def evaluate(args,
             model,
             tokenizer,
             prefix="",
             output_predictions=False,
             sample_percentage=1.0):
    """Evaluate ``model`` on the split selected by ``args.evaluate_on``.

    For ``args.task_name == "mnli"`` both the matched and mis-matched sets are
    evaluated (the latter writing to ``args.output_dir + '-MM'``); any other
    task is evaluated once. Slice-aware model types
    ('bert-slice-aware', 'bert-slice-aware-random-slices') are scored through
    their Snorkel slice dataloader (``model.score`` / ``model.predict``); all
    other model types run a plain batched forward pass.

    Predictions are post-processed per ``args.output_mode``: "ranking" keeps the
    softmax probability of column 1, "classification" takes the argmax, and
    "regression" squeezes the array.

    Args:
        args: Run configuration namespace. Side effect: ``args.eval_batch_size``
            is (re)assigned here.
        model: The model to evaluate.
        tokenizer: Forwarded to ``load_and_cache_examples``.
        prefix: Label used in the log header and as a path component of the
            prediction output files.
        output_predictions: When True, write average precisions, predictions
            and captured CLS representations under
            ``os.path.join(eval_output_dir, prefix, args.run_id)``; for
            'bert-slice-aware' (outside debug mode) slice-membership scores
            are written there as well.
        sample_percentage: Fraction of evaluation *batches* to run on the
            non-slice-aware path (1.0 = whole set; at least one batch is always
            run). Ignored by the slice-aware path.

    Returns:
        dict: Metrics from ``compute_metrics``, merged over all evaluated tasks.
    """
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    if args.task_name == "mnli":
        eval_task_names = ("mnli", "mnli-mm")
        eval_outputs_dirs = (args.output_dir, args.output_dir + '-MM')
    else:
        eval_task_names = (args.task_name, )
        eval_outputs_dirs = (args.output_dir, )

    slice_aware_types = ('bert-slice-aware', 'bert-slice-aware-random-slices')

    results = {}
    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
        eval_dataset = load_and_cache_examples(args,
                                               eval_task,
                                               tokenizer,
                                               used_set=args.evaluate_on)

        if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(eval_output_dir)

        args.eval_batch_size = args.per_gpu_eval_batch_size * max(
            1, args.n_gpu)
        # Note that DistributedSampler samples randomly
        if args.local_rank == -1:
            eval_sampler = SequentialSampler(eval_dataset)
        else:
            eval_sampler = DistributedSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        # Eval!
        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  sample_percentage = %f", sample_percentage)
        logger.info("  Batch size = %d", args.eval_batch_size)
        preds = None
        out_label_ids = None
        slice_membership_scores = None

        # CLS representations collected by the forward hook below. A fresh list
        # per task keeps a second task (MNLI-mm) from inheriting the first
        # task's representations.
        representations = []

        def get_cls_rep(m, i, output):
            # Capped at 1000 captured batches to avoid memory issues.
            if len(representations) < 1000:
                # get only CLS token rep (index 0 along dim 1)
                representations.append(output[:, 0, :].cpu())

        # Handle of the registered forward hook. It is removed once prediction
        # is done so hooks do not pile up on the model across evaluate() calls
        # (which would also make the closure above see a stale list).
        hook_handle = None
        try:
            if args.model_type in slice_aware_types:
                sfs = slicing_functions[args.task_name]
                processor = slicing_processors[args.task_name]()
                if args.evaluate_on == 'dev' or args.task_name == 'antique':
                    examples_dev = processor.get_dev_examples(args.data_dir)
                else:
                    examples_dev = processor.get_test_examples(args.data_dir)

                snorkel_sf_applier = SFApplier(sfs)

                # NOTE(review): the cache is keyed only on the split name, not
                # on the slicing functions, so a file produced with different
                # slices would be reused as-is. The pickle is assumed to be a
                # trusted local cache.
                slices_cache_path = (args.data_dir + "/snorkel_slices_" +
                                     args.evaluate_on + ".pickle")
                if os.path.isfile(slices_cache_path):
                    with open(slices_cache_path, "rb") as f:
                        logger.info("loaded cached pickle for sliced " +
                                    args.evaluate_on)
                        snorkel_slices_dev = pickle.load(f)
                else:
                    snorkel_slices_dev = snorkel_sf_applier.apply(examples_dev)
                    with open(slices_cache_path, "wb") as f:
                        pickle.dump(snorkel_slices_dev, f)
                        logger.info("dumped pickle with sliced " +
                                    args.evaluate_on)

                # Repeat each example's slice-membership row once per entry in
                # example.documents — presumably to line up with the rows of
                # eval_dataset (confirm against load_and_cache_examples).
                snorkel_slices_with_ns = [
                    snorkel_slices_dev[idx]
                    for idx, example in enumerate(examples_dev)
                    for _ in range(len(example.documents))
                ]
                snorkel_slices_with_ns_np = np.array(
                    snorkel_slices_with_ns, dtype=snorkel_slices_dev.dtype)

                X_dict = {
                    'input_ids': eval_dataset.tensors[0],
                    'attention_mask': eval_dataset.tensors[1],
                    'token_type_ids': eval_dataset.tensors[2]
                }
                Y_dict = {'labels': eval_dataset.tensors[3]}

                ds = DictDataset(name='labels',
                                 split=args.evaluate_on,
                                 X_dict=X_dict,
                                 Y_dict=Y_dict)

                dev_dl_slice = model.make_slice_dataloader(
                    ds,
                    snorkel_slices_with_ns_np,
                    shuffle=False,
                    batch_size=args.eval_batch_size)

                if not args.debug_mode:
                    slice_membership_scores = model.score([dev_dl_slice])

                if output_predictions:
                    hook_handle = model.base_task.module_pool['base_architecture'].\
                        module.module.bert.encoder.layer[11].output.dense.\
                        register_forward_hook(get_cls_rep)

                pred_dict = model.predict(dev_dl_slice,
                                          debug_mode=args.debug_mode,
                                          return_preds=True)
                preds = pred_dict['probs']['labels']
                out_label_ids = pred_dict['golds']['labels']

            else:
                if output_predictions:
                    hook_handle = model.bert.encoder.layer[11].output.dense.\
                        register_forward_hook(get_cls_rep)

                model.eval()
                # sample_percentage is a fraction of the evaluation batches;
                # always run at least one batch.
                max_eval_steps = max(
                    1, int(sample_percentage * len(eval_dataloader)))

                # Collect per-batch arrays and concatenate once at the end
                # (repeated np.append would copy the accumulator every batch).
                batch_preds = []
                batch_label_ids = []
                for nb_eval_steps, batch in enumerate(eval_dataloader,
                                                      start=1):
                    batch = tuple(t.to(args.device) for t in batch)
                    inputs = {
                        'input_ids': batch[0],
                        'attention_mask': batch[1],
                        'labels': batch[3]
                    }
                    if args.model_type != 'distilbert':
                        # XLM, DistilBERT and RoBERTa don't use segment_ids
                        inputs['token_type_ids'] = (
                            batch[2]
                            if args.model_type in ['bert', 'xlnet'] else None)
                    if args.model_type == 'bert-mtl':
                        inputs["clf_head"] = 0

                    with torch.no_grad():
                        outputs = model(**inputs)
                    # outputs[0] is the loss (labels were passed); it is not
                    # reported, so only the logits are kept.
                    logits = outputs[1]

                    batch_preds.append(logits.detach().cpu().numpy())
                    batch_label_ids.append(
                        inputs['labels'].detach().cpu().numpy())

                    if nb_eval_steps >= max_eval_steps or args.debug_mode:
                        break

                if batch_preds:
                    preds = np.concatenate(batch_preds, axis=0)
                    out_label_ids = np.concatenate(batch_label_ids, axis=0)
        finally:
            if hook_handle is not None:
                hook_handle.remove()

        if args.output_mode == "ranking":
            preds = softmax(preds, axis=1)
            preds = preds[:, 1]
        elif args.output_mode == "classification":
            preds = np.argmax(preds, axis=1)
        elif args.output_mode == "regression":
            preds = np.squeeze(preds)
        result = compute_metrics(eval_task, preds, out_label_ids)
        results.update(result)

        if output_predictions:
            aps = compute_aps(preds, out_label_ids)
            output_aps_file = os.path.join(eval_output_dir, prefix,
                                           args.run_id + "/eval_aps.txt")
            # The per-run directory may not exist yet.
            os.makedirs(os.path.dirname(output_aps_file), exist_ok=True)
            with open(output_aps_file, "w") as f:
                f.writelines(str(ap) + "\n" for ap in aps)
            output_preds_file = os.path.join(
                eval_output_dir, prefix, args.run_id + "/eval_predictions.txt")
            with open(output_preds_file, "w") as writer:
                writer.writelines(str(pred) + "\n" for pred in preds)
            output_rep_file = os.path.join(
                eval_output_dir, prefix,
                args.run_id + "/eval_representations.pt")
            if representations:
                torch.save(torch.cat(representations, 0), output_rep_file)
            else:
                logger.warning("No CLS representations captured; not writing %s",
                               output_rep_file)
            if args.model_type == 'bert-slice-aware' and not args.debug_mode:
                output_slice_membership_f = os.path.join(
                    eval_output_dir, prefix,
                    args.run_id + "/slices_membership_scores.pickle")
                with open(output_slice_membership_f, "wb") as f:
                    pickle.dump(slice_membership_scores, f)

    return results
