def _encoder_preprocessor(self, velocity_sequence, n_nodes, n_conn,
                          node_locations, node_connections):
    """Packs mesh nodes and their connectivity into an encoder GraphsTuple.

    Node features are the velocity history with axes 1 and 2 merged into a
    single feature axis. Edge features are the sender-minus-receiver
    displacement of node locations followed by its Euclidean norm.

    Args:
      velocity_sequence: per-node velocity history; axes 1 and 2 are merged
        (presumably [num_nodes, time, dims] -- TODO confirm).
      n_nodes: per-graph node counts, forwarded to the connectivity helper and
        stored as `n_node`.
      n_conn: per-graph connection counts, forwarded to the connectivity helper.
      node_locations: node positions used to compute edge displacements.
      node_connections: connectivity description consumed by
        `connectivity_utils.get_connectivity_for_batch_pyfunc`.

    Returns:
      A `gn.graphs.GraphsTuple` with no globals.
    """
    senders, receivers, n_edge = (
        connectivity_utils.get_connectivity_for_batch_pyfunc(
            node_locations, node_connections, n_nodes, n_conn))

    merge_history = snt.MergeDims(start=1, size=2)
    node_feature_list = [merge_history(velocity_sequence)]

    sender_locations = tf.gather(node_locations, senders)
    receiver_locations = tf.gather(node_locations, receivers)
    displacement = sender_locations - receiver_locations
    distance = tf.norm(displacement, axis=-1, keepdims=True)
    edge_feature_list = [displacement, distance]

    return gn.graphs.GraphsTuple(
        nodes=tf.concat(node_feature_list, axis=-1),
        edges=tf.concat(edge_feature_list, axis=-1),
        globals=None,  # This graph carries no global features.
        n_node=n_nodes,
        n_edge=n_edge,
        senders=senders,
        receivers=receivers,
    )
def sample(self, sample_shape=(), y=None, mean=False):
    """Draws a sample from the learnt distribution p(x).

    Args:
      sample_shape: `int` or 0D `Tensor` giving the number of samples to
        return. If empty tuple (default value), 1 sample will be returned.
        Only used when `y` is None, in which case it is passed to the prior's
        `sample` method; any validation of it happens there, not here.
      y: Optional, the one hot label on which to condition the sample. When
        None, labels are drawn from `self.compute_prior()`.
      mean: Boolean, if True the expected value of the output distribution is
        returned, otherwise samples from the output distribution.

    Returns:
      Sample tensor of shape `[B * N, ...]` where `B` is the batch size of the
      prior, `N` is the number of samples requested, and `...` represents the
      shape of the observations.
    """
    with tf.name_scope('{}_sample'.format(self.scope_name)):
        if y is None:
            y = tf.to_float(self.compute_prior().sample(sample_shape))

        # Collapse all leading (sample/batch) axes so y is rank 2.
        if y.shape.ndims > 2:
            y = snt.MergeDims(start=0, size=y.shape.ndims - 1,
                              name='merge_y')(y)

        z = self._latent_decoder(y, is_training=self._is_training)
        output_dist = self.predict(z.sample(), y)
        samples = output_dist.mean() if mean else output_dist.sample()

    return samples
def _build(self, h, x, presence=None):
    """Decodes capsule parameters from `h` and scores them against points `x`.

    Args:
      h: Tensor of encodings of shape [B, n_enc_dims].
      x: Tensor of inputs of shape [B, n_points, n_input_dims].
      presence: Tensor or None; if given, it indicates which input points
        exist. NOTE(review): previously documented as [B, n_points, 1], but
        the code multiplies `tf.expand_dims(presence, -1)` into a one-hot over
        capsules, which broadcasts cleanly only for [B, n_points] -- confirm.

    Returns:
      The CapsuleLayer output dict, in which:
        * 'transform' holds the raw votes;
        * 'vote' holds `math_ops.apply_transform` of them;
        * every entry with rank > 0 has axes 1 and 2 merged.
      It is updated with the likelihood's fields plus 'mixing_kl',
      'sparsity_loss', 'caps_presence_prob' and 'mean_scale'.
    """
    batch_size, n_input_points, _ = x.shape.as_list()

    capsule_layer = _capsule.CapsuleLayer(
        self._n_caps, self._n_caps_dims, self._n_votes,
        **self._capsule_kwargs)
    outputs = capsule_layer(h)

    # Keep the raw votes under 'transform'; 'vote' becomes the applied transform.
    raw_votes = outputs.vote
    outputs.transform = raw_votes
    outputs.vote = math_ops.apply_transform(transform=raw_votes)

    # Merge axes 1 and 2 of every non-scalar entry.
    for key, value in outputs.items():
        if value.shape.ndims > 0:
            outputs[key] = snt.MergeDims(1, 2)(value)

    likelihood = _capsule.OrderInvariantCapsuleLikelihood(
        self._n_votes, outputs.vote, outputs.scale, outputs.vote_presence)
    ll = likelihood(x, presence)
    outputs.update(ll._asdict())

    # Mixing term measured against a uniform log-prob of log(1 / n_points).
    mixing_probs = tf.nn.softmax(ll.mixing_logits, 1)
    uniform_log_prob = tf.log(1. / n_input_points)
    mixing_kl = mixing_probs * (ll.mixing_log_prob - uniform_log_prob)
    mixing_kl = tf.reduce_mean(tf.reduce_sum(mixing_kl, -1))

    # Per-capsule count of (present) points assigned to it, summed over axis 1.
    wins = tf.one_hot(ll.is_from_capsule, depth=self._n_caps)
    if presence is not None:
        wins *= tf.expand_dims(presence, -1)
    wins = tf.reduce_sum(wins, 1)

    has_any_wins = tf.to_float(tf.greater(wins, 0))
    should_be_active = tf.to_float(tf.greater(wins, 1))
    per_caps_xe = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=should_be_active, logits=outputs.pres_logit_per_caps)
    sparsity_loss = tf.reduce_mean(
        tf.reduce_sum(per_caps_xe * has_any_wins, -1))

    vote_presence_per_caps = tf.reshape(
        outputs.vote_presence, [batch_size, self._n_caps, self._n_votes])
    caps_presence_prob = tf.reduce_max(vote_presence_per_caps, 2)

    outputs.update(dict(
        mixing_kl=mixing_kl,
        sparsity_loss=sparsity_loss,
        caps_presence_prob=caps_presence_prob,
        mean_scale=tf.reduce_mean(outputs.scale),
    ))
    return outputs
