def validate_step(self, x, y):
    """Evaluate the forward model on a single batch of validation data

    The inputs are perturbed with `cont_noise` and the model is queried
    with training=False; no parameters are updated

    Args:

    x: tf.Tensor
        a batch of validation inputs shaped like [batch_size, channels]
    y: tf.Tensor
        a batch of validation labels shaped like [batch_size, 1]

    Returns:

    statistics: dict
        a dictionary that contains logging information
    """

    # perturb the inputs before they are shown to the model
    noisy_x = cont_noise(x, self.continuous_noise_std)

    # query the model in inference mode
    prediction = self.forward_model(noisy_x, training=False)

    # mean squared error against the labels (logged under the "nll" key)
    squared_error = tf.keras.losses.mean_squared_error(y, prediction)

    # rank agreement between the labels and the first prediction column
    rank_corr = spearman(y[:, 0], prediction[:, 0])

    prefix = self.logger_prefix
    return {
        f'{prefix}/validate/nll': squared_error,
        f'{prefix}/validate/rank_corr': rank_corr,
    }
def validate_step(self, x, y):
    """Evaluate the probabilistic forward model on a validation batch

    The inputs are perturbed with `cont_noise`, the model's predictive
    distribution is requested with training=False, and the negative log
    probability of the labels is reported; no parameters are updated

    Args:

    x: tf.Tensor
        a batch of validation inputs shaped like [batch_size, channels]
    y: tf.Tensor
        a batch of validation labels shaped like [batch_size, 1]

    Returns:

    statistics: dict
        a dictionary that contains logging information
    """

    # perturb the inputs before they are shown to the model
    noisy_x = cont_noise(x, self.noise_std)

    # predictive distribution of the model in inference mode
    dist = self.fm.get_distribution(noisy_x, training=False)

    # negative log probability of the labels under that distribution
    neg_log_prob = -dist.log_prob(y)

    # rank agreement between the labels and the predicted mean
    rank_corr = spearman(y[:, 0], dist.mean()[:, 0])

    return {
        'validate/nll': neg_log_prob,
        'validate/rank_corr': rank_corr,
    }
def train_step(self, x, y):
    """Perform a training step of gradient descent on the forward model
    by minimizing the mean squared error on a noise-corrupted batch

    Args:

    x: tf.Tensor
        a batch of training inputs shaped like [batch_size, channels]
    y: tf.Tensor
        a batch of training labels shaped like [batch_size, 1]

    Returns:

    statistics: dict
        a dictionary that contains logging information
    """

    # corrupt the inputs with noise
    x = cont_noise(x, self.continuous_noise_std)
    statistics = dict()

    # only the operations the gradient depends on are recorded on the tape
    with tf.GradientTape() as tape:

        # calculate the prediction error of the model (logged as "nll")
        d = self.forward_model(x, training=True)
        nll = tf.keras.losses.mean_squared_error(y, d)
        total_loss = tf.reduce_mean(nll)

    # evaluate how correct the rank of the model predictions are; this is
    # logging only, so it does not need to be recorded on the tape
    rank_correlation = spearman(y[:, 0], d[:, 0])

    # when the last trainable variable holds exactly one element, log it
    last_weight = self.forward_model.trainable_variables[-1]
    if tf.shape(tf.reshape(last_weight, [-1]))[0] == 1:
        statistics[f'{self.logger_prefix}/train/tanh_multipier'] = \
            last_weight

    # take a gradient step on the model
    grads = tape.gradient(total_loss,
                          self.forward_model.trainable_variables)
    self.forward_model_optim.apply_gradients(
        zip(grads, self.forward_model.trainable_variables))

    statistics[f'{self.logger_prefix}/train/nll'] = nll
    statistics[f'{self.logger_prefix}/train/rank_corr'] = rank_correlation

    return statistics
