def single_policy_prediction(self, state, next_action):
    """Predict the encoded expected observation for one action under every class hypothesis.

    Args:
        state: dict holding at least 's', the recurrent state, [B, rnn].
        next_action: the location action to evaluate, [B, loc].

    Returns:
        Encoder prior output with leading shape [B, num_classes_kn].
    """
    n_hyp = self.num_classes_kn
    # One one-hot hypothesis row per known class, tiled over the batch: [B * 1 * n_hyp, n_hyp].
    hyp_one_hot = self._hyp_tiling(n_classes=n_hyp, n_tiles=self.B * 1)
    # Repeat state and action per hypothesis so all class-conditional priors run in one batch.
    state_rep = repeat_axis(state['s'], axis=0, repeats=n_hyp)    # [B, rnn] -> [B * n_hyp, rnn]
    action_rep = repeat_axis(next_action, axis=0, repeats=n_hyp)  # [B, loc] -> [B * n_hyp, loc]
    return self.m['VAEEncoder'].calc_prior([hyp_one_hot, state_rep, action_rep],
                                           out_shp=[self.B, n_hyp])
def current_belief_update(current_state, new_observation, exp_obs_prior, time):
    """Given a new observation, and the last believes over the state, update the believes over the states.

    The sufficient statistic of the old state in this case is z, as the VAEencoder is class-specific.

    NOTE(review): this appears to be a module-level copy of the identically-named closure defined
    inside the model's __init__ — it references names that are not in this scope (self, FLAGS, env,
    VAEencoder, VAEdecoder, one_hot_label_repeated, repeat_axis, calculate_gaussian_nll) and cannot
    run stand-alone. Confirm which copy is live and delete the dead one.

    Args:
        current_state: dict with 's' (recurrent state), 'l' (last location) and 'c' (class belief).
        new_observation: encoded glimpse just received, [B, obs].
        exp_obs_prior: expected observation under the prior, fed to the Gaussian NLL below.
        time: Python int glimpse index (the surrounding loop is unrolled), used as averaging weight.

    Returns:
        c: [B, num_classes] believe over classes based on past observations
        zs_post: [B, num_classes, 2*size_z] posterior mu/sigma of z conditional on each class
        zs_post_samples: [B, num_classes, size_z] sampled zs conditional on each class
        nll_post: [B, num_classes] NLL of the observation conditional on each class
        reconstr_post: [B, num_classes, glimpse] per-class reconstruction (visualisation)
    """
    with tf.name_scope('Belief_update'):
        # Infer posterior z for all hypotheses
        with tf.name_scope('poterior_inference_per_hyp'):
            # Flatten the per-class recurrent state so every class hypothesis is its own batch row.
            class_conditional_s = tf.reshape(current_state['s'], [self.B * FLAGS.num_classes, FLAGS.size_rnn])
            new_action_repeated = repeat_axis(current_state['l'], 0, FLAGS.num_classes)
            new_observation_repeated = repeat_axis(new_observation, 0, FLAGS.num_classes)
            # Action is stop-gradient'd: belief updating must not backprop into the location policy.
            z_post = VAEencoder.posterior_inference(one_hot_label_repeated,
                                                    class_conditional_s,
                                                    tf.stop_gradient(new_action_repeated),
                                                    new_observation_repeated)
            # 2 possibilties to infer state from received observations:
            # i) judge by likelihood of the observations under each hypothesis
            # ii) train a separate model (e.g. LSTM) for infering states
            # TODO: CAN WE DO THIS IN AN ENCODED SPACE?
            posterior = VAEdecoder.decode(one_hot_label_repeated,
                                          class_conditional_s,
                                          z_post['sample'],
                                          tf.stop_gradient(new_action_repeated),
                                          new_observation_repeated)
            # ^= filtering, given that transitions are deterministic
            zs_post = tf.reshape(tf.concat([z_post['mu'], z_post['sigma']], axis=1),
                                 [self.B, FLAGS.num_classes, 2*FLAGS.size_z])
            zs_post_samples = tf.reshape(z_post['sample'], [self.B, FLAGS.num_classes, FLAGS.size_z])
            reconstr_post = tf.reshape(posterior['sample'], [self.B, FLAGS.num_classes, env.patch_shape_flat])
            nll_post = tf.reshape(posterior['loss'], [self.B, FLAGS.num_classes])

        # believes over the classes based on all past observations (uniformly weighted)
        with tf.name_scope('belief_update'):
            # TODO: THINK ABOUT THE SHAPE. PRIOR SHOULD BE FOR EACH HYP. USE new_observation_repeated?
            prior_nll = calculate_gaussian_nll(exp_obs_prior, new_observation)
            if time == 0:
                c = tf.nn.softmax(-prior_nll, axis=1)
            else:
                # Running uniform average: new evidence gets weight 1/time, old belief (time-1)/time.
                c = (1. / time) * tf.nn.softmax(-prior_nll, axis=1) + (time - 1.) / time * current_state['c']
    return (c,                # [B, num_classes]
            zs_post,          # [B, num_classes, 2*z]
            zs_post_samples,  # [B, num_classes, z]
            nll_post,         # [B, num_classes]
            reconstr_post)    # [B, num_classes, glimpse]
