def save_to_buffer(self, examples, replay_buffer):
    """Save examples to the replay buffer, screening them by annotation.

    An example whose environment has a non-empty annotation in
    ``self.env_annotation_dict`` is kept only when the program decoded
    from its trajectory passes ``annotation.verify_program``.  Examples
    from environments without an annotation are always kept.

    Args:
        examples: iterable of samples exposing ``traj.env_name``.
        replay_buffer: object whose ``save`` method receives the list of
            examples that survived the screening.
    """

    def fits_annotation(example):
        annotation = self.env_annotation_dict.get(example.traj.env_name, None)
        if annotation is None or not (len(annotation) > 0):
            # Nothing to check this example against.
            return True
        explored_program = agent_factory.traj_to_program(
            example.traj, self.decode_vocab)
        return annotation.verify_program(explored_program)

    kept_examples = [
        example for example in examples if fits_annotation(example)
    ]
    replay_buffer.save(kept_examples)
def annotate_example(self, envs):
    """Create an ``OracleAnnotation`` for each env that has an oracle program.

    Args:
        envs: JSON descriptions of the environments; they are converted
            with ``json_to_envs`` before anything else happens.

    Returns:
        A dict mapping every env name to an ``OracleAnnotation`` built
        from its oracle program, or to ``None`` when no oracle program
        was obtained for that env.
    """
    # The inputs are JSON descriptions; build the real envs first.
    envs = json_to_envs(envs)

    # The oracle lookup depends on which executor is configured.
    fetch_oracles = (
        get_wtq_annotations if FLAGS.executor == 'wtq' else get_env_trajs)
    oracle_envs, oracle_trajs = fetch_oracles(envs)

    program_by_name = {}
    for oracle_env, oracle_traj in zip(oracle_envs, oracle_trajs):
        program_by_name[oracle_env.name] = traj_to_program(
            oracle_traj, envs[0].de_vocab)

    annotation_by_name = {}
    for env in envs:
        program = program_by_name.get(env.name, None)
        annotation_by_name[env.name] = (
            None if program is None else OracleAnnotation(env, program))
    return annotation_by_name
def annotate_example_exploration(self, envs):
    """Create a ``SketchAnnotation`` for each env by exploring its sketch.

    For every env with an oracle program, the program's sketch is
    extracted with ``get_sketch``, programs fitting that sketch are
    enumerated with ``SketchAnnotator.explore_sketch_programs``, and each
    of them is turned into a sample with probability 1.0.

    Args:
        envs: JSON descriptions of the environments; they are converted
            with ``json_to_envs`` before anything else happens.

    Returns:
        A dict mapping every env name to a ``SketchAnnotation``, or to
        ``None`` when no oracle program was obtained for that env.
    """
    # The inputs are JSON descriptions; build the real envs first.
    envs = json_to_envs(envs)

    if FLAGS.executor == 'wtq':
        oracle_envs, oracle_trajs = get_wtq_annotations(envs)
    else:
        oracle_envs, oracle_trajs = get_env_trajs(envs)

    program_by_name = {}
    for oracle_env, oracle_traj in zip(oracle_envs, oracle_trajs):
        program_by_name[oracle_env.name] = traj_to_program(
            oracle_traj, envs[0].de_vocab)

    annotation_by_name = {}
    for env in envs:
        program = program_by_name.get(env.name, None)
        if program is None:
            annotation_by_name[env.name] = None
            continue

        sketch = get_sketch(program)
        explored_programs = SketchAnnotator.explore_sketch_programs(
            env, sketch)

        # Replay every explored program to recover its trajectory ...
        explored_trajs = []
        for explored_program in explored_programs:
            explored_trajs.append(
                collect_traj_for_program(env, explored_program))

        # ... and wrap each trajectory as a sample with full probability.
        samples = []
        for explored_traj in explored_trajs:
            samples.append(agent_factory.Sample(explored_traj, prob=1.0))

        annotation_by_name[env.name] = SketchAnnotation(env, sketch, samples)
    return annotation_by_name