def _build_layers_v2(self, input_dict, num_outputs, options):
    """Builds an MLP policy conditioned on the last `delay` actions.

    The previous `delay - 1` actions are carried as recurrent state. Each call
    does the following:
      * appends `input_dict["prev_actions"]` to that state;
      * one-hot encodes the resulting window of `delay` actions and flattens it;
      * concatenates it to the observation before the MLP trunk.

    Args:
      input_dict: dict providing at least "obs" and "prev_actions".
        Assumes "obs" is a flat float32 [batch, obs_dim] tensor so it can be
        concatenated with the embedding -- TODO confirm.
      num_outputs: number of discrete actions; also the one-hot depth used to
        embed the delayed actions.
      options: model options; reads options["custom_options"]["delay"],
        options["fcnet_hiddens"] and options["fcnet_activation"].

    Returns:
      (logits, trunk_outputs): `num_outputs` action logits and the last hidden
      layer of the trunk.

    Raises:
      ValueError: if the configured delay is not positive.
    """
    delay = options["custom_options"]["delay"]
    if delay <= 0:
        raise ValueError("delay must be positive, got {}".format(delay))

    self.state_init = np.zeros([delay - 1], np.int64)
    if not self.state_in:
        self.state_in = tf.placeholder(
            tf.int64, [None, delay - 1], name="delayed_actions")

    # Window of the last `delay` actions: stored state + newest action.
    delayed_actions = tf.concat([
        self.state_in,
        tf.expand_dims(input_dict["prev_actions"], 1)
    ], axis=1)
    # Drop the oldest action so the carried state keeps `delay - 1` entries.
    self.state_out = delayed_actions[:, 1:]

    embedded_delayed_actions = tf.one_hot(delayed_actions, num_outputs)
    embedded_delayed_actions = snt.MergeDims(1, 2)(embedded_delayed_actions)

    trunk = snt.nets.MLP(
        output_sizes=options["fcnet_hiddens"],
        activation=getattr(tf.nn, options["fcnet_activation"]),
        activate_final=True)

    # The trunk must see the delayed-action embedding, not just the
    # observation; otherwise the delay state has no effect on the policy.
    inputs = tf.concat([input_dict["obs"], embedded_delayed_actions], 1)
    trunk_outputs = trunk(inputs)
    logits = snt.Linear(num_outputs)(trunk_outputs)
    return logits, trunk_outputs
def _build(self, x):
    """Encodes an input batch into primary-capsule parameters.

    Args:
      x: input batch; its leading static dimension is used as the batch size.

    Returns:
      self.OutputTuple(pose, feature, pres, pres_logit, img_embedding), where:
        * pose has been passed through `math_ops.geometric_transform`;
        * feature is None when `self._n_features == 0`;
        * pres = sigmoid(pres_logit), with uniform noise added to the logits
          when `self._noise_scale > 0`;
        * img_embedding is the raw output of `self._encoder`.

    Raises:
      ValueError: if `self._encoder_type` is not 'linear', 'conv' or
        'conv_att'.
    """
    batch_size = x.shape[0]
    img_embedding = self._encoder(x)

    # Per-capsule parameter layout: pose dims, feature dims, 1 presence logit.
    splits = [self._n_caps_dims, self._n_features, 1]
    n_dims = sum(splits)

    if self._encoder_type == 'linear':
        n_outputs = self._n_caps * n_dims

        h = snt.BatchFlatten()(img_embedding)
        h = snt.Linear(n_outputs)(h)

    else:
        h = snt.AddBias(bias_dims=[1, 2, 3])(img_embedding)

        if self._encoder_type == 'conv':
            h = snt.Conv2D(n_dims * self._n_caps, 1, 1)(h)
            h = tf.reduce_mean(h, (1, 2))
            h = tf.reshape(h, [batch_size, self._n_caps, n_dims])

        elif self._encoder_type == 'conv_att':
            h = snt.Conv2D(n_dims * self._n_caps + self._n_caps, 1, 1)(h)
            h = snt.MergeDims(1, 2)(h)
            h, a = tf.split(h, [n_dims * self._n_caps, self._n_caps], -1)

            h = tf.reshape(h, [batch_size, -1, n_dims, self._n_caps])
            a = tf.nn.softmax(a, 1)
            a = tf.reshape(a, [batch_size, -1, 1, self._n_caps])
            h = tf.reduce_sum(h * a, 1)
            # h is [B, n_dims, n_caps]; column k was pooled with capsule k's
            # attention. Transpose (a plain reshape would interleave values
            # from different attention heads into one capsule). Checkpoints
            # trained with the old reshape layout will decode differently.
            h = tf.transpose(h, [0, 2, 1])

        else:
            raise ValueError('Invalid encoder type="{}".'.format(
                self._encoder_type))

    h = tf.reshape(h, [batch_size, self._n_caps, n_dims])

    pose, feature, pres_logit = tf.split(h, splits, -1)
    if self._n_features == 0:
        feature = None

    pres_logit = tf.squeeze(pres_logit, -1)
    if self._noise_scale > 0.:
        pres_logit += ((tf.random.uniform(pres_logit.shape) - .5)
                       * self._noise_scale)

    pres = tf.nn.sigmoid(pres_logit)
    pose = math_ops.geometric_transform(pose, self._similarity_transform)
    return self.OutputTuple(pose, feature, pres, pres_logit, img_embedding)
