def State_Encoder(s, per, batch_size, scope='State_Encoder', reuse=False):
    # Encodes a state `s` into a feature vector and concatenates the perception
    # vector `per`. `is_train`, `self.pixel_input`, and `self.state_encoder_fc`
    # are captured from the enclosing model scope.
    with tf.variable_scope(scope, reuse=reuse) as scope:
        if not reuse:
            log.warning(scope.name)
        _ = conv2d(s, 16, is_train, k_h=3, k_w=3, info=not reuse,
                   batch_norm=True, name='conv1')
        _ = conv2d(_, 32, is_train, k_h=3, k_w=3, info=not reuse,
                   batch_norm=True, name='conv2')
        _ = conv2d(_, 48, is_train, k_h=3, k_w=3, info=not reuse,
                   batch_norm=True, name='conv3')
        if self.pixel_input:
            _ = conv2d(_, 48, is_train, k_h=3, k_w=3, info=not reuse,
                       batch_norm=True, name='conv4')
            _ = conv2d(_, 48, is_train, k_h=3, k_w=3, info=not reuse,
                       batch_norm=True, name='conv5')
        state_feature = tf.reshape(_, [batch_size, -1])
        if self.state_encoder_fc:
            state_feature = fc(state_feature, 512, is_train,
                               info=not reuse, name='fc1')
            state_feature = fc(state_feature, 512, is_train,
                               info=not reuse, name='fc2')
        state_feature = tf.concat([state_feature, per], axis=-1)
        if not reuse:
            log.info('concat feature {}'.format(state_feature))
        return state_feature
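# --------------------------------------------------------------------------
# Hedged usage sketch (not part of the original code): State_Encoder is meant
# to be called once with reuse=False and then with reuse=True for subsequent
# states, so all calls share the same convolutional weights inside the
# 'State_Encoder' variable scope. The tensor names s0, sT, and per below are
# illustrative assumptions, not names from the repo.
#
#   feat_0 = State_Encoder(s0, per, batch_size, reuse=False)
#   feat_T = State_Encoder(sT, per, batch_size, reuse=True)
# --------------------------------------------------------------------------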
def __init__(self, config, dataset, dataset_test):
    self.config = config
    hyper_parameter_str = 'bs_{}_lr_{}_{}_cell_{}'.format(
        config.batch_size, config.learning_rate,
        config.encoder_rnn_type, config.num_lstm_cell_units)
    if config.scheduled_sampling:
        hyper_parameter_str += '_sd_{}'.format(
            config.scheduled_sampling_decay_steps)
    hyper_parameter_str += '_k_{}'.format(self.config.num_k)

    self.train_dir = './train_dir/%s-%s-%s-%s-%s-%s' % (
        config.dataset_type,
        '_'.join(config.dataset_path.split('/')),
        config.model, config.prefix, hyper_parameter_str,
        time.strftime("%Y%m%d-%H%M%S"))

    if not os.path.exists(self.train_dir):
        os.makedirs(self.train_dir)
    log.infov("Train Dir: %s", self.train_dir)

    # --- input ops ---
    self.batch_size = config.batch_size
    if config.dataset_type == 'karel':
        from karel_env.input_ops_karel import create_input_ops
    elif config.dataset_type == 'vizdoom':
        from vizdoom_env.input_ops_vizdoom import create_input_ops
    else:
        raise ValueError(config.dataset_type)

    _, self.batch_train = create_input_ops(
        dataset, self.batch_size, is_training=True)
    _, self.batch_test = create_input_ops(
        dataset_test, self.batch_size, is_training=False)

    # --- optimizer ---
    self.global_step = tf.contrib.framework.get_or_create_global_step(
        graph=None)

    # --- create model ---
    Model = self.get_model_class(config.model)
    log.infov("Using Model class: %s", Model)
    self.model = Model(config, debug_information=config.debug,
                       global_step=self.global_step)

    if config.lr_weight_decay:
        self.init_learning_rate = config.learning_rate
        self.learning_rate = tf.train.exponential_decay(
            self.init_learning_rate,
            global_step=self.global_step,
            decay_steps=10000,
            decay_rate=0.5,
            staircase=True,
            name='decaying_learning_rate')
    else:
        self.learning_rate = config.learning_rate

    self.check_op = tf.no_op()

    # --- checkpoint and monitoring ---
    all_vars = tf.trainable_variables()
    log.warn("********* var ********** ")
    slim.model_analyzer.analyze_vars(all_vars, print_info=True)

    self.optimizer = tf.contrib.layers.optimize_loss(
        loss=self.model.loss,
        global_step=self.global_step,
        learning_rate=self.learning_rate,
        optimizer=tf.train.AdamOptimizer,
        clip_gradients=20.0,
        name='optimizer_pixel_loss')

    self.train_summary_op = tf.summary.merge_all(key='train')
    self.test_summary_op = tf.summary.merge_all(key='test')

    self.saver = tf.train.Saver(max_to_keep=100)
    self.pretrain_saver = tf.train.Saver(var_list=all_vars, max_to_keep=1)
    self.summary_writer = tf.summary.FileWriter(self.train_dir)

    self.log_step = self.config.log_step
    self.test_sample_step = self.config.test_sample_step
    self.write_summary_step = self.config.write_summary_step

    self.checkpoint_secs = 600  # 10 min

    self.supervisor = tf.train.Supervisor(
        logdir=self.train_dir,
        is_chief=True,
        saver=None,
        summary_op=None,
        summary_writer=self.summary_writer,
        save_summaries_secs=300,
        save_model_secs=self.checkpoint_secs,
        global_step=self.global_step,
    )

    session_config = tf.ConfigProto(
        allow_soft_placement=True,
        gpu_options=tf.GPUOptions(allow_growth=True),
        device_count={'GPU': 1},
    )
    self.session = self.supervisor.prepare_or_wait_for_session(
        config=session_config)

    self.ckpt_path = config.checkpoint
    if self.ckpt_path is not None:
        log.info("Checkpoint path: %s", self.ckpt_path)
        self.pretrain_saver.restore(self.session, self.ckpt_path)
        log.info("Loaded the pretrain parameters from the provided "
                 "checkpoint path")
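# --------------------------------------------------------------------------
# Hedged sketch (not the repo's actual train loop): with the supervisor
# session, optimizer, and summary ops built above, a single training step
# could look roughly like the following. The helper `self.model.get_feed_dict`
# is an assumed name for whatever method maps a fetched batch to a feed_dict;
# it is not defined in this section.
#
#   batch_chunk = self.session.run(self.batch_train)
#   fetch = [self.global_step, self.model.loss, self.optimizer,
#            self.train_summary_op]
#   step, loss, _, summary = self.session.run(
#       fetch, feed_dict=self.model.get_feed_dict(batch_chunk))
#   self.summary_writer.add_summary(summary, global_step=step)
# --------------------------------------------------------------------------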
def eval_run(self):
    # load checkpoint
    if self.checkpoint:
        self.saver.restore(self.session, self.checkpoint)
        log.info("Loaded from checkpoint!")

    log.infov("Start Inference and Evaluation")

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(
        self.session, coord=coord, start=True)

    try:
        if self.config.pred_program:
            if not os.path.exists(self.output_dir):
                os.makedirs(self.output_dir)
            log.infov("Output Dir: %s", self.output_dir)
            base_name = os.path.join(
                self.output_dir,
                'out_{}_{}'.format(self.checkpoint_name, self.dataset_split))
            text_file = open('{}.txt'.format(base_name), 'w')
            from karel_env.dsl import get_KarelDSL
            dsl = get_KarelDSL(dsl_type=self.dataset.dsl_type, seed=123)

            hdf5_file = h5py.File('{}.hdf5'.format(base_name), 'w')
            log_file = open('{}.log'.format(base_name), 'w')
        else:
            log_file = None

        if self.config.result_data:
            result_file = h5py.File(self.config.result_data_path, 'w')
            data_file = h5py.File(
                os.path.join(self.config.dataset_path, 'data.hdf5'), 'r')

        if not self.config.no_loss:
            loss_all = []
            acc_all = []
            hist_all = {}
            time_all = []

            for s in xrange(self.config.max_steps):
                step, loss, acc, hist, \
                    pred_program, pred_program_len, pred_is_correct_syntax, \
                    greedy_pred_program, greedy_program_len, greedy_is_correct_syntax, \
                    gt_program, gt_program_len, output, program_id, \
                    program_num_execution_correct, program_is_correct_execution, \
                    greedy_num_execution_correct, greedy_is_correct_execution, \
                    step_time = self.run_single_step(self.batch)

                if not self.config.quiet:
                    step_msg = self.log_step_message(
                        s, loss, acc, hist, step_time)

                if self.config.result_data:
                    for i in range(len(program_id)):
                        try:
                            grp = result_file.create_group(program_id[i])
                            grp['program'] = gt_program[i]
                            grp['pred_program'] = greedy_pred_program[i]
                            grp['pred_program_len'] = greedy_program_len[i][0]
                            grp['s_h'] = \
                                data_file[program_id[i]]['s_h'].value
                            grp['test_s_h'] = \
                                data_file[program_id[i]]['test_s_h'].value
                        except:
                            print('Duplicates: {}'.format(program_id[i]))
                            pass

                # write pred/gt program
                if self.config.pred_program:
                    log_file.write('{}\n'.format(step_msg))
                    for i in range(self.batch_size):
                        pred_program_token = np.argmax(
                            pred_program[i, :, :pred_program_len[i, 0]],
                            axis=0)
                        pred_program_str = dsl.intseq2str(pred_program_token)
                        greedy_program_token = np.argmax(
                            greedy_pred_program[i, :, :greedy_program_len[i, 0]],
                            axis=0)
                        greedy_program_str = dsl.intseq2str(
                            greedy_program_token)
                        try:
                            grp = hdf5_file.create_group(program_id[i])
                        except:
                            pass
                        else:
                            correctness = ['wrong', 'correct']
                            grp['program_prediction'] = pred_program_str
                            grp['program_syntax'] = \
                                correctness[int(pred_is_correct_syntax[i])]
                            grp['program_num_execution_correct'] = \
                                int(program_num_execution_correct[i])
                            grp['program_is_correct_execution'] = \
                                program_is_correct_execution[i]
                            grp['greedy_prediction'] = greedy_program_str
                            grp['greedy_syntax'] = \
                                correctness[int(greedy_is_correct_syntax[i])]
                            grp['greedy_num_execution_correct'] = \
                                int(greedy_num_execution_correct[i])
                            grp['greedy_is_correct_execution'] = \
                                greedy_is_correct_execution[i]

                        text_file.write(
                            '[id: {}]\ngt: {}\npred{}: {}\ngreedy{}: {}\n'.format(
                                program_id[i],
                                dsl.intseq2str(np.argmax(
                                    gt_program[i, :, :gt_program_len[i, 0]],
                                    axis=0)),
                                '(error)' if pred_is_correct_syntax[i] == 0 else '',
                                pred_program_str,
                                '(error)' if greedy_is_correct_syntax[i] == 0 else '',
                                greedy_program_str,
                            ))

                loss_all.append(np.array(loss.values()))
                acc_all.append(np.array(acc.values()))
                time_all.append(step_time)
                for hist_key, hist_value in hist.items():
                    if hist_key not in hist_all:
                        hist_all[hist_key] = []
                    hist_all[hist_key].append(hist_value)

            loss_avg = np.average(np.stack(loss_all), axis=0)
            acc_avg = np.average(np.stack(acc_all), axis=0)
            hist_avg = {}
            for hist_key, hist_values in hist_all.items():
                hist_avg[hist_key] = np.average(
                    np.stack(hist_values), axis=0)
            final_msg = self.log_final_message(
                loss_avg, loss.keys(), acc_avg, acc.keys(),
                hist_avg, hist_avg.keys(), np.sum(time_all),
                write_summary=self.config.write_summary,
                summary_file=self.config.summary_file)

        if self.config.result_data:
            result_file.close()
            data_file.close()

        if self.config.pred_program:
            log_file.write('{}\n'.format(final_msg))
            log_file.write("Model class: {}\n".format(self.config.model))
            log_file.write("Checkpoint: {}\n".format(self.checkpoint))
            log_file.write("Dataset: {}\n".format(self.config.dataset_path))
            log_file.close()
            text_file.close()
            hdf5_file.close()
    except Exception as e:
        coord.request_stop(e)

    log.warning('Completed Evaluation.')

    coord.request_stop()
    try:
        coord.join(threads, stop_grace_period_secs=3)
    except RuntimeError as e:
        log.warn(str(e))
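# --------------------------------------------------------------------------
# Hedged sketch (not part of the repo): the prediction hdf5 written above
# stores, per program id, the keys 'program_prediction', 'program_syntax',
# 'greedy_prediction', 'greedy_syntax', and the execution-correctness fields.
# It could be inspected offline roughly as follows (the file name here is
# illustrative only):
#
#   import h5py
#   with h5py.File('out_model-10000_test.hdf5', 'r') as f:
#       for program_id in f:
#           grp = f[program_id]
#           print(program_id, grp['greedy_syntax'][()],
#                 grp['greedy_prediction'][()])
# --------------------------------------------------------------------------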
def __init__(self, config, dataset):
    self.config = config
    self.dataset_split = config.dataset_split
    self.train_dir = config.train_dir
    self.output_dir = getattr(config, 'output_dir',
                              config.train_dir) or self.train_dir
    log.info("self.train_dir = %s", self.train_dir)

    # --- input ops ---
    self.batch_size = config.batch_size

    if config.dataset_type == 'karel':
        from karel_env.input_ops_karel import create_input_ops
    elif config.dataset_type == 'vizdoom':
        from vizdoom_env.input_ops_vizdoom import create_input_ops
    else:
        raise NotImplementedError(
            "The dataset related code is not implemented.")

    self.dataset = dataset

    _, self.batch = create_input_ops(dataset, self.batch_size,
                                     is_training=False, shuffle=False)

    # --- create model ---
    Model = self.get_model_class(config.model)
    log.infov("Using Model class: %s", Model)
    self.model = Model(config, is_train=False)

    self.global_step = tf.contrib.framework.get_or_create_global_step(
        graph=None)
    self.step_op = tf.no_op(name='step_no_op')

    # --- vars ---
    all_vars = tf.trainable_variables()
    log.warn("********* var ********** ")
    slim.model_analyzer.analyze_vars(all_vars, print_info=True)

    tf.set_random_seed(123)

    session_config = tf.ConfigProto(
        allow_soft_placement=True,
        gpu_options=tf.GPUOptions(allow_growth=True),
        device_count={'GPU': 1},
    )
    self.session = tf.Session(config=session_config)

    # --- checkpoint and monitoring ---
    self.saver = tf.train.Saver(max_to_keep=100)

    self.checkpoint = config.checkpoint
    if self.checkpoint == '' and self.train_dir:
        self.checkpoint = tf.train.latest_checkpoint(self.train_dir)
    if not self.checkpoint:
        log.warn("No checkpoint is given. Just random initialization :-)")
        self.session.run(tf.global_variables_initializer())
    else:
        self.checkpoint_name = os.path.basename(self.checkpoint)
        log.info("Checkpoint path : %s", self.checkpoint)
        self.config.summary_file = \
            self.checkpoint + '_report_testdata{}_num_k{}.txt'.format(
                self.config.max_steps * self.config.batch_size,
                self.config.num_k)
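# --------------------------------------------------------------------------
# Note (an assumption about TensorFlow behavior, not code from this repo):
# when config.checkpoint is empty, tf.train.latest_checkpoint reads the
# `checkpoint` index file in train_dir and returns the newest model path,
# or None if no checkpoint exists. The directory name below is illustrative:
#
#   ckpt = tf.train.latest_checkpoint('./train_dir/karel-example-run')
#   # -> './train_dir/karel-example-run/model-100000' or None
# --------------------------------------------------------------------------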