Ejemplo n.º 1
0
 def reparameterize(self, mu, var):
     """Sample a latent code with the VAE reparameterization trick.

     ``var`` is a log-variance: the standard deviation used is
     ``exp(0.5 * var)``. In training mode this returns ``mu + eps * std``
     with ``eps`` drawn from a standard normal, so gradients flow through
     ``mu`` and ``var``. Outside training mode ``mu`` itself is returned,
     making inference deterministic.
     """
     if not self.training:
         return mu
     sigma = var.mul(0.5).exp_()
     # Fresh standard-normal noise with sigma's shape/dtype, moved to the
     # right device by the project helper.
     noise = utils.var_or_cuda((sigma.data.new(sigma.size()).normal_()))
     return noise.mul(sigma).add_(mu)
Ejemplo n.º 2
0
    def forward(self, images):
        """Encode every view, sample a latent per view, and fuse them.

        Args:
            images: iterable of per-view image batches. View ``i`` is written
                into row ``i`` of the per-view buffers, so it must yield at
                most ``self.args.num_views`` items.

        Returns:
            ``(combined_z, means, log_vars)``:
            - ``combined_z`` is ``self.combine`` applied to the per-view
              samples.
            - ``means`` and ``log_vars`` hold the per-view encoder outputs.
              Each has shape ``(num_views, batch_size, 200)``; rows for views
              not supplied stay zero.
        """
        # NOTE(review): the latent width 200 is hard-coded; presumably it
        # matches the output size of single_image_forward — confirm.
        means = utils.var_or_cuda(
            torch.zeros(self.args.num_views, self.args.batch_size, 200))
        log_vars = utils.var_or_cuda(
            torch.zeros(self.args.num_views, self.args.batch_size, 200))
        zs = utils.var_or_cuda(
            torch.zeros(self.args.num_views, self.args.batch_size, 200))
        for i, image in enumerate(images):
            image = utils.var_or_cuda(image)
            z_mean, z_log_var = self.single_image_forward(image)
            # Write only this view's row. The former ``[i:]`` slice
            # overwrote every remaining view with the current one. That cost
            # O(num_views**2) copies and left duplicates of the last view in
            # any unused trailing rows.
            zs[i] = self.reparameterize(z_mean, z_log_var)
            means[i] = z_mean
            log_vars[i] = z_log_var

        return self.combine(zs), means, log_vars
def test_3DVAEGAN(args):
    """Evaluate a trained 3D-VAE-GAN on the test split and save samples.

    Steps:
    - Load (image, voxel) pairs from ``<input_dir><data_dir>test/``.
    - Restore the encoder ``E``, generator ``G`` and their solvers from the
      ``3DVAEGAN`` pickle.
    - Decode each image's latent back to a voxel grid and print the
      per-batch squared-error reconstruction loss.
    - Save up to 8 reconstructions per batch with ``SavePloat_Voxels``.

    Args:
        args: parsed command-line namespace (paths, ``batch_size``,
            ``g_lr``, ``beta``, ...).

    Returns:
        float: summed squared error averaged over all test samples, or
        ``0.0`` when the loader yields nothing.
    """
    # dataset definition
    dsets_path = args.input_dir + args.data_dir + "test/"
    print(dsets_path)
    dsets = ShapeNetPlusImageDataset(dsets_path, args)
    dset_loaders = torch.utils.data.DataLoader(dsets,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)

    # model definition
    E = _E(args)
    G = _G(args)

    # The solvers are never stepped here; they exist so read_pickle can
    # restore their state together with the models.
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)
    E_solver = optim.Adam(E.parameters(), lr=args.g_lr, betas=args.beta)

    if torch.cuda.is_available():
        print("using cuda")
        G.cuda()
        E.cuda()

    pickle_path = "." + args.pickle_dir + '3DVAEGAN'
    # NOTE(review): G/G_solver are passed twice, presumably to fill a
    # discriminator slot this model does not have — confirm read_pickle's
    # signature.
    read_pickle(pickle_path, G, G_solver, G, G_solver, E, E_solver)

    # The output directory does not depend on the batch; create it once.
    image_path = args.output_dir + args.image_dir + '3DVAEGAN_test'
    os.makedirs(image_path, exist_ok=True)

    recon_loss_total = 0.0
    n_samples = 0
    # Evaluation only: disabling autograd stops every batch's graph from
    # being retained.
    with torch.no_grad():
        for i, (image, model_3d) in enumerate(dset_loaders):

            X = var_or_cuda(model_3d)
            image = var_or_cuda(image)

            z_mu, z_var = E(image)
            Z_vae = E.reparameterize(z_mu, z_var)
            G_vae = G(Z_vae)

            recon_loss = torch.sum(torch.pow((G_vae - X), 2), dim=(1, 2, 3))
            print(recon_loss.size())
            print("RECON LOSS ITER: ", i, " - ", torch.mean(recon_loss))
            # Accumulate a Python float. Adding the tensor itself fails with
            # a shape mismatch when the final batch is smaller.
            recon_loss_total += recon_loss.sum().item()
            n_samples += X.size(0)
            samples = G_vae.cpu().data[:8].squeeze().numpy()

            SavePloat_Voxels(samples, image_path, i)

    return recon_loss_total / n_samples if n_samples else 0.0
