Example #1
def _resize_embeddings_if_needed(model: transformers.PreTrainedModel,
                                 vocab_size: int) -> None:
    if vocab_size is not None:
        mean_emb = model.base_model.wte.weight.data.mean(0)
        old_size = model.base_model.wte.weight.data.size()[0]
        n_new = vocab_size - old_size

        if n_new < 0:
            raise ModelLoadingError(
                "Can't resize embeddings: new vocab size can not be less than "
                "the old embeddings number (old vocab size).")

        model.resize_token_embeddings(vocab_size)
        idx = vocab_size - n_new
        model.base_model.wte.weight.data[idx:] = mean_emb.unsqueeze(0)
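# A minimal usage sketch for the helper above (the model/tokenizer choices are
# only illustrative; the helper assumes a GPT-2-style `base_model.wte` embedding,
# as in the implementation).
import transformers

demo_tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2")
demo_model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")
# Add task-specific special tokens, then grow the embedding matrix to match,
# initialising the new rows to the mean of the existing embeddings.
demo_tokenizer.add_special_tokens({"pad_token": "<pad>", "sep_token": "<sep>"})
_resize_embeddings_if_needed(demo_model, vocab_size=len(demo_tokenizer))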
def train(args, train_dataset, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples,
                            batch_first=True,
                            padding_value=tokenizer.pad_token_id)

    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    model = model.module if hasattr(
        model,
        "module") else model  # Take care of distributed/parallel training
    model.resize_token_embeddings(len(tokenizer))

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (args.model_name_or_path and os.path.isfile(
            os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt"))):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    # logger.info("***** Running training *****")
    # logger.info("  Num examples = %d", len(train_dataset))
    # logger.info("  Num Epochs = %d", args.num_train_epochs)
    # logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    # logger.info(
    #     "  Total train batch size (w. parallel, distributed & accumulation) = %d",
    #     args.train_batch_size
    #     * args.gradient_accumulation_steps
    #     * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    # )
    # logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    # logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to the global_step of the last saved checkpoint from the model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            # logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            # logger.info("  Continuing training from epoch %d", epochs_trained)
            # logger.info("  Continuing training from global step %d", global_step)
            # logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            pass  # no checkpoint step suffix found; start fine-tuning from step 0
            # logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0

    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            inputs, labels = mask_tokens(batch, tokenizer,
                                         args) if args.mlm else (batch, batch)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            outputs = model(inputs,
                            masked_lm_labels=labels) if args.mlm else model(
                                inputs, labels=labels)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if (
                            args.local_rank == -1
                            and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir,
                        "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    # logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    # logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if 0 < args.max_steps < global_step:
                epoch_iterator.close()
                break
        if 0 < args.max_steps < global_step:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
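The `train()` above pulls every hyperparameter from an `args` object (typically built with `argparse`). As a reference, a minimal sketch of the attributes it reads might look like the following; the values are illustrative, and `train_dataset`, `model`, and `tokenizer` are assumed to be constructed elsewhere.

from argparse import Namespace
import torch

# Illustrative values only; these are the attributes read by train() above.
args = Namespace(
    local_rank=-1,
    n_gpu=torch.cuda.device_count(),
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    per_gpu_train_batch_size=4,
    gradient_accumulation_steps=1,
    num_train_epochs=3,
    max_steps=-1,                 # <= 0 means "train for num_train_epochs"
    learning_rate=5e-5,
    weight_decay=0.01,
    adam_epsilon=1e-8,
    warmup_steps=0,
    max_grad_norm=1.0,
    fp16=False,
    fp16_opt_level="O1",
    mlm=False,
    model_name_or_path="gpt2",
    output_dir="./checkpoints",
    logging_steps=500,
    save_steps=500,
    evaluate_during_training=False,
)

# global_step, average_loss = train(args, train_dataset, model, tokenizer)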
Example #3
def train(args, train_dataset, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples,
                            batch_first=True,
                            padding_value=tokenizer.pad_token_id)

    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    model = model.module if hasattr(
        model,
        "module") else model  # Take care of distributed/parallel training
    model.resize_token_embeddings(len(tokenizer))

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (args.model_name_or_path and os.path.isfile(
            os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt"))):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    non_multi_model = model
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(non_multi_model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to the global_step of the last saved checkpoint from the model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0

    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility
    best_perplexity = float('inf')
    for i, epoch in enumerate(train_iterator):
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])

        if args.local_rank != -1:
            train_sampler.set_epoch(epoch)

        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            inputs, labels = mask_tokens(batch, tokenizer,
                                         args) if args.mlm else (batch, batch)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            outputs = model(inputs,
                            masked_lm_labels=labels) if args.mlm else model(
                                inputs, labels=labels)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

        if args.do_eval:
            file_path = Path(args.data_dir, args.eval_data_file)

            out_file_path = Path(args.data_dir,
                                 "output_" + args.eval_data_file)
            id_to_json_map = {}
            with open(file_path, encoding="utf-8") as f:
                eval_loss = 0.0
                nb_eval_steps = 0
                for line in tqdm(f, desc="Evaluating"):
                    out_json = {}
                    line = json.loads(line)
                    example_id = line.get("example_id")
                    question_text = line.get("question_text")

                    prompt_text = question_text + " " + args.sep_token + " "
                    encoded_prompt = tokenizer.encode(prompt_text,
                                                      add_special_tokens=False,
                                                      return_tensors="pt")
                    encoded_prompt = encoded_prompt.to(args.device)

                    output_sequences = non_multi_model.generate(
                        input_ids=encoded_prompt,
                        max_length=args.length + len(encoded_prompt[0]),
                        temperature=args.temperature,
                        top_k=args.k,
                        top_p=args.p,
                        repetition_penalty=args.repetition_penalty,
                        do_sample=True,
                        num_return_sequences=args.num_return_sequences,
                    )
                    if len(output_sequences.shape) > 2:
                        output_sequences.squeeze_()

                    generated_sequences = []

                    for generated_sequence_idx, generated_sequence in enumerate(
                            output_sequences):
                        # print("=== GENERATED SEQUENCE {} ===".format(generated_sequence_idx + 1))
                        # generated_sequence = output_sequences[0]

                        generated_sequence = generated_sequence.tolist()

                        # Decode text
                        text = tokenizer.decode(
                            generated_sequence,
                            clean_up_tokenization_spaces=True)

                        # Remove all text after the stop token
                        if args.stop_token:
                            text = text[:text.find(args.stop_token)]

                        # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
                        total_sequence = (prompt_text + text[len(
                            tokenizer.decode(encoded_prompt[0],
                                             clean_up_tokenization_spaces=True)
                        ):])

                        # print(total_sequence)

                        out_json["journaling_input"], out_json[
                            "reflection_output"] = total_sequence.split(
                                args.sep_token)[:2]

                        sample_dataset = GenerateTextDataset(
                            tokenizer, total_sequence, args.block_size)

                        def collate(examples: List[torch.Tensor]):
                            if tokenizer._pad_token is None:
                                return pad_sequence(examples, batch_first=True)
                            return pad_sequence(
                                examples,
                                batch_first=True,
                                padding_value=tokenizer.pad_token_id)

                        eval_sampler = SequentialSampler(sample_dataset)
                        eval_dataloader = DataLoader(sample_dataset,
                                                     sampler=eval_sampler,
                                                     batch_size=1,
                                                     collate_fn=collate)

                        # `model` is already wrapped in DataParallel above when n_gpu > 1
                        model_lm = model

                        model_lm.eval()

                        for batch in eval_dataloader:
                            inputs, labels = mask_tokens(
                                batch, tokenizer,
                                args) if args.mlm else (batch, batch)
                            inputs = inputs.to(args.device)
                            labels = labels.to(args.device)

                            with torch.no_grad():
                                outputs = model_lm(inputs,
                                                   masked_lm_labels=labels
                                                   ) if args.mlm else model_lm(
                                                       inputs, labels=labels)
                                lm_loss = outputs[0]
                                example_loss = lm_loss.mean().item()
                                eval_loss += example_loss
                            nb_eval_steps += 1

                        perplexity = torch.exp(
                            torch.tensor(example_loss)).item()
                        # print(perplexity)
                        out_json["perplexity"] = perplexity

                        example_id += "-" + str(generated_sequence_idx)
                        id_to_json_map[example_id] = json.dumps(
                            out_json, ensure_ascii=False)

                    # result = {"perplexity": perplexity}

                eval_loss = eval_loss / nb_eval_steps
                total_perplexity = torch.exp(torch.tensor(eval_loss))
                logger.info(f"total_loss:: {eval_loss}")
                logger.info(f"total_perplexity:: {total_perplexity}")
                if total_perplexity < best_perplexity:
                    logger.info(
                        f"Current best epoch::: {i}, with perplexity:: {total_perplexity}"
                    )
                    best_perplexity = total_perplexity

                    with open(out_file_path, "w+",
                              encoding="utf-8") as out_file:
                        for _, out_json in id_to_json_map.items():
                            out_file.write(out_json + "\n")

                    model_to_save = model.module if hasattr(
                        model,
                        'module') else model  # Only save the model it-self

                    # If we save using the predefined names, we can load using `from_pretrained`
                    output_model_file = os.path.join(args.output_dir,
                                                     WEIGHTS_NAME)
                    output_config_file = os.path.join(args.output_dir,
                                                      CONFIG_NAME)
                    torch.save(model_to_save.state_dict(), output_model_file)
                    model_to_save.config.to_json_file(output_config_file)
                    tokenizer.save_vocabulary(args.output_dir)

    return global_step, tr_loss / global_step
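# The evaluation block above reports perplexity as exp(mean cross-entropy).
# A tiny self-contained sketch of that relationship (values are illustrative):
import torch

def _perplexity_from_loss(mean_lm_loss: float) -> float:
    # Perplexity is the exponential of the average per-token negative log-likelihood.
    return torch.exp(torch.tensor(mean_lm_loss)).item()

# e.g. a mean LM loss of 3.0 corresponds to a perplexity of about 20.09
assert abs(_perplexity_from_loss(3.0) - 20.0855) < 1e-3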
def _resize_embeddings_if_needed(model: transformers.PreTrainedModel,
                                 vocab_size: int) -> None:
    if vocab_size is not None:
        model.resize_token_embeddings(vocab_size)
def train(args, train_dataset, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)

    def collate(examples: List[Dict]):
        inputs, labels = [], []
        for sample in examples:
            inputs.append(sample['inputs'])
            labels.append(sample['label'])
        labels = torch.LongTensor(labels)

        if tokenizer._pad_token is None:
            return {
                'inputs': pad_sequence(inputs, batch_first=True),
                'label': labels
            }
        return {
            'inputs':
            pad_sequence(inputs,
                         batch_first=True,
                         padding_value=tokenizer.pad_token_id),
            'label':
            labels
        }

    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    model = model.module if hasattr(
        model,
        "module") else model  # Take care of distributed/parallel training
    model.resize_token_embeddings(len(tokenizer))

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (args.model_name_or_path and os.path.isfile(
            os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt"))):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to the global_step of the last saved checkpoint from the model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0

    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            inputs, labels = batch['inputs'], batch['label']
            inputs, labels = inputs.to(args.device), labels.to(args.device)
            model.train()

            outputs = model(inputs, labels=labels)
            logits = outputs[1]

            loss_fct = nn.CrossEntropyLoss(reduction='none')
            loss_raw = loss_fct(logits, labels)

            # Get the probability predicted for gold labels
            gt_pred_prb = torch.gather(F.softmax(logits, dim=-1),
                                       dim=-1,
                                       index=labels.unsqueeze(-1)).squeeze(1)
            # Focal Loss
            alpha_matrix = torch.Tensor(
                [4067, 2632, 1583])  # magic number - label distribution

            alpha_matrix = torch.gather(
                (alpha_matrix / alpha_matrix.max()).to(logits.device),
                dim=0,
                index=labels)

            loss = (loss_raw * torch.pow(
                (1 - gt_pred_prb), 2) * alpha_matrix).mean()

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:
                        # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(
                            args, model.module,
                            tokenizer) if args.n_gpu > 1 else evaluate(
                                args, model, tokenizer)
                    logging_loss = tr_loss

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir,
                        "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size,
                                      collate_fn=collate)

    # if args.local_rank in [-1, 0]:
    #     tb_writer.close()

    return global_step, tr_loss / global_step
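# A self-contained sketch of the class-weighted focal loss computed inside the
# training loop above (gamma fixed at 2 and the same per-class counts; both
# values are taken from the code, not tuned here).
import torch
import torch.nn as nn
import torch.nn.functional as F

def _focal_loss(logits: torch.Tensor, labels: torch.Tensor,
                class_counts=(4067, 2632, 1583), gamma: float = 2.0) -> torch.Tensor:
    # Per-example cross-entropy, kept unreduced.
    ce = nn.CrossEntropyLoss(reduction='none')(logits, labels)
    # Probability the model assigns to the gold label of each example.
    gt_prob = torch.gather(F.softmax(logits, dim=-1), dim=-1,
                           index=labels.unsqueeze(-1)).squeeze(-1)
    # Class weights proportional to label frequency, normalised by the largest class.
    alpha = torch.tensor(class_counts, dtype=torch.float, device=logits.device)
    alpha = torch.gather(alpha / alpha.max(), dim=0, index=labels)
    # Down-weight easy examples via (1 - p_t) ** gamma, then apply class weights.
    return (ce * (1 - gt_prob) ** gamma * alpha).mean()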
def train(args, train_dataset, tb_writer, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer,
          saved_model_cur: str) -> Tuple[int, float]:
    """ Train the model """

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples,
                            batch_first=True,
                            padding_value=tokenizer.pad_token_id)

    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    model = model.module if hasattr(
        model,
        "module") else model  # Take care of distributed/parallel training
    model.resize_token_embeddings(len(tokenizer))

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (saved_model_cur
            and os.path.isfile(os.path.join(saved_model_cur, "optimizer.pt"))
            and os.path.isfile(os.path.join(saved_model_cur, "scheduler.pt"))):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(saved_model_cur, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(saved_model_cur, "scheduler.pt")))

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if saved_model_cur and os.path.exists(saved_model_cur):
        try:
            args_dict = vars(
                torch.load(os.path.join(saved_model_cur, "training_args.bin")))
            global_step = args_dict['global_step']
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except Exception:
            logger.info("  Starting fine-tuning.")

    # all steps accumulated loss, loss per epoch, loss per gradient update step.
    tr_loss, epoch_loss, logging_loss, val_loss_previous_best = 0.0, 0.0, 0.0, 1e8

    model.zero_grad()
    # trange(i) is a special optimised instance of tqdm(range(i))
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility

    val_loss, perplexity_val = evaluate(args,
                                        model,
                                        tokenizer,
                                        str(global_step),
                                        debug=args.debug)
    tb_writer.add_scalar("val/perplexity", perplexity_val, global_step)
    tb_writer.add_scalar("val/loss", val_loss, global_step)

    print(f"the val loss is {val_loss} and perplexity is {preplexity_val}.")
    print("one iteration has been finished.")
Example #7
def train(args, train_dataset, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """

    # GX: you can make changes here.
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter(log_dir='tb_logs_' + args.descr_string)
        log_string_config = '  '.join(
            [k + ':' + str(v) for k, v in vars(args).items()])
        tb_writer.add_text('experiment args', log_string_config, 0)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples,
                            batch_first=True,
                            padding_value=tokenizer.pad_token_id)

    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    model = model.module if hasattr(
        model,
        "module") else model  # Take care of distributed/parallel training
    model.resize_token_embeddings(len(tokenizer))

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (args.model_name_or_path and os.path.isfile(
            os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt"))):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            args_dict = vars(
                torch.load(
                    os.path.join(args.model_name_or_path,
                                 "training_args.bin")))
            global_step = args_dict['global_step']
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except Exception:
            logger.info("  Starting fine-tuning.")

    # all steps accumulated loss, loss per epoch, loss per gradient update step.
    tr_loss, epoch_loss, logging_loss, val_loss_previous_best = 0.0, 0.0, 0.0, 1e8

    model.zero_grad()
    # trange(i) is a special optimised instance of tqdm(range(i))
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            inputs, labels = mask_tokens(batch, tokenizer,
                                         args) if args.mlm else (batch, batch)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            outputs = model(inputs,
                            masked_lm_labels=labels) if args.mlm else model(
                                inputs, labels=labels)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                ## clip gradients before the optimizer step
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()  # update parameters once per gradient_accumulation_steps
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1
                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    # doing the things on the perplexity per batch.
                    cur_gas_loss = (
                        tr_loss - logging_loss
                    ) / args.logging_steps  # average loss over the last logging_steps updates
                    cur_gas_perplexity = np.exp(cur_gas_loss)
                    tb_writer.add_scalar("train/iter_loss", cur_gas_loss,
                                         global_step)
                    tb_writer.add_scalar("train/iter_perplexity",
                                         cur_gas_perplexity, global_step)
                    tb_writer.add_scalar("progress/lr",
                                         scheduler.get_lr()[0], global_step)
                    logging_loss = tr_loss

                    ## save and evaluation.
                    current_prefix = "checkpoint-current"
                    best_prefix = "checkpoint-best"
                    args.global_step = global_step

                    # in case either torch.save or tokenizer.save does not overwrite
                    current_output_dir = os.path.join(args.output_dir,
                                                      current_prefix)
                    if os.path.isdir(current_output_dir):
                        shutil.rmtree(current_output_dir)
                    os.makedirs(current_output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(
                        current_output_dir)  # build-in save function.
                    tokenizer.save_pretrained(current_output_dir)
                    torch.save(
                        args,
                        os.path.join(current_output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s",
                                current_output_dir)
                    torch.save(
                        optimizer.state_dict(),
                        os.path.join(current_output_dir, "optimizer.pt"))
                    torch.save(
                        scheduler.state_dict(),
                        os.path.join(current_output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                current_output_dir)

                    # saving the best model.
                    val_loss, perplexity_val = evaluate(args,
                                                        model,
                                                        tokenizer,
                                                        str(global_step),
                                                        debug=args.debug)
                    tb_writer.add_scalar("val/perplexity", perplexity_val,
                                         global_step)
                    tb_writer.add_scalar("val/loss", val_loss, global_step)

                    if val_loss < val_loss_previous_best:
                        logger.info(
                            "The current model surpasses the previous best.")
                        best_output_dir = os.path.join(args.output_dir,
                                                       best_prefix)
                        if os.path.isdir(best_output_dir):
                            shutil.rmtree(best_output_dir)
                        os.makedirs(best_output_dir)
                        copy_tree(current_output_dir, best_output_dir)
                        logger.info(
                            "Saving the best model by copying  %s to %s",
                            current_output_dir, best_output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break

            if args.debug and step == 20:
                print(global_step, step)
                break

        ## Record the average loss per step and the perplexity for this epoch.
        cur_epoch_loss = (tr_loss - epoch_loss) / step
        cur_epoch_perplexity = np.exp(cur_epoch_loss)
        epoch_loss = tr_loss
        tb_writer.add_scalar("train/epoch_loss", cur_epoch_loss, global_step)
        tb_writer.add_scalar("train/epoch_perplexity", cur_epoch_perplexity,
                             global_step)

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
Example #8
0
def train(args, data, datasets, model: PreTrainedModel, original_model,
          tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    train_datasets = datasets['train']
    dev_datasets = datasets['dev']

    train_dataloaders, train_example_num, train_distribution = create_dataloader(
        args, train_datasets, tokenizer, train=True)
    dev_dataloaders, dev_example_num, dev_distribution = create_dataloader(
        args, dev_datasets, tokenizer, train=False)

    train_iter_num = sum(
        [len(dataloader) for dataloader in train_dataloaders.values()])
    dev_iter_num = sum(
        [len(dataloader) for dataloader in dev_dataloaders.values()])

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            train_iter_num // args.gradient_accumulation_steps) + 1
    else:
        t_total = train_iter_num // args.gradient_accumulation_steps * args.num_train_epochs

    model = model.module if hasattr(
        model,
        "module") else model  # Take care of distributed/parallel training
    model.resize_token_embeddings(len(tokenizer))

    original_model = original_model.module if hasattr(
        original_model, "module"
    ) else original_model  # Take care of distributed/parallel training
    original_model.resize_token_embeddings(len(tokenizer))

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (args.model_name_or_path and os.path.isfile(
            os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt"))):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
        original_model = torch.nn.DataParallel(original_model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)
        original_model = torch.nn.parallel.DistributedDataParallel(
            original_model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", train_example_num)
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    best_loss = float('inf')
    best_step = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (train_iter_num //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                train_iter_num // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    model.zero_grad()
    original_model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])

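    # Loss helpers used below: inner_product measures alignment between hidden
    # states and an attribute vector (summed over the hidden dimension, then
    # averaged); mean_square is a squared difference averaged over the given
    # dimension and then over the batch.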
    def inner_product(x, y):
        return torch.mean(torch.sum(y * x, 3))

    def mean_square(x, y, idx):
        return torch.mean(torch.mean((y - x)**2, idx))
        #return torch.mean(torch.sum((y - x) ** 2, 3))

    def save_best_model(best_loss, best_step, dev_dataloaders):
        eval_loss = None
        if (
                args.local_rank == -1 and args.evaluate_during_training
        ):  # Only evaluate when single GPU otherwise metrics may not average well
            eval_loss = evaluate(model, attributes_hiddens, dev_dataloaders)
            #eval_loss = evaluate(args, model, original_model, dev_dataloaders, dev_example_num, dev_distribution, criterion_mse, criterion_ip, feminine_hiddens, masculine_hiddens, gender_hiddens)
            logger.info(" global_step = %s, evaluate loss = %s", global_step,
                        eval_loss)
            tb_writer.add_scalar("eval_loss", eval_loss, global_step)
        tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)

        # Guard against the case where evaluation did not run on this call.
        if eval_loss is not None and eval_loss < best_loss:
            best_loss = eval_loss
            best_step = global_step
            checkpoint_prefix = "checkpoint"
            # Save model checkpoint
            output_dir = os.path.join(args.output_dir, "checkpoint-best")
            os.makedirs(output_dir, exist_ok=True)
            model_to_save = (
                model.module if hasattr(model, "module") else model
            )  # Take care of distributed/parallel training
            model_to_save.save_pretrained(output_dir)
            tokenizer.save_pretrained(output_dir)

            torch.save(args, os.path.join(output_dir, "training_args.bin"))
            logger.info("Saving model checkpoint to %s", output_dir)

            #_rotate_checkpoints(args, checkpoint_prefix)

            torch.save(optimizer.state_dict(),
                       os.path.join(output_dir, "optimizer.pt"))
            torch.save(scheduler.state_dict(),
                       os.path.join(output_dir, "scheduler.pt"))
            logger.info("Saving optimizer and scheduler states to %s",
                        output_dir)
        logger.info(" best_step = %s, best loss = %s", best_step, best_loss)

        return best_loss, best_step

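    # Returns the tuple of hidden states from the model's backbone, selecting
    # the architecture-specific attribute (roberta/bert/albert/distilbert/
    # electra/transformer) according to args.model_type.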
    def get_hiddens_of_model(input):
        model.zero_grad()
        if args.model_type == 'roberta':
            _, _, hiddens = model.roberta(input)
        elif args.model_type == 'bert':
            _, _, hiddens = model.bert(input)
        elif args.model_type == 'albert':
            _, _, hiddens = model.albert(input)
        elif args.model_type == 'dbert':
            _, hiddens = model.distilbert(input)
        elif args.model_type == 'electra':
            _, hiddens = model.electra(input)
        elif args.model_type == 'gpt2':
            _, _, hiddens = model.transformer(input)
        elif args.model_type == 'gpt':
            _, hiddens = model.transformer(input)

        return hiddens

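    # Builds one averaged, detached hidden-state vector per non-neutral
    # attribute from a single batch per attribute; these vectors serve as the
    # debiasing targets for the inner-product loss on neutral examples.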
    def attribute_vector_example():
        attributes_hiddens = {f'attribute{i}': [] for i in range(2)}

        dataloaders, _, distribution = create_dataloader(args,
                                                         train_datasets,
                                                         tokenizer,
                                                         train=True)
        for key in distribution:
            if key != 'neutral':
                inputs, labels = next(dataloaders[key])
                inputs = inputs.to(args.device)
                hiddens = get_hiddens_of_model(inputs)
                hiddens = torch.stack(hiddens, 2)
                if labels.size(1) > 1:
                    onehot = torch.eye(hiddens.size(1))
                    zeros = torch.zeros(1, onehot.size(0))
                    onehot = torch.cat((zeros, onehot), 0)
                    onehot = onehot[labels]
                    onehot = torch.sum(onehot, 1)
                    onehot = onehot.view(hiddens.size(0), -1, 1, 1)
                else:
                    onehot = torch.eye(hiddens.size(1))[labels].view(
                        hiddens.size(0), -1, 1, 1)
                onehot = onehot.to(args.device)
                attributes_hiddens[key].append(
                    torch.sum(hiddens * onehot, 1) / labels.size(1))

        # data['train']['example'] also includes the 'neutral' class, hence the -1 below.
        attribute_size = len(data['train']['example'])
        for i in range(attribute_size - 1):
            attributes_hiddens[f'attribute{i}'] = torch.mean(
                torch.cat(attributes_hiddens[f'attribute{i}'], 0),
                0).detach().unsqueeze(0)

        return attributes_hiddens

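    # One forward pass for the given key: runs the model (and, for non-neutral
    # keys, the frozen original model) on the next batch, then computes either
    # the attribute inner-product loss (neutral batches) or the regularization
    # loss against the original model's hidden states (attribute batches).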
    def forward(attributes_hiddens, dataloaders, key):
        inputs = next(dataloaders[key])
        if len(inputs) == 2:
            inputs, labels = inputs
            labels = labels.to(args.device)
        else:
            labels = None
        inputs = inputs.to(args.device)
        if args.model_type == 'roberta':
            final_layer_hiddens, first_token_hidden, all_layer_hiddens = model.roberta(
                inputs)
            if 'neutral' != key:
                with torch.no_grad():
                    final_layer_original_hiddens, _, all_layer_original_hiddens = original_model.roberta(
                        inputs)
                if args.token_loss:
                    token_predicts = model.lm_head(final_layer_hiddens)
                    token_original = original_model.lm_head(
                        final_layer_original_hiddens)
        elif args.model_type == 'bert':
            final_layer_hiddens, first_token_hidden, all_layer_hiddens = model.bert(
                inputs)
            if 'neutral' != key:
                with torch.no_grad():
                    final_layer_original_hiddens, _, all_layer_original_hiddens = original_model.bert(
                        inputs)
                if args.token_loss:
                    token_predicts = model.cls(final_layer_hiddens)
                    token_original = original_model.cls(
                        final_layer_original_hiddens)
        elif args.model_type == 'albert':
            final_layer_hiddens, first_token_hidden, all_layer_hiddens = model.albert(
                inputs)
            if 'neutral' != key:
                with torch.no_grad():
                    final_layer_original_hiddens, _, all_layer_original_hiddens = original_model.albert(
                        inputs)
                if args.token_loss:
                    token_predicts = model.classifier(final_layer_hiddens)
                    token_original = original_model.classifier(
                        final_layer_original_hiddens)
        elif args.model_type == 'dbert':
            final_layer_hiddens, all_layer_hiddens = model.distilbert(inputs)
            if 'neutral' != key:
                with torch.no_grad():
                    final_layer_original_hiddens, all_layer_original_hiddens = original_model.distilbert(
                        inputs)
                if args.token_loss:
                    token_predicts = model.classifier(final_layer_hiddens)
                    token_original = original_model.classifier(
                        final_layer_original_hiddens)
        elif args.model_type == 'electra':
            final_layer_hiddens, all_layer_hiddens = model.electra(inputs)
            if 'neutral' != key:
                with torch.no_grad():
                    final_layer_original_hiddens, all_layer_original_hiddens = original_model.electra(
                        inputs)
                if args.token_loss:
                    hiddens = model.generator_predictions(final_layer_hiddens)
                    token_predicts = model.generator_lm_head(hiddens)
                    original_hiddens = original_model.generator_predictions(
                        final_layer_original_hiddens)
                    token_original = original_model.generator_lm_head(
                        original_hiddens)
        elif args.model_type == 'gpt2':
            final_layer_hiddens, first_token_hidden, all_layer_hiddens = model.transformer(
                inputs)
            if 'neutral' != key:
                with torch.no_grad():
                    final_layer_original_hiddens, _, all_layer_original_hiddens = original_model.transformer(
                        inputs)
                if args.token_loss:
                    token_predicts = model.lm_head(final_layer_hiddens)
                    token_original = original_model.lm_head(
                        final_layer_original_hiddens)
        elif args.model_type == 'gpt':
            final_layer_hiddens, all_layer_hiddens = model.transformer(inputs)
            if 'neutral' != key:
                with torch.no_grad():
                    final_layer_original_hiddens, all_layer_original_hiddens = original_model.transformer(
                        inputs)
                if args.token_loss:
                    token_predicts = model.lm_head(final_layer_hiddens)
                    token_original = original_model.lm_head(
                        final_layer_original_hiddens)

        all_layer_hiddens = torch.stack(all_layer_hiddens, 2)
        if 'neutral' != key:
            all_original_hiddens = torch.stack(all_layer_original_hiddens, 2)
            all_original_hiddens = all_original_hiddens.detach()
            if args.token_loss:
                original_hiddens = original_hiddens.detach()  # only defined in the 'electra' branch
                token_original = token_original.detach()
        if args.debias_layer == 'all':
            target_layer_hiddens = all_layer_hiddens
            target_original_hiddens = all_layer_hiddens
        else:
            if args.debias_layer == 'first':
                idx = 0
            elif args.debias_layer == 'last':
                idx = -1
            target_layer_hiddens = all_layer_hiddens[:, :, idx]
            target_layer_hiddens = target_layer_hiddens.unsqueeze(2)
            if 'neutral' != key:
                target_original_hiddens = all_original_hiddens[:, :, idx]
                target_original_hiddens = target_original_hiddens.unsqueeze(2)
            else:
                attributes_hiddens = {
                    key: value[:, idx, :].unsqueeze(1)
                    for key, value in attributes_hiddens.items()
                }

        if args.loss_target == 'sentence' or labels is None:
            attributes_hiddens = {
                key: value.unsqueeze(1)
                for key, value in attributes_hiddens.items()
            }
        #elif args.loss_target == 'token' and key == 'neutral':
        elif args.loss_target == 'token':
            if labels.size(1) > 1:
                onehot = torch.eye(target_layer_hiddens.size(1))
                zeros = torch.zeros(1, onehot.size(0))
                onehot = torch.cat((zeros, onehot), 0)
                onehot = onehot[labels]
                onehot = torch.sum(onehot, 1)
                onehot = onehot.view(target_layer_hiddens.size(0), -1, 1, 1)
            else:
                onehot = torch.eye(target_layer_hiddens.size(1))[labels].view(
                    target_layer_hiddens.size(0), -1, 1, 1)
            onehot = onehot.to(args.device)
            target_layer_hiddens = torch.sum(target_layer_hiddens * onehot,
                                             1).unsqueeze(1) / labels.size(1)
            if 'neutral' != key:
                target_original_hiddens = torch.sum(
                    target_original_hiddens * onehot,
                    1).unsqueeze(1) / labels.size(1)
            else:
                attributes_hiddens = {
                    key: value.expand(target_layer_hiddens.size(0), 1,
                                      value.size(1), value.size(2))
                    for key, value in attributes_hiddens.items()
                }

        if 'neutral' == key:
            loss = 0
            for attribute_hiddens in attributes_hiddens.values():
                tmp_loss = criterion_ip(target_layer_hiddens,
                                        attribute_hiddens)
                if args.square_loss:
                    tmp_loss = tmp_loss**2
                tmp_loss *= alpha
                loss += tmp_loss
        else:
            #loss = criterion_ms(target_layer_hiddens, target_original_hiddens)
            loss = criterion_ms(all_layer_hiddens, all_original_hiddens, 3)
            if args.token_loss:
                loss += criterion_ms(token_predicts, token_original, 2)
                #loss += criterion_ms(hiddens, original_hiddens, 2)
            loss *= beta

        return loss

    #def evaluate(args, model: PreTrainedModel, original_model, dev_dataloaders, dev_example_num, dev_distribution, criterion_mse, criterion_ip, feminine_hiddens, masculine_hiddens, gender_hiddens, prefix="") -> Dict:
    def evaluate(model, attributes_hiddens, dev_dataloaders, prefix=""):
        # Compute the debiasing loss on the dev dataloaders (no parameter updates)
        eval_output_dir = args.output_dir

        if args.local_rank in [-1, 0]:
            os.makedirs(eval_output_dir, exist_ok=True)

        args.eval_batch_size = args.per_gpu_eval_batch_size * max(
            1, args.n_gpu)
        # Note that DistributedSampler samples randomly

        # multi-gpu evaluate
        if args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Eval!
        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info("  Num examples = %d", dev_example_num)
        logger.info("  Batch size = %d", args.eval_batch_size)
        eval_loss = 0.0
        model.eval()
        #criterion.eval()

        for key in tqdm(dev_distribution):
            with torch.no_grad():
                loss = forward(attributes_hiddens, dev_dataloaders, key)

                eval_loss += loss.item()

                model.zero_grad()
                original_model.zero_grad()

        output_eval_file = os.path.join(eval_output_dir, prefix,
                                        "eval_results.txt")
        '''
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results {} *****".format(prefix))
            logger.info("  Loss = %s", eval_loss)
            writer.write("Loss = %s\n" % (eval_loss))
        '''

        return eval_loss

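    # Loss criteria: mean_square regularizes against the frozen original model,
    # inner_product drives the debiasing objective; the original model stays in
    # eval mode and its parameters are never updated by the optimizer.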
    #criterion_ms = torch.nn.MSELoss()
    criterion_ms = mean_square
    #criterion.train()
    criterion_ip = inner_product
    original_model.eval()

    alpha, beta = args.weighted_loss
    alpha = float(alpha)
    beta = float(beta)

    train_loss = 0.0

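    # Each training step samples one key (an attribute or 'neutral') from the
    # shuffled training distribution and accumulates the corresponding loss;
    # the attribute vectors are re-estimated at the start of every epoch.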
    for _ in train_iterator:

        random.shuffle(train_distribution)
        epoch_iterator = tqdm(train_distribution,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])

        model.eval()
        with torch.no_grad():
            attributes_hiddens = attribute_vector_example()

        for step, key in enumerate(epoch_iterator):
            model.train()

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            loss = forward(attributes_hiddens, train_dataloaders, key)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            train_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                original_model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logger.info(" global_step = %s, train loss = %s",
                                global_step, train_loss)
                    train_loss = 0.0
                    # Log metrics
                    best_loss, best_step = save_best_model(
                        best_loss, best_step, dev_dataloaders)
                    dev_dataloaders, dev_example_num, dev_distribution = create_dataloader(
                        args, dev_datasets, tokenizer, train=False)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
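            # Rebuild the training dataloaders after every step so that next()
            # can always draw a batch for whichever key comes up next.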
            train_dataloaders, train_example_num, train_distribution = create_dataloader(
                args, train_datasets, tokenizer, train=True)

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    dev_dataloaders, dev_example_num, dev_distribution = create_dataloader(
        args, dev_datasets, tokenizer, train=False)
    best_loss, best_step = save_best_model(best_loss, best_step,
                                           dev_dataloaders)

    if args.local_rank in [-1, 0]:
        tb_writer.close()