def save_ckpt(generator, critic, adversarial):
    """Save the training information into a file. This includes, but is not
    limited to, the weights and biases of the given networks. The GAN model
    is a combination of three different neural networks (generator,
    critic/discriminator, adversarial), and the information on each of them
    is saved.

    For more information on the `Checkpoint` constructor from the module
    `tensorflow.train`, refer to
    https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint

    Parameters
    ----------
    generator : ganpdfs.model.WassersteinGanModel.generator
        generator neural network
    critic : ganpdfs.model.WassersteinGanModel.critic
        critic/discriminator neural network
    adversarial : ganpdfs.model.WassersteinGanModel.adversarial
        adversarial neural network

    Returns
    -------
    tf.train.Checkpoint
        checkpoint object tracking the three networks; restoring from it
        yields a load-status object that can be used to make assertions
        about the status of a checkpoint restoration
    """
    checkpoint = Checkpoint(
        critic=critic,
        generator=generator,
        adversarial=adversarial
    )
    return checkpoint

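# Hedged usage sketch (not part of the original ganpdfs code): saving and
# restoring with the checkpoint returned by save_ckpt above. The "ckpt/gan"
# file prefix and the three network objects are illustrative assumptions.
def demo_save_and_restore(generator, critic, adversarial):
    checkpoint = save_ckpt(generator, critic, adversarial)
    # save() writes the tracked weights and returns the full path prefix.
    save_path = checkpoint.save("ckpt/gan")
    # restore() returns a load-status object that supports assertions.
    status = checkpoint.restore(save_path)
    status.assert_existing_objects_matched()
    return save_path
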
def set_checkpoint(opt, G_YtoX, G_XtoY, D_X, D_Y,
                   G_YtoX_optimizer, G_XtoY_optimizer,
                   D_X_optimizer, D_Y_optimizer):
    if opt["use_cycle_consistency_loss"]:
        os.makedirs("./checkpoints/{}/train".format(opt["dataset_name"]),
                    exist_ok=True)
        checkpoint_path = os.path.join("checkpoints", opt["dataset_name"], "train")
    else:
        os.makedirs("./no_cycle/checkpoints/{}/train".format(opt["dataset_name"]),
                    exist_ok=True)
        checkpoint_path = os.path.join("no_cycle", "checkpoints",
                                       opt["dataset_name"], "train")

    ckpt = Checkpoint(G_YtoX_optimizer=G_YtoX_optimizer,
                      G_XtoY_optimizer=G_XtoY_optimizer,
                      D_X_optimizer=D_X_optimizer,
                      D_Y_optimizer=D_Y_optimizer,
                      G_YtoX=G_YtoX,
                      G_XtoY=G_XtoY,
                      D_X=D_X,
                      D_Y=D_Y)
    ckpt_manager = CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

    # If a checkpoint exists, restore the latest checkpoint.
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        print('Latest checkpoint restored!!')

    return ckpt, ckpt_manager

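# Hedged usage sketch (assumption, not part of the original code): saving
# periodically with the (ckpt, ckpt_manager) pair returned by set_checkpoint
# above. `opt`, the model/optimizer tuples, and the epoch counts are placeholders.
def demo_periodic_save(opt, models, optimizers, n_epochs=10, save_every=5):
    G_YtoX, G_XtoY, D_X, D_Y = models
    G_YtoX_opt, G_XtoY_opt, D_X_opt, D_Y_opt = optimizers
    ckpt, ckpt_manager = set_checkpoint(opt, G_YtoX, G_XtoY, D_X, D_Y,
                                        G_YtoX_opt, G_XtoY_opt,
                                        D_X_opt, D_Y_opt)
    for epoch in range(n_epochs):
        # ... run one training epoch here ...
        if (epoch + 1) % save_every == 0:
            path = ckpt_manager.save()
            print("Saved checkpoint for epoch {} at {}".format(epoch + 1, path))
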
def train_model(model, model_dir, x_train, y_train, batch_size, epochs,
                learning_rate, decay):
    history = []

    # Compile the model
    optimizer, loss_fn, train_acc_metric, val_acc_metric = compile_model(
        model, learning_rate, decay)

    # Create the checkpoint object
    if args.checkpoint_enabled.lower() == 'true':
        checkpoint = Checkpoint(model)

    # Prepare the batch datasets
    x_val, y_val, train_dataset, val_dataset = prepare_batch_datasets(
        x_train, y_train, batch_size)

    # Perform training
    logger.info('Training the model...')
    training_start_time = time.time()
    logger.debug('Iterating over epochs...')

    # Iterate over epochs
    for epoch in range(epochs):
        logger.debug('Starting epoch {}...'.format(int(epoch) + 1))
        epoch_start_time = time.time()

        # Iterate over the batches of the dataset
        for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
            logger.debug('Running training step {}...'.format(int(step) + 1))
            loss_value = training_step(model, x_batch_train, y_batch_train,
                                       optimizer, loss_fn, train_acc_metric)
            logger.debug('Training loss in step = {}'.format(loss_value))
            logger.debug(
                'Completed running training step {}.'.format(int(step) + 1))

        # Perform validation and save metrics at the end of each epoch
        history.append([
            int(epoch) + 1,
            train_acc_metric.result(),
            perform_validation(model, val_dataset, val_acc_metric)
        ])

        # Reset metrics
        train_acc_metric.reset_states()
        val_acc_metric.reset_states()

        # Save the model as a checkpoint
        if args.checkpoint_enabled.lower() == 'true':
            save_checkpoint(checkpoint)

        epoch_end_time = time.time()
        logger.debug("Epoch duration = %.2f second(s)" %
                     (epoch_end_time - epoch_start_time))
        logger.debug('Completed epoch {}.'.format(int(epoch) + 1))

    logger.debug('Completed iterating over epochs.')
    training_end_time = time.time()
    logger.info('Training duration = %.2f second(s)' %
                (training_end_time - training_start_time))
    print_training_result(history)
    logger.info('Completed training the model.')

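# Hedged sketch of the save_checkpoint helper referenced in train_model above;
# the original helper is not shown in this snippet, so the body below is an
# assumption. The 'checkpoints' directory is a placeholder; `logger` is the
# module-level logger used by train_model.
import os

def save_checkpoint(checkpoint, directory='checkpoints'):
    # Each call writes a new numbered checkpoint under `directory`.
    save_path = checkpoint.save(os.path.join(directory, 'ckpt'))
    logger.debug('Saved checkpoint to {}'.format(save_path))
    return save_path
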
def __init__(self, network, generative_model, loss, summary_stats=None, optimizer=None,
             learning_rate=0.0005, checkpoint_path=None, max_to_keep=5,
             clip_method='global_norm', clip_value=None):
    """
    Creates a trainer instance for performing single-model forward inference and training an
    amortized neural estimator for parameter estimation (BayesFlow). If a checkpoint_path is
    provided, the network's weights will be stored after each training epoch. If the folder
    contains a checkpoint, the trainer will try to load the weights and continue training
    with a pre-trained net.
    ----------
    Arguments:
    network          : bayesflow.Amortizer instance -- the neural architecture to be optimized
    generative_model : callable -- a function or an object with n_sim and n_obs mandatory
                       arguments returning randomly sampled parameter vectors and datasets
                       from a process model
    loss             : callable with three arguments: (network, m_indices, x) -- the loss function
    ----------
    Keyword arguments:
    summary_stats   : callable -- optional summary statistics function
    optimizer       : None or tf.keras.optimizers.Optimizer class -- if None, a default Adam
                      optimizer is used; otherwise the given class is instantiated with learning_rate
    learning_rate   : float -- the learning rate used for the optimizer
    checkpoint_path : string -- optional folder name for storing the trained network
    max_to_keep     : int -- optional number of checkpoints to keep
    clip_method     : string in ('norm', 'value', 'global_norm') -- optional gradient clipping method
    clip_value      : float -- the value used for gradient clipping when clip_method is set
                      to 'value' or 'norm'
    """

    # Basic attributes
    self.network = network
    self.generative_model = generative_model
    self.loss = loss
    self.summary_stats = summary_stats
    self.clip_method = clip_method
    self.clip_value = clip_value
    self.n_obs = None

    # Optimizer settings
    if optimizer is None:
        if tf.__version__.startswith('1'):
            self.optimizer = tf.train.AdamOptimizer(learning_rate)
        else:
            self.optimizer = Adam(learning_rate)
    else:
        self.optimizer = optimizer(learning_rate)

    # Checkpoint settings
    if checkpoint_path is not None:
        self.checkpoint = Checkpoint(optimizer=self.optimizer, model=self.network)
        self.manager = CheckpointManager(self.checkpoint, checkpoint_path, max_to_keep=max_to_keep)
        self.checkpoint.restore(self.manager.latest_checkpoint)
        if self.manager.latest_checkpoint:
            print("Networks loaded from {}".format(self.manager.latest_checkpoint))
        else:
            print("Initializing networks from scratch.")
    else:
        self.checkpoint = None
        self.manager = None
    self.checkpoint_path = checkpoint_path

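# Hedged sketch (assumption, not part of the original BayesFlow trainer): the
# per-epoch saving described in the docstring above, written as a companion
# method of the same class.
def _save_after_epoch(self, epoch):
    # Only write a checkpoint when a checkpoint_path was given at construction.
    if self.manager is not None:
        save_path = self.manager.save()
        print("Epoch {}: networks saved to {}".format(epoch, save_path))
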
def __init__(self, optimizer, encoder, decoder, targ_lang, batch_size):
    self.directory = Path("data/ckpt")
    self.directory.mkdir(parents=True, exist_ok=True)
    self.optimizer = optimizer
    self.encoder = encoder
    self.decoder = decoder
    self.targ_lang = targ_lang
    self.batch_size = batch_size
    self.checkpoint = Checkpoint(optimizer=optimizer, encoder=encoder, decoder=decoder)

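# Hedged sketch (assumption): companion save/restore helpers for the class whose
# __init__ is shown above, writing under the self.directory created there.
import tensorflow as tf

def save(self):
    return self.checkpoint.save(str(self.directory / "ckpt"))

def restore_latest(self):
    latest = tf.train.latest_checkpoint(str(self.directory))
    if latest is not None:
        self.checkpoint.restore(latest)
    return latest
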
def train(epochs, batch_size, ckpt_path, imgs_path, lr, out_path):
    tf.keras.backend.clear_session()
    train_data = get_data(imgs_path, batch_size)

    gen = generator()
    disc = discriminator()
    print(gen.summary())
    print(disc.summary())

    gen_opt = Adam(learning_rate=lr, beta_1=0.5)
    disc_opt = Adam(learning_rate=lr, beta_1=0.5)

    ckpt = Checkpoint(disc=disc, gen=gen, disc_opt=disc_opt, gen_opt=gen_opt)
    manager = CheckpointManager(ckpt, ckpt_path, max_to_keep=3)
    if manager.latest_checkpoint:
        print("Restored from {}".format(manager.latest_checkpoint))
        ckpt.restore(manager.latest_checkpoint)
    else:
        print("Initializing from scratch.")

    seed = tf.random.normal([16, ENCODING_SIZE], seed=1234)
    generate_and_save_images(gen, 0, seed, out_path)

    for ep in range(epochs):
        gen_loss = []
        disc_loss_real = []
        disc_loss_fake = []
        print('Epoch: %d of %d' % (ep + 1, epochs))
        start = time.time()

        for images in train_data:
            g_loss, d_loss_r, d_loss_f = train_step(images, gen, disc,
                                                    gen_opt, disc_opt, batch_size)
            gen_loss.append(g_loss)
            disc_loss_real.append(d_loss_r)
            disc_loss_fake.append(d_loss_f)

        gen_loss = np.mean(np.asarray(gen_loss))
        disc_loss_real = np.mean(np.asarray(disc_loss_real))
        disc_loss_fake = np.mean(np.asarray(disc_loss_fake))

        if (np.isnan(gen_loss) or np.isnan(disc_loss_real) or np.isnan(disc_loss_fake)):
            print("Something broke.")
            break

        manager.save()
        generate_and_save_images(gen, ep + 1, seed, out_path)
        print("Time for epoch:", time.time() - start)
        print("Gen loss=", gen_loss)
        print("Disc loss real=", disc_loss_real)
        print("Disc loss fake=", disc_loss_fake)

def train(args):
    train_ds, test_ds = get_data(args.img_path, args.batch)

    gen = generator()
    disc = discriminator()
    gen_opt = Adam(args.learning_rate, beta_1=0.5, beta_2=0.999)
    disc_opt = Adam(args.learning_rate, beta_1=0.5, beta_2=0.999)
    print(gen.summary())
    print(disc.summary())

    ckpt = Checkpoint(disc=disc, gen=gen, disc_opt=disc_opt, gen_opt=gen_opt)
    manager = CheckpointManager(ckpt, args.ckpt_path, max_to_keep=3)

    # Epoch offset, used to continue numbering when resuming from a checkpoint.
    off = 0
    if args.continue_training:
        latest = manager.latest_checkpoint
        if latest:
            print("Restored from {}".format(latest))
            ckpt.restore(latest)
            # Checkpoint names end in "-<number>", e.g. "ckpt-12".
            off = int(re.split('-', latest)[-1])
        else:
            print("Initializing from scratch.")

    for ep in range(args.epochs):
        for x, y in test_ds.take(1):
            generate_and_save_imgs(gen, ep + off, x, y, args.out_path)

        gen_loss = []
        disc_loss = []
        print('Epoch: %d of %d' % (ep + 1 + off, args.epochs + off))
        start = time.time()

        for x, y in train_ds:
            g_loss, d_loss = train_step(x, y, gen, disc, gen_opt, disc_opt, args.batch)
            gen_loss.append(g_loss)
            disc_loss.append(d_loss)

        gen_loss = np.mean(np.asarray(gen_loss))
        disc_loss = np.mean(np.asarray(disc_loss))

        manager.save()
        print("Time for epoch:", time.time() - start)
        print("Gen loss=", gen_loss)
        print("Disc loss=", disc_loss)

    # Storing three different outputs after the final epoch
    for x, y in test_ds.take(3):
        generate_and_save_imgs(gen, args.epochs + off, x, y, args.out_path)
        off += 1

def get_checkpoint_manager(model, optimizer, checkpoints_dir, max_checkpoints=None):
    """Obtains a checkpoint manager to manage model saving and restoring.

    Arguments:
        model (model.ImageCaptionModel): object containing encoder, decoder and tokenizer
        optimizer (tf.optimizers.Optimizer): the optimizer used during the backpropagation step
        checkpoints_dir (str): directory in which checkpoints are stored
        max_checkpoints (int): maximum number of checkpoints to keep
    Returns:
        tf.train.CheckpointManager, tf.train.Checkpoint
    """
    ckpt = Checkpoint(encoder=model.encoder, decoder=model.decoder, optimizer=optimizer)
    ckpt_manager = CheckpointManager(ckpt, checkpoints_dir, max_to_keep=max_checkpoints)
    return ckpt_manager, ckpt

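# Hedged usage sketch (assumption): restoring the latest checkpoint with the
# manager returned by get_checkpoint_manager above. `model`, `optimizer`, and
# the "./checkpoints" directory are placeholders.
def demo_restore_latest(model, optimizer):
    ckpt_manager, ckpt = get_checkpoint_manager(model, optimizer,
                                                "./checkpoints", max_checkpoints=5)
    if ckpt_manager.latest_checkpoint:
        # expect_partial() silences warnings about objects not used at restore time.
        ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
        print("Restored from {}".format(ckpt_manager.latest_checkpoint))
    return ckpt_manager, ckpt
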
    parse_np_array_image,
)


def predict(model: CAE, image: Image) -> Image:
    # Add a batch dimension, run the autoencoder, and convert back to an image.
    image = tf.expand_dims(image, 0)
    pred_image = model(image)
    pred_image = pred_image.numpy()
    pred_image = pred_image[0, :, :, :]
    return parse_np_array_image(pred_image)


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--model-path', required=True)
    parser.add_argument('--image-path', required=True)
    parser.add_argument(
        '--image-output-path',
        required=False,
        default='pred_image.jpg',
        type=str,
    )
    args = parser.parse_args()

    model = CAE()
    ckpt = Checkpoint(transformer=model)
    # expect_partial() suppresses warnings about variables not needed for inference.
    ckpt.restore(latest_checkpoint(args.model_path)).expect_partial()

    pred_image = predict(model, load_image(args.image_path))
    pred_image.save(args.image_output_path)