def __init__(self,
                 config,
                 debug_information=False,
                 is_train=True,
                 global_step=None):
        """Read hyper-parameters from ``config`` and create the input placeholders.

        Every graph input (program, demonstration and held-out test tensors)
        is created as a ``tf.placeholder`` whose leading dimension is
        ``config.batch_size``; afterwards ``self.build(is_train=is_train)``
        constructs the rest of the model.

        Args:
            config: configuration object; every hyper-parameter read below is
                an attribute of it. ``scheduled_sampling`` and
                ``scheduled_sampling_decay_steps`` are optional and fall back
                to ``False`` / ``5000`` when missing or falsy.
            debug_information: stored as ``self.debug``.
            is_train: forwarded to ``self.build`` and used as the default
                value of the ``is_training`` placeholder.
            global_step: step tensor driving the scheduled-sampling decay;
                required when scheduled sampling is enabled.

        Raises:
            ValueError: if scheduled sampling is enabled but ``global_step``
                is ``None``.
        """
        self.debug = debug_information
        self.global_step = global_step

        self.config = config
        self.pixel_input = self.config.pixel_input or self.config.dataset_type == 'vizdoom'
        self.attn_type = self.config.attn_type
        self.scheduled_sampling = \
            getattr(self.config, 'scheduled_sampling', False) or False
        self.scheduled_sampling_decay_steps = \
            getattr(self.config, 'scheduled_sampling_decay_steps', 5000) or 5000
        self.batch_size = self.config.batch_size
        self.state_encoder_fc = self.config.state_encoder_fc
        self.concat_state_feature_direct_prediction = \
            self.config.concat_state_feature_direct_prediction
        self.encoder_rnn_type = self.config.encoder_rnn_type
        self.dataset_type = self.config.dataset_type
        self.dsl_type = self.config.dsl_type
        self.env_type = self.config.env_type
        self.vizdoom_pos_keys = self.config.vizdoom_pos_keys
        self.vizdoom_max_init_pos_len = self.config.vizdoom_max_init_pos_len
        self.perception_type = self.config.perception_type
        self.level = self.config.level
        self.stack_subsequent_state = self.config.stack_subsequent_state
        self.num_lstm_cell_units = self.config.num_lstm_cell_units
        self.demo_aggregation = self.config.demo_aggregation
        self.dim_program_token = config.dim_program_token
        self.max_program_len = config.max_program_len
        self.max_demo_len = config.max_demo_len
        # Actions are stored per demo step, so both sequences share a length.
        self.max_action_len = self.max_demo_len
        # k / test_k size the second axis of the demo tensors (presumably
        # seen vs. held-out demonstrations per program -- confirm).
        self.k = config.k
        self.test_k = config.test_k
        self.h = config.h
        self.w = config.w
        self.depth = config.depth
        self.action_space = config.action_space
        self.per_dim = config.per_dim

        if self.scheduled_sampling:
            if global_step is None:
                raise ValueError('scheduled sampling requires global_step')
            # Linear (power=1.0) decay of the sampling probability from 1.0
            # down to final_teacher_forcing_prob over the configured steps.
            final_teacher_forcing_prob = 0.1
            self.sample_prob = tf.train.polynomial_decay(
                1.0,
                global_step,
                self.scheduled_sampling_decay_steps,
                end_learning_rate=final_teacher_forcing_prob,
                power=1.0,
                name='scheduled_sampling')
        # Text: any dataset type other than 'karel' uses the VizDoom vocab.
        if self.dataset_type == 'karel':
            from karel_env.dsl import get_KarelDSL
            self.vocab = get_KarelDSL(dsl_type=self.dsl_type, seed=123)
        else:
            from vizdoom_env.dsl.vocab import VizDoomDSLVocab
            self.vocab = VizDoomDSLVocab(perception_type=self.perception_type,
                                         level=self.level)

        # create placeholders for the input
        self.program_id = tf.placeholder(
            name='program_id',
            dtype=tf.string,
            shape=[self.batch_size],
        )

        self.program = tf.placeholder(
            name='program',
            dtype=tf.float32,
            shape=[
                self.batch_size, self.dim_program_token, self.max_program_len
            ],
        )

        self.program_tokens = tf.placeholder(
            name='program_tokens',
            dtype=tf.int32,
            shape=[self.batch_size, self.max_program_len])

        self.s_h = tf.placeholder(
            name='s_h',
            dtype=tf.float32,
            shape=[
                self.batch_size, self.k, self.max_demo_len, self.h, self.w,
                self.depth
            ],
        )

        self.test_s_h = tf.placeholder(
            name='test_s_h',
            dtype=tf.float32,
            shape=[
                self.batch_size, self.test_k, self.max_demo_len, self.h,
                self.w, self.depth
            ],
        )

        self.a_h = tf.placeholder(
            name='a_h',
            dtype=tf.float32,
            shape=[
                self.batch_size, self.k, self.max_action_len, self.action_space
            ],
        )

        self.a_h_tokens = tf.placeholder(
            name='a_h_tokens',
            dtype=tf.int32,
            shape=[self.batch_size, self.k, self.max_action_len],
        )

        # Sized with max_action_len like a_h / test_a_h_tokens (the value is
        # identical to max_demo_len, see above).
        self.test_a_h = tf.placeholder(
            name='test_a_h',
            dtype=tf.float32,
            shape=[
                self.batch_size, self.test_k, self.max_action_len,
                self.action_space
            ],
        )

        self.test_a_h_tokens = tf.placeholder(
            name='test_a_h_tokens',
            dtype=tf.int32,
            shape=[self.batch_size, self.test_k, self.max_action_len],
        )

        self.per = tf.placeholder(
            name='per',
            dtype=tf.float32,
            shape=[self.batch_size, self.k, self.max_demo_len, self.per_dim],
        )

        self.test_per = tf.placeholder(
            name='test_per',
            dtype=tf.float32,
            shape=[
                self.batch_size, self.test_k, self.max_demo_len, self.per_dim
            ],
        )

        if self.dataset_type == 'vizdoom':
            # Initial positions per position key, padded to
            # vizdoom_max_init_pos_len entries of 2 ints each.
            self.init_pos = tf.placeholder(
                name='init_pos',
                dtype=tf.int32,
                shape=[
                    self.batch_size, self.k,
                    len(self.vizdoom_pos_keys), self.vizdoom_max_init_pos_len,
                    2
                ],
            )

            self.init_pos_len = tf.placeholder(
                name='init_pos_len',
                dtype=tf.int32,
                shape=[self.batch_size, self.k,
                       len(self.vizdoom_pos_keys)],
            )

            self.test_init_pos = tf.placeholder(
                name='test_init_pos',
                dtype=tf.int32,
                shape=[
                    self.batch_size, self.test_k,
                    len(self.vizdoom_pos_keys), self.vizdoom_max_init_pos_len,
                    2
                ],
            )

            self.test_init_pos_len = tf.placeholder(
                name='test_init_pos_len',
                dtype=tf.int32,
                shape=[
                    self.batch_size, self.test_k,
                    len(self.vizdoom_pos_keys)
                ],
            )

        # NOTE: the float32 length placeholders below are immediately
        # replaced by their int32 casts; the raw placeholders remain in the
        # graph under their names ('program_len', 'demo_len', ...).
        self.program_len = tf.placeholder(
            name='program_len',
            dtype=tf.float32,
            shape=[self.batch_size, 1],
        )
        self.program_len = tf.cast(self.program_len, dtype=tf.int32)

        self.demo_len = tf.placeholder(
            name='demo_len',
            dtype=tf.float32,
            shape=[self.batch_size, self.k],
        )
        self.demo_len = tf.cast(self.demo_len, dtype=tf.int32)
        self.action_len = self.demo_len

        self.test_demo_len = tf.placeholder(
            name='test_demo_len',
            dtype=tf.float32,
            shape=[self.batch_size, self.test_k],
        )
        self.test_demo_len = tf.cast(self.test_demo_len, dtype=tf.int32)
        self.test_action_len = self.test_demo_len

        self.is_train = tf.placeholder(
            name='is_train',
            dtype=tf.bool,
            shape=[],
        )

        self.is_training = tf.placeholder_with_default(bool(is_train), [],
                                                       name='is_training')

        self.build(is_train=is_train)