def _build(self, data):
    """Runs primary capsules, the object-capsule autoencoder and the losses.

    Args:
      data: batch consumed by `self._img` and `self._label`; `data.get(
        'labeled', None)` is forwarded to the classification probes.

    Returns:
      The dict returned by `self._decoder`, extended with:
        * primary-decoder reconstructions (bottom-up, top-down, per-capsule);
        * reconstruction metrics (mode/mean, MSE, log-likelihood);
        * posterior/prior sparsity losses;
        * classification probes (when a label is available);
        * `primary_caps_l1` and `weight_decay_loss`.

    Raises:
      ValueError: if `self._vote_type` or `self._pres_type` is not one of
        'enc', 'soft', 'hard'.
    """
    input_x = self._img(data, False)
    target_x = self._img(data, prep=self._prep)
    batch_size = int(input_x.shape[0])

    primary_caps = self._primary_encoder(input_x)
    pres = primary_caps.presence

    expanded_pres = tf.expand_dims(pres, -1)
    pose = primary_caps.pose
    input_pose = tf.concat([pose, 1. - expanded_pres], -1)

    input_pres = pres
    if self._stop_grad_caps_inpt:
        input_pose = tf.stop_gradient(input_pose)
        input_pres = tf.stop_gradient(pres)

    target_pose, target_pres = pose, pres
    if self._stop_grad_caps_target:
        target_pose = tf.stop_gradient(target_pose)
        target_pres = tf.stop_gradient(target_pres)

    # skip connection from the img to the higher level capsule
    if primary_caps.feature is not None:
        input_pose = tf.concat([input_pose, primary_caps.feature], -1)

    # try to feed presence as a separate input
    # and if that works, concatenate templates to poses
    # this is necessary for set transformer
    n_templates = int(primary_caps.pose.shape[1])
    templates = self._primary_decoder.make_templates(n_templates,
                                                     primary_caps.feature)

    if self._feed_templates:
        inpt_templates = templates
        if self._stop_grad_caps_inpt:
            inpt_templates = tf.stop_gradient(inpt_templates)

        if inpt_templates.shape[0] == 1:
            inpt_templates = snt.TileByDim([0], [batch_size])(inpt_templates)
        inpt_templates = snt.BatchFlatten(2)(inpt_templates)
        pose_with_templates = tf.concat([input_pose, inpt_templates], -1)
    else:
        pose_with_templates = input_pose

    # Only the encoder call is guarded: a TypeError here means the encoder
    # does not accept a presence argument, so fall back to poses alone.
    try:
        h = self._encoder(pose_with_templates, input_pres)
    except TypeError:
        h = self._encoder(input_pose)

    res = self._decoder(h, target_pose, target_pres)
    res.primary_presence = primary_caps.presence

    if self._vote_type == 'enc':
        primary_dec_vote = primary_caps.pose
    elif self._vote_type == 'soft':
        primary_dec_vote = res.soft_winner
    elif self._vote_type == 'hard':
        primary_dec_vote = res.winner
    else:
        raise ValueError('Invalid vote_type="{}".'.format(self._vote_type))

    if self._pres_type == 'enc':
        primary_dec_pres = pres
    elif self._pres_type == 'soft':
        primary_dec_pres = res.soft_winner_pres
    elif self._pres_type == 'hard':
        primary_dec_pres = res.winner_pres
    else:
        raise ValueError('Invalid pres_type="{}".'.format(self._pres_type))

    res.bottom_up_rec = self._primary_decoder(
        primary_caps.pose,
        primary_caps.presence,
        template_feature=primary_caps.feature,
        img_embedding=primary_caps.img_embedding)

    res.top_down_rec = self._primary_decoder(
        res.winner,
        primary_caps.presence,
        template_feature=primary_caps.feature,
        img_embedding=primary_caps.img_embedding)

    rec = self._primary_decoder(
        primary_dec_vote,
        primary_dec_pres,
        template_feature=primary_caps.feature,
        img_embedding=primary_caps.img_embedding)

    # Tile primary-capsule inputs along axis 0 by res.vote.shape[1] to match
    # the votes merged over axes 0 and 1 below.
    tile = snt.TileByDim([0], [res.vote.shape[1]])
    tiled_presence = tile(primary_caps.presence)

    tiled_feature = primary_caps.feature
    if tiled_feature is not None:
        tiled_feature = tile(tiled_feature)

    tiled_img_embedding = tile(primary_caps.img_embedding)

    res.top_down_per_caps_rec = self._primary_decoder(
        snt.MergeDims(0, 2)(res.vote),
        snt.MergeDims(0, 2)(res.vote_presence) * tiled_presence,
        template_feature=tiled_feature,
        img_embedding=tiled_img_embedding)

    res.templates = templates
    res.template_pres = pres
    res.used_templates = rec.transformed_templates

    res.rec_mode = rec.pdf.mode()
    res.rec_mean = rec.pdf.mean()

    res.mse_per_pixel = tf.square(target_x - res.rec_mode)
    res.mse = math_ops.flat_reduce(res.mse_per_pixel)

    res.rec_ll_per_pixel = rec.pdf.log_prob(target_x)
    res.rec_ll = math_ops.flat_reduce(res.rec_ll_per_pixel)

    n_points = int(res.posterior_mixing_probs.shape[1])
    mass_explained_by_capsule = tf.reduce_sum(res.posterior_mixing_probs, 1)

    (res.posterior_within_sparsity_loss,
     res.posterior_between_sparsity_loss) = _capsule.sparsity_loss(
         self._posterior_sparsity_loss_type,
         mass_explained_by_capsule / n_points,
         num_classes=self._n_classes)

    (res.prior_within_sparsity_loss,
     res.prior_between_sparsity_loss) = _capsule.sparsity_loss(
         self._prior_sparsity_loss_type,
         res.caps_presence_prob,
         num_classes=self._n_classes,
         within_example_constant=self._prior_within_example_constant)

    label = self._label(data)
    if label is not None:
        res.posterior_cls_xe, res.posterior_cls_acc = probe.classification_probe(
            mass_explained_by_capsule,
            label,
            self._n_classes,
            labeled=data.get('labeled', None))
        res.prior_cls_xe, res.prior_cls_acc = probe.classification_probe(
            res.caps_presence_prob,
            label,
            self._n_classes,
            labeled=data.get('labeled', None))
        # Both accuracies only exist when a label is available.
        res.best_cls_acc = tf.maximum(res.prior_cls_acc, res.posterior_cls_acc)

    res.primary_caps_l1 = math_ops.flat_reduce(res.primary_presence)

    if self._weight_decay > 0.0:
        # L2 over trainable variables whose name contains 'w:' or 'weights:'.
        decay_losses_list = []
        for var in tf.trainable_variables():
            if 'w:' in var.name or 'weights:' in var.name:
                decay_losses_list.append(tf.nn.l2_loss(var))
        res.weight_decay_loss = tf.reduce_sum(decay_losses_list)
    else:
        res.weight_decay_loss = 0.0

    return res
