Example #1
def train_toy(toy, load=True, nb_steps=20, nb_flow=1, folder=""):
    device = "cpu"
    logger = utils.get_logger(logpath=os.path.join(folder, toy, 'logs'),
                              filepath=os.path.abspath(__file__))

    logger.info("Creating model...")
    model = UMNNMAFFlow(nb_flow=nb_flow,
                        nb_in=2,
                        hidden_derivative=[50, 50, 50, 50],
                        hidden_embedding=[50, 50, 50, 50],
                        embedding_s=10,
                        nb_steps=nb_steps,
                        device=device).to(device)
    logger.info("Model created.")
    opt = torch.optim.Adam(model.parameters(), 1e-3, weight_decay=1e-5)

    if load:
        logger.info("Loading model...")
        model.load_state_dict(torch.load(folder + toy + '/model.pt'))
        model.train()
        opt.load_state_dict(torch.load(folder + toy + '/ADAM.pt'))
        logger.info("Model loaded.")

    nb_samp = 1000
    batch_size = 100

    x_test = torch.tensor(toy_data.inf_train_gen(toy,
                                                 batch_size=1000)).to(device)
    x = torch.tensor(toy_data.inf_train_gen(toy, batch_size=1000)).to(device)

    for epoch in range(10000):
        ll_tot = 0
        start = timer()
        for j in range(0, nb_samp, batch_size):
            cur_x = torch.tensor(
                toy_data.inf_train_gen(toy, batch_size=batch_size)).to(device)
            ll, z = model.compute_ll(cur_x)
            ll = -ll.mean()
            ll_tot += ll.detach() / (nb_samp / batch_size)
            loss = ll
            opt.zero_grad()
            loss.backward()
            opt.step()
        end = timer()
        ll_test, _ = model.compute_ll(x_test)
        ll_test = -ll_test.mean()
        logger.info(
            "epoch: {:d} - Train loss: {:4f} - Test loss: {:4f} - Elapsed time per epoch {:4f} (seconds)"
            .format(epoch, ll_tot.item(), ll_test.item(), end - start))

        if (epoch % 100) == 0:
            summary_plots(x, x_test, folder, epoch, model, ll_tot, ll_test)
            torch.save(model.state_dict(), folder + toy + '/model.pt')
            torch.save(opt.state_dict(), folder + toy + '/ADAM.pt')
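A hedged usage sketch for train_toy; the dataset name and output directory below are illustrative assumptions, not taken from the source. Note that the function concatenates folder + toy when saving, so that directory must exist and folder should end with a path separator:

import os

toy = "8gaussians"   # assumed: any name accepted by toy_data.inf_train_gen
folder = "results/"  # assumed: checkpoints land in folder + toy + '/model.pt'
os.makedirs(os.path.join(folder, toy), exist_ok=True)
train_toy(toy, load=False, nb_steps=20, nb_flow=1, folder=folder)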
Example #2
def compute_loss(args, model, batch_size=None, beta=1.):
    if batch_size is None:
        batch_size = args.batch_size

    # load data
    x = toy_data.inf_train_gen(args.data, batch_size=batch_size)
    x = torch.from_numpy(x).type(torch.float32).to(device)

    # x is a (batch_size, 2) float32 tensor on the training device,
    # e.g. shape (500, 2) for batch_size=500.

    zero = torch.zeros(x.shape[0], 1).to(x)

    # transform to z
    z, delta_logp = model(x, zero)

    # compute log p(z)
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)

    logpx = logpz - beta * delta_logp
    loss = -torch.mean(logpx)

    return loss, torch.mean(logpz), torch.mean(-delta_logp)
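For reference, the quantity returned above is the standard change-of-variables log-likelihood used throughout these examples. The flow maps x to z and reports the accumulated density change delta_logp, which enters with a minus sign in this codebase's convention:

$$\log p_x(x) = \log p_z(z) - \beta\,\Delta\log p, \qquad \mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N}\log p_x(x_i),$$

where log p_z is the standard-normal log-density summed over the two coordinates and beta (default 1) scales the volume-change term.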
Example #3
def loss_fn2(first_term_loss, genFGen2, args, model):
    # genFGen3 is the reference Gaussian noise used by the third (diversity)
    # term below; the later variants of this function define it the same way.
    genFGen3 = torch.randn((args.batch_size, 2)).to(device)

    xData = toy_data.inf_train_gen(args.data, batch_size=args.batch_size)
    xData = torch.from_numpy(xData).type(torch.float32).to(device)

    # Sum of pairwise L2 distances between generated and real samples.
    second_term_loss = torch.from_numpy(np.array(0.0, dtype='float32')).to(device)

    for i in genFGen2:
        for j in xData:
            second_term_loss += torch.dist(i, j, 2)

        second_term_loss /= args.batch_size
        second_term_loss *= 30

    second_term_loss /= args.batch_size

    # Ratio of noise-space to sample-space pairwise distances, penalizing
    # collapse of the generated samples onto each other.
    third_term_loss = torch.from_numpy(np.array(0.0, dtype='float32')).to(device)

    for i in range(args.batch_size):
        for j in range(args.batch_size):
            if i != j:
                third_term_loss += (
                    torch.dist(genFGen3[i, :], genFGen3[j, :], 2) /
                    torch.dist(genFGen2[i, :], genFGen2[j, :], 2))

        third_term_loss /= (args.batch_size - 1)
        third_term_loss *= 10

    third_term_loss /= args.batch_size

    return first_term_loss + second_term_loss + third_term_loss
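The doubly nested Python loops above accumulate all pairwise L2 distances one scalar at a time. A minimal vectorized sketch, assuming genFGen2 and xData are (batch_size, 2) tensors on the same device (note the loop variant also rescales the running sum on every outer pass, so the two are not numerically identical):

# torch.cdist builds the full (N, M) matrix of pairwise L2 distances,
# collapsing the double loop into a single kernel call.
pairwise = torch.cdist(genFGen2, xData, p=2)
second_term_loss = 30 * pairwise.mean()  # constant factor from the loop version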
Example #4
    def get_ckpt_model_and_data(args):
        # Load checkpoint.
        checkpt = torch.load(args.checkpt,
                             map_location=lambda storage, loc: storage)
        ckpt_args = checkpt['args']
        state_dict = checkpt['state_dict']

        # Construct model and restore checkpoint.
        regularization_fns, regularization_coeffs = create_regularization_fns(
            ckpt_args)
        model = build_model_tabular(ckpt_args, 2,
                                    regularization_fns).to(device)
        if ckpt_args.spectral_norm: add_spectral_norm(model)
        set_cnf_options(ckpt_args, model)

        model.load_state_dict(state_dict)
        model.to(device)

        print(model)
        print("Number of trainable parameters: {}".format(
            count_parameters(model)))

        # Load samples from dataset
        data_samples = toy_data.inf_train_gen(ckpt_args.data, batch_size=2000)

        return model, data_samples
Example #5
def compute_loss(args, model, batch_size=args.batch_size):

    x = toy_data.inf_train_gen(args.data, batch_size=batch_size)
    x = torch.from_numpy(x).type(torch.float32).to(device)
    zero = torch.zeros(x.shape[0], 1).to(x)
    z, change = model(x, zero)

    logpx = standard_normal_logprob(z).sum(1, keepdim=True) - change
    loss = -torch.mean(logpx)
    return loss
Example #6
def compute_loss(args, model, batch_size=None, beta=1.):
    if batch_size is None: batch_size = args.batch_size

    # load data
    x = toy_data.inf_train_gen(args.data, batch_size=batch_size)
    x = torch.from_numpy(x).type(torch.float32).to(device)
    zero = torch.zeros(x.shape[0], 1).to(x)

    # transform to z
    z, delta_logp = model(x, zero)

    # compute log p(z)
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)

    logpx = logpz - beta * delta_logp
    loss = -torch.mean(logpx)
    return loss, torch.mean(logpz), torch.mean(-delta_logp)
Example #7
def compute_loss(args, model, batch_size=None):
    if batch_size is None: batch_size = args.batch_size

    # load data
    x = toy_data.inf_train_gen(args.data, batch_size=batch_size)
    x = torch.from_numpy(x).type(torch.float32).to(device)
    zero = torch.zeros(x.shape[0], 1).to(x)
    lec = None if (args.poly_coef is None or not model.training) else torch.tensor(0.0).to(x)

    # transform to z
    z, delta_logp, lec = model(x, zero, lec)

    # compute log q(z)
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)

    logpx = logpz - delta_logp
    loss = -torch.mean(logpx)
    return loss, lec
Example #8
def compute_loss(args, model, batch_size=None):
    if batch_size is None: batch_size = args.batch_size

    # load data
    x = toy_data.inf_train_gen(args.data, batch_size=batch_size)
    x = torch.from_numpy(x).type(torch.float32).to(device)
    zero = torch.zeros(x.shape[0], 1).to(x)

    # transform to z
    std = (args.std_max - args.std_min) * torch.rand_like(x[:,0]).view(-1,1) + args.std_min
    eps = torch.randn_like(x) * std
    std_in = std / args.std_max * args.std_weight
    z, delta_logp = model(x+eps, std_in, zero)

    # compute log q(z)
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)

    logpx = logpz - delta_logp
    loss = -torch.mean(logpx)
    return loss
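Example #8 perturbs the inputs with Gaussian noise whose scale is drawn uniformly per sample before running the flow (a SoftFlow-style conditioning). In symbols, with $w$ = args.std_weight:

$$\sigma \sim \mathcal{U}(\sigma_{\min}, \sigma_{\max}), \qquad \epsilon \sim \mathcal{N}(0, \sigma^2 I), \qquad \tilde{x} = x + \epsilon, \qquad \sigma_{\text{in}} = \frac{\sigma}{\sigma_{\max}}\, w,$$

and the model receives both the noisy input $\tilde{x}$ and the conditioning value $\sigma_{\text{in}}$.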
Example #9
def compute_loss(args, model, batch_size=None, beta=1.):
    if batch_size is None:
        batch_size = args.batch_size

    # load data
    x = toy_data.inf_train_gen(args.data, batch_size=batch_size)
    x = torch.from_numpy(x).type(torch.float32).to(device)

    zero = torch.zeros(x.shape[0], 1).to(x)

    # transform to z
    z, delta_logp = model(x, zero)

    # compute log p(z)
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)

    logpx = logpz - beta * delta_logp
    loss = -torch.mean(logpx)

    return loss, torch.mean(logpz), torch.mean(-delta_logp)
Example #10
def loss_fn2(genFGen2, args, model):
    genFGen3 = torch.randn((args.batch_size, 2)).to(device)

    first_term_loss = compute_loss2(genFGen2, args, model)

    xData = toy_data.inf_train_gen(args.data, batch_size=args.batch_size)
    xData = torch.from_numpy(xData).type(torch.float32).to(device)

    # Track the minimum generated-to-data distance; start at +inf.
    second_term_loss = torch.from_numpy(np.array(float('inf'),
                                                 dtype='float32')).to(device)

    for i in genFGen2:
        for j in xData:
            store_second_term_loss = torch.dist(i, j, 2)
            # Keep the running minimum distance (initialized to +inf above).
            if store_second_term_loss < second_term_loss:
                second_term_loss = store_second_term_loss

        second_term_loss *= 0.1

    second_term_loss /= args.batch_size

    third_term_loss = torch.from_numpy(np.array(0.0,
                                                dtype='float32')).to(device)

    for i in range(args.batch_size):
        for j in range(args.batch_size):
            if i != j:
                third_term_loss += (
                    torch.dist(genFGen3[i, :], genFGen3[j, :], 2) /
                    torch.dist(genFGen2[i, :], genFGen2[j, :], 2))

        third_term_loss /= (args.batch_size - 1)
        third_term_loss *= 0.1

    third_term_loss /= args.batch_size

    return first_term_loss + second_term_loss + third_term_loss
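The minimum-tracking loop above reduces to a single call on the full distance matrix. A minimal sketch under the same tensor assumptions as before (it omits the per-pass 0.1 rescaling of the loop version):

# Smallest generated-to-data L2 distance over all pairs.
second_term_loss = torch.cdist(genFGen2, xData, p=2).min() / args.batch_size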
Example #11
    ax.set_facecolor(COLOR_BACK)
    fig_filename = os.path.join(save_path, file_name + '.png')
    utils.makedirs(os.path.dirname(fig_filename))
    plt.savefig(fig_filename, format='png', dpi=1200, bbox_inches='tight')
    plt.close()


if __name__ == '__main__':
    save_path = 'generate1/' + args.data
    n_samples = 2000
    COLOR = 1, 1, 1, 0.64

    if not os.path.exists(save_path):
        os.makedirs(save_path)

    sample_real = toy_data.inf_train_gen(args.data, batch_size=n_samples)
    save_fig(sample_real, COLOR, 'D', 0.1, save_path, 'sample_data')

    softflow = build_model_tabular(args, 2).to(device)
    softflow_path = args.load_path
    ckpt_softflow = torch.load(softflow_path)
    softflow.load_state_dict(ckpt_softflow['state_dict'])
    softflow.eval()

    z = torch.randn(n_samples, 2).type(torch.float32).to(device)
    sample_s = []
    inds = torch.arange(0, z.shape[0]).to(torch.int64)
    with torch.no_grad():
        for ii in torch.split(inds, int(100**2)):
            zeros_std = torch.zeros(z[ii].shape[0], 1).to(z)
            sample_s.append(softflow(z[ii], zeros_std, reverse=True))
Example #12
    def sample(self):
        return torch.tensor(inf_train_gen(self.data, batch_size=1)), torch.tensor([])
Example #13
def loss_fn2(genFGen2, args, model):
    genFGen3 = torch.randn((args.batch_size, 2)).to(device)

    # Gaussian fitted to the toy data (empirical mean and diagonal covariance).
    mu = torch.from_numpy(np.array([2.805741, -0.00889241],
                                   dtype="float32")).to(device)
    S = torch.from_numpy(
        np.array([[pow(0.3442525, 2), 0.0], [0.0, pow(0.35358343, 2)]],
                 dtype="float32")).to(device)
    # Mean log-density of the generated samples under the fitted Gaussian.
    storeAll = torch.from_numpy(np.array(0.0, dtype="float32")).to(device)
    toUse_storeAll = torch.distributions.MultivariateNormal(
        loc=mu, covariance_matrix=S)
    for loopIndex_i in range(genFGen2.size()[0]):
        storeAll += toUse_storeAll.log_prob(
            genFGen2[loopIndex_i:1 + loopIndex_i, :].squeeze(0))
    storeAll /= genFGen2.size()[0]

    first_term_loss = storeAll

    xData = toy_data.inf_train_gen(args.data, batch_size=args.batch_size)
    xData = torch.from_numpy(xData).type(torch.float32).to(device)

    # Distance from every generated sample to every data sample; the loss
    # keeps only the single smallest distance over all pairs.

    var1 = []
    for i in genFGen2:
        for j in xData:
            var1.append(torch.dist(i, j, 2).unsqueeze(0))
    var1_tensor = torch.cat(var1)
    second_term_loss = torch.min(var1_tensor) / args.batch_size

    second_term_loss *= 10000.0

    if torch.abs(first_term_loss) / second_term_loss > 200.0:
        second_term_loss /= 2.0

    print('')
    print(first_term_loss)
    print(second_term_loss)

    print('')

    return first_term_loss + second_term_loss
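The per-row loop over MultivariateNormal.log_prob above can be collapsed, since log_prob broadcasts over a leading batch dimension. A minimal sketch, assuming mu, S, and genFGen2 as defined in this example:

gauss = torch.distributions.MultivariateNormal(loc=mu, covariance_matrix=S)
# log_prob on a (batch_size, 2) tensor returns (batch_size,) log-densities.
first_term_loss = gauss.log_prob(genFGen2).mean()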
Example #14
                if test_loss.item() < best_loss:
                    best_loss = test_loss.item()
                    utils.makedirs(args.save)
                    torch.save({
                        'args': args,
                        'state_dict': model.state_dict(),
                    }, os.path.join(args.save, 'checkpt.pth'))
                model.train()

        if itr % args.viz_freq == 0:
            with torch.no_grad():
                model.eval()

                p_samples = toy_data.inf_train_gen(args.data, batch_size=20000)
                sample_fn, density_fn = model.inverse, model.forward

                plt.figure(figsize=(9, 3))
                visualize_transform(p_samples, torch.randn, standard_normal_logprob, transform=sample_fn,
                                    inverse_transform=density_fn, samples=True, npts=400, device=device)

                fig_filename = os.path.join(args.save, 'figs', '{:04d}.jpg'.format(itr))
                print(fig_filename)

                utils.makedirs(os.path.dirname(fig_filename))
                plt.savefig(fig_filename)
Example #15
                    log_message = '[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}'.format(itr, test_loss, test_nfe)
                    logger.info(log_message)

                    if test_loss.item() < best_loss:
                        best_loss = test_loss.item()
                        utils.makedirs(args.save)
                        torch.save({
                            'args': args,
                            'state_dict': model.state_dict(),
                        }, os.path.join(args.save, 'checkpt.pth'))
                    model.train()

            if itr % args.viz_freq == 0:
                with torch.no_grad():
                    model.eval()
                    p_samples = toy_data.inf_train_gen(args.data, rng_seed=420000, batch_size=2000)

                    sample_fn, density_fn = get_transforms(model)

                    plt.figure(figsize=(9, 3))
                    visualize_transform(
                        p_samples, torch.randn, standard_normal_logprob, transform=sample_fn, inverse_transform=density_fn,
                        samples=True, npts=800, device=device
                    ) 
                    fig_filename = os.path.join(args.save, 'figs', '{:04d}.pdf'.format(itr))
                    utils.makedirs(os.path.dirname(fig_filename))
                    plt.savefig(fig_filename)
                    plt.close()
                    model.train()

            end = time.time()
Example #16
def loss_fn2(genFGen2, args, model):
    first_term_loss = compute_loss2(genFGen2, args, model)

    xData = toy_data.inf_train_gen(args.data, batch_size=args.batch_size)
    xData = torch.from_numpy(xData).type(torch.float32).to(device)

    # If you know in advance the size of the final tensor, you can allocate
    # an empty tensor beforehand and fill it in the for loop.
    second_term_loss3 = torch.empty(size=(len(genFGen2), len(xData)),
                                    device=device,
                                    requires_grad=False)
    for i in range(len(genFGen2)):
        for j in range(len(xData)):
            # Squared L2 distance between generated sample i and data sample j.
            second_term_loss3[i, j] = torch.dist(genFGen2[i, :], xData[j, :],
                                                 2).requires_grad_()**2
    second_term_loss2, _ = torch.min(second_term_loss3, 1)
    second_term_loss = 1000.0 * torch.mean(second_term_loss2) / (
        args.batch_size**2)

    print('')
    print(first_term_loss)
    print(second_term_loss)

    print('')

    return first_term_loss + second_term_loss
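The preallocated (N, M) matrix of squared distances above can be produced in one call, and the per-row minimum picks each generated sample's nearest data point. A minimal sketch under the same assumptions:

second_term_loss3 = torch.cdist(genFGen2, xData, p=2) ** 2
second_term_loss2, _ = torch.min(second_term_loss3, 1)  # nearest data point per row
second_term_loss = 1000.0 * torch.mean(second_term_loss2) / (args.batch_size ** 2)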
Example #17
    x = torch.from_numpy(x).type(torch.float32).to(device)
    zero = torch.zeros(x.shape[0], 1).to(x)

    # transform to z
    z, delta_logp = model(x, zero)

    # compute log q(z)
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)

    logpx = logpz - delta_logp
    loss = -torch.mean(logpx)
    return loss


if __name__ == '__main__':
    x = toy_data.inf_train_gen(args.data, batch_size=args.batch_size)
    print(x.shape)
    plt.hist(x, bins=500)
    plt.savefig('testwill.png')

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    # model = build_model_tabular(args, 200, regularization_fns).to(device)
    model = build_model_tabular(args, 1, regularization_fns).to(device)
    if args.spectral_norm: add_spectral_norm(model)
    set_cnf_options(args, model)

    # logger.info(model)
    logger.info("Number of trainable parameters: {}".format(
        count_parameters(model)))

    optimizer = optim.Adam(model.parameters(),
Example #18
    alph    = checkpt['args'].alph
    nTh     = checkpt['args'].nTh
    d       = checkpt['state_dict']['A'].size(1) - 1
    net     = Phi(nTh=nTh, m=m, d=d, alph=alph)
    prec    = checkpt['state_dict']['A'].dtype
    net     = net.to(prec)
    net.load_state_dict(checkpt['state_dict'])
    net     = net.to(device)

    args.data = checkpt['args'].data

    torch.set_default_dtype(prec)
    cvt = lambda x: x.type(prec).to(device, non_blocking=True)

    nSamples = args.batch_size
    p_samples = cvt(torch.Tensor(toy_data.inf_train_gen(args.data, batch_size=nSamples)))
    y         = cvt(torch.randn(nSamples,d))

    net.eval()
    with torch.no_grad():

        test_loss, test_cs = compute_loss(net, p_samples, args.nt)

        # sample_fn, density_fn = get_transforms(model)
        modelFx     = integrate(p_samples[:, 0:d], net, [0.0, 1.0], args.nt, stepper="rk4", alph=net.alph)
        modelFinvfx = integrate(modelFx[:, 0:d]  , net, [1.0, 0.0], args.nt, stepper="rk4", alph=net.alph)
        modelGen    = integrate(y[:, 0:d]        , net, [1.0, 0.0], args.nt, stepper="rk4", alph=net.alph)

        print("          {:9s}  {:9s}  {:11s}  {:9s}".format( "loss", "L (L_2)", "C (loss)", "R (HJB)"))
        print("[TEST]:   {:9.3e}  {:9.3e}  {:11.5e}  {:9.3e}".format(test_loss, test_cs[0], test_cs[1], test_cs[2]))
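Since modelFinvfx is the inverse map applied to modelFx, the round trip should approximately recover p_samples. A small hedged check (names as above; the per-sample error metric is an illustrative choice, not from the source):

inv_err = torch.norm(p_samples[:, 0:d] - modelFinvfx[:, 0:d]) / nSamples
print("mean inversion error per sample: {:9.3e}".format(inv_err.item()))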
Example #19
def summary_plots(x, x_test, folder, epoch, model, ll_tot, ll_test):
    fig = plt.figure()
    ax = plt.subplot(1, 3, 1, aspect="equal")
    vf.plt_flow(model.compute_ll, ax)
    ax = plt.subplot(1, 3, 2, aspect="equal")
    vf.plt_samples(toy_data.inf_train_gen(toy, batch_size=1000000),
                   ax,
                   npts=200)
    ax = plt.subplot(1, 3, 3, aspect="equal")
    samples = model.invert(
        torch.distributions.Normal(0., 1.).sample([1000, 2]), 8, "Binary")
    vf.plt_samples(samples.detach().numpy(), ax, title=r"$x\sim q(x)$")
    plt.savefig("%s/flow_%d.png" % (folder + toy, epoch))
    plt.close(fig)
    fig = plt.figure()

    z = torch.distributions.Normal(0., 1.).sample(x_test.shape)
    plt.subplot(222)
    plt.title("z")
    plt.scatter(z[:, 0], z[:, 1], alpha=.2)
    x_min = z.min(0)[0] - .5
    x_max = z.max(0)[0] + .5
    ticks = [1, 1]
    plt.xticks(np.arange(int(x_min[0]), int(x_max[0]), ticks[0]),
               np.arange(int(x_min[0]), int(x_max[0]), ticks[0]))
    plt.yticks(np.arange(int(x_min[1]), int(x_max[1]), ticks[1]),
               np.arange(int(x_min[1]), int(x_max[1]), ticks[1]))

    z_pred = model.forward(x_test)
    z_pred = z_pred.detach().cpu().numpy()
    plt.subplot(221)
    plt.title("z pred")
    plt.scatter(z_pred[:, 0], z_pred[:, 1], alpha=.2)
    plt.xticks(np.arange(int(x_min[0]), int(x_max[0]), ticks[0]),
               np.arange(int(x_min[0]), int(x_max[0]), ticks[0]))
    plt.yticks(np.arange(int(x_min[1]), int(x_max[1]), ticks[1]),
               np.arange(int(x_min[1]), int(x_max[1]), ticks[1]))

    start = timer()
    z = torch.distributions.Normal(0., 1.).sample((1000, 2))
    x_pred = model.invert(z, 5, "ParallelSimpler")
    end = timer()
    print("Inversion time: {:4f}s".format(end - start))
    plt.subplot(223)
    plt.title("x pred")
    x_pred = x_pred.detach().cpu().numpy()
    plt.scatter(x_pred[:, 0], x_pred[:, 1], alpha=.2)
    x_min = x.min(0)[0] - .5
    x_max = x.max(0)[0] + .5
    ticks = [1, 1]
    plt.xticks(np.arange(int(x_min[0]), int(x_max[0]), ticks[0]),
               np.arange(int(x_min[0]), int(x_max[0]), ticks[0]))
    plt.yticks(np.arange(int(x_min[1]), int(x_max[1]), ticks[1]),
               np.arange(int(x_min[1]), int(x_max[1]), ticks[1]))

    plt.subplot(224)
    plt.title("x")
    plt.scatter(x[:, 0], x[:, 1], alpha=.2)
    plt.xticks(np.arange(int(x_min[0]), int(x_max[0]), ticks[0]),
               np.arange(int(x_min[0]), int(x_max[0]), ticks[0]))
    plt.yticks(np.arange(int(x_min[1]), int(x_max[1]), ticks[1]),
               np.arange(int(x_min[1]), int(x_max[1]), ticks[1]))

    plt.suptitle(
        str(("epoch: ", epoch, "Train loss: ", ll_tot.item(), "Test loss: ",
             ll_test.item())))
    plt.savefig("%s/%d.png" % (folder + toy, epoch))
    plt.close(fig)
Example #20
                log_message = '[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}'.format(itr, test_loss, test_nfe)
                logger.info(log_message)

                if test_loss.item() < best_loss:
                    best_loss = test_loss.item()
                    utils.makedirs(args.save)
                    torch.save({
                        'args': args,
                        'state_dict': model.state_dict(),
                    }, os.path.join(args.save, 'checkpt.pth'))
                model.train()

        if itr % args.viz_freq == 0:
            with torch.no_grad():
                model.eval()
                p_samples = toy_data.inf_train_gen(args.data, batch_size=2000)

                sample_fn, density_fn = get_transforms(model)

                plt.figure(figsize=(9, 3))
                visualize_transform(
                    p_samples, torch.randn, standard_normal_logprob, transform=sample_fn, inverse_transform=density_fn,
                    samples=True, npts=800, device=device
                )
                fig_filename = os.path.join(args.save, 'figs', '{:04d}.pdf'.format(itr))
                utils.makedirs(os.path.dirname(fig_filename))
                plt.savefig(fig_filename)
                plt.close()
                model.train()

        end = time.time()
Example #21
def loss_fn2(genFGen2, args, model):
    first_term_loss = compute_loss2(genFGen2, args, model)

    xData = toy_data.inf_train_gen(args.data, batch_size=args.batch_size)
    xData = torch.from_numpy(xData).type(torch.float32).to(device)

    # For each generated sample, the distance to its nearest data point.
    var2 = []
    for i in genFGen2:
        var1 = []
        for j in xData:
            var1.append(torch.dist(i, j, 2).unsqueeze(0))
        var1_tensor = torch.cat(var1)
        second_term_loss2 = torch.min(var1_tensor) / args.batch_size
        var2.append(second_term_loss2.unsqueeze(0))
    var2_tensor = torch.cat(var2)
    second_term_loss = torch.mean(var2_tensor) / args.batch_size
    second_term_loss *= 100.0

    print('')
    print(first_term_loss)
    print(second_term_loss)

    print('')

    return first_term_loss + second_term_loss
Example #22
        model = construct_discrete_model().to(device)
        model.load_state_dict(torch.load(args.checkpt)['state_dict'])
    else:
        model = build_model_tabular(args, 2).to(device)

        sd = torch.load(args.checkpt)['state_dict']
        fixed_sd = {}
        for k, v in sd.items():
            fixed_sd[k.replace('odefunc.odefunc', 'odefunc')] = v
        model.load_state_dict(fixed_sd)

    print(model)
    print("Number of trainable parameters: {}".format(count_parameters(model)))

    model.eval()
    p_samples = toy_data.inf_train_gen(args.data, batch_size=800**2)

    with torch.no_grad():
        sample_fn, density_fn = get_transforms(model)

        plt.figure(figsize=(10, 10))
        ax = plt.gca()
        viz_flow.plt_samples(p_samples, ax, npts=800)
        plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
        fig_filename = os.path.join(args.save, 'figs', 'true_samples.jpg')
        utils.makedirs(os.path.dirname(fig_filename))
        plt.savefig(fig_filename)
        plt.close()

        plt.figure(figsize=(10, 10))
        ax = plt.gca()
Example #23
def train_toy(toy,
              load=True,
              nb_step_dual=300,
              nb_steps=15,
              folder="",
              l1=1.,
              nb_epoch=20000,
              pre_heating_epochs=10,
              nb_flow=3,
              cond_type="Coupling",
              emb_net=[150, 150, 150]):
    logger = utils.get_logger(logpath=os.path.join(folder, toy, 'logs'),
                              filepath=os.path.abspath(__file__))

    logger.info("Creating model...")

    device = "cpu" if not (torch.cuda.is_available()) else "cuda:0"

    nb_samp = 100
    batch_size = 100

    x_test = torch.tensor(toy_data.inf_train_gen(toy,
                                                 batch_size=1000)).to(device)
    x = torch.tensor(toy_data.inf_train_gen(toy, batch_size=1000)).to(device)

    dim = x.shape[1]

    norm_type = "Affine"
    save_name = norm_type + str(emb_net) + str(nb_flow)
    solver = "CCParallel"
    int_net = [150, 150, 150]

    conditioner_type = cond_types[cond_type]
    conditioner_args = {
        "in_size": dim,
        "hidden": emb_net[:-1],
        "out_size": emb_net[-1]
    }
    if conditioner_type is DAGConditioner:
        conditioner_args['l1'] = l1
        conditioner_args['gumble_T'] = .5
        conditioner_args['nb_epoch_update'] = nb_step_dual
        conditioner_args["hot_encoding"] = True
    normalizer_type = norm_types[norm_type]
    if normalizer_type is MonotonicNormalizer:
        normalizer_args = {
            "integrand_net": int_net,
            "cond_size": emb_net[-1],
            "nb_steps": nb_steps,
            "solver": solver
        }
    else:
        normalizer_args = {}

    model = buildFCNormalizingFlow(nb_flow, conditioner_type, conditioner_args,
                                   normalizer_type, normalizer_args)

    opt = torch.optim.Adam(model.parameters(), 1e-4, weight_decay=1e-5)

    if load:
        logger.info("Loading model...")
        model.load_state_dict(
            torch.load(folder + toy + '/' + save_name + 'model.pt'))
        model.train()
        opt.load_state_dict(
            torch.load(folder + toy + '/' + save_name + 'ADAM.pt'))
        logger.info("Model loaded.")

    if True:
        for step in model.steps:
            step.conditioner.stoch_gate = True
            step.conditioner.noise_gate = False
            step.conditioner.gumble_T = .5
    torch.autograd.set_detect_anomaly(True)
    for epoch in range(nb_epoch):
        loss_tot = 0
        start = timer()
        for j in range(0, nb_samp, batch_size):
            cur_x = torch.tensor(
                toy_data.inf_train_gen(toy, batch_size=batch_size)).to(device)
            z, jac = model(cur_x)
            loss = model.loss(z, jac)
            loss_tot += loss.detach()
            if math.isnan(loss.item()):
                ll, z = model.compute_ll(cur_x)
                print(ll)
                print(z)
                print(ll.max(), z.max())
                exit()
            opt.zero_grad()
            loss.backward(retain_graph=True)
            opt.step()
        model.step(epoch, loss_tot)

        end = timer()
        z, jac = model(x_test)
        ll = (model.z_log_density(z) + jac)
        ll_test = -ll.mean()
        dagness = max(model.DAGness())
        logger.info(
            "epoch: {:d} - Train loss: {:4f} - Test loss: {:4f} - <<DAGness>>: {:4f} - Elapsed time per epoch {:4f} (seconds)"
            .format(epoch, loss_tot.item(), ll_test.item(), dagness,
                    end - start))

        if epoch % 100 == 0 and False:
            with torch.no_grad():
                stoch_gate = model.getDag().stoch_gate
                noise_gate = model.getDag().noise_gate
                s_thresh = model.getDag().s_thresh
                model.getDag().stoch_gate = False
                model.getDag().noise_gate = False
                model.getDag().s_thresh = True
                for threshold in [.95, .1, .01, .0001, 1e-8]:
                    model.set_h_threshold(threshold)
                    # Valid loop
                    z, jac = model(x_test)
                    ll = (model.z_log_density(z) + jac)
                    ll_test = -ll.mean().item()
                    dagness = max(model.DAGness()).item()
                    logger.info(
                        "epoch: {:d} - Threshold: {:4f} - Valid log-likelihood: {:4f} - <<DAGness>>: {:4f}"
                        .format(epoch, threshold, ll_test, dagness))
                model.getDag().stoch_gate = stoch_gate
                model.getDag().noise_gate = noise_gate
                model.getDag().s_thresh = s_thresh
                model.set_h_threshold(0.)

        if epoch % 500 == 0:
            font = {'family': 'normal', 'weight': 'normal', 'size': 25}

            matplotlib.rc('font', **font)
            if toy in [
                    "2spirals-8gaussians", "4-2spirals-8gaussians",
                    "8-2spirals-8gaussians", "2gaussians", "4gaussians",
                    "2igaussians", "8gaussians"
            ] or True:

                def compute_ll(x):
                    z, jac = model(x)
                    ll = (model.z_log_density(z) + jac)
                    return ll, z

                with torch.no_grad():
                    npts = 100
                    plt.figure(figsize=(12, 12))
                    gs = gridspec.GridSpec(2,
                                           2,
                                           width_ratios=[3, 1],
                                           height_ratios=[3, 1])
                    ax = plt.subplot(gs[0])
                    qz_1, qz_2 = vf.plt_flow(compute_ll,
                                             ax,
                                             npts=npts,
                                             device=device)
                    plt.subplot(gs[1])
                    plt.plot(qz_1, np.linspace(-4, 4, npts))
                    plt.ylabel('$x_2$', fontsize=25, rotation=-90, labelpad=20)

                    plt.xticks([])
                    plt.subplot(gs[2])
                    plt.plot(np.linspace(-4, 4, npts), qz_2)
                    plt.xlabel('$x_1$', fontsize=25)
                    plt.yticks([])
                    plt.savefig("%s%s/flow_%s_%d.pdf" %
                                (folder, toy, save_name, epoch))
                    torch.save(model.state_dict(),
                               folder + toy + '/' + save_name + 'model.pt')
                    torch.save(opt.state_dict(),
                               folder + toy + '/' + save_name + 'ADAM.pt')
Example #24
def loss_fn2(genFGen2, args, model):
    # Gaussian fitted to the toy data (empirical mean and diagonal covariance).
    mu = torch.from_numpy(np.array([2.805741, -0.00889241], dtype="float32")).to(device)
    S = torch.from_numpy(np.array([[pow(0.3442525, 2), 0.0], [0.0, pow(0.35358343, 2)]], dtype="float32")).to(device)

    storeAll = torch.from_numpy(np.array(0.0, dtype="float32")).to(device)
    toUse_storeAll = torch.distributions.MultivariateNormal(loc=mu, covariance_matrix=S)
    for loopIndex_i in range(genFGen2.size()[0]):
        storeAll += toUse_storeAll.log_prob(genFGen2[loopIndex_i:1+loopIndex_i, :].squeeze(0))
    storeAll /= genFGen2.size()[0]

    first_term_loss = storeAll

    xData = toy_data.inf_train_gen(args.data, batch_size=args.batch_size)
    xData = torch.from_numpy(xData).type(torch.float32).to(device)

    var2 = []
    for i in genFGen2:
        var1 = []
        for j in xData:
            var1.append(torch.dist(i, j, 2).unsqueeze(0))
        var1_tensor = torch.cat(var1)
        second_term_loss2 = torch.min(var1_tensor) / args.batch_size
        var2.append(second_term_loss2.unsqueeze(0))
    var2_tensor = torch.cat(var2)
    second_term_loss = torch.mean(var2_tensor) / args.batch_size

    first_term_loss = torch.exp(first_term_loss)

    print('')
    print(first_term_loss)
    print(second_term_loss)

    print('')

    return first_term_loss + second_term_loss
Example #25
    logger.info("-------------------------")
    logger.info(str(optim))  # optimizer info
    logger.info("data={:} batch_size={:} gpu={:}".format(
        args.data, args.batch_size, args.gpu))
    logger.info("maxIters={:} val_freq={:} viz_freq={:}".format(
        args.niters, args.val_freq, args.viz_freq))
    logger.info("saveLocation = {:}".format(args.save))
    logger.info("-------------------------\n")

    end = time.time()
    best_loss = float('inf')
    bestParams = None

    # setup data [nSamples, d]
    # use one batch as the entire data set
    x0 = toy_data.inf_train_gen(args.data, batch_size=args.batch_size)
    x0 = cvt(torch.from_numpy(x0))

    x0val = toy_data.inf_train_gen(args.data, batch_size=args.val_batch_size)
    x0val = cvt(torch.from_numpy(x0val))

    log_msg = (
        '{:5s}  {:6s}   {:9s}  {:9s}  {:9s}  {:9s}      {:9s}  {:9s}  {:9s}  {:9s}  '
        .format('iter', ' time', 'loss', 'L (L_2)', 'C (loss)', 'R (HJB)',
                'valLoss', 'valL', 'valC', 'valR'))
    logger.info(log_msg)

    time_meter = utils.AverageMeter()

    net.train()
    for itr in range(1, args.niters + 1):