def run(self):
    """Repeatedly evaluate the latest checkpoint on the dev envs.

    Each iteration runs batched beam search over all envs, picks the
    samples to score with ``self.select_top``, evaluates them with
    ``agent.evaluate`` and saves the model whenever the average return
    improves.  ``best_model_info.json`` is rewritten on every iteration
    and the dev samples are logged to ``dev_samples_{i}.txt``.  With
    ``FLAGS.eval_only`` the full beam is also dumped to
    ``dev_programs_in_beam_{i}.json``.

    The method returns after the first iteration when ``FLAGS.eval_only``
    is set, or once the global step reaches ``FLAGS.n_steps``; otherwise
    it polls for a newer checkpoint, restores it and evaluates again.

    NOTE(review): ``i`` is never incremented, so every iteration logs
    "iteration 0" and overwrites the ``*_0`` output files.  Left as is
    because consumers may rely on those file names -- confirm.
    """
    agent, envs = init_experiment(
        self.fns, FLAGS.eval_use_gpu, gpu_id=str(FLAGS.eval_gpu_id))
    for env in envs:
        env.punish_extra_work = False
    graph = agent.model.graph
    dev_writer = tf.summary.FileWriter(
        os.path.join(get_experiment_dir(), FLAGS.tb_log_dir, 'dev'))
    best_dev_avg_return = 0.0
    best_model_path = ''
    best_model_dir = os.path.join(get_experiment_dir(), FLAGS.best_model_dir)
    if not tf.gfile.Exists(best_model_dir):
        tf.gfile.MkDir(best_model_dir)
    i = 0
    current_ckpt = get_init_model_path()
    env_dict = dict([(env.name, env) for env in envs])
    while True:
        t1 = time.time()
        tf.logging.info(
            'dev: iteration {}, evaluating {}.'.format(i, current_ckpt))

        env_batch_size = FLAGS.eval_batch_size
        env_iterator = data_utils.BatchIterator(
            dict(envs=envs), shuffle=False, batch_size=env_batch_size)
        dev_samples_in_beam = []
        for j, batch_dict in enumerate(env_iterator):
            t3 = time.time()
            batch_envs = batch_dict['envs']
            tf.logging.info('=' * 50)
            tf.logging.info('{} iteration {}, batch {}: {} envs'.format(
                self.name, i, j, len(batch_envs)))
            new_samples_in_beam = agent.beam_search(
                batch_envs, beam_size=FLAGS.eval_beam_size)
            dev_samples_in_beam += new_samples_in_beam
            tf.logging.info('{} samples in beam, batch {}.'.format(
                len(new_samples_in_beam), j))
            t4 = time.time()
            tf.logging.info(
                '{} sec used in evaluator batch {}.'.format(t4 - t3, j))

        # Account for beam search where the beam doesn't contain any
        # examples without error, which will make len(dev_samples)
        # smaller than len(envs); hence true_n is passed explicitly.
        dev_samples = self.select_top(dev_samples_in_beam)
        dev_avg_return, dev_avg_len = agent.evaluate(
            dev_samples, writer=dev_writer, true_n=len(envs))
        tf.logging.info(
            '{} samples in non-empty beam.'.format(len(dev_samples)))
        tf.logging.info('true n is {}'.format(len(envs)))
        tf.logging.info('{} questions in dev set.'.format(len(envs)))
        tf.logging.info('{} dev avg return.'.format(dev_avg_return))
        tf.logging.info('dev: avg return: {}, avg length: {}.'.format(
            dev_avg_return, dev_avg_len))

        if dev_avg_return > best_dev_avg_return:
            best_model_path = graph.save(
                os.path.join(best_model_dir, 'model'),
                agent.model.get_global_step())
            best_dev_avg_return = dev_avg_return
            tf.logging.info(
                'New best dev avg returns is {}'.format(best_dev_avg_return))
            tf.logging.info(
                'New best model is saved in {}'.format(best_model_path))

        with open(os.path.join(get_experiment_dir(),
                               'best_model_info.json'), 'w') as f:
            result = {'best_model_path': best_model_path}
            if FLAGS.eval_only:
                result['best_eval_avg_return'] = best_dev_avg_return
            else:
                result['best_dev_avg_return'] = best_dev_avg_return
            json.dump(result, f)

        if FLAGS.eval_only:
            # Save the decoding results for further analysis.
            dev_programs_in_beam_dict = {}
            for sample in dev_samples_in_beam:
                name = sample.traj.env_name
                # Use envs[0] rather than the loop-leaked batch_envs[0]:
                # it matches the vocab passed to show_samples below and
                # does not depend on the batch loop having run.
                program = agent_factory.traj_to_program(
                    sample.traj, envs[0].de_vocab)
                answer = sample.traj.answer
                dev_programs_in_beam_dict.setdefault(name, []).append(
                    (program, answer, sample.prob))

            t3 = time.time()
            with open(os.path.join(
                    get_experiment_dir(),
                    'dev_programs_in_beam_{}.json'.format(i)), 'w') as f:
                json.dump(dev_programs_in_beam_dict, f)
            t4 = time.time()
            tf.logging.info(
                '{} sec used dumping programs in beam in eval iteration {}.'.format(
                    t4 - t3, i))

        t3 = time.time()
        with codecs.open(
                os.path.join(get_experiment_dir(),
                             'dev_samples_{}.txt'.format(i)),
                'w', encoding='utf-8') as f:
            for sample in dev_samples:
                f.write(show_samples([sample], envs[0].de_vocab, env_dict))
        t4 = time.time()
        tf.logging.info(
            '{} sec used logging dev samples in eval iteration {}.'.format(
                t4 - t3, i))

        t2 = time.time()
        tf.logging.info(
            '{} sec used in eval iteration {}.'.format(t2 - t1, i))

        if FLAGS.eval_only or agent.model.get_global_step() >= FLAGS.n_steps:
            tf.logging.info('{} finished'.format(self.name))
            if FLAGS.eval_only:
                print('Eval average return (accuracy) of the best model is {}'.format(
                    best_dev_avg_return))
            else:
                print('Best dev average return (accuracy) is {}'.format(
                    best_dev_avg_return))
                print('Best model is saved in {}'.format(best_model_path))
            return

        # Reload on the latest model: poll once per second until a
        # checkpoint different from the current one shows up.
        new_ckpt = None
        t1 = time.time()
        while new_ckpt is None or new_ckpt == current_ckpt:
            time.sleep(1)
            new_ckpt = tf.train.latest_checkpoint(
                os.path.join(get_experiment_dir(), FLAGS.saved_model_dir))
        t2 = time.time()
        tf.logging.info(
            '{} sec used waiting for new checkpoint in evaluator.'.format(
                t2-t1))

        tf.logging.info('lastest ckpt to evaluate is {}.'.format(new_ckpt))
        tf.logging.info('{} loading ckpt {}'.format(self.name, new_ckpt))
        t1 = time.time()
        graph.restore(new_ckpt)
        t2 = time.time()
        tf.logging.info('{} sec used {} loading ckpt {}'.format(
            t2-t1, self.name, new_ckpt))
        current_ckpt = new_ckpt