def validate_step(self, x, y):
    """Evaluate the conservative objective model on a validation batch

    Reports the mean squared error and rank correlation of the model on
    noise-corrupted inputs, plus the gap between the model's predictions
    at looked-ahead inputs and at the reference inputs selected by
    self.constraint_type; no parameters are updated

    Args:

    x: tf.Tensor
        a batch of validation inputs shaped like [batch_size, channels]
    y: tf.Tensor
        a batch of validation labels shaped like [batch_size, 1]

    Returns:

    statistics: dict
        a dictionary that contains logging information
    """

    logs = dict()
    batch_dim = tf.shape(y)[0]

    # perturb the inputs before they are shown to the model
    noisy_x = cont_noise(x, self.continuous_noise_std)

    # prediction error and rank agreement on the perturbed batch
    pred = self.forward_model(noisy_x, training=False)
    logs['validate/mse'] = tf.keras.losses.mean_squared_error(y, pred)
    logs['validate/rank_corr'] = spearman(y[:, 0], pred[:, 0])

    # per-example mask that broadcasts over every non-batch axis: where it
    # is true the dataset input is kept, elsewhere the solution is used
    mask_shape = [batch_dim] + [1] * (len(noisy_x.shape) - 1)
    use_dataset = tf.random.uniform(mask_shape) < self.negatives_fraction
    x_mix = tf.where(use_dataset, noisy_x, self.solution[:batch_dim])

    # negative samples are obtained by looking ahead from the mixed inputs
    x_neg = self.lookahead(x_mix, self.lookahead_steps, training=False)
    if not self.lookahead_backprop:
        x_neg = tf.stop_gradient(x_neg)

    # choose the reference inputs according to the constraint type
    reference = {
        "dataset": noisy_x,
        "mix": x_mix,
        "solution": self.solution[:batch_dim],
    }
    d_pos = self.forward_model(
        reference[self.constraint_type], training=False)
    d_neg = self.forward_model(x_neg, training=False)

    # prediction at the negatives minus prediction at the reference
    logs['validate/conservatism'] = d_neg[:, 0] - d_pos[:, 0]

    return logs
def train_step(self, x, y, b):
    """Perform a training step of gradient descent on an ensemble
    using bootstrap weights for each model in the ensemble

    Each of the self.bootstraps models is trained on its own
    noise-corrupted copy of the batch, minimizing the negative log
    likelihood weighted by its column of b

    Args:

    x: tf.Tensor
        a batch of training inputs shaped like [batch_size, channels]
    y: tf.Tensor
        a batch of training labels shaped like [batch_size, 1]
    b: tf.Tensor
        bootstrap indicators shaped like [batch_size, num_oracles]

    Returns:

    statistics: dict
        a dictionary that contains logging information
    """

    statistics = dict()

    for i in range(self.bootstraps):
        model = self.forward_models[i]
        optim = self.forward_model_optims[i]

        # corrupt the inputs with noise (sampled separately per model)
        if self.is_discrete:
            x0 = disc_noise(x, keep=self.keep, temp=self.temp)
        else:
            x0 = cont_noise(x, self.noise_std)

        # exactly one gradient is taken from each tape, so a regular
        # (non-persistent) tape is enough and frees its resources as
        # soon as the gradient has been computed
        with tf.GradientTape() as tape:

            # calculate the prediction error and accuracy of the model
            d = model.get_distribution(x0, training=True)
            nll = -d.log_prob(y)
            statistics[f'oracle_{i}/train/nll'] = nll

            # evaluate how correct the rank of the model predictions are
            rank_correlation = spearman(y[:, 0], d.mean()[:, 0])
            statistics[f'oracle_{i}/train/rank_corr'] = rank_correlation

            # build the total loss as a bootstrap-weighted average; the
            # result is zero when every weight for this model is zero
            total_loss = tf.math.divide_no_nan(
                tf.reduce_sum(b[:, i] * nll), tf.reduce_sum(b[:, i]))

        # take a gradient step on this model only
        var_list = model.trainable_variables
        optim.apply_gradients(
            zip(tape.gradient(total_loss, var_list), var_list))

    return statistics
