if euc:
    param_file = '/EuclideanEncoder/' + expinf + '_params.pt'
else:
    param_file = '/GeodesicCoordinates/' + expinf + '_params.pt'
probe.load_state_dict(torch.load(SAVE_DIR + fold + param_file))

#%%
means = []
std = []
for line_idx in [2]:
    ############## load the sentence
    line_idx_in_file = idx[line_idx]
    if line_idx_in_file < 0:
        continue
    line = linecache.getline(SAVE_DIR + bracket_file, line_idx_in_file + 1)
    sentence = BracketedSentence(line, dep_tree=dep)

    # check that the sentence in swapped_data.pkl matches that in train.txt
    # sameword = [sentence.words[i]==dist[line_idx][0][i] for i in range(sentence.ntok)]
    # if ~np.all(sameword):
    #     print('Mismatch between lines!!')
    #     print('Line %d is '%line_idx)
    #     break

    if sentence.ntok < 10:
        print('Skipping line %d, too short!' % line_idx)
        continue

    try:
        with bz2.BZ2File(LOAD_DIR + '/%d/original_vectors.pkl' % line_idx,
                         'rb') as vfile:
            orig_vecs = pkl.load(vfile)  # assumed completion; the original snippet is cut off here
    except (OSError, IOError):
        continue

#%% Example 2

# phrase_type = 'unreal'
# phrase_type = 'blocks'
# phrase_type = 'all'
phrase_window = None
all_window = 4

sum_align = []
sum_attn = []
num_in_phrase = []
num_pairs = []  # number of pairs we compare, to take averages
dt_inphrase = []  # distance of words to other words in the same phrase
t0 = time()
for line_idx in tqdm(range(1000)):

    line = linecache.getline(dfile, line_idx + 1)
    sentence = BracketedSentence(line)
    if sentence.ntok < 10:
        continue
    # orig = d[0]
    orig = sentence.words
    ntok = sentence.ntok

    phrs = sentence.phrases(order=order, min_length=1)
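    # window sizes: half the length of the longest phrase, rounded up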
    all_window = np.ceil(max([len(p) for p in phrs]) / 2)
    phrase_window = np.ceil(max([len(p) for p in phrs]) / 2)

    # whether to use true phrases
    if phrase_type == 'real':
        const = np.array([[sentence.is_relative(i,j, order=order) \
                           for i in range(ntok)] \
                              for j in range(ntok)])
model = BertModel.from_pretrained('bert-base-cased',
                                  output_hidden_states=True,
                                  output_attentions=True)
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

dfile = SAVE_DIR+'train_bracketed.txt'

# with open(SAVE_DIR+'permuted_data.pkl', 'rb') as dfile:
#     dist = pkl.load(dfile)

#%%
print('Computing mean and variance ...\n')
# foo = []
all_vecs = []
for line_idx in tqdm.tqdm(np.random.choice(5000,500,replace=False)):
    
    line = linecache.getline(dfile, line_idx+1)
    sentence = BracketedSentence(line)
    orig = sentence.words
    ntok = sentence.ntok
    if ntok < 10:
        continue
    # orig = d[0]
        
    orig_idx = np.array(range(ntok))
    
    n = np.random.choice(range(1,6))
    # n=1
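    # ngram_shuffling (a helper defined elsewhere in this repo) returns a permutation
    # of the token indices, presumably shuffling them in contiguous blocks of size n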
    swap_idx = ngram_shuffling(np.arange(ntok), n)
    swapped = [orig[i] for i in swap_idx]
    
    # assert(swapped == line[1+phrase_type][swap_type][2])
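    # sanity check: undoing the permutation (via argsort) recovers the original word order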
    assert([swapped[i] for i in np.argsort(swap_idx)] == orig)

#%% Example 4

t0 = time()
for epoch in range(nepoch):
    nbatch = 0  # we'll only take gradient steps every bsz datapoints
    cumloss_pos = 0
    cumloss_syn = 0
    optimizer_pos.zero_grad()
    optimizer_syn.zero_grad()

    for line_idx in np.random.permutation(range(len(dist))):  # range(13):

        line_idx_in_file = idx[line_idx]
        if line_idx_in_file < 0:
            continue
        line = linecache.getline(SAVE_DIR + '/data/' + bracket_file,
                                 line_idx_in_file + 1)
        sentence = BracketedSentence(line, dep_tree=dep)

        # check that the sentence in swapped_data.pkl matches that in train.txt
        sameword = [
            sentence.words[i] == dist[line_idx][0][i]
            for i in range(sentence.ntok)
        ]
        if not np.all(sameword):
            if verbose:
                print('Mismatch between lines!!')
                print('Line %d is ' % line_idx)
            break

        if sentence.ntok < 10:
            if verbose:
                print('Skipping line %d, too short!' % line_idx)

#%% Example 5

#     # swap_idx_all = [[] for _ in range(500)]

#     for layer in range(13):

#         num_pos = np.zeros(5)
#         num_syn = np.zeros(5)

#         with open(SAVE_DIR+fold+'pos_layer%d_init%d_params.pt'%(layer,init), 'rb') as f:
#             glm_pos.load_state_dict(torch.load(f))
#         with open(SAVE_DIR+fold+'syn_layer%d_init%d_params.pt'%(layer,init), 'rb') as f:
#             glm_syn.load_state_dict(torch.load(f))

for line_idx in tqdm.tqdm(np.random.choice(5000, num_lines, replace=False)):

    line = linecache.getline(dfile, line_idx + 1)
    sentence = BracketedSentence(line)

    if sentence.ntok < 10:
        continue

    # get the POS and Grandparent tag labels
    labels_pos = np.array([
        pos_tags.index(w) if w in pos_tags else np.nan
        for w in sentence.pos_tags
    ])
    valid_pos = ~np.isnan(labels_pos)
    y_pos = torch.tensor(labels_pos[valid_pos]).long()

    anc_tags = [sentence.ancestor_tags(i, 2) for i in range(sentence.ntok)]
    labels_syn = np.array([
        phrase_tags.index(w) if w in phrase_tags else np.nan for w in anc_tags
    ])

#%% Example 6

num_cond = np.zeros(len(these_bounds))
corrs = [[] for _ in these_bounds]
cca = [[] for _ in these_bounds]
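# corrs / cca: one list of results per tree-distance bound in these_bounds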
# corrs = np.zeros((max_num, len(these_bounds)))
t0 = time()
pbar = tqdm.tqdm(total=max_num*len(these_bounds), desc='Tree dist 0/%d'%len(these_bounds))
for dt in these_bounds:
    pbar.set_description('Tree dist %d/%d' % (dt, len(these_bounds)))
    vecs = []
    whichline = []
    whichswap = []
    for line_idx in np.random.permutation(range(5000)):
        
        line = linecache.getline(dfile, line_idx+1)
        sentence = BracketedSentence(line)
        if sentence.ntok<10:
            continue
        # orig = d[0]
        orig = sentence.words
        ntok = sentence.ntok
        
        crossings = np.diff(np.abs(sentence.brackets).cumsum()[sentence.term2brak])
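        # crossings[i]: number of brackets between consecutive terminal tokens i and i+1,
        # used here as the tree-distance criterion compared against dt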
        these_pairs = np.isin(crossings,dt)
        if not np.any(these_pairs):
            continue
        
        orig_idx = np.array(range(ntok))
        
        i = np.random.choice(np.where(these_pairs)[0])
        c = crossings[i]
expinf = 'layer%d_rank%d_init%d_%s_linear'%(layer, N, init, criterion.__class__.__name__)

num_pos = np.zeros(5)
num_syn = np.zeros(5)

with open(SAVE_DIR+fold+expinf+'_params.pt', 'rb') as f:
    probe.load_state_dict(torch.load(f))

# dbs = [[] for _ in range(5)]
# dts = []
scorr = [[] for _ in range(5)]
# logli = [[] for _ in range(5)]
for line_idx in tqdm.tqdm(np.random.choice(5000,num_lines), desc='Layer %d, model %d'%(layer, init)):

    line = linecache.getline(dfile, line_idx+1)
    sentence = BracketedSentence(line, dep_tree=dep_tree)

    if sentence.ntok<10:
        continue
    if sentence.ntok>110:
        continue

    toks = np.arange(sentence.ntok)
    ntok = sentence.ntok
    w1, w2 = np.nonzero(np.triu(np.ones((ntok,ntok)),k=1))

    # tree distance
    dT = np.array([sentence.tree_dist(w1[i],w2[i],term=(not dep_tree)) for i in range(len(w1))])

    # num_pos[0] += 1
    # num_syn[0] += 1
# full_sentence = False

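# linear probe: encode BERT's 768-d hidden states into an N-dimensional space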
encoder = nn.Linear(768, N, bias=False)
probe = EuclideanEncoder(encoder)
# criterion = nn.MSELoss(reduction='mean')
# criterion = nn.L1Loss(reduction='mean')
criterion = nn.PoissonNLLLoss(reduction='mean', log_input=False)

distortion = []
logli = [[] for _ in range(13)]  # log-likelihood on the training context
swp_logli = [[] for _ in range(13)]  # cross-context log-likelihood
dtree = []
for line_idx in tqdm.tqdm(np.random.choice(5000, num_line)):

    line = linecache.getline(dfile, line_idx + 1)
    sentence = BracketedSentence(line, dep_tree=dep_tree)

    if sentence.ntok < 10:
        continue
    if sentence.ntok > 110:
        continue

    toks = np.arange(sentence.ntok)
    ntok = sentence.ntok

    if full_sentence:
        w1, w2 = np.nonzero(np.triu(np.ones((ntok, ntok)), k=1))
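        # w1, w2: all token pairs (i, j) with i < j; dT[i] is their distance in the parse tree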
        dT = np.array([
            sentence.tree_dist(w1[i], w2[i], term=(not dep_tree))
            for i in range(len(w1))
        ])
# tree_dists = pkl.load(open(SAVE_DIR+'data/line%d_treedist.pkl'%line_idx,'rb'))
# weights = list((-np.array(tree_dists)[:,0]))
# idx = np.array(list(np.array(tree_dists)[:,1]))

# og = pkl.load(open(SAVE_DIR+'extracted/bert-base-cased/'+str(line_idx)+'/original_vectors.pkl','rb'))

# X = og[layer,:,:].T.dot(og[layer,:,:])/la.norm(og[layer,:,:], axis=0)**2
# X = -la.norm(og[-1,:,:,None]-og[-1,:,None,:],axis=0)
# seq = dist[line_idx][0]

with open(SAVE_DIR + 'data/train_bracketed.txt', 'r') as dfile:
    for i in range(line_idx):
        line = dfile.readline()

brak = BracketedSentence(line)

# weights = []
# idx = []
# for i in range(len(brak.terminals)):
#     for j in range(i+1,len(brak.terminals)):
#         weights.append(-brak.tree_dist(i,j))
#         idx.append((i,j))
# idx = np.array(idx)
# seq = brak.words

weights = []
idx = []
for i in range(len(brak.nodes)):
    for j in range(i + 1, len(brak.nodes)):
        weights.append(-brak.tree_dist(i, j, False))
        idx.append((i, j))  # record the pair, mirroring the commented block above

#%% Example 10

    optimizer_swap_syn.zero_grad()

    num_batch = int(5000 / bsz)
    with tqdm.tqdm(range(num_batch),
                   total=num_batch,
                   desc='Epoch %d' % epoch,
                   postfix=[
                       dict(loss_orig_pos_=0,
                            loss_swap_pos_=0,
                            loss_orig_syn_=0,
                            loss_swap_syn_=0)
                   ]) as pbar:
        for line_idx in np.random.permutation(range(5000)):

            line = linecache.getline(dfile, line_idx + 1)
            sentence = BracketedSentence(line)

            orig = sentence.words
            ntok = sentence.ntok

            orig_idx = np.array(range(ntok))
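            # draw one random word-order permutation per line and cache it for reuse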
            if len(swap_idx_all[line_idx]) == 0:
                swap_idx_all[line_idx] = np.random.permutation(orig_idx)
            swap_idx = swap_idx_all[line_idx]

            orig_vecs = extract_tensor(orig, indices=orig_idx)
            swap_vecs = extract_tensor([orig[s] for s in swap_idx],
                                       indices=swap_idx)

            labels_pos = np.array([
                pos_tags.index(w) if w in pos_tags else np.nan
                for w in sentence.pos_tags
            ])

#%% Example 11

if not os.path.isdir(SAVE_DIR + svfolder):
    os.makedirs(SAVE_DIR + svfolder)

#%%

dep = list(open(SAVE_DIR + '/dependency_train_bracketed.txt', 'r'))

# tok_num = dep[0].values
# line_num = np.cumsum(np.diff(tok_num, prepend=0)<0)
# tokens = dep[1].values
# sent_length = np.unique(line_num, return_counts=True)[1]
# sent_length = [BracketedSentence(d,True).ntok for d in dep]
sent_length = []
tokens = []
for d in dep:
    bs = BracketedSentence(d, True)
    sent_length.append(bs.ntok)
    tokens.append(bs.words)

#%%
const = open(SAVE_DIR + '/train_bracketed.txt', 'r')

unmatched_lines = list(range(len(dep)))
idx_in_train = np.zeros(len(unmatched_lines)) * np.nan
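# scan the constituency file for the line whose words match each dependency-parsed sentence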
for line_idx, line in enumerate(const):
    if line_idx < start:
        continue
    if line_idx > stop:
        break

    words = BracketedSentence(line).words

#%% Example 12

from transformers import BertTokenizer, BertModel, BertConfig, AutoConfig, AutoModel, AutoTokenizer
import pickle as pkl
import numpy as np
import scipy.linalg as la
import linecache
import os  # used below for os.listdir
from time import time
import matplotlib.pyplot as plt

#%%
jon_folder = 'C:/Users/mmall/Documents/github/bertembeddings/data/jonathans/'

lines = os.listdir(jon_folder+model)
dist = pkl.load(open(jon_folder+'/permuted_data.pkl','rb'))

dfile = SAVE_DIR+'train_bracketed.txt'

#%%

ptb_in_perm = []
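# for each PTB line, find the index of the sentence with the same words in the permuted data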
for line_idx in range(2416):
    
    line_ptb = linecache.getline(dfile, line_idx+1)
    sent = BracketedSentence(line_ptb).words
    
    found = False
    for pd_line in np.random.permutation(range(len(dist))):
        if sent == dist[pd_line][0]:
            found = True
            ptb_in_perm.append(pd_line)
            print('Found one')
            break
    
    cumloss_swap_pos = 0
    cumloss_orig_syn = 0
    cumloss_swap_syn = 0
    optimizer_orig_pos.zero_grad()
    optimizer_swap_pos.zero_grad()
    optimizer_orig_syn.zero_grad()
    optimizer_swap_syn.zero_grad()

    for line_idx in np.random.permutation(range(len(dist))):  # range(13):

        line_idx_in_file = idx[line_idx]
        if line_idx_in_file < 0:
            continue
        line = linecache.getline(SAVE_DIR + '/data/' + bracket_file,
                                 line_idx_in_file + 1)
        sentence = BracketedSentence(line, dep_tree=dep)

        # check that the sentence in swapped_data.pkl matches that in train.txt
        sameword = [
            sentence.words[i] == dist[line_idx][0][i]
            for i in range(sentence.ntok)
        ]
        if not np.all(sameword):
            if verbose:
                print('Mismatch between lines!!')
                print('Line %d is ' % line_idx)
            break

        if sentence.ntok < 10:
            if verbose:
                print('Skipping line %d, too short!' % line_idx)
# s_dist = [] # distance in sequence (i.e. number of words between)
which_line = []  # outside index of swapped_data
which_swap = []  # inside index of swapped_data

num = 0
num_vecs = 0
print('Beginning extraction ... ')
t0 = time()
for line_idx in np.random.permutation(range(len(dist))):  # range(13):

    line_idx_in_file = idx[line_idx]
    if line_idx_in_file < 0:
        continue
    line = linecache.getline(SAVE_DIR + '/data/' + bracket_file,
                             line_idx_in_file + 1)
    sentence = BracketedSentence(line, dep_tree=dep)

    # check that the sentence in swapped_data.pkl matches that in train.txt
    sameword = [
        sentence.words[i] == dist[line_idx][0][i] for i in range(sentence.ntok)
    ]
    if not np.all(sameword):
        if verbose:
            print('Mismatch between lines!!')
            print('Line %d is ' % line_idx)
        break

    if sentence.ntok < 10:
        if verbose:
            print('Skipping line %d, too short!' % line_idx)
        continue
line_idx = 4

# tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

# if random_model:
#     model = BertModel(BertConfig(output_hidden_states=True))
#     base_directory = 'vectors/permuted_depth/bert_untrained/'
# else:
#     model = BertModel.from_pretrained('bert-base-cased', output_hidden_states=True)
#     base_directory = 'extracted/bert-base-cased/'


with open(SAVE_DIR+'data/train_bracketed.txt', 'r') as dfile:
    for i in range(line_idx+1):
        line = dfile.readline()
brak = BracketedSentence(line)



cors = np.zeros((13,8,5,2,100)) # (layer, tree_dist, seq_dist, line)
print('Beginning extraction ... ')
with open(SAVE_DIR+'data/extracted/swapped_data.pkl', 'rb+') as dfile:
    dist = pkl.load(dfile)
# with open(SAVE_DIR+'data/train_bracketed.txt', 'r') as dfile:
    # dist = pkl.load(dfile)
    # t0 = time()
    for i in range(stop):
        if i < start:
            continue
        t0 = time()
        

#%% Example 16

# lines = os.listdir(jon_folder+'/ngram/bert-base-cased/')

#%%
order = 1
# num_phrases = None
num_phrases = None

# swap_type = 'within'
swap_type = 'among'

# print('Computing mean and variance ...')
all_vecs = []
for line_idx in tqdm.tqdm(np.random.permutation(len(dist))[:500]):  # loop over line indices

    line = linecache.getline(dfile, line_idx + 1)
    sentence = BracketedSentence(line)
    if sentence.ntok < 10:
        continue
    # orig = d[0]
    orig = sentence.words
    ntok = sentence.ntok

    phrase_idx = sentence.phrases(order=order, strict=True)
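    # require at least num_phrases phrases, or (if unspecified) at least two phrases
    # that together cover no more than half the tokens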
    if num_phrases is not None:
        if len(phrase_idx) < num_phrases:
            continue
    else:
        if len(phrase_idx) < 2 or (np.sum(
                np.isin(range(ntok), np.concatenate(phrase_idx))) > ntok / 2):
            continue

#%% Example 17

        layer_att).squeeze()[:, :, 1:-1, 1:-1], np.array(split_word_idx)


#%%
model = BertModel.from_pretrained('bert-base-cased',
                                  output_hidden_states=True,
                                  output_attentions=True)
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

dfile = SAVE_DIR + 'train_bracketed.txt'

#%%
line_idx = 4

line = linecache.getline(dfile, line_idx + 1)
sentence = BracketedSentence(line)
orig = sentence.words
ntok = sentence.ntok
orig_idx = np.arange(ntok)

# output_att: layer, head, nTok, nTok
orig_vecs, orig_attn, word_idx = extract_tensor(orig,
                                                indices=orig_idx,
                                                all_attn=True)

#%%
phrase_order = 1
# layers = list(range(12))
layers = [9]

phr = sentence.phrases(phrase_order)