Beispiel #2
0
    def eval_run(self):
        """Restore the checkpoint (if any) and evaluate over the dataset.

        Unless ``config.no_loss`` is set, runs ``config.max_steps`` calls to
        ``self.run_single_step(self.batch)`` and averages the per-step losses,
        accuracies and histograms, reporting them via
        ``self.log_final_message``.

        Optional outputs, controlled by ``self.config``:

        * ``pred_program``: writes ``out_<checkpoint>_<split>.{txt,hdf5,log}``
          under ``self.output_dir`` with ground-truth, sampled and greedy
          program decodings per program id.
        * ``result_data``: writes ``config.result_data_path`` (HDF5) with the
          greedy prediction plus ``s_h`` / ``test_s_h`` copied from
          ``<config.dataset_path>/data.hdf5``.

        Any exception raised during evaluation is passed to the queue-runner
        coordinator via ``coord.request_stop(e)``. Every file opened here is
        closed whether or not evaluation succeeds.
        """
        # load checkpoint
        if self.checkpoint:
            self.saver.restore(self.session, self.checkpoint)
            log.info("Loaded from checkpoint!")

        log.infov("Start Inference and Evaluation")

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(self.session,
                                               coord=coord,
                                               start=True)
        # Every file opened below is recorded here and closed in `finally`,
        # so a failure mid-evaluation does not leak open handles.
        opened_files = []
        try:
            text_file = hdf5_file = log_file = dsl = None
            if self.config.pred_program:
                if not os.path.exists(self.output_dir):
                    os.makedirs(self.output_dir)
                log.infov("Output Dir: %s", self.output_dir)
                base_name = os.path.join(
                    self.output_dir,
                    'out_{}_{}'.format(self.checkpoint_name,
                                       self.dataset_split))
                text_file = open('{}.txt'.format(base_name), 'w')
                opened_files.append(text_file)
                # NOTE(review): programs are always decoded with the Karel
                # DSL regardless of dataset type -- confirm this is intended.
                from karel_env.dsl import get_KarelDSL
                dsl = get_KarelDSL(dsl_type=self.dataset.dsl_type, seed=123)

                hdf5_file = h5py.File('{}.hdf5'.format(base_name), 'w')
                opened_files.append(hdf5_file)
                log_file = open('{}.log'.format(base_name), 'w')
                opened_files.append(log_file)

            result_file = data_file = None
            if self.config.result_data:
                result_file = h5py.File(self.config.result_data_path, 'w')
                opened_files.append(result_file)
                data_file = h5py.File(
                    os.path.join(self.config.dataset_path, 'data.hdf5'), 'r')
                opened_files.append(data_file)

            # Stays None when the loss loop is skipped (no_loss) or runs zero
            # steps; it is only written to the log file when available.
            final_msg = None
            if not self.config.no_loss:
                loss_all = []
                acc_all = []
                hist_all = {}
                time_all = []
                for s in range(self.config.max_steps):
                    (step, loss, acc, hist,
                     pred_program, pred_program_len, pred_is_correct_syntax,
                     greedy_pred_program, greedy_program_len,
                     greedy_is_correct_syntax,
                     gt_program, gt_program_len, output, program_id,
                     program_num_execution_correct,
                     program_is_correct_execution,
                     greedy_num_execution_correct,
                     greedy_is_correct_execution,
                     step_time) = self.run_single_step(self.batch)

                    # Only produced when not quiet; a quiet run must not
                    # reference an undefined message below.
                    step_msg = None
                    if not self.config.quiet:
                        step_msg = self.log_step_message(
                            s, loss, acc, hist, step_time)

                    if self.config.result_data:
                        for i in range(len(program_id)):
                            try:
                                grp = result_file.create_group(program_id[i])
                            except ValueError:
                                # h5py raises ValueError when the group name
                                # already exists: keep the first occurrence.
                                print('Duplicates: {}'.format(program_id[i]))
                                continue
                            grp['program'] = gt_program[i]
                            grp['pred_program'] = greedy_pred_program[i]
                            grp['pred_program_len'] = greedy_program_len[i][0]
                            # `Dataset.value` was removed in h5py 3.0;
                            # `[()]` reads the whole dataset on every version.
                            source = data_file[program_id[i]]
                            grp['s_h'] = source['s_h'][()]
                            grp['test_s_h'] = source['test_s_h'][()]

                    # write pred/gt program
                    if self.config.pred_program:
                        if step_msg is not None:
                            log_file.write('{}\n'.format(step_msg))
                        correctness = ['wrong', 'correct']
                        for i in range(self.batch_size):
                            # argmax over the token axis turns the first
                            # `len` columns of each program into token ids.
                            pred_program_str = dsl.intseq2str(np.argmax(
                                pred_program[i, :, :pred_program_len[i, 0]],
                                axis=0))
                            greedy_program_str = dsl.intseq2str(np.argmax(
                                greedy_pred_program[
                                    i, :, :greedy_program_len[i, 0]],
                                axis=0))
                            try:
                                grp = hdf5_file.create_group(program_id[i])
                            except ValueError:
                                # Duplicate program id: keep the first entry.
                                pass
                            else:
                                grp['program_prediction'] = pred_program_str
                                grp['program_syntax'] = \
                                    correctness[int(pred_is_correct_syntax[i])]
                                grp['program_num_execution_correct'] = \
                                    int(program_num_execution_correct[i])
                                grp['program_is_correct_execution'] = \
                                    program_is_correct_execution[i]
                                grp['greedy_prediction'] = \
                                    greedy_program_str
                                grp['greedy_syntax'] = \
                                    correctness[int(greedy_is_correct_syntax[i])]
                                grp['greedy_num_execution_correct'] = \
                                    int(greedy_num_execution_correct[i])
                                grp['greedy_is_correct_execution'] = \
                                    greedy_is_correct_execution[i]

                            text_file.write(
                                '[id: {}]\ngt: {}\npred{}: {}\ngreedy{}: {}\n'.
                                format(
                                    program_id[i],
                                    dsl.intseq2str(np.argmax(
                                        gt_program[
                                            i, :, :gt_program_len[i, 0]],
                                        axis=0)),
                                    '(error)'
                                    if pred_is_correct_syntax[i] == 0 else '',
                                    pred_program_str,
                                    '(error)'
                                    if greedy_is_correct_syntax[i] == 0 else '',
                                    greedy_program_str,
                                ))
                    # list(...) so numpy builds a numeric array on Python 3
                    # (a bare dict view would become a 0-d object array).
                    loss_all.append(np.array(list(loss.values())))
                    acc_all.append(np.array(list(acc.values())))
                    time_all.append(step_time)
                    for hist_key, hist_value in hist.items():
                        hist_all.setdefault(hist_key, []).append(hist_value)

                if loss_all:
                    loss_avg = np.average(np.stack(loss_all), axis=0)
                    acc_avg = np.average(np.stack(acc_all), axis=0)
                    hist_avg = {
                        hist_key: np.average(np.stack(hist_values), axis=0)
                        for hist_key, hist_values in hist_all.items()
                    }
                    final_msg = self.log_final_message(
                        loss_avg,
                        loss.keys(),
                        acc_avg,
                        acc.keys(),
                        hist_avg,
                        hist_avg.keys(),
                        np.sum(time_all),
                        write_summary=self.config.write_summary,
                        summary_file=self.config.summary_file)
                else:
                    # np.stack([]) would raise; nothing to average.
                    log.warning('No evaluation steps were run '
                                '(max_steps={}).'.format(
                                    self.config.max_steps))

            if self.config.pred_program:
                if final_msg is not None:
                    log_file.write('{}\n'.format(final_msg))
                log_file.write("Model class: {}\n".format(self.config.model))
                log_file.write("Checkpoint: {}\n".format(self.checkpoint))
                log_file.write("Dataset: {}\n".format(
                    self.config.dataset_path))

        except Exception as e:
            coord.request_stop(e)
        finally:
            # Best-effort cleanup: one failing close must not keep the
            # remaining files open or skip the coordinator shutdown below.
            for handle in reversed(opened_files):
                try:
                    handle.close()
                except Exception as close_error:
                    log.warning('Failed to close {}: {}'.format(
                        handle, close_error))

        log.warning('Completed Evaluation.')

        coord.request_stop()
        try:
            coord.join(threads, stop_grace_period_secs=3)
        except RuntimeError as e:
            log.warning(str(e))