def run(self):
    """Repeatedly evaluate the latest checkpoint using ``beam_search_eval``.

    Each iteration scores the dev envs, saves the model when the average
    return improves, rewrites ``best_model_info.json`` and logs the dev
    samples.  With ``FLAGS.eval_only`` the programs in the beam are also
    dumped to JSON.  The method returns after the first iteration when
    ``FLAGS.eval_only`` is set, or once the global step reaches
    ``FLAGS.n_steps``; otherwise it waits for a newer checkpoint,
    restores it and starts over.
    """
    agent, envs = init_experiment(
        self.fns, FLAGS.eval_use_gpu, gpu_id=str(FLAGS.eval_gpu_id))
    for env in envs:
        env.punish_extra_work = False
    graph = agent.model.graph
    dev_writer = tf.summary.FileWriter(
        os.path.join(get_experiment_dir(), FLAGS.tb_log_dir, 'dev'))

    best_dev_avg_return = 0.0
    best_model_path = ''
    best_model_dir = os.path.join(get_experiment_dir(), FLAGS.best_model_dir)
    if not tf.gfile.Exists(best_model_dir):
        tf.gfile.MkDir(best_model_dir)

    iteration = 0
    current_ckpt = get_init_model_path()
    env_dict = {env.name: env for env in envs}

    while True:
        iteration_start = time.time()
        tf.logging.info('dev: iteration {}, evaluating {}.'.format(
            iteration, current_ckpt))

        dev_avg_return, dev_samples, dev_samples_in_beam = beam_search_eval(
            agent, envs, writer=dev_writer)

        if dev_avg_return > best_dev_avg_return:
            best_model_path = graph.save(
                os.path.join(best_model_dir, 'model'),
                agent.model.get_global_step())
            best_dev_avg_return = dev_avg_return
            tf.logging.info(
                'New best dev avg returns is {}'.format(best_dev_avg_return))
            tf.logging.info(
                'New best model is saved in {}'.format(best_model_path))

        info_path = os.path.join(get_experiment_dir(), 'best_model_info.json')
        with open(info_path, 'w') as info_file:
            result = {'best_model_path': compress_home_path(best_model_path)}
            # The key name records which kind of run produced the score.
            return_key = ('best_eval_avg_return' if FLAGS.eval_only
                          else 'best_dev_avg_return')
            result[return_key] = best_dev_avg_return
            json.dump(result, info_file)

        if FLAGS.eval_only:
            # Save the decoding results for further analysis.
            dev_programs_in_beam_dict = {}
            for sample in dev_samples_in_beam:
                name = sample.traj.env_name
                program = agent_factory.traj_to_program(
                    sample.traj, envs[0].de_vocab)
                answer = sample.traj.answer
                dev_programs_in_beam_dict.setdefault(name, []).append(
                    (program, answer, sample.prob))

            dump_start = time.time()
            beam_path = os.path.join(
                get_experiment_dir(),
                'dev_programs_in_beam_{}.json'.format(iteration))
            with open(beam_path, 'w') as beam_file:
                json.dump(dev_programs_in_beam_dict, beam_file)
            dump_end = time.time()
            tf.logging.info(
                '{} sec used dumping programs in beam in eval iteration {}.'.format(
                    dump_end - dump_start, iteration))

        log_start = time.time()
        samples_path = os.path.join(
            get_experiment_dir(), 'dev_samples_{}.txt'.format(iteration))
        with codecs.open(samples_path, 'w', encoding='utf-8') as samples_file:
            for sample in dev_samples:
                samples_file.write(
                    show_samples([sample], envs[0].de_vocab, env_dict))
        log_end = time.time()
        tf.logging.info(
            '{} sec used logging dev samples in eval iteration {}.'.format(
                log_end - log_start, iteration))

        iteration_end = time.time()
        tf.logging.info('{} sec used in eval iteration {}.'.format(
            iteration_end - iteration_start, iteration))

        if FLAGS.eval_only or agent.model.get_global_step() >= FLAGS.n_steps:
            tf.logging.info('{} finished'.format(self.name))
            if FLAGS.eval_only:
                print('Eval average return (accuracy) of the best model is {}'.format(
                    best_dev_avg_return))
            else:
                print('Best dev average return (accuracy) is {}'.format(
                    best_dev_avg_return))
                print('Best model is saved in {}'.format(best_model_path))
            return

        # Reload on the latest model: poll once per second until a
        # checkpoint different from the current one is available.
        wait_start = time.time()
        new_ckpt = None
        while new_ckpt is None or new_ckpt == current_ckpt:
            time.sleep(1)
            new_ckpt = tf.train.latest_checkpoint(
                os.path.join(get_experiment_dir(), FLAGS.saved_model_dir))
        wait_end = time.time()
        tf.logging.info(
            '{} sec used waiting for new checkpoint in evaluator.'.format(
                wait_end-wait_start))

        tf.logging.info('lastest ckpt to evaluate is {}.'.format(new_ckpt))
        tf.logging.info('{} loading ckpt {}'.format(self.name, new_ckpt))
        load_start = time.time()
        graph.restore(new_ckpt)
        load_end = time.time()
        tf.logging.info('{} sec used {} loading ckpt {}'.format(
            load_end-load_start, self.name, new_ckpt))
        current_ckpt = new_ckpt
