Example #1
0
def set_checkpoint(opt, G_YtoX, G_XtoY, D_X, D_Y, G_YtoX_optimizer,
                   G_XtoY_optimizer, D_X_optimizer, D_Y_optimizer):
    """Create a checkpoint for all CycleGAN networks/optimizers and restore the latest one.

    Args:
        opt: option mapping; reads ``"use_cycle_consistency_loss"`` and ``"dataset_name"``.
        G_YtoX, G_XtoY: the two generators.
        D_X, D_Y: the two discriminators.
        G_YtoX_optimizer, G_XtoY_optimizer, D_X_optimizer, D_Y_optimizer: their optimizers.

    Returns:
        tuple: ``(ckpt, ckpt_manager)`` -- the ``Checkpoint`` tracking every object above
        and a ``CheckpointManager`` keeping at most 5 checkpoints in the training directory.
    """
    # Build the directory path once so the folder we create and the folder the
    # manager writes to can never drift apart. Runs without the cycle-consistency
    # loss live under a separate "no_cycle" root.
    if opt["use_cycle_consistency_loss"]:
        checkpoint_path = os.path.join("checkpoints", opt["dataset_name"], "train")
    else:
        checkpoint_path = os.path.join("no_cycle", "checkpoints",
                                       opt["dataset_name"], "train")
    os.makedirs(checkpoint_path, exist_ok=True)

    ckpt = Checkpoint(G_YtoX_optimizer=G_YtoX_optimizer,
                      G_XtoY_optimizer=G_XtoY_optimizer,
                      D_X_optimizer=D_X_optimizer,
                      D_Y_optimizer=D_Y_optimizer,
                      G_YtoX=G_YtoX,
                      G_XtoY=G_XtoY,
                      D_X=D_X,
                      D_Y=D_Y)

    ckpt_manager = CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

    # If a checkpoint exists, resume from the most recent one.
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        print('Latest checkpoint restored!!')

    return ckpt, ckpt_manager
Example #2
0
    def __init__(self, network, generative_model, loss, summary_stats=None, optimizer=None,
                 learning_rate=0.0005, checkpoint_path=None, max_to_keep=5, clip_method='global_norm', clip_value=None):
        """
        Creates a trainer instance for performing single-model forward inference and training an
        amortized neural estimator for parameter estimation (BayesFlow). If a checkpoint_path is provided, the
        network's weights will be stored after each training epoch. If the folder contains a
        checkpoint, the trainer will try to load the weights and continue training with a pre-trained net.
        ----------

        Arguments:
        network     : bayesflow.Amortizer instance -- the neural architecture to be optimized
        generative_model: callable -- a function or an object with n_sim and n_obs mandatory arguments
                          returning randomly sampled parameter vectors and datasets from a process model
        loss        : callable with three arguments: (network, m_indices, x) -- the loss function
        ----------

        Keyword arguments:
        summary_stats   : callable -- optional summary statistics function
        optimizer       : None, an optimizer class / factory, or an optimizer instance --
                          None selects Adam; a class or factory is called as optimizer(learning_rate);
                          an instance (an object with an apply_gradients method) is used as-is and
                          learning_rate is ignored
        learning_rate   : float -- the learning rate used when the optimizer is constructed here
        checkpoint_path : string -- optional folder name for storing the trained network
        max_to_keep     : int -- optional number of checkpoints to keep
        clip_method     : string in ('norm', 'value', 'global_norm') -- optional gradient clipping method
        clip_value      : float -- the value used for gradient clipping when clip_method is set to 'value' or 'norm'
        """

        # Basic attributes
        self.network = network
        self.generative_model = generative_model
        self.loss = loss
        self.summary_stats = summary_stats
        self.clip_method = clip_method
        self.clip_value = clip_value
        self.n_obs = None

        # Optimizer settings
        if optimizer is None:
            # TF1 and TF2 expose Adam under different names.
            if tf.__version__.startswith('1'):
                self.optimizer = tf.train.AdamOptimizer(learning_rate)
            else:
                self.optimizer = Adam(learning_rate)
        elif isinstance(optimizer, type) or not hasattr(optimizer, 'apply_gradients'):
            # An optimizer class (or a factory such as functools.partial): build it here.
            self.optimizer = optimizer(learning_rate)
        else:
            # An already-constructed optimizer instance: calling it would fail, so use it directly.
            self.optimizer = optimizer

        # Checkpoint settings
        if checkpoint_path is not None:
            self.checkpoint = Checkpoint(optimizer=self.optimizer, model=self.network)
            self.manager = CheckpointManager(self.checkpoint, checkpoint_path, max_to_keep=max_to_keep)
            # latest_checkpoint is None for an empty folder; the weights then keep their fresh initialization.
            self.checkpoint.restore(self.manager.latest_checkpoint)
            if self.manager.latest_checkpoint:
                print("Networks loaded from {}".format(self.manager.latest_checkpoint))
            else:
                print("Initializing networks from scratch.")
        else:
            self.checkpoint = None
            self.manager = None
        self.checkpoint_path = checkpoint_path
