Пример #1
0
    def process(self, sess):
        """
        Process grabs a rollout that's been produced by the thread runner,
        and updates the parameters.  The update is then sent to the parameter
        server.

        In order, this method:
          1. runs ``self.sync`` to copy the shared weights into the local
             network;
          2. pulls one rollout from the queue and turns it into a batch with
             ``process_rollout`` (GAMMA / LAMBDA taken from ``constants``);
          3. if ``self.unsup`` is set, every 10001 local steps saves the
             variables of ``self.local_ap_network`` to
             ``checkpoints/model_<k>.p`` with k = local_steps // 10001;
          4. runs ``self.train_op`` (plus ``self.summary_op`` on task 0 every
             11 local steps) and writes the summary when it was fetched;
          5. increments ``self.local_steps``.

        Args:
            sess: TensorFlow session used for every ``sess.run`` call.
        """
        summary_every = 11
        checkpoint_every = 10001

        sess.run(self.sync)  # copy weights from shared to local
        rollout = self.pull_batch_from_queue()
        batch = process_rollout(rollout,
                                gamma=constants['GAMMA'],
                                lambda_=constants['LAMBDA'],
                                clip=self.envWrap)

        # Only task 0 fetches summaries, and only every `summary_every` steps.
        should_compute_summary = (self.task == 0 and
                                  self.local_steps % summary_every == 0)

        if should_compute_summary:
            fetches = [self.summary_op, self.train_op, self.global_step]
        else:
            fetches = [self.train_op, self.global_step]

        if self.unsup and self.local_steps % checkpoint_every == 0:
            import os
            # Create-then-tolerate rather than check-then-create: if several
            # workers run this at the same time, one of them could otherwise
            # lose the race between os.path.exists and os.mkdir and crash.
            try:
                os.mkdir("checkpoints")
            except OSError:
                if not os.path.isdir("checkpoints"):
                    raise
            saveToFlat(self.local_ap_network.get_variables(),
                       "checkpoints/model_%i.p"
                       % (self.local_steps // checkpoint_every))

        feed_dict = {
            self.local_network.x: batch.si,
            self.ac: batch.a,
            self.adv: batch.adv,
            self.r: batch.r,
            self.local_network.state_in[0]: batch.features[0],
            self.local_network.state_in[1]: batch.features[1],
        }
        if self.unsup:
            # s1/s2 receive consecutive-state pairs (si[t], si[t + 1]); the
            # policy input is trimmed to the same length (all but the last).
            feed_dict[self.local_network.x] = batch.si[:-1]
            feed_dict[self.local_ap_network.s1] = batch.si[:-1]
            feed_dict[self.local_ap_network.s2] = batch.si[1:]
            feed_dict[self.local_ap_network.asample] = batch.a

        fetched = sess.run(fetches, feed_dict=feed_dict)
        # The global step is always the last fetched value.
        if batch.terminal:
            print("Global Step Counter: %d" % fetched[-1])

        if should_compute_summary:
            self.summary_writer.add_summary(tf.Summary.FromString(fetched[0]),
                                            fetched[-1])
            self.summary_writer.flush()
        self.local_steps += 1
Пример #2
0
def learn(sess, n_tasks, z_size, data_dir, num_steps, max_seq_len,
          batch_size_per_task=16, rnn_size=256,
          grad_clip=1.0, v_lr=0.0001, vr_lr=0.0001,
          min_v_lr=0.00001, v_decay=0.999, kl_tolerance=0.5,
          lr=0.001, min_lr=0.00001, decay=0.999,
          view="transposed",
          model_dir="tf_rnn", layer_norm=False,
          rnn_mmd=False, no_cor=False,
          w_mmd=1.0,
          alpha=1.0, beta=0.1,
          recurrent_dp=1.0,
          input_dp=1.0,
          output_dp=1.0):
  """Jointly train per-task VAEs and a shared recurrent model (VRNN).

  Builds one VAE per task, a copy ``old_vae0`` of VAE 0 (re-synchronised
  from it every ``config.target_update_interval`` iterations) used by the
  MMD loss builders, and a ``VRNN`` connected to every VAE.  It then runs
  ``num_steps`` outer iterations, each with two phases:

  * prediction: ``config.psteps_per_it`` fresh batches are drawn from the
    data manager and each runs ``rpred_op`` (RNN update with element-wise
    clipped gradients) and ``vpred_all_op`` (VAE ops built from the RNN
    losses).  The drawn batches are kept in a buffer.
  * reconstruction: ``config.rsteps_per_it`` steps of ``vrec_all_op``
    replay the buffered batches; each pass over the buffer visits every
    batch once, in a freshly shuffled order.

  Learning rates decay geometrically towards their floors once per
  iteration.  A log line is written every ``config.log_interval``
  iterations, and weights are saved with ``saveToFlat`` every
  ``config.model_save_interval`` iterations and at the end.

  Args:
    sess: TensorFlow session; all global variables are initialised in it.
    n_tasks: number of tasks (one VAE each).  Tasks 0 and 1 are indexed
      directly, so at least two are required.
    z_size: VAE latent size; also the RNN output size.
    data_dir: directory scanned for files containing '.npz'.
    num_steps: number of outer training iterations.
    max_seq_len: sequence length for the RNN and the data manager.
    batch_size_per_task: batch size fed to each task's VAE.
    rnn_size: hidden size passed to ``VRNN``.
    grad_clip: element-wise clipping bound for the RNN gradients.
    v_lr, vr_lr, lr: initial learning rates for VAE reconstruction, VAE
      prediction and RNN prediction respectively.
    min_v_lr, min_lr: floors the VAE / RNN learning rates decay towards.
    v_decay, decay: per-iteration decay factors for those rates.
    kl_tolerance: forwarded to the VAE / RNN builders.
    view: key looked up with ``WrapperFactory.get_wrapper``.
    model_dir: output directory for the log file and checkpoints.
    layer_norm, recurrent_dp, input_dp, output_dp: forwarded to ``VRNN``
      (the ``*_dp`` values are presumably dropout keep-probabilities --
      TODO confirm).
    rnn_mmd: if true, ``0.1 * get_rmmd_losses(...)`` is added to every RNN
      loss.
    no_cor: its negation is passed to ``get_dm``.
    w_mmd: weight passed to ``get_vae_rec_ops`` with the VAE MMD losses.
    alpha, beta: forwarded to the MMD loss builders.

  Raises:
    Exception: if ``view`` has no registered wrapper.
  """
  batch_size = batch_size_per_task * n_tasks

  wrapper = WrapperFactory.get_wrapper(view)
  if wrapper is None:
    raise Exception("Such view is not available")

  print("Batch size for each task is", batch_size_per_task)
  print("The total batch size is", batch_size)

  check_dir(model_dir)
  # `with` guarantees the log file is closed, also when training raises.
  with open(model_dir + '/log_%s' % datetime.now().isoformat(), "w") as lf:
    # define env
    na = make_env(config.env_name).action_space.n
    input_size = z_size + na
    output_size = z_size
    print("the environment", config.env_name, "has %i actions" % na)

    seq_len = max_seq_len

    fns = [fn for fn in os.listdir(data_dir) if '.npz' in fn]
    random.shuffle(fns)
    dm = get_dm(wrapper, seq_len, na, data_dir, fns, not no_cor)
    tf_vrct_lr = tf.placeholder(tf.float32,
                                shape=[])  # learn from reconstruction.
    vaes, vcomps = build_vaes(n_tasks, na, z_size, seq_len, tf_vrct_lr,
                              kl_tolerance)
    vae_losses = [vcomp.loss for vcomp in vcomps]
    transform_loss = get_transform_loss(vcomps[0], vaes[1], wrapper)

    # Copy of VAE 0; assign_old_eq_new copies VAE 0's variables into it.
    old_vae0 = ConvVAE(name="old_vae0", z_size=z_size)
    old_vcomp0 = build_vae("old_vae0", old_vae0, na, z_size, seq_len,
                           tf_vrct_lr, kl_tolerance)
    assign_old_eq_new = tf.group([tf.assign(oldv, newv)
                                  for (oldv, newv) in
                                  zip(old_vcomp0.var_list, vcomps[0].var_list)])

    vmmd_losses = get_vmmd_losses(n_tasks, old_vcomp0, vcomps, alpha, beta)
    vrec_ops = get_vae_rec_ops(n_tasks, vcomps, vmmd_losses, w_mmd)
    vrec_all_op = tf.group(vrec_ops)

    # Meta RNN.
    rnn = VRNN("rnn", max_seq_len, input_size, output_size, batch_size_per_task,
               rnn_size, layer_norm, recurrent_dp, input_dp, output_dp)

    global_step = tf.Variable(0, name='global_step', trainable=False)
    tf_rpred_lr = tf.placeholder(tf.float32, shape=[])
    rcomp0 = build_rnn("rnn", rnn, na, z_size, batch_size_per_task, seq_len)

    print("The basic rnn has been built")

    rcomps = build_rnns(n_tasks, rnn, vaes, vcomps, kl_tolerance)
    rnn_losses = [rcomp.loss for rcomp in rcomps]

    if rnn_mmd:
      rmmd_losses = get_rmmd_losses(n_tasks, old_vcomp0, vcomps, alpha, beta)
      for j in range(n_tasks):
        rnn_losses[j] += 0.1 * rmmd_losses[j]

    ptransform_loss = get_predicted_transform_loss(vcomps[0], rcomps[0],
                                                   vaes[1],
                                                   wrapper, batch_size_per_task,
                                                   seq_len)
    print("RNN has been connected to each VAE")

    rnn_total_loss = tf.reduce_mean(rnn_losses)
    rpred_opt = tf.train.AdamOptimizer(tf_rpred_lr, name="rpred_opt")
    gvs = rpred_opt.compute_gradients(rnn_total_loss, rcomp0.var_list)
    # Element-wise clipping; variables without a gradient are dropped.
    clip_gvs = [(tf.clip_by_value(grad, -grad_clip, grad_clip), var) for
                grad, var in gvs if grad is not None]
    rpred_op = rpred_opt.apply_gradients(clip_gvs, global_step=global_step,
                                         name='rpred_op')

    # VAE in prediction phase
    vpred_ops, tf_vpred_lrs = get_vae_pred_ops(n_tasks, vcomps, rnn_losses)
    vpred_all_op = tf.group(vpred_ops)

    rpred_lr = lr
    vrct_lr = v_lr
    vpred_lr = vr_lr
    sess.run(tf.global_variables_initializer())

    for i in range(num_steps):
      step = sess.run(global_step)
      # Geometric decay towards the configured minimum learning rates.
      rpred_lr = (rpred_lr - min_lr) * decay + min_lr
      vrct_lr = (vrct_lr - min_v_lr) * v_decay + min_v_lr
      vpred_lr = (vpred_lr - min_v_lr) * v_decay + min_v_lr

      # Task 1's VAE prediction learning rate is scaled by the ratio of the
      # two tasks' RNN losses measured on the previous prediction step.
      ratio = 1.0

      # Prediction phase: fresh batches, buffered for the replay phase.
      data_buffer = []
      for _ in range(config.psteps_per_it):
        raw_obs_list, raw_a_list = dm.random_batch(batch_size_per_task)
        data_buffer.append((raw_obs_list, raw_a_list))

        feed = {tf_rpred_lr: rpred_lr, tf_vrct_lr: vrct_lr,
                tf_vpred_lrs[0]: vpred_lr,
                tf_vpred_lrs[1]: vpred_lr * ratio}
        _fill_vae_feed(feed, n_tasks, old_vcomp0, vcomps,
                       raw_obs_list, raw_a_list)

        (rnn_cost, rnn_cost2, vae_cost, vae_cost2,
         transform_cost, ptransform_cost, _, _) = sess.run(
          [rnn_losses[0], rnn_losses[1],
           vae_losses[0], vae_losses[1],
           transform_loss, ptransform_loss,
           rpred_op, vpred_all_op], feed)
        ratio = rnn_cost2 / rnn_cost

      if i % config.log_interval == 0:
        lf.write(get_output_log(step, rpred_lr, [vae_cost], [rnn_cost],
                                [transform_cost], [ptransform_cost]))

      # Reconstruction phase: replay the buffered batches pass by pass.
      nd = len(data_buffer)
      data_order = np.arange(nd)
      np.random.shuffle(data_order)

      for it in range(config.rsteps_per_it):
        # Reshuffle at the start of every new pass (not just before its last
        # element), so each pass visits every buffered batch exactly once.
        if it > 0 and it % nd == 0:
          np.random.shuffle(data_order)
        rid = data_order[it % nd]

        raw_obs_list, raw_a_list = data_buffer[rid]

        feed = {tf_rpred_lr: rpred_lr, tf_vrct_lr: vrct_lr}
        _fill_vae_feed(feed, n_tasks, old_vcomp0, vcomps,
                       raw_obs_list, raw_a_list)

        (rnn_cost, rnn_cost2, vae_cost, vae_cost2, transform_cost,
         ptransform_cost, _) = sess.run([
          rnn_losses[0], rnn_losses[1],
          vae_losses[0], vae_losses[1],
          transform_loss, ptransform_loss,
          vrec_all_op], feed)

      if i % config.log_interval == 0:
        lf.write(get_output_log(step, rpred_lr, [vae_cost], [rnn_cost],
                                [transform_cost], [ptransform_cost]))

      lf.flush()

      if (i + 1) % config.target_update_interval == 0:
        sess.run(assign_old_eq_new)

      if i % config.model_save_interval == 0:
        tmp_dir = model_dir + '/it_%i' % i
        check_dir(tmp_dir)
        saveToFlat(rcomp0.var_list, tmp_dir + '/rnn.p')
        for j in range(n_tasks):
          saveToFlat(vcomps[j].var_list, tmp_dir + '/vae%i.p' % j)

  saveToFlat(rcomp0.var_list, model_dir + '/final_rnn.p')
  for j in range(n_tasks):
    saveToFlat(vcomps[j].var_list, model_dir + '/final_vae%i.p' % j)


def _fill_vae_feed(feed, n_tasks, old_vcomp0, vcomps, raw_obs_list,
                   raw_a_list):
  """Add one batch's per-task inputs to ``feed`` in place.

  ``old_vcomp0.x`` gets task 0's observations; each task's VAE gets its own
  observations and its actions with the last entry along axis 1 dropped
  (``[:, :-1, :]``).
  """
  feed[old_vcomp0.x] = raw_obs_list[0]
  for j in range(n_tasks):
    vcomp = vcomps[j]
    feed[vcomp.x] = raw_obs_list[j]
    feed[vcomp.a] = raw_a_list[j][:, :-1, :]
Пример #3
0
def train_root(config):
    """Train the root-joint actor and critic from a replay buffer.

    Weights are restored from ``<saved_model_path>/<dataset>/<name>_root.pkl``
    when those files exist; otherwise a warning is printed and the fresh
    initialisation is kept.  Both models are written back to the same paths
    at the end of every round.

    Each of the ``config['n_rounds']`` rounds adds newly collected samples to
    the buffer, then runs ``config['n_iters']`` blocks of
    ``config['update_iters']`` mini-batch updates.  In each update the actor
    is trained on the critic's action gradient, and the critic is trained
    with the target actor's action for the next state.  Target networks are
    refreshed after every block and the average losses are printed.

    Args:
        config: dict of training settings (dataset name, network layouts,
            learning rate, buffer/batch sizes, iteration counts, save paths).

    Raises:
        ValueError: if ``config['dataset']`` is not 'nyu', 'icvl' or 'mrsa15'.
    """
    dataset_builders = {
        'nyu': lambda: NYUDataset(subset='training',
                                  root_dir='/home/data/nyu/'),
        'icvl': lambda: ICVLDataset(subset='training',
                                    root_dir='/hand_pose_data/icvl/'),
        'mrsa15': lambda: MRSADataset(subset='training',
                                      test_fold=config['mrsa_test_fold'],
                                      root_dir='/hand_pose_data/mrsa15/'),
    }
    build_dataset = dataset_builders.get(config['dataset'])
    if build_dataset is None:
        raise ValueError('Dataset name %s error...' % config['dataset'])
    dataset = build_dataset()

    # Hyper-parameters shared by the actor and the critic.
    shared_kwargs = {
        'tau': config['tau'],
        'lr': config['learning_rate'],
        'obs_dims': config['root_obs_dims'],
    }
    actor_root = Actor(scope='actor_root',
                       cnn_layer=config['root_actor_cnn_layers'],
                       fc_layer=config['root_actor_fc_layers'],
                       **shared_kwargs)
    critic_root = Critic(scope='critic_root',
                         cnn_layer=config['root_critic_cnn_layers'],
                         fc_layer=config['root_critic_fc_layers'],
                         **shared_kwargs)
    env = HandEnv(dataset=config['dataset'],
                  subset='training',
                  iter_per_joint=config['iter_per_joint'],
                  reward_beta=config['beta'])
    root_buffer = ReplayBuffer(buffer_size=config['buffer_size'])
    sampler = Sampler(actor_root, critic_root, None, None, env, dataset)

    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    with tf.Session(config=tf_config) as sess:
        sess.run(tf.global_variables_initializer())
        model_root = config['saved_model_path'] + '/' + config['dataset'] + '/'

        def restore(model, path, loaded_fmt, missing_msg):
            # Load saved weights when available, bind the session, then sync
            # the target network with the (possibly restored) weights.
            if os.path.exists(path):
                utils.loadFromFlat(model.get_trainable_variables(), path)
                print(loaded_fmt % path)
            else:
                print(missing_msg)
            model.load_sess(sess)
            sess.run(model.update_target_ops)

        saved_actor_dir = model_root + config['actor_model_name'] + '_root.pkl'
        restore(actor_root, saved_actor_dir,
                "Actor parameter loaded from %s",
                "[Warning]: initialize actor root model")

        saved_critic_dir = model_root + config['critic_model_name'] + '_root.pkl'
        restore(critic_root, saved_critic_dir,
                "Critic parameter loaded from %s",
                "[Warning]: initialize critic root model")

        for round_no in range(1, config['n_rounds'] + 1):
            print('--------------------------------Round % i---------------------------------' % round_no)
            root_buffer.add(
                sampler.collect_multiple_samples_root(config['files_per_time']))

            for _ in range(config['n_iters']):
                actor_losses, critic_losses = [], []
                for _ in range(config['update_iters']):
                    state, action, reward, new_state, gamma = \
                        root_buffer.get_batch(config['batch_size'])
                    # actor step driven by the critic's action gradient
                    q_gradient = critic_root.get_q_gradient(obs=state, ac=action)
                    _, actor_loss = actor_root.train(q_gradient=q_gradient[0],
                                                     obs=state)
                    # critic step using the target actor's next action
                    next_ac = actor_root.get_target_action(obs=new_state)
                    _, q_loss = critic_root.train(obs=state, ac=action,
                                                  next_obs=new_state,
                                                  next_ac=next_ac,
                                                  r=reward, gamma=gamma)
                    actor_losses.append(actor_loss)
                    critic_losses.append(q_loss)
                sess.run(actor_root.update_target_ops)
                sess.run(critic_root.update_target_ops)
                print('Actor average loss: {:.4f}, Critic: {:.4f}'
                      .format(np.mean(actor_losses), np.mean(critic_losses)))

            utils.saveToFlat(actor_root.get_trainable_variables(), saved_actor_dir)
            utils.saveToFlat(critic_root.get_trainable_variables(), saved_critic_dir)
Пример #4
0
def train(config):
    """Train the actor/critic models on top of a pre-trained pose network.

    The pre-trained ``Pretrain`` weights must exist on disk (``ValueError``
    otherwise).  Each round (0-based) first evaluates on a random subset of
    the test samples every ``config['test_gap']`` rounds, saving the actor
    and critic whenever the mean max error improves.  It then collects new
    experiences into the replay buffer and runs ``config['train_iters']``
    actor/critic updates, syncing the target networks every
    ``config['update_iters']`` global steps.  Metrics go to a
    ``SummaryWriter`` under ``<saved_model_path><dataset>/``.

    Args:
        config: dict of training settings (dataset name, learning rates,
            tau/beta, buffer and batch sizes, round/iteration counts, paths).

    Raises:
        ValueError: if the dataset name is unknown or the pre-trained model
            file is missing.
    """
    bbx = (63, 63, 31)
    dataset_name = config['dataset']
    if dataset_name == 'nyu':
        dataset = NYUDataset(subset='training', root_dir='/hand_pose_data/nyu/',
                             predefined_bbx=bbx)
    elif dataset_name == 'icvl':
        dataset = ICVLDataset(subset='training', root_dir='/hand_pose_data/icvl/',
                              predefined_bbx=bbx)
    elif dataset_name == 'mrsa15':
        # (180, 120, 70), 6 * 21 = 126
        dataset = MRSADataset(subset='training', test_fold=config['mrsa_test_fold'],
                              root_dir='/hand_pose_data/mrsa15/', predefined_bbx=bbx)
    else:
        raise ValueError('Dataset name %s error...' % dataset_name)

    # Every dataset shares one architecture; only the joint count differs.
    n_joints = dataset.jnt_num
    cnn_layer = (8, 16, 32, 64, 128)  # 512
    pre_ac_dim = 3 * n_joints          # pre-trained model output size
    pre_fc_layer = (512, 512, 256)
    ac_dim = 4 * (n_joints - 1)        # actor-critic action size
    actor_fc_layer = (512, 512, 256)
    critic_fc_layer = (ac_dim, 512, 512, 128)
    obs_dims = tuple(dataset.predefined_bbx[k] + 1 for k in (2, 1, 0))

    pretrain_model = Pretrain(scope='pretrain',
                              obs_dims=obs_dims + (1,),
                              cnn_layer=cnn_layer,
                              fc_layer=pre_fc_layer,
                              ac_dim=pre_ac_dim)
    actor = Actor(scope='actor',
                  obs_dims=obs_dims + (2,),
                  ac_dim=ac_dim,
                  cnn_layer=cnn_layer,
                  fc_layer=actor_fc_layer,
                  tau=config['tau'],
                  beta=config['beta'],
                  lr=config['actor_lr'])
    critic = Critic(scope='critic',
                    obs_dims=obs_dims + (2,),
                    ac_dim=ac_dim,
                    cnn_layer=cnn_layer,
                    fc_layer=critic_fc_layer,
                    tau=config['tau'],
                    lr=config['critic_lr'])

    env = HandEnv(dataset_name=dataset_name, subset='training', max_iters=config['max_iters'],
                  predefined_bbx=dataset.predefined_bbx, pretrained_model=pretrain_model,
                  reward_range=config['reward_range'], num_cpus=config['num_cpus'])
    sampler = Sampler(actor, env, dataset, step_size=config['step_size'], gamma=config['gamma'])
    buffer = ReplayBuffer(buffer_size=config['buffer_size'])

    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    with tf.Session(config=tf_config) as sess:
        sess.run(tf.global_variables_initializer())
        root_dir = config['saved_model_path'] + dataset_name + '/'
        writer = SummaryWriter(log_dir=root_dir)
        prefix = root_dir + dataset_name

        # The pre-trained model is mandatory.
        if dataset_name == 'mrsa15':
            model_save_dir = prefix + '_' + config['mrsa_test_fold'] + '_pretrain.pkl'
        else:
            model_save_dir = prefix + '_pretrain.pkl'
        if not os.path.exists(model_save_dir):
            raise ValueError('Model not found from %s' % model_save_dir)
        utils.loadFromFlat(pretrain_model.get_trainable_variables(), model_save_dir)
        print("Pre-train parameter loaded from %s" % model_save_dir)

        # NOTE(review): restoring actor/critic weights is disabled (no
        # loadFromFlat call), yet the "loaded" message is still printed when
        # the file exists; both models keep their fresh initialisation.
        save_actor_dir = prefix + '_actor.pkl'
        print("Actor parameter loaded from %s" % save_actor_dir
              if os.path.exists(save_actor_dir)
              else "[Warning]: initialize the actor model")
        sess.run(actor.update_target_ops)
        actor.load_sess(sess)

        save_critic_dir = prefix + '_critic.pkl'
        print("Critic parameter loaded from %s" % save_critic_dir
              if os.path.exists(save_critic_dir)
              else "[Warning]: initialize critic root model")
        sess.run(critic.update_target_ops)
        critic.load_sess(sess)

        tag = 'RL_' + dataset_name
        best_max_error = 20
        test_examples = sampler.aggregate_test_samples()
        for i in range(config['n_rounds']):
            print('--------------------------------Round % i---------------------------------' % i)
            if i % config['test_gap'] == 0:
                tic = time.time()
                n_eval = min(2 * config['num_batch_samples'], len(test_examples))
                print('>>>number of examples for testing: %i(%i)'
                      % (n_eval, len(test_examples)))
                examples = random.sample(test_examples, n_eval)
                max_error, rs = sampler.test_batch_samples(examples, 8 * config['batch_size'], sess)
                mean_max_error = np.mean(max_error)
                writer.add_histogram(tag + '_final_rewards', rs, i)
                writer.add_histogram(tag + '_max_error', max_error, i)
                writer.add_scalar(tag + '_mean_max_error', mean_max_error, i)
                if best_max_error > mean_max_error:
                    utils.saveToFlat(actor.get_trainable_variables(), save_actor_dir)
                    utils.saveToFlat(critic.get_trainable_variables(), save_critic_dir)
                    best_max_error = mean_max_error
                    print('>>>Model save as %s' % save_actor_dir)
                toc = time.time()
                print('>>>Testing: Average max error {:.2f}, average reward {:.2f}, time used {:.2f}s'
                      .format(mean_max_error, np.mean(rs), toc - tic))

            # sampling
            tic = time.time()
            experiences, rs = sampler.collect_experiences(num_files=config['files_per_time'],
                                                          num_batch_samples=config['num_batch_samples'],
                                                          batch_size=8 * config['batch_size'],
                                                          sess=sess,
                                                          num_cpus=config['num_cpus'])
            buffer.add(experiences)
            toc = time.time()
            print('Sampling: time used %.2fs, buffer size %i' % (toc - tic, buffer.count()))

            # training
            tic = time.time()
            actor_loss_list, q_loss_list = [], []
            for _ in range(config['train_iters']):
                action, reward, gamma, state, new_state = buffer.get_batch(config['batch_size'])
                q_gradient = critic.get_q_gradient(obs=state, ac=action, dropout_prob=1.0)
                _, actor_loss, global_step, actor_acs = actor.train(
                    q_gradient=q_gradient[0], obs=state, dropout_prob=0.5,
                    step_size=config['step_size'])
                next_ac = actor.get_target_action(obs=new_state, dropout_prob=1.0,
                                                  step_size=config['step_size'])
                _, critic_loss = critic.train(obs=state, ac=action, next_obs=new_state, next_ac=next_ac,
                                              r=reward, gamma=gamma, dropout_prob=0.5)
                mean_actor_loss = np.mean(actor_loss)
                actor_loss_list.append(mean_actor_loss)
                q_loss_list.append(critic_loss)
                writer.add_scalar(tag + '_actor_loss', mean_actor_loss, global_step)
                writer.add_scalar(tag + '_critic_loss', critic_loss, global_step)

                if global_step % config['update_iters'] == 0:
                    # sync both target networks
                    sess.run([actor.update_target_ops, critic.update_target_ops])
                    print('Average loss: actor {:.4f}, critic: {:.4f}, training steps: {}, '
                          'average acs {:.4f}, average q-gradients {:.4f}'
                          .format(np.mean(actor_loss_list), np.mean(q_loss_list), global_step,
                                  np.mean(actor_acs), np.mean(q_gradient)))

            toc = time.time()
            print('Training time used: {:.2f}s, training steps: {}'.format(toc - tic, global_step))
        writer.close()
Пример #5
0
def pre_train(config):
    """Supervised pre-training of the ``Pretrain`` joint-regression network.

    Rounds are numbered from 1.  On rounds 1, 1 + test_gap, 1 + 2*test_gap,
    ... the model is evaluated on a fixed test set, and its trainable
    variables are saved with ``utils.saveToFlat`` whenever the test loss
    improves.  Every round then collects fresh training samples from
    ``HandEnv`` and runs ``config['train_iters']`` Adam steps on random
    mini-batches (exponentially decayed learning rate).

    Loss: for every joint the squared Euclidean distance between predicted
    and labelled coordinates (3 per joint), averaged over joints with
    per-joint ``weights`` (root joint weighted by ``config['root_weight']``),
    then averaged over the mini-batch.

    Args:
        config: dict of settings.  Keys read here: ``dataset`` ('nyu',
            'icvl' or 'mrsa15'), ``mrsa_test_fold`` (mrsa15 only),
            ``root_weight``, ``batch_size``, ``lr_start``, ``lr_decay_iters``,
            ``lr_decay_rate``, ``saved_model_path``, ``new_training``,
            ``num_cpus``, ``n_rounds``, ``test_gap``, ``files_per_time``,
            ``samples_per_time`` and ``train_iters``.

    Raises:
        ValueError: if ``config['dataset']`` is not a supported dataset.
    """
    bbx = (63, 63, 31)
    if config['dataset'] == 'nyu':
        dataset = NYUDataset(subset='training',
                             root_dir='/hand_pose_data/nyu/',
                             predefined_bbx=bbx)
        root_joint = 13
    elif config['dataset'] == 'icvl':
        dataset = ICVLDataset(subset='training',
                              root_dir='/hand_pose_data/icvl/',
                              predefined_bbx=bbx)
        root_joint = 0
    elif config['dataset'] == 'mrsa15':
        # (180, 120, 70), 6 * 21 = 126
        dataset = MRSADataset(subset='training',
                              test_fold=config['mrsa_test_fold'],
                              root_dir='/hand_pose_data/mrsa15/',
                              predefined_bbx=bbx)
        root_joint = 0
    else:
        raise ValueError('Dataset name %s error...' % config['dataset'])

    n_joints = dataset.jnt_num
    ac_dim = 3 * n_joints  # 3 coordinates per joint
    weights = np.ones([1, n_joints])
    weights[0, root_joint] = config['root_weight']  # weight root joint error
    cnn_layer = (8, 16, 32, 64, 128)  # 512
    fc_layer = (512, 512, 256)
    print('Loss Weights:', weights)
    obs_dims = (dataset.predefined_bbx[2] + 1, dataset.predefined_bbx[1] + 1,
                dataset.predefined_bbx[0] + 1, 1)
    env = HandEnv(dataset_name=config['dataset'],
                  subset='training',
                  max_iters=5,
                  predefined_bbx=dataset.predefined_bbx,
                  pretrained_model=None)
    scope = 'pre_train'
    batch_size = config['batch_size']

    # define model and loss
    model = Pretrain(scope, obs_dims, cnn_layer, fc_layer,
                     ac_dim)  # model.obs, model.ac, model.dropout_prob
    tf_label = tf.placeholder(shape=(None, ac_dim),
                              dtype=tf.float32,
                              name='action')
    # per-joint loss weights (fed from `weights`)
    tf_weights = tf.placeholder(shape=(1, n_joints),
                                dtype=tf.float32,
                                name='weights')
    # squared Euclidean error of every joint: (batch, n_joints)
    sq_joint_err = tf.reduce_sum(
        tf.reshape(tf.square(model.ac - tf_label), [-1, n_joints, 3]),
        axis=2)
    # weighted average over joints, one value per example
    tf_mse = tf.reduce_mean(tf_weights * sq_joint_err, axis=1)
    tf_loss = tf.reduce_mean(tf_mse)  # average over mini-batch
    # largest per-joint Euclidean error of each example
    tf_max_error = tf.sqrt(tf.reduce_max(sq_joint_err, axis=1))

    global_step = tf.Variable(0, trainable=False, name='step')
    lr = tf.train.exponential_decay(config['lr_start'], global_step,
                                    config['lr_decay_iters'],
                                    config['lr_decay_rate'])
    optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(
        loss=tf_loss, global_step=global_step)

    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    with tf.Session(config=tf_config) as sess:
        sess.run(tf.global_variables_initializer())
        root_dir = config['saved_model_path'] + config['dataset'] + '/'
        writer = SummaryWriter(log_dir=root_dir)

        if config['dataset'] == 'mrsa15':
            model_save_dir = root_dir + config['dataset'] + '_' + config[
                'mrsa_test_fold'] + '_pretrain.pkl'
        else:
            model_save_dir = root_dir + config['dataset'] + '_pretrain.pkl'
        if os.path.exists(model_save_dir) and not config['new_training']:
            utils.loadFromFlat(model.get_trainable_variables(), model_save_dir)
            print("Pre-train parameter loaded from %s" % model_save_dir)
        else:
            print("[Warning]: initialize the pre-train model")

        x_test, y_test, _ = collect_test_samples(env, dataset,
                                                 config['num_cpus'], ac_dim)

        n_test = x_test.shape[0]
        print('test samples %i' % n_test)
        if n_test == 0:
            print('[Warning]: no test samples, evaluation is skipped')
        best_loss = 1000
        for i in range(1, config['n_rounds'] + 1):
            print(
                '--------------------------------Round % i---------------------------------'
                % i)
            # test on rounds 1, 1 + test_gap, ...; `(i - 1) % gap == 0`
            # instead of `i % gap == 1`, which never holds for test_gap == 1
            if n_test > 0 and (i - 1) % config['test_gap'] == 0:
                start_time = time.time()
                loss_list, max_error_list = [], []
                # stepping to n_test avoids an empty trailing batch when
                # n_test is a multiple of batch_size
                for idx1 in range(0, n_test, batch_size):
                    idx2 = min(idx1 + batch_size, n_test)
                    batch_loss, batch_max_error = sess.run(
                        [tf_mse, tf_max_error],
                        feed_dict={
                            model.obs: x_test[idx1:idx2, ...],
                            tf_label: y_test[idx1:idx2],
                            model.dropout_prob: 1.0,
                            tf_weights: weights
                        })
                    loss_list.append(batch_loss)
                    max_error_list.append(batch_max_error)
                test_loss = np.mean(np.hstack(loss_list))
                max_error = np.hstack(max_error_list)
                writer.add_scalar(config['dataset'] + '_test_loss', test_loss,
                                  i)
                writer.add_histogram(config['dataset'] + '_max_error',
                                     max_error, i)
                end_time = time.time()
                print(
                    '>>> Testing loss: {:.4f}, best loss {:.4f}, mean_max_error {:.4f}, time used: {:.2f}s'
                    .format(test_loss, best_loss, np.mean(max_error),
                            end_time - start_time))
                if best_loss > test_loss:
                    utils.saveToFlat(model.get_trainable_variables(),
                                     model_save_dir)
                    best_loss = float(test_loss)
                    print('>>> Model saved... best loss {:.4f}'.format(
                        best_loss))

            # train
            start_time = time.time()
            x_train, y_train = collect_train_samples(
                env, dataset, config['files_per_time'],
                config['samples_per_time'], config['num_cpus'], ac_dim)
            print('Collected samples {}'.format(x_train.shape[0]))
            loss_list = []
            for _ in range(config['train_iters']):
                batch_idx = np.random.randint(0, x_train.shape[0], batch_size)
                _, batch_loss, step = sess.run(
                    [optimizer, tf_loss, global_step],
                    feed_dict={
                        model.obs: x_train[batch_idx, ...],
                        tf_label: y_train[batch_idx],
                        model.dropout_prob: 0.5,
                        tf_weights: weights
                    })
                loss_list.append(batch_loss)
            end_time = time.time()
            writer.add_scalar(config['dataset'] + '_train_loss',
                              np.mean(loss_list), i)
            print(
                'Training loss: {:.4f}, time used: {:.2f}s, step: {:d}'.format(
                    np.mean(loss_list), end_time - start_time, step))
        writer.close()