def render_constellations(pred_points,
                          capsule_num,
                          canvas_size,
                          gt_points=None,
                          n_caps=2,
                          gt_presence=None,
                          pred_presence=None,
                          caps_presence_prob=None):
    """Renders predicted and ground-truth points as gaussian blobs.

    Args:
      pred_points: [B, m, 2].
      capsule_num: [B, m] tensor indicating which capsule the corresponding
        point comes from. Points from different capsules are plotted with
        different colors. Currently supported values: {0, 1, ..., 11}.
      canvas_size: tuple of ints.
      gt_points: [B, k, 2]; plots ground-truth points if present.
      n_caps: integer, number of capsules (one-hot depth for `capsule_num`).
      gt_presence: [B, k] binary tensor.
      pred_presence: [B, m] binary tensor.
      caps_presence_prob: optional [B, c] presence probabilities with
        c <= number of palette colors; drawn as an extra strip under the
        color legend.

    Returns:
      Tensor clipped to [0, 1]. When `gt_points` is given, the ground-truth
      render, the legend separator and the overlaid render are stacked along
      axis 1.
    """
    # Convert coordinates onto the canvas pixel grid.
    pred_points = denormalize_coords(pred_points, canvas_size, rounded=True)

    # Colour every predicted point by the capsule it came from.
    batch_size, n_points = pred_points.shape[:2].as_list()
    caps_one_hot = tf.to_float(tf.one_hot(capsule_num, depth=n_caps))
    caps_one_hot = tf.reshape(caps_one_hot,
                              [batch_size, n_points, 1, 1, n_caps, 1])

    palette = tf.convert_to_tensor(_COLORS[:n_caps])
    palette = tf.reshape(palette, [1, 1, 1, 1, n_caps, 3])
    # The one-hot mask keeps one palette row per point; summing selects it.
    point_colors = tf.reduce_sum(palette * caps_one_hot, -2)
    point_colors = tf.squeeze(tf.squeeze(point_colors, 3), 2)

    colored = render_by_scatter(canvas_size, pred_points, point_colors,
                                pred_presence)

    # Build a separator between predicted and gt renders. It shows every
    # supported color and so doubles as a legend. Layout: [b, h, w, 3].
    n_colors = _COLORS.shape[0]
    canvas_width = int(colored.shape[2])
    n_tiles = canvas_width // n_colors

    legend = tf.reshape(tf.convert_to_tensor(_COLORS), [1, 1, n_colors, 3])
    legend = snt.TileByDim([0, 1, 3], [batch_size, 3, n_tiles])(legend)
    legend = tf.reshape(legend, [batch_size, 3, n_tiles * n_colors, 3])

    half_pad, odd_pad = divmod(canvas_width - n_colors * n_tiles, 2)

    if caps_presence_prob is not None:
        n_prob_caps = int(caps_presence_prob.shape[1])
        probs = tf.pad(caps_presence_prob,
                       ([0, 0], [0, n_colors - n_prob_caps]))
        zeros = tf.zeros([batch_size, 3, n_colors, n_tiles, 3],
                         dtype=tf.float32)
        probs = tf.reshape(probs, [batch_size, 1, n_colors, 1, 1])
        # Broadcast each probability across its colour's run of pixels.
        prob_strip = snt.MergeDims(2, 2)(probs + zeros)
        legend = tf.concat(
            [legend, tf.ones_like(legend[:, :1]), prob_strip], 1)

    legend = tf.pad(legend,
                    [(0, 0), (1, 1), (half_pad, half_pad + odd_pad), (0, 0)],
                    constant_values=1.)

    if gt_points is not None:
        gt_points = denormalize_coords(gt_points, canvas_size, rounded=True)
        gt_rendered = render_by_scatter(canvas_size, gt_points, colors=None,
                                        gt_presence=gt_presence)

        # Keep non-zero predicted pixels; fill the rest from the gt render.
        colored = tf.where(tf.cast(colored, bool), colored, gt_rendered)
        colored = tf.concat([gt_rendered, legend, colored], 1)

    return tf.clip_by_value(colored, 0., 1.)