def wikisql_error_analysis(env_file=None, decoded_beam_file=None):
    """Print an error analysis of decoded WikiSQL beams against oracles.

    The test environments are loaded from ``env_file`` and their oracle
    programs recovered with ``get_env_trajs``.  The decoded beams are
    read from ``decoded_beam_file``, a JSON object mapping an env name to
    a list of ``[program, answer, prob]`` entries.  Overall and
    confidence analyses are printed for all evaluated examples, and again
    (plus a failed-words analysis) for the examples whose top hypothesis
    does not score 1.0 under ``wikisql_score``.

    Args:
        env_file: path of the ``.jsonl`` file with the environments.
            Defaults to the original hard-coded test split path.
        decoded_beam_file: path of the decoded beam JSON dump.  Defaults
            to the original hard-coded baseline output path.
    """
    if env_file is None:
        env_file = '/Users/ansongni/projects/data/wikisql/processed_input/preprocess_4/test_split.jsonl'
    if decoded_beam_file is None:
        decoded_beam_file = '/Users/ansongni/projects/data/wikisql/output/eval_imp_baseline/dev_programs_in_beam_0.json'

    # first load the test environments and get oracle programs
    test_envs = get_envs([env_file])
    envs, trajs = get_env_trajs(test_envs)
    oracle_env_programs = [(env, traj_to_program(traj, envs[0].de_vocab))
                           for env, traj in zip(envs, trajs)]

    # then load decoded results in the beam
    with open(decoded_beam_file) as f:
        decoded_beam = json.load(f)

    # generate the detailed example result for each env that got oracle
    example_results = []
    for env, oracle_program in oracle_env_programs:
        env_id = env.name
        table = env.question_annotation['context']
        question = env.question_annotation['question']
        oracle_answer = env.question_annotation['answer']

        # Skip examples that are missing from the dump or have an empty
        # beam: the failure check below reads the top hypothesis.
        hyps = decoded_beam.get(env_id, None)
        if not hyps:
            continue

        # Each dumped hypothesis is stored as (program, answer, prob).
        beam = [
            program_result(hyp[2], hyp[0], hyp[1]) for hyp in hyps
        ]
        example_results.append(
            detailed_example_result(env_id, table, question, oracle_program,
                                    oracle_answer, beam))

    # 1. now we do overall error analysis
    print('%d test examples, %d have oracle program, %d gets evaluated' %
          (len(test_envs), len(envs), len(example_results)))
    overall_analysis(example_results)
    confidence_analysis(example_results)

    # 2. now we analyze the failed cases.  A list (not a lazy filter
    # object) is built so that len() works on Python 2 and Python 3.
    failed_example_results = [
        result for result in example_results
        if 1.0 - wikisql_score(result.hyps_in_beam[0].predicted_answer,
                               result.oracle_answer)
    ]
    print('%d failed examples' % len(failed_example_results))
    overall_analysis(failed_example_results)
    confidence_analysis(failed_example_results)
    failed_words_analysis(example_results, failed_example_results)