def z_tiling():
    """Sanity-check the repeat_axis + reshape round-trip used for per-policy z tensors.

    Verifies that repeating along axis 0 duplicates each batch slab consecutively
    (np.repeat semantics) and that the flatten/unflatten reshapes recover the
    [B, n_policies, hyp, z] layout.
    """
    batch_sz, n_policies, hyp, size_z = 3, 2, 4, 3

    # Each batch element is a [hyp, size_z] slab filled with its own batch index.
    dummy = np.tile(np.arange(batch_sz)[:, np.newaxis, np.newaxis], [1, hyp, size_z])
    assert dummy.shape == (batch_sz, hyp, size_z)
    assert (dummy == np.arange(batch_sz)[:, np.newaxis, np.newaxis]).all()

    # [B, hyp, z] -> [[B] * n_policies, hyp, z]
    repeated = repeat_axis(tf.constant(dummy), axis=0, repeats=n_policies)
    # [[B] * n_policies, hyp, z] -> [[B * hyp] * n_policies, z]
    tiled = tf.reshape(repeated, [batch_sz * hyp * n_policies, size_z])
    # ... and back out to an explicit [B, n_policies, hyp, z]
    un_tiled = tf.reshape(tiled, [batch_sz, n_policies, hyp, size_z])

    with tf.Session() as sess:
        rep, til, un_til = sess.run([repeated, tiled, un_tiled])

    # Expected layout: per batch-obs, each of the hyp slabs appears n_policies times below each other.
    exp_rep = np.repeat(dummy, n_policies, axis=0)
    assert rep.shape == (n_policies * batch_sz, hyp, size_z)
    assert (rep == exp_rep).all()
    assert til.shape == (n_policies * batch_sz * hyp, size_z)
    assert (til == exp_rep.reshape(-1, size_z)).all()
    assert un_til.shape == (batch_sz, n_policies, hyp, size_z)
    assert (un_til == exp_rep.reshape(batch_sz, n_policies, hyp, size_z)).all()
def hyp_tiling2():
    """Sanity-check that repeat_axis duplicates each one-hot class row consecutively."""
    num_classes = 5
    n_repeats = 2
    eye = tf.one_hot(tf.range(num_classes), depth=num_classes)  # [num_classes, num_classes]
    hyp = repeat_axis(eye, axis=0, repeats=n_repeats)           # each row repeated n_repeats times
    with tf.Session() as sess:
        out = sess.run(hyp)
    expected = np.repeat(np.eye(num_classes), n_repeats, axis=0)
    assert (out == expected).all()
def class_cond_predictions(self, r, prior_h, prior_z, img, seen, class_believes):
    """Run the VAE prediction for every known-class hypothesis and attach the NLL of the true image.

    Args:
        r: per-batch representation fed to the predictor, repeated once per class.
        prior_h, prior_z: prior state passed through to self.predict.
        img: the true image the predictions are scored against.
        seen, class_believes: forwarded to self.nll_loss.

    Returns:
        (VAE_results, obs_prior) from self.predict, with VAE_results['nll'] filled in.
    """
    shape_out = [self.B, self.num_classes_kn]
    # One copy of r per known class so all hypotheses are evaluated in a single batch.
    r_per_hyp = repeat_axis(r, axis=0, repeats=self.num_classes_kn)
    results, obs_prior = self.predict(r=r_per_hyp,
                                      hyp=self.hyp,
                                      prior_h=prior_h,
                                      prior_z=prior_z,
                                      out_shp=shape_out)  # [B, hyp, height, width, C]
    # Per-hypothesis reconstruction NLL of the true glimpse.
    results['nll'] = self.nll_loss(xhat_probs=obs_prior['mu_probs'],
                                   true_img=img,
                                   out_shp=shape_out,
                                   seen=seen,
                                   class_believes=class_believes)
    return results, obs_prior
