Example #1
def test(args):
    hyparam_list = [("model", args.model_name), ("cube", args.cube_len),
                    ("bs", args.batch_size), ("g_lr", args.g_lr),
                    ("d_lr", args.d_lr), ("z", args.z_dis),
                    ("bias", args.bias), ("sl", args.soft_label)]

    hyparam_dict = OrderedDict(hyparam_list)
    log_param = make_hyparam_string(hyparam_dict)
    print(log_param)

    # model define
    D = _D(args)
    G = _G(args)

    D_solver = optim.Adam(D.parameters(), lr=args.d_lr, betas=args.beta)
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)

    if torch.cuda.is_available():
        print("using cuda")
        D.cuda()
        G.cuda()

    pickle_path = args.pickle_dir
    read_pickle(pickle_path, G, G_solver, D, D_solver)  # load the models

    Z = generateZ(args)

    fake = G(Z)

    samples = fake.cpu().data[:8].squeeze().numpy()
    image_path = args.output_dir + args.image_dir + log_param
    if not os.path.exists(image_path):
        os.makedirs(image_path)
    # no training iteration in test mode, so label the saved sample 0
    iteration = 0
    SavePloat_Voxels(samples, image_path, iteration)
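
These snippets lean on a handful of project helpers (generateZ, var_or_cuda, read_pickle, SavePloat_Voxels, make_hyparam_string) that are defined elsewhere in the repository. As a rough orientation only, here is a minimal sketch of what generateZ and var_or_cuda might look like, inferred from how they are called; the real signatures differ between examples (Example #5 calls generateZ(num_samples=..., z_size=...)), so treat every name and default below as an assumption.

import torch

def var_or_cuda(x):
    # assumed helper: move a tensor to the GPU when one is available
    if torch.cuda.is_available():
        x = x.cuda()
    return x

def generateZ(args):
    # assumed helper: sample a latent batch; args.z_dis is taken to select the prior
    if args.z_dis == "norm":
        Z = torch.randn(args.batch_size, args.z_size)
    else:
        Z = torch.rand(args.batch_size, args.z_size)
    return var_or_cuda(Z)
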
Example #2
def test_3DGAN(args):
    # dataset define
    dsets_path = args.input_dir + args.data_dir + "test/"
    print(dsets_path)
    dsets = ShapeNetDataset(dsets_path, args)
    dset_loaders = torch.utils.data.DataLoader(dsets,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)

    # model define
    D = _D(args)
    G = _G(args)

    D_solver = optim.Adam(D.parameters(), lr=args.d_lr, betas=args.beta)
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)

    if torch.cuda.is_available():
        print("using cuda")
        D.cuda()
        G.cuda()

    pickle_path = "." + args.pickle_dir + '3DVAEGAN_MULTIVIEW_MAX'
    read_pickle(pickle_path, G, G_solver, D, D_solver)
    recon_loss_total = 0
    for i, X in enumerate(dset_loaders):
        #X = X.view(-1, 1, args.cube_len, args.cube_len, args.cube_len)
        X = var_or_cuda(X)
        print(X.size())
        Z = generateZ(args)
        print(Z.size())
        fake = G(Z).squeeze()
        print(fake.size())
        recon_loss = torch.sum(torch.pow((fake - X), 2), dim=(1, 2, 3))
        print(recon_loss.size())
        print("RECON LOSS ITER: ", i, " - ", torch.mean(recon_loss))
        recon_loss_total += (recon_loss)
        samples = fake.cpu().data[:8].squeeze().numpy()

        image_path = args.output_dir + args.image_dir + '3DVAEGAN_MULTIVIEW_MAX_test'
        if not os.path.exists(image_path):
            os.makedirs(image_path)

        SavePloat_Voxels(samples, image_path, i)
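
recon_loss_total is accumulated above but never reported. If a dataset-wide number is wanted, a final line along these lines (an assumption about the intent) would close the loop:

    # assumed follow-up after the loop: average per-sample reconstruction error over all batches
    print("mean recon loss:", (recon_loss_total / len(dset_loaders)).mean().item())
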
Example #3
def train(args):

    hyparam_list = [
        ("model", args.model_name),
        ("cube", args.cube_len),
        ("bs", args.batch_size),
        ("g_lr", args.g_lr),
        ("d_lr", args.d_lr),
        ("z", args.z_dis),
        ("bias", args.bias),
        ("sl", args.soft_label),
    ]

    hyparam_dict = OrderedDict(hyparam_list)
    log_param = make_hyparam_string(hyparam_dict)
    print(log_param)

    # for using tensorboard
    if args.use_tensorboard:
        import tensorflow as tf

        summary_writer = tf.summary.FileWriter(args.output_dir + args.log_dir +
                                               log_param)

        def inject_summary(summary_writer, tag, value, step):
            summary = tf.Summary(
                value=[tf.Summary.Value(tag=tag, simple_value=value)])
            summary_writer.add_summary(summary, global_step=step)

        inject_summary = inject_summary

    # dataset define
    dsets_path = args.input_dir + args.data_dir + "train/"
    print(dsets_path)
    dsets = ShapeNetDataset(dsets_path, args)
    dset_loaders = torch.utils.data.DataLoader(dsets,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)

    # model define
    D = _D(args)
    G = _G(args)

    D_solver = optim.Adam(D.parameters(), lr=args.d_lr, betas=args.beta)
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)

    if args.lrsh:
        D_scheduler = MultiStepLR(D_solver, milestones=[500, 1000])

    if torch.cuda.is_available():
        print("using cuda")
        D.cuda()
        G.cuda()

    criterion = nn.BCELoss()

    pickle_path = "." + args.pickle_dir + log_param
    read_pickle(pickle_path, G, G_solver, D, D_solver)

    for epoch in range(args.n_epochs):
        for i, X in enumerate(dset_loaders):

            X = var_or_cuda(X)

            if X.size()[0] != int(args.batch_size):
                #print("batch_size != {} drop last incompatible batch".format(int(args.batch_size)))
                continue

            Z = generateZ(args)
            real_labels = var_or_cuda(torch.ones(args.batch_size))
            fake_labels = var_or_cuda(torch.zeros(args.batch_size))

            if args.soft_label:
                real_labels = var_or_cuda(
                    torch.Tensor(args.batch_size).uniform_(0.7, 1.2))
                fake_labels = var_or_cuda(
                    torch.Tensor(args.batch_size).uniform_(0, 0.3))

            # ============= Train the discriminator =============#
            d_real = D(X)
            d_real_loss = criterion(d_real, real_labels)

            fake = G(Z)
            d_fake = D(fake)
            d_fake_loss = criterion(d_fake, fake_labels)

            d_loss = d_real_loss + d_fake_loss

            d_real_acu = torch.ge(d_real.squeeze(), 0.5).float()
            d_fake_acu = torch.le(d_fake.squeeze(), 0.5).float()
            d_total_acu = torch.mean(torch.cat((d_real_acu, d_fake_acu), 0))

            if d_total_acu <= args.d_thresh:
                D.zero_grad()
                d_loss.backward()
                D_solver.step()

            # =============== Train the generator ===============#

            Z = generateZ(args)

            fake = G(Z)
            d_fake = D(fake)
            g_loss = criterion(d_fake, real_labels)

            D.zero_grad()
            G.zero_grad()
            g_loss.backward()
            G_solver.step()

        # =============== logging at the end of each epoch ===============#
        iteration = str(G_solver.state_dict()['state'][
            G_solver.state_dict()['param_groups'][0]['params'][0]]['step'])
        if args.use_tensorboard:
            log_save_path = args.output_dir + args.log_dir + log_param
            if not os.path.exists(log_save_path):
                os.makedirs(log_save_path)

            info = {
                'loss/loss_D_R': d_real_loss.item(),
                'loss/loss_D_F': d_fake_loss.item(),
                'loss/loss_D': d_loss.item(),
                'loss/loss_G': g_loss.item(),
                'loss/acc_D': d_total_acu.item()
            }

            for tag, value in info.items():
                inject_summary(summary_writer, tag, value, iteration)

            summary_writer.flush()

        # =============== each epoch save model or save image ===============#
        print(
            'Iter-{}; , D_loss : {:.4}, G_loss : {:.4}, D_acu : {:.4}, D_lr : {:.4}'
            .format(iteration, d_loss.item(), g_loss.item(),
                    d_total_acu.item(),
                    D_solver.state_dict()['param_groups'][0]["lr"]))

        if (epoch + 1) % args.image_save_step == 0:

            samples = fake.cpu().data[:8].squeeze().numpy()

            image_path = args.output_dir + args.image_dir + log_param
            if not os.path.exists(image_path):
                os.makedirs(image_path)

            SavePloat_Voxels(samples, image_path, iteration)

        if (epoch + 1) % args.pickle_step == 0:
            pickle_save_path = args.output_dir + args.pickle_dir + log_param
            save_new_pickle(pickle_save_path, iteration, G, G_solver, D,
                            D_solver)

        if args.lrsh:
            try:
                D_scheduler.step()
            except Exception as e:
                print("lr scheduling failed", e)
Example #4
def train(args):
    #for creating the visdom object
    DEFAULT_PORT = 8097
    DEFAULT_HOSTNAME = "http://localhost"
    viz = Visdom(DEFAULT_HOSTNAME, DEFAULT_PORT, ipv6=False)

    hyparam_list = [
        ("model", args.model_name),
        ("cube", args.cube_len),
        ("bs", args.batch_size),
        ("g_lr", args.g_lr),
        ("d_lr", args.d_lr),
        ("z", args.z_dis),
        ("bias", args.bias),
        ("sl", args.soft_label),
    ]

    hyparam_dict = OrderedDict(hyparam_list)
    log_param = make_hyparam_string(hyparam_dict)
    print(log_param)

    # for using tensorboard
    if args.use_tensorboard:
        import tensorflow as tf

        summary_writer = tf.summary.FileWriter(args.output_dir + args.log_dir +
                                               log_param)

        def inject_summary(summary_writer, tag, value, step):
            summary = tf.Summary(
                value=[tf.Summary.Value(tag=tag, simple_value=value)])
            summary_writer.add_summary(summary, global_step=step)

        inject_summary = inject_summary

    # dataset define
    dsets_path = args.input_dir + args.data_dir + "train/"
    print(dsets_path)

    x_train = np.load("voxels_3DMNIST_16.npy")
    dataset = x_train.reshape(-1,
                              args.cube_len * args.cube_len * args.cube_len)
    print(dataset.shape)
    dset_loaders = torch.utils.data.DataLoader(dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)

    # model define
    D = _D(args)
    G = _G(args)

    D_solver = optim.Adam(D.parameters(), lr=args.d_lr, betas=args.beta)
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)

    if torch.cuda.is_available():
        print("using cuda")
        D.cuda()
        G.cuda()

    criterion = nn.BCELoss()

    pickle_path = "." + args.pickle_dir + log_param
    read_pickle(pickle_path, G, G_solver, D, D_solver)

    for epoch in range(args.n_epochs):
        epoch_start_time = time.time()
        print("epoch %d started" % (epoch))
        for i, X in enumerate(dset_loaders):

            X = var_or_cuda(X)
            X = X.type(torch.cuda.FloatTensor)
            if X.size()[0] != int(args.batch_size):
                #print("batch_size != {} drop last incompatible batch".format(int(args.batch_size)))
                continue

            Z = generateZ(args)
            real_labels = var_or_cuda(torch.ones(args.batch_size)).view(
                -1, 1, 1, 1, 1)
            fake_labels = var_or_cuda(torch.zeros(args.batch_size)).view(
                -1, 1, 1, 1, 1)

            if args.soft_label:
                real_labels = var_or_cuda(
                    torch.Tensor(args.batch_size).uniform_(0.9, 1.1)).view(
                        -1, 1, 1, 1, 1)  ####
                #fake_labels = var_or_cuda(torch.Tensor(args.batch_size).uniform_(0, 0.3)).view(-1,1,1,1,1)
                fake_labels = var_or_cuda(torch.zeros(args.batch_size)).view(
                    -1, 1, 1, 1, 1)  #####
            # ============= Train the discriminator =============#
            d_real = D(X)
            d_real_loss = criterion(d_real, real_labels)

            fake = G(Z)
            d_fake = D(fake)
            d_fake_loss = criterion(d_fake, fake_labels)

            d_loss = d_real_loss + d_fake_loss

            d_real_acu = torch.ge(d_real.squeeze(), 0.5).float()
            d_fake_acu = torch.le(d_fake.squeeze(), 0.5).float()
            d_total_acu = torch.mean(torch.cat((d_real_acu, d_fake_acu), 0))

            #if 1:
            if d_total_acu <= args.d_thresh:
                D.zero_grad()
                d_loss.backward()
                D_solver.step()

            # =============== Train the generator ===============#

            Z = generateZ(args)

            fake = G(Z)
            d_fake = D(fake)
            g_loss = criterion(d_fake, real_labels)

            D.zero_grad()
            G.zero_grad()
            g_loss.backward()
            G_solver.step()
            #######
            #print(fake.shape)
            #print(fake.cpu().data[:8].squeeze().numpy().shape)

            # =============== logging each iteration ===============#
            iteration = str(G_solver.state_dict()['state'][
                G_solver.state_dict()['param_groups'][0]['params'][0]]['step'])
            #print(type(iteration))
            #iteration = str(i)
            # plot a few sample voxels every 300 iterations
            if int(iteration) % 300 == 0:
                #pickle_save_path = args.output_dir + args.pickle_dir + log_param
                #save_new_pickle(pickle_save_path, iteration, G, G_solver, D, D_solver)
                samples = fake.cpu().data[:8].squeeze().numpy()

                #print(samples.shape)
                for s in range(8):
                    plotVoxelVisdom(samples[s, ...], viz,
                                    "Iteration:{}".format(iteration))

#                 image_path = args.output_dir + args.image_dir + log_param
#                 if not os.path.exists(image_path):
#                     os.makedirs(image_path)

#                 SavePloat_Voxels(samples, image_path, iteration)
# =============== each epoch save model or save image ===============#
            print(
                'Iter-{}; , D_loss : {:.4}, G_loss : {:.4}, D_acu : {:.4}, D_lr : {:.4}'
                .format(iteration, d_loss.item(), g_loss.item(),
                        d_total_acu.item(),
                        D_solver.state_dict()['param_groups'][0]["lr"]))

        epoch_end_time = time.time()

        if (epoch + 1) % args.image_save_step == 0:

            samples = fake.cpu().data[:8].squeeze().numpy()

            image_path = args.output_dir + args.image_dir + log_param
            if not os.path.exists(image_path):
                os.makedirs(image_path)

            SavePloat_Voxels(samples, image_path, iteration)

        if (epoch + 1) % args.pickle_step == 0:
            pickle_save_path = args.output_dir + args.pickle_dir + log_param
            save_new_pickle(pickle_save_path, iteration, G, G_solver, D,
                            D_solver)

        print("epoch time", (epoch_end_time - epoch_start_time) / 60)
        print("epoch %d ended" % (epoch))
        print("################################################")
Example #5
def train(args):
    #WSGAN related params
    lambda_gp = 10
    n_critic = 5

    hyparam_list = [
        ("model", args.model_name),
        ("cube", args.cube_len),
        ("bs", args.batch_size),
        ("g_lr", args.g_lr),
        ("d_lr", args.d_lr),
        ("z", args.z_dis),
        ("bias", args.bias),
    ]

    hyparam_dict = OrderedDict(hyparam_list)
    log_param = make_hyparam_string(hyparam_dict)
    print(log_param)

    #define different paths
    pickle_path = "." + args.pickle_dir + log_param
    image_path = args.output_dir + args.image_dir + log_param
    pickle_save_path = args.output_dir + args.pickle_dir + log_param

    N = None  # None for the whole dataset
    VOL_SIZE = 64
    train_path = pathlib.Path("../Vert_dataset")
    dataset = VertDataset(train_path,
                          n=N,
                          transform=transforms.Compose(
                              [ResizeTo(VOL_SIZE),
                               transforms.ToTensor()]))
    print('Number of samples: ', len(dataset))
    dset_loaders = torch.utils.data.DataLoader(dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=0)
    print('Number of batches: ', len(dset_loaders))

    #  Build the model
    D = _D(args)
    G = _G(args)

    #Create the solvers
    D_solver = optim.Adam(D.parameters(), lr=args.d_lr, betas=args.beta)
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)

    if torch.cuda.device_count() > 1:
        D = nn.DataParallel(D)
        G = nn.DataParallel(G)
        print("Using {} GPUs".format(torch.cuda.device_count()))
        D.cuda()
        G.cuda()

    elif torch.cuda.is_available():
        print("using cuda")
        D.cuda()
        G.cuda()

    #Load checkpoint if available
    read_pickle(pickle_path, G, G_solver, D, D_solver)

    G_losses = []
    D_losses = []

    for epoch in range(args.n_epochs):
        epoch_start_time = time.time()
        print("epoch %d started" % (epoch))
        for i, X in enumerate(dset_loaders):
            #print(X.shape)
            X = X.view(-1, args.cube_len * args.cube_len * args.cube_len)
            X = var_or_cuda(X)
            X = X.type(torch.cuda.FloatTensor)
            Z = generateZ(num_samples=X.size(0), z_size=args.z_size)

            #Train the critic
            d_loss, Wasserstein_D, gp = train_critic(X, Z, D, G, D_solver,
                                                     G_solver)

            # Train the generator every n_critic steps
            if i % n_critic == 0:
                Z = generateZ(num_samples=X.size(0), z_size=args.z_size)
                g_loss = train_gen(Z, D, G, D_solver, G_solver)

            #Log each iteration
            iteration = str(G_solver.state_dict()['state'][
                G_solver.state_dict()['param_groups'][0]['params'][0]]['step'])
            print('Iter-{}; , D_loss : {:.4}, G_loss : {:.4}, WSdistance : {:.4}, GP : {:.4}'.format(iteration, d_loss.item(), \
                                                                            g_loss.item(), Wasserstein_D.item(), gp.item() ))
        ## End of epoch
        epoch_end_time = time.time()

        #Plot the losses each epoch
        G_losses.append(g_loss.item())
        D_losses.append(d_loss.item())
        plot_losess(G_losses, D_losses, epoch)

        if (epoch + 1) % args.image_save_step == 0:
            print("Saving voxels")
            Z = generateZ(num_samples=8, z_size=args.z_size)
            gen_output = G(Z)
            samples = gen_output.cpu().data[:8].squeeze().numpy()
            samples = samples.reshape(-1, args.cube_len, args.cube_len,
                                      args.cube_len)
            Save_Voxels(samples, image_path, iteration)

        if (epoch + 1) % args.pickle_step == 0:
            print("Pickeling the model")
            save_new_pickle(pickle_save_path, iteration, G, G_solver, D,
                            D_solver)

        print("epoch time", (epoch_end_time - epoch_start_time) / 60)
        print("epoch %d ended" % (epoch))
        print("################################################")
Example #6
def train(args):
    if args.use_tensorboard:
        import tensorflow as tf

        # log_param is never built in this snippet; assume the model name as the run label
        log_param = args.model_name

        summary_writer = tf.summary.FileWriter(args.output_dir + args.log_dir +
                                               log_param)

        def inject_summary(summary_writer, tag, value, step):
            summary = tf.Summary(
                value=[tf.Summary.Value(tag=tag, simple_value=value)])
            summary_writer.add_summary(summary, global_step=step)

        inject_summary = inject_summary

    data_path = args.input_dir + args.data_dir + "train/"
    dataset = ShapeNetDataset(data_path, args)
    shape_loader = DataLoader(dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=1)

    Gen = Generator(args)
    Disc = Discriminator(args)
    print('[*] Number of parameters in Generator: {:,}'.format(
        sum([p.data.nelement() for p in Gen.parameters()])))
    print('[*] Number of parameters in Discriminator: {:,}'.format(
        sum([p.data.nelement() for p in Disc.parameters()])))

    criterion = nn.BCELoss()

    Gen_optim = optim.Adam(Gen.parameters(), lr=args.lr_gen, betas=args.betas)
    Disc_optim = optim.Adam(Disc.parameters(),
                            lr=args.lr_disc,
                            betas=args.betas)

    if args.sched:
        Disc_scheduler = optim.lr_scheduler.MultiStepLR(
            Disc_optim, milestones=[500, 1000, 1500, 2000])

    Gen.cuda()
    Disc.cuda()
    criterion.cuda()

    for epoch in trange(args.epochs):
        for i, x in enumerate(shape_loader):

            real_labels = torch.ones(args.batch_size).cuda()
            fake_labels = torch.zeros(args.batch_size).cuda()

            if args.noisy_label:
                real_labels = torch.Tensor(args.batch_size).uniform_(0.7, 1.2).cuda()
                fake_labels = torch.Tensor(args.batch_size).uniform_(0, 0.3).cuda()

            x = x.cuda()
            z = generateZ(args)

            Disc_real = Disc(x)
            Disc_real_loss = criterion(Disc_real, real_labels)

            fake = Gen(z)
            Disc_fake = Disc(fake)
            Disc_fake_loss = criterion(Disc_fake, fake_labels)

            Disc_loss_total = Disc_real_loss + Disc_fake_loss

            real_acc = torch.ge(Disc_real.squeeze(), 0.5).float()
            fake_acc = torch.le(Disc_fake.squeeze(), 0.5).float()
            total_acc = torch.mean(torch.cat((real_acc, fake_acc), 0))

            if total_acc <= args.d_thresh:
                Disc.zero_grad()
                Disc_loss_total.backward()
                Disc_optim.step()

            z = generateZ(args)

            fake = Gen(z)
            Disc_fake = Disc(fake)
            Gen_loss = criterion(Disc_fake, real_labels)

            Gen.zero_grad()
            Disc.zero_grad()
            Gen_loss.backward()
            Gen_optim.step()

        iteration = str(Gen_optim.state_dict()['state'][
            Gen_optim.state_dict()['param_groups'][0]['params'][0]]['step'])
        print(
            'Iter-{}; , Dis_loss : {:.4}, Gen_loss : {:.4}, Dis_acc : {:.4}, Disc_lr : {:.4}'
            .format(iteration, Disc_loss_total.item(), Gen_loss.item(),
                    total_acc.item(),
                    Disc_optim.state_dict()['param_groups'][0]["lr"]))
        save_checkpoint()
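
save_checkpoint, save_new_pickle and read_pickle handle model checkpointing in these examples. A minimal sketch of the save side, assuming plain torch.save of the state dicts (file names and layout here are guesses):

import os
import torch

def save_new_pickle(path, iteration, G, G_solver, D, D_solver):
    # assumed helper: persist generator/discriminator weights plus optimizer state
    if not os.path.exists(path):
        os.makedirs(path)
    torch.save(G.state_dict(), os.path.join(path, "G_{}.pkl".format(iteration)))
    torch.save(G_solver.state_dict(), os.path.join(path, "G_optim_{}.pkl".format(iteration)))
    torch.save(D.state_dict(), os.path.join(path, "D_{}.pkl".format(iteration)))
    torch.save(D_solver.state_dict(), os.path.join(path, "D_optim_{}.pkl".format(iteration)))
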