Ejemplo n.º 4
0
def test_3DGAN(args):
    """Sample a trained 3D-GAN generator against the test split.

    Steps:
    - Load voxel grids from ``<input_dir><data_dir>test/``.
    - Restore ``G``/``D`` and their solvers from the
      ``3DVAEGAN_MULTIVIEW_MAX`` pickle.
    - For every batch, generate voxels from a fresh random ``Z`` and print
      the squared-error distance to the batch.
    - Save up to 8 generated grids per batch with ``SavePloat_Voxels``.

    NOTE(review): ``Z`` is random rather than encoded from ``X``, so the
    "recon loss" measures distance to an unrelated sample, not a
    reconstruction.

    Args:
        args: parsed command-line namespace (paths, ``batch_size``,
            ``d_lr``, ``g_lr``, ``beta``, ...).

    Returns:
        float: summed squared error averaged over all test samples, or
        ``0.0`` when the loader yields nothing.
    """
    # dataset definition
    dsets_path = args.input_dir + args.data_dir + "test/"
    print(dsets_path)
    dsets = ShapeNetDataset(dsets_path, args)
    dset_loaders = torch.utils.data.DataLoader(dsets,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)

    # model definition
    D = _D(args)
    G = _G(args)

    # Solvers are only needed so read_pickle can restore their state.
    D_solver = optim.Adam(D.parameters(), lr=args.d_lr, betas=args.beta)
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)

    if torch.cuda.is_available():
        print("using cuda")
        D.cuda()
        G.cuda()

    pickle_path = "." + args.pickle_dir + '3DVAEGAN_MULTIVIEW_MAX'
    read_pickle(pickle_path, G, G_solver, D, D_solver)

    # The output directory does not depend on the batch; create it once.
    image_path = args.output_dir + args.image_dir + '3DVAEGAN_MULTIVIEW_MAX_test'
    os.makedirs(image_path, exist_ok=True)

    recon_loss_total = 0.0
    n_samples = 0
    # Evaluation only: no autograd graph is needed or retained.
    with torch.no_grad():
        for i, X in enumerate(dset_loaders):
            X = var_or_cuda(X)
            print(X.size())
            Z = generateZ(args)
            print(Z.size())
            fake = G(Z).squeeze()
            print(fake.size())
            recon_loss = torch.sum(torch.pow((fake - X), 2), dim=(1, 2, 3))
            print(recon_loss.size())
            print("RECON LOSS ITER: ", i, " - ", torch.mean(recon_loss))
            # Accumulate a Python float, not a graph-carrying tensor whose
            # shape depends on the batch size.
            recon_loss_total += recon_loss.sum().item()
            n_samples += X.size(0)
            samples = fake.cpu().data[:8].squeeze().numpy()

            SavePloat_Voxels(samples, image_path, i)

    return recon_loss_total / n_samples if n_samples else 0.0