def decode_sketch_program(envs, beam_size=50, max_samples=10):
    """Annotate envs with programs found by sketch-constrained beam search.

    The oracle program of each env is reduced to a sketch with
    ``get_sketch`` and set as a constraint on the env.  A saved agent is
    then loaded and beam search is run over all envs.  For every env, the
    beam entries whose final reward is 1.0 are kept, their probabilities
    are renormalised to sum to one and, when there are more than
    ``max_samples`` of them, only the most probable ``max_samples`` are
    retained.

    Args:
        envs: JSON descriptions of the environments; they are converted
            with ``json_to_envs`` before anything else happens.
        beam_size: beam size handed to ``agent.beam_search``.
        max_samples: maximum number of successful samples kept per env.

    Returns:
        A dict mapping every env name to a ``SketchAnnotation``, or to
        ``None`` when the env has no sketch or no successful beam entry.
    """
    # first create the real envs from the jsons and add constraints to them
    envs = json_to_envs(envs)
    env_name_dict = {env.name: env for env in envs}

    if FLAGS.executor == 'wtq':
        oracle_envs, oracle_trajs = get_wtq_annotations(envs)
    else:
        oracle_envs, oracle_trajs = get_env_trajs(envs)

    env_sketch_dict = dict([
        (env.name, get_sketch(traj_to_program(traj, envs[0].de_vocab)))
        for env, traj in zip(oracle_envs, oracle_trajs)
    ])
    for env in envs:
        sketch = env_sketch_dict.get(env.name, None)
        if sketch is not None:
            # Hand over a copy so the stored sketch is never shared
            # with the env.
            env.set_sketch_constraint(sketch[:])

    # create the agent
    graph_config = get_saved_graph_config()
    graph_config['use_gpu'] = False
    graph_config['gpu_id'] = '0'
    init_model_path = get_init_model_path()
    agent = create_agent(graph_config, init_model_path)

    # beam search
    beam_samples = []
    env_iterator = data_utils.BatchIterator(
        dict(envs=envs), shuffle=False, batch_size=FLAGS.eval_batch_size)
    for _, batch_dict in tqdm(enumerate(env_iterator)):
        batch_envs = batch_dict['envs']
        beam_samples += agent.beam_search(batch_envs, beam_size=beam_size)

    # group the flat sample list into one beam per env; appending in
    # place avoids rebuilding the list for every sample
    env_beam_dict = dict()
    for sample in beam_samples:
        env_beam_dict.setdefault(sample.traj.env_name, []).append(sample)

    # get the trajs with 1.0 reward for each example and re-weight the prob
    env_name_annotation_dict = dict()
    for env_name, env in env_name_dict.items():
        # Real lists (not lazy filter/map objects) so that len(), sorting
        # and slicing behave the same on Python 2 and Python 3.
        success_beam = [
            sample for sample in env_beam_dict.get(env_name, [])
            if sample.traj.rewards[-1] == 1.0
        ]
        # retrieve the sketch result from previous steps
        sketch = env_sketch_dict.get(env_name, None)
        if not success_beam or sketch is None:
            env_name_annotation_dict[env_name] = None
            continue

        # re-weight the examples in the beam
        prob_sum = sum(sample.prob for sample in success_beam)
        success_beam = [
            agent_factory.Sample(traj=sample.traj,
                                 prob=sample.prob / prob_sum)
            for sample in success_beam
        ]
        if len(success_beam) > max_samples:
            success_beam = sorted(
                success_beam, key=lambda sample: sample.prob, reverse=True)
            success_beam = success_beam[:max_samples]

        env_name_annotation_dict[env_name] = SketchAnnotation(
            env, sketch, success_beam)

    return env_name_annotation_dict
def explore_sketch_programs(env, sketch, max_hyp=1000):
    """Provide a sketch, find executable programs that fit that sketch.

    Starting from a clone of ``env``, every command of the sketch is
    expanded in turn: the opening bracket and the command head are
    stepped, then every valid action is tried on a cloned env until some
    hypothesis can only close the statement.  After each expansion round
    at most ``max_hyp`` randomly chosen hypotheses are kept.  Finally
    ``<END>`` is stepped and the hypotheses whose last reward is 1.0 are
    converted to programs.

    Args:
        env: environment to explore; it is cloned, not modified.
        sketch: sequence of command heads, one per statement.
        max_hyp: maximum number of hypotheses kept after each round.

    Returns:
        The successful programs as a list, in random order.

    Raises:
        ValueError: if ``(``, a sketch command or ``<END>`` is not a
            valid action at the point where it is required.
    """
    env = env.clone()
    env.use_cache = False
    all_hyps = [(env, env.start_ob)]

    for cmd in sketch:
        # first add the left bracket and the cmd head
        new_all_hyps = []
        for hyp_env, ob in all_hyps:
            try:
                action_index = list(ob[0].valid_indices).index(
                    hyp_env.de_vocab.lookup('('))
                ob, _, _, _ = hyp_env.step(action_index)
                action_index = list(ob[0].valid_indices).index(
                    hyp_env.de_vocab.lookup(cmd))
                ob, _, _, _ = hyp_env.step(action_index)
                new_all_hyps.append((hyp_env, ob))
            except ValueError:
                raise ValueError('This should not happen')
        all_hyps = new_all_hyps

        # then for every possible action, we use a cloned env and step
        # that until the closure of this stmt
        stmt_done = False
        while not stmt_done:
            new_all_hyps = []
            for hyp_env, ob in all_hyps:
                valid_indices = list(ob[0].valid_indices)
                if not valid_indices:
                    # this is a dead end for current env
                    continue
                for action_index in range(len(valid_indices)):
                    new_env = hyp_env.clone()
                    new_env.use_cache = False
                    new_ob, _, _, _ = new_env.step(action_index)
                    new_all_hyps.append((new_env, new_ob))
                if valid_indices == [hyp_env.de_vocab.lookup(')')]:
                    # maximum stmt length is reached
                    stmt_done = True

            # Permute indices rather than the (env, ob) pairs themselves
            # so numpy never tries to coerce the pairs into an array.
            order = np.random.permutation(len(new_all_hyps))[:max_hyp]
            all_hyps = [new_all_hyps[k] for k in order]

            if not all_hyps:
                # Every hypothesis hit a dead end; without this exit the
                # loop would spin forever on an empty hypothesis list.
                break

    # add the end token
    new_all_hyps = []
    for hyp_env, ob in all_hyps:
        try:
            action_index = list(ob[0].valid_indices).index(
                hyp_env.de_vocab.lookup('<END>'))
            ob, _, _, _ = hyp_env.step(action_index)
            new_all_hyps.append((hyp_env, ob))
        except ValueError:
            raise ValueError('This should not happen')
    all_hyps = new_all_hyps

    # pruning based on the result
    explored_programs = []
    for hyp_env, _ in all_hyps:
        if hyp_env.rewards[-1] == 1.0:
            traj = agent_factory.Traj(obs=hyp_env.obs,
                                      actions=hyp_env.actions,
                                      rewards=hyp_env.rewards,
                                      context=hyp_env.get_context(),
                                      env_name=hyp_env.name,
                                      answer=hyp_env.interpreter.result)
            explored_programs.append(traj_to_program(traj, hyp_env.de_vocab))

    # Shuffle by index so each program is returned as the object that
    # traj_to_program produced, not as a numpy array row.
    order = np.random.permutation(len(explored_programs))
    return [explored_programs[k] for k in order]
