def main(FLAGS):
    """Train a metric-learning (triplet-loss) embedding network.

    Builds the TF1 graph (input pipeline, network, semi-hard-mined triplet
    loss, streaming metrics, summaries), then runs an epoch loop of
    training followed by validation, writing checkpoints and TensorBoard
    summaries under ``FLAGS.logdir``.

    Relies on project helpers defined elsewhere in the repository:
    ``load_data``, ``split_data``, ``DataSampler``, ``preprocess_for_train``,
    ``preprocess_for_eval``, ``custom_ops``, ``LRManager``, ``TimeMeter``,
    ``feat2emb`` and ``TSNE_transform`` — their exact contracts cannot be
    verified from this file.
    """
    # set seed (both numpy and the TF graph-level seed, for reproducibility)
    np.random.seed(FLAGS.seed)
    tf.set_random_seed(FLAGS.seed)

    with tf.device('/cpu:0'), tf.name_scope('input'):
        # load data
        data, meta = load_data(FLAGS.dataset_root, FLAGS.dataset,
                               is_training=True)
        train_data, val_data = split_data(data, FLAGS.validate_rate)
        # each generator yield is n_class_per_iter classes x n_img_per_class
        # images, i.e. one P-K style batch for triplet mining
        batch_size = FLAGS.n_class_per_iter * FLAGS.n_img_per_class
        # assumes train_data[0] is the image array, shape (N, H, W, C)
        # — TODO confirm against load_data
        img_shape = train_data[0].shape[1:]
        # build DataSampler
        train_data_sampler = DataSampler(train_data, meta['n_class'],
                                         FLAGS.n_class_per_iter,
                                         FLAGS.n_img_per_class)
        val_data_sampler = DataSampler(val_data, meta['n_class'],
                                       FLAGS.n_class_per_iter,
                                       FLAGS.n_img_per_class)
        # build tf_dataset for training: one sampler yield is un-batched via
        # flat_map, per-image preprocessed (8 parallel calls), then re-batched
        train_dataset = (tf.data.Dataset.from_generator(
            lambda: train_data_sampler, (tf.float32, tf.int32),
            ([batch_size, *img_shape], [batch_size])).take(
                FLAGS.n_iter_per_epoch).flat_map(
                    lambda x, y: tf.data.Dataset.from_tensor_slices(
                        (x, y))).map(preprocess_for_train,
                                     8).batch(batch_size).prefetch(1))
        # build tf_dataset for val (fixed 100 sampler yields per epoch)
        val_dataset = (tf.data.Dataset.from_generator(
            lambda: val_data_sampler, (tf.float32, tf.int32),
            ([batch_size, *img_shape], [batch_size])).take(100).flat_map(
                lambda x, y: tf.data.Dataset.from_tensor_slices((x, y))).map(
                    preprocess_for_eval, 8).batch(batch_size).prefetch(1))
        # clean up (NOTE(review): the samplers captured above still hold
        # references to the underlying arrays, so this only drops local names)
        del data, train_data, val_data
        # construct a reinitializable iterator shared by train and val
        data_iterator = tf.data.Iterator.from_structure(
            train_dataset.output_types, train_dataset.output_shapes)
        # construct iterator initializer for training and validation
        train_data_init = data_iterator.make_initializer(train_dataset)
        val_data_init = data_iterator.make_initializer(val_dataset)
        # get data from data iterator
        images, labels = data_iterator.get_next()
        tf.summary.image('images', images)

    # define useful scalars
    learning_rate = tf.placeholder(tf.float32, shape=(), name='learning_rate')
    tf.summary.scalar('lr', learning_rate)
    is_training = tf.placeholder(tf.bool, [], name='is_training')
    global_step = tf.train.create_global_step()

    # define optimizer
    optimizer = tf.train.AdamOptimizer(learning_rate)
    # optimizer = tf.train.GradientDescentOptimizer(learning_rate)

    # build the net: model module is chosen by name from models/
    model = importlib.import_module('models.{}'.format(FLAGS.model))
    net = model.Net(n_feats=FLAGS.n_feats, weight_decay=FLAGS.weight_decay)
    # dataset yields NHWC; transpose for channels-first networks
    if net.data_format == 'channels_first' or net.data_format == 'NCHW':
        images = tf.transpose(images, [0, 3, 1, 2])
    # get features
    features = net(images, is_training)
    tf.summary.histogram('features', features)
    # summary variable defined in net
    for w in net.global_variables:
        tf.summary.histogram(w.name, w)

    with tf.name_scope('losses'):
        # compute loss, if features is l2 normed, then 2 * cosine_distance will
        # equal squared l2 distance.
        distance = 2 * custom_ops.cosine_distance(features)
        # hard mining — presumably returns (anchor, positive, negative) index
        # triples that violate the margin; verify against custom_ops
        arch_idx, pos_idx, neg_idx = custom_ops.semi_hard_mining(
            distance, FLAGS.n_class_per_iter, FLAGS.n_img_per_class,
            FLAGS.threshold)
        # triplet loss — guarded by tf.cond because mining may return no pairs
        N_pair_lefted = tf.shape(arch_idx)[0]

        def true_fn():
            pos_distance = tf.gather_nd(distance,
                                        tf.stack([arch_idx, pos_idx], 1))
            neg_distance = tf.gather_nd(distance,
                                        tf.stack([arch_idx, neg_idx], 1))
            return custom_ops.triplet_distance(pos_distance, neg_distance,
                                               FLAGS.threshold)

        loss = tf.cond(N_pair_lefted > 0, true_fn, lambda: 0.)
        # fraction of candidate triplets that survived mining
        pair_rate = N_pair_lefted / (FLAGS.n_class_per_iter *
                                     FLAGS.n_img_per_class**2)
        # compute l2 regularization
        l2_reg = tf.losses.get_regularization_loss()

    with tf.name_scope('metrics') as scope:
        # streaming means; their accumulator variables live in LOCAL_VARIABLES
        # under this scope so reset_metrics can re-initialize just them
        mean_loss, mean_loss_update_op = tf.metrics.mean(loss,
                                                         name='mean_loss')
        mean_pair_rate, mean_pair_rate_update_op = tf.metrics.mean(
            pair_rate, name='mean_pair_rate')
        reset_metrics = tf.variables_initializer(
            tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope))
        metrics_update_op = tf.group(mean_loss_update_op,
                                     mean_pair_rate_update_op)
        # collect metric summary alone, because it need to
        # summary after metrics update
        metric_summary = [
            tf.summary.scalar('loss', mean_loss, collections=[]),
            tf.summary.scalar('pair_rate', mean_pair_rate, collections=[])
        ]

    # compute grad
    grads_and_vars = optimizer.compute_gradients(loss + l2_reg)
    # summary grads
    for g, v in grads_and_vars:
        tf.summary.histogram(v.name + '/grad', g)
    # run train_op and update_op together
    train_op = optimizer.apply_gradients(grads_and_vars,
                                         global_step=global_step)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    train_op = tf.group(train_op, *update_ops)

    # build summary: the embedding plot is rendered outside the graph
    # (feat2emb returns an encoded JPEG string) and fed back through
    # this placeholder for summarizing
    jpg_img_str = tf.placeholder(tf.string, shape=[], name='jpg_img_str')
    emb_summary_str = tf.summary.image(
        'emb',
        tf.expand_dims(tf.image.decode_image(jpg_img_str, 3), 0),
        collections=[])
    train_summary_str = tf.summary.merge_all()
    metric_summary_str = tf.summary.merge(metric_summary)

    # init op
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # prepare for the logdir
    if not tf.gfile.Exists(FLAGS.logdir):
        tf.gfile.MakeDirs(FLAGS.logdir)

    # saver
    saver = tf.train.Saver(max_to_keep=FLAGS.n_epoch)

    # summary writer
    train_writer = tf.summary.FileWriter(os.path.join(FLAGS.logdir, 'train'),
                                         tf.get_default_graph())
    val_writer = tf.summary.FileWriter(os.path.join(FLAGS.logdir, 'val'),
                                       tf.get_default_graph())

    # session
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False,
                            intra_op_parallelism_threads=8,
                            inter_op_parallelism_threads=0)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    # do initialization
    sess.run(init_op)

    # restore
    if FLAGS.restore:
        saver.restore(sess, FLAGS.restore)

    # piecewise-constant LR schedule parsed from comma-separated flags
    lr_boundaries = list(map(int, FLAGS.boundaries.split(',')))
    lr_values = list(map(float, FLAGS.values.split(',')))
    lr_manager = LRManager(lr_boundaries, lr_values)

    time_meter = TimeMeter()

    # start to train
    for e in range(FLAGS.n_epoch):
        print('-' * 40)
        print('Epoch: {:d}'.format(e))

        # training loop: runs until the train dataset is exhausted
        # (OutOfRangeError terminates the while-True)
        try:
            i = 0
            sess.run([train_data_init, reset_metrics])
            while True:
                lr = lr_manager.get(e)
                # only fetch the (expensive) merged summary on log iterations
                fetch = [train_summary_str] if i % FLAGS.log_every == 0 else []
                time_meter.start()
                result = sess.run([train_op, metrics_update_op] + fetch, {
                    learning_rate: lr,
                    is_training: True
                })
                time_meter.stop()
                if i % FLAGS.log_every == 0:
                    # fetch summary str
                    t_summary = result[-1]
                    t_metric_summary = sess.run(metric_summary_str)
                    t_loss, t_pr = sess.run([mean_loss, mean_pair_rate])
                    sess.run(reset_metrics)
                    # images/sec over the window since the last log
                    spd = batch_size / time_meter.get_and_reset()
                    print(
                        'Iter: {:d}, LR: {:g}, Loss: {:.4f}, PR: {:.2f}, Spd: {:.2f} i/s'
                        .format(i, lr, t_loss, t_pr, spd))
                    train_writer.add_summary(
                        t_summary, global_step=sess.run(global_step))
                    train_writer.add_summary(
                        t_metric_summary, global_step=sess.run(global_step))
                i += 1
        except tf.errors.OutOfRangeError:
            pass

        # save checkpoint (one per epoch; graph not re-serialized)
        saver.save(sess,
                   '{}/{}'.format(FLAGS.logdir, FLAGS.model),
                   global_step=sess.run(global_step),
                   write_meta_graph=False)

        # val loop: accumulate metrics over the whole val set and keep the
        # first n_iter_for_emb batches of features for the embedding plot
        try:
            sess.run([val_data_init, reset_metrics])
            v_flist, v_llist = [], []
            v_iter = 0
            while True:
                v_feats, v_labels, _ = sess.run(
                    [features, labels, metrics_update_op],
                    {is_training: False})
                if v_iter < FLAGS.n_iter_for_emb:
                    v_flist.append(v_feats)
                    v_llist.append(v_labels)
                v_iter += 1
        except tf.errors.OutOfRangeError:
            pass
        v_loss, v_pr = sess.run([mean_loss, mean_pair_rate])
        print('[VAL]Loss: {:.4f}, PR: {:.2f}'.format(v_loss, v_pr))
        # project features to 2-D (via TSNE when n_feats > 2) and render an
        # encoded image string for the summary placeholder
        v_jpg_str = feat2emb(
            np.concatenate(v_flist, axis=0), np.concatenate(v_llist, axis=0),
            TSNE_transform if int(FLAGS.n_feats) > 2 else None)
        val_writer.add_summary(sess.run(metric_summary_str),
                               global_step=sess.run(global_step))
        val_writer.add_summary(sess.run(emb_summary_str,
                                        {jpg_img_str: v_jpg_str}),
                               global_step=sess.run(global_step))
        print('-' * 40)