def __init__(self, FLAGS, env, phase):
    """Build the full TF1 graph for one phase: modules, unrolled glimpse loop, losses, optimiser, summaries.

    Args:
        FLAGS: experiment flags (num_glimpses, num_classes, size_rnn, size_z, planner, ...).
        env: environment wrapper providing step(), patch_shape_flat, y_MC and composed_glimpse().
        phase: phase description dict; phase['random_locations'] switches to random-policy pre-training.
    """
    super().__init__(FLAGS, env, phase)
    min_glimpses = 3
    random_locations = phase['random_locations']  # tf.logical_and(self.epoch_num < FLAGS.pre_train_epochs, self.is_training)

    # Initialise modules.
    # ActInf evaluates one location policy per class hypothesis; other planners use a single policy.
    n_policies = FLAGS.num_classes if FLAGS.planner == 'ActInf' else 1
    policyNet = PolicyNetwork(FLAGS, self.B, n_policies)
    glimpseEncoder = GlimpseEncoder(FLAGS)
    VAEencoder = Encoder(FLAGS, env.patch_shape_flat)
    VAEdecoder = Decoder(FLAGS, env.patch_shape_flat)
    stateTransition_AC = StateTransition_AC(FLAGS.size_rnn, 2*FLAGS.size_z)
    fc_baseline = tf.layers.Dense(1, name='fc_baseline')  # scalar value baseline for REINFORCE
    submodules = {'policyNet': policyNet, 'VAEencoder': VAEencoder, 'VAEdecoder': VAEdecoder}
    if FLAGS.planner == 'ActInf':
        planner = ActInfPlanner(FLAGS, submodules, self.B, env.patch_shape_flat, self.C, stateTransition_AC)
    elif FLAGS.planner == 'RL':
        planner = REINFORCEPlanner(FLAGS, submodules, self.B, env.patch_shape_flat)
    else:
        raise ValueError('Undefined planner.')
    self.n_policies = planner.n_policies

    # Variables to remember across the unrolled loop, one TensorArray per record.
    out_ta = []
    out_ta.append(tf.TensorArray(tf.float32, size=min_glimpses, dynamic_size=True, name='obs'))
    out_ta.append(tf.TensorArray(tf.float32, size=min_glimpses, dynamic_size=True, name='glimpse_nlls_posterior'))
    out_ta.append(tf.TensorArray(tf.float32, size=min_glimpses, dynamic_size=True, name='glimpse_reconstr'))
    out_ta.append(tf.TensorArray(tf.float32, size=min_glimpses, dynamic_size=True, name='zs_post'))
    out_ta.append(tf.TensorArray(tf.float32, size=min_glimpses, dynamic_size=True, name='G'))
    out_ta.append(tf.TensorArray(tf.float32, size=min_glimpses, dynamic_size=True, name='actions'))
    out_ta.append(tf.TensorArray(tf.float32, size=min_glimpses, dynamic_size=True, name='actions_mean'))
    out_ta.append(tf.TensorArray(tf.int32, size=min_glimpses, dynamic_size=True, name='decisions'))
    out_ta.append(tf.TensorArray(tf.float32, size=min_glimpses, dynamic_size=True, name='rewards'))
    out_ta.append(tf.TensorArray(tf.float32, size=min_glimpses, dynamic_size=True, name='baselines'))
    # current_cs holds one extra slot: beliefs are recorded for t=0..T inclusive.
    out_ta.append(tf.TensorArray(tf.float32, size=min_glimpses+1, dynamic_size=True, name='current_cs'))
    out_ta.append(tf.TensorArray(tf.bool, size=min_glimpses, dynamic_size=True, name='done'))
    out_ta.append(tf.TensorArray(tf.float32, size=min_glimpses, dynamic_size=True, name='exp_exp_obs'))
    out_ta.append(tf.TensorArray(tf.float32, size=min_glimpses, dynamic_size=True, name='exp_obs'))
    out_ta.append(tf.TensorArray(tf.float32, size=min_glimpses, dynamic_size=True, name='H_exp_exp_obs'))
    out_ta.append(tf.TensorArray(tf.float32, size=min_glimpses, dynamic_size=True, name='exp_H'))
    out_ta.append(tf.TensorArray(tf.float32, size=min_glimpses, dynamic_size=True, name='potential_actions'))
    # Key each TensorArray by the name it was created with (recovered from its handle).
    ta_d = {}
    for i, ta in enumerate(out_ta):
        ta_d[ta.handle.name.split('/')[-1].replace(':0', '')] = ta

    # Initial values
    last_done = tf.zeros([self.B], dtype=tf.bool)
    last_decision = tf.fill([self.B], -1)  # -1 encodes "no classification decision yet"
    # In case starting calculation after initial observation (as first location should be
    # identical for all images).
    next_action, next_action_mean = policyNet.inital_loc()
    next_decision = tf.fill([self.B], -1)
    current_state = stateTransition_AC.initial_state(self.B, next_action)
    ta_d['current_cs'] = write_zero_out(0, ta_d['current_cs'], current_state['c'], last_done)

    # Out of loop to not create new tensors every step.
    one_hot_label = tf.one_hot(tf.range(FLAGS.num_classes), depth=FLAGS.num_classes)
    one_hot_label_repeated = repeat_axis(one_hot_label, 0, self.B)  # [B * hyp, hyp]

    def current_belief_update(current_state, new_observation, exp_obs_prior, time):
        """Given a new observation, and the last believes over the state, update the believes over the states.

        The sufficient statistic of the old state in this case is z, as the VAEencoder is class-specific.

        Returns:
            c: [B, num_classes] believe over classes based on past observations
            zs_post: [B, num_classes, 2*size_z] posterior mu/sigma of z conditional on each class
            zs_post_samples: [B, num_classes, size_z] sampled zs conditional on each class
            nll_post: [B, num_classes] NLL of the observation conditional on each class
            reconstr_post: [B, num_classes, glimpse] per-class reconstruction (visualisation)
        """
        with tf.name_scope('Belief_update'):
            # Infer posterior z for all hypotheses
            with tf.name_scope('poterior_inference_per_hyp'):
                # Flatten the per-class recurrent state: one batch row per class hypothesis.
                class_conditional_s = tf.reshape(current_state['s'], [self.B * FLAGS.num_classes, FLAGS.size_rnn])
                new_action_repeated = repeat_axis(current_state['l'], 0, FLAGS.num_classes)
                new_observation_repeated = repeat_axis(new_observation, 0, FLAGS.num_classes)
                # Action is stop-gradient'd so belief updating does not train the location policy.
                z_post = VAEencoder.posterior_inference(one_hot_label_repeated,
                                                        class_conditional_s,
                                                        tf.stop_gradient(new_action_repeated),
                                                        new_observation_repeated)
                # 2 possibilties to infer state from received observations:
                # i) judge by likelihood of the observations under each hypothesis
                # ii) train a separate model (e.g. LSTM) for infering states
                # TODO: CAN WE DO THIS IN AN ENCODED SPACE?
                posterior = VAEdecoder.decode(one_hot_label_repeated,
                                              class_conditional_s,
                                              z_post['sample'],
                                              tf.stop_gradient(new_action_repeated),
                                              new_observation_repeated)
                # ^= filtering, given that transitions are deterministic
                zs_post = tf.reshape(tf.concat([z_post['mu'], z_post['sigma']], axis=1),
                                     [self.B, FLAGS.num_classes, 2*FLAGS.size_z])
                zs_post_samples = tf.reshape(z_post['sample'], [self.B, FLAGS.num_classes, FLAGS.size_z])
                reconstr_post = tf.reshape(posterior['sample'], [self.B, FLAGS.num_classes, env.patch_shape_flat])
                nll_post = tf.reshape(posterior['loss'], [self.B, FLAGS.num_classes])

            # believes over the classes based on all past observations (uniformly weighted)
            with tf.name_scope('belief_update'):
                # TODO: THINK ABOUT THE SHAPE. PRIOR SHOULD BE FOR EACH HYP. USE new_observation_repeated?
                prior_nll = calculate_gaussian_nll(exp_obs_prior, new_observation)
                if time == 0:
                    c = tf.nn.softmax(-prior_nll, axis=1)
                else:
                    # Running uniform average: new evidence weighted 1/time, old belief (time-1)/time.
                    c = (1. / time) * tf.nn.softmax(-prior_nll, axis=1) + (time - 1.) / time * current_state['c']
        return (c,                # [B, num_classes]
                zs_post,          # [B, num_classes, 2*z]
                zs_post_samples,  # [B, num_classes, z]
                nll_post,         # [B, num_classes]
                reconstr_post)    # [B, num_classes, glimpse]

    with tf.name_scope('Main_loop'):
        # Loop is unrolled in Python (time is a plain int), so per-step branches are static.
        for time in range(FLAGS.num_glimpses):
            # NOTE(review): as transcribed, the planning branch below is dead code — `time > 1`
            # can never hold inside `if time == 0:`. This looks like an indentation artifact;
            # presumably the two conditions were siblings or an if/else. Also `next_exp_obs`
            # would be undefined at t=0 when passed to current_belief_update, and `labels` is
            # not defined in this scope. Confirm against version control.
            if time == 0:
                if time > 1:
                    if random_locations:
                        next_decision, next_action, next_action_mean, pl_records = planner.random_policy()
                    else:
                        next_decision, next_action, next_action_mean, next_exp_obs, pl_records = planner.planning_step(current_state, z_samples, time, self.is_training)
                    # TODO : Could REUSE FROM PLANNING STEP
                    current_state = stateTransition_AC([last_z, labels, next_action], last_state)
            observation, corr_classification_fb, done = env.step(next_action, next_decision)
            done = tf.logical_or(last_done, done)  # once an episode is done it stays done
            obs_enc = glimpseEncoder.encode(observation)
            current_state['c'], zs_post, z_samples, nll_posterior, reconstr_posterior = current_belief_update(current_state, obs_enc, next_exp_obs, time)
            # baseline = fc_baseline(tf.stop_gradient(tf.concat([current_c, tf.fill([self.B, 1], tf.cast(time, tf.float32))], axis=1)))
            # Value baseline from the (stop-gradient'd) class belief; [B, 1] squeezed to [B].
            baseline = tf.squeeze(fc_baseline(tf.stop_gradient(current_state['c'])), 1)

            # t=0 to T-1. ACTION RECORDING HAS TO STAY BEFORE PLANNING OR WILL BE OVERWRITTEN
            ta_d['obs'] = write_zero_out(time, ta_d['obs'], observation, done)
            ta_d['zs_post'] = write_zero_out(time, ta_d['zs_post'], zs_post, done)  # [B, n_policies, size_z]
            ta_d['glimpse_nlls_posterior'] = write_zero_out(time, ta_d['glimpse_nlls_posterior'], nll_posterior, done)  # [B, n_policies]
            ta_d['glimpse_reconstr'] = write_zero_out(time, ta_d['glimpse_reconstr'], reconstr_posterior, done)  # for visualisation only
            ta_d['actions'] = write_zero_out(time, ta_d['actions'], next_action, done)  # location actions, not including the decision acions
            ta_d['actions_mean'] = write_zero_out(time, ta_d['actions_mean'], next_action_mean, done)  # location actions, not including the decision acions
            ta_d['baselines'] = write_zero_out(time, ta_d['baselines'], baseline, done)
            ta_d['done'] = ta_d['done'].write(time, done)
            # t=0 to T
            ta_d['rewards'] = write_zero_out(time, ta_d['rewards'], corr_classification_fb, last_done)

            if random_locations:
                next_decision, next_action, next_action_mean, pl_records = planner.random_policy()
            else:
                # NOTE(review): this call passes 5 args / unpacks 4 values while the call above
                # passes 4 args / unpacks 5 — the planner interface is inconsistent; confirm.
                next_decision, next_action, next_action_mean, pl_records = planner.planning_step(current_state, zs_post, z_samples, time, self.is_training)
            # t=1 to T
            for k, v in pl_records.items():
                ta_d[k] = write_zero_out(time, ta_d[k], v, last_done)
            ta_d['current_cs'] = write_zero_out(time+1, ta_d['current_cs'], current_state['c'], last_done)  # ONLY ONE t=0 TO T
            ta_d['decisions'] = write_zero_out(time, ta_d['decisions'], next_decision, last_done)

            # copy forward: keep the first decision an episode made
            classification_decision = tf.where(last_done, last_decision, next_decision)
            # pass on to next time step
            last_done = done
            last_decision = next_decision
            last_z = zs_post  # TODO: or should this be the sampled ones?
            # last_c = current_c  # TODO: could also use the one from planning (new_c) or pi
            # last_s = current_s
            last_state = current_state
            # TODO: break loop if tf.reduce_all(last_done) (requires tf.while loop)
            time += 1  # NOTE(review): no effect — `time` is rebound by the for loop each iteration

    with tf.name_scope('Stacking'):
        self.obs = ta_d['obs'].stack()  # [T,B,glimpse]
        self.actions = ta_d['actions'].stack()  # [T,B,2]
        actions_mean = ta_d['actions_mean'].stack()  # [T,B,2]
        self.decisions = ta_d['decisions'].stack()
        rewards = ta_d['rewards'].stack()
        done = ta_d['done'].stack()
        self.glimpse_nlls_posterior = ta_d['glimpse_nlls_posterior'].stack()  # [T,B,hyp]
        zs_post = ta_d['zs_post'].stack()  # [T,B,hyp,2*z]
        self.state_believes = ta_d['current_cs'].stack()  # [T+1,B,hyp]
        self.G = ta_d['G'].stack()  # not zero'd-out so far!
        bl_loc = ta_d['baselines'].stack()
        self.glimpse_reconstr = ta_d['glimpse_reconstr'].stack()  # [T,B,hyp,glimpse]
        # further records for debugging
        self.exp_exp_obs = ta_d['exp_exp_obs'].stack()
        self.exp_obs = ta_d['exp_obs'].stack()
        self.H_exp_exp_obs = ta_d['H_exp_exp_obs'].stack()
        self.exp_H = ta_d['exp_H'].stack()
        self.potential_actions = ta_d['potential_actions'].stack()  # [T,B,n_policies,loc]
        self.num_glimpses_dyn = tf.shape(self.obs)[0]
        # Per-example number of glimpses actually taken (before 'done' was set).
        T = FLAGS.num_glimpses - tf.count_nonzero(done, 0, dtype=tf.float32)
        self.avg_T = tf.reduce_mean(T)

    with tf.name_scope('Losses'):
        with tf.name_scope('RL'):
            # Undiscounted reward-to-go per timestep.
            returns = tf.cumsum(rewards, reverse=True, axis=0)
            policy_losses = policyNet.REINFORCE_losses(returns, bl_loc, self.actions, actions_mean)  # [T,B]
            policy_loss = tf.reduce_sum(tf.reduce_mean(policy_losses, 1))
            # Baseline at t predicts the return from t+1 onwards.
            baseline_mse = tf.reduce_mean(tf.square(tf.stop_gradient(returns[1:]) - bl_loc[:-1]))

        with tf.name_scope('Classification'):
            # might never make a classification decision
            # TODO: SHOULD I FORCE THE ACTION AT t=t TO BE A CLASSIFICATION?
            self.classification = classification_decision

        with tf.name_scope('VAE'):
            # mask losses of wrong hyptheses
            nll_posterior = tf.reduce_sum(self.glimpse_nlls_posterior, 0)  # sum over time
            correct_hypoths = tf.cast(tf.one_hot(env.y_MC, depth=FLAGS.num_classes), tf.bool)
            nll_posterior = tf.where(correct_hypoths, nll_posterior, tf.zeros_like(nll_posterior))  # zero-out all but true hypothesis
            nll_posterior = tf.reduce_mean(nll_posterior)  # mean over batch

            # assume N(0,1) prior model (event though atm prior never used)
            prior_mu = tf.fill([self.B, FLAGS.size_z], 0.)
            prior_sigma = tf.fill([self.B, FLAGS.size_z], 1.)
            # KL only for the true-class hypothesis' posterior.
            zs_post_correct = tf.boolean_mask(zs_post, correct_hypoths, axis=1)
            post_mu, post_sigma = tf.split(zs_post_correct, 2, axis=2)
            # KL_div = T * VAEencoder.kl_div_normal(post_mu, post_sigma, prior_mu, prior_sigma)
            # NOTE: "T *" is wrong as T is [self.B]. Incorporat before reducing to a scalar
            N_post = tfd.Normal(loc=post_mu, scale=post_sigma)
            N_prior = tfd.Normal(loc=prior_mu, scale=prior_sigma)
            KL_div = N_post.kl_divergence(N_prior)
            # Zero-out steps that happened after the episode was done.
            KL_div = tf.where(tf.tile(done[:, :, tf.newaxis], [1, 1, FLAGS.size_z]), tf.zeros_like(KL_div), KL_div)  # replace those that are done
            KL_div = tf.reduce_mean(tf.reduce_sum(KL_div, 0))

        # TODO: SCALE LOSSES DIFFERENTLY? (only necessary if they flow into the same weights,
        # might not be the case so far)
        self.loss = policy_loss + baseline_mse + nll_posterior + KL_div

    with tf.variable_scope('Optimizer'):
        if random_locations:
            # Pre-training phase: only the VAE is trained, on its own loss.
            pretrain_vars = VAEencoder.trainable + VAEdecoder.trainable
            self.train_op, gradient_check_Pre, _ = self._create_train_op(FLAGS, nll_posterior + KL_div, self.global_step, varlist=pretrain_vars)
        else:
            self.train_op, gradient_check_F, _ = self._create_train_op(FLAGS, self.loss, self.global_step)

    with tf.name_scope('Summaries'):
        metrics_upd_coll = "streaming_updates"
        scalars = {'loss/loss': self.loss,
                   'loss/accuracy': tf.reduce_mean(tf.cast(tf.equal(classification_decision, self.y_MC), tf.float32)),
                   'loss/VAE_nll_posterior': nll_posterior,
                   'loss/VAE_KL_div': KL_div,
                   'loss/RL_loc_baseline_mse': tf.reduce_mean(baseline_mse),
                   'loss/RL_policy_loss': policy_loss,
                   'loss/RL_returns': tf.reduce_mean(returns),
                   'misc/T': self.avg_T,
                   'misc/share_no_decision': tf.count_nonzero(tf.equal(classification_decision, -1), dtype=tf.float32) / tf.cast(self.B, tf.float32)}
        # Emit each scalar both as a per-step summary and as a streaming (epoch-mean) metric.
        for name, scalar in scalars.items():
            tf.summary.scalar(name, scalar)
            tf.metrics.mean(scalar, name=name, updates_collections=metrics_upd_coll)
        self.metrics_update = tf.get_collection(metrics_upd_coll)
        self.metrics_names = [v.name.replace('_1/update_op:0', '').replace('Summaries/', '') for v in self.metrics_update]
        self.summary = tf.summary.merge_all()

    self.glimpses_composed = env.composed_glimpse(FLAGS, self.obs, self.num_glimpses_dyn)
    # only to get easy direct intermendiate outputs
    self.acc = tf.reduce_mean(tf.cast(tf.equal(classification_decision, self.y_MC), tf.float32))
    self.saver = self._create_saver(phase)
