class _3DGAN(object):
    """Trainer for a vanilla 3D-GAN on ShapeNet voxel data.

    Wires up the generator/discriminator pair, optional GPU placement,
    checkpoint restore, logging, and the adversarial training loop.
    """

    def __init__(self, args, config=config):
        self.args = args
        self.attribute = args.attribute
        self.gpu = args.gpu          # list of GPU ids; empty/falsy means CPU
        self.mode = args.mode        # 'train' or 'test'
        self.restore = args.restore  # iteration to resume from, or None

        # init dataset and networks
        self.config = config
        self.dataset = ShapeNet(self.attribute)
        self.G = Generator()
        self.D = Discriminator()
        self.adv_criterion = torch.nn.BCELoss()

        self.set_mode_and_gpu()
        self.restore_from_file()

    def _place_on_gpu(self):
        """Move both nets (and, in train mode, the loss) onto the configured GPUs."""
        with torch.cuda.device(self.gpu[0]):
            self.G.cuda()
            self.D.cuda()
            if self.mode == 'train':
                self.adv_criterion.cuda()
            if len(self.gpu) > 1:
                self.G = torch.nn.DataParallel(self.G, device_ids=self.gpu)
                self.D = torch.nn.DataParallel(self.D, device_ids=self.gpu)

    def set_mode_and_gpu(self):
        """Put the networks in train/eval mode and move them to the GPU(s)."""
        if self.mode == 'train':
            self.G.train()
            self.D.train()
            if self.gpu:
                self._place_on_gpu()
        elif self.mode == 'test':
            self.G.eval()
            self.D.eval()
            if self.gpu:
                self._place_on_gpu()
        else:
            # BUG FIX: the original raised NotImplementationError, which does
            # not exist and would itself crash with a NameError.
            raise NotImplementedError()

    def restore_from_file(self):
        """Load G (and D in train mode) from checkpoints; set self.start_step."""
        if self.restore is not None:
            ckpt_file_G = os.path.join(
                self.config.model_dir, 'G_iter_{:06d}.pth'.format(self.restore))
            assert os.path.exists(ckpt_file_G)
            self.G.load_state_dict(torch.load(ckpt_file_G))

            if self.mode == 'train':
                ckpt_file_D = os.path.join(
                    self.config.model_dir, 'D_iter_{:06d}.pth'.format(self.restore))
                assert os.path.exists(ckpt_file_D)
                self.D.load_state_dict(torch.load(ckpt_file_D))

            self.start_step = self.restore + 1
        else:
            self.start_step = 1

    def save_log(self):
        """Write current losses and learning rates to tensorboard."""
        scalar_info = {
            'loss_D': self.loss_D,
            'loss_G': self.loss_G,
            'G_lr': self.G_lr_scheduler.get_lr()[0],
            'D_lr': self.D_lr_scheduler.get_lr()[0],
        }
        for key, value in self.G_loss.items():
            scalar_info['G_loss/' + key] = value
        for key, value in self.D_loss.items():
            scalar_info['D_loss/' + key] = value
        for tag, value in scalar_info.items():
            self.writer.add_scalar(tag, value, self.step)

    def save_img(self, save_num=5):
        """Dump the first `save_num` generated voxel grids as .mat files."""
        for i in range(save_num):
            mdict = {'instance': self.fake_X[i, 0].data.cpu().numpy()}
            sio.savemat(
                os.path.join(self.config.img_dir,
                             '{:06d}_{:02d}.mat'.format(self.step, i)),
                mdict)

    def save_model(self):
        """Checkpoint both networks (state dicts moved to CPU first)."""
        torch.save(
            {key: val.cpu() for key, val in self.G.state_dict().items()},
            os.path.join(self.config.model_dir, 'G_iter_{:06d}.pth'.format(self.step)))
        torch.save(
            {key: val.cpu() for key, val in self.D.state_dict().items()},
            os.path.join(self.config.model_dir, 'D_iter_{:06d}.pth'.format(self.step)))

    def train(self):
        """Run the adversarial training loop for config.max_iter steps."""
        self.writer = SummaryWriter(self.config.log_dir)
        self.opt_G = torch.optim.Adam(self.G.parameters(), lr=self.config.G_lr,
                                      betas=(0.5, 0.999))
        self.opt_D = torch.optim.Adam(self.D.parameters(), lr=self.config.D_lr,
                                      betas=(0.5, 0.999))
        self.G_lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.opt_G, step_size=self.config.step_size, gamma=self.config.gamma)
        self.D_lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.opt_D, step_size=self.config.step_size, gamma=self.config.gamma)

        # start training
        for step in range(self.start_step, 1 + self.config.max_iter):
            self.step = step
            # NOTE(review): scheduler.step() before optimizer.step() is the
            # legacy (pre-1.1) PyTorch ordering; kept to preserve the original
            # learning-rate schedule.
            self.G_lr_scheduler.step()
            self.D_lr_scheduler.step()

            self.real_X = next(self.dataset.gen(True))
            self.noise = torch.randn(self.config.nchw[0], 200)
            if len(self.gpu):
                with torch.cuda.device(self.gpu[0]):
                    self.real_X = self.real_X.cuda()
                    self.noise = self.noise.cuda()
            self.fake_X = self.G(self.noise)

            # update D: real -> 1, detached fake -> 0
            self.D_real = self.D(self.real_X)
            self.D_fake = self.D(self.fake_X.detach())
            self.D_loss = {
                'adv_real': self.adv_criterion(self.D_real, torch.ones_like(self.D_real)),
                'adv_fake': self.adv_criterion(self.D_fake, torch.zeros_like(self.D_fake)),
            }
            self.loss_D = sum(self.D_loss.values())
            self.opt_D.zero_grad()
            self.loss_D.backward()
            self.opt_D.step()

            # update G: fool D into predicting 1 for fakes
            self.D_fake = self.D(self.fake_X)
            self.G_loss = {
                'adv_fake': self.adv_criterion(self.D_fake, torch.ones_like(self.D_fake))
            }
            self.loss_G = sum(self.G_loss.values())
            self.opt_G.zero_grad()
            self.loss_G.backward()
            self.opt_G.step()

            print('step: {:06d}, loss_D: {:.6f}, loss_G: {:.6f}'.format(
                self.step,
                self.loss_D.data.cpu().numpy(),
                self.loss_G.data.cpu().numpy()))

            if self.step % 100 == 0:
                self.save_log()
            if self.step % 1000 == 0:
                self.save_img()
                self.save_model()

        print('Finished training!')
        self.writer.close()
