Example #1
    def test_collator(self):
        template = '[T] [T] {arbitrary} [T] {fields} [P]'
        tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
        config = AutoConfig.from_pretrained('bert-base-cased')
        utils.add_task_specific_tokens(tokenizer)
        templatizer = utils.TriggerTemplatizer(
            template,
            config,
            tokenizer,
            add_special_tokens=False
        )
        collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)

        instances = [
            {'arbitrary': 'a', 'fields': 'the', 'label': 'hot'},
            {'arbitrary': 'a a', 'fields': 'the the', 'label': 'cold'}
        ]
        templatized_instances = [templatizer(x) for x in instances]
        loader = DataLoader(
            templatized_instances,
            batch_size=2,
            shuffle=False,
            collate_fn=collator
        )
        model_inputs, labels = next(iter(loader))

        # Check results match our expectations
        expected_labels = torch.tensor([
            tokenizer.encode('hot', add_special_tokens=False, add_prefix_space=True),
            tokenizer.encode('cold', add_special_tokens=False, add_prefix_space=True),
        ])
        assert torch.equal(expected_labels, labels)

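        # The first instance templatizes to 6 tokens ([T] [T] a [T] the [P]) and is
        # padded to the 8 tokens of the second ([T] [T] a a [T] the the [P]); the
        # masks below mark the trigger and predict positions in the padded batch.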
        expected_trigger_mask = torch.tensor([
            [True, True, False, True, False, False, False, False],
            [True, True, False, False, True, False, False, False],
        ])
        assert torch.equal(expected_trigger_mask, model_inputs['trigger_mask'])

        expected_predict_mask = torch.tensor([
            [False, False, False, False, False, True, False, False],
            [False, False, False, False, False, False, False, True],
        ])
        assert torch.equal(expected_predict_mask, model_inputs['predict_mask'])
def run_model(args):

    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logger.info('Loading model, tokenizer, etc.')
    config, model, tokenizer = load_pretrained(args.model_name)
    model.to(device)
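    # Hook the input embedding layer so its gradient can be read back after each
    # backward pass; the prediction wrapper is assumed to splice the trigger ids
    # into the trigger positions and return logits at the [P] positions.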
    embeddings = get_embeddings(model, config)
    embedding_gradient = GradientStorage(embeddings)
    predictor = PredictWrapper(model)

    if args.label_map is not None:
        label_map = json.loads(args.label_map)
        logger.info(f"Label map: {label_map}")
    else:
        label_map = None

    templatizer = utils.TriggerTemplatizer(
        args.template,
        config,
        tokenizer,
        label_map=label_map,
        label_field=args.label_field,
        tokenize_labels=args.tokenize_labels,
        add_special_tokens=False,
        use_ctx=args.use_ctx
    )

    # Obtain the initial trigger tokens and label mapping
    if args.initial_trigger:
        trigger_ids = tokenizer.convert_tokens_to_ids(args.initial_trigger)
        logger.debug(f'Initial trigger: {args.initial_trigger}')
        logger.debug(f'Trigger ids: {trigger_ids}')
        assert len(trigger_ids) == templatizer.num_trigger_tokens
    else:
        trigger_ids = [tokenizer.mask_token_id] * templatizer.num_trigger_tokens
    trigger_ids = torch.tensor(trigger_ids, device=device).unsqueeze(0)
    best_trigger_ids = trigger_ids.clone()

    # NOTE: Accuracy can only be computed if a fixed pool of labels is given, which currently
    # requires the label map to be specified. Since producing a label map may be cumbersome (e.g.,
    # for link prediction tasks), we just use (negative) loss as the evaluation metric in these cases.
    if label_map:
        evaluation_fn = AccuracyFn(tokenizer, label_map, device)
    else:
        evaluation_fn = lambda x, y: -get_loss(x, y)

    logger.info('Loading datasets')
    collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)

    if args.perturbed:
        train_dataset = utils.load_augmented_trigger_dataset(args.train, templatizer, limit=args.limit)
    else:
        train_dataset = utils.load_trigger_dataset(args.train, templatizer, use_ctx=args.use_ctx, limit=args.limit)
    train_loader = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)

    if args.perturbed:
        dev_dataset = utils.load_augmented_trigger_dataset(args.dev, templatizer)
    else:
        dev_dataset = utils.load_trigger_dataset(args.dev, templatizer, use_ctx=args.use_ctx)
    dev_loader = DataLoader(dev_dataset, batch_size=args.eval_size, shuffle=False, collate_fn=collator)

    # To "filter" unwanted trigger tokens, we subtract a huge number from their logits.
    filter = torch.zeros(tokenizer.vocab_size, dtype=torch.float32, device=device)
    if args.filter:
        logger.info('Filtering label tokens.')
        if label_map:
            for label_tokens in label_map.values():
                label_ids = utils.encode_label(tokenizer, label_tokens).unsqueeze(0)
                filter[label_ids] = -1e32
        else:
            for _, label_ids in train_dataset:
                filter[label_ids] = -1e32
        logger.info('Filtering special tokens and capitalized words.')
        for word, idx in tokenizer.get_vocab().items():
            if len(word) == 1 or idx >= tokenizer.vocab_size:
                continue
            # Filter special tokens.
            if idx in tokenizer.all_special_ids:
                logger.debug('Filtered: %s', word)
                filter[idx] = -1e32
            # Filter capitalized words (lazy way to remove proper nouns).
            if isupper(idx, tokenizer):
                logger.debug('Filtered: %s', word)
                filter[idx] = -1e32

    logger.info('Evaluating')
    numerator = 0
    denominator = 0
    for model_inputs, labels in tqdm(dev_loader):
        model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
        labels = labels.to(device)
        with torch.no_grad():
            predict_logits = predictor(model_inputs, trigger_ids)
        numerator += evaluation_fn(predict_logits, labels).sum().item()
        denominator += labels.size(0)
    dev_metric = numerator / (denominator + 1e-13)
    logger.info(f'Dev metric: {dev_metric}')

    best_dev_metric = -float('inf')
    # Measure elapsed time of trigger search
    start = time.time()

    for i in range(args.iters):

        logger.info(f'Iteration: {i}')

        logger.info('Accumulating Gradient')
        model.zero_grad()

        pbar = tqdm(range(args.accumulation_steps))
        train_iter = iter(train_loader)
        averaged_grad = None

        # Accumulate
        for step in pbar:

            # Shuttle inputs to GPU
            try:
                model_inputs, labels = next(train_iter)
            except StopIteration:
                logger.warning(
                    'Insufficient data for number of accumulation steps. '
                    'Effective batch size will be smaller than specified.'
                )
                break
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            labels = labels.to(device)
            predict_logits = predictor(model_inputs, trigger_ids)
            loss = get_loss(predict_logits, labels).mean()
            loss.backward()

            grad = embedding_gradient.get()
            bsz, _, emb_dim = grad.size()
            selection_mask = model_inputs['trigger_mask'].unsqueeze(-1)
            grad = torch.masked_select(grad, selection_mask)
            grad = grad.view(bsz, templatizer.num_trigger_tokens, emb_dim)

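            # Running average over accumulation steps of the gradient w.r.t. the
            # trigger-token embeddings (summed over the batch, divided by the
            # number of steps).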
            if averaged_grad is None:
                averaged_grad = grad.sum(dim=0) / args.accumulation_steps
            else:
                averaged_grad += grad.sum(dim=0) / args.accumulation_steps

        logger.info('Evaluating Candidates')
        pbar = tqdm(range(args.accumulation_steps))
        train_iter = iter(train_loader)

        token_to_flip = random.randrange(templatizer.num_trigger_tokens)
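        # HotFlip: rank vocabulary replacements for the chosen trigger position by a
        # first-order (gradient · embedding) approximation of the change in loss and
        # keep the top num_cand candidates.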
        candidates = hotflip_attack(averaged_grad[token_to_flip],
                                    embeddings.weight,
                                    increase_loss=False,
                                    num_candidates=args.num_cand,
                                    filter=filter)

        current_score = 0
        candidate_scores = torch.zeros(args.num_cand, device=device)
        denom = 0
        for step in pbar:

            try:
                model_inputs, labels = next(train_iter)
            except StopIteration:
                logger.warning(
                    'Insufficient data for number of accumulation steps. '
                    'Effective batch size will be smaller than specified.'
                )
                break
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            labels = labels.to(device)
            with torch.no_grad():
                predict_logits = predictor(model_inputs, trigger_ids)
                eval_metric = evaluation_fn(predict_logits, labels)

            # Update current score
            current_score += eval_metric.sum()
            denom += labels.size(0)

            # NOTE: Instead of iterating over tokens to flip we randomly change just one each
            # time so the gradients don't get stale.
            for i, candidate in enumerate(candidates):

                # if candidate.item() in filter_candidates:
                #     candidate_scores[i] = -1e32
                #     continue

                temp_trigger = trigger_ids.clone()
                temp_trigger[:, token_to_flip] = candidate
                with torch.no_grad():
                    predict_logits = predictor(model_inputs, temp_trigger)
                    eval_metric = evaluation_fn(predict_logits, labels)

                candidate_scores[i] += eval_metric.sum()

        # TODO: Something cleaner. LAMA templates can't have mask tokens, so if
        # there are still mask tokens in the trigger then set the current score
        # to -inf.
        if args.print_lama:
            if trigger_ids.eq(tokenizer.mask_token_id).any():
                current_score = float('-inf')

        if (candidate_scores > current_score).any():
            logger.info('Better trigger detected.')
            best_candidate_score = candidate_scores.max()
            best_candidate_idx = candidate_scores.argmax()
            trigger_ids[:, token_to_flip] = candidates[best_candidate_idx]
            logger.info(f'Train metric: {best_candidate_score / (denom + 1e-13): 0.4f}')
        else:
            logger.info('No improvement detected. Skipping evaluation.')
            continue

        logger.info('Evaluating')
        numerator = 0
        denominator = 0
        for model_inputs, labels in tqdm(dev_loader):
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            labels = labels.to(device)
            with torch.no_grad():
                predict_logits = predictor(model_inputs, trigger_ids)
            numerator += evaluation_fn(predict_logits, labels).sum().item()
            denominator += labels.size(0)
        dev_metric = numerator / (denominator + 1e-13)

        logger.info(f'Trigger tokens: {tokenizer.convert_ids_to_tokens(trigger_ids.squeeze(0))}')
        logger.info(f'Dev metric: {dev_metric}')

        # TODO: Something cleaner. LAMA templates can't have mask tokens, so if
        # there are still mask tokens in the trigger then set the current score
        # to -inf.
        if args.print_lama:
            if best_trigger_ids.eq(tokenizer.mask_token_id).any():
                best_dev_metric = float('-inf')

        if dev_metric > best_dev_metric:
            logger.info('Best performance so far')
            best_trigger_ids = trigger_ids.clone()
            best_dev_metric = dev_metric

    best_trigger_tokens = tokenizer.convert_ids_to_tokens(best_trigger_ids.squeeze(0))
    logger.info(f'Best tokens: {best_trigger_tokens}')
    logger.info(f'Best dev metric: {best_dev_metric}')
    if args.print_lama:
        # Templatize with [X] and [Y]
        if args.use_ctx:
            model_inputs, label_ids = templatizer({
                'sub_label': '[X]',
                'obj_label': tokenizer.lama_y,
                'context': ''
            })
        else:
            model_inputs, label_ids = templatizer({
                'sub_label': '[X]',
                'obj_label': tokenizer.lama_y,
            })
        lama_template = model_inputs['input_ids']
        # Instantiate trigger tokens
        lama_template.masked_scatter_(
            mask=model_inputs['trigger_mask'],
            source=best_trigger_ids.cpu())
        # Instantiate label token
        lama_template.masked_scatter_(
            mask=model_inputs['predict_mask'],
            source=label_ids)
        # Print LAMA JSON template
        relation = args.train.parent.stem

        # The decoding below is a bit hacky, but it produces the desired template string.
        if args.use_ctx:
            template = tokenizer.decode(lama_template.squeeze(0)[1:-1]).replace('[SEP] ', '').replace('</s> ', '').replace('[ X ]', '[X]')
        else:
            template = tokenizer.decode(lama_template.squeeze(0)[1:-1]).replace('[ X ]', '[X]')

        out = {
            'relation': relation,
            'template': template
        }
        print(json.dumps(out))
Example #3
def run_autoprompt(args, dataset, cache_test):
    if cache_test.is_test:
        raise CacheMiss()

    ct.set_seed(args.seed)
    global_data = GlobalData.from_pretrained(args.model_name)

    templatizer = utils.TriggerTemplatizer(
        args.template,
        global_data.config,
        global_data.tokenizer,
        label_field=args.label_field,
        label_map=dataset.label_map,
        tokenize_labels=args.tokenize_labels,
        add_special_tokens=True,
    )
    evaluation_fn = ct.AccuracyFn(global_data.tokenizer,
                                  dataset.label_map,
                                  global_data.device,
                                  tokenize_labels=args.tokenize_labels)

    # Do not allow for initial trigger specification.
    trigger_ids = [global_data.tokenizer.mask_token_id
                   ] * templatizer.num_trigger_tokens
    trigger_ids = torch.tensor(trigger_ids,
                               device=global_data.device).unsqueeze(0)
    best_trigger_ids = trigger_ids.clone()

    # Load datasets
    logger.info('Loading datasets')
    collator = utils.Collator(pad_token_id=global_data.tokenizer.pad_token_id)
    try:
        train_dataset = load_trigger_dataset(dataset.train, templatizer)
    except KeyError as e:
        raise RuntimeError(
            'A field in your template is not present in the uploaded dataset. '
            f'Check that there is a column with the name: {e}')

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.bsz,
                                               shuffle=True,
                                               collate_fn=collator)

    progress = st.progress(0.0)
    trigger_placeholder = st.empty()
    best_dev_metric = -float('inf')
    for i in range(args.iters):
        logger.info(f'Iteration: {i}')
        progress.progress(float(i) / args.iters)

        current_trigger = ','.join(
            global_data.tokenizer.convert_ids_to_tokens(
                best_trigger_ids.squeeze(0)))
        trigger_placeholder.markdown(f'**Current trigger**: {current_trigger}')

        global_data.model.zero_grad()
        train_iter = iter(train_loader)
        averaged_grad = None

        # Compute gradient of loss
        for step in range(args.accumulation_steps):
            try:
                model_inputs, labels = next(train_iter)
            except StopIteration:
                logger.warning(
                    'Insufficient data for number of accumulation steps. '
                    'Effective batch size will be smaller than specified.')
                break
            model_inputs = {
                k: v.to(global_data.device)
                for k, v in model_inputs.items()
            }
            labels = labels.to(global_data.device)
            predict_logits = global_data.predictor(model_inputs, trigger_ids)
            loss = ct.get_loss(predict_logits, labels).mean()
            loss.backward()

            grad = global_data.embedding_gradient.get()
            bsz, _, emb_dim = grad.size()
            selection_mask = model_inputs['trigger_mask'].unsqueeze(-1)
            grad = torch.masked_select(grad, selection_mask)
            grad = grad.view(bsz, templatizer.num_trigger_tokens, emb_dim)

            if averaged_grad is None:
                averaged_grad = grad.sum(dim=0) / args.accumulation_steps
            else:
                averaged_grad += grad.sum(dim=0) / args.accumulation_steps

        logger.info('Evaluating Candidates')
        pbar = tqdm(range(args.accumulation_steps))
        train_iter = iter(train_loader)

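        # One trigger position is updated per iteration, cycling through the
        # positions in order; HotFlip proposes candidate replacements from the
        # accumulated gradient.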
        token_to_flip = i % templatizer.num_trigger_tokens
        candidates = ct.hotflip_attack(averaged_grad[token_to_flip],
                                       global_data.embeddings.weight,
                                       increase_loss=False,
                                       num_candidates=args.num_cand)
        current_score = 0
        candidate_scores = torch.zeros(args.num_cand,
                                       device=global_data.device)
        denom = 0
        for step in pbar:
            try:
                model_inputs, labels = next(train_iter)
            except StopIteration:
                logger.warning(
                    'Insufficient data for number of accumulation steps. '
                    'Effective batch size will be smaller than specified.')
                break
            model_inputs = {
                k: v.to(global_data.device)
                for k, v in model_inputs.items()
            }
            labels = labels.to(global_data.device)
            with torch.no_grad():
                predict_logits = global_data.predictor(model_inputs,
                                                       trigger_ids)
                eval_metric = evaluation_fn(predict_logits, labels)

            # Update current score
            current_score += eval_metric.sum()
            denom += labels.size(0)

            # NOTE: Only one trigger token (selected above) is changed per
            # iteration so the gradients don't get stale.
            for i, candidate in enumerate(candidates):

                # if candidate.item() in filter_candidates:
                #     candidate_scores[i] = -1e32
                #     continue

                temp_trigger = trigger_ids.clone()
                temp_trigger[:, token_to_flip] = candidate
                with torch.no_grad():
                    predict_logits = global_data.predictor(
                        model_inputs, temp_trigger)
                    eval_metric = evaluation_fn(predict_logits, labels)

                candidate_scores[i] += eval_metric.sum()

        if (candidate_scores >= current_score).any():
            logger.info('Better trigger detected.')
            best_candidate_score = candidate_scores.max()
            best_candidate_idx = candidate_scores.argmax()
            trigger_ids[:, token_to_flip] = candidates[best_candidate_idx]
            logger.info(
                f'Train metric: {best_candidate_score / (denom + 1e-13): 0.4f}'
            )

        # Skip dev evaluation; the latest trigger is always kept as the best.
        best_trigger_ids = trigger_ids.clone()

    progress.progress(1.0)
    current_trigger = ','.join(
        global_data.tokenizer.convert_ids_to_tokens(
            best_trigger_ids.squeeze(0)))
    trigger_placeholder.markdown(f'**Current trigger**: {current_trigger}')

    best_trigger_tokens = global_data.tokenizer.convert_ids_to_tokens(
        best_trigger_ids.squeeze(0))

    train_output = predict_test(map(lambda x: x['sentence'], dataset.train),
                                dataset.label_map, templatizer,
                                best_trigger_ids, global_data.tokenizer,
                                global_data.predictor, args)

    # Streamlit does not like accessing widgets across functions, which is
    # problematic for this "live updating" widget which we want to still
    # display even if the train output is cached. To get around this, we're
    # going to delete the widget and replace it with a very similar looking
    # widget outside the function...no one will ever notice ;)
    trigger_placeholder.empty()

    return (best_trigger_tokens, current_score / denom, dataset.label_map,
            templatizer, best_trigger_ids, global_data.tokenizer,
            global_data.predictor, args, train_output)
Example #4
def main(args):
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config = AutoConfig.from_pretrained(args.model_name,
                                        num_labels=args.num_labels)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSequenceClassification.from_pretrained(args.model_name,
                                                               config=config)
    model.to(device)

    collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)
    train_dataset, label_map = utils.load_classification_dataset(
        args.train,
        tokenizer,
        args.field_a,
        args.field_b,
        args.label_field,
        limit=args.limit)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.bsz,
                              shuffle=True,
                              collate_fn=collator)
    dev_dataset, _ = utils.load_classification_dataset(args.dev, tokenizer,
                                                       args.field_a,
                                                       args.field_b,
                                                       args.label_field,
                                                       label_map)
    dev_loader = DataLoader(dev_dataset,
                            batch_size=args.bsz,
                            shuffle=True,
                            collate_fn=collator)
    test_dataset, _ = utils.load_classification_dataset(
        args.test, tokenizer, args.field_a, args.field_b, args.label_field,
        label_map)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.bsz,
                             shuffle=True,
                             collate_fn=collator)
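    # Only the classification head is trained; the pretrained encoder parameters
    # are not passed to the optimizer and therefore stay fixed.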
    optimizer = torch.optim.Adam(model.classifier.parameters(),
                                 lr=args.lr,
                                 weight_decay=1e-6)

    # if not args.ckpt_dir.exists():
    #     logger.info(f'Making checkpoint directory: {args.ckpt_dir}')
    #     args.ckpt_dir.mkdir(parents=True)
    # elif not args.force_overwrite:
    #     raise RuntimeError('Checkpoint directory already exists.')

    best_accuracy = 0
    for epoch in range(args.epochs):
        logger.info('Training...')
        model.train()
        avg_loss = utils.ExponentialMovingAverage()
        pbar = tqdm(train_loader)
        for model_inputs, labels in pbar:
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            labels = labels.to(device)
            optimizer.zero_grad()
            logits, *_ = model(**model_inputs)
            loss = F.cross_entropy(logits, labels.squeeze(-1))
            loss.backward()
            optimizer.step()
            avg_loss.update(loss.item())
            pbar.set_description(f'loss: {avg_loss.get_metric(): 0.4f}')

        logger.info('Evaluating...')
        model.eval()
        correct = 0
        total = 0
        for model_inputs, labels in dev_loader:
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            labels = labels.to(device)
            logits, *_ = model(**model_inputs)
            _, preds = logits.max(dim=-1)
            correct += (preds == labels.squeeze(-1)).sum().item()
            total += labels.size(0)
        accuracy = correct / (total + 1e-13)
        logger.info(f'Accuracy: {accuracy : 0.4f}')

        if accuracy > best_accuracy:
            logger.info('Best performance so far.')
            # torch.save(model.state_dict(), args.ckpt_dir / WEIGHTS_NAME)
            # model.config.to_json_file(args.ckpt_dir / CONFIG_NAME)
            # tokenizer.save_pretrained(args.ckpt_dir)
            best_accuracy = accuracy

    logger.info('Testing...')
    model.eval()
    correct = 0
    total = 0
    for model_inputs, labels in test_loader:
        model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
        labels = labels.to(device)
        logits, *_ = model(**model_inputs)
        _, preds = logits.max(dim=-1)
        correct += (preds == labels.squeeze(-1)).sum().item()
        total += labels.size(0)
    accuracy = correct / (total + 1e-13)
    logger.info(f'Accuracy: {accuracy : 0.4f}')
def main(args):
    logger.info("Dataset: %s" % str(args.train).split("/")[3])
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config = AutoConfig.from_pretrained(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelWithLMHead.from_pretrained(args.model_name, config=config)
    if args.model_name == "bert-base-cased":
        model.embeds = model.bert.embeddings.word_embeddings
        eos_idx = 102
        if not args.finetune:
            for param in model.bert.parameters():
                param.requires_grad = False
    elif args.model_name == "roberta-base":
        model.embeds = model.roberta.embeddings.word_embeddings
        eos_idx = tokenizer.eos_token_id
        if not args.finetune:
            for param in model.roberta.parameters():
                param.requires_grad = False
    if not args.finetune:
        for param in model.parameters():
            param.requires_grad = False
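    # Continuous trigger: a learnable (trigger_length x embedding_dim) matrix of
    # embeddings optimized directly in embedding space instead of discrete tokens.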
    model.relation_embeds = torch.nn.Parameter(
        torch.rand(args.trigger_length,
                   model.embeds.weight.shape[1],
                   requires_grad=True))
    model.to(device)

    collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)
    train_dataset = utils.load_continuous_trigger_dataset(args.train,
                                                          tokenizer,
                                                          args.field_a,
                                                          args.field_b,
                                                          args.label_field,
                                                          limit=args.limit)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.bsz,
                              shuffle=True,
                              collate_fn=collator)
    dev_dataset = utils.load_continuous_trigger_dataset(
        args.dev, tokenizer, args.field_a, args.field_b, args.label_field)
    dev_loader = DataLoader(dev_dataset,
                            batch_size=args.bsz,
                            shuffle=True,
                            collate_fn=collator)
    test_dataset = utils.load_continuous_trigger_dataset(
        args.test, tokenizer, args.field_a, args.field_b, args.label_field)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.bsz,
                             shuffle=True,
                             collate_fn=collator)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=1e-6)

    best_accuracy = 0
    for epoch in range(args.epochs):
        logger.info('Training...')
        model.train()
        avg_loss = utils.ExponentialMovingAverage()
        pbar = tqdm(train_loader)
        for model_inputs, labels in pbar:
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
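            # The prediction position is assumed to sit at the [SEP]/EOS token,
            # shifted right by the number of trigger embeddings inserted before it.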
            mask_token_idxs = (model_inputs["input_ids"]
                               == eos_idx).nonzero()[:,
                                                     1] + args.trigger_length
            model_inputs = generate_inputs_embeds(model_inputs, model,
                                                  tokenizer, eos_idx)
            labels = labels.to(device)[:, 1]
            optimizer.zero_grad()
            logits, *_ = model(**model_inputs)
            mask_logits = logits[
                torch.arange(0, logits.shape[0], dtype=torch.long),
                mask_token_idxs]
            loss = F.cross_entropy(mask_logits, labels)
            loss.backward()
            optimizer.step()
            avg_loss.update(loss.item())
            pbar.set_description(f'loss: {avg_loss.get_metric(): 0.4f}')

        logger.info('Evaluating...')
        model.eval()
        correct = 0
        total = 0
        for model_inputs, labels in dev_loader:
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            mask_token_idxs = (model_inputs["input_ids"]
                               == eos_idx).nonzero()[:,
                                                     1] + args.trigger_length
            model_inputs = generate_inputs_embeds(model_inputs, model,
                                                  tokenizer, eos_idx)
            labels = labels.to(device)[:, 1]
            logits, *_ = model(**model_inputs)
            mask_logits = logits[
                torch.arange(0, logits.shape[0], dtype=torch.long),
                mask_token_idxs]
            preds = torch.topk(mask_logits, 1, dim=1).indices[:, 0]
            correct += (preds == labels).sum().item()
            total += labels.size(0)
        accuracy = correct / (total + 1e-13)
        logger.info(f'Accuracy: {accuracy : 0.4f}')

        if accuracy > best_accuracy:
            logger.info('Best performance so far.')
            # torch.save(model.state_dict(), args.ckpt_dir / WEIGHTS_NAME)
            # model.config.to_json_file(args.ckpt_dir / CONFIG_NAME)
            # tokenizer.save_pretrained(args.ckpt_dir)
            best_accuracy = accuracy

    logger.info('Testing...')
    model.eval()
    correct = 0
    total = 0
    # TODO: testing currently uses the final-epoch model, not the best validation model
    for model_inputs, labels in test_loader:
        model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
        mask_token_idxs = (model_inputs["input_ids"]
                           == eos_idx).nonzero()[:, 1] + args.trigger_length
        model_inputs = generate_inputs_embeds(model_inputs, model, tokenizer,
                                              eos_idx)
        labels = labels.to(device)[:, 1]
        logits, *_ = model(**model_inputs)
        mask_logits = logits[
            torch.arange(0, logits.shape[0], dtype=torch.long),
            mask_token_idxs]
        preds = torch.topk(mask_logits, 1, dim=1).indices[:, 0]
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    accuracy = correct / (total + 1e-13)
    logger.info(f'Accuracy: {accuracy : 0.4f}')
Example #6
def run_autoprompt(args, dataset):
    ct.set_seed(args.seed)
    global_data = GlobalData.from_pretrained(args.model_name)

    templatizer = utils.TriggerTemplatizer(
        args.template,
        global_data.config,
        global_data.tokenizer,
        label_field=args.label_field,
        label_map=dataset.label_map,
        tokenize_labels=args.tokenize_labels,
        add_special_tokens=False,
    )
    evaluation_fn = ct.AccuracyFn(global_data.tokenizer, dataset.label_map,
                                  global_data.device)

    # Do not allow for initial trigger specification.
    trigger_ids = [global_data.tokenizer.mask_token_id
                   ] * templatizer.num_trigger_tokens
    trigger_ids = torch.tensor(trigger_ids,
                               device=global_data.device).unsqueeze(0)
    best_trigger_ids = trigger_ids.clone()

    # Load datasets
    logger.info('Loading datasets')
    collator = utils.Collator(pad_token_id=global_data.tokenizer.pad_token_id)
    train_dataset = load_trigger_dataset(dataset.train, templatizer)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.bsz,
                                               shuffle=True,
                                               collate_fn=collator)
    dev_dataset = load_trigger_dataset(dataset.dev, templatizer)
    dev_loader = torch.utils.data.DataLoader(dev_dataset,
                                             batch_size=args.eval_size,
                                             shuffle=False,
                                             collate_fn=collator)

    progress = st.progress(0.0)
    trigger_placeholder = st.empty()
    best_dev_metric = -float('inf')
    for i in range(args.iters):
        logger.info(f'Iteration: {i}')
        progress.progress(float(i) / args.iters)

        current_trigger = ','.join(
            global_data.tokenizer.convert_ids_to_tokens(
                best_trigger_ids.squeeze(0)))
        trigger_placeholder.markdown(f'**Current trigger**: {current_trigger}')

        global_data.model.zero_grad()
        train_iter = iter(train_loader)
        averaged_grad = None

        # Compute gradient of loss
        for step in range(args.accumulation_steps):
            try:
                model_inputs, labels = next(train_iter)
            except StopIteration:
                logger.warning(
                    'Insufficient data for number of accumulation steps. '
                    'Effective batch size will be smaller than specified.')
                break
            model_inputs = {
                k: v.to(global_data.device)
                for k, v in model_inputs.items()
            }
            labels = labels.to(global_data.device)
            predict_logits = global_data.predictor(model_inputs, trigger_ids)
            loss = ct.get_loss(predict_logits, labels).mean()
            loss.backward()

            grad = global_data.embedding_gradient.get()
            bsz, _, emb_dim = grad.size()
            selection_mask = model_inputs['trigger_mask'].unsqueeze(-1)
            grad = torch.masked_select(grad, selection_mask)
            grad = grad.view(bsz, templatizer.num_trigger_tokens, emb_dim)

            if averaged_grad is None:
                averaged_grad = grad.sum(dim=0) / args.accumulation_steps
            else:
                averaged_grad += grad.sum(dim=0) / args.accumulation_steps

        logger.info('Evaluating Candidates')
        pbar = tqdm(range(args.accumulation_steps))
        train_iter = iter(train_loader)

        token_to_flip = i % templatizer.num_trigger_tokens
        candidates = ct.hotflip_attack(averaged_grad[token_to_flip],
                                       global_data.embeddings.weight,
                                       increase_loss=False,
                                       num_candidates=args.num_cand)
        current_score = 0
        candidate_scores = torch.zeros(args.num_cand,
                                       device=global_data.device)
        denom = 0
        for step in pbar:
            try:
                model_inputs, labels = next(train_iter)
            except StopIteration:
                logger.warning(
                    'Insufficient data for number of accumulation steps. '
                    'Effective batch size will be smaller than specified.')
                break
            model_inputs = {
                k: v.to(global_data.device)
                for k, v in model_inputs.items()
            }
            labels = labels.to(global_data.device)
            with torch.no_grad():
                predict_logits = global_data.predictor(model_inputs,
                                                       trigger_ids)
                eval_metric = evaluation_fn(predict_logits, labels)

            # Update current score
            current_score += eval_metric.sum()
            denom += labels.size(0)

            # NOTE: Only one trigger token (selected above) is changed per
            # iteration so the gradients don't get stale.
            for i, candidate in enumerate(candidates):

                # if candidate.item() in filter_candidates:
                #     candidate_scores[i] = -1e32
                #     continue

                temp_trigger = trigger_ids.clone()
                temp_trigger[:, token_to_flip] = candidate
                with torch.no_grad():
                    predict_logits = global_data.predictor(
                        model_inputs, temp_trigger)
                    eval_metric = evaluation_fn(predict_logits, labels)

                candidate_scores[i] += eval_metric.sum()

        if (candidate_scores > current_score).any():
            logger.info('Better trigger detected.')
            best_candidate_score = candidate_scores.max()
            best_candidate_idx = candidate_scores.argmax()
            trigger_ids[:, token_to_flip] = candidates[best_candidate_idx]
            logger.info(
                f'Train metric: {best_candidate_score / (denom + 1e-13): 0.4f}'
            )

        logger.info('Evaluating')
        numerator = 0
        denominator = 0
        for model_inputs, labels in tqdm(dev_loader):
            model_inputs = {
                k: v.to(global_data.device)
                for k, v in model_inputs.items()
            }
            labels = labels.to(global_data.device)
            with torch.no_grad():
                predict_logits = global_data.predictor(model_inputs,
                                                       trigger_ids)
            numerator += evaluation_fn(predict_logits, labels).sum().item()
            denominator += labels.size(0)
        dev_metric = numerator / (denominator + 1e-13)

        if dev_metric > best_dev_metric:
            logger.info('Best performance so far')
            best_trigger_ids = trigger_ids.clone()
            best_dev_metric = dev_metric

    progress.progress(1.0)
    current_trigger = ','.join(
        global_data.tokenizer.convert_ids_to_tokens(
            best_trigger_ids.squeeze(0)))
    trigger_placeholder.markdown(f'**Current trigger**: {current_trigger}')

    best_trigger_tokens = global_data.tokenizer.convert_ids_to_tokens(
        best_trigger_ids.squeeze(0))
    dev_output = predict_test(map(lambda x: x['sentence'],
                                  dataset.dev), dataset.label_map, templatizer,
                              best_trigger_ids, global_data.tokenizer,
                              global_data.predictor, args)
    st.dataframe(pd.DataFrame(dev_output).style.highlight_min(axis=1))
    return best_trigger_tokens, best_dev_metric, dataset.label_map, templatizer, best_trigger_ids, global_data.tokenizer, global_data.predictor, args
Example #7
def main(args):
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config = AutoConfig.from_pretrained(args.model_name, num_labels=args.num_labels)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSequenceClassification.from_pretrained(args.model_name, config=config)
    model.to(device)

    collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)
    train_dataset, label_map = utils.load_classification_dataset(
        args.train,
        tokenizer,
        args.field_a,
        args.field_b,
        args.label_field,
        limit=args.limit
    )
    train_loader = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
    dev_dataset, _ = utils.load_classification_dataset(
        args.dev,
        tokenizer,
        args.field_a,
        args.field_b,
        args.label_field,
        label_map
    )
    dev_loader = DataLoader(dev_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)
    test_dataset, _ = utils.load_classification_dataset(
        args.test,
        tokenizer,
        args.field_a,
        args.field_b,
        args.label_field,
        label_map
    )
    test_loader = DataLoader(test_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)

    if args.bias_correction:
        betas = (0.9, 0.999)
    else:
        betas = (0.0, 0.000)
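    # Setting both betas to zero disables Adam's exponential moving averages, so
    # each update uses only the current gradient and bias correction is a no-op.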

    optimizer = AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=1e-2,
        betas=betas
    )

    # Use suggested learning rate scheduler
    num_training_steps = len(train_dataset) * args.epochs // args.bsz
    num_warmup_steps = num_training_steps // 10
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps,
                                                num_training_steps)

    if not args.ckpt_dir.exists():
        logger.info(f'Making checkpoint directory: {args.ckpt_dir}')
        args.ckpt_dir.mkdir(parents=True)
    elif not args.force_overwrite:
        raise RuntimeError('Checkpoint directory already exists.')

    try:
        best_accuracy = 0
        for epoch in range(args.epochs):
            logger.info('Training...')
            model.train()
            avg_loss = utils.ExponentialMovingAverage()
            pbar = tqdm(train_loader)
            for model_inputs, labels in pbar:
                model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
                labels = labels.to(device)
                optimizer.zero_grad()
                logits, *_ = model(**model_inputs)
                loss = F.cross_entropy(logits, labels.squeeze(-1))
                loss.backward()
                optimizer.step()
                scheduler.step()
                avg_loss.update(loss.item())
                pbar.set_description(f'loss: {avg_loss.get_metric(): 0.4f}, '
                                     f'lr: {optimizer.param_groups[0]["lr"]: .3e}')

            logger.info('Evaluating...')
            model.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                for model_inputs, labels in dev_loader:
                    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
                    labels = labels.to(device)
                    logits, *_ = model(**model_inputs)
                    _, preds = logits.max(dim=-1)
                    correct += (preds == labels.squeeze(-1)).sum().item()
                    total += labels.size(0)
                accuracy = correct / (total + 1e-13)
            logger.info(f'Accuracy: {accuracy : 0.4f}')

            if accuracy > best_accuracy:
                logger.info('Best performance so far.')
                model.save_pretrained(args.ckpt_dir)
                tokenizer.save_pretrained(args.ckpt_dir)
                best_accuracy = accuracy
    except KeyboardInterrupt:
        logger.info('Interrupted...')

    logger.info('Testing...')
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for model_inputs, labels in test_loader:
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            labels = labels.to(device)
            logits, *_ = model(**model_inputs)
            _, preds = logits.max(dim=-1)
            correct += (preds == labels.squeeze(-1)).sum().item()
            total += labels.size(0)
        accuracy = correct / (total + 1e-13)
    logger.info(f'Accuracy: {accuracy : 0.4f}')
Example #8
def main(args):
    ct.set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logger.info('Loading model, tokenizer, etc.')
    config, model, tokenizer = load_pretrained(args.model_name)
    model.to(device)
    final_embeddings = get_final_embeddings(model)
    embedding_storage = utils.OutputStorage(final_embeddings)
    word_embeddings = get_word_embeddings(model)

    label_map = json.loads(args.label_map)
    reverse_label_map = {y: x for x, y in label_map.items()}
    templatizer = utils.TriggerTemplatizer(
        args.template,
        tokenizer,
        label_map=label_map,
        label_field=args.label_field,
        add_special_tokens=False
    )

    # The weights of this projection will help identify the best label words.
    projection = torch.nn.Linear(config.hidden_size, len(label_map))
    projection.to(device)

    # Obtain the initial trigger tokens and label mapping
    if args.initial_trigger:
        trigger_ids = tokenizer.encode(
            args.initial_trigger,
            add_special_tokens=False,
            add_prefix_space=True
        )
        assert len(trigger_ids) == templatizer.num_trigger_tokens
    else:
        trigger_ids = [tokenizer.mask_token_id] * templatizer.num_trigger_tokens
    trigger_ids = torch.tensor(trigger_ids, device=device).unsqueeze(0)

    logger.info('Loading datasets')
    collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)
    train_dataset = utils.load_trigger_dataset(args.train, templatizer)
    train_loader = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)

    optimizer = torch.optim.Adam(projection.parameters(), lr=args.lr)

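    # Rank the vocabulary for each class: each row of the projection acts as a class
    # prototype, and its inner product with the word embeddings scores how strongly
    # each word indicates that class; the top-k words per class are reported.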
    scores = torch.matmul(projection.weight, word_embeddings.transpose(0, 1))
    scores = F.softmax(scores, dim=0)
    for i, row in enumerate(scores):
        _, top = row.topk(args.k)
        decoded = tokenizer.convert_ids_to_tokens(top)
        logger.info(f"Top k for class {reverse_label_map[i]}: {', '.join(decoded)}")

    logger.info('Training')
    for i in range(args.iters):
        pbar = tqdm(train_loader)
        for model_inputs, labels in pbar:
            optimizer.zero_grad()
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            labels = labels.to(device)
            trigger_mask = model_inputs.pop('trigger_mask')
            predict_mask = model_inputs.pop('predict_mask')
            model_inputs = ct.replace_trigger_tokens(model_inputs, trigger_ids, trigger_mask)
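            # The transformer runs without gradients; only the linear projection on
            # top of the final-layer embedding at the [P] position is trained.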
            with torch.no_grad():
                model(**model_inputs)
            embeddings = embedding_storage.get()
            predict_embeddings = embeddings.masked_select(predict_mask.unsqueeze(-1)).view(embeddings.size(0), -1)
            logits = projection(predict_embeddings)
            loss = F.cross_entropy(logits, labels.squeeze(-1))
            loss.backward()
            optimizer.step()
            pbar.set_description(f'loss: {loss : 0.4f}')

        scores = torch.matmul(projection.weight, word_embeddings.transpose(0, 1))
        scores = F.softmax(scores, dim=0)
        for i, row in enumerate(scores):
            _, top = row.topk(args.k)
            decoded = tokenizer.convert_ids_to_tokens(top)
            logger.info(f"Top k for class {reverse_label_map[i]}: {', '.join(decoded)}")