# Example #1
def main(FLAGS):
    """Train a feature-embedding net with semi-hard-mined triplet loss (TF1 graph mode).

    Builds the CPU input pipeline, the embedding network, the triplet loss
    with semi-hard mining, streaming metrics and summaries, then runs an
    epoch loop of training followed by validation, saving a checkpoint
    after every epoch.

    Args:
        FLAGS: parsed command-line flags. Fields read here: seed,
            dataset_root, dataset, validate_rate, n_class_per_iter,
            n_img_per_class, n_iter_per_epoch, model, n_feats,
            weight_decay, threshold, logdir, n_epoch, restore,
            boundaries, values, log_every, n_iter_for_emb.
    """

    # set seed
    np.random.seed(FLAGS.seed)
    tf.set_random_seed(FLAGS.seed)

    # Keep the whole input pipeline on the CPU so the accelerator is free
    # for the network itself.
    with tf.device('/cpu:0'), tf.name_scope('input'):

        # load data
        data, meta = load_data(FLAGS.dataset_root,
                               FLAGS.dataset,
                               is_training=True)
        train_data, val_data = split_data(data, FLAGS.validate_rate)
        # One sampled batch = n_class_per_iter classes x n_img_per_class
        # images per class (the structure the triplet miner relies on).
        batch_size = FLAGS.n_class_per_iter * FLAGS.n_img_per_class
        img_shape = train_data[0].shape[1:]

        # build DataSampler
        train_data_sampler = DataSampler(train_data, meta['n_class'],
                                         FLAGS.n_class_per_iter,
                                         FLAGS.n_img_per_class)

        val_data_sampler = DataSampler(val_data, meta['n_class'],
                                       FLAGS.n_class_per_iter,
                                       FLAGS.n_img_per_class)

        # build tf_dataset for training
        # Each generator yield is one whole batch; flat_map unpacks it into
        # single images so `preprocess_for_train` can run with 8 parallel
        # calls, then the images are re-batched to the same batch_size.
        train_dataset = (tf.data.Dataset.from_generator(
            lambda: train_data_sampler, (tf.float32, tf.int32),
            ([batch_size, *img_shape
              ], [batch_size])).take(FLAGS.n_iter_per_epoch).flat_map(
                  lambda x, y: tf.data.Dataset.from_tensor_slices((x, y))).map(
                      preprocess_for_train, 8).batch(batch_size).prefetch(1))

        # build tf_dataset for val
        # Same structure as training but with eval preprocessing and a
        # fixed 100 sampled batches per validation pass.
        val_dataset = (tf.data.Dataset.from_generator(
            lambda: val_data_sampler, (tf.float32, tf.int32),
            ([batch_size, *img_shape], [batch_size])).take(100).flat_map(
                lambda x, y: tf.data.Dataset.from_tensor_slices((x, y))).map(
                    preprocess_for_eval, 8).batch(batch_size).prefetch(1))

        # clean up
        del data, train_data, val_data

        # construct data iterator
        # A reinitializable iterator lets the train and val datasets share
        # a single get_next() op; the two initializers below switch between
        # them at epoch boundaries.
        data_iterator = tf.data.Iterator.from_structure(
            train_dataset.output_types, train_dataset.output_shapes)

        # construct iterator initializer for training and validation
        train_data_init = data_iterator.make_initializer(train_dataset)
        val_data_init = data_iterator.make_initializer(val_dataset)

        # get data from data iterator
        images, labels = data_iterator.get_next()
        tf.summary.image('images', images)

    # define useful scalars
    learning_rate = tf.placeholder(tf.float32, shape=(), name='learning_rate')
    tf.summary.scalar('lr', learning_rate)
    is_training = tf.placeholder(tf.bool, [], name='is_training')
    global_step = tf.train.create_global_step()

    # define optimizer
    optimizer = tf.train.AdamOptimizer(learning_rate)
    # optimizer = tf.train.GradientDescentOptimizer(learning_rate)

    # build the net
    # The model module is selected by name at runtime; it must expose a
    # `Net` class taking (n_feats, weight_decay).
    model = importlib.import_module('models.{}'.format(FLAGS.model))
    net = model.Net(n_feats=FLAGS.n_feats, weight_decay=FLAGS.weight_decay)

    # The pipeline produces NHWC; transpose to NCHW if the net wants it.
    if net.data_format == 'channels_first' or net.data_format == 'NCHW':
        images = tf.transpose(images, [0, 3, 1, 2])

    # get features
    features = net(images, is_training)
    tf.summary.histogram('features', features)

    # summary variable defined in net
    for w in net.global_variables:
        tf.summary.histogram(w.name, w)

    with tf.name_scope('losses'):
        # compute loss, if features is l2 normed, then 2 * cosine_distance will
        # equal squared l2 distance.
        distance = 2 * custom_ops.cosine_distance(features)
        # hard mining
        # Selects (anchor, positive, negative) index triplets; relies on the
        # batch being grouped n_class_per_iter x n_img_per_class (see above).
        arch_idx, pos_idx, neg_idx = custom_ops.semi_hard_mining(
            distance, FLAGS.n_class_per_iter, FLAGS.n_img_per_class,
            FLAGS.threshold)

        # triplet loss
        N_pair_lefted = tf.shape(arch_idx)[0]

        def true_fn():
            # Gather pairwise distances for the surviving triplets only.
            pos_distance = tf.gather_nd(distance,
                                        tf.stack([arch_idx, pos_idx], 1))
            neg_distance = tf.gather_nd(distance,
                                        tf.stack([arch_idx, neg_idx], 1))
            return custom_ops.triplet_distance(pos_distance, neg_distance,
                                               FLAGS.threshold)

        # Guard against gathering from empty index tensors when mining
        # found no triplet at all: the loss is simply 0 in that case.
        loss = tf.cond(N_pair_lefted > 0, true_fn, lambda: 0.)
        # Fraction of candidate pairs that survived mining (denominator is
        # presumably the candidate count n_class_per_iter * n_img_per_class^2
        # -- TODO confirm against semi_hard_mining).
        pair_rate = N_pair_lefted / (FLAGS.n_class_per_iter *
                                     FLAGS.n_img_per_class**2)

        # compute l2 regularization
        l2_reg = tf.losses.get_regularization_loss()

    with tf.name_scope('metrics') as scope:

        # Streaming means accumulate across a whole epoch / val pass and
        # are reset explicitly via `reset_metrics`.
        mean_loss, mean_loss_update_op = tf.metrics.mean(loss,
                                                         name='mean_loss')

        mean_pair_rate, mean_pair_rate_update_op = tf.metrics.mean(
            pair_rate, name='mean_pair_rate')

        # tf.metrics state lives in LOCAL_VARIABLES under this name scope,
        # so re-initializing that collection zeroes the running means.
        reset_metrics = tf.variables_initializer(
            tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope))
        metrics_update_op = tf.group(mean_loss_update_op,
                                     mean_pair_rate_update_op)

        # collect metric summary alone, because it need to
        # summary after metrics update
        metric_summary = [
            tf.summary.scalar('loss', mean_loss, collections=[]),
            tf.summary.scalar('pair_rate', mean_pair_rate, collections=[])
        ]

    # compute grad
    # Gradients are taken w.r.t. the regularized loss.
    grads_and_vars = optimizer.compute_gradients(loss + l2_reg)

    # summary grads
    for g, v in grads_and_vars:
        tf.summary.histogram(v.name + '/grad', g)

    # run train_op and update_op together
    # UPDATE_OPS (e.g. batch-norm moving averages) must run with each step.
    train_op = optimizer.apply_gradients(grads_and_vars,
                                         global_step=global_step)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    train_op = tf.group(train_op, *update_ops)

    # build summary
    # The embedding plot is rendered outside the graph as a JPEG string and
    # fed back through this placeholder for summarizing.
    jpg_img_str = tf.placeholder(tf.string, shape=[], name='jpg_img_str')
    emb_summary_str = tf.summary.image(
        'emb',
        tf.expand_dims(tf.image.decode_image(jpg_img_str, 3), 0),
        collections=[])
    train_summary_str = tf.summary.merge_all()
    metric_summary_str = tf.summary.merge(metric_summary)

    # init op
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # prepare for the logdir
    if not tf.gfile.Exists(FLAGS.logdir):
        tf.gfile.MakeDirs(FLAGS.logdir)

    # saver
    saver = tf.train.Saver(max_to_keep=FLAGS.n_epoch)

    # summary writer
    train_writer = tf.summary.FileWriter(os.path.join(FLAGS.logdir, 'train'),
                                         tf.get_default_graph())
    val_writer = tf.summary.FileWriter(os.path.join(FLAGS.logdir, 'val'),
                                       tf.get_default_graph())

    # session
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False,
                            intra_op_parallelism_threads=8,
                            inter_op_parallelism_threads=0)
    config.gpu_options.allow_growth = True

    sess = tf.Session(config=config)

    # do initialization
    sess.run(init_op)

    # restore
    # NOTE: restore happens after init, so restored values win.
    if FLAGS.restore:
        saver.restore(sess, FLAGS.restore)

    # Piecewise-constant LR schedule parsed from comma-separated flags.
    lr_boundaries = list(map(int, FLAGS.boundaries.split(',')))
    lr_values = list(map(float, FLAGS.values.split(',')))
    lr_manager = LRManager(lr_boundaries, lr_values)
    time_meter = TimeMeter()

    # start to train
    for e in range(FLAGS.n_epoch):
        print('-' * 40)
        print('Epoch: {:d}'.format(e))

        # training loop
        # The dataset's take() bounds the epoch; OutOfRangeError ends it.
        try:
            i = 0
            sess.run([train_data_init, reset_metrics])
            while True:

                lr = lr_manager.get(e)
                # Only fetch the (expensive) merged summary on log steps.
                fetch = [train_summary_str] if i % FLAGS.log_every == 0 else []

                time_meter.start()
                result = sess.run([train_op, metrics_update_op] + fetch, {
                    learning_rate: lr,
                    is_training: True
                })
                time_meter.stop()

                if i % FLAGS.log_every == 0:
                    # fetch summary str
                    # result[-1] is the merged summary: the fetch list had
                    # it appended exactly when this branch is taken.
                    t_summary = result[-1]
                    t_metric_summary = sess.run(metric_summary_str)

                    t_loss, t_pr = sess.run([mean_loss, mean_pair_rate])
                    sess.run(reset_metrics)

                    # images per second since the last log step
                    spd = batch_size / time_meter.get_and_reset()

                    print(
                        'Iter: {:d}, LR: {:g}, Loss: {:.4f}, PR: {:.2f}, Spd: {:.2f} i/s'
                        .format(i, lr, t_loss, t_pr, spd))

                    train_writer.add_summary(t_summary,
                                             global_step=sess.run(global_step))
                    train_writer.add_summary(t_metric_summary,
                                             global_step=sess.run(global_step))

                i += 1
        except tf.errors.OutOfRangeError:
            pass

        # save checkpoint
        saver.save(sess,
                   '{}/{}'.format(FLAGS.logdir, FLAGS.model),
                   global_step=sess.run(global_step),
                   write_meta_graph=False)

        # val loop
        # Collects the first n_iter_for_emb batches of features/labels for
        # the embedding visualization while streaming metrics over all
        # validation batches.
        try:
            sess.run([val_data_init, reset_metrics])
            v_flist, v_llist = [], []
            v_iter = 0
            while True:
                v_feats, v_labels, _ = sess.run(
                    [features, labels, metrics_update_op],
                    {is_training: False})
                if v_iter < FLAGS.n_iter_for_emb:
                    v_flist.append(v_feats)
                    v_llist.append(v_labels)
                v_iter += 1
        except tf.errors.OutOfRangeError:
            pass

        v_loss, v_pr = sess.run([mean_loss, mean_pair_rate])
        print('[VAL]Loss: {:.4f}, PR: {:.2f}'.format(v_loss, v_pr))

        # NOTE(review): if the val iterator yields no batch at all,
        # v_flist is empty and np.concatenate raises -- presumably the
        # validation split is never empty; verify against split_data.
        # TSNE projection is used only when features are >2-D.
        v_jpg_str = feat2emb(
            np.concatenate(v_flist, axis=0), np.concatenate(v_llist, axis=0),
            TSNE_transform if int(FLAGS.n_feats) > 2 else None)

        val_writer.add_summary(sess.run(metric_summary_str),
                               global_step=sess.run(global_step))
        val_writer.add_summary(sess.run(emb_summary_str,
                                        {jpg_img_str: v_jpg_str}),
                               global_step=sess.run(global_step))

    print('-' * 40)