def train_step(self, x, y, b):
    """Perform a training step of gradient descent on the forward model
    by minimizing a weighted negative log likelihood

    Args:

    x: tf.Tensor
        a batch of training inputs shaped like [batch_size, channels]
    y: tf.Tensor
        a batch of training labels shaped like [batch_size, 1]
    b: tf.Tensor
        weights multiplied elementwise with the negative log likelihood;
        assumed to broadcast against it (TODO confirm the shape with the
        caller, the previous docstring said [batch_size, num_oracles])

    Returns:

    statistics: dict
        a dictionary that contains logging information
    """

    # corrupt the inputs with noise
    x0 = cont_noise(x, self.noise_std)
    statistics = dict()

    # a single gradient is taken from this tape, so it does not need to
    # be persistent and releases its resources right after that call
    with tf.GradientTape() as tape:

        # calculate the prediction error and accuracy of the model
        d = self.fm.get_distribution(x0, training=True)
        nll = -d.log_prob(y)

        # evaluate how correct the rank of the model predictions are
        rank_correlation = spearman(y[:, 0], d.mean()[:, 0])

        # build the total loss as a weighted average of the negative log
        # likelihood; the result is zero when all the weights are zero
        denom = tf.reduce_sum(b)
        total_loss = tf.math.divide_no_nan(
            tf.reduce_sum(b * nll), denom)

    # take a gradient step on the model
    grads = tape.gradient(total_loss, self.fm.trainable_variables)
    self.optim.apply_gradients(zip(grads, self.fm.trainable_variables))

    statistics[f'train/nll'] = nll
    statistics[f'train/rank_corr'] = rank_correlation

    return statistics
def validate_step(self, x, y):
    """Evaluate every model of the ensemble on a validation batch

    Each of the self.bootstraps models receives its own noise-corrupted
    copy of the inputs and is queried with training=False; bootstrap
    weights are not used and no parameters are updated

    Args:

    x: tf.Tensor
        a batch of validation inputs shaped like [batch_size, channels]
    y: tf.Tensor
        a batch of validation labels shaped like [batch_size, 1]

    Returns:

    statistics: dict
        a dictionary that contains logging information
    """

    results = dict()

    for idx in range(self.bootstraps):
        member = self.forward_models[idx]

        # draw fresh noise for this ensemble member
        noisy_x = (disc_noise(x, keep=self.keep, temp=self.temp)
                   if self.is_discrete
                   else cont_noise(x, self.noise_std))

        # negative log probability of the labels under the prediction
        dist = member.get_distribution(noisy_x, training=False)
        results[f'oracle_{idx}/validate/nll'] = -dist.log_prob(y)

        # rank agreement between the labels and the predicted mean
        results[f'oracle_{idx}/validate/rank_corr'] = spearman(
            y[:, 0], dist.mean()[:, 0])

    return results
def train_step(self, x, y):
    """Take one gradient step on the forward model and on log_alpha

    The model minimizes its mean squared error minus alpha times the
    conservatism, where conservatism is the model's prediction on the
    (noise-corrupted) data minus its prediction on inputs returned by
    self.outer_optimize. A second loss, built from alpha, the
    conservatism and self.target_conservatism, is minimized with respect
    to self.log_alpha

    Args:

    x: tf.Tensor
        a batch of training inputs shaped like [batch_size, channels]
    y: tf.Tensor
        a batch of training labels shaped like [batch_size, 1]

    Returns:

    statistics: dict
        a dictionary that contains logging information
    """

    # perturb the inputs before they are shown to the model
    noisy_x = cont_noise(x, self.continuous_noise_std)
    logs = dict()

    # persistent because two separate gradients are taken below
    with tf.GradientTape(persistent=True) as tape:

        # prediction error on the data
        pred_data = self.forward_model(noisy_x, training=True)
        mse = tf.keras.losses.mean_squared_error(y, pred_data)
        logs['train/mse'] = mse

        # rank agreement between labels and predictions
        rank_corr = spearman(y[:, 0], pred_data[:, 0])
        logs['train/rank_corr'] = rank_corr

        # optimized inputs are treated as constants by the model loss
        x_opt = tf.stop_gradient(self.outer_optimize(
            noisy_x, self.beta, self.outer_gradient_steps, training=False))

        # prediction on the data minus prediction on the optimized inputs
        pred_opt = self.forward_model(x_opt, training=False)
        gap = pred_data[:, 0] - pred_opt[:, 0]
        logs['train/conservatism'] = gap

        # lagrangian term used for the dual step on alpha
        dual_term = -(self.alpha * self.target_conservatism -
                      self.alpha * gap)
        logs['train/alpha'] = self.alpha

        # model objective: fit the data while being rewarded for the gap
        primal_term = mse - self.alpha * gap
        total_loss = tf.reduce_mean(primal_term)
        alpha_loss = tf.reduce_mean(dual_term)

    # both gradients come from the same persistent tape
    model_vars = self.forward_model.trainable_variables
    alpha_grads = tape.gradient(alpha_loss, self.log_alpha)
    model_grads = tape.gradient(total_loss, model_vars)

    # update the multiplier first, then the model
    self.alpha_opt.apply_gradients([[alpha_grads, self.log_alpha]])
    self.forward_model_opt.apply_gradients(zip(model_grads, model_vars))

    return logs
