def get_model(model, args):
    """Build the generator/discriminator pair for the requested GAN variant.

    Args:
        model: one of 'vanillagan', 'dcgan', 'cgan', 'wgan', 'wgangp'.
        args: parsed CLI namespace; in 'train' mode the recommended settings
            may be substituted via change_train_settings.

    Returns:
        (generator, discriminator) moved to args.device in 'train' mode,
        otherwise just the generator.

    Raises:
        ValueError: if `model` is not a known variant (previously this fell
            through the if/elif chain and crashed with NameError on `gen`).
    """
    if args.mode == 'train':
        if args.no_recommend_setting:
            print('Use arg parsing settings for training ')
        else:
            print(
                f'Change recommendation train settings based on {model} model')
            args = change_train_settings(model, args)

    if model == 'vanillagan':
        gen = vanillaGAN.Generator(args.z_dim, args.ngf, args.img_size)
        dis = vanillaGAN.Discriminator(args.ndf, args.img_size)
    elif model == 'dcgan':
        gen = DCGAN.Generator(args.z_dim, args.ngf, args.gen_img_ch)
        dis = DCGAN.Discriminator(args.ndf, args.gen_img_ch)
    elif model == 'cgan':
        gen = CGAN.Generator(args.z_dim, args.ngf, args.gen_img_ch)
        dis = CGAN.Discriminator(args.ndf, args.img_size, args.gen_img_ch)
    elif model == 'wgan':
        gen = WGAN.Generator(args.z_dim, args.ngf, args.gen_img_ch)
        dis = WGAN.Discriminator(args.ndf, args.gen_img_ch)
    elif model == 'wgangp':
        gen = WGANGP.Generator(args.z_dim, args.ngf, args.gen_img_ch)
        dis = WGANGP.Discriminator(args.ndf, args.gen_img_ch)
    else:
        # Fail loudly instead of the original NameError on unbound `gen`.
        raise ValueError(f'Unknown model: {model!r}')

    if args.mode == 'train':
        return gen.to(args.device), dis.to(args.device)
    return gen.to(args.device)
def build(self):
    """Assemble the generator, discriminator and combined DCGAN graph.

    Loads both sub-networks, compiles the generator alone, then freezes the
    discriminator while compiling the combined model, and finally re-enables
    and compiles the discriminator on its own loss.
    """
    gen_loss, dcgan_loss, disc_loss = self.losses

    self.generator = load_model("generator_unet_%s" % (self.flag),
                                self.img_dim, 64, True, True,
                                self.batch_size)
    nb_patch, img_shape_disc = get_nb_patch(self.img_dim, self.patch_size)
    self.discriminator = load_model("DCGAN_discriminator", self.img_dim,
                                    nb_patch, -1, True, self.batch_size)

    # Two identical Adam instances: one for the combined model, one shared
    # by the generator and the discriminator compiles.
    opt_dcgan = Adam(lr=self.learning_rate, beta_1=0.9, beta_2=0.999,
                     epsilon=1e-08)
    opt_discriminator = Adam(lr=self.learning_rate, beta_1=0.9, beta_2=0.999,
                             epsilon=1e-08)

    # NOTE(review): the generator is compiled with opt_discriminator — this
    # mirrors the original code; confirm it is intentional.
    self.generator.compile(loss=gen_loss, optimizer=opt_discriminator)

    # Freeze D while the combined model trains G.
    self.discriminator.trainable = False
    self.DCGAN_model = DCGAN(self.generator, self.discriminator,
                             self.img_dim, self.patch_size, "channels_last")
    self.DCGAN_model.compile(loss=[dcgan_loss, disc_loss],
                             loss_weights=[1E1, 1],
                             optimizer=opt_dcgan)

    # Unfreeze and compile D for its standalone updates.
    self.discriminator.trainable = True
    self.discriminator.compile(loss=disc_loss, optimizer=opt_discriminator)
def create_model(op, device):
    """Instantiate the model selected by ``op.model.name``.

    Args:
        op: options object exposing ``op.model.name``.
        device: device handle forwarded to the model constructor.

    Returns:
        A DCGAN or CycleGAN instance.

    Raises:
        ValueError: for an unrecognized model name (previously the function
            silently returned None, deferring the failure to the caller).
    """
    model_name = op.model.name
    if model_name == 'dcgan':
        return DCGAN(op, device)
    if model_name == 'cyclegan':
        return CycleGAN(op, device)
    raise ValueError(f'Unsupported model name: {model_name!r}')
