Example #1
    def eval(envs):
        # first create the real envs from the jsons
        envs = json_to_envs(envs)

        # create the agent
        graph_config = get_saved_graph_config()
        graph_config['use_gpu'] = False
        graph_config['gpu_id'] = '0'
        init_model_path = get_init_model_path()
        agent = create_agent(graph_config, init_model_path)

        # greedy decode
        greedy_samples = []
        env_iterator = data_utils.BatchIterator(
            dict(envs=envs), shuffle=False, batch_size=FLAGS.eval_batch_size)
        for j, batch_dict in tqdm(enumerate(env_iterator)):
            batch_envs = batch_dict['envs']
            greedy_samples += agent.generate_samples(batch_envs,
                                                     n_samples=1,
                                                     greedy=True,
                                                     use_cache=False,
                                                     filter_error=False)

        env_sample_list = zip(envs, greedy_samples)
        failed_env_sample_list = filter(lambda x: x[1].traj.rewards[-1] < 1.0,
                                        env_sample_list)

        failed_envs = [env_sample[0] for env_sample in failed_env_sample_list]
        failed_envs = list(np.random.permutation(failed_envs))
        failed_envs = [(1.0, env.name) for env in failed_envs]

        return failed_envs
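
A minimal usage sketch for the list returned above, assuming only that it holds (score, env_name) pairs; pick_env_names and budget are hypothetical names introduced here for illustration.

def pick_env_names(scored_envs, budget):
    # Sort by descending score; Python's sort is stable, so equal scores keep
    # the random order produced by the permutation above.
    ranked = sorted(scored_envs, key=lambda pair: pair[0], reverse=True)
    return [name for _, name in ranked[:budget]]

# e.g. pick_env_names(failed_envs, budget=50) picks up to 50 environment names.
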
Example #2
    def eval(envs):
        # first create the real envs from the jsons
        envs = json_to_envs(envs)

        # create the agent
        graph_config = get_saved_graph_config()
        graph_config['use_gpu'] = False
        graph_config['gpu_id'] = '0'
        init_model_path = get_init_model_path()
        agent = create_agent(graph_config, init_model_path)

        # beam search decode
        beam_samples = []
        env_iterator = data_utils.BatchIterator(
            dict(envs=envs), shuffle=False, batch_size=FLAGS.eval_batch_size)
        for j, batch_dict in tqdm(enumerate(env_iterator)):
            batch_envs = batch_dict['envs']
            beam_samples += agent.beam_search(batch_envs, beam_size=5)

        # group the samples into per-environment beams (beam_search returns a flat list)
        env_beam_dict = dict()
        for sample in beam_samples:
            env_beam_dict[sample.traj.env_name] = env_beam_dict.get(
                sample.traj.env_name, []) + [sample]

        # score each example by least confidence: 1 - (max probability in its beam)
        conf_envs = [(1.0 - max(map(lambda x: x.prob, beam)), env_name)
                     for env_name, beam in env_beam_dict.items()]

        return conf_envs
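
A toy illustration (not from the original code) of the least-confidence score computed above: the more probable the best hypothesis in a beam, the lower the score.

beam_probs = [0.6, 0.25, 0.1]        # hypothetical hypothesis probabilities in one beam
uncertainty = 1.0 - max(beam_probs)  # 0.4; a more confident beam scores lower
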
Example #3
    def eval(envs):
        # first create the real envs from the jsons
        envs = json_to_envs(envs)

        # create the agent
        graph_config = get_saved_graph_config()
        graph_config['use_gpu'] = False
        graph_config['gpu_id'] = '0'
        init_model_path = get_init_model_path()
        agent = create_agent(graph_config, init_model_path)

        # beam search decode
        beam_samples = []
        env_iterator = data_utils.BatchIterator(
            dict(envs=envs), shuffle=False, batch_size=FLAGS.eval_batch_size)
        for j, batch_dict in tqdm(enumerate(env_iterator)):
            batch_envs = batch_dict['envs']
            beam_samples += agent.beam_search(batch_envs, beam_size=5)

        # group the samples into per-environment beams (beam_search returns a flat list)
        env_beam_dict = dict()
        for sample in beam_samples:
            env_beam_dict[sample.traj.env_name] = env_beam_dict.get(
                sample.traj.env_name, []) + [sample]

        # get the top hyps and find those failed ones
        top_hyps = map(
            lambda (env_name, beam):
            (env_name,
             reduce(lambda s1, s2: s1 if s1.prob > s2.prob else s2, beam)),
            env_beam_dict.items())
        failed_top_hyps = filter(
            lambda (env_name, sample): sample.traj.rewards[-1] == 0.0,
            top_hyps)
        conf_envs = map(lambda (env_name, sample): (sample.prob, env_name),
                        failed_top_hyps)

        return conf_envs
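
The tuple-parameter lambdas above (lambda (env_name, beam): ...) and the bare reduce are Python 2 syntax. A rough Python 3 equivalent, sketched under the assumption that env_beam_dict is as built above, replaces the reduce-based argmax with max(..., key=...):

top_hyps = [
    (env_name, max(beam, key=lambda s: s.prob))
    for env_name, beam in env_beam_dict.items()
]
failed_top_hyps = [(name, s) for name, s in top_hyps
                   if s.traj.rewards[-1] == 0.0]
conf_envs = [(s.prob, name) for name, s in failed_top_hyps]
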
Example #4
def run_experiment():
  print('=' * 100)
  if FLAGS.show_log:
    tf.logging.set_verbosity(tf.logging.INFO)

  experiment_dir = get_experiment_dir()
  if tf.gfile.Exists(experiment_dir):
    tf.gfile.DeleteRecursively(experiment_dir)
  tf.gfile.MkDir(experiment_dir)

  experiment_config = create_experiment_config()
  
  with open(os.path.join(
      get_experiment_dir(), 'experiment_config.json'), 'w') as f:
    json.dump(experiment_config, f)

  ckpt_queue = multiprocessing.Queue()
  train_queue = multiprocessing.Queue()
  eval_queue = multiprocessing.Queue()
  replay_queue = multiprocessing.Queue()

  run_type = 'evaluation' if FLAGS.eval_only else 'experiment'
  print('Start {} {}.'.format(run_type, FLAGS.experiment_name))
  print('The data of this {} is saved in {}.'.format(run_type, experiment_dir))

  if FLAGS.eval_only:
    print('Start evaluating the best model {}.'.format(get_init_model_path()))
  else:
    print('Start distributed training.')

  print('Start evaluator.')
  if FLAGS.eval_on_train:
    print('Evaluating on the training set...')
    evaluator = Evaluator(
      'Evaluator',
      [get_train_shard_path(i) for i in range(FLAGS.shard_start, FLAGS.shard_end)])
  else:
    evaluator = Evaluator(
      'Evaluator',
      [FLAGS.eval_file if FLAGS.eval_only else FLAGS.dev_file])
  evaluator.start()

  if not FLAGS.eval_only:
    actors = []
    actor_shard_dict = dict([(i, []) for i in range(FLAGS.n_actors)])
    for i in xrange(FLAGS.shard_start, FLAGS.shard_end):
      actor_num = i % FLAGS.n_actors
      actor_shard_dict[actor_num].append(i)

    if FLAGS.use_active_learning:
        print('########## use active actor ##########')
        envs = load_envs_as_json([get_train_shard_path(i) for i in range(FLAGS.shard_start, FLAGS.shard_end)])
        al_dict = active_learning(envs, FLAGS.active_picker_class, FLAGS.active_annotator_class, FLAGS.al_budget_n)

    for k in xrange(FLAGS.n_actors):
      name = 'actor_{}'.format(k)

      if FLAGS.use_oracle_examples_in_train:
        actor = OracleActor(name, k, actor_shard_dict[k], ckpt_queue, train_queue, eval_queue, replay_queue)
      elif FLAGS.use_active_learning:
        actor = ActiveActor(name, k, actor_shard_dict[k], ckpt_queue, train_queue, eval_queue, replay_queue, al_dict)
      else:
        actor = Actor(name, k, actor_shard_dict[k], ckpt_queue, train_queue, eval_queue, replay_queue)
      actors.append(actor)
      actor.start()
    print('Start {} actors.'.format(len(actors)))

    print('Start learner.')
    learner = Learner(
      'Learner', [FLAGS.dev_file], ckpt_queue,
      train_queue, eval_queue, replay_queue)
    learner.start()
    print('Use tensorboard to monitor the training progress (see README).')
    for actor in actors:
      actor.join()
    print('All actors finished')
    # Send learner the signal that all the actors have finished.
    train_queue.put(None)
    eval_queue.put(None)
    replay_queue.put(None)
    learner.join()
    print('Learner finished')

  evaluator.join()
  print('Evaluator finished')
  print('=' * 100)
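
The actor/learner wiring above relies on multiprocessing queues with None as an end-of-work sentinel (train_queue.put(None) once every actor has joined). A self-contained sketch of that pattern, with hypothetical names, looks like this:

import multiprocessing


def producer(q):
    for item in range(3):
        q.put(item)
    q.put(None)  # sentinel: no more work


def consumer(q):
    while True:
        item = q.get()
        if item is None:  # producer is done
            break
        print('consumed %d' % item)


if __name__ == '__main__':
    q = multiprocessing.Queue()
    p = multiprocessing.Process(target=producer, args=(q,))
    c = multiprocessing.Process(target=consumer, args=(q,))
    p.start()
    c.start()
    p.join()
    c.join()
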
Example #5
  def run(self):
    # Writers to record training and replay information.
    train_writer = tf.summary.FileWriter(os.path.join(
      get_experiment_dir(), FLAGS.tb_log_dir, 'train'))
    replay_writer = tf.summary.FileWriter(os.path.join(
      get_experiment_dir(), FLAGS.tb_log_dir, 'replay'))
    saved_model_dir = os.path.join(get_experiment_dir(), FLAGS.saved_model_dir)
    if not tf.gfile.Exists(saved_model_dir):
      tf.gfile.MkDir(saved_model_dir)
    agent, envs = init_experiment(self.fns, FLAGS.train_use_gpu, gpu_id=str(FLAGS.train_gpu_id))
    agent.train_writer = train_writer
    graph = agent.model.graph
    current_ckpt = get_init_model_path()

    i = 0
    n_save = 0
    while True:
      tf.logging.info('Start train step {}'.format(i))
      t1 = time.time()
      train_samples, behaviour_logprobs, clip_frac = self.train_queue.get()
      eval_samples, eval_true_n = self.eval_queue.get()
      replay_samples, replay_true_n = self.replay_queue.get()
      t2 = time.time()
      tf.logging.info('{} secs used waiting in train step {}.'.format(
        t2-t1, i))
      t1 = time.time()
      n_train_samples = 0
      if FLAGS.use_replay_samples_in_train:
        n_train_samples += FLAGS.n_replay_samples
      if FLAGS.use_policy_samples_in_train and FLAGS.use_nonreplay_samples_in_train:
        raise ValueError(
          'Cannot use both on-policy samples and nonreplay samples for training!')
      if FLAGS.use_policy_samples_in_train:
        n_train_samples += FLAGS.n_policy_samples

      if train_samples:
        if FLAGS.use_trainer_prob:
          train_samples = agent.update_replay_prob(
            train_samples, min_replay_weight=FLAGS.min_replay_weight)
        for _ in xrange(FLAGS.n_opt_step):
          agent.train(
            train_samples,
            parameters=dict(en_rnn_dropout=FLAGS.dropout, rnn_dropout=FLAGS.dropout),
            use_baseline=FLAGS.use_baseline,
            min_prob=FLAGS.min_prob,
            scale=n_train_samples,
            behaviour_logprobs=behaviour_logprobs,
            use_importance_sampling=FLAGS.use_importance_sampling,
            ppo_epsilon=FLAGS.ppo_epsilon,
            de_vocab=envs[0].de_vocab,
            debug=FLAGS.debug)

      avg_return, avg_len = agent.evaluate(
        eval_samples, writer=train_writer, true_n=eval_true_n,
        clip_frac=clip_frac)
      tf.logging.info('train: avg return: {}, avg length: {}.'.format(
        avg_return, avg_len))
      avg_return, avg_len = agent.evaluate(
        replay_samples, writer=replay_writer, true_n=replay_true_n)
      tf.logging.info('replay: avg return: {}, avg length: {}.'.format(avg_return, avg_len))
      t2 = time.time()
      tf.logging.info('{} sec used training in train iteration {}, {} samples.'.format(
        t2-t1, i, len(train_samples)))
      i += 1
      if i % self.save_every_n == 0:
        t1 = time.time()
        current_ckpt = graph.save(
          os.path.join(saved_model_dir, 'model'),
          agent.model.get_global_step())
        t2 = time.time()
        tf.logging.info('{} sec used saving model to {}, train iteration {}.'.format(
          t2-t1, current_ckpt, i))
        self.ckpt_queue.put(current_ckpt)
        if agent.model.get_global_step() >= FLAGS.n_steps:
          t1 = time.time()
          while True:
            train_data = self.train_queue.get()
            _ = self.eval_queue.get()
            _ = self.replay_queue.get()
            self.ckpt_queue.put(current_ckpt)
            # Get the signal that all the actors have
            # finished.
            if train_data is None:
              t2 = time.time()
              tf.logging.info('{} finished, {} sec used waiting for actors'.format(
                self.name, t2-t1))
              return
      else:
        # After training on one set of samples, put one ckpt
        # back so that the ckpt queue is always full.
        self.ckpt_queue.put(current_ckpt)
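
One detail worth noting in the loop above: every (train, eval, replay) batch the learner consumes is answered with exactly one ckpt_queue.put, in both the save and the non-save branch, so an actor blocking on ckpt_queue.get() after enqueuing a batch never deadlocks. A thread-based toy of that token exchange (Python 3, illustrative names only):

import queue
import threading

train_q, ckpt_q = queue.Queue(), queue.Queue()


def learner():
    current_ckpt = 'ckpt-0'
    while True:
        batch = train_q.get()
        ckpt_q.put(current_ckpt)  # one token back per batch consumed
        if batch is None:         # sentinel from the actor side
            return


def actor(n_batches):
    for b in range(n_batches):
        train_q.put(b)
        ckpt_q.get()  # blocks until the learner answers
    train_q.put(None)


t = threading.Thread(target=learner)
t.start()
actor(3)
t.join()
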
Example #6
  def run(self):
    agent, envs = init_experiment(self.fns, FLAGS.eval_use_gpu, gpu_id=str(FLAGS.eval_gpu_id))
    for env in envs:
      env.punish_extra_work = False
    graph = agent.model.graph
    dev_writer = tf.summary.FileWriter(os.path.join(
      get_experiment_dir(), FLAGS.tb_log_dir, 'dev'))
    best_dev_avg_return = 0.0
    best_model_path = ''
    best_model_dir = os.path.join(get_experiment_dir(), FLAGS.best_model_dir)
    if not tf.gfile.Exists(best_model_dir):
      tf.gfile.MkDir(best_model_dir)
    i = 0
    current_ckpt = get_init_model_path()
    env_dict = dict([(env.name, env) for env in envs])
    while True:
      t1 = time.time()
      tf.logging.info('dev: iteration {}, evaluating {}.'.format(i, current_ckpt))

      dev_avg_return, dev_samples, dev_samples_in_beam = beam_search_eval(
        agent, envs, writer=dev_writer)
      
      if dev_avg_return > best_dev_avg_return:
        best_model_path = graph.save(
          os.path.join(best_model_dir, 'model'),
          agent.model.get_global_step())
        best_dev_avg_return = dev_avg_return
        tf.logging.info('New best dev avg return is {}'.format(best_dev_avg_return))
        tf.logging.info('New best model is saved in {}'.format(best_model_path))
        with open(os.path.join(get_experiment_dir(), 'best_model_info.json'), 'w') as f:
          result = {'best_model_path': compress_home_path(best_model_path)}
          if FLAGS.eval_only:
            result['best_eval_avg_return'] = best_dev_avg_return
          else:
            result['best_dev_avg_return'] = best_dev_avg_return
          json.dump(result, f)

      if FLAGS.eval_only:
        # Save the decoding results for further analysis.
        dev_programs_in_beam_dict = {}
        for sample in dev_samples_in_beam:
          name = sample.traj.env_name
          program = agent_factory.traj_to_program(sample.traj, envs[0].de_vocab)
          answer = sample.traj.answer
          if name in dev_programs_in_beam_dict:
            dev_programs_in_beam_dict[name].append((program, answer, sample.prob))
          else:
            dev_programs_in_beam_dict[name] = [(program, answer, sample.prob)]

        t3 = time.time()
        with open(
            os.path.join(get_experiment_dir(), 'dev_programs_in_beam_{}.json'.format(i)),
            'w') as f:
          json.dump(dev_programs_in_beam_dict, f)
        t4 = time.time()
        tf.logging.info('{} sec used dumping programs in beam in eval iteration {}.'.format(
          t4 - t3, i))

        t3 = time.time()
        with codecs.open(
            os.path.join(
              get_experiment_dir(), 'dev_samples_{}.txt'.format(i)),
            'w', encoding='utf-8') as f:
          for sample in dev_samples:
            f.write(show_samples([sample], envs[0].de_vocab, env_dict))
        t4 = time.time()
        tf.logging.info('{} sec used logging dev samples in eval iteration {}.'.format(
          t4 - t3, i))

      t2 = time.time()
      tf.logging.info('{} sec used in eval iteration {}.'.format(
        t2 - t1, i))

      if FLAGS.eval_only or agent.model.get_global_step() >= FLAGS.n_steps:
        tf.logging.info('{} finished'.format(self.name))
        if FLAGS.eval_only:
          print('Eval average return (accuracy) of the best model is {}'.format(
            best_dev_avg_return))
        else:
          print('Best dev average return (accuracy) is {}'.format(best_dev_avg_return))
          print('Best model is saved in {}'.format(best_model_path))
        return

      # Reload the latest model.
      new_ckpt = None
      t1 = time.time()
      while new_ckpt is None or new_ckpt == current_ckpt:
        time.sleep(1)
        new_ckpt = tf.train.latest_checkpoint(
          os.path.join(get_experiment_dir(), FLAGS.saved_model_dir))
      t2 = time.time()
      tf.logging.info('{} sec used waiting for new checkpoint in evaluator.'.format(
        t2-t1))
      
      tf.logging.info('latest ckpt to evaluate is {}.'.format(new_ckpt))
      tf.logging.info('{} loading ckpt {}'.format(self.name, new_ckpt))
      t1 = time.time()
      graph.restore(new_ckpt)
      t2 = time.time()
      tf.logging.info('{} sec used {} loading ckpt {}'.format(
        t2-t1, self.name, new_ckpt))
      current_ckpt = new_ckpt
Example #7
    def decode_sketch_program(envs):
        # first create the real envs from the jsons and add constraints to them
        envs = json_to_envs(envs)
        env_name_dict = dict(map(lambda env: (env.name, env), envs))

        if FLAGS.executor == 'wtq':
            oracle_envs, oracle_trajs = get_wtq_annotations(envs)
        else:
            oracle_envs, oracle_trajs = get_env_trajs(envs)

        env_sketch_dict = dict([
            (env.name, get_sketch(traj_to_program(traj, envs[0].de_vocab)))
            for env, traj in zip(oracle_envs, oracle_trajs)
        ])
        for env in envs:
            sketch = env_sketch_dict.get(env.name, None)
            if sketch is not None:
                env.set_sketch_constraint(sketch[:])

        # create the agent
        graph_config = get_saved_graph_config()
        graph_config['use_gpu'] = False
        graph_config['gpu_id'] = '0'
        init_model_path = get_init_model_path()
        agent = create_agent(graph_config, init_model_path)

        # beam search
        beam_samples = []
        env_iterator = data_utils.BatchIterator(
            dict(envs=envs), shuffle=False, batch_size=FLAGS.eval_batch_size)
        for j, batch_dict in tqdm(enumerate(env_iterator)):
            batch_envs = batch_dict['envs']
            beam_samples += agent.beam_search(batch_envs, beam_size=50)

        # group the samples into per-environment beams (beam_search returns a flat list)
        env_beam_dict = dict()
        for sample in beam_samples:
            env_beam_dict[sample.traj.env_name] = env_beam_dict.get(
                sample.traj.env_name, []) + [sample]

        # get the trajs with 1.0 reward for each example and re-weight the prob
        env_name_annotation_dict = dict()
        for env_name, env in env_name_dict.iteritems():
            beam = env_beam_dict.get(env_name, [])
            success_beam = filter(lambda x: x.traj.rewards[-1] == 1.0, beam)
            if len(success_beam) > 0:
                # retrieve the sketch result from previous steps
                sketch = env_sketch_dict.get(env_name, None)

                if sketch is None:
                    env_name_annotation_dict[env_name] = None
                else:
                    # re-weight the examples in the beam
                    prob_sum = sum(
                        map(lambda sample: sample.prob, success_beam))
                    success_beam = map(
                        lambda sample: agent_factory.Sample(
                            traj=sample.traj, prob=sample.prob / prob_sum),
                        success_beam)
                    if len(success_beam) > 10:
                        success_beam = sorted(success_beam,
                                              key=lambda sample: sample.prob,
                                              reverse=True)
                        success_beam = success_beam[:10]

                    annotation = SketchAnnotation(env, sketch, success_beam)
                    env_name_annotation_dict[env_name] = annotation
            else:
                env_name_annotation_dict[env_name] = None

        return env_name_annotation_dict
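
A numeric illustration (not from the original code) of the re-weighting above: the probabilities of the reward-1.0 hypotheses are normalised to sum to 1 before the beam is truncated to the top 10.

probs = [0.30, 0.15, 0.05]                    # hypothetical success_beam probabilities
prob_sum = sum(probs)                         # 0.50
renormalised = [p / prob_sum for p in probs]  # [0.6, 0.3, 0.1]
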
Example #8
    def run(self):
        agent, envs = init_experiment(
            [get_train_shard_path(i) for i in self.shard_ids],
            use_gpu=FLAGS.actor_use_gpu,
            gpu_id=str(self.actor_id + FLAGS.actor_gpu_start_id))
        self.decode_vocab = envs[0].de_vocab

        graph = agent.model.graph
        current_ckpt = get_init_model_path()

        env_dict = dict([(env.name, env) for env in envs])
        replay_buffer = agent_factory.AllGoodReplayBuffer(
            agent, envs[0].de_vocab)

        # Load saved programs to warm start the replay buffer.
        if FLAGS.load_saved_programs:
            load_programs(envs, replay_buffer, FLAGS.saved_program_file)

        if FLAGS.save_replay_buffer_at_end:
            replay_buffer_copy = agent_factory.AllGoodReplayBuffer(
                de_vocab=envs[0].de_vocab)
            replay_buffer_copy.program_prob_dict = copy.deepcopy(
                replay_buffer.program_prob_dict)

        # shrink the annotation dict to the envs needed
        small_env_annotation_dict = dict()
        for env in envs:
            annotation = self.env_annotation_dict.get(env.name, None)
            if annotation is not None:
                small_env_annotation_dict[env.name] = annotation
        self.env_annotation_dict = small_env_annotation_dict
        print('Actor %d, total %d envs, %d have been annotated.' %
              (self.actor_id, len(envs), len(self.env_annotation_dict)))

        # get samples from the annotations and put them into the buffer
        env_name_dict = dict([(env.name, env) for env in envs])
        if len(self.env_annotation_dict) > 0:
            annotated_samples = []
            for env_name, annotation in self.env_annotation_dict.items():
                samples_from_annotation = annotation.get_samples(
                    env_name_dict[env_name])
                annotated_samples += samples_from_annotation
            self.save_to_buffer(annotated_samples, replay_buffer)

        i = 0
        while True:
            # Create the logging files.
            if FLAGS.log_samples_every_n_epoch > 0 and i % FLAGS.log_samples_every_n_epoch == 0:
                f_replay = codecs.open(os.path.join(
                    get_experiment_dir(),
                    'replay_samples_{}_{}.txt'.format(self.name, i)),
                                       'w',
                                       encoding='utf-8')
                f_policy = codecs.open(os.path.join(
                    get_experiment_dir(),
                    'policy_samples_{}_{}.txt'.format(self.name, i)),
                                       'w',
                                       encoding='utf-8')
                f_train = codecs.open(os.path.join(
                    get_experiment_dir(),
                    'train_samples_{}_{}.txt'.format(self.name, i)),
                                      'w',
                                      encoding='utf-8')

            n_train_samples = 0
            if FLAGS.use_replay_samples_in_train:
                n_train_samples += FLAGS.n_replay_samples

            if FLAGS.use_policy_samples_in_train and FLAGS.use_nonreplay_samples_in_train:
                raise ValueError(
                    'Cannot use both on-policy samples and nonreplay samples for training!'
                )

            if FLAGS.use_policy_samples_in_train or FLAGS.use_nonreplay_samples_in_train:
                # Note that nonreplay samples are drawn by rejection
                # sampling from on-policy samples.
                n_train_samples += FLAGS.n_policy_samples

            # Make sure that all the samples from the env batch
            # fit into one training batch.
            if FLAGS.batch_size < n_train_samples:
                raise ValueError(
                    'One batch has to contain at least the samples from one environment.'
                )

            env_batch_size = FLAGS.batch_size / n_train_samples

            env_iterator = data_utils.BatchIterator(dict(envs=envs),
                                                    shuffle=True,
                                                    batch_size=env_batch_size)

            for j, batch_dict in enumerate(env_iterator):
                batch_envs = batch_dict['envs']
                tf.logging.info('=' * 50)
                tf.logging.info('{} iteration {}, batch {}: {} envs'.format(
                    self.name, i, j, len(batch_envs)))

                t1 = time.time()
                # Generate samples with cache and save to replay buffer.
                t3 = time.time()
                n_explore = 0
                for _ in xrange(FLAGS.n_explore_samples):
                    explore_samples = agent.generate_samples(
                        batch_envs,
                        n_samples=1,
                        use_cache=FLAGS.use_cache,
                        greedy=FLAGS.greedy_exploration)
                    self.save_to_buffer(explore_samples, replay_buffer)
                    n_explore += len(explore_samples)

                if FLAGS.n_extra_explore_for_hard > 0:
                    hard_envs = [
                        env for env in batch_envs
                        if not replay_buffer.has_found_solution(env.name)
                    ]
                    if hard_envs:
                        for _ in xrange(FLAGS.n_extra_explore_for_hard):
                            explore_samples = agent.generate_samples(
                                hard_envs,
                                n_samples=1,
                                use_cache=FLAGS.use_cache,
                                greedy=FLAGS.greedy_exploration)
                            self.save_to_buffer(explore_samples, replay_buffer)
                            n_explore += len(explore_samples)

                t4 = time.time()
                tf.logging.info(
                    '{} sec used generating {} exploration samples.'.format(
                        t4 - t3, n_explore))

                tf.logging.info(
                    '{} samples saved in the replay buffer.'.format(
                        replay_buffer.size))

                t3 = time.time()
                replay_samples = replay_buffer.replay(
                    batch_envs,
                    FLAGS.n_replay_samples,
                    use_top_k=FLAGS.use_top_k_replay_samples,
                    agent=None if FLAGS.random_replay_samples else agent,
                    truncate_at_n=FLAGS.truncate_replay_buffer_at_n)
                t4 = time.time()
                tf.logging.info(
                    '{} sec used selecting {} replay samples.'.format(
                        t4 - t3, len(replay_samples)))

                t3 = time.time()
                if FLAGS.use_top_k_policy_samples:
                    if FLAGS.n_policy_samples == 1:
                        policy_samples = agent.generate_samples(
                            batch_envs,
                            n_samples=FLAGS.n_policy_samples,
                            greedy=True)
                    else:
                        policy_samples = agent.beam_search(
                            batch_envs, beam_size=FLAGS.n_policy_samples)
                else:
                    policy_samples = agent.generate_samples(
                        batch_envs,
                        n_samples=FLAGS.n_policy_samples,
                        greedy=False)
                t4 = time.time()
                tf.logging.info(
                    '{} sec used generating {} on-policy samples'.format(
                        t4 - t3, len(policy_samples)))

                t2 = time.time()
                tf.logging.info(
                    ('{} sec used generating replay and on-policy samples,'
                     ' {} iteration {}, batch {}: {} envs').format(
                         t2 - t1, self.name, i, j, len(batch_envs)))

                t1 = time.time()
                self.eval_queue.put((policy_samples, len(batch_envs)))
                self.replay_queue.put((replay_samples, len(batch_envs)))

                assert (FLAGS.fixed_replay_weight >= 0.0
                        and FLAGS.fixed_replay_weight <= 1.0)

                if FLAGS.use_replay_prob_as_weight:
                    new_samples = []
                    for sample in replay_samples:
                        name = sample.traj.env_name
                        if name in replay_buffer.prob_sum_dict:
                            replay_prob = max(
                                replay_buffer.prob_sum_dict[name],
                                FLAGS.min_replay_weight)
                        else:
                            replay_prob = 0.0
                        scale = replay_prob
                        new_samples.append(
                            agent_factory.Sample(traj=sample.traj,
                                                 prob=sample.prob * scale))
                    replay_samples = new_samples
                else:
                    replay_samples = agent_factory.scale_probs(
                        replay_samples, FLAGS.fixed_replay_weight)

                replay_samples = sorted(replay_samples,
                                        key=lambda x: x.traj.env_name)

                policy_samples = sorted(policy_samples,
                                        key=lambda x: x.traj.env_name)

                if FLAGS.use_nonreplay_samples_in_train:
                    nonreplay_samples = []
                    for sample in policy_samples:
                        if not replay_buffer.contain(sample.traj):
                            nonreplay_samples.append(sample)

                self.save_to_buffer(policy_samples, replay_buffer)

                def weight_samples(samples):
                    if FLAGS.use_replay_prob_as_weight:
                        new_samples = []
                        for sample in samples:
                            name = sample.traj.env_name
                            if name in replay_buffer.prob_sum_dict:
                                replay_prob = max(
                                    replay_buffer.prob_sum_dict[name],
                                    FLAGS.min_replay_weight)
                            else:
                                replay_prob = 0.0
                            scale = 1.0 - replay_prob
                            new_samples.append(
                                agent_factory.Sample(traj=sample.traj,
                                                     prob=sample.prob * scale))
                    else:
                        new_samples = agent_factory.scale_probs(
                            samples, 1 - FLAGS.fixed_replay_weight)
                    return new_samples

                train_samples = []
                if FLAGS.use_replay_samples_in_train:
                    if FLAGS.use_trainer_prob:
                        replay_samples = [
                            sample._replace(prob=None)
                            for sample in replay_samples
                        ]
                    train_samples += replay_samples

                if FLAGS.use_policy_samples_in_train:
                    train_samples += weight_samples(policy_samples)

                if FLAGS.use_nonreplay_samples_in_train:
                    train_samples += weight_samples(nonreplay_samples)

                train_samples = sorted(train_samples,
                                       key=lambda x: x.traj.env_name)
                tf.logging.info('{} train samples'.format(len(train_samples)))

                if FLAGS.use_importance_sampling:
                    step_logprobs = agent.compute_step_logprobs(
                        [s.traj for s in train_samples])
                else:
                    step_logprobs = None

                if FLAGS.use_replay_prob_as_weight:
                    n_clip = 0
                    for env in batch_envs:
                        name = env.name
                        if (name in replay_buffer.prob_sum_dict
                                and replay_buffer.prob_sum_dict[name] <
                                FLAGS.min_replay_weight):
                            n_clip += 1
                    clip_frac = float(n_clip) / len(batch_envs)
                else:
                    clip_frac = 0.0

                # Put all weight on the annotated examples at first, then
                # gradually ramp up the weight of explored examples.
                al_scale_factor = min(
                    1.0,
                    (agent.model.get_global_step() - FLAGS.active_start_step) /
                    float(FLAGS.active_scale_steps))
                assert (al_scale_factor >= 0.0 and al_scale_factor <= 1.0)
                # Use a separate index here so the epoch counter `i` above is
                # not clobbered by this inner loop.
                for k, sample in enumerate(train_samples):
                    if sample.traj.env_name not in self.env_annotation_dict:
                        train_samples[k] = agent_factory.Sample(
                            traj=sample.traj,
                            prob=sample.prob * al_scale_factor)
                    else:
                        annotation = self.env_annotation_dict[
                            sample.traj.env_name]
                        explored_program = agent_factory.traj_to_program(
                            sample.traj, self.decode_vocab)
                        if not annotation.verify_program(explored_program):
                            train_samples[k] = agent_factory.Sample(
                                traj=sample.traj,
                                prob=sample.prob * al_scale_factor)

                self.train_queue.put((train_samples, step_logprobs, clip_frac))
                t2 = time.time()
                tf.logging.info(
                    ('{} sec used preparing and enqueuing samples, {}'
                     ' iteration {}, batch {}: {} envs').format(
                         t2 - t1, self.name, i, j, len(batch_envs)))

                t1 = time.time()
                # Wait for a ckpt that still exists or is the same as the
                # current ckpt (no need to load anything).
                while True:
                    new_ckpt = self.ckpt_queue.get()
                    new_ckpt_file = new_ckpt + '.meta'
                    if new_ckpt == current_ckpt or tf.gfile.Exists(
                            new_ckpt_file):
                        break
                t2 = time.time()
                tf.logging.info(
                    '{} sec waiting {} iteration {}, batch {}'.format(
                        t2 - t1, self.name, i, j))

                if new_ckpt != current_ckpt:
                    # If the ckpt is not the same, then restore the new
                    # ckpt.
                    tf.logging.info('{} loading ckpt {}'.format(
                        self.name, new_ckpt))
                    t1 = time.time()
                    graph.restore(new_ckpt)
                    t2 = time.time()
                    tf.logging.info('{} sec used {} restoring ckpt {}'.format(
                        t2 - t1, self.name, new_ckpt))
                    current_ckpt = new_ckpt

                if FLAGS.log_samples_every_n_epoch > 0 and i % FLAGS.log_samples_every_n_epoch == 0:
                    f_replay.write(
                        show_samples(replay_samples, envs[0].de_vocab,
                                     env_dict))
                    f_policy.write(
                        show_samples(policy_samples, envs[0].de_vocab,
                                     env_dict))
                    f_train.write(
                        show_samples(train_samples, envs[0].de_vocab,
                                     env_dict))

            if FLAGS.log_samples_every_n_epoch > 0 and i % FLAGS.log_samples_every_n_epoch == 0:
                f_replay.close()
                f_policy.close()
                f_train.close()

            if agent.model.get_global_step() >= FLAGS.n_steps:
                if FLAGS.save_replay_buffer_at_end:
                    all_replay = os.path.join(
                        get_experiment_dir(),
                        'all_replay_samples_{}.txt'.format(self.name))
                    # Only dump when the flag is set: all_replay and
                    # replay_buffer_copy only exist in that case.
                    with codecs.open(all_replay, 'w', encoding='utf-8') as f:
                        samples = replay_buffer.all_samples(envs, agent=None)
                        samples = [
                            s for s in samples
                            if not replay_buffer_copy.contain(s.traj)
                        ]
                        f.write(show_samples(samples, envs[0].de_vocab, None))

                tf.logging.info('{} finished'.format(self.name))
                return
            i += 1
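
The active-learning weight ramp used above can be read in isolation: it grows linearly from 0 to 1 over FLAGS.active_scale_steps steps once the global step passes FLAGS.active_start_step. A small sketch with hypothetical argument names:

def al_scale_factor(global_step, start_step, scale_steps):
    # Mirrors the expression in the loop above; the assert there guarantees
    # global_step >= start_step, so the result stays in [0.0, 1.0].
    return min(1.0, (global_step - start_step) / float(scale_steps))

# e.g. start_step=1000, scale_steps=500:
#   step 1000 -> 0.0, step 1250 -> 0.5, step 1500 and beyond -> 1.0
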
Example #9
    def run(self):
        agent, all_envs = init_experiment(
            [get_train_shard_path(i) for i in self.shard_ids],
            use_gpu=FLAGS.actor_use_gpu,
            gpu_id=str(self.actor_id + FLAGS.actor_gpu_start_id))
        graph = agent.model.graph
        current_ckpt = get_init_model_path()

        # Obtain oracle trajectories for the examples and drop the examples
        # for which no oracle can be found.
        envs, env_trajs = get_env_trajs(all_envs)

        # build a dict to store the oracle trajs
        env_oracle_trajs_dict = dict()
        for env, env_traj in zip(envs, env_trajs):
            env_oracle_trajs_dict[env.name] = env_traj
        tf.logging.info(
            'Found oracle for {} envs out of a total of {} for actor_{}'.format(
                len(envs), len(all_envs), self.actor_id))

        i = 0
        while True:
            n_train_samples = 1  # one oracle sample per environment

            # Make sure that all the samples from the env batch
            # fit into one training batch.
            if FLAGS.batch_size < n_train_samples:
                raise ValueError(
                    'One batch has to contain at least the samples from one environment.'
                )

            env_batch_size = FLAGS.batch_size / n_train_samples

            env_iterator = data_utils.BatchIterator(dict(envs=envs),
                                                    shuffle=True,
                                                    batch_size=env_batch_size)

            for j, batch_dict in enumerate(env_iterator):
                batch_envs = batch_dict['envs']
                tf.logging.info('=' * 50)
                tf.logging.info('{} iteration {}, batch {}: {} envs'.format(
                    self.name, i, j, len(batch_envs)))

                t1 = time.time()

                # get the oracle samples
                oracle_samples = []
                for batch_env in batch_envs:
                    oracle_samples.append(
                        agent_factory.Sample(
                            traj=env_oracle_trajs_dict[batch_env.name],
                            prob=1.0))

                self.eval_queue.put((oracle_samples, len(batch_envs)))
                self.replay_queue.put((oracle_samples, len(batch_envs)))

                assert (FLAGS.fixed_replay_weight >= 0.0
                        and FLAGS.fixed_replay_weight <= 1.0)

                train_samples = []

                train_samples += oracle_samples

                train_samples = sorted(train_samples,
                                       key=lambda x: x.traj.env_name)
                tf.logging.info('{} train samples'.format(len(train_samples)))

                if FLAGS.use_importance_sampling:
                    step_logprobs = agent.compute_step_logprobs(
                        [s.traj for s in train_samples])
                else:
                    step_logprobs = None

                # TODO: the clip_factor may be wrong
                self.train_queue.put((train_samples, step_logprobs, 0.0))
                t2 = time.time()
                tf.logging.info(
                    ('{} sec used preparing and enqueuing samples, {}'
                     ' iteration {}, batch {}: {} envs').format(
                         t2 - t1, self.name, i, j, len(batch_envs)))

                t1 = time.time()
                # Wait for a ckpt that still exists or is the same as the
                # current ckpt (no need to load anything).
                while True:
                    new_ckpt = self.ckpt_queue.get()
                    new_ckpt_file = new_ckpt + '.meta'
                    if new_ckpt == current_ckpt or tf.gfile.Exists(
                            new_ckpt_file):
                        break
                t2 = time.time()
                tf.logging.info(
                    '{} sec waiting {} iteration {}, batch {}'.format(
                        t2 - t1, self.name, i, j))

                if new_ckpt != current_ckpt:
                    # If the ckpt is not the same, then restore the new
                    # ckpt.
                    tf.logging.info('{} loading ckpt {}'.format(
                        self.name, new_ckpt))
                    t1 = time.time()
                    graph.restore(new_ckpt)
                    t2 = time.time()
                    tf.logging.info('{} sec used {} restoring ckpt {}'.format(
                        t2 - t1, self.name, new_ckpt))
                    current_ckpt = new_ckpt

            if agent.model.get_global_step() >= FLAGS.n_steps:
                tf.logging.info('{} finished'.format(self.name))
                return
            i += 1