def train_step(self, x, y):
    """Perform a training step of gradient descent on the loss function
    of a conservative objective model

    Besides updating the forward model and log_alpha, this step
    increments self.step and, on selected steps, also updates
    self.solution with a gradient step computed from the model

    Args:

    x: tf.Tensor
        a batch of training inputs shaped like [batch_size, channels]
    y: tf.Tensor
        a batch of training labels shaped like [batch_size, 1]

    Returns:

    statistics: dict
        a dictionary that contains logging information
    """

    # advance the step counter that schedules the solution updates below
    self.step.assign_add(1)
    statistics = dict()
    batch_dim = tf.shape(y)[0]

    # corrupt the inputs with noise
    x = cont_noise(x, self.continuous_noise_std)

    # persistent because two separate gradients are taken from this tape
    with tf.GradientTape(persistent=True) as tape:

        # calculate the prediction error and accuracy of the model
        d = self.forward_model(x, training=True)
        mse = tf.keras.losses.mean_squared_error(y, d)
        statistics[f'train/mse'] = mse

        # evaluate how correct the rank of the model predictions are
        rank_corr = spearman(y[:, 0], d[:, 0])
        statistics[f'train/rank_corr'] = rank_corr

        # calculate negative samples starting from the dataset: a random
        # per-example mask keeps the dataset input where the uniform draw
        # is below negatives_fraction and uses the solution elsewhere
        x_pos = x
        x_pos = tf.where(tf.random.uniform([batch_dim] + [1 for _ in x.shape[1:]])
                         < self.negatives_fraction, x_pos, self.solution[:batch_dim])
        x_neg = self.lookahead(x_pos, self.lookahead_steps, training=False)
        if not self.lookahead_backprop:
            x_neg = tf.stop_gradient(x_neg)

        # calculate the prediction error and accuracy of the model; the
        # reference inputs are selected by self.constraint_type
        d_pos = self.forward_model(
            {"dataset": x, "mix": x_pos, "solution": self.solution[:batch_dim]}
            [self.constraint_type], training=False)
        d_neg = self.forward_model(x_neg, training=False)
        conservatism = d_neg[:, 0] - d_pos[:, 0]
        statistics[f'train/conservatism'] = conservatism

        # build a lagrangian for dual descent
        alpha_loss = (self.alpha * self.target_conservatism -
                      self.alpha * conservatism)
        statistics[f'train/alpha'] = self.alpha

        # multiplier_loss is a constant zero here, so it does not change
        # the model loss; the last trainable variable is logged when it
        # holds exactly one element
        multiplier_loss = 0.0
        last_weight = self.forward_model.trainable_variables[-1]
        if tf.shape(tf.reshape(last_weight, [-1]))[0] == 1:
            statistics[f'train/tanh_multipier'] = \
                self.forward_model.trainable_variables[-1]

        # loss that combines maximum likelihood with a constraint
        model_loss = mse + self.alpha * conservatism + multiplier_loss
        total_loss = tf.reduce_mean(model_loss)
        alpha_loss = tf.reduce_mean(alpha_loss)

    # initialize stateful variables at the first iteration; they take the
    # shape of the conservatism computed for this batch
    if self.particle_loss is None:
        initialization = tf.zeros_like(conservatism)
        self.particle_loss = tf.Variable(initialization)
        self.particle_constraint = tf.Variable(initialization)

    # calculate gradients using the model
    alpha_grads = tape.gradient(alpha_loss, self.log_alpha)
    model_grads = tape.gradient(
        total_loss, self.forward_model.trainable_variables)

    # occasionally update the solution: only when the step is a multiple
    # of solver_interval and at least solver_warmup
    if tf.logical_and(
            tf.equal(tf.math.mod(self.step, self.solver_interval), 0),
            tf.math.greater_equal(self.step, self.solver_warmup)):

        # NOTE: this rebinds the name `tape`; both gradients of the outer
        # persistent tape have already been computed above
        with tf.GradientTape() as tape:

            # take gradient steps on the model first, so the solution is
            # scored by the updated model
            self.alpha_opt.apply_gradients([[alpha_grads, self.log_alpha]])
            self.forward_model_opt.apply_gradients(
                zip(model_grads, self.forward_model.trainable_variables))

            # calculate the predicted score of the current solution
            current_score_new_model = self.forward_model(
                self.solution, training=False)[:, 0]

            # look into the future and evaluate future solutions
            future_new_model = self.lookahead(
                self.solution, self.solver_steps, training=False)
            future_score_new_model = self.forward_model(
                future_new_model, training=False)[:, 0]

            # evaluate the conservatism of the current solution
            particle_loss = (self.solver_beta * future_score_new_model -
                             current_score_new_model)

            # step the solution against the gradient of particle_loss
            update = (self.solution - self.solver_lr *
                      tape.gradient(particle_loss, self.solution))

        # entries where self.done is true keep their current solution,
        # all other entries receive the update
        self.solution.assign(tf.where(self.done, self.solution, update))
        self.particle_loss.assign(particle_loss)
        self.particle_constraint.assign(
            future_score_new_model - current_score_new_model)

    else:

        # take gradient steps on the model
        self.alpha_opt.apply_gradients([[alpha_grads, self.log_alpha]])
        self.forward_model_opt.apply_gradients(
            zip(model_grads, self.forward_model.trainable_variables))

    # the particle statistics hold the values stored by the most recent
    # solution update (zeros until the first one happens)
    statistics[f'train/done'] = tf.cast(self.done, tf.float32)
    statistics[f'train/particle_loss'] = self.particle_loss
    statistics[f'train/particle_constraint'] = self.particle_constraint

    return statistics