def main(unused_args):
    """Train a DCGAN on Fashion-MNIST with alternating D/G updates (TF1)."""
    mnist = DCGANFashionMNIST()
    inputs_real, inputs_z, learning_rate = model_inputs(
        IMAGE_SIZE, IMAGE_SIZE, 1, FLAGS.z_dim)
    dcgan = DCGAN(inputs_real, inputs_z, learning_rate, FLAGS)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(FLAGS.epoch):
            for batch_images in mnist.to_batches(batch_size=FLAGS.batch_size):
                # Fresh uniform noise for every step.
                batch_z = np.random.uniform(
                    -1, 1, size=[FLAGS.batch_size, FLAGS.z_dim])
                # Same feeds for both updates: discriminator first, then
                # generator.
                feeds = {
                    dcgan.inputs_real: batch_images,
                    dcgan.inputs_z: batch_z,
                    dcgan.learning_rate: FLAGS.learning_rate,
                }
                sess.run(dcgan.d_optimizer, feed_dict=feeds)
                sess.run(dcgan.g_optimizer, feed_dict=feeds)
def train(real_imgs, cls):
    """One discriminator update: real batch vs. generated batch.

    Returns (loss_D_value, z) where z are the features extracted from the
    real images and reused by the generator step.

    NOTE(review): `self` is not a parameter here, yet the body reads
    `self.D`, `self.G` and `self.optimizer_D` — unless a module-level
    `self` exists, this raises NameError. Was this meant to be a method?
    """
    num_cls = self.D.num_classes
    batch_sz = real_imgs.shape[0]
    self.optimizer_D.zero_grad()
    # Real images labelled with their true class.
    label = one_hot_embedding(cls, num_cls)
    real_pred = self.D(real_imgs).view(batch_sz, -1)
    real_loss = DCGAN.criterion(real_pred, label)
    # Generate fakes from features of the real batch; detach so the
    # generator receives no gradient from the discriminator update.
    z = self.D.get_features(real_imgs)
    fake_imgs = self.G(z).detach()
    # NOTE(review): `num_cls / 2` is float division in Py3 — class indices
    # are presumably ints (`num_cls // 2`); confirm one_hot_embedding's
    # contract before relying on this.
    label = one_hot_embedding(cls + num_cls / 2, num_cls)
    fake_pred = self.D(fake_imgs).view(batch_sz, -1)
    fake_loss = DCGAN.criterion(fake_pred, label)
    loss_D = real_loss + fake_loss
    loss_D.backward()
    self.optimizer_D.step()
    return loss_D.item(), z
def train(z, real_imgs, cls):
    """One generator update: fool D into predicting the true class for fakes.

    Returns (loss_G_value, fake_imgs).

    NOTE(review): `self` is not a parameter here, yet the body reads
    `self.G`, `self.D` and `self.optimizer_G` — unless a module-level
    `self` exists, this raises NameError. Was this meant to be a method?
    """
    num_cls = self.D.num_classes
    batch_sz = real_imgs.shape[0]
    self.optimizer_G.zero_grad()
    fake_imgs = self.G(z)
    # Target label is the real class: the generator is rewarded when the
    # discriminator classifies its output as genuine.
    label = one_hot_embedding(cls, num_cls)
    pred = self.D(fake_imgs).view(batch_sz, -1)
    loss_G = DCGAN.criterion(pred, label)
    loss_G.backward()
    self.optimizer_G.step()
    return loss_G.item(), fake_imgs
def main(argv=None):
    """Instantiate one DCGAN and one LSGAN and print both."""
    models = (DCGAN(), LSGAN())
    print(*models)
def create_model(args):
    """Build the model selected by ``args.id`` (case-insensitive).

    Args:
        args: namespace with an ``id`` attribute ('mlp' or 'dcgan').

    Returns:
        An MLP or DCGAN instance.

    Raises:
        ValueError: for an unrecognized id (previously the function silently
            returned None, deferring the failure to the caller).
    """
    print('model', args)
    model_id = args.id.lower()
    if model_id == 'mlp':
        return MLP([H * W * C, 100, 10])
    if model_id == 'dcgan':
        return DCGAN(args)
    raise ValueError(f'Unknown model id: {args.id!r}')
