def train(active_mv):
    # Baseline trainer: random-policy rollouts only, MVnet updates in
    # 'train_mv' mode, and evaluation of the random policy alone.
    senv = ShapeNetEnv(FLAGS)
    replay_mem = ReplayMemory(FLAGS)

    log_string('====== Starting burning in memories ======')
    burn_in(senv, replay_mem)
    log_string('====== Done. {} trajectories burnt in ======'.format(
        FLAGS.burn_in_length))

    # Camera intrinsics, tiled to one 3x3 matrix per step of an episode.
    K_single = np.asarray([[420.0, 0.0, 112.0],
                           [0.0, 420.0, 112.0],
                           [0.0, 0.0, 1.0]])
    K_list = np.tile(K_single[None, None, ...],
                     (1, FLAGS.max_episode_length, 1, 1))

    rollout_obj = Rollout(active_mv, senv, replay_mem, FLAGS)

    # Burn-in (pretraining) for the MVnet.
    if FLAGS.burn_in_iter > 0:
        for i in xrange(FLAGS.burnin_start_iter,
                        FLAGS.burnin_start_iter + FLAGS.burn_in_iter):
            rollout_obj.go(i, verbose=True, add_to_mem=True, mode='random',
                           is_train=True)
            if not FLAGS.random_pretrain:
                replay_mem.enable_gbl()
                mvnet_input = replay_mem.get_batch_list(FLAGS.batch_size)
            else:
                mvnet_input = replay_mem.get_batch_list_random(
                    senv, FLAGS.batch_size)

            tic = time.time()
            out_stuff = active_mv.run_step(mvnet_input, mode='burnin',
                                           is_training=True)
            burnin_summ = burnin_log(i, out_stuff, time.time() - tic)
            active_mv.train_writer.add_summary(burnin_summ, i)

            if (i + 1) % 5000 == 0 and i > FLAGS.burnin_start_iter:
                save_pretrain(active_mv, i + 1)
            if (i + 1) % 1000 == 0 and i > FLAGS.burnin_start_iter:
                evaluate_burnin(active_mv, FLAGS.test_episode_num, replay_mem,
                                i + 1, rollout_obj, mode='random')

    # Main loop: roll out one episode, sample a batch from replay memory,
    # and take one MVnet training step.
    for i_idx in xrange(FLAGS.max_iter):
        t0 = time.time()
        rollout_obj.go(i_idx, verbose=True, add_to_mem=True, mode='random')
        t1 = time.time()

        replay_mem.enable_gbl()
        mvnet_input = replay_mem.get_batch_list(FLAGS.batch_size)
        t2 = time.time()

        out_stuff = active_mv.run_step(mvnet_input, mode='train_mv',
                                       is_training=True)
        replay_mem.disable_gbl()
        t3 = time.time()

        train_log(i_idx, out_stuff, (t0, t1, t2, t3))
        active_mv.train_writer.add_summary(out_stuff.merged_train, i_idx)

        if i_idx % FLAGS.save_every_step == 0 and i_idx > 0:
            save(active_mv, i_idx, i_idx, i_idx)

        if i_idx % FLAGS.test_every_step == 0 and i_idx > 0:
            print('Evaluating random policy')
            evaluate(active_mv, FLAGS.test_episode_num, replay_mem, i_idx,
                     rollout_obj, mode='random')
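# Illustrative only: a minimal stand-in for the FLAGS namespace the two
# trainers in this file read. Every field name appears in the loops above
# or below; every default value is an assumption for this sketch, not the
# project's actual configuration.
from argparse import Namespace

_example_flags = Namespace(
    burn_in_length=10,       # trajectories burnt into replay memory up front
    burn_in_iter=1000,       # MVnet pretraining iterations
    burnin_start_iter=0,     # starting step (nonzero when resuming)
    burnin_mode='random',    # rollout policy during burn-in
    max_episode_length=4,    # views per episode (also sets the K_list tiling)
    batch_size=4,
    random_pretrain=False,   # sample pretraining batches outside replay memory
    reproj_mode=False,       # reuse a single rollout/batch during burn-in
    eval0=False,             # evaluate once before burn-in training starts
    epsilon=0.1,             # exploration probability in the main loop
    explore_mode='random',   # rollout policy when exploring
    finetune_dqn=False,
    finetune_dqn_only=False,
    max_iter=100000,
    save_every_step=5000,
    test_every_step=1000,
    test_episode_num=20,
)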
def train(active_mv):
    # Extended trainer (variant of train() above): epsilon-greedy rollouts,
    # optional reprojection-mode burn-in, and DQN fine-tuning modes.
    senv = ShapeNetEnv(FLAGS)
    replay_mem = ReplayMemory(FLAGS)

    log_string('====== Starting burning in memories ======')
    burn_in(senv, replay_mem)
    log_string('====== Done. {} trajectories burnt in ======'.format(
        FLAGS.burn_in_length))

    # Camera intrinsics, tiled to one 3x3 matrix per step of an episode.
    K_single = np.asarray([[420.0, 0.0, 112.0],
                           [0.0, 420.0, 112.0],
                           [0.0, 0.0, 1.0]])
    K_list = np.tile(K_single[None, None, ...],
                     (1, FLAGS.max_episode_length, 1, 1))

    rollout_obj = Rollout(active_mv, senv, replay_mem, FLAGS)

    # Burn-in (pretraining) for the MVnet.
    if FLAGS.burn_in_iter > 0:
        for i in xrange(FLAGS.burnin_start_iter,
                        FLAGS.burnin_start_iter + FLAGS.burn_in_iter):
            # In reprojection mode, roll out and sample a batch only once and
            # keep training on it; otherwise refresh both every iteration.
            if (not FLAGS.reproj_mode) or (i == FLAGS.burnin_start_iter):
                rollout_obj.go(i, verbose=True, add_to_mem=True,
                               mode=FLAGS.burnin_mode, is_train=True)
                if not FLAGS.random_pretrain:
                    replay_mem.enable_gbl()
                    mvnet_input = replay_mem.get_batch_list(FLAGS.batch_size)
                else:
                    mvnet_input = replay_mem.get_batch_list_random(
                        senv, FLAGS.batch_size)

            tic = time.time()
            out_stuff = active_mv.run_step(mvnet_input, mode='burnin',
                                           is_training=True)
            summs_burnin = burnin_log(i, out_stuff, time.time() - tic)
            for summ in summs_burnin:
                active_mv.train_writer.add_summary(summ, i)

            if ((i + 1) % FLAGS.save_every_step == 0
                    and i > FLAGS.burnin_start_iter):
                save_pretrain(active_mv, i + 1)

            if (((i + 1) % FLAGS.test_every_step == 0
                 and i > FLAGS.burnin_start_iter)
                    or (FLAGS.eval0 and i == FLAGS.burnin_start_iter)):
                evaluate_burnin(
                    active_mv, FLAGS.test_episode_num, replay_mem, i + 1,
                    rollout_obj, mode=FLAGS.burnin_mode,
                    override_mvnet_input=(batch_to_single_mvinput(mvnet_input)
                                          if FLAGS.reproj_mode else None))

    # Main loop with epsilon-greedy exploration.
    for i_idx in xrange(FLAGS.max_iter):
        t0 = time.time()
        # With probability epsilon, roll out under the exploration policy;
        # otherwise follow the current learned policy.
        if np.random.uniform() < FLAGS.epsilon:
            rollout_obj.go(i_idx, verbose=True, add_to_mem=True,
                           mode=FLAGS.explore_mode, is_train=True)
        else:
            rollout_obj.go(i_idx, verbose=True, add_to_mem=True,
                           is_train=True)
        t1 = time.time()

        replay_mem.enable_gbl()
        mvnet_input = replay_mem.get_batch_list(FLAGS.batch_size)
        t2 = time.time()

        # Pick the update mode: one of the DQN fine-tuning variants, or the
        # default joint 'train' step.
        if FLAGS.finetune_dqn:
            out_stuff = active_mv.run_step(mvnet_input, mode='train_dqn',
                                           is_training=True)
        elif FLAGS.finetune_dqn_only:
            out_stuff = active_mv.run_step(mvnet_input, mode='train_dqn_only',
                                           is_training=True)
        else:
            out_stuff = active_mv.run_step(mvnet_input, mode='train',
                                           is_training=True)
        replay_mem.disable_gbl()
        t3 = time.time()

        train_log(i_idx, out_stuff, (t0, t1, t2, t3))
        active_mv.train_writer.add_summary(out_stuff.merged_train, i_idx)

        if (i_idx + 1) % FLAGS.save_every_step == 0 and i_idx > 0:
            save(active_mv, i_idx + 1, i_idx + 1, i_idx + 1)

        if (i_idx + 1) % FLAGS.test_every_step == 0 and i_idx > 0:
            print('Evaluating active policy')
            evaluate(active_mv, FLAGS.test_episode_num, replay_mem, i_idx + 1,
                     rollout_obj, mode='active')
            print('Evaluating oneway policy')
            evaluate(active_mv, FLAGS.test_episode_num, replay_mem, i_idx + 1,
                     rollout_obj, mode='oneway')
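# Illustrative only: the epsilon-greedy choice in the main loop above,
# factored into a standalone helper. The helper name and the None return
# are hypothetical, not part of this codebase; on the exploit branch the
# loop calls rollout_obj.go without a mode kwarg, which presumably falls
# back to the learned active policy.
import numpy as np

def _choose_rollout_mode(epsilon, explore_mode):
    """With probability `epsilon`, explore; otherwise exploit."""
    if np.random.uniform() < epsilon:
        return explore_mode  # e.g. 'random': sample the next view uniformly
    return None              # None -> use rollout_obj.go's default policy

# e.g.: mode = _choose_rollout_mode(FLAGS.epsilon, FLAGS.explore_mode)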