def validate_step(self, x_real, y_real):
    """Evaluate the generator and discriminator on a validation batch

    The discriminator loss is computed on generated samples, optionally
    on shuffled real inputs and on replay pool samples (all with a target
    of zeros), and on the noise-corrupted real inputs (with a target of
    ones); a penalty on interpolated inputs is also reported. No
    parameters are updated

    Args:

    x_real: tf.Tensor
        a batch of validation inputs shaped like [batch_size, channels]
    y_real: tf.Tensor
        a batch of validation labels shaped like [batch_size, 1]

    Returns:

    statistics: dict
        a dictionary that contains logging information
    """

    logs = dict()
    batch_dim = tf.shape(y_real)[0]

    # perturb the real inputs with the noise matching the data type
    if self.is_discrete:
        x_real = disc_noise(x_real, keep=self.keep, temp=self.temp)
    else:
        x_real = cont_noise(x_real, self.noise_std)

    # generated samples, scored against a target of zeros
    x_fake = self.generator.sample(y_real, temp=self.temp, training=False)
    p_fake, d_fake, acc_fake = self.discriminator.loss(
        x_fake, y_real, tf.zeros([batch_dim, 1]), training=False)
    logs.update({
        'generator/validate/x_fake': x_fake,
        'generator/validate/y_real': y_real,
        'discriminator/validate/p_fake': p_fake,
        'discriminator/validate/d_fake': d_fake,
        'discriminator/validate/acc_fake': acc_fake,
    })

    # zero placeholders are logged when fake pairs are disabled
    x_pair = tf.zeros_like(x_fake)
    p_pair = tf.zeros_like(p_fake)
    d_pair = tf.zeros_like(d_fake)
    acc_pair = tf.zeros_like(acc_fake)
    if self.fake_pair_frac > 0:

        # real inputs shuffled along the batch axis, so they no longer
        # line up with their labels, scored against a target of zeros
        x_pair = tf.random.shuffle(x_real)
        p_pair, d_pair, acc_pair = self.discriminator.loss(
            x_pair, y_real, tf.zeros([batch_dim, 1]), training=False)

    logs.update({
        'generator/validate/x_pair': x_pair,
        'discriminator/validate/p_pair': p_pair,
        'discriminator/validate/d_pair': d_pair,
        'discriminator/validate/acc_pair': acc_pair,
    })

    # zero placeholders are logged when the replay pool is not used
    x_pool = tf.zeros_like(x_fake)
    p_pool = tf.zeros_like(p_fake)
    d_pool = tf.zeros_like(d_fake)
    acc_pool = tf.zeros_like(acc_fake)
    if self.pool.size > batch_dim and self.pool_frac > 0:

        # samples drawn from the replay pool, scored against zeros
        x_pool, y_pool = self.pool.sample(batch_dim)
        p_pool, d_pool, acc_pool = self.discriminator.loss(
            x_pool, y_pool, tf.zeros([batch_dim, 1]), training=False)

    logs.update({
        'generator/validate/x_pool': x_pool,
        'discriminator/validate/p_pool': p_pool,
        'discriminator/validate/d_pool': d_pool,
        'discriminator/validate/acc_pool': acc_pool,
    })

    # real inputs, scored against a target of ones
    p_real, d_real, acc_real = self.discriminator.loss(
        x_real, y_real, tf.ones([batch_dim, 1]), training=False)
    logs.update({
        'generator/validate/x_real': x_real,
        'discriminator/validate/p_real': p_real,
        'discriminator/validate/d_real': d_real,
        'discriminator/validate/acc_real': acc_real,
    })

    # penalty evaluated at random per-example blends of real and fake
    e = tf.random.uniform([batch_dim] + [1] * (len(x_fake.shape) - 1))
    x_interp = x_real * e + x_fake * (1 - e)
    penalty = self.discriminator.penalty(x_interp, y_real, training=False)
    logs.update({
        'discriminator/validate/neg_critic_loss': -(d_real + d_fake),
        'discriminator/validate/penalty': penalty,
    })

    return logs
