def copy(self, existing_inputs):
    """Creates a copy of self using existing input placeholders."""

    # Loss inputs come first in `existing_inputs`; if the policy is
    # recurrent, the state inputs plus one trailing seq-lens tensor follow.
    num_loss = len(self._loss_inputs)
    extra = len(self._state_inputs) + 1 if self._state_inputs else 0
    if num_loss + extra != len(existing_inputs):
        raise ValueError("Tensor list mismatch", self._loss_inputs,
                         self._state_inputs, existing_inputs)
    for idx, (name, tensor) in enumerate(self._loss_inputs):
        if tensor.shape.as_list() != existing_inputs[idx].shape.as_list():
            raise ValueError("Tensor shape mismatch", idx, name,
                             tensor.shape, existing_inputs[idx].shape)

    # By convention, the loss inputs are followed by state inputs and then
    # the seq len tensor.
    rnn_inputs = [("state_in_{}".format(j), existing_inputs[num_loss + j])
                  for j in range(len(self._state_inputs))]
    if rnn_inputs:
        rnn_inputs.append(("seq_lens", existing_inputs[-1]))

    # Pair each loss-input name with its replacement placeholder; this list
    # is reused below when re-initializing the loss.
    loss_pairs = [(name, existing_inputs[idx])
                  for idx, (name, _) in enumerate(self._loss_inputs)]
    input_dict = OrderedDict(loss_pairs + rnn_inputs)

    # Build a fresh policy-graph instance over the shared placeholders.
    instance = self.__class__(
        self.observation_space,
        self.action_space,
        self.config,
        existing_inputs=input_dict)

    loss = instance._loss_fn(instance, input_dict)
    if instance._stats_fn:
        instance._stats_fetches.update(
            instance._stats_fn(instance, input_dict))
    TFPolicyGraph._initialize_loss(instance, loss, loss_pairs)
    if instance._grad_stats_fn:
        instance._stats_fetches.update(
            instance._grad_stats_fn(instance, instance._grads))
    return instance
def _initialize_loss(self):
    """Builds the loss by tracing a dummy batch through postprocessing.

    A fake one-row sample batch is constructed from the policy's input
    placeholders, run through ``postprocess_trajectory`` (so any fields the
    postprocessor adds are discovered), and placeholders are created for
    every field the loss/stats functions actually access. The loss is then
    registered with the base ``TFPolicyGraph``.
    """

    def fake_array(tensor):
        # Zero-filled array with batch dimension 1, matching the
        # placeholder's remaining shape and dtype.
        shape = tensor.shape.as_list()
        shape[0] = 1
        return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

    dummy_batch = {
        SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
        SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
        SampleBatch.CUR_OBS: fake_array(self._obs_input),
        SampleBatch.NEXT_OBS: fake_array(self._obs_input),
        SampleBatch.ACTIONS: fake_array(self._prev_action_input),
        SampleBatch.REWARDS: np.array([0], dtype=np.float32),
        # Builtin `bool` instead of `np.bool`: the np.bool alias was
        # deprecated in NumPy 1.20 and removed in 1.24; behavior identical.
        SampleBatch.DONES: np.array([False], dtype=bool),
    }
    state_init = self.get_initial_state()
    for i, h in enumerate(state_init):
        dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
        dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
    if state_init:
        dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
    for k, v in self.extra_compute_action_fetches().items():
        dummy_batch[k] = fake_array(v)

    # postprocessing might depend on variable init, so run it first here
    self._sess.run(tf.global_variables_initializer())
    postprocessed_batch = self.postprocess_trajectory(
        SampleBatch(dummy_batch))

    # UsageTrackingDict records which keys the loss/stats fns read, so only
    # those become loss inputs below.
    batch_tensors = UsageTrackingDict({
        SampleBatch.PREV_ACTIONS: self._prev_action_input,
        SampleBatch.PREV_REWARDS: self._prev_reward_input,
        SampleBatch.CUR_OBS: self._obs_input,
    })
    loss_inputs = [
        (SampleBatch.PREV_ACTIONS, self._prev_action_input),
        (SampleBatch.PREV_REWARDS, self._prev_reward_input),
        (SampleBatch.CUR_OBS, self._obs_input),
    ]

    for k, v in postprocessed_batch.items():
        if k in batch_tensors:
            continue
        # Builtin `object` instead of `np.object` (alias removed in
        # NumPy 1.24); dtype comparison is unchanged.
        elif v.dtype == object:
            continue  # can't handle arbitrary objects in TF
        shape = (None, ) + v.shape[1:]
        # TF models conventionally run in float32; downcast float64 fields.
        dtype = np.float32 if v.dtype == np.float64 else v.dtype
        placeholder = tf.placeholder(dtype, shape=shape, name=k)
        batch_tensors[k] = placeholder

    if log_once("loss_init"):
        logger.info(
            "Initializing loss function with dummy input:\n\n{}\n".format(
                summarize(batch_tensors)))

    loss = self._loss_fn(self, batch_tensors)
    if self._stats_fn:
        self._stats_fetches.update(self._stats_fn(self, batch_tensors))
    # Only the keys the loss/stats fns actually touched feed the loss.
    for k in sorted(batch_tensors.accessed_keys):
        loss_inputs.append((k, batch_tensors[k]))
    TFPolicyGraph._initialize_loss(self, loss, loss_inputs)
    if self._grad_stats_fn:
        self._stats_fetches.update(self._grad_stats_fn(self, self._grads))
    # Re-run init so variables created during loss construction are set.
    self._sess.run(tf.global_variables_initializer())