import sys

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim

# MI_Estimator, LSH, SimHash, build() and the module-level constants used
# below (d, ed, K, L, batch_size, ...) are defined elsewhere in the original
# repository; illustrative stand-ins appear in the sketches further down.


def train(device, data, schedule, mi_type, args):
    model = MI_Estimator(device, D=d, ED=ed, HD=256)
    model.to(device)
    model.train()

    optimizer = optim.Adam(model.parameters(), lr=5e-4)

    xs, ys = data
    xs = xs.to(device)
    ys = ys.to(device)
    zxs = torch.cat([xs, zerot], dim=0)
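    # Appending one zero row at index xs.size(0) lets F.embedding() treat
    # that index as padding for the variable-length negative lists below.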

    lsh = LSH(SimHash(ed, K, L), K, L)
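    # L hash tables, each keyed by a K-bit SimHash signature of the
    # ed-dimensional embeddings (an illustrative sketch of this interface
    # appears after Example #2).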

    estimates = []
    for batch_idx, MI in enumerate(schedule):
        optimizer.zero_grad()

        # Sample batch_size distinct indices from the slice of the dataset
        # that corresponds to the current MI level in the schedule.
        sdx_iter = (batch_idx // mi_range) * mi_range
        sdx_offset = sdx_iter * batch_size
        sdx = torch.from_numpy(
            np.random.choice(mi_range * batch_size, batch_size,
                             replace=False) + sdx_offset).to(device)

        # Rebuild the LSH tables periodically: every 10 steps early in
        # training (while embeddings change quickly), every 100 afterwards.
        t = 10 if batch_idx <= 1000 else 100
        if batch_idx % t == 0:
            # Load the first desired_size rows into the LSH hash tables.
            lxs = xs[:desired_size, :]
            assert lxs.size(0) == desired_size
            build(lsh, model, lxs)

            # Alternative: build over the full dataset (slower).
            # build(lsh, model, xs)
            # lsh.stats()

        # Gather the sampled y rows and embed them; detach so gradients do
        # not propagate into the raw data tensors.
        y = F.embedding(sdx, ys).detach()
        ey = model.embed_y(y)

        # For each sample, query the LSH tables for candidate negatives,
        # remove accidental hits (the query's own index), and pack the
        # variable-length results into a matrix padded with xs.size(0).
        np_indices = lsh.query_remove_matrix(ey, sdx, xs.size(0))
        indices = torch.from_numpy(np_indices).to(device)

        # Mask distinguishing real candidates (1) from padding (0), with a
        # ones column prepended for the positive sample.
        mask = 1.0 - torch.eq(indices, xs.size(0)).float()
        mask = torch.cat([bs_onet, mask], dim=1).detach()

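        # Positive sample (one per row, placed at column 0) concatenated
        # with its LSH negatives; padding indices select the zero row in zxs.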
        px = torch.unsqueeze(F.embedding(sdx, xs), dim=1)
        nx = F.embedding(indices, zxs, padding_idx=xs.size(0))
        x = torch.cat([px, nx], dim=1).detach()

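        # model() returns an MI lower-bound estimate; minimizing its
        # negation maximizes the bound.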
        mi = model(x, y, mask, args)
        loss = -mi
        loss.backward()
        optimizer.step()

        estimates.append(mi.item())
        if (batch_idx + 1) % 100 == 0:
            print('{} {} MI:{}, E_MI: {:.6f}'.format(mi_type.name,
                                                     batch_idx + 1, MI,
                                                     mi.item()))
            sys.stdout.flush()
    lsh.stats()
    return estimates
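
Both examples depend on module-level globals from the original repository (d, ed, K, L, batch_size, mi_range, desired_size, num_iterations, wsize, zerot, bs_onet) as well as the MI_Estimator, LSH and SimHash classes and the build() helper. A minimal sketch of that setup, with every value chosen purely for illustration:

# Hypothetical configuration; the real values live elsewhere in the repo.
d, ed = 20, 32               # raw and embedded feature dimensions (assumed)
K, L = 10, 10                # SimHash bits per table / number of tables (assumed)
batch_size = 128             # samples drawn per optimization step (assumed)
mi_range = 4000              # iterations spent at each MI level (assumed)
desired_size = 100000        # rows loaded into the LSH tables (assumed)
num_iterations, wsize = 100, 5000        # Example #2 loop/window sizes (assumed)
zerot = torch.zeros(1, d)                # zero padding row appended to xs
bs_onet = torch.ones(batch_size, 1)      # ones column for the positive sample
# In the real code, zerot and bs_onet are moved to the training device.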
Example #2
import sys

import torch
import torch.nn.functional as F
import torch.optim as optim


def train(device, data, schedule, mi_type, args):
    model = MI_Estimator(device, D=d, ED=ed, HD=256)
    model.to(device)
    model.train()

    optimizer = optim.Adam(model.parameters(), lr=5e-4)

    xs, ys = data
    xs = xs.to(device)
    ys = ys.to(device)

    lsh = LSH(SimHash(ed, K, L), K, L)
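    # Same SimHash-backed LSH structure as in Example #1 (sketch below).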

    estimates = []
    avg_estimate = []

    id_set = set()
    n_iters = num_iterations * batch_size
    for batch_idx in range(n_iters):
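        # Example #2 draws one sample per step, so the MI level in the
        # schedule advances once every batch_size steps.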
        iteration = batch_idx // batch_size
        MI = schedule[iteration]

        # Rebuild the LSH tables over all of xs: every 10 steps early in
        # training, every 100 once the embeddings stabilize.
        t = 10 if batch_idx <= 1000 else 100
        if batch_idx % t == 0:
            build(lsh, model, xs)

        optimizer.zero_grad()

        y = ys[batch_idx:batch_idx + 1]
        ey = model.embed_y(y)

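        # Query the LSH tables for candidate negatives for this single y;
        # id_set tracks how many distinct rows the tables returned.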
        id_list = lsh.query(ey)
        id_set = id_set.union(set(id_list))
        indices = torch.LongTensor(id_list).to(device)

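        # Positive row plus its negatives, shaped (1, 1 + num_candidates, d).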
        nx = F.embedding(indices, xs)
        px = xs[batch_idx:batch_idx + 1]
        x = torch.cat([px, nx], dim=0)
        x = torch.unsqueeze(x, dim=0)

        mi = model(x, y, args)
        loss = -mi
        loss.backward()
        optimizer.step()

        avg_estimate.append(mi.item())
        if (batch_idx + 1) % 100 == 0:
            # Optional diagnostic (disabled): compare the average cosine
            # similarity of true pairs against noise pairs.
            # asim = model.cosine_similarity(x, y)
            # true = torch.mean(torch.diag(asim))
            # neye = 1. - torch.eye(batch_size).to(device)
            # noise = torch.sum(torch.mul(asim, neye)).item() / (batch_size * (batch_size - 1))
            # print("MI:{} true: {:.4f}, noise: {:.4f}".format(MI, true, noise))
            avg_mi = sum(avg_estimate) / float(len(avg_estimate))
            print('{} {} MI:{}, E_MI: {:.6f}'.format(mi_type.name,
                                                     batch_idx + 1, MI,
                                                     avg_mi))
            sys.stdout.flush()

        if (batch_idx + 1) % wsize == 0:
            # Report how many distinct LSH candidates the window touched,
            # then record the window-averaged MI estimate and reset.
            print(len(id_set), len(id_set) // wsize)
            id_set.clear()
            avg_mi = sum(avg_estimate) / float(len(avg_estimate))
            estimates.append(avg_mi)
            avg_estimate.clear()
    lsh.stats()
    return estimates
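
The LSH and SimHash classes themselves come from the original repository; the sketch below is an illustrative stand-in showing only the interface the two examples rely on (insert into L tables keyed by K-bit signed-random-projection signatures, query by bucket union). The batched query_remove_matrix() used in Example #1, stats(), and the build() helper are not reproduced, and every detail here is an assumption.

import torch

class SimHash:
    """Signed random projections: K bits per table, L tables."""
    def __init__(self, ed, K, L):
        self.planes = torch.randn(ed, K * L)   # one hyperplane per bit
        self.K, self.L = K, L

    def hash(self, e):
        # Sign pattern of each projection, packed into one K-bit code per
        # table; returns an (N, L) array of bucket ids.
        e = e.detach()
        bits = (e @ self.planes.to(e.device) > 0).long()   # (N, K*L)
        bits = bits.view(-1, self.L, self.K)
        weights = 2 ** torch.arange(self.K, device=e.device)
        return (bits * weights).sum(dim=-1).cpu().numpy()

class LSH:
    def __init__(self, hasher, K, L):
        self.hasher, self.K, self.L = hasher, K, L
        self.tables = [dict() for _ in range(L)]

    def insert(self, e, ids):
        # Drop each id into its bucket in every table.
        for row, i in zip(self.hasher.hash(e), ids):
            for t, c in enumerate(row):
                self.tables[t].setdefault(int(c), set()).add(int(i))

    def query(self, ey):
        # Union of the buckets that a single query embedding lands in.
        row = self.hasher.hash(ey)[0]
        hits = set()
        for t, c in enumerate(row):
            hits |= self.tables[t].get(int(c), set())
        return list(hits)

Embeddings with high cosine similarity agree on the sign of most random projections, so queries collide mainly with nearby points; that is what lets both training loops mine informative negatives without scanning the whole dataset.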