def create_data_stream(args):
    print(args)
    sw = StopWatch()

    if not args.no_copy:
        with sw:
            print('Copying data to local machine...')
            rsync = Rsync(args.tmpdir)
            rsync.sync(args.data_path)
            args.data_path = os.path.join(args.tmpdir,
                                          os.path.basename(args.data_path))

    return fuel_utils.get_datastream(path=args.data_path,
                                     which_set=args.dataset,
                                     batch_size=args.batch_size)
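# A minimal usage sketch for create_data_stream. Assumed context: StopWatch,
# Rsync, and fuel_utils are this repo's own helpers; all flag values below are
# hypothetical, not defaults read from the real argument parser.
#
#     from argparse import Namespace
#
#     args = Namespace(no_copy=True,           # skip the rsync staging step
#                      tmpdir='/tmp/feats',    # hypothetical scratch dir
#                      data_path='data.h5',    # hypothetical Fuel HDF5 file
#                      dataset='train_si84',   # hypothetical split name
#                      batch_size=16)
#     stream = create_data_stream(args)
#     for batch in stream.get_epoch_iterator():
#         pass  # consume one epoch of batches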
###################
# get reload path #
###################
if not args.reload_model:
    reload_path = args.save_path + '_last_model.pkl'
    if os.path.exists(reload_path):
        print('Previously trained model detected: {}'.format(reload_path))
        print('Training continues')
        args.reload_model = reload_path

##############
# print args #
##############
print(args)

sw = StopWatch()
if not args.no_copy:
    print('Loading data streams from {}'.format(args.data_path))
    print('Copying data to local machine...')
    rsync = Rsync(args.tmpdir)
    rsync.sync(args.data_path)
    args.data_path = os.path.join(args.tmpdir,
                                  os.path.basename(args.data_path))
    sw.print_elapsed()

####################
# load data stream #
####################
train_datastream = get_datastream(path=args.data_path,
                                  which_set=args.train_dataset,
def main(_):
    print(' '.join(sys.argv))
    args = FLAGS
    print(args.__flags)

    if not args.start_from_ckpt:
        if tf.gfile.Exists(args.log_dir):
            tf.gfile.DeleteRecursively(args.log_dir)
        tf.gfile.MakeDirs(args.log_dir)

    tf.get_variable_scope()._reuse = None

    _seed = args.base_seed + args.add_seed
    tf.set_random_seed(_seed)
    np.random.seed(_seed)

    prefix_name = os.path.join(args.log_dir, 'model')
    file_name = '%s.npz' % prefix_name

    eval_summary = OrderedDict()

    tg = build_graph(args)
    tg_ml_cost = tf.reduce_mean(tg.ml_cost)

    global_step = tf.Variable(0, trainable=False, name="global_step")

    tvars = tf.trainable_variables()

    ml_opt_func = tf.train.AdamOptimizer(learning_rate=args.learning_rate,
                                         beta1=0.9, beta2=0.99)

    if args.grad_clip:
        ml_grads, _ = tf.clip_by_global_norm(tf.gradients(tg_ml_cost, tvars),
                                             clip_norm=1.0)
    else:
        ml_grads = tf.gradients(tg_ml_cost, tvars)

    ml_op = ml_opt_func.apply_gradients(zip(ml_grads, tvars),
                                        global_step=global_step)

    sync_data(args)
    datasets = [args.train_dataset, args.valid_dataset, args.test_dataset]
    train_set, valid_set, test_set = [
        create_ivector_datastream(path=args.data_path, which_set=dataset,
                                  batch_size=args.batch_size,
                                  min_after_cache=args.min_after_cache,
                                  length_sort=not args.no_length_sort)
        for dataset in datasets
    ]

    init_op = tf.global_variables_initializer()
    save_op = tf.train.Saver(max_to_keep=5)
    best_save_op = tf.train.Saver(max_to_keep=5)

    with tf.name_scope("per_step_eval"):
        tr_ce = tf.placeholder(tf.float32)
        tr_ce_summary = tf.summary.scalar("tr_ce", tr_ce)

    with tf.name_scope("per_epoch_eval"):
        best_val_ce = tf.placeholder(tf.float32)
        val_ce = tf.placeholder(tf.float32)
        best_val_ce_summary = tf.summary.scalar("best_valid_ce", best_val_ce)
        val_ce_summary = tf.summary.scalar("valid_ce", val_ce)

    with tf.Session() as sess:
        sess.run(init_op)

        if args.start_from_ckpt:
            save_op = tf.train.import_meta_graph(
                os.path.join(args.log_dir, 'model.ckpt.meta'))
            save_op.restore(sess, os.path.join(args.log_dir, 'model.ckpt'))
            print("Restore from the last checkpoint. "
                  "Restarting from %d step." % global_step.eval())

        summary_writer = tf.summary.FileWriter(args.log_dir, sess.graph,
                                               flush_secs=5.0)

        tr_ce_sum = 0.
        tr_ce_count = 0
        tr_acc_sum = 0
        tr_acc_count = 0
        _best_score = np.iinfo(np.int32).max

        epoch_sw = StopWatch()
        disp_sw = StopWatch()
        eval_sw = StopWatch()
        per_sw = StopWatch()

        # For each epoch
        for _epoch in xrange(args.n_epoch):
            _n_exp = 0

            epoch_sw.reset()
            disp_sw.reset()

            print('--')
            print('Epoch {} training'.format(_epoch + 1))

            # For each batch
            for batch in train_set.get_epoch_iterator():
                orig_x, orig_x_mask, _, _, orig_y, _ = batch

                # Get skipped frames
                for sub_batch in skip_frames_fixed([orig_x, orig_x_mask, orig_y],
                                                   args.n_skip + 1):
                    x, x_mask, y = sub_batch
                    n_batch, _, _ = x.shape

                    _feed_states = initial_states(n_batch, args.n_hidden)

                    _tr_ml_cost, _seq_logit, _ = sess.run(
                        [tg.ml_cost, tg.seq_logit, ml_op],
                        feed_dict={tg.seq_x_data: x,
                                   tg.seq_x_mask: x_mask,
                                   tg.seq_y_data: y,
                                   tg.init_state: _feed_states})

                    tr_ce_sum += _tr_ml_cost.sum()
                    tr_ce_count += x_mask.sum()

                    _tr_ce_summary, = sess.run(
                        [tr_ce_summary],
                        feed_dict={tr_ce: _tr_ml_cost.sum() / x_mask.sum()})
                    summary_writer.add_summary(_tr_ce_summary,
                                               global_step.eval())

                    _, n_seq = orig_y.shape

                    _expand_seq_logit = interpolate_feat(_seq_logit,
                                                         num_skips=args.n_skip + 1,
                                                         axis=1,
                                                         use_bidir=True)
                    _pred_idx = _expand_seq_logit.argmax(axis=2)

                    tr_acc_sum += ((_pred_idx == orig_y) * orig_x_mask).sum()
                    tr_acc_count += orig_x_mask.sum()

                if global_step.eval() % args.display_freq == 0:
                    avg_tr_ce = tr_ce_sum / tr_ce_count
                    avg_tr_fer = 1. - float(tr_acc_sum) / tr_acc_count

                    print("TRAIN: epoch={} iter={} ml_cost(ce/frame)={:.2f} "
                          "fer={:.2f} time_taken={:.2f}".format(
                              _epoch, global_step.eval(), avg_tr_ce,
                              avg_tr_fer, disp_sw.elapsed()))

                    tr_ce_sum = 0.
                    tr_ce_count = 0
                    tr_acc_sum = 0
                    tr_acc_count = 0
                    disp_sw.reset()

            print('--')
            print('End of epoch {}'.format(_epoch + 1))
            epoch_sw.print_elapsed()

            print('Testing')

            # Evaluate the model on the validation set
            val_ce_sum = 0.
            val_ce_count = 0
            val_acc_sum = 0
            val_acc_count = 0
            eval_sw.reset()

            for batch in valid_set.get_epoch_iterator():
                orig_x, orig_x_mask, _, _, orig_y, _ = batch

                for sub_batch in skip_frames_fixed([orig_x, orig_x_mask, orig_y],
                                                   args.n_skip + 1,
                                                   return_first=True):
                    x, x_mask, y = sub_batch
                    n_batch, _, _ = x.shape

                    _feed_states = initial_states(n_batch, args.n_hidden)

                    _val_ml_cost, _seq_logit = sess.run(
                        [tg.ml_cost, tg.seq_logit],
                        feed_dict={tg.seq_x_data: x,
                                   tg.seq_x_mask: x_mask,
                                   tg.seq_y_data: y,
                                   tg.init_state: _feed_states})

                    val_ce_sum += _val_ml_cost.sum()
                    val_ce_count += x_mask.sum()

                    _, n_seq = orig_y.shape
                    _expand_seq_logit = interpolate_feat(_seq_logit,
                                                         num_skips=args.n_skip + 1,
                                                         axis=1,
                                                         use_bidir=True)
                    _pred_idx = _expand_seq_logit.argmax(axis=2)

                    val_acc_sum += ((_pred_idx == orig_y) * orig_x_mask).sum()
                    val_acc_count += orig_x_mask.sum()

            avg_val_ce = val_ce_sum / val_ce_count
            avg_val_fer = 1. - float(val_acc_sum) / val_acc_count

            print("VALID: epoch={} ml_cost(ce/frame)={:.2f} fer={:.2f} "
                  "time_taken={:.2f}".format(_epoch, avg_val_ce, avg_val_fer,
                                             eval_sw.elapsed()))

            _val_ce_summary, = sess.run([val_ce_summary],
                                        feed_dict={val_ce: avg_val_ce})
            summary_writer.add_summary(_val_ce_summary, global_step.eval())

            insert_item2dict(eval_summary, 'val_ce', avg_val_ce)
            insert_item2dict(eval_summary, 'time', eval_sw.elapsed())
            save_npz2(file_name, eval_summary)

            # Save model
            if avg_val_ce < _best_score:
                _best_score = avg_val_ce
                best_ckpt = best_save_op.save(
                    sess, os.path.join(args.log_dir, "best_model.ckpt"),
                    global_step=global_step)
                print("Best checkpoint stored in: %s" % best_ckpt)

            ckpt = save_op.save(sess, os.path.join(args.log_dir, "model.ckpt"),
                                global_step=global_step)
            print("Checkpoint stored in: %s" % ckpt)

            _best_val_ce_summary, = sess.run(
                [best_val_ce_summary],
                feed_dict={best_val_ce: _best_score})
            summary_writer.add_summary(_best_val_ce_summary,
                                       global_step.eval())

        summary_writer.close()

    print("Optimization Finished.")
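# The training loop above leans on two repo helpers, skip_frames_fixed and
# interpolate_feat, whose implementations are not shown in this file. A rough,
# hypothetical sketch of the assumed semantics (subsample every n-th frame,
# then expand the subsampled predictions back to the full frame rate):
#
#     import numpy as np
#
#     def skip_frames_fixed_sketch(arrays, every_n, return_first=False):
#         """Yield phase-shifted subsamplings [a[:, 0::n], a[:, 1::n], ...]."""
#         n_offsets = 1 if return_first else every_n
#         for offset in range(n_offsets):
#             yield [a[:, offset::every_n] for a in arrays]
#
#     def interpolate_feat_sketch(feat, num_skips, axis=1):
#         """Nearest-neighbour expansion back to the original frame rate."""
#         return np.repeat(feat, num_skips, axis=axis)
#
# The real interpolate_feat also takes use_bidir=True and presumably
# interpolates in both directions, which this sketch does not model.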
def main():
    print(' '.join(sys.argv))
    args = get_args()
    print(args)
    print('Hostname: {}'.format(socket.gethostname()))
    print('GPU: {}'.format(get_gpuname()))

    if not args.start_from_ckpt:
        if tf.gfile.Exists(args.log_dir):
            tf.gfile.DeleteRecursively(args.log_dir)
        tf.gfile.MakeDirs(args.log_dir)

    tf.get_variable_scope()._reuse = None

    _seed = args.base_seed + args.add_seed
    tf.set_random_seed(_seed)
    np.random.seed(_seed)

    prefix_name = os.path.join(args.log_dir, 'model')
    file_name = '%s.npz' % prefix_name

    eval_summary = OrderedDict()

    tg, sg = build_graph(args)
    tg_ml_cost = tf.reduce_mean(tg.ml_cost)

    global_step = tf.Variable(0, trainable=False, name="global_step")

    tvars = tf.trainable_variables()
    print([tvar.name for tvar in tvars])

    ml_opt_func = tf.train.AdamOptimizer(learning_rate=args.learning_rate,
                                         beta1=0.9, beta2=0.99)
    rl_opt_func = tf.train.AdamOptimizer(learning_rate=args.rl_learning_rate,
                                         beta1=0.9, beta2=0.99)

    if args.grad_clip:
        ml_grads, _ = tf.clip_by_global_norm(tf.gradients(tg_ml_cost, tvars),
                                             clip_norm=1.0)
    else:
        ml_grads = tf.gradients(tg_ml_cost, tvars)
    ml_op = ml_opt_func.apply_gradients(zip(ml_grads, tvars),
                                        global_step=global_step)

    tg_rl_cost = tf.reduce_mean(tg.rl_cost)
    rl_grads = tf.gradients(tg_rl_cost, tvars)
    # do not increase global step -- ml op increases it
    rl_op = rl_opt_func.apply_gradients(zip(rl_grads, tvars))

    tf.add_to_collection('n_fast_action', args.n_fast_action)

    sync_data(args)
    datasets = [args.train_dataset, args.valid_dataset, args.test_dataset]
    train_set, valid_set, test_set = [
        create_ivector_datastream(path=args.data_path, which_set=dataset,
                                  batch_size=args.n_batch,
                                  min_after_cache=args.min_after_cache,
                                  length_sort=not args.no_length_sort)
        for dataset in datasets
    ]

    init_op = tf.global_variables_initializer()
    save_op = tf.train.Saver(max_to_keep=5)
    best_save_op = tf.train.Saver(max_to_keep=5)

    with tf.name_scope("per_step_eval"):
        tr_ce = tf.placeholder(tf.float32)
        tr_ce_summary = tf.summary.scalar("tr_ce", tr_ce)
        tr_image = tf.placeholder(tf.float32)
        tr_image_summary = tf.summary.image("tr_image", tr_image)
        tr_fer = tf.placeholder(tf.float32)
        tr_fer_summary = tf.summary.scalar("tr_fer", tr_fer)
        tr_rl = tf.placeholder(tf.float32)
        tr_rl_summary = tf.summary.scalar("tr_rl", tr_rl)
        tr_rw_hist = tf.placeholder(tf.float32)
        tr_rw_hist_summary = tf.summary.histogram("tr_reward_hist", tr_rw_hist)

    with tf.name_scope("per_epoch_eval"):
        best_val_ce = tf.placeholder(tf.float32)
        val_ce = tf.placeholder(tf.float32)
        best_val_ce_summary = tf.summary.scalar("best_valid_ce", best_val_ce)
        val_ce_summary = tf.summary.scalar("valid_ce", val_ce)

    vf = LinearVF()

    with tf.Session() as sess:
        sess.run(init_op)

        if args.start_from_ckpt:
            save_op = tf.train.import_meta_graph(
                os.path.join(args.log_dir, 'model.ckpt.meta'))
            save_op.restore(sess, os.path.join(args.log_dir, 'model.ckpt'))
            print("Restore from the last checkpoint. Restarting from %d step."
                  % global_step.eval())

        summary_writer = tf.summary.FileWriter(args.log_dir, sess.graph,
                                               flush_secs=5.0)

        tr_ce_sum = 0.
        tr_ce_count = 0
        tr_acc_sum = 0
        tr_acc_count = 0
        tr_rl_costs = []
        tr_action_entropies = []
        tr_rewards = []

        _best_score = np.iinfo(np.int32).max

        epoch_sw = StopWatch()
        disp_sw = StopWatch()
        eval_sw = StopWatch()
        per_sw = StopWatch()

        # For each epoch
        for _epoch in xrange(args.n_epoch):
            _n_exp = 0

            epoch_sw.reset()
            disp_sw.reset()

            print('Epoch {} training'.format(_epoch + 1))

            # For each batch
            for batch in train_set.get_epoch_iterator():
                x, x_mask, _, _, y, _ = batch
                n_batch = x.shape[0]
                _n_exp += n_batch

                new_x, new_y, actions_1hot, rewards, action_entropies, \
                    new_x_mask, new_reward_mask, output_image = \
                    gen_episode_with_seg_reward(x, x_mask, y, sess, sg, args)

                advantages, _ = compute_advantage(new_x, new_x_mask, rewards,
                                                  new_reward_mask, vf, args)

                zero_state = gen_zero_state(n_batch, args.n_hidden)

                feed_dict = {tg.seq_x_data: new_x,
                             tg.seq_x_mask: new_x_mask,
                             tg.seq_y_data: new_y,
                             tg.seq_action: actions_1hot,
                             tg.seq_advantage: advantages,
                             tg.seq_action_mask: new_reward_mask,
                             tg.seq_y_data_for_action: new_y}
                feed_init_state(feed_dict, tg.init_state, zero_state)

                _tr_ml_cost, _tr_rl_cost, _, _, pred_idx = \
                    sess.run([tg.ml_cost, tg.rl_cost, ml_op, rl_op,
                              tg.pred_idx],
                             feed_dict=feed_dict)

                tr_ce_sum += _tr_ml_cost.sum()
                tr_ce_count += new_x_mask.sum()

                pred_idx = expand_output(actions_1hot, x_mask, new_x_mask,
                                         pred_idx.reshape([n_batch, -1]),
                                         args.n_fast_action)
                tr_acc_sum += ((pred_idx == y) * x_mask).sum()
                tr_acc_count += x_mask.sum()

                _tr_ce_summary, _tr_fer_summary, _tr_rl_summary, \
                    _tr_image_summary, _tr_rw_hist_summary = \
                    sess.run([tr_ce_summary, tr_fer_summary, tr_rl_summary,
                              tr_image_summary, tr_rw_hist_summary],
                             feed_dict={
                                 tr_ce: _tr_ml_cost.sum() / new_x_mask.sum(),
                                 tr_fer: ((pred_idx == y) * x_mask).sum() / x_mask.sum(),
                                 tr_rl: _tr_rl_cost.sum() / new_reward_mask.sum(),
                                 tr_image: output_image,
                                 tr_rw_hist: rewards})
                summary_writer.add_summary(_tr_ce_summary, global_step.eval())
                summary_writer.add_summary(_tr_fer_summary, global_step.eval())
                summary_writer.add_summary(_tr_rl_summary, global_step.eval())
                summary_writer.add_summary(_tr_image_summary, global_step.eval())
                summary_writer.add_summary(_tr_rw_hist_summary, global_step.eval())

                tr_rl_costs.append(_tr_rl_cost.sum() / new_reward_mask.sum())
                tr_action_entropies.append(
                    action_entropies.sum() / new_reward_mask.sum())
                tr_rewards.append(rewards.sum() / new_reward_mask.sum())

                if global_step.eval() % args.display_freq == 0:
                    avg_tr_ce = tr_ce_sum / tr_ce_count
                    avg_tr_fer = 1. - tr_acc_sum / tr_acc_count
                    avg_tr_rl_cost = np.asarray(tr_rl_costs).mean()
                    avg_tr_action_entropy = np.asarray(tr_action_entropies).mean()
                    avg_tr_reward = np.asarray(tr_rewards).mean()

                    print("TRAIN: epoch={} iter={} ml_cost(ce/frame)={:.2f} "
                          "fer={:.2f} rl_cost={:.4f} reward={:.4f} "
                          "action_entropy={:.2f} time_taken={:.2f}".format(
                              _epoch, global_step.eval(), avg_tr_ce,
                              avg_tr_fer, avg_tr_rl_cost, avg_tr_reward,
                              avg_tr_action_entropy, disp_sw.elapsed()))

                    tr_ce_sum = 0.
                    tr_ce_count = 0
                    tr_acc_sum = 0.
                    tr_acc_count = 0
                    tr_rl_costs = []
                    tr_action_entropies = []
                    tr_rewards = []
                    disp_sw.reset()

            print('--')
            print('End of epoch {}'.format(_epoch + 1))
            epoch_sw.print_elapsed()

            print('Testing')

            # Evaluate the model on the validation set
            val_ce_sum = 0.
            val_ce_count = 0
            val_acc_sum = 0
            val_acc_count = 0
            val_rl_costs = []
            val_action_entropies = []
            val_rewards = []
            eval_sw.reset()

            for batch in valid_set.get_epoch_iterator():
                x, x_mask, _, _, y, _ = batch
                n_batch = x.shape[0]

                new_x, new_y, actions_1hot, rewards, action_entropies, \
                    new_x_mask, new_reward_mask, output_image, new_y_sample = \
                    gen_episode_with_seg_reward(x, x_mask, y, sess, sg, args,
                                                sample_y=True)

                advantages, _ = compute_advantage(new_x, new_x_mask, rewards,
                                                  new_reward_mask, vf, args)

                zero_state = gen_zero_state(n_batch, args.n_hidden)

                feed_dict = {tg.seq_x_data: new_x,
                             tg.seq_x_mask: new_x_mask,
                             tg.seq_y_data: new_y,
                             tg.seq_action: actions_1hot,
                             tg.seq_advantage: advantages,
                             tg.seq_action_mask: new_reward_mask,
                             tg.seq_y_data_for_action: new_y_sample}
                feed_init_state(feed_dict, tg.init_state, zero_state)

                _val_ml_cost, _val_rl_cost, pred_idx = sess.run(
                    [tg.ml_cost, tg.rl_cost, tg.pred_idx],
                    feed_dict=feed_dict)

                val_ce_sum += _val_ml_cost.sum()
                val_ce_count += new_x_mask.sum()

                pred_idx = expand_output(actions_1hot, x_mask, new_x_mask,
                                         pred_idx.reshape([n_batch, -1]),
                                         args.n_fast_action)
                val_acc_sum += ((pred_idx == y) * x_mask).sum()
                val_acc_count += x_mask.sum()

                val_rl_costs.append(_val_rl_cost.sum() / new_reward_mask.sum())
                val_action_entropies.append(
                    action_entropies.sum() / new_reward_mask.sum())
                val_rewards.append(rewards.sum() / new_reward_mask.sum())

            avg_val_ce = val_ce_sum / val_ce_count
            avg_val_fer = 1. - val_acc_sum / val_acc_count
            avg_val_rl_cost = np.asarray(val_rl_costs).mean()
            avg_val_action_entropy = np.asarray(val_action_entropies).mean()
            avg_val_reward = np.asarray(val_rewards).mean()

            print("VALID: epoch={} ml_cost(ce/frame)={:.2f} fer={:.2f} "
                  "rl_cost={:.4f} reward={:.4f} action_entropy={:.2f} "
                  "time_taken={:.2f}".format(
                      _epoch, avg_val_ce, avg_val_fer, avg_val_rl_cost,
                      avg_val_reward, avg_val_action_entropy,
                      eval_sw.elapsed()))

            _val_ce_summary, = sess.run([val_ce_summary],
                                        feed_dict={val_ce: avg_val_ce})
            summary_writer.add_summary(_val_ce_summary, global_step.eval())

            insert_item2dict(eval_summary, 'val_ce', avg_val_ce)
            insert_item2dict(eval_summary, 'val_rl_cost', avg_val_rl_cost)
            insert_item2dict(eval_summary, 'val_reward', avg_val_reward)
            insert_item2dict(eval_summary, 'val_action_entropy',
                             avg_val_action_entropy)
            insert_item2dict(eval_summary, 'time', eval_sw.elapsed())
            save_npz2(file_name, eval_summary)

            # Save model
            if avg_val_ce < _best_score:
                _best_score = avg_val_ce
                best_ckpt = best_save_op.save(
                    sess, os.path.join(args.log_dir, "best_model.ckpt"),
                    global_step=global_step)
                print("Best checkpoint stored in: %s" % best_ckpt)

            ckpt = save_op.save(sess, os.path.join(args.log_dir, "model.ckpt"),
                                global_step=global_step)
            print("Checkpoint stored in: %s" % ckpt)

            _best_val_ce_summary, = sess.run(
                [best_val_ce_summary],
                feed_dict={best_val_ce: _best_score})
            summary_writer.add_summary(_best_val_ce_summary,
                                       global_step.eval())

        summary_writer.close()

    print("Optimization Finished.")
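# compute_advantage above consumes a LinearVF baseline fitted on hidden
# states. A rough, hypothetical sketch of such a linear value-function
# baseline (ridge-regularised least squares on flattened features; the real
# LinearVF in this repo may differ in features and regularisation):
#
#     import numpy as np
#
#     class LinearVFSketch(object):
#         def __init__(self):
#             self.w = None
#
#         def fit(self, feats, returns):
#             # w = (F'F + eps I)^-1 F'r
#             f = feats.reshape(-1, feats.shape[-1])
#             r = returns.reshape(-1)
#             a = f.T.dot(f) + 1e-5 * np.eye(f.shape[1])
#             self.w = np.linalg.solve(a, f.T.dot(r))
#
#         def predict(self, feats):
#             if self.w is None:
#                 return np.zeros(feats.shape[:-1])
#             flat = feats.reshape(-1, feats.shape[-1]).dot(self.w)
#             return flat.reshape(feats.shape[:-1])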
with tf.name_scope("val_eval"): val_summary = utils.get_summary('ce cr fer image'.split()) with tf.Session() as sess: sess.run(init_op) summary_writer = tf.summary.FileWriter(args.logdir, sess.graph, flush_secs=5.0) # ce, accuracy, compression ratio accu_list = [Accumulator() for i in range(3)] ce, ac, cr = accu_list _best_score = np.iinfo(np.int32).max epoch_sw, disp_sw, eval_sw = StopWatch(), StopWatch(), StopWatch() # For each epoch for _epoch in range(1, args.n_epoch + 1): epoch_sw.reset() disp_sw.reset() print('--') print('Epoch {} training'.format(_epoch)) for accu in accu_list: accu.reset() # For each batch for batch in train_set.get_epoch_iterator(): orig_x, orig_x_mask, _, _, orig_y, _ = batch
last_hstates = [op.outputs[0] for op in sess.graph.get_operations()
                if match_h(op.name)]
for c, h in zip(last_cstates, last_hstates):
    step_last_state.append(tf.contrib.rnn.LSTMStateTuple(c, h))

test_graph = TestGraph(step_last_state, _step_label_probs, step_action_probs,
                       step_action_samples, step_pred_idx, step_x_data,
                       init_state)

writer = kaldi_io.BaseFloatMatrixWriter(args.wxfilename)

print('Computing label probs...', file=sys.stderr)
sw = StopWatch()

for bidx, (batch, uttid_batch) in enumerate(
        zip(test_set.get_epoch_iterator(),
            uttid_stream.get_epoch_iterator())):
    orig_x, orig_x_mask, _, _ = batch
    uttid_batch, = uttid_batch

    feat_lens = orig_x_mask.sum(axis=1, dtype=np.int32)

    actions_1hot, label_probs, new_mask = skip_rnn_forward_supervised(
        orig_x, orig_x_mask, sess, test_graph, n_fast_action)

    seq_label_probs = expand_output(actions_1hot, orig_x_mask, new_mask,
                                    label_probs)

    for out_idx, (output, uttid) in enumerate(zip(seq_label_probs,
parser = get_arg_parser()
args = parser.parse_args()
args.save_path = os.path.join(args.log_dir, get_save_path(args))
print(args)

if not os.path.exists(args.log_dir):
    os.makedirs(args.log_dir)

args.file_name = '{}.npz'.format(args.save_path)
args.best_file_name = '{}.best.npz'.format(args.save_path)
args.opt_file_name = '{}.grads.npz'.format(args.save_path)
args.best_opt_file_name = '{}.best.grads.npz'.format(args.save_path)

sw = StopWatch()
print('Building and compiling the network')
f_prop, f_update, f_log_prob, f_debug, tparams, opt_tparams, \
    states, st_slope = build_graph_am(args)
sw.print_elapsed()
sw.reset()

summary = OrderedDict()
if args.start_from_ckpt and os.path.exists(args.file_name):
    tparams = init_tparams_with_restored_value(tparams, args.file_name)
    model = numpy.load(args.file_name)
    for k, v in model.items():
        if 'summary' in k:
def train_model():
    sw = StopWatch()

    # Fix random seeds
    rand_seed = FLAGS.base_seed + FLAGS.add_seed
    tf.set_random_seed(rand_seed)
    np.random.seed(rand_seed)

    # Get model graph
    model_graph = build_graph(FLAGS)

    # Get model parameters
    model_param = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

    # Set weight decay
    if FLAGS.weight_decay > 0.0:
        l2_cost = tf.add_n([0.5 * tf.nn.l2_loss(W) for W in model_param
                            if 'W' in W.name and 'action' not in W.name
                            and 'baseline' not in W.name])
        l2_cost *= FLAGS.weight_decay
    else:
        l2_cost = 0.0

    # Set total cost
    model_total_cost = (model_graph.ml_cost + model_graph.rl_cost +
                        model_graph.bl_cost + l2_cost)

    # Define global training step
    global_step = tf.contrib.framework.get_or_create_global_step()

    # Set ML optimizer (Adam; the original paper uses 0.99 for beta2)
    model_opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate,
                                       name='model_optimizer')
    model_grad = tf.gradients(ys=model_total_cost, xs=model_param,
                              aggregation_method=2)

    # Set gradient clipping
    if FLAGS.grad_clip > 0.0:
        model_grad, _ = tf.clip_by_global_norm(t_list=model_grad,
                                               clip_norm=FLAGS.grad_clip)

    model_update = model_opt.apply_gradients(
        grads_and_vars=zip(model_grad, model_param),
        global_step=global_step)

    # Set dataset
    sync_data(FLAGS)
    datasets = [FLAGS.train_dataset, FLAGS.valid_dataset, FLAGS.test_dataset]
    train_set, valid_set, test_set = [
        create_ivector_datastream(path=FLAGS.data_path, which_set=dataset,
                                  batch_size=FLAGS.batch_size)
        for dataset in datasets
    ]

    # Set variable initializer
    init_op = tf.global_variables_initializer()

    # Set last model saver
    last_save_op = tf.train.Saver(max_to_keep=5)

    # Set best model saver
    best_save_op = tf.train.Saver(max_to_keep=5)

    # Get hardware config
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # Set session
    with tf.Session(config=config) as sess:
        # Get summary
        merged_summary = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                             sess.graph)

        # Init model
        sess.run(init_op)

        # Load checkpoint
        if FLAGS.start_from_ckpt:
            last_save_op = tf.train.import_meta_graph(
                os.path.join(FLAGS.log_dir, 'last_model.ckpt.meta'))
            last_save_op.restore(
                sess, os.path.join(FLAGS.log_dir, 'last_model.ckpt'))
            print("Restore from the last checkpoint. Restarting from %d step."
                  % global_step.eval())

        accr_history = []
        loss_history = []
        ml_cost_history = []
        rl_cost_history = []
        bl_cost_history = []
        sum_cost_history = []
        best_accr = 0.0

        sw.reset()

        # For each epoch
        for e_idx in xrange(FLAGS.n_epoch):
            # For each batch (update)
            for b_idx, batch_data in enumerate(train_set.get_epoch_iterator()):
                # Get data x, y
                x_data, x_mask, _, _, y_data, _ = batch_data

                # Roll axis
                x_data = x_data.transpose((1, 0, 2))
                x_mask = x_mask.transpose((1, 0))
                y_data = y_data.transpose((1, 0))

                # Update model
                mean_accr, mean_loss, ml_cost, rl_cost, bl_cost, read_ratio, \
                    summary_output = updater(model_graph=model_graph,
                                             model_updater=model_update,
                                             x_data=x_data,
                                             x_mask=x_mask,
                                             y_data=y_data,
                                             summary=merged_summary,
                                             session=sess)

                # Write summary
                train_writer.add_summary(summary_output, global_step.eval())

                accr_history.append(mean_accr)
                loss_history.append(mean_loss)
                ml_cost_history.append(ml_cost)
                rl_cost_history.append(rl_cost)
                bl_cost_history.append(bl_cost)
                sum_cost_history.append(ml_cost + rl_cost + bl_cost)

                # Display results
                if global_step.eval() % FLAGS.display_freq == 0:
                    mean_accr = np.array(accr_history).mean()
                    mean_loss = np.array(loss_history).mean()
                    mean_ml_cost = np.array(ml_cost_history).mean()
                    mean_rl_cost = np.array(rl_cost_history).mean()
                    mean_bl_cost = np.array(bl_cost_history).mean()
                    mean_sum_cost = np.array(sum_cost_history).mean()

                    print("====================================================")
                    print("Epoch " + str(e_idx) +
                          ", Total Iter " + str(global_step.eval()))
                    print("----------------------------------------------------")
                    print("Average FER: {:.2f}%".format((1.0 - mean_accr) * 100))
                    print("Average CCE: {:.6f}".format(mean_loss))
                    print("Average ML: {:.6f}".format(mean_ml_cost))
                    if FLAGS.use_skim:
                        print("Average RL: {:.6f}".format(mean_rl_cost))
                        print("Average BL: {:.6f}".format(mean_bl_cost))
                        print("Average SUM: {:.6f}".format(mean_sum_cost))
                        print("Read ratio: ", read_ratio)
                    sw.print_elapsed()
                    sw.reset()

                    last_ckpt = last_save_op.save(
                        sess, os.path.join(FLAGS.log_dir, "last_model.ckpt"),
                        global_step=global_step)
                    print("Last checkpoint stored in: %s" % last_ckpt)

                    accr_history = []
                    loss_history = []
                    ml_cost_history = []
                    rl_cost_history = []
                    bl_cost_history = []
                    sum_cost_history = []

                # Evaluate model
                if global_step.eval() % FLAGS.evaluation_freq == 0:
                    # Monitor validation loss, accr
                    valid_accr, valid_cce = evaluation(model_graph=model_graph,
                                                       session=sess,
                                                       dataset=valid_set)

                    # Save model
                    if best_accr < valid_accr:
                        best_accr = valid_accr
                        best_ckpt = best_save_op.save(
                            sess,
                            os.path.join(FLAGS.log_dir, "best_model.ckpt"),
                            global_step=global_step)
                        print("Best checkpoint stored in: %s" % best_ckpt)

                    print("----------------------------------------------------")
                    print("Validation evaluation")
                    print("----------------------------------------------------")
                    print("FER: {:.2f}%".format((1.0 - valid_accr) * 100.))
                    print("CCE: {:.6f}".format(valid_cce))
                    print("----------------------------------------------------")
                    print("Best FER: {:.2f}%".format((1.0 - best_accr) * 100.))

    print("Optimization Finished.")
def main(_):
    # Print settings
    print(' '.join(sys.argv))
    args = FLAGS
    for k, v in args.__flags.iteritems():
        print(k, v)

    # Reset log dir unless restarting from a checkpoint
    if not args.start_from_ckpt:
        if tf.gfile.Exists(args.log_dir):
            tf.gfile.DeleteRecursively(args.log_dir)
        tf.gfile.MakeDirs(args.log_dir)

    # Reset variable scope reuse (touches a private TF attribute)
    tf.get_variable_scope()._reuse = None

    # Set random seed
    _seed = args.base_seed + args.add_seed
    tf.set_random_seed(_seed)
    np.random.seed(_seed)

    # Set save file name
    prefix_name = os.path.join(args.log_dir, 'model')
    file_name = '%s.npz' % prefix_name

    # Set evaluation summary
    eval_summary = OrderedDict()

    # Build model graph
    tg, sg = build_graph(args)

    # Set linear regressor for baseline
    vf = LinearVF()

    # Set global step
    global_step = tf.Variable(0, trainable=False, name="global_step")

    # Get ml/rl related parameters
    tvars = tf.trainable_variables()
    ml_vars = [tvar for tvar in tvars if "action" not in tvar.name]
    rl_vars = [tvar for tvar in tvars if "action" in tvar.name]

    # Set optimizer
    ml_opt_func = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    rl_opt_func = tf.train.AdamOptimizer(learning_rate=args.rl_learning_rate)

    # Set model ml cost (sum over all and divide it by batch_size)
    ml_cost = tf.reduce_sum(tg.seq_ml_cost)
    ml_cost /= tf.to_float(tf.shape(tg.seq_x_data)[0])
    ml_cost += args.ml_l2 * 0.5 * tf.add_n(
        [tf.reduce_sum(tf.square(var)) for var in ml_vars])

    # Set model rl cost (sum over all and divide it by batch_size,
    # also entropy cost)
    rl_cost = tg.seq_rl_cost - args.ent_weight * tg.seq_a_ent
    rl_cost = tf.reduce_sum(rl_cost)
    rl_cost /= tf.to_float(tf.shape(tg.seq_x_data)[0])

    # Set real rl cost for monitoring (sum over all, divided by the
    # number of actions)
    real_rl_cost = tf.reduce_sum(tg.seq_real_rl_cost)
    real_rl_cost /= tf.reduce_sum(tg.seq_a_mask)

    # Gradient clipping for ML
    ml_grads = tf.gradients(ml_cost, ml_vars)
    if args.grad_clip > 0.0:
        ml_grads, _ = tf.clip_by_global_norm(t_list=ml_grads,
                                             clip_norm=args.grad_clip)

    # Gradient for RL
    rl_grads = tf.gradients(rl_cost, rl_vars)

    # ML optimization
    ml_op = ml_opt_func.apply_gradients(grads_and_vars=zip(ml_grads, ml_vars),
                                        global_step=global_step,
                                        name='ml_op')

    # RL optimization
    rl_op = rl_opt_func.apply_gradients(grads_and_vars=zip(rl_grads, rl_vars),
                                        global_step=global_step,
                                        name='rl_op')

    # Sync dataset
    sync_data(args)

    # Get dataset
    train_set = create_ivector_datastream(
        path=args.data_path, which_set=args.train_dataset,
        batch_size=args.batch_size, min_after_cache=args.min_after_cache,
        length_sort=not args.no_length_sort)
    valid_set = create_ivector_datastream(
        path=args.data_path, which_set=args.valid_dataset,
        batch_size=args.batch_size, min_after_cache=args.min_after_cache,
        length_sort=not args.no_length_sort)

    # Set param init op
    init_op = tf.global_variables_initializer()

    # Set save ops
    save_op = tf.train.Saver(max_to_keep=5)
    best_save_op = tf.train.Saver(max_to_keep=5)

    # Set per-step logging
    with tf.name_scope("per_step_eval"):
        # For ML cost (ce)
        tr_ce = tf.placeholder(tf.float32)
        tr_ce_summary = tf.summary.scalar("train_ml_cost", tr_ce)
        # For output visualization
        tr_image = tf.placeholder(tf.float32)
        tr_image_summary = tf.summary.image("train_image", tr_image)
        # For ML FER
        tr_fer = tf.placeholder(tf.float32)
        tr_fer_summary = tf.summary.scalar("train_fer", tr_fer)
        # For RL cost
        tr_rl = tf.placeholder(tf.float32)
        tr_rl_summary = tf.summary.scalar("train_rl", tr_rl)
        # For RL reward
        tr_reward = tf.placeholder(tf.float32)
        tr_reward_summary = tf.summary.scalar("train_reward", tr_reward)
        # For RL entropy
        tr_ent = tf.placeholder(tf.float32)
        tr_ent_summary = tf.summary.scalar("train_entropy", tr_ent)
        # For RL reward histogram
        tr_rw_hist = tf.placeholder(tf.float32)
        tr_rw_hist_summary = tf.summary.histogram("train_reward_hist",
                                                  tr_rw_hist)
        # For RL skip count
        tr_skip_cnt = tf.placeholder(tf.float32)
        tr_skip_cnt_summary = tf.summary.scalar("train_skip_cnt", tr_skip_cnt)

    # Set per-epoch logging
    with tf.name_scope("per_epoch_eval"):
        # For best valid ML cost (full)
        best_val_ce = tf.placeholder(tf.float32)
        best_val_ce_summary = tf.summary.scalar("best_valid_ce", best_val_ce)
        # For best valid FER
        best_val_fer = tf.placeholder(tf.float32)
        best_val_fer_summary = tf.summary.scalar("best_valid_fer",
                                                 best_val_fer)
        # For valid ML cost (full)
        val_ce = tf.placeholder(tf.float32)
        val_ce_summary = tf.summary.scalar("valid_ce", val_ce)
        # For valid FER
        val_fer = tf.placeholder(tf.float32)
        val_fer_summary = tf.summary.scalar("valid_fer", val_fer)
        # For output visualization
        val_image = tf.placeholder(tf.float32)
        val_image_summary = tf.summary.image("valid_image", val_image)
        # For RL skip count
        val_skip_cnt = tf.placeholder(tf.float32)
        val_skip_cnt_summary = tf.summary.scalar("valid_skip_cnt",
                                                 val_skip_cnt)

    # Set episode-generation module
    gen_episodes = improve_skip_rnn_act_parallel

    # Init session
    with tf.Session() as sess:
        # Init model
        sess.run(init_op)

        # Load from checkpoint
        if args.start_from_ckpt:
            save_op = tf.train.import_meta_graph(
                os.path.join(args.log_dir, 'model.ckpt.meta'))
            save_op.restore(sess, os.path.join(args.log_dir, 'model.ckpt'))
            print("Restore from the last checkpoint. Restarting from %d step."
                  % global_step.eval())

        # Summary writer
        summary_writer = tf.summary.FileWriter(args.log_dir, sess.graph,
                                               flush_secs=5.0)

        # For train tracking
        tr_ce_sum = 0.
        tr_ce_count = 0
        tr_acc_sum = 0.
        tr_acc_count = 0
        tr_rl_sum = 0.
        tr_rl_count = 0
        tr_ent_sum = 0.
        tr_ent_count = 0
        tr_reward_sum = 0.
        tr_reward_count = 0
        tr_skip_sum = 0.
        tr_skip_count = 0

        _best_ce = np.iinfo(np.int32).max
        _best_fer = 1.00

        # For time measure
        epoch_sw = StopWatch()
        disp_sw = StopWatch()
        eval_sw = StopWatch()

        # For each epoch
        for _epoch in xrange(args.n_epoch):
            # Reset timer
            epoch_sw.reset()
            disp_sw.reset()

            print('Epoch {} training'.format(_epoch + 1))

            # Set rl skipping flag
            use_rl_skipping = True

            # For each batch (update)
            for batch_data in train_set.get_epoch_iterator():
                ##################
                # Sampling Phase #
                ##################
                # Get batch data
                seq_x_data, seq_x_mask, _, _, seq_y_data, _ = batch_data

                # Use skipping
                if use_rl_skipping:
                    # Transpose axis
                    seq_x_data = np.transpose(seq_x_data, (1, 0, 2))
                    seq_x_mask = np.transpose(seq_x_mask, (1, 0))
                    seq_y_data = np.transpose(seq_y_data, (1, 0))

                    # Number of samples
                    batch_size = seq_x_data.shape[1]

                    # Sample actions (episode generation)
                    [skip_x_data, skip_h_data, skip_x_mask, skip_y_data,
                     skip_a_data, skip_a_mask, skip_rewards, result_image] = \
                        gen_episodes(seq_x_data=seq_x_data,
                                     seq_x_mask=seq_x_mask,
                                     seq_y_data=seq_y_data,
                                     sess=sess,
                                     sample_graph=sg,
                                     args=args,
                                     use_sampling=True)

                    # Compute skip ratio
                    tr_skip_sum += skip_x_mask.sum() / seq_x_mask.sum()
                    tr_skip_count += 1.0

                    # Compute baseline and refine reward
                    skip_advantage, skip_disc_rewards = compute_advantage(
                        seq_h_data=skip_h_data,
                        seq_r_data=skip_rewards,
                        seq_r_mask=skip_a_mask,
                        vf=vf,
                        args=args,
                        final_cost=args.use_final_reward)
                    if args.use_baseline is False:
                        skip_advantage = skip_disc_rewards

                    ##################
                    # Training Phase #
                    ##################
                    # Update model
                    [_tr_ml_cost, _tr_rl_cost, _, _, _tr_act_ent,
                     _tr_pred_logit] = sess.run(
                        [ml_cost, real_rl_cost, ml_op, rl_op,
                         tg.seq_a_ent, tg.seq_label_logits],
                        feed_dict={tg.seq_x_data: skip_x_data,
                                   tg.seq_x_mask: skip_x_mask,
                                   tg.seq_y_data: skip_y_data,
                                   tg.seq_a_data: skip_a_data,
                                   tg.seq_a_mask: skip_a_mask,
                                   tg.seq_advantage: skip_advantage,
                                   tg.seq_reward: skip_disc_rewards})

                    seq_x_mask = np.transpose(seq_x_mask, (1, 0))
                    seq_y_data = np.transpose(seq_y_data, (1, 0))

                    # Get full sequence prediction
                    _tr_pred_full = expand_pred_idx(
                        seq_skip_1hot=skip_a_data,
                        seq_skip_mask=skip_a_mask,
                        seq_prd_idx=_tr_pred_logit.argmax(axis=-1).reshape(
                            [batch_size, -1]),
                        seq_x_mask=seq_y_data)

                    # Update history
                    tr_ce_sum += _tr_ml_cost.sum() * batch_size
                    tr_ce_count += skip_x_mask.sum()
                    tr_acc_sum += ((_tr_pred_full == seq_y_data) *
                                   seq_x_mask).sum()
                    tr_acc_count += seq_x_mask.sum()
                    tr_rl_sum += _tr_rl_cost.sum()
                    tr_rl_count += 1.0
                    tr_ent_sum += _tr_act_ent.sum()
                    tr_ent_count += skip_a_mask.sum()
                    tr_reward_sum += (skip_rewards * skip_a_mask).sum()
                    tr_reward_count += skip_a_mask.sum()

                    ################
                    # Write result #
                    ################
                    [_tr_rl_summary, _tr_image_summary, _tr_ent_summary,
                     _tr_reward_summary, _tr_rw_hist_summary,
                     _tr_skip_cnt_summary] = sess.run(
                        [tr_rl_summary, tr_image_summary, tr_ent_summary,
                         tr_reward_summary, tr_rw_hist_summary,
                         tr_skip_cnt_summary],
                        feed_dict={
                            tr_rl: _tr_rl_cost.sum(),
                            tr_image: result_image,
                            tr_ent: _tr_act_ent.sum() / skip_a_mask.sum(),
                            tr_reward: ((skip_rewards * skip_a_mask).sum() /
                                        skip_a_mask.sum()),
                            tr_rw_hist: skip_rewards,
                            tr_skip_cnt: skip_x_mask.sum() / seq_x_mask.sum()})
                    summary_writer.add_summary(_tr_rl_summary,
                                               global_step.eval())
                    summary_writer.add_summary(_tr_image_summary,
                                               global_step.eval())
                    summary_writer.add_summary(_tr_ent_summary,
                                               global_step.eval())
                    summary_writer.add_summary(_tr_reward_summary,
                                               global_step.eval())
                    summary_writer.add_summary(_tr_rw_hist_summary,
                                               global_step.eval())
                    summary_writer.add_summary(_tr_skip_cnt_summary,
                                               global_step.eval())
                else:
                    # Number of samples
                    batch_size = seq_x_data.shape[0]

                    ##################
                    # Training Phase #
                    ##################
                    [_tr_ml_cost, _, _tr_pred_full] = sess.run(
                        [ml_cost, ml_op, tg.seq_label_logits],
                        feed_dict={tg.seq_x_data: seq_x_data,
                                   tg.seq_x_mask: seq_x_mask,
                                   tg.seq_y_data: seq_y_data})

                    _tr_pred_full = np.reshape(_tr_pred_full.argmax(axis=1),
                                               seq_y_data.shape)

                    # Update history
                    tr_ce_sum += _tr_ml_cost.sum() * batch_size
                    tr_ce_count += seq_x_mask.sum()
                    tr_acc_sum += ((_tr_pred_full == seq_y_data) *
                                   seq_x_mask).sum()
                    tr_acc_count += seq_x_mask.sum()

                    skip_x_mask = seq_x_mask

                ################
                # Write result #
                ################
                [_tr_ce_summary, _tr_fer_summary] = sess.run(
                    [tr_ce_summary, tr_fer_summary],
                    feed_dict={
                        tr_ce: (_tr_ml_cost.sum() * batch_size) /
                               skip_x_mask.sum(),
                        tr_fer: ((_tr_pred_full == seq_y_data) *
                                 seq_x_mask).sum() / seq_x_mask.sum()})
                summary_writer.add_summary(_tr_ce_summary, global_step.eval())
                summary_writer.add_summary(_tr_fer_summary, global_step.eval())

                # Display results
                if global_step.eval() % args.display_freq == 0:
                    # Get average results
                    avg_tr_ce = tr_ce_sum / tr_ce_count
                    avg_tr_fer = 1. - tr_acc_sum / tr_acc_count

                    if use_rl_skipping:
                        avg_tr_rl = tr_rl_sum / tr_rl_count
                        avg_tr_ent = tr_ent_sum / tr_ent_count
                        avg_tr_reward = tr_reward_sum / tr_reward_count
                        avg_tr_skip = tr_skip_sum / tr_skip_count

                        print("TRAIN: epoch={} iter={} "
                              "ml_cost(ce/frame)={:.2f} fer={:.2f} "
                              "rl_cost={:.4f} reward={:.4f} "
                              "action_entropy={:.2f} "
                              "skip_ratio={:.2f} "
                              "time_taken={:.2f}".format(
                                  _epoch, global_step.eval(), avg_tr_ce,
                                  avg_tr_fer, avg_tr_rl, avg_tr_reward,
                                  avg_tr_ent, avg_tr_skip,
                                  disp_sw.elapsed()))
                    else:
                        print("TRAIN: epoch={} iter={} "
                              "ml_cost(ce/frame)={:.2f} fer={:.2f} "
                              "time_taken={:.2f}".format(
                                  _epoch, global_step.eval(), avg_tr_ce,
                                  avg_tr_fer, disp_sw.elapsed()))

                    # Reset average results
                    tr_ce_sum = 0.
                    tr_ce_count = 0
                    tr_acc_sum = 0.
                    tr_acc_count = 0
                    tr_rl_sum = 0.
                    tr_rl_count = 0
                    tr_ent_sum = 0.
                    tr_ent_count = 0
                    tr_reward_sum = 0.
                    tr_reward_count = 0
                    tr_skip_sum = 0.
                    tr_skip_count = 0
                    disp_sw.reset()

            # End of epoch
            print('--')
            print('End of epoch {}'.format(_epoch + 1))
            epoch_sw.print_elapsed()

            # Evaluation
            print('Testing')

            # Evaluate the model on the validation set
            val_ce_sum = 0.
            val_ce_count = 0
            val_acc_sum = 0.
            val_acc_count = 0
            val_rl_sum = 0.
            val_rl_count = 0
            val_ent_sum = 0.
            val_ent_count = 0
            val_reward_sum = 0.
            val_reward_count = 0
            val_skip_sum = 0.
            val_skip_count = 0
            eval_sw.reset()

            # For each batch in valid
            for batch_data in valid_set.get_epoch_iterator():
                ##################
                # Sampling Phase #
                ##################
                # Get batch data
                seq_x_data, seq_x_mask, _, _, seq_y_data, _ = batch_data

                if use_rl_skipping:
                    # Transpose axis
                    seq_x_data = np.transpose(seq_x_data, (1, 0, 2))
                    seq_x_mask = np.transpose(seq_x_mask, (1, 0))
                    seq_y_data = np.transpose(seq_y_data, (1, 0))

                    # Number of samples
                    batch_size = seq_x_data.shape[1]

                    # Sample actions (episode generation)
                    [skip_x_data, skip_h_data, skip_x_mask, skip_y_data,
                     skip_a_data, skip_a_mask, skip_rewards, result_image] = \
                        gen_episodes(seq_x_data=seq_x_data,
                                     seq_x_mask=seq_x_mask,
                                     seq_y_data=seq_y_data,
                                     sess=sess,
                                     sample_graph=sg,
                                     args=args,
                                     use_sampling=False)

                    # Compute skip ratio
                    val_skip_sum += skip_x_mask.sum() / seq_x_mask.sum()
                    val_skip_count += 1.0

                    # Compute baseline and refine reward
                    skip_advantage, skip_disc_rewards = compute_advantage(
                        seq_h_data=skip_h_data,
                        seq_r_data=skip_rewards,
                        seq_r_mask=skip_a_mask,
                        vf=vf,
                        args=args,
                        final_cost=args.use_final_reward)
                    if args.use_baseline is False:
                        skip_advantage = skip_disc_rewards

                    #################
                    # Forward Phase #
                    #################
                    [_val_ml_cost, _val_rl_cost, _val_pred_logit,
                     _val_action_ent] = sess.run(
                        [ml_cost, real_rl_cost, tg.seq_label_logits,
                         tg.seq_a_ent],
                        feed_dict={tg.seq_x_data: skip_x_data,
                                   tg.seq_x_mask: skip_x_mask,
                                   tg.seq_y_data: skip_y_data,
                                   tg.seq_a_data: skip_a_data,
                                   tg.seq_a_mask: skip_a_mask,
                                   tg.seq_advantage: skip_advantage,
                                   tg.seq_reward: skip_disc_rewards})

                    seq_x_mask = np.transpose(seq_x_mask, (1, 0))
                    seq_y_data = np.transpose(seq_y_data, (1, 0))

                    # Get full sequence prediction
                    _val_pred_full = expand_pred_idx(
                        seq_skip_1hot=skip_a_data,
                        seq_skip_mask=skip_a_mask,
                        seq_prd_idx=_val_pred_logit.argmax(axis=-1).reshape(
                            [batch_size, -1]),
                        seq_x_mask=seq_y_data)

                    # Update history
                    val_ce_sum += _val_ml_cost.sum() * batch_size
                    val_ce_count += skip_x_mask.sum()
                    val_acc_sum += ((_val_pred_full == seq_y_data) *
                                    seq_x_mask).sum()
                    val_acc_count += seq_x_mask.sum()
                    val_rl_sum += _val_rl_cost.sum()
                    val_rl_count += 1.0
                    val_ent_sum += _val_action_ent.sum()
                    val_ent_count += skip_a_mask.sum()
                    val_reward_sum += (skip_rewards * skip_a_mask).sum()
                    val_reward_count += skip_a_mask.sum()
                else:
                    # Number of samples
                    batch_size = seq_x_data.shape[0]

                    #################
                    # Forward Phase #
                    #################
                    # Forward pass (no update)
                    [_val_ml_cost, _val_pred_full] = sess.run(
                        [ml_cost, tg.seq_label_logits],
                        feed_dict={tg.seq_x_data: seq_x_data,
                                   tg.seq_x_mask: seq_x_mask,
                                   tg.seq_y_data: seq_y_data})

                    _val_pred_full = np.reshape(_val_pred_full.argmax(axis=1),
                                                seq_y_data.shape)

                    # Update history
                    val_ce_sum += _val_ml_cost.sum() * batch_size
                    val_ce_count += seq_x_mask.sum()
                    val_acc_sum += ((_val_pred_full == seq_y_data) *
                                    seq_x_mask).sum()
                    val_acc_count += seq_x_mask.sum()

            # Aggregate over all valid data
            avg_val_ce = val_ce_sum / val_ce_count
            avg_val_fer = 1. - val_acc_sum / val_acc_count

            if use_rl_skipping:
                avg_val_rl = val_rl_sum / val_rl_count
                avg_val_ent = val_ent_sum / val_ent_count
                avg_val_reward = val_reward_sum / val_reward_count
                avg_val_skip = val_skip_sum / val_skip_count

                print("VALID: epoch={} "
                      "ml_cost(ce/frame)={:.2f} fer={:.2f} "
                      "rl_cost={:.4f} reward={:.4f} action_entropy={:.2f} "
                      "skip_ratio={:.2f} "
                      "time_taken={:.2f}".format(
                          _epoch, avg_val_ce, avg_val_fer, avg_val_rl,
                          avg_val_reward, avg_val_ent, avg_val_skip,
                          eval_sw.elapsed()))
            else:
                print("VALID: epoch={} "
                      "ml_cost(ce/frame)={:.2f} fer={:.2f} "
                      "time_taken={:.2f}".format(_epoch, avg_val_ce,
                                                 avg_val_fer,
                                                 eval_sw.elapsed()))

            ################
            # Write result #
            ################
            [_val_ce_summary, _val_fer_summary, _val_skip_cnt_summary,
             _val_img_summary] = sess.run(
                [val_ce_summary, val_fer_summary, val_skip_cnt_summary,
                 val_image_summary],
                feed_dict={val_ce: avg_val_ce,
                           val_fer: avg_val_fer,
                           val_skip_cnt: avg_val_skip,
                           val_image: result_image})
            summary_writer.add_summary(_val_skip_cnt_summary,
                                       global_step.eval())
            summary_writer.add_summary(_val_img_summary, global_step.eval())
            summary_writer.add_summary(_val_ce_summary, global_step.eval())
            summary_writer.add_summary(_val_fer_summary, global_step.eval())

            insert_item2dict(eval_summary, 'val_ce', avg_val_ce)
            insert_item2dict(eval_summary, 'val_fer', avg_val_fer)
            # insert_item2dict(eval_summary, 'val_rl', avg_val_rl)
            # insert_item2dict(eval_summary, 'val_reward', avg_val_reward)
            # insert_item2dict(eval_summary, 'val_ent', avg_val_ent)
            insert_item2dict(eval_summary, 'time', eval_sw.elapsed())
            save_npz2(file_name, eval_summary)

            # Save best model
            if avg_val_ce < _best_ce:
                _best_ce = avg_val_ce
                best_ckpt = best_save_op.save(
                    sess=sess,
                    save_path=os.path.join(args.log_dir,
                                           "best_model(ce).ckpt"),
                    global_step=global_step)
                print("Best checkpoint based on CE stored in: %s" % best_ckpt)
            if avg_val_fer < _best_fer:
                _best_fer = avg_val_fer
                best_ckpt = best_save_op.save(
                    sess=sess,
                    save_path=os.path.join(args.log_dir,
                                           "best_model(fer).ckpt"),
                    global_step=global_step)
                print("Best checkpoint based on FER stored in: %s" % best_ckpt)

            # Save model
            ckpt = save_op.save(sess=sess,
                                save_path=os.path.join(args.log_dir,
                                                       "model.ckpt"),
                                global_step=global_step)
            print("Checkpoint stored in: %s" % ckpt)

            # Write result
            [_best_val_ce_summary, _best_val_fer_summary] = sess.run(
                [best_val_ce_summary, best_val_fer_summary],
                feed_dict={best_val_ce: _best_ce,
                           best_val_fer: _best_fer})
            summary_writer.add_summary(_best_val_ce_summary,
                                       global_step.eval())
            summary_writer.add_summary(_best_val_fer_summary,
                                       global_step.eval())

        # Done with training
        summary_writer.close()

    print("Optimization Finished.")
def main(_):
    print(' '.join(sys.argv))
    args = FLAGS
    print(args.__flags)
    print('Hostname: {}'.format(socket.gethostname()))
    print('GPU: {}'.format(get_gpuname()))

    if not args.start_from_ckpt:
        if tf.gfile.Exists(args.log_dir):
            tf.gfile.DeleteRecursively(args.log_dir)
        tf.gfile.MakeDirs(args.log_dir)

    tf.get_variable_scope()._reuse = None

    _seed = args.base_seed + args.add_seed
    tf.set_random_seed(_seed)
    np.random.seed(_seed)

    prefix_name = os.path.join(args.log_dir, 'model')
    file_name = '%s.npz' % prefix_name

    eval_summary = OrderedDict()

    tg, test_graph = build_graph(args)
    tg_ml_cost = tf.reduce_mean(tg.ml_cost)

    global_step = tf.Variable(0, trainable=False, name="global_step")

    tvars = tf.trainable_variables()
    print([tvar.name for tvar in tvars])
    ml_tvars = [tvar for tvar in tvars if "action_logit" not in tvar.name]
    rl_tvars = [tvar for tvar in tvars if "action_logit" in tvar.name]

    ml_opt_func = tf.train.AdamOptimizer(learning_rate=args.learning_rate,
                                         beta1=0.9, beta2=0.99)
    rl_opt_func = tf.train.AdamOptimizer(learning_rate=args.rl_learning_rate,
                                         beta1=0.9, beta2=0.99)

    if args.grad_clip:
        ml_grads, _ = tf.clip_by_global_norm(
            tf.gradients(tg_ml_cost, ml_tvars), clip_norm=1.0)
    else:
        ml_grads = tf.gradients(tg_ml_cost, ml_tvars)
    ml_op = ml_opt_func.apply_gradients(zip(ml_grads, ml_tvars),
                                        global_step=global_step)

    tg_rl_cost = tf.reduce_mean(tg.rl_cost)
    rl_grads = tf.gradients(tg_rl_cost, rl_tvars)
    rl_op = rl_opt_func.apply_gradients(zip(rl_grads, rl_tvars))

    tf.add_to_collection('fast_action', args.fast_action)
    tf.add_to_collection('fast_action', args.n_fast_action)

    sync_data(args)
    datasets = [args.train_dataset, args.valid_dataset, args.test_dataset]
    train_set, valid_set, test_set = [
        create_ivector_datastream(path=args.data_path, which_set=dataset,
                                  batch_size=args.n_batch,
                                  min_after_cache=args.min_after_cache,
                                  length_sort=not args.no_length_sort)
        for dataset in datasets
    ]

    init_op = tf.global_variables_initializer()
    save_op = tf.train.Saver(max_to_keep=5)
    best_save_op = tf.train.Saver(max_to_keep=5)

    with tf.name_scope("per_step_eval"):
        tr_ce = tf.placeholder(tf.float32)
        tr_ce_summary = tf.summary.scalar("tr_ce", tr_ce)
        tr_fer = tf.placeholder(tf.float32)
        tr_fer_summary = tf.summary.scalar("tr_fer", tr_fer)
        tr_ce2 = tf.placeholder(tf.float32)
        tr_ce2_summary = tf.summary.scalar("tr_rl", tr_ce2)
        tr_image = tf.placeholder(tf.float32)
        tr_image_summary = tf.summary.image("tr_image", tr_image)

    with tf.name_scope("per_epoch_eval"):
        val_fer = tf.placeholder(tf.float32)
        val_fer_summary = tf.summary.scalar("val_fer", val_fer)
        best_val_fer = tf.placeholder(tf.float32)
        best_val_fer_summary = tf.summary.scalar("best_valid_fer",
                                                 best_val_fer)
        val_image = tf.placeholder(tf.float32)
        val_image_summary = tf.summary.image("val_image", val_image)

    vf = LinearVF()

    with tf.Session() as sess:
        sess.run(init_op)

        if args.start_from_ckpt:
            save_op = tf.train.import_meta_graph(
                os.path.join(args.log_dir, 'model.ckpt.meta'))
            save_op.restore(sess, os.path.join(args.log_dir, 'model.ckpt'))
            print("Restore from the last checkpoint. Restarting from %d step."
                  % global_step.eval())

        summary_writer = tf.summary.FileWriter(args.log_dir, sess.graph,
                                               flush_secs=5.0)

        tr_ce_sum = 0.
        tr_ce_count = 0
        tr_acc_sum = 0
        tr_acc_count = 0
        tr_ce2_sum = 0.
        tr_ce2_count = 0
        _best_score = np.iinfo(np.int32).max

        epoch_sw = StopWatch()
        disp_sw = StopWatch()
        eval_sw = StopWatch()
        per_sw = StopWatch()

        # For each epoch
        for _epoch in xrange(args.n_epoch):
            _n_exp = 0

            epoch_sw.reset()
            disp_sw.reset()

            print('Epoch {} training'.format(_epoch + 1))

            # For each batch
            for batch in train_set.get_epoch_iterator():
                x, x_mask, _, _, y, _ = batch
                n_batch = x.shape[0]
                _n_exp += n_batch

                if args.no_sampling:
                    new_x, new_y, actions, actions_1hot, new_x_mask = \
                        gen_supervision(x, x_mask, y, args)

                    zero_state = gen_zero_state(n_batch, args.n_hidden)

                    feed_dict = {tg.seq_x_data: new_x,
                                 tg.seq_x_mask: new_x_mask,
                                 tg.seq_y_data: new_y,
                                 tg.seq_jump_data: actions}
                    feed_init_state(feed_dict, tg.init_state, zero_state)

                    _tr_ml_cost, _tr_rl_cost, _, _ = \
                        sess.run([tg.ml_cost, tg.rl_cost, ml_op, rl_op],
                                 feed_dict=feed_dict)

                    tr_ce_sum += _tr_ml_cost.sum()
                    tr_ce_count += new_x_mask.sum()
                    tr_ce2_sum += _tr_rl_cost.sum()
                    tr_ce2_count += new_x_mask[:, :-1].sum()

                    actions_1hot, label_probs, new_mask, output_image = \
                        skip_rnn_forward_supervised(x, x_mask, sess,
                                                    test_graph,
                                                    args.fast_action,
                                                    args.n_fast_action, y)
                    pred_idx = expand_output(actions_1hot, x_mask, new_mask,
                                             label_probs.argmax(axis=-1))
                    tr_acc_sum += ((pred_idx == y) * x_mask).sum()
                    tr_acc_count += x_mask.sum()

                    _tr_ce_summary, _tr_fer_summary, _tr_ce2_summary, \
                        _tr_image_summary = \
                        sess.run([tr_ce_summary, tr_fer_summary,
                                  tr_ce2_summary, tr_image_summary],
                                 feed_dict={
                                     tr_ce: _tr_ml_cost.sum() / new_x_mask.sum(),
                                     tr_fer: 1 - ((pred_idx == y) * x_mask).sum() / x_mask.sum(),
                                     tr_ce2: _tr_rl_cost.sum() / new_x_mask[:, :-1].sum(),
                                     tr_image: output_image})
                    summary_writer.add_summary(_tr_ce_summary,
                                               global_step.eval())
                    summary_writer.add_summary(_tr_fer_summary,
                                               global_step.eval())
                    summary_writer.add_summary(_tr_ce2_summary,
                                               global_step.eval())
                    summary_writer.add_summary(_tr_image_summary,
                                               global_step.eval())
                else:
                    # Train the jump prediction part
                    new_x, _, actions, _, new_x_mask = \
                        gen_supervision(x, x_mask, y, args)

                    zero_state = gen_zero_state(n_batch, args.n_hidden)

                    feed_dict = {tg.seq_x_data: new_x,
                                 tg.seq_x_mask: new_x_mask,
                                 tg.seq_jump_data: actions}
                    feed_init_state(feed_dict, tg.init_state, zero_state)

                    _tr_rl_cost, _ = sess.run([tg.rl_cost, rl_op],
                                              feed_dict=feed_dict)

                    tr_ce2_sum += _tr_rl_cost.sum()
                    tr_ce2_count += new_x_mask[:, :-1].sum()

                    _tr_ce2_summary, = sess.run(
                        [tr_ce2_summary],
                        feed_dict={tr_ce2: _tr_rl_cost.sum() /
                                   new_x_mask[:, :-1].sum()})

                    # Generate jumps from the model
                    new_x, new_y, actions_1hot, label_probs, new_x_mask, \
                        output_image = gen_episode_supervised(
                            x, y, x_mask, sess, test_graph,
                            args.fast_action, args.n_fast_action)

                    feed_dict = {tg.seq_x_data: new_x,
                                 tg.seq_x_mask: new_x_mask,
                                 tg.seq_y_data: new_y}
                    feed_init_state(feed_dict, tg.init_state, zero_state)

                    # Train the label prediction part
                    _tr_ml_cost, _ = sess.run([tg.ml_cost, ml_op],
                                              feed_dict=feed_dict)

                    tr_ce_sum += _tr_ml_cost.sum()
                    tr_ce_count += new_x_mask.sum()

                    actions_1hot, label_probs, new_mask, output_image = \
                        skip_rnn_forward_supervised(x, x_mask, sess,
                                                    test_graph,
                                                    args.fast_action,
                                                    args.n_fast_action, y)
                    pred_idx = expand_output(actions_1hot, x_mask, new_mask,
                                             label_probs.argmax(axis=-1))
                    tr_acc_sum += ((pred_idx == y) * x_mask).sum()
                    tr_acc_count += x_mask.sum()

                    _tr_ce_summary, _tr_fer_summary, _tr_image_summary = \
                        sess.run([tr_ce_summary, tr_fer_summary,
                                  tr_image_summary],
                                 feed_dict={
                                     tr_ce: _tr_ml_cost.sum() / new_x_mask.sum(),
                                     tr_fer: 1 - ((pred_idx == y) * x_mask).sum() / x_mask.sum(),
                                     tr_image: output_image})
                    summary_writer.add_summary(_tr_ce_summary,
                                               global_step.eval())
                    summary_writer.add_summary(_tr_fer_summary,
                                               global_step.eval())
                    summary_writer.add_summary(_tr_ce2_summary,
                                               global_step.eval())
                    summary_writer.add_summary(_tr_image_summary,
                                               global_step.eval())

                if global_step.eval() % args.display_freq == 0:
                    avg_tr_ce = tr_ce_sum / tr_ce_count
                    avg_tr_fer = 1. - tr_acc_sum / tr_acc_count
                    avg_tr_ce2 = tr_ce2_sum / tr_ce2_count

                    print("TRAIN: epoch={} iter={} ml_cost(ce/frame)={:.2f} "
                          "fer={:.2f} rl_cost={:.4f} time_taken={:.2f}".format(
                              _epoch, global_step.eval(), avg_tr_ce,
                              avg_tr_fer, avg_tr_ce2, disp_sw.elapsed()))

                    tr_ce_sum = 0.
                    tr_ce_count = 0
                    tr_acc_sum = 0.
                    tr_acc_count = 0
                    tr_ce2_sum = 0.
                    tr_ce2_count = 0
                    disp_sw.reset()

            print('--')
            print('End of epoch {}'.format(_epoch + 1))
            epoch_sw.print_elapsed()

            print('Testing')

            # Evaluate the model on the validation set
            val_acc_sum = 0
            val_acc_count = 0
            eval_sw.reset()

            for batch in valid_set.get_epoch_iterator():
                x, x_mask, _, _, y, _ = batch
                n_batch = x.shape[0]

                actions_1hot, label_probs, new_mask, output_image = \
                    skip_rnn_forward_supervised(x, x_mask, sess, test_graph,
                                                args.fast_action,
                                                args.n_fast_action, y)
                pred_idx = expand_output(actions_1hot, x_mask, new_mask,
                                         label_probs.argmax(axis=-1))
                val_acc_sum += ((pred_idx == y) * x_mask).sum()
                val_acc_count += x_mask.sum()

            avg_val_fer = 1. - val_acc_sum / val_acc_count

            print("VALID: epoch={} fer={:.2f} time_taken={:.2f}".format(
                _epoch, avg_val_fer, eval_sw.elapsed()))

            _val_fer_summary, _val_image_summary = sess.run(
                [val_fer_summary, val_image_summary],
                feed_dict={val_fer: avg_val_fer,
                           val_image: output_image})
            summary_writer.add_summary(_val_fer_summary, global_step.eval())
            summary_writer.add_summary(_val_image_summary, global_step.eval())

            insert_item2dict(eval_summary, 'val_fer', avg_val_fer)
            insert_item2dict(eval_summary, 'time', eval_sw.elapsed())
            save_npz2(file_name, eval_summary)

            # Save model
            if avg_val_fer < _best_score:
                _best_score = avg_val_fer
                best_ckpt = best_save_op.save(
                    sess, os.path.join(args.log_dir, "best_model.ckpt"),
                    global_step=global_step)
                print("Best checkpoint stored in: %s" % best_ckpt)

            ckpt = save_op.save(sess, os.path.join(args.log_dir, "model.ckpt"),
                                global_step=global_step)
            print("Checkpoint stored in: %s" % ckpt)

            _best_val_fer_summary, = sess.run(
                [best_val_fer_summary],
                feed_dict={best_val_fer: _best_score})
            summary_writer.add_summary(_best_val_fer_summary,
                                       global_step.eval())

        summary_writer.close()

    print("Optimization Finished.")
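# The main(_) entry points above follow the TF 1.x tf.app.run() convention
# (module-level FLAGS plus a main taking one positional argument). A typical
# launcher stanza, assuming FLAGS are defined in the same module:
#
#     if __name__ == '__main__':
#         tf.app.run(main=main)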