class GAN_CLS(object):
    def __init__(self, args, data_loader, SUPERVISED=True):
        """Text-conditional GAN (GAN-CLS) trainer.

        Arguments :
        ----------
        args : Arguments defined in Argument Parser
        data_loader = An instance of class DataLoader for loading our dataset in batches
        SUPERVISED : whether training is run in supervised mode
        """
        self.data_loader = data_loader
        self.num_epochs = args.num_epochs
        self.batch_size = args.batch_size

        self.log_step = config.log_step
        self.sample_step = config.sample_step
        # BUG FIX: train_model() reads self.model_save_step, which was never
        # initialised anywhere and raised AttributeError at checkpoint time.
        # NOTE(review): assumes `config` defines model_save_step alongside
        # log_step / sample_step — confirm against the config module.
        self.model_save_step = config.model_save_step

        self.log_dir = args.log_dir
        self.checkpoint_dir = args.checkpoint_dir
        self.sample_dir = config.sample_dir
        self.final_model = args.final_model

        self.dataset = args.dataset
        self.model_name = args.model_name
        self.img_size = args.img_size
        self.z_dim = args.z_dim
        self.text_embed_dim = args.text_embed_dim
        self.text_reduced_dim = args.text_reduced_dim
        self.learning_rate = args.learning_rate
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        self.l1_coeff = args.l1_coeff
        self.resume_epoch = args.resume_epoch
        self.SUPERVISED = SUPERVISED

        # Logger setting
        # BUG FIX: the original passed the literal string '__name__' instead of
        # the module's __name__ variable, naming the logger "__name__".
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        self.formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
        self.file_handler = logging.FileHandler(self.log_dir)
        self.file_handler.setFormatter(self.formatter)
        self.logger.addHandler(self.file_handler)

        self.build_model()

    def build_model(self):
        """ A function of defining following instances :

        ----- Generator
        ----- Discriminator
        ----- Optimizer for Generator
        ----- Optimizer for Discriminator
        ----- Defining Loss functions
        """
        # ---------------------------------------------------------------------
        # 1. Network Initialization
        # ---------------------------------------------------------------------
        # BUG FIX: the original passed `img_size=self, img_size` (comma typed
        # instead of a dot) for both networks — a syntax error.
        self.gen = Generator(batch_size=self.batch_size,
                             img_size=self.img_size,
                             z_dim=self.z_dim,
                             text_embed_dim=self.text_embed_dim,
                             text_reduced_dim=self.text_reduced_dim)
        self.disc = Discriminator(batch_size=self.batch_size,
                                  img_size=self.img_size,
                                  text_embed_dim=self.text_embed_dim,
                                  text_reduced_dim=self.text_reduced_dim)

        self.gen_optim = optim.Adam(self.gen.parameters(),
                                    lr=self.learning_rate,
                                    betas=(self.beta1, self.beta2))
        self.disc_optim = optim.Adam(self.disc.parameters(),
                                     lr=self.learning_rate,
                                     betas=(self.beta1, self.beta2))
        self.cls_gan_optim = optim.Adam(itertools.chain(self.gen.parameters(),
                                                        self.disc.parameters()),
                                        lr=self.learning_rate,
                                        betas=(self.beta1, self.beta2))

        print('------------- Generator Model Info ---------------')
        self.print_network(self.gen, 'G')
        print('------------------------------------------------')

        print('------------- Discriminator Model Info ---------------')
        self.print_network(self.disc, 'D')
        print('------------------------------------------------')

        self.gen.cuda()
        self.disc.cuda()
        self.criterion = nn.BCELoss().cuda()
        # self.CE_loss = nn.CrossEntropyLoss().cuda()
        # self.MSE_loss = nn.MSELoss().cuda()
        self.gen.train()
        self.disc.train()

    def print_network(self, model, name):
        """ A function for printing total number of model parameters """
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(model)
        print(name)
        print("Total number of parameters: {}".format(num_params))

    def load_checkpoints(self, resume_epoch):
        """Restore the trained generator and discriminator."""
        print('Loading the trained models from step {}...'.format(resume_epoch))
        G_path = os.path.join(self.checkpoint_dir, '{}-G.ckpt'.format(resume_epoch))
        D_path = os.path.join(self.checkpoint_dir, '{}-D.ckpt'.format(resume_epoch))
        self.gen.load_state_dict(torch.load(G_path,
                                            map_location=lambda storage, loc: storage))
        self.disc.load_state_dict(torch.load(D_path,
                                             map_location=lambda storage, loc: storage))

    def train_model(self):
        """Run the GAN-CLS training loop over self.num_epochs epochs."""
        data_loader = self.data_loader

        start_epoch = 0
        if self.resume_epoch:
            start_epoch = self.resume_epoch
            self.load_checkpoints(self.resume_epoch)

        print('--------------- Model Training Started ---------------')
        start_time = time.time()

        for epoch in range(start_epoch, self.num_epochs):
            for idx, batch in enumerate(data_loader):
                true_imgs = batch['true_imgs']
                true_embed = batch['true_embed']
                false_imgs = batch['false_imgs']

                real_labels = torch.ones(true_imgs.size(0))
                fake_labels = torch.zeros(true_imgs.size(0))

                # one-sided label smoothing applied to the real labels
                smooth_real_labels = torch.FloatTensor(
                    Utils.smooth_label(real_labels.numpy(), -0.1))

                true_imgs = Variable(true_imgs.float()).cuda()
                true_embed = Variable(true_embed.float()).cuda()
                false_imgs = Variable(false_imgs.float()).cuda()
                real_labels = Variable(real_labels).cuda()
                smooth_real_labels = Variable(smooth_real_labels).cuda()
                fake_labels = Variable(fake_labels).cuda()

                # ---------------------------------------------------------------
                # 2. Training the generator
                # ---------------------------------------------------------------
                self.gen.zero_grad()
                z = Variable(torch.randn(true_imgs.size(0), self.z_dim)).cuda()
                fake_imgs = self.gen(true_embed, z)
                fake_out, fake_logit = self.disc(fake_imgs, true_embed)
                true_out, true_logit = self.disc(true_imgs, true_embed)
                # BUG FIX: the original wrote nn.L1Loss(fake_imgs, true_imgs),
                # which passes tensors to the module *constructor* instead of
                # instantiating the loss and applying it.
                gen_loss = (self.criterion(fake_out, real_labels) +
                            self.l1_coeff * nn.L1Loss()(fake_imgs, true_imgs))

                # BUG FIX: fake_out is reused in the discriminator loss below,
                # so the graph must be retained here or the second backward()
                # raises. NOTE(review): a cleaner fix would recompute D's
                # outputs from fake_imgs.detach() — confirm intended D update.
                gen_loss.backward(retain_graph=True)
                self.gen_optim.step()

                # ---------------------------------------------------------------
                # 3. Training the discriminator
                # ---------------------------------------------------------------
                self.disc.zero_grad()
                false_out, false_logit = self.disc(false_imgs, true_embed)
                disc_loss = (self.criterion(true_out, smooth_real_labels) +
                             self.criterion(fake_out, fake_labels) +
                             self.criterion(false_out, fake_labels))
                disc_loss.backward()
                self.disc_optim.step()
                # self.cls_gan_optim.step()

                # Logging
                loss = {}
                loss['G_loss'] = gen_loss.item()
                loss['D_loss'] = disc_loss.item()

                # ---------------------------------------------------------------
                # 4. Logging INFO into log_dir
                # ---------------------------------------------------------------
                if (idx + 1) % self.log_step == 0:
                    end_time = time.time() - start_time
                    end_time = datetime.timedelta(seconds=end_time)
                    log = "Elapsed [{}], Epoch [{}/{}], Idx [{}]".format(
                        end_time, epoch + 1, self.num_epochs, idx)
                    for net, loss_value in loss.items():
                        log += ", {}: {:.4f}".format(net, loss_value)
                    self.logger.info(log)
                    print(log)

                # ---------------------------------------------------------------
                # 5. Saving generated images
                # ---------------------------------------------------------------
                if (idx + 1) % self.sample_step == 0:
                    # stack real and fake images along the height dimension
                    concat_imgs = torch.cat((true_imgs, fake_imgs), 2)
                    save_path = os.path.join(self.sample_dir,
                                             '{}-images.jpg'.format(idx + 1))
                    # BUG FIX: the original mistyped `cocat_imgs` (NameError)
                    # here; rescale from [-1, 1] to [0, 1] before saving.
                    concat_imgs = (concat_imgs + 1) / 2
                    # out.clamp_(0, 1)
                    # BUG FIX: the original saved to self.sample_dir (a
                    # directory) and left save_path unused.
                    save_image(concat_imgs.data.cpu(), save_path, nrow=1, padding=0)
                    print('Saved real and fake images into {}...'.format(self.sample_dir))

                # ---------------------------------------------------------------
                # 6. Saving the checkpoints & final model
                # ---------------------------------------------------------------
                if (idx + 1) % self.model_save_step == 0:
                    G_path = os.path.join(self.checkpoint_dir, '{}-G.ckpt'.format(idx + 1))
                    D_path = os.path.join(self.checkpoint_dir, '{}-D.ckpt'.format(idx + 1))
                    torch.save(self.gen.state_dict(), G_path)
                    torch.save(self.disc.state_dict(), D_path)
                    print('Saved model checkpoints into {}...'.format(self.checkpoint_dir))