Example #3
0
 def __init__(self, optimizer, encoder, decoder, targ_lang, batch_size,
              directory="data/ckpt"):
     """Store the training components and create a checkpoint tracking them.

     Args:
         optimizer: optimizer tracked by the checkpoint.
         encoder: encoder model tracked by the checkpoint.
         decoder: decoder model tracked by the checkpoint.
         targ_lang: target-language object, stored as-is.
         batch_size: batch size, stored as-is.
         directory: folder created on construction and kept as ``self.directory``
             (default ``"data/ckpt"``, the previous hard-coded value).
     """
     # Create the folder eagerly so later saves have a place to go.
     self.directory = Path(directory)
     self.directory.mkdir(parents=True, exist_ok=True)
     self.optimizer = optimizer
     self.encoder = encoder
     self.decoder = decoder
     self.targ_lang = targ_lang
     self.batch_size = batch_size
     self.checkpoint = Checkpoint(optimizer=optimizer,
                                  encoder=encoder,
                                  decoder=decoder)
Example #4
0
def train(epochs, batch_size, ckpt_path, imgs_path, lr, out_path):
    """Train the GAN, checkpointing and saving sample images after every epoch.

    Args:
        epochs: number of epochs to run.
        batch_size: batch size passed to ``get_data`` and ``train_step``.
        ckpt_path: directory managed by a ``CheckpointManager`` (last 3 kept);
            training resumes from the latest checkpoint found there.
        imgs_path: location of the training images, passed to ``get_data``.
        lr: learning rate for both Adam optimizers (beta_1=0.5).
        out_path: where ``generate_and_save_images`` writes the sample grids.

    Training stops early if any mean epoch loss is NaN; that epoch is not saved.
    """
    tf.keras.backend.clear_session()
    train_data = get_data(imgs_path, batch_size)
    gen = generator()
    disc = discriminator()

    # Model.summary() prints its own table and returns None; wrapping it in
    # print() used to emit a stray "None" line after each summary.
    gen.summary()
    disc.summary()
    gen_opt = Adam(learning_rate=lr, beta_1=0.5)
    disc_opt = Adam(learning_rate=lr, beta_1=0.5)

    ckpt = Checkpoint(disc=disc, gen=gen, disc_opt=disc_opt, gen_opt=gen_opt)
    manager = CheckpointManager(ckpt, ckpt_path, max_to_keep=3)

    if manager.latest_checkpoint:
        print("Restored from {}".format(manager.latest_checkpoint))
        ckpt.restore(manager.latest_checkpoint)
    else:
        print("Initializing from scratch.")

    # Fixed noise so the per-epoch sample images are comparable across epochs.
    seed = tf.random.normal([16, ENCODING_SIZE], seed=1234)

    generate_and_save_images(gen, 0, seed, out_path)
    for ep in range(epochs):
        gen_loss = []
        disc_loss_real = []
        disc_loss_fake = []
        print('Epoch: %d of %d' % (ep + 1, epochs))
        start = time.time()

        for images in train_data:
            g_loss, d_loss_r, d_loss_f = train_step(images, gen, disc, gen_opt,
                                                    disc_opt, batch_size)
            gen_loss.append(g_loss)
            disc_loss_real.append(d_loss_r)
            disc_loss_fake.append(d_loss_f)
        gen_loss = np.mean(np.asarray(gen_loss))
        disc_loss_real = np.mean(np.asarray(disc_loss_real))
        disc_loss_fake = np.mean(np.asarray(disc_loss_fake))

        # Abort before saving so a diverged model never overwrites a good checkpoint.
        if (np.isnan(gen_loss) or np.isnan(disc_loss_real)
                or np.isnan(disc_loss_fake)):
            print("Something broke.")
            break

        manager.save()
        generate_and_save_images(gen, ep + 1, seed, out_path)

        print("Time for epoch:", time.time() - start)
        print("Gen loss=", gen_loss)
        print("Disc loss real=", disc_loss_real)
        print("Disc loss fake=", disc_loss_fake)
