def decode_examples(self, environments: List[QAProgrammingEnv], beam_size, batch_size=32):
    decode_results = []
    use_sketch_constrained_decoding = self.config.get('use_sketch_constrained_decoding', False)

    if use_sketch_constrained_decoding:
        assert self.sketch_manager is not None
        print('[Model] use sketch-constrained decoding...', file=sys.stderr)
        num_sketch = self.config.get('sketch_constrained_decoding_num_sketch', 5)

    with torch.no_grad():
        batch_iter = nn_util.batch_iter(environments, batch_size, shuffle=False)
        for batched_envs in tqdm(batch_iter, total=len(environments) // batch_size, file=sys.stdout):
            if use_sketch_constrained_decoding:
                batched_hyp_sketches = self.sketch_manager.get_sketches(
                    batched_envs, K=num_sketch
                )
                constraint_sketches = {
                    env.name: sketches
                    for env, sketches in zip(batched_envs, batched_hyp_sketches)
                }
            else:
                constraint_sketches = None

            batch_decode_result = self.new_beam_search(
                batched_envs,
                beam_size=beam_size,
                constraint_sketches=constraint_sketches,
                strict_constraint_on_sketches=use_sketch_constrained_decoding
            )

            batch_decode_result = list(batch_decode_result.values())
            decode_results.extend(batch_decode_result)

    return decode_results  # list of decoding results, with the same element type as train_examples
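# --- Illustrative usage sketch (not part of the original class) ---
# A minimal example of how `decode_examples` might be driven at evaluation time,
# assuming `agent` is an instance of this model and `dev_envs` is a list of
# QAProgrammingEnv objects; both names are hypothetical.
#
#     decode_results = agent.decode_examples(dev_envs, beam_size=5, batch_size=32)
#     num_correct = sum(
#         1 for hyp_list in decode_results
#         if hyp_list and hyp_list[0].trajectory.reward == 1.
#     )
#     print(f'top-1 accuracy: {num_correct / len(dev_envs):.4f}')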
def train(self):
    config = self.config
    epoch_id = 0
    env_dict = {env.name: env for env in self.environments}
    sample_method = self.config['sample_method']
    method = self.config['method']
    assert sample_method in ('sample', 'beam_search')
    assert method in ('sample', 'mapo', 'mml')

    work_dir = Path(self.config['work_dir'])
    log_dir = work_dir / 'log'
    log_dir.mkdir(exist_ok=True, parents=True)

    debug_file = None
    if self.config.get('save_actor_log', False):
        debug_file = (log_dir / f'debug.actor{self.actor_id}.log').open('w')
    # self.agent.log = debug_file

    with torch.no_grad():
        while True:
            epoch_id += 1
            epoch_start = time.time()

            batch_iter = nn_util.batch_iter(
                self.environments,
                batch_size=self.config['batch_size'],
                shuffle=True)
            for batch_id, batched_envs in enumerate(batch_iter):
                print('batched envs from batch_iter: ', batched_envs)
                try:
                    # print(f'[Actor {self.actor_id}] epoch {epoch_id} batch {batch_id}', file=sys.stderr)

                    # perform sampling
                    strict_constraint_on_sketches = config.get(
                        'sketch_explore_strict_constraint_on_sketch', True)
                    force_sketch_coverage = config.get(
                        'sketch_explore_force_coverage', False)
                    constraint_sketches = None

                    if isinstance(self.agent, PGAgent) and self.use_sketch_exploration:
                        constraint_sketches = dict()
                        explore_beam_size = config.get('sketch_explore_beam_size', 5)
                        num_sketches_per_example = config.get('num_candidate_sketches', 5)
                        remove_explored_sketch = config.get('remove_explored_sketch', True)
                        use_sketch_exploration_for_nepoch = config.get(
                            'use_sketch_exploration_for_nepoch', 10000)
                        use_trainable_sketch_predictor = self.config.get(
                            'use_trainable_sketch_predictor', False)

                        if epoch_id <= use_sketch_exploration_for_nepoch:
                            t1 = time.time()
                            if use_trainable_sketch_predictor:
                                # use the trained sketch predictor to propose sketches for the batch
                                candidate_sketches = self.sketch_predictor.get_sketches(
                                    batched_envs, K=num_sketches_per_example)
                                for env, sketches in zip(batched_envs, candidate_sketches):
                                    constraint_sketches[env.name] = sketches
                            else:
                                # fall back to sketches cached from similar questions
                                for env in batched_envs:
                                    env_candidate_sketches = self.sketch_predictor.get_sketches_from_similar_questions(
                                        env.name,
                                        remove_explored=remove_explored_sketch,
                                        log_file=None)

                                    if debug_file:
                                        print(
                                            f"Question {env.name} Candidate sketches in the cache:\n"
                                            f"{json.dumps({str(k): v for k, v in env_candidate_sketches.items()}, indent=2, default=str)}",
                                            file=debug_file)

                                    env_candidate_sketches = sorted(
                                        env_candidate_sketches,
                                        key=lambda s: env_candidate_sketches[s]['score'],
                                        reverse=True
                                    )[:num_sketches_per_example]

                                    constraint_sketches[env.name] = env_candidate_sketches

                            # logging
                            # print('[Actor] Sampled sketches', file=sys.stderr)
                            # print(constraint_sketches, file=sys.stderr)
                            if debug_file:
                                print(f'Found candidate sketches took {time.time() - t1}s',
                                      file=debug_file)
                                for env in batched_envs:
                                    print("======", file=debug_file)
                                    print(
                                        f"Question [{env.name}] "
                                        f"{env.question_annotation['question']}",
                                        file=debug_file)
                                    print(
                                        f"Selected sketches for [{env.name}]:\n"
                                        f"{json.dumps(constraint_sketches[env.name], indent=2, default=str)}",
                                        file=debug_file)

                    t1 = time.time()
                    if sample_method == 'sample':
                        explore_samples = self.agent.sample(
                            batched_envs,
                            sample_num=config['n_explore_samples'],
                            use_cache=config['use_cache'],
                            constraint_sketches=constraint_sketches)
                    else:
                        explore_samples = self.agent.new_beam_search(
                            batched_envs,
                            beam_size=config['n_explore_samples'],
                            use_cache=config['use_cache'],
                            return_list=True,
                            constraint_sketches=constraint_sketches,
                            strict_constraint_on_sketches=strict_constraint_on_sketches,
                            force_sketch_coverage=force_sketch_coverage)
                    t2 = time.time()

                    if debug_file:
                        print('Explored programs:', file=debug_file)
                        for sample in explore_samples:
                            print(
                                f"[{sample.trajectory.environment_name}] "
                                f"{' '.join(sample.trajectory.program)} "
                                f"(prob={sample.prob:.4f}, correct={sample.trajectory.reward == 1.})",
                                file=debug_file)

                    print(
                        f'[Actor {self.actor_id}] '
                        f'epoch {epoch_id} batch {batch_id}, '
                        f'sampled {len(explore_samples)} trajectories (took {t2 - t1}s)',
                        file=sys.stderr)

                    # retain samples with high reward
                    good_explore_samples = [
                        sample for sample in explore_samples
                        if sample.trajectory.reward == 1.
                    ]
                    # for sample in good_explore_samples:
                    #     print(f'[Actor {self.actor_id}] epoch {epoch_id} batch {batch_id}, '
                    #           f'add 1 traj [{sample.trajectory}] for env [{sample.trajectory.environment_name}] to buffer',
                    #           file=sys.stderr)
                    self.replay_buffer.save_samples(good_explore_samples)

                    # sample replay examples from the replay buffer
                    t1 = time.time()

                    replay_constraint_sketches = None
                    if self.use_sketch_guided_replay:
                        replay_constraint_sketches = dict()
                        num_sketches_per_example = config.get('num_candidate_sketches', 5)
                        candidate_sketches = self.sketch_predictor.get_sketches(batched_envs)

                        for env, env_candidate_sketches in zip(batched_envs, candidate_sketches):
                            env_selected_candidate_sketches = sorted(
                                env_candidate_sketches,
                                key=lambda s: env_candidate_sketches[s]['score'],
                                reverse=True
                            )[:num_sketches_per_example]
                            replay_constraint_sketches[env.name] = env_selected_candidate_sketches

                            if debug_file:
                                print("======begin sketch guided replay======", file=debug_file)
                                print(
                                    f"Question [{env.name}] "
                                    f"{env.question_annotation['question']}",
                                    file=debug_file)
                                print(
                                    f"Candidate sketches in the cache:\n"
                                    f"{json.dumps({str(k): v for k, v in env_candidate_sketches.items()}, indent=2, default=str)}",
                                    file=debug_file)
                                print("======end sketch guided replay======", file=debug_file)

                    replay_samples = self.replay_buffer.replay(
                        batched_envs,
                        n_samples=config['n_replay_samples'],
                        use_top_k=config['use_top_k_replay_samples'],
                        replace=config['replay_sample_with_replacement'],
                        truncate_at_n=config.get('sample_replay_from_topk', 0),
                        consistency_model=self.consistency_model,
                        constraint_sketches=replay_constraint_sketches,
                        debug_file=debug_file)
                    t2 = time.time()
                    print(
                        f'[Actor {self.actor_id}] epoch {epoch_id} batch {batch_id}, '
                        f'got {len(replay_samples)} replay samples (took {t2 - t1}s)',
                        file=sys.stderr)

                    samples_info = dict()
                    if method == 'mapo':
                        train_examples = []
                        for sample in replay_samples:
                            sample_weight = self.replay_buffer.env_program_prob_sum_dict.get(
                                sample.trajectory.environment_name, 0.)
                            sample_weight = max(sample_weight, self.config['min_replay_samples_weight'])

                            sample.weight = sample_weight * 1. / config['n_replay_samples']
                            train_examples.append(sample)

                        on_policy_samples = self.agent.sample(
                            batched_envs,
                            sample_num=config['n_policy_samples'],
                            use_cache=False)
                        non_replay_samples = [
                            sample for sample in on_policy_samples
                            if sample.trajectory.reward == 1.
                            and not self.replay_buffer.contains(sample.trajectory)
                        ]
                        self.replay_buffer.save_samples(non_replay_samples)

                        for sample in non_replay_samples:
                            if self.use_consistency_model and self.consistency_model.debug:
                                print(
                                    f'>>>>>>>>>> non replay samples for {sample.trajectory.environment_name}',
                                    file=self.consistency_model.log_file)
                                self.consistency_model.compute_consistency_score(
                                    sample.trajectory.environment_name, [sample])
                                print(
                                    f'<<<<<<<<<<< non replay samples for {sample.trajectory.environment_name}',
                                    file=self.consistency_model.log_file)

                            replay_samples_prob = self.replay_buffer.env_program_prob_sum_dict.get(
                                sample.trajectory.environment_name, 0.)
                            if replay_samples_prob > 0.:
                                # clip the sum of probabilities for replay samples if the replay buffer is not empty
                                replay_samples_prob = max(
                                    replay_samples_prob,
                                    self.config['min_replay_samples_weight'])

                            sample_weight = 1. - replay_samples_prob

                            sample.weight = sample_weight * 1. / config['n_policy_samples']
                            train_examples.append(sample)

                        n_clip = 0
                        for env in batched_envs:
                            name = env.name
                            if (name in self.replay_buffer.env_program_prob_dict
                                    and self.replay_buffer.env_program_prob_sum_dict.get(name, 0.)
                                    < self.config['min_replay_samples_weight']):
                                n_clip += 1
                        clip_frac = n_clip / len(batched_envs)

                        samples_info['clip_frac'] = clip_frac
                    elif method == 'mml':
                        for sample in replay_samples:
                            sample.weight = sample.prob / self.replay_buffer.env_program_prob_sum_dict[
                                sample.trajectory.environment_name]
                        train_examples = replay_samples
                    elif method == 'sample':
                        train_examples = replay_samples
                        for sample in train_examples:
                            sample.weight = max(sample.prob, config['min_replay_samples_weight'])
                except RuntimeError as e:
                    if 'out of memory' in str(e):
                        msg = (
                            f'[Actor {self.actor_id}] WARNING: ran out of memory with exception: '
                            f'{e};\n Skipping batch')
                        print(msg, file=sys.stderr)
                        sys.stderr.flush()
                        continue
                    else:
                        raise e

                print("len of train examples put in queue: ", len(train_examples))
                sys.stdout.flush()
                if train_examples:
                    self.train_queue.put((train_examples, samples_info))
                else:
                    continue

                self.check_and_load_new_model()
                if debug_file:
                    debug_file.flush()

                if self.device.type == 'cuda':
                    mem_cached_mb = torch.cuda.memory_cached() / 1000000
                    if mem_cached_mb > 8000:
                        print(
                            f'Actor {self.actor_id} empty cached memory [{mem_cached_mb} MB]',
                            file=sys.stderr)
                        torch.cuda.empty_cache()

            epoch_end = time.time()
            print(
                f"[Actor {self.actor_id}] epoch {epoch_id} finished, took {epoch_end - epoch_start}s",
                file=sys.stderr)

            # buffer_content = dict()
            # for env_name, samples in self.replay_buffer.all_samples().items():
            #     buffer_content[env_name] = [dict(program=' '.join(sample.trajectory.program), prob=sample.prob) for sample in samples]
            # buffer_save_path = os.path.join(config['work_dir'], f'replay_buffer_actor{self.actor_id}_epoch{epoch_id}.json')
            # with open(buffer_save_path, 'w') as f:
            #     json.dump(buffer_content, f, indent=2)

            # dump program cache for the current actor
            # cur_program_cache = self.replay_buffer.all_samples()
            # with multiprocessing.Lock():
            #     program_cache_save_file = log_dir / f'program_cache.epoch{epoch_id}.jsonl'
            #
            #     with program_cache_save_file.open('a') as f:
            #         for env_name, samples in cur_program_cache.items():
            #             entry = {
            #                 'question_id': env_name,
            #                 'hypotheses': [
            #                     {
            #                         'program': ' '.join(sample.trajectory.human_readable_program),
            #                         'prob': sample.prob
            #                     }
            #                     for sample in samples
            #                 ]
            #             }
            #             line = json.dumps(entry)
            #             f.write(line + os.linesep)

            if self.consistency_model:
                self.consistency_model.log_file.flush()

            sys.stderr.flush()
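# --- Illustrative sketch of the MAPO-style weighting used in `train` above
# (hypothetical helper, not part of the original module) ---
# For method == 'mapo', each replay sample receives the clipped probability mass
# of its environment's buffer, split across `n_replay_samples`, while each newly
# discovered on-policy sample receives the remaining mass, split across
# `n_policy_samples`. `pi_buffer`, `min_weight`, `n_replay`, and `n_policy` are
# illustrative names.
#
#     def mapo_sample_weights(pi_buffer, min_weight, n_replay, n_policy):
#         # replay samples: clip the buffer's probability mass from below
#         replay_weight = max(pi_buffer, min_weight) / n_replay
#         # on-policy samples: the remainder, clipped only if the buffer is non-empty
#         clipped = max(pi_buffer, min_weight) if pi_buffer > 0. else pi_buffer
#         on_policy_weight = (1. - clipped) / n_policy
#         return replay_weight, on_policy_weight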
def fine_tune(self):
    beam_size = self.config['beam_size']
    decoding_results = self.agent.decode_examples(self.train_set, beam_size=beam_size)
    decoding_results_dict = to_decode_results_dict(decoding_results, self.train_set)

    train_examples = []
    for env, hyp_list in zip(self.train_set, decoding_results):
        # hyp_list = [hyp for hyp in hyp_list if hyp.trajectory.reward == 1.]
        if not hyp_list:
            continue

        is_best_hyp_correct = hyp_list[0].trajectory.reward == 1.
        if not is_best_hyp_correct:  # if True:
            correct_hyps = [
                hyp for hyp in hyp_list
                if hyp.trajectory.reward == 1.
            ]
            if not correct_hyps:
                continue

            hyp_supports = [
                _compute_consistency_score(
                    env_name=env.name,
                    hyp_program=hyp.trajectory.program,
                    nearest_neighbors=self.nearest_neighbors,
                    decode_results_dict=decoding_results_dict,
                    K=3)
                for hyp in correct_hyps
            ]
            # pick the correct hypothesis with the highest consistency support
            best_hyp_idx = np.argmax(hyp_supports)
            best_hyp = correct_hyps[best_hyp_idx]

            train_examples.append(best_hyp.trajectory)
        else:
            train_examples.append(hyp_list[0].trajectory)

    print(f'Num. fine tune examples: {len(train_examples)}', file=sys.stderr)

    max_epoch = 1
    model = self.agent.train()
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=0.001)

    for epoch in range(max_epoch):
        batch_iter = nn_util.batch_iter(train_examples, batch_size=32, shuffle=True)
        for batch_id, train_trajectories in enumerate(batch_iter):
            optimizer.zero_grad()

            # (batch_size,)
            batch_log_prob = self.agent(train_trajectories)

            loss = -batch_log_prob.mean()
            loss.backward()
            loss_val = loss.item()

            # clip gradient
            grad_norm = torch.nn.utils.clip_grad_norm_(params, 5.)

            optimizer.step()
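# --- Illustrative sketch (assumption): `to_decode_results_dict`, referenced in
# `fine_tune` above, is not defined in this section. A minimal version consistent
# with how it is used (mapping an environment name to its hypothesis list) could
# look like the following; the actual helper may differ.
#
#     def to_decode_results_dict(decode_results, train_set):
#         return {
#             env.name: hyp_list
#             for env, hyp_list in zip(train_set, decode_results)
#         }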