def __init__(self, env, sketch, samples):
    """Store a sketch together with the programs that realise it.

    Args:
        env: environment being annotated; its ``name`` is passed to
            ``Annotation.__init__`` and its ``de_vocab`` is used to
            decode the sample trajectories.
        sketch: the sketch this annotation stands for.
        samples: iterable of samples exposing ``traj`` and ``prob``.
    """
    Annotation.__init__(self, env.name)
    self.sketch = sketch

    # Keep each decoded program next to the probability of its sample.
    program_prob_pairs = []
    for sample in samples:
        program = traj_to_program(sample.traj, env.de_vocab)
        program_prob_pairs.append((program, sample.prob))
    self.sketch_programs = program_prob_pairs
def run(self):
    """Actor loop: explore, build training batches and sync checkpoints.

    After loading its shard of envs and warm-starting the replay buffer
    (saved programs and samples derived from annotations), the actor
    loops over epochs.  For every env batch it

    1. generates exploration samples and saves the ones accepted by
       ``self.save_to_buffer``,
    2. draws replay samples and on-policy samples and puts them on
       ``self.replay_queue`` and ``self.eval_queue``,
    3. weights the samples, down-scales the ones not backed by a verified
       annotation by ``al_scale_factor`` and puts the training batch on
       ``self.train_queue``,
    4. waits on ``self.ckpt_queue`` and restores a new checkpoint if one
       arrived.

    The method returns once the global step reaches ``FLAGS.n_steps``.

    Raises:
        ValueError: if both on-policy and non-replay samples are
            requested for training, or if ``FLAGS.batch_size`` is smaller
            than the number of training samples per env.
    """
    agent, envs = init_experiment(
        [get_train_shard_path(shard_id) for shard_id in self.shard_ids],
        use_gpu=FLAGS.actor_use_gpu,
        gpu_id=str(self.actor_id + FLAGS.actor_gpu_start_id))
    self.decode_vocab = envs[0].de_vocab
    graph = agent.model.graph
    current_ckpt = get_init_model_path()

    env_dict = dict([(env.name, env) for env in envs])
    replay_buffer = agent_factory.AllGoodReplayBuffer(
        agent, envs[0].de_vocab)

    # Load saved programs to warm start the replay buffer.
    if FLAGS.load_saved_programs:
        load_programs(envs, replay_buffer, FLAGS.saved_program_file)

    if FLAGS.save_replay_buffer_at_end:
        # Snapshot of the warm-start contents; at the end only samples
        # that are not in this snapshot are written out.
        replay_buffer_copy = agent_factory.AllGoodReplayBuffer(
            de_vocab=envs[0].de_vocab)
        replay_buffer_copy.program_prob_dict = copy.deepcopy(
            replay_buffer.program_prob_dict)

    # shrink the annotation dict to the envs needed
    small_env_annotation_dict = dict()
    for env in envs:
        annotation = self.env_annotation_dict.get(env.name, None)
        if annotation is not None:
            small_env_annotation_dict[env.name] = annotation
    self.env_annotation_dict = small_env_annotation_dict
    print('Actor %d, total %d envs, %d has been annotated.' %
          (self.actor_id, len(envs), len(self.env_annotation_dict)))

    # get samples from the annotations and put them into the buffer
    if len(self.env_annotation_dict) > 0:
        annotated_samples = []
        for env_name, annotation in self.env_annotation_dict.items():
            samples_from_annotation = annotation.get_samples(
                env_dict[env_name])
            annotated_samples += samples_from_annotation
        self.save_to_buffer(annotated_samples, replay_buffer)

    i = 0
    while True:
        # Create the logging files.
        if (FLAGS.log_samples_every_n_epoch > 0
                and i % FLAGS.log_samples_every_n_epoch == 0):
            f_replay = codecs.open(os.path.join(
                get_experiment_dir(),
                'replay_samples_{}_{}.txt'.format(self.name, i)),
                'w', encoding='utf-8')
            f_policy = codecs.open(os.path.join(
                get_experiment_dir(),
                'policy_samples_{}_{}.txt'.format(self.name, i)),
                'w', encoding='utf-8')
            f_train = codecs.open(os.path.join(
                get_experiment_dir(),
                'train_samples_{}_{}.txt'.format(self.name, i)),
                'w', encoding='utf-8')

        n_train_samples = 0
        if FLAGS.use_replay_samples_in_train:
            n_train_samples += FLAGS.n_replay_samples

        if (FLAGS.use_policy_samples_in_train
                and FLAGS.use_nonreplay_samples_in_train):
            raise ValueError(
                'Cannot use both on-policy samples and nonreplay samples for training!'
            )

        if (FLAGS.use_policy_samples_in_train
                or FLAGS.use_nonreplay_samples_in_train):
            # Note that nonreplay samples are drawn by rejection
            # sampling from on-policy samples.
            n_train_samples += FLAGS.n_policy_samples

        # Make sure that all the samples from the env batch
        # fits into one batch for training.
        if FLAGS.batch_size < n_train_samples:
            raise ValueError(
                'One batch have to at least contain samples from one environment.'
            )

        # Floor division keeps the batch size an integer regardless of
        # the interpreter's division semantics.
        env_batch_size = FLAGS.batch_size // n_train_samples

        env_iterator = data_utils.BatchIterator(
            dict(envs=envs), shuffle=True, batch_size=env_batch_size)

        for j, batch_dict in enumerate(env_iterator):
            batch_envs = batch_dict['envs']
            tf.logging.info('=' * 50)
            tf.logging.info('{} iteration {}, batch {}: {} envs'.format(
                self.name, i, j, len(batch_envs)))

            t1 = time.time()
            # Generate samples with cache and save to replay buffer.
            t3 = time.time()
            n_explore = 0
            for _ in xrange(FLAGS.n_explore_samples):
                explore_samples = agent.generate_samples(
                    batch_envs, n_samples=1, use_cache=FLAGS.use_cache,
                    greedy=FLAGS.greedy_exploration)
                self.save_to_buffer(explore_samples, replay_buffer)
                n_explore += len(explore_samples)

            if FLAGS.n_extra_explore_for_hard > 0:
                # Envs without any solution in the buffer get extra tries.
                hard_envs = [
                    env for env in batch_envs
                    if not replay_buffer.has_found_solution(env.name)
                ]
                if hard_envs:
                    for _ in xrange(FLAGS.n_extra_explore_for_hard):
                        explore_samples = agent.generate_samples(
                            hard_envs, n_samples=1,
                            use_cache=FLAGS.use_cache,
                            greedy=FLAGS.greedy_exploration)
                        self.save_to_buffer(explore_samples, replay_buffer)
                        n_explore += len(explore_samples)

            t4 = time.time()
            tf.logging.info(
                '{} sec used generating {} exploration samples.'.format(
                    t4 - t3, n_explore))

            tf.logging.info(
                '{} samples saved in the replay buffer.'.format(
                    replay_buffer.size))

            t3 = time.time()
            replay_samples = replay_buffer.replay(
                batch_envs, FLAGS.n_replay_samples,
                use_top_k=FLAGS.use_top_k_replay_samples,
                agent=None if FLAGS.random_replay_samples else agent,
                truncate_at_n=FLAGS.truncate_replay_buffer_at_n)
            t4 = time.time()
            tf.logging.info(
                '{} sec used selecting {} replay samples.'.format(
                    t4 - t3, len(replay_samples)))

            t3 = time.time()
            if FLAGS.use_top_k_policy_samples:
                if FLAGS.n_policy_samples == 1:
                    policy_samples = agent.generate_samples(
                        batch_envs, n_samples=FLAGS.n_policy_samples,
                        greedy=True)
                else:
                    policy_samples = agent.beam_search(
                        batch_envs, beam_size=FLAGS.n_policy_samples)
            else:
                policy_samples = agent.generate_samples(
                    batch_envs, n_samples=FLAGS.n_policy_samples,
                    greedy=False)
            t4 = time.time()
            tf.logging.info(
                '{} sec used generating {} on-policy samples'.format(
                    t4 - t3, len(policy_samples)))

            t2 = time.time()
            tf.logging.info(
                ('{} sec used generating replay and on-policy samples,'
                 ' {} iteration {}, batch {}: {} envs').format(
                     t2 - t1, self.name, i, j, len(batch_envs)))

            t1 = time.time()
            self.eval_queue.put((policy_samples, len(batch_envs)))
            self.replay_queue.put((replay_samples, len(batch_envs)))

            assert (FLAGS.fixed_replay_weight >= 0.0
                    and FLAGS.fixed_replay_weight <= 1.0)

            if FLAGS.use_replay_prob_as_weight:
                new_samples = []
                for sample in replay_samples:
                    name = sample.traj.env_name
                    if name in replay_buffer.prob_sum_dict:
                        replay_prob = max(
                            replay_buffer.prob_sum_dict[name],
                            FLAGS.min_replay_weight)
                    else:
                        replay_prob = 0.0
                    scale = replay_prob
                    new_samples.append(
                        agent_factory.Sample(traj=sample.traj,
                                             prob=sample.prob * scale))
                replay_samples = new_samples
            else:
                replay_samples = agent_factory.scale_probs(
                    replay_samples, FLAGS.fixed_replay_weight)

            replay_samples = sorted(replay_samples,
                                    key=lambda x: x.traj.env_name)
            policy_samples = sorted(policy_samples,
                                    key=lambda x: x.traj.env_name)

            if FLAGS.use_nonreplay_samples_in_train:
                nonreplay_samples = []
                for sample in policy_samples:
                    if not replay_buffer.contain(sample.traj):
                        nonreplay_samples.append(sample)

            self.save_to_buffer(policy_samples, replay_buffer)

            def weight_samples(samples):
                """Scale sample probs by the complement of the replay weight."""
                if FLAGS.use_replay_prob_as_weight:
                    new_samples = []
                    for sample in samples:
                        name = sample.traj.env_name
                        if name in replay_buffer.prob_sum_dict:
                            replay_prob = max(
                                replay_buffer.prob_sum_dict[name],
                                FLAGS.min_replay_weight)
                        else:
                            replay_prob = 0.0
                        scale = 1.0 - replay_prob
                        new_samples.append(
                            agent_factory.Sample(traj=sample.traj,
                                                 prob=sample.prob * scale))
                else:
                    new_samples = agent_factory.scale_probs(
                        samples, 1 - FLAGS.fixed_replay_weight)
                return new_samples

            train_samples = []
            if FLAGS.use_replay_samples_in_train:
                if FLAGS.use_trainer_prob:
                    replay_samples = [
                        sample._replace(prob=None)
                        for sample in replay_samples
                    ]
                train_samples += replay_samples

            if FLAGS.use_policy_samples_in_train:
                train_samples += weight_samples(policy_samples)

            if FLAGS.use_nonreplay_samples_in_train:
                train_samples += weight_samples(nonreplay_samples)

            train_samples = sorted(train_samples,
                                   key=lambda x: x.traj.env_name)
            tf.logging.info('{} train samples'.format(len(train_samples)))

            if FLAGS.use_importance_sampling:
                step_logprobs = agent.compute_step_logprobs(
                    [s.traj for s in train_samples])
            else:
                step_logprobs = None

            if FLAGS.use_replay_prob_as_weight:
                n_clip = 0
                for env in batch_envs:
                    name = env.name
                    if (name in replay_buffer.prob_sum_dict
                            and replay_buffer.prob_sum_dict[name] <
                            FLAGS.min_replay_weight):
                        n_clip += 1
                clip_frac = float(n_clip) / len(batch_envs)
            else:
                clip_frac = 0.0

            # put all weight on the annotated ones at first, then
            # gradually increase explored examples
            al_scale_factor = min(
                1.0,
                (agent.model.get_global_step() - FLAGS.active_start_step) /
                float(FLAGS.active_scale_steps))
            assert (al_scale_factor >= 0.0 and al_scale_factor <= 1.0)

            # NOTE(review): with FLAGS.use_trainer_prob the replay samples
            # carry prob=None, which the multiplications below cannot
            # scale -- confirm these options are never combined.
            # The index is deliberately not named `i`: that name is the
            # epoch counter of the enclosing loop.
            for sample_idx, sample in enumerate(train_samples):
                if sample.traj.env_name not in self.env_annotation_dict:
                    train_samples[sample_idx] = agent_factory.Sample(
                        traj=sample.traj,
                        prob=sample.prob * al_scale_factor)
                else:
                    annotation = self.env_annotation_dict[
                        sample.traj.env_name]
                    explored_program = agent_factory.traj_to_program(
                        sample.traj, self.decode_vocab)
                    if not annotation.verify_program(explored_program):
                        train_samples[sample_idx] = agent_factory.Sample(
                            traj=sample.traj,
                            prob=sample.prob * al_scale_factor)

            self.train_queue.put((train_samples, step_logprobs, clip_frac))
            t2 = time.time()
            tf.logging.info(
                ('{} sec used preparing and enqueuing samples, {}'
                 ' iteration {}, batch {}: {} envs').format(
                     t2 - t1, self.name, i, j, len(batch_envs)))

            t1 = time.time()
            # Wait for a ckpt that still exist or it is the same
            # ckpt (no need to load anything).
            while True:
                new_ckpt = self.ckpt_queue.get()
                new_ckpt_file = new_ckpt + '.meta'
                if (new_ckpt == current_ckpt
                        or tf.gfile.Exists(new_ckpt_file)):
                    break
            t2 = time.time()
            tf.logging.info(
                '{} sec waiting {} iteration {}, batch {}'.format(
                    t2 - t1, self.name, i, j))

            if new_ckpt != current_ckpt:
                # If the ckpt is not the same, then restore the new
                # ckpt.
                tf.logging.info('{} loading ckpt {}'.format(
                    self.name, new_ckpt))
                t1 = time.time()
                graph.restore(new_ckpt)
                t2 = time.time()
                tf.logging.info('{} sec used {} restoring ckpt {}'.format(
                    t2 - t1, self.name, new_ckpt))
                current_ckpt = new_ckpt

            if (FLAGS.log_samples_every_n_epoch > 0
                    and i % FLAGS.log_samples_every_n_epoch == 0):
                f_replay.write(
                    show_samples(replay_samples, envs[0].de_vocab, env_dict))
                f_policy.write(
                    show_samples(policy_samples, envs[0].de_vocab, env_dict))
                f_train.write(
                    show_samples(train_samples, envs[0].de_vocab, env_dict))

        if (FLAGS.log_samples_every_n_epoch > 0
                and i % FLAGS.log_samples_every_n_epoch == 0):
            f_replay.close()
            f_policy.close()
            f_train.close()

        if agent.model.get_global_step() >= FLAGS.n_steps:
            if FLAGS.save_replay_buffer_at_end:
                all_replay = os.path.join(
                    get_experiment_dir(),
                    'all_replay_samples_{}.txt'.format(self.name))
                with codecs.open(all_replay, 'w', encoding='utf-8') as f:
                    samples = replay_buffer.all_samples(envs, agent=None)
                    samples = [
                        s for s in samples
                        if not replay_buffer_copy.contain(s.traj)
                    ]
                    f.write(show_samples(samples, envs[0].de_vocab, None))
            tf.logging.info('{} finished'.format(self.name))
            return
        i += 1