def train(args):
    """Train a 1-D GAN on GalaxySet data for args.iters iterations.

    Saves "latest" generator/discriminator checkpoints and sample plots to
    args.out every 5000 iterations.
    """
    device_str = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_str)

    gen = Generator(args.nz, 800)
    gen = gen.to(device)
    gen.apply(weights_init)

    discriminator = Discriminator(800)
    discriminator = discriminator.to(device)
    discriminator.apply(weights_init)

    bce = nn.BCELoss()
    bce = bce.to(device)

    galaxy_dataset = GalaxySet(args.data_path, normalized=args.normalized, out=args.out)
    loader = DataLoader(galaxy_dataset, batch_size=args.bs, shuffle=True,
                        num_workers=2, drop_last=True)
    loader_iter = iter(loader)

    d_optimizer = Adam(discriminator.parameters(), betas=(0.5, 0.999), lr=args.lr)
    g_optimizer = Adam(gen.parameters(), betas=(0.5, 0.999), lr=args.lr)

    real_labels = to_var(torch.ones(args.bs), device_str)
    fake_labels = to_var(torch.zeros(args.bs), device_str)

    fixed_noise = to_var(torch.randn(1, args.nz), device_str)

    for i in tqdm(range(args.iters)):
        try:
            # BUG FIX: loader_iter.next() is Python-2-only syntax and raises
            # AttributeError on Python 3; use the next() builtin.
            batch_data = next(loader_iter)
        except StopIteration:
            # restart the loader once a full pass over the dataset is done
            loader_iter = iter(loader)
            batch_data = next(loader_iter)

        batch_data = to_var(batch_data, device).unsqueeze(1)
        # keep every other sample of the first 1600 points -> 800 features
        batch_data = batch_data[:, :, :1600:2]
        batch_data = batch_data.view(-1, 800)

        ### Train Discriminator ###
        d_optimizer.zero_grad()

        # train Infer with real
        pred_real = discriminator(batch_data)
        d_loss = bce(pred_real, real_labels)

        # train infer with fakes (detached so G gets no gradient here)
        z = to_var(torch.randn((args.bs, args.nz)), device)
        fakes = gen(z)
        pred_fake = discriminator(fakes.detach())
        d_loss += bce(pred_fake, fake_labels)

        d_loss.backward()
        d_optimizer.step()

        ### Train Gen ###
        g_optimizer.zero_grad()
        z = to_var(torch.randn((args.bs, args.nz)), device)
        fakes = gen(z)
        pred_fake = discriminator(fakes)
        gen_loss = bce(pred_fake, real_labels)
        gen_loss.backward()
        g_optimizer.step()

        if i % 5000 == 0:
            print("Iteration %d >> g_loss: %.4f., d_loss: %.4f." % (i, gen_loss, d_loss))

            # NOTE(review): the '% 0' filenames mean these checkpoints and the
            # real-sample plot are overwritten each time ("latest" semantics),
            # while the generated-sample image is indexed by iteration —
            # confirm this asymmetry is intended.
            torch.save(gen.state_dict(), os.path.join(args.out, 'gen_%d.pkl' % 0))
            torch.save(discriminator.state_dict(), os.path.join(args.out, 'disc_%d.pkl' % 0))

            gen.eval()
            fixed_fake = gen(fixed_noise).detach().cpu().numpy()
            real_data = batch_data[0].detach().cpu().numpy()
            gen.train()

            display_noise(fixed_fake.squeeze(),
                          os.path.join(args.out, "gen_sample_%d.png" % i))
            display_noise(real_data.squeeze(),
                          os.path.join(args.out, "real_%d.png" % 0))
