Example #1
from pathlib import Path

import src.visualize as vis
from src.dataset import load_data

# Path info
base_path = Path('data') / 'KITTI'
date = '2011_09_26'
drive = 5

# Setup
dataset = load_data(base_path, date, drive)
i = 22                # frame index to visualize
h_fov = (-90, 90)     # horizontal field of view, degrees
v_fov = (-24.9, 2.0)  # vertical field of view, degrees

# Show Image
pcl_uv, pcl_d, img, _ = dataset.get_projected_pts(i, h_fov, v_fov)
vis.show_projection(pcl_uv, pcl_d, img)

# Write Video
vis.write_to_video(dataset, h_fov, v_fov)
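
# The projection itself is hidden inside get_projected_pts. As a rough illustration of
# what such a LiDAR-to-image projection involves, a minimal numpy sketch assuming a
# combined 3x4 camera projection matrix P (illustrative, not this repo's implementation):
import numpy as np

def project_points(pts_xyz, P):
    """Project Nx3 LiDAR points to pixel coordinates with a 3x4 projection matrix P."""
    pts_h = np.hstack([pts_xyz, np.ones((pts_xyz.shape[0], 1))])  # homogeneous coordinates
    uvw = pts_h @ P.T                                             # N x 3 image-plane coordinates
    in_front = uvw[:, 2] > 0                                      # drop points behind the camera
    uv = uvw[in_front, :2] / uvw[in_front, 2:3]                   # perspective divide -> pixels
    depth = uvw[in_front, 2]                                      # depth, used to color the points
    return uv, depth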
Example #2
def train(args):
    if args.device == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
        print("using gpu: ", torch.cuda.get_device_name(torch.cuda.current_device()))
        
    else:
        device = torch.device('cpu')
        print('using cpu')
    
    if args.dataset_name == 'pubmed':
        LABEL_LIST = PUBMED_LABEL_LIST
    elif args.dataset_name == 'nicta':
        LABEL_LIST = NICTA_LABEL_LIST
    elif args.dataset_name == 'csabstract':
        LABEL_LIST = CSABSTRACT_LABEL_LIST
    else:
        raise ValueError(f"Unknown dataset_name: {args.dataset_name}")

    train_x, train_labels = load_data(args.train_data, args.max_par_len, LABEL_LIST)
    dev_x, dev_labels = load_data(args.dev_data, args.max_par_len, LABEL_LIST)
    test_x, test_labels = load_data(args.test_data, args.max_par_len, LABEL_LIST)

    tokenizer = AutoTokenizer.from_pretrained(args.bert_model)
    train_x = tokenize_and_pad(train_x, tokenizer, args.max_par_len, args.max_seq_len, LABEL_LIST)  # shape: (N, par_len, seq_len)
    dev_x = tokenize_and_pad(dev_x, tokenizer, args.max_par_len, args.max_seq_len, LABEL_LIST)
    test_x = tokenize_and_pad(test_x, tokenizer, args.max_par_len, args.max_seq_len, LABEL_LIST)

    training_params = {
        "batch_size": args.batch_size,
        "shuffle": True,
        "drop_last": False,
    }
    dev_params = {
        "batch_size": args.batch_size,
        "shuffle": False,
        "drop_last": False,
    }
    test_params = {
        "batch_size": args.batch_size,
        "shuffle": False,
        "drop_last": False,
    }

    print(f'train.py train_x.shape: {train_x.shape}, train_labels.shape: {train_labels.shape}')
    training_generator = return_dataloader(inputs=train_x, labels=train_labels, params=training_params)
    dev_generator = return_dataloader(inputs=dev_x, labels=dev_labels, params=dev_params)
    test_generator = return_dataloader(inputs=test_x, labels=test_labels, params=test_params)   

    src_pad_idx = 0
    trg_pad_idx = 0
    model = Transformer(
        label_list=LABEL_LIST,
        src_pad_idx=src_pad_idx,
        trg_pad_idx=trg_pad_idx,
        embed_size=args.embed_size,
        num_layers=args.num_layers,
        forward_expansion=args.forward_expansion,
        heads=len(LABEL_LIST),
        dropout=0.1,
        device=device,
        max_par_len=args.max_par_len,
        max_seq_len=args.max_seq_len,
        bert_model=args.bert_model
    )
    model = model.to(device).float()
    
    # the CRF head computes the loss internally, so no separate criterion is needed
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
    
    best_val_f1 = 0.0
    for epoch in range(args.num_epochs):
        model.train()
        print(f"----------------[Epoch {epoch} / {args.num_epochs}]-----------------------")

        losses = []
        for batch_idx, batch in enumerate(tqdm(training_generator)):
            inp_data, target = batch
            inp_data = inp_data.to(device)
            target = target.to(device)

            ## For CRF
            optimizer.zero_grad()

            # the CRF head returns the log-likelihood when training=True; negate it for a loss
            loss = -model(inp_data.long(), target[:, 1:], training=True)
            losses.append(loss.item())

            loss.backward()
            optimizer.step()
            
        mean_loss = sum(losses) / len(losses)

        print(f"Mean loss for epoch {epoch} is {mean_loss}")
        # Validation
        model.eval()
        val_targets = []
        val_preds = []
        for batch_idx, batch in enumerate(tqdm(dev_generator)):
            inp_data, target = batch
            inp_data = inp_data.to(device)
            target = target.to(device)
            with torch.no_grad():
                # with training=False the CRF decodes and returns label sequences rather than logits
                output = model(inp_data, target[:, :-1], training=False)

            flattened_target = target[:, 1:].to('cpu').flatten()
            output = convert_crf_output_to_tensor(output, args.max_par_len)
            flattened_preds = output.to('cpu').flatten()
            for target_i, pred_i in zip(flattened_target, flattened_preds):
                if target_i != 0:
                    val_targets.append(target_i)
                    val_preds.append(pred_i)

        f1 = f1_score(val_targets, val_preds, average='micro')
        
        print(f'------Micro F1 score on dev set: {f1}------')

        if loss < best_val_loss:
            print(f"val loss less than previous best val loss of {best_val_loss}")
            best_val_loss = loss
            if args.save_model:
                dir_name = f"seed_{args.seed}_parlen_{args.max_par_len}_seqlen_{args.max_seq_len}_lr_{args.lr}.pt"
                output_path = os.path.join(args.save_path,dir_name)
                if not os.path.exists(args.save_path):
                    os.makedirs(args.save_path)
                print(f"Saving model to path {output_path}")
                torch.save(model,output_path)

        # Testing
        if epoch % args.test_interval == 0:
            model.eval()
            test_targets = []
            test_preds = []
            for batch_idx, batch in enumerate(tqdm(test_generator)):
                inp_data, target = batch
                inp_data = inp_data.to(device)
                target = target.to(device)
                with torch.no_grad():
                    output = model(inp_data, target[:, :-1], training=False)

                flattened_target = target[:, 1:].to('cpu').flatten()
                output = convert_crf_output_to_tensor(output, args.max_par_len)
                flattened_preds = output.to('cpu').flatten()
                for target_i, pred_i in zip(flattened_target, flattened_preds):
                    if target_i != 0:
                        test_targets.append(target_i)
                        test_preds.append(pred_i)
            
            f1 = f1_score(test_targets, test_preds, average='micro')
            print(f"------Micro F1 score on test set: {f1}------")
Example #3
def main(args):
    # Load data
    dataset, num_labels = load_data(args)
    if args.dataset == 'mnli':
        text_key = None
        testset_key = 'validation_%s' % args.mnli_option
    else:
        text_key = 'text' if (args.dataset in ["ag_news", "imdb", "yelp"]) else 'sentence'
        testset_key = 'test' if (args.dataset in ["ag_news", "imdb", "yelp"]) else 'validation'
    
    # Load target model
    pretrained = args.target_model.startswith('textattack')
    pretrained_surrogate = args.surrogate_model.startswith('textattack')
    suffix = '_finetune' if args.finetune else ''
    tokenizer = AutoTokenizer.from_pretrained(args.target_model, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(args.target_model, num_labels=num_labels).cuda()
    if not pretrained:
        model_checkpoint = os.path.join(args.result_folder, '%s_%s%s.pth' % (args.target_model.replace('/', '-'), args.dataset, suffix))
        print('Loading checkpoint: %s' % model_checkpoint)
        model.load_state_dict(torch.load(model_checkpoint))
        tokenizer.model_max_length = 512
    if args.target_model == 'gpt2':
        tokenizer.padding_side = "right"
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id
        
    label_perm = lambda x: x
    if pretrained:
        if args.target_model == 'textattack/bert-base-uncased-MNLI' or args.target_model == 'textattack/xlnet-base-cased-MNLI':
            label_perm = lambda x: (x + 1) % 3   # (0, 1, 2) -> (1, 2, 0)
        elif args.target_model == 'textattack/roberta-base-MNLI':
            label_perm = lambda x: -(x - 1) + 1  # (0, 1, 2) -> (2, 1, 0)

    # Compute clean accuracy
    corr = evaluate(model, tokenizer, dataset[testset_key], text_key, pretrained=pretrained, label_perm=label_perm)
    print('Clean accuracy = %.4f' % corr.float().mean())
    
    surr_tokenizer = AutoTokenizer.from_pretrained(args.surrogate_model, use_fast=True)
    surr_tokenizer.model_max_length = 512
    if args.surrogate_model == 'gpt2':
        surr_tokenizer.padding_side = "right"
        surr_tokenizer.pad_token = surr_tokenizer.eos_token
        
    clean_texts, adv_texts, clean_logits, adv_logits, adv_log_coeffs, labels, times = load_checkpoints(args)

    label_perm = lambda x: x
    if pretrained and args.surrogate_model != args.target_model:
        if args.target_model == 'textattack/bert-base-uncased-MNLI' or args.target_model == 'textattack/xlnet-base-cased-MNLI':
            label_perm = lambda x: (x + 1) % 3
        elif args.target_model == 'textattack/roberta-base-MNLI':
            label_perm = lambda x: -(x - 1) + 1
    
    attack_target = args.attack_target if args.dataset == 'mnli' else ''
    all_sentences, all_corr, cosine_sim = evaluate_adv_samples(
        model, tokenizer, surr_tokenizer, adv_log_coeffs, clean_texts, labels, attack_target=attack_target,
        gumbel_samples=args.gumbel_samples, batch_size=args.batch_size, print_every=args.print_every,
        pretrained=pretrained, pretrained_surrogate=pretrained_surrogate, label_perm=label_perm)
    
    print("__logs:" + json.dumps({
        "cosine_similarity": float(cosine_sim),
        "adv_acc2": all_corr.float().mean(1).eq(1).float().mean().item()
    }))
    output_file = get_output_file(args, args.surrogate_model, args.start_index, args.end_index)
    output_file = os.path.join(args.adv_samples_folder,
                               'transfer_%s_%s' % (args.target_model.replace('/', '-'), output_file))
    torch.save({
        'all_sentences': all_sentences, 
        'all_corr': all_corr, 
        'clean_texts': clean_texts,        
        'labels': labels, 
        'times': times
    }, output_file)
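
# The two MNLI label permutations used above just re-index the label ordering between the
# dataset and the textattack checkpoints; a quick sanity check of what each one maps:
perm_bert = lambda x: (x + 1) % 3      # cyclic shift: (0, 1, 2) -> (1, 2, 0)
perm_roberta = lambda x: -(x - 1) + 1  # reversal: (0, 1, 2) -> (2, 1, 0)
assert [perm_bert(i) for i in range(3)] == [1, 2, 0]
assert [perm_roberta(i) for i in range(3)] == [2, 1, 0]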
Example #4
def train(args):
    if args.device == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
        print("using gpu: ", torch.cuda.get_device_name(torch.cuda.current_device()))
        
    else:
        device = torch.device('cpu')
        print('using cpu')
    
    if args.dataset_name == 'pubmed':
        LABEL_LIST = PUBMED_LABEL_LIST
    elif args.dataset_name == 'nicta':
        LABEL_LIST = NICTA_LABEL_LIST
    elif args.dataset_name == 'csabstract':
        LABEL_LIST = CSABSTRACT_LABEL_LIST
    else:
        raise ValueError(f"Unknown dataset_name: {args.dataset_name}")

    train_x, train_labels = load_data(args.train_data, args.max_par_len, LABEL_LIST)
    dev_x, dev_labels = load_data(args.dev_data, args.max_par_len, LABEL_LIST)
    test_x, test_labels = load_data(args.test_data, args.max_par_len, LABEL_LIST)

    tokenizer = AutoTokenizer.from_pretrained(args.bert_model)
    train_x = tokenize_and_pad(train_x, tokenizer, args.max_par_len, args.max_seq_len, LABEL_LIST)  # shape: (N, par_len, seq_len)
    dev_x = tokenize_and_pad(dev_x, tokenizer, args.max_par_len, args.max_seq_len, LABEL_LIST)
    test_x = tokenize_and_pad(test_x, tokenizer, args.max_par_len, args.max_seq_len, LABEL_LIST)

    training_params = {
        "batch_size": args.batch_size,
        "shuffle": True,
        "drop_last": False,
    }
    dev_params = {
        "batch_size": args.batch_size,
        "shuffle": False,
        "drop_last": False,
    }
    test_params = {
        "batch_size": args.batch_size,
        "shuffle": False,
        "drop_last": False,
    }

    print(f'train.py train_x.shape: {train_x.shape}, train_labels.shape: {train_labels.shape}')
    training_generator = return_dataloader(inputs=train_x, labels=train_labels, params=training_params)
    dev_generator = return_dataloader(inputs=dev_x, labels=dev_labels, params=dev_params)
    test_generator = return_dataloader(inputs=test_x, labels=test_labels, params=test_params)   

    src_pad_idx = 0
    trg_pad_idx = 0
    model = Transformer(
        label_list=LABEL_LIST,
        src_pad_idx=src_pad_idx,
        trg_pad_idx=trg_pad_idx,
        embed_size=args.embed_size,
        num_layers=args.num_layers,
        forward_expansion=args.forward_expansion,
        heads=len(LABEL_LIST),
        dropout=0.1,
        device=device,
        max_par_len=args.max_par_len,
        max_seq_len=args.max_seq_len,
        bert_model=args.bert_model
    )
    model = model.to(device).float()

    # the CRF head computes the loss internally, so no separate criterion is needed
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
    
    for epoch in range(args.num_epochs):
        model.train()
        print(f"----------------[Epoch {epoch} / {args.num_epochs}]-----------------------")

        losses = []
        for batch_idx, batch in enumerate(tqdm(training_generator)):
            inp_data, target = batch
            inp_data = inp_data.to(device)
            target = target.to(device)

            ## For CRF
            optimizer.zero_grad()

            # the CRF head returns the log-likelihood when training=True; negate it for a loss
            loss = -model(inp_data.long(), target[:, 1:], training=True)
            losses.append(loss.item())

            loss.backward()
            optimizer.step()
            
        mean_loss = sum(losses) / len(losses)

        print(f"Mean loss for epoch {epoch} is {mean_loss}")
        # Validation
        model.eval()
        val_targets = []
        val_preds = []
        for batch_idx, batch in enumerate(tqdm(dev_generator)):
            inp_data, target = batch
            inp_data = inp_data.to(device)
            target = target.to(device)
            with torch.no_grad():
                # with training=False the CRF decodes and returns label sequences rather than logits
                output = model(inp_data, target[:, :-1], training=False)

            flattened_target = target[:, 1:].to('cpu').flatten()
            output = convert_crf_output_to_tensor(output, args.max_par_len)
            flattened_preds = output.to('cpu').flatten()
            for target_i, pred_i in zip(flattened_target, flattened_preds):
                if target_i != 0:
                    val_targets.append(target_i)
                    val_preds.append(pred_i)

        f1 = f1_score(val_targets, val_preds, average='micro')
        
        print(f'------Micro F1 score on dev set: {f1}------')

        # if loss < best_val_loss:
        #     print(f"val loss less than previous best val loss of {best_val_loss}")
        #     best_val_loss = loss
        #     if args.save_model:
        #         dir_name = f"seed_{args.seed}_parlen_{args.max_par_len}_seqlen_{args.max_seq_len}_lr_{args.lr}.pt"
        #         output_path = os.path.join(args.save_path,dir_name)
        #         if not os.path.exists(args.save_path):
        #             os.makedirs(args.save_path)
        #         print(f"Saving model to path {output_path}")
        #         torch.save(model,output_path)

        # Testing
        if epoch % args.test_interval == 0:
            model.eval()
            test_targets = []
            test_preds = []
            for batch_idx, batch in enumerate(tqdm(test_generator)):
                inp_data, target = batch
                inp_data = inp_data.to(device)
                target = target.to(device)
                with torch.no_grad():
                    output = model(inp_data, target[:, :-1], training=False)
                    
                flattened_target = target[:, 1:].to('cpu').flatten()
                output = convert_crf_output_to_tensor(output, args.max_par_len)
                flattened_preds = output.to('cpu').flatten()
                for target_i, pred_i in zip(flattened_target, flattened_preds):
                    if target_i != 0:
                        test_targets.append(target_i)
                        test_preds.append(pred_i)

            f1 = f1_score(test_targets, test_preds, average='micro')
            print(f"------Micro F1 score on test set: {f1}------")

    ## Generate attention heat maps for a few training examples.
    # See src/word_level_labelatt.py for details of computing and storing word-level attention scores,
    # and src/selfatt.py for sentence-level attention scores.
    att_x = train_x[:10, :, :].to(device)
    att_y = train_labels[:10, :].to(device)[:, :-1]
    model(att_x, att_y, training=False, att_heat_map=True)
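
# The heat maps themselves are rendered inside the model (see src/word_level_labelatt.py
# and src/selfatt.py). For illustration only, a generic matplotlib sketch of plotting one
# attention score matrix, with placeholder data rather than this repo's actual scores:
import numpy as np
import matplotlib.pyplot as plt

att = np.random.rand(8, 8)              # placeholder attention scores
plt.imshow(att, cmap='viridis')
plt.colorbar(label='attention weight')
plt.xlabel('attended sentence')
plt.ylabel('query sentence')
plt.savefig('att_heatmap.png')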
Example #5
def main(params):
    # Loading data
    dataset, num_labels = load_data(params)
    dataset = dataset["train"]
    text_key = 'text'
    if params.dataset == "dbpedia14":
        text_key = 'content'
    print(f"Loaded dataset {params.dataset}, that has {len(dataset)} rows")

    # Load model and tokenizer from HuggingFace
    model_class = transformers.AutoModelForSequenceClassification
    model = model_class.from_pretrained(params.model,
                                        num_labels=num_labels).cuda()

    if params.ckpt is not None:
        state_dict = torch.load(params.ckpt)
        model.load_state_dict(state_dict)
    tokenizer = textattack.models.tokenizers.AutoTokenizer(params.model)
    model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(
        model, tokenizer, batch_size=params.batch_size)

    # Create radioactive directions and modify classification layer to use those
    if params.radioactive:
        torch.manual_seed(0)
        radioactive_directions = torch.randn(num_labels, 768)
        radioactive_directions /= torch.norm(radioactive_directions,
                                             dim=1,
                                             keepdim=True)
        print(radioactive_directions)
        model.classifier.weight.data = radioactive_directions.cuda()
        model.classifier.bias.data = torch.zeros(num_labels).cuda()

    start_index = params.chunk_id * params.chunk_size
    end_index = start_index + params.chunk_size

    if params.target_dir is not None:
        target_file = join(params.target_dir, f"{params.chunk_id}.csv")
        f = open(target_file, "w")
        # keep the file handle separate from the csv writer so it can be closed later
        writer = csv.writer(f,
                            delimiter=',',
                            quotechar='"',
                            quoting=csv.QUOTE_NONNUMERIC)

    # Creating attack
    print(f"Building {params.attack} attack")
    if params.attack == "custom":
        current_label = -1
        if params.targeted:
            current_label = dataset[start_index]['label']
            assert all([
                dataset[i]['label'] == current_label
                for i in range(start_index, end_index)
            ])
        attack = build_attack(model_wrapper, current_label)
    elif params.attack == "bae":
        print(f"Building BAE method with threshold={params.bae_threshold:.2f}")
        attack = build_baegarg2019(model_wrapper,
                                   threshold_cosine=params.bae_threshold,
                                   query_budget=params.query_budget)
    elif params.attack == "bert-attack":
        assert params.query_budget is None
        attack = BERTAttackLi2020.build(model_wrapper)
    elif params.attack == "clare":
        assert params.query_budget is None
        attack = CLARE2020.build(model_wrapper)

    # Launching attack
    begin_time = time.time()
    samples = [
        (dataset[i][text_key],
         attack.goal_function.get_output(AttackedText(dataset[i][text_key])))
        for i in range(start_index, end_index)
    ]
    results = list(attack.attack_dataset(samples))

    # Storing attacked text
    bert_scorer = BERTScorer(model_type="bert-base-uncased", idf=False)

    n_success = 0
    similarities = []
    queries = []
    use = USE()

    for i_result, result in enumerate(results):
        print("")
        print(50 * "*")
        print("")
        text = dataset[start_index + i_result][text_key]
        ptext = result.perturbed_text()
        i_data = start_index + i_result
        if params.target_dir is not None:
            if params.dataset == 'dbpedia14':
                writer.writerow([
                    dataset[i_data]['label'] + 1, dataset[i_data]['title'],
                    ptext
                ])
            else:
                writer.writerow([dataset[i_data]['label'] + 1, ptext])

        print("True label ", dataset[i_data]['label'])
        print(f"CLEAN TEXT\n {text}")
        print(f"ADV TEXT\n {ptext}")

        if type(result) not in [SuccessfulAttackResult, FailedAttackResult]:
            print("WARNING: Attack neither succeeded nor failed...")
        print(result.goal_function_result_str())
        precision, recall, f1 = [
            r.item() for r in bert_scorer.score([ptext], [text])
        ]
        print(
            f"Bert scores: precision {precision:.2f}, recall: {recall:.2f}, f1: {f1:.2f}"
        )
        initial_logits = model_wrapper([text])
        final_logits = model_wrapper([ptext])
        print("Initial logits", initial_logits)
        print("Final logits", final_logits)
        print("Logits difference", final_logits - initial_logits)

        # Statistics
        n_success += 1 if type(result) is SuccessfulAttackResult else 0
        queries.append(result.num_queries)
        similarities.append(use.compute_sim([text], [ptext]))

    print("Processing all samples took %.2f" % (time.time() - begin_time))
    print(f"Total success: {n_success}/{len(results)}")
    logs = {
        "success_rate": n_success / len(results),
        "avg_queries": sum(queries) / len(queries),
        "queries": queries,
        "avg_similarity": sum(similarities) / len(similarities),
        "similarities": similarities,
    }
    print("__logs:" + json.dumps(logs))
    if params.target_dir is not None:
        f.close()
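
# The BERTScore precision/recall/F1 printed above come from the bert-score package; a
# minimal standalone usage sketch (the scoring model is downloaded on first run):
from bert_score import BERTScorer

scorer = BERTScorer(model_type="bert-base-uncased", idf=False)
P, R, F1 = scorer.score(["a quick brown fox"], ["the quick brown fox"])  # candidates, references
print(f"precision {P.item():.2f}, recall {R.item():.2f}, f1 {F1.item():.2f}")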
Example #6
def main(args):

    dataset, num_labels = load_data(args)

    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model, num_labels=num_labels)
    if args.model == 'gpt2':
        tokenizer.padding_side = "right"
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id
    if args.dataset == "mnli":
        # only evaluate on matched validation set
        testset_key = "validation_matched"
        preprocess_function = lambda examples: tokenizer(examples["premise"],
                                                         examples["hypothesis"
                                                                  ],
                                                         max_length=256,
                                                         truncation=True)
    else:
        text_key = 'text' if (args.dataset in ["ag_news", "imdb", "yelp"
                                               ]) else 'sentence'
        testset_key = 'test' if (args.dataset in ["ag_news", "imdb", "yelp"
                                                  ]) else 'validation'
        preprocess_function = lambda examples: tokenizer(
            examples[text_key], max_length=256, truncation=True)
    encoded_dataset = dataset.map(preprocess_function, batched=True)

    train_args = TrainingArguments(
        args.checkpoint_folder,
        disable_tqdm=not args.tqdm,
        evaluation_strategy="epoch",
        learning_rate=args.lr,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        weight_decay=args.weight_decay,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
    )

    trainer = Trainer(
        model,
        train_args,
        train_dataset=encoded_dataset["train"],
        eval_dataset=encoded_dataset[testset_key],
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    if not args.finetune:
        # freeze parameters of transformer
        transformer = list(model.children())[0]
        for param in transformer.parameters():
            param.requires_grad = False

    trainer.train()
    trainer.evaluate()
    suffix = ''
    if args.finetune:
        suffix += '_finetune'
    torch.save(
        model.state_dict(),
        os.path.join(
            args.result_folder, "%s_%s%s.pth" %
            (args.model.replace('/', '-'), args.dataset, suffix)))
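
# Freezing the transformer body as above leaves only the classification head trainable; a
# small self-contained sketch of verifying that via parameter counts (toy modules standing
# in for the HuggingFace model):
import torch.nn as nn

def count_params(model):
    """Return (trainable, total) parameter counts."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return trainable, sum(p.numel() for p in model.parameters())

body = nn.Linear(768, 768)   # stands in for the transformer body
head = nn.Linear(768, 4)     # stands in for the classification head
net = nn.Sequential(body, head)
for p in body.parameters():
    p.requires_grad = False  # freeze the body, as in the snippet above
print(count_params(net))     # trainable count now covers only the head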
Example #7
def main(args):
    pretrained = args.model.startswith('textattack')
    output_file = get_output_file(args, args.model, args.start_index, args.start_index + args.num_samples)
    output_file = os.path.join(args.adv_samples_folder, output_file)
    print(f"Outputting files to {output_file}")
    if os.path.exists(output_file):
        print('Skipping batch as it has already been completed.')
        exit()
    
    # Load dataset
    dataset, num_labels = load_data(args)
    label_perm = lambda x: x
    if pretrained and args.model == 'textattack/bert-base-uncased-MNLI':
        label_perm = lambda x: (x + 1) % 3

    # Load tokenizer, model, and reference model
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True)
    tokenizer.model_max_length = 512
    model = AutoModelForSequenceClassification.from_pretrained(args.model, num_labels=num_labels).cuda()
    if not pretrained:
        # Load model to attack
        suffix = '_finetune' if args.finetune else ''
        model_checkpoint = os.path.join(args.result_folder, '%s_%s%s.pth' % (args.model.replace('/', '-'), args.dataset, suffix))
        print('Loading checkpoint: %s' % model_checkpoint)
        model.load_state_dict(torch.load(model_checkpoint))
        tokenizer.model_max_length = 512
    if args.model == 'gpt2':
        tokenizer.padding_side = "right"
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id
        
    if 'bert-base-uncased' in args.model:
        # for BERT, load GPT-2 trained on BERT tokenizer
        ref_model = load_gpt2_from_dict("%s/transformer_wikitext-103.pth" % args.gpt2_checkpoint_folder, output_hidden_states=True).cuda()
    else:
        ref_model = AutoModelForCausalLM.from_pretrained(args.model, output_hidden_states=True).cuda()
    with torch.no_grad():
        embeddings = model.get_input_embeddings()(torch.arange(0, tokenizer.vocab_size).long().cuda())
        ref_embeddings = ref_model.get_input_embeddings()(torch.arange(0, tokenizer.vocab_size).long().cuda())
        
    # encode dataset using tokenizer
    if args.dataset == "mnli":
        testset_key = "validation_%s" % args.mnli_option
        preprocess_function = lambda examples: tokenizer(
            examples['premise'], examples['hypothesis'], max_length=256, truncation=True)
    else:
        text_key = 'text' if (args.dataset in ["ag_news", "imdb", "yelp"]) else 'sentence'
        testset_key = 'test' if (args.dataset in ["ag_news", "imdb", "yelp"]) else 'validation'
        preprocess_function = lambda examples: tokenizer(examples[text_key], max_length=256, truncation=True)
    encoded_dataset = dataset.map(preprocess_function, batched=True)
        
    # Compute idf dictionary for BERTScore
    if args.constraint == "bertscore_idf":
        if args.dataset == 'mnli':
            idf_dict = get_idf_dict(dataset['train']['premise'] + dataset['train']['hypothesis'], tokenizer, nthreads=20)
        else:
            idf_dict = get_idf_dict(dataset['train'][text_key], tokenizer, nthreads=20)
        
    if args.dataset == 'mnli':
        adv_log_coeffs = {'premise': [], 'hypothesis': []}
        clean_texts = {'premise': [], 'hypothesis': []}
        adv_texts = {'premise': [], 'hypothesis': []}
    else:
        adv_log_coeffs, clean_texts, adv_texts = [], [], []
    clean_logits = []
    adv_logits = []
    token_errors = []
    times = []
    
    assert args.start_index < len(encoded_dataset[testset_key]), 'Starting index %d is larger than dataset length %d' % (args.start_index, len(encoded_dataset[testset_key]))
    end_index = min(args.start_index + args.num_samples, len(encoded_dataset[testset_key]))
    n_samples = end_index - args.start_index
    adv_losses = torch.zeros(n_samples, args.num_iters)
    ref_losses = torch.zeros(n_samples, args.num_iters)
    perp_losses = torch.zeros(n_samples, args.num_iters)
    entropies = torch.zeros(n_samples, args.num_iters)
    for idx in range(args.start_index, end_index):
        input_ids = encoded_dataset[testset_key]['input_ids'][idx]
        if args.model == 'gpt2':
            token_type_ids = None
        else:
            token_type_ids = encoded_dataset[testset_key]['token_type_ids'][idx]
        label = label_perm(encoded_dataset[testset_key]['label'][idx])
        clean_logit = model(input_ids=torch.LongTensor(input_ids).unsqueeze(0).cuda(),
                             token_type_ids=(None if token_type_ids is None else torch.LongTensor(token_type_ids).unsqueeze(0).cuda())).logits.data.cpu()
        print('LABEL')
        print(label)
        print('TEXT')
        print(tokenizer.decode(input_ids))
        print('LOGITS')
        print(clean_logit)
        
        forbidden = np.zeros(len(input_ids)).astype('bool')
        # set [CLS] and [SEP] tokens to forbidden
        forbidden[0] = True
        forbidden[-1] = True
        offset = 0 if args.model == 'gpt2' else 1
        if args.dataset == 'mnli':
            # set either premise or hypothesis to forbidden
            premise_length = len(tokenizer.encode(encoded_dataset[testset_key]['premise'][idx]))
            input_ids_premise = input_ids[offset:(premise_length-offset)]
            input_ids_hypothesis = input_ids[premise_length:len(input_ids)-offset]
            if args.attack_target == "hypothesis":
                forbidden[:premise_length] = True
            else:
                forbidden[(premise_length-offset):] = True
        forbidden_indices = np.arange(0, len(input_ids))[forbidden]
        forbidden_indices = torch.from_numpy(forbidden_indices).cuda()
        token_type_ids_batch = (None if token_type_ids is None else torch.LongTensor(token_type_ids).unsqueeze(0).repeat(args.batch_size, 1).cuda())

        start_time = time.time()
        with torch.no_grad():
            orig_output = ref_model(torch.LongTensor(input_ids).cuda().unsqueeze(0)).hidden_states[args.embed_layer]
            if args.constraint.startswith('bertscore'):
                if args.constraint == "bertscore_idf":
                    ref_weights = torch.FloatTensor([idf_dict[tok] for tok in input_ids]).cuda()
                    ref_weights /= ref_weights.sum()
                else:
                    ref_weights = None
            elif args.constraint == 'cosine':
                # GPT-2 reference model uses last token embedding instead of pooling
                if args.model == 'gpt2' or 'bert-base-uncased' in args.model:
                    orig_output = orig_output[:, -1]
                else:
                    orig_output = orig_output.mean(1)
            log_coeffs = torch.zeros(len(input_ids), embeddings.size(0))
            indices = torch.arange(log_coeffs.size(0)).long()
            log_coeffs[indices, torch.LongTensor(input_ids)] = args.initial_coeff
            log_coeffs = log_coeffs.cuda()
            log_coeffs.requires_grad = True

        optimizer = torch.optim.Adam([log_coeffs], lr=args.lr)
        start = time.time()
        for i in range(args.num_iters):
            optimizer.zero_grad()
            coeffs = F.gumbel_softmax(log_coeffs.unsqueeze(0).repeat(args.batch_size, 1, 1), hard=False) # B x T x V
            inputs_embeds = (coeffs @ embeddings[None, :, :]) # B x T x D
            pred = model(inputs_embeds=inputs_embeds, token_type_ids=token_type_ids_batch).logits
            if args.adv_loss == 'ce':
                adv_loss = -F.cross_entropy(pred, label * torch.ones(args.batch_size).long().cuda())
            elif args.adv_loss == 'cw':
                # CW-style margin loss: push the true-label logit below the best other logit by at least kappa
                top_preds = pred.sort(descending=True)[1]
                correct = (top_preds[:, 0] == label).long()
                indices = top_preds.gather(1, correct.view(-1, 1))  # index of the highest non-label logit
                adv_loss = (pred[:, label] - pred.gather(1, indices) + args.kappa).clamp(min=0).mean()
            
            # Similarity constraint
            ref_embeds = (coeffs @ ref_embeddings[None, :, :])
            pred = ref_model(inputs_embeds=ref_embeds)
            if args.lam_sim > 0:
                output = pred.hidden_states[args.embed_layer]
                if args.constraint.startswith('bertscore'):
                    ref_loss = -args.lam_sim * bert_score(orig_output, output, weights=ref_weights).mean()
                else:
                    if args.model == 'gpt2' or 'bert-base-uncased' in args.model:
                        output = output[:, -1]
                    else:
                        output = output.mean(1)
                    cosine = (output * orig_output).sum(1) / output.norm(2, 1) / orig_output.norm(2, 1)
                    ref_loss = -args.lam_sim * cosine.mean()
            else:
                ref_loss = torch.Tensor([0]).cuda()
               
            # (log) perplexity constraint
            if args.lam_perp > 0:
                perp_loss = args.lam_perp * log_perplexity(pred.logits, coeffs)
            else:
                perp_loss = torch.Tensor([0]).cuda()
                
            # Compute loss and backward
            total_loss = adv_loss + ref_loss + perp_loss
            total_loss.backward()
            
            entropy = torch.sum(-F.log_softmax(log_coeffs, dim=1) * F.softmax(log_coeffs, dim=1))
            if i % args.print_every == 0:
                print('Iteration %d: loss = %.4f, adv_loss = %.4f, ref_loss = %.4f, perp_loss = %.4f, entropy=%.4f, time=%.2f' % (
                    i+1, total_loss.item(), adv_loss.item(), ref_loss.item(), perp_loss.item(), entropy.item(), time.time() - start))

            # Gradient step
            log_coeffs.grad.index_fill_(0, forbidden_indices, 0)
            optimizer.step()

            # Log statistics
            adv_losses[idx - args.start_index, i] = adv_loss.detach().item()
            ref_losses[idx - args.start_index, i] = ref_loss.detach().item()
            perp_losses[idx - args.start_index, i] = perp_loss.detach().item()
            entropies[idx - args.start_index, i] = entropy.detach().item()
        times.append(time.time() - start_time)
        
        print('CLEAN TEXT')
        if args.dataset == 'mnli':
            clean_premise = tokenizer.decode(input_ids_premise)
            clean_hypothesis = tokenizer.decode(input_ids_hypothesis)
            clean_texts['premise'].append(clean_premise)
            clean_texts['hypothesis'].append(clean_hypothesis)
            print('%s %s' % (clean_premise, clean_hypothesis))
        else:
            clean_text = tokenizer.decode(input_ids[offset:(len(input_ids)-offset)])
            clean_texts.append(clean_text)
            print(clean_text)
        clean_logits.append(clean_logit)
        
        print('ADVERSARIAL TEXT')
        with torch.no_grad():
            for j in range(args.gumbel_samples):
                adv_ids = F.gumbel_softmax(log_coeffs, hard=True).argmax(1)
                if args.dataset == 'mnli':
                    if args.attack_target == 'premise':
                        adv_ids_premise = adv_ids[offset:(premise_length-offset)].cpu().tolist()
                        adv_ids_hypothesis = input_ids_hypothesis
                    else:
                        adv_ids_premise = input_ids_premise
                        adv_ids_hypothesis = adv_ids[premise_length:len(adv_ids)-offset].cpu().tolist()
                    adv_premise = tokenizer.decode(adv_ids_premise)
                    adv_hypothesis = tokenizer.decode(adv_ids_hypothesis)
                    x = tokenizer(adv_premise, adv_hypothesis, max_length=256, truncation=True, return_tensors='pt')
                    token_errors.append(wer(input_ids_premise + input_ids_hypothesis, x['input_ids'][0]))
                else:
                    adv_ids = adv_ids[offset:len(adv_ids)-offset].cpu().tolist()
                    adv_text = tokenizer.decode(adv_ids)
                    x = tokenizer(adv_text, max_length=256, truncation=True, return_tensors='pt')
                    token_errors.append(wer(adv_ids, x['input_ids'][0]))
                adv_logit = model(input_ids=x['input_ids'].cuda(), attention_mask=x['attention_mask'].cuda(),
                                  token_type_ids=(x['token_type_ids'].cuda() if 'token_type_ids' in x else None)).logits.data.cpu()
                if adv_logit.argmax() != label or j == args.gumbel_samples - 1:
                    if args.dataset == 'mnli':
                        adv_texts['premise'].append(adv_premise)
                        adv_texts['hypothesis'].append(adv_hypothesis)
                        print('%s %s' % (adv_premise, adv_hypothesis))
                    else:
                        adv_texts.append(adv_text)
                        print(adv_text)
                    adv_logits.append(adv_logit)
                    break
                
        # remove special tokens from adv_log_coeffs
        if args.dataset == 'mnli':
            adv_log_coeffs['premise'].append(log_coeffs[offset:(premise_length-offset), :].cpu())
            adv_log_coeffs['hypothesis'].append(log_coeffs[premise_length:(log_coeffs.size(0)-offset), :].cpu())
        else:
            adv_log_coeffs.append(log_coeffs[offset:(log_coeffs.size(0)-offset), :].cpu()) # size T x V
        
        print('')
        print('CLEAN LOGITS')
        print(clean_logit) # size 1 x C
        print('ADVERSARIAL LOGITS')
        print(adv_logit)   # size 1 x C
            
    print("Token Error Rate: %.4f (over %d tokens)" % (sum(token_errors) / len(token_errors), len(token_errors)))
    torch.save({
        'adv_log_coeffs': adv_log_coeffs, 
        'adv_logits': torch.cat(adv_logits, 0), # size N x C
        'adv_losses': adv_losses,
        'adv_texts': adv_texts,
        'clean_logits': torch.cat(clean_logits, 0), 
        'clean_texts': clean_texts, 
        'entropies': entropies,
        'labels': list(map(label_perm, encoded_dataset[testset_key]['label'][args.start_index:end_index])), 
        'perp_losses': perp_losses,
        'ref_losses': ref_losses,
        'times': times,
        'token_error': token_errors,
    }, output_file)
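
# The core trick above is the Gumbel-softmax relaxation: a T x V matrix of log-coefficients
# defines a distribution over tokens at each position, and soft samples mixed with the
# embedding table give inputs the loss can be differentiated through. A toy sketch of just
# that mechanism, with stand-in sizes and a stand-in loss:
import torch
import torch.nn.functional as F

T, V, D = 5, 100, 16                               # toy sizes: positions, vocab, embed dim
embeddings = torch.randn(V, D)                     # stand-in embedding table
log_coeffs = torch.zeros(T, V, requires_grad=True) # one token distribution per position

coeffs = F.gumbel_softmax(log_coeffs, hard=False)  # T x V soft one-hot samples
inputs_embeds = coeffs @ embeddings                # T x D differentiable "token" embeddings
loss = inputs_embeds.pow(2).mean()                 # stand-in for the classifier's adversarial loss
loss.backward()                                    # gradient reaches the log-coefficients
print(log_coeffs.grad.shape)                       # torch.Size([5, 100])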