Example #1
0
def train(args):
    """Train a 3D-GAN (generator ``G`` vs. discriminator ``D``) on ShapeNet voxels.

    Builds the data loader, models, and Adam optimizers from ``args``,
    optionally logs scalar losses to TensorBoard, and at epoch intervals
    saves sample voxel images and model checkpoints.

    Args:
        args: parsed command-line namespace; fields used include model_name,
            cube_len, batch_size, g_lr, d_lr, z_dis, bias, soft_label,
            use_tensorboard, input/output dirs, d_thresh, lrsh, n_epochs,
            image_save_step, pickle_step.
    """
    hyparam_list = [
        ("model", args.model_name),
        ("cube", args.cube_len),
        ("bs", args.batch_size),
        ("g_lr", args.g_lr),
        ("d_lr", args.d_lr),
        ("z", args.z_dis),
        ("bias", args.bias),
        ("sl", args.soft_label),
    ]

    # OrderedDict accepts the (key, value) pairs directly; no generator needed.
    hyparam_dict = OrderedDict(hyparam_list)
    log_param = make_hyparam_string(hyparam_dict)
    print(log_param)

    # for using tensorboard
    if args.use_tensorboard:
        import tensorflow as tf

        summary_writer = tf.summary.FileWriter(args.output_dir + args.log_dir +
                                               log_param)

        def inject_summary(summary_writer, tag, value, step):
            # Write a single scalar ``value`` under ``tag`` at ``step``.
            summary = tf.Summary(
                value=[tf.Summary.Value(tag=tag, simple_value=value)])
            summary_writer.add_summary(summary, global_step=step)

    # dataset define
    dsets_path = args.input_dir + args.data_dir + "train/"
    print(dsets_path)
    dsets = ShapeNetDataset(dsets_path, args)
    dset_loaders = torch.utils.data.DataLoader(dsets,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)

    # model define
    D = _D(args)
    G = _G(args)

    D_solver = optim.Adam(D.parameters(), lr=args.d_lr, betas=args.beta)
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)

    if args.lrsh:
        D_scheduler = MultiStepLR(D_solver, milestones=[500, 1000])

    if torch.cuda.is_available():
        print("using cuda")
        D.cuda()
        G.cuda()

    criterion = nn.BCELoss()

    # Resume from a checkpoint matching this hyper-parameter string, if any.
    pickle_path = "." + args.pickle_dir + log_param
    read_pickle(pickle_path, G, G_solver, D, D_solver)

    for epoch in range(args.n_epochs):
        for i, X in enumerate(dset_loaders):

            X = var_or_cuda(X)

            # Drop the last incomplete batch: the label tensors below are
            # sized for exactly ``args.batch_size`` samples.
            if X.size()[0] != int(args.batch_size):
                continue

            Z = generateZ(args)
            real_labels = var_or_cuda(torch.ones(args.batch_size))
            fake_labels = var_or_cuda(torch.zeros(args.batch_size))

            # Soft labels: noisy targets around 1 (real) and 0 (fake),
            # a common GAN-stabilisation trick.
            if args.soft_label:
                real_labels = var_or_cuda(
                    torch.Tensor(args.batch_size).uniform_(0.7, 1.2))
                fake_labels = var_or_cuda(
                    torch.Tensor(args.batch_size).uniform_(0, 0.3))

            # ============= Train the discriminator =============#
            d_real = D(X)
            d_real_loss = criterion(d_real, real_labels)

            fake = G(Z)
            d_fake = D(fake)
            d_fake_loss = criterion(d_fake, fake_labels)

            d_loss = d_real_loss + d_fake_loss

            # Discriminator accuracy: a real sample counts as correct when
            # scored >= 0.5, a fake when scored <= 0.5.
            d_real_acu = torch.ge(d_real.squeeze(), 0.5).float()
            d_fake_acu = torch.le(d_fake.squeeze(), 0.5).float()
            d_total_acu = torch.mean(torch.cat((d_real_acu, d_fake_acu), 0))

            # Only update D while it is not already too strong, so G keeps
            # getting a useful gradient.
            if d_total_acu <= args.d_thresh:
                D.zero_grad()
                d_loss.backward()
                D_solver.step()

            # =============== Train the generator ===============#
            Z = generateZ(args)

            fake = G(Z)
            d_fake = D(fake)
            g_loss = criterion(d_fake, real_labels)

            D.zero_grad()
            G.zero_grad()
            g_loss.backward()
            G_solver.step()

        # =============== logging at the end of each epoch ===============#
        # Global step of G's first parameter, read from the optimizer state
        # so it survives checkpoint reloads.
        iteration = str(G_solver.state_dict()['state'][
            G_solver.state_dict()['param_groups'][0]['params'][0]]['step'])
        if args.use_tensorboard:
            log_save_path = args.output_dir + args.log_dir + log_param
            if not os.path.exists(log_save_path):
                os.makedirs(log_save_path)

            # FIX: ``loss.data[0]`` fails on 0-dim tensors in PyTorch >= 0.5;
            # ``.item()`` is the supported scalar accessor (0.4+).
            info = {
                'loss/loss_D_R': d_real_loss.item(),
                'loss/loss_D_F': d_fake_loss.item(),
                'loss/loss_D': d_loss.item(),
                'loss/loss_G': g_loss.item(),
                'loss/acc_D': d_total_acu.item()
            }

            for tag, value in info.items():
                inject_summary(summary_writer, tag, value, iteration)

            summary_writer.flush()

        # =============== each epoch save model or save image ===============#
        print(
            'Iter-{}; , D_loss : {:.4}, G_loss : {:.4}, D_acu : {:.4}, D_lr : {:.4}'
            .format(iteration, d_loss.item(), g_loss.item(),
                    d_total_acu.item(),
                    D_solver.state_dict()['param_groups'][0]["lr"]))

        if (epoch + 1) % args.image_save_step == 0:

            samples = fake.cpu().data[:8].squeeze().numpy()

            image_path = args.output_dir + args.image_dir + log_param
            if not os.path.exists(image_path):
                os.makedirs(image_path)

            SavePloat_Voxels(samples, image_path, iteration)

        if (epoch + 1) % args.pickle_step == 0:
            pickle_save_path = args.output_dir + args.pickle_dir + log_param
            save_new_pickle(pickle_save_path, iteration, G, G_solver, D,
                            D_solver)

        if args.lrsh:
            # Step the LR scheduler; tolerate failures without aborting
            # a long training run.
            try:
                D_scheduler.step()
            except Exception as e:
                print("fail lr scheduling", e)