Example #5
0
class Training:
    """Teacher-forced encoder/decoder trainer with periodic checkpointing."""

    def __init__(self, optimizer, encoder, decoder, targ_lang, batch_size,
                 directory="data/ckpt"):
        """Store the training components and create the checkpoint.

        Args:
            optimizer: optimizer whose ``apply_gradients`` updates the weights.
            encoder: callable ``(inp, hidden) -> (output, hidden)`` exposing
                ``initialize_hidden_state()``.
            decoder: callable ``(dec_input, hidden, enc_output) -> (predictions, hidden, _)``.
            targ_lang: tokenizer-like object with ``word_index['<start>']``.
            batch_size: number of sequences per batch.
            directory: folder that receives the checkpoint files (default ``"data/ckpt"``).
        """
        self.directory = Path(directory)
        self.directory.mkdir(parents=True, exist_ok=True)
        self.optimizer = optimizer
        self.encoder = encoder
        self.decoder = decoder
        self.targ_lang = targ_lang
        self.batch_size = batch_size
        self.checkpoint = Checkpoint(optimizer=optimizer,
                                     encoder=encoder,
                                     decoder=decoder)

    @function
    def train_step(self, inp, targ, enc_hidden):
        """Run one teacher-forced training step and return the per-timestep mean loss."""
        loss = 0

        with GradientTape() as tape:
            enc_output, enc_hidden = self.encoder(inp, enc_hidden)
            dec_hidden = enc_hidden
            # Every sequence in the batch starts decoding from the <start> token.
            dec_input = expand_dims([self.targ_lang.word_index['<start>']] *
                                    self.batch_size, 1)
            for t in range(1, targ.shape[1]):
                predictions, dec_hidden, _ = self.decoder(
                    dec_input, dec_hidden, enc_output)
                # NOTE(review): the loss is looked up on the optimizer object;
                # confirm the project's optimizer really provides loss_function.
                loss += self.optimizer.loss_function(targ[:, t], predictions)
                # Teacher forcing: the ground-truth token at step t becomes the
                # next decoder input. The previous code passed the sequence
                # length (targ.shape[1]) and no axis.
                dec_input = expand_dims(targ[:, t], 1)
        batch_loss = loss / int(targ.shape[1])
        variables = self.encoder.trainable_variables + self.decoder.trainable_variables
        gradients = tape.gradient(loss, variables)
        self.optimizer.apply_gradients(zip(gradients, variables))
        return batch_loss

    def __call__(self, dataset, steps_per_epoch, epochs=10):
        """Train for ``epochs`` epochs of ``steps_per_epoch`` batches, checkpointing every 2nd epoch."""
        for epoch in range(epochs):
            start = time.time()
            enc_hidden = self.encoder.initialize_hidden_state()
            total_loss = 0
            for (batch, (inp,
                         targ)) in enumerate(dataset.take(steps_per_epoch)):
                batch_loss = self.train_step(inp, targ, enc_hidden)
                total_loss += batch_loss
                if batch % 100 == 0:
                    print('Epoch {} Batch {} Loss {:.4f}'.format(
                        epoch + 1, batch, batch_loss.numpy()))
            if (epoch + 1) % 2 == 0:
                # Save inside the created directory using a "ckpt" prefix
                # (files like <directory>/ckpt-1.index). Passing the directory
                # itself as the prefix wrote them next to it.
                self.checkpoint.save(file_prefix=str(self.directory / "ckpt"))
            print("Epoch {} Loss {:.4f}".format(epoch + 1,
                                                total_loss / steps_per_epoch))
            print("Time take for 1 epoch {} sec \n".format(time.time() -
                                                           start))
