Example no. 1
batch_ip = []
batch_gt = []

for ind in indices[batch_num * batch_size:batch_num * batch_size + batch_size]:
    model_path = models[ind[0]]
    img_path = join(FLAGS.data_dir_imgs, model_path, 'rendering', PNG_FILES[ind[1]])
    pcl_path = join(FLAGS.data_dir_pcl, model_path, 'pointcloud_2048.npy')

    pcl_gt = np.load(pcl_path)

    # crop the image border, keep the first three channels, and convert BGR -> RGB
    ip_image = cv2.imread(img_path)[4:-5, 4:-5, :3]
    ip_image = cv2.cvtColor(ip_image, cv2.COLOR_BGR2RGB)

    batch_gt.append(pcl_gt)
    batch_ip.append(ip_image)
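The excerpt stops after filling the Python lists; downstream code would typically stack them into arrays before feeding the network (an assumed continuation, not part of the original excerpt):

batch_ip = np.array(batch_ip, dtype=np.float32)  # (B, H, W, 3) RGB crops
batch_gt = np.array(batch_gt, dtype=np.float32)  # (B, 2048, 3) point clouds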
Example no. 2
    def build_model(self):
        """ A function of defining following instances :

        -----  Generator
        -----  Discriminator
        -----  Optimizer for Generator
        -----  Optimizer for Discriminator
        -----  Defining Loss functions

        """

        # ---------------------------------------------------------------------
        #                       1. Network Initialization
        # ---------------------------------------------------------------------
        self.gen = Generator(batch_size=self.batch_size,
                             img_size=self.img_size,
                             z_dim=self.z_dim,
                             text_embed_dim=self.text_embed_dim,
                             text_reduced_dim=self.text_reduced_dim)

        self.disc = Discriminator(batch_size=self.batch_size,
                                  img_size=self.img_size,
                                  text_embed_dim=self.text_embed_dim,
                                  text_reduced_dim=self.text_reduced_dim)

        self.gen_optim = optim.Adam(self.gen.parameters(),
                                    lr=self.learning_rate,
                                    betas=(self.beta1, self.beta2))

        self.disc_optim = optim.Adam(self.disc.parameters(),
                                     lr=self.learning_rate,
                                     betas=(self.beta1, self.beta2))

        self.cls_gan_optim = optim.Adam(itertools.chain(self.gen.parameters(),
                                                        self.disc.parameters()),
                                        lr=self.learning_rate,
                                        betas=(self.beta1, self.beta2))

        print('-------------  Generator Model Info  ---------------')
        self.print_network(self.gen, 'G')
        print('------------------------------------------------')

        print('-------------  Discriminator Model Info  ---------------')
        self.print_network(self.disc, 'D')
        print('------------------------------------------------')

        self.gen.cuda()
        self.disc.cuda()
        self.criterion = nn.BCELoss().cuda()
        # self.CE_loss = nn.CrossEntropyLoss().cuda()
        # self.MSE_loss = nn.MSELoss().cuda()
        self.gen.train()
        self.disc.train()
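print_network is called above but not shown; a common implementation simply prints the module and its trainable parameter count. A minimal sketch, assuming the usual counting helper rather than the project's actual method:

    def print_network(self, model, name):
        # Sketch only: print the module's name, its layers, and how many
        # trainable parameters it holds.
        num_params = sum(p.numel() for p in model.parameters())
        print(name)
        print(model)
        print('Total number of parameters: %d' % num_params)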
Example no. 3
def main(training,
         train_data_filename=None,
         pretrained_G_dir=None,
         save_ckpt_dir=None):
    # 0) Initialise
    src_dir = os.path.dirname(os.path.abspath(__file__))
    model_pdir = os.path.join(src_dir, '..', 'models')

    print(tf.config.experimental.list_physical_devices())
    strategy = tf.distribute.MirroredStrategy(
        cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())

    with strategy.scope():
        print(f'#in sync: {strategy.num_replicas_in_sync}')
        batch_size = 1  # * strategy.num_replicas_in_sync

        # 1) Get model
        G = ReconNet(1)
        ckpt_dir = os.path.join(model_pdir, pretrained_G_dir)
        latest_ckpt = tf.train.latest_checkpoint(ckpt_dir)
        G.load_weights(latest_ckpt)
        print(f'Generator loaded from {pretrained_G_dir}.')
        #G_lr_schedule = 1e-4
        G_lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            boundaries=[1e5], values=[1e-4, 1e-5])
        G_optimizer = tf.keras.optimizers.Adam(learning_rate=G_lr_schedule)

        D = Discriminator()
        #D_lr_schedule = 1e-4
        D_lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            boundaries=[1e5], values=[1e-4, 1e-5])
        D_optimizer = tf.keras.optimizers.Adam(learning_rate=D_lr_schedule)

    if training:
        # 3) Load training data
        data_pdir = os.path.join(src_dir, '..', 'data')

        train_data_path = os.path.join(data_pdir, train_data_filename)
        train_data = np.load(
            train_data_path)  # Assume data are normalised to [-1,1]
        train_dataset = tf.data.Dataset.from_tensor_slices(
            (train_data, train_data)).batch(batch_size)
        print(f'Training data {train_data_filename} loaded.')

        # 4) Train model
        epochs = 10
        ckpt_path = os.path.join(model_pdir, save_ckpt_dir, 'ckpt')
        train([G, D], [G_optimizer, D_optimizer], train_dataset, epochs,
              ckpt_path)
    return G, D
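A hypothetical invocation of main (the data file and checkpoint directory names below are placeholders, not taken from the original project):

if __name__ == '__main__':
    # Placeholders: a .npy array under ../data, checkpoint directories
    # under ../models.
    G, D = main(training=True,
                train_data_filename='train.npy',
                pretrained_G_dir='pretrained_G',
                save_ckpt_dir='gan_finetune')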
Example no. 4
    def __init__(self,
                 logDir,
                 printEvery=1,
                 resume=False,
                 useTensorboard=True):
        super(GAN3DTrainer, self).__init__()

        self.logDir = logDir

        self.currentEpoch = 0
        self.totalBatches = 0

        self.trainStats = {'lossG': [], 'lossD': [], 'accG': [], 'accD': []}

        self.printEvery = printEvery

        self.G = Generator()
        self.D = Discriminator()

        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')

            self.G = self.G.to(self.device)
            self.D = self.D.to(self.device)

            # parallelize models on both devices, splitting input on batch dimension
            self.G = torch.nn.DataParallel(self.G, device_ids=[0, 1])
            self.D = torch.nn.DataParallel(self.D, device_ids=[0, 1])

        # optim params direct from paper
        self.optimG = torch.optim.Adam(self.G.parameters(),
                                       lr=0.0025,
                                       betas=(0.5, 0.999))

        self.optimD = torch.optim.Adam(self.D.parameters(),
                                       lr=0.00005,
                                       betas=(0.5, 0.999))

        if resume:
            self.load()

        self.useTensorboard = useTensorboard
        self.tensorGraphInitialized = False
        self.writer = None
        if useTensorboard:
            self.writer = SummaryWriter(
                os.path.join(self.logDir, 'tensorboard'))
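device_ids=[0, 1] hard-codes a two-GPU machine. A more portable variant (a sketch, not part of the original trainer) derives the list from the visible device count:

            # Sketch: parallelize across however many GPUs are visible,
            # instead of assuming exactly two.
            device_ids = list(range(torch.cuda.device_count()))
            self.G = torch.nn.DataParallel(self.G, device_ids=device_ids)
            self.D = torch.nn.DataParallel(self.D, device_ids=device_ids)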
Example no. 5
    def __init__(self,
                 logDir,
                 printEvery=1,
                 resume=False,
                 lossRatio=0.0,
                 useTensorboard=True):
        super().__init__()

        self.printEvery = printEvery
        self.logDir = logDir
        self.lossRatio = lossRatio  # (1-a)*dissimLoss + a*realismLoss
        self.currentEpoch = 0
        self.totalBatches = 0

        self.P = Projector()
        # pre-trained G and D !
        self.G = Generator()
        self.D = Discriminator()

        # once hook is attached, activations will be pushed to self.activations
        self.D.attachLayerHook(self.D.layer3)

        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')

            self.G = self.G.to(self.device)
            self.D = self.D.to(self.device)
            self.P = self.P.to(self.device)

            # parallelize models on both devices, splitting input on batch dimension
            self.G = torch.nn.DataParallel(self.G, device_ids=[0, 1])
            self.D = torch.nn.DataParallel(self.D, device_ids=[0, 1])
            self.P = torch.nn.DataParallel(self.P, device_ids=[0, 1])

        self.optim = torch.optim.Adam(self.P.parameters(),
                                      lr=0.0005,
                                      betas=(0.5, 0.999))

        self.load(resume=resume)

        self.useTensorboard = useTensorboard
        self.tensorGraphInitialized = False
        self.writer = None
        if useTensorboard:
            self.writer = SummaryWriter(
                os.path.join(self.logDir, 'tensorboard'))
Example no. 6
    def __init__(self, args, config=config):
        self.args = args
        self.attribute = args.attribute
        self.gpu = args.gpu
        self.mode = args.mode
        self.restore = args.restore

        # init dataset and networks
        self.config = config
        self.dataset = ShapeNet(self.attribute)
        self.G = Generator()
        self.D = Discriminator()

        self.adv_criterion = torch.nn.BCELoss()

        self.set_mode_and_gpu()
        self.restore_from_file()
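set_mode_and_gpu and restore_from_file belong to the same class but are not shown. A plausible restore_from_file, purely an assumption about the checkpoint layout and naming:

    def restore_from_file(self):
        # Hypothetical sketch: resume G and D from epoch-numbered checkpoints
        # when --restore is given; the directory attribute and file names
        # are assumptions, not from the original class.
        if self.restore is not None:
            ckpt_G = os.path.join(self.config.log_dir, 'G_epoch-%d.pth' % self.restore)
            ckpt_D = os.path.join(self.config.log_dir, 'D_epoch-%d.pth' % self.restore)
            self.G.load_state_dict(torch.load(ckpt_G))
            self.D.load_state_dict(torch.load(ckpt_D))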
Example no. 7
def test_discriminator(input, net):
    # torch.cuda.is_available = lambda : False

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    discriminator = Discriminator(net).to(device)

    # optimizer_generator = Adam(generator.parameters())
    # criterion = nn.BCELoss()

    discriminator.apply(weights_init)

    # Print the model
    print(discriminator)

    vector = discriminator(input)
    print('output of discriminator is:', vector.item())
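Several snippets here call weights_init without defining it. The usual definition is the DCGAN-tutorial initialiser below; if these projects deviate from it, treat this as an assumption:

import torch.nn as nn

def weights_init(m):
    # DCGAN-style init: conv weights ~ N(0, 0.02); batch-norm weights
    # ~ N(1, 0.02) with zero bias.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)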
Example no. 8
    trn_dataset = HelenDataset(mode='train')
    val_dataset = HelenDataset(mode='test')
else:
    print('not implemented')
    exit()

trn_dloader = torch.utils.data.DataLoader(dataset=trn_dataset, batch_size=14, shuffle=True)
val_dloader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=1, shuffle=False)
hmaps_ch, pmaps_ch = trn_dataset.num_channels()

# load networks
G = FSRNet(hmaps_ch, pmaps_ch)
G = nn.DataParallel(G)
G = G.cuda()

D = Discriminator(input_shape=(3, 128, 128))
D = nn.DataParallel(D)
D = D.cuda()

F = FeatureExtractor().cuda()
F.eval()

# settings
a = 1
b = 1
r_c = 1e-3
r_p = 1e-1
learning_rate = 2.5e-4
criterion_MSE = nn.MSELoss()
criterion_BCE = nn.BCELoss()
optimizer_G = optim.RMSprop(G.parameters(), lr=learning_rate)
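The excerpt ends before the discriminator's optimizer is created; a symmetric line would presumably follow (an assumption, since the original cuts off here):

# Assumed continuation, mirroring optimizer_G above.
optimizer_D = optim.RMSprop(D.parameters(), lr=learning_rate)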
Example no. 9
def train(dataloader,
          num_epochs,
          net,
          run_settings,
          learning_rate=0.0002,
          optimizerD='Adam'):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create the nets
    generator = Generator(net).to(device)
    discriminator = Discriminator(net).to(device)

    # Apply the weights_init function to randomly initialize all weights
    generator.apply(weights_init)
    discriminator.apply(weights_init)

    # Initialize BCELoss function
    criterion = nn.BCELoss()

    # Create batch of latent vectors that we will use to visualize
    #  the progression of the generator
    fixed_noise = torch.randn(64, nz, 1, 1, device=device)

    # Establish convention for real and fake labels during training
    real_label = 1.
    fake_label = 0.

    beta1 = 0.5

    # Setup Adam optimizers for both G and D
    if optimizerD == 'SGD':
        optimizerD = optim.SGD(discriminator.parameters(), lr=learning_rate)
    else:
        optimizerD = optim.Adam(discriminator.parameters(),
                                lr=learning_rate,
                                betas=(beta1, 0.999))
    optimizerG = optim.Adam(generator.parameters(),
                            lr=learning_rate,
                            betas=(beta1, 0.999))

    # Lists to keep track of progress
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0

    print("Starting Training Loop...")
    for epoch in range(num_epochs):
        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            discriminator.zero_grad()
            # Format batch
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size, ),
                               real_label,
                               dtype=torch.float,
                               device=device)
            # Forward pass real batch through D
            output = discriminator(real_cpu).view(-1)
            # Calculate loss on all-real batch
            errD_real = criterion(output, label)
            # Calculate gradients for D in backward pass
            errD_real.backward()
            D_x = output.mean().item()

            ## Train with all-fake batch
            # Generate batch of latent vectors
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            # Generate fake image batch with G
            fake = generator(noise)
            label.fill_(fake_label)
            # Classify all fake batch with D
            output = discriminator(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, label)
            # Calculate the gradients for this batch
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Add the gradients from the all-real and all-fake batches
            errD = errD_real + errD_fake
            # Update D
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            generator.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = discriminator(fake).view(-1)
            # Calculate G's loss based on this output
            errG = criterion(output, label)
            # Calculate gradients for G
            errG.backward()
            D_G_z2 = output.mean().item()
            # Update G
            optimizerG.step()

            # Output training stats
            if i % 3 == 0:
                print(
                    '[%d/%d][%d/%d]\t\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                    % (epoch + 1, num_epochs, i + 1, len(dataloader),
                       errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            # Save Losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # Check how the generator is doing by saving its output on fixed_noise
            if (iters %
                (len(dataloader) * 50) == 0) or ((epoch == num_epochs - 1) and
                                                 (i == len(dataloader) - 1)):
                with torch.no_grad():
                    fake = generator(fixed_noise).detach().cpu()
                img_list.append(
                    vutils.make_grid(fake, padding=2, normalize=True))

            iters += 1

    print("finished")

    for i in range(len(img_list)):
        plt.imshow(np.transpose(img_list[i], (1, 2, 0)))
        plt.savefig('generated_images_' + str(i) + '.png')

    plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
    plt.savefig('generated_images_' + run_settings + '.png')

    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(G_losses, label="G")
    plt.plot(D_losses, label="D")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig('loss_graph_' + run_settings + '.png')
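A hypothetical smoke test for this loop. The random TensorDataset and the 'dcgan' net tag are stand-ins, not from the original code, and the global nz read inside train() must already be defined (e.g. nz = 100):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in data: random 64x64 RGB images in place of a real dataset.
dataset = TensorDataset(torch.randn(256, 3, 64, 64), torch.zeros(256))
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
train(dataloader, num_epochs=1, net='dcgan', run_settings='smoke_test')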
Example no. 10
def train(args):
    device_str = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_str)

    gen = Generator(args.nz, 800)
    gen = gen.to(device)
    gen.apply(weights_init)

    discriminator = Discriminator(800)
    discriminator = discriminator.to(device)
    discriminator.apply(weights_init)

    bce = nn.BCELoss()
    bce = bce.to(device)

    galaxy_dataset = GalaxySet(args.data_path,
                               normalized=args.normalized,
                               out=args.out)
    loader = DataLoader(galaxy_dataset,
                        batch_size=args.bs,
                        shuffle=True,
                        num_workers=2,
                        drop_last=True)
    loader_iter = iter(loader)

    d_optimizer = Adam(discriminator.parameters(),
                       betas=(0.5, 0.999),
                       lr=args.lr)
    g_optimizer = Adam(gen.parameters(), betas=(0.5, 0.999), lr=args.lr)

    real_labels = to_var(torch.ones(args.bs), device_str)
    fake_labels = to_var(torch.zeros(args.bs), device_str)
    fixed_noise = to_var(torch.randn(1, args.nz), device_str)

    for i in tqdm(range(args.iters)):
        try:
            batch_data = next(loader_iter)
        except StopIteration:
            loader_iter = iter(loader)
            batch_data = next(loader_iter)

        batch_data = to_var(batch_data, device).unsqueeze(1)

        batch_data = batch_data[:, :, :1600:2]
        batch_data = batch_data.view(-1, 800)

        ### Train Discriminator ###

        d_optimizer.zero_grad()

        # train Infer with real
        pred_real = discriminator(batch_data)
        d_loss = bce(pred_real, real_labels)

        # train infer with fakes
        z = to_var(torch.randn((args.bs, args.nz)), device)
        fakes = gen(z)
        pred_fake = discriminator(fakes.detach())
        d_loss += bce(pred_fake, fake_labels)

        d_loss.backward()

        d_optimizer.step()

        ### Train Gen ###

        g_optimizer.zero_grad()

        z = to_var(torch.randn((args.bs, args.nz)), device)
        fakes = gen(z)
        pred_fake = discriminator(fakes)
        gen_loss = bce(pred_fake, real_labels)

        gen_loss.backward()
        g_optimizer.step()

        if i % 5000 == 0:
            print("Iteration %d >> g_loss: %.4f., d_loss: %.4f." %
                  (i, gen_loss, d_loss))
            torch.save(gen.state_dict(),
                       os.path.join(args.out, 'gen_%d.pkl' % 0))
            torch.save(discriminator.state_dict(),
                       os.path.join(args.out, 'disc_%d.pkl' % 0))
            gen.eval()
            fixed_fake = gen(fixed_noise).detach().cpu().numpy()
            real_data = batch_data[0].detach().cpu().numpy()
            gen.train()
            display_noise(fixed_fake.squeeze(),
                          os.path.join(args.out, "gen_sample_%d.png" % i))
            display_noise(real_data.squeeze(),
                          os.path.join(args.out, "real_%d.png" % 0))
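train(args) reads its settings via attribute access; a hypothetical argparse wiring (flag names mirror the attributes used above, the defaults are guesses):

import argparse

# Hypothetical CLI: attribute names mirror those read inside train().
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', required=True)
parser.add_argument('--out', default='output')
parser.add_argument('--normalized', action='store_true')
parser.add_argument('--nz', type=int, default=100)
parser.add_argument('--bs', type=int, default=64)
parser.add_argument('--lr', type=float, default=2e-4)
parser.add_argument('--iters', type=int, default=100000)
train(parser.parse_args())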
Example no. 11
#             nn.LeakyReLU(0.2, inplace=True),
#             nn.Linear(512, 1024),
#             nn.LeakyReLU(0.2, inplace=True),
#             nn.Linear(1024, 784),
#             nn.Tanh()
#         )

#     def forward(self, x):
#         x = x.view(x.size(0), 100)
#         out = self.model(x)
#         return out

from nets import Generator, Discriminator

G = Generator((100, 500, 28 * 28), 'relu')
D = Discriminator((28 * 28, 500, 1), 'relu')

discriminator = D.cuda()
generator = G.cuda()

criterion = nn.BCELoss()
lr = 0.0002
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)


def train_discriminator(discriminator, images, real_labels, fake_images,
                        fake_labels):
    discriminator.zero_grad()

    # real_outputs = discriminator(images.reshape(-1, 28*28))
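    # --- Sketch of a completion (not in the original excerpt): score real and
    # --- fake batches, sum the BCE terms, and step the global d_optimizer.
    real_outputs = discriminator(images.view(images.size(0), -1))
    d_loss_real = criterion(real_outputs, real_labels)

    fake_outputs = discriminator(fake_images)
    d_loss_fake = criterion(fake_outputs, fake_labels)

    d_loss = d_loss_real + d_loss_fake
    d_loss.backward()
    d_optimizer.step()
    return d_loss, real_outputs, fake_outputs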
Example no. 12
            list_B.append(path)

iterator_A = data_iterator(
    DataSource(list_A,
               os.path.join(opt.root, 'Img/img_align_celeba_png/'),
               shuffle=True), opt.batch_size)
iterator_B = data_iterator(
    DataSource(list_B,
               os.path.join(opt.root, 'Img/img_align_celeba_png/'),
               shuffle=True), opt.batch_size)

# define networks
gen_AB = Generator('gen_AB', opt.hidden_channel, opt.out_channel)
gen_BA = Generator('gen_BA', opt.hidden_channel, opt.out_channel)

dis_A = Discriminator('dis_A', opt.hidden_channel)
dis_B = Discriminator('dis_B', opt.hidden_channel)

# define solvers
solver_gen_AB = S.Adam(opt.learning_rate, beta1=0.5)
solver_gen_BA = S.Adam(opt.learning_rate, beta1=0.5)

solver_dis_A = S.Adam(opt.learning_rate, beta1=0.5)
solver_dis_B = S.Adam(opt.learning_rate, beta1=0.5)

# define updater
updater = Updater(opt.batch_size, opt.lmd, opt.input_shape, iterator_A,
                  iterator_B, gen_AB, gen_BA, dis_A, dis_B, solver_gen_AB,
                  solver_gen_BA, solver_dis_A, solver_dis_B)

# define monitor
Example no. 13
def gan_augment(x, y, seed, n_samples=None):
    if n_samples is None:
        n_samples = len(x)

    lr = 3e-4
    num_ep = 300
    z_dim = 100
    model_path = "./gan_checkpoint_%d.pth" % seed

    device = "cuda" if torch.cuda.is_available() else "cpu"
    G = Generator(z_dim).to(device)
    D = Discriminator(z_dim).to(device)
    bce_loss = nn.BCELoss()
    G_optim = optim.Adam(G.parameters(), lr=lr * 3, betas=(0.5, 0.999))
    D_optim = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

    batch = 64
    train_x = torch.Tensor(x)
    train_labels = torch.LongTensor(y)

    if os.path.exists(model_path):
        print("load trained GAN...")
        state = torch.load(model_path)
        G.load_state_dict(state["G"])
    else:
        print("training a new GAN...")
        for epoch in range(num_ep):
            for _ in range(len(train_x) // batch):
                idx = np.random.choice(range(len(train_x)), batch)
                batch_x = train_x[idx].to(device)
                batch_labels = train_labels[idx].to(device)

                y_real = torch.ones(batch).to(device)
                y_fake = torch.zeros(batch).to(device)

                # train D with real images
                D.zero_grad()
                D_real_out = D(batch_x, batch_labels).squeeze()
                D_real_loss = bce_loss(D_real_out, y_real)

                # train D with fake images
                z_ = torch.randn((batch, z_dim)).view(-1, z_dim, 1,
                                                      1).to(device)
                fake_labels = torch.randint(0, 10, (batch, )).to(device)
                G_out = G(z_, fake_labels)

                D_fake_out = D(G_out, fake_labels).squeeze()
                D_fake_loss = bce_loss(D_fake_out, y_fake)
                D_loss = D_real_loss + D_fake_loss
                D_loss.backward()
                D_optim.step()

                # train G
                G.zero_grad()
                z_ = torch.randn((batch, z_dim)).view(-1, z_dim, 1,
                                                      1).to(device)
                fake_labels = torch.randint(0, 10, (batch, )).to(device)
                G_out = G(z_, fake_labels)
                D_out = D(G_out, fake_labels).squeeze()
                G_loss = bce_loss(D_out, y_real)
                G_loss.backward()
                G_optim.step()

            plot2img(G_out[:50].cpu())
            print("epoch: %d G_loss: %.2f D_loss: %.2f" %
                  (epoch, G_loss, D_loss))
        state = {"G": G.state_dict(), "D": D.state_dict()}
        torch.save(state, model_path)

    with torch.no_grad():
        z_ = torch.randn((n_samples, z_dim)).view(-1, z_dim, 1, 1).to(device)
        fake_labels = torch.randint(0, 10, (n_samples, )).to(device)
        G_samples = G(z_, fake_labels)
        samples = G_samples.cpu().numpy().reshape((-1, 28, 28, 1))
    return samples, fake_labels.cpu().numpy()
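A hypothetical call (x_train and y_train stand for the real training images and integer labels; the names are not from the original source):

# Hypothetical usage: synthesise one generated sample per real training image.
samples, labels = gan_augment(x_train, y_train, seed=0)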
Example no. 14
def main():
    # Supervised GAN?
    options = [False, True]
    # Alternative: run over different pre-processing types, comment the above line and uncomment the one below
    # options = [None,'returns','logreturns','scale_S_ref']

    results_path = META['results_path']

    for i in range(len(options)):
        # Reset the seed at each iteration for equal initialisation of the nets
        torch.manual_seed(SEED)
        np.random.seed(seed=SEED)
        META['seed'] = SEED

        # Make a folder for each run of the training loop
        if not pt.exists(pt.join(results_path, 'iter_%d' % i)):
            os.mkdir(pt.join(results_path, 'iter_%d' % i))

        #---------------------------------------------------------
        # Modify training conditions in loop
        #---------------------------------------------------------

        META['supervised'] = options[i]
        # Alternative: run over different pre-processing types, comment the above line and uncomment the one below
        # META['proc_type'] = options[i]

        #---------------------------------------------------------

        # Override the default n_D, the amount of training steps of D per G training step if vanilla GAN.
        META['n_D'] = 1 if META['supervised'] else META['n_D']

        #---------------------------------------------------------

        # Make the dataset and initialise the GAN
        # X.generate_CIR_data()
        X = load_preset(META['preset'],
                        N_train=META['N_train'],
                        N_test=META['N_test'])
        X.exact = preprocess(X.exact,
                             torch.tensor(X.params['S0'], dtype=torch.float32).view(-1, 1),
                             proc_type=META['proc_type'],
                             S_ref=torch.tensor(X.params['S_bar'],
                                                device=torch.device('cpu'),
                                                dtype=torch.float32),
                             eps=META['eps'])

        c_dim = 0 if X.C is None else len(X.C)
        netG = Generator(c_dim=c_dim).to(DEVICE)
        netG.eps = META['eps']
        disc_c_dim = c_dim + 1 if META['supervised'] else c_dim
        netD = Discriminator(c_dim=disc_c_dim,
                             negative_slope=META['negative_slope'],
                             hidden_dim=META['hidden_dim'],
                             activation=META['activation']).to(DEVICE)
        analysis = CGANalysis(X,
                              netD,
                              netG,
                              SDE=X.SDE,
                              save_all_figs=META['save_figs'],
                              results_path=results_path,
                              proc_type=META['proc_type'],
                              eps=META['eps'],
                              supervised=META['supervised'])

        # Train the GAN
        output_dict, results_df = train_GAN(netD, netG, X, META)

        # Store results
        netG_dir = pt.join(results_path, 'iter_%d' % i, 'netG.pth')
        netD_dir = pt.join(results_path, 'iter_%d' % i, 'netD.pth')
        torch.save(netG.state_dict(), netG_dir)
        print('Saved Generator in %s' % netG_dir)
        torch.save(netD.state_dict(), netD_dir)
        print('Saved Discriminator in %s' % netD_dir)

        if META['report']:
            results_df.to_csv(pt.join(results_path, 'iter_%d' % i,
                                      'train_log.csv'),
                              index=False,
                              header=True)
            # Uncomment to save the entire output dict
            # log_path = pt.join(results_path,'iter_%d'%i,'train_log.pkl')
            # pickle_it(output_dict,log_path)
            meta_path = pt.join(results_path, 'iter_%d' % i, 'metadata.pkl')
            pickle_it(META, meta_path)
            print('Saved logs in ' + results_path + '/iter_%d/' % i)

    print('----- Experiment finished -----')