Example 1
 def get_samples(self, env):
     samples = [
         agent_factory.Sample(traj=collect_traj_for_program(env, program),
                              prob=prob)
         for program, prob in self.sketch_programs
     ]
     return samples
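get_samples replays every (program, prob) pair stored in self.sketch_programs, so the stored probabilities become the sample weights. A hypothetical call pattern (the annotator variable and the weight sum are illustrative, not from the source):

    # samples = annotator.get_samples(env)
    # total_weight = sum(s.prob for s in samples)  # probs act as sample weights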
Example 2
 def weight_samples(samples):
     if FLAGS.use_replay_prob_as_weight:
         new_samples = []
         for sample in samples:
             name = sample.traj.env_name
             if name in replay_buffer.prob_sum_dict:
                 replay_prob = max(
                     replay_buffer.prob_sum_dict[name],
                     FLAGS.min_replay_weight)
             else:
                 replay_prob = 0.0
             scale = 1.0 - replay_prob
             new_samples.append(
                 agent_factory.Sample(traj=sample.traj,
                                      prob=sample.prob * scale))
     else:
         new_samples = agent_factory.scale_probs(
             samples, 1 - FLAGS.fixed_replay_weight)
     return new_samples
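A worked example with illustrative numbers: if prob_sum_dict[name] = 0.7 and min_replay_weight = 0.1, then replay_prob = max(0.7, 0.1) = 0.7, so scale = 1.0 - 0.7 = 0.3 and a sample with prob 0.2 is reweighted to 0.2 * 0.3 = 0.06. An env absent from the buffer gets replay_prob = 0.0 and keeps its full weight. Note that the max() clipping means min_replay_weight sets a floor on the replay weight, and hence a cap of 1 - min_replay_weight on the on-policy weight.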
Example 3
    def annotate_example_exploration(self, envs):
        # First create the envs from the JSON specs.
        envs = json_to_envs(envs)

        if FLAGS.executor == 'wtq':
            oracle_envs, oracle_trajs = get_wtq_annotations(envs)
        else:
            oracle_envs, oracle_trajs = get_env_trajs(envs)

        oracle_env_programs = [(env.name,
                                traj_to_program(traj, envs[0].de_vocab))
                               for env, traj in zip(oracle_envs, oracle_trajs)]
        env_name_program_dict = dict(oracle_env_programs)
        env_name_annotation_dict = dict()

        for env in envs:
            program = env_name_program_dict.get(env.name, None)

            if program is None:
                env_name_annotation_dict[env.name] = None
            else:
                sketch = get_sketch(program)
                explored_programs = SketchAnnotator.explore_sketch_programs(
                    env, sketch)
                explored_trajs = [
                    collect_traj_for_program(env, explored_program)
                    for explored_program in explored_programs
                ]
                samples = [
                    agent_factory.Sample(traj, prob=1.0)
                    for traj in explored_trajs
                ]

                annotation = SketchAnnotation(env, sketch, samples)
                env_name_annotation_dict[env.name] = annotation

        return env_name_annotation_dict
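The returned dict maps each env name to a SketchAnnotation, or to None when no oracle program was found. A hypothetical consumption loop (train_on and the attribute access are placeholders; SketchAnnotation's fields follow its constructor call above):

    # for name, annotation in env_name_annotation_dict.items():
    #     if annotation is None:
    #         continue  # no oracle program was found for this env
    #     train_on(annotation.samples)  # placeholder trainer hook

Example 4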
  def run(self):
    agent, envs = init_experiment(
      [get_train_shard_path(i) for i in self.shard_ids],
      use_gpu=FLAGS.actor_use_gpu,
      gpu_id=str(self.actor_id + FLAGS.actor_gpu_start_id))
    graph = agent.model.graph
    current_ckpt = get_init_model_path()

    env_dict = dict([(env.name, env) for env in envs])
    replay_buffer = agent_factory.AllGoodReplayBuffer(agent, envs[0].de_vocab)

    # Load saved programs to warm start the replay buffer. 
    if FLAGS.load_saved_programs:
      load_programs(
        envs, replay_buffer, FLAGS.saved_program_file)

    i = 0
    while True:
      # Create the logging files. 
      if FLAGS.log_samples_every_n_epoch > 0 and i % FLAGS.log_samples_every_n_epoch == 0:
        f_replay = codecs.open(os.path.join(
          get_experiment_dir(), 'replay_samples_{}_{}.txt'.format(self.name, i)),
                               'w', encoding='utf-8')
        f_policy = codecs.open(os.path.join(
          get_experiment_dir(), 'policy_samples_{}_{}.txt'.format(self.name, i)),
                               'w', encoding='utf-8')
        f_train = codecs.open(os.path.join(
          get_experiment_dir(), 'train_samples_{}_{}.txt'.format(self.name, i)),
                              'w', encoding='utf-8')

      n_train_samples = 0
      if FLAGS.use_replay_samples_in_train:
        n_train_samples += FLAGS.n_replay_samples

      if FLAGS.use_policy_samples_in_train and FLAGS.use_nonreplay_samples_in_train:
        raise ValueError(
          'Cannot use both on-policy samples and nonreplay samples for training!')
        
      if FLAGS.use_policy_samples_in_train or FLAGS.use_nonreplay_samples_in_train:
        # Note that nonreplay samples are drawn by rejection
        # sampling from on-policy samples.
        n_train_samples += FLAGS.n_policy_samples

      # Make sure that all the samples from the env batch
      # fit into one training batch.
      if FLAGS.batch_size < n_train_samples:
        raise ValueError(
            'One batch has to contain at least the samples from one environment.')

      env_batch_size = FLAGS.batch_size // n_train_samples  # integer division
      
      env_iterator = data_utils.BatchIterator(
        dict(envs=envs), shuffle=True,
        batch_size=env_batch_size)

      for j, batch_dict in enumerate(env_iterator):
        batch_envs = batch_dict['envs']
        tf.logging.info('=' * 50)
        tf.logging.info('{} iteration {}, batch {}: {} envs'.format(
            self.name, i, j, len(batch_envs)))
  
        t1 = time.time()
        # Generate samples with cache and save to replay buffer.
        t3 = time.time()
        n_explore = 0
        for _ in xrange(FLAGS.n_explore_samples):
          explore_samples = agent.generate_samples(
            batch_envs, n_samples=1, use_cache=FLAGS.use_cache,
            greedy=FLAGS.greedy_exploration)
          replay_buffer.save(explore_samples)
          n_explore += len(explore_samples)

        if FLAGS.n_extra_explore_for_hard > 0:
          hard_envs = [env for env in batch_envs
                       if not replay_buffer.has_found_solution(env.name)]
          if hard_envs:
            for _ in xrange(FLAGS.n_extra_explore_for_hard):
              explore_samples = agent.generate_samples(
                hard_envs, n_samples=1, use_cache=FLAGS.use_cache,
                greedy=FLAGS.greedy_exploration)
              replay_buffer.save(explore_samples)
              n_explore += len(explore_samples)

        t4 = time.time()
        tf.logging.info('{} sec used generating {} exploration samples.'.format(
          t4 - t3, n_explore))

        tf.logging.info('{} samples saved in the replay buffer.'.format(
          replay_buffer.size))
        
        t3 = time.time()
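        # Draw replay samples from the buffer: top-k trajectories when
        # FLAGS.use_top_k_replay_samples is set; with
        # FLAGS.random_replay_samples the agent is withheld (agent=None),
        # so selection does not use model probabilities.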
        replay_samples = replay_buffer.replay(
          batch_envs, FLAGS.n_replay_samples,
          use_top_k=FLAGS.use_top_k_replay_samples,
          agent=None if FLAGS.random_replay_samples else agent,
          truncate_at_n=FLAGS.truncate_replay_buffer_at_n)
        t4 = time.time()
        tf.logging.info('{} sec used selecting {} replay samples.'.format(
          t4 - t3, len(replay_samples)))
          
        t3 = time.time()
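        # On-policy samples: greedy decoding when a single top sample is
        # requested, beam search for k > 1, otherwise stochastic sampling.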
        if FLAGS.use_top_k_policy_samples:
          if FLAGS.n_policy_samples == 1:
            policy_samples = agent.generate_samples(
              batch_envs, n_samples=FLAGS.n_policy_samples,
              greedy=True)
          else:
            policy_samples = agent.beam_search(
              batch_envs, beam_size=FLAGS.n_policy_samples)
        else:
          policy_samples = agent.generate_samples(
            batch_envs, n_samples=FLAGS.n_policy_samples,
            greedy=False)
        t4 = time.time()
        tf.logging.info('{} sec used generating {} on-policy samples'.format(
          t4-t3, len(policy_samples)))

        t2 = time.time()
        tf.logging.info(
          ('{} sec used generating replay and on-policy samples,'
           ' {} iteration {}, batch {}: {} envs').format(
            t2-t1, self.name, i, j, len(batch_envs)))

        t1 = time.time()
        self.eval_queue.put((policy_samples, len(batch_envs)))
        self.replay_queue.put((replay_samples, len(batch_envs)))

        assert 0.0 <= FLAGS.fixed_replay_weight <= 1.0

        # Weight replay samples by the (clipped) total probability mass the
        # buffer already holds for their env; weight_samples below applies
        # the complementary 1 - replay_prob to on-policy samples.
        if FLAGS.use_replay_prob_as_weight:
          new_samples = []
          for sample in replay_samples:
            name = sample.traj.env_name
            if name in replay_buffer.prob_sum_dict:
              replay_prob = max(
                replay_buffer.prob_sum_dict[name], FLAGS.min_replay_weight)
            else:
              replay_prob = 0.0
            scale = replay_prob
            new_samples.append(
              agent_factory.Sample(
                traj=sample.traj,
                prob=sample.prob * scale))
          replay_samples = new_samples
        else:
          replay_samples = agent_factory.scale_probs(
            replay_samples, FLAGS.fixed_replay_weight)

        replay_samples = sorted(
          replay_samples, key=lambda x: x.traj.env_name)

        policy_samples = sorted(
          policy_samples, key=lambda x: x.traj.env_name)

        if FLAGS.use_nonreplay_samples_in_train:
          nonreplay_samples = []
          for sample in policy_samples:
            if not replay_buffer.contain(sample.traj):
              nonreplay_samples.append(sample)

        replay_buffer.save(policy_samples)

        def weight_samples(samples):
          if FLAGS.use_replay_prob_as_weight:
            new_samples = []
            for sample in samples:
              name = sample.traj.env_name
              if name in replay_buffer.prob_sum_dict:
                replay_prob = max(
                  replay_buffer.prob_sum_dict[name],
                  FLAGS.min_replay_weight)
              else:
                replay_prob = 0.0
              scale = 1.0 - replay_prob
              new_samples.append(
                agent_factory.Sample(
                  traj=sample.traj,
                  prob=sample.prob * scale))
          else:
            new_samples = agent_factory.scale_probs(
              samples, 1 - FLAGS.fixed_replay_weight)
          return new_samples

        # Assemble the training batch: replay samples (optionally deferring
        # their probs to the trainer), plus reweighted on-policy or
        # non-replay samples (never both; checked earlier in run()).
        train_samples = []
        if FLAGS.use_replay_samples_in_train:
          if FLAGS.use_trainer_prob:
            replay_samples = [
              sample._replace(prob=None) for sample in replay_samples]
          train_samples += replay_samples

        if FLAGS.use_policy_samples_in_train:
          train_samples += weight_samples(policy_samples)

        if FLAGS.use_nonreplay_samples_in_train:
          train_samples += weight_samples(nonreplay_samples)
        
        train_samples = sorted(train_samples, key=lambda x: x.traj.env_name)
        tf.logging.info('{} train samples'.format(len(train_samples)))

        if FLAGS.use_importance_sampling:
          step_logprobs = agent.compute_step_logprobs(
            [s.traj for s in train_samples])
        else:
          step_logprobs = None

        if FLAGS.use_replay_prob_as_weight:
          n_clip = 0
          for env in batch_envs:
            name = env.name
            if (name in replay_buffer.prob_sum_dict and
                replay_buffer.prob_sum_dict[name] < FLAGS.min_replay_weight):
              n_clip += 1
          clip_frac = float(n_clip) / len(batch_envs)
        else:
          clip_frac = 0.0
  
        self.train_queue.put((train_samples, step_logprobs, clip_frac))
        t2 = time.time()
        tf.logging.info(
          ('{} sec used preparing and enqueuing samples, {}'
           ' iteration {}, batch {}: {} envs').format(
             t2-t1, self.name, i, j, len(batch_envs)))

        t1 = time.time()
        # Wait for a ckpt that still exists, or for the same ckpt
        # (in which case nothing needs to be reloaded).
        while True:
          new_ckpt = self.ckpt_queue.get()
          new_ckpt_file = new_ckpt + '.meta'
          if new_ckpt == current_ckpt or tf.gfile.Exists(new_ckpt_file):
            break
        t2 = time.time()
        tf.logging.info('{} sec waiting {} iteration {}, batch {}'.format(
          t2-t1, self.name, i, j))

        if new_ckpt != current_ckpt:
          # If the ckpt is not the same, then restore the new
          # ckpt.
          tf.logging.info('{} loading ckpt {}'.format(self.name, new_ckpt))
          t1 = time.time()          
          graph.restore(new_ckpt)
          t2 = time.time()
          tf.logging.info('{} sec used {} restoring ckpt {}'.format(
            t2-t1, self.name, new_ckpt))
          current_ckpt = new_ckpt

        if FLAGS.log_samples_every_n_epoch > 0 and i % FLAGS.log_samples_every_n_epoch == 0:
          f_replay.write(show_samples(replay_samples, envs[0].de_vocab, env_dict))
          f_policy.write(show_samples(policy_samples, envs[0].de_vocab, env_dict))
          f_train.write(show_samples(train_samples, envs[0].de_vocab, env_dict))

      if FLAGS.log_samples_every_n_epoch > 0 and i % FLAGS.log_samples_every_n_epoch == 0:
        f_replay.close()
        f_policy.close()
        f_train.close()

      if agent.model.get_global_step() >= FLAGS.n_steps:
        tf.logging.info('{} finished'.format(self.name))
        return
      i += 1
Example 5
    def decode_sketch_program(envs):
        # First create the real envs from the JSON specs and add sketch constraints to them.
        envs = json_to_envs(envs)
        env_name_dict = dict(map(lambda env: (env.name, env), envs))

        if FLAGS.executor == 'wtq':
            oracle_envs, oracle_trajs = get_wtq_annotations(envs)
        else:
            oracle_envs, oracle_trajs = get_env_trajs(envs)

        env_sketch_dict = dict([
            (env.name, get_sketch(traj_to_program(traj, envs[0].de_vocab)))
            for env, traj in zip(oracle_envs, oracle_trajs)
        ])
        for env in envs:
            sketch = env_sketch_dict.get(env.name, None)
            if sketch is not None:
                env.set_sketch_constraint(sketch[:])

        # create the agent
        graph_config = get_saved_graph_config()
        graph_config['use_gpu'] = False
        graph_config['gpu_id'] = '0'
        init_model_path = get_init_model_path()
        agent = create_agent(graph_config, init_model_path)

        # beam search
        beam_samples = []
        env_iterator = data_utils.BatchIterator(
            dict(envs=envs), shuffle=False, batch_size=FLAGS.eval_batch_size)
        for j, batch_dict in tqdm(enumerate(env_iterator)):
            batch_envs = batch_dict['envs']
            beam_samples += agent.beam_search(batch_envs, beam_size=50)

        # Group the beam-search samples into per-env beams, keyed by env name.
        env_beam_dict = dict()
        for sample in beam_samples:
            env_beam_dict[sample.traj.env_name] = env_beam_dict.get(
                sample.traj.env_name, []) + [sample]

        # get the trajs with 1.0 reward for each example and re-weight the prob
        env_name_annotation_dict = dict()
        for env_name, env in env_name_dict.iteritems():
            beam = env_beam_dict.get(env_name, [])
            success_beam = filter(lambda x: x.traj.rewards[-1] == 1.0, beam)
            if len(success_beam) > 0:
                # retrieve the sketch result from previous steps
                sketch = env_sketch_dict.get(env_name, None)

                if sketch is None:
                    env_name_annotation_dict[env_name] = None
                else:
                    # re-weight the examples in the beam
                    prob_sum = sum(
                        map(lambda sample: sample.prob, success_beam))
                    success_beam = map(
                        lambda sample: agent_factory.Sample(
                            traj=sample.traj, prob=sample.prob / prob_sum),
                        success_beam)
                    if len(success_beam) > 10:
                        success_beam = sorted(success_beam,
                                              key=lambda sample: sample.prob,
                                              reverse=True)
                        success_beam = success_beam[:10]

                    annotation = SketchAnnotation(env, sketch, success_beam)
                    env_name_annotation_dict[env_name] = annotation
            else:
                env_name_annotation_dict[env_name] = None

        return env_name_annotation_dict
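The re-weighting above renormalizes the successful beam entries to sum to 1 before truncating to the top 10. With illustrative numbers: probs [0.3, 0.1] have prob_sum = 0.4 and become [0.75, 0.25]. The same idiom in standalone form (Sample is a namedtuple in this codebase, so _replace works; the variable names are placeholders):

    # prob_sum = sum(s.prob for s in success_beam)
    # success_beam = [s._replace(prob=s.prob / prob_sum) for s in success_beam]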
Example 6
 def get_samples(self, env):
     sample_traj = collect_traj_for_program(env, self.oracle_program)
     return [agent_factory.Sample(sample_traj, prob=1.0)]
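Unlike Example 1, which replays a list of weighted sketch programs, this variant replays a single oracle program, so the returned list always holds exactly one Sample with full weight (prob=1.0).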
Example 7
    def run(self):
        agent, all_envs = init_experiment(
            [get_train_shard_path(i) for i in self.shard_ids],
            use_gpu=FLAGS.actor_use_gpu,
            gpu_id=str(self.actor_id + FLAGS.actor_gpu_start_id))
        graph = agent.model.graph
        current_ckpt = get_init_model_path()

        # Obtain oracle trajectories for the examples, dropping examples
        # for which no oracle can be found.
        envs, env_trajs = get_env_trajs(all_envs)

        # build a dict to store the oracle trajs
        env_oracle_trajs_dict = dict()
        for env, env_traj in zip(envs, env_trajs):
            env_oracle_trajs_dict[env.name] = env_traj
        tf.logging.info(
            'Found oracle for {} envs out of a total of {} for actor_{}'.format(
                len(envs), len(all_envs), self.actor_id))

        i = 0
        while True:
            n_train_samples = 0

            # Exactly one oracle sample per environment.
            n_train_samples += 1

            # Make sure that all the samples from the env batch
            # fit into one training batch.
            if FLAGS.batch_size < n_train_samples:
                raise ValueError(
                    'One batch has to contain at least the samples from one environment.'
                )

            env_batch_size = FLAGS.batch_size // n_train_samples  # integer division

            env_iterator = data_utils.BatchIterator(dict(envs=envs),
                                                    shuffle=True,
                                                    batch_size=env_batch_size)

            for j, batch_dict in enumerate(env_iterator):
                batch_envs = batch_dict['envs']
                tf.logging.info('=' * 50)
                tf.logging.info('{} iteration {}, batch {}: {} envs'.format(
                    self.name, i, j, len(batch_envs)))

                t1 = time.time()

                # get the oracle samples
                oracle_samples = []
                for batch_env in batch_envs:
                    oracle_samples.append(
                        agent_factory.Sample(
                            traj=env_oracle_trajs_dict[batch_env.name],
                            prob=1.0))

                self.eval_queue.put((oracle_samples, len(batch_envs)))
                self.replay_queue.put((oracle_samples, len(batch_envs)))

                assert 0.0 <= FLAGS.fixed_replay_weight <= 1.0

                train_samples = []

                train_samples += oracle_samples

                train_samples = sorted(train_samples,
                                       key=lambda x: x.traj.env_name)
                tf.logging.info('{} train samples'.format(len(train_samples)))

                if FLAGS.use_importance_sampling:
                    step_logprobs = agent.compute_step_logprobs(
                        [s.traj for s in train_samples])
                else:
                    step_logprobs = None

                # TODO: the hard-coded clip_frac of 0.0 may be wrong.
                self.train_queue.put((train_samples, step_logprobs, 0.0))
                t2 = time.time()
                tf.logging.info(
                    ('{} sec used preparing and enqueuing samples, {}'
                     ' iteration {}, batch {}: {} envs').format(
                         t2 - t1, self.name, i, j, len(batch_envs)))

                t1 = time.time()
                # Wait for a ckpt that still exists, or for the same ckpt
                # (in which case nothing needs to be reloaded).
                while True:
                    new_ckpt = self.ckpt_queue.get()
                    new_ckpt_file = new_ckpt + '.meta'
                    if new_ckpt == current_ckpt or tf.gfile.Exists(
                            new_ckpt_file):
                        break
                t2 = time.time()
                tf.logging.info(
                    '{} sec waiting {} iteration {}, batch {}'.format(
                        t2 - t1, self.name, i, j))

                if new_ckpt != current_ckpt:
                    # If the ckpt is not the same, then restore the new
                    # ckpt.
                    tf.logging.info('{} loading ckpt {}'.format(
                        self.name, new_ckpt))
                    t1 = time.time()
                    graph.restore(new_ckpt)
                    t2 = time.time()
                    tf.logging.info('{} sec used {} restoring ckpt {}'.format(
                        t2 - t1, self.name, new_ckpt))
                    current_ckpt = new_ckpt

            if agent.model.get_global_step() >= FLAGS.n_steps:
                tf.logging.info('{} finished'.format(self.name))
                return
            i += 1
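Across Examples 4 and 7 the actor's contract with the trainer is the same set of queues; a schematic of the items enqueued, with shapes inferred from the calls above (the producer of ckpt_queue is assumed to be the trainer/learner process):

    # eval_queue / replay_queue items: (list_of_samples, n_envs_in_batch)
    # train_queue items:               (train_samples, step_logprobs or None, clip_frac)
    # ckpt_queue items:                checkpoint path strings (restored via graph.restore)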