def train(args):
    """Train a 3D Wasserstein GAN with gradient penalty (WGAN-GP) on the
    vertebra dataset.

    The critic ``D`` is updated on every batch via ``train_critic``; the
    generator ``G`` is updated once every ``n_critic`` batches via
    ``train_gen``. Losses are plotted per epoch; voxel samples and model
    checkpoints are written at the configured intervals.

    Args:
        args: parsed command-line namespace; fields used include model_name,
            cube_len, batch_size, g_lr, d_lr, z_dis, bias, z_size, n_epochs,
            image_save_step, pickle_step, and output/pickle/image dirs.
    """
    # WGAN-GP related params
    # NOTE(review): lambda_gp is unused in this function — presumably the
    # gradient-penalty weight consumed inside train_critic; confirm.
    lambda_gp = 10
    n_critic = 5  # critic updates per generator update

    hyparam_list = [
        ("model", args.model_name),
        ("cube", args.cube_len),
        ("bs", args.batch_size),
        ("g_lr", args.g_lr),
        ("d_lr", args.d_lr),
        ("z", args.z_dis),
        ("bias", args.bias),
    ]

    # OrderedDict accepts the (key, value) pairs directly; no generator needed.
    hyparam_dict = OrderedDict(hyparam_list)
    log_param = make_hyparam_string(hyparam_dict)
    print(log_param)

    # Checkpoint / output paths derived from the hyper-parameter string.
    pickle_path = "." + args.pickle_dir + log_param
    image_path = args.output_dir + args.image_dir + log_param
    pickle_save_path = args.output_dir + args.pickle_dir + log_param

    N = None  # None for the whole dataset
    VOL_SIZE = 64
    train_path = pathlib.Path("../Vert_dataset")
    dataset = VertDataset(train_path,
                          n=N,
                          transform=transforms.Compose(
                              [ResizeTo(VOL_SIZE),
                               transforms.ToTensor()]))
    print('Number of samples: ', len(dataset))
    dset_loaders = torch.utils.data.DataLoader(dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=0)
    print('Number of batches: ', len(dset_loaders))

    # Build the model
    D = _D(args)
    G = _G(args)

    # Create the solvers
    D_solver = optim.Adam(D.parameters(), lr=args.d_lr, betas=args.beta)
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)

    if torch.cuda.device_count() > 1:
        D = nn.DataParallel(D)
        G = nn.DataParallel(G)
        print("Using {} GPUs".format(torch.cuda.device_count()))
        D.cuda()
        G.cuda()

    elif torch.cuda.is_available():
        print("using cuda")
        D.cuda()
        G.cuda()

    # Load checkpoint if available
    read_pickle(pickle_path, G, G_solver, D, D_solver)

    G_losses = []
    D_losses = []

    for epoch in range(args.n_epochs):
        epoch_start_time = time.time()
        print("epoch %d started" % (epoch))
        for i, X in enumerate(dset_loaders):
            # Flatten each volume into one vector per sample.
            X = X.view(-1, args.cube_len * args.cube_len * args.cube_len)
            X = var_or_cuda(X)
            # FIX: the original ``X.type(torch.cuda.FloatTensor)`` crashes on
            # CPU-only machines; ``.float()`` converts the dtype while keeping
            # the device selected by var_or_cuda().
            X = X.float()
            Z = generateZ(num_samples=X.size(0), z_size=args.z_size)

            # Train the critic on every batch.
            d_loss, Wasserstein_D, gp = train_critic(X, Z, D, G, D_solver,
                                                     G_solver)

            # Train the generator every n_critic steps. i == 0 satisfies the
            # condition, so g_loss is always bound before the log line below.
            if i % n_critic == 0:
                Z = generateZ(num_samples=X.size(0), z_size=args.z_size)
                g_loss = train_gen(Z, D, G, D_solver, G_solver)

            # Global step of G's first parameter, read from the optimizer
            # state so it survives checkpoint reloads.
            iteration = str(G_solver.state_dict()['state'][
                G_solver.state_dict()['param_groups'][0]['params'][0]]['step'])
            print('Iter-{}; , D_loss : {:.4}, G_loss : {:.4}, WSdistance : {:.4}, GP : {:.4}'.format(
                iteration, d_loss.item(), g_loss.item(),
                Wasserstein_D.item(), gp.item()))

        ## End of epoch
        epoch_end_time = time.time()

        # Track and plot the last-batch losses each epoch.
        G_losses.append(g_loss.item())
        D_losses.append(d_loss.item())
        plot_losess(G_losses, D_losses, epoch)

        if (epoch + 1) % args.image_save_step == 0:
            print("Saving voxels")
            Z = generateZ(num_samples=8, z_size=args.z_size)
            gen_output = G(Z)
            samples = gen_output.cpu().data[:8].squeeze().numpy()
            samples = samples.reshape(-1, args.cube_len, args.cube_len,
                                      args.cube_len)
            Save_Voxels(samples, image_path, iteration)

        if (epoch + 1) % args.pickle_step == 0:
            print("Pickeling the model")
            save_new_pickle(pickle_save_path, iteration, G, G_solver, D,
                            D_solver)

        print("epoch time", (epoch_end_time - epoch_start_time) / 60)
        print("epoch %d ended" % (epoch))
        print("################################################")