class K_DCGAN(data_model):
    """Pix2pix-style conditional DCGAN built on the project's data_model base.

    Owns three Keras models: a U-Net generator, a patch discriminator and the
    combined DCGAN model, plus their weight files under self.weight_path.
    """

    def __init__(
            self,
            flag="deconv",
            epoch=100000,
            img_shape=[256, 256, 3],
            learning_rate=1e-3,
            white_bk=True,
            name="gan_unet_model",
            loss=[
                'categorical_crossentropy',  # generator loss
                l1_loss,                     # combined-DCGAN loss
                'binary_crossentropy'        # discriminator loss
            ]):
        # NOTE(review): mutable defaults (img_shape, loss) are shared across
        # calls — harmless only as long as they are never mutated.
        self.loss = loss
        assert (flag == "deconv") or (
            flag == "upsampling"), "Only support flag for deconv or upsampling"
        self.flag = flag
        self.img_dim = img_shape
        # Background tag baked into the run name: "wh" for white, else "bk".
        bk = "bk"
        if white_bk:
            bk = "wh"
        data_model.__init__(self,
                            name + "bk%s_lr_%s_img_dim%s_loss_%s" %
                            (bk, learning_rate, img_shape[0], loss),
                            "DCGAN",
                            img_shape=img_shape,
                            epochs=epoch)
        assert len(loss) == 3
        # training params
        self.patch_size = [64, 64]
        self.n_batch_per_epoch = self.batch_size * 100
        self.learning_rate = learning_rate
        self.losses = loss
        # init all need dir and model
        self.build()
        # Weight-file locations for the three models.
        self.disc_weights_path = os.path.join(self.weight_path,
                                              "disc_weight_epoch.h5")
        self.gen_weights_path = os.path.join(self.weight_path,
                                             "gen_weight_epoch.h5")
        self.DCGAN_weights_path = os.path.join(self.weight_path,
                                               "DCGAN_weight_epoch.h5")
        check_folders(self.weight_path)

    def build(self):
        """Load and compile generator, discriminator and combined model."""
        generator_loss, dc_gan_loss, discriminator_loss = self.losses
        self.generator = load_model("generator_unet_%s" % (self.flag),
                                    self.img_dim, 64, True, True,
                                    self.batch_size)
        nb_patch, img_shape_disc = get_nb_patch(self.img_dim, self.patch_size)
        self.discriminator = load_model("DCGAN_discriminator", self.img_dim,
                                        nb_patch, -1, True, self.batch_size)
        opt_dcgan, opt_discriminator = Adam(
            lr=self.learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08),\
            Adam(lr=self.learning_rate, beta_1=0.9, beta_2=0.999,
                 epsilon=1e-08)
        # NOTE(review): generator compiled with opt_discriminator — mirrors
        # the original; confirm it is intentional.
        self.generator.compile(loss=generator_loss,
                               optimizer=opt_discriminator)
        # Freeze D while the combined model trains G.
        self.discriminator.trainable = False
        self.DCGAN_model = DCGAN(self.generator, self.discriminator,
                                 self.img_dim, self.patch_size,
                                 "channels_last")
        loss = [dc_gan_loss, discriminator_loss]
        loss_weight = [1E1, 1]
        self.DCGAN_model.compile(loss=loss,
                                 loss_weights=loss_weight,
                                 optimizer=opt_dcgan)
        # Unfreeze and compile D for its standalone updates.
        self.discriminator.trainable = True
        self.discriminator.compile(loss=discriminator_loss,
                                   optimizer=opt_discriminator)

    def log_checkpoint(self, epoch, batch, loss):
        """Write a one-shot checkpoint file recording epoch, batch and losses."""
        log_path = os.path.join(self.weight_path, "checkpoint")
        prev_epochs = 0
        if os.path.isfile(log_path):
            # NOTE(review): "w+" truncates the file before reading, so these
            # readline() calls always see an empty file and prev_epochs stays
            # 0 — "r" was probably intended.
            with open(log_path, "w+") as f:
                line = f.readline()
                if "Epoch" in line:
                    line = f.readline().split(" ")
                    prev_epochs = int(line[4])
        with open(log_path, "w+") as f:
            f.write("Model_Name {} ".format(self.title))
            f.write("Epoch {} in batch {}".format(epoch + prev_epochs, batch))
            f.write("\n")
            f.write("Losses: {}".format(loss))

    def save(self):
        """Persist weights of all three models, creating the files if absent."""
        if not os.path.exists(self.gen_weights_path):
            # NOTE(review): h5py.File handles opened here are never closed;
            # this only pre-creates the files before save_weights overwrites
            # them.
            h5py.File(self.gen_weights_path)
            h5py.File(self.disc_weights_path)
            h5py.File(self.DCGAN_weights_path)
        self.generator.save_weights(self.gen_weights_path, overwrite=True)
        self.discriminator.save_weights(self.disc_weights_path,
                                        overwrite=True)
        self.DCGAN_model.save_weights(self.DCGAN_weights_path, overwrite=True)

    def test_img(self):
        """Run the generator on one random sample and stack input/target/output."""
        # pick a random index
        idx = rnd.choice([i for i in range(0, len(self.data['X']))])
        X, y = self.get_data(idx)
        # normalized images
        self.load()
        X_pred = self.generator.predict(np.array([X]))
        X = image.array_to_img(inverse_normalization(X, self.max, self.min))
        y = image.array_to_img(inverse_normalization(y, self.max, self.min))
        X_pred = image.array_to_img(
            inverse_normalization(X_pred[0], self.max, self.min))
        suffix = "End_test"
        result = np.hstack((X, y, X_pred))
        # check_folders("../figures/%s" % (self.title))
        # plt.imshow(result)
        # plt.savefig(
        #     "../figures/%s/current_batch_%s.png" %
        #     (self.title, suffix))
        # plt.axis("off")
        # plt.show()

    def load(self):
        '''Load models weight from log/${model_name}'''
        if os.path.exists(self.gen_weights_path):
            self.generator.load_weights(self.gen_weights_path)
            self.discriminator.load_weights(self.disc_weights_path)
            self.DCGAN_model.load_weights(self.DCGAN_weights_path)
        else:
            raise FileNotFoundError("No Previous Model Found")
        print("Loading model from {}".format([
            self.disc_weights_path, self.gen_weights_path,
            self.DCGAN_weights_path
        ]))

    def summary(self, name="DCGAN"):
        """Print the Keras summary of the selected sub-model."""
        if name == "Generator":
            self.generator.summary()
        elif name == "Discriminator":
            self.discriminator.summary()
        else:
            self.DCGAN_model.summary()

    @timeit(log_info="Training pix2pix")
    def train(self, label_smoothing=False, retrain=False):
        """Alternate discriminator/generator updates for self.nb_epochs epochs.

        Saves weights and a checkpoint every 5 epochs and on KeyboardInterrupt.
        """
        gen_loss, disc_loss = 100, 100
        n_batch_per_epoch = self.n_batch_per_epoch
        total_epoch = n_batch_per_epoch * self.batch_size
        try:
            if retrain:
                print("Found prev_trained models ...")
                self.load()
                print("Retrain the model ")
        except FileNotFoundError:
            print("No previous model found start train a new model")
        try:
            os.system("clear")
            for e in range(self.nb_epochs):
                batch_counter = 1
                start = time()
                progbar = generic_utils.Progbar(total_epoch)
                for X, y in self.gen_batch(self.batch_size):
                    # Discriminator update on a mixed real/fake patch batch.
                    X_disc, y_disc = self.get_disc_batch(
                        X,
                        y,
                        self.generator,
                        batch_counter,
                        self.patch_size,
                        label_smoothing=label_smoothing,
                        label_flipping=0)
                    disc_loss = self.discriminator.train_on_batch(
                        X_disc, y_disc)
                    # Generator update through the frozen-D combined model.
                    X_gen_target, Y_gen = next(self.gen_batch(self.batch_size))
                    self.generator.train_on_batch(X_gen_target, Y_gen)
                    y_gen = np.zeros((Y_gen.shape[0], 2), dtype=np.uint8)
                    y_gen[:, 1] = 1  # "real" label fed to fool D
                    self.discriminator.trainable = False
                    gen_loss = self.DCGAN_model.train_on_batch(
                        X_gen_target, [Y_gen, y_gen])
                    self.DCGAN_model.trainable = True
                    batch_counter += 1
                    progbar.add(self.batch_size,
                                values=[("D logloss", disc_loss),
                                        ("G tot", gen_loss[0]),
                                        ("G L1", gen_loss[1]),
                                        ("G logloss", gen_loss[2])])
                    if batch_counter % (n_batch_per_epoch / 2) == 0:
                        # Get new images from validation
                        plot_generated_batch(X, y, self.generator,
                                             self.batch_size, "training",
                                             self.title, self)
                        # get next validation batches
                        X_test, y_test = next(
                            self.gen_batch(self.batch_size, validation=True))
                        plot_generated_batch(X_test, y_test, self.generator,
                                             self.batch_size, "validation",
                                             self.title, self)
                    if batch_counter >= n_batch_per_epoch:
                        break
                print("")
                t_time = time() - start
                print('Epoch %s/%s, Time: %s ms' %
                      (e + 1, self.nb_epochs, round(t_time, 2)),
                      end="\r")
                if e % 5 == 0:
                    self.save()
                    self.log_checkpoint(e, batch_counter,
                                        [("D logloss", disc_loss),
                                         ("G tot", gen_loss[0]),
                                         ("G L1", gen_loss[1]),
                                         ("G logloss", gen_loss[2])])
        except KeyboardInterrupt:
            print(
                "\nInterruption occured.... Saving the model Epochs:{}".format(
                    e))
            self.save()
            self.log_checkpoint(e, batch_counter,
                                [("D logloss", disc_loss),
                                 ("G tot", gen_loss[0]),
                                 ("G L1", gen_loss[1]),
                                 ("G logloss", gen_loss[2])])
