Example #1
    def train(self):
        log.infov("Training Starts!")
        pprint(self.batch_train)

        max_steps = 1000000

        ckpt_save_step = 1000
        log_step = self.log_step
        test_sample_step = self.test_sample_step
        write_summary_step = self.write_summary_step

        for s in xrange(max_steps):
            # train a single step
            step, train_summary, loss, output, step_time = \
                self.run_single_step(
                    self.batch_train, step=s, is_train=True)
            if s % log_step == 0:
                self.log_step_message(step, loss, step_time)

            # periodic inference
            if s % test_sample_step == 0:
                test_step, test_summary, test_loss, output, test_step_time = \
                    self.run_test(self.batch_test)
                self.summary_writer.add_summary(test_summary,
                                                global_step=test_step)
                self.log_step_message(step,
                                      test_loss,
                                      test_step_time,
                                      is_train=False)

            if s % write_summary_step == 0:
                self.summary_writer.add_summary(train_summary,
                                                global_step=step)

            if s % ckpt_save_step == 0:
                log.infov("Saved checkpoint at %d", s)
                self.saver.save(self.session,
                                os.path.join(self.train_dir, 'model'),
                                global_step=step)
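
The loop above drives everything from a single step counter: logging, test-set inference, summary writing, and checkpointing each fire when the counter hits their respective interval. A minimal standalone sketch of that scheduling pattern (hypothetical function name and interval values, no TensorFlow involved):

def run_with_periodic_hooks(max_steps=10, log_step=2, ckpt_save_step=5):
    for s in range(max_steps):
        # ... a real optimization step would run here ...
        if s % log_step == 0:
            print("step {}: log training metrics".format(s))
        if s % ckpt_save_step == 0:
            print("step {}: save a checkpoint".format(s))

run_with_periodic_hooks()
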
Example #2
 def log_final_message(self,
                       loss,
                       loss_key,
                       acc,
                       acc_key,
                       hist,
                       hist_key,
                       time,
                       write_summary=False,
                       summary_file=None,
                       is_train=False):
     loss_str = ""
     for key, i in sorted(zip(loss_key, range(len(loss_key)))):
         loss_str += "{}:{loss: .3f} ".format(loss_key[i], loss=loss[i])
     acc_str = ""
     for key, i in sorted(zip(acc_key, range(len(acc_key)))):
         acc_str += "{}:{acc: .3f}\n".format(acc_key[i], acc=acc[i])
     hist_str = ""
     for key in sorted(hist_key):
         hist_str += "{}: [".format(key)
         for h in hist[key]:
             hist_str += "{acc: .3f}, ".format(acc=h)
         hist_str += "]\n"
     log_fn = log.info if is_train else log.infov
     msg = ("[Final Avg Report] \n" + "[Loss] {loss_str}\n" +
            "[Acc]  {acc_str}\n" + "[Hist] {hist_str}\n" +
            "[Time] ({time:.3f} sec)").format(
                loss_str=loss_str,
                acc_str=acc_str[:-1],
                hist_str=hist_str[:-1],
                time=time,
            )
     log_fn(msg)
     log.infov("Model class: %s", self.config.model)
     log.infov("Checkpoint: %s", self.checkpoint)
     log.infov("Dataset: %s", self.config.dataset_path)
     if write_summary:
         final_msg = 'Model class: {}\nCheckpoint: {}\nDataset: {}\n{}'.format(
             self.config.model, self.checkpoint, self.config.dataset_path,
             msg)
         with open(summary_file, 'w') as f:
             f.write(final_msg)
     return msg
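
The helper assembles one "key: value" entry per metric from parallel key/value lists, iterating in key-sorted order before joining everything into the final report string. A small standalone sketch with hypothetical metric names and values:

loss_key = ['recon', 'kl']
loss = [0.1234, 0.5678]

loss_str = ""
for key, i in sorted(zip(loss_key, range(len(loss_key)))):
    loss_str += "{}:{loss: .3f} ".format(loss_key[i], loss=loss[i])

print(loss_str)  # prints "kl: 0.568 recon: 0.123 " -- keys in sorted order
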
Example #3
    def __init__(self, config, dataset, dataset_test):
        self.config = config
        hyper_parameter_str = 'bs_{}_lr_{}_{}_cell_{}'.format(
            config.batch_size, config.learning_rate, config.encoder_rnn_type,
            config.num_lstm_cell_units)
        if config.scheduled_sampling:
            hyper_parameter_str += '_sd_{}'.format(
                config.scheduled_sampling_decay_steps)
        hyper_parameter_str += '_k_{}'.format(self.config.num_k)

        self.train_dir = './train_dir/%s-%s-%s-%s-%s-%s' % (
            config.dataset_type, '_'.join(
                config.dataset_path.split('/')), config.model, config.prefix,
            hyper_parameter_str, time.strftime("%Y%m%d-%H%M%S"))

        if not os.path.exists(self.train_dir):
            os.makedirs(self.train_dir)
        log.infov("Train Dir: %s", self.train_dir)

        # --- input ops ---
        self.batch_size = config.batch_size

        if config.dataset_type == 'karel':
            from karel_env.input_ops_karel import create_input_ops
        elif config.dataset_type == 'vizdoom':
            from vizdoom_env.input_ops_vizdoom import create_input_ops
        else:
            raise ValueError(
                'Unknown dataset_type: {}'.format(config.dataset_type))

        _, self.batch_train = create_input_ops(dataset,
                                               self.batch_size,
                                               is_training=True)
        _, self.batch_test = create_input_ops(dataset_test,
                                              self.batch_size,
                                              is_training=False)
        # --- optimizer ---
        self.global_step = tf.contrib.framework.get_or_create_global_step(
            graph=None)

        # --- create model ---
        Model = self.get_model_class(config.model)
        log.infov("Using Model class: %s", Model)
        self.model = Model(config,
                           debug_information=config.debug,
                           global_step=self.global_step)

        if config.lr_weight_decay:
            self.init_learning_rate = config.learning_rate
            self.learning_rate = tf.train.exponential_decay(
                self.init_learning_rate,
                global_step=self.global_step,
                decay_steps=10000,
                decay_rate=0.5,
                staircase=True,
                name='decaying_learning_rate')
        else:
            self.learning_rate = config.learning_rate

        self.check_op = tf.no_op()

        # --- checkpoint and monitoring ---
        all_vars = tf.trainable_variables()
        log.warn("********* var ********** ")
        slim.model_analyzer.analyze_vars(all_vars, print_info=True)

        self.optimizer = tf.contrib.layers.optimize_loss(
            loss=self.model.loss,
            global_step=self.global_step,
            learning_rate=self.learning_rate,
            optimizer=tf.train.AdamOptimizer,
            clip_gradients=20.0,
            name='optimizer_pixel_loss')

        self.train_summary_op = tf.summary.merge_all(key='train')
        self.test_summary_op = tf.summary.merge_all(key='test')

        self.saver = tf.train.Saver(max_to_keep=100)
        self.pretrain_saver = tf.train.Saver(var_list=all_vars, max_to_keep=1)
        self.summary_writer = tf.summary.FileWriter(self.train_dir)
        self.log_step = self.config.log_step
        self.test_sample_step = self.config.test_sample_step
        self.write_summary_step = self.config.write_summary_step

        self.checkpoint_secs = 600  # 10 min

        self.supervisor = tf.train.Supervisor(
            logdir=self.train_dir,
            is_chief=True,
            saver=None,
            summary_op=None,
            summary_writer=self.summary_writer,
            save_summaries_secs=300,
            save_model_secs=self.checkpoint_secs,
            global_step=self.global_step,
        )

        session_config = tf.ConfigProto(
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(allow_growth=True),
            device_count={'GPU': 1},
        )
        self.session = self.supervisor.prepare_or_wait_for_session(
            config=session_config)

        self.ckpt_path = config.checkpoint
        if self.ckpt_path is not None:
            log.info("Checkpoint path: %s", self.ckpt_path)
            self.pretrain_saver.restore(self.session, self.ckpt_path)
            log.info("Loaded the pretrain parameters from the provided" +
                     "checkpoint path")
Example #4
    def eval_run(self):
        # load checkpoint
        if self.checkpoint:
            self.saver.restore(self.session, self.checkpoint)
            log.info("Loaded from checkpoint!")

        log.infov("Start Inference and Evaluation")

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(self.session,
                                               coord=coord,
                                               start=True)
        try:
            if self.config.pred_program:
                if not os.path.exists(self.output_dir):
                    os.makedirs(self.output_dir)
                log.infov("Output Dir: %s", self.output_dir)
                base_name = os.path.join(
                    self.output_dir,
                    'out_{}_{}'.format(self.checkpoint_name,
                                       self.dataset_split))
                text_file = open('{}.txt'.format(base_name), 'w')
                from karel_env.dsl import get_KarelDSL
                dsl = get_KarelDSL(dsl_type=self.dataset.dsl_type, seed=123)

                hdf5_file = h5py.File('{}.hdf5'.format(base_name), 'w')
                log_file = open('{}.log'.format(base_name), 'w')
            else:
                log_file = None

            if self.config.result_data:
                result_file = h5py.File(self.config.result_data_path, 'w')
                data_file = h5py.File(
                    os.path.join(self.config.dataset_path, 'data.hdf5'), 'r')

            if not self.config.no_loss:
                loss_all = []
                acc_all = []
                hist_all = {}
                time_all = []
                for s in xrange(self.config.max_steps):
                    step, loss, acc, hist, \
                        pred_program, pred_program_len, pred_is_correct_syntax, \
                        greedy_pred_program, greedy_program_len, greedy_is_correct_syntax, \
                        gt_program, gt_program_len, output, program_id, \
                        program_num_execution_correct, program_is_correct_execution, \
                        greedy_num_execution_correct, greedy_is_correct_execution, \
                        step_time = self.run_single_step(self.batch)
                    if not self.config.quiet:
                        step_msg = self.log_step_message(
                            s, loss, acc, hist, step_time)
                    if self.config.result_data:
                        for i in range(len(program_id)):
                            try:
                                grp = result_file.create_group(program_id[i])
                                grp['program'] = gt_program[i]
                                grp['pred_program'] = greedy_pred_program[i]
                                grp['pred_program_len'] = greedy_program_len[
                                    i][0]
                                grp['s_h'] = data_file[
                                    program_id[i]]['s_h'].value
                                grp['test_s_h'] = data_file[
                                    program_id[i]]['test_s_h'].value
                            except Exception:
                                # group already exists for this program_id
                                print('Duplicates: {}'.format(program_id[i]))

                    # write pred/gt program
                    if self.config.pred_program:
                        log_file.write('{}\n'.format(step_msg))
                        for i in range(self.batch_size):
                            pred_program_token = np.argmax(
                                pred_program[i, :, :pred_program_len[i, 0]],
                                axis=0)
                            pred_program_str = dsl.intseq2str(
                                pred_program_token)
                            greedy_program_token = np.argmax(
                                greedy_pred_program[i, :, :greedy_program_len[
                                    i, 0]],
                                axis=0)
                            greedy_program_str = dsl.intseq2str(
                                greedy_program_token)
                            try:
                                grp = hdf5_file.create_group(program_id[i])
                            except Exception:
                                # a group for this program_id already exists
                                pass
                            else:
                                correctness = ['wrong', 'correct']
                                grp['program_prediction'] = pred_program_str
                                grp['program_syntax'] = \
                                    correctness[int(pred_is_correct_syntax[i])]
                                grp['program_num_execution_correct'] = \
                                    int(program_num_execution_correct[i])
                                grp['program_is_correct_execution'] = \
                                    program_is_correct_execution[i]
                                grp['greedy_prediction'] = \
                                    greedy_program_str
                                grp['greedy_syntax'] = \
                                    correctness[int(greedy_is_correct_syntax[i])]
                                grp['greedy_num_execution_correct'] = \
                                    int(greedy_num_execution_correct[i])
                                grp['greedy_is_correct_execution'] = \
                                    greedy_is_correct_execution[i]

                            text_file.write(
                                '[id: {}]\ngt: {}\npred{}: {}\ngreedy{}: {}\n'.
                                format(
                                    program_id[i],
                                    dsl.intseq2str(
                                        np.argmax(gt_program[
                                            i, :, :gt_program_len[i, 0]],
                                                  axis=0)),
                                    '(error)'
                                    if pred_is_correct_syntax[i] == 0 else '',
                                    pred_program_str,
                                    '(error)' if greedy_is_correct_syntax[i]
                                    == 0 else '',
                                    greedy_program_str,
                                ))
                    loss_all.append(np.array(loss.values()))
                    acc_all.append(np.array(acc.values()))
                    time_all.append(step_time)
                    for hist_key, hist_value in hist.items():
                        if hist_key not in hist_all:
                            hist_all[hist_key] = []
                        hist_all[hist_key].append(hist_value)

                loss_avg = np.average(np.stack(loss_all), axis=0)
                acc_avg = np.average(np.stack(acc_all), axis=0)
                hist_avg = {}
                for hist_key, hist_values in hist_all.items():
                    hist_avg[hist_key] = np.average(np.stack(hist_values),
                                                    axis=0)
                final_msg = self.log_final_message(
                    loss_avg,
                    loss.keys(),
                    acc_avg,
                    acc.keys(),
                    hist_avg,
                    hist_avg.keys(),
                    np.sum(time_all),
                    write_summary=self.config.write_summary,
                    summary_file=self.config.summary_file)

            if self.config.result_data:
                result_file.close()
                data_file.close()

            if self.config.pred_program:
                log_file.write('{}\n'.format(final_msg))
                log_file.write("Model class: {}\n".format(self.config.model))
                log_file.write("Checkpoint: {}\n".format(self.checkpoint))
                log_file.write("Dataset: {}\n".format(
                    self.config.dataset_path))
                log_file.close()
                text_file.close()
                hdf5_file.close()

        except Exception as e:
            coord.request_stop(e)

        log.warning('Completed Evaluation.')

        coord.request_stop()
        try:
            coord.join(threads, stop_grace_period_secs=3)
        except RuntimeError as e:
            log.warn(str(e))
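
When `config.pred_program` is set, predictions are stored one HDF5 group per program_id, each holding scalar string/int datasets such as program_prediction, program_syntax, and greedy_prediction. A sketch of reading such a file back with h5py (the file name is hypothetical; it follows the out_<checkpoint>_<split>.hdf5 pattern used above):

import h5py

with h5py.File('out_model-10000_test.hdf5', 'r') as f:
    for program_id in f:  # one group per program id
        grp = f[program_id]
        print('{}: syntax={}, greedy={}'.format(
            program_id, grp['program_syntax'][()],
            grp['greedy_prediction'][()]))
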
Example #5
    def __init__(self, config, dataset):
        self.config = config
        self.dataset_split = config.dataset_split
        self.train_dir = config.train_dir
        self.output_dir = getattr(config, 'output_dir',
                                  config.train_dir) or self.train_dir
        log.info("self.train_dir = %s", self.train_dir)

        # --- input ops ---
        self.batch_size = config.batch_size

        if config.dataset_type == 'karel':
            from karel_env.input_ops_karel import create_input_ops
        elif config.dataset_type == 'vizdoom':
            from vizdoom_env.input_ops_vizdoom import create_input_ops
        else:
            raise NotImplementedError(
                "The dataset related code is not implemented.")

        self.dataset = dataset

        _, self.batch = create_input_ops(dataset,
                                         self.batch_size,
                                         is_training=False,
                                         shuffle=False)

        # --- create model ---
        Model = self.get_model_class(config.model)
        log.infov("Using Model class: %s", Model)
        self.model = Model(config, is_train=False)

        self.global_step = tf.contrib.framework.get_or_create_global_step(
            graph=None)
        self.step_op = tf.no_op(name='step_no_op')

        # --- vars ---
        all_vars = tf.trainable_variables()
        log.warn("********* var ********** ")
        slim.model_analyzer.analyze_vars(all_vars, print_info=True)

        tf.set_random_seed(123)

        session_config = tf.ConfigProto(
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(allow_growth=True),
            device_count={'GPU': 1},
        )
        self.session = tf.Session(config=session_config)

        # --- checkpoint and monitoring ---
        self.saver = tf.train.Saver(max_to_keep=100)

        self.checkpoint = config.checkpoint
        if not self.checkpoint and self.train_dir:
            self.checkpoint = tf.train.latest_checkpoint(self.train_dir)
        if not self.checkpoint:
            log.warn("No checkpoint is given. Just random initialization :-)")
            self.session.run(tf.global_variables_initializer())
        else:
            self.checkpoint_name = os.path.basename(self.checkpoint)
            log.info("Checkpoint path : %s", self.checkpoint)
        self.config.summary_file = self.checkpoint + '_report_testdata{}_num_k{}.txt'.format(
            self.config.max_steps * self.config.batch_size, self.config.num_k)
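
Checkpoint selection above falls back from an explicit config.checkpoint to the newest checkpoint in train_dir, and finally to random initialization. A minimal sketch of that fallback logic (hypothetical directory path):

import tensorflow as tf

def resolve_checkpoint(checkpoint, train_dir):
    if not checkpoint and train_dir:
        # tf.train.latest_checkpoint returns None if no checkpoint is found
        checkpoint = tf.train.latest_checkpoint(train_dir)
    return checkpoint or None  # None means "fall back to random init"

print(resolve_checkpoint('', './train_dir/some_run'))
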