def _init_helper(self,
                 observation_space,
                 action_space,
                 config,
                 existing_inputs=None):
    """Builds the imitation-learning policy graph.

    The policy is a `HumanActionModel` trained with a negative
    log-likelihood imitation loss plus `vf_loss_coeff` times a TD(lambda)
    value loss. Inputs are time-major placeholders of shape
    [sample_batch_size, train_batch_size // sample_batch_size].

    Args:
      observation_space: gym observation space, forwarded to the model.
      action_space: gym Discrete or MultiDiscrete action space.
      config: overrides merged on top of `impala.impala.DEFAULT_CONFIG`.
      existing_inputs: unused here.

    Raises:
      ValueError: if `config["imitation"]` is falsy. The loss slices along a
        leading time axis and normalizes by T, which only the imitation
        placeholder layout provides.
      UnsupportedSpaceException: for action spaces other than Discrete or
        MultiDiscrete.
    """
    print(get_available_gpus())
    config = dict(impala.impala.DEFAULT_CONFIG, **config)
    assert config["batch_mode"] == "truncate_episodes", \
        "Must use `truncate_episodes` batch mode with V-trace."
    self.config = config
    self.sess = tf.get_default_session()
    self.grads = None

    imitation = config["imitation"]
    if not imitation:
        raise ValueError(
            "This policy graph only supports config['imitation'] = True.")

    T = config["sample_batch_size"]
    B = config["train_batch_size"] // T
    batch_shape = (T, B)

    if isinstance(action_space, gym.spaces.Discrete):
        output_hidden_shape = [action_space.n]
    elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
        output_hidden_shape = action_space.nvec.astype(np.int32)
    else:
        raise UnsupportedSpaceException(
            "Action space {} is not supported for IMPALA.".format(
                action_space))

    def make_action_ph():
        return ssbm_actions.make_ph(ssbm_actions.flat_repeated_config,
                                    batch_shape)

    actions = make_action_ph()
    prev_actions = make_action_ph()

    # Create input placeholders
    dones = tf.placeholder(tf.bool, batch_shape, name="dones")
    rewards = tf.placeholder(tf.float32, batch_shape, name="rewards")
    observations = ssbm_spaces.slippi_conv_list[0].make_ph(batch_shape)

    existing_state_in = None
    existing_seq_lens = None

    # Setup the policy
    autoregressive = config.get("autoregressive")
    if autoregressive:
        logit_dim = 128  # not really logits
    else:
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])

    prev_rewards = tf.placeholder(tf.float32, batch_shape, name="prev_reward")
    self.model = HumanActionModel(
        {
            "obs": observations,
            "prev_actions": prev_actions,
            "prev_rewards": prev_rewards,
            "is_training": self._get_is_training_placeholder(),
        },
        observation_space,
        action_space,
        logit_dim,
        self.config["model"],
        imitation=imitation,
        state_in=existing_state_in,
        seq_lens=existing_seq_lens)

    if autoregressive:
        action_dist = ssbm_actions.AutoRegressive(
            nest.map_structure(lambda conv: conv.build_dist(),
                               ssbm_actions.flat_repeated_config),
            residual=config.get("residual"))
        actions_logp = snt.BatchApply(action_dist.logp)(self.model.outputs,
                                                        actions)
        action_sampler, sampled_logp = snt.BatchApply(action_dist.sample)(
            self.model.outputs)
        sampled_prob = tf.exp(sampled_logp)
    else:
        dist_inputs = tf.split(self.model.outputs, output_hidden_shape,
                               axis=-1)
        action_dist = dist_class(snt.MergeDims(0, 2)(dist_inputs))
        int64_actions = [tf.cast(x, tf.int64) for x in actions]
        actions_logp = action_dist.logp(snt.MergeDims(0, 2)(int64_actions))
        action_sampler = action_dist.sample()
        # Must be a tensor (previously a trailing comma made it a 1-tuple).
        sampled_prob = action_dist.sampled_action_prob()

    self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                      tf.get_variable_scope().name)

    # actual loss computation
    imitation_loss = -tf.reduce_mean(actions_logp)

    tm_values = self.model.values
    baseline_values = tm_values[:-1]

    if config.get("soft_horizon"):
        discounts = config["gamma"]
    else:
        discounts = tf.to_float(~dones[:-1]) * config["gamma"]

    td_lambda = trfl.td_lambda(
        state_values=baseline_values,
        rewards=rewards[:-1],
        pcontinues=discounts,
        bootstrap_value=tm_values[-1],
        lambda_=config.get("lambda", 1.))

    # td_lambda.loss has shape [B] after a reduce_sum over time; dividing by
    # T turns it into a per-step loss.
    vf_loss = tf.reduce_mean(td_lambda.loss) / T

    self.total_loss = imitation_loss + self.config["vf_loss_coeff"] * vf_loss

    # Initialize TFPolicyGraph
    loss_in = [
        (SampleBatch.ACTIONS, actions),
        (SampleBatch.DONES, dones),
        (SampleBatch.REWARDS, rewards),
        (SampleBatch.CUR_OBS, observations),
        (SampleBatch.PREV_ACTIONS, prev_actions),
        (SampleBatch.PREV_REWARDS, prev_rewards),
    ]
    LearningRateSchedule.__init__(self, self.config["lr"],
                                  self.config["lr_schedule"])
    TFPolicyGraph.__init__(
        self,
        observation_space,
        action_space,
        self.sess,
        obs_input=observations,
        action_sampler=action_sampler,
        action_prob=sampled_prob,
        loss=self.total_loss,
        model=self.model,
        loss_inputs=loss_in,
        state_inputs=self.model.state_in,
        state_outputs=self.model.state_out,
        prev_action_input=prev_actions,
        prev_reward_input=prev_rewards,
        seq_lens=self.model.seq_lens,
        max_seq_len=self.config["model"]["max_seq_len"],
        batch_divisibility_req=self.config["sample_batch_size"])

    self._loss_input_dict = dict(self._loss_inputs,
                                 state_in=self._state_inputs)

    self.sess.run(tf.global_variables_initializer())

    self.stats_fetches = {
        LEARNER_STATS_KEY: {
            "cur_lr": tf.cast(self.cur_lr, tf.float64),
            "imitation_loss": imitation_loss,
            "grad_gnorm": tf.global_norm(self._grads),
            "var_gnorm": tf.global_norm(self.var_list),
            "vf_loss": vf_loss,
            "vf_explained_var": explained_variance(
                tf.reshape(td_lambda.extra.discounted_returns, [-1]),
                tf.reshape(baseline_values, [-1])),
        },
        "state_out": self.model.state_out,
    }