# --- TF1 training-script setup: input pipeline, models, losses, optimizers ---
N = len(fnames)
print("%d images found" % N)

# Hyperparameters
epochs = 50
batch_size = 128
lr = 1e-4

# Shuffle over the full dataset, repeat for all epochs, then batch.
image_set = make_dataset(fnames)
tr_set = image_set.shuffle(N).repeat(epochs).batch(batch_size)
tr_feed = tr_set.make_one_shot_iterator().get_next()  # TF1 one-shot iterator

noise_dim = 100
noise_shape = (batch_size, noise_dim)

birdgan = DCGAN()
generator = birdgan.build_generator(input_shape=(noise_dim, ))
discriminator = birdgan.build_discriminator()

# Fresh normal noise each step; losses wire G and D to the data feed.
z = tf.random_normal(noise_shape)
losses = make_losses(z, tr_feed, generator, discriminator)

# Fixed noise so sampled images are comparable across training.
# NOTE(review): num_examples is defined outside this chunk — confirm.
sampling_noise_seed = np.random.normal(size=(num_examples, noise_dim))
noise_feed = tf.placeholder(dtype=tf.float32,
                            shape=(num_examples, noise_dim))
# holds images sampled from the generator during training
sampled_images = generator(noise_feed, training=False)

gen_optim = tf.train.AdamOptimizer(lr)
disc_optim = tf.train.AdamOptimizer(lr)
# --- PyTorch training-script setup: device, data, model, loss, optimizers ---
torch.cuda.manual_seed_all(args.manualSeed)
cudnn.benchmark = True  # let cuDNN autotune conv algorithms
if torch.cuda.is_available() and not args.cuda:
    print(
        "WARNING: You have a CUDA device, so you should probably run with --cuda"
    )

