예제 #1
0
    def fake_train(self):
        """Drive the learner on dummy experiences under the TF profiler.

        Builds one zero-filled experience of length
        ``self.learner.config.experience_length``, replicates it
        ``self.batch_size`` times, and repeatedly calls
        ``self.learner.train`` inside a ``tf.contrib.tfprof.ProfileContext``.
        That context writes traces and automatic profiles to ``/tmp/train_dir``.

        Sweeps continue until ``self.sweep_limit`` have run. The counter is
        compared with ``!=``, so a negative ``sweep_limit`` never stops.
        Presumably that is an intentional "run forever" setting; confirm.
        """
        length = self.learner.config.experience_length
        blank = ssbm.prepareStateActions((ssbm.SimpleStateAction * length)())
        # Initial recurrent state. util.deepMap presumably applies np.zeros to
        # each leaf of the core's hidden_size structure (confirm).
        blank['initial'] = util.deepMap(np.zeros, self.learner.core.hidden_size)

        # Every batch entry is the *same* object. That is fine as long as
        # learner.train does not mutate its inputs.
        batch = [blank] * self.batch_size

        option_builder = tf.profiler.ProfileOptionBuilder
        # Time-and-memory report, ordered by 'micros'.
        op_options = (option_builder(option_builder.time_and_memory())
                      .order_by('micros')
                      .build())
        # Parameter report for trainable variables.
        param_options = option_builder.trainable_variables_parameter()

        # Trace profiler steps 10..19 (range end is exclusive) and dump the
        # accumulated profile at step 20. The dump can later be inspected
        # with the tfprof command line or web UI.
        with tf.contrib.tfprof.ProfileContext('/tmp/train_dir',
                                              trace_steps=range(10, 20),
                                              dump_steps=[20]) as profiler:
            # 'op' view at steps 15, 18 and 20; 'scope' view at step 20.
            profiler.add_auto_profiling('op', op_options, [15, 18, 20])
            profiler.add_auto_profiling('scope', param_options, [20])

            sweeps = 0
            while sweeps != self.sweep_limit:
                self.learner.train(batch, self.batch_steps)
                sweeps += 1
예제 #2
0
  def train(self):
    """Run ``self.epochs`` epochs of training with a held-out validation split.

    ``self.experiences`` is cut along its first axis into consecutive
    mini-batches of ``self.batch_size`` rows; the last one may be shorter.
    The first ``self.valid_batches`` mini-batches are only evaluated
    (``train=False``), and the rest are trained on. After every epoch the
    training time is printed and the model is saved via ``self.rl.save()``.
    Where the platform provides the ``resource`` module, the process's peak
    memory usage is printed as well.
    """
    # ``resource`` is POSIX-only. Importing it unguarded inside the epoch
    # loop would crash other platforms right after the first epoch's save.
    try:
      import resource
    except ImportError:
      resource = None

    # The first dimension of the 'action' entry is taken as the dataset size.
    # All entries are presumably aligned on their first axis; confirm.
    data_size = self.experiences['action'].shape[0]
    step = self.batch_size

    # util.deepMap applies the slice to the experiences structure. ``start``
    # is bound as a default so each lambda keeps its own offset even if
    # deepMap evaluated lazily.
    batches = [
        util.deepMap(lambda t, start=start: t[start:start + step], self.experiences)
        for start in range(0, data_size, step)
    ]

    valid_batches = batches[:self.valid_batches]
    train_batches = batches[self.valid_batches:]

    for epoch in range(self.epochs):
      print("Epoch", epoch)
      # perf_counter is monotonic, so the reported duration is unaffected by
      # wall-clock adjustments (time.time() is not).
      start_time = time.perf_counter()

      for batch in train_batches:
        self.rl.train(batch, log=False, zipped=True)

      print(time.perf_counter() - start_time)

      for batch in valid_batches:
        self.rl.train(batch, train=False, zipped=True)

      self.rl.save()

      if resource is not None:
        # NOTE: ru_maxrss is in kilobytes on Linux but in bytes on macOS, so
        # the "(kb)" label is only accurate on Linux.
        print('Memory usage: %s (kb)' % resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
예제 #3
0
파일: tf_lib.py 프로젝트: jCrompton/phillip
def scan(f, inputs, initial_state, axis=0):
  """Scan ``f`` over ``axis`` of ``inputs`` with a Python-level, unrolled loop.

  Each tensor in ``inputs`` is split with ``tf.unstack`` along ``axis``.
  ``util.deepIter`` then presumably yields one structure of per-step slices
  per step (confirm against util). ``f(state, step_input)`` is applied to
  each step in order, and its return value is threaded into the next call.
  All per-step states are re-stacked along ``axis`` via
  ``util.deepZipWith`` + ``tf.stack``.

  Args:
    f: step function ``(state, step_input) -> new_state``.
    inputs: tensor or nested structure of tensors, as handled by
      ``util.deepMap``. ``tf.unstack`` needs the size along ``axis`` to be
      inferable.
    initial_state: state passed to the first call of ``f``.
    axis: axis to iterate over, and to stack the per-step states along.

  Returns:
    The per-step return values of ``f``, stacked along ``axis``.

  NOTE(review): with zero steps, ``util.deepZipWith`` receives no
  arguments. What happens then depends on util; confirm.
  """
  def split(tensor):
    return iter(tf.unstack(tensor, axis=axis))

  steps = util.deepIter(util.deepMap(split, inputs))

  state = initial_state
  history = []
  for step_input in steps:
    state = f(state, step_input)
    history.append(state)

  def restack(*per_step):
    return tf.stack(per_step, axis=axis)

  return util.deepZipWith(restack, *history)