Ejemplo n.º 5
0
def train(args):
    """Train the voxel GAN (``_G`` generator vs. ``_D`` discriminator).

    Training data comes from ``ShapeNetDataset`` under
    ``<input_dir><data_dir>train/``. A run name built from the
    hyper-parameters below names the checkpoint, image and tensorboard-log
    locations.

    Once per epoch, losses and discriminator accuracy are printed and, when
    ``args.use_tensorboard`` is set, logged to tensorboard. Every
    ``image_save_step`` epochs sample voxels are saved; every
    ``pickle_step`` epochs a checkpoint is saved.

    Args:
        args: parsed command-line namespace providing the paths, learning
            rates, ``beta``, ``batch_size``, ``n_epochs``, ``soft_label``,
            ``d_thresh``, ``lrsh`` and save-step settings read below.
    """
    hyparam_list = [
        ("model", args.model_name),
        ("cube", args.cube_len),
        ("bs", args.batch_size),
        ("g_lr", args.g_lr),
        ("d_lr", args.d_lr),
        ("z", args.z_dis),
        ("bias", args.bias),
        ("sl", args.soft_label),
    ]

    hyparam_dict = OrderedDict(hyparam_list)
    log_param = make_hyparam_string(hyparam_dict)
    print(log_param)

    # for using tensorboard
    if args.use_tensorboard:
        import tensorflow as tf

        summary_writer = tf.summary.FileWriter(args.output_dir + args.log_dir +
                                               log_param)

        def inject_summary(summary_writer, tag, value, step):
            """Write one scalar ``value`` under ``tag`` at ``step``."""
            summary = tf.Summary(
                value=[tf.Summary.Value(tag=tag, simple_value=value)])
            summary_writer.add_summary(summary, global_step=step)

    # dataset definition
    dsets_path = args.input_dir + args.data_dir + "train/"
    print(dsets_path)
    dsets = ShapeNetDataset(dsets_path, args)
    dset_loaders = torch.utils.data.DataLoader(dsets,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)

    # model definition
    D = _D(args)
    G = _G(args)

    D_solver = optim.Adam(D.parameters(), lr=args.d_lr, betas=args.beta)
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)

    if args.lrsh:
        # Discriminator LR decays at epochs 500 and 1000; the scheduler is
        # stepped once per epoch at the end of the loop below.
        D_scheduler = MultiStepLR(D_solver, milestones=[500, 1000])

    if torch.cuda.is_available():
        print("using cuda")
        D.cuda()
        G.cuda()

    criterion = nn.BCELoss()

    pickle_path = "." + args.pickle_dir + log_param
    read_pickle(pickle_path, G, G_solver, D, D_solver)

    for epoch in range(args.n_epochs):
        for i, X in enumerate(dset_loaders):

            X = var_or_cuda(X)

            # Skip a trailing partial batch: Z and the labels are sized to
            # args.batch_size.
            if X.size()[0] != int(args.batch_size):
                continue

            Z = generateZ(args)
            real_labels = var_or_cuda(torch.ones(args.batch_size))
            fake_labels = var_or_cuda(torch.zeros(args.batch_size))

            if args.soft_label:
                # NOTE(review): real targets up to 1.2 fall outside the
                # [0, 1] target range BCELoss documents — confirm intended.
                real_labels = var_or_cuda(
                    torch.Tensor(args.batch_size).uniform_(0.7, 1.2))
                fake_labels = var_or_cuda(
                    torch.Tensor(args.batch_size).uniform_(0, 0.3))

            # ============= Train the discriminator =============#
            d_real = D(X)
            d_real_loss = criterion(d_real, real_labels)

            fake = G(Z)
            # Detach so the D update does not backprop into G. G's gradients
            # are zeroed before its own step, so this only saves work.
            d_fake = D(fake.detach())
            d_fake_loss = criterion(d_fake, fake_labels)

            d_loss = d_real_loss + d_fake_loss

            d_real_acu = torch.ge(d_real.squeeze(), 0.5).float()
            d_fake_acu = torch.le(d_fake.squeeze(), 0.5).float()
            d_total_acu = torch.mean(torch.cat((d_real_acu, d_fake_acu), 0))

            # Update D only while its accuracy is at or below args.d_thresh.
            if d_total_acu <= args.d_thresh:
                D.zero_grad()
                d_loss.backward()
                D_solver.step()

            # =============== Train the generator ===============#

            Z = generateZ(args)

            fake = G(Z)
            d_fake = D(fake)
            g_loss = criterion(d_fake, real_labels)

            D.zero_grad()
            G.zero_grad()
            g_loss.backward()
            G_solver.step()

        # =============== logging once per epoch ===============#
        # Global step = Adam's step counter for G's first parameter. It is an
        # int on older PyTorch and a tensor on newer releases; int()
        # normalizes both so the label reads e.g. "1200".
        g_state = G_solver.state_dict()
        first_param = g_state['param_groups'][0]['params'][0]
        iteration = str(int(g_state['state'][first_param]['step']))
        if args.use_tensorboard:
            log_save_path = args.output_dir + args.log_dir + log_param
            if not os.path.exists(log_save_path):
                os.makedirs(log_save_path)

            # .item() replaces .data[0], which raises on the 0-dim loss
            # tensors produced by PyTorch >= 0.4.
            info = {
                'loss/loss_D_R': d_real_loss.item(),
                'loss/loss_D_F': d_fake_loss.item(),
                'loss/loss_D': d_loss.item(),
                'loss/loss_G': g_loss.item(),
                'loss/acc_D': d_total_acu.item()
            }

            for tag, value in info.items():
                inject_summary(summary_writer, tag, value, iteration)

            summary_writer.flush()

        # =============== each epoch save model or save image ===============#
        print(
            'Iter-{}; , D_loss : {:.4}, G_loss : {:.4}, D_acu : {:.4}, D_lr : {:.4}'
            .format(iteration, d_loss.item(), g_loss.item(),
                    d_total_acu.item(),
                    D_solver.state_dict()['param_groups'][0]["lr"]))

        if (epoch + 1) % args.image_save_step == 0:

            samples = fake.cpu().data[:8].squeeze().numpy()

            image_path = args.output_dir + args.image_dir + log_param
            if not os.path.exists(image_path):
                os.makedirs(image_path)

            SavePloat_Voxels(samples, image_path, iteration)

        if (epoch + 1) % args.pickle_step == 0:
            pickle_save_path = args.output_dir + args.pickle_dir + log_param
            save_new_pickle(pickle_save_path, iteration, G, G_solver, D,
                            D_solver)

        if args.lrsh:
            # Best-effort: a scheduler failure is reported, not fatal.
            try:
                D_scheduler.step()
            except Exception as e:
                print("fail lr scheduling", e)
