    def decode_examples(self, environments: List[QAProgrammingEnv], beam_size, batch_size=32):
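        """Beam-decode programs for a list of environments.

        Runs batched beam search, optionally constrained to predicted sketches
        when `use_sketch_constrained_decoding` is enabled, and returns one list
        of hypotheses per environment.
        """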
        decode_results = []
        use_sketch_constrained_decoding = self.config.get('use_sketch_constrained_decoding', False)

        if use_sketch_constrained_decoding:
            assert self.sketch_manager is not None
            print('[Model] use sketch-constrained decoding...', file=sys.stderr)
            num_sketch = self.config.get('sketch_constrained_decoding_num_sketch', 5)

        with torch.no_grad():
            batch_iter = nn_util.batch_iter(environments, batch_size, shuffle=False)
            for batched_envs in tqdm(batch_iter, total=(len(environments) + batch_size - 1) // batch_size, file=sys.stdout):
                if use_sketch_constrained_decoding:
                    batched_hyp_sketches = self.sketch_manager.get_sketches(
                        batched_envs, K=num_sketch
                    )
                    constraint_sketches = {
                        env.name: sketches
                        for env, sketches
                        in zip(batched_envs, batched_hyp_sketches)
                    }
                else:
                    constraint_sketches = None

                batch_decode_result = self.new_beam_search(
                    batched_envs,
                    beam_size=beam_size,
                    constraint_sketches=constraint_sketches,
                    strict_constraint_on_sketches=use_sketch_constrained_decoding
                )

                batch_decode_result = list(batch_decode_result.values())
                decode_results.extend(batch_decode_result)

        return decode_results  # list of per-environment hypothesis lists

    def train(self):
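        """Main exploration loop of the actor.

        For each batch of environments, the actor (1) samples or beam-searches
        candidate programs, optionally constrained to predicted sketches,
        (2) saves fully-rewarded trajectories to the replay buffer, (3) draws
        replay samples and weights them according to the configured method
        (`mapo`, `mml`, or `sample`), and (4) pushes the weighted examples to
        the learner via `self.train_queue`.
        """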
        config = self.config
        epoch_id = 0
        env_dict = {env.name: env for env in self.environments}
        sample_method = self.config['sample_method']
        method = self.config['method']
        assert sample_method in ('sample', 'beam_search')
        assert method in ('sample', 'mapo', 'mml')

        work_dir = Path(self.config['work_dir'])
        log_dir = work_dir / 'log'
        log_dir.mkdir(exist_ok=True, parents=True)

        debug_file = None
        if self.config.get('save_actor_log', False):
            debug_file = (log_dir /
                          f'debug.actor{self.actor_id}.log').open('w')
        # self.agent.log = debug_file

        with torch.no_grad():
            while True:
                epoch_id += 1
                epoch_start = time.time()
                batch_iter = nn_util.batch_iter(
                    self.environments,
                    batch_size=self.config['batch_size'],
                    shuffle=True)
                for batch_id, batched_envs in enumerate(batch_iter):
                    # print('batched envs from batch_iter: ', batched_envs)
                    try:
                        # print(f'[Actor {self.actor_id}] epoch {epoch_id} batch {batch_id}', file=sys.stderr)
                        # perform sampling

                        strict_constraint_on_sketches = config.get(
                            'sketch_explore_strict_constraint_on_sketch', True)
                        force_sketch_coverage = config.get(
                            'sketch_explore_force_coverage', False)
                        constraint_sketches = None

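                        # When sketch exploration is enabled, restrict exploration to
                        # candidate program sketches predicted for each question, either
                        # from a trainable sketch predictor or from similar questions in
                        # the sketch cache.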
                        if isinstance(self.agent,
                                      PGAgent) and self.use_sketch_exploration:
                            constraint_sketches = dict()
                            explore_beam_size = config.get(
                                'sketch_explore_beam_size', 5)
                            num_sketches_per_example = config.get(
                                'num_candidate_sketches', 5)
                            remove_explored_sketch = config.get(
                                'remove_explored_sketch', True)
                            use_sketch_exploration_for_nepoch = config.get(
                                'use_sketch_exploration_for_nepoch', 10000)
                            use_trainable_sketch_predictor = self.config.get(
                                'use_trainable_sketch_predictor', False)

                            if epoch_id <= use_sketch_exploration_for_nepoch:
                                t1 = time.time()
                                if use_trainable_sketch_predictor:
                                    candidate_sketches = self.sketch_predictor.get_sketches(
                                        batched_envs,
                                        K=num_sketches_per_example)
                                    for env, sketches in zip(
                                            batched_envs, candidate_sketches):
                                        constraint_sketches[
                                            env.name] = sketches
                                else:
                                    for env in batched_envs:
                                        env_candidate_sketches = self.sketch_predictor.get_sketches_from_similar_questions(
                                            env.name,
                                            remove_explored=
                                            remove_explored_sketch,
                                            log_file=None)

                                        if debug_file:
                                            print(
                                                f"Question {env.name} Candidate sketches in the cache:\n"
                                                f"{json.dumps({str(k): v for k, v in env_candidate_sketches.items()}, indent=2, default=str)}",
                                                file=debug_file)

                                        env_candidate_sketches = sorted(
                                            env_candidate_sketches,
                                            key=lambda s:
                                            env_candidate_sketches[s]['score'],
                                            reverse=True
                                        )[:num_sketches_per_example]

                                        constraint_sketches[
                                            env.name] = env_candidate_sketches

                                # logging
                                # print('[Actor] Sampled sketches', file=sys.stderr)
                                # print(constraint_sketches, file=sys.stderr)
                                if debug_file:
                                    print(
                                        f'Found candidate sketches took {time.time() - t1}s',
                                        file=debug_file)
                                    for env in batched_envs:
                                        print("======", file=debug_file)
                                        print(
                                            f"Question [{env.name}] "
                                            f"{env.question_annotation['question']}",
                                            file=debug_file)

                                        print(
                                            f"Selected sketches for [{env.name}]:\n"
                                            f"{json.dumps(constraint_sketches[env.name], indent=2, default=str)}",
                                            file=debug_file)

                        t1 = time.time()
                        if sample_method == 'sample':
                            explore_samples = self.agent.sample(
                                batched_envs,
                                sample_num=config['n_explore_samples'],
                                use_cache=config['use_cache'],
                                constraint_sketches=constraint_sketches)
                        else:
                            explore_samples = self.agent.new_beam_search(
                                batched_envs,
                                beam_size=config['n_explore_samples'],
                                use_cache=config['use_cache'],
                                return_list=True,
                                constraint_sketches=constraint_sketches,
                                strict_constraint_on_sketches=
                                strict_constraint_on_sketches,
                                force_sketch_coverage=force_sketch_coverage)
                        t2 = time.time()

                        if debug_file:
                            print('Explored programs:', file=debug_file)
                            for sample in explore_samples:
                                print(
                                    f"[{sample.trajectory.environment_name}] "
                                    f"{' '.join(sample.trajectory.program)} "
                                    f"(prob={sample.prob:.4f}, correct={sample.trajectory.reward == 1.})",
                                    file=debug_file)

                        print(
                            f'[Actor {self.actor_id}] '
                            f'epoch {epoch_id} batch {batch_id}, '
                            f'sampled {len(explore_samples)} trajectories (took {t2 - t1}s)',
                            file=sys.stderr)

                        # retain samples with high reward
                        good_explore_samples = [
                            sample for sample in explore_samples
                            if sample.trajectory.reward == 1.
                        ]
                        # for sample in good_explore_samples:
                        #     print(f'[Actor {self.actor_id}] epoch {epoch_id} batch {batch_id}, '
                        #           f'add 1 traj [{sample.trajectory}] for env [{sample.trajectory.environment_name}] to buffer',
                        #           file=sys.stderr)
                        self.replay_buffer.save_samples(good_explore_samples)

                        # sample replay examples from the replay buffer
                        t1 = time.time()
                        replay_constraint_sketches = None
                        if self.use_sketch_guided_replay:
                            replay_constraint_sketches = dict()
                            num_sketches_per_example = config.get(
                                'num_candidate_sketches', 5)

                            # Retrieve candidate sketches for each question from the
                            # cache of similar questions, keep the top-scoring ones,
                            # and use them to constrain replay.
                            for env in batched_envs:
                                env_candidate_sketches = self.sketch_predictor.get_sketches_from_similar_questions(
                                    env.name)
                                env_selected_candidate_sketches = sorted(
                                    env_candidate_sketches,
                                    key=lambda s: env_candidate_sketches[s]['score'],
                                    reverse=True)[:num_sketches_per_example]

                                replay_constraint_sketches[
                                    env.name] = env_selected_candidate_sketches

                                if debug_file:
                                    print(
                                        "======begin sketch guided replay======",
                                        file=debug_file)
                                    print(
                                        f"Question [{env.name}] "
                                        f"{env.question_annotation['question']}",
                                        file=debug_file)

                                    print(
                                        f"Candidate sketches in the cache:\n"
                                        f"{json.dumps({str(k): v for k, v in env_candidate_sketches.items()}, indent=2, default=str)}",
                                        file=debug_file)

                                    print(
                                        "======end sketch guided replay======",
                                        file=debug_file)

                        replay_samples = self.replay_buffer.replay(
                            batched_envs,
                            n_samples=config['n_replay_samples'],
                            use_top_k=config['use_top_k_replay_samples'],
                            replace=config['replay_sample_with_replacement'],
                            truncate_at_n=config.get('sample_replay_from_topk',
                                                     0),
                            consistency_model=self.consistency_model,
                            constraint_sketches=replay_constraint_sketches,
                            debug_file=debug_file)
                        t2 = time.time()
                        print(
                            f'[Actor {self.actor_id}] epoch {epoch_id} batch {batch_id}, got {len(replay_samples)} replay samples (took {t2 - t1}s)',
                            file=sys.stderr)

                        samples_info = dict()

                        if method == 'mapo':
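                            # MAPO-style weighting: replay samples share the (clipped)
                            # total probability mass the current policy assigns to the
                            # buffered programs for their question; fresh on-policy
                            # samples share the remaining (1 - buffer probability) mass.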
                            train_examples = []
                            for sample in replay_samples:
                                sample_weight = self.replay_buffer.env_program_prob_sum_dict.get(
                                    sample.trajectory.environment_name, 0.)
                                sample_weight = max(
                                    sample_weight,
                                    self.config['min_replay_samples_weight'])

                                sample.weight = sample_weight * 1. / config[
                                    'n_replay_samples']
                                train_examples.append(sample)

                            on_policy_samples = self.agent.sample(
                                batched_envs,
                                sample_num=config['n_policy_samples'],
                                use_cache=False)
                            non_replay_samples = [
                                sample for sample in on_policy_samples
                                if sample.trajectory.reward == 1. and not self.
                                replay_buffer.contains(sample.trajectory)
                            ]
                            self.replay_buffer.save_samples(non_replay_samples)

                            for sample in non_replay_samples:
                                if self.use_consistency_model and self.consistency_model.debug:
                                    print(
                                        f'>>>>>>>>>> non replay samples for {sample.trajectory.environment_name}',
                                        file=self.consistency_model.log_file)
                                    self.consistency_model.compute_consistency_score(
                                        sample.trajectory.environment_name,
                                        [sample])
                                    print(
                                        f'<<<<<<<<<<< non replay samples for {sample.trajectory.environment_name}',
                                        file=self.consistency_model.log_file)

                                replay_samples_prob = self.replay_buffer.env_program_prob_sum_dict.get(
                                    sample.trajectory.environment_name, 0.)
                                if replay_samples_prob > 0.:
                                    # clip the sum of probabilities for replay samples if the replay buffer is not empty
                                    replay_samples_prob = max(
                                        replay_samples_prob, self.
                                        config['min_replay_samples_weight'])

                                sample_weight = 1. - replay_samples_prob

                                sample.weight = sample_weight * 1. / config[
                                    'n_policy_samples']
                                train_examples.append(sample)

                            n_clip = 0
                            for env in batched_envs:
                                name = env.name
                                if (name in self.replay_buffer.
                                        env_program_prob_dict
                                        and self.replay_buffer.
                                        env_program_prob_sum_dict.get(
                                            name, 0.) < self.
                                        config['min_replay_samples_weight']):
                                    n_clip += 1
                            clip_frac = n_clip / len(batched_envs)

                            samples_info['clip_frac'] = clip_frac
                        elif method == 'mml':
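                            # MML-style weighting: each replay sample is weighted by its
                            # probability normalized over all buffered programs for the
                            # same question.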
                            for sample in replay_samples:
                                sample.weight = sample.prob / self.replay_buffer.env_program_prob_sum_dict[
                                    sample.trajectory.environment_name]
                            train_examples = replay_samples
                        elif method == 'sample':
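                            # Plain policy-gradient weighting: use the sample probability,
                            # floored at `min_replay_samples_weight`.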
                            train_examples = replay_samples
                            for sample in train_examples:
                                sample.weight = max(
                                    sample.prob,
                                    config['min_replay_samples_weight'])
                    except RuntimeError as e:
                        if 'out of memory' in str(e):
                            msg = (
                                f'[Actor {self.actor_id}] WARNING: ran out of memory '
                                f'with exception: {e}; skipping batch')
                            print(msg, file=sys.stderr)
                            sys.stderr.flush()

                            continue
                        else:
                            raise e

                    print("len of train examples put in queue: ",
                          len(train_examples))
                    sys.stdout.flush()
                    if train_examples:
                        self.train_queue.put((train_examples, samples_info))
                    else:
                        continue

                    self.check_and_load_new_model()
                    if debug_file:
                        debug_file.flush()

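                    # Periodically release cached CUDA memory so long-running actors
                    # do not exhaust GPU memory.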
                    if self.device.type == 'cuda':
                        mem_cached_mb = torch.cuda.memory_cached() / 1000000
                        if mem_cached_mb > 8000:
                            print(
                                f'Actor {self.actor_id} empty cached memory [{mem_cached_mb} MB]',
                                file=sys.stderr)
                            torch.cuda.empty_cache()

                epoch_end = time.time()
                print(
                    f"[Actor {self.actor_id}] epoch {epoch_id} finished, took {epoch_end - epoch_start}s",
                    file=sys.stderr)

                # buffer_content = dict()
                # for env_name, samples in self.replay_buffer.all_samples().items():
                #     buffer_content[env_name] = [dict(program=' '.join(sample.trajectory.program), prob=sample.prob) for sample in samples]
                # buffer_save_path = os.path.join(config['work_dir'], f'replay_buffer_actor{self.actor_id}_epoch{epoch_id}.json')
                # with open(buffer_save_path, 'w') as f:
                #     json.dump(buffer_content, f, indent=2)

                # dump program cache for the current actor
                # cur_program_cache = self.replay_buffer.all_samples()
                # with multiprocessing.Lock():
                #     program_cache_save_file = log_dir / f'program_cache.epoch{epoch_id}.jsonl'
                #
                #     with program_cache_save_file.open('a') as f:
                #         for env_name, samples in cur_program_cache.items():
                #             entry = {
                #                 'question_id': env_name,
                #                 'hypotheses': [
                #                     {
                #                         'program': ' '.join(sample.trajectory.human_readable_program),
                #                         'prob': sample.prob
                #                     }
                #                     for sample in samples
                #                 ]
                #             }
                #             line = json.dumps(entry)
                #             f.write(line + os.linesep)

                if self.consistency_model:
                    self.consistency_model.log_file.flush()
                    sys.stderr.flush()
    def fine_tune(self):
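        """Fine-tune the agent on its own consistency-filtered decodes.

        Beam-decodes the training set, keeps the top hypothesis when it is
        correct, otherwise re-ranks the correct hypotheses by their consistency
        with programs decoded for nearest-neighbor questions, and runs a short
        maximum-likelihood pass over the selected trajectories.
        """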
        beam_size = self.config['beam_size']
        decoding_results = self.agent.decode_examples(self.train_set,
                                                      beam_size=beam_size)
        decoding_results_dict = to_decode_results_dict(decoding_results,
                                                       self.train_set)

        train_examples = []
        for env, hyp_list in zip(self.train_set, decoding_results):
            # hyp_list = [hyp for hyp in hyp_list if hyp.trajectory.reward == 1.]
            if not hyp_list:
                continue

            is_best_hyp_correct = hyp_list[0].trajectory.reward == 1.
            if not is_best_hyp_correct:
                # if True:
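                # The top beam hypothesis is wrong: fall back to the correct
                # hypotheses and pick the one whose program is most consistent
                # with decodes for the K nearest-neighbor questions.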
                correct_hyps = [
                    hyp for hyp in hyp_list if hyp.trajectory.reward == 1.
                ]
                if not correct_hyps: continue

                hyp_supports = [
                    _compute_consistency_score(
                        env_name=env.name,
                        hyp_program=hyp.trajectory.program,
                        nearest_neighbors=self.nearest_neighbors,
                        decode_results_dict=decoding_results_dict,
                        K=3) for hyp in correct_hyps
                ]
                best_hyp_idx = np.argmax(hyp_supports)
                best_hyp = correct_hyps[best_hyp_idx]
                train_examples.append(best_hyp.trajectory)
            else:
                train_examples.append(hyp_list[0].trajectory)

        print(f'Num. fine tune examples: {len(train_examples)}',
              file=sys.stderr)
        max_epoch = 1

        model = self.agent.train()
        params = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.Adam(params, lr=0.001)

        for epoch in range(max_epoch):
            batch_iter = nn_util.batch_iter(train_examples,
                                            batch_size=32,
                                            shuffle=True)

            for batch_id, train_trajectories in enumerate(batch_iter):
                optimizer.zero_grad()

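                # Maximum-likelihood objective: minimize the negative mean
                # log-probability of the selected trajectories.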
                # (batch_size)
                batch_log_prob = self.agent(train_trajectories)
                loss = -batch_log_prob.mean()

                loss.backward()
                loss_val = loss.item()

                # clip gradient
                grad_norm = torch.nn.utils.clip_grad_norm_(params, 5.)

                optimizer.step()