class GAN3DTrainer(object):
    """Trainer for a 3D voxel GAN; optimizer hyper-parameters taken from the paper."""

    def __init__(self, logDir, printEvery=1, resume=False, useTensorboard=True):
        super(GAN3DTrainer, self).__init__()
        self.logDir = logDir
        self.currentEpoch = 0
        self.totalBatches = 0
        self.trainStats = {'lossG': [], 'lossD': [], 'accG': [], 'accD': []}
        self.printEvery = printEvery

        self.G = Generator()
        self.D = Discriminator()
        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')
            self.G = self.G.to(self.device)
            self.D = self.D.to(self.device)
            # BUG FIX: the original unconditionally wrapped both nets with
            # DataParallel(device_ids=[0, 1]), which crashes on machines with
            # a single GPU. Parallelize over whatever GPUs are present
            # (input split on the batch dimension).
            if torch.cuda.device_count() > 1:
                ids = list(range(torch.cuda.device_count()))
                self.G = torch.nn.DataParallel(self.G, device_ids=ids)
                self.D = torch.nn.DataParallel(self.D, device_ids=ids)

        # optim params direct from paper
        self.optimG = torch.optim.Adam(self.G.parameters(), lr=0.0025, betas=(0.5, 0.999))
        self.optimD = torch.optim.Adam(self.D.parameters(), lr=0.00005, betas=(0.5, 0.999))

        if resume:
            self.load()

        self.useTensorboard = useTensorboard
        self.tensorGraphInitialized = False
        self.writer = None
        if useTensorboard:
            self.writer = SummaryWriter(
                os.path.join(self.logDir, 'tensorboard'))

    def train(self, trainData: torch.utils.data.DataLoader):
        """Run one epoch of adversarial training over `trainData`."""
        epochLoss = 0.0
        numBatches = 0
        self.G.train()
        self.D.train()
        for i, sample in enumerate(tqdm(trainData)):
            data = sample['data']
            self.optimG.zero_grad()
            self.G.zero_grad()
            self.optimD.zero_grad()
            self.D.zero_grad()

            # pad the 62^3 target volume into a 64^3 grid (1-voxel border)
            realVoxels = torch.zeros(data['62'].shape[0], 64, 64, 64).to(self.device)
            realVoxels[:, 1:-1, 1:-1, 1:-1] = data['62'].to(self.device)

            # discriminator train
            z = torch.normal(torch.zeros(data['62'].shape[0], 200),
                             torch.ones(data['62'].shape[0], 200) * 0.33).to(self.device)
            fakeVoxels = self.G(z)
            fakeD = self.D(fakeVoxels)
            realD = self.D(realVoxels)
            lossD = -torch.mean(torch.log(realD) + torch.log(1. - fakeD))
            accD = ((realD >= .5).float().mean() + (fakeD < .5).float().mean()) / 2.
            accG = (fakeD > .5).float().mean()

            # only train if Disc wrong enough :)
            if accD < .8:
                self.D.zero_grad()
                lossD.backward()
                self.optimD.step()

            # gen train
            z = torch.normal(torch.zeros(data['62'].shape[0], 200),
                             torch.ones(data['62'].shape[0], 200) * 0.33).to(self.device)
            fakeVoxels = self.G(z)
            fakeD = self.D(fakeVoxels)
            # https://arxiv.org/pdf/1706.05170.pdf (IV. Methods, A. Training the gen model)
            lossG = -torch.mean(torch.log(fakeD))
            self.D.zero_grad()
            self.G.zero_grad()
            lossG.backward()
            self.optimG.step()

            # log
            numBatches += 1
            if i % self.printEvery == 0:
                tqdm.write(
                    f'[TRAIN] Epoch {self.currentEpoch:03d}, Batch {i:03d}: '
                    f'gen: {float(accG.item()):2.3f}, dis = {float(accD.item()):2.3f}'
                )
                if (self.useTensorboard):
                    self.writer.add_scalar('GenLoss/train', lossG,
                                           numBatches + self.totalBatches)
                    self.writer.add_scalar('DisLoss/train', lossD,
                                           numBatches + self.totalBatches)
                    self.writer.add_scalar('GenAcc/train', accG,
                                           numBatches + self.totalBatches)
                    self.writer.add_scalar('DisAcc/train', accD,
                                           numBatches + self.totalBatches)
                    self.writer.flush()
                    if not self.tensorGraphInitialized:
                        # TODO: why can't I push graph?
                        # BUG FIX: the original moved the probe input with
                        # .cuda(1) (hard-coded second GPU) — crashes with
                        # fewer than two GPUs; use the trainer's device and
                        # unwrap DataParallel only when it is in use.
                        tempZ = torch.autograd.Variable(
                            torch.rand(data['62'].shape[0], 200, 1, 1, 1)).to(self.device)
                        netG = self.G.module if isinstance(
                            self.G, torch.nn.DataParallel) else self.G
                        netD = self.D.module if isinstance(
                            self.D, torch.nn.DataParallel) else self.D
                        self.writer.add_graph(netG, tempZ)
                        self.writer.flush()
                        self.writer.add_graph(netD, fakeVoxels)
                        self.writer.flush()
                        self.tensorGraphInitialized = True

        # self.trainLoss.append(epochLoss)
        self.currentEpoch += 1
        self.totalBatches += numBatches

    def save(self):
        """Checkpoint models, optimizers, and run metadata into logDir."""
        logTable = {
            'epoch': self.currentEpoch,
            'totalBatches': self.totalBatches
        }
        torch.save(self.G.state_dict(),
                   os.path.join(self.logDir, 'generator.pth'))
        torch.save(self.D.state_dict(),
                   os.path.join(self.logDir, 'discrim.pth'))
        torch.save(self.optimG.state_dict(),
                   os.path.join(self.logDir, 'optimG.pth'))
        torch.save(self.optimD.state_dict(),
                   os.path.join(self.logDir, 'optimD.pth'))
        with open(os.path.join(self.logDir, 'recent.log'), 'w') as f:
            f.write(json.dumps(logTable))
        # BUG FIX: the original passed an open() result directly to
        # pickle.dump, leaking the file handle; use a context manager.
        with open(os.path.join(self.logDir, 'trainStats.pkl'), 'wb') as f:
            pickle.dump(self.trainStats, f)
        tqdm.write('======== SAVED RECENT MODEL ========')

    def load(self):
        """Restore models, optimizers, and run metadata from logDir."""
        self.G.load_state_dict(
            torch.load(os.path.join(self.logDir, 'generator.pth')))
        self.D.load_state_dict(
            torch.load(os.path.join(self.logDir, 'discrim.pth')))
        self.optimG.load_state_dict(
            torch.load(os.path.join(self.logDir, 'optimG.pth')))
        self.optimD.load_state_dict(
            torch.load(os.path.join(self.logDir, 'optimD.pth')))
        with open(os.path.join(self.logDir, 'recent.log'), 'r') as f:
            runData = json.load(f)
        # BUG FIX: same leaked file handle as in save(); use a context manager.
        with open(os.path.join(self.logDir, 'trainStats.pkl'), 'rb') as f:
            self.trainStats = pickle.load(f)
        self.currentEpoch = runData['epoch']
        self.totalBatches = runData['totalBatches']