Ejemplo n.º 6
0
def train(args):
    """Train the voxel GAN on the 3D-MNIST voxel array, plotting to Visdom.

    Data and models:
    - Loads ``voxels_3DMNIST_16.npy`` from the working directory and
      flattens each sample to ``cube_len ** 3`` values.
    - Trains ``_G`` against ``_D`` with BCE losses.

    Output:
    - Every 300 generator steps, up to 8 generated samples go to the Visdom
      server at http://localhost:8097.
    - Every ``image_save_step`` epochs sample voxels are saved to disk.
    - Every ``pickle_step`` epochs a checkpoint is saved.

    Args:
        args: parsed command-line namespace providing the paths, learning
            rates, ``beta``, ``batch_size``, ``cube_len``, ``n_epochs``,
            ``soft_label``, ``d_thresh`` and save-step settings read below.
    """
    # for creating the visdom object
    DEFAULT_PORT = 8097
    DEFAULT_HOSTNAME = "http://localhost"
    viz = Visdom(DEFAULT_HOSTNAME, DEFAULT_PORT, ipv6=False)

    hyparam_list = [
        ("model", args.model_name),
        ("cube", args.cube_len),
        ("bs", args.batch_size),
        ("g_lr", args.g_lr),
        ("d_lr", args.d_lr),
        ("z", args.z_dis),
        ("bias", args.bias),
        ("sl", args.soft_label),
    ]

    hyparam_dict = OrderedDict(hyparam_list)
    log_param = make_hyparam_string(hyparam_dict)
    print(log_param)

    # for using tensorboard
    # NOTE(review): the writer and helper are created but not used for
    # logging anywhere in this function.
    if args.use_tensorboard:
        import tensorflow as tf

        summary_writer = tf.summary.FileWriter(args.output_dir + args.log_dir +
                                               log_param)

        def inject_summary(summary_writer, tag, value, step):
            """Write one scalar ``value`` under ``tag`` at ``step``."""
            summary = tf.Summary(
                value=[tf.Summary.Value(tag=tag, simple_value=value)])
            summary_writer.add_summary(summary, global_step=step)

    # dataset definition
    dsets_path = args.input_dir + args.data_dir + "train/"
    # NOTE(review): printed only; the data below comes from the .npy file.
    print(dsets_path)

    x_train = np.load("voxels_3DMNIST_16.npy")
    dataset = x_train.reshape(-1,
                              args.cube_len * args.cube_len * args.cube_len)
    print(dataset.shape)
    dset_loaders = torch.utils.data.DataLoader(dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)

    # model definition
    D = _D(args)
    G = _G(args)

    D_solver = optim.Adam(D.parameters(), lr=args.d_lr, betas=args.beta)
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)

    if torch.cuda.is_available():
        print("using cuda")
        D.cuda()
        G.cuda()

    criterion = nn.BCELoss()

    pickle_path = "." + args.pickle_dir + log_param
    read_pickle(pickle_path, G, G_solver, D, D_solver)

    for epoch in range(args.n_epochs):
        epoch_start_time = time.time()
        print("epoch %d started" % (epoch))
        for i, X in enumerate(dset_loaders):

            # Cast to float32 and keep the batch on the models' device. The
            # previous hard-coded torch.cuda.FloatTensor cast crashed on
            # CPU-only machines.
            X = var_or_cuda(X).float()
            if torch.cuda.is_available():
                X = X.cuda()
            # Skip a trailing partial batch: Z and the labels are sized to
            # args.batch_size.
            if X.size()[0] != int(args.batch_size):
                continue

            Z = generateZ(args)
            # Labels are reshaped to (batch, 1, 1, 1, 1) — presumably to
            # match D's output shape; confirm against _D.
            real_labels = var_or_cuda(torch.ones(args.batch_size)).view(
                -1, 1, 1, 1, 1)
            fake_labels = var_or_cuda(torch.zeros(args.batch_size)).view(
                -1, 1, 1, 1, 1)

            if args.soft_label:
                # Only the real targets are smoothed; fake targets stay 0.
                real_labels = var_or_cuda(
                    torch.Tensor(args.batch_size).uniform_(0.9, 1.1)).view(
                        -1, 1, 1, 1, 1)
                fake_labels = var_or_cuda(torch.zeros(args.batch_size)).view(
                    -1, 1, 1, 1, 1)
            # ============= Train the discriminator =============#
            d_real = D(X)
            d_real_loss = criterion(d_real, real_labels)

            fake = G(Z)
            # Detach so the D update does not backprop into G. G's gradients
            # are zeroed before its own step, so this only saves work.
            d_fake = D(fake.detach())
            d_fake_loss = criterion(d_fake, fake_labels)

            d_loss = d_real_loss + d_fake_loss

            d_real_acu = torch.ge(d_real.squeeze(), 0.5).float()
            d_fake_acu = torch.le(d_fake.squeeze(), 0.5).float()
            d_total_acu = torch.mean(torch.cat((d_real_acu, d_fake_acu), 0))

            # Update D only while its accuracy is at or below args.d_thresh.
            if d_total_acu <= args.d_thresh:
                D.zero_grad()
                d_loss.backward()
                D_solver.step()

            # =============== Train the generator ===============#

            Z = generateZ(args)

            fake = G(Z)
            d_fake = D(fake)
            g_loss = criterion(d_fake, real_labels)

            D.zero_grad()
            G.zero_grad()
            g_loss.backward()
            G_solver.step()

            # =============== logging each iteration ===============#
            # Global step = Adam's step counter for G's first parameter. It
            # is an int on older PyTorch and a tensor on newer releases;
            # int() normalizes both.
            g_state = G_solver.state_dict()
            first_param = g_state['param_groups'][0]['params'][0]
            step = int(g_state['state'][first_param]['step'])
            iteration = str(step)
            # Every 300 generator steps, push generated samples to Visdom.
            if step % 300 == 0:
                samples = fake.cpu().data[:8].squeeze().numpy()
                # Iterate over what was actually sliced, so batches smaller
                # than 8 do not raise IndexError.
                for sample in samples:
                    # Plain "{}": the former "{:.4}" truncated the step
                    # string to 4 characters (e.g. "12300" -> "1230").
                    plotVoxelVisdom(sample, viz,
                                    "Iteration:{}".format(iteration))

            print(
                'Iter-{}; , D_loss : {:.4}, G_loss : {:.4}, D_acu : {:.4}, D_lr : {:.4}'
                .format(iteration, d_loss.item(), g_loss.item(),
                        d_total_acu.item(),
                        D_solver.state_dict()['param_groups'][0]["lr"]))

        epoch_end_time = time.time()

        # =============== each epoch save model or save image ===============#
        if (epoch + 1) % args.image_save_step == 0:

            samples = fake.cpu().data[:8].squeeze().numpy()

            image_path = args.output_dir + args.image_dir + log_param
            if not os.path.exists(image_path):
                os.makedirs(image_path)

            SavePloat_Voxels(samples, image_path, iteration)

        if (epoch + 1) % args.pickle_step == 0:
            pickle_save_path = args.output_dir + args.pickle_dir + log_param
            save_new_pickle(pickle_save_path, iteration, G, G_solver, D,
                            D_solver)

        print("epoch time", (epoch_end_time - epoch_start_time) / 60)
        print("epoch %d ended" % (epoch))
        print("################################################")