Example #6
0
def train(args):
    """Train the conditional GAN described by ``args``, optionally resuming a previous run.

    Reads ``args.img_path``, ``args.batch``, ``args.learning_rate``, ``args.ckpt_path``,
    ``args.continue_training``, ``args.epochs`` and ``args.out_path``.

    When resuming, the number in the latest checkpoint name (``ckpt-N``) is used as
    an epoch offset, so logs and image indices continue where the last run stopped.
    """
    train_ds, test_ds = get_data(args.img_path, args.batch)

    gen = generator()
    disc = discriminator()
    gen_opt = Adam(args.learning_rate, beta_1=0.5, beta_2=0.999)
    disc_opt = Adam(args.learning_rate, beta_1=0.5, beta_2=0.999)
    # summary() prints the table itself and returns None.
    gen.summary()
    disc.summary()

    ckpt = Checkpoint(disc=disc, gen=gen, disc_opt=disc_opt, gen_opt=gen_opt)
    manager = CheckpointManager(ckpt, args.ckpt_path, max_to_keep=3)

    # Defined unconditionally: previously `off` was unbound (NameError in the
    # epoch loop) whenever continue_training was false.
    off = 0
    if args.continue_training:
        latest = manager.latest_checkpoint
        if latest:
            print("Restored from {}".format(latest))
            ckpt.restore(latest)
            # Checkpoint names end in "-<save counter>", one save per epoch.
            off = int(latest.rsplit('-', 1)[-1])
        else:
            print("Initializing from scratch.")

    for ep in range(args.epochs):
        for x, y in test_ds.take(1):
            generate_and_save_imgs(gen, ep + off, x, y, args.out_path)
        gen_loss = []
        disc_loss = []
        print('Epoch: %d of %d' % (ep + 1 + off, args.epochs + off))
        start = time.time()

        for x, y in train_ds:
            g_loss, d_loss = train_step(x, y, gen, disc, gen_opt, disc_opt,
                                        args.batch)
            gen_loss.append(g_loss)
            disc_loss.append(d_loss)
        gen_loss = np.mean(np.asarray(gen_loss))
        disc_loss = np.mean(np.asarray(disc_loss))

        manager.save()
        print("Time for epoch:", time.time() - start)
        print("Gen loss=", gen_loss)
        print("Disc loss=", disc_loss)

    # Storing three different outputs after final epoch
    for x, y in test_ds.take(3):
        generate_and_save_imgs(gen, args.epochs + off, x, y, args.out_path)
        off += 1
Example #7
0
def save_ckpt(generator, critic, adversarial):
    """Create a checkpoint object tracking the three networks of the GAN model.

    Despite its name, this function does not write anything to disk. It only
    builds a ``Checkpoint`` that tracks the generator, the critic/discriminator
    and the adversarial network. Call ``.save(...)`` on the returned object to
    persist their weights and biases, or ``.restore(...)`` to load them back.

    For more information on the constructor `Checkpoint` from
    the module `tensorflow.train`, refer to
    https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint

    Parameters
    ----------
    generator : ganpdfs.model.WassersteinGanModel.generator
        generator neural network
    critic : ganpdfs.model.WassersteinGanModel.critic
        critic/discriminator neural network
    adversarial : ganpdfs.model.WassersteinGanModel.adversarial
        adversarial neural network

    Returns
    -------
    Checkpoint
        The checkpoint object tracking ``critic``, ``generator`` and
        ``adversarial`` (under those attribute names).
    """

    checkpoint = Checkpoint(
            critic=critic,
            generator=generator,
            adversarial=adversarial
    )
    return checkpoint
def train_model(model, model_dir, x_train, y_train, batch_size, epochs,
                learning_rate, decay):
    """Train ``model`` with a custom loop, validating and optionally checkpointing each epoch.

    Args:
        model: the model to train.
        model_dir: unused in this function body; kept for interface compatibility.
        x_train, y_train: training data, split into train/validation batches by
            ``prepare_batch_datasets``.
        batch_size: batch size for the datasets.
        epochs: number of epochs to run.
        learning_rate, decay: forwarded to ``compile_model``.

    Checkpointing is controlled by the module-level ``args.checkpoint_enabled``
    string (enabled when it equals "true", case-insensitively).
    """
    history = []

    # Compile the model
    optimizer, loss_fn, train_acc_metric, val_acc_metric = compile_model(
        model, learning_rate, decay)

    # Parse the flag once. The checkpoint is then always bound (None when disabled),
    # instead of relying on two separate checks staying in sync.
    checkpoint_enabled = args.checkpoint_enabled.lower() == 'true'
    checkpoint = Checkpoint(model) if checkpoint_enabled else None

    # Prepare the batch datasets (the standalone validation arrays are not needed here)
    _x_val, _y_val, train_dataset, val_dataset = prepare_batch_datasets(
        x_train, y_train, batch_size)

    # Perform training
    logger.info('Training the model...')
    training_start_time = time.time()
    logger.debug('Iterating over epochs...')
    for epoch in range(epochs):
        epoch_num = epoch + 1
        # Lazy %-style arguments: the message is only formatted if DEBUG is enabled,
        # which matters in the per-step loop below.
        logger.debug('Starting epoch %d...', epoch_num)
        epoch_start_time = time.time()

        # Iterate over the batches of the dataset
        for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
            step_num = step + 1
            logger.debug('Running training step %d...', step_num)
            loss_value = training_step(model, x_batch_train, y_batch_train,
                                       optimizer, loss_fn, train_acc_metric)
            logger.debug('Training loss in step = %s', loss_value)
            logger.debug('Completed running training step %d.', step_num)

        # Perform validation and save metrics at the end of each epoch
        history.append([
            epoch_num,
            train_acc_metric.result(),
            perform_validation(model, val_dataset, val_acc_metric)
        ])
        # Reset metrics so the next epoch starts from a clean accumulation
        train_acc_metric.reset_states()
        val_acc_metric.reset_states()

        # Save the model as a checkpoint
        if checkpoint is not None:
            save_checkpoint(checkpoint)

        epoch_end_time = time.time()
        logger.debug("Epoch duration = %.2f second(s)",
                     epoch_end_time - epoch_start_time)
        logger.debug('Completed epoch %d.', epoch_num)

    logger.debug('Completed iterating over epochs.')
    training_end_time = time.time()
    logger.info('Training duration = %.2f second(s)',
                training_end_time - training_start_time)
    print_training_result(history)
    logger.info('Completed training the model.')
