def main(args):
    logger.info("Dataset: %s" % str(args.train).split("/")[3])
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config = AutoConfig.from_pretrained(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    # Note: AutoModelWithLMHead is deprecated in recent transformers releases;
    # AutoModelForMaskedLM is the modern equivalent for these encoders.
    model = AutoModelWithLMHead.from_pretrained(args.model_name, config=config)
    if args.model_name == "bert-base-cased":
        model.embeds = model.bert.embeddings.word_embeddings
        # BERT has no <eos> token, so [SEP] (id 102) is used in its place.
        eos_idx = tokenizer.sep_token_id
        if not args.finetune:
            for param in model.bert.parameters():
                param.requires_grad = False
    elif args.model_name == "roberta-base":
        model.embeds = model.roberta.embeddings.word_embeddings
        eos_idx = tokenizer.eos_token_id
        if not args.finetune:
            for param in model.roberta.parameters():
                param.requires_grad = False
    if not args.finetune:
        # Freeze everything that remains (including the LM head); the trigger
        # embeddings created below are added afterwards and stay trainable.
        for param in model.parameters():
            param.requires_grad = False
    model.relation_embeds = torch.nn.Parameter(
        torch.rand(args.trigger_length,
                   model.embeds.weight.shape[1],
                   requires_grad=True))
    model.to(device)

    collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)
    train_dataset = utils.load_continuous_trigger_dataset(args.train,
                                                          tokenizer,
                                                          args.field_a,
                                                          args.field_b,
                                                          args.label_field,
                                                          limit=args.limit)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.bsz,
                              shuffle=True,
                              collate_fn=collator)
    dev_dataset = utils.load_continuous_trigger_dataset(
        args.dev, tokenizer, args.field_a, args.field_b, args.label_field)
    dev_loader = DataLoader(dev_dataset,
                            batch_size=args.bsz,
                            shuffle=False,  # no need to shuffle for evaluation
                            collate_fn=collator)
    test_dataset = utils.load_continuous_trigger_dataset(
        args.test, tokenizer, args.field_a, args.field_b, args.label_field)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.bsz,
                             shuffle=False,
                             collate_fn=collator)
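    # When finetuning is disabled, every pretrained weight is frozen above, so
    # the optimizer below effectively updates only model.relation_embeds.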
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=1e-6)

    best_accuracy = 0
    for epoch in range(args.epochs):
        logger.info('Training...')
        model.train()
        avg_loss = utils.ExponentialMovingAverage()
        pbar = tqdm(train_loader)
        for model_inputs, labels in pbar:
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            # The separator index shifts right by trigger_length once the
            # trigger embeddings are inserted into the sequence.
            mask_token_idxs = (model_inputs["input_ids"]
                               == eos_idx).nonzero()[:, 1] + args.trigger_length
            model_inputs = generate_inputs_embeds(model_inputs, model,
                                                  tokenizer, eos_idx)
            labels = labels.to(device)[:, 1]
            optimizer.zero_grad()
            logits, *_ = model(**model_inputs)
            mask_logits = logits[
                torch.arange(0, logits.shape[0], dtype=torch.long),
                mask_token_idxs]
            loss = F.cross_entropy(mask_logits, labels)
            loss.backward()
            optimizer.step()
            avg_loss.update(loss.item())
            pbar.set_description(f'loss: {avg_loss.get_metric(): 0.4f}')

        logger.info('Evaluating...')
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for model_inputs, labels in dev_loader:
                model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
                mask_token_idxs = (model_inputs["input_ids"]
                                   == eos_idx).nonzero()[:, 1] + args.trigger_length
                model_inputs = generate_inputs_embeds(model_inputs, model,
                                                      tokenizer, eos_idx)
                labels = labels.to(device)[:, 1]
                logits, *_ = model(**model_inputs)
                mask_logits = logits[
                    torch.arange(0, logits.shape[0], dtype=torch.long),
                    mask_token_idxs]
                preds = torch.topk(mask_logits, 1, dim=1).indices[:, 0]
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        accuracy = correct / (total + 1e-13)
        logger.info(f'Accuracy: {accuracy : 0.4f}')

        if accuracy > best_accuracy:
            logger.info('Best performance so far.')
            # torch.save(model.state_dict(), args.ckpt_dir / WEIGHTS_NAME)
            # model.config.to_json_file(args.ckpt_dir / CONFIG_NAME)
            # tokenizer.save_pretrained(args.ckpt_dir)
            best_accuracy = accuracy

    logger.info('Testing...')
    model.eval()
    correct = 0
    total = 0
    # TODO: evaluates the final model, not the best validation checkpoint.
    with torch.no_grad():
        for model_inputs, labels in test_loader:
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            mask_token_idxs = (model_inputs["input_ids"]
                               == eos_idx).nonzero()[:, 1] + args.trigger_length
            model_inputs = generate_inputs_embeds(model_inputs, model,
                                                  tokenizer, eos_idx)
            labels = labels.to(device)[:, 1]
            logits, *_ = model(**model_inputs)
            mask_logits = logits[
                torch.arange(0, logits.shape[0], dtype=torch.long),
                mask_token_idxs]
            preds = torch.topk(mask_logits, 1, dim=1).indices[:, 0]
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    accuracy = correct / (total + 1e-13)
    logger.info(f'Accuracy: {accuracy : 0.4f}')
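
The helper generate_inputs_embeds is called above but not shown. Below is a
minimal sketch of one plausible implementation, assuming the learned trigger
vectors are spliced in right after the first token (consistent with the
+ args.trigger_length offset applied to the separator index); the actual
helper may differ:

def generate_inputs_embeds(model_inputs, model, tokenizer, eos_idx):
    """Swap input_ids for embeddings with the trigger vectors inserted."""
    input_ids = model_inputs.pop("input_ids")
    batch_size = input_ids.size(0)
    token_embeds = model.embeds(input_ids)  # (batch, seq_len, hidden)
    triggers = model.relation_embeds.unsqueeze(0).expand(batch_size, -1, -1)
    # Insert the triggers after position 0 (the [CLS] / <s> token).
    model_inputs["inputs_embeds"] = torch.cat(
        [token_embeds[:, :1], triggers, token_embeds[:, 1:]], dim=1)
    # Grow the attention mask to cover the inserted positions.
    attention_mask = model_inputs["attention_mask"]
    trigger_mask = attention_mask.new_ones(batch_size, triggers.size(1))
    model_inputs["attention_mask"] = torch.cat(
        [attention_mask[:, :1], trigger_mask, attention_mask[:, 1:]], dim=1)
    # BERT inputs also carry token_type_ids, which must grow to match.
    if "token_type_ids" in model_inputs:
        tt = model_inputs["token_type_ids"]
        model_inputs["token_type_ids"] = torch.cat(
            [tt[:, :1], tt.new_zeros(batch_size, triggers.size(1)), tt[:, 1:]],
            dim=1)
    return model_inputs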
Example #2
def main(args):
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config = AutoConfig.from_pretrained(args.model_name,
                                        num_labels=args.num_labels)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSequenceClassification.from_pretrained(args.model_name,
                                                               config=config)
    model.to(device)

    collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)
    train_dataset, label_map = utils.load_classification_dataset(
        args.train,
        tokenizer,
        args.field_a,
        args.field_b,
        args.label_field,
        limit=args.limit)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.bsz,
                              shuffle=True,
                              collate_fn=collator)
    dev_dataset, _ = utils.load_classification_dataset(args.dev, tokenizer,
                                                       args.field_a,
                                                       args.field_b,
                                                       args.label_field,
                                                       label_map)
    dev_loader = DataLoader(dev_dataset,
                            batch_size=args.bsz,
                            shuffle=False,  # no need to shuffle for evaluation
                            collate_fn=collator)
    test_dataset, _ = utils.load_classification_dataset(
        args.test, tokenizer, args.field_a, args.field_b, args.label_field,
        label_map)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.bsz,
                             shuffle=False,
                             collate_fn=collator)
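    # Only the classification head is optimized here; the encoder keeps its
    # pretrained weights (linear probing).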
    optimizer = torch.optim.Adam(model.classifier.parameters(),
                                 lr=args.lr,
                                 weight_decay=1e-6)

    # if not args.ckpt_dir.exists():
    #     logger.info(f'Making checkpoint directory: {args.ckpt_dir}')
    #     args.ckpt_dir.mkdir(parents=True)
    # elif not args.force_overwrite:
    #     raise RuntimeError('Checkpoint directory already exists.')

    best_accuracy = 0
    for epoch in range(args.epochs):
        logger.info('Training...')
        model.train()
        avg_loss = utils.ExponentialMovingAverage()
        pbar = tqdm(train_loader)
        for model_inputs, labels in pbar:
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            labels = labels.to(device)
            optimizer.zero_grad()
            logits, *_ = model(**model_inputs)
            loss = F.cross_entropy(logits, labels.squeeze(-1))
            loss.backward()
            optimizer.step()
            avg_loss.update(loss.item())
            pbar.set_description(f'loss: {avg_loss.get_metric(): 0.4f}')

        logger.info('Evaluating...')
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for model_inputs, labels in dev_loader:
                model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
                labels = labels.to(device)
                logits, *_ = model(**model_inputs)
                _, preds = logits.max(dim=-1)
                correct += (preds == labels.squeeze(-1)).sum().item()
                total += labels.size(0)
        accuracy = correct / (total + 1e-13)
        logger.info(f'Accuracy: {accuracy : 0.4f}')

        if accuracy > best_accuracy:
            logger.info('Best performance so far.')
            # torch.save(model.state_dict(), args.ckpt_dir / WEIGHTS_NAME)
            # model.config.to_json_file(args.ckpt_dir / CONFIG_NAME)
            # tokenizer.save_pretrained(args.ckpt_dir)
            best_accuracy = accuracy

    logger.info('Testing...')
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for model_inputs, labels in test_loader:
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            labels = labels.to(device)
            logits, *_ = model(**model_inputs)
            _, preds = logits.max(dim=-1)
            correct += (preds == labels.squeeze(-1)).sum().item()
            total += labels.size(0)
    accuracy = correct / (total + 1e-13)
    logger.info(f'Accuracy: {accuracy : 0.4f}')
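
utils.Collator and utils.ExponentialMovingAverage are used by all three
examples but not shown. Plausible minimal versions follow, assuming each
dataset item is a (dict of 1-D tensors, label tensor) pair and that the loss
EMA uses a fixed smoothing weight; both are assumptions, not taken from the
source:

class Collator:
    """Pads every tensor field in a batch to the longest sequence."""
    def __init__(self, pad_token_id):
        self._pad_token_id = pad_token_id

    def __call__(self, batch):
        model_inputs, labels = zip(*batch)
        max_len = max(x["input_ids"].size(0) for x in model_inputs)
        padded = {}
        for key in model_inputs[0]:
            # Only input_ids use the tokenizer's pad id; masks etc. pad with 0.
            value = self._pad_token_id if key == "input_ids" else 0
            padded[key] = torch.stack([
                F.pad(x[key], (0, max_len - x[key].size(0)), value=value)
                for x in model_inputs])
        return padded, torch.stack(labels)

class ExponentialMovingAverage:
    """Exponentially smoothed running loss for progress-bar display."""
    def __init__(self, weight=0.3):
        self._weight = weight
        self._value = None

    def update(self, x):
        self._value = x if self._value is None else (
            self._weight * x + (1 - self._weight) * self._value)

    def get_metric(self):
        return self._value if self._value is not None else 0.0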
Example #3
def main(args):
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config = AutoConfig.from_pretrained(args.model_name, num_labels=args.num_labels)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSequenceClassification.from_pretrained(args.model_name, config=config)
    model.to(device)

    collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)
    train_dataset, label_map = utils.load_classification_dataset(
        args.train,
        tokenizer,
        args.field_a,
        args.field_b,
        args.label_field,
        limit=args.limit
    )
    train_loader = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
    dev_dataset, _ = utils.load_classification_dataset(
        args.dev,
        tokenizer,
        args.field_a,
        args.field_b,
        args.label_field,
        label_map
    )
    dev_loader = DataLoader(dev_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)
    test_dataset, _ = utils.load_classification_dataset(
        args.test,
        tokenizer,
        args.field_a,
        args.field_b,
        args.label_field,
        label_map
    )
    test_loader = DataLoader(test_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)

    # With both betas at zero, Adam keeps no running averages, so bias
    # correction has no effect.
    if args.bias_correction:
        betas = (0.9, 0.999)
    else:
        betas = (0.0, 0.0)

    optimizer = AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=1e-2,
        betas=betas
    )

    # Linear warmup over the first 10% of steps, then linear decay to zero.
    # Note: len(train_loader) * args.epochs would also count partial batches.
    num_training_steps = len(train_dataset) * args.epochs // args.bsz
    num_warmup_steps = num_training_steps // 10
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps,
                                                num_training_steps)

    if not args.ckpt_dir.exists():
        logger.info(f'Making checkpoint directory: {args.ckpt_dir}')
        args.ckpt_dir.mkdir(parents=True)
    elif not args.force_overwrite:
        raise RuntimeError('Checkpoint directory already exists.')

    try:
        best_accuracy = 0
        for epoch in range(args.epochs):
            logger.info('Training...')
            model.train()
            avg_loss = utils.ExponentialMovingAverage()
            pbar = tqdm(train_loader)
            for model_inputs, labels in pbar:
                model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
                labels = labels.to(device)
                optimizer.zero_grad()
                logits, *_ = model(**model_inputs)
                loss = F.cross_entropy(logits, labels.squeeze(-1))
                loss.backward()
                optimizer.step()
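                # The scheduler steps once per batch, so the LR follows the
                # warmup/decay schedule at step granularity.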
                scheduler.step()
                avg_loss.update(loss.item())
                pbar.set_description(f'loss: {avg_loss.get_metric(): 0.4f}, '
                                     f'lr: {optimizer.param_groups[0]["lr"]: .3e}')

            logger.info('Evaluating...')
            model.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                for model_inputs, labels in dev_loader:
                    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
                    labels = labels.to(device)
                    logits, *_ = model(**model_inputs)
                    _, preds = logits.max(dim=-1)
                    correct += (preds == labels.squeeze(-1)).sum().item()
                    total += labels.size(0)
                accuracy = correct / (total + 1e-13)
            logger.info(f'Accuracy: {accuracy : 0.4f}')

            if accuracy > best_accuracy:
                logger.info('Best performance so far.')
                model.save_pretrained(args.ckpt_dir)
                tokenizer.save_pretrained(args.ckpt_dir)
                best_accuracy = accuracy
    except KeyboardInterrupt:
        logger.info('Interrupted...')

    logger.info('Testing...')
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for model_inputs, labels in test_loader:
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            labels = labels.to(device)
            logits, *_ = model(**model_inputs)
            _, preds = logits.max(dim=-1)
            correct += (preds == labels.squeeze(-1)).sum().item()
            total += labels.size(0)
        accuracy = correct / (total + 1e-13)
    logger.info(f'Accuracy: {accuracy : 0.4f}')
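
For completeness, a hypothetical command-line entry point covering the
attributes main reads in this last example; the flag names and defaults here
are assumptions, not taken from the source:

import argparse
from pathlib import Path

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-name', type=str, default='bert-base-cased')
    parser.add_argument('--train', type=Path, required=True)
    parser.add_argument('--dev', type=Path, required=True)
    parser.add_argument('--test', type=Path, required=True)
    parser.add_argument('--field-a', type=str, required=True)
    parser.add_argument('--field-b', type=str, default=None)
    parser.add_argument('--label-field', type=str, default='label')
    parser.add_argument('--num-labels', type=int, default=2)
    parser.add_argument('--bsz', type=int, default=32)
    parser.add_argument('--epochs', type=int, default=3)
    parser.add_argument('--lr', type=float, default=2e-5)
    parser.add_argument('--limit', type=int, default=None)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--bias-correction', action='store_true')
    parser.add_argument('--ckpt-dir', type=Path, default=Path('ckpt/'))
    parser.add_argument('--force-overwrite', action='store_true')
    args = parser.parse_args()
    main(args)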