import sys

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim

# Hyperparameters (d, ed, K, L, batch_size, mi_range, desired_size, zerot,
# bs_onet, num_iterations, wsize) and the helpers (MI_Estimator, LSH, SimHash,
# build) are module-level definitions from elsewhere in this repo.


# Batched variant: each step draws a batch of positives, queries the LSH
# tables in one matrix call, and pads/masks the ragged candidate lists.
def train(device, data, schedule, mi_type, args):
    model = MI_Estimator(device, D=d, ED=ed, HD=256)
    model.to(device)
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=5e-4)

    xs, ys = data
    xs = xs.to(device)
    ys = ys.to(device)
    # Append a zero row so index xs.size(0) can serve as a padding slot.
    zxs = torch.cat([xs, zerot], dim=0)

    lsh = LSH(SimHash(ed, K, L), K, L)

    estimates = []
    for batch_idx, MI in enumerate(schedule):
        optimizer.zero_grad()

        # Randomly select a batch of indices from the current
        # mi_range-sized window of the data distribution.
        sdx_iter = (batch_idx // mi_range) * mi_range
        sdx_offset = sdx_iter * batch_size
        sdx = torch.from_numpy(
            np.random.choice(mi_range * batch_size, batch_size, replace=False)
            + sdx_offset).to(device)

        # Rebuild the hash tables every 10 steps early in training, then
        # every 100 steps once the embeddings change more slowly.
        t = 10 if batch_idx <= 1000 else 100
        if batch_idx % t == 0:
            # Load the first desired_size rows into the LSH hash tables.
            lxs = xs[:desired_size, :]
            assert lxs.size(0) == desired_size
            build(lsh, model, lxs)
            # lsh.stats()

            # Full - load all data:
            # build(lsh, model, xs)

        # Embed the selected y samples.
        y = F.embedding(sdx, ys).detach()
        ey = model.embed_y(y)

        # For each sample, query the LSH structure and remove accidental
        # hits; find the maximum number of returned candidates and build a
        # padded index matrix (padding index = xs.size(0)).
        np_indices = lsh.query_remove_matrix(ey, sdx, xs.size(0))
        indices = torch.from_numpy(np_indices).to(device)

        # Mask distinguishing real samples (1) from padding (0); prepend a
        # column of ones for the positive sample.
        mask = 1.0 - torch.eq(indices, xs.size(0)).float()
        mask = torch.cat([bs_onet, mask], dim=1).detach()

        px = torch.unsqueeze(F.embedding(sdx, xs), dim=1)
        nx = F.embedding(indices, zxs, padding_idx=xs.size(0))
        x = torch.cat([px, nx], dim=1).detach()

        mi = model(x, y, mask, args)
        loss = -mi
        loss.backward()
        optimizer.step()

        estimates.append(mi.item())
        if (batch_idx + 1) % 100 == 0:
            print('{} {} MI:{}, E_MI: {:.6f}'.format(
                mi_type.name, batch_idx + 1, MI, mi.item()))
            sys.stdout.flush()

    lsh.stats()
    return estimates
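
# A minimal, hypothetical sketch of what `build` (defined elsewhere in the
# repo) must do given the calls above: re-embed the candidate pool under the
# current encoder and refresh the hash tables so queries track the moving
# embeddings. `embed_x`, `clear`, and `insert_multi` are assumed names, not
# confirmed APIs; it is named `build_sketch` to avoid shadowing the real
# helper.
def build_sketch(lsh, model, xs):
    with torch.no_grad():
        ex = model.embed_x(xs)  # assumed mirror of model.embed_y
    lsh.clear()                 # assumed: drop stale table entries
    lsh.insert_multi(ex, ex.size(0))  # assumed: bulk insert into all L tables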
# Streaming variant: processes one (x, y) pair per step, querying the LSH
# tables for that sample's candidate set and averaging the estimate over a
# window of wsize steps.
def train(device, data, schedule, mi_type, args):
    model = MI_Estimator(device, D=d, ED=ed, HD=256)
    model.to(device)
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=5e-4)

    xs, ys = data
    xs = xs.to(device)
    ys = ys.to(device)

    lsh = LSH(SimHash(ed, K, L), K, L)

    estimates = []
    avg_estimate = []
    id_set = set()
    n_iters = num_iterations * batch_size
    for batch_idx in range(n_iters):
        iteration = batch_idx // batch_size
        MI = schedule[iteration]

        # Rebuild the hash tables every 10 steps early on, every 100 later.
        t = 10 if batch_idx <= 1000 else 100
        if batch_idx % t == 0:
            build(lsh, model, xs)

        optimizer.zero_grad()

        # Embed the current y sample and query the LSH tables for its
        # candidate x indices.
        y = ys[batch_idx:batch_idx + 1]
        ey = model.embed_y(y)
        id_list = lsh.query(ey)
        id_set |= set(id_list)

        # Positive sample first, retrieved candidates after.
        indices = torch.LongTensor(id_list).to(device)
        nx = F.embedding(indices, xs)
        px = xs[batch_idx:batch_idx + 1]
        x = torch.cat([px, nx], dim=0)
        x = torch.unsqueeze(x, dim=0)

        mi = model(x, y, args)
        loss = -mi
        loss.backward()
        optimizer.step()

        avg_estimate.append(mi.item())
        if (batch_idx + 1) % 100 == 0:
            # Disabled cosine-similarity diagnostics:
            # asim = model.cosine_similarity(x, y)
            # true = torch.mean(torch.diag(asim))
            # neye = 1. - torch.eye(batch_size).to(device)
            # noise = torch.sum(torch.mul(asim, neye)).item() / (batch_size * (batch_size - 1))
            # print("MI:{} true: {:.4f}, noise: {:.4f}".format(MI, true, noise))
            avg_mi = sum(avg_estimate) / float(len(avg_estimate))
            print('{} {} MI:{}, E_MI: {:.6f}'.format(
                mi_type.name, batch_idx + 1, MI, avg_mi))
            sys.stdout.flush()

        # At the end of each window, report the number of unique candidates
        # retrieved and record the window's average estimate.
        if (batch_idx + 1) % wsize == 0:
            print(len(id_set), len(id_set) // wsize)
            id_set.clear()
            avg_mi = sum(avg_estimate) / float(len(avg_estimate))
            estimates.append(avg_mi)
            avg_estimate.clear()

    lsh.stats()
    return estimates
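
# The SimHash(ed, K, L) constructor above is repo-local; as a self-contained
# illustration of the signature scheme it implies (L tables, each mapping an
# ed-dimensional embedding to a K-bit bucket id via random hyperplanes), here
# is a hypothetical stand-in. Only the (ed, K, L) parameters are taken from
# the code above; none of these method names are the repo's.
class SimHashSketch:
    def __init__(self, ed, K, L):
        # One (ed x K) random projection per table; the sign pattern of the
        # projection gives a K-bit signature.
        self.planes = torch.randn(L, ed, K)

    def signatures(self, e):
        # e: (N, ed) embeddings -> (L, N) integer bucket ids.
        bits = (torch.einsum('ne,lek->lnk', e, self.planes) > 0).long()
        pows = (2 ** torch.arange(bits.size(-1))).to(bits)
        return (bits * pows).sum(dim=-1)

# Example: bucket 4 embeddings into 8 tables of 2^12 buckets each.
# sh = SimHashSketch(ed=32, K=12, L=8)
# buckets = sh.signatures(torch.randn(4, 32))  # shape (8, 4)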