Example 1
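
# Module-level dependencies assumed by this snippet (defined elsewhere in the
# original file): FLAGS, ShapeNetEnv, ReplayMemory, Rollout, burn_in,
# log_string, burnin_log, save_pretrain, evaluate_burnin, train_log, save and
# evaluate.
import time

import numpy as np
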
def train(active_mv):

    senv = ShapeNetEnv(FLAGS)
    replay_mem = ReplayMemory(FLAGS)

    # Debug helper (disabled): sanity-check ReplayMemory.calu_IoU on two
    # small binary masks, e.g.
    #   a = np.array([[1, 0, 1], [0, 0, 0]])
    #   b = np.array([[1, 0, 1], [0, 1, 0]])
    #   print('IoU: {}'.format(replay_mem.calu_IoU(a, b)))

    log_string('====== Starting to burn in memories ======')
    burn_in(senv, replay_mem)
    log_string('====== Done. {} trajectories burnt in ======'.format(
        FLAGS.burn_in_length))

    #epsilon = FLAGS.init_eps
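    # Camera intrinsics K (focal length 420 px, principal point (112, 112),
    # presumably for 224x224 renderings), tiled across the episode length.
    # Note: K_list is not used further in this snippet.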
    K_single = np.asarray([[420.0, 0.0, 112.0],
                           [0.0, 420.0, 112.0],
                           [0.0, 0.0, 1.0]])
    K_list = np.tile(K_single[None, None, ...],
                     (1, FLAGS.max_episode_length, 1, 1))

    rollout_obj = Rollout(active_mv, senv, replay_mem, FLAGS)
    ### Burn-in (pretraining) for MVnet
    if FLAGS.burn_in_iter > 0:
        for i in range(FLAGS.burnin_start_iter,
                       FLAGS.burnin_start_iter + FLAGS.burn_in_iter):
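            # Roll out one random-policy trajectory and append it to the
            # replay memory.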
            rollout_obj.go(i,
                           verbose=True,
                           add_to_mem=True,
                           mode='random',
                           is_train=True)
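            # Sample a pretraining batch; with random_pretrain set, the batch
            # is drawn directly via the environment rather than from the
            # gbl-enabled replay memory.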
            if not FLAGS.random_pretrain:
                replay_mem.enable_gbl()
                mvnet_input = replay_mem.get_batch_list(FLAGS.batch_size)
            else:
                mvnet_input = replay_mem.get_batch_list_random(
                    senv, FLAGS.batch_size)
            tic = time.time()
            out_stuff = active_mv.run_step(mvnet_input,
                                           mode='burnin',
                                           is_training=True)
            burnin_summ = burnin_log(i, out_stuff, time.time() - tic)
            active_mv.train_writer.add_summary(burnin_summ, i)

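            # Periodically checkpoint the pretrained weights and run a
            # burn-in evaluation with the random policy.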
            if (i + 1) % 5000 == 0 and i > FLAGS.burnin_start_iter:
                save_pretrain(active_mv, i + 1)

            if (i + 1) % 1000 == 0 and i > FLAGS.burnin_start_iter:
                evaluate_burnin(active_mv,
                                FLAGS.test_episode_num,
                                replay_mem,
                                i + 1,
                                rollout_obj,
                                mode='random')

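    # Main training loop: one random-policy rollout per iteration, followed
    # by a single MVnet update on a batch from the replay memory.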
    for i_idx in range(FLAGS.max_iter):

        t0 = time.time()

        rollout_obj.go(i_idx, verbose=True, add_to_mem=True, mode='random')
        t1 = time.time()

        replay_mem.enable_gbl()
        mvnet_input = replay_mem.get_batch_list(FLAGS.batch_size)
        t2 = time.time()

        out_stuff = active_mv.run_step(mvnet_input,
                                       mode='train_mv',
                                       is_training=True)
        replay_mem.disable_gbl()
        t3 = time.time()

        train_log(i_idx, out_stuff, (t0, t1, t2, t3))

        active_mv.train_writer.add_summary(out_stuff.merged_train, i_idx)

        if i_idx % FLAGS.save_every_step == 0 and i_idx > 0:
            save(active_mv, i_idx, i_idx, i_idx)

        if i_idx % FLAGS.test_every_step == 0 and i_idx > 0:
            #print('Evaluating active policy')
            #evaluate(active_mv, FLAGS.test_episode_num, replay_mem, i_idx, rollout_obj, mode='active')
            print('Evaluating random policy')
            evaluate(active_mv,
                     FLAGS.test_episode_num,
                     replay_mem,
                     i_idx,
                     rollout_obj,
                     mode='random')
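
Example 2 is a later variant of the same training loop: the burn-in phase can
freeze a single batch under reproj_mode, the main-loop rollouts become
epsilon-greedy, and the update step can optionally fine-tune a DQN.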
Example 2
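
# Module-level dependencies assumed by this snippet (defined elsewhere in the
# original file): FLAGS, ShapeNetEnv, ReplayMemory, Rollout, burn_in,
# log_string, burnin_log, save_pretrain, evaluate_burnin,
# batch_to_single_mvinput, train_log, save and evaluate.
import time

import numpy as np
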
def train(active_mv):

    senv = ShapeNetEnv(FLAGS)
    replay_mem = ReplayMemory(FLAGS)

    # Debug helper (disabled): sanity-check ReplayMemory.calu_IoU on two
    # small binary masks, e.g.
    #   a = np.array([[1, 0, 1], [0, 0, 0]])
    #   b = np.array([[1, 0, 1], [0, 1, 0]])
    #   print('IoU: {}'.format(replay_mem.calu_IoU(a, b)))

    log_string('====== Starting to burn in memories ======')
    burn_in(senv, replay_mem)
    log_string('====== Done. {} trajectories burnt in ======'.format(
        FLAGS.burn_in_length))

    #epsilon = FLAGS.init_eps
    K_single = np.asarray([[420.0, 0.0, 112.0],
                           [0.0, 420.0, 112.0],
                           [0.0, 0.0, 1.0]])
    K_list = np.tile(K_single[None, None, ...],
                     (1, FLAGS.max_episode_length, 1, 1))

    rollout_obj = Rollout(active_mv, senv, replay_mem, FLAGS)
    ### Burn-in (pretraining) for MVnet
    if FLAGS.burn_in_iter > 0:
        for i in range(FLAGS.burnin_start_iter,
                       FLAGS.burnin_start_iter + FLAGS.burn_in_iter):

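            # Under reproj_mode, roll out and sample a batch only on the
            # first iteration; subsequent iterations keep re-training on the
            # same mvnet_input.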
            if (not FLAGS.reproj_mode) or (i == FLAGS.burnin_start_iter):
                rollout_obj.go(i,
                               verbose=True,
                               add_to_mem=True,
                               mode=FLAGS.burnin_mode,
                               is_train=True)
                if not FLAGS.random_pretrain:
                    replay_mem.enable_gbl()
                    mvnet_input = replay_mem.get_batch_list(FLAGS.batch_size)
                else:
                    mvnet_input = replay_mem.get_batch_list_random(
                        senv, FLAGS.batch_size)

            tic = time.time()
            out_stuff = active_mv.run_step(mvnet_input,
                                           mode='burnin',
                                           is_training=True)

            summs_burnin = burnin_log(i, out_stuff, time.time() - tic)
            for summ in summs_burnin:
                active_mv.train_writer.add_summary(summ, i)

            if ((i + 1) % FLAGS.save_every_step == 0
                    and i > FLAGS.burnin_start_iter):
                save_pretrain(active_mv, i + 1)

            if (((i + 1) % FLAGS.test_every_step == 0
                 and i > FLAGS.burnin_start_iter)
                    or (FLAGS.eval0 and i == FLAGS.burnin_start_iter)):

                evaluate_burnin(
                    active_mv,
                    FLAGS.test_episode_num,
                    replay_mem,
                    i + 1,
                    rollout_obj,
                    mode=FLAGS.burnin_mode,
                    override_mvnet_input=(batch_to_single_mvinput(mvnet_input)
                                          if FLAGS.reproj_mode else None))

    for i_idx in range(FLAGS.max_iter):

        t0 = time.time()

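        # Epsilon-greedy exploration: with probability FLAGS.epsilon take
        # explore_mode actions; otherwise use rollout_obj's default mode
        # (presumably the learned active policy).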
        if np.random.uniform() < FLAGS.epsilon:
            rollout_obj.go(i_idx,
                           verbose=True,
                           add_to_mem=True,
                           mode=FLAGS.explore_mode,
                           is_train=True)
        else:
            rollout_obj.go(i_idx, verbose=True, add_to_mem=True, is_train=True)
        t1 = time.time()

        replay_mem.enable_gbl()
        mvnet_input = replay_mem.get_batch_list(FLAGS.batch_size)
        t2 = time.time()

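        # Pick the update step: fine-tune with the DQN objective
        # ('train_dqn'), train the DQN alone ('train_dqn_only'), or run the
        # default 'train' step.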
        if FLAGS.finetune_dqn:
            out_stuff = active_mv.run_step(mvnet_input,
                                           mode='train_dqn',
                                           is_training=True)
        elif FLAGS.finetune_dqn_only:
            out_stuff = active_mv.run_step(mvnet_input,
                                           mode='train_dqn_only',
                                           is_training=True)
        else:
            out_stuff = active_mv.run_step(mvnet_input,
                                           mode='train',
                                           is_training=True)
        replay_mem.disable_gbl()
        t3 = time.time()

        train_log(i_idx, out_stuff, (t0, t1, t2, t3))

        active_mv.train_writer.add_summary(out_stuff.merged_train, i_idx)

        if (i_idx + 1) % FLAGS.save_every_step == 0 and i_idx > 0:
            save(active_mv, i_idx + 1, i_idx + 1, i_idx + 1)

        if (i_idx + 1) % FLAGS.test_every_step == 0 and i_idx > 0:
            print('Evaluating active policy')
            evaluate(active_mv,
                     FLAGS.test_episode_num,
                     replay_mem,
                     i_idx + 1,
                     rollout_obj,
                     mode='active')
            print('Evaluating oneway policy')
            evaluate(active_mv,
                     FLAGS.test_episode_num,
                     replay_mem,
                     i_idx + 1,
                     rollout_obj,
                     mode='oneway')
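
Both examples read their hyper-parameters from a module-level FLAGS object.
Below is a minimal sketch of the fields referenced above, assuming
argparse-style flags; the names are taken from the code, but every default
value here is hypothetical, not the authors' setting.

import argparse

parser = argparse.ArgumentParser()
# burn-in / pretraining
parser.add_argument('--burn_in_length', type=int, default=10)
parser.add_argument('--burn_in_iter', type=int, default=0)
parser.add_argument('--burnin_start_iter', type=int, default=0)
parser.add_argument('--burnin_mode', default='random')
parser.add_argument('--random_pretrain', action='store_true')
parser.add_argument('--reproj_mode', action='store_true')
parser.add_argument('--eval0', action='store_true')
# main loop
parser.add_argument('--max_iter', type=int, default=100000)
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--max_episode_length', type=int, default=4)
parser.add_argument('--epsilon', type=float, default=0.1)
parser.add_argument('--explore_mode', default='random')
parser.add_argument('--finetune_dqn', action='store_true')
parser.add_argument('--finetune_dqn_only', action='store_true')
# checkpointing / evaluation
parser.add_argument('--save_every_step', type=int, default=5000)
parser.add_argument('--test_every_step', type=int, default=1000)
parser.add_argument('--test_episode_num', type=int, default=20)
# Use the defaults here; a real script would pass sys.argv[1:] instead.
FLAGS = parser.parse_args([])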