# Example #2
# 0
def create_dict_dataloader(X, Y, split, **kwargs):
    """Wrap bag-of-words features and labels in a ``DictDataLoader``.

    Args:
        X: Feature matrix; converted to a ``torch.FloatTensor``.
        Y: Labels; converted to a ``torch.LongTensor``.
        split: Split name handed to ``DictDataset.from_tensors``.
        **kwargs: Passed straight through to ``DictDataLoader``.

    Returns:
        DictDataLoader over the resulting dataset.
    """
    features = torch.FloatTensor(X)
    labels = torch.LongTensor(Y)
    dataset = DictDataset.from_tensors(features, labels, split)
    return DictDataLoader(dataset, **kwargs)


def train(args, train_dataset, model, tokenizer):
    """Train ``model`` and return the trained model.

    Slice-aware model types ('bert-slice-aware',
    'bert-slice-aware-random-slices') are wrapped in a Snorkel
    ``SliceAwareClassifier`` and fitted with Snorkel's ``Trainer`` (see
    ``_fit_slice_aware_model``); the returned object is then that classifier.
    Every other model type goes through a standard AdamW + linear-warmup loop
    with optional gradient accumulation, apex fp16, DataParallel and
    DistributedDataParallel.

    Side effects: sets ``args.train_batch_size`` (and ``args.num_train_epochs``
    when ``args.max_steps > 0``); on the main process logs to TensorBoard and
    ``ex``; in the standard loop, saves a checkpoint to
    ``args.output_dir/checkpoint-<global_step>`` every ``args.save_steps``
    optimizer steps.

    Args:
        args: Run configuration namespace.
        train_dataset: Dataset whose items are indexed below as
            (input_ids, attention_mask, token_type_ids, labels).
        model: Model to fine-tune.
        tokenizer: Forwarded to ``evaluate`` when evaluating during training.

    Returns:
        The trained model. On the standard path this is the (possibly
        DataParallel/DistributedDataParallel-wrapped) model.
    """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_dataset)
    else:
        train_sampler = DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay).
    # NOTE: the slice-aware path below uses Snorkel's own optimizer, so this
    # optimizer/scheduler is only stepped by the standard loop.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = range(int(args.num_train_epochs))
    # Added here for reproductibility (even between python 2 and 3)
    set_seed(args)

    if args.model_type in ('bert-slice-aware',
                           'bert-slice-aware-random-slices'):
        model = _fit_slice_aware_model(args, train_dataset, model)
    else:
        is_main_process = args.local_rank in [-1, 0]
        for _ in train_iterator:
            for step, batch in enumerate(train_dataloader):
                # evaluate() may switch the model to eval mode, so re-enable
                # training mode on every step.
                model.train()
                batch = tuple(t.to(args.device) for t in batch)
                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'labels': batch[3]
                }
                if args.model_type != 'distilbert':
                    # XLM, DistilBERT and RoBERTa don't use segment_ids
                    inputs['token_type_ids'] = (
                        batch[2]
                        if args.model_type in ['bert', 'xlnet'] else None)
                if args.model_type == 'bert-mtl':
                    inputs["clf_head"] = 0
                outputs = model(**inputs)
                # model outputs are always tuple in transformers (see doc)
                loss = outputs[0]

                if args.n_gpu > 1:
                    # mean() to average on multi-gpu parallel training
                    loss = loss.mean()
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                tr_loss += loss.item()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    # Clip once per optimizer step, on the fully accumulated
                    # gradient; clipping each partial sum would distort the
                    # update when gradient_accumulation_steps > 1.
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    model.zero_grad()
                    global_step += 1

                    if (is_main_process and args.logging_steps > 0
                            and global_step % args.logging_steps == 0):
                        # Only evaluate when single GPU otherwise metrics may
                        # not average well
                        if args.local_rank == -1 and args.evaluate_during_training:
                            results = evaluate(args,
                                               model,
                                               tokenizer,
                                               sample_percentage=0.01)
                            for key, value in results.items():
                                tb_writer.add_scalar('eval_{}'.format(key),
                                                     value, global_step)
                                ex.log_scalar('eval_{}'.format(key), value,
                                              global_step)
                                logger.info('eval_{}'.format(key) + ": " +
                                            str(value) + ", step: " +
                                            str(global_step))
                        current_lr = scheduler.get_lr()[0]
                        mean_loss = (tr_loss -
                                     logging_loss) / args.logging_steps
                        tb_writer.add_scalar('lr', current_lr, global_step)
                        tb_writer.add_scalar('loss', mean_loss, global_step)
                        ex.log_scalar("lr", current_lr, global_step)
                        ex.log_scalar("loss", mean_loss, global_step)
                        logging_loss = tr_loss

                    if (is_main_process and args.save_steps > 0
                            and global_step % args.save_steps == 0):
                        # Save model checkpoint
                        output_dir = os.path.join(
                            args.output_dir,
                            'checkpoint-{}'.format(global_step))
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        # Take care of distributed/parallel training
                        model_to_save = (model.module
                                         if hasattr(model, 'module') else model)
                        model_to_save.save_pretrained(output_dir)
                        torch.save(
                            args, os.path.join(output_dir,
                                               'training_args.bin'))
                        logger.info("Saving model checkpoint to %s",
                                    output_dir)

                if args.max_steps > 0 and global_step > args.max_steps:
                    break
            if args.max_steps > 0 and global_step > args.max_steps:
                break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return model