# data
dataset = create_dataset(args.dataset, image_size=args.image_size)
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=int(args.workers))

# model
net = DCGAN()
if args.resume != "":
    # NOTE(review): `load_state_dicts` (plural) must be a custom method on
    # this DCGAN wrapper — nn.Module only defines load_state_dict; verify.
    net.load_state_dicts(torch.load(args.resume))

# loss: one BCE for the generator objective, one for the discriminator.
net.set_criterion([nn.BCELoss(), nn.BCELoss()])
net.train()

# optimizer
# NOTE(review): indexing params[0]/params[1] assumes this DCGAN overrides
# parameters() to return [gen_params, disc_params]; the stock nn.Module
# parameters() is a generator and does not support indexing — confirm.
params = net.parameters()
optim_G = optim.Adam(params[0], lr=args.lr, betas=(args.beta1, 0.999))
optim_D = optim.Adam(params[1], lr=args.lr, betas=(args.beta1, 0.999))
net.set_optim([optim_G, optim_D])

if args.cuda:
    net.cuda()
# Load generator model generator_model = load_model("generator_unet_%s" % generator, img_dim, nb_patch, bn_mode, use_mbd, batch_size, do_plot) # Load discriminator model discriminator_model = load_model("DCGAN_discriminator", img_dim_disc, nb_patch, bn_mode, use_mbd, batch_size, do_plot) # Compile generator model generator_model.compile(loss='mae', optimizer=opt_discriminator) discriminator_model.trainable = False # Define DCGAN model DCGAN_model = DCGAN(generator_model, discriminator_model, img_dim, patch_size, image_data_format) # Define loss function and loss weights loss = [l1_loss, 'binary_crossentropy'] loss_weights = [3, 1] # Compile DCGAN model DCGAN_model.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan) # Compile discriminator model discriminator_model.trainable = True discriminator_model.compile(loss='binary_crossentropy', optimizer=opt_discriminator)
shuffle=False)  # NOTE(review): closes a loader/split call begun above this chunk