def _encoder_preprocessor(self, position_sequence, n_node, global_context,
                          particle_types):
    """Turns a window of particle positions into an input GraphsTuple.

    Node features:
      * the normalized finite-difference velocity history, with axes 1 and 2
        merged;
      * clipped distances to the lower/upper boundaries, divided by the
        connectivity radius;
      * a particle-type embedding when more than one type exists.

    Edge features are sender-minus-receiver displacements divided by the
    connectivity radius, plus their norms.

    Args:
      position_sequence: positions whose axis 1 indexes time; the last entry
        is used as the current position.
      n_node: per-graph node counts.
      global_context: optional global features, normalized with the
        "context" statistics.
      particle_types: indices into `self._particle_type_embedding`.

    Returns:
      A `gn.graphs.GraphsTuple` whose globals are the (normalized) context.
    """
    radius = self._connectivity_radius

    most_recent_position = position_sequence[:, -1]
    velocity_sequence = time_diff(position_sequence)  # Finite-difference.

    senders, receivers, n_edge = (
        connectivity_utils.compute_connectivity_for_batch_pyfunc(
            most_recent_position, n_node, radius))

    # --- Node features ------------------------------------------------------
    velocity_stats = self._normalization_stats["velocity"]
    normalized_velocities = (
        (velocity_sequence - velocity_stats.mean) / velocity_stats.std)
    # Merge the spatial and time axes into a single feature axis.
    node_feature_list = [
        snt.MergeDims(start=1, size=2)(normalized_velocities)]

    # self._boundaries has shape [num_dimensions, 2]: column 0 holds the
    # lower bounds, column 1 the upper bounds.
    boundaries = tf.constant(self._boundaries, dtype=tf.float32)
    lower_bounds = tf.expand_dims(boundaries[:, 0], 0)
    to_lower = most_recent_position - lower_bounds
    upper_bounds = tf.expand_dims(boundaries[:, 1], 0)
    to_upper = upper_bounds - most_recent_position
    to_boundaries = tf.concat([to_lower, to_upper], axis=1)
    node_feature_list.append(
        tf.clip_by_value(to_boundaries / radius, -1., 1.))

    if self._num_particle_types > 1:
        node_feature_list.append(
            tf.nn.embedding_lookup(self._particle_type_embedding,
                                   particle_types))

    # --- Edge features ------------------------------------------------------
    displacements = (
        tf.gather(most_recent_position, senders)
        - tf.gather(most_recent_position, receivers)) / radius
    distances = tf.norm(displacements, axis=-1, keepdims=True)
    edge_feature_list = [displacements, distances]

    # --- Global context -----------------------------------------------------
    if global_context is not None:
        context_stats = self._normalization_stats["context"]
        # Some datasets have an all-zero context, so floor the std at
        # STD_EPSILON for numerical stability.
        global_context = (
            (global_context - context_stats.mean)
            / tf.math.maximum(context_stats.std, STD_EPSILON))

    return gn.graphs.GraphsTuple(
        nodes=tf.concat(node_feature_list, axis=-1),
        edges=tf.concat(edge_feature_list, axis=-1),
        globals=global_context,  # self._graph_net appends this to the nodes.
        n_node=n_node,
        n_edge=n_edge,
        senders=senders,
        receivers=receivers,
    )
