def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

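    # t_total is the total number of optimizer update steps; when max_steps is
    # set it takes precedence and num_train_epochs is derived from it.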
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
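    # Biases and LayerNorm weights are excluded from weight decay, following the
    # original BERT fine-tuning setup.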
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = range(int(args.num_train_epochs))
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)

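    # Slice-aware models: apply Snorkel slicing functions to the training examples
    # and train a SliceAwareClassifier on top of the base model; all other model
    # types use the standard fine-tuning loop in the else branch below.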
    if args.model_type == 'bert-slice-aware' or args.model_type == 'bert-slice-aware-random-slices':
        if args.model_type == 'bert-slice-aware':
            sfs = slicing_functions[args.task_name]
        elif args.model_type == 'bert-slice-aware-random-slices':
            if args.number_random_slices is None or args.size_random_slices is None:
                sfs = random_slicing_functions[args.task_name]
            else:
                sfs = args.sfs
        processor = slicing_processors[args.task_name]()
        examples_train = processor.get_train_examples(args.data_dir)

        snorkel_sf_applier = SFApplier(sfs)

        if os.path.isfile(args.data_dir + "/snorkel_slices_train.pickle"):
            with open(args.data_dir + "/snorkel_slices_train.pickle",
                      "rb") as f:
                logger.info("loaded cached pickle for sliced train.")
                snorkel_slices_train = pickle.load(f)
        else:
            snorkel_slices_train = snorkel_sf_applier.apply(examples_train)
            with open(args.data_dir + "/snorkel_slices_train.pickle",
                      "wb") as f:
                pickle.dump(snorkel_slices_train, f)
                logger.info("dumped pickle with sliced train.")

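        # Each training example contributes one input row per candidate document,
        # so its slice labels are replicated once per document below.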
        snorkel_slices_with_ns = []
        for i, example in enumerate(examples_train):
            for _ in range(len(example.documents)):
                snorkel_slices_with_ns.append(snorkel_slices_train[i])

        snorkel_slices_with_ns_np = np.array(snorkel_slices_with_ns,
                                             dtype=snorkel_slices_train.dtype)

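        # Wrap the base transformer with Snorkel's SliceAwareClassifier, which adds
        # slice-specific heads on top of the 768-dimensional pooled representation.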
        slice_model = SliceAwareClassifier(
            task_name='labels',
            input_data_key='input_ids',
            base_architecture=model,
            head_dim=768,  #* args.max_seq_length,
            slice_names=[sf.name for sf in sfs])

        X_dict = {
            'input_ids': train_dataset.tensors[0],
            'attention_mask': train_dataset.tensors[1],
            'token_type_ids': train_dataset.tensors[2]
        }
        Y_dict = {'labels': train_dataset.tensors[3]}

        ds = DictDataset(name='labels',
                         split='train',
                         X_dict=X_dict,
                         Y_dict=Y_dict)
        train_dl_slice = slice_model.make_slice_dataloader(
            ds,
            snorkel_slices_with_ns_np,
            shuffle=True,
            batch_size=args.train_batch_size)

        trainer = Trainer(lr=args.learning_rate,
                          n_epochs=int(args.num_train_epochs),
                          l2=args.weight_decay,
                          optimizer="adamax",
                          max_steps=args.max_steps,
                          seed=args.seed)

        trainer.fit(slice_model, [train_dl_slice])
        model = slice_model
    else:
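        # Standard fine-tuning loop (no slicing functions).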
        for _ in train_iterator:
            epoch_iterator = train_dataloader
            for step, batch in enumerate(epoch_iterator):
                model.train()
                batch = tuple(t.to(args.device) for t in batch)
                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'labels': batch[3]
                }
                if args.model_type != 'distilbert':
                    inputs['token_type_ids'] = batch[2] if args.model_type in [
                        'bert', 'xlnet'
                    ] else None  # XLM, DistilBERT and RoBERTa don't use segment_ids
                if args.model_type == 'bert-mtl':
                    inputs["clf_head"] = 0
                outputs = model(**inputs)
                loss = outputs[
                    0]  # model outputs are always tuple in transformers (see doc)

                if args.n_gpu > 1:
                    loss = loss.mean(
                    )  # mean() to average on multi-gpu parallel training
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                tr_loss += loss.item()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    model.zero_grad()
                    global_step += 1

                    if args.local_rank in [
                            -1, 0
                    ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                        # Log metrics
                        if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                            results = evaluate(args,
                                               model,
                                               tokenizer,
                                               sample_percentage=0.01)
                            for key, value in results.items():
                                tb_writer.add_scalar('eval_{}'.format(key),
                                                     value, global_step)
                                ex.log_scalar('eval_{}'.format(key), value,
                                              global_step)
                                logger.info('eval_{}'.format(key) + ": " +
                                            str(value) + ", step: " +
                                            str(global_step))
                        tb_writer.add_scalar('lr',
                                             scheduler.get_lr()[0],
                                             global_step)
                        tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                             args.logging_steps, global_step)
                        ex.log_scalar("lr", scheduler.get_lr()[0], global_step)
                        ex.log_scalar("loss", (tr_loss - logging_loss) /
                                      args.logging_steps, global_step)
                        logging_loss = tr_loss

                    if args.local_rank in [
                            -1, 0
                    ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                        # Save model checkpoint
                        output_dir = os.path.join(
                            args.output_dir,
                            'checkpoint-{}'.format(global_step))
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        model_to_save = model.module if hasattr(
                            model, 'module'
                        ) else model  # Take care of distributed/parallel training
                        model_to_save.save_pretrained(output_dir)
                        torch.save(
                            args, os.path.join(output_dir,
                                               'training_args.bin'))
                        logger.info("Saving model checkpoint to %s",
                                    output_dir)

                if args.max_steps > 0 and global_step > args.max_steps:
                    break
                    # epoch_iterator.close()
            if args.max_steps > 0 and global_step > args.max_steps:
                break
                # train_iterator.close()

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return model
def train_with_early_stopping(train_dataset, eval_dataset, model, tokenizer):
    tb_writer = SummaryWriter()

    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args['train_batch_size'])
    eval_sampler = RandomSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args['train_batch_size'])


    t_total = len(train_dataloader) // args['gradient_accumulation_steps'] * args['num_train_epochs']

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args['weight_decay']},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args['learning_rate'], eps=args['adam_epsilon'])
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args['warmup_steps'], t_total=t_total)

    if args['fp16']:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args['fp16_opt_level'])

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args['num_train_epochs'])
    logger.info("  Total train batch size  = %d", args['train_batch_size'])
    logger.info("  Gradient Accumulation steps = %d", args['gradient_accumulation_steps'])
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, ev_loss, logging_loss = 0.0, 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args['num_train_epochs']), desc="Epoch")
    count_eval_loss = 0

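    # Train for num_train_epochs epochs; after each epoch, run one pass over the
    # eval set and stop early if the evaluation loss keeps increasing.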
    for _ in train_iterator:  # loop over the 'num_train_epochs'
        epoch_iterator = tqdm_notebook(train_dataloader, desc="Iteration")
        epoch_iterator_eval = tqdm_notebook(eval_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):
            model.train()  # switch to training mode
            batch = tuple(t.to(device) for t in batch)
            inputs = {'input_ids': batch[0],
                      'attention_mask': batch[1],
                      'token_type_ids': batch[2] if args['model_type'] in ['bert', 'xlnet'] else None,
                      # XLM doesn't use segment_ids
                      'labels': batch[3]}
            outputs = model(**inputs)  # feed the input into the model and receive the output
            loss = outputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)
            print("\r%f" % loss, end='')

            if args['gradient_accumulation_steps'] > 1:
                loss = loss / args['gradient_accumulation_steps']

            if args['fp16']:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args['max_grad_norm'])

            else:
                loss.backward()  # back_prop
                torch.nn.utils.clip_grad_norm_(model.parameters(), args['max_grad_norm']) # gradient clipping

            tr_loss += loss.item()  # extracts the loss’s value as a Python float.



            # update learning rate
            if (step + 1) % args['gradient_accumulation_steps'] == 0:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1 # count the global step


                if args['logging_steps'] > 0 and global_step % args['logging_steps'] == 0:
                    # Log metrics
                    if args['evaluate_during_training']:
                        # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value, global_step)


                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args['logging_steps'], global_step)
                    logging_loss = tr_loss


                # Save a checkpoint of the model trained so far
                if args['save_steps'] > 0 and global_step % args['save_steps'] == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args['output_dir'], 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(model,'module') else model
                    # Take care of distributed/parallel training

                    model_to_save.save_pretrained(output_dir)
                    logger.info("Saving model checkpoint to %s", output_dir)


        # Early stopping: run one pass over the eval set after each epoch and stop
        # if the evaluation loss has increased for 3 consecutive epochs.
        epoch_eval_loss = 0.0
        for step, batch in enumerate(epoch_iterator_eval):
            model.eval()  # switch to evaluation mode
            batch = tuple(t.to(device) for t in batch)
            inputs = {'input_ids': batch[0],
                      'attention_mask': batch[1],
                      'token_type_ids': batch[2] if args['model_type'] in ['bert', 'xlnet'] else None,
                      # XLM doesn't use segment_ids
                      'labels': batch[3]}
            with torch.no_grad():
                outputs = model(**inputs)  # feed the input into the model and receive the output
            loss = outputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)
            epoch_eval_loss += loss.item()

        if ev_loss and epoch_eval_loss > ev_loss:
            count_eval_loss += 1  # eval loss went up compared to the previous epoch
        else:
            count_eval_loss = 0
        ev_loss = epoch_eval_loss

        if count_eval_loss >= 3:
            print("Evaluation loss increased for 3 consecutive epochs -- stopping early (overfitting).")
            break



    return global_step, tr_loss / global_step
# Example #3
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs
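    # Warm up the learning rate over the first 1% of the total optimization steps.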
    args.warmup_steps = t_total // 100

    # Prepare optimizer and schedule (linear warmup and decay)
    optimizer_grouped_parameters = get_param_groups(args, model)
    optimizer = RAdam(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelModel(model)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)
    args.logging_steps = len(train_dataloader)  # log and save once per epoch
    args.save_steps = args.logging_steps
    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch")
    set_seed(args)
    for epoch in train_iterator:
        args.current_epoch = epoch
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'input_ids':
                batch[0],
                'attention_mask':
                batch[1],
                'token_type_ids':
                batch[2] if args.model_type in ['bert', 'xlnet'] else None,
            }  # XLM and RoBERTa don't use segment_ids
            #   'labels':         batch[3]}
            outputs = model(**inputs)
            outputs = [outputs[i][0] for i in range(len(outputs))]

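            # The loss is computed outside the model so it can be parallelized with
            # DataParallelCriterion, matching the DataParallelModel wrapper above.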
            loss_fct = CrossEntropyLoss()
            loss_fct = DataParallelCriterion(loss_fct)

            loss = loss_fct(outputs, batch[3])

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value,
                                                 global_step)
                    tb_writer.add_scalar('lr',
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(
                        model, 'module'
                    ) else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
# Example #4
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, num_workers=args.n_gpu)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank,
                                                          find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                   args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            input_ids, input_mask, segment_ids, index, is_selected, start_positions, end_positions, is_impossible = batch
            batch_size = input_ids.shape[0]
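            # Each example holds num_para paragraphs: flatten (batch, num_para, seq_len)
            # to (batch * num_para, seq_len) for the encoder, then reshape the span
            # logits back to one row per example.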
            inputs = {'input_ids':       input_ids.view(-1, input_ids.shape[2]),
                      'attention_mask':  input_mask.view(-1, input_mask.shape[2]),
                      'token_type_ids':  segment_ids.view(-1, segment_ids.shape[2])}
            outputs = model(**inputs)
            start_logits, end_logits, select_logits = outputs[0], outputs[1], outputs[2] # (batch_size * num_para, max_seq_len)
            start_logits = start_logits.view(batch_size, args.num_para * args.max_seq_length)
            end_logits = end_logits.view(batch_size, args.num_para * args.max_seq_length)

            # If we are on multi-GPU, the split adds an extra dimension; squeeze it
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            span_loss = (start_loss + end_loss) / 2

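            # Paragraph-selection loss: binary cross-entropy over which paragraph is
            # selected; the final loss interpolates span and selection losses via gamma.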
            select_loss_fct = torch.nn.BCEWithLogitsLoss()
            select_loss = select_loss_fct(select_logits, is_selected.view(-1))
            
            loss = args.gamma * span_loss + (1 - args.gamma) * select_loss

            if args.n_gpu > 1:
                loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
def main():
    args = get_args()
    print('args: ', args)

    # Get ready...
    logging.basicConfig(
        filename=args.log_output_path,
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger = logging.getLogger(__name__)
    logger.info('\nargs: {}'.format(args))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("device: {}, n_gpu {}".format(device, n_gpu))

    # Load tokenizer and model
    with open(args.slots_tokens_path, 'r') as slots_file:
        slots_dict = json.load(slots_file)
        slots_tokens = list(slots_dict.keys())
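    # Add slot placeholders and cluster-rank markers (<0>, <1>, ...) as special tokens
    # so the GPT-2 tokenizer treats them as single units.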
    rank_tokens = [
        '<' + str(i) + '>' for i in range(args.target_cluster_num_threshold)
    ]
    special_tokens = slots_tokens + rank_tokens
    tokenizer = GPT2Tokenizer.from_pretrained(args.model_name)
    tokenizer.add_tokens(special_tokens)
    special_tokens_ids = tokenizer.convert_tokens_to_ids(special_tokens)
    slots_tokens_ids = tokenizer.convert_tokens_to_ids(slots_tokens)
    rank_tokens_ids = tokenizer.convert_tokens_to_ids(rank_tokens)
    logger.info("special_tokens: {}".format(special_tokens))
    logger.info("special_tokens_ids: {}".format(special_tokens_ids))
    logger.info("special_tokens: {}".format(tokenizer.special_tokens_map))
    bos_token, eos_token, unk_token = tokenizer.special_tokens_map['bos_token'], \
                                      tokenizer.special_tokens_map['eos_token'], \
                                      tokenizer.special_tokens_map['unk_token']
    bos_token_id, eos_token_id, unk_token_id = tokenizer.convert_tokens_to_ids(
        [bos_token, eos_token, unk_token])
    if args.wo_pretrained:
        config = GPT2Config.from_pretrained(args.model_name,
                                            n_layer=args.wo_pretrained_layer)
        model = GPT2LMHeadModel(config)
        logger.info("GPT2 model without pretrained: {}".format(model))
    else:
        model = GPT2LMHeadModel.from_pretrained(args.model_name)
    model.resize_token_embeddings(len(tokenizer))
    model.to(device)

    # Load and encode the datasets
    logger.info("Loading and Encoding dataset...")

    def tokenize_and_encode(data):
        """ Tokenize and encode a nested object """
        encoded_data = []
        for cluster in tqdm(data, desc='Tokenize and Encode'):
            encoded_data.append([[
                tokenizer.convert_tokens_to_ids(tokenizer.tokenize(s))
                for s in pair
            ] for pair in cluster])
        return encoded_data

    encoded_datasets = {}
    if args.do_train:
        encoded_path = args.train_dataset.split('.')
        encoded_path[-2] += '_encoded'
        encoded_path = '.'.join(encoded_path)
        if os.path.exists(encoded_path):
            encoded_datasets.update(pickle.load(open(encoded_path, "rb")))
        else:
            train_dataset = load_da_dataset(args,
                                            args.train_dataset,
                                            unk_token_old=args.unknown_token,
                                            unk_token_new=unk_token)
            encoded_datasets['train'] = tokenize_and_encode(train_dataset)
            with open(encoded_path, "wb") as f:
                pickle.dump(encoded_datasets, f)
    if args.do_gen:
        encoded_path = args.gen_dataset.split('.')
        encoded_path[-2] += '_gen_encoded'
        encoded_path = '.'.join(encoded_path)
        if os.path.exists(encoded_path):
            encoded_datasets.update(pickle.load(open(encoded_path, "rb")))
        else:
            gen_dataset = load_da_dataset(args,
                                          args.gen_dataset,
                                          trg_file=False,
                                          unk_token_old=args.unknown_token,
                                          unk_token_new=unk_token)
            encoded_datasets['gen'] = tokenize_and_encode(gen_dataset)
            with open(encoded_path, "wb") as f:
                pickle.dump(encoded_datasets, f)

    log_data_type = 'train' if args.do_train else 'gen'
    rank_idx = -1 if log_data_type == 'gen' else -2
    for i in range(2):
        for j in range(min(5, len(encoded_datasets[log_data_type][i]))):
            logger.info("\n*****Examples*****")
            logger.info('Source tokens: {}'.format([
                tokenizer.convert_ids_to_tokens(d)
                for d in encoded_datasets[log_data_type][i][j][:rank_idx]
            ]))
            logger.info('Source ids: {}'.format(
                encoded_datasets[log_data_type][i][j][:rank_idx]))
            logger.info('Rank tokens: {}'.format([
                tokenizer.convert_ids_to_tokens(d)
                for d in encoded_datasets[log_data_type][i][j][rank_idx]
            ]))
            logger.info('Rank ids: {}'.format(
                encoded_datasets[log_data_type][i][j][rank_idx]))
            if log_data_type != 'gen':
                logger.info('Target tokens: {}'.format([
                    tokenizer.convert_ids_to_tokens(d)
                    for d in encoded_datasets[log_data_type][i][j][-1]
                ]))
                logger.info('Target ids: {}'.format(
                    encoded_datasets[log_data_type][i][j][-1]))
            logger.info('')

    # Train
    if args.do_train:

        # Prepare inputs tensors and dataloaders
        encoded_datasets_train = encoded_datasets['train']
        train_tensor_dataset = pre_process_datasets(args,
                                                    encoded_datasets_train,
                                                    model.config.n_positions,
                                                    bos_token_id, eos_token_id)

        train_data = TensorDataset(*train_tensor_dataset)
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        # Prepare optimizer
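        # Each dataloader batch of clusters is split into several optimization
        # sub-steps of at most train_target_size sequences each.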
        step_num_per_batch = math.ceil(args.train_batch_size *
                                       args.target_cluster_num_threshold /
                                       args.train_target_size)
        if args.max_steps > 0:
            t_total = args.max_steps
            args.num_train_epochs = args.max_steps // (
                len(train_dataloader) * step_num_per_batch //
                args.gradient_accumulation_steps) + 1
        else:
            t_total = len(
                train_dataloader
            ) * step_num_per_batch // args.gradient_accumulation_steps * args.num_train_epochs

        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            args.weight_decay
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=args.warmup_steps,
                                         t_total=t_total)

        if args.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.fp16_opt_level)

        def save_model(model, tokenizer, time):
            model_output_dir = os.path.join(args.model_output_dir, str(time))
            version = 1
            while os.path.exists(model_output_dir):
                model_output_dir = os.path.join(args.model_output_dir,
                                                str(time) + '.' + str(version))
                version += 1
            os.makedirs(model_output_dir)
            logger.info("\nSaving the model to {}\n".format(
                os.path.join(model_output_dir)))

            # Save a trained model, configuration and tokenizer
            model_to_save = model.module if hasattr(
                model, 'module') else model  # Only save the model itself

            # If we save using the predefined names, we can load using `from_pretrained`
            output_model_file = os.path.join(model_output_dir, WEIGHTS_NAME)
            output_config_file = os.path.join(model_output_dir, CONFIG_NAME)

            torch.save(model_to_save.state_dict(), output_model_file)
            model_to_save.config.to_json_file(output_config_file)
            tokenizer.save_vocabulary(model_output_dir)

        # Let's train
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)
        model.train()

        update_step, total_step, exp_average_loss = 0, 0, None
        for idx_epochs in trange(int(args.num_train_epochs), desc="Epoch"):
            total_step = 0
            num_save_checkpoint = len(
                train_dataloader) // args.n_save_per_epoch
            tqdm_bar = tqdm(train_dataloader, desc="Training")
            optimizer.zero_grad()
            for idx_batch, batch in enumerate(tqdm_bar):
                batch = tuple(t.to(device) for t in batch)
                input_ids, lm_labels, len_sources = batch
                len_sources = len_sources.expand(-1, input_ids.shape[1]).view(
                    input_ids.shape[0] * input_ids.shape[1])
                input_ids = input_ids.view(
                    input_ids.shape[0] * input_ids.shape[1],
                    input_ids.shape[2])
                lm_labels = lm_labels.view(
                    lm_labels.shape[0] * lm_labels.shape[1],
                    lm_labels.shape[2])

                batch_size = input_ids.shape[0]
                random_indices = torch.tensor(
                    random.sample(range(batch_size), batch_size)).to(device)
                for idx_step in range(step_num_per_batch):
                    idx_begin = idx_step * args.train_target_size
                    idx_end = min(idx_begin + args.train_target_size,
                                  input_ids.shape[0])
                    step_size = idx_end - idx_begin
                    input_ids_step = input_ids.index_select(
                        0, random_indices[idx_begin:idx_end])
                    lm_labels_step = lm_labels.index_select(
                        0, random_indices[idx_begin:idx_end])
                    len_sources_step = len_sources.index_select(
                        0, random_indices[idx_begin:idx_end])

                    assert len(len_sources_step.unique()) == 1
                    logits = model(
                        input_ids_step,
                        intra_attention=args.intra_attention,
                        len_source=len_sources_step,
                        intra_attention_weight=args.intra_attention_weight)

                    # calculate the nll loss
                    shift_logits = logits[..., :-1, :].contiguous()
                    shift_labels = lm_labels_step[..., 1:].contiguous()
                    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-1)
                    loss = loss_fct(
                        shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1))

                    # calculate the kl loss
                    if args.intra_kl_loss:
                        kl_loss_fun = torch.nn.KLDivLoss(reduce=False)
                        kl_loss = 0
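                        # The KL term is negated below, so minimizing the total loss
                        # pushes each sample's output distribution away from the other
                        # samples in the cluster (a diversity-encouraging regularizer),
                        # scaled by an annealed weight.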

                        # calculate loss within the current step
                        for idx_sample in range(1, step_size):
                            indices = torch.tensor(
                                list(range(idx_sample, step_size)) +
                                list(range(0, idx_sample))).to(device)
                            disper_logits = logits.detach().index_select(
                                0, indices)
                            log_q = F.log_softmax(logits, dim=-1)
                            p = F.softmax(disper_logits, dim=-1)
                            kl_loss_step = kl_loss_fun(log_q, p)
                            kl_weight = kl_anneal_function(
                                args.intra_kl_anneal_func, update_step,
                                args.intra_kl_k, args.intra_kl_x0,
                                args.intra_kl_loss_weight)
                            kl_loss += -kl_loss_step.sum(
                                dim=-1).mean() * kl_weight

                        # calculate loss outside the current step
                        indices_outside = torch.cat(
                            (random_indices[:idx_begin],
                             random_indices[idx_end:]))
                        input_ids_outside = input_ids.index_select(
                            0, indices_outside)
                        len_sources_outside = len_sources.index_select(
                            0, indices_outside)
                        with torch.no_grad():
                            assert len(len_sources_outside.unique()) == 1
                            logits_outside = model(
                                input_ids_outside,
                                intra_attention=args.intra_attention,
                                len_source=len_sources_outside,
                                intra_attention_weight=args.
                                intra_attention_weight)
                        for idx_sample in range(batch_size - step_size):
                            indices = torch.tensor([
                                idx_sample for _ in range(step_size)
                            ]).to(device)
                            disper_logits = logits_outside.detach(
                            ).index_select(0, indices)
                            log_q = F.log_softmax(logits, dim=-1)
                            p = F.softmax(disper_logits, dim=-1)
                            kl_loss_step = kl_loss_fun(log_q, p)
                            kl_weight = kl_anneal_function(
                                args.intra_kl_anneal_func, update_step,
                                args.intra_kl_k, args.intra_kl_x0,
                                args.intra_kl_loss_weight)
                            kl_loss += -kl_loss_step.sum(
                                dim=-1).mean() * kl_weight

                        kl_loss = kl_loss / (batch_size - 1)
                        loss += kl_loss
                    else:
                        kl_loss = torch.zeros_like(loss)
                    loss = loss / args.gradient_accumulation_steps

                    # backward, step and report
                    if args.fp16:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()

                    if (total_step +
                            1) % args.gradient_accumulation_steps == 0:
                        if args.fp16:
                            torch.nn.utils.clip_grad_norm_(
                                amp.master_params(optimizer),
                                args.max_grad_norm)
                        else:
                            torch.nn.utils.clip_grad_norm_(
                                model.parameters(), args.max_grad_norm)
                        optimizer.step()
                        scheduler.step()
                        optimizer.zero_grad()
                        update_step += 1

                    exp_average_loss = (loss.item() if exp_average_loss is None else
                                        0.7 * exp_average_loss + 0.3 * loss.item())
                    total_step += 1
                    tqdm_bar.desc = "Training loss: {:.2e} lr: {:.2e}".format(
                        exp_average_loss,
                        scheduler.get_lr()[0])
                    kl_weight = kl_anneal_function(args.intra_kl_anneal_func,
                                                   update_step,
                                                   args.intra_kl_k,
                                                   args.intra_kl_x0,
                                                   args.intra_kl_loss_weight)
                    logger.info(
                        "Epoch: {}, Step: {}, Training loss: {:.2e} current loss: {:.2e} "
                        "nll loss: {:.2e} kl loss: {:.2e} lr: {:.2e} update_step: {:.2e} kl weight {:.2e}\n"
                        .format(idx_epochs, total_step, exp_average_loss,
                                loss.item(),
                                loss.item() - kl_loss.item(), kl_loss.item(),
                                scheduler.get_lr()[0], update_step, kl_weight))

                if (idx_epochs + 1) % args.n_save_epochs == 0 and (
                        idx_batch + 1) % num_save_checkpoint == 0:
                    save_model(
                        model, tokenizer,
                        'epoch_' + str(idx_epochs) + '_batch_' +
                        str(idx_batch) + '_step_' + str(total_step))

    # Generation
    if args.do_gen:
        encoded_datasets_gen = encoded_datasets['gen']

        if not os.path.exists(args.gen_output_dir):
            os.mkdir(args.gen_output_dir)
        gen_file_name = os.path.split(args.gen_dataset)[-1].split('.')
        gen_file_name[-2] += '_gen'
        gen_file_name = '.'.join(gen_file_name)
        gen_file_path = os.path.join(args.gen_output_dir, gen_file_name)
        gen_file = open(gen_file_path, 'w')

        for cluster in tqdm(encoded_datasets_gen,
                            desc='Generate target utterances',
                            total=len(encoded_datasets_gen)):
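            # Build the conditioning context for each pair:
            # <bos> source utterances (separated by <eos><bos>) <eos> <bos> rank tokens <eos> <bos>;
            # the model then generates the target utterance after this prefix.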
            pairs_with_st = []
            for pair in cluster:
                rank = pair[-1]
                source_sentences = pair[0]
                for source_sentence in pair[1:-1]:
                    source_sentences += [eos_token_id, bos_token_id
                                         ] + source_sentence
                pair_with_st = [bos_token_id] + source_sentences + [eos_token_id, bos_token_id] + rank + \
                               [eos_token_id, bos_token_id]
                pairs_with_st.append(pair_with_st)

            # Generation without intra_attention
            if not args.intra_attention:
                for pair_with_st in pairs_with_st:
                    out = sample_sequence(
                        model=model,
                        context=pair_with_st,
                        length=args.gen_length,
                        rank_tokens_ids=rank_tokens_ids,
                        slots_tokens_ids=slots_tokens_ids,
                        stop_early_id=eos_token_id
                        if args.gen_stop_early else None,
                        greed=True if args.gen_mode == 'greed' else False,
                        empty_accept=True if args.gen_accept_empty else False,
                        argmax_slots=args.gen_argmax_slots,
                        intra_attn=args.intra_attention,
                        intra_attention_weight=args.intra_attention_weight,
                        device=device)
                    if args.gen_stop_early:
                        out = out[0, len(pair_with_st):-2].tolist()
                    else:
                        out = out[0].tolist()
                    text = tokenizer.decode(out)
                    gen_file.write(
                        text.replace(unk_token, args.unknown_token).replace(
                            '\n', ' ') + '\n')
                    gen_file.flush()
            # Generation with intra_attention
            else:
                out = sample_sequence(
                    model=model,
                    context=pairs_with_st,
                    length=args.gen_length,
                    rank_tokens_ids=rank_tokens_ids,
                    slots_tokens_ids=slots_tokens_ids,
                    stop_early_id=eos_token_id
                    if args.gen_stop_early else None,
                    greed=True if args.gen_mode == 'greed' else False,
                    empty_accept=True if args.gen_accept_empty else False,
                    argmax_slots=args.gen_argmax_slots,
                    intra_attn=args.intra_attention,
                    intra_attention_weight=args.intra_attention_weight,
                    device=device)
                for idx_out in range(out.size(0)):
                    text = out[idx_out].tolist()
                    if args.gen_stop_early:
                        idx_end = -2
                        while text[idx_end] == eos_token_id:
                            idx_end -= 1
                        text = text[len(pairs_with_st[0]):idx_end + 1]
                    text = tokenizer.decode(text)
                    gen_file.write(
                        text.replace(unk_token, args.unknown_token).replace(
                            '\n', ' ') + '\n')
                    gen_file.flush()
        gen_file.close()

        # Combine the original utterances with the generated utterances into the full augmented dataset
        augmented_file_name = os.path.split(
            args.gen_dataset)[-1].split('clustered')[:-1]
        augmented_file_name[-1] += 'augmented'
        augmented_file_name = 'clustered'.join(augmented_file_name)
        augmented_file_path = os.path.join(args.gen_output_dir,
                                           augmented_file_name)
        with open(gen_file_path, 'r') as gen_file, open(args.original_data_path, 'r') as origin_file, \
                open(augmented_file_path, 'w') as augmented_file:
            for line in origin_file:
                augmented_file.write(line)
            for line in gen_file:
                augmented_file.write(line)
class TransformerBase(TrainableModel):
    """
    Transformers base model (for working with pytorch-transformers models)
    """
    MODEL_CONFIGURATIONS = {
        'bert': (BertConfig, BertTokenizer),
        'quant_bert': (QuantizedBertConfig, BertTokenizer),
        'xlnet': (XLNetConfig, XLNetTokenizer),
        'xlm': (XLMConfig, XLMTokenizer),
        'roberta': (RobertaConfig, RobertaTokenizer)
    }

    def __init__(self,
                 model_type: str,
                 model_name_or_path: str,
                 labels: List[str] = None,
                 num_labels: int = None,
                 config_name=None,
                 tokenizer_name=None,
                 do_lower_case=False,
                 output_path=None,
                 device='cpu',
                 n_gpus=0):
        """
        Transformers base model (for working with pytorch-transformers models)

        Args:
            model_type (str): transformer model type
            model_name_or_path (str): model name or path to model
            labels (List[str], optional): list of labels. Defaults to None.
            num_labels (int, optional): number of labels. Defaults to None.
            config_name (str, optional): configuration name. Defaults to None.
            tokenizer_name (str, optional): tokenizer name. Defaults to None.
            do_lower_case (bool, optional): lower case input words. Defaults to False.
            output_path (str, optional): model output path. Defaults to None.
            device (str, optional): backend device. Defaults to 'cpu'.
            n_gpus (int, optional): num of gpus. Defaults to 0.

        Raises:
            FileNotFoundError: if output_path is given but does not exist
        """
        assert model_type in self.MODEL_CONFIGURATIONS, "unsupported model_type"
        self.model_type = model_type
        self.model_name_or_path = model_name_or_path
        self.labels = labels
        self.num_labels = num_labels
        self.do_lower_case = do_lower_case
        if output_path is not None and not os.path.exists(output_path):
            raise FileNotFoundError('output_path is not found')
        self.output_path = output_path

        self.model_class = None
        config_class, tokenizer_class = self.MODEL_CONFIGURATIONS[model_type]
        self.config_class = config_class
        self.tokenizer_class = tokenizer_class

        self.tokenizer_name = tokenizer_name
        self.tokenizer = self._load_tokenizer(self.tokenizer_name)
        self.config_name = config_name
        self.config = self._load_config(config_name)

        self.model = None
        self.device = device
        self.n_gpus = n_gpus

        self._optimizer = None
        self._scheduler = None

    def to(self, device='cpu', n_gpus=0):
        if self.model is not None:
            self.model.to(device)
            if n_gpus > 1:
                self.model = torch.nn.DataParallel(self.model)
        self.device = device
        self.n_gpus = n_gpus

    @property
    def optimizer(self):
        return self._optimizer

    @optimizer.setter
    def optimizer(self, opt):
        self._optimizer = opt

    @property
    def scheduler(self):
        return self._scheduler

    @scheduler.setter
    def scheduler(self, sch):
        self._scheduler = sch

    def setup_default_optimizer(self,
                                weight_decay: float = 0.0,
                                learning_rate: float = 5e-5,
                                adam_epsilon: float = 1e-8,
                                warmup_steps: int = 0,
                                total_steps: int = 0):
        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ['bias', 'LayerNorm.weight']
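        # Parameters whose names contain any of these fragments (biases and
        # LayerNorm weights) go into the zero-weight-decay group below,
        # following the usual BERT fine-tuning recipe.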
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in self.model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            weight_decay
        }, {
            'params': [
                p for n, p in self.model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0
        }]
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=learning_rate,
                               eps=adam_epsilon)
        self.scheduler = WarmupLinearSchedule(self.optimizer,
                                              warmup_steps=warmup_steps,
                                              t_total=total_steps)

    def _load_config(self, config_name=None):
        config = self.config_class.from_pretrained(
            config_name if config_name else self.model_name_or_path,
            num_labels=self.num_labels)
        return config

    def _load_tokenizer(self, tokenizer_name=None):
        tokenizer = self.tokenizer_class.from_pretrained(
            tokenizer_name if tokenizer_name else self.model_name_or_path,
            do_lower_case=self.do_lower_case)
        return tokenizer

    def save_model(self,
                   output_dir: str,
                   save_checkpoint: bool = False,
                   args=None):
        """
        Save model/tokenizer/arguments to given output directory

        Args:
            output_dir (str): path to output directory
            save_checkpoint (bool, optional): save as checkpoint. Defaults to False.
            args (object, optional): arguments object to save. Defaults to None.
        """
        # Create output directory if needed
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        logger.info("Saving model checkpoint to %s", output_dir)
        model_to_save = self.model.module if hasattr(self.model,
                                                     'module') else self.model
        model_to_save.save_pretrained(output_dir)
        if not save_checkpoint:
            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(output_dir)
            with io.open(output_dir + os.sep + 'labels.txt',
                         'w',
                         encoding='utf-8') as fw:
                for l in self.labels:
                    fw.write('{}\n'.format(l))
            if args is not None:
                torch.save(args, os.path.join(output_dir, 'training_args.bin'))

    @classmethod
    def load_model(cls, model_path: str, model_type: str, *args, **kwargs):
        """
        Create a TransformerBase model from a given path

        Args:
            model_path (str): path to model
            model_type (str): model type

        Returns:
            TransformerBase: model
        """
        # Load a trained model and vocabulary from given path
        if not os.path.exists(model_path):
            raise FileNotFoundError
        with io.open(model_path + os.sep + 'labels.txt') as fp:
            labels = [l.strip() for l in fp.readlines()]
        return cls(model_type=model_type,
                   model_name_or_path=model_path,
                   labels=labels,
                   *args,
                   **kwargs)

    @staticmethod
    def get_train_steps_epochs(max_steps: int, num_train_epochs: int,
                               gradient_accumulation_steps: int,
                               num_samples: int):
        """
        get train steps and epochs

        Args:
            max_steps (int): max steps
            num_train_epochs (int): num epochs
            gradient_accumulation_steps (int): gradient accumulation steps
            num_samples (int): number of samples

        Returns:
            Tuple: total steps, number of epochs
        """
        if max_steps > 0:
            t_total = max_steps
            num_train_epochs = max_steps // (num_samples //
                                             gradient_accumulation_steps) + 1
        else:
            t_total = num_samples // gradient_accumulation_steps * num_train_epochs
        return t_total, num_train_epochs

    def get_logits(self, batch):
        self.model.eval()
        inputs = self._batch_mapper(batch)
        outputs = self.model(**inputs)
        return outputs[-1]

    def _train(self,
               data_set: DataLoader,
               dev_data_set: Union[DataLoader, List[DataLoader]] = None,
               test_data_set: Union[DataLoader, List[DataLoader]] = None,
               gradient_accumulation_steps: int = 1,
               per_gpu_train_batch_size: int = 8,
               max_steps: int = -1,
               num_train_epochs: int = 3,
               max_grad_norm: float = 1.0,
               logging_steps: int = 50,
               save_steps: int = 100):
        """Run model training
            batch_mapper: a function that maps a batch into parameters that the model
                          expects in the forward method (for use with custom heads and models).
                          If None it will default to the basic models input structure.
            logging_callback_fn: a function that is called in each evaluation step
                          with the model as a parameter.

        """
        t_total, num_train_epochs = self.get_train_steps_epochs(
            max_steps, num_train_epochs, gradient_accumulation_steps,
            len(data_set))
        if self.optimizer is None and self.scheduler is None:
            logger.info("Loading default optimizer and scheduler")
            self.setup_default_optimizer(total_steps=t_total)

        train_batch_size = per_gpu_train_batch_size * max(1, self.n_gpus)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(data_set.dataset))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per GPU/CPU = %d",
                    per_gpu_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            train_batch_size * gradient_accumulation_steps)
        logger.info("  Gradient Accumulation steps = %d",
                    gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        tr_loss, logging_loss = 0.0, 0.0
        self.model.zero_grad()
        train_iterator = trange(num_train_epochs, desc="Epoch")
        for _ in train_iterator:
            epoch_iterator = tqdm(data_set, desc="Train iteration")
            for step, batch in enumerate(epoch_iterator):
                self.model.train()
                batch = tuple(t.to(self.device) for t in batch)
                inputs = self._batch_mapper(batch)
                outputs = self.model(**inputs)
                loss = outputs[0]  # get loss

                if self.n_gpus > 1:
                    loss = loss.mean(
                    )  # mean() to average on multi-gpu parallel training
                if gradient_accumulation_steps > 1:
                    loss = loss / gradient_accumulation_steps

                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               max_grad_norm)

                tr_loss += loss.item()
                if (step + 1) % gradient_accumulation_steps == 0:
                    self.optimizer.step()
                    self.scheduler.step()
                    self.model.zero_grad()
                    global_step += 1

                    if logging_steps > 0 and global_step % logging_steps == 0:
                        # Log metrics and run evaluation on dev/test
                        for ds in [dev_data_set, test_data_set]:
                            if ds is None:  # got no data loader
                                continue
                            if isinstance(ds, DataLoader):
                                ds = [ds]
                            for d in ds:
                                logits, label_ids = self._evaluate(d)
                                self.evaluate_predictions(logits, label_ids)
                        logger.info('lr = {}'.format(
                            self.scheduler.get_lr()[0]))
                        logger.info('loss = {}'.format(
                            (tr_loss - logging_loss) / logging_steps))
                        logging_loss = tr_loss

                    if save_steps > 0 and global_step % save_steps == 0:
                        # Save model checkpoint
                        self.save_model_checkpoint(
                            output_path=self.output_path,
                            name='checkpoint-{}'.format(global_step))

                if 0 < max_steps < global_step:
                    epoch_iterator.close()
                    break
            if 0 < max_steps < global_step:
                train_iterator.close()
                break

        logger.info(" global_step = %s, average loss = %s", global_step,
                    tr_loss)

    def _evaluate(self, data_set: DataLoader):
        logger.info("***** Running inference *****")
        logger.info(" Batch size: {}".format(data_set.batch_size))
        eval_loss = 0.0
        nb_eval_steps = 0
        preds = None
        out_label_ids = None
        for batch in tqdm(data_set, desc="Inference iteration"):
            self.model.eval()
            batch = tuple(t.to(self.device) for t in batch)

            with torch.no_grad():
                inputs = self._batch_mapper(batch)
                outputs = self.model(**inputs)
                if 'labels' in inputs:
                    tmp_eval_loss, logits = outputs[:2]
                    eval_loss += tmp_eval_loss.mean().item()
                else:
                    logits = outputs[0]
            nb_eval_steps += 1
            model_output = logits.detach().cpu()
            model_out_label_ids = inputs['labels'].detach().cpu(
            ) if 'labels' in inputs else None
            if preds is None:
                preds = model_output
                out_label_ids = model_out_label_ids
            else:
                preds = torch.cat((preds, model_output), dim=0)
                out_label_ids = torch.cat(
                    (out_label_ids, model_out_label_ids),
                    dim=0) if out_label_ids is not None else None
        if out_label_ids is None:
            return preds
        return preds, out_label_ids

    def _batch_mapper(self, batch):
        mapping = {
            'input_ids':
            batch[0],
            'attention_mask':
            batch[1],
            # XLM doesn't use segment_ids
            'token_type_ids':
            batch[2]
            if self.model_type in ['bert', 'quant_bert', 'xlnet'] else None
        }
        if len(batch) == 4:
            mapping.update({'labels': batch[3]})
        return mapping

    def evaluate_predictions(self, logits, label_ids):
        raise NotImplementedError(
            'evaluate_predictions method must be implemented in order to'
            'be used for dev/test set evaluation')

    def save_model_checkpoint(self, output_path: str, name: str):
        """
        save model checkpoint

        Args:
            output_path (str): output path
            name (str): name of checkpoint
        """
        output_dir_path = os.path.join(output_path, name)
        self.save_model(output_dir_path, save_checkpoint=True)
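

# Illustrative usage sketch (not part of the original source): a minimal concrete
# subclass of TransformerBase, assuming the classes above and pytorch-transformers
# are importable and that any abstract methods required by TrainableModel are
# handled as in the original project. The SequenceClassifier name and the
# accuracy metric below are hypothetical stand-ins.
import torch
from pytorch_transformers import BertForSequenceClassification


class SequenceClassifier(TransformerBase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # TransformerBase leaves self.model as None; load a concrete head here.
        self.model_class = BertForSequenceClassification
        self.model = self.model_class.from_pretrained(self.model_name_or_path,
                                                      config=self.config)
        self.to(self.device, self.n_gpus)

    def evaluate_predictions(self, logits, label_ids):
        # Simple accuracy over the concatenated dev/test logits.
        preds = logits.argmax(dim=-1)
        accuracy = (preds == label_ids).float().mean().item()
        logger.info("eval accuracy = %f", accuracy)


# Typical call pattern (not executed here):
#   clf = SequenceClassifier('bert', 'bert-base-uncased',
#                            labels=['neg', 'pos'], num_labels=2, device='cuda')
#   clf._train(train_loader, dev_data_set=dev_loader, num_train_epochs=3)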
Example #7
    def train(self, train_dataset, output_dir, show_running_loss=True):
        """
        Trains the model on train_dataset.

        Utility function to be used by the train_model() method. Not intended to be used directly.
        """
        tokenizer = self.tokenizer
        device = self.device
        model = self.model
        args = self.args
        tb_writer = SummaryWriter()
        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args['train_batch_size'])

        t_total = len(train_dataloader) // args[
            'gradient_accumulation_steps'] * args['num_train_epochs']

        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            args['weight_decay']
        }, {
            'params': [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0
        }]

        warmup_steps = math.ceil(t_total * args['warmup_ratio'])
        args['warmup_steps'] = warmup_steps if args[
            'warmup_steps'] == 0 else args['warmup_steps']
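        # Illustrative: with t_total = 1500 and warmup_ratio = 0.06,
        # warmup_steps = ceil(90.0) = 90; it is only used when
        # args['warmup_steps'] was 0, otherwise the configured value wins.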

        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args['learning_rate'],
                          eps=args['adam_epsilon'])
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=args['warmup_steps'],
                                         t_total=t_total)

        if args['fp16']:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args['fp16_opt_level'])

        global_step = 0
        tr_loss, logging_loss = 0.0, 0.0
        model.zero_grad()
        train_iterator = trange(int(args['num_train_epochs']), desc="Epoch")

        for _ in train_iterator:
            # epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Current iteration")):
                model.train()
                batch = tuple(t.to(device) for t in batch)
                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'labels': batch[3]
                }
                # XLM, DistilBERT and RoBERTa don't use segment_ids
                if args['model_type'] != 'distilbert':
                    inputs['token_type_ids'] = batch[2] if args[
                        'model_type'] in ['bert', 'xlnet'] else None
                outputs = model(**inputs)
                # model outputs are always tuple in pytorch-transformers (see doc)
                loss = outputs[0]
                if show_running_loss:
                    print("\rRunning loss: %f" % loss, end='')

                if args['gradient_accumulation_steps'] > 1:
                    loss = loss / args['gradient_accumulation_steps']

                if args['fp16']:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args['max_grad_norm'])

                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args['max_grad_norm'])

                tr_loss += loss.item()
                if (step + 1) % args['gradient_accumulation_steps'] == 0:
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    model.zero_grad()
                    global_step += 1

                    if args['logging_steps'] > 0 and global_step % args[
                            'logging_steps'] == 0:
                        # Log metrics
                        # Only evaluate when single GPU otherwise metrics may not average well
                        tb_writer.add_scalar('lr',
                                             scheduler.get_lr()[0],
                                             global_step)
                        tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                             args['logging_steps'],
                                             global_step)
                        logging_loss = tr_loss

                    if args['save_steps'] > 0 and global_step % args[
                            'save_steps'] == 0:
                        # Save model checkpoint
                        # Use a separate variable so the base output_dir is not
                        # overwritten (otherwise later checkpoints would nest
                        # inside earlier ones).
                        checkpoint_dir = os.path.join(
                            output_dir, 'checkpoint-{}'.format(global_step))
                        if not os.path.exists(checkpoint_dir):
                            os.makedirs(checkpoint_dir)
                        # Take care of distributed/parallel training
                        model_to_save = model.module if hasattr(
                            model, 'module') else model
                        model_to_save.save_pretrained(checkpoint_dir)
        return global_step, tr_loss / global_step
Example #8
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    summary_dir_name = construct_folder_name(args)
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter(f"{args.logging_dir}/{summary_dir_name}")

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    # add sampling
    if args.num_sample > -1:
        random_indices = np.arange(len(train_dataset))
        np.random.shuffle(random_indices)
        random_indices = random_indices[:args.num_sample]
        train_dataset = torch.utils.data.TensorDataset(
            *train_dataset[random_indices])

    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    condition_fn = create_filter_conditions(args)
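    # create_filter_conditions is defined elsewhere; judging by its use below it
    # returns a predicate over parameter names, e.g. (hypothetically)
    # lambda name: 'classifier' in name or 'encoder.layer.11' in name,
    # so only the matching parameters keep requires_grad=True.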
    optimizer_grouped_parameters = [{
        'params': [],
        'weight_decay': args.weight_decay
    }, {
        'params': [],
        'weight_decay': 0.0
    }]
    optimizer_parameters_name = []
    # 201 components
    for n, p in model.named_parameters():
        # print(n)
        if condition_fn(n):
            p.requires_grad = True
            optimizer_parameters_name.append(n)
            if any(nd in n for nd in no_decay):
                optimizer_grouped_parameters[1]['params'].append(p)
            else:
                optimizer_grouped_parameters[0]['params'].append(p)
        else:
            p.requires_grad = False
    print(f"Parameters to be optimized: {optimizer_parameters_name}")
    import sys
    open("bert-base-uncased-param2", "w").write(str(optimizer_parameters_name))
    sys.exit(0)
    # bp()
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    all_param = 0  # accumulates wall-clock time spent in backward() (see the timing below)
    # only fine-tune the last layer
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(
        args)  # Added here for reproducibility (even between python 2 and 3)
    for idx in train_iterator:
        print(f"idx: {idx}")
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                'start_positions': batch[3],
                'end_positions': batch[4]
            }
            if args.model_type != 'distilbert':
                inputs[
                    'token_type_ids'] = None if args.model_type == 'xlm' else batch[
                        2]
            if args.model_type in ['xlnet', 'xlm']:
                inputs.update({'cls_index': batch[5], 'p_mask': batch[6]})
            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)
            # results = evaluate(args, model, tokenizer)
            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel (not distributed) training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            from time import time
            start = time()
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            all_param += time() - start

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value,
                                                 global_step)
                    tb_writer.add_scalar('lr',
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(
                        model, 'module'
                    ) else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
Example #9
def train(args, train_dataset, model, tokenizer, processor, transformer, feats, sample_names=None):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomBatchedSampler(train_dataset, batch_size=args.train_batch_size)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
    #optimizer = optim.Adam(model.parameters())
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank,
                                                          find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                   args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    best_dev_acc, best_dev_loss = 0.0, 99999999999.0
    best_steps = 0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            if args.model_type == 'slu-protonet':
                model.set_support('train')

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            if sample_names:
                logger.info("--------[input samples: %s]---------", [sample_names[x.item()] for x in batch[2]])
            # feats maps example ids to feature tensors; stack them into one batch
            # tensor so the noise/rounding/softmax transforms below apply element-wise.
            embeddings = torch.stack([feats[i.item()] for i in batch[0]]).to(args.device)

            if args.do_noise:
                embeddings += torch.randn_like(embeddings)

            if args.do_round > 0:
                embeddings = (embeddings * 10**args.do_round).round() / (10**args.do_round)

            if args.do_softmax:
                embeddings = torch.nn.functional.softmax(embeddings, dim=1)

            outputs = model(embeddings, batch[1])
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean() # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer, processor, transformer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                        if results["eval_acc"] > best_dev_acc:
                            best_dev_acc = results["eval_acc"]
                            best_dev_loss = results["eval_loss"]
                            best_steps = global_step
                            if args.do_test:
                                results_test = evaluate(args, model, tokenizer, processor, transformer, test=True)
                                for key, value in results_test.items():
                                    tb_writer.add_scalar('test_{}'.format(key), value, global_step)
                                logger.info("test acc: %s, loss: %s, global steps: %s", str(results_test['eval_acc']), str(results_test['eval_loss']), str(global_step))
                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
                    logger.info("Average loss: %s at global step: %s", str((tr_loss - logging_loss)/args.logging_steps), str(global_step))
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
                    save_path = os.path.join(output_dir, WEIGHTS_NAME)
                    torch.save(model, save_path)
                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

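            # Stop when max_steps is exceeded, or when dev accuracy has not
            # improved for more than 42000 optimization steps (early stopping).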
            if (args.max_steps > 0 and global_step > args.max_steps) or (global_step - best_steps > 42000):
                logger.info('Early stopping')
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step, best_steps
Example #10
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # filter weights to train

    if args.only_classifier:
        logger.info("Only training last classifier")
        args.layers_to_fine_tune = []
    elif args.layers_to_fine_tune:
        logger.info(f"Finetuning layers: {str(args.layers_to_fine_tune)}")

    # parameters = model.named_parameters()

    # print("before")
    # for name, params in parameters:
    #     print(name, params.size(), params.mean())

    parameters = model.named_parameters()

    param_maps = {
        "roberta.encoder.layer":0,
        "roberta.embeddings":0,
        "roberta.pooler":0,
        "bert.encoder.layer":0,
        "bert.embeddings":0,
        "bert.pooler":0,
        "transformer.mask_emb":0,
        "transformer.word_embedding.weight":0,
        "transformer.layer":0,
        "sequence_summary":0
    }

    total_param = 0

    for name, params in parameters:
        param_size = params.size()
        param_aggregated = 1
        for size in param_size:
            param_aggregated = param_aggregated * size

        flag = True
        for key in param_maps.keys():
            if key in name:
                flag=False
                param_maps[key] = param_maps[key] + param_aggregated
                total_param = total_param + param_aggregated
                break
        if flag:
            print("failed to count layers", name)


    if args.model_type == 'bert':
        organized_param_maps = {
            "embedding": param_maps["bert.embeddings"],
            "total_encoder": param_maps["bert.encoder.layer"],
            "encoder/12": param_maps["bert.encoder.layer"]/12,
            "encoder/24": param_maps["bert.encoder.layer"]/24,
            "pooling": param_maps["bert.pooler"],
        }
    elif args.model_type == 'roberta':
        organized_param_maps = {
            "embedding": param_maps["roberta.embeddings"],
            "total_encoder": param_maps["roberta.encoder.layer"],
            "encoder/12": param_maps["roberta.encoder.layer"]/12,
            "encoder/24": param_maps["roberta.encoder.layer"]/24,
            "pooling": param_maps["roberta.pooler"],
        }
    elif args.model_type == 'xlnet':
        organized_param_maps = {
            "embedding": param_maps["transformer.mask_emb"] + param_maps["transformer.word_embedding.weight"],
            "total_encoder": param_maps["transformer.layer"],
            "encoder/12": param_maps["transformer.layer"]/12,
            "encoder/24": param_maps["transformer.layer"]/24,
            "pooling": param_maps["sequence_summary"]
        }
    else:
        organized_param_maps = {}  # other model types: no per-module breakdown

    for key, val in organized_param_maps.items():
        print(key, '\n\t', val, '\t', round(val/1000000,1), '\t', round(100*val/total_param,1))

    print("total", round(total_param/1000000,1))

    parameters = model.named_parameters()

    if args.only_classifier or args.layers_to_fine_tune:
        parameters = []
        for name, params in model.named_parameters():
            if args.model_type == 'bert' or args.model_type == 'roberta':
                if 'encoder' in name:
                    for layer in args.layers_to_fine_tune:
                        if f"layer.{layer}." in name:
                            parameters.append((name, params))
                            break
                elif 'embeddings' in name:
                    continue
                else:
                    parameters.append((name, params))

            elif args.model_type == 'xlnet':
                if 'transformer' in name:
                    for layer in args.layers_to_fine_tune:
                        if f"layer.{layer}." in name:
                            parameters.append((name, params))
                            break
                else:
                    parameters.append((name, params))

    # print("after")
    # for name, params in parameters:
    #     print(name, params.size())

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in parameters if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in parameters if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank,
                                                          find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                   args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {'input_ids':      batch[0],
                      'attention_mask': batch[1],
                      'labels':         batch[3]}
            if args.model_type != 'distilbert':
                inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 'xlnet'] else None  # XLM, DistilBERT and RoBERTa don't use segment_ids
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean() # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0 and not args.tpu:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if args.tpu:
                args.xla_model.optimizer_step(optimizer, barrier=True)
                model.zero_grad()
                global_step += 1

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
Example #11
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument("--data_dir", default=None, type=str, required=True,
                        help="directory containing the data")
    parser.add_argument("--output_dir", default="BERT_output", type=str, required=True,
                        help="The model output save dir")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument("--evaluate_during_training", action='store_true',
                        help="Run evaluation during training at each logging step.")

    parser.add_argument("--max_seq_length", default=100, type=int, required=False, 
                        help="maximum sequence length for BERT sequence classificatio")
    parser.add_argument("--max_steps", default=-1, type=int,
                        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--num_train_epochs", default=3, type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--learning_rate", default=1e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")

    parser.add_argument("--train_batch_size", default=64, type=int, required=False,
                        help="batch size for train and eval")
    parser.add_argument('--logging_steps', type=int, default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--log_path', default=None, type=str, required=False)

    args = parser.parse_args()
    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt = '%m/%d/%Y %H:%M:%S',
                        level = logging.INFO)
    set_seed(args)
    ## get train and dev data
    print('loading dataset...')
    processor = FAQProcessor()
    label_list = processor.get_labels()
    num_labels = len(label_list)
    config = BertConfig.from_pretrained('bert-base-chinese', cache_dir='./cache_down', num_labels=num_labels)
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese', cache_dir='./cache_down')

    train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
    eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)

    ## Build the model
    model = BertForSequenceClassification.from_pretrained("./cache_down/pytorch_model.bin", config=config)
    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(args.device)
    # print(model)

    ## Optimizer and learning-rate schedule
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    # total optimization steps = (examples // batch size // accumulation) * epochs
    t_total = len(train_dataset) // args.train_batch_size // args.gradient_accumulation_steps * args.num_train_epochs
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
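    # Illustrative: 10,000 training examples, batch size 64, no accumulation and
    # 3 epochs give t_total = 10000 // 64 // 1 * 3 = 468 optimization steps.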

    ## training 
    logger.info('*****Running training*******')
    logger.info(' Num examples = %d', len(train_dataset))
    logger.info(' Gradient Accumulation steps = %d', args.gradient_accumulation_steps)
    
    best_acc_f1 = 0
    if not os.path.exists(os.path.join(args.output_dir, args.log_path)):
        os.makedirs(os.path.join(args.output_dir, args.log_path))
    else:
        for file in os.listdir(os.path.join(args.output_dir, args.log_path)):
            os.remove(os.path.join(args.output_dir, args.log_path, file))
    train_loss_file = os.path.join(args.output_dir, args.log_path, 'train_loss_file.txt')
    train_acc_file = os.path.join(args.output_dir, args.log_path, 'train_acc_file.txt')
    eval_loss_file = os.path.join(args.output_dir, args.log_path, 'eval_loss_file.txt')
    for epoch in range(args.num_train_epochs):
        logger.info(' Num epochs = %d', epoch)
        train(args, train_dataset, model, optimizer, scheduler, args.device, tokenizer, train_loss_file, train_acc_file)
        results = evaluate(args, eval_dataset, model, args.device, tokenizer)
        with open(eval_loss_file, 'a+') as eval_writer:
            eval_writer.write('epoch:{}, lr: {}, eval_loss:{}, result: {}\n'.format(epoch, scheduler.get_lr()[0],results[0], results[1]))
        if results[1]['acc_and_f1'] > best_acc_f1:
            best_acc_f1 = results[1]['acc_and_f1']
            print('saving best model')
            model_to_save = model.module if hasattr(model, 'module') else model
            model_to_save.save_pretrained(os.path.join(args.output_dir, args.log_path))
            tokenizer.save_pretrained(os.path.join(args.output_dir, args.log_path))
            torch.save(args, os.path.join(args.output_dir, args.log_path, 'training_args_bert.bin'))
Example #12
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriterP(args.output_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    if args.lr_decay:
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=args.warmup_steps,
                                         t_total=t_total)
    else:
        scheduler = WarmupConstantSchedule(optimizer,
                                           warmup_steps=args.warmup_steps)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    try:
        with open(os.path.join(args.model_name_or_path, 'step.txt'), 'r') as c:
            global_step = int(c.readline())
    except OSError:
        global_step = 0

    tr_loss, logging_loss = 0.0, 0.0
    moving_loss = MovingLoss(10000)
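    # MovingLoss is defined elsewhere in this project; from its use here it
    # appears to keep a running average over the last 10000 loss values and to
    # expose it as .loss (used for the perplexity logging below).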
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(
        args)  # Added here for reproducibility (even between python 2 and 3)
    try:
        for _ in train_iterator:
            epoch_iterator = tqdm(train_dataloader,
                                  desc="Iteration",
                                  disable=args.local_rank not in [-1, 0])
            for step, batch in enumerate(epoch_iterator):
                inputs, labels = mask_tokens(
                    batch, tokenizer, args) if args.mlm else (batch, batch)
                inputs = inputs.to(args.device)
                labels = labels.to(args.device)
                model.train()
                outputs = model(
                    inputs, masked_lm_labels=labels) if args.mlm else model(
                        inputs, labels=labels)
                loss = outputs[
                    0]  # model outputs are always tuple in pytorch-transformers (see doc)

                if args.n_gpu > 1:
                    loss = loss.mean(
                    )  # mean() to average on multi-gpu parallel training
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                tr_loss += loss.item()
                moving_loss.add(loss.item())
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    model.zero_grad()
                    global_step += 1

                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training and global_step % args.eval_steps == 0:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer,
                                           f"step {global_step}")
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value,
                                                 global_step)

                    if args.local_rank in [
                            -1, 0
                    ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                        tb_writer.add_scalar('lr',
                                             scheduler.get_lr()[0],
                                             global_step)
                        tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                             args.logging_steps, global_step)
                        logging_loss = tr_loss
                        logger.info(
                            f"Moving loss {moving_loss.loss:.2f}, perplexity {torch.exp(torch.tensor(moving_loss.loss)):.2f}"
                        )

                    if args.local_rank in [
                            -1, 0
                    ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                        # Save model checkpoint
                        save_state(args, model, tokenizer, global_step)

                if args.max_steps > 0 and global_step > args.max_steps:
                    epoch_iterator.close()
                    break
            print_sample(model, tokenizer, args.device)
            if args.max_steps > 0 and global_step > args.max_steps:
                train_iterator.close()
                break
    except (KeyboardInterrupt, SystemExit):
        save_state(args, model, tokenizer, global_step)
        raise

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
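
A minimal sketch (not part of the original snippet) of the MovingLoss helper assumed above: it keeps a bounded window of recent batch losses and exposes their mean as `.loss`, which the logging code turns into a perplexity via exp().

from collections import deque


class MovingLoss:
    """Windowed average of recent losses; a hypothetical stand-in for the helper used above."""

    def __init__(self, window=10000):
        self.values = deque(maxlen=window)  # oldest entries are dropped automatically

    def add(self, value):
        self.values.append(value)

    @property
    def loss(self):
        # Mean of the retained losses; 0.0 before any value has been added.
        return sum(self.values) / len(self.values) if self.values else 0.0
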
Example #13
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter(join(args.output_dir, 'tensorboard'))

    model.resize_token_embeddings(len(tokenizer))
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = SequentialSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    if args.warmup_proportion:
        args.warmup_steps = int(t_total * args.warmup_proportion)

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)

    init_epoch = 0
    init_step, global_step = 0, 0
    tr_loss, logging_loss = 0.0, 0.0

    if args.resume_optimizer:
        if os.path.exists(os.path.join(args.model_name_or_path, 'optim_state.bin')):
            ckpt = torch.load(os.path.join(args.model_name_or_path, 'optim_state.bin'), map_location='cpu')
            optimizer.load_state_dict(ckpt['optimizer_state_dict'])
            scheduler.load_state_dict(ckpt['scheduler_state_dict'])
            init_epoch = ckpt['epoch']
            init_step, global_step = ckpt['step'], ckpt['global_step']
            tr_loss, logging_loss = ckpt['tr_loss'], ckpt['logging_loss']
            train_dataset = train_dataset[init_step:]
        else:
            logger.warning('optimizer state not found.')

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank,
                                                          find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                   args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    model.zero_grad()

    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0], initial=init_epoch)
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    
    for epoch_idx in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0], initial=init_step)
        for step, batch in enumerate(epoch_iterator):
            if step < 13312:
                # Hard-coded offset: skip batches already consumed before this resume point.
                pass
            else:
                inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
                inputs = inputs.to(args.device)
                labels = labels.to(args.device)
                model.train()
                outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
                loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

                if args.n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu parallel training
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                tr_loss += loss.item()
                if step % args.gradient_accumulation_steps == 0 and step != 0:
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    model.zero_grad()
                    global_step += 1

                    # Log metrics
                    if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                        tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                        tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
                        logging_loss = tr_loss
                    # Log evaluation

                    if args.evaluate_during_training and \
                            global_step % args.evaluation_steps == 0:
                        results = evaluate(args, model, tokenizer)
                        if args.local_rank in [-1, 0]:
                            for key, value in results.items():
                                tb_writer.add_scalar('eval_{}'.format(key), value, global_step)

                    if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                        # Save model checkpoint
                        output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
                        model_to_save.save_pretrained(output_dir)
                        torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                        logger.info("Saving model checkpoint to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
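
The resume block above loads an 'optim_state.bin' checkpoint; below is a minimal, hypothetical sketch (not part of the original snippet) of the matching save step, using the key names implied by the load code.

import os
import torch


def save_optim_state(output_dir, optimizer, scheduler, epoch, step, global_step,
                     tr_loss, logging_loss):
    # Persist optimizer/scheduler state plus progress counters so training can resume.
    torch.save({
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'epoch': epoch,
        'step': step,
        'global_step': global_step,
        'tr_loss': tr_loss,
        'logging_loss': logging_loss,
    }, os.path.join(output_dir, 'optim_state.bin'))
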
def train(args, dataloaders, model, tokenizer, num_batches, observed_pairs,
          heldout_pairs, partition, lang_nns):
    """ Train the model """
    _, ex_counts = zip(*[x for x in sorted(num_batches.items())])
    if args.weight_by_size:
        sample_probs = ex_counts / np.linalg.norm(ex_counts, ord=1)
        ex_avg = sum(ex_counts) / len(ex_counts)
        lr_weights = {
            key: ex_avg / value
            for key, value in sorted(num_batches.items())
        }
    else:
        sample_probs = None
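    # With weight_by_size, (task, language) pairs are drawn with probability proportional to
    # their batch counts, and lr_weights (= average count / pair count) later scales the
    # learning rate per draw, boosting rarely-sampled pairs and damping frequent ones.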

    num_batches = sum(ex_counts)

    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            num_batches // args.gradient_accumulation_steps) + 1
    else:
        t_total = num_batches // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight', 'mean',
                'logvar']  # No decay in Gaussian weights
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=int(args.warmup_proportion *
                                                      t_total),
                                     t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", num_batches * args.train_batch_size)
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)
    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    tr_kl, logging_kl = 0.0, 0.0
    best_valid = 0.0
    patience = 0
    model.zero_grad()
    set_seed(
        args)  # Added here for reproducibility (even between python 2 and 3)
    for epoch in range(int(args.num_train_epochs)):
        step = 0
        while step < num_batches:
            if not args.largest_source:
                which = np.random.choice(len(observed_pairs), p=sample_probs)
                task, language = observed_pairs[which]
            else:
                task = random.choice(args.tasks)
                language = 'en_{}'.format(partition) if task == "ner" else 'en'
            bank = random.choice([
                k for k in dataloaders[task].keys() if k.startswith(language)
            ])
            batch = next(dataloaders[task][bank]['train'])
            model.train()
            batch = tuple(t.to(args.device) for t in batch)

            outputs = model(batch, task,
                            language if not args.largest_source else 'en')
            loss, _, kl_term = outputs

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel (not distributed) training
                kl_term = kl_term.mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
                kl_term = kl_term / args.gradient_accumulation_steps

            tr_loss += loss.item()

            if args.mode in ['lrcmeta', 'svimeta']:
                if args.scaling == "uniform":
                    scaling = 1. / t_total
                elif args.scaling == "linear_annealing":
                    scaling = ((t_total - step - 1) * 2. + 1.) / t_total**2
                elif args.scaling == "logistic_annealing":
                    steepness = 0.0025
                    scaling = 1. / (1 + np.exp(-steepness *
                                               (step - t_total / 2.)))
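                # KL weight schedules: 'uniform' applies a constant 1/t_total weight,
                # 'linear_annealing' starts near 2/t_total and decays linearly toward 0,
                # and 'logistic_annealing' ramps the weight from ~0 to ~1 around step t_total/2.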
                loss = loss + scaling * kl_term
                tr_kl += kl_term.item()

            if args.local_rank in [
                    -1, 0
            ] and global_step % 1000 == 0 and global_step:
                logger.info("Epoch {} seen examples {} log-lik {}".format(
                    epoch, step * args.train_batch_size,
                    (tr_loss - logging_loss) / 1000.))
                tb_writer.add_scalar('lr_{}'.format(partition),
                                     scheduler.get_lr()[0], global_step)
                tb_writer.add_scalar('log-lik_{}'.format(partition),
                                     (tr_loss - logging_loss) / 1000.,
                                     global_step)
                if args.mode in ['lrcmeta', 'svimeta']:
                    logger.info("Epoch {} seen examples {} kl {}".format(
                        epoch, step * args.train_batch_size,
                        (tr_kl - logging_kl) / 1000.))
                    tb_writer.add_scalar('kl_{}'.format(partition),
                                         (tr_kl - logging_kl) / 1000.,
                                         global_step)
                logging_loss = tr_loss
                logging_kl = tr_kl

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            if args.debug:
                plot_grad_flow(model.named_parameters())

            if args.weight_by_size:
                lr2s = []
                for param_group in optimizer.param_groups:
                    lr2 = param_group['lr']
                    param_group['lr'] = lr2 * lr_weights[(task, language)]
                    lr2s.append(lr2)

            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args,
                                           dataloaders,
                                           model,
                                           tokenizer,
                                           'dev',
                                           heldout_pairs,
                                           lang_nns,
                                           partition,
                                           prefix=partition)
                        for task, value1 in results.items():
                            for language, value2 in value1.items():
                                for metric, value3 in value2.items():
                                    tb_writer.add_scalar(
                                        'eval_{}'.format("-".join(
                                            [task, language, metric])), value3,
                                        global_step)
                        overall_valid = np.mean([
                            results[task]['all']['f1'] for task in args.tasks
                        ])
                        if overall_valid > best_valid:
                            logger.info(
                                "New best validation! Average f1 {}".format(
                                    overall_valid))
                            checkpoint(model,
                                       args,
                                       affix='-best-{}'.format(partition),
                                       results=results)
                            best_valid = overall_valid
                            patience = 0
                        else:
                            patience += 1

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    checkpoint(model,
                               args,
                               affix='-latest-{}'.format(partition),
                               results=None)

                if args.weight_by_size:
                    for param_group, lr2 in zip(optimizer.param_groups, lr2s):
                        param_group['lr'] = lr2

            if (args.max_steps > 0 and global_step > args.max_steps
                ) or patience > args.max_patience:
                break
            step += 1
        if (args.max_steps > 0 and
                global_step > args.max_steps) or patience > args.max_patience:
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
Example #15
class Distiller:
    def __init__(self, params: dict, dataset: LmSeqsDataset,
                 token_probs: torch.tensor, student: nn.Module,
                 teacher: nn.Module):
        logger.info('Initializing Distiller')
        self.params = params
        self.dump_path = params.dump_path
        self.multi_gpu = params.multi_gpu
        self.fp16 = params.fp16

        self.student = student
        self.teacher = teacher

        self.student_config = student.config
        self.vocab_size = student.config.vocab_size

        if params.n_gpu <= 1:
            sampler = RandomSampler(dataset)
        else:
            sampler = DistributedSampler(dataset)

        if params.group_by_size:
            groups = create_lengths_groups(lengths=dataset.lengths,
                                           k=params.max_model_input_size)
            sampler = GroupedBatchSampler(sampler=sampler,
                                          group_ids=groups,
                                          batch_size=params.batch_size)
        else:
            sampler = BatchSampler(sampler=sampler,
                                   batch_size=params.batch_size,
                                   drop_last=False)

        self.dataloader = DataLoader(dataset=dataset,
                                     batch_sampler=sampler,
                                     collate_fn=dataset.batch_sequences)

        self.temperature = params.temperature
        assert self.temperature > 0.

        self.alpha_ce = params.alpha_ce
        self.alpha_mlm = params.alpha_mlm
        self.alpha_clm = params.alpha_clm
        self.alpha_mse = params.alpha_mse
        self.alpha_cos = params.alpha_cos

        self.mlm = params.mlm
        if self.mlm:
            logger.info(f'Using MLM loss for LM step.')
            self.mlm_mask_prop = params.mlm_mask_prop
            assert 0.0 <= self.mlm_mask_prop <= 1.0
            assert params.word_mask + params.word_keep + params.word_rand == 1.0
            self.pred_probs = torch.FloatTensor(
                [params.word_mask, params.word_keep, params.word_rand])
            self.pred_probs = self.pred_probs.to(
                f'cuda:{params.local_rank}'
            ) if params.n_gpu > 0 else self.pred_probs
            self.token_probs = token_probs.to(
                f'cuda:{params.local_rank}'
            ) if params.n_gpu > 0 else token_probs
            if self.fp16:
                self.pred_probs = self.pred_probs.half()
                self.token_probs = self.token_probs.half()
        else:
            logger.info(f'Using CLM loss for LM step.')

        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_ce = 0
        self.last_loss_mlm = 0
        self.last_loss_clm = 0
        if self.alpha_mse > 0.: self.last_loss_mse = 0
        if self.alpha_cos > 0.: self.last_loss_cos = 0
        self.last_log = 0

        self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
        self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
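        # ignore_index=-1 matches the -1 labels produced for non-predicted positions
        # in prepare_batch_mlm / prepare_batch_clm below.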
        if self.alpha_mse > 0.:
            self.mse_loss_fct = nn.MSELoss(reduction='sum')
        if self.alpha_cos > 0.:
            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction='mean')

        logger.info('--- Initializing model optimizer')
        assert params.gradient_accumulation_steps >= 1
        self.num_steps_epoch = len(self.dataloader)
        num_train_optimization_steps = int(
            self.num_steps_epoch / params.gradient_accumulation_steps *
            params.n_epoch) + 1

        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in student.named_parameters()
                if not any(nd in n for nd in no_decay) and p.requires_grad
            ],
            'weight_decay':
            params.weight_decay
        }, {
            'params': [
                p for n, p in student.named_parameters()
                if any(nd in n for nd in no_decay) and p.requires_grad
            ],
            'weight_decay':
            0.0
        }]
        logger.info(
            "------ Number of trainable parameters (student): %i" % sum([
                p.numel() for p in self.student.parameters() if p.requires_grad
            ]))
        logger.info("------ Number of parameters (student): %i" %
                    sum([p.numel() for p in self.student.parameters()]))
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=params.learning_rate,
                               eps=params.adam_epsilon,
                               betas=(0.9, 0.98))

        warmup_steps = math.ceil(num_train_optimization_steps *
                                 params.warmup_prop)
        self.scheduler = WarmupLinearSchedule(
            self.optimizer,
            warmup_steps=warmup_steps,
            t_total=num_train_optimization_steps)

        if self.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            logger.info(
                f"Using fp16 training: {self.params.fp16_opt_level} level")
            self.student, self.optimizer = amp.initialize(
                self.student,
                self.optimizer,
                opt_level=self.params.fp16_opt_level)
            self.teacher = self.teacher.half()

        if self.multi_gpu:
            if self.fp16:
                from apex.parallel import DistributedDataParallel
                logger.info(
                    "Using apex.parallel.DistributedDataParallel for distributed training."
                )
                self.student = DistributedDataParallel(self.student)
            else:
                from torch.nn.parallel import DistributedDataParallel
                logger.info(
                    "Using nn.parallel.DistributedDataParallel for distributed training."
                )
                self.student = DistributedDataParallel(
                    self.student,
                    device_ids=[params.local_rank],
                    output_device=params.local_rank,
                    find_unused_parameters=True)

        self.is_master = params.is_master
        if self.is_master:
            logger.info('--- Initializing Tensorboard')
            self.tensorboard = SummaryWriter(
                log_dir=os.path.join(self.dump_path, 'log', 'train'))
            self.tensorboard.add_text(tag='config/training',
                                      text_string=str(self.params),
                                      global_step=0)
            self.tensorboard.add_text(tag='config/student',
                                      text_string=str(self.student_config),
                                      global_step=0)

    def prepare_batch_mlm(self, batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked labels for MLM.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.tensor(bs, seq_length)` - The padded token ids for each sequence in the batch.
                lengths: `torch.tensor(bs)` - The length of each sequence in the batch.

        Output:
        -------
            token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
            attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
            mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. There is a -1 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = (torch.arange(token_ids.size(1),
                                  dtype=torch.long,
                                  device=lengths.device) < lengths[:, None])

        bs, max_seq_len = token_ids.size()
        mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)

        x_prob = self.token_probs[token_ids.flatten()]
        n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
        tgt_ids = torch.multinomial(x_prob / x_prob.sum(),
                                    n_tgt,
                                    replacement=False)
        pred_mask = torch.zeros(
            bs * max_seq_len, dtype=torch.bool, device=token_ids.device
        )  # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
        pred_mask[tgt_ids] = 1
        pred_mask = pred_mask.view(bs, max_seq_len)

        pred_mask[token_ids == self.params.special_tok_ids['pad_token']] = 0

        # with fp16, keep the number of masked positions a multiple of 8 (faster on tensor cores)
        if self.fp16:
            n1 = pred_mask.sum().item()
            if n1 > 8:
                pred_mask = pred_mask.view(-1)
                n2 = max(n1 % 8, 8 * (n1 // 8))
                if n2 != n1:
                    pred_mask[torch.nonzero(pred_mask).view(-1)[:n1 - n2]] = 0
                pred_mask = pred_mask.view(bs, max_seq_len)
                assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()

        _token_ids_real = token_ids[pred_mask]
        _token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
        _token_ids_mask = _token_ids_real.clone().fill_(
            self.params.special_tok_ids['mask_token'])
        probs = torch.multinomial(self.pred_probs,
                                  len(_token_ids_real),
                                  replacement=True)
        _token_ids = _token_ids_mask * (
            probs == 0).long() + _token_ids_real * (
                probs == 1).long() + _token_ids_rand * (probs == 2).long()
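        # probs == 0 -> replace with the mask token, 1 -> keep the real token,
        # 2 -> use a random token; proportions come from params.word_mask /
        # word_keep / word_rand (the BERT-style corruption split).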
        token_ids = token_ids.masked_scatter(pred_mask, _token_ids)

        mlm_labels[
            ~pred_mask] = -1  # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility

        # sanity checks
        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size

        return token_ids, attn_mask, mlm_labels

    def prepare_batch_clm(self, batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the labels for CLM.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.tensor(bs, seq_length)` - The padded token ids for each sequence in the batch.
                lengths: `torch.tensor(bs)` - The length of each sequence in the batch.

        Output:
        -------
            token_ids: `torch.tensor(bs, seq_length)` - The token ids (left unchanged for CLM).
            attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
            clm_labels: `torch.tensor(bs, seq_length)` - The causal language modeling labels. There is a -1 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = (torch.arange(token_ids.size(1),
                                  dtype=torch.long,
                                  device=lengths.device) < lengths[:, None])
        clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
        clm_labels[
            ~attn_mask] = -1  # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility

        # sanity checks
        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size

        return token_ids, attn_mask, clm_labels

    def round_batch(self, x: torch.tensor, lengths: torch.tensor):
        """
        For float16 only.
        Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.

        Input:
        ------
            x: `torch.tensor(bs, seq_length)` - The token ids.
            lengths: `torch.tensor(bs)` - The length of each sequence in the batch.

        Output:
        -------
            x:  `torch.tensor(new_bs, new_seq_length)` - The updated token ids.
            lengths: `torch.tensor(new_bs)` - The updated lengths.
        """
        if not self.fp16 or len(lengths) < 8:
            return x, lengths

        # sub-sample sentences so the batch size is a multiple of 8
        bs1 = len(lengths)
        bs2 = 8 * (bs1 // 8)
        assert bs2 > 0 and bs2 % 8 == 0
        if bs1 != bs2:
            idx = torch.randperm(bs1)[:bs2]
            lengths = lengths[idx]
            slen = lengths.max().item()
            x = x[idx, :slen]
        else:
            idx = None

        # pad the sequence length up to a multiple of 8
        ml1 = x.size(1)
        if ml1 % 8 != 0:
            pad = 8 - (ml1 % 8)
            ml2 = ml1 + pad
            if self.mlm:
                pad_id = self.params.special_tok_ids['pad_token']
            else:
                pad_id = self.params.special_tok_ids['unk_token']
            padding_tensor = torch.zeros(bs2,
                                         pad,
                                         dtype=torch.long,
                                         device=x.device).fill_(pad_id)
            x = torch.cat([x, padding_tensor], 1)
            assert x.size() == (bs2, ml2)

        assert x.size(0) % 8 == 0
        assert x.size(1) % 8 == 0
        return x, lengths

    def train(self):
        """
        The real training loop.
        """
        if self.is_master: logger.info('Starting training')
        self.last_log = time.time()
        self.student.train()
        self.teacher.eval()

        for _ in range(self.params.n_epoch):
            if self.is_master:
                logger.info(
                    f'--- Starting epoch {self.epoch}/{self.params.n_epoch-1}')
            if self.multi_gpu:
                torch.distributed.barrier()

            iter_bar = tqdm(self.dataloader,
                            desc="-Iter",
                            disable=self.params.local_rank not in [-1, 0])
            for batch in iter_bar:
                if self.params.n_gpu > 0:
                    batch = tuple(
                        t.to(f'cuda:{self.params.local_rank}') for t in batch)

                if self.mlm:
                    token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(
                        batch=batch)
                else:
                    token_ids, attn_mask, lm_labels = self.prepare_batch_clm(
                        batch=batch)
                self.step(input_ids=token_ids,
                          attention_mask=attn_mask,
                          lm_labels=lm_labels)

                iter_bar.update()
                iter_bar.set_postfix({
                    'Last_loss':
                    f'{self.last_loss:.2f}',
                    'Avg_cum_loss':
                    f'{self.total_loss_epoch/self.n_iter:.2f}'
                })
            iter_bar.close()

            if self.is_master:
                logger.info(
                    f'--- Ending epoch {self.epoch}/{self.params.n_epoch-1}')
            self.end_epoch()

        if self.is_master:
            logger.info(f'Save very last checkpoint as `pytorch_model.bin`.')
            self.save_checkpoint(checkpoint_name=f'pytorch_model.bin')
            logger.info('Training is finished')

    def step(self, input_ids: torch.tensor, attention_mask: torch.tensor,
             lm_labels: torch.tensor):
        """
        One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
        and possibly a parameter update (depending on the gradient accumulation).

        Input:
        ------
        input_ids: `torch.tensor(bs, seq_length)` - The token ids.
        attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
        lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
        """
        if self.mlm:
            s_logits, s_hidden_states = self.student(
                input_ids=input_ids,
                attention_mask=attention_mask)  # (bs, seq_length, voc_size)
            with torch.no_grad():
                t_logits, t_hidden_states = self.teacher(
                    input_ids=input_ids, attention_mask=attention_mask
                )  # (bs, seq_length, voc_size)
        else:
            s_logits, _, s_hidden_states = self.student(
                input_ids=input_ids,
                attention_mask=None)  # (bs, seq_length, voc_size)
            with torch.no_grad():
                t_logits, _, t_hidden_states = self.teacher(
                    input_ids=input_ids,
                    attention_mask=None)  # (bs, seq_length, voc_size)
        assert s_logits.size() == t_logits.size()

        #https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
        #https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
        if self.params.restrict_ce_to_mask:
            mask = (lm_labels > -1).unsqueeze(-1).expand_as(
                s_logits)  # (bs, seq_length, voc_size)
        else:
            mask = attention_mask.unsqueeze(-1).expand_as(
                s_logits)  # (bs, seq_length, voc_size)
        s_logits_slct = torch.masked_select(
            s_logits,
            mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        s_logits_slct = s_logits_slct.view(-1, s_logits.size(
            -1))  # (bs * seq_length, voc_size) modulo the 1s in mask
        t_logits_slct = torch.masked_select(
            t_logits,
            mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        t_logits_slct = t_logits_slct.view(-1, s_logits.size(
            -1))  # (bs * seq_length, voc_size) modulo the 1s in mask
        assert t_logits_slct.size() == s_logits_slct.size()

        loss_ce = self.ce_loss_fct(
            F.log_softmax(s_logits_slct / self.temperature, dim=-1),
            F.softmax(t_logits_slct / self.temperature,
                      dim=-1)) * (self.temperature)**2
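        # Soft-target distillation loss: KL divergence between temperature-softened student
        # and teacher distributions, scaled by T^2 so gradient magnitudes stay comparable
        # across temperatures (Hinton et al., 2015).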
        loss = self.alpha_ce * loss_ce

        if self.alpha_mlm > 0.:
            loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)),
                                        lm_labels.view(-1))
            loss += self.alpha_mlm * loss_mlm
        if self.alpha_clm > 0.:
            shift_logits = s_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_clm = self.lm_loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1))
            loss += self.alpha_clm * loss_clm

        if self.alpha_mse > 0.:
            loss_mse = self.mse_loss_fct(
                s_logits_slct, t_logits_slct) / s_logits_slct.size(
                    0)  # Reproducing batchmean reduction
            loss += self.alpha_mse * loss_mse
        if self.alpha_cos > 0.:
            s_hidden_states = s_hidden_states[-1]  # (bs, seq_length, dim)
            t_hidden_states = t_hidden_states[-1]  # (bs, seq_length, dim)
            mask = attention_mask.unsqueeze(-1).expand_as(
                s_hidden_states)  # (bs, seq_length, dim)
            assert s_hidden_states.size() == t_hidden_states.size()
            dim = s_hidden_states.size(-1)

            s_hidden_states_slct = torch.masked_select(
                s_hidden_states, mask)  # (bs * seq_length * dim)
            s_hidden_states_slct = s_hidden_states_slct.view(
                -1, dim)  # (bs * seq_length, dim)
            t_hidden_states_slct = torch.masked_select(
                t_hidden_states, mask)  # (bs * seq_length * dim)
            t_hidden_states_slct = t_hidden_states_slct.view(
                -1, dim)  # (bs * seq_length, dim)

            target = s_hidden_states_slct.new(
                s_hidden_states_slct.size(0)).fill_(1)  # (bs * seq_length,)
            loss_cos = self.cosine_loss_fct(s_hidden_states_slct,
                                            t_hidden_states_slct, target)
            loss += self.alpha_cos * loss_cos

        self.total_loss_epoch += loss.item()
        self.last_loss = loss.item()
        self.last_loss_ce = loss_ce.item()
        if self.alpha_mlm > 0.:
            self.last_loss_mlm = loss_mlm.item()
        if self.alpha_clm > 0.:
            self.last_loss_clm = loss_clm.item()
        if self.alpha_mse > 0.:
            self.last_loss_mse = loss_mse.item()
        if self.alpha_cos > 0.:
            self.last_loss_cos = loss_cos.item()

        self.optimize(loss)

        self.n_sequences_epoch += input_ids.size(0)

    def optimize(self, loss):
        """
        Normalization on the loss (gradient accumulation or distributed training), followed by
        backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
        Also update the metrics for tensorboard.
        """
        # Check for NaN
        if (loss != loss).data.any():
            logger.error('NaN detected')
            exit()

        if self.multi_gpu:
            loss = loss.mean()
        if self.params.gradient_accumulation_steps > 1:
            loss = loss / self.params.gradient_accumulation_steps

        if self.fp16:
            from apex import amp
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.iter()
        if self.n_iter % self.params.gradient_accumulation_steps == 0:
            if self.fp16:
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(self.optimizer),
                    self.params.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(self.student.parameters(),
                                               self.params.max_grad_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()

    def iter(self):
        """
        Update global counts, write to tensorboard and save checkpoint.
        """
        self.n_iter += 1
        self.n_total_iter += 1

        if self.n_total_iter % self.params.log_interval == 0:
            self.log_tensorboard()
            self.last_log = time.time()
        if self.n_total_iter % self.params.checkpoint_interval == 0:
            self.save_checkpoint()

    def log_tensorboard(self):
        """
        Log into tensorboard. Only by the master process.
        """
        if not self.is_master:
            return

        for param_name, param in self.student.named_parameters():
            self.tensorboard.add_scalar(tag='parameter_mean/' + param_name,
                                        scalar_value=param.data.mean(),
                                        global_step=self.n_total_iter)
            self.tensorboard.add_scalar(tag='parameter_std/' + param_name,
                                        scalar_value=param.data.std(),
                                        global_step=self.n_total_iter)
            if param.grad is None:
                continue
            self.tensorboard.add_scalar(tag="grad_mean/" + param_name,
                                        scalar_value=param.grad.data.mean(),
                                        global_step=self.n_total_iter)
            self.tensorboard.add_scalar(tag="grad_std/" + param_name,
                                        scalar_value=param.grad.data.std(),
                                        global_step=self.n_total_iter)

        self.tensorboard.add_scalar(tag="losses/cum_avg_loss_epoch",
                                    scalar_value=self.total_loss_epoch /
                                    self.n_iter,
                                    global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="losses/loss",
                                    scalar_value=self.last_loss,
                                    global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="losses/loss_ce",
                                    scalar_value=self.last_loss_ce,
                                    global_step=self.n_total_iter)
        if self.alpha_mlm > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mlm",
                                        scalar_value=self.last_loss_mlm,
                                        global_step=self.n_total_iter)
        if self.alpha_clm > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_clm",
                                        scalar_value=self.last_loss_clm,
                                        global_step=self.n_total_iter)
        if self.alpha_mse > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mse",
                                        scalar_value=self.last_loss_mse,
                                        global_step=self.n_total_iter)
        if self.alpha_cos > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_cos",
                                        scalar_value=self.last_loss_cos,
                                        global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="learning_rate/lr",
                                    scalar_value=self.scheduler.get_lr()[0],
                                    global_step=self.n_total_iter)

        self.tensorboard.add_scalar(
            tag="global/memory_usage",
            scalar_value=psutil.virtual_memory()._asdict()['used'] / 1_000_000,
            global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="global/speed",
                                    scalar_value=time.time() - self.last_log,
                                    global_step=self.n_total_iter)

    def end_epoch(self):
        """
        Finally arrived at the end of epoch (full pass on dataset).
        Do some tensorboard logging and checkpoint saving.
        """
        logger.info(
            f'{self.n_sequences_epoch} sequences have been trained during this epoch.'
        )

        if self.is_master:
            self.save_checkpoint(
                checkpoint_name=f'model_epoch_{self.epoch}.pth')
            self.tensorboard.add_scalar(tag='epoch/loss',
                                        scalar_value=self.total_loss_epoch /
                                        self.n_iter,
                                        global_step=self.epoch)

        self.epoch += 1
        self.n_sequences_epoch = 0
        self.n_iter = 0
        self.total_loss_epoch = 0

    def save_checkpoint(self, checkpoint_name: str = 'checkpoint.pth'):
        """
        Save the current state. Only by the master process.
        """
        if not self.is_master:
            return
        mdl_to_save = self.student.module if hasattr(
            self.student, 'module') else self.student
        mdl_to_save.config.save_pretrained(self.dump_path)
        state_dict = mdl_to_save.state_dict()
        torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps)
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    # File for writing evaluation results
    output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            ascii=True)
    set_seed(args)  # Set the random seed for reproducibility
    for epoch in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", ascii=True)
        logger.info(f"Epoch {epoch + 1}\n")
        with open(output_eval_file, "a") as writer:
            writer.write(f"Epoch {epoch + 1}\n")
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                'labels': batch[3]
            }
            if args.model_type != 'distilbert':
                inputs['token_type_ids'] = batch[2] if args.model_type in [
                    'bert', 'xlnet'
                ] else None  # XLM, DistilBERT and RoBERTa don't use segment_ids
            outputs = model(**inputs)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        with open(output_eval_file, "a") as writer:
                            train_loss = (tr_loss -
                                          logging_loss) / args.logging_steps
                            logger.info(
                                f"\tglobal step:{global_step} - loss: {train_loss:.4f} - val_loss: {results['loss']:.4f} - val_acc: {results['acc']:.4f} - val_f1: {results['f1']:.4f}\n"
                            )
                            writer.write(
                                f"\tglobal step:{global_step} - loss: {train_loss:.4f} - val_loss: {results['loss']:.4f} - val_acc: {results['acc']:.4f} - val_f1: {results['f1']:.4f}\n"
                            )
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value,
                                                 global_step)
                    tb_writer.add_scalar('lr',
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    # Take care of distributed/parallel training
                    model_to_save = model.module if hasattr(
                        model, 'module') else model
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    tb_writer.close()

    return global_step, tr_loss / global_step
Example #17
    def fit(
        self,
        token_ids,
        input_mask,
        labels,
        val_token_ids,
        val_input_mask,
        val_labels,
        token_type_ids=None,
        val_token_type_ids=None,
        verbose=True,
        logging_steps=0,
        save_steps=0,
        val_steps=0,
    ):
        """Fine-tunes the XLNet classifier using the given training data.

        Args:
            token_ids (list): List of training token id lists.
            input_mask (list): List of input mask lists.
            labels (list): List of training labels.
            val_token_ids (list): List of validation token id lists.
            val_input_mask (list): List of validation input mask lists.
            val_labels (list): List of validation labels.
            token_type_ids (list, optional): List of lists. Each sublist
                contains segment ids indicating if the token belongs to
                the first sentence(0) or second sentence(1). Only needed
                for two-sentence tasks.
            val_token_type_ids (list, optional): Same as token_type_ids, for
                the validation set.
            verbose (bool, optional): If True, shows the training progress and
                loss values. Defaults to True.
            logging_steps (int, optional): Log training metrics every this many
                steps (0 disables logging).
            save_steps (int, optional): Save a checkpoint every this many steps
                (0 disables checkpointing).
            val_steps (int, optional): Run validation every this many steps
                (0 disables periodic validation).
        """

        device, num_gpus = get_device(self.num_gpus)
        self.model = move_to_device(self.model, device, self.num_gpus)

        token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
        input_mask_tensor = torch.tensor(input_mask, dtype=torch.long)
        labels_tensor = torch.tensor(labels, dtype=torch.long)

        val_token_ids_tensor = torch.tensor(val_token_ids, dtype=torch.long)
        val_input_mask_tensor = torch.tensor(val_input_mask, dtype=torch.long)
        val_labels_tensor = torch.tensor(val_labels, dtype=torch.long)

        if token_type_ids:
            token_type_ids_tensor = torch.tensor(token_type_ids, dtype=torch.long)
            val_token_type_ids_tensor = torch.tensor(val_token_type_ids, dtype=torch.long)

            train_dataset = TensorDataset(
                token_ids_tensor, input_mask_tensor, token_type_ids_tensor, labels_tensor
            )

            val_dataset = TensorDataset(
                val_token_ids_tensor,
                val_input_mask_tensor,
                val_token_type_ids_tensor,
                val_labels_tensor,
            )

        else:

            train_dataset = TensorDataset(token_ids_tensor, input_mask_tensor, labels_tensor)

            val_dataset = TensorDataset(
                val_token_ids_tensor, val_input_mask_tensor, val_labels_tensor
            )

        # define optimizer and model parameters
        param_optimizer = list(self.model.named_parameters())
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                "weight_decay": self.weight_decay,
            },
            {
                "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
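        # The first group gets the configured weight decay; bias and LayerNorm
        # weights go in the second group with weight_decay=0.0, following the
        # standard transformer fine-tuning recipe.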

        val_sampler = RandomSampler(val_dataset)

        val_dataloader = DataLoader(val_dataset, sampler=val_sampler, batch_size=self.batch_size)

        num_examples = len(token_ids)
        num_batches = int(np.ceil(num_examples / self.batch_size))
        num_train_optimization_steps = num_batches * self.num_epochs

        optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr, eps=self.adam_eps)
        scheduler = WarmupLinearSchedule(
            optimizer, warmup_steps=self.warmup_steps, t_total=num_train_optimization_steps
        )
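        # WarmupLinearSchedule raises the learning rate linearly from 0 to the
        # base lr over `warmup_steps`, then decays it linearly to 0 at `t_total`.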

        global_step = 0
        self.model.train()
        optimizer.zero_grad()
        for epoch in range(self.num_epochs):

            train_sampler = RandomSampler(train_dataset)

            train_dataloader = DataLoader(
                train_dataset, sampler=train_sampler, batch_size=self.batch_size
            )

            tr_loss = 0.0
            logging_loss = 0.0
            val_loss = 0.0

            for i, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                if token_type_ids:
                    x_batch, mask_batch, token_type_ids_batch, y_batch = tuple(
                        t.to(device) for t in batch
                    )
                else:
                    token_type_ids_batch = None
                    x_batch, mask_batch, y_batch = tuple(t.to(device) for t in batch)

                outputs = self.model(
                    input_ids=x_batch,
                    token_type_ids=token_type_ids_batch,
                    attention_mask=mask_batch,
                    labels=y_batch,
                )

                loss = outputs[0]  # model outputs are always tuple in pytorch-transformers

                loss.sum().backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

                tr_loss += loss.sum().item()
                optimizer.step()
                # Update learning rate schedule
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1
                # logging of learning rate and loss
                if logging_steps > 0 and global_step % logging_steps == 0:
                    mlflow.log_metric("learning rate", scheduler.get_lr()[0], step=global_step)
                    mlflow.log_metric(
                        "training loss",
                        (tr_loss - logging_loss) / (logging_steps * self.batch_size),
                        step=global_step,
                    )
                    logging_loss = tr_loss
                # model checkpointing
                if save_steps > 0 and global_step % save_steps == 0:
                    checkpoint_dir = os.path.join(os.getcwd(), "checkpoints")
                    if not os.path.isdir(checkpoint_dir):
                        os.makedirs(checkpoint_dir)
                    checkpoint_path = os.path.join(checkpoint_dir, str(global_step) + ".pth")
                    torch.save(self.model.state_dict(), checkpoint_path)
                    mlflow.log_artifact(checkpoint_path)
                # model validation
                if val_steps > 0 and global_step % val_steps == 0:
                    # run model on validation set
                    self.model.eval()
                    val_loss = 0.0
                    for j, val_batch in enumerate(val_dataloader):
                        if token_type_ids:
                            val_x_batch, val_mask_batch, val_token_type_ids_batch, val_y_batch = tuple(
                                t.to(device) for t in val_batch
                            )
                        else:
                            val_token_type_ids_batch = None
                            val_x_batch, val_mask_batch, val_y_batch = tuple(
                                t.to(device) for t in val_batch
                            )
                        val_outputs = self.model(
                            input_ids=val_x_batch,
                            token_type_ids=val_token_type_ids_batch,
                            attention_mask=val_mask_batch,
                            labels=val_y_batch,
                        )
                        vloss = val_outputs[0]
                        val_loss += vloss.sum().item()
                    mlflow.log_metric(
                        "validation loss", val_loss / len(val_dataset), step=global_step
                    )
                    self.model.train()

                if verbose:
                    if i % ((num_batches // 10) + 1) == 0:
                        if val_loss > 0:
                            print(
                                "epoch:{}/{}; batch:{}->{}/{}; "
                                "average training loss:{:.6f}; "
                                "average val loss:{:.6f}".format(
                                    epoch + 1,
                                    self.num_epochs,
                                    i + 1,
                                    min(i + 1 + num_batches // 10, num_batches),
                                    num_batches,
                                    tr_loss / (i + 1),
                                    val_loss / (j + 1),
                                )
                            )
                        else:
                            print(
                                "epoch:{}/{}; batch:{}->{}/{}; average train loss:{:.6f}".format(
                                    epoch + 1,
                                    self.num_epochs,
                                    i + 1,
                                    min(i + 1 + num_batches // 10, num_batches),
                                    num_batches,
                                    tr_loss / (i + 1),
                                )
                            )
        checkpoint_dir = os.path.join(os.getcwd(), "checkpoints")
        if not os.path.isdir(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        checkpoint_path = os.path.join(checkpoint_dir, "final.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        mlflow.log_artifact(checkpoint_path)
        # empty cache
        del [x_batch, y_batch, mask_batch, token_type_ids_batch]
        if val_steps > 0:
            del [val_x_batch, val_y_batch, val_mask_batch, val_token_type_ids_batch]
        torch.cuda.empty_cache()
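
# A minimal usage sketch for the fit() method above. The wrapper class name
# `XLNetTextClassifier` and the preprocessed input variables are hypothetical
# placeholders, not part of the original example:
#
#   clf = XLNetTextClassifier(num_labels=2, num_epochs=2, batch_size=16)
#   clf.fit(
#       token_ids=train_token_ids,
#       input_mask=train_input_mask,
#       labels=train_labels,
#       val_token_ids=val_token_ids,
#       val_input_mask=val_input_mask,
#       val_labels=val_labels,
#       logging_steps=50,
#       save_steps=500,
#       val_steps=500,
#   )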
Exemple #18
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='openai-gpt',
                        help='pretrained model name')
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument('--train_dataset', type=str, default='')
    parser.add_argument('--eval_dataset', type=str, default='')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--num_train_epochs', type=int, default=3)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--eval_batch_size', type=int, default=16)
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument('--max_grad_norm', type=float, default=1.0)
    parser.add_argument("--max_steps", default=-1, type=int,
                        help="If > 0: set total number of training \
                        steps to perform. Override num_train_epochs.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before\
                        performing a backward/update pass.")
    parser.add_argument('--learning_rate', type=float, default=6.25e-5)
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--lm_coef', type=float, default=0.9)
    parser.add_argument('--n_valid', type=int, default=374)

    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    args = parser.parse_args()
    print(args)

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("device: {}, n_gpu {}".format(device, n_gpu))

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Load tokenizer and model
    # This loading functions also add new tokens and embeddings called `special tokens`
    # These new embeddings will be fine-tuned on the RocStories dataset
    special_tokens = ['_start_', '_delimiter_', '_classify_']
    tokenizer = OpenAIGPTTokenizer.from_pretrained(args.model_name)
    tokenizer.add_tokens(special_tokens)
    special_tokens_ids = tokenizer.convert_tokens_to_ids(special_tokens)
    model = OpenAIGPTDoubleHeadsModel.from_pretrained(args.model_name)
    model.resize_token_embeddings(len(tokenizer))
    model.to(device)

    # Load and encode the datasets
    if not args.train_dataset and not args.eval_dataset:
        roc_stories = cached_path(ROCSTORIES_URL)
    def tokenize_and_encode(obj):
        """ Tokenize and encode a nested object """
        if isinstance(obj, str):
            return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
        elif isinstance(obj, int):
            return obj
        return list(tokenize_and_encode(o) for o in obj)
    logger.info("Encoding dataset...")
    train_dataset = load_rocstories_dataset(args.train_dataset)
    eval_dataset = load_rocstories_dataset(args.eval_dataset)
    datasets = (train_dataset, eval_dataset)
    encoded_datasets = tokenize_and_encode(datasets)

    # Compute the max input length for the Transformer
    max_length = model.config.n_positions // 2 - 2
    input_length = max(len(story[:max_length]) + max(len(cont1[:max_length]), len(cont2[:max_length])) + 3  \
                           for dataset in encoded_datasets for story, cont1, cont2, _ in dataset)
    input_length = min(input_length, model.config.n_positions)  # Max size of input for the pre-trained model

    # Prepare inputs tensors and dataloaders
    tensor_datasets = pre_process_datasets(encoded_datasets, input_length, max_length, *special_tokens_ids)
    train_tensor_dataset, eval_tensor_dataset = tensor_datasets[0], tensor_datasets[1]

    train_data = TensorDataset(*train_tensor_dataset)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

    eval_data = TensorDataset(*eval_tensor_dataset)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # Prepare optimizer
    if args.do_train:
        if args.max_steps > 0:
            t_total = args.max_steps
            args.num_train_epochs = args.max_steps //\
                (len(train_dataloader) // args.gradient_accumulation_steps) + 1
        else:
            t_total = len(train_dataloader)\
                // args.gradient_accumulation_steps * args.num_train_epochs

        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
            ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)

    if args.do_train:
        nb_tr_steps, tr_loss, exp_average_loss = 0, 0, None
        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_steps = 0
            tqdm_bar = tqdm(train_dataloader, desc="Training")
            for step, batch in enumerate(tqdm_bar):
                batch = tuple(t.to(device) for t in batch)
                input_ids, mc_token_ids, lm_labels, mc_labels = batch
                losses = model(input_ids, mc_token_ids=mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels)
                loss = args.lm_coef * losses[0] + losses[1]
                loss.backward()
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                optimizer.zero_grad()
                tr_loss += loss.item()
                exp_average_loss = loss.item() if exp_average_loss is None else 0.7*exp_average_loss+0.3*loss.item()
                nb_tr_steps += 1
                tqdm_bar.desc = "Training loss: {:.2e} lr: {:.2e}".format(exp_average_loss, scheduler.get_lr()[0])

    # Save a trained model
    if args.do_train:
        # Save a trained model, configuration and tokenizer
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)

        torch.save(model_to_save.state_dict(), output_model_file)
        model_to_save.config.to_json_file(output_config_file)
        tokenizer.save_vocabulary(args.output_dir)

        # Load a trained model and vocabulary that you have fine-tuned
        model = OpenAIGPTDoubleHeadsModel.from_pretrained(args.output_dir)
        tokenizer = OpenAIGPTTokenizer.from_pretrained(args.output_dir)
        model.to(device)

    if args.do_eval:
        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = tuple(t.to(device) for t in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels = batch
            with torch.no_grad():
                _, mc_loss, _, mc_logits = model(input_ids, mc_token_ids=mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels)

            mc_logits = mc_logits.detach().cpu().numpy()
            mc_labels = mc_labels.to('cpu').numpy()
            tmp_eval_accuracy = accuracy(mc_logits, mc_labels)

            eval_loss += mc_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples
        train_loss = tr_loss/nb_tr_steps if args.do_train else None
        result = {'eval_loss': eval_loss,
                  'eval_accuracy': eval_accuracy,
                  'train_loss': train_loss}

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
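
# Inner loop body from a separate summarization fine-tuning example: the loss
# is computed only over the reference summary tokens (after the `sum_idx`
# separator), with gradient accumulation, periodic evaluation, and sample
# generation.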
    logits = model(inputs)[0]
    idx = batch['sum_idx'].item()  # index of separator token
    # only consider loss on reference summary just like seq2seq models
    shift_logits = logits[..., idx:-1, :].contiguous()
    shift_labels = labels[..., idx + 1:].contiguous()
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    loss = loss / args.gradient_accumulation_steps
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
    tr_loss += loss.item()
    if (step + 1) % args.gradient_accumulation_steps == 0:
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        model.zero_grad()
        global_step += 1
        writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
        writer.add_scalar('loss', (tr_loss - logging_loss) / args.gradient_accumulation_steps, global_step)
        logging_loss = tr_loss
        print("loss:", loss.item(), end='\n\n')
        if (step + 1) / args.gradient_accumulation_steps == 1.0:
            print('After 1st update: ', end='\n\n')
            generate_sample(valid_dataset, tokenizer, num=2, eval_step=False, device=args.device)

    if (step + 1) % (10 * args.gradient_accumulation_steps) == 0:
        results = evaluate(args, model, valid_dataset, ignore_index, global_step)
        for key, value in results.items():
            writer.add_scalar('eval_{}'.format(key), value, global_step)
        print('After', global_step + 1, 'updates: ', end='\n\n')
        generate_sample(valid_dataset, tokenizer, num=2, eval_step=True, device=args.device)

def train(train_dataset, model, device, eval_dataset=None):
    if args.use_tensorboard:
        tb_writer = SummaryWriter()

    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.train_batch_size,
                                  sampler=train_sampler)
    t_total = len(
        train_dataloader) // args.gradient_accumulation_steps * args.epochs
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0
    }]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)

    logger.info("***** Running training *****")
    logger.info("example number: {}".format(len(train_dataset)))
    logger.info("batch size: {}".format(args.train_batch_size))
    logger.info("epoch size: {}".format(args.epochs))
    logger.info("gradient accumulation step number: {}".format(
        args.gradient_accumulation_steps))
    logger.info("total step number: {}".format(t_total))
    logger.info("warmup step number: {}".format(args.warmup_steps))
    global_step = 0
    tr_loss, logging_loss = 0, 0
    for epoch in range(1, args.epochs + 1):
        logger.info("##### Epoch {} #####".format(epoch))
        epoch_iterator = tqdm(train_dataloader, desc="Training")
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "labels": batch[3]
            }
            outputs = model(**inputs)
            loss = outputs[0]

            if torch.cuda.device_count() > 1:
                loss = loss.mean()
            if args.gradient_accumulation_steps > 1:
                loss /= args.gradient_accumulation_steps
            loss.backward()
            tr_loss += loss.item()

            # Logging
            if global_step % args.logging_steps == 0:
                # Write some info to tensorboard.
                if args.use_tensorboard:
                    model_to_save = model.module if hasattr(
                        model, "module") else model
                    tb_writer.add_histogram("classifier.weight",
                                            model_to_save.classifier.weight,
                                            global_step)
                    tb_writer.add_histogram("classifier.bias",
                                            model_to_save.classifier.bias,
                                            global_step)
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("train_loss",
                                         (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                logging_loss = tr_loss
                # Evaluation.
                if eval_dataset is not None:
                    result = evaluate(eval_dataset, model, device)
                    logger.info("eval accuracy: {}, eval loss: {}".format(
                        result["acc"], result["loss"]))
                    for k, v in result.items():
                        tb_writer.add_scalar("eval_{}".format(k), v,
                                             global_step)

            # Update parameters.
            if (step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1
    logger.info("***** Finish Training! *****")
Exemple #21
def train_squad(args, tokenizer, model):
    # open to new log file (need modify with logging but later)
    w_log_file = open(args.path_log_file, "a")

    if not args.load_data_from_pt:
        train_dataset, train_dataloader = load_squad_to_torch_dataset(
            args.path_input_train_data,
            tokenizer,
            args.max_seq_length,
            args.max_query_length,
            args.batch_size,
            is_training=True)
        torch.save(train_dataset, args.path_pt_train_dataset)

    else:
        train_dataset = torch.load(args.path_pt_train_dataset)
        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.batch_size)

    t_total = len(train_dataloader
                  ) // args.gradient_accumulation_steps * args.num_train_epochs

    w_log_file.write("Number train sample: {}\n".format(len(train_dataset)))
    w_log_file.write("Load dataset done !!!\n")
    print("Load dataset done !!!")

    if not args.no_cuda:
        device = torch.device(args.device)
    else:
        device = torch.device("cpu")

    args.device = device

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Train!
    w_log_file.write("***** Running training *****\n")
    w_log_file.write("  Num examples = {}".format(len(train_dataset)))
    w_log_file.write("  Num Epochs = {}".format(args.num_train_epochs))
    w_log_file.write("  Gradient Accumulation steps = {}".format(
        args.gradient_accumulation_steps))
    w_log_file.write("  Total optimization steps = {}".format(t_total))

    n_epoch = 0
    global_step = 0

    model.zero_grad()
    set_seed(args)

    for _ in range(args.num_train_epochs):
        l_full_target = []
        l_full_predict = []
        tr_loss, logging_loss = 0.0, 0.0

        epoch_iterator = tqdm(train_dataloader,
                              desc="training ...",
                              leave=False)
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                'token_type_ids': batch[2],
                'label': batch[3]
            }

            loss, l_predict, l_target = model.loss(inputs['input_ids'],
                                                   inputs['attention_mask'],
                                                   inputs['token_type_ids'],
                                                   inputs['label'])
            l_full_target.extend(l_target)
            l_full_predict.extend(l_predict)

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0 or (
                    step == len(train_dataloader) - 1):
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if (args.save_steps > 0 and global_step % args.save_steps == 0 or \
                        step >= int(len(train_dataloader) - 1)) and \
                        global_step > (2/3 * (len(train_dataloader) / (args.batch_size * args.gradient_accumulation_steps))):
                    line_start_logging = "Log write at epoch: {}, step: {} and lr: {}\n".format(
                        n_epoch, global_step, round(scheduler.get_lr()[0], 6))
                    print(line_start_logging)
                    w_log_file.write(line_start_logging)

                    f1_score_micro = f1_score(l_full_target, l_full_predict)
                    accuracy = accuracy_score(l_full_target, l_full_predict)

                    output_train = {
                        "loss": round(tr_loss / len(train_dataset), 3),
                        "accuracy": round(accuracy, 3),
                        "f1": round(f1_score_micro, 3)
                    }
                    line_log_train = "train result - loss: {}, acc: {}, f1: {}\n".format(
                        output_train['loss'], output_train['accuracy'],
                        output_train['f1'])
                    print(line_log_train)
                    w_log_file.write(line_log_train)

                    if args.path_input_test_data is not None:
                        w_log_file.write("Start evaluating test data !!\n")
                        output_test = evaluate(args,
                                               model,
                                               tokenizer,
                                               is_test=True)
                        line_log_test = "test result - loss: {}, acc: {}, f1: {}\n".format(
                            output_test['loss'], output_test['accuracy'],
                            output_test['f1'])
                        print(line_log_test)
                        w_log_file.write(line_log_test)

                    if args.path_input_validation_data is not None:
                        w_log_file.write(
                            "Start evaluating validation data !!\n")
                        output_validation = evaluate(args,
                                                     model,
                                                     tokenizer,
                                                     is_test=False)
                        line_log_val = "validation result - loss: {}, acc: {}, f1: {}\n".format(
                            output_validation['loss'],
                            output_validation['accuracy'],
                            output_validation['f1'])
                        print(line_log_val)
                        w_log_file.write(line_log_val)
                    line_end_logging = "End of logging for step {} !!!\n".format(
                        global_step)
                    print(line_end_logging)

                    w_log_file.write(line_end_logging)

                    prefix_dir_save = "epoch{}_step{}".format(
                        n_epoch, global_step)
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, prefix_dir_save)
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(
                        model, 'module') else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, 'training_args.bin'))
                    w_log_file.write(
                        "Saving model checkpoint to {}".format(output_dir))

        n_epoch += 1
    w_log_file.close()
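
# train_squad() above likewise pulls its configuration from `args`. A sketch of
# the fields it references (the paths and values are placeholder assumptions):
#
#   args = argparse.Namespace(
#       path_log_file="train.log", load_data_from_pt=False,
#       path_input_train_data="train.json", path_pt_train_dataset="train.pt",
#       path_input_test_data=None, path_input_validation_data=None,
#       max_seq_length=384, max_query_length=64, batch_size=16,
#       gradient_accumulation_steps=1, num_train_epochs=2,
#       no_cuda=False, device="cuda", n_gpu=1,
#       weight_decay=0.01, learning_rate=3e-5, adam_epsilon=1e-8,
#       warmup_steps=0, max_grad_norm=1.0,
#       save_steps=500, output_dir="output",
#   )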
Exemple #22
def train(model, dataset, args):
    tb_writer = SummaryWriter()

    train_num = int(args.eval_precent*len(dataset))
    eval_num = len(dataset) - train_num
    train_set, eval_set = random_split(dataset, [train_num, eval_num])

    train_sampler = RandomSampler(train_set)
    train_dataloader = DataLoader(train_set, sampler=train_sampler,
                                  batch_size=8)

    # Set up the optimizer and learning rate decay (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']

    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(
            nd in n for nd in no_decay)], 'weight_decay': 0.0},
        {'params': [p for n, p in model.named_parameters() if any(
            nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5, eps=1e-8)
    scheduler = WarmupLinearSchedule(
        optimizer, warmup_steps=0, t_total=len(dataset))

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    best_step, best_acc = 0, 0.0

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epoch}")
    logger.info(f"  Train batch size  = {args.batch_size}")
    logger.info(f"  Total optimization steps = {len(dataset)}")

    model.zero_grad()
    train_iterator = trange(args.num_train_epoch, desc="Epoch")
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for _, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(device) for t in batch)
            inputs = {'input_ids':      batch[0],
                      'attention_mask': batch[1],
                      'token_type_ids': batch[2],
                      'labels':         batch[3]}
            outputs = model(**inputs)
            loss = outputs[0]

            # Plot / log metrics
            tr_loss += loss.item()
            global_step += 1
            if global_step % args.logging_step == 0:

                tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                tb_writer.add_scalar(
                    'loss',
                    (tr_loss - logging_loss)/args.logging_step,
                    global_step)

                logging_loss = tr_loss

            # Save the model
            if global_step % args.save_step == 0:
                output_dir = os.path.join(
                    args.output_dir, 'checkpoint-{}'.format(global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                model.save_pretrained(output_dir)
                # tokenizer.save_vocabulary(output_dir)
                torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                logger.info("Saving model checkpoint to %s", output_dir)

            # Validation
            if global_step % args.eval_step == 0 and args.train_with_eval:
                pred, labels = predict(model, eval_set, args, eval=True)
                acc = sum(pred == labels)*1.0/len(eval_set)
                tb_writer.add_scalar('acc', float(acc), global_step)
                if acc > best_acc:
                    best_acc = acc
                    best_step = global_step
                logger.info(f"Global step is {global_step},"
                            f"acc is {float(acc):.4f},")
                logger.info(f"Best step is {best_step},"
                            f"best acc is {float(best_acc):.4f}")

            # The usual steps: backpropagation, optimizer update, etc.
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()  # Update the learning rate
            model.zero_grad()
Exemple #23
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    train_loss, test_loss, test_f1, test_acc = [], [], [], []
    model_name = "{}-{}-{}".format(args.optimizer.lower(), args.task_name,
                                   args.model_type)
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    if args.optimizer.lower() == "adamw":
        print("We use AdamW optimizer!")
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adam":
        print("We use Adam optimizer!")
        optimizer = Adam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    elif args.optimizer.lower() == "sgd":
        print("We use SGD optimizer!")
        optimizer = SGD(optimizer_grouped_parameters,
                        lr=args.learning_rate,
                        momentum=0.9)
    elif "acclip" in args.optimizer.lower():
        print("We use ACClip optimizer!")
        optimizer = ACClip(optimizer_grouped_parameters, lr=args.learning_rate)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)
    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(
        args)  # Added here for reproducibility (even between python 2 and 3)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                'labels': batch[3]
            }
            if args.model_type != 'distilbert':
                inputs['token_type_ids'] = batch[2] if args.model_type in [
                    'bert', 'xlnet'
                ] else None  # XLM, DistilBERT and RoBERTa don't use segment_ids
            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.optimizer.lower(
                ) != "acclip":  # make sure we don't clip for acclip
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)  #
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results, eval_loss = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value,
                                                 global_step)
                            print('eval_{}'.format(key), value, global_step)
                        print('eval_loss', eval_loss, global_step)
                        test_f1.append(results['f1'] * 100)
                        test_acc.append(results['acc'] * 100)
                        test_loss.append(eval_loss)
                    tb_writer.add_scalar('lr',
                                         scheduler.get_lr()[0], global_step)
                    iter_train_loss = (tr_loss -
                                       logging_loss) / args.logging_steps
                    tb_writer.add_scalar('loss', iter_train_loss, global_step)
                    print("eval_Training_Loss_{}".format(iter_train_loss),
                          global_step)
                    logging_loss = tr_loss
                    train_loss.append(iter_train_loss)

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(
                        model, 'module'
                    ) else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    if not os.path.exists("./curves"):
        os.mkdir("./curves")
    with open(os.path.join('./curves', model_name), "wb") as f:
        pickle.dump(
            {
                'train_loss': train_loss,
                'test_loss': test_loss,
                'test_f1': test_f1,
                'test_acc': test_acc
            }, f)
    return global_step, tr_loss / global_step
Exemple #24
def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()
        epoch_writer = open('epoch_result.txt', 'w', encoding='utf8')

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [{
        "params": [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        "weight_decay":
        args.weight_decay
    }, {
        "params": [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        "weight_decay":
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(
        args)  # Added here for reproducibility (even between python 2 and 3)
    i = 1
    dev_f1 = 0.0
    best_performance = 1
    print(args.device)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids":
                batch[0],
                "attention_mask":
                batch[1],
                "token_type_ids":
                batch[2] if args.model_type in ["bert", "xlnet"] else None,
                # XLM and RoBERTa don't use segment_ids
                "labels":
                batch[3]
            }
            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in pytorch-transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results, _ = evaluate(args, model, tokenizer, labels,
                                              pad_token_label_id)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)

                    logging_loss = tr_loss
                '''
                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(model, "module") else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)
                 '''
            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

        # epoch_test
        epoch_writer.write("epoch: {} -------------------\n".format(i))

        res, _ = evaluate(args,
                          model,
                          tokenizer,
                          labels,
                          pad_token_label_id,
                          mode="dev")
        for key in sorted(res.keys()):
            epoch_writer.write("dev: {} = {}\n".format(key, str(res[key])))

        if float(res['f1']) > dev_f1:
            dev_f1 = float(res['f1'])
            best_performance = i
            output_dir = os.path.join(args.output_dir,
                                      "checkpoint-{}".format(i))
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)
            model.save_pretrained(output_dir)
            tokenizer.save_pretrained(output_dir)
            torch.save(args, os.path.join(output_dir, "training_args.bin"))

        res, _ = evaluate(args,
                          model,
                          tokenizer,
                          labels,
                          pad_token_label_id,
                          mode="test")
        for key in sorted(res.keys()):
            epoch_writer.write("test: {} = {}\n".format(key, str(res[key])))

        epoch_writer.write("\n")

        i += 1

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return best_performance, global_step, tr_loss / global_step
def train(args, model, tokenizer):
    with open("../../data/train_texts", "r") as fr:
        texts = fr.readlines()
    with open("../../data/train_labels", "r") as fr:
        labels = fr.readlines()
    examples = load_dataset(texts, labels)
    label_list = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
    features = convert_examples_to_features(examples, tokenizer,
                                 label_list=label_list,
                                 output_mode="classification",
                                 max_length=32)  # special tokens are added during this step
    
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
    all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
    train_dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
    train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
    t_total = len(train_dataloader) * args.num_train_epochs
    args.t_total = t_total
    args.warmup_steps = 0.1 * t_total
    
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)  #!
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=args.t_total)  #!!!
    
    # Train!
    print("***** Running training *****")
    print("  Num examples = %d", len(train_dataset))
    print("  Num Epochs = %d", args.num_train_epochs)
    print("  Total optimization steps = %d", args.t_total)
    tr_loss = 0.0
    global_step = 0
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch")
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    best_acc = 0.0
    for epoch, _ in enumerate(train_iterator):
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):
            if step % 100 == 0:
                print("lr: {}".format(scheduler.get_lr()))
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {'input_ids':      batch[0],
                    'attention_mask': batch[1],
                    'token_type_ids': batch[2],
                    'labels':         batch[3]}

            outputs = model(**inputs)
            
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
            loss.backward()
            tr_loss += loss.item()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1
        result = evaluate(args, model, tokenizer)
        print("result: {}, best: {}".format(result, best_acc))
        if result >= best_acc:
            print("model saved")
            best_acc = result
            output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(epoch))
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)
            model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
            model_to_save.save_pretrained(output_dir)
            torch.save(args, os.path.join(output_dir, 'training_args.bin'))
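
# --- Illustrative sketch (not part of the original code) ---------------------
# The train() above assumes an evaluate(args, model, tokenizer) helper that
# returns a single scalar used for model selection (result >= best_acc). A
# minimal accuracy-based version might look like this; build_dev_dataset() is a
# hypothetical helper mirroring the feature conversion used for training.
import torch
from torch.utils.data import DataLoader

def evaluate(args, model, tokenizer):
    eval_dataset = build_dev_dataset(args, tokenizer)  # hypothetical helper
    eval_dataloader = DataLoader(eval_dataset, batch_size=args.eval_batch_size)
    correct, total = 0, 0
    model.eval()
    with torch.no_grad():
        for batch in eval_dataloader:
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {'input_ids': batch[0],
                      'attention_mask': batch[1],
                      'token_type_ids': batch[2]}
            logits = model(**inputs)[0]
            preds = torch.argmax(logits, dim=-1)
            correct += (preds == batch[3]).sum().item()
            total += batch[3].size(0)
    return correct / total
# -----------------------------------------------------------------------------
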
def train(args, model, task_dataset, src_probs, loss_ignore_index, src_predictions=None):
    """ Train the model using multi-task training and knowledge distillation """
    tb_writer = SummaryWriter(log_dir=args.log_dir)

    if src_probs is not None:
        # check the size of the two relative datasets, expand task_dataset with src_probs_list
        assert len(task_dataset) == src_probs.size(0)

        # build hard labels if needed
        if args.hard_label_usage != 'none':
            num_labels = src_probs.size(-1)
            confident_labels = None
            confident_labels_mask = None
            for src_prediction in src_predictions:
                tmp_src_labels = torch.argmax(src_prediction, dim=-1)
                if confident_labels is None:
                    confident_labels = tmp_src_labels
                    confident_labels_mask = torch.ones_like(confident_labels)
                else:
                    confident_labels_mask[confident_labels != tmp_src_labels] = 0
                    confident_labels[confident_labels != tmp_src_labels] = num_labels

            embedding_matrix = torch.cat([torch.eye(num_labels), torch.zeros(1, num_labels)], dim=0).to(args.device)
            hard_labels = torch.nn.functional.embedding(confident_labels, embedding_matrix).detach()
            confident_labels_mask = confident_labels_mask.detach()

            # s_label0 = torch.argmax(src_predictions[0], dim=-1)
            # s_label1 = torch.argmax(src_predictions[1], dim=-1)
            # s_label2 = torch.argmax(src_predictions[2], dim=-1)
            # for ki in range(src_probs.size(0)):
            #     for kj in range(src_probs.size(1)):
            #         if confident_labels_mask[ki, kj] == 1:
            #             if not (confident_labels[ki, kj] == s_label0[ki, kj] and confident_labels[ki, kj] == s_label1[
            #                 ki, kj] and confident_labels[ki, kj] == s_label2[ki, kj]):
            #                 raise ValueError("Error 0")
            #             if not (hard_labels[ki, kj, confident_labels[ki, kj]].cpu().item() == 1 and torch.sum(
            #                     hard_labels[ki, kj]).cpu().item() == 1):
            #                 raise ValueError("Error 2")
            #         else:
            #             if (confident_labels[ki, kj] == s_label0[ki, kj] and confident_labels[ki, kj] == s_label1[
            #                 ki, kj] and confident_labels[ki, kj] == s_label2[ki, kj]):
            #                 raise ValueError("Error 1")
            #             if not torch.sum(hard_labels[ki, kj]).cpu().item() == 0:
            #                 raise ValueError("Error 3")

            if args.hard_label_usage == 'replace':
                src_probs[confident_labels_mask == 1, :] = hard_labels[confident_labels_mask == 1, :]

                # for ki in range(src_probs.size(0)):
                #     for kj in range(src_probs.size(1)):
                #         if confident_labels_mask[ki, kj] == 1:
                #             if not torch.sum(torch.abs(src_probs[ki, kj] - hard_labels[ki, kj])).cpu().item() == 0:
                #                 raise ValueError("Error 0")
                #         else:
                #             if torch.sum(torch.abs(src_probs[ki, kj] - hard_labels[ki, kj])).cpu().item() == 0:
                #                 raise ValueError("Error 1")

        task_dataset.tensors += (src_probs,)
        if args.hard_label_usage == 'weight':
            task_dataset.tensors += (hard_labels,)
            task_dataset.tensors += (confident_labels_mask,)

    # prepare dataloader
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    sampler = RandomSampler(task_dataset)
    dataloader = DataLoader(task_dataset, sampler=sampler, batch_size=args.train_batch_size)

    # compute total update steps
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // len(dataloader) + 1
    else:
        t_total = len(dataloader) * args.num_train_epochs

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num task examples = %d", len(task_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  GPU IDs for training: %s", " ".join([str(id) for id in args.gpu_ids]))
    logger.info("  Total task optimization steps = %d", t_total)
    logger.info("  Total language identifier optimization steps = %d", t_total)

    # Prepare optimizer and schedule (linear warmup and decay)
    no_grad = ["embeddings"] + ["layer." + str(layer_i) + "." for layer_i in range(12) if layer_i < args.freeze_bottom_layer]
    opt_params = get_optimizer_grouped_parameters(args, model, no_grad=no_grad)
    optimizer = AdamW(opt_params, lr=args.learning_rate, eps=args.adam_epsilon, weight_decay=args.weight_decay)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=int(t_total * args.warmup_ratio), t_total=t_total)

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=args.gpu_ids)  # wrap so loss.mean() below averages per-GPU losses

    global_step = 0
    loss_accum, loss_KD_accum = 0.0, 0.0
    logging_loss, logging_loss_KD = 0.0, 0.0

    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    for epoch_i in range(args.num_train_epochs):
        for step, batch in enumerate(dataloader):
            model.train()
            model.zero_grad()

            inputs = {"input_ids": batch[0].to(args.device),
                      "attention_mask": batch[1].to(args.device),
                      "token_type_ids": batch[2].to(args.device),
                      "labels": batch[3].to(args.device),
                      "src_probs": batch[4] if not src_probs is None else None,
                      "loss_ignore_index": loss_ignore_index,
                      "hard_labels": batch[5] if args.hard_label_usage == 'weight' else None,
                      "hard_labels_mask": batch[6] if args.hard_label_usage == 'weight' else None,
                      "hard_label_loss_weight": args.hard_label_weight} # activate the KD loss

            outputs = model(**inputs)

            if src_probs is None:
                loss = outputs[0]
                loss_KD = loss
            else:
                loss_KD, loss = outputs[:2]

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
                loss_KD = loss_KD.mean()

            # loss.backward()
            loss_KD.backward()

            loss_accum += loss.item()
            loss_KD_accum += loss_KD.item()

            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                tb_writer.add_scalar("loss", (loss_accum - logging_loss) / args.logging_steps, global_step)
                tb_writer.add_scalar("loss_KD", (loss_KD_accum - logging_loss_KD) / args.logging_steps, global_step)
                logger.info("Epoch: {}\t global_step: {}\t lr: {:.8}\tloss: {:.8f}\tloss_KD: {:.8f}".format(epoch_i,
                        global_step, scheduler.get_lr()[0], (loss_accum - logging_loss) / args.logging_steps,
                                                        (loss_KD_accum - logging_loss_KD) / args.logging_steps))

                logging_loss = loss_accum
                logging_loss_KD = loss_KD_accum

            if args.save_steps > 0 and global_step % args.save_steps == 0:
                # Save model checkpoint
                output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                # base model
                model_to_save = model.module if hasattr(model, "module") else model
                model_to_save.save_pretrained(output_dir)
                torch.save(args, os.path.join(output_dir, "training_args.bin"))

                logger.info("Saving model checkpoint to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            break

    tb_writer.close()

    return global_step, loss_KD_accum / global_step, loss_accum / global_step
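
# --- Illustrative sketch (not part of the original code) ---------------------
# One plausible way to build the (src_probs, src_predictions) pair consumed by
# the knowledge-distillation train() above: average the soft probabilities of
# several source/teacher models, and keep the per-model distributions so that
# the "confident" hard labels can be derived. All names here are assumptions.
import torch

def build_teacher_targets(src_logits_list):
    # src_logits_list: list of tensors, each [num_examples, seq_len, num_labels]
    src_predictions = [torch.softmax(logits, dim=-1) for logits in src_logits_list]
    src_probs = torch.stack(src_predictions, dim=0).mean(dim=0)  # average over teachers
    return src_probs, src_predictions
# -----------------------------------------------------------------------------
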
    def train(self, train_dataset, output_dir, show_running_loss=True):
        """
        Trains the model on train_dataset.

        Utility function to be used by the train_model() method. Not intended to be used directly.
        """

        tokenizer = self.tokenizer
        device = self.device
        model = self.model
        args = self.args

        tb_writer = SummaryWriter()
        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args["train_batch_size"])

        t_total = len(train_dataloader) // args["gradient_accumulation_steps"] * args["num_train_epochs"]

        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {"params": [p for n, p in model.named_parameters() if not any(
                nd in n for nd in no_decay)], "weight_decay": args["weight_decay"]},
            {"params": [p for n, p in model.named_parameters() if any(
                nd in n for nd in no_decay)], "weight_decay": 0.0}
        ]

        warmup_steps = math.ceil(t_total * args["warmup_ratio"])
        args["warmup_steps"] = warmup_steps if args["warmup_steps"] == 0 else args["warmup_steps"]

        optimizer = AdamW(optimizer_grouped_parameters, lr=args["learning_rate"], eps=args["adam_epsilon"])
        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args["warmup_steps"], t_total=t_total)

        if args["fp16"]:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")

            model, optimizer = amp.initialize(model, optimizer, opt_level=args["fp16_opt_level"])

        if args["n_gpu"] > 1:
            model = torch.nn.DataParallel(model)

        global_step = 0
        tr_loss, logging_loss = 0.0, 0.0
        model.zero_grad()
        train_iterator = trange(int(args["num_train_epochs"]), desc="Epoch", disable=args["silent"])

        model.train()
        for _ in train_iterator:
            # epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args["silent"])
            for step, batch in enumerate(tqdm(train_dataloader, desc="Current iteration", disable=args["silent"])):
                batch = tuple(t.to(device) for t in batch)

                inputs = self._get_inputs_dict(batch)
                outputs = model(**inputs)
                # model outputs are always tuple in pytorch-transformers (see doc)
                loss = outputs[0]
                if show_running_loss:
                    if not args["silent"]:
                        print("\rRunning loss: %f" % loss, end="")

                if args["gradient_accumulation_steps"] > 1:
                    loss = loss / args["gradient_accumulation_steps"]

                if args["fp16"]:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args["max_grad_norm"])
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args["max_grad_norm"])

                tr_loss += loss.item()
                if (step + 1) % args["gradient_accumulation_steps"] == 0:
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    model.zero_grad()
                    global_step += 1

                    if args["logging_steps"] > 0 and global_step % args["logging_steps"] == 0:
                        # Log metrics
                        # Only evaluate when single GPU otherwise metrics may not average well
                        tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                        tb_writer.add_scalar("loss", (tr_loss - logging_loss)/args["logging_steps"], global_step)
                        logging_loss = tr_loss

                    if args["save_steps"] > 0 and global_step % args["save_steps"] == 0:
                        # Save model checkpoint
                        output_dir_current = os.path.join(output_dir, "checkpoint-{}".format(global_step))

                        if not os.path.exists(output_dir_current):
                            os.makedirs(output_dir_current)

                        # Take care of distributed/parallel training
                        model_to_save = model.module if hasattr(model, "module") else model
                        model_to_save.save_pretrained(output_dir_current)
                        self.tokenizer.save_pretrained(output_dir_current)


        return global_step, tr_loss / global_step
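
# --- Illustrative sketch (not part of the original code) ---------------------
# Worked example of the step arithmetic used in the method above (the numbers
# are made up): with 1000 batches per epoch, 4 gradient-accumulation steps,
# 3 epochs and a warmup ratio of 0.06, the scheduler sees 750 optimizer
# updates, the first 45 of which are warmup.
import math

batches_per_epoch = 1000
gradient_accumulation_steps = 4
num_train_epochs = 3
warmup_ratio = 0.06

t_total = batches_per_epoch // gradient_accumulation_steps * num_train_epochs
warmup_steps = math.ceil(t_total * warmup_ratio)
assert t_total == 750 and warmup_steps == 45
# -----------------------------------------------------------------------------
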
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    tb_writer = SummaryWriter()

    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps)
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=False)
    set_seed(
        args)  # Added here for reproducibility (even between python 2 and 3)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=False)
        for step, batch in enumerate(epoch_iterator):
            inputs, labels = batch, batch
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            outputs = model(inputs, labels=labels)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value,
                                                 global_step)
                    tb_writer.add_scalar('lr',
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = 'checkpoint'
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir,
                        '{}-{}'.format(checkpoint_prefix, global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    # Take care of distributed/parallel training
                    model_to_save = model.module if hasattr(
                        model, 'module') else model
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    tb_writer.close()
    return global_step, tr_loss / global_step
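
# --- Illustrative sketch (not part of the original code) ---------------------
# The set_seed(args) helper called throughout these training loops is not shown
# in this snippet; the usual pattern seeds Python, NumPy and PyTorch (including
# all GPUs) from args.seed, roughly as follows.
import random
import numpy as np
import torch

def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
# -----------------------------------------------------------------------------
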
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(
        args)  # Added here for reproducibility (even between python 2 and 3)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                'labels': batch[3]
            }
            if args.model_type != 'distilbert':
                inputs['token_type_ids'] = batch[2] if args.model_type in [
                    'bert', 'xlnet'
                ] else None  # XLM, DistilBERT and RoBERTa don't use segment_ids
            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            tr_loss += loss.item()
            if (step + 1
                ) % args.gradient_accumulation_steps == 0 and not args.tpu:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    tb_writer.add_scalar('lr',
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value,
                                                 global_step)
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(
                        model, 'module'
                    ) else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if args.tpu:
                args.xla_model.optimizer_step(optimizer, barrier=True)
                model.zero_grad()
                global_step += 1

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
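
# --- Illustrative sketch (not part of the original code) ---------------------
# WarmupLinearSchedule comes from the old pytorch-transformers package. If
# these loops are ported to a recent transformers release, the equivalent
# schedule can be built as below (a sketch, reusing the same warmup_steps and
# t_total values as above).
from transformers import get_linear_schedule_with_warmup

def build_linear_warmup_schedule(optimizer, warmup_steps, t_total):
    # Drop-in replacement for WarmupLinearSchedule(optimizer, warmup_steps, t_total)
    return get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=t_total)
# -----------------------------------------------------------------------------
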
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.model_type in ["bert", "xlnet"]:

        aggression_tensor = torch.tensor(tokenizer.encode("aggression"),
                                         dtype=torch.long).to(args.device)
        attack_tensor = torch.tensor(tokenizer.encode("attack"),
                                     dtype=torch.long).to(args.device)
        toxicity_tensor = torch.tensor(tokenizer.encode("toxicity"),
                                       dtype=torch.long).to(args.device)
    elif args.model_type == "roberta":
        aggression_tensor = torch.tensor([0], dtype=torch.long).to(args.device)
        attack_tensor = torch.tensor([1], dtype=torch.long).to(args.device)
        toxicity_tensor = torch.tensor([2], dtype=torch.long).to(args.device)

    char_vocab = get_char_vocab()
    aggression_char_ids = char2ids("aggression", char_vocab)
    attack_char_ids = char2ids("attack", char_vocab)
    toxicity_char_ids = char2ids("toxicity", char_vocab)

    aggression_char_tensor = torch.tensor(aggression_char_ids,
                                          dtype=torch.long).to(args.device)
    attack_char_tensor = torch.tensor(attack_char_ids,
                                      dtype=torch.long).to(args.device)
    toxicity_char_tensor = torch.tensor(toxicity_char_ids,
                                        dtype=torch.long).to(args.device)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    best_f1 = 0.0
    best_aggression_score = {}
    best_attack_score = {}
    best_toxicity_score = {}
    epoch_num = 0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(
        args)  # Added here for reproducibility (even between python 2 and 3)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                'aggression_labels': batch[3],
                'attack_labels': batch[4],
                'toxicity_labels': batch[5],
                'aggression_tensor': aggression_tensor,
                'attack_tensor': attack_tensor,
                'toxicity_tensor': toxicity_tensor,
                'aggression_char_tensor': aggression_char_tensor,
                'attack_char_tensor': attack_char_tensor,
                'toxicity_char_tensor': toxicity_char_tensor
            }
            if args.model_type != 'distilbert':
                inputs['token_type_ids'] = batch[2] if args.model_type in [
                    'bert', 'xlnet'
                ] else None  # XLM, DistilBERT and RoBERTa don't use segment_ids
            if args.all_task:
                aggression_logits, attack_logits, toxicity_logits, loss, _, _, _ = model(
                    **inputs)
            elif args.aggression_attack_task:
                aggression_logits, attack_logits, loss = model(**inputs)
            elif args.aggression_toxicity_task:
                aggression_logits, toxicity_logits, loss = model(**inputs)
            elif args.attack_toxicity_task:
                attack_logits, toxicity_logits, loss = model(**inputs)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1
                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                         args.logging_steps, epoch_num)
                    logging_loss = tr_loss

        if args.all_task:

            aggression_results, attack_results, toxicity_results = evaluate(
                args, model, tokenizer)

            aggression_f1 = aggression_results['score']['f1']
            attack_f1 = attack_results['score']['f1']
            toxicity_f1 = toxicity_results['score']['f1']

            tb_writer.add_scalar('lr', scheduler.get_lr()[0], epoch_num)

            tb_writer.add_scalars(
                'accuracy', {
                    'aggression': aggression_results['score']['acc'],
                    'attack': attack_results['score']['acc'],
                    'toxicity': toxicity_results['score']['acc']
                }, epoch_num)

            tb_writer.add_scalars(
                'f1', {
                    'aggression': aggression_results['score']['f1'],
                    'attack': attack_results['score']['f1'],
                    'toxicity': toxicity_results['score']['f1']
                }, epoch_num)
            tb_writer.add_scalars(
                'precision', {
                    'aggression': aggression_results['score']['precision'],
                    'attack': attack_results['score']['precision'],
                    'toxicity': toxicity_results['score']['precision']
                }, epoch_num)
            tb_writer.add_scalars(
                'recall', {
                    'aggression': aggression_results['score']['recall'],
                    'attack': attack_results['score']['recall'],
                    'toxicity': toxicity_results['score']['recall']
                }, epoch_num)
            tb_writer.add_scalars(
                'auc', {
                    'aggression': aggression_results['score']['auc'],
                    'attack': attack_results['score']['auc'],
                    'toxicity': toxicity_results['score']['auc']
                }, epoch_num)

            if (aggression_f1 + attack_f1 + toxicity_f1) / 3.0 > best_f1:
                best_f1 = (aggression_f1 + attack_f1 + toxicity_f1) / 3.0
                best_aggression_score.update(aggression_results)
                best_attack_score.update(attack_results)
                best_toxicity_score.update(toxicity_results)

                # Save model checkpoint
                output_dir = os.path.join(args.output_dir,
                                          'checkpoint-{}'.format(global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                output_model_file = os.path.join(output_dir, "model.pt")
                torch.save(model.state_dict(), output_model_file)
                torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                logger.info("Saving model checkpoint to %s", output_dir)

                aggression_output_eval_file = os.path.join(
                    args.output_dir, "aggression_eval_results.txt")
                attack_output_eval_file = os.path.join(
                    args.output_dir, "attack_eval_results.txt")
                toxicity_output_eval_file = os.path.join(
                    args.output_dir, "toxicity_eval_results.txt")
                with open(aggression_output_eval_file, "a") as writer:

                    for key in sorted(aggression_results.keys()):

                        writer.write("checkpoint%s-%s = %s\n" %
                                     (str(global_step), key,
                                      str(aggression_results[key])))

                with open(attack_output_eval_file, "a") as writer:

                    for key in sorted(attack_results.keys()):

                        writer.write(
                            "checkpoint%s-%s = %s\n" %
                            (str(global_step), key, str(attack_results[key])))

                with open(toxicity_output_eval_file, "a") as writer:

                    for key in sorted(toxicity_results.keys()):

                        writer.write("checkpoint%s-%s = %s\n" %
                                     (str(global_step), key,
                                      str(toxicity_results[key])))

            logger.info("************* best  results ***************")
            for key in sorted(best_aggression_score.keys()):
                logger.info("aggression-%s = %s", key,
                            str(best_aggression_score[key]))

            for key in sorted(best_attack_score.keys()):
                logger.info("attack-%s = %s", key, str(best_attack_score[key]))

            for key in sorted(best_toxicity_score.keys()):
                logger.info("toxicity-%s = %s", key,
                            str(best_toxicity_score[key]))

        elif args.aggression_attack_task:
            aggression_results, attack_results = evaluate(
                args, model, tokenizer)

            aggression_f1 = aggression_results['score']['f1']
            attack_f1 = attack_results['score']['f1']

            tb_writer.add_scalar('lr', scheduler.get_lr()[0], epoch_num)
            tb_writer.add_scalar('loss', tr_loss / args.logging_steps,
                                 epoch_num)

            if (aggression_f1 + attack_f1) / 2.0 > best_f1:
                best_f1 = (aggression_f1 + attack_f1) / 2.0
                best_aggression_score.update(aggression_results)
                best_attack_score.update(attack_results)

                # Save model checkpoint
                output_dir = os.path.join(args.output_dir,
                                          'checkpoint-{}'.format(global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                output_model_file = os.path.join(output_dir, "model.pt")
                torch.save(model.state_dict(), output_model_file)
                torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                logger.info("Saving model checkpoint to %s", output_dir)

                aggression_output_eval_file = os.path.join(
                    args.output_dir, "aggression_eval_results.txt")
                attack_output_eval_file = os.path.join(
                    args.output_dir, "attack_eval_results.txt")
                with open(aggression_output_eval_file, "a") as writer:

                    for key in sorted(aggression_results.keys()):
                        writer.write("checkpoint%s-%s = %s\n" %
                                     (str(global_step), key,
                                      str(aggression_results[key])))

                with open(attack_output_eval_file, "a") as writer:

                    for key in sorted(attack_results.keys()):
                        writer.write(
                            "checkpoint%s-%s = %s\n" %
                            (str(global_step), key, str(attack_results[key])))

            logger.info("************* best  results ***************")
            for key in sorted(best_aggression_score.keys()):
                logger.info("aggression-%s = %s", key,
                            str(best_aggression_score[key]))

            for key in sorted(best_attack_score.keys()):
                logger.info("attack-%s = %s", key, str(best_attack_score[key]))

        elif args.aggression_toxicity_task:
            aggression_results, toxicity_results = evaluate(
                args, model, tokenizer)

            aggression_f1 = aggression_results['score']['f1']

            toxicity_f1 = toxicity_results['score']['f1']

            tb_writer.add_scalar('lr', scheduler.get_lr()[0], epoch_num)
            tb_writer.add_scalar('loss', tr_loss / args.logging_steps,
                                 epoch_num)
            if (aggression_f1 + toxicity_f1) / 2.0 > best_f1:
                best_f1 = (aggression_f1 + toxicity_f1) / 2.0
                best_aggression_score.update(aggression_results)

                best_toxicity_score.update(toxicity_results)

                # Save model checkpoint
                output_dir = os.path.join(args.output_dir,
                                          'checkpoint-{}'.format(global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                output_model_file = os.path.join(output_dir, "model.pt")
                torch.save(model.state_dict(), output_model_file)
                torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                logger.info("Saving model checkpoint to %s", output_dir)

                aggression_output_eval_file = os.path.join(
                    args.output_dir, "aggression_eval_results.txt")

                toxicity_output_eval_file = os.path.join(
                    args.output_dir, "toxicity_eval_results.txt")
                with open(aggression_output_eval_file, "a") as writer:

                    for key in sorted(aggression_results.keys()):
                        writer.write("checkpoint%s-%s = %s\n" %
                                     (str(global_step), key,
                                      str(aggression_results[key])))

                with open(toxicity_output_eval_file, "a") as writer:

                    for key in sorted(toxicity_results.keys()):
                        writer.write("checkpoint%s-%s = %s\n" %
                                     (str(global_step), key,
                                      str(toxicity_results[key])))

            logger.info("************* best  results ***************")
            for key in sorted(best_aggression_score.keys()):
                logger.info("aggression-%s = %s", key,
                            str(best_aggression_score[key]))

            for key in sorted(best_toxicity_score.keys()):
                logger.info("toxicity-%s = %s", key,
                            str(best_toxicity_score[key]))

        elif args.attack_toxicity_task:
            attack_results, toxicity_results = evaluate(args, model, tokenizer)

            attack_f1 = attack_results['score']['f1']
            toxicity_f1 = toxicity_results['score']['f1']

            tb_writer.add_scalar('lr', scheduler.get_lr()[0], epoch_num)
            tb_writer.add_scalar('loss', tr_loss / args.logging_steps,
                                 epoch_num)

            if (attack_f1 + toxicity_f1) / 2.0 > best_f1:
                best_f1 = (attack_f1 + toxicity_f1) / 2.0

                best_attack_score.update(attack_results)
                best_toxicity_score.update(toxicity_results)

                # Save model checkpoint
                output_dir = os.path.join(args.output_dir,
                                          'checkpoint-{}'.format(global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                output_model_file = os.path.join(output_dir, "model.pt")
                torch.save(model.state_dict(), output_model_file)
                torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                logger.info("Saving model checkpoint to %s", output_dir)

                attack_output_eval_file = os.path.join(
                    args.output_dir, "attack_eval_results.txt")
                toxicity_output_eval_file = os.path.join(
                    args.output_dir, "toxicity_eval_results.txt")

                with open(attack_output_eval_file, "a") as writer:

                    for key in sorted(attack_results.keys()):
                        writer.write(
                            "checkpoint%s-%s = %s\n" %
                            (str(global_step), key, str(attack_results[key])))

                with open(toxicity_output_eval_file, "a") as writer:

                    for key in sorted(toxicity_results.keys()):
                        writer.write("checkpoint%s-%s = %s\n" %
                                     (str(global_step), key,
                                      str(toxicity_results[key])))

            logger.info("************* best  results ***************")
            for key in sorted(best_aggression_score.keys()):
                logger.info("aggression-%s = %s", key,
                            str(best_aggression_score[key]))

            for key in sorted(best_attack_score.keys()):
                logger.info("attack-%s = %s", key, str(best_attack_score[key]))

            for key in sorted(best_toxicity_score.keys()):
                logger.info("toxicity-%s = %s", key,
                            str(best_toxicity_score[key]))

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break

        epoch_num += 1
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
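
# --- Illustrative sketch (not part of the original code) ---------------------
# get_char_vocab() and char2ids(), used in the last train() above, are not
# defined in this snippet. A minimal implementation consistent with how they
# are called (a word in, a list of integer character ids out) could be:
import string

def get_char_vocab():
    # id 0 is reserved for unknown characters
    return {ch: i + 1 for i, ch in enumerate(string.ascii_lowercase)}

def char2ids(word, char_vocab):
    return [char_vocab.get(ch, 0) for ch in word.lower()]
# -----------------------------------------------------------------------------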