def main(args):
    """Build a DQN/C51 graph and run the training loop against an Atari env.

    Supports plain DQN, double DQN (``args.double``), the C51 distributional
    head (``args.model == 'c51_cnn'``) and prioritized experience replay
    (``args.use_priority``). Checkpoints and TensorBoard summaries are
    written under ``args.logdir``.

    Relies on project helpers defined elsewhere: ``Env``, ``MemoryPool``,
    ``PriorityMemoryPool``, ``memory_iter``, ``state_preprocess``,
    ``sample_action``, ``softmax``, ``project_distribution``,
    ``linearly_decaying_epsilon``, ``StateRecorder`` and ``TimeMeter``.

    Fix over the original: the C51 branch used the name ``indices`` for its
    (batch, action) gather indices, shadowing the replay-memory sample
    indices yielded by the iterator; with ``use_priority`` + ``c51_cnn`` the
    priority write-back then updated the wrong memory slots.
    """
    # create environment
    env = Env(args.game)

    with tf.name_scope('data'), tf.device('/cpu:0'):
        global_step = tf.train.create_global_step()
        # exploration rate, fed on every action-selection step
        epsilon = tf.placeholder(dtype=tf.float32, name='epsilon')
        # replay memory; prioritized replay additionally yields the sampled
        # slot indices and their priorities
        if not args.use_priority:
            memory_pool = MemoryPool(args.pool_size)
            output_dtypes = (tf.uint8, tf.int64, tf.float32, tf.bool, tf.uint8)
            output_shapes = ([None, 4, 84, 84], [None], [None], [None],
                             [None, 4, 84, 84])
        else:
            memory_pool = PriorityMemoryPool(args.pool_size)
            output_dtypes = (tf.uint8, tf.int64, tf.float32, tf.bool, tf.uint8,
                             tf.int64, tf.float32)
            output_shapes = ([None, 4, 84, 84], [None], [None], [None],
                             [None, 4, 84, 84], [None], [None])
        train_dataset = (tf.data.Dataset.from_generator(
            lambda: memory_iter(memory_pool, args.batch_size),
            output_dtypes,
            output_shapes=output_shapes).prefetch(1))
        iterator = train_dataset.make_one_shot_iterator()
        if not args.use_priority:
            state, action, reward, done, next_state = iterator.get_next()
        else:
            # `indices` locate the sampled transitions inside the memory
            # pool; they are needed later to write updated priorities back
            (state, action, reward, done, next_state, indices,
             priorities) = iterator.get_next()

    with tf.name_scope('net'), tf.device('/device:GPU:0'):
        Net = importlib.import_module('models.{}'.format(args.model)).Net
        online_net = Net(env.n_action, name='online_net')
        target_net = Net(env.n_action, name='target_net')
        if args.model != 'c51_cnn':
            # plain DQN: networks output per-action Q values directly
            online_q = online_net(state_preprocess(state), True)
            target_q = target_net(state_preprocess(next_state), False)
        else:
            # C51: networks output logits over n_atom return atoms per
            # action; Q is the expectation under the categorical distribution
            support = np.linspace(-args.vmax, args.vmax, args.n_atom,
                                  dtype='float32')
            online_logits = online_net(state_preprocess(state), True)
            online_q_distribution = softmax(online_logits, axis=2)
            online_q = tf.reduce_sum(online_q_distribution * support, axis=2)
            target_logits = target_net(state_preprocess(next_state), False)
            target_q_distribution = softmax(target_logits, axis=2)
            target_q = tf.reduce_sum(target_q_distribution * support, axis=2)
        # epsilon-greedy action selection
        max_action = tf.cond(
            tf.random_uniform(shape=[], dtype=tf.float32) < epsilon,
            lambda: sample_action(env.n_action),
            lambda: tf.argmax(online_q, axis=1)[0])
        # op that copies online-net weights into the target net
        sync_op = []
        for w_target, w_online in zip(target_net.global_variables,
                                      online_net.global_variables):
            sync_op.append(tf.assign(w_target, w_online, use_locking=True))
        sync_op = tf.group(*sync_op)

    with tf.name_scope('losses'), tf.device('/device:GPU:0'):
        if args.model != 'c51_cnn':
            # Q value of the action actually taken
            online_q_val = tf.reduce_sum(
                tf.one_hot(action, env.n_action, 1., 0.) * online_q, axis=1)
            # bootstrap target; double DQN selects the argmax with the online
            # net but evaluates it with the target net
            if args.double:
                online_action = tf.argmax(online_net(
                    state_preprocess(next_state), False),
                                          axis=1)
                Y = tf.reduce_sum(
                    tf.one_hot(online_action, env.n_action, 1., 0.) *
                    target_q,
                    axis=1)
            else:
                Y = tf.reduce_max(target_q, axis=1)
            target_q_max = reward + args.gamma * \
                Y * (1. - tf.cast(done, 'float32'))
            target_q_max = tf.stop_gradient(target_q_max)
            # per-sample loss (no reduction) so priorities can be derived
            loss = tf.losses.huber_loss(labels=target_q_max,
                                        predictions=online_q_val,
                                        reduction=tf.losses.Reduction.NONE)
        else:
            # NOTE(review): args.double is silently ignored for c51_cnn
            batch_indices = tf.range(args.batch_size, dtype=tf.int64)
            # FIX: dedicated name for the (batch, action) gather indices.
            # The original assigned this to `indices`, shadowing the
            # replay-memory indices from the iterator, so the priority
            # write-back below updated the wrong memory slots.
            gather_idx = tf.stack([batch_indices, action], axis=1)
            online_q_val = tf.gather_nd(online_q, gather_idx)
            online_a_logits = tf.gather_nd(online_logits, gather_idx)
            projected_distribution = project_distribution(
                target_q_distribution, support, reward, done, args.gamma,
                args.batch_size)
            projected_distribution = tf.stop_gradient(projected_distribution)
            loss = tf.nn.softmax_cross_entropy_with_logits(
                logits=online_a_logits, labels=projected_distribution)
        if args.use_priority:
            print('using priority memory...')
            # new priorities come from the UNweighted per-sample loss
            new_priority = tf.sqrt(loss + 1e-10)
            priorities_update_op = tf.py_func(memory_pool.set_priority,
                                              [indices, new_priority], [],
                                              name='priority_update_op')
            # importance-sampling-style weights, normalized to a max of 1
            loss_weights = 1.0 / tf.sqrt(priorities + 1e-10)
            loss_weights = loss_weights / tf.reduce_max(loss_weights)
            loss = loss_weights * loss
        else:
            priorities_update_op = tf.no_op()
        loss = tf.reduce_mean(loss)

    with tf.name_scope('metrics') as scope:
        # streaming means; accumulators live in LOCAL_VARIABLES under this
        # scope so metrics_reset_op can re-initialize just them
        mean_q_val, mean_q_val_update_op = tf.metrics.mean(
            tf.reduce_mean(online_q_val))
        mean_loss, mean_loss_update_op = tf.metrics.mean(loss)
        episode_reward = tf.placeholder(tf.float32, shape=[])
        mean_episode_reward, mean_episode_reward_update_op = tf.metrics.mean(
            episode_reward)
        metrics_update_op = tf.group(mean_q_val_update_op,
                                     mean_loss_update_op)
        metrics_reset_op = tf.variables_initializer(
            tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope))
        # collect metric summary alone, because it need to
        # summary after metrics update
        metric_summary = [
            tf.summary.scalar('loss', mean_loss, collections=[]),
            tf.summary.scalar('mean_q_val', mean_q_val, collections=[]),
            tf.summary.scalar('mean_episode_reward',
                              mean_episode_reward,
                              collections=[])
        ]

    optimizer = tf.train.AdamOptimizer(args.learning_rate, epsilon=0.01 / 32)
    with tf.device('/device:GPU:0'):
        grad_and_v = optimizer.compute_gradients(
            loss, var_list=online_net.trainable_variables)
        train_op = optimizer.apply_gradients(grad_and_v,
                                             global_step=global_step)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # one train step also refreshes priorities and runs any UPDATE_OPS
    train_op = tf.group(train_op, priorities_update_op, *update_ops)

    # build summary
    for g, v in grad_and_v:
        tf.summary.histogram(v.name + '/grad', g)
    for w in online_net.global_variables + target_net.global_variables:
        tf.summary.histogram(w.name, w)
    tf.summary.histogram('q_value', online_q)
    train_summary_str = tf.summary.merge_all()
    metric_summary_str = tf.summary.merge(metric_summary)

    # init op
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # prepare for the logdir
    if not tf.gfile.Exists(args.logdir):
        tf.gfile.MakeDirs(args.logdir)

    # summary writer
    train_writer = tf.summary.FileWriter(os.path.join(args.logdir, 'train'),
                                         tf.get_default_graph())

    # session
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False,
                            intra_op_parallelism_threads=4,
                            inter_op_parallelism_threads=4)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    # do initialization
    sess.run(init_op)

    # saver
    saver = tf.train.Saver(max_to_keep=20)
    if args.restore:
        saver.restore(sess, args.restore)

    reward_sum = 0
    n_step = 0
    state_recorder = StateRecorder()
    tm = TimeMeter()
    # resume the env-step counter from the restored global (train) step:
    # one train step happens every update_period env steps
    start = sess.run(global_step) * args.update_period
    pre_ob = env.reset()
    sess.run([sync_op])
    for i in range(args.max_step):
        i_s = i + start
        # act uniformly at random until the replay memory has warmed up,
        # then linearly decay epsilon toward args.epsilon_train
        if i_s < args.min_replay_history:
            eps = 1.
        else:
            eps = linearly_decaying_epsilon(args.epsilon_decay_period, i_s,
                                            args.min_replay_history,
                                            args.epsilon_train)
        state_recorder.add(pre_ob)
        a = sess.run(max_action, {
            state: state_recorder.state[None, ...],
            epsilon: eps
        })
        # step env
        ob, r, t, _ = env.step(a)
        n_step += 1
        reward_sum += r
        # store the clipped reward, but report the raw episode return
        r = np.clip(r, -1, 1)
        # record pre observation, action, reward, False
        memory_pool.add(pre_ob, a, r, False)
        pre_ob = ob
        if args.render:
            env.render()
        if t:
            print('reward sum: {:.2f}, n step: {:d}'.format(
                reward_sum, n_step))
            sess.run([mean_episode_reward_update_op],
                     {episode_reward: reward_sum})
            # mark the terminal transition, then reset env and episode state
            memory_pool.add(pre_ob, 0, 0, True)
            pre_ob = env.reset()
            state_recorder.reset()
            reward_sum, n_step = 0., 0
        if i >= args.min_replay_history:
            if i_s % args.update_period == 0:
                tm.start()
                # fetch the merged train summary only on the last update
                # before a print step (summaries are expensive)
                summary_fetch = ([train_summary_str] if abs(
                    (i_s % args.print_every - args.print_every) %
                    args.print_every) < args.update_period else [])
                ret = sess.run([train_op, metrics_update_op] + summary_fetch)
                tm.stop()
                if i_s % args.print_every == 0:
                    ml, mq, mer = sess.run(
                        [mean_loss, mean_q_val, mean_episode_reward])
                    print(('Step: {:d}, Mean Loss: {:.4f}, Mean Q: {:.4f}, '
                           'Mean Episode Reward: {:.2f}, Epsilon: {:.2f}, '
                           'Speed: {:.2f} i/s').format(
                               i_s, ml, mq, mer, eps,
                               args.batch_size / tm.get()))
                    # ret has a 3rd element only when the summary was fetched
                    if len(ret) > 2:
                        train_writer.add_summary(ret[-1],
                                                 sess.run(global_step))
                    train_writer.add_summary(sess.run(metric_summary_str),
                                             sess.run(global_step))
                    tm.reset()
            if i_s % args.target_update_period == 0:
                # copy online weights to the target net and restart metrics
                sess.run([sync_op, metrics_reset_op])
                print('........sync........')
            if i_s % 10000 == 0:
                saver.save(sess,
                           '{}/{}'.format(args.logdir, args.model),
                           global_step=sess.run(global_step),
                           write_meta_graph=False)