def _init_helper(self,
                 observation_space,
                 action_space,
                 config,
                 existing_inputs=None):
    """Builds a V-trace (IMPALA-style) policy graph around HumanActionModel.

    Actions are stacked MultiDiscrete components, shape [batch, n_components].
    The loss is the V-trace policy-gradient term plus `vf_loss_coeff` times
    the value loss, minus `entropy_coeff` times the mean entropy. Importance
    ratios come from the "behavior_logp" loss input.

    Args:
      observation_space: gym observation space, forwarded to the model.
      action_space: must be a gym MultiDiscrete space.
      config: overrides merged on top of `DEFAULT_CONFIG`; must have
        `imitation` falsy.
      existing_inputs: unused here.

    Raises:
      UnsupportedSpaceException: for action spaces other than Discrete or
        MultiDiscrete.
    """
    config = dict(DEFAULT_CONFIG, **config)
    assert config["batch_mode"] == "truncate_episodes", \
        "Must use `truncate_episodes` batch mode with V-trace."
    self.config = config
    self.sess = tf.get_default_session()
    self.grads = None

    imitation = config["imitation"]
    assert not imitation
    if imitation:
        T = config["sample_batch_size"]
        B = config["train_batch_size"] // T
        batch_shape = (T, B)
    else:
        batch_shape = (None, )

    if isinstance(action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        actions_shape = batch_shape
        output_hidden_shape = [action_space.n]
    elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
        is_multidiscrete = True
        actions_shape = batch_shape + (len(action_space.nvec), )
        output_hidden_shape = action_space.nvec.astype(np.int32)
    else:
        raise UnsupportedSpaceException(
            "Action space {} is not supported for IMPALA.".format(
                action_space))

    assert is_multidiscrete

    if imitation:
        make_action_ph = lambda: ssbm_actions.make_ph(
            ssbm_actions.flat_repeated_config, batch_shape)
        actions = make_action_ph()
        prev_actions = make_action_ph()
    else:
        # actions are stacked "multidiscrete"
        actions = tf.placeholder(tf.int64, actions_shape, name="actions")
        prev_actions = tf.placeholder(tf.int64, actions_shape,
                                      name="prev_actions")

    # Create input placeholders
    dones = tf.placeholder(tf.bool, batch_shape, name="dones")
    rewards = tf.placeholder(tf.float32, batch_shape, name="rewards")
    if imitation:
        observations = ssbm_spaces.slippi_conv_list[0].make_ph(batch_shape)
    else:
        observations = tf.placeholder(
            tf.float32, [None] + list(observation_space.shape))

    behavior_logp = tf.placeholder(tf.float32, batch_shape)

    existing_state_in = None
    existing_seq_lens = None

    # Setup the policy
    autoregressive = config.get("autoregressive")
    if autoregressive:
        logit_dim = 128  # not really logits
    else:
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])

    prev_rewards = tf.placeholder(tf.float32, batch_shape, name="prev_reward")
    self.model = HumanActionModel(
        {
            "obs": observations,
            "prev_actions": prev_actions,
            "prev_rewards": prev_rewards,
            "is_training": self._get_is_training_placeholder(),
        },
        observation_space,
        action_space,
        logit_dim,
        self.config["model"],
        imitation=imitation,
        state_in=existing_state_in,
        seq_lens=existing_seq_lens)

    # HumanActionModel doesn't flatten outputs
    flat_outputs = snt.MergeDims(0, 2)(self.model.outputs)

    if autoregressive:
        action_dist = ssbm_actions.AutoRegressive(
            nest.map_structure(lambda conv: conv.build_dist(),
                               ssbm_actions.flat_repeated_config),
            residual=config.get("residual"))
        actions_logp, actions_entropy = action_dist.logp(
            flat_outputs, tf.unstack(actions, axis=-1))
        action_sampler, self.sampled_logp = action_dist.sample(flat_outputs)
        action_sampler = tf.stack(
            [tf.cast(t, tf.int64) for t in nest.flatten(action_sampler)],
            axis=-1)
        sampled_prob = tf.exp(self.sampled_logp)
    else:
        dist_inputs = tf.split(flat_outputs, output_hidden_shape, axis=-1)
        action_dist = dist_class(dist_inputs)
        # Split the stacked [batch, n_components] actions into one tensor per
        # component (iterating the tensor directly is not allowed in graph
        # mode and would walk the batch axis instead).
        int64_actions = [
            tf.cast(x, tf.int64) for x in tf.unstack(actions, axis=-1)]
        actions_logp = action_dist.logp(int64_actions)
        actions_entropy = action_dist.entropy()
        action_sampler = action_dist.sample()
        sampled_prob = action_dist.sampled_action_prob()
        self.sampled_logp = tf.log(sampled_prob)

    self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                      tf.get_variable_scope().name)

    def make_time_major(tensor, drop_last=False):
        """Swaps batch and trajectory axis.

        Args:
          tensor: A tensor or list of tensors to reshape.
          drop_last: A bool indicating whether to drop the last
            trajectory item.

        Returns:
          res: A tensor with swapped axes or a list of tensors with
            swapped axes.
        """
        if isinstance(tensor, list):
            return [make_time_major(t, drop_last) for t in tensor]

        if self.model.state_init:
            B = tf.shape(self.model.seq_lens)[0]
            T = tf.shape(tensor)[0] // B
        else:
            # Important: chop the tensor into batches at known episode cut
            # boundaries. TODO(ekl) this is kind of a hack
            T = self.config["sample_batch_size"]
            B = tf.shape(tensor)[0] // T
        rs = tf.reshape(tensor,
                        tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))

        # swap B and T axes
        res = tf.transpose(
            rs,
            [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0]))))

        if drop_last:
            return res[:-1]
        return res

    # actual loss computation
    values_tm = make_time_major(self.model.value_function())
    baseline_values = values_tm[:-1]
    actions_logp_tm = make_time_major(actions_logp, True)
    behavior_logp_tm = make_time_major(behavior_logp, True)
    log_rhos_tm = actions_logp_tm - behavior_logp_tm

    discounts = tf.fill(tf.shape(baseline_values), config["gamma"])
    if not config.get("soft_horizon"):
        discounts *= tf.to_float(~make_time_major(dones, True))

    vtrace_returns = vtrace.from_importance_weights(
        log_rhos=log_rhos_tm,
        discounts=discounts,
        rewards=make_time_major(rewards, True),
        values=baseline_values,
        bootstrap_value=values_tm[-1])

    vf_loss = tf.reduce_mean(
        tf.squared_difference(vtrace_returns.vs, baseline_values))
    pi_loss = -tf.reduce_mean(
        actions_logp_tm * vtrace_returns.pg_advantages)
    entropy_mean = tf.reduce_mean(actions_entropy)

    total_loss = pi_loss
    total_loss += self.config["vf_loss_coeff"] * vf_loss
    total_loss -= self.config["entropy_coeff"] * entropy_mean
    self.total_loss = total_loss

    kl_mean = -tf.reduce_mean(log_rhos_tm)

    # Initialize TFPolicyGraph
    loss_in = [
        (SampleBatch.ACTIONS, actions),
        (SampleBatch.DONES, dones),
        ("behavior_logp", behavior_logp),
        (SampleBatch.REWARDS, rewards),
        (SampleBatch.CUR_OBS, observations),
        (SampleBatch.PREV_ACTIONS, prev_actions),
        (SampleBatch.PREV_REWARDS, prev_rewards),
    ]
    LearningRateSchedule.__init__(self, self.config["lr"],
                                  self.config["lr_schedule"])
    TFPolicyGraph.__init__(
        self,
        observation_space,
        action_space,
        self.sess,
        obs_input=observations,
        action_sampler=action_sampler,
        action_prob=sampled_prob,
        loss=self.total_loss,
        model=self.model,
        loss_inputs=loss_in,
        state_inputs=self.model.state_in,
        state_outputs=self.model.state_out,
        prev_action_input=prev_actions,
        prev_reward_input=prev_rewards,
        seq_lens=self.model.seq_lens,
        max_seq_len=self.config["model"]["max_seq_len"],
        batch_divisibility_req=self.config["sample_batch_size"])

    self.sess.run(tf.global_variables_initializer())

    self.stats_fetches = {
        LEARNER_STATS_KEY: {
            "cur_lr": tf.cast(self.cur_lr, tf.float64),
            "pi_loss": pi_loss,
            "entropy": entropy_mean,
            "grad_gnorm": tf.global_norm(self._grads),
            "var_gnorm": tf.global_norm(self.var_list),
            "vf_loss": vf_loss,
            "vf_explained_var": explained_variance(
                tf.reshape(vtrace_returns.vs, [-1]),
                tf.reshape(baseline_values, [-1])),
            "kl_mean": kl_mean,
        },
    }
