Example #1
def train_model(df: pd.DataFrame, tokenizer: transformers.PreTrainedTokenizer,
                model: transformers.PreTrainedModel, steps: int, batch_size: int, save_path: str) -> None:
    
    
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    #device = torch.device('cpu')
    model.to(device)
    model.train()

    optim = AdamW(model.parameters(), lr=5e-5)

    losses = []
    for step in trange(steps):
    
        optim.zero_grad()

        sample = df.sample(batch_size)

        X = sample['articles'].tolist()
        y = sample['labels'].tolist()
        
        inputs = tokenizer(X, return_tensors='pt', padding=True, truncation=True)
        
        input_ids = inputs['input_ids'].to(device)
        attention_mask = inputs['attention_mask'].to(device)
        labels = torch.tensor(y).unsqueeze(1).to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)

        loss = outputs.loss
        losses.append(loss.item())  # store a float so the computation graph is not kept alive
        
        if (step + 1) % 100 == 0:
            print(f'Step: {step + 1} Loss: {sum(losses)/len(losses)}')
            send_message(f'Step: {step + 1} Loss: {sum(losses)/len(losses)}')
            losses = []
            
        loss.backward()
        optim.step()

    model.save_pretrained(save_path)
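
A minimal driver for the train_model function above, assuming it is in scope together with its imports; the DataFrame contents and the distilbert checkpoint below are illustrative assumptions, not part of the original snippet.

# Hypothetical smoke-test driver for train_model; data and checkpoint are assumed.
import pandas as pd
import transformers

df = pd.DataFrame({
    "articles": ["stocks rallied on strong earnings", "the home team lost the final"],
    "labels": [0, 1],
})
tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2)

# A short run only; a real fine-tune would use far more steps.
train_model(df, tokenizer, model, steps=10, batch_size=2, save_path="./model_out")
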
Example #2
def train_step(model: PreTrainedModel, batched_input: Dict,
               device: torch.device):
    model.train()
    for k, v in batched_input.items():
        batched_input[k] = v.to(device)

    # check device
    if batched_input["input_ids"].device.type != "cuda":
        logging.info(
            f"Warning: inputs not on cuda; they are on {batched_input['input_ids'].device.type}."
        )

    # INPUT to model
    # GPT2 has labels param and BERT has masked_lm_labels param
    # both have input_ids, attention_mask
    output = model(**batched_input)

    # OUTPUT
    # output[0] is loss
    # output[1] is prediction scores (logits)

    return (output[0], output[1])
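
A hypothetical call to train_step above for a causal LM; the GPT-2 checkpoint, the pad-token workaround, and the labels construction are assumptions made for illustration.

# Hypothetical usage of train_step; checkpoint and batch layout are assumed.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

enc = tokenizer(["hello world", "a slightly longer example"],
                return_tensors="pt", padding=True)
batch = {
    "input_ids": enc["input_ids"],
    "attention_mask": enc["attention_mask"],
    # GPT-2 shifts labels internally; pad positions would normally be masked to -100
    "labels": enc["input_ids"].clone(),
}
loss, logits = train_step(model, batch, device)
loss.backward()  # the optimizer step is left to the caller in this split
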
Example #3
def train(args, train_dataset, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples,
                            batch_first=True,
                            padding_value=tokenizer.pad_token_id)

    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (args.model_name_or_path and os.path.isfile(
            os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt"))):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to the global_step of the last saved checkpoint from the model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0

    model_to_resize = model.module if hasattr(
        model,
        "module") else model  # Take care of distributed/parallel training
    model_to_resize.resize_token_embeddings(len(tokenizer))

    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            inputs, labels = mask_tokens(batch, tokenizer,
                                         args) if args.mlm else (batch, batch)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            outputs = model(inputs,
                            masked_lm_labels=labels) if args.mlm else model(
                                inputs, labels=labels)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if (
                            args.local_rank == -1
                            and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir,
                        "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
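
The no_decay parameter grouping shared by Examples #3, #4, #6, and #7 can be exercised on its own. A standalone sketch follows; TinyBlock is a stand-in module whose attribute names mirror BERT-style layers, and torch.optim.AdamW replaces the transformers AdamW the examples import.

# Standalone sketch of the weight-decay grouping pattern (stand-in module).
import torch
from torch import nn

class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(8, 8)
        self.LayerNorm = nn.LayerNorm(8)  # named so "LayerNorm.weight" matches

    def forward(self, x):
        return self.LayerNorm(self.dense(x))

model = TinyBlock()
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {"params": [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)],
     "weight_decay": 0.01},
    {"params": [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)],
     "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=5e-5, eps=1e-8)
print([len(g["params"]) for g in optimizer.param_groups])  # [1, 3]
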
Example #4
def train(args, train_dataset, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples,
                            batch_first=True,
                            padding_value=tokenizer.pad_token_id)

    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    model = model.module if hasattr(
        model,
        "module") else model  # Take care of distributed/parallel training
    model.resize_token_embeddings(len(tokenizer))

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (args.model_name_or_path and os.path.isfile(
            os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt"))):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    non_multi_model = model
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(non_multi_model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to the global_step of the last saved checkpoint from the model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0

    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility
    best_perplexity = float('inf')
    for i, epoch in enumerate(train_iterator):
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])

        if args.local_rank != -1:
            train_sampler.set_epoch(epoch)

        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            inputs, labels = mask_tokens(batch, tokenizer,
                                         args) if args.mlm else (batch, batch)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            outputs = model(inputs,
                            masked_lm_labels=labels) if args.mlm else model(
                                inputs, labels=labels)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

        if args.do_eval:
            file_path = Path(args.data_dir, args.eval_data_file)

            out_file_path = Path(args.data_dir,
                                 "output_" + args.eval_data_file)
            id_to_json_map = {}
            with open(file_path, encoding="utf-8") as f:
                lines = []

                eval_loss = 0.0
                nb_eval_steps = 0
                for line in tqdm(f, desc="Evaluating"):
                    out_json = {}
                    line = json.loads(line)
                    example_id = line.get("example_id")
                    question_text = line.get("question_text")

                    prompt_text = question_text + " " + args.sep_token + " "
                    encoded_prompt = tokenizer.encode(prompt_text,
                                                      add_special_tokens=False,
                                                      return_tensors="pt")
                    encoded_prompt = encoded_prompt.to(args.device)

                    output_sequences = non_multi_model.generate(
                        input_ids=encoded_prompt,
                        max_length=args.length + len(encoded_prompt[0]),
                        temperature=args.temperature,
                        top_k=args.k,
                        top_p=args.p,
                        repetition_penalty=args.repetition_penalty,
                        do_sample=True,
                        num_return_sequences=args.num_return_sequences,
                    )
                    if len(output_sequences.shape) > 2:
                        output_sequences.squeeze_()

                    generated_sequences = []

                    for generated_sequence_idx, generated_sequence in enumerate(
                            output_sequences):
                        # print("=== GENERATED SEQUENCE {} ===".format(generated_sequence_idx + 1))
                        # generated_sequence = output_sequences[0]

                        generated_sequence = generated_sequence.tolist()

                        # Decode text
                        text = tokenizer.decode(
                            generated_sequence,
                            clean_up_tokenization_spaces=True)

                        # Remove all text after the stop token
                        if args.stop_token:
                            text = text[:text.find(args.stop_token)]

                        # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
                        total_sequence = (prompt_text + text[len(
                            tokenizer.decode(encoded_prompt[0],
                                             clean_up_tokenization_spaces=True)
                        ):])

                        # print(total_sequence)

                        out_json["journaling_input"], out_json[
                            "reflection_output"] = total_sequence.split(
                                args.sep_token)[:2]

                        sample_dataset = GenerateTextDataset(
                            tokenizer, total_sequence, args.block_size)

                        def collate(examples: List[torch.Tensor]):
                            if tokenizer._pad_token is None:
                                return pad_sequence(examples, batch_first=True)
                            return pad_sequence(
                                examples,
                                batch_first=True,
                                padding_value=tokenizer.pad_token_id)

                        eval_sampler = SequentialSampler(sample_dataset)
                        eval_dataloader = DataLoader(sample_dataset,
                                                     sampler=eval_sampler,
                                                     batch_size=1,
                                                     collate_fn=collate)

                        model_lm = model
                        if args.n_gpu > 1:
                            model_lm = torch.nn.DataParallel(model_lm)

                        model_lm.eval()

                        for batch in eval_dataloader:
                            inputs, labels = mask_tokens(
                                batch, tokenizer,
                                args) if args.mlm else (batch, batch)
                            inputs = inputs.to(args.device)
                            labels = labels.to(args.device)

                            with torch.no_grad():
                                outputs = model_lm(inputs,
                                                   masked_lm_labels=labels
                                                   ) if args.mlm else model_lm(
                                                       inputs, labels=labels)
                                lm_loss = outputs[0]
                                example_loss = lm_loss.mean().item()
                                eval_loss += example_loss
                            nb_eval_steps += 1

                        perplexity = torch.exp(
                            torch.tensor(example_loss)).item()
                        # print(perplexity)
                        out_json["perplexity"] = perplexity

                        example_id += "-" + str(generated_sequence_idx)
                        id_to_json_map[example_id] = json.dumps(
                            out_json, ensure_ascii=False)

                    # result = {"perplexity": perplexity}

                eval_loss = eval_loss / nb_eval_steps
                total_perplexity = torch.exp(torch.tensor(eval_loss))
                logger.info(f"total_loss:: {eval_loss}")
                logger.info(f"total_perplexity:: {total_perplexity}")
                if total_perplexity < best_perplexity:
                    logger.info(
                        f"Current best epoch::: {i}, with perplexity:: {total_perplexity}"
                    )
                    best_perplexity = total_perplexity

                    with open(out_file_path, "w+",
                              encoding="utf-8") as out_file:
                        for _, out_json in id_to_json_map.items():
                            out_file.write(out_json + "\n")

                    model_to_save = model.module if hasattr(
                        model,
                        'module') else model  # Only save the model it-self

                    # If we save using the predefined names, we can load using `from_pretrained`
                    output_model_file = os.path.join(args.output_dir,
                                                     WEIGHTS_NAME)
                    output_config_file = os.path.join(args.output_dir,
                                                      CONFIG_NAME)
                    torch.save(model_to_save.state_dict(), output_model_file)
                    model_to_save.config.to_json_file(output_config_file)
                    tokenizer.save_vocabulary(args.output_dir)

    return global_step, tr_loss / global_step
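
For reference, the evaluation bookkeeping in Example #4 reduces to exponentiating a mean LM loss to get perplexity; a tiny sketch with fabricated loss values follows.

# Sketch of the eval-loss -> perplexity step from Example #4; losses are made up.
import torch

per_example_losses = [2.31, 2.05, 1.98]
eval_loss = sum(per_example_losses) / len(per_example_losses)
total_perplexity = torch.exp(torch.tensor(eval_loss)).item()
print(f"total_loss:: {eval_loss:.4f}  total_perplexity:: {total_perplexity:.2f}")
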
Example #5
def train(args, train_dataset, eval_dataset, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer, run_batch_fn_train,
          run_batch_fn_eval) -> Tuple[int, float]:
    if args.local_rank in [-1, 0]:
        log_dir = os.path.join("runs",
                               args.exp_name) if args.exp_name else None
        tb_writer = SummaryWriter(log_dir)
        args.output_dir = log_dir

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=train_dataset.collate_fn)

    t_total = len(train_dataloader
                  ) // args.gradient_accumulation_steps * args.num_train_epochs

    optimizer = AdamW(model.parameters(),
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    global_step = 0
    model.zero_grad()
    train_iterator = trange(0,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(args)  # for reproducibility

    for _ in train_iterator:
        local_steps = 0
        tr_loss = 0.0
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.train()
            loss, _, _, _ = run_batch_fn_train(args, model, batch)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1
                local_steps += 1
                epoch_iterator.set_postfix(Loss=tr_loss / local_steps)

        results = evaluate(args,
                           eval_dataset,
                           model,
                           tokenizer,
                           run_batch_fn_eval,
                           desc=str(global_step))
        if args.local_rank in [-1, 0]:
            for key, value in results.items():
                tb_writer.add_scalar("eval_{}".format(key), value, global_step)
            tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
            tb_writer.add_scalar("loss", tr_loss / local_steps, global_step)

            checkpoint_prefix = "checkpoint"
            # Save model checkpoint
            output_dir = os.path.join(
                args.output_dir, "{}-{}".format(checkpoint_prefix,
                                                global_step))
            os.makedirs(output_dir, exist_ok=True)
            model_to_save = (
                model.module if hasattr(model, "module") else model
            )  # Take care of distributed/parallel training

            logger.info("Saving model checkpoint to %s", output_dir)
            model_to_save.save_pretrained(output_dir)
            tokenizer.save_pretrained(output_dir)

            torch.save(args, os.path.join(output_dir, "training_args.bin"))
            with open(os.path.join(output_dir, "params.json"),
                      "w") as jsonfile:
                json.dump(args.params,
                          jsonfile,
                          indent=2,
                          default=lambda x: str(x))
            logger.info("Saving model checkpoint to %s", output_dir)

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / local_steps
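
Example #5 delegates the forward pass to a run_batch_fn_train callable. A hypothetical implementation of that contract follows; the (input_ids, attention_mask, labels) batch layout and the labels-bearing Hugging Face model are assumptions.

# Hypothetical run_batch_fn_train compatible with Example #5's call site:
#   loss, _, _, _ = run_batch_fn_train(args, model, batch)
def run_batch_fn_train(args, model, batch):
    input_ids, attention_mask, labels = (t.to(args.device) for t in batch)
    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
    loss, logits = outputs[0], outputs[1]
    return loss, logits, labels, None
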
Example #6
def train(args, train_dataset, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def pad_examples(examples, padding_value=tokenizer.pad_token_id):
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples,
                            batch_first=True,
                            padding_value=padding_value)

    def collate(examples):
        text_examples = [None] * len(examples)
        text_labels = [None] * len(examples)
        text_type_ids = [None] * len(examples)

        video_examples = [None] * len(examples)
        video_labels = [None] * len(examples)
        video_type_ids = [None] * len(examples)

        joint_examples = [None] * len(examples)
        joint_labels = [None] * len(examples)
        joint_type_ids = [None] * len(examples)

        for i, (te, tl, tti, ve, vl, vti, je, jl, jti) in enumerate(examples):
            text_examples[i] = te
            video_examples[i] = ve
            text_labels[i] = tl
            video_labels[i] = vl

            text_type_ids[i] = tti
            video_type_ids[i] = vti

            joint_examples[i] = je
            joint_labels[i] = jl
            joint_type_ids[i] = jti

        padded_text_ids = pad_examples(text_examples)
        text_attention_mask = torch.ones(padded_text_ids.shape,
                                         dtype=torch.int64)
        text_attention_mask[(padded_text_ids == 0)] = 0

        padded_video_ids = pad_examples(video_examples)
        video_attention_mask = torch.ones(padded_video_ids.shape,
                                          dtype=torch.int64)
        video_attention_mask[(padded_video_ids == 0)] = 0

        padded_joint_ids = pad_examples(joint_examples)
        joint_attention_mask = torch.ones(padded_joint_ids.shape,
                                          dtype=torch.int64)
        joint_attention_mask[(padded_joint_ids == 0)] = 0

        return padded_text_ids, \
               torch.tensor(text_labels, dtype=torch.int64), \
               pad_examples(text_type_ids, padding_value=0), \
               text_attention_mask, \
               padded_video_ids, \
               torch.tensor(video_labels, dtype=torch.int64), \
               pad_examples(video_type_ids, padding_value=0), \
               video_attention_mask, \
               padded_joint_ids, \
               torch.tensor(joint_labels, dtype=torch.int64), \
               pad_examples(joint_type_ids, padding_value=0), \
               joint_attention_mask

    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (args.model_name_or_path and os.path.isfile(
            os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt"))):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to the global_step of the last saved checkpoint from the model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0

    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility
    for epoch in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])

        if args.local_rank != -1:
            train_sampler.set_epoch(epoch)

        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            text_ids = batch[0]
            text_seq_labels = batch[1]
            text_token_type_ids = batch[2]
            text_attention_masks = batch[3]

            video_ids = batch[4]
            video_seq_labels = batch[5]
            video_token_type_ids = batch[6]
            video_attention_masks = batch[7]

            joint_ids = batch[8]
            joint_labels = batch[9]
            joint_token_type_ids = batch[10]
            joint_attention_masks = batch[11]

            text_inputs, text_mask_labels = mask_tokens(
                text_ids, tokenizer, args) if args.mlm else (text_ids,
                                                             text_ids)
            video_inputs, video_mask_labels = mask_tokens(
                video_ids, tokenizer, args) if args.mlm else (video_ids,
                                                              video_ids)
            joint_inputs, joint_mask_labels = mask_tokens(
                joint_ids, tokenizer, args) if args.mlm else (joint_ids,
                                                              joint_ids)

            text_inputs = text_inputs.to(args.device)
            text_mask_labels = text_mask_labels.to(args.device)
            text_seq_labels = text_seq_labels.to(args.device)

            text_token_type_ids = text_token_type_ids.to(args.device)
            video_token_type_ids = video_token_type_ids.to(args.device)
            joint_token_type_ids = joint_token_type_ids.to(args.device)

            text_attention_masks = text_attention_masks.to(args.device)
            video_attention_masks = video_attention_masks.to(args.device)
            joint_attention_masks = joint_attention_masks.to(args.device)

            video_inputs = video_inputs.to(args.device)
            video_mask_labels = video_mask_labels.to(args.device)
            video_seq_labels = video_seq_labels.to(args.device)

            joint_inputs = joint_inputs.to(args.device)
            joint_mask_labels = joint_mask_labels.to(args.device)
            joint_labels = joint_labels.to(args.device)

            model.train()

            outputs = model(
                text_input_ids=text_inputs,
                video_input_ids=video_inputs,
                joint_input_ids=joint_inputs,
                text_token_type_ids=text_token_type_ids,
                video_token_type_ids=video_token_type_ids,
                joint_token_type_ids=joint_token_type_ids,
                text_attention_mask=text_attention_masks,
                video_attention_mask=video_attention_masks,
                joint_attention_mask=joint_attention_masks,
                text_masked_lm_labels=text_mask_labels,
                video_masked_lm_labels=video_mask_labels,
                joint_masked_lm_labels=joint_mask_labels,
                text_next_sentence_label=text_seq_labels,
                video_next_sentence_label=video_seq_labels,
                joint_vis_lin_label=joint_labels,
            )

            loss = outputs[0]
            text_loss = outputs[1]
            video_loss = outputs[2]
            joint_loss = outputs[3]

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
                # text_loss = text_loss / args.gradient_accumulation_steps
                # video_loss = video_loss / args.gradient_accumulation_steps
                # joint_loss = joint_loss / args.gradient_accumulation_steps

            loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

                print('loss:', loss.item(), 'text loss:', text_loss.item(),
                      'video loss:', video_loss.item(), 'joint loss:',
                      joint_loss.item())

                # keep BERT embeddings frozen
                model.bert.embeddings.word_embeddings.weight.grad[
                    globals.frozen_indices] = 0

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    print('writing tf logs...')
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    tb_writer.add_scalar("text_loss", text_loss.item(),
                                         global_step)
                    tb_writer.add_scalar("video_loss", video_loss.item(),
                                         global_step)
                    tb_writer.add_scalar("joint_loss", joint_loss.item(),
                                         global_step)
                    logging_loss = tr_loss

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir,
                        "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
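
The embedding-freezing trick in Example #6 (zeroing selected gradient rows before optimizer.step()) can be demonstrated in isolation; the sizes and indices below are arbitrary.

# Standalone sketch of the frozen-embedding-rows trick from Example #6.
import torch
from torch import nn

emb = nn.Embedding(100, 16)
frozen_indices = torch.tensor([0, 1, 2])            # rows to keep fixed
opt = torch.optim.SGD(emb.parameters(), lr=0.1)

before = emb.weight[frozen_indices].clone()
loss = emb(torch.tensor([[0, 5, 7], [2, 9, 1]])).sum()
loss.backward()
emb.weight.grad[frozen_indices] = 0                 # mirrors Example #6
opt.step()
assert torch.equal(emb.weight[frozen_indices], before)  # frozen rows unchanged
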
Example #7
def train(args, train_dataset, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    tb_writer = SummaryWriter(log_dir=str(
        Path(args.output_dir).parent / "tensorboard" /
        Path(args.output_dir).name))
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=None)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    if (args.model_name_or_path and os.path.isfile(
            os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt"))):
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps)
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0

    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:

            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0

    model_to_resize = model.module if hasattr(model, "module") else model
    model_to_resize.resize_token_embeddings(len(tokenizer))

    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch")
    set_seed(args)
    lm_losses = []
    mc_losses = []
    mc_logitsp, mc_logitsn = [], []
    total_samples = 0
    pos_samples = 0
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              dynamic_ncols=True)
        for step, (batch_lm, mc_labels) in enumerate(epoch_iterator):
            inputs, lm_labels = batch_lm, batch_lm
            inputs = inputs.to(args.device)
            lm_labels = lm_labels.to(args.device)
            mc_labels = mc_labels.to(args.device)
            model.train()
            outputs = model(inputs, lm_labels=lm_labels, mc_labels=mc_labels)
            lm_loss = torch.where(mc_labels == 1, outputs[0],
                                  torch.zeros_like(outputs[0]))

            mc_loss = outputs[1]
            if torch.any(mc_labels == 1):
                if len(mc_logitsp) > 100:
                    mc_logitsp.pop(0)
                mc_logitsp.append(outputs[3][0, 1].squeeze().item())
                pos_samples += 1
            else:
                if len(mc_logitsn) > 100:
                    mc_logitsn.pop(0)
                mc_logitsn.append(outputs[3][0, 1].squeeze().item())
            total_samples += 1
            loss = lm_loss + mc_loss

            loss.backward()

            if lm_loss.item() != 0.0:
                if len(lm_losses) > 100:
                    lm_losses.pop(0)
                lm_losses.append(lm_loss.item())
            if len(mc_losses) > 100:
                mc_losses.pop(0)
            mc_losses.append(mc_loss.item())

            mean_lm_loss = np.mean(lm_losses) if lm_losses else 0
            train_results = {
                "lm_loss": mean_lm_loss,
                "mc_loss": np.mean(mc_losses),
                "ppl": 2**mean_lm_loss
            }
            for key, value in train_results.items():
                tb_writer.add_scalar("train_{}".format(key), value,
                                     global_step)

            train_info = f"#{step:3d} lm_loss: {train_results['lm_loss']:6.4f} " \
                         f"mc_loss: {train_results['mc_loss']:6.4f}" \
                         f" ppl: {train_results['ppl']:6.2f}" \
                         f" mc+/-: {np.mean(mc_logitsp):3.2f}/{np.mean(mc_logitsn):3.2f}" \
                         f" {100 * pos_samples / total_samples:.1f}"
            epoch_iterator.set_description(train_info)
            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # results = evaluate(args, model, tokenizer)
                    # for key, value in results.items():
                    #     tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss
                    tb_writer.flush()

                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    save_checkpoint(args, global_step, model, optimizer,
                                    scheduler, tokenizer)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    tb_writer.close()

    return global_step, tr_loss / global_step
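
The bounded running windows in Example #7 (append plus pop(0) once a list passes 100 entries) essentially behave like fixed-length deques; an equivalent sketch of that pattern, with placeholder values, follows.

# Equivalent bounded window using collections.deque(maxlen=...).
from collections import deque
import numpy as np

lm_losses = deque(maxlen=100)
for fake_loss in (2.5, 2.4, 2.3):
    lm_losses.append(fake_loss)       # oldest entries drop off automatically
mean_lm_loss = np.mean(lm_losses) if lm_losses else 0.0
print(mean_lm_loss, 2 ** mean_lm_loss)  # running loss and the 2**loss ppl used above
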
def train(args, train_dataset: CoLDataset, valid_dataset: CoLDataset,
          model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    set_seed(args)  # Added here for reproducibility

    if args.gpu == 0:
        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        tb_writer = SummaryWriter(args.output_dir + '/runs/' + current_time)

    args.train_batch_size = args.per_gpu_train_batch_size

    def col_collate(examples):
        tokens, vokens = zip(*examples)
        if tokenizer._pad_token is None:
            tokens = pad_sequence(tokens, batch_first=True)
        else:
            tokens = pad_sequence(tokens, batch_first=True, padding_value=tokenizer.pad_token_id)
        vokens = pad_sequence(vokens, batch_first=True, padding_value=-100)
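        # -100 is the conventional ignore_index, so padded voken positions are excluded from the loss.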
        return tokens, vokens

    if args.shuffle:
        logger.info(f"Shuffle the dataset in training,"
                       f"GPU: {args.gpu},"
                       f"Rank: {args.rank},"
                       f"Total: {args.world_size}")
    train_sampler = DistributedSampler(
        train_dataset,
        num_replicas=args.world_size,
        rank=args.rank,
        shuffle=args.shuffle,
    )
    train_dataloader = DataLoader(
        train_dataset, sampler=train_sampler, shuffle=False, num_workers=0,
        batch_size=args.train_batch_size, collate_fn=col_collate, pin_memory=True
    )

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
        # args.num_train_epochs = 9595
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    if args.lamb:
        no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']
    else:
        no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    if args.lamb:
        logger.info(f"Using LAMB Optimizer with max grad norm {args.max_grad_norm}")
        import apex
        optimizer = apex.optimizers.FusedLAMB(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            eps=args.adam_epsilon,
            max_grad_norm=args.max_grad_norm
        )
    else:
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          #betas=(0.9, 0.98),
                          eps=args.adam_epsilon)
    if args.gpu == 0:
        print(f"Optimized with lr: {optimizer.defaults['lr']}, total steps: {t_total},"
              f" warmup steps: {args.warmup_steps}, epsilon {optimizer.defaults['eps']},"
              f" beta: {optimizer.defaults['betas']}, weight decay {args.weight_decay}.")
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Check if saved optimizer or scheduler states exist
    if (
        args.model_name_or_path
        and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt"))
        and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt"))
    ):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
        from apex.parallel import DistributedDataParallel as DDP
        model = DDP(model)
    else:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], find_unused_parameters=True
        )

    # Allow not calculating the lm heads.
    if args.mlm_ratio == 0.:
        # model is already wrapped in DDP at this point, so drop the head on the underlying module.
        model_without_ddp = model.module if hasattr(model, "module") else model
        model_without_ddp.lm_head = None


    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * args.world_size
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    # Check if continuing training from a checkpoint
    # if args.model_name_or_path and os.path.exists(args.model_name_or_path):
    #     try:
    #         # set global_step to gobal_step of last saved checkpoint from model path
    #         checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
    #         epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
    #         steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
    #         logger.info("  Continuing training from checkpoint, will skip to saved global_step")
    #         logger.info("  Continuing training from epoch %d", epochs_trained)
    #     except ValueError:
    #         logger.info("  Do not load model from %s, restart training" % args.model_name_or_path)

    model_to_resize = model.module if hasattr(model, "module") else model  # Take care of distributed/parallel training
    assert model_to_resize.config.vocab_size == len(tokenizer)
    # model_to_resize.resize_token_embeddings(len(tokenizer))

    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.gpu != 0
    )
    set_seed(args)  # Added here for reproducibility
    LOSS_NAMES = ['token_loss', 'voken_loss', 'total_loss']
    for epoch in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.gpu != 0)
        tr_loss, logging_loss = np.zeros(len(LOSS_NAMES)), 0.0
        model.zero_grad()
        for step, (tokens, vokens) in enumerate(epoch_iterator):
            token_inputs, token_labels, voken_labels = mask_tokens(tokens, vokens, tokenizer, args)
            token_inputs = token_inputs.to(args.device)
            token_labels = token_labels.to(args.device) if args.mlm_ratio != 0. else None
            voken_labels = voken_labels.to(args.device)
            # If some of the input is padded, then the attention mask is needed
            attention_mask = (token_inputs != tokenizer.pad_token_id)         # word_tokens --> 1, pad_token --> 0
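            # If nothing in the batch is padded, pass None so the model skips masking entirely.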
            if attention_mask.all():
                attention_mask = None

            if epoch == 0 and step < 3 and args.gpu == 0:
                print()
                print("Token inputs:", token_inputs.shape, token_inputs[0])
                print("Token inputs (in str): ", tokenizer.convert_ids_to_tokens(token_inputs[0].cpu().numpy()))
                print("Attention Mask:", attention_mask)
                print("Token Labels: ", token_labels[0] if token_labels is not None else token_labels)
                print("Token Labels (in str): ", tokenizer.convert_ids_to_tokens(token_labels[0].cpu().numpy()) if token_labels is not None else token_labels)
                print("Voken Labels: ", voken_labels[0])
                print()

            model.train()
            outputs = model(token_inputs,
                            attention_mask=attention_mask,
                            masked_lm_labels=token_labels,
                            voken_labels=voken_labels)
            voken_loss = outputs[0]
            token_loss = outputs[1]
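            # Weight the token (MLM) loss by mlm_ratio; with mlm_ratio == 0 only the voken loss is optimized.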

            if args.mlm_ratio == 0.:
                loss = voken_loss
            else:
                loss = voken_loss + args.mlm_ratio * token_loss

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            # print(f"GPU: {args.gpu}, Global Step: {global_step + 1}, "
            #       f"Step: {step}, "
            #       f"Range: {train_dataset.get_item_info(step * args.world_size + args.gpu)}, "
            #       f"Loss: {loss.item()}, "
            #       f"Scaled Loss: {scaled_loss.item()}")

            tr_loss += np.array((token_loss.item() / args.gradient_accumulation_steps,
                                 voken_loss.item() / args.gradient_accumulation_steps,
                                 loss.item()))

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.max_grad_norm > 0. and not args.lamb:
                    # Only clip the grad when it is valid and not using LAMB optimizer,
                    # because the LAMB optimizer already apply grad clipping
                    if args.fp16:
                        total_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                elif args.max_grad_norm <= 0. and step <= args.gradient_accumulation_steps:
                    logger.warning("Have not clipped the gradient because "
                                   "the max_grad_norm is set to %0.2f" % args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.gpu == 0 and args.logging_steps > 0 and (step + 1) % args.logging_steps == 0:
                    # Log metrics
                    tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                    if args.fp16:
                        try:
                            from apex.amp import _amp_state
                            tb_writer.add_scalar("loss_scale", _amp_state.loss_scalers[0]._loss_scale, global_step)
                            tb_writer.add_scalar("scaled_loss", scaled_loss.item(), global_step)
                        except ImportError:
                            logger.warning("Cannot import apex.amp._amp_state, "
                                           "would not state the loss_scale in the log")
                    if args.max_grad_norm > 0. and not args.lamb:  # Only clip the grad when it is valid
                        tb_writer.add_scalar("grad_norm", total_norm, global_step)
                    interval_loss = (tr_loss - logging_loss) / args.logging_steps
                    for loss_idx, loss_name in enumerate(LOSS_NAMES):
                        tb_writer.add_scalar(loss_name, interval_loss[loss_idx], global_step)
                    logging_loss = tr_loss.copy()

            if args.max_steps > 0 and global_step >= args.max_steps:
                break

            # if step == 200:
            #     break
            #
        # Save it each epoch
        if args.gpu == 0:
            # Save checkpoints
            checkpoint_name = "checkpoint-epoch%04d" % epoch
            save_model(args, checkpoint_name, model, tokenizer, optimizer, scheduler)

            # last_path = os.path.join(args.output_dir, 'checkpoint-last')
            # if os.path.exists(last_path):
            #     os.remove(last_path)
            # os.symlink(os.path.join(args.output_dir, checkpoint_name), last_path)

            # Evaluate the model
            for loss_idx, loss_name in enumerate(LOSS_NAMES):
                logger.info(" Training %s of Epoch %d: %0.4f" % (
                    loss_name, epoch, tr_loss[loss_idx] / len(train_dataloader)))

            if args.do_eval:
                logger.info(" Evaluation Results of Epoch %d: " % epoch)
                old_eval_batch_size = args.per_gpu_eval_batch_size
                while args.per_gpu_eval_batch_size > 0:
                    try:
                        results = evaluate(args, valid_dataset, model, tokenizer)
                        break
                    except RuntimeError as e:
                        args.per_gpu_eval_batch_size = int(args.per_gpu_eval_batch_size / 2)
                        logger.warning("Runtime error during evaluation; "
                                       "halving the per-GPU eval batch size to %d.",
                                       args.per_gpu_eval_batch_size)
                        if args.per_gpu_eval_batch_size == 0:
                            raise e
                        time.sleep(5)
                args.per_gpu_eval_batch_size = old_eval_batch_size

                for key, value in results.items():
                    tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    logger.info("\t %s: %0.4f" % (key, value))
                tb_writer.add_scalar("epoch", epoch, global_step)
                output_eval_file = os.path.join(args.output_dir, checkpoint_name, "eval_results.json")
                json.dump(results, open(output_eval_file, 'w'), sort_keys=True, indent=4)
            # Currently, only GPU 0 is responsible for the evaluation.
            # torch.cuda.empty_cache()
            # torch.distributed.barrier()
        else:
            pass
            # torch.cuda.empty_cache()
            # torch.distributed.barrier()

        if args.max_steps > 0 and global_step >= args.max_steps:
            epoch_iterator.close()
            train_iterator.close()
            break

    if args.gpu == 0:
        tb_writer.close()
Example #9
0
def train(args, train_dataset, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """

    # GX: you can make changes here.
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter(log_dir='tb_logs_' + args.descr_string)
        log_string_config = '  '.join(
            [k + ':' + str(v) for k, v in vars(args).items()])
        tb_writer.add_text('experiment args', log_string_config, 0)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples,
                            batch_first=True,
                            padding_value=tokenizer.pad_token_id)

    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    model = model.module if hasattr(
        model,
        "module") else model  # Take care of distributed/parallel training
    model.resize_token_embeddings(len(tokenizer))

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (args.model_name_or_path and os.path.isfile(
            os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt"))):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            args_dict = vars(
                torch.load(
                    os.path.join(args.model_name_or_path,
                                 "training_args.bin")))
            global_step = args_dict['global_step']
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except Exception:
            logger.info("  Starting fine-tuning.")

    # Accumulated loss over all steps, loss at the last epoch boundary,
    # loss at the last logging step, and the best validation loss so far.
    tr_loss, epoch_loss, logging_loss, val_loss_previous_best = 0.0, 0.0, 0.0, 1e8

    model.zero_grad()
    # trange(i) is a special optimised instance of tqdm(range(i))
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            inputs, labels = mask_tokens(batch, tokenizer,
                                         args) if args.mlm else (batch, batch)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            outputs = model(inputs,
                            masked_lm_labels=labels) if args.mlm else model(
                                inputs, labels=labels)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                ## Clip gradients before applying the accumulated update.
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()  # apply the update accumulated over gradient_accumulation_steps
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1
                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics: loss and perplexity averaged over the last logging interval.
                    cur_gas_loss = (
                        tr_loss - logging_loss
                    ) / args.logging_steps  # mean loss per update step since the last logging step
                    cur_gas_perplexity = np.exp(cur_gas_loss)
                    tb_writer.add_scalar("train/iter_loss", cur_gas_loss,
                                         global_step)
                    tb_writer.add_scalar("train/iter_perplexity",
                                         cur_gas_perplexity, global_step)
                    tb_writer.add_scalar("progress/lr",
                                         scheduler.get_lr()[0], global_step)
                    logging_loss = tr_loss

                    ## save and evaluation.
                    current_prefix = "checkpoint-current"
                    best_prefix = "checkpoint-best"
                    args.global_step = global_step

                    # in case either torch.save or tokenizer.save does not overwrite
                    current_output_dir = os.path.join(args.output_dir,
                                                      current_prefix)
                    if os.path.isdir(current_output_dir):
                        shutil.rmtree(current_output_dir)
                    os.makedirs(current_output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(
                        current_output_dir)  # build-in save function.
                    tokenizer.save_pretrained(current_output_dir)
                    torch.save(
                        args,
                        os.path.join(current_output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s",
                                current_output_dir)
                    torch.save(
                        optimizer.state_dict(),
                        os.path.join(current_output_dir, "optimizer.pt"))
                    torch.save(
                        scheduler.state_dict(),
                        os.path.join(current_output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                current_output_dir)

                    # saving the best model.
                    val_loss, perplexity_val = evaluate(args,
                                                        model,
                                                        tokenizer,
                                                        str(global_step),
                                                        debug=args.debug)
                    tb_writer.add_scalar("val/perplexity", perplexity_val,
                                         global_step)
                    tb_writer.add_scalar("val/loss", val_loss, global_step)

                    if val_loss < val_loss_previous_best:
                        logger.info(
                            "The current model surpasses the previous best.")
                        best_output_dir = os.path.join(args.output_dir,
                                                       best_prefix)
                        if os.path.isdir(best_output_dir):
                            shutil.rmtree(best_output_dir)
                        os.makedirs(best_output_dir)
                        copy_tree(current_output_dir, best_output_dir)
                        logger.info(
                            "Saving the best model by copying %s to %s",
                            current_output_dir, best_output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break

            if args.debug and step == 20:
                print(global_step, step)
                break

        ## Record the average per-step loss (and its perplexity) for this epoch.
        cur_epoch_loss = (tr_loss - epoch_loss) / (step + 1)
        cur_epoch_perplexity = np.exp(cur_epoch_loss)
        epoch_loss = tr_loss
        tb_writer.add_scalar("train/epoch_loss", cur_epoch_loss, global_step)
        tb_writer.add_scalar("train/epoch_perplexity", cur_epoch_perplexity,
                             global_step)

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
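# NOTE: mask_tokens() used above is not shown. The following is a minimal sketch of the
# standard BERT-style masking recipe it presumably follows (an assumption): mask
# args.mlm_probability of the tokens, replace 80% of the masked positions with the mask
# token and 10% with random tokens, keep 10%, and ignore unmasked positions via label -100.
def mask_tokens_sketch(inputs, tokenizer, args):
    import torch
    labels = inputs.clone()
    probability_matrix = torch.full(labels.shape, args.mlm_probability)
    special_tokens_mask = [
        tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
        for val in labels.tolist()
    ]
    probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # only compute loss on masked positions
    # 80% of the time, replace masked input tokens with the mask token
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
    # 10% of the time, replace masked input tokens with a random word
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]
    # The remaining 10% keep the original tokens
    return inputs, labels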
Example #10
0
def train(args, train_dataset, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer):

    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())

    writer = SummaryWriter()

    def collate(examples: List[torch.Tensor]):
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples,
                            batch_first=True,
                            padding_value=tokenizer.pad_token_id)

    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  args.train_batch_size,
                                  sampler=train_sampler,
                                  collate_fn=collate,
                                  num_workers=args.num_workers)

    t_total = len(train_dataloader) \
        // args.gradient_accumulation_steps \
        * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                args.warmup_steps,
                                                num_training_steps=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    if args.n_gpu > 1:
        model = nn.DataParallel(model)

    # TODO: Loading checkpoint for AMP
    # Train!
    logger.info('***** Running training *****')
    logger.info(f'  Num examples = {len(train_dataset)}')
    logger.info(f'  Num Epochs = {args.num_train_epochs}')
    logger.info(
        f'  Instantaneous batch size per GPU = {args.per_gpu_train_batch_size}'
    )
    logger.info('  Total train batch size (w. parallel, & accumulation) = %d',
                args.train_batch_size * args.gradient_accumulation_steps)
    logger.info(
        f'  Gradient Accumulation steps = {args.gradient_accumulation_steps}')
    logger.info(f'  Total optimization steps = {t_total}')

    global_step = 0
    best_perplexity = 0.0
    training_loss, running_loss = 0.0, 0.0

    # Take care of distributed/parallel training
    model_to_resize = model.module if hasattr(model, "module") else model
    model_to_resize.resize_token_embeddings(len(tokenizer))
    model.train()

    for epoch in range(args.num_train_epochs):
        print(f'Epoch {epoch}/{args.num_train_epochs - 1}')
        print('-' * 10)

        for step, batch in enumerate(tqdm(train_dataloader)):
            inputs, labels = mask_tokens(batch, tokenizer, args) \
                if args.mlm else (batch, batch)
            inputs, labels = inputs.to(args.device), labels.to(args.device)

            # zero the parameter gradients only at accumulation boundaries,
            # otherwise gradients accumulated across micro-batches would be discarded
            if step % args.gradient_accumulation_steps == 0:
                optimizer.zero_grad()

            outputs = model(inputs, masked_lm_labels=labels) \
                if args.mlm else model(inputs, labels=labels)
            loss = outputs[0]

            if args.n_gpu > 1:
                loss = loss.mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            training_loss += loss.item()
            running_loss += loss.item() * inputs.size(0)
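            # training_loss drives the per-update TensorBoard scalar below;
            # running_loss accumulates per-example loss for the epoch average.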
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                             args.max_grad_norm)
                else:
                    nn.utils.clip_grad_norm_(model.parameters(),
                                             args.max_grad_norm)
                optimizer.step()
                scheduler.step()
                global_step += 1

                # TODO: args.evaluate_during_training
                writer.add_scalar('learning_rate',
                                  scheduler.get_lr()[0], global_step)
                writer.add_scalar('loss/training', training_loss, global_step)
                training_loss = 0.0

        epoch_loss = running_loss / len(train_dataset)
        # TODO: Evaluates and saves checkpoint after every epoch
        result = evaluate(args, model, tokenizer)
        epoch_perplexity = result.get('perplexity')

        if epoch == 0:
            best_perplexity = epoch_perplexity
        else:
            if epoch_perplexity < best_perplexity:
                best_perplexity = epoch_perplexity

        writer.add_scalar('perplexity per epoch', epoch_perplexity, epoch)
        print(f'Loss: {epoch_loss:.4f} perplexity:{epoch_perplexity}')

    writer.close()

    time_elapsed = time.time() - since
    print('Training completed in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print(f'Perplexity: {best_perplexity}')

    return model
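# NOTE: evaluate() used above is not shown. The following is a minimal sketch of a
# perplexity-style evaluation loop consistent with how it is called here (an assumption;
# the eval dataloader is passed in explicitly for brevity rather than rebuilt from args):
def evaluate_sketch(args, model, tokenizer, eval_dataloader):
    import torch
    import numpy as np
    model.eval()
    total_loss, steps = 0.0, 0
    with torch.no_grad():
        for batch in eval_dataloader:
            inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
            inputs, labels = inputs.to(args.device), labels.to(args.device)
            outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
            total_loss += outputs[0].mean().item()  # outputs[0] is the loss
            steps += 1
    eval_loss = total_loss / max(steps, 1)
    return {'loss': eval_loss, 'perplexity': float(np.exp(eval_loss))}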
Example #11
0
def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    set_seed(args)  # Added here for reproducibility
    if args.gpu == 0:
        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        tb_writer = SummaryWriter(args.output_dir + '/runs/' + current_time)

    args.train_batch_size = args.per_gpu_train_batch_size

    def collate(examples: List[torch.Tensor]):
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)

    if args.shuffle:
        logger.info(f"Shuffle the dataset in training,"
                       f"GPU: {args.gpu},"
                       f"Rank: {args.rank},"
                       f"Total: {args.world_size}")
    train_sampler = DistributedSampler(
        train_dataset,
        num_replicas=args.world_size,
        rank=args.rank,
        shuffle=args.shuffle,
    )
    train_dataloader = DataLoader(
        train_dataset, sampler=train_sampler, shuffle=False, num_workers=0,
        batch_size=args.train_batch_size, collate_fn=collate, pin_memory=True
    )

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      # betas=(0.9, 0.98),
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
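    # Optionally derive warmup_steps from warmup_ratio, a fraction of the total number of optimization steps.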
    if args.warmup_ratio > 0.:
        assert args.warmup_steps == 0
        args.warmup_steps = int(t_total * args.warmup_ratio)
    if args.gpu == 0:
        print("Optimized with lr %f, steps %d, warmup steps %d, and use beta, epsilon %0.8f." % (
            args.learning_rate, t_total, args.warmup_steps, optimizer.defaults['eps']
        ), optimizer.defaults['betas'])
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Check if saved optimizer or scheduler states exist
    if (
        args.model_name_or_path
        and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt"))
        and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt"))
    ):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level,
                                          verbosity=0)
        from apex.parallel import DistributedDataParallel as DDP
        model = DDP(model)
    else:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], find_unused_parameters=True
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * args.world_size
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    # Check if continuing training from a checkpoint
    # if args.model_name_or_path and os.path.exists(args.model_name_or_path):
    #     try:
    #         # set global_step to gobal_step of last saved checkpoint from model path
    #         checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
    #         epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
    #         steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
    #         logger.info("  Continuing training from checkpoint, will skip to saved global_step")
    #         logger.info("  Continuing training from epoch %d", epochs_trained)
    #     except ValueError:
    #         logger.info("  Do not load model from %s, restart training" % args.model_name_or_path)

    # model_to_resize = model.module if hasattr(model, "module") else model  # Take care of distributed/parallel training
    # model_to_resize.resize_token_embeddings(len(tokenizer))

    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.gpu != 0
    )
    for epoch in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.gpu != 0)
        tr_loss, logging_loss = 0.0, 0.0
        model.zero_grad()       # Support of accumulating gradients
        for step, batch in enumerate(epoch_iterator):
            inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            # If some of the input is padded, then the attention mask is needed
            attention_mask = (inputs != tokenizer.pad_token_id)         # word_tokens --> 1, pad_token --> 0
            if attention_mask.all():
                attention_mask = None

            if epoch == 0 and step < 3 and args.gpu == 0:
                print(inputs.shape)
                print(inputs[0])
                print(tokenizer.convert_ids_to_tokens(inputs[0].cpu().numpy()))
                print(labels[0])
                print(attention_mask)

            model.train()
            outputs = model(inputs,
                            attention_mask=attention_mask,
                            masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.max_grad_norm > 0.:
                    if args.fp16:
                        total_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.gpu == 0 and args.logging_steps > 0 and (step + 1) % args.logging_steps == 0:
                    # Log metrics
                    tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                    if args.fp16:
                        try:
                            from apex.amp import _amp_state
                            tb_writer.add_scalar("loss_scale", _amp_state.loss_scalers[0]._loss_scale, global_step)
                            tb_writer.add_scalar("scaled_loss", scaled_loss.item(), global_step)
                        except ImportError:
                            logger.warning("Cannot import apex.amp._amp_state, "
                                           "would not state the loss_scale in the log")
                    if args.max_grad_norm > 0.:  # Only clip the grad when it is valid
                        tb_writer.add_scalar("grad_norm", total_norm, global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

            if args.max_steps > 0 and global_step >= args.max_steps:
                break

        # Save it each epoch
        if args.gpu == 0:
            # Save checkpoints
            checkpoint_name = "checkpoint-epoch%04d" % epoch
            save_model(args, checkpoint_name, model, tokenizer, optimizer, scheduler)
            # last_path = os.path.join(args.output_dir, 'checkpoint-last')
            # if os.path.exists(last_path):
            #     print(last_path)
            #     os.remove(last_path)
            # os.symlink(os.path.join(args.output_dir, checkpoint_name), last_path)

            # Evaluate the model
            logger.info(" Training loss of Epoch %d: %0.4f" % (epoch, tr_loss / step))
            logger.info(" Evaluation Results of Epoch %d: " % epoch)
            results = evaluate(args, model, tokenizer)
            for key, value in results.items():
                tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                logger.info("\t %s: %0.4f" % (key, value))
            output_eval_file = os.path.join(args.output_dir, checkpoint_name, "eval_results.json")
            json.dump(results, open(output_eval_file, 'w'), sort_keys=True, indent=4)

        if args.max_steps > 0 and global_step >= args.max_steps:
            epoch_iterator.close()
            train_iterator.close()
            break

    if args.gpu == 0:
        tb_writer.close()
Example #12
0
def train(args, data, datasets, model: PreTrainedModel, original_model,
          tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    train_datasets = datasets['train']
    dev_datasets = datasets['dev']

    train_dataloaders, train_example_num, train_distribution = create_dataloader(
        args, train_datasets, tokenizer, train=True)
    dev_dataloaders, dev_example_num, dev_distribution = create_dataloader(
        args, dev_datasets, tokenizer, train=False)

    train_iter_num = sum(
        [len(dataloader) for dataloader in train_dataloaders.values()])
    dev_iter_num = sum(
        [len(dataloader) for dataloader in dev_dataloaders.values()])

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            train_iter_num // args.gradient_accumulation_steps) + 1
    else:
        t_total = train_iter_num // args.gradient_accumulation_steps * args.num_train_epochs

    model = model.module if hasattr(
        model,
        "module") else model  # Take care of distributed/parallel training
    model.resize_token_embeddings(len(tokenizer))

    original_model = original_model.module if hasattr(
        original_model, "module"
    ) else original_model  # Take care of distributed/parallel training
    original_model.resize_token_embeddings(len(tokenizer))
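    # original_model is kept as a frozen reference; its hidden states are computed under
    # no_grad and detached so they act as fixed targets for the fine-tuned model.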

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (args.model_name_or_path and os.path.isfile(
            os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt"))):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
        original_model = torch.nn.DataParallel(original_model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)
        original_model = torch.nn.parallel.DistributedDataParallel(
            original_model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", train_example_num)
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    best_loss = float('inf')
    best_step = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to gobal_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (train_iter_num //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                train_iter_num // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    model.zero_grad()
    original_model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])

    def inner_product(x, y):
        return torch.mean(torch.sum(y * x, 3))

    def mean_square(x, y, idx):
        return torch.mean(torch.mean((y - x)**2, idx))
        #return torch.mean(torch.sum((y - x) ** 2, 3))

    def save_best_model(best_loss, best_step, dev_dataloaders):
        if (
                args.local_rank == -1 and args.evaluate_during_training
        ):  # Only evaluate when single GPU otherwise metrics may not average well
            eval_loss = evaluate(model, attributes_hiddens, dev_dataloaders)
            #eval_loss = evaluate(args, model, original_model, dev_dataloaders, dev_example_num, dev_distribution, criterion_mse, criterion_ip, feminine_hiddens, masculine_hiddens, gender_hiddens)
            logger.info(" global_step = %s, evaluate loss = %s", global_step,
                        eval_loss)
            tb_writer.add_scalar("eval_loss", eval_loss, global_step)
        tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)

        if eval_loss < best_loss:
            best_loss = eval_loss
            best_step = global_step
            checkpoint_prefix = "checkpoint"
            # Save model checkpoint
            output_dir = os.path.join(args.output_dir, "checkpoint-best")
            os.makedirs(output_dir, exist_ok=True)
            model_to_save = (
                model.module if hasattr(model, "module") else model
            )  # Take care of distributed/parallel training
            model_to_save.save_pretrained(output_dir)
            tokenizer.save_pretrained(output_dir)

            torch.save(args, os.path.join(output_dir, "training_args.bin"))
            logger.info("Saving model checkpoint to %s", output_dir)

            #_rotate_checkpoints(args, checkpoint_prefix)

            torch.save(optimizer.state_dict(),
                       os.path.join(output_dir, "optimizer.pt"))
            torch.save(scheduler.state_dict(),
                       os.path.join(output_dir, "scheduler.pt"))
            logger.info("Saving optimizer and scheduler states to %s",
                        output_dir)
        logger.info(" best_step = %s, best loss = %s", best_step, best_loss)

        return best_loss, best_step

    def get_hiddens_of_model(input):
        model.zero_grad()
        if args.model_type == 'roberta':
            _, _, hiddens = model.roberta(input)
        elif args.model_type == 'bert':
            _, _, hiddens = model.bert(input)
        elif args.model_type == 'albert':
            _, _, hiddens = model.albert(input)
        elif args.model_type == 'dbert':
            _, hiddens = model.distilbert(input)
        elif args.model_type == 'electra':
            _, hiddens = model.electra(input)
        elif args.model_type == 'gpt2':
            _, _, hiddens = model.transformer(input)
        elif args.model_type == 'gpt':
            _, hiddens = model.transformer(input)

        return hiddens

    def attribute_vector_example():
        attributes_hiddens = {f'attribute{i}': [] for i in range(2)}
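        # Collect hidden states for each (non-neutral) attribute class from the training
        # data and average them into one vector per attribute.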

        dataloaders, _, distribution = create_dataloader(args,
                                                         train_datasets,
                                                         tokenizer,
                                                         train=True)
        for key in distribution:
            if key != 'neutral':
                inputs, labels = next(dataloaders[key])
                inputs = inputs.to(args.device)
                hiddens = get_hiddens_of_model(inputs)
                hiddens = torch.stack(hiddens, 2)
                if labels.size(1) > 1:
                    onehot = torch.eye(hiddens.size(1))
                    zeros = torch.zeros(1, onehot.size(0))
                    onehot = torch.cat((zeros, onehot), 0)
                    onehot = onehot[labels]
                    onehot = torch.sum(onehot, 1)
                    onehot = onehot.view(hiddens.size(0), -1, 1, 1)
                else:
                    onehot = torch.eye(hiddens.size(1))[labels].view(
                        hiddens.size(0), -1, 1, 1)
                onehot = onehot.to(args.device)
                attributes_hiddens[key].append(
                    torch.sum(hiddens * onehot, 1) / labels.size(1))

        # 'neutral' is also included in this count, hence the "- 1" below.
        attribute_size = len(data['train']['example'])
        for i in range(attribute_size - 1):
            attributes_hiddens[f'attribute{i}'] = torch.mean(
                torch.cat(attributes_hiddens[f'attribute{i}'], 0),
                0).detach().unsqueeze(0)

        return attributes_hiddens
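    # forward() runs a batch through both the current model and the frozen original model,
    # stacks the per-layer hidden states, and (for non-neutral batches) detaches the
    # original model's outputs so they serve as fixed targets.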

    def forward(attributes_hiddens, dataloaders, key):
        inputs = next(dataloaders[key])
        if len(inputs) == 2:
            inputs, labels = inputs
            labels = labels.to(args.device)
        else:
            labels = None
        inputs = inputs.to(args.device)
        if args.model_type == 'roberta':
            final_layer_hiddens, first_token_hidden, all_layer_hiddens = model.roberta(
                inputs)
            if 'neutral' != key:
                with torch.no_grad():
                    final_layer_original_hiddens, _, all_layer_original_hiddens = original_model.roberta(
                        inputs)
                if args.token_loss:
                    token_predicts = model.lm_head(final_layer_hiddens)
                    token_original = original_model.lm_head(
                        final_layer_original_hiddens)
        elif args.model_type == 'bert':
            final_layer_hiddens, first_token_hidden, all_layer_hiddens = model.bert(
                inputs)
            if 'neutral' != key:
                with torch.no_grad():
                    final_layer_original_hiddens, _, all_layer_original_hiddens = original_model.bert(
                        inputs)
                if args.token_loss:
                    token_predicts = model.cls(final_layer_hiddens)
                    token_original = original_model.cls(
                        final_layer_original_hiddens)
        elif args.model_type == 'albert':
            final_layer_hiddens, first_token_hidden, all_layer_hiddens = model.albert(
                inputs)
            if 'neutral' != key:
                with torch.no_grad():
                    final_layer_original_hiddens, _, all_layer_original_hiddens = original_model.albert(
                        inputs)
                if args.token_loss:
                    token_predicts = model.classifier(final_layer_hiddens)
                    token_original = original_model.classifier(
                        final_layer_original_hiddens)
        elif args.model_type == 'dbert':
            final_layer_hiddens, all_layer_hiddens = model.distilbert(inputs)
            if 'neutral' != key:
                with torch.no_grad():
                    final_layer_original_hiddens, all_layer_original_hiddens = original_model.distilbert(
                        inputs)
                if args.token_loss:
                    token_predicts = model.classifier(final_layer_hiddens)
                    token_original = original_model.classifier(
                        final_layer_original_hiddens)
        elif args.model_type == 'electra':
            final_layer_hiddens, all_layer_hiddens = model.electra(inputs)
            if 'neutral' != key:
                with torch.no_grad():
                    final_layer_original_hiddens, all_layer_original_hiddens = original_model.electra(
                        inputs)
                if args.token_loss:
                    hiddens = model.generator_predictions(final_layer_hiddens)
                    token_predicts = model.generator_lm_head(hiddens)
                    original_hiddens = original_model.generator_predictions(
                        final_layer_original_hiddens)
                    token_original = original_model.generator_lm_head(
                        original_hiddens)
        elif args.model_type == 'gpt2':
            final_layer_hiddens, first_token_hidden, all_layer_hiddens = model.transformer(
                inputs)
            if 'neutral' != key:
                with torch.no_grad():
                    final_layer_original_hiddens, _, all_layer_original_hiddens = original_model.transformer(
                        inputs)
                if args.token_loss:
                    token_predicts = model.lm_head(final_layer_hiddens)
                    token_original = original_model.lm_head(
                        final_layer_original_hiddens)
        elif args.model_type == 'gpt':
            final_layer_hiddens, all_layer_hiddens = model.transformer(inputs)
            if 'neutral' != key:
                with torch.no_grad():
                    final_layer_original_hiddens, all_layer_original_hiddens = original_model.transformer(
                        inputs)
                if args.token_loss:
                    token_predicts = model.lm_head(final_layer_hiddens)
                    token_original = original_model.lm_head(
                        final_layer_original_hiddens)

        all_layer_hiddens = torch.stack(all_layer_hiddens, 2)
        if 'neutral' != key:
            all_original_hiddens = torch.stack(all_layer_original_hiddens, 2)
            all_original_hiddens = all_original_hiddens.detach()
            if args.token_loss:
                original_hiddens = original_hiddens.detach()
                token_original = token_original.detach()
        if args.debias_layer == 'all':
            target_layer_hiddens = all_layer_hiddens
            if 'neutral' != key:
                target_original_hiddens = all_original_hiddens
        else:
            if args.debias_layer == 'first':
                idx = 0
            elif args.debias_layer == 'last':
                idx = -1
            target_layer_hiddens = all_layer_hiddens[:, :, idx]
            target_layer_hiddens = target_layer_hiddens.unsqueeze(2)
            if 'neutral' != key:
                target_original_hiddens = all_original_hiddens[:, :, idx]
                target_original_hiddens = target_original_hiddens.unsqueeze(2)
            else:
                attributes_hiddens = {
                    key: value[:, idx, :].unsqueeze(1)
                    for key, value in attributes_hiddens.items()
                }

        if args.loss_target == 'sentence' or labels is None:
            attributes_hiddens = {
                key: value.unsqueeze(1)
                for key, value in attributes_hiddens.items()
            }
        #elif args.loss_target == 'token' and key == 'neutral':
        elif args.loss_target == 'token':
            if labels.size(1) > 1:
                onehot = torch.eye(target_layer_hiddens.size(1))
                zeros = torch.zeros(1, onehot.size(0))
                onehot = torch.cat((zeros, onehot), 0)
                onehot = onehot[labels]
                onehot = torch.sum(onehot, 1)
                onehot = onehot.view(target_layer_hiddens.size(0), -1, 1, 1)
            else:
                onehot = torch.eye(target_layer_hiddens.size(1))[labels].view(
                    target_layer_hiddens.size(0), -1, 1, 1)
            onehot = onehot.to(args.device)
            target_layer_hiddens = torch.sum(target_layer_hiddens * onehot,
                                             1).unsqueeze(1) / labels.size(1)
            if 'neutral' != key:
                target_original_hiddens = torch.sum(
                    target_original_hiddens * onehot,
                    1).unsqueeze(1) / labels.size(1)
            else:
                attributes_hiddens = {
                    key: value.expand(target_layer_hiddens.size(0), 1,
                                      value.size(1), value.size(2))
                    for key, value in attributes_hiddens.items()
                }

        if 'neutral' == key:
            loss = 0
            for attribute_hiddens in attributes_hiddens.values():
                tmp_loss = criterion_ip(target_layer_hiddens,
                                        attribute_hiddens)
                if args.square_loss:
                    tmp_loss = tmp_loss**2
                tmp_loss *= alpha
                loss += tmp_loss
        else:
            #loss = criterion_ms(target_layer_hiddens, target_original_hiddens)
            loss = criterion_ms(all_layer_hiddens, all_original_hiddens, 3)
            if args.token_loss:
                loss += criterion_ms(token_predicts, token_original, 2)
                #loss += criterion_ms(hiddens, original_hiddens, 2)
            loss *= beta

        return loss

    #def evaluate(args, model: PreTrainedModel, original_model, dev_dataloaders, dev_example_num, dev_distribution, criterion_mse, criterion_ip, feminine_hiddens, masculine_hiddens, gender_hiddens, prefix="") -> Dict:
    def evaluate(model, attributes_hiddens, dev_dataloaders, prefix=""):
        # Loop to handle MNLI double evaluation (matched, mis-matched)
        eval_output_dir = args.output_dir

        if args.local_rank in [-1, 0]:
            os.makedirs(eval_output_dir, exist_ok=True)

        args.eval_batch_size = args.per_gpu_eval_batch_size * max(
            1, args.n_gpu)
        # Note that DistributedSampler samples randomly

        # multi-gpu evaluate
        if args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Eval!
        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info("  Num examples = %d", dev_example_num)
        logger.info("  Batch size = %d", args.eval_batch_size)
        eval_loss = 0.0
        model.eval()
        #criterion.eval()

        for key in tqdm(dev_distribution):
            with torch.no_grad():
                loss = forward(attributes_hiddens, dev_dataloaders, key)

                eval_loss += loss.item()

                model.zero_grad()
                original_model.zero_grad()

        output_eval_file = os.path.join(eval_output_dir, prefix,
                                        "eval_results.txt")
        '''
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results {} *****".format(prefix))
            logger.info("  Loss = %s", eval_loss)
            writer.write("Loss = %s\n" % (eval_loss))
        '''

        return eval_loss

    #criterion_ms = torch.nn.MSELoss()
    criterion_ms = mean_square
    #criterion.train()
    criterion_ip = inner_product
    original_model.eval()

    alpha, beta = args.weighted_loss
    alpha = float(alpha)
    beta = float(beta)

    train_loss = 0.0

    for _ in train_iterator:

        random.shuffle(train_distribution)
        epoch_iterator = tqdm(train_distribution,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])

        model.eval()
        with torch.no_grad():
            attributes_hiddens = attribute_vector_example()

        for step, key in enumerate(epoch_iterator):
            model.train()

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            loss = forward(attributes_hiddens, train_dataloaders, key)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            train_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                original_model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logger.info(" global_step = %s, train loss = %s",
                                global_step, train_loss)
                    train_loss = 0.0
                    # Log metrics
                    best_loss, best_step = save_best_model(
                        best_loss, best_step, dev_dataloaders)
                    dev_dataloaders, dev_example_num, dev_distribution = create_dataloader(
                        args, dev_datasets, tokenizer, train=False)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
            train_dataloaders, train_example_num, train_distribution = create_dataloader(
                args, train_datasets, tokenizer, train=True)

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    dev_dataloaders, dev_example_num, dev_distribution = create_dataloader(
        args, dev_datasets, tokenizer, train=False)
    best_loss, best_step = save_best_model(best_loss, best_step,
                                           dev_dataloaders)

    if args.local_rank in [-1, 0]:
        tb_writer.close()
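The debiasing loop above binds criterion_ms and criterion_ip to mean_square and inner_product, neither of which appears in this excerpt. Below is a minimal sketch consistent with how they are called (criterion_ms(x, y, dim) and criterion_ip(target, attribute)); it is an assumption, not necessarily the original definitions.
import torch

def inner_product(target_hiddens: torch.Tensor,
                  attribute_hiddens: torch.Tensor) -> torch.Tensor:
    # Mean inner product between neutral hiddens and an attribute vector,
    # broadcast over the batch/sequence/layer dimensions.
    return torch.mean(torch.sum(target_hiddens * attribute_hiddens, dim=-1))

def mean_square(hiddens: torch.Tensor, original_hiddens: torch.Tensor,
                dim: int) -> torch.Tensor:
    # Squared difference averaged over `dim`, then over the remaining dimensions.
    return torch.mean(torch.mean(torch.pow(hiddens - original_hiddens, 2), dim))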
Example #13
0
def train(args, train_dataset, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    args.train_batch_size = args.per_gpu_train_batch_size

    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps)
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, tr_loss_sent, logging_loss, logging_loss_sent = 0.0, 0.0, 0.0, 0.0

    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch")
    set_seed(args)  # Added here for reproducibility

    results = {}
    acc_prev = 0.
    preds = None
    labels = None

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            batch = tuple(t.to(args.device) for t in batch)
            input_ids, attention, token_ids, child, head = batch[0], batch[
                1], batch[2], batch[3], batch[4]
            dep_labels, num_dependency, arcs, arc_labels = batch[5], batch[
                6], batch[7], batch[8]
            arc_label_lengths, sent_labels = batch[9], batch[10]

            inputs = {
                'input_ids': input_ids,
                'attention': attention,
                'token_ids': token_ids,
                'child': child,
                'head': head,
                'dep_labels': dep_labels,
                'arcs': arc_labels,
                'arc_label_lengths': arc_label_lengths,
                'device': args.device
            }

            model.train()
            outputs = model(**inputs)

            loss = outputs[0]
            logits = outputs[1]

            tr_loss += loss.item()

            loss.backward()

            if preds is None:
                preds = logits.detach().cpu().numpy()
                labels = dep_labels.view(-1).cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                labels = np.append(labels,
                                   dep_labels.view(-1).cpu().numpy(),
                                   axis=0)

            if (step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.save_steps > 0 and global_step % args.save_steps == 0:

                    logs = {}
                    loss_scalar_dep = (tr_loss -
                                       logging_loss) / args.save_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss_dep"] = loss_scalar_dep
                    logging_loss = tr_loss

                    print(json.dumps({**logs, **{"step": global_step}}))
                    logger.info(json.dumps({**logs, **{"step": global_step}}))

                    preds = preds.reshape(-1, 2)
                    preds = softmax(preds)
                    preds = np.argmax(preds, axis=1)
                    res_train = compute_metrics_intermediate(preds, labels)
                    preds = None
                    labels = None

                    print(res_train)
                    # Evaluation
                    result = evaluate(args, model, tokenizer)
                    results.update(result)

                    save_checkpoints(args, args.output_dir, model, tokenizer)

                    if result['acc'] > acc_prev:
                        acc_prev = result['acc']
                        # Save model checkpoint best
                        output_dir = os.path.join(args.output_dir,
                                                  "model-best")
                        save_checkpoints(args, output_dir, model, tokenizer)

            if 0 < args.max_steps < global_step:
                epoch_iterator.close()
                break
        if 0 < args.max_steps < global_step:
            train_iterator.close()
            break

    return global_step, tr_loss / global_step
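Example #13 depends on two helpers that are not shown: softmax over the reshaped logits and compute_metrics_intermediate for training-time accuracy. A minimal sketch under the assumption of binary dependency labels and NumPy inputs (hypothetical, not the original implementations):
import numpy as np

def softmax(logits: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the last axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=-1, keepdims=True)

def compute_metrics_intermediate(preds: np.ndarray, labels: np.ndarray) -> dict:
    # Accuracy over positions with a valid label (assumes -100 marks padding).
    valid = labels != -100
    acc = float((preds[valid] == labels[valid]).mean()) if valid.any() else 0.0
    return {"acc": acc}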
Example #14
0
def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, train_dataset_second, DP_classifier) -> Tuple[int, float]:
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)
    

    train_sampler = SequentialSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate
    )
    
    correct_sampler = SequentialSampler(train_dataset_second) if args.local_rank == -1 else DistributedSampler(train_dataset_second)
    correct_dataloader = DataLoader(
        train_dataset_second, sampler=correct_sampler, batch_size=args.train_batch_size, collate_fn=collate
    )

    wrong_sampler = RandomSampler(train_dataset_second) if args.local_rank == -1 else DistributedSampler(train_dataset_second)
    wrong_dataloader = DataLoader(
        train_dataset_second, sampler=wrong_sampler, batch_size=args.train_batch_size, collate_fn=collate
    )


    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)] + [p for n, p in DP_classifier.named_parameters()],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
    #scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (
        args.model_name_or_path
        and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt"))
        and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt"))
    ):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0

    model_to_resize = model.module if hasattr(model, "module") else model  # Take care of distributed/parallel training
    model_to_resize.resize_token_embeddings(len(tokenizer))

    model.zero_grad()
    DP_classifier.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    set_seed(args)  # Added here for reproducibility
    
    zipped_data = zip(train_dataloader, correct_dataloader, wrong_dataloader)
    
    correct_mc_tensor = torch.ones(args.train_batch_size, dtype=torch.float)
    correct_mc_tensor = correct_mc_tensor.to(args.device)
    wrong_mc_tensor = torch.zeros(args.train_batch_size, dtype=torch.float)
    wrong_mc_tensor = wrong_mc_tensor.to(args.device)
    
    print(correct_mc_tensor)
    print(wrong_mc_tensor)
    
    accumulated_lm_loss = 0.0
    accumulated_mc_loss = 0.0
    
    for _ in train_iterator:
        train_sampler = SequentialSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
        train_dataloader = DataLoader(
            train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate
        )
        
        correct_sampler = SequentialSampler(train_dataset_second) if args.local_rank == -1 else DistributedSampler(train_dataset_second)
        correct_dataloader = DataLoader(
            train_dataset_second, sampler=correct_sampler, batch_size=args.train_batch_size, collate_fn=collate
        )
    
        wrong_sampler = RandomSampler(train_dataset_second) if args.local_rank == -1 else DistributedSampler(train_dataset_second)
        wrong_dataloader = DataLoader(
            train_dataset_second, sampler=wrong_sampler, batch_size=args.train_batch_size, collate_fn=collate
        )
        zipped_data = zip(train_dataloader, correct_dataloader, wrong_dataloader)
        epoch_iterator = tqdm(zipped_data, desc="Iteration", disable=args.local_rank not in [-1, 0], total=len(train_dataloader))
        for step, zipped_batch in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            DP_classifier.train()
            

            # unpack zipped_batch
            batch, correct_batch, wrong_batch = zipped_batch
            
                        
            # First: original sentence
            inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
            labels = inputs.clone()
            
            cls_pos = []
            for curr in labels:
                for idx, tk in enumerate(curr):
                    if tk == tokenizer.cls_token_id:
                        curr[idx] = -100
                        cls_pos.append(idx)
                        break
                
                
            
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            
            outputs = model(inputs, lm_labels=labels)
            loss_lm_1 = outputs[0]
            hidden_1 = outputs[3]
            
            sentence_embed_1_pieces = [hh[cls_pos[idx]].unsqueeze(0) for idx, hh in enumerate(hidden_1)]
            sentence_embed_1 = torch.cat(sentence_embed_1_pieces)
            # Second: correct next sentence
            correct_input = correct_batch
            correct_labels = correct_input.clone()
            
            cls_pos = []
            for curr in correct_labels:
                for idx, tk in enumerate(curr):
                    if tk == tokenizer.cls_token_id:
                        curr[idx] = -100
                        cls_pos.append(idx)
                        break
            
            
            
            correct_input = correct_input.to(args.device)
            correct_labels = correct_labels.to(args.device)            
            
            outputs = model(correct_input, lm_labels=correct_labels)

            loss_lm_2 = outputs[0]
            hidden_2 = outputs[3]
            sentence_embed_2_pieces = [hh[cls_pos[idx]].unsqueeze(0) for idx, hh in enumerate(hidden_2)]
            sentence_embed_2 = torch.cat(sentence_embed_2_pieces)

            
            
            
            # Get correct loss
            if random.randint(0, 1) == 1:
                outputs = DP_classifier(sentence_embed_1, sentence_embed_2, correct_mc_tensor)
            else:
                outputs = DP_classifier(sentence_embed_2, sentence_embed_1, correct_mc_tensor)
            loss_mc = outputs[0]
            
            # MC_LOSS SCALING
            SCALING = 0.05
            loss_lm = loss_lm_1 + loss_lm_2
            
            
            #loss = loss_lm
            loss_first = loss_lm + SCALING * loss_mc
            #print("loss_mc: ", loss_mc.item())
            #print("loss_lm: ", loss_lm.item())
            
            accumulated_lm_loss += loss_lm.item() / 2.0
            accumulated_mc_loss += SCALING * loss_mc.item()
            
            # Second loss: wrong next sentence randomly sampled from training set
            wrong_input = wrong_batch
            wrong_labels = wrong_input.clone()
            
            cls_pos = []
            for curr in wrong_labels:
                for idx, tk in enumerate(curr):
                    if tk == tokenizer.cls_token_id:
                        curr[idx] = -100
                        cls_pos.append(idx)
                        break

            
            wrong_input = wrong_input.to(args.device)
            wrong_labels = wrong_labels.to(args.device)
            
            outputs = model(wrong_input, lm_labels=wrong_labels)

            loss_lm_3 = outputs[0]
            hidden_3 = outputs[3]
            sentence_embed_3_pieces = [hh[cls_pos[idx]].unsqueeze(0) for idx, hh in enumerate(hidden_3)]
            sentence_embed_3 = torch.cat(sentence_embed_3_pieces)

            
            if random.randint(0, 1) == 1:
                outputs = DP_classifier(sentence_embed_1, sentence_embed_3, wrong_mc_tensor)
            else:
                outputs = DP_classifier(sentence_embed_3, sentence_embed_1, wrong_mc_tensor)
            loss_mc = outputs[0]
            
            #loss = loss_lm
            loss_second = loss_lm_3 + SCALING * loss_mc
            #print("loss_mc: ", loss_mc.item())
            #print("loss_lm: ", loss_lm.item())
            accumulated_mc_loss += SCALING * loss_mc.item()
            
            # Total loss
            loss = loss_first + loss_second


            SKIP_STEP = 50
            if (step % SKIP_STEP == 0):
                print(' iter %d, avg. lm_loss %.2f, avg. mc_loss %.2f, avg. ppl %.2f ' % (step,
                                                                    accumulated_lm_loss / SKIP_STEP,
                                                                    accumulated_mc_loss / SKIP_STEP,
                                                                    math.exp(loss_lm.item() / 2),
                                                                    ), file=sys.stderr)
                tb_writer.add_scalar("training_lm_loss", accumulated_lm_loss / SKIP_STEP, global_step)
                tb_writer.add_scalar("training_mc_loss", accumulated_mc_loss / SKIP_STEP, global_step)
                accumulated_lm_loss = 0.0
                accumulated_mc_loss = 0.0
                

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                    torch.nn.utils.clip_grad_norm_(DP_classifier.parameters(), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    torch.nn.utils.clip_grad_norm_(DP_classifier.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                DP_classifier.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if (
                        args.local_rank == -1 and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer, DP_classifier)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(DP_classifier, os.path.join(output_dir, "DP_classifier.bin"))
                    
                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
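Example #14 treats DP_classifier as a pairwise head that takes two [CLS] sentence embeddings plus a 0/1 target tensor and returns (loss, logits). The module below is a sketch of such a head under those assumptions; its name and interface are inferred from the call sites, not taken from the original code.
import torch
import torch.nn as nn

class SentencePairClassifier(nn.Module):
    # Scores whether sentence B follows sentence A from their [CLS] embeddings.
    def __init__(self, hidden_size: int):
        super().__init__()
        self.scorer = nn.Linear(2 * hidden_size, 1)
        self.loss_fct = nn.BCEWithLogitsLoss()

    def forward(self, embed_a: torch.Tensor, embed_b: torch.Tensor,
                targets: torch.Tensor):
        logits = self.scorer(torch.cat((embed_a, embed_b), dim=-1)).squeeze(-1)
        loss = self.loss_fct(logits, targets)
        return loss, logits

# Usage mirroring the loop above (hidden size is a hypothetical value):
# DP_classifier = SentencePairClassifier(hidden_size=768).to(args.device)
# loss_mc, _ = DP_classifier(sentence_embed_1, sentence_embed_2, correct_mc_tensor)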
Example #15
0
def train(args, train_dataset, corrects, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    log_dir = os.path.join(
        config.output_dir, 'runs', args.relation,
        os.path.basename(args.output_dir) + '_' + current_time)
    tb_writer = SummaryWriter(log_dir=log_dir)

    def collate(examples: List[torch.Tensor]):
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples,
                            batch_first=True,
                            padding_value=tokenizer.pad_token_id)

    train_dataloader = DataLoader(train_dataset,
                                  sampler=RandomSampler(train_dataset),
                                  batch_size=args.batch_size,
                                  collate_fn=collate)
    t_total = len(
        train_dataloader) // args.gradient_accumulation_steps * args.epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.batch_size * args.gradient_accumulation_steps,
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0

    tr_loss, logging_loss = 0.0, 0.0

    model_to_resize = model.module if hasattr(
        model,
        "module") else model  # Take care of distributed/parallel training
    model_to_resize.resize_token_embeddings(len(tokenizer))

    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.epochs),
                            desc="Epoch",
                            disable=False)
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=False)
        for step, batch in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            inputs, labels = mask_tokens(batch, tokenizer, args)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            outputs = model(inputs, masked_lm_labels=labels)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    results = evaluate(args, corrects, model, tokenizer)

                    for key, value in results.items():
                        tb_writer.add_scalar("{}".format(key), value,
                                             global_step)
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    # print((tr_loss - logging_loss) / args.logging_steps)

                    logging_loss = tr_loss

                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir,
                        "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

    tb_writer.close()

    return global_step, tr_loss / global_step
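This example (and Example #14 when args.mlm is set) calls mask_tokens, which is not included in the excerpt. It most likely mirrors the standard masked-LM masking helper from the transformers language-modeling example; the sketch below follows that recipe and assumes an args.mlm_probability field.
import torch
from transformers import PreTrainedTokenizer

def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args):
    # Standard MLM masking: select args.mlm_probability of the non-special tokens,
    # replace 80% with [MASK], 10% with a random token, and leave 10% unchanged.
    labels = inputs.clone()
    probability_matrix = torch.full(labels.shape, args.mlm_probability)
    special_tokens_mask = [
        tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
        for val in labels.tolist()
    ]
    probability_matrix.masked_fill_(
        torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
    if tokenizer._pad_token is not None:
        probability_matrix.masked_fill_(labels.eq(tokenizer.pad_token_id), value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # compute loss only on masked tokens

    indices_replaced = torch.bernoulli(
        torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    indices_random = (torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
                      & masked_indices & ~indices_replaced)
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    return inputs, labels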
Example #16
0
def train(args, train_dataset, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """

    tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (args.model_name_or_path and os.path.isfile(
            os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt"))):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0

    model_to_resize = model.module if hasattr(
        model,
        "module") else model  # Take care of distributed/parallel training
    model_to_resize.resize_token_embeddings(len(tokenizer))

    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch")
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            inputs, label_ids = batch
            labels = label_ids.repeat((inputs.shape[1], 1)).T
            masks = inputs.eq(tokenizer.mask_token_id)
            labels[~masks] = -100
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            outputs = model(inputs, masked_lm_labels=labels)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:

                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir,
                        "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    tb_writer.close()

    return global_step, tr_loss / global_step
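For reference, the label construction in Example #16 (repeating one gold token id across the sequence and keeping it only at [MASK] positions) behaves as follows on a toy batch; the token ids below are illustrative BERT-style ids, not taken from the original data.
import torch

mask_token_id = 103                                    # BERT-style [MASK] id (assumed)
inputs = torch.tensor([[101, 1996, 103, 2003, 102]])   # toy sequence: [CLS] w1 [MASK] w2 [SEP]
label_ids = torch.tensor([2307])                       # gold id for the blank (assumed)

labels = label_ids.repeat((inputs.shape[1], 1)).T      # broadcast the label to every position
masks = inputs.eq(mask_token_id)                       # True only at the [MASK] slot
labels[~masks] = -100                                  # ignore every other position in the loss
# labels is now tensor([[-100, -100, 2307, -100, -100]])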