# Dataset size summary (literal kept verbatim — runtime output text).
print(f""" Total training data: {len(trainset)} Total testing data: {len(testset)} Total data: {len(trainset) + len(testset)} """, flush=True)

# 2. instantiate the model
Z_DIM = 100  # latent (noise) vector dimensionality

# create the generator G: Z_DIM -> 1-channel image via transposed convs.
G = DCGAN.Generator(channels=[Z_DIM, 256, 128, 64, 1],
                    kernels=[None, 7, 5, 4, 4],
                    strides=[None, 1, 1, 2, 2],
                    paddings=[None, 0, 2, 1, 1],
                    batch_norm=True,
                    internal_activation=nn.ReLU(),
                    output_activation=nn.Tanh())

# create the discrimintor D: mirror of G's layer plan, sigmoid output.
D = DCGAN.Discriminator(channels=[1, 64, 128, 256, 1],
                        kernels=[None, 4, 4, 5, 7],
                        strides=[None, 2, 2, 1, 1],
                        paddings=[None, 1, 1, 2, 0],
                        batch_norm=True,
                        internal_activation=nn.LeakyReLU(0.2),
                        output_activation=nn.Sigmoid())

print(f""" Generator G: {G}
def get_models(fname_hparams, device, load_gen=True, load_discr=True, verbose=True):
    """Yield (hparams, generator, discriminator) for each run in a hparams file.

    Runs whose output folder already exists are skipped. Architectures are
    resolved by name; optionally pre-trained weights are loaded. Each yielded
    hparams object gains a `has_next` flag indicating whether more runs follow.
    """
    # Set default params
    defaults = {
        'model': 'original',
        'nz': 100,
        'nc': 1,
        'ndf': 64,
        'ngf': 64,
        'n_epochs': 500,
        'batch_size': 100,
        'lrD': 0.0001,
        'lrG': 0.0001,
        'beta1': 0.5,
        'beta2': 0.999,
        'nD': 1,
        'nG': 2,
        'image_interval': 20,
        'save_interval': 20,
        'score_interval': 20,
        'dataroot': '/home/raynor/datasets/april/velocity/',
        'modelroot': '/home/raynor/code/seismogan/saved/',
        'load_name': 'None',
        'load_step': -1,
    }

    # Architecture modules share the same Generator/Discriminator signatures,
    # so a dispatch table replaces the original if/elif chain.
    architectures = {
        'DCGAN': DCGAN,
        'DCGAN_SN': DCGAN_SN,
        'original': original,
        'original_SN': original_SN,
        'original_SN2': original_SN2,
    }

    # Load params from text file
    hparams = load_hparams(fname_hparams, defaults)

    for idx, h in enumerate(hparams):
        if os.path.exists(os.path.join(h.modelroot, h.name)):
            print(f'{h.name} folder exists. Skipping.')
            continue

        if verbose:
            print('Loading models')

        arch = architectures.get(h.model)
        if arch is None:
            raise NotImplementedError
        gen = arch.Generator(h.nz, h.nc, h.ngf, device) if load_gen else None
        discr = arch.Discriminator(h.nc, h.ndf, device) if load_discr else None

        if h.load_name.lower() != 'none':
            fname = load_models(os.path.join(h.modelroot, h.load_name),
                                gen, discr, h.load_step)
            if verbose:
                print(f'Loaded model: {fname}')

        h.has_next = idx + 1 < len(hparams)
        yield h, gen, discr