Example #3
0
def train(args):
    """Train a 3D-GAN on voxelised 3D-MNIST, visualising samples via Visdom.

    Loads ``voxels_3DMNIST_16.npy``, trains discriminator ``D`` and
    generator ``G`` with BCE loss, pushes sample voxel plots to a local
    Visdom server every 300 iterations, and saves images and checkpoints at
    epoch intervals.

    Args:
        args: parsed command-line namespace; fields used include model_name,
            cube_len, batch_size, g_lr, d_lr, z_dis, bias, soft_label,
            use_tensorboard, d_thresh, n_epochs, image_save_step,
            pickle_step, and input/output dirs.
    """
    # Visdom server for live voxel visualisation.
    DEFAULT_PORT = 8097
    DEFAULT_HOSTNAME = "http://localhost"
    viz = Visdom(DEFAULT_HOSTNAME, DEFAULT_PORT, ipv6=False)

    hyparam_list = [
        ("model", args.model_name),
        ("cube", args.cube_len),
        ("bs", args.batch_size),
        ("g_lr", args.g_lr),
        ("d_lr", args.d_lr),
        ("z", args.z_dis),
        ("bias", args.bias),
        ("sl", args.soft_label),
    ]

    # OrderedDict accepts the (key, value) pairs directly; no generator needed.
    hyparam_dict = OrderedDict(hyparam_list)
    log_param = make_hyparam_string(hyparam_dict)
    print(log_param)

    # for using tensorboard
    if args.use_tensorboard:
        import tensorflow as tf

        summary_writer = tf.summary.FileWriter(args.output_dir + args.log_dir +
                                               log_param)

        def inject_summary(summary_writer, tag, value, step):
            # Write a single scalar ``value`` under ``tag`` at ``step``.
            summary = tf.Summary(
                value=[tf.Summary.Value(tag=tag, simple_value=value)])
            summary_writer.add_summary(summary, global_step=step)

    # dataset define
    dsets_path = args.input_dir + args.data_dir + "train/"
    print(dsets_path)

    # Flatten each voxel grid into one vector per sample.
    x_train = np.load("voxels_3DMNIST_16.npy")
    dataset = x_train.reshape(-1,
                              args.cube_len * args.cube_len * args.cube_len)
    print(dataset.shape)
    dset_loaders = torch.utils.data.DataLoader(dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)

    # model define
    D = _D(args)
    G = _G(args)

    D_solver = optim.Adam(D.parameters(), lr=args.d_lr, betas=args.beta)
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)

    if torch.cuda.is_available():
        print("using cuda")
        D.cuda()
        G.cuda()

    criterion = nn.BCELoss()

    # Resume from a checkpoint matching this hyper-parameter string, if any.
    pickle_path = "." + args.pickle_dir + log_param
    read_pickle(pickle_path, G, G_solver, D, D_solver)

    for epoch in range(args.n_epochs):
        epoch_start_time = time.time()
        print("epoch %d started" % (epoch))
        for i, X in enumerate(dset_loaders):

            X = var_or_cuda(X)
            # FIX: the original ``X.type(torch.cuda.FloatTensor)`` crashes on
            # CPU-only machines; ``.float()`` converts the dtype while keeping
            # the device selected by var_or_cuda().
            X = X.float()
            # Drop the last incomplete batch: label tensors below are sized
            # for exactly ``args.batch_size`` samples.
            if X.size()[0] != int(args.batch_size):
                continue

            Z = generateZ(args)
            real_labels = var_or_cuda(torch.ones(args.batch_size)).view(
                -1, 1, 1, 1, 1)
            fake_labels = var_or_cuda(torch.zeros(args.batch_size)).view(
                -1, 1, 1, 1, 1)

            # Soft labels for the real class only: noisy targets around 1;
            # fake targets stay exactly 0.
            if args.soft_label:
                real_labels = var_or_cuda(
                    torch.Tensor(args.batch_size).uniform_(0.9, 1.1)).view(
                        -1, 1, 1, 1, 1)
                fake_labels = var_or_cuda(torch.zeros(args.batch_size)).view(
                    -1, 1, 1, 1, 1)
            # ============= Train the discriminator =============#
            d_real = D(X)
            d_real_loss = criterion(d_real, real_labels)

            fake = G(Z)
            d_fake = D(fake)
            d_fake_loss = criterion(d_fake, fake_labels)

            d_loss = d_real_loss + d_fake_loss

            # Discriminator accuracy: real scored >= 0.5, fake scored <= 0.5.
            d_real_acu = torch.ge(d_real.squeeze(), 0.5).float()
            d_fake_acu = torch.le(d_fake.squeeze(), 0.5).float()
            d_total_acu = torch.mean(torch.cat((d_real_acu, d_fake_acu), 0))

            # Only update D while it is not already too strong.
            if d_total_acu <= args.d_thresh:
                D.zero_grad()
                d_loss.backward()
                D_solver.step()

            # =============== Train the generator ===============#
            Z = generateZ(args)

            fake = G(Z)
            d_fake = D(fake)
            g_loss = criterion(d_fake, real_labels)

            D.zero_grad()
            G.zero_grad()
            g_loss.backward()
            G_solver.step()

            # =============== logging each iteration ===============#
            # Global step of G's first parameter, read from the optimizer
            # state so it survives checkpoint reloads.
            iteration = str(G_solver.state_dict()['state'][
                G_solver.state_dict()['param_groups'][0]['params'][0]]['step'])
            # Push the first 8 generated samples to Visdom every 300 steps.
            if int(iteration) % 300 == 0:
                samples = fake.cpu().data[:8].squeeze().numpy()

                for s in range(8):
                    # FIX: the original "Iteration:{:.4}" applied a float
                    # precision spec to the *string* counter, truncating it
                    # to 4 characters; show the full counter instead.
                    plotVoxelVisdom(samples[s, ...], viz,
                                    "Iteration:{}".format(iteration))

            print(
                'Iter-{}; , D_loss : {:.4}, G_loss : {:.4}, D_acu : {:.4}, D_lr : {:.4}'
                .format(iteration, d_loss.item(), g_loss.item(),
                        d_total_acu.item(),
                        D_solver.state_dict()['param_groups'][0]["lr"]))

        epoch_end_time = time.time()

        # =============== each epoch save model or save image ===============#
        if (epoch + 1) % args.image_save_step == 0:

            samples = fake.cpu().data[:8].squeeze().numpy()

            image_path = args.output_dir + args.image_dir + log_param
            if not os.path.exists(image_path):
                os.makedirs(image_path)

            SavePloat_Voxels(samples, image_path, iteration)

        if (epoch + 1) % args.pickle_step == 0:
            pickle_save_path = args.output_dir + args.pickle_dir + log_param
            save_new_pickle(pickle_save_path, iteration, G, G_solver, D,
                            D_solver)

        print("epoch time", (epoch_end_time - epoch_start_time) / 60)
        print("epoch %d ended" % (epoch))
        print("################################################")