def train(args):
    """Train a WGAN-GP style voxel generator on ``VertDataset`` volumes.

    Data:
    - Samples come from ``../Vert_dataset`` and pass through
      ``ResizeTo(64)`` + ``ToTensor``.
    - Each batch is flattened to ``cube_len ** 3`` values.

    Training:
    - The critic ``_D`` is updated on every batch via ``train_critic``.
    - The generator ``_G`` is updated via ``train_gen`` on every
      ``n_critic``-th batch.

    Output:
    - Losses are printed per batch and plotted once per epoch.
    - Every ``image_save_step`` epochs, 8 generated volumes are saved.
    - Every ``pickle_step`` epochs, a checkpoint is saved.

    Args:
        args: parsed command-line namespace providing the paths, learning
            rates, ``beta``, ``batch_size``, ``cube_len``, ``z_size``,
            ``n_epochs`` and save-step settings read below.
    """
    # WGAN-related settings: the generator is updated once every n_critic
    # batches. NOTE(review): the gradient-penalty weight is not used in this
    # function; train_critic presumably applies its own — confirm.
    n_critic = 5

    hyparam_list = [
        ("model", args.model_name),
        ("cube", args.cube_len),
        ("bs", args.batch_size),
        ("g_lr", args.g_lr),
        ("d_lr", args.d_lr),
        ("z", args.z_dis),
        ("bias", args.bias),
    ]

    hyparam_dict = OrderedDict(hyparam_list)
    log_param = make_hyparam_string(hyparam_dict)
    print(log_param)

    # define the checkpoint / output paths
    pickle_path = "." + args.pickle_dir + log_param
    image_path = args.output_dir + args.image_dir + log_param
    pickle_save_path = args.output_dir + args.pickle_dir + log_param

    N = None  # None for the whole dataset
    VOL_SIZE = 64
    train_path = pathlib.Path("../Vert_dataset")
    dataset = VertDataset(train_path,
                          n=N,
                          transform=transforms.Compose(
                              [ResizeTo(VOL_SIZE),
                               transforms.ToTensor()]))
    print('Number of samples: ', len(dataset))
    dset_loaders = torch.utils.data.DataLoader(dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=0)
    print('Number of batches: ', len(dset_loaders))

    # Build the model
    D = _D(args)
    G = _G(args)

    # Create the solvers
    D_solver = optim.Adam(D.parameters(), lr=args.d_lr, betas=args.beta)
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)

    if torch.cuda.device_count() > 1:
        D = nn.DataParallel(D)
        G = nn.DataParallel(G)
        print("Using {} GPUs".format(torch.cuda.device_count()))
        D.cuda()
        G.cuda()

    elif torch.cuda.is_available():
        print("using cuda")
        D.cuda()
        G.cuda()

    # Load checkpoint if available
    read_pickle(pickle_path, G, G_solver, D, D_solver)

    G_losses = []
    D_losses = []

    for epoch in range(args.n_epochs):
        epoch_start_time = time.time()
        print("epoch %d started" % (epoch))
        for i, X in enumerate(dset_loaders):
            X = X.view(-1, args.cube_len * args.cube_len * args.cube_len)
            # Cast to float32 and keep the batch on the models' device. The
            # previous hard-coded torch.cuda.FloatTensor cast crashed on
            # CPU-only machines.
            X = var_or_cuda(X).float()
            if torch.cuda.is_available():
                X = X.cuda()
            Z = generateZ(num_samples=X.size(0), z_size=args.z_size)

            # Train the critic
            d_loss, Wasserstein_D, gp = train_critic(X, Z, D, G, D_solver,
                                                     G_solver)

            # Train the generator every n_critic steps. i starts at 0, so
            # g_loss exists before the print below first reads it.
            if i % n_critic == 0:
                Z = generateZ(num_samples=X.size(0), z_size=args.z_size)
                g_loss = train_gen(Z, D, G, D_solver, G_solver)

            # Log each iteration. Global step = Adam's step counter for G's
            # first parameter; int() normalizes the int (older PyTorch) and
            # tensor (newer PyTorch) forms.
            g_state = G_solver.state_dict()
            first_param = g_state['param_groups'][0]['params'][0]
            iteration = str(int(g_state['state'][first_param]['step']))
            print('Iter-{}; , D_loss : {:.4}, G_loss : {:.4}, WSdistance : {:.4}, GP : {:.4}'.format(
                iteration, d_loss.item(), g_loss.item(), Wasserstein_D.item(),
                gp.item()))
        ## End of epoch
        epoch_end_time = time.time()

        # Plot the losses each epoch (last batch's values)
        G_losses.append(g_loss.item())
        D_losses.append(d_loss.item())
        plot_losess(G_losses, D_losses, epoch)

        if (epoch + 1) % args.image_save_step == 0:
            print("Saving voxels")
            Z = generateZ(num_samples=8, z_size=args.z_size)
            # Sampling only: no autograd graph is needed.
            with torch.no_grad():
                gen_output = G(Z)
            samples = gen_output.cpu().data[:8].squeeze().numpy()
            samples = samples.reshape(-1, args.cube_len, args.cube_len,
                                      args.cube_len)
            Save_Voxels(samples, image_path, iteration)

        if (epoch + 1) % args.pickle_step == 0:
            print("Pickeling the model")
            save_new_pickle(pickle_save_path, iteration, G, G_solver, D,
                            D_solver)

        print("epoch time", (epoch_end_time - epoch_start_time) / 60)
        print("epoch %d ended" % (epoch))
        print("################################################")