Example #9
0
def get_checkpoint_manager(model,
                           optimizer,
                           checkpoints_dir,
                           max_checkpoints=None):
    """Obtain a checkpoint manager to manage model saving and restoring.

    Arguments:
        model: object exposing ``encoder`` and ``decoder`` attributes (e.g. the
            image-captioning model); both sub-models are tracked by the checkpoint.
        optimizer (tf.optimizers.Optimizer): the optimizer used during the
            backpropagation step; its state is tracked as well.
        checkpoints_dir (str): directory where checkpoints are written and looked up.
        max_checkpoints (int or None): forwarded as ``max_to_keep``; ``None`` keeps
            every checkpoint.

    Returns:
        tuple: ``(tf.train.CheckpointManager, tf.train.Checkpoint)`` -- note the
        manager comes first.
    """

    ckpt = Checkpoint(encoder=model.encoder,
                      decoder=model.decoder,
                      optimizer=optimizer)
    ckpt_manager = CheckpointManager(ckpt,
                                     checkpoints_dir,
                                     max_to_keep=max_checkpoints)

    return ckpt_manager, ckpt
Example #10
0
class MetaTrainer:
    """Trainer for amortized neural inference networks (BayesFlow).

    Supports online (simulate-per-step), offline (pre-simulated data) and
    round-based training. When a checkpoint path is provided, the network and
    optimizer state are restored on construction and saved after every epoch.
    """

    def __init__(self, network, generative_model, loss, summary_stats=None, optimizer=None,
                 learning_rate=0.0005, checkpoint_path=None, max_to_keep=5, clip_method='global_norm', clip_value=None):
        """
        Creates a trainer instance for performing multi-model forward inference and training an
        amortized neural estimator (BayesFlow). If a checkpoint_path is provided, the
        network's weights will be stored after each training epoch. If the folder contains a
        checkpoint, the trainer will try to load the weights and continue training with a pre-trained net.
        ----------

        Arguments:
        network     : bayesflow.Amortizer instance -- the neural architecture to be optimized
        generative_model: callable -- called as generative_model(n_sim, n_obs, **kwargs) and returning
                          a tuple (model_indices, params, sim_data) sampled from the process model(s)
        loss        : callable called as loss(network, model_indices, params, sim_data) -- the loss
                      function; must return a scalar tensor (it is differentiated and .numpy()-ed)
        ----------

        Keyword arguments:
        summary_stats   : callable -- optional summary statistics function
        optimizer       : None, an optimizer class / factory, or an optimizer instance --
                          None selects Adam; a class or factory is called as optimizer(learning_rate);
                          an instance (an object with an apply_gradients method) is used as-is and
                          learning_rate is ignored
        learning_rate   : float -- the learning rate used when the optimizer is constructed here
        checkpoint_path : string -- optional folder name for storing the trained network
        max_to_keep     : int -- optional number of checkpoints to keep
        clip_method     : string in ('norm', 'value', 'global_norm') -- optional gradient clipping method
        clip_value      : float -- the value used for gradient clipping; clipping is skipped when None
        """

        # Basic attributes
        self.network = network
        self.generative_model = generative_model
        self.loss = loss
        self.summary_stats = summary_stats
        self.clip_method = clip_method
        self.clip_value = clip_value
        self.n_obs = None

        # Optimizer settings
        if optimizer is None:
            # TF1 and TF2 expose Adam under different names.
            if tf.__version__.startswith('1'):
                self.optimizer = tf.train.AdamOptimizer(learning_rate)
            else:
                self.optimizer = Adam(learning_rate)
        elif isinstance(optimizer, type) or not hasattr(optimizer, 'apply_gradients'):
            # An optimizer class (or a factory such as functools.partial): build it here.
            self.optimizer = optimizer(learning_rate)
        else:
            # An already-constructed optimizer instance: calling it would fail, so use it directly.
            self.optimizer = optimizer

        # Checkpoint settings
        if checkpoint_path is not None:
            self.checkpoint = Checkpoint(optimizer=self.optimizer, model=self.network)
            self.manager = CheckpointManager(self.checkpoint, checkpoint_path, max_to_keep=max_to_keep)
            # latest_checkpoint is None for an empty folder; the weights then keep their fresh initialization.
            self.checkpoint.restore(self.manager.latest_checkpoint)
            if self.manager.latest_checkpoint:
                print("Networks loaded from {}".format(self.manager.latest_checkpoint))
            else:
                print("Initializing networks from scratch.")
        else:
            self.checkpoint = None
            self.manager = None
        self.checkpoint_path = checkpoint_path

    def train_online(self, epochs, iterations_per_epoch, batch_size, n_obs, **kwargs):
        """
        Trains the inference network(s) via online learning. Additional keyword arguments
        are passed to the simulators.
        ----------

        Arguments:
        epochs               : int -- number of epochs (and number of times a checkpoint is stored)
        iterations_per_epoch : int -- number of batch simulations to perform per epoch
        batch_size           : int -- number of simulations to perform at each backprop step
        n_obs                : int or callable -- if callable, it is treated as a function for sampling
                               N, i.e., N ~ p(N), called once per iteration; otherwise it is treated as a
                               fixed number of observations
        ----------

        Returns:
        losses : dict (ep_num : list_of_losses) -- a dictionary storing the losses across epochs and iterations
        """

        losses = dict()
        for ep in range(1, epochs+1):
            losses[ep] = []
            with tqdm(total=iterations_per_epoch, desc='Training epoch {}'.format(ep)) as p_bar:
                for it in range(1, iterations_per_epoch+1):

                    # Determine n_obs and generate data on-the-fly. Checking callability
                    # (rather than `type(n_obs) is int`) also accepts NumPy integers.
                    n_obs_it = n_obs() if callable(n_obs) else n_obs
                    model_indices, params, sim_data = self._forward_inference(batch_size, n_obs_it, **kwargs)

                    # One step backprop
                    loss = self._train_step(model_indices, params, sim_data)

                    # Store loss into dictionary
                    losses[ep].append(loss)

                    # Update progress bar
                    p_bar.set_postfix_str("Epoch {0},Iteration {1},Loss: {2:.3f},Running Loss: {3:.3f}"
                    .format(ep, it, loss, np.mean(losses[ep])))
                    p_bar.update(1)

            # Store after each epoch, if specified
            if self.manager is not None:
                self.manager.save()
        return losses

    def train_offline(self, epochs, batch_size, model_indices, params, sim_data):
        """
        Trains the inference network(s) via offline learning. Assume params and data have already
        been simulated (i.e., forward inference). If summary_stats is set, it is applied to the
        raw sim_data here, so callers must pass un-summarized data.
        ----------

        Arguments:
        epochs           : int -- number of epochs (and number of times a checkpoint is stored)
        batch_size       : int -- number of simulations to perform at each backprop step
        model_indices    : np.array of shape (n_sim, ) or (n_sim, n_models) -- the true model indices
        params           : np.array of shape (n_sim, n_params) -- the true data-generating parameters
        sim_data         : np.array of shape (n_sim, n_obs, data_dim) -- the simulated data sets from each model
        ----------

        Returns:
        losses : dict (ep_num : list_of_losses) -- a dictionary storing the losses across epochs and iterations
        """

        # Convert to a data set
        n_sim = int(sim_data.shape[0])

        # Compute summary statistics, if provided
        if self.summary_stats is not None:
            print('Computing hand-crafted summary statistics...')
            sim_data = self.summary_stats(sim_data)

        print('Converting {} simulations to a TensorFlow data set...'.format(n_sim))
        data_set = tf.data.Dataset \
                    .from_tensor_slices((model_indices, params, sim_data)) \
                    .shuffle(n_sim) \
                    .batch(batch_size)

        losses = dict()
        for ep in range(1, epochs+1):
            losses[ep] = []
            with tqdm(total=int(np.ceil(n_sim / batch_size)), desc='Training epoch {}'.format(ep)) as p_bar:
                # Loop through dataset
                for bi, batch in enumerate(data_set):

                    # Extract params from batch
                    model_indices_b, params_b, sim_data_b = batch[0], batch[1], batch[2]

                    # One step backprop
                    loss = self._train_step(model_indices_b, params_b, sim_data_b)

                    # Store loss and update progress bar
                    losses[ep].append(loss)
                    p_bar.set_postfix_str("Epoch {0},Batch {1},Loss: {2:.3f},Running Loss: {3:.3f}"
                    .format(ep, bi+1, loss, np.mean(losses[ep])))
                    p_bar.update(1)

            # Store after each epoch, if specified
            if self.manager is not None:
                self.manager.save()
        return losses

    def train_rounds(self, epochs, rounds, sim_per_round, batch_size, n_obs, **kwargs):
        """
        Trains the inference network(s) via round-based learning. Additional arguments are
        passed to the simulator.
        ----------

        Arguments:
        epochs         : int -- number of epochs (and number of times a checkpoint is stored)
        rounds         : int -- number of rounds to perform
        sim_per_round  : int -- number of simulations per round
        batch_size     : int -- number of simulations to perform at each backprop step
        n_obs          : int -- number of observations (fixed) for each data set
        ----------

        Returns:
        losses : nested dict with each (ep_num : list_of_losses) -- a dictionary storing the losses across rounds,
                 epochs and iterations
        """

        # Make sure n_obs is fixed, otherwise not working
        assert type(n_obs) is int,\
        'Round-based training currently only works with fixed n_obs. Use online learning for variable n_obs or fix n_obs to an integer value.'

        losses = dict()
        for r in range(1, rounds+1):

            # Data generation step. Simulate WITHOUT summarizing: train_offline applies
            # summary_stats itself, so summarizing here would apply it twice.
            if r == 1:
                # Simulate initial data
                print('Simulating initial {} data sets...'.format(sim_per_round))
                model_indices, params, sim_data = self._forward_inference(
                    sim_per_round, n_obs, summarize=False, **kwargs)
            else:
                # Simulate further data
                print('Simulating new {} data sets and appending to previous...'.format(sim_per_round))
                print('New total number of simulated data sets: {}'.format(sim_per_round * r))
                model_indices_r, params_r, sim_data_r = self._forward_inference(
                    sim_per_round, n_obs, summarize=False, **kwargs)

                # Add new simulations to previous data
                model_indices = np.concatenate((model_indices, model_indices_r), axis=0)
                params = np.concatenate((params, params_r), axis=0)
                sim_data = np.concatenate((sim_data, sim_data_r), axis=0)

            # Train offline with generated stuff
            losses_r = self.train_offline(epochs, batch_size, model_indices, params, sim_data)
            losses[r] = losses_r

        return losses

    def simulate_and_train_offline(self, n_sim, epochs, batch_size, n_obs, **kwargs):
        """
        Simulates n_sim data sets and then trains the inference network(s) via offline learning.

        Additional keyword arguments are passed to the simulator.
        ----------

        Arguments:
        n_sim          : int -- total number of simulations to perform
        epochs         : int -- number of epochs (and number of times a checkpoint is stored)
        batch_size     : int -- number of simulations to perform at each backprop step
        n_obs          : int -- number of observations for each dataset
        ----------

        Returns:
        losses : dict (ep_num : list_of_losses) -- a dictionary storing the losses across epochs and iterations
        """

        # Make sure n_obs is fixed, otherwise not working, for now
        assert type(n_obs) is int,\
        'Offline training currently only works with fixed n_obs. Use online learning for variable n_obs or fix n_obs to an integer value.'

        # Simulate raw data; train_offline computes the summary statistics.
        print('Simulating {} data sets upfront...'.format(n_sim))
        model_indices, params, sim_data = self._forward_inference(n_sim, n_obs, summarize=False, **kwargs)

        # Train offline
        losses = self.train_offline(epochs, batch_size, model_indices, params, sim_data)
        return losses

    def load_pretrained_network(self):
        """
        Attempts to load a pre-trained network if checkpoint path is provided and a checkpoint manager exists.
        Returns False when no checkpoint manager is configured, otherwise the status object from restore().
        """

        if self.manager is None or self.checkpoint is None:
            return False
        status = self.checkpoint.restore(self.manager.latest_checkpoint)
        return status

    def _forward_inference(self, n_sim, n_obs, summarize=True, **kwargs):
        """
        Performs one step of multi-model forward inference.
        ----------

        Arguments:
        n_sim : int -- number of simulation to perform at the given step (i.e., batch size)
        n_obs : int -- number of observations, forwarded to the generative model
        ----------

        Keyword arguments:
        summarize : bool -- whether to summarize the data if hand-crafted summaries are given

        Returns:
        model_indices : the model indices returned by the generative model
        params        : np.array (np.float32) of shape (batch_size, param_dim) -- array of sampled parameters
        sim_data      : np.array (np.float32) of shape (batch_size, n_obs, data_dim) -- array of simulated
                        data sets, or the output of summary_stats when summarize is True and it is set
        """

        # Simulate data with n_sims and n_obs
        model_indices, params, sim_data = self.generative_model(n_sim, n_obs, **kwargs)

        # Compute hand-crafted summary stats, if given
        if summarize and self.summary_stats is not None:
            sim_data = self.summary_stats(sim_data)

        return model_indices, params, sim_data

    def _train_step(self, model_indices, params, sim_data):
        """
        Performs one step of backpropagation with the given model indices and data.
        ----------

        Arguments:
        model_indices  : np.array (np.float32) of shape (n_sim, n_models) -- the true model indices
        params         : np.array (np.float32) of shape (batch_size, n_params) -- matrix of n_samples x n_params
        sim_data       : np.array (np.float32) of shape (batch_size, n_obs, data_dim) or (batch_size, summary_dim)
                    -- array of simulated data sets (or summary statistics thereof)
        ----------

        Returns:
        loss : the loss value converted with .numpy() (a scalar for a scalar loss tensor)
        """

        # Compute loss and store gradients
        with tf.GradientTape() as tape:
            loss = self.loss(self.network, model_indices, params, sim_data)

        # One step backprop
        gradients = tape.gradient(loss, self.network.trainable_variables)
        self._apply_gradients(gradients, self.network.trainable_variables)

        return loss.numpy()

    def _apply_gradients(self, gradients, tensors):
        """
        Updates each tensor in the 'tensors' list via backpropagation. Operation is performed in-place.
        ----------

        Arguments:
        gradients: list of tf.Tensor -- the list of gradients for all neural network parameters
        tensors  : list of tf.Variable -- the list of all neural network parameters, aligned with gradients
        """

        # Optional gradient clipping (skipped entirely when clip_value is None)
        if self.clip_value is not None:
            gradients = clip_gradients(gradients, clip_value=self.clip_value, clip_method=self.clip_method)
        self.optimizer.apply_gradients(zip(gradients, tensors))