def gan_augment(x, y, seed, n_samples=None):
    # Train (or load, if a checkpoint for this seed exists) a conditional GAN
    # on (x, y), then return `n_samples` generated images and their labels.
    # NOTE(review): the 28x28 output reshape and the 10-way label sampling
    # assume MNIST-like data with 10 classes — confirm against callers.
    if n_samples is None:
        n_samples = len(x)
    lr = 3e-4
    num_ep = 300
    z_dim = 100
    # checkpoint path is keyed by seed so each seed trains its own GAN
    model_path = "./gan_checkpoint_%d.pth" % seed
    device = "cuda" if torch.cuda.is_available() else "cpu"
    G = Generator(z_dim).to(device)
    D = Discriminator(z_dim).to(device)
    bce_loss = nn.BCELoss()
    # generator uses a 3x larger learning rate than the discriminator
    G_optim = optim.Adam(G.parameters(), lr=lr * 3, betas=(0.5, 0.999))
    D_optim = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
    batch = 64
    train_x = torch.Tensor(x)
    train_labels = torch.LongTensor(y)
    if os.path.exists(model_path):
        # reuse the previously trained generator (D is not needed for sampling)
        print("load trained GAN...")
        state = torch.load(model_path)
        G.load_state_dict(state["G"])
    else:
        print("training a new GAN...")
        for epoch in range(num_ep):
            for _ in range(len(train_x) // batch):
                # sample a random batch (with replacement)
                idx = np.random.choice(range(len(train_x)), batch)
                batch_x = train_x[idx].to(device)
                batch_labels = train_labels[idx].to(device)
                y_real = torch.ones(batch).to(device)
                y_fake = torch.zeros(batch).to(device)

                # train D with real images
                D.zero_grad()
                D_real_out = D(batch_x, batch_labels).squeeze()
                D_real_loss = bce_loss(D_real_out, y_real)

                # train D with fake images (random labels condition the fakes)
                z_ = torch.randn((batch, z_dim)).view(-1, z_dim, 1, 1).to(device)
                fake_labels = torch.randint(0, 10, (batch, )).to(device)
                G_out = G(z_, fake_labels)
                D_fake_out = D(G_out, fake_labels).squeeze()
                D_fake_loss = bce_loss(D_fake_out, y_fake)

                D_loss = D_real_loss + D_fake_loss
                D_loss.backward()
                D_optim.step()

                # train G: make D classify fresh fakes as real
                G.zero_grad()
                z_ = torch.randn((batch, z_dim)).view(-1, z_dim, 1, 1).to(device)
                fake_labels = torch.randint(0, 10, (batch, )).to(device)
                G_out = G(z_, fake_labels)
                D_out = D(G_out, fake_labels).squeeze()
                G_loss = bce_loss(D_out, y_real)
                G_loss.backward()
                G_optim.step()
            # per-epoch progress: plot the last generated batch and the losses
            plot2img(G_out[:50].cpu())
            print("epoch: %d G_loss: %.2f D_loss: %.2f" % (epoch, G_loss, D_loss))
        state = {"G": G.state_dict(), "D": D.state_dict()}
        torch.save(state, model_path)
    # generate the augmentation samples without tracking gradients
    # NOTE(review): G is left in train mode here (no G.eval()) — confirm that
    # is intended if G contains batch-norm or dropout layers.
    with torch.no_grad():
        z_ = torch.randn((n_samples, z_dim)).view(-1, z_dim, 1, 1).to(device)
        fake_labels = torch.randint(0, 10, (n_samples, )).to(device)
        G_samples = G(z_, fake_labels)
    samples = G_samples.cpu().numpy().reshape((-1, 28, 28, 1))
    return samples, fake_labels.cpu().numpy()
def main():
    """Run the GAN experiment once per option, saving models/logs per iteration."""
    # Supervised GAN?
    options = [False, True]
    # Alternative: run over different pre-processing types, comment the above line and uncomment the one below
    # options = [None,'returns','logreturns','scale_S_ref']
    results_path = META['results_path']

    for i, option in enumerate(options):
        # Reset the seed at each iteration for equal initalisation of the nets
        torch.manual_seed(SEED)
        np.random.seed(seed=SEED)
        META['seed'] = SEED

        # Make folder for each run of the training loop
        # CONSISTENCY FIX: the original built this path twice, once with
        # pt.join(a, b) and once with string concatenation; use one path.
        run_dir = pt.join(results_path, 'iter_%d' % i)
        if not pt.exists(run_dir):
            os.mkdir(run_dir)

        # ---------------------------------------------------------
        # Modify training conditions in loop
        # ---------------------------------------------------------
        META['supervised'] = option
        # Alternative: run over different pre-processing types, comment the above line and uncomment the one below
        # META['proc_type'] = option
        # ---------------------------------------------------------

        # Override the default n_D, the amount of training steps of D per G training step if vanilla GAN.
        META['n_D'] = 1 if META['supervised'] else META['n_D']
        # ---------------------------------------------------------

        # Make the dataset and initialise the GAN
        # X.generate_CIR_data()
        X = load_preset(META['preset'], N_train=META['N_train'], N_test=META['N_test'])
        X.exact = preprocess(X.exact,
                             torch.tensor(X.params['S0'], dtype=torch.float32).view(-1, 1),
                             proc_type=META['proc_type'],
                             S_ref=torch.tensor(X.params['S_bar'],
                                                device=torch.device('cpu'),
                                                dtype=torch.float32),
                             eps=META['eps'])
        c_dim = 0 if X.C is None else len(X.C)

        netG = Generator(c_dim=c_dim).to(DEVICE)
        netG.eps = META['eps']
        # supervised mode conditions D on one extra channel
        netD = Discriminator(c_dim=c_dim + 1,
                             negative_slope=META['negative_slope'],
                             hidden_dim=META['hidden_dim'],
                             activation=META['activation']).to(DEVICE) if META['supervised'] \
            else Discriminator(c_dim=c_dim,
                               negative_slope=META['negative_slope'],
                               hidden_dim=META['hidden_dim'],
                               activation=META['activation']).to(DEVICE)

        # NOTE(review): `analysis` is never used below — kept in case
        # CGANalysis has side effects on construction (e.g. figure saving);
        # confirm before removing.
        analysis = CGANalysis(X, netD, netG, SDE=X.SDE,
                              save_all_figs=META['save_figs'],
                              results_path=results_path,
                              proc_type=META['proc_type'],
                              eps=META['eps'],
                              supervised=META['supervised'])

        # Train the GAN
        output_dict, results_df = train_GAN(netD, netG, X, META)

        # Store results
        netG_dir = pt.join(results_path, 'iter_%d' % i, 'netG.pth')
        netD_dir = pt.join(results_path, 'iter_%d' % i, 'netD.pth')
        torch.save(netG.state_dict(), netG_dir)
        print('Saved Generator in %s' % netG_dir)
        torch.save(netD.state_dict(), netD_dir)
        print('Saved Discriminator in %s' % netD_dir)

        if META['report']:
            results_df.to_csv(pt.join(results_path, 'iter_%d' % i, 'train_log.csv'),
                              index=False, header=True)
        # Uncomment to save the entire output dict
        # log_path = pt.join(results_path,'iter_%d'%i,'train_log.pkl')
        # pickle_it(output_dict,log_path)
        meta_path = pt.join(results_path, 'iter_%d' % i, 'metadata.pkl')
        pickle_it(META, meta_path)
        print('Saved logs in ' + results_path + '/iter_%d/' % i)

    print('----- Experiment finished -----')