def _fit_slice_aware_model(args, train_dataset, model):
    """Wrap ``model`` in a Snorkel ``SliceAwareClassifier`` and fit it.

    Slice membership for the training examples is computed with ``SFApplier``
    (or loaded from ``<data_dir>/snorkel_slices_train.pickle`` when that file
    exists) and fitted with Snorkel's ``Trainer`` ("adamax" optimizer).

    Returns:
        The fitted ``SliceAwareClassifier``.
    """
    if args.model_type == 'bert-slice-aware':
        sfs = slicing_functions[args.task_name]
    elif args.model_type == 'bert-slice-aware-random-slices':
        if args.number_random_slices is None or args.size_random_slices is None:
            sfs = random_slicing_functions[args.task_name]
        else:
            sfs = args.sfs
    processor = slicing_processors[args.task_name]()
    examples_train = processor.get_train_examples(args.data_dir)

    snorkel_sf_applier = SFApplier(sfs)

    # NOTE(review): the cache is keyed only on the split name, not on the
    # slicing functions in use; a file produced with different slices would be
    # reused as-is. The pickle is assumed to be a trusted local cache.
    slices_cache_path = args.data_dir + "/snorkel_slices_train.pickle"
    if os.path.isfile(slices_cache_path):
        with open(slices_cache_path, "rb") as f:
            logger.info("loaded cached pickle for sliced train.")
            snorkel_slices_train = pickle.load(f)
    else:
        snorkel_slices_train = snorkel_sf_applier.apply(examples_train)
        with open(slices_cache_path, "wb") as f:
            pickle.dump(snorkel_slices_train, f)
            logger.info("dumped pickle with sliced train.")

    # Repeat each example's slice-membership row once per entry in
    # example.documents — presumably to line up with the rows of
    # train_dataset (confirm against the dataset builder).
    snorkel_slices_with_ns = [
        snorkel_slices_train[idx]
        for idx, example in enumerate(examples_train)
        for _ in range(len(example.documents))
    ]
    snorkel_slices_with_ns_np = np.array(snorkel_slices_with_ns,
                                         dtype=snorkel_slices_train.dtype)

    slice_model = SliceAwareClassifier(
        task_name='labels',
        input_data_key='input_ids',
        base_architecture=model,
        # presumably the base encoder's hidden size (768 for BERT-base) —
        # TODO confirm for other architectures
        head_dim=768,
        slice_names=[sf.name for sf in sfs])

    X_dict = {
        'input_ids': train_dataset.tensors[0],
        'attention_mask': train_dataset.tensors[1],
        'token_type_ids': train_dataset.tensors[2]
    }
    Y_dict = {'labels': train_dataset.tensors[3]}

    ds = DictDataset(name='labels',
                     split='train',
                     X_dict=X_dict,
                     Y_dict=Y_dict)
    train_dl_slice = slice_model.make_slice_dataloader(
        ds,
        snorkel_slices_with_ns_np,
        shuffle=True,
        batch_size=args.train_batch_size)

    trainer = Trainer(lr=args.learning_rate,
                      n_epochs=int(args.num_train_epochs),
                      l2=args.weight_decay,
                      optimizer="adamax",
                      max_steps=args.max_steps,
                      seed=args.seed)

    trainer.fit(slice_model, [train_dl_slice])
    return slice_model