def train_step(self, i, x_real, y_real, w):
    """Perform a training step for a generator and a discriminator
    using a least squares objective function

    The discriminator is updated on every call; the generator is only
    updated when i is a multiple of self.critic_frequency

    Args:

    i: tf.Tensor
        value taken modulo self.critic_frequency to decide whether the
        generator is updated (presumably the training iteration index,
        confirm against the caller)
    x_real: tf.Tensor
        a batch of training inputs shaped like [batch_size, channels]
    y_real: tf.Tensor
        a batch of training labels shaped like [batch_size, 1]
    w: tf.Tensor
        importance sampling weights shaped like [batch_size, 1]

    Returns:

    statistics: dict
        a dictionary that contains logging information
    """

    statistics = dict()
    batch_dim = tf.shape(y_real)[0]

    # corrupt the inputs with noise
    if self.is_discrete:
        x_real = disc_noise(x_real, keep=self.keep, temp=self.temp)
    else:
        x_real = cont_noise(x_real, self.noise_std)

    with tf.GradientTape() as tape:

        # evaluate the discriminator on generated samples; the generator
        # is sampled with training=False and the target is all zeros
        x_fake = self.generator.sample(y_real,
                                       temp=self.temp, training=False)
        p_fake, d_fake, acc_fake = self.discriminator.loss(
            x_fake, y_real, tf.zeros([batch_dim, 1]), training=False)

        # d_fake is logged here, before it is rescaled below
        statistics[f'generator/train/x_fake'] = x_fake
        statistics[f'generator/train/y_real'] = y_real
        statistics[f'discriminator/train/p_fake'] = p_fake
        statistics[f'discriminator/train/d_fake'] = d_fake
        statistics[f'discriminator/train/acc_fake'] = acc_fake

        # normalize the fake evaluation metrics: generated samples get
        # the weight left over after the fake pair and pool fractions
        d_fake = d_fake * (1.0 - self.fake_pair_frac - self.pool_frac)

        # zero placeholders are logged when fake pairs are disabled
        x_pair = tf.zeros_like(x_fake)
        p_pair = tf.zeros_like(p_fake)
        d_pair = tf.zeros_like(d_fake)
        acc_pair = tf.zeros_like(acc_fake)
        if self.fake_pair_frac > 0:

            # evaluate the discriminator on fake pairs of real inputs,
            # made by shuffling the real inputs along the batch axis
            x_pair = tf.random.shuffle(x_real)
            p_pair, d_pair, acc_pair = self.discriminator.loss(
                x_pair, y_real, tf.zeros([batch_dim, 1]), training=False)

            # average the metrics between fake samples
            d_fake = d_pair * self.fake_pair_frac + d_fake

        statistics[f'generator/train/x_pair'] = x_pair
        statistics[f'discriminator/train/p_pair'] = p_pair
        statistics[f'discriminator/train/d_pair'] = d_pair
        statistics[f'discriminator/train/acc_pair'] = acc_pair

        # zero placeholders are logged when the replay pool is not used
        x_pool = tf.zeros_like(x_fake)
        p_pool = tf.zeros_like(p_fake)
        d_pool = tf.zeros_like(d_fake)
        acc_pool = tf.zeros_like(acc_fake)
        if self.pool.size > batch_dim and self.pool_frac > 0:

            # evaluate discriminator on samples from a replay buffer
            x_pool, y_pool = self.pool.sample(batch_dim)
            p_pool, d_pool, acc_pool = self.discriminator.loss(
                x_pool, y_pool, tf.zeros([batch_dim, 1]), training=False)

            # average the metrics between fake samples
            d_fake = d_pool * self.pool_frac + d_fake

        statistics[f'generator/train/x_pool'] = x_pool
        statistics[f'discriminator/train/p_pool'] = p_pool
        statistics[f'discriminator/train/d_pool'] = d_pool
        statistics[f'discriminator/train/acc_pool'] = acc_pool

        if self.pool_save > 0:

            # possibly add more generated samples to the replay pool:
            # the first pool_save generated inputs and their labels
            self.pool.insert_many(x_fake[:self.pool_save],
                                  y_real[:self.pool_save])

        # evaluate the discriminator on real inputs; each target is 1.0
        # where flip_frac <= a uniform sample and 0.0 otherwise, and this
        # is the only discriminator pass run with training=True
        labels = tf.cast(
            self.flip_frac <= tf.random.uniform([batch_dim, 1]), tf.float32)
        p_real, d_real, acc_real = self.discriminator.loss(
            x_real, y_real, labels, training=True)

        statistics[f'generator/train/x_real'] = x_real
        statistics[f'discriminator/train/p_real'] = p_real
        statistics[f'discriminator/train/d_real'] = d_real
        statistics[f'discriminator/train/acc_real'] = acc_real

        # evaluate a gradient penalty on interpolations: e holds one
        # uniform blending factor per example, broadcast over other axes
        e = tf.random.uniform([batch_dim] + [1] * (len(x_fake.shape) - 1))
        x_interp = x_real * e + x_fake * (1 - e)
        penalty = self.discriminator.penalty(x_interp,
                                             y_real, training=False)

        statistics[f'discriminator/train/neg_critic_loss'] = -(d_real + d_fake)
        statistics[f'discriminator/train/penalty'] = penalty

        # build the total loss, weighted by the importance weights w
        total_loss = tf.reduce_mean(
            w * (d_real + d_fake + self.penalty_weight * penalty))

    # take a gradient step on the discriminator
    var_list = self.discriminator.trainable_variables
    self.discriminator_optim.apply_gradients(
        zip(tape.gradient(total_loss, var_list), var_list))

    # the generator is only updated when i is a multiple of
    # critic_frequency
    if tf.equal(tf.math.floormod(i, self.critic_frequency), 0):

        with tf.GradientTape() as tape:

            # evaluate the discriminator on generated samples; here the
            # generator is sampled with training=True and the target is
            # all ones
            x_fake = self.generator.sample(y_real,
                                           temp=self.temp, training=True)
            p_fake, d_fake, acc_fake = self.discriminator.loss(
                x_fake, y_real, tf.ones([batch_dim, 1]), training=False)

            # build the total loss
            total_loss = tf.reduce_mean(w * d_fake)

        # take a gradient step on the generator only
        var_list = self.generator.trainable_variables
        self.generator_optim.apply_gradients(
            zip(tape.gradient(total_loss, var_list), var_list))

    return statistics