Example #4
0
def train(args):
    """Train a convolutional VAE (``Vae_deconv``) on an image-folder dataset.

    Minimises ELBO = sum-reduced MSE reconstruction loss + analytic KL
    divergence to a standard normal prior. Optionally logs scalars to
    TensorBoard; periodically saves reconstruction grids and checkpoints.

    Args:
        args: parsed command-line namespace; fields used include z_dim, lr,
            normalization, model_name, use_tensorboard, image_scale,
            mb_size, epoch_iter, image_save_step, log_step, pickle_step,
            and input/output dirs.
    """
    print(args)

    ########################## tensorboard ##############################

    hyparam_list = [("z_dim", args.z_dim), ("lr", args.lr),
                    ("nm", args.normalization)]

    # OrderedDict accepts the (key, value) pairs directly; no generator needed.
    hyparam_dict = OrderedDict(hyparam_list)
    log_param = make_hyparam_string(hyparam_dict)
    log_param = args.model_name + "_" + log_param
    print(log_param)

    if args.use_tensorboard:
        import tensorflow as tf

        summary_writer = tf.summary.FileWriter(args.output_dir + args.log_dir +
                                               log_param)

        def inject_summary(summary_writer, tag, value, step):
            # Write a single scalar ``value`` under ``tag`` at ``step``.
            summary = tf.Summary(
                value=[tf.Summary.Value(tag=tag, simple_value=value)])
            summary_writer.add_summary(summary, global_step=step)

    ########################## data loading ###############################

    data_transforms = {
        'train':
        transforms.Compose([
            transforms.RandomHorizontalFlip(),
            # NOTE: transforms.Scale was renamed Resize in newer torchvision;
            # kept as-is for compatibility with the version this code targets.
            transforms.Scale(args.image_scale),
            transforms.ToTensor(),
            Grayscale(),
        ])
    }

    dsets_path = args.input_dir + args.data_dir
    dsets = {
        x: datasets.ImageFolder(os.path.join(dsets_path, x),
                                data_transforms[x])
        for x in ['train']
    }
    dset_loaders = {
        x: torch.utils.data.DataLoader(dsets[x],
                                       batch_size=args.mb_size,
                                       shuffle=True,
                                       num_workers=4,
                                       drop_last=True)
        for x in ['train']
    }

    model = Vae_deconv(args)

    if torch.cuda.is_available():
        model.cuda()

    solver = optim.Adam(model.parameters(), lr=args.lr)

    # Store epoch/iteration counters inside the optimizer's param group so
    # they are saved and restored together with the checkpoint.
    solver.param_groups[0]['epoch'] = 0
    solver.param_groups[0]['iter'] = 0

    # Sum-reduced MSE reconstruction loss. NOTE: size_average is deprecated
    # in newer PyTorch in favour of reduction='sum'; kept for compatibility.
    criterion = torch.nn.MSELoss(size_average=False)

    pickle_path = "." + args.pickle_dir + log_param
    read_pickle(pickle_path, model, solver)

    # Training
    for epoch in range(0, args.epoch_iter):

        solver.param_groups[0]['epoch'] += 1
        epoch = solver.param_groups[0]['epoch']

        for img, label in dset_loaders['train']:
            # Sample data
            solver.param_groups[0]['iter'] += 1
            iteration = solver.param_groups[0]['iter']

            img = to_var(img)

            recon, mu, log_var = model(img)

            recon_loss = criterion(recon, img)
            # Analytic KL divergence KL(N(mu, exp(log_var)) || N(0, 1)),
            # summed over the batch and latent dimensions.
            kld_loss = torch.sum(0.5 *
                                 (mu**2 + torch.exp(log_var) - log_var - 1))

            # "ELBO" here is the negative evidence lower bound to minimise.
            ELBO = recon_loss + kld_loss
            solver.zero_grad()
            ELBO.backward()
            solver.step()

            # FIX: ``loss.data[0]`` fails on 0-dim tensors in PyTorch >= 0.5;
            # ``.item()`` is the supported scalar accessor (0.4+).
            print(
                'Iter-{}; recon_loss: {:.4} , kld_loss : {:.4}, ELBO  : {:.4}'.
                format(str(iteration), recon_loss.item(), kld_loss.item(),
                       ELBO.item()))

            if (iteration) % args.image_save_step == 0:
                # Save a grid of the first 16 reconstructions.
                samples = recon.cpu().data[:16]

                image_path = args.output_dir + args.image_dir + log_param
                if not os.path.exists(image_path):
                    os.makedirs(image_path)

                torchvision.utils.save_image(
                    samples,
                    image_path + '/{}.png'.format(str(iteration).zfill(3)),
                    normalize=True)

            if args.use_tensorboard and (iteration) % args.log_step == 0:

                log_save_path = args.output_dir + args.log_dir + log_param
                if not os.path.exists(log_save_path):
                    os.makedirs(log_save_path)

                info = {
                    'loss/loss_D_R': recon_loss.item(),
                    'loss/loss_D': kld_loss.item(),
                    'loss/loss_G': ELBO.item(),
                }

                for tag, value in info.items():
                    inject_summary(summary_writer, tag, value, iteration)

                summary_writer.flush()

        if (epoch) % args.pickle_step == 0:
            pickle_save_path = args.output_dir + args.pickle_dir + log_param
            save_new_pickle(pickle_save_path, iteration, model, solver)