Example #1
    def build_graph(self):
        """Builds the neural network graph."""

        # define graph
        self.g = tf.Graph()
        with self.g.as_default():

            # create and store a new session for the graph
            self.sess = tf.Session()

            # define placeholders
            self.x = tf.placeholder(shape=[None, self.dim_input],
                                    dtype=tf.float32)
            self.y = tf.placeholder(shape=[None, self.num_classes],
                                    dtype=tf.float32)

            # define simple model
            with tf.variable_scope('last_layer'):
                self.z = tf.layers.dense(inputs=self.x, units=self.num_classes)

            self.loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits_v2(labels=self.y,
                                                           logits=self.z))

            self.output_probs = tf.nn.softmax(self.z)

            # Variables of the last layer
            self.ll_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            self.ll_vars_concat = tf.concat(
                [self.ll_vars[0],
                 tf.expand_dims(self.ll_vars[1], axis=0)], 0)

            # Summary
            _variable_summaries(self.ll_vars_concat)

            # saving the weights of last layer when running bootstrap algorithm
            self.saver = tf.train.Saver(var_list=self.ll_vars)

            self.gd_opt = tf.train.GradientDescentOptimizer(self.step_size)

            # SGD optimizer for the last layer
            grads_vars_sgd = self.gd_opt.compute_gradients(self.loss)
            self.train_op = self.gd_opt.apply_gradients(grads_vars_sgd)

            for g, v in grads_vars_sgd:
                if g is not None:
                    # Replace the ':' in the variable name (e.g. 'w:0') so the
                    # summary tag is valid.
                    s = list(v.name)
                    s[v.name.rindex(':')] = '_'
                    tf.summary.histogram(''.join(s) + '/grad_hist_boot_sgd', g)

            # Merge all the summaries and write them out
            self.all_summaries = tf.summary.merge_all()
            location = os.path.join(self.working_dir, 'logs')
            self.writer = tf.summary.FileWriter(location, graph=self.g)

            saver_network = tf.train.Saver(var_list=self.ll_vars)
            print('Loading the network...')
            # Restores from checkpoint
            saver_network.restore(self.sess, self.model_dir)
            print('Graph successfully loaded.')
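
Note: this snippet calls a module-level helper `_variable_summaries` that is
not shown. A minimal sketch in the standard TensorBoard style, assuming
TensorFlow 1.x is imported as `tf` (an assumption, not the original helper):

def _variable_summaries(var):
    """Attaches mean/stddev/min/max/histogram summaries to a tensor."""
    with tf.name_scope('summaries'):
        mean = tf.reduce_mean(var)
        tf.summary.scalar('mean', mean)
        stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
        tf.summary.scalar('stddev', stddev)
        tf.summary.scalar('max', tf.reduce_max(var))
        tf.summary.scalar('min', tf.reduce_min(var))
        tf.summary.histogram('histogram', var)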
Example #2
 def _set_up(self, eval_mode):
     """Sets up the runner by creating and initializing the agent."""
     # Reset the tf default graph to avoid name collisions from previous runs
     # before doing anything else.
     tf.reset_default_graph()
     self._summary_writer = tf.summary.FileWriter(self._output_dir)
     if self._episode_log_file:
         self._episode_writer = tf.io.TFRecordWriter(
             os.path.join(self._output_dir, self._episode_log_file))
     # Set up a session and initialize variables.
     self._sess = tf.Session(config=tf.ConfigProto(
         allow_soft_placement=True))
     self._agent = self._create_agent_fn(
         self._sess,
         self._env,
         summary_writer=self._summary_writer,
         eval_mode=eval_mode)
     # type check: env/agent must both be multi- or single-user
     if self._agent.multi_user and not isinstance(
             self._env.environment, environment.MultiUserEnvironment):
         raise ValueError(
             'Multi-user agent requires multi-user environment.')
     if not self._agent.multi_user and isinstance(
             self._env.environment, environment.MultiUserEnvironment):
         raise ValueError(
             'Single-user agent requires single-user environment.')
     self._summary_writer.add_graph(graph=tf.get_default_graph())
     self._sess.run(tf.global_variables_initializer())
     self._sess.run(tf.local_variables_initializer())
Example #3
def init_sess(var_list=None, path=None):
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    saver = None
    writer = None
    if var_list is not None:
        saver = tf.train.Saver(var_list=var_list, filename='model.ckpt')
    if path is not None:
        writer = tf.summary.FileWriter(logdir=path)
    return sess, saver, writer
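
A minimal usage sketch (the variable and paths are illustrative; assumes
TensorFlow 1.x imported as `tf`):

# Build a trivial graph first; init_sess() then runs the initializers.
w = tf.get_variable('w', shape=[3], initializer=tf.zeros_initializer())
sess, saver, writer = init_sess(var_list=[w], path='/tmp/logs')
if saver is not None:
    saver.save(sess, '/tmp/model.ckpt')
if writer is not None:
    writer.add_graph(sess.graph)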
Example #4
 def _set_up(self, eval_mode):
     """Sets up the runner by creating and initializing the agent."""
     # Reset the tf default graph to avoid name collisions from previous runs
     # before doing anything else.
     tf.reset_default_graph()
     self._summary_writer = tf.summary.FileWriter(self._output_dir)
     if self._episode_log_file:
         self._episode_writer = tf.python_io.TFRecordWriter(
             os.path.join(self._output_dir, self._episode_log_file))
     # Set up a session and initialize variables.
     self._sess = tf.Session(config=tf.ConfigProto(
         allow_soft_placement=True))
     self._agent = self._create_agent_fn(
         self._sess,
         self._env,
         summary_writer=self._summary_writer,
         eval_mode=eval_mode)
     self._summary_writer.add_graph(graph=tf.get_default_graph())
     self._sess.run(tf.global_variables_initializer())
     self._sess.run(tf.local_variables_initializer())
Example #5
    def generate_episodes(self,
                          sampler,
                          num_episodes,
                          shuffle=True,
                          shuffle_seed=None):
        dataset_spec = sampler.dataset_spec
        split = sampler.split
        if shuffle:
            shuffle_buffer_size = self.shuffle_buffer_size
        else:
            shuffle_buffer_size = 0

        episode_reader = DummyEpisodeReader(dataset_spec, split,
                                            shuffle_buffer_size,
                                            self.read_buffer_size_bytes)
        input_pipeline = episode_reader.create_dataset_input_pipeline(
            sampler, shuffle_seed=shuffle_seed)
        iterator = input_pipeline.make_one_shot_iterator()
        next_element = iterator.get_next()
        with tf.Session() as sess:
            episodes = [sess.run(next_element) for _ in range(num_episodes)]
        return episodes
Example #6
    def build_graph(self):
        """Builds the neural network graph."""

        # define graph
        self.g = tf.Graph()
        with self.g.as_default():

            # create and store a new session for the graph
            self.sess = tf.Session()

            # define placeholders
            self.x = tf.placeholder(shape=[None, self.dim_input],
                                    dtype=tf.float32)
            self.y = tf.placeholder(shape=[None, self.num_classes],
                                    dtype=tf.float32)

            # Linear layer (Wx + b)
            with tf.variable_scope('last_layer/dense') as scope:
                weights = tf.get_variable('kernel',
                                          [self.dim_input, self.num_classes],
                                          dtype=tf.float32)
                biases = tf.get_variable('bias', [self.num_classes],
                                         dtype=tf.float32)
                wb = tf.concat([weights, tf.expand_dims(biases, axis=0)], 0)
                wb_renorm = tf.matmul(self.sigma_half_inv, wb)
                weights_renorm = wb_renorm[:self.dim_input, :]
                biases_renorm = wb_renorm[-1, :]
                self.z = tf.add(tf.matmul(self.x, weights_renorm),
                                biases_renorm,
                                name=scope.name)

            # Gaussian prior
            # prior = tf.nn.l2_loss(weights) + tf.nn.l2_loss(biases)

            # Unnormalized loss, because of the preconditioning
            self.loss = self.n * tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits_v2(labels=self.y,
                                                           logits=self.z))

            # Bayesian loss
            self.bayesian_loss = self.loss  # + prior

            self.output_probs = tf.nn.softmax(self.z)

            # Variables of the last layer
            self.ll_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            self.ll_vars_concat = tf.concat(
                [self.ll_vars[0],
                 tf.expand_dims(self.ll_vars[1], axis=0)], 0)

            # Summary
            _variable_summaries(self.ll_vars_concat)

            # saving the weights of last layer when running SGLD/SGD/MCMC algorithm
            self.saver = tf.train.Saver(var_list=self.ll_vars,
                                        max_to_keep=self.num_samples)

            self.gd_opt = tf.train.GradientDescentOptimizer(self.step_size)
            # SGLD optimizer for the last layer
            if self.sampler in ['sgld', 'lmc']:
                grads_vars = self.gd_opt.compute_gradients(self.bayesian_loss)
                grads_vars_sgld = []

                for g, v in grads_vars:
                    if g is not None:
                        # Make the variable name a valid summary tag by
                        # replacing the ':' separator.
                        s = list(v.name)
                        s[v.name.rindex(':')] = '_'
                        # Adding Gaussian noise to the gradient
                        gaussian_noise = (np.sqrt(2. / self.step_size) *
                                          tf.random_normal(tf.shape(g)))
                        g_sgld = g + gaussian_noise
                        tf.summary.histogram(''.join(s) + '/grad_hist_mcmc', g)
                        tf.summary.histogram(
                            ''.join(s) + '/gaussian_noise_hist_mcmc',
                            gaussian_noise)
                        tf.summary.histogram(
                            ''.join(s) + '/grad_total_hist_mcmc', g_sgld)
                        grads_vars_sgld.append((g_sgld, v))

                self.train_op = self.gd_opt.apply_gradients(grads_vars_sgld)

            # SGD optimizer for the last layer
            if self.sampler == 'sgd':
                grads_vars_sgd = self.gd_opt.compute_gradients(self.loss)
                self.train_op = self.gd_opt.apply_gradients(grads_vars_sgd)

                for g, v in grads_vars_sgd:
                    if g is not None:
                        s = list(v.name)
                        s[v.name.rindex(':')] = '_'
                        tf.summary.histogram(''.join(s) + '/grad_hist_sgd', g)

            # Merge all the summaries and write them out
            self.all_summaries = tf.summary.merge_all()
            location = os.path.join(self.working_dir, 'logs')
            self.writer = tf.summary.FileWriter(location, graph=self.g)

            saver_network = tf.train.Saver(var_list=self.ll_vars)
            print('Loading the network...')
            # Restores from checkpoint
            saver_network.restore(self.sess, self.model_dir)
            print('Graph successfully loaded.')
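
For reference, the noise injection above reproduces the SGLD update: once
GradientDescentOptimizer scales by step_size, the applied step is
theta - step_size * g - sqrt(2 * step_size) * N(0, 1), i.e. Langevin dynamics
with epsilon = 2 * step_size. A standalone numpy sketch of the same update
(illustrative, not part of the original):

import numpy as np

def sgld_step(theta, grad, step_size, rng=np.random):
    # The optimizer multiplies both the gradient and the injected noise by
    # step_size, giving theta - step_size*grad - sqrt(2*step_size)*N(0, 1).
    noise = np.sqrt(2.0 / step_size) * rng.standard_normal(theta.shape)
    return theta - step_size * (grad + noise)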
Example #7
    def __init__(self,
                 base_dir,
                 data_load_fn=load_data,
                 checkpoint_file_prefix='ckpt',
                 logging_file_prefix='log',
                 log_every_n=1,
                 num_iterations=200,
                 training_steps=250,
                 batch_size=100,
                 evaluation_inputs=None,
                 evaluation_size=None):
        """Initialize the Runner object in charge of running a full experiment.

    Args:
      base_dir: str, the base directory to host all required sub-directories.
      data_load_fn: function that returns data as a tuple (inputs, outputs).
      checkpoint_file_prefix: str, the prefix to use for checkpoint files.
      logging_file_prefix: str, prefix to use for the log files.
      log_every_n: int, the frequency for writing logs.
      num_iterations: int, the iteration number threshold (must be greater than
        start_iteration).
      training_steps: int, the number of training steps to perform.
      batch_size: int, batch size used for the training.
      evaluation_inputs: tuple of inputs to the generator that can be used
        during qualitative evaluation. If None, inputs set passed above will
        be used.
      evaluation_size: int, the number of images that should be generated
        randomly sampling from the data specified in evaluation_inputs. If
        None, all evaluation_inputs are generated.

    This constructor will take the following actions:
    - Initialize a `tf.Session`.
    - Initialize a logger.
    - Initialize a generator.
    - Reload from the latest checkpoint, if available, and initialize the
      Checkpointer object.
    """
        assert base_dir is not None
        inputs, data_to_generate = data_load_fn()
        assert inputs is None or inputs.shape[0] == data_to_generate.shape[0]
        assert evaluation_inputs is None or \
               evaluation_inputs.shape[1:] == inputs.shape[1:]
        assert evaluation_inputs is not None or evaluation_size is not None, \
               'Either evaluation_inputs or evaluation_size must be initialized.'

        self._logging_file_prefix = logging_file_prefix
        self._log_every_n = log_every_n
        self._data_to_generate = data_to_generate
        self._inputs = inputs
        self._num_iterations = num_iterations
        self._training_steps = training_steps
        self._batch_size = batch_size
        self._evaluation_inputs = evaluation_inputs
        if self._evaluation_inputs is None:
            self._evaluation_inputs = inputs
        self._evaluation_size = evaluation_size
        self._base_dir = base_dir
        self._create_directories()
        self._summary_writer = tf.summary.FileWriter(self._base_dir)

        config = tf.ConfigProto(allow_soft_placement=True)
        # Allocate only subset of the GPU memory as needed which allows for running
        # multiple workers on the same GPU.
        config.gpu_options.allow_growth = True
        # Set up a session and initialize variables.
        self._sess = tf.Session('', config=config)
        self._generator = create_generator(self._sess,
                                           data_to_generate,
                                           inputs,
                                           summary_writer=self._summary_writer)
        self._summary_writer.add_graph(graph=tf.get_default_graph())
        self._sess.run(tf.global_variables_initializer())

        self._initialize_checkpointer_and_maybe_resume(checkpoint_file_prefix)
Example #8
    def __init__(self,
                 num_actions=None,
                 observation_size=None,
                 num_players=None,
                 gamma=0.99,
                 update_horizon=1,
                 min_replay_history=500,
                 update_period=4,
                 stack_size=1,
                 target_update_period=500,
                 epsilon_fn=linearly_decaying_epsilon,
                 epsilon_train=0.02,
                 epsilon_eval=0.001,
                 epsilon_decay_period=1000,
                 graph_template=dqn_template,
                 tf_device='/cpu:*',
                 use_staging=True,
                 optimizer=tf.train.RMSPropOptimizer(learning_rate=.0025,
                                                     decay=0.95,
                                                     momentum=0.0,
                                                     epsilon=1e-6,
                                                     centered=True)):
        """Initializes the agent and constructs its graph.

    Args:
      num_actions: int, number of actions the agent can take at any state.
      observation_size: int, size of observation vector.
      num_players: int, number of players playing this game.
      gamma: float, discount factor as commonly used in the RL literature.
      update_horizon: int, horizon at which updates are performed, the 'n' in
        n-step update.
      min_replay_history: int, number of stored transitions before training.
      update_period: int, period between DQN updates.
      stack_size: int, number of observations to use as state.
      target_update_period: Update period for the target network.
      epsilon_fn: Function expecting 4 parameters: (decay_period, step,
        warmup_steps, epsilon), and which returns the epsilon value used for
        exploration during training.
      epsilon_train: float, final epsilon for training.
      epsilon_eval: float, epsilon during evaluation.
      epsilon_decay_period: int, number of steps for epsilon to decay.
      graph_template: function for building the neural network graph.
      tf_device: str, Tensorflow device on which to run computations.
      use_staging: bool, when True use a staging area to prefetch the next
        sampling batch.
      optimizer: Optimizer instance used for learning.
    """

        self.partial_reload = False

        tf.logging.info('Creating %s agent with the following parameters:',
                        self.__class__.__name__)
        tf.logging.info('\t gamma: %f', gamma)
        tf.logging.info('\t update_horizon: %f', update_horizon)
        tf.logging.info('\t min_replay_history: %d', min_replay_history)
        tf.logging.info('\t update_period: %d', update_period)
        tf.logging.info('\t target_update_period: %d', target_update_period)
        tf.logging.info('\t epsilon_train: %f', epsilon_train)
        tf.logging.info('\t epsilon_eval: %f', epsilon_eval)
        tf.logging.info('\t epsilon_decay_period: %d', epsilon_decay_period)
        tf.logging.info('\t tf_device: %s', tf_device)
        tf.logging.info('\t use_staging: %s', use_staging)
        tf.logging.info('\t optimizer: %s', optimizer)

        # Global variables.
        self.num_actions = num_actions
        self.observation_size = observation_size
        self.num_players = num_players
        self.gamma = gamma
        self.update_horizon = update_horizon
        self.cumulative_gamma = math.pow(gamma, update_horizon)
        self.min_replay_history = min_replay_history
        self.target_update_period = target_update_period
        self.epsilon_fn = epsilon_fn
        self.epsilon_train = epsilon_train
        self.epsilon_eval = epsilon_eval
        self.epsilon_decay_period = epsilon_decay_period
        self.update_period = update_period
        self.eval_mode = False
        self.training_steps = 0
        self.batch_staged = False
        self.optimizer = optimizer

        with tf.device(tf_device):
            # Calling online_convnet will generate a new graph as defined in
            # graph_template using whatever input is passed, but will always share
            # the same weights.
            online_convnet = tf.make_template('Online', graph_template)
            target_convnet = tf.make_template('Target', graph_template)
            # The state of the agent. The last axis is the number of past observations
            # that make up the state.
            states_shape = (1, observation_size, stack_size)
            self.state = np.zeros(states_shape)
            self.state_ph = tf.placeholder(tf.uint8,
                                           states_shape,
                                           name='state_ph')
            self.legal_actions_ph = tf.placeholder(tf.float32,
                                                   [self.num_actions],
                                                   name='legal_actions_ph')
            self._q = online_convnet(state=self.state_ph,
                                     num_actions=self.num_actions)
            self._replay = self._build_replay_memory(use_staging)
            self._replay_qs = online_convnet(self._replay.states,
                                             self.num_actions)
            self._replay_next_qt = target_convnet(self._replay.next_states,
                                                  self.num_actions)
            self._train_op = self._build_train_op()
            self._sync_qt_ops = self._build_sync_op()

            self._q_argmax = tf.argmax(self._q + self.legal_actions_ph,
                                       axis=1)[0]

        # Set up a session and initialize variables.
        self._sess = tf.Session(
            '', config=tf.ConfigProto(allow_soft_placement=True))
        self._init_op = tf.global_variables_initializer()
        self._sess.run(self._init_op)

        self._saver = tf.train.Saver(max_to_keep=3)

        # This keeps track of the observed transitions during play, for each
        # player.
        self.transitions = [[] for _ in range(num_players)]
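
The two tf.make_template calls above give the online and target networks
separate weights, while each template shares its own weights across calls. A
minimal sketch of that sharing behavior (names are illustrative):

def _layer(state, num_actions):
    return tf.layers.dense(state, units=num_actions, name='fc')

net = tf.make_template('Online', _layer)
q1 = net(tf.zeros([1, 8]), num_actions=4)  # creates Online/fc variables
q2 = net(tf.ones([1, 8]), num_actions=4)   # reuses the same variables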
Example #9
    def __init__(self,
                 base_dir,
                 create_agent_fn,
                 create_environment_fn=atari_lib.create_atari_environment,
                 checkpoint_file_prefix='ckpt',
                 logging_file_prefix='log',
                 log_every_n=1,
                 num_iterations=200,
                 training_steps=250000,
                 evaluation_steps=125000,
                 max_steps_per_episode=27000,
                 reward_clipping=(-1, 1)):
        """Initialize the Runner object in charge of running a full experiment.

    Args:
      base_dir: str, the base directory to host all required sub-directories.
      create_agent_fn: A function that takes as args a Tensorflow session and an
        environment, and returns an agent.
      create_environment_fn: A function which receives a problem name and
        creates a Gym environment for that problem (e.g. an Atari 2600 game).
      checkpoint_file_prefix: str, the prefix to use for checkpoint files.
      logging_file_prefix: str, prefix to use for the log files.
      log_every_n: int, the frequency for writing logs.
      num_iterations: int, the iteration number threshold (must be greater than
        start_iteration).
      training_steps: int, the number of training steps to perform.
      evaluation_steps: int, the number of evaluation steps to perform.
      max_steps_per_episode: int, maximum number of steps after which an episode
        terminates.
      reward_clipping: Tuple(int, int), with the minimum and maximum bounds for
        reward at each step. If `None` no clipping is applied.

    This constructor will take the following actions:
    - Initialize an environment.
    - Initialize a `tf.Session`.
    - Initialize a logger.
    - Initialize an agent.
    - Reload from the latest checkpoint, if available, and initialize the
      Checkpointer object.
    """
        assert base_dir is not None
        self._logging_file_prefix = logging_file_prefix
        self._log_every_n = log_every_n
        self._num_iterations = num_iterations
        self._training_steps = training_steps
        self._evaluation_steps = evaluation_steps
        self._max_steps_per_episode = max_steps_per_episode
        self._base_dir = base_dir
        self._create_directories()
        self._summary_writer = tf.summary.FileWriter(self._base_dir)

        self._environment = create_environment_fn()
        # Set up a session and initialize variables.
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        self._sess = tf.Session('', config=config)

        self._agent = create_agent_fn(self._sess,
                                      self._environment,
                                      summary_writer=self._summary_writer)
        self._summary_writer.add_graph(graph=tf.get_default_graph())
        self._sess.run(tf.global_variables_initializer())

        self._initialize_checkpointer_and_maybe_resume(checkpoint_file_prefix)
        self._reward_clipping = reward_clipping
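
For context, the `create_agent_fn` argument must accept a session and an
environment and support a `summary_writer` keyword; a hypothetical stub (the
agent class and its interface are assumptions):

def my_create_agent_fn(sess, environment, summary_writer=None):
    # Hypothetical: return any object implementing the agent interface the
    # Runner expects (begin_episode / step / end_episode, etc.).
    return MyAgent(sess,
                   num_actions=environment.action_space.n,
                   summary_writer=summary_writer)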
Example #10
  def sample_distance_pairs(self, num_samples_per_cell=2, verbose=False):
    """Sample a set of points from each cell and compute all pairwise distances.

    This method also writes the resulting distances to disk.

    Args:
      num_samples_per_cell: int, number of samples to draw per cell.
      verbose: bool, whether to print verbose messages.
    """
    paired_states_ph = tf.placeholder(tf.float64, (1, 4),
                                      name='paired_states_ph')
    online_network = tf.make_template('Online', self._network_template)
    distance = online_network(paired_states_ph)
    saver = tf.train.Saver()
    if not self.add_noise:
      num_samples_per_cell = 1
    with tf.Session() as sess:
      saver.restore(sess, os.path.join(self.base_dir, 'tf_ckpt-239900'))
      total_samples = None
      for s_idx in range(self.num_states):
        s = self.inverse_index_states[s_idx]
        s = s.astype(np.float32)
        s += 0.5  # Place in center of cell.
        s = np.tile([s], (num_samples_per_cell, 1))
        if self.add_noise:
          sampled_noise = np.clip(
              np.random.normal(0, 0.1, size=(num_samples_per_cell, 2)),
              -0.3, 0.3)
          s += sampled_noise
        if total_samples is None:
          total_samples = s
        else:
          total_samples = np.concatenate([total_samples, s])
      num_total_samples = len(total_samples)
      distances = np.zeros((num_total_samples, num_total_samples))
      if verbose:
        tf.logging.info('Will compute distances for %d samples',
                        num_total_samples)
      for i in range(num_total_samples):
        s1 = total_samples[i]
        if verbose:
          tf.logging.info('Will compute distances from sample %d', i)
        for j in range(num_total_samples):
          s2 = total_samples[j]
          paired_states_1 = np.reshape(np.append(s1, s2), (1, 4))
          paired_states_2 = np.reshape(np.append(s2, s1), (1, 4))
          distance_np_1 = sess.run(
              distance, feed_dict={paired_states_ph: paired_states_1})
          distance_np_2 = sess.run(
              distance, feed_dict={paired_states_ph: paired_states_2})
          max_dist = max(distance_np_1, distance_np_2)
          distances[i, j] = max_dist
          distances[j, i] = max_dist
    sampled_distances = {
        'samples_per_cell': num_samples_per_cell,
        'samples': total_samples,
        'distances': distances,
    }
    file_path = os.path.join(self.base_dir, 'sampled_distances.pkl')
    # Open in binary mode since pickle writes bytes.
    with tf.gfile.GFile(file_path, 'wb') as f:
      pickle.dump(sampled_distances, f)
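
Note that the nested loop above issues two sess.run calls per ordered pair of
samples. If the network template accepts a leading batch axis (an assumption
about self._network_template), the inner loop can be batched to one call per
anchor sample:

# Sketch: reuse the template with a batched placeholder.
batch_ph = tf.placeholder(tf.float64, (None, 4), name='batch_ph')
batch_distance = online_network(batch_ph)  # second template call shares weights
for i in range(num_total_samples):
    pairs = np.concatenate(
        [np.tile(total_samples[i], (num_total_samples, 1)), total_samples],
        axis=1)
    distances[i, :] = sess.run(
        batch_distance, feed_dict={batch_ph: pairs}).ravel()
distances = np.maximum(distances, distances.T)  # symmetrize, as above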
Example #11
  def learn_metric(self, verbose=False):
    """Approximate the bisimulation metric by learning.

    Args:
      verbose: bool, whether to print verbose messages.
    """
    summary_writer = tf.summary.FileWriter(self.base_dir)
    global_step = tf.Variable(0, trainable=False)
    inc_global_step_op = tf.assign_add(global_step, 1)
    bisim_horizon = 0.0
    bisim_horizon_discount_value = 1.0
    if self.use_decayed_learning_rate:
      learning_rate = tf.train.exponential_decay(self.starting_learning_rate,
                                                 global_step,
                                                 self.num_iterations,
                                                 self.learning_rate_decay,
                                                 staircase=self.staircase)
    else:
      learning_rate = self.starting_learning_rate
    tf.summary.scalar('Learning/LearningRate', learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                       epsilon=self.epsilon)
    train_op = self._build_train_op(optimizer)
    sync_op = self._build_sync_op()
    eval_op = tf.stop_gradient(self._build_eval_metric())
    eval_states = []
    # Build the evaluation tensor.
    for state in range(self.num_states):
      row, col = self.inverse_index_states[state]
      # We make the evaluation states at the center of each grid cell.
      eval_states.append([row + 0.5, col + 0.5])
    eval_states = np.array(eval_states, dtype=np.float64)
    normalized_bisim_metric = (
        self.bisim_metric / np.linalg.norm(self.bisim_metric))
    metric_errors = []
    average_metric_errors = []
    normalized_metric_errors = []
    average_normalized_metric_errors = []
    saver = tf.train.Saver(max_to_keep=3)
    with tf.Session() as sess:
      summary_writer.add_graph(graph=tf.get_default_graph())
      sess.run(tf.global_variables_initializer())
      merged_summaries = tf.summary.merge_all()
      for i in range(self.num_iterations):
        sampled_states = np.random.randint(self.num_states,
                                           size=(self.batch_size,))
        sampled_actions = np.random.randint(4,
                                            size=(self.batch_size,))
        if self.add_noise:
          sampled_noise = np.clip(
              np.random.normal(0, 0.1, size=(self.batch_size, 2)),
              -0.3, 0.3)
        sampled_action_names = [self.actions[x] for x in sampled_actions]
        next_states = [self.next_states[a][s]
                       for s, a in zip(sampled_states, sampled_action_names)]
        rewards = np.array([self.rewards[a][s]
                            for s, a in zip(sampled_states,
                                            sampled_action_names)])
        states = np.array(
            [self.inverse_index_states[x] for x in sampled_states])
        next_states = np.array([self.inverse_index_states[x]
                                for x in next_states])
        states = states.astype(np.float64)
        states += 0.5  # Place points in center of grid.
        next_states = next_states.astype(np.float64)
        next_states += 0.5
        if self.add_noise:
          states += sampled_noise
          next_states += sampled_noise

        _, summary = sess.run(
            [train_op, merged_summaries],
            feed_dict={self.s1_ph: states,
                       self.s2_ph: next_states,
                       self.action_ph: sampled_actions,
                       self.rewards_ph: rewards,
                       self.bisim_horizon_ph: bisim_horizon,
                       self.eval_states_ph: eval_states})
        summary_writer.add_summary(summary, i)
        if self.double_period_halfway and i > self.num_iterations / 2.:
          self.target_update_period *= 2
          self.double_period_halfway = False
        if i % self.target_update_period == 0:
          bisim_horizon = 1.0 - bisim_horizon_discount_value
          bisim_horizon_discount_value *= self.bisim_horizon_discount
          sess.run(sync_op)
        # Now compute difference with exact metric.
        self.learned_distance = sess.run(
            eval_op, feed_dict={self.eval_states_ph: eval_states})
        self.learned_distance = np.reshape(self.learned_distance,
                                           (self.num_states, self.num_states))
        metric_difference = np.max(
            abs(self.learned_distance - self.bisim_metric))
        average_metric_difference = np.mean(
            abs(self.learned_distance - self.bisim_metric))
        normalized_learned_distance = (
            self.learned_distance / np.linalg.norm(self.learned_distance))
        normalized_metric_difference = np.max(
            abs(normalized_learned_distance - normalized_bisim_metric))
        average_normalized_metric_difference = np.mean(
            abs(normalized_learned_distance - normalized_bisim_metric))
        error_summary = tf.Summary(value=[
            tf.Summary.Value(tag='Approx/Error',
                             simple_value=metric_difference),
            tf.Summary.Value(tag='Approx/AvgError',
                             simple_value=average_metric_difference),
            tf.Summary.Value(tag='Approx/NormalizedError',
                             simple_value=normalized_metric_difference),
            tf.Summary.Value(tag='Approx/AvgNormalizedError',
                             simple_value=average_normalized_metric_difference),
        ])
        summary_writer.add_summary(error_summary, i)
        sess.run(inc_global_step_op)
        if i % 100 == 0:
          # Collect statistics every 100 steps.
          metric_errors.append(metric_difference)
          average_metric_errors.append(average_metric_difference)
          normalized_metric_errors.append(normalized_metric_difference)
          average_normalized_metric_errors.append(
              average_normalized_metric_difference)
          saver.save(sess, os.path.join(self.base_dir, 'tf_ckpt'),
                     global_step=i)
        if self.debug and i % 100 == 0:
          self.pretty_print_metric(metric_type='learned')
          print('Iteration: {}'.format(i))
          print('Metric difference: {}'.format(metric_difference))
          print('Normalized metric difference: {}'.format(
              normalized_metric_difference))
      if self.add_noise:
        # Finally, if we have noise, we draw a bunch of samples to get estimates
        # of the distances between states.
        sampled_distances = {}
        for _ in range(self.total_final_samples):
          eval_states = []
          for state in range(self.num_states):
            row, col = self.inverse_index_states[state]
            # We make the evaluation states at the center of each grid cell.
            eval_states.append([row + 0.5, col + 0.5])
          eval_states = np.array(eval_states, dtype=np.float64)
          eval_noise = np.clip(
              np.random.normal(0, 0.1, size=(self.num_states, 2)),
              -0.3, 0.3)
          eval_states += eval_noise
          distance_samples = sess.run(
              eval_op, feed_dict={self.eval_states_ph: eval_states})
          distance_samples = np.reshape(distance_samples,
                                        (self.num_states, self.num_states))
          for s1 in range(self.num_states):
            for s2 in range(self.num_states):
              sampled_distances[(tuple(eval_states[s1]),
                                 tuple(eval_states[s2]))] = (
                                     distance_samples[s1, s2])
      else:
        # Otherwise we just use the last evaluation metric.
        sampled_distances = self.learned_distance
    learned_statistics = {
        'num_iterations': self.num_iterations,
        'metric_errors': metric_errors,
        'average_metric_errors': average_metric_errors,
        'normalized_metric_errors': normalized_metric_errors,
        'average_normalized_metric_errors': average_normalized_metric_errors,
        'learned_distances': sampled_distances,
    }
    self.statistics['learned'] = learned_statistics
    if verbose:
      self.pretty_print_metric(metric_type='learned')
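
As an aside, the bisim_horizon updates follow a geometric schedule: after the
k-th sync it equals 1 - bisim_horizon_discount**k, approaching 1.0. A tiny
illustration (values are examples):

discount = 0.99
value, horizons = 1.0, []
for _ in range(5):
    horizons.append(1.0 - value)  # 0.0, 0.01, 0.0199, 0.029701, ...
    value *= discount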
Example #12
    def build_graph(self):
        """Builds the neural network graph."""

        # define graph
        self.g = tf.Graph()
        with self.g.as_default():

            # create and store a new session for the graph
            self.sess = tf.Session()

            # define placeholders
            self.x = tf.placeholder(shape=[None, self.dim_input],
                                    dtype=tf.float32)
            self.y = tf.placeholder(shape=[None, self.num_classes],
                                    dtype=tf.float32)

            # define simple model
            with tf.variable_scope('last_layer'):
                self.z = tf.layers.dense(inputs=self.x, units=self.num_classes)

            self.loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits_v2(labels=self.y,
                                                           logits=self.z))

            self.output_probs = tf.nn.softmax(self.z)

            # Variables of the last layer
            self.ll_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            self.ll_vars_concat = tf.concat(
                [self.ll_vars[0],
                 tf.expand_dims(self.ll_vars[1], axis=0)], 0)

            # Summary
            _variable_summaries(self.ll_vars_concat)

            # add regularization that acts as a unit Gaussian prior on the last layer
            regularizer = tf.contrib.layers.l2_regularizer(1.0)

            # regularization
            prior = tf.contrib.layers.apply_regularization(
                regularizer, self.ll_vars)
            self.bayesian_loss = self.n * self.loss + prior

            # saving the weights of last layer when running SGLD/SGD/MCMC algorithm
            self.saver = tf.train.Saver(var_list=self.ll_vars,
                                        max_to_keep=self.num_samples)

            # SGLD optimizer for the last layer
            if self.sampler in ['sgld', 'lmc']:
                step = self.step_size / self.n
                gd_opt = tf.train.GradientDescentOptimizer(step)
                grads_vars = gd_opt.compute_gradients(self.bayesian_loss)
                grads_vars_sgld = []

                for g, v in grads_vars:
                    if g is not None:
                        s = list(v.name)
                        s[v.name.rindex(':')] = '_'
                        # Adding Gaussian noise to the gradient
                        gaussian_noise = (np.sqrt(2. / step) *
                                          tf.random_normal(tf.shape(g)))
                        g_sgld = g + gaussian_noise
                        tf.summary.histogram(''.join(s) + '/grad_hist_mcmc',
                                             g / self.n)
                        tf.summary.histogram(
                            ''.join(s) + '/gaussian_noise_hist_mcmc',
                            gaussian_noise / self.n)
                        tf.summary.histogram(
                            ''.join(s) + '/grad_total_hist_mcmc',
                            g_sgld / self.n)
                        grads_vars_sgld.append((g_sgld, v))

                self.train_op = gd_opt.apply_gradients(grads_vars_sgld)

            # SGD optimizer for the last layer
            if self.sampler == 'sgd':
                gd_opt = tf.train.GradientDescentOptimizer(self.step_size)
                grads_vars_sgd = gd_opt.compute_gradients(self.loss)
                self.train_op = gd_opt.apply_gradients(grads_vars_sgd)

                for g, v in grads_vars_sgd:
                    if g is not None:
                        s = list(v.name)
                        s[v.name.rindex(':')] = '_'
                        tf.summary.histogram(''.join(s) + '/grad_hist_sgd', g)

            # Merge all the summaries and write them out
            self.all_summaries = tf.summary.merge_all()
            location = os.path.join(self.working_dir, 'logs')
            self.writer = tf.summary.FileWriter(location, graph=self.g)

            saver_network = tf.train.Saver(var_list=self.ll_vars)
            print('Loading the network...')
            # Restores from checkpoint
            # self.sess.run(tf.global_variables_initializer())
            saver_network.restore(self.sess, self.model_dir)
            print('Graph successfully loaded.')
Example #13
def compute_recall(
        ground_truth_data,
        encoder_fn,
        repr_transform_fn,
        decoder_fn,
        random_state,
        artifact_dir=None,
        #                   num_recall_samples=gin.REQUIRED,
        nhood_sizes=gin.REQUIRED,
        num_interventions_per_latent_dim=gin.REQUIRED,
        num_pca_components=gin.REQUIRED):
    """TBA

  Args:



    random_state: Numpy random state used for randomness.
    artifact_dir: Optional path to directory where artifacts can be saved.
  """
    del artifact_dir
    train_ground_truth_data, test_ground_truth_data = ground_truth_data
    ground_truth_data = train_ground_truth_data
    num_recall_samples = train_ground_truth_data.data_size
    dummy_input = ground_truth_data.sample_observations(1, random_state)
    dummy_mean, dummy_var = encoder_fn(dummy_input)

    # Samples from the normal prior
    latent_dim = repr_transform_fn(*encoder_fn(dummy_input)).shape[-1]
    latent_shape = [num_recall_samples, latent_dim]
    latent_prior_samples_np = np.random.normal(size=latent_shape)

    # Ground truth samples
    gt_train_samples = ground_truth_data.sample_observations(
        num_recall_samples, random_state)
    gt_train_repr = repr_transform_fn(*encoder_fn(gt_train_samples))
    gt_train_repr_mean = np.mean(gt_train_repr, axis=0)  # latent_shape
    gt_train_repr_std = np.std(gt_train_repr, axis=0)
    gt_train_repr_min = np.min(gt_train_repr, axis=0)
    gt_train_repr_max = np.max(gt_train_repr, axis=0)

    # The predetermined set of interventions from the estimated training prior
    fixed_trained_prior_samples_np = np.random.normal(loc=gt_train_repr_mean,
                                                      scale=gt_train_repr_std,
                                                      size=latent_shape)
    print(fixed_trained_prior_samples_np.shape, fixed_trained_prior_samples_np)

    result_d = {
        'nhoods': nhood_sizes,
        'gt_train_repr_mean': list(gt_train_repr_mean),
        'gt_train_repr_std': list(gt_train_repr_std),
        'gt_train_repr_min': list(gt_train_repr_min),
        'gt_train_repr_max': list(gt_train_repr_max)
    }
    sess = tf.Session()
    with sess.as_default():
        n_comp = min(num_recall_samples, num_pca_components)

        print('\n\n\n Computing the total recall...')
        # Sample ground truth data and vae process it
        decoded_gt_samples = decoder_fn(
            repr_transform_fn(*encoder_fn(gt_train_samples)))
        decoded_gt_samples = decoded_gt_samples.reshape(num_recall_samples, -1)
        decoded_gt_pca = PCA(n_components=n_comp)
        reduced_decoded_gt_samples = decoded_gt_pca.fit_transform(
            decoded_gt_samples)

        # Generated samples from the normal prior, processed with gt PCA
        generated_prior_samples = decoder_fn(latent_prior_samples_np)
        generated_prior_samples = generated_prior_samples.reshape(
            num_recall_samples, -1)
        reduced_generated_prior_samples = decoded_gt_pca.transform(
            generated_prior_samples)
        assert (reduced_generated_prior_samples.shape ==
                reduced_decoded_gt_samples.shape)

        # Generated samples from the estimated training prior, processed with gt PCA
        generated_trained_prior_samples = decoder_fn(
            fixed_trained_prior_samples_np)
        generated_trained_prior_samples = generated_trained_prior_samples.reshape(
            num_recall_samples, -1)
        reduced_generated_trained_prior_samples = decoded_gt_pca.transform(
            generated_trained_prior_samples)
        assert (reduced_generated_prior_samples.shape ==
                reduced_generated_trained_prior_samples.shape)

        # --- Original sharp images - discarded to now
        #    gt_train_samples = gt_samples.reshape(num_recall_samples, -1)
        #    gt_pca = PCA(n_components=n_comp)
        #    reduced_gt_samples = gt_pca.fit_transform(gt_samples)
        #    assert(reduced_gt_samples.shape == reduced_decoded_gt_samples.shape)

        #    # compute model recall: gt vs generated
        #    gt_generated_result = iprd.knn_precision_recall_features(
        #        reduced_gt_samples,
        #        reduced_generated_prior_samples,
        #        nhood_sizes=nhood_sizes,
        #        row_batch_size=500, col_batch_size=100, num_gpus=1)
        #    update_result_dict(result_d, ['gt_generated_', gt_generated_result])
        # ----

        # compute model recall: model(gt) vs normal prior generated
        decoded_gt_prior_generated_result = iprd.knn_precision_recall_features(
            reduced_decoded_gt_samples,
            reduced_generated_prior_samples,
            nhood_sizes=nhood_sizes,
            row_batch_size=500,
            col_batch_size=100,
            num_gpus=1)
        update_result_dict(
            result_d,
            ['decoded_gt_prior_generated_', decoded_gt_prior_generated_result])

        # compute model recall: model(gt) vs estimated training prior generated
        decoded_gt_trained_prior_generated_result = iprd.knn_precision_recall_features(
            reduced_decoded_gt_samples,
            reduced_generated_trained_prior_samples,
            nhood_sizes=nhood_sizes,
            row_batch_size=500,
            col_batch_size=100,
            num_gpus=1)
        update_result_dict(result_d, [
            'decoded_gt_trained_prior_generated_',
            decoded_gt_trained_prior_generated_result
        ])

        # compute model recall: normal prior generated vs estimated training prior generated
        prior_generated_trained_prior_generated_result = iprd.knn_precision_recall_features(
            reduced_generated_prior_samples,
            reduced_generated_trained_prior_samples,
            nhood_sizes=nhood_sizes,
            row_batch_size=500,
            col_batch_size=100,
            num_gpus=1)
        update_result_dict(result_d, [
            'prior_generated_trained_prior_generated_',
            prior_generated_trained_prior_generated_result
        ])

        # Choose a subset of interventions
        subset_interventions = np.random.choice(
            np.arange(num_recall_samples),
            size=num_interventions_per_latent_dim,
            replace=False)
        result_d['subset_interventions'] = list(subset_interventions)
        result_d['num_pca_comp'] = n_comp

        # Pick a latent dimension
        for dim in range(latent_dim):
            print('\n\n\n Computing the recall for latent dim ', dim)

            agg_fix_one_vs_prior_generated_result = {
                'precision': [],
                'recall': []
            }
            agg_fix_one_vs_trained_prior_generated_result = {
                'precision': [],
                'recall': []
            }
            agg_fix_one_vs_decoded_gt_result = {'precision': [], 'recall': []}

            agg_vary_one_vs_prior_generated_result = {
                'precision': [],
                'recall': []
            }
            agg_vary_one_vs_trained_prior_generated_result = {
                'precision': [],
                'recall': []
            }
            agg_vary_one_vs_decoded_gt_result = {'precision': [], 'recall': []}

            # intervene several times
            for intervention in range(num_interventions_per_latent_dim):
                inter_id = subset_interventions[intervention]
                print('\n\n\n Intervention num', intervention)
                print(' Intervention row', inter_id)

                # --- fix one, vary the rest
                latent_intervention = float(
                    fixed_trained_prior_samples_np[inter_id, dim])
                fix_one_latent_from_trained_prior_samples = np.copy(
                    fixed_trained_prior_samples_np)
                fix_one_latent_from_trained_prior_samples[:,
                                                          dim] = latent_intervention
                # decode the samples and transform them with the PCA
                gen_fix_one_latent_from_trained_prior_samples = decoder_fn(
                    fix_one_latent_from_trained_prior_samples).reshape(
                        num_recall_samples, -1)
                reduced_gen_fix_one_latent_from_trained_prior_samples = decoded_gt_pca.transform(
                    gen_fix_one_latent_from_trained_prior_samples)
                assert(reduced_gen_fix_one_latent_from_trained_prior_samples.shape == \
                       reduced_generated_prior_samples.shape)

                # - Calculate the relative recall
                # compare to the generated normal prior samples
                fix_one_vs_prior_generated_result = iprd.knn_precision_recall_features(
                    reduced_generated_prior_samples,
                    reduced_gen_fix_one_latent_from_trained_prior_samples,
                    nhood_sizes=nhood_sizes,
                    row_batch_size=500,
                    col_batch_size=100,
                    num_gpus=1)
                print(agg_fix_one_vs_prior_generated_result)
                print(fix_one_vs_prior_generated_result)
                agg_fix_one_vs_prior_generated_result = agg_recall_dict(
                    agg_fix_one_vs_prior_generated_result,
                    fix_one_vs_prior_generated_result, intervention)
                print(agg_fix_one_vs_prior_generated_result)

                # compare to the generated trained prior samples
                fix_one_vs_trained_prior_generated_result = iprd.knn_precision_recall_features(
                    reduced_generated_trained_prior_samples,
                    reduced_gen_fix_one_latent_from_trained_prior_samples,
                    nhood_sizes=nhood_sizes,
                    row_batch_size=500,
                    col_batch_size=100,
                    num_gpus=1)
                agg_fix_one_vs_trained_prior_generated_result = agg_recall_dict(
                    agg_fix_one_vs_trained_prior_generated_result,
                    fix_one_vs_trained_prior_generated_result, intervention)

                # compare to the decoded gt
                fix_one_vs_decoded_gt_result = iprd.knn_precision_recall_features(
                    reduced_decoded_gt_samples,
                    reduced_gen_fix_one_latent_from_trained_prior_samples,
                    nhood_sizes=nhood_sizes,
                    row_batch_size=500,
                    col_batch_size=100,
                    num_gpus=1)
                agg_fix_one_vs_decoded_gt_result = agg_recall_dict(
                    agg_fix_one_vs_decoded_gt_result,
                    fix_one_vs_decoded_gt_result, intervention)

                # --- vary one, fix the rest
                latent_variation = np.copy(fixed_trained_prior_samples_np[:,
                                                                          dim])
                print(latent_variation.shape, latent_variation)
                vary_one_latent_from_trained_prior_samples = np.copy(
                    fixed_trained_prior_samples_np[inter_id]).reshape(
                        1, latent_dim)
                vary_one_latent_from_trained_prior_samples = np.full(
                    latent_shape, vary_one_latent_from_trained_prior_samples)
                print(vary_one_latent_from_trained_prior_samples.shape)
                print(vary_one_latent_from_trained_prior_samples)

                vary_one_latent_from_trained_prior_samples[:,
                                                           dim] = latent_variation
                print(vary_one_latent_from_trained_prior_samples.shape)
                print(vary_one_latent_from_trained_prior_samples)

                # decode the samples and transform them with the PCA
                gen_vary_one_latent_from_trained_prior_samples = decoder_fn(
                    vary_one_latent_from_trained_prior_samples).reshape(
                        num_recall_samples, -1)
                reduced_gen_vary_one_latent_from_trained_prior_samples = decoded_gt_pca.transform(
                    gen_vary_one_latent_from_trained_prior_samples)
                assert(reduced_gen_vary_one_latent_from_trained_prior_samples.shape == \
                       reduced_generated_prior_samples.shape)

                # - Calculate the recall
                # compare to the generated normal prior samples
                vary_one_vs_prior_generated_result = iprd.knn_precision_recall_features(
                    reduced_generated_prior_samples,
                    reduced_gen_vary_one_latent_from_trained_prior_samples,
                    nhood_sizes=nhood_sizes,
                    row_batch_size=500,
                    col_batch_size=100,
                    num_gpus=1)
                print(agg_vary_one_vs_prior_generated_result)
                print(vary_one_vs_prior_generated_result)
                agg_vary_one_vs_prior_generated_result = agg_recall_dict(
                    agg_vary_one_vs_prior_generated_result,
                    vary_one_vs_prior_generated_result, intervention)
                print(agg_vary_one_vs_prior_generated_result)

                # compare to the generated trained prior samples
                vary_one_vs_trained_prior_generated_result = iprd.knn_precision_recall_features(
                    reduced_generated_trained_prior_samples,
                    reduced_gen_vary_one_latent_from_trained_prior_samples,
                    nhood_sizes=nhood_sizes,
                    row_batch_size=500,
                    col_batch_size=100,
                    num_gpus=1)
                agg_vary_one_vs_trained_prior_generated_result = agg_recall_dict(
                    agg_vary_one_vs_trained_prior_generated_result,
                    vary_one_vs_trained_prior_generated_result, intervention)

                # compare to the decoded gt
                vary_one_vs_decoded_gt_result = iprd.knn_precision_recall_features(
                    reduced_decoded_gt_samples,
                    reduced_gen_vary_one_latent_from_trained_prior_samples,
                    nhood_sizes=nhood_sizes,
                    row_batch_size=500,
                    col_batch_size=100,
                    num_gpus=1)
                agg_vary_one_vs_decoded_gt_result = agg_recall_dict(
                    agg_vary_one_vs_decoded_gt_result,
                    vary_one_vs_decoded_gt_result, intervention)

            update_result_dict_with_agg(result_d, [
                str(dim) + '_fix_one_vs_prior_generated_',
                agg_fix_one_vs_prior_generated_result
            ], [
                str(dim) + '_fix_one_vs_trained_prior_generated_',
                agg_fix_one_vs_trained_prior_generated_result
            ], [
                str(dim) + '_fix_one_vs_decoded_gt_',
                agg_fix_one_vs_decoded_gt_result
            ], [
                str(dim) + '_vary_one_vs_prior_generated_',
                agg_vary_one_vs_prior_generated_result
            ], [
                str(dim) + '_vary_one_vs_trained_prior_generated_',
                agg_vary_one_vs_trained_prior_generated_result
            ], [
                str(dim) + '_vary_one_vs_decoded_gt_',
                agg_vary_one_vs_decoded_gt_result
            ])
    print(result_d)
    return [result_d]
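
The helpers update_result_dict, agg_recall_dict, and update_result_dict_with_agg
are not shown. Minimal sketches consistent with the call sites above, assuming
iprd.knn_precision_recall_features returns a dict with 'precision' and 'recall'
arrays and numpy is imported as np (assumptions, not the originals):

def update_result_dict(result_d, prefixed_result):
    # prefixed_result is [prefix, {'precision': ..., 'recall': ...}].
    prefix, result = prefixed_result
    for key, value in result.items():
        result_d[prefix + key] = np.asarray(value).tolist()

def agg_recall_dict(agg, result, intervention):
    # Accumulates one precision/recall curve per intervention.
    agg['precision'].append(np.asarray(result['precision']).tolist())
    agg['recall'].append(np.asarray(result['recall']).tolist())
    return agg

def update_result_dict_with_agg(result_d, *prefixed_aggs):
    for prefix, agg in prefixed_aggs:
        result_d[prefix + 'precision'] = agg['precision']
        result_d[prefix + 'recall'] = agg['recall']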
Example #14
def run_experiment(agent,
                   environment,
                   start_iteration,
                   obs_stacker,
                   experiment_logger,
                   experiment_checkpointer,
                   checkpoint_dir,
                   num_iterations=200,
                   training_steps=5000,
                   logging_file_prefix='log',
                   log_every_n=100,
                   checkpoint_every_n=1):
    """Runs a full experiment, spread over multiple iterations."""
    tf.logging.info('Beginning training...')
    if num_iterations <= start_iteration:
        tf.logging.warning('num_iterations (%d) <= start_iteration (%d)',
                           num_iterations, start_iteration)
        return
    """ 
  run_one_episode() updates the metrics
  metrics compute tf.summaries
  """

    # -----------

    # train_summary_writer = tf.compat.v2.summary.create_file_writer(checkpoint_dir+'_tensorboard/', flush_millis=1000)
    # train_summary_writer.set_as_default()

    # metric_avg_return = AverageReturnMetric()
    # env_steps = EnvironmentSteps()

    # observers = [metric_avg_return]
    # global_step = tf.compat.v1.train.get_or_create_global_step()
    # write graph to disk
    # with tf.compat.v2.summary.record_if(lambda: tf.math.equal(global_step % 5, 0)):
    #   summary_avg_return = tf.identity(metric_avg_return.tf_summaries(train_step=global_step))
    #   with tf.Session() as sess:
    #     initialize_uninitialized_variables(sess)
    #     sess.run(train_summary_writer.init())
    # -----------
    tf.reset_default_graph()
    sess = tf.Session()
    # Create the summary writer once; creating it inside the loop would start
    # a new event file every iteration.
    summary_writer = tf.summary.FileWriter(checkpoint_dir + '/summary/')
    for iteration in range(start_iteration, num_iterations):
        # # -----------
        # global_step_val = sess.run(global_step)
        # # -----------
        # start_time = time.time()
        statistics = run_one_iteration(agent,
                                       environment,
                                       obs_stacker,
                                       iteration,
                                       training_steps,
                                       observers=None)
        # tf.logging.info('Iteration %d took %d seconds', iteration, time.time() - start_time)
        # start_time = time.time()
        log_experiment(experiment_logger, iteration, statistics,
                       logging_file_prefix, log_every_n)
        # tf.logging.info('Logging iteration %d took %d seconds', iteration, time.time() - start_time)

        # start_time = time.time()
        checkpoint_experiment(experiment_checkpointer, agent,
                              experiment_logger, iteration, checkpoint_dir,
                              checkpoint_every_n)
        summary = tf.Summary()
        summary.value.add(tag='AverageReturn/EnvironmentSteps',
                          simple_value=statistics['average_return'][0])
        summary_writer.add_summary(summary, statistics['env_steps'][0])
        summary_writer.flush()
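For context, a hedged sketch of how run_experiment might be wired up. Every helper here (create_environment, ObservationStacker, create_agent, Logger, Checkpointer) is a hypothetical stand-in for project-specific factories, not part of the code above:

# Hypothetical factories; the real project supplies its own equivalents.
environment = create_environment()
obs_stacker = ObservationStacker(environment)
agent = create_agent(environment, obs_stacker)
experiment_logger = Logger('/tmp/experiment/logs')
experiment_checkpointer = Checkpointer('/tmp/experiment/checkpoints')

run_experiment(agent,
               environment,
               start_iteration=0,
               obs_stacker=obs_stacker,
               experiment_logger=experiment_logger,
               experiment_checkpointer=experiment_checkpointer,
               checkpoint_dir='/tmp/experiment/checkpoints',
               num_iterations=200,
               training_steps=5000)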
Exemple #15
0
    def initialize_session(self):
        """Initializes a tf Session."""
        if ENABLE_TF_OPTIMIZATIONS:
            self.sess = tf.Session()
        else:
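            # rewriter_config_pb2 comes from tensorflow.core.protobuf.
            # Disabling every grappler rewrite keeps the executed graph
            # identical to the one that was built.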
            rewriter_config = rewriter_config_pb2.RewriterConfig(
                disable_model_pruning=True,
                constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
                arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
                remapping=rewriter_config_pb2.RewriterConfig.OFF,
                shape_optimization=rewriter_config_pb2.RewriterConfig.OFF,
                dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
                function_optimization=rewriter_config_pb2.RewriterConfig.OFF,
                layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF,
                loop_optimization=rewriter_config_pb2.RewriterConfig.OFF,
                memory_optimization=rewriter_config_pb2.RewriterConfig.NO_MEM_OPT)
            graph_options = tf.GraphOptions(rewrite_options=rewriter_config)
            session_config = tf.ConfigProto(graph_options=graph_options)
            self.sess = tf.Session(config=session_config)

        # Restore or initialize the variables.
        self.sess.run(tf.global_variables_initializer())
        self.sess.run(tf.local_variables_initializer())
        if self.learner_config.checkpoint_for_eval:
            # Requested a specific checkpoint.
            self.saver.restore(self.sess,
                               self.learner_config.checkpoint_for_eval)
            tf.logging.info('Restored checkpoint: %s' %
                            self.learner_config.checkpoint_for_eval)
        else:
            # Continue from the latest checkpoint if one exists.
            # This handles fault-tolerance.
            latest_checkpoint = None
            if self.checkpoint_dir is not None:
                latest_checkpoint = tf.train.latest_checkpoint(
                    self.checkpoint_dir)
            if latest_checkpoint:
                self.saver.restore(self.sess, latest_checkpoint)
                tf.logging.info('Restored checkpoint: %s' % latest_checkpoint)
            else:
                tf.logging.info('No previous checkpoint.')
                self.sess.run(tf.global_variables_initializer())
                self.sess.run(tf.local_variables_initializer())

        # For episodic models, optionally start training from pretrained
        # weights. This overwrites the embedding weights while taking care
        # not to restore the Adam parameters.
        if self.learner_config.pretrained_checkpoint and not self.sess.run(
                tf.train.get_global_step()):
            self.saver.restore(self.sess,
                               self.learner_config.pretrained_checkpoint)
            tf.logging.info('Restored checkpoint: %s' %
                            self.learner_config.pretrained_checkpoint)
            # We only want the embedding weights of the checkpoint we just
            # restored, so we re-initialize everything that is not an embedding
            # weight. Since this episodic finetuning procedure is a different
            # optimization problem than the original training of the baseline
            # whose embedding weights are re-used, we do not reload Adam's
            # variables and instead learn them from scratch.
            vars_to_reinit, embedding_var_names, vars_to_reinit_names = [], [], []
            for var in tf.global_variables():
                if (any(keyword in var.name for keyword in EMBEDDING_KEYWORDS)
                        and 'adam' not in var.name.lower()):
                    embedding_var_names.append(var.name)
                    continue
                vars_to_reinit.append(var)
                vars_to_reinit_names.append(var.name)
            tf.logging.info('Initializing all variables except for %s.' %
                            embedding_var_names)
            self.sess.run(tf.variables_initializer(vars_to_reinit))
            tf.logging.info('Re-initialized vars %s.' % vars_to_reinit_names)
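    # A standalone sketch of the embedding-preserving re-init pattern above,
    # assuming EMBEDDING_KEYWORDS is defined as in initialize_session; this
    # hypothetical helper is not part of the original class.
    def _reinit_all_but_embeddings(self):
        """Re-initializes every global variable except embedding weights."""
        vars_to_reinit = [
            v for v in tf.global_variables()
            if not (any(kw in v.name for kw in EMBEDDING_KEYWORDS)
                    and 'adam' not in v.name.lower())
        ]
        self.sess.run(tf.variables_initializer(vars_to_reinit))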
    def __init__(self,
                 base_dir,
                 agent_creator,
                 create_environment_fn=create_atari_environment,
                 game_name=None,
                 checkpoint_file_prefix='ckpt',
                 logging_file_prefix='log',
                 log_every_n=1,
                 num_iterations=200,
                 training_steps=250000,
                 evaluation_steps=125000,
                 max_steps_per_episode=27000):
        """Initialize the Runner object in charge of running a full experiment.

    Args:
      base_dir: str, the base directory to host all required sub-directories.
      agent_creator: A function that takes as args a Tensorflow session and an
        Atari 2600 Gym environment, and returns an agent.
      create_environment_fn: A function which receives a game name and creates
        an Atari 2600 Gym environment.
      game_name: str, name of the Atari 2600 domain to run.
      checkpoint_file_prefix: str, the prefix to use for checkpoint files.
      logging_file_prefix: str, prefix to use for the log files.
      log_every_n: int, the frequency for writing logs.
      num_iterations: int, the iteration number threshold (must be greater than
        start_iteration).
      training_steps: int, the number of training steps to perform.
      evaluation_steps: int, the number of evaluation steps to perform.
      max_steps_per_episode: int, maximum number of steps after which an episode
        terminates.

    This constructor will take the following actions:
    - Initialize an environment.
    - Initialize a `tf.Session`.
    - Initialize a logger.
    - Initialize an agent.
    - Reload from the latest checkpoint, if available, and initialize the
      Checkpointer object.
    """
        assert base_dir is not None
        self._logging_file_prefix = logging_file_prefix
        self._log_every_n = log_every_n
        self._num_iterations = num_iterations
        self._training_steps = training_steps
        self._evaluation_steps = evaluation_steps
        self._max_steps_per_episode = max_steps_per_episode
        self._base_dir = base_dir
        self._create_directories()
        self._summary_writer = tf.summary.FileWriter(self._base_dir)

        self._environment = create_environment_fn(game_name)
        # Set up a session and initialize variables.
        self._sess = tf.Session(
            '', config=tf.ConfigProto(allow_soft_placement=True))
        self._agent = agent_creator(self._sess,
                                    self._environment,
                                    summary_writer=self._summary_writer)
        self._summary_writer.add_graph(graph=tf.get_default_graph())
        self._sess.run(tf.global_variables_initializer())

        self._summary_helper = SummaryHelper(self._summary_writer)

        self._initialize_checkpointer_and_maybe_resume(checkpoint_file_prefix)
        self._steps_done = 0

        self._total_timer = None
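A hedged usage sketch for the constructor above, assuming the enclosing class is named Runner, as its docstring suggests. create_my_agent is a hypothetical factory matching the agent_creator contract (a TensorFlow session and an environment in, an agent out), and MyAgent is a placeholder for any agent implementing the expected API:

def create_my_agent(sess, environment, summary_writer=None):
    # Hypothetical factory; return any agent implementing the expected API.
    return MyAgent(sess, environment, summary_writer=summary_writer)

runner = Runner(base_dir='/tmp/atari_run',
                agent_creator=create_my_agent,
                game_name='Pong',
                num_iterations=200,
                training_steps=250000,
                evaluation_steps=125000)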