def planning_step(self, current_state, time, is_training, rnd_loc_eval):
    """Perform one planning step.

    Proposes candidate glimpse locations (one policy per class hypothesis, or a single
    policy), predicts the expected observation for every (policy, hypothesis) pair via
    the VAE prior, scores the candidates with G (plus a decision action), and selects
    the next action.

    Args:
        current_state: dict with at least 's' ([B, rnn] recurrent state) and 'c'
            ([B, num_classes] class belief).
        time: glimpse index; used to index self.C (Python int — the caller's loop is unrolled).
        is_training: forwarded to location planning / action selection.
        rnd_loc_eval: forwarded to _location_planning — presumably a random-location
            evaluation flag; TODO confirm.

    Returns:
        (decision, selected_action, selected_action_mean, selected_exp_obs_enc, records)
        where records is a dict of per-step debug tensors written to TensorArrays by the caller.
    """
    with tf.name_scope('Planning_loop/'):  # loop over policies, parallised into [B * self.n_policies, ...]
        # TODO: define inputs for policyNet (and use the same in reinforce-planner if using it for pre-training)`
        # inputs = [current_state['s'], tf.fill([self.B, 1], tf.cast(time, tf.float32))]
        if self.n_policies == 1:  # 'G1'
            # Single policy: condition directly on state and belief.
            inputs = [current_state['s'], current_state['c']]
        elif self.rl_reward == 'clf':
            assert self.n_policies == self.num_classes_kn
            inputs = [repeat_axis(current_state['s'], axis=0, repeats=self.num_classes_kn),
                      self.policy_dep_input]
        elif self.rl_reward == 'G':
            assert self.n_policies == self.num_classes_kn
            boosted_hyp = repeat_axis(current_state['c'], axis=0,
                                      repeats=self.n_policies)  # [B * n_policies, num_classes]
            # increase each 'hypothesis'-policy by 50% and renormalise
            boosted_hyp += self.policy_dep_input * boosted_hyp * 0.5

            def re_normalise(x, axis=-1):
                # Re-normalise to a probability distribution along `axis`.
                # (keep_dims is the deprecated TF1 spelling of keepdims.)
                return x / tf.reduce_sum(x, axis=axis, keep_dims=True)

            boosted_hyp = re_normalise(boosted_hyp, axis=-1)
            inputs = [repeat_axis(current_state['s'], axis=0, repeats=self.n_policies),
                      boosted_hyp]
        else:
            raise ValueError('Unknown policy strategies',
                             'n_policies: {}, rl_reward: {}'.format(self.n_policies, self.rl_reward))

        # select locations to evaluate
        next_actions, next_actions_mean = self._location_planning(inputs, is_training, rnd_loc_eval)

        with tf.name_scope('Hypotheses_loop/'):  # for every action: loop over hypotheses
            s_tiled = repeat_axis(current_state['s'], axis=0,
                                  repeats=self.n_policies * self.num_classes_kn)  # [B, rnn] -> [B * n_policies * hyp, rnn]
            # TODO: THIS MIGHT HAVE TO CHANGE IF POLICIES DEPEND ON CLASS-CONDITIONAL Z
            next_actions_tiled = repeat_axis(
                tf.reshape(next_actions, [self.B * self.n_policies, self.loc_dim]),
                axis=0,
                repeats=self.num_classes_kn)  # [B, n_policies, loc] -> [B * n_policies, loc] -> [B * n_policies * hyp, loc]
            # Expected observation (encoded) under the prior for every (policy, hypothesis) pair.
            exp_obs_prior_enc = self.m['VAEEncoder'].calc_prior(
                [self.hyp, s_tiled, next_actions_tiled],
                out_shp=[self.B, self.n_policies, self.num_classes_kn])
            if not self.use_pixel_obs_FE:
                # Score in the encoded space directly.
                exp_obs_prior_logits = exp_obs_prior_enc['mu']
                exp_obs_prior_sigma = exp_obs_prior_enc['sigma']
                sample = exp_obs_prior_enc['sample']
            else:
                # Decode to pixel space first, then score there.
                exp_obs_prior = self.m['VAEDecoder'].decode(
                    [tf.reshape(exp_obs_prior_enc['sample'],
                                [-1] + self.m['VAEEncoder'].output_shape_flat),
                     next_actions_tiled],
                    out_shp=[self.B, self.n_policies, self.num_classes_kn])
                exp_obs_prior_logits = exp_obs_prior['mu_logits']
                exp_obs_prior_sigma = exp_obs_prior['sigma']
                sample = exp_obs_prior['sample']

        G_obs, exp_exp_obs, exp_H, H_exp_exp_obs = self.calc_G_obs_prePreferences(
            exp_obs_prior_logits, exp_obs_prior_sigma, c_believes=current_state['c'])
        # For all non-decision actions the probability of classifying is 0, hence the probability of an observation is 1
        preference_error_obs = 1. * self.C[time, 0]
        G_obs += preference_error_obs
        # decision actions
        G_dec = self._G_decision(time, current_state['c'])
        G = tf.concat([G_obs, G_dec[:, tf.newaxis]], axis=1)

        # action selection
        decision, selected_action, selected_action_mean, selected_exp_obs_enc, selected_action_idx = self._action_selection(
            next_actions, next_actions_mean, current_state, G, exp_obs_prior_enc, time, is_training)

        if self.rl_reward == 'G':
            # Re-score G under each boosted hypothesis to use as a per-policy reward signal.
            boosted_hyp = tf.reshape(boosted_hyp, [self.B, self.n_policies, self.num_classes_kn])
            rewards_Gobs, _, rewards_exp_H, rewards_H_exp_exp_obs = self.calc_G_obs_prePreferences(
                exp_obs_prior_logits, exp_obs_prior_sigma, c_believes_tiled=boosted_hyp)
            r = rewards_Gobs
        else:
            r = tf.zeros([self.B, self.n_policies])

        records = {'G': G,
                   'exp_obs': sample,  # [B, n_policies, num_classes, z]
                   'exp_exp_obs': exp_exp_obs,  # [B, n_policies, z]
                   'H_exp_exp_obs': H_exp_exp_obs,
                   'exp_H': exp_H,
                   'potential_actions': next_actions[:, tf.newaxis, :] if (self.n_policies == 1) else next_actions,
                   'selected_action_idx': selected_action_idx,
                   'rewards_Gobs': r}
    return decision, selected_action, selected_action_mean, selected_exp_obs_enc, records