Example #1
# Note: State_Encoder is written as a nested function inside the model, so
# `is_train`, `self.pixel_input`, `self.state_encoder_fc`, and the `conv2d`/`fc`
# layer helpers are free variables taken from the enclosing scope.
def State_Encoder(s, per, batch_size, scope='State_Encoder', reuse=False):
    with tf.variable_scope(scope, reuse=reuse) as scope:
        if not reuse: log.warning(scope.name)
        _ = conv2d(s, 16, is_train, k_h=3, k_w=3,
                   info=not reuse, batch_norm=True, name='conv1')
        _ = conv2d(_, 32, is_train, k_h=3, k_w=3,
                   info=not reuse, batch_norm=True, name='conv2')
        _ = conv2d(_, 48, is_train, k_h=3, k_w=3,
                   info=not reuse, batch_norm=True, name='conv3')
        if self.pixel_input:
            _ = conv2d(_, 48, is_train, k_h=3, k_w=3,
                       info=not reuse, batch_norm=True, name='conv4')
            _ = conv2d(_, 48, is_train, k_h=3, k_w=3,
                       info=not reuse, batch_norm=True, name='conv5')
        state_feature = tf.reshape(_, [batch_size, -1])
        if self.state_encoder_fc:
            state_feature = fc(state_feature, 512, is_train,
                               info=not reuse, name='fc1')
            state_feature = fc(state_feature, 512, is_train,
                               info=not reuse, name='fc2')
        state_feature = tf.concat([state_feature, per], axis=-1)
        if not reuse: log.info(
            'concat feature {}'.format(state_feature))
        return state_feature
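The `conv2d` and `fc` calls above are this repository's own layer helpers, so the snippet does not run on its own. Below is a minimal standalone sketch of the same pattern (stacked conv + batch-norm blocks, flatten, then concatenate a perception vector) using vanilla TensorFlow 1.x layers; the layer sizes, strides, and placeholder shapes are illustrative assumptions, not the repository's API.

import tensorflow as tf  # TensorFlow 1.x

def state_encoder_sketch(s, per, batch_size, is_train,
                         scope='state_encoder_sketch', reuse=False):
    with tf.variable_scope(scope, reuse=reuse):
        _ = s
        for i, ch in enumerate([16, 32, 48]):
            # conv + batch norm + ReLU, as in conv1..conv3 above
            _ = tf.layers.conv2d(_, ch, kernel_size=3, strides=2,
                                 padding='same', name='conv{}'.format(i + 1))
            _ = tf.layers.batch_normalization(_, training=is_train,
                                              name='bn{}'.format(i + 1))
            _ = tf.nn.relu(_)
        feature = tf.reshape(_, [batch_size, -1])
        # append the perception vector to the flattened conv feature
        return tf.concat([feature, per], axis=-1)

# usage sketch with made-up shapes
s = tf.placeholder(tf.float32, [8, 64, 64, 3])
per = tf.placeholder(tf.float32, [8, 6])
is_train = tf.placeholder(tf.bool, [])
state_feature = state_encoder_sketch(s, per, batch_size=8, is_train=is_train)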
Example #2
    def __init__(self, config, dataset, dataset_test):
        self.config = config
        hyper_parameter_str = 'bs_{}_lr_{}_{}_cell_{}'.format(
            config.batch_size, config.learning_rate, config.encoder_rnn_type,
            config.num_lstm_cell_units)
        if config.scheduled_sampling:
            hyper_parameter_str += '_sd_{}'.format(
                config.scheduled_sampling_decay_steps)
        hyper_parameter_str += '_k_{}'.format(self.config.num_k)

        self.train_dir = './train_dir/%s-%s-%s-%s-%s-%s' % (
            config.dataset_type, '_'.join(
                config.dataset_path.split('/')), config.model, config.prefix,
            hyper_parameter_str, time.strftime("%Y%m%d-%H%M%S"))

        if not os.path.exists(self.train_dir): os.makedirs(self.train_dir)
        log.infov("Train Dir: %s", self.train_dir)

        # --- input ops ---
        self.batch_size = config.batch_size

        if config.dataset_type == 'karel':
            from karel_env.input_ops_karel import create_input_ops
        elif config.dataset_type == 'vizdoom':
            from vizdoom_env.input_ops_vizdoom import create_input_ops
        else:
            raise ValueError('Unknown dataset_type: {}'.format(
                config.dataset_type))

        _, self.batch_train = create_input_ops(dataset,
                                               self.batch_size,
                                               is_training=True)
        _, self.batch_test = create_input_ops(dataset_test,
                                              self.batch_size,
                                              is_training=False)
        # --- optimizer ---
        self.global_step = tf.contrib.framework.get_or_create_global_step(
            graph=None)

        # --- create model ---
        Model = self.get_model_class(config.model)
        log.infov("Using Model class: %s", Model)
        self.model = Model(config,
                           debug_information=config.debug,
                           global_step=self.global_step)

        if config.lr_weight_decay:
            self.init_learning_rate = config.learning_rate
            self.learning_rate = tf.train.exponential_decay(
                self.init_learning_rate,
                global_step=self.global_step,
                decay_steps=10000,
                decay_rate=0.5,
                staircase=True,
                name='decaying_learning_rate')
        else:
            self.learning_rate = config.learning_rate

        self.check_op = tf.no_op()

        # --- checkpoint and monitoring ---
        all_vars = tf.trainable_variables()
        log.warn("********* var ********** ")
        slim.model_analyzer.analyze_vars(all_vars, print_info=True)

        self.optimizer = tf.contrib.layers.optimize_loss(
            loss=self.model.loss,
            global_step=self.global_step,
            learning_rate=self.learning_rate,
            optimizer=tf.train.AdamOptimizer,
            clip_gradients=20.0,
            name='optimizer_pixel_loss')

        self.train_summary_op = tf.summary.merge_all(key='train')
        self.test_summary_op = tf.summary.merge_all(key='test')

        self.saver = tf.train.Saver(max_to_keep=100)
        self.pretrain_saver = tf.train.Saver(var_list=all_vars, max_to_keep=1)
        self.summary_writer = tf.summary.FileWriter(self.train_dir)
        self.log_step = self.config.log_step
        self.test_sample_step = self.config.test_sample_step
        self.write_summary_step = self.config.write_summary_step

        self.checkpoint_secs = 600  # 10 min

        self.supervisor = tf.train.Supervisor(
            logdir=self.train_dir,
            is_chief=True,
            saver=None,
            summary_op=None,
            summary_writer=self.summary_writer,
            save_summaries_secs=300,
            save_model_secs=self.checkpoint_secs,
            global_step=self.global_step,
        )

        session_config = tf.ConfigProto(
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(allow_growth=True),
            device_count={'GPU': 1},
        )
        self.session = self.supervisor.prepare_or_wait_for_session(
            config=session_config)

        self.ckpt_path = config.checkpoint
        if self.ckpt_path is not None:
            log.info("Checkpoint path: %s", self.ckpt_path)
            self.pretrain_saver.restore(self.session, self.ckpt_path)
            log.info("Loaded the pretrain parameters from the provided" +
                     "checkpoint path")
Example #3
    def eval_run(self):
        # load checkpoint
        if self.checkpoint:
            self.saver.restore(self.session, self.checkpoint)
            log.info("Loaded from checkpoint!")

        log.infov("Start Inference and Evaluation")

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(self.session,
                                               coord=coord,
                                               start=True)
        try:
            if self.config.pred_program:
                if not os.path.exists(self.output_dir):
                    os.makedirs(self.output_dir)
                log.infov("Output Dir: %s", self.output_dir)
                base_name = os.path.join(
                    self.output_dir,
                    'out_{}_{}'.format(self.checkpoint_name,
                                       self.dataset_split))
                text_file = open('{}.txt'.format(base_name), 'w')
                from karel_env.dsl import get_KarelDSL
                dsl = get_KarelDSL(dsl_type=self.dataset.dsl_type, seed=123)

                hdf5_file = h5py.File('{}.hdf5'.format(base_name), 'w')
                log_file = open('{}.log'.format(base_name), 'w')
            else:
                log_file = None

            if self.config.result_data:
                result_file = h5py.File(self.config.result_data_path, 'w')
                data_file = h5py.File(
                    os.path.join(self.config.dataset_path, 'data.hdf5'), 'r')

            if not self.config.no_loss:
                loss_all = []
                acc_all = []
                hist_all = {}
                time_all = []
                for s in xrange(self.config.max_steps):
                    step, loss, acc, hist, \
                        pred_program, pred_program_len, pred_is_correct_syntax, \
                        greedy_pred_program, greedy_program_len, greedy_is_correct_syntax, \
                        gt_program, gt_program_len, output, program_id, \
                        program_num_execution_correct, program_is_correct_execution, \
                        greedy_num_execution_correct, greedy_is_correct_execution, \
                        step_time = self.run_single_step(self.batch)
                    # define step_msg even in quiet mode; it is written to the
                    # prediction log file below
                    step_msg = ''
                    if not self.config.quiet:
                        step_msg = self.log_step_message(
                            s, loss, acc, hist, step_time)
                    if self.config.result_data:
                        for i in range(len(program_id)):
                            try:
                                grp = result_file.create_group(program_id[i])
                                grp['program'] = gt_program[i]
                                grp['pred_program'] = greedy_pred_program[i]
                                grp['pred_program_len'] = greedy_program_len[
                                    i][0]
                                grp['s_h'] = data_file[
                                    program_id[i]]['s_h'].value
                                grp['test_s_h'] = data_file[
                                    program_id[i]]['test_s_h'].value
                            except Exception:
                                # duplicate program_id groups are skipped
                                print('Duplicates: {}'.format(program_id[i]))

                    # write pred/gt program
                    if self.config.pred_program:
                        log_file.write('{}\n'.format(step_msg))
                        for i in range(self.batch_size):
                            pred_program_token = np.argmax(
                                pred_program[i, :, :pred_program_len[i, 0]],
                                axis=0)
                            pred_program_str = dsl.intseq2str(
                                pred_program_token)
                            greedy_program_token = np.argmax(
                                greedy_pred_program[i, :, :greedy_program_len[
                                    i, 0]],
                                axis=0)
                            greedy_program_str = dsl.intseq2str(
                                greedy_program_token)
                            try:
                                grp = hdf5_file.create_group(program_id[i])
                            except Exception:
                                pass
                            else:
                                correctness = ['wrong', 'correct']
                                grp['program_prediction'] = pred_program_str
                                grp['program_syntax'] = \
                                    correctness[int(pred_is_correct_syntax[i])]
                                grp['program_num_execution_correct'] = \
                                    int(program_num_execution_correct[i])
                                grp['program_is_correct_execution'] = \
                                    program_is_correct_execution[i]
                                grp['greedy_prediction'] = \
                                    greedy_program_str
                                grp['greedy_syntax'] = \
                                    correctness[int(greedy_is_correct_syntax[i])]
                                grp['greedy_num_execution_correct'] = \
                                    int(greedy_num_execution_correct[i])
                                grp['greedy_is_correct_execution'] = \
                                    greedy_is_correct_execution[i]

                            text_file.write(
                                '[id: {}]\ngt: {}\npred{}: {}\ngreedy{}: {}\n'.
                                format(
                                    program_id[i],
                                    dsl.intseq2str(
                                        np.argmax(gt_program[
                                            i, :, :gt_program_len[i, 0]],
                                                  axis=0)),
                                    '(error)'
                                    if pred_is_correct_syntax[i] == 0 else '',
                                    pred_program_str,
                                    '(error)' if greedy_is_correct_syntax[i]
                                    == 0 else '',
                                    greedy_program_str,
                                ))
                    loss_all.append(np.array(loss.values()))
                    acc_all.append(np.array(acc.values()))
                    time_all.append(step_time)
                    for hist_key, hist_value in hist.items():
                        if hist_key not in hist_all:
                            hist_all[hist_key] = []
                        hist_all[hist_key].append(hist_value)

                loss_avg = np.average(np.stack(loss_all), axis=0)
                acc_avg = np.average(np.stack(acc_all), axis=0)
                hist_avg = {}
                for hist_key, hist_values in hist_all.items():
                    hist_avg[hist_key] = np.average(np.stack(hist_values),
                                                    axis=0)
                final_msg = self.log_final_message(
                    loss_avg,
                    loss.keys(),
                    acc_avg,
                    acc.keys(),
                    hist_avg,
                    hist_avg.keys(),
                    np.sum(time_all),
                    write_summary=self.config.write_summary,
                    summary_file=self.config.summary_file)

            if self.config.result_data:
                result_file.close()
                data_file.close()

            if self.config.pred_program:
                log_file.write('{}\n'.format(final_msg))
                log_file.write("Model class: {}\n".format(self.config.model))
                log_file.write("Checkpoint: {}\n".format(self.checkpoint))
                log_file.write("Dataset: {}\n".format(
                    self.config.dataset_path))
                log_file.close()
                text_file.close()
                hdf5_file.close()

        except Exception as e:
            coord.request_stop(e)

        log.warning('Completed Evaluation.')

        coord.request_stop()
        try:
            coord.join(threads, stop_grace_period_secs=3)
        except RuntimeError as e:
            log.warn(str(e))
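The queue-runner plumbing in eval_run follows the standard TF1 Coordinator pattern: start the runner threads, do work inside try, hand any exception to the coordinator, then stop and join. A self-contained sketch of just that skeleton, with a toy FIFOQueue standing in for the dataset input ops (everything below is illustrative, not the repository's code).

import tensorflow as tf  # TensorFlow 1.x

# toy input pipeline: a FIFOQueue fed by two enqueue threads
queue = tf.FIFOQueue(capacity=10, dtypes=[tf.float32])
enqueue_op = queue.enqueue(tf.random_normal([]))
tf.train.add_queue_runner(tf.train.QueueRunner(queue, [enqueue_op] * 2))
sample = queue.dequeue()

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord=coord, start=True)
    try:
        for _ in range(5):
            print(sess.run(sample))
    except Exception as e:
        # hand the exception to the coordinator, as eval_run does
        coord.request_stop(e)
    finally:
        coord.request_stop()
        coord.join(threads, stop_grace_period_secs=3)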
Example #4
    def __init__(self, config, dataset):
        self.config = config
        self.dataset_split = config.dataset_split
        self.train_dir = config.train_dir
        self.output_dir = getattr(config, 'output_dir',
                                  config.train_dir) or self.train_dir
        log.info("self.train_dir = %s", self.train_dir)

        # --- input ops ---
        self.batch_size = config.batch_size

        if config.dataset_type == 'karel':
            from karel_env.input_ops_karel import create_input_ops
        elif config.dataset_type == 'vizdoom':
            from vizdoom_env.input_ops_vizdoom import create_input_ops
        else:
            raise NotImplementedError(
                "The dataset related code is not implemented.")

        self.dataset = dataset

        _, self.batch = create_input_ops(dataset,
                                         self.batch_size,
                                         is_training=False,
                                         shuffle=False)

        # --- create model ---
        Model = self.get_model_class(config.model)
        log.infov("Using Model class: %s", Model)
        self.model = Model(config, is_train=False)

        self.global_step = tf.contrib.framework.get_or_create_global_step(
            graph=None)
        self.step_op = tf.no_op(name='step_no_op')

        # --- vars ---
        all_vars = tf.trainable_variables()
        log.warn("********* var ********** ")
        slim.model_analyzer.analyze_vars(all_vars, print_info=True)

        tf.set_random_seed(123)

        session_config = tf.ConfigProto(
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(allow_growth=True),
            device_count={'GPU': 1},
        )
        self.session = tf.Session(config=session_config)

        # --- checkpoint and monitoring ---
        self.saver = tf.train.Saver(max_to_keep=100)

        self.checkpoint = config.checkpoint
        if not self.checkpoint and self.train_dir:
            self.checkpoint = tf.train.latest_checkpoint(self.train_dir)
        if not self.checkpoint:
            log.warn("No checkpoint is given. Just random initialization :-)")
            self.session.run(tf.global_variables_initializer())
        else:
            self.checkpoint_name = os.path.basename(self.checkpoint)
            log.info("Checkpoint path : %s", self.checkpoint)
        # `self.checkpoint` may be None when nothing was restored
        self.config.summary_file = (self.checkpoint or '') + '_report_testdata{}_num_k{}.txt'.format(
            self.config.max_steps * self.config.batch_size, self.config.num_k)
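The checkpoint handling at the end is a common TF1 idiom: use the given checkpoint if any, otherwise fall back to the latest checkpoint in the train dir, and random-initialize when neither exists. A standalone sketch of that idiom; the variable and directory below are placeholders.

import tensorflow as tf  # TensorFlow 1.x

v = tf.get_variable('v', shape=[], initializer=tf.zeros_initializer())
saver = tf.train.Saver(max_to_keep=100)

train_dir = '/tmp/example_train_dir'  # hypothetical directory
checkpoint = ''                       # empty means "look for the latest one"

with tf.Session() as sess:
    if not checkpoint and train_dir:
        # latest_checkpoint returns None when the directory has no checkpoint
        checkpoint = tf.train.latest_checkpoint(train_dir)
    if not checkpoint:
        sess.run(tf.global_variables_initializer())
    else:
        saver.restore(sess, checkpoint)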