# Example #2
def main(args):
    """Train a DQN-family agent (optionally double DQN, C51, prioritized replay).

    Builds the replay-memory input pipeline, the online/target networks,
    the TD (or distributional cross-entropy) loss, metrics and summaries,
    then runs the environment-interaction / training loop, periodically
    syncing the target network and checkpointing.

    Args:
        args: parsed command-line arguments. Fields read here: game,
            use_priority, pool_size, batch_size, model, vmax, n_atom,
            double, gamma, learning_rate, logdir, restore, update_period,
            max_step, min_replay_history, epsilon_decay_period,
            epsilon_train, render, print_every, target_update_period.
    """

    # create env
    env = Env(args.game)

    with tf.name_scope('data'), tf.device('/cpu:0'):
        global_step = tf.train.create_global_step()
        # exploration rate, fed per action-selection step
        epsilon = tf.placeholder(dtype=tf.float32, name='epsilon')

        # A replay sample is (state, action, reward, done, next_state);
        # prioritized replay additionally yields the sampled pool indices
        # and their priorities so they can be re-weighted and updated.
        if not args.use_priority:
            memory_pool = MemoryPool(args.pool_size)
            output_dtypes = (tf.uint8, tf.int64, tf.float32, tf.bool, tf.uint8)
            output_shapes = ([None, 4, 84, 84], [None], [None], [None],
                             [None, 4, 84, 84])
        else:
            memory_pool = PriorityMemoryPool(args.pool_size)
            output_dtypes = (tf.uint8, tf.int64, tf.float32, tf.bool, tf.uint8,
                             tf.int64, tf.float32)
            output_shapes = ([None, 4, 84, 84], [None], [None], [None],
                             [None, 4, 84, 84], [None], [None])

        train_dataset = (tf.data.Dataset.from_generator(
            lambda: memory_iter(memory_pool, args.batch_size),
            output_dtypes,
            output_shapes=output_shapes).prefetch(1))

        iterator = train_dataset.make_one_shot_iterator()

        if not args.use_priority:
            state, action, reward, done, next_state = iterator.get_next()
        else:
            (state, action, reward, done, next_state, indices,
             priorities) = iterator.get_next()

    with tf.name_scope('net'), tf.device('/device:GPU:0'):
        # The model module is selected by name at runtime; it must expose
        # a `Net` class taking (n_action, name=...).
        Net = importlib.import_module('models.{}'.format(args.model)).Net

        online_net = Net(env.n_action, name='online_net')
        target_net = Net(env.n_action, name='target_net')

        if args.model != 'c51_cnn':
            # Plain DQN head: nets output per-action Q values directly.
            online_q = online_net(state_preprocess(state), True)
            target_q = target_net(state_preprocess(next_state), False)
        else:
            # C51 head: nets output per-action logits over n_atom value
            # atoms; Q is the expectation over the atom support.
            support = np.linspace(-args.vmax,
                                  args.vmax,
                                  args.n_atom,
                                  dtype='float32')

            online_logits = online_net(state_preprocess(state), True)
            online_q_distribution = softmax(online_logits, axis=2)
            online_q = tf.reduce_sum(online_q_distribution * support, axis=2)

            target_logits = target_net(state_preprocess(next_state), False)
            target_q_distribution = softmax(target_logits, axis=2)
            target_q = tf.reduce_sum(target_q_distribution * support, axis=2)

        # choice action: epsilon-greedy over the online net's Q values
        max_action = tf.cond(
            tf.random_uniform(shape=[], dtype=tf.float32) < epsilon,
            lambda: sample_action(env.n_action),
            lambda: tf.argmax(online_q, axis=1)[0])

        # target <- online hard copy, run every target_update_period steps
        sync_op = []
        for w_target, w_online in zip(target_net.global_variables,
                                      online_net.global_variables):
            sync_op.append(tf.assign(w_target, w_online, use_locking=True))
        sync_op = tf.group(*sync_op)

    with tf.name_scope('losses'), tf.device('/device:GPU:0'):
        if args.model != 'c51_cnn':
            # Q(s, a) for the action actually taken
            online_q_val = tf.reduce_sum(
                tf.one_hot(action, env.n_action, 1., 0.) * online_q, axis=1)

            # bootstrap target value
            if args.double:
                # Double DQN: action selected by the online net, evaluated
                # by the target net.
                online_action = tf.argmax(online_net(
                    state_preprocess(next_state), False),
                                          axis=1)
                Y = tf.reduce_sum(
                    tf.one_hot(online_action, env.n_action, 1., 0.) * target_q,
                    axis=1)
            else:
                Y = tf.reduce_max(target_q, axis=1)

            # Terminal transitions contribute only the immediate reward.
            target_q_max = reward + args.gamma * \
                Y * (1. - tf.cast(done, 'float32'))
            target_q_max = tf.stop_gradient(target_q_max)

            # per-sample loss (reduction NONE) so priority weighting below
            # can scale each element before the final mean
            loss = tf.losses.huber_loss(labels=target_q_max,
                                        predictions=online_q_val,
                                        reduction=tf.losses.Reduction.NONE)
        else:
            # C51: gather the taken action's value / logits per batch row.
            batch_indices = tf.range(args.batch_size, dtype=tf.int64)
            # BUGFIX: this tensor was previously named `indices`, which
            # shadowed the replay-sample indices unpacked from the iterator
            # above -- the priority update below then wrote priorities to
            # the wrong pool slots when use_priority and c51 were combined.
            gather_idx = tf.stack([batch_indices, action], axis=1)

            online_q_val = tf.gather_nd(online_q, gather_idx)
            online_a_logits = tf.gather_nd(online_logits, gather_idx)

            projected_distribution = project_distribution(
                target_q_distribution, support, reward, done, args.gamma,
                args.batch_size)
            projected_distribution = tf.stop_gradient(projected_distribution)

            loss = tf.nn.softmax_cross_entropy_with_logits(
                logits=online_a_logits, labels=projected_distribution)

        if args.use_priority:
            print('using priority memory...')
            # New priority ~ sqrt of the per-sample loss (epsilon keeps it
            # strictly positive so every sample stays samplable).
            new_priority = tf.sqrt(loss + 1e-10)

            # `indices` here are the replay-pool sample indices from the
            # iterator, not the gather indices above.
            priorities_update_op = tf.py_func(memory_pool.set_priority,
                                              [indices, new_priority], [],
                                              name='priority_update_op')

            # importance-sampling-style weights, normalized to max 1
            loss_weights = 1.0 / tf.sqrt(priorities + 1e-10)
            loss_weights = loss_weights / tf.reduce_max(loss_weights)
            loss = loss_weights * loss
        else:
            priorities_update_op = tf.no_op()

        loss = tf.reduce_mean(loss)

    with tf.name_scope('metrics') as scope:
        # Streaming means; reset via metrics_reset_op (re-inits the
        # LOCAL_VARIABLES created under this scope).
        mean_q_val, mean_q_val_update_op = tf.metrics.mean(
            tf.reduce_mean(online_q_val))
        mean_loss, mean_loss_update_op = tf.metrics.mean(loss)

        # episode reward is computed in Python and fed at episode end
        episode_reward = tf.placeholder(tf.float32, shape=[])
        mean_episode_reward, mean_episode_reward_update_op = tf.metrics.mean(
            episode_reward)

        metrics_update_op = tf.group(mean_q_val_update_op, mean_loss_update_op)
        metrics_reset_op = tf.variables_initializer(
            tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope))

        # collect metric summary alone, because it need to
        # summary after metrics update
        metric_summary = [
            tf.summary.scalar('loss', mean_loss, collections=[]),
            tf.summary.scalar('mean_q_val', mean_q_val, collections=[]),
            tf.summary.scalar('mean_episode_reward',
                              mean_episode_reward,
                              collections=[])
        ]

    optimizer = tf.train.AdamOptimizer(args.learning_rate, epsilon=0.01 / 32)

    with tf.device('/device:GPU:0'):
        grad_and_v = optimizer.compute_gradients(
            loss, var_list=online_net.trainable_variables)

        # Run the gradient step, the priority write-back and any
        # UPDATE_OPS (e.g. batch-norm moving averages) together.
        train_op = optimizer.apply_gradients(grad_and_v,
                                             global_step=global_step)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        train_op = tf.group(train_op, priorities_update_op, *update_ops)

    # build summary
    for g, v in grad_and_v:
        tf.summary.histogram(v.name + '/grad', g)
    for w in online_net.global_variables + target_net.global_variables:
        tf.summary.histogram(w.name, w)
    tf.summary.histogram('q_value', online_q)

    train_summary_str = tf.summary.merge_all()
    metric_summary_str = tf.summary.merge(metric_summary)

    # init op
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # prepare for the logdir
    if not tf.gfile.Exists(args.logdir):
        tf.gfile.MakeDirs(args.logdir)

    # summary writer
    train_writer = tf.summary.FileWriter(os.path.join(args.logdir, 'train'),
                                         tf.get_default_graph())

    # session
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False,
                            intra_op_parallelism_threads=4,
                            inter_op_parallelism_threads=4)
    config.gpu_options.allow_growth = True

    sess = tf.Session(config=config)

    # do initialization
    sess.run(init_op)

    # saver (restore after init so restored values win)
    saver = tf.train.Saver(max_to_keep=20)
    if args.restore:
        saver.restore(sess, args.restore)

    reward_sum = 0
    n_step = 0

    state_recorder = StateRecorder()
    tm = TimeMeter()

    # Estimate the environment step we resumed at: one train step was run
    # every update_period env steps.
    start = sess.run(global_step) * args.update_period

    pre_ob = env.reset()

    # BUGFIX: pre-initialize `ret` -- if print_every is not a multiple of
    # update_period, the print branch below could otherwise read it before
    # the first train step assigned it (UnboundLocalError).
    ret = []

    sess.run([sync_op])
    for i in range(args.max_step):
        i_s = i + start

        # epsilon schedule keyed to the resumed step count
        if i_s < args.min_replay_history:
            eps = 1.
        else:
            eps = linearly_decaying_epsilon(args.epsilon_decay_period, i_s,
                                            args.min_replay_history,
                                            args.epsilon_train)

        state_recorder.add(pre_ob)

        a = sess.run(max_action, {
            state: state_recorder.state[None, ...],
            epsilon: eps
        })

        # step env
        ob, r, t, _ = env.step(a)
        n_step += 1
        reward_sum += r
        r = np.clip(r, -1, 1)

        # record pre observation, action, reward, False
        memory_pool.add(pre_ob, a, r, False)
        pre_ob = ob

        if args.render:
            env.render()

        if t:
            print('reward sum: {:.2f}, n step: {:d}'.format(
                reward_sum, n_step))
            sess.run([mean_episode_reward_update_op],
                     {episode_reward: reward_sum})

            # mark the terminal transition, then reset env / recorder
            memory_pool.add(pre_ob, 0, 0, True)

            pre_ob = env.reset()
            state_recorder.reset()
            reward_sum, n_step = 0., 0

        # Gate on the fresh step counter `i` (not i_s): the replay pool is
        # not persisted across restarts, so it must refill after a restore.
        if i >= args.min_replay_history:
            if i_s % args.update_period == 0:
                tm.start()
                # fetch the merged summary only on the train step closest
                # to the next print_every boundary
                summary_fetch = ([train_summary_str] if abs(
                    (i_s % args.print_every - args.print_every) %
                    args.print_every) < args.update_period else [])
                ret = sess.run([train_op, metrics_update_op] + summary_fetch)
                tm.stop()

            if i_s % args.print_every == 0:

                ml, mq, mer = sess.run(
                    [mean_loss, mean_q_val, mean_episode_reward])

                print(('Step: {:d}, Mean Loss: {:.4f}, Mean Q: {:.4f}, '
                       'Mean Episode Reward: {:.2f}, Epsilon: {:.2f}, '
                       'Speed: {:.2f} i/s').format(i_s, ml, mq, mer, eps,
                                                   args.batch_size / tm.get()))

                # ret has >2 entries only when the summary was fetched
                if len(ret) > 2:
                    train_writer.add_summary(ret[-1], sess.run(global_step))

                train_writer.add_summary(sess.run(metric_summary_str),
                                         sess.run(global_step))

                tm.reset()

            if i_s % args.target_update_period == 0:
                sess.run([sync_op, metrics_reset_op])
                print('........sync........')

            if i_s % 10000 == 0:
                saver.save(sess,
                           '{}/{}'.format(args.logdir, args.model),
                           global_step=sess.run(global_step),
                           write_meta_graph=False)