def _build(self, h, x, presence=None):
    """Predicts capsule votes with a batch MLP and scores them against `x`.

    Args:
      h: Tensor of encodings. The MLP output is reshaped to
        [B, n_caps, ...], so this presumably is [B, n_caps, n_enc_dims]
        (previously documented as [B, n_enc_dims]) -- TODO confirm.
      x: Tensor of inputs of shape [B, n_points, n_input_dims].
      presence: Tensor or None; if given, it indicates which input points
        exist. NOTE(review): previously documented as [B, n_points, 1], but
        `tf.expand_dims(presence, -1)` is multiplied into a one-hot over
        capsules, which suggests [B, n_points] -- confirm.

    Returns:
      An AttrDict with:
        * 'dynamic_weights_l2' (constant 0);
        * presence logits, 'vote_presence', 'scale' and 'vote', with axes
          1 and 2 merged for every non-scalar entry;
        * the likelihood's fields;
        * 'mixing_kl', 'sparsity_loss', 'caps_presence_prob' and
          'mean_scale'.
    """
    batch_size, n_input_points, _ = x.shape.as_list()

    res = AttrDict(dynamic_weights_l2=tf.constant(0.))

    output_shapes = (
        [1],  # per-capsule presence
        [self._n_votes],  # per-vote-presence
        [self._n_votes],  # per-vote scale
        [self._n_votes, self._n_caps_dims],  # votes
    )

    splits = [np.prod(i).astype(np.int32) for i in output_shapes]
    n_outputs = sum(splits)

    batch_mlp = neural.BatchMLP([self._n_hiddens, self._n_hiddens, n_outputs],
                                use_bias=True)
    all_params = batch_mlp(h)
    all_params = tf.split(all_params, splits, -1)

    batch_shape = [batch_size, self._n_caps]
    all_params = [tf.reshape(i, batch_shape + s)
                  for (i, s) in zip(all_params, output_shapes)]

    noise_scale = 4.

    def add_noise(tensor):
        # Perturb the logits with uniform noise in [-2, 2); the logits must be
        # kept (previously only the noise was returned, discarding them).
        return tensor + tf.random.uniform(
            tensor.shape, minval=-.5, maxval=.5) * noise_scale

    res.pres_logit_per_caps = add_noise(all_params[0])
    res.pres_logit_per_vote = add_noise(all_params[1])
    res.scale = tf.nn.softplus(all_params[2] + .5) + 1e-6
    res.vote_presence = (tf.nn.sigmoid(res.pres_logit_per_caps)
                         * tf.nn.sigmoid(res.pres_logit_per_vote))
    res.vote = all_params[3]

    # Merge axes 1 and 2 of every non-scalar entry.
    for k, v in res.items():
        if v.shape.ndims > 0:
            res[k] = snt.MergeDims(1, 2)(v)

    likelihood = _capsule.OrderInvariantCapsuleLikelihood(
        self._n_votes, res.vote, res.scale, res.vote_presence)
    ll_res = likelihood(x, presence)
    res.update(ll_res._asdict())

    # post processing
    mixing_probs = tf.nn.softmax(ll_res.mixing_logits, 1)
    prior_mixing_log_prob = tf.log(1. / n_input_points)
    mixing_kl = mixing_probs * (ll_res.mixing_log_prob - prior_mixing_log_prob)
    mixing_kl = tf.reduce_mean(tf.reduce_sum(mixing_kl, -1))

    wins_per_caps = tf.one_hot(ll_res.is_from_capsule, depth=self._n_caps)

    if presence is not None:
        wins_per_caps *= tf.expand_dims(presence, -1)

    wins_per_caps = tf.reduce_sum(wins_per_caps, 1)

    has_any_wins = tf.to_float(tf.greater(wins_per_caps, 0))
    should_be_active = tf.to_float(tf.greater(wins_per_caps, 1))

    sparsity_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=should_be_active, logits=res.pres_logit_per_caps)
    sparsity_loss = tf.reduce_sum(sparsity_loss * has_any_wins, -1)
    sparsity_loss = tf.reduce_mean(sparsity_loss)

    caps_presence_prob = tf.reduce_max(
        tf.reshape(res.vote_presence,
                   [batch_size, self._n_caps, self._n_votes]), 2)

    res.update(dict(
        mixing_kl=mixing_kl,
        sparsity_loss=sparsity_loss,
        caps_presence_prob=caps_presence_prob,
        mean_scale=tf.reduce_mean(res.scale),
    ))
    return res
def value_function(self):
    """Returns `self.values` with its first two axes merged into one."""
    merge_leading_axes = snt.MergeDims(start=0, size=2)
    return merge_leading_axes(self.values)