Example #11
0
    parse_np_array_image,
)


def predict(model: CAE, image: Image) -> Image:
    """Reconstruct a single image with the autoencoder.

    The input is wrapped in a batch of one, passed through ``model``, and the
    first (only) output of the batch is converted back with ``parse_np_array_image``.
    """
    batched_input = tf.expand_dims(image, 0)  # prepend a batch axis of size 1
    reconstruction = model(batched_input).numpy()
    return parse_np_array_image(reconstruction[0, :, :, :])


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--model-path', required=True)
    parser.add_argument('--image-path', required=True)
    parser.add_argument(
        '--image-output-path',
        required=False,
        default='pred_image.jpg',
        type=str,
    )
    args = parser.parse_args()

    # latest_checkpoint returns None when the directory holds no checkpoint.
    # Restoring None is a silent no-op, which would produce predictions from
    # randomly initialised weights, so fail loudly instead.
    checkpoint_file = latest_checkpoint(args.model_path)
    if checkpoint_file is None:
        parser.error(f'no checkpoint found in {args.model_path!r}')

    model = CAE()
    ckpt = Checkpoint(transformer=model)
    # expect_partial(): only the model is needed, so optimizer slots etc. may stay unmatched.
    ckpt.restore(checkpoint_file).expect_partial()
    pred_image = predict(model, load_image(args.image_path))
    pred_image.save(args.image_output_path)