Exemple #1
0
 def __init__(self,
              process_function,
              logger_name,
              process_count=(multiprocessing.cpu_count() * 2),
              moniter_childprocess_seconds=5,
              process_function_params_dict=None):
     """Set up the worker-process manager and install a SIGTERM handler.

     Args:
         process_function: callable stored (possibly wrapped) as
             ``self.process_function``.
         logger_name: name handed to ``create_logger`` to build
             ``self.logger``.
         process_count: number of worker processes; the default is twice
             ``multiprocessing.cpu_count()``, evaluated once when the method
             is defined.
         moniter_childprocess_seconds: stored as-is; presumably the interval
             for checking child processes -- confirm with the caller.
         process_function_params_dict: optional keyword arguments pre-bound
             to ``process_function`` with ``functools.partial``; ``None`` or
             an empty dict leaves the callable unchanged.

     Note:
         ``signal.signal`` may only be called from the main thread, so this
         object must be constructed there.
     """
     self.process_count = process_count
     self.moniter_childprocess_seconds = moniter_childprocess_seconds
     self._works = []
     self.quit_event = multiprocessing.Event()
     self.logger = create_logger(logger_name)

     bound_kwargs = process_function_params_dict
     if not bound_kwargs:
         self.process_function = process_function
     else:
         # Pre-bind the configured keyword arguments to the target.
         self.process_function = functools.partial(process_function,
                                                   **bound_kwargs)

     # Route SIGTERM to self._quit_worker_process.
     signal.signal(signal.SIGTERM, self._quit_worker_process)
Exemple #2
0
def run_experiment(args):
    """Train a stochastic actor and a value critic on ``args.env_name`` with PPO.

    Builds LSTM or feed-forward networks (chosen by ``args.recurrent``) and
    runs a noisy ``eval_policy`` rollout of ``args.prenormalize_steps``
    timesteps before training (presumably to gather observation-normalization
    statistics -- confirm against ``eval_policy``).  It then alternates
    ``PPO.do_iteration`` with an evaluation rollout until ``args.timesteps``
    environment steps have been taken.

    Whenever the evaluation return improves, the actor and critic are written
    with ``torch.save`` to ``args.save_actor`` / ``args.save_critic`` (each
    only if set).  When logging is enabled (``not args.nolog``), missing paths
    default to ``actor.pt`` / ``critic.pt`` in the logger directory.
    Per-iteration KL, return and losses are then recorded with
    ``logger.add_scalar``.

    Args:
        args: parsed command-line namespace holding the hyperparameters read
            below.  ``args.save_actor`` / ``args.save_critic`` may be filled
            in by this function.
    """
    torch.set_num_threads(1)

    from util.env import env_factory
    from util.log import create_logger

    from policies.critic import FF_V, LSTM_V
    from policies.actor import FF_Stochastic_Actor, LSTM_Stochastic_Actor

    import locale
    import os
    # Use the user's locale so the {:n} format specs below group digits.
    locale.setlocale(locale.LC_ALL, '')

    # Factory for creating (possibly parallel) environment instances.  A
    # single probe instance is enough to read both space dimensions.
    env_fn = env_factory(args.env_name)
    probe_env = env_fn()
    obs_dim = probe_env.observation_space.shape[0]
    action_dim = probe_env.action_space.shape[0]
    del probe_env

    # Set seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    std = torch.ones(action_dim) * args.std

    if args.recurrent:
        policy = LSTM_Stochastic_Actor(obs_dim,
                                       action_dim,
                                       env_name=args.env_name,
                                       fixed_std=std,
                                       bounded=False)
        critic = LSTM_V(obs_dim)
    else:
        policy = FF_Stochastic_Actor(obs_dim,
                                     action_dim,
                                     env_name=args.env_name,
                                     fixed_std=std,
                                     bounded=False)
        critic = FF_V(obs_dim)

    env = env_fn()
    # Pre-training rollout with noise=1; the positional True is passed
    # through to eval_policy (its meaning is defined there).
    eval_policy(policy,
                env,
                True,
                min_timesteps=args.prenormalize_steps,
                max_traj_len=args.traj_len,
                noise=1)

    # train(0): with torch.nn.Module semantics this selects evaluation mode.
    policy.train(0)
    critic.train(0)

    algo = PPO(policy, critic, env_fn, args)

    # create a tensorboard logging object
    if not args.nolog:
        logger = create_logger(args)
    else:
        logger = None

    if args.save_actor is None and logger is not None:
        args.save_actor = os.path.join(logger.dir, 'actor.pt')

    if args.save_critic is None and logger is not None:
        args.save_critic = os.path.join(logger.dir, 'critic.pt')

    print()
    print("Proximal Policy Optimization:")
    print("\tseed:               {}".format(args.seed))
    print("\tenv:                {}".format(args.env_name))
    print("\ttimesteps:          {:n}".format(int(args.timesteps)))
    print("\titeration steps:    {:n}".format(int(args.num_steps)))
    print("\tprenormalize steps: {}".format(int(args.prenormalize_steps)))
    print("\ttraj_len:           {}".format(args.traj_len))
    print("\tdiscount:           {}".format(args.discount))
    print("\tactor_lr:           {}".format(args.a_lr))
    print("\tcritic_lr:          {}".format(args.c_lr))
    print("\tadam eps:           {}".format(args.eps))
    print("\tentropy coeff:      {}".format(args.entropy_coeff))
    print("\tgrad clip:          {}".format(args.grad_clip))
    print("\tbatch size:         {}".format(args.batch_size))
    print("\tepochs:             {}".format(args.epochs))
    print("\tworkers:            {}".format(args.workers))
    print()

    itr = 0
    timesteps = 0
    best_reward = None
    while timesteps < args.timesteps:
        kl, a_loss, c_loss, steps = algo.do_iteration(
            args.num_steps,
            args.traj_len,
            args.epochs,
            batch_size=args.batch_size,
            kl_thresh=args.kl)
        eval_reward = eval_policy(algo.actor,
                                  env,
                                  False,
                                  min_timesteps=args.traj_len * 5,
                                  max_traj_len=args.traj_len,
                                  verbose=False)

        timesteps += steps
        print("iter {:4d} | return: {:5.2f} | KL {:5.4f} | timesteps {:n}".
              format(itr, eval_reward, kl, timesteps))

        if best_reward is None or eval_reward > best_reward:
            # Only mention a save path when one is actually configured.
            if args.save_actor is not None:
                print("\t(best policy so far! saving to {})".format(
                    args.save_actor))
            else:
                print("\t(best policy so far!)")
            best_reward = eval_reward
            if args.save_actor is not None:
                torch.save(algo.actor, args.save_actor)

            if args.save_critic is not None:
                torch.save(algo.critic, args.save_critic)

        if logger is not None:
            logger.add_scalar(args.env_name + '/kl', kl, timesteps)
            logger.add_scalar(args.env_name + '/return', eval_reward,
                              timesteps)
            logger.add_scalar(args.env_name + '/actor loss', a_loss, timesteps)
            logger.add_scalar(args.env_name + '/critic loss', c_loss,
                              timesteps)
        itr += 1
    print("Finished ({} of {}).".format(timesteps, args.timesteps))
Exemple #3
0
    def __init__(self, args):
        """Build this object's logger from ``args`` via ``create_logger``."""
        logger = create_logger(args)
        self.logger = logger
Exemple #4
0
def _load_or_collect_qbn_data(policy, args, actor_dir, layertype):
    """Return the QBN training and test datasets for ``policy``.

    If ``train_states.pt`` already exists in ``actor_dir``, every split is
    loaded from there.  Otherwise ``args.workers`` ray tasks of
    ``collect_data`` gather rollouts.  The stacked arrays are split 80/20
    into train/test, and each tensor is cached in ``actor_dir`` for later
    runs.

    Args:
        policy: the recurrent policy being discretized.
        args: experiment namespace (``dataset`` and ``workers`` are read).
        actor_dir: directory containing the policy file; used as the cache.
        layertype: ``'LSTMCell'`` or ``'GRUCell'``; cell states are only
            loaded/collected for LSTM policies.

    Returns:
        ``(train, test)``, each a tuple ``(states, actions, hiddens, cells)``
        where ``cells`` is ``None`` for non-LSTM policies.
    """
    names = ['states', 'actions', 'hiddens']
    if layertype == 'LSTMCell':
        names.append('cells')

    def cache_path(split, name):
        return os.path.join(actor_dir, '{}_{}.pt'.format(split, name))

    if os.path.exists(cache_path('train', 'states')):
        train = [torch.load(cache_path('train', name)) for name in names]
        test = [torch.load(cache_path('test', name)) for name in names]
    else:
        start = time.time()
        # One collection task per worker, each with a random seed.
        # NOTE(review): the constant 400 is passed through unexplained, and
        # args.dataset / args.workers is a float -- confirm collect_data
        # expects that.
        data = ray.get([
            collect_data.remote(policy, args.dataset / args.workers, 400,
                                np.random.randint(65535))
            for _ in range(args.workers)
        ])
        # Each result r holds states, actions, hiddens and (LSTM only) cells
        # at indices 0..3.
        full = [
            torch.from_numpy(np.vstack([r[idx] for r in data]))
            for idx in range(len(names))
        ]

        split = int(0.8 * len(full[0]))  # 80/20 train test split
        train = [tensor[:split] for tensor in full]
        test = [tensor[split:] for tensor in full]

        print(
            "{:3.2f} to collect {} timesteps.  Training set is {}, test set is {}"
            .format(time.time() - start, len(full[0]), len(train[0]),
                    len(test[0])))

        for name, tensor in zip(names, train):
            torch.save(tensor, cache_path('train', name))
        for name, tensor in zip(names, test):
            torch.save(tensor, cache_path('test', name))

    if len(names) == 3:  # no cell states for non-LSTM policies
        train.append(None)
        test.append(None)
    return tuple(train), tuple(test)


def _qbn_evaluation_suite(policy, layertype, obs_qbn, hidden_qbn, cell_qbn,
                          action_qbn):
    """Evaluate ``policy`` with all QBNs inserted, then with each QBN alone.

    ``evaluate`` is called in this order:
      1. all QBNs together;
      2. cell-only (LSTM policies only);
      3. hidden-only;
      4. observation-only;
      5. action-only.

    Returns:
        ``(d_reward, c_reward, h_reward, s_reward, a_reward, s_states,
        h_states, c_states, a_states)``.  ``d_reward`` is the reward with
        every QBN inserted, ``c_reward`` is 0.0 for non-LSTM policies, and
        the state counts come from the all-QBN evaluation.
    """
    d_reward, s_states, h_states, c_states, a_states = evaluate(
        policy,
        obs_qbn=obs_qbn,
        hid_qbn=hidden_qbn,
        cel_qbn=cell_qbn,
        act_qbn=action_qbn)
    c_reward = 0.0
    if layertype == 'LSTMCell':
        c_reward, _, _, _, _ = evaluate(policy,
                                        obs_qbn=None,
                                        hid_qbn=None,
                                        cel_qbn=cell_qbn,
                                        act_qbn=None)
    h_reward, _, _, _, _ = evaluate(policy,
                                    obs_qbn=None,
                                    hid_qbn=hidden_qbn,
                                    cel_qbn=None,
                                    act_qbn=None)
    s_reward, _, _, _, _ = evaluate(policy,
                                    obs_qbn=obs_qbn,
                                    hid_qbn=None,
                                    cel_qbn=None,
                                    act_qbn=None)
    a_reward, _, _, _, _ = evaluate(policy,
                                    obs_qbn=None,
                                    hid_qbn=None,
                                    cel_qbn=None,
                                    act_qbn=action_qbn)
    return (d_reward, c_reward, h_reward, s_reward, a_reward, s_states,
            h_states, c_states, a_states)


def _save_qbns(save_dir, layertype, obs_qbn, hidden_qbn, cell_qbn):
    """Save the obs/hidden (and, for LSTM policies, cell) QBNs to ``save_dir``."""
    torch.save(obs_qbn, os.path.join(save_dir, 'obsqbn.pt'))
    torch.save(hidden_qbn, os.path.join(save_dir, 'hidqbn.pt'))
    if layertype == 'LSTMCell':
        torch.save(cell_qbn, os.path.join(save_dir, 'celqbn.pt'))
    # NOTE(review): action_qbn is trained and included in the "best" reward
    # but is not saved here -- confirm whether it should be.


def run_experiment(args):
    """Entry point for the QBN insertion algorithm.

    Called by r2l.py with an ``args`` namespace holding the experiment's
    hyperparameters.  Loads the single-layer recurrent policy at
    ``args.policy`` and creates QBNs for its observations, hidden state,
    actions and (LSTM only) cell state.

    Training has two phases:
      1. ``args.epochs`` epochs in which each QBN learns to reconstruct a
         dataset of rollout data.
      2. ``args.iterations`` finetuning iterations in which the QBN-wrapped
         policy's actions are pushed toward those of the unmodified policy.

    Whenever the combined evaluation reward improves, the QBNs are saved to
    the logger directory.  Nothing is saved when ``args.nolog`` is set.

    Raises:
        SystemExit: if ``args.policy`` is not given.
        NotImplementedError: if the policy is not a single-layer LSTMCell or
            GRUCell network.
    """
    import sys

    locale.setlocale(locale.LC_ALL, '')

    from util.env import env_factory
    from util.log import create_logger

    if args.policy is None:
        print("You must provide a policy with --policy.")
        sys.exit(1)

    policy = torch.load(args.policy)  # load policy to be discretized

    # QBN insertion is only implemented for a single recurrent cell.
    layertype = policy.layers[0].__class__.__name__
    if layertype not in ('LSTMCell', 'GRUCell'):
        print("Cannot do QBN insertion on a non-recurrent policy.")
        raise NotImplementedError
    if len(policy.layers) > 1:
        print(
            "Cannot do QBN insertion on a policy with more than one hidden layer."
        )
        raise NotImplementedError
    is_lstm = layertype == 'LSTMCell'

    # Retrieve the dimensions each QBN encodes; one probe environment is
    # enough to read both space shapes.
    env_fn = env_factory(policy.env_name)
    probe_env = env_fn()
    obs_dim = probe_env.observation_space.shape[0]
    action_dim = probe_env.action_space.shape[0]
    del probe_env
    hidden_dim = policy.layers[0].hidden_size

    # QBN layer sizes come from a comma-separated command-line argument.
    layers = [int(x) for x in args.layers.split(',')]

    obs_qbn = QBN(obs_dim, layers=layers)
    hidden_qbn = QBN(hidden_dim, layers=layers)
    action_qbn = QBN(action_dim, layers=layers)
    cell_qbn = QBN(hidden_dim, layers=layers) if is_lstm else None

    obs_optim = optim.Adam(obs_qbn.parameters(), lr=args.lr, eps=1e-6)
    hidden_optim = optim.Adam(hidden_qbn.parameters(), lr=args.lr, eps=1e-6)
    action_optim = optim.Adam(action_qbn.parameters(), lr=args.lr, eps=1e-6)
    if is_lstm:
        cell_optim = optim.Adam(cell_qbn.parameters(), lr=args.lr, eps=1e-6)

    best_reward = None

    if not args.nolog:
        logger = create_logger(args)
    else:
        logger = None
    # QBN checkpoints go to the logger directory; without a logger they are
    # not saved.
    save_dir = logger.dir if logger is not None else None

    actor_dir = os.path.split(args.policy)[0]

    if not ray.is_initialized():
        ray.init()

    # Baseline: evaluate the policy without any QBNs inserted.
    n_reward, _, _, _, _ = evaluate(policy, episodes=20)
    if logger is not None:
        logger.add_scalar(policy.env_name + '_qbn/nominal_reward', n_reward, 0)

    train_data, test_data = _load_or_collect_qbn_data(policy, args, actor_dir,
                                                      layertype)
    train_states, train_actions, train_hiddens, train_cells = train_data
    test_states, test_actions, test_hiddens, test_cells = test_data

    # Phase 1: unsupervised training.  Each QBN minimizes half the mean
    # squared error between its input and its reconstruction.
    for epoch in range(args.epochs):
        random_indices = SubsetRandomSampler(range(train_states.shape[0]))
        sampler = BatchSampler(random_indices,
                               args.batch_size,
                               drop_last=False)

        for i, batch in enumerate(sampler):
            batch_states = train_states[batch]
            batch_actions = train_actions[batch]
            batch_hiddens = train_hiddens[batch]
            if is_lstm:
                batch_cells = train_cells[batch]

            obs_loss = 0.5 * (batch_states -
                              obs_qbn(batch_states)).pow(2).mean()
            hid_loss = 0.5 * (batch_hiddens -
                              hidden_qbn(batch_hiddens)).pow(2).mean()
            act_loss = 0.5 * (batch_actions -
                              action_qbn(batch_actions)).pow(2).mean()
            if is_lstm:
                cel_loss = 0.5 * (batch_cells -
                                  cell_qbn(batch_cells)).pow(2).mean()

            obs_optim.zero_grad()
            obs_loss.backward()
            obs_optim.step()

            hidden_optim.zero_grad()
            hid_loss.backward()
            hidden_optim.step()

            action_optim.zero_grad()
            act_loss.backward()
            action_optim.step()

            if is_lstm:
                cell_optim.zero_grad()
                cel_loss.backward()
                cell_optim.step()

            print("epoch {:3d} / {:3d}, batch {:3d} / {:3d}".format(
                epoch + 1, args.epochs, i + 1, len(sampler)),
                  end='\r')

        # Reconstruction losses on the held-out test split.
        with torch.no_grad():
            state_loss = 0.5 * (test_states -
                                obs_qbn(test_states)).pow(2).mean()
            hidden_loss = 0.5 * (test_hiddens -
                                 hidden_qbn(test_hiddens)).pow(2).mean()
            action_loss = 0.5 * (test_actions -
                                 action_qbn(test_actions)).pow(2).mean()
            if is_lstm:
                cell_loss = 0.5 * (test_cells -
                                   cell_qbn(test_cells)).pow(2).mean()

        print("\nEvaluating...")
        (d_reward, c_reward, h_reward, s_reward, a_reward, s_states, h_states,
         c_states, a_states) = _qbn_evaluation_suite(policy, layertype,
                                                     obs_qbn, hidden_qbn,
                                                     cell_qbn, action_qbn)

        if best_reward is None or d_reward > best_reward:
            best_reward = d_reward
            if save_dir is not None:
                _save_qbns(save_dir, layertype, obs_qbn, hidden_qbn, cell_qbn)

        if is_lstm:
            print("Losses: {:7.5f} {:7.5f} {:7.5f}".format(
                state_loss, hidden_loss, cell_loss))
            print("States: {:5d} {:5d} {:5d}".format(s_states, h_states,
                                                     c_states))
            print(
                "QBN reward: {:5.1f} ({:5.1f}, {:5.1f}, {:5.1f}, {:5.1f}) | Nominal reward {:5.0f} "
                .format(d_reward, h_reward, s_reward, c_reward, a_reward,
                        n_reward))
        else:
            print("Losses: {:7.5f} {:7.5f}".format(state_loss, hidden_loss))
            print("States: {:5d} {:5d} ".format(s_states, h_states))
            print(
                "QBN reward: {:5.1f} ({:5.1f}, {:5.1f}, {:5.1f}) | Nominal reward {:5.0f} "
                .format(d_reward, h_reward, s_reward, a_reward, n_reward))

        if logger is not None:
            logger.add_scalar(policy.env_name + '_qbn/obs_loss', state_loss,
                              epoch)
            logger.add_scalar(policy.env_name + '_qbn/hidden_loss',
                              hidden_loss, epoch)
            logger.add_scalar(policy.env_name + '_qbn/action_loss',
                              action_loss, epoch)
            logger.add_scalar(policy.env_name + '_qbn/qbn_reward', d_reward,
                              epoch)
            if is_lstm:
                logger.add_scalar(policy.env_name + '_qbn/cell_loss',
                                  cell_loss, epoch)
                logger.add_scalar(policy.env_name + '_qbn/cellonly_reward',
                                  c_reward, epoch)
                logger.add_scalar(policy.env_name + '_qbn/cell_states',
                                  c_states, epoch)

            logger.add_scalar(policy.env_name + '_qbn/obsonly_reward',
                              s_reward, epoch)
            logger.add_scalar(policy.env_name + '_qbn/hiddenonly_reward',
                              h_reward, epoch)
            logger.add_scalar(policy.env_name + '_qbn/actiononly_reward',
                              a_reward, epoch)

            logger.add_scalar(policy.env_name + '_qbn/observation_states',
                              s_states, epoch)
            logger.add_scalar(policy.env_name + '_qbn/hidden_states', h_states,
                              epoch)
            logger.add_scalar(policy.env_name + '_qbn/action_states', a_states,
                              epoch)

    print("Training phase over. Beginning finetuning.")

    # NOTE(review): only action_optim is stepped during finetuning.  It is
    # the phase-1 optimizer, not a fresh one.  obs_qbn receives gradients
    # from the finetuning loss but is never updated, and hidden/cell QBN
    # insertion in the rollout below is commented out.  This looks like an
    # in-progress experiment; confirm which QBNs should be finetuned (fresh
    # Adam optimizers would be needed for the others).
    optims = [action_optim]

    # Phase 2: finetune so the QBN-wrapped policy reproduces the nominal
    # policy's actions along the nominal policy's own trajectories.
    for fine_iter in range(args.iterations):
        step_losses = []
        for _ in range(args.episodes):
            env = env_fn()
            state = torch.as_tensor(env.reset())

            done = False
            traj_len = 0
            if hasattr(policy, 'init_hidden_state'):
                policy.init_hidden_state()

            while not done and traj_len < args.traj_len:
                with torch.no_grad():
                    state = torch.as_tensor(state).float()

                # Remember the recurrent state so the QBN forward pass below
                # does not advance the trajectory's memory.
                hidden = policy.hidden[0]
                #policy.hidden = [hidden_qbn(hidden)]

                if is_lstm:
                    cell = policy.cells[0]
                    #policy.cells = [cell_qbn(cell)]

                # Action with the observation and action QBNs inserted.
                qbn_action = action_qbn(policy(obs_qbn(state)))

                # Restore the recurrent state and compute the reference
                # action from the unmodified policy.
                with torch.no_grad():
                    policy.hidden = [hidden]
                    if is_lstm:
                        policy.cells = [cell]
                    action = policy(state)

                state, _, done, _ = env.step(action.numpy())
                traj_len += 1

                # Keeps the graph through qbn_action for the backward pass.
                step_losses.append(0.5 * (action - qbn_action).pow(2))

        for opt in optims:
            opt.zero_grad()

        finetune_loss = torch.stack(step_losses).mean()
        finetune_loss.backward()

        for opt in optims:
            opt.step()

        print("\nEvaluating...")
        (d_reward, c_reward, h_reward, s_reward, a_reward, s_states, h_states,
         c_states, a_states) = _qbn_evaluation_suite(policy, layertype,
                                                     obs_qbn, hidden_qbn,
                                                     cell_qbn, action_qbn)

        print("Finetuning loss: {:7.5f}".format(finetune_loss.item()))
        if is_lstm:
            print("States: {:5d} {:5d} {:5d}".format(s_states, h_states,
                                                     c_states))
            print(
                "QBN reward: {:5.1f} ({:5.1f}, {:5.1f}, {:5.1f}, {:5.1f}) | Nominal reward {:5.0f} "
                .format(d_reward, h_reward, s_reward, c_reward, a_reward,
                        n_reward))
        else:
            print("States: {:5d} {:5d} ".format(s_states, h_states))
            print(
                "QBN reward: {:5.1f} ({:5.1f}, {:5.1f}, {:5.1f}) | Nominal reward {:5.0f} "
                .format(d_reward, h_reward, s_reward, a_reward, n_reward))

        # Continue the logging step axis after the last training epoch.
        step = args.epochs + fine_iter
        if logger is not None:
            logger.add_scalar(policy.env_name + '_qbn/finetune_loss',
                              finetune_loss.item(), step)
            logger.add_scalar(policy.env_name + '_qbn/qbn_reward', d_reward,
                              step)
            if is_lstm:
                logger.add_scalar(policy.env_name + '_qbn/cellonly_reward',
                                  c_reward, step)
                logger.add_scalar(policy.env_name + '_qbn/cell_states',
                                  c_states, step)

            logger.add_scalar(policy.env_name + '_qbn/obsonly_reward',
                              s_reward, step)
            logger.add_scalar(policy.env_name + '_qbn/hiddenonly_reward',
                              h_reward, step)
            logger.add_scalar(policy.env_name + '_qbn/actiononly_reward',
                              a_reward, step)

            logger.add_scalar(policy.env_name + '_qbn/observation_states',
                              s_states, step)
            logger.add_scalar(policy.env_name + '_qbn/hidden_states', h_states,
                              step)
            logger.add_scalar(policy.env_name + '_qbn/action_states', a_states,
                              step)

        if best_reward is None or d_reward > best_reward:
            best_reward = d_reward
            if save_dir is not None:
                _save_qbns(save_dir, layertype, obs_qbn, hidden_qbn, cell_qbn)
Exemple #5
0
def run_experiment(args):
  """Train a stochastic actor and a value critic on ``args.env`` with PPO.

  ``args.arch`` selects the network family ('lstm', 'gru', 'qbngru' or 'ff',
  case-insensitive), and ``args.layers`` gives its comma-separated layer
  sizes.  Observation-normalization statistics are gathered with
  ``train_normalizer``, then ``PPO.do_iteration`` runs until
  ``args.timesteps`` environment steps have been taken.

  The actor/critic with the best return so far are written with
  ``torch.save`` to ``args.save_actor`` / ``args.save_critic`` (each only if
  set).  When logging is enabled (``not args.nolog``), missing paths default
  to ``actor.pt`` / ``critic.pt`` in the logger directory, and per-iteration
  statistics are recorded with ``logger.add_scalar``.

  Raises:
    RuntimeError: if ``args.arch`` is not one of the supported names.
  """
  from util.env import env_factory, train_normalizer
  from util.log import create_logger

  from policies.critic import FF_V, LSTM_V, GRU_V
  from policies.actor import FF_Stochastic_Actor, LSTM_Stochastic_Actor, GRU_Stochastic_Actor, QBN_GRU_Stochastic_Actor

  import locale
  import os
  # Use the user's locale so the {:n} format specs below group digits.
  locale.setlocale(locale.LC_ALL, '')

  # Factory for creating (possibly parallel) environment instances.  A
  # single probe instance is enough to read both space dimensions.
  env_fn = env_factory(args.env)
  probe_env = env_fn()
  obs_dim = probe_env.observation_space.shape[0]
  action_dim = probe_env.action_space.shape[0]
  del probe_env

  # Set seeds
  torch.manual_seed(args.seed)
  np.random.seed(args.seed)

  std = torch.ones(action_dim)*args.std

  layers = [int(x) for x in args.layers.split(',')]

  arch = args.arch.lower()
  if arch == 'lstm':
    policy = LSTM_Stochastic_Actor(obs_dim, action_dim, env_name=args.env, fixed_std=std, bounded=False, layers=layers)
    critic = LSTM_V(obs_dim, layers=layers)
  elif arch == 'gru':
    policy = GRU_Stochastic_Actor(obs_dim, action_dim, env_name=args.env, fixed_std=std, bounded=False, layers=layers)
    critic = GRU_V(obs_dim, layers=layers)
  elif arch == 'qbngru':
    policy = QBN_GRU_Stochastic_Actor(obs_dim, action_dim, env_name=args.env, fixed_std=std, bounded=False, layers=layers)
    critic = GRU_V(obs_dim, layers=layers)
  elif arch == 'ff':
    policy = FF_Stochastic_Actor(obs_dim, action_dim, env_name=args.env, fixed_std=std, bounded=False, layers=layers)
    critic = FF_V(obs_dim, layers=layers)
  else:
    raise RuntimeError(
        "Unrecognized arch '{}': expected one of lstm, gru, qbngru, ff".format(args.arch))
  policy.legacy = False

  print("Collecting normalization statistics with {} states...".format(args.prenormalize_steps))
  train_normalizer(policy, args.prenormalize_steps, max_traj_len=args.traj_len, noise=1)

  # Presumably copies the actor's normalization statistics into the critic.
  critic.copy_normalizer_stats(policy)

  # train(0): with torch.nn.Module semantics this selects evaluation mode.
  policy.train(0)
  critic.train(0)

  algo = PPO(policy, critic, env_fn, args)

  # create a tensorboard logging object
  if not args.nolog:
    logger = create_logger(args)
  else:
    logger = None

  if args.save_actor is None and logger is not None:
    args.save_actor = os.path.join(logger.dir, 'actor.pt')

  if args.save_critic is None and logger is not None:
    args.save_critic = os.path.join(logger.dir, 'critic.pt')

  print()
  print("Proximal Policy Optimization:")
  print("\tseed:               {}".format(args.seed))
  print("\tenv:                {}".format(args.env))
  print("\ttimesteps:          {:n}".format(int(args.timesteps)))
  print("\titeration steps:    {:n}".format(int(args.num_steps)))
  print("\tprenormalize steps: {}".format(int(args.prenormalize_steps)))
  print("\ttraj_len:           {}".format(args.traj_len))
  print("\tdiscount:           {}".format(args.discount))
  print("\tactor_lr:           {}".format(args.a_lr))
  print("\tcritic_lr:          {}".format(args.c_lr))
  print("\tadam eps:           {}".format(args.eps))
  print("\tentropy coeff:      {}".format(args.entropy_coeff))
  print("\tgrad clip:          {}".format(args.grad_clip))
  print("\tbatch size:         {}".format(args.batch_size))
  print("\tepochs:             {}".format(args.epochs))
  print("\tworkers:            {}".format(args.workers))
  print()

  itr = 0
  timesteps = 0
  best_reward = None
  while timesteps < args.timesteps:
    eval_reward, kl, a_loss, c_loss, m_loss, s_loss, steps, times = algo.do_iteration(
        args.num_steps, args.traj_len, args.epochs, batch_size=args.batch_size,
        kl_thresh=args.kl, mirror=args.mirror)

    timesteps += steps
    print("iter {:4d} | return: {:5.2f} | KL {:5.4f} | ".format(itr, eval_reward, kl), end='')
    # Mirror / sparsity losses are only shown when non-zero.
    if m_loss != 0:
      print("mirror {:6.5f} | ".format(m_loss), end='')

    if s_loss != 0:
      print("sparsity {:6.5f} | ".format(s_loss), end='')

    print("timesteps {:n}".format(timesteps))

    if best_reward is None or eval_reward > best_reward:
      # Only mention a save path when one is actually configured.
      if args.save_actor is not None:
        print("\t(best policy so far! saving to {})".format(args.save_actor))
      else:
        print("\t(best policy so far!)")
      best_reward = eval_reward
      if args.save_actor is not None:
        torch.save(algo.actor, args.save_actor)

      if args.save_critic is not None:
        torch.save(algo.critic, args.save_critic)

    if logger is not None:
      logger.add_scalar(args.env + '/kl', kl, timesteps)
      logger.add_scalar(args.env + '/return', eval_reward, timesteps)
      logger.add_scalar(args.env + '/actor loss', a_loss, timesteps)
      logger.add_scalar(args.env + '/critic loss', c_loss, timesteps)
      logger.add_scalar(args.env + '/mirror loss', m_loss, timesteps)
      logger.add_scalar(args.env + '/sparsity loss', s_loss, timesteps)
      logger.add_scalar(args.env + '/sample rate', times[0], timesteps)
      logger.add_scalar(args.env + '/update time', times[1], timesteps)
    itr += 1
  print("Finished ({} of {}).".format(timesteps, args.timesteps))
Exemple #6
0
def run_experiment(args):
    """Train a policy with synchronous distributed PPO on ``args.env_name``.

    Setup steps:
      * Build the environment factory from the env-related ``args`` fields.
      * Start ray unless it is already running.
      * Either load ``actor.pt`` / ``critic.pt`` from ``args.previous``, or
        create a new LSTM or feed-forward Gaussian actor and matching
        critic.  New networks get their observation mean/std from
        ``get_normalization_params``.

    A logger is always created; its directory is passed to ``PPO`` as the
    save path, and ``algo.train`` then runs for ``args.n_itr`` iterations.
    """
    from util.env import env_factory
    from util.log import create_logger

    # wrapper function for creating parallelized envs
    env_fn = env_factory(args.env_name,
                         simrate=args.simrate,
                         command_profile=args.command_profile,
                         input_profile=args.input_profile,
                         learn_gains=args.learn_gains,
                         dynamics_randomization=args.dyn_random,
                         reward=args.reward,
                         history=args.history,
                         mirror=args.mirror,
                         ik_baseline=args.ik_baseline,
                         no_delta=args.no_delta,
                         traj=args.traj)
    # A single probe instance is enough to read both space dimensions.
    probe_env = env_fn()
    obs_dim = probe_env.observation_space.shape[0]
    action_dim = probe_env.action_space.shape[0]
    del probe_env

    # Set up Parallelism
    # NOTE(review): presumably set so that ray worker processes started
    # below inherit a single OpenMP thread -- confirm.
    os.environ['OMP_NUM_THREADS'] = '1'
    if not ray.is_initialized():
        if args.redis_address is not None:
            ray.init(num_cpus=args.num_procs, redis_address=args.redis_address)
        else:
            ray.init(num_cpus=args.num_procs)

    # Set seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.previous is not None:
        policy = torch.load(os.path.join(args.previous, "actor.pt"))
        critic = torch.load(os.path.join(args.previous, "critic.pt"))
        # TODO: add ability to load previous hyperparameters, if this is something that we event want
        # with open(args.previous + "experiment.pkl", 'rb') as file:
        #     args = pickle.loads(file.read())
        print("loaded model from {}".format(args.previous))
    else:
        if args.recurrent:
            # NOTE(review): the recurrent actor uses a fixed std of exp(-2)
            # regardless of args.std_dev / args.learn_stddev.
            policy = Gaussian_LSTM_Actor(obs_dim,
                                         action_dim,
                                         fixed_std=np.exp(-2),
                                         env_name=args.env_name)
            critic = LSTM_V(obs_dim)
        else:
            if args.learn_stddev:
                policy = Gaussian_FF_Actor(obs_dim,
                                           action_dim,
                                           fixed_std=None,
                                           env_name=args.env_name,
                                           bounded=args.bounded)
            else:
                policy = Gaussian_FF_Actor(obs_dim,
                                           action_dim,
                                           fixed_std=np.exp(args.std_dev),
                                           env_name=args.env_name,
                                           bounded=args.bounded)
            critic = FF_V(obs_dim)

        with torch.no_grad():
            policy.obs_mean, policy.obs_std = map(
                torch.Tensor,
                get_normalization_params(iter=args.input_norm_steps,
                                         noise_std=1,
                                         policy=policy,
                                         env_fn=env_fn,
                                         procs=args.num_procs))
        # The critic is given the same normalization tensors as the actor.
        critic.obs_mean = policy.obs_mean
        critic.obs_std = policy.obs_std

    policy.train()
    critic.train()

    print("obs_dim: {}, action_dim: {}".format(obs_dim, action_dim))

    # create a tensorboard logging object
    logger = create_logger(args)

    algo = PPO(args=vars(args), save_path=logger.dir)

    print()
    print("Synchronous Distributed Proximal Policy Optimization:")
    print(" ├ recurrent:      {}".format(args.recurrent))
    print(" ├ run name:       {}".format(args.run_name))
    print(" ├ seed:           {}".format(args.seed))
    print(" ├ num procs:      {}".format(args.num_procs))
    print(" ├ lr:             {}".format(args.lr))
    print(" ├ eps:            {}".format(args.eps))
    print(" ├ lam:            {}".format(args.lam))
    print(" ├ gamma:          {}".format(args.gamma))
    print(" ├ learn stddev:   {}".format(args.learn_stddev))
    print(" ├ std_dev:        {}".format(args.std_dev))
    print(" ├ entropy coeff:  {}".format(args.entropy_coeff))
    print(" ├ clip:           {}".format(args.clip))
    print(" ├ minibatch size: {}".format(args.minibatch_size))
    print(" ├ epochs:         {}".format(args.epochs))
    print(" ├ num steps:      {}".format(args.num_steps))
    print(" ├ use gae:        {}".format(args.use_gae))
    print(" ├ max grad norm:  {}".format(args.max_grad_norm))
    print(" └ max traj len:   {}".format(args.max_traj_len))
    print()

    algo.train(env_fn,
               policy,
               critic,
               args.n_itr,
               logger=logger,
               anneal_rate=args.anneal)
Exemple #7
0
def run_experiment(args):
    """Train a policy with Augmented Random Search (ARS).

    Builds environment/policy factories, repeatedly calls ``ARS.step`` until
    ``args.timesteps`` environment samples have been consumed, evaluates the
    current policy after every step (average over ``num_eval_rollouts``
    rollouts), logs that value under ``'eval'`` and saves the policy to
    ``<logger.dir>/actor.pt``.

    Args:
        args: parsed command-line namespace.  Reads ``env_name``,
            ``load_model``, ``recurrent``, ``hidden_size``, ``seed``,
            ``timesteps``, ``std``, ``deltas``, ``lr``, ``reward_shift``,
            ``workers``, ``redis``, ``algo`` ('v1' or 'v2'), ``traj_len`` and
            ``average_every``.  ``args.save_model`` is overwritten.

    Exits the process with status 1 if ``args.algo`` is not 'v1' or 'v2'.
    """
    import locale
    import sys

    # Validate flags before creating environments or ARS workers so a bad
    # value fails immediately instead of after the ARS setup.
    if args.algo not in ('v1', 'v2'):
        print("Valid arguments for --algo are 'v1' and 'v2'")
        sys.exit(1)
    # 'v2' feeds the policy normalized observations; 'v1' uses raw ones.
    normalize_states = args.algo == 'v2'

    # Number of evaluation rollouts averaged after each ARS step.
    num_eval_rollouts = 10

    # wrapper function for creating parallelized envs
    env_thunk = env_factory(args.env_name)
    with env_thunk() as env:
        obs_space = env.observation_space.shape[0]
        act_space = env.action_space.shape[0]

    # wrapper function for creating parallelized policies
    def policy_thunk():
        """Return the policy from ``args.load_model`` or a new zero-initialized one."""
        from rl.policies.actor import LSTM_Actor, Linear_Actor
        if args.load_model is not None:
            return torch.load(args.load_model)

        actor_cls = LSTM_Actor if args.recurrent else Linear_Actor
        policy = actor_cls(obs_space,
                           act_space,
                           hidden_size=args.hidden_size).float()

        # policy parameters should be zero initialized according to ARS paper
        for p in policy.parameters():
            p.data = torch.zeros(p.shape)
        return policy

    # the 'black box' function that will get passed into ARS
    def eval_fn(policy,
                env,
                reward_shift,
                traj_len,
                visualize=False,
                normalize=False):
        """Run one rollout of at most ``traj_len`` steps.

        Returns:
            ``(sum of (reward - reward_shift) over the rollout, steps taken)``.
        ``visualize`` is accepted but not used.
        """
        if hasattr(policy, 'init_hidden_state'):
            policy.init_hidden_state()

        state = torch.tensor(env.reset()).float()
        rollout_reward = 0
        done = False

        timesteps = 0
        while not done and timesteps < traj_len:
            if normalize:
                state = policy.normalize_state(state)
            action = policy.forward(state).detach().numpy()
            state, reward, done, _ = env.step(action)
            state = torch.tensor(state).float()
            rollout_reward += reward - reward_shift
            timesteps += 1
        return rollout_reward, timesteps

    locale.setlocale(locale.LC_ALL, '')

    print("Augmented Random Search:")
    print("\tenv:          {}".format(args.env_name))
    print("\tseed:         {}".format(args.seed))
    print("\ttimesteps:    {:n}".format(args.timesteps))
    print("\tstd:          {}".format(args.std))
    print("\tdeltas:       {}".format(args.deltas))
    print("\tstep size:    {}".format(args.lr))
    print("\treward shift: {}".format(args.reward_shift))
    print()
    algo = ARS(policy_thunk,
               env_thunk,
               deltas=args.deltas,
               step_size=args.lr,
               std=args.std,
               workers=args.workers,
               redis_addr=args.redis)

    def black_box(p, env):
        """Objective handed to ``ARS.step``; returns ``eval_fn``'s tuple."""
        return eval_fn(p,
                       env,
                       args.reward_shift,
                       args.traj_len,
                       normalize=normalize_states)

    # Running sum of eval returns within the current averaging window.
    avg_reward = 0
    timesteps = 0
    i = 0

    logger = create_logger(args)

    # NOTE: any user-supplied ``args.save_model`` is replaced by a path
    # inside the logger's directory.
    args.save_model = os.path.join(logger.dir, 'actor.pt')

    env = env_thunk()
    while timesteps < args.timesteps:
        if not i % args.average_every:
            avg_reward = 0
            print()

        start = time.time()
        samples = algo.step(black_box)
        elapsed = time.time() - start

        iter_reward = 0
        for _rollout in range(num_eval_rollouts):
            reward, _steps = eval_fn(algo.policy,
                                     env,
                                     0,
                                     args.traj_len,
                                     normalize=normalize_states)
            iter_reward += reward / num_eval_rollouts

        timesteps += samples
        avg_reward += iter_reward
        secs_per_1k = 1000 * elapsed / samples
        window_len = (i % args.average_every) + 1
        print(("iter {:4d} | "
               "ret {:6.2f} | "
               "last {:3d} iters: {:6.2f} | "
               "{:0.4f}s per 1k steps | "
               "timesteps {:10n}").format(i + 1, iter_reward, window_len,
                                          avg_reward / window_len,
                                          secs_per_1k, timesteps),
              end="\r")
        i += 1

        logger.add_scalar('eval', iter_reward, timesteps)
        torch.save(algo.policy, args.save_model)
Exemple #8
0
def run_experiment(args):
    """Train an off-policy actor-critic agent (SAC, TD3 or DDPG).

    Collects one transition per loop iteration into a replay buffer.  The
    agent becomes trainable when the buffer holds more than
    ``args.batch_size`` transitions (trajectories, for recurrent agents) and
    the warmup period is over; after that, policy updates run at the end of
    each episode.  Every ``args.eval_every`` episodes the actor is evaluated,
    and the best-scoring one is saved to ``args.save_actor``.

    Args:
        args: parsed command-line namespace.  Reads ``env_name``, ``seed``,
            ``timesteps``, ``recurrent``, ``algo``, ``a_lr``, ``c_lr``,
            ``discount``, ``tau``, ``batch_size``, ``start_timesteps``,
            ``save_actor``, ``prenormalize_steps``, ``traj_len`` and
            ``eval_every``.

    Raises:
        ValueError: if ``args.algo`` is not 'sac', 'td3' or 'ddpg'.
    """
    from policies.critic import FF_Q, LSTM_Q
    from policies.actor import FF_Stochastic_Actor, LSTM_Stochastic_Actor, FF_Actor, LSTM_Actor

    # Fail fast: an unknown algorithm would otherwise leave ``algo`` unbound
    # and crash much later with an UnboundLocalError.
    if args.algo not in ('sac', 'td3', 'ddpg'):
        raise ValueError(
            "Unknown algo {!r}; expected 'sac', 'td3' or 'ddpg'".format(
                args.algo))

    locale.setlocale(locale.LC_ALL, '')

    # wrapper function for creating parallelized envs
    env = env_factory(args.env_name)()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if hasattr(env, 'seed'):
        env.seed(args.seed)

    obs_space = env.observation_space.shape[0]
    act_space = env.action_space.shape[0]

    replay_buff = ReplayBuffer(obs_space, act_space, args.timesteps)

    if args.recurrent:
        print('Recurrent ', end='')
        q1 = LSTM_Q(obs_space, act_space, env_name=args.env_name)
        q2 = LSTM_Q(obs_space, act_space, env_name=args.env_name)

        if args.algo == 'sac':
            actor = LSTM_Stochastic_Actor(obs_space,
                                          act_space,
                                          env_name=args.env_name,
                                          bounded=True)
        else:
            actor = LSTM_Actor(obs_space, act_space, env_name=args.env_name)
    else:
        q1 = FF_Q(obs_space, act_space, env_name=args.env_name)
        q2 = FF_Q(obs_space, act_space, env_name=args.env_name)

        if args.algo == 'sac':
            actor = FF_Stochastic_Actor(obs_space,
                                        act_space,
                                        env_name=args.env_name,
                                        bounded=True)
        else:
            actor = FF_Actor(obs_space, act_space, env_name=args.env_name)

    if args.algo == 'sac':
        print('Soft Actor-Critic')
        algo = SAC(actor, q1, q2, torch.prod(torch.Tensor(env.reset().shape)),
                   args)
    elif args.algo == 'td3':
        print('Twin-Delayed Deep Deterministic Policy Gradient')
        algo = TD3(actor, q1, q2, args)
    else:  # 'ddpg' -- validated at the top of the function
        print('Deep Deterministic Policy Gradient')
        algo = DDPG(actor, q1, args)

    print("\tenv:            {}".format(args.env_name))
    print("\tseed:           {}".format(args.seed))
    print("\ttimesteps:      {:n}".format(args.timesteps))
    print("\tactor_lr:       {}".format(args.a_lr))
    print("\tcritic_lr:      {}".format(args.c_lr))
    print("\tdiscount:       {}".format(args.discount))
    print("\ttau:            {}".format(args.tau))
    print("\tbatch_size:     {}".format(args.batch_size))
    print("\twarmup period:  {:n}".format(args.start_timesteps))
    print()

    episode = 0
    episode_reward = 0
    episode_timesteps = 0

    # create a tensorboard logging object
    logger = create_logger(args)

    if args.save_actor is None:
        args.save_actor = os.path.join(logger.dir, 'actor.pt')

    # Keep track of some statistics for each episode
    training_start = time.time()
    episode_start = time.time()
    update_steps = 0
    best_reward = None
    # (secs per 1k samples, hours left, minutes left) from the latest update;
    # None until the first policy update, which suppresses the status line.
    progress = None

    train_normalizer(algo.actor,
                     args.prenormalize_steps,
                     noise=algo.expl_noise)

    # Fill replay buffer, update policy until n timesteps have passed
    timesteps = 0
    state = env.reset().astype(np.float32)
    while timesteps < args.timesteps:
        if algo.recurrent:
            buffer_ready = replay_buff.trajectories > args.batch_size
        else:
            buffer_ready = replay_buff.size > args.batch_size
        warmup = timesteps < args.start_timesteps

        state, r, done = collect_experience(algo.actor,
                                            env,
                                            replay_buff,
                                            state,
                                            episode_timesteps,
                                            max_len=args.traj_len,
                                            noise=algo.expl_noise)
        episode_reward += r
        episode_timesteps += 1
        timesteps += 1

        # Episodes are only counted once training updates can happen.
        if not buffer_ready or warmup:
            episode = 0

        # Update the policy once our replay buffer is big enough
        if buffer_ready and done and not warmup:
            # Feed-forward agents do one update per transition in the
            # finished episode; recurrent agents do a single update.
            num_updates = 1 if algo.recurrent else episode_timesteps
            losses = [
                algo.update_policy(replay_buff,
                                   args.batch_size,
                                   traj_len=args.traj_len)
                for _ in range(num_updates)
            ]

            episode_elapsed = time.time() - episode_start
            episode_secs_per_sample = episode_elapsed / episode_timesteps

            actor_loss = np.mean([loss[0] for loss in losses])
            critic_loss = np.mean([loss[1] for loss in losses])
            update_steps = sum(loss[-1] for loss in losses)

            # Logged x-axis: timesteps since the end of warmup.
            log_step = timesteps - args.start_timesteps
            logger.add_scalar(args.env_name + '/actor loss', actor_loss,
                              log_step)
            logger.add_scalar(args.env_name + '/critic loss', critic_loss,
                              log_step)
            logger.add_scalar(args.env_name + '/update steps', update_steps,
                              log_step)

            if args.algo == 'sac':
                alpha_loss = np.mean([loss[2] for loss in losses])
                logger.add_scalar(args.env_name + '/alpha loss', alpha_loss,
                                  log_step)

            # Estimate remaining wall-clock time from the average time per
            # sample so far.
            remaining_frac = 1 - float(timesteps) / args.timesteps
            avg_sample_r = (time.time() - training_start) / timesteps
            secs_remaining = avg_sample_r * args.timesteps * remaining_frac
            hrs_remaining = int(secs_remaining // (60 * 60))
            min_remaining = int(secs_remaining - hrs_remaining * 60 * 60) // 60
            progress = (1000 * episode_secs_per_sample, hrs_remaining,
                        min_remaining)

            if episode % args.eval_every == 0 and episode != 0:
                eval_reward = eval_policy(algo.actor,
                                          min_timesteps=1000,
                                          verbose=False,
                                          visualize=False,
                                          max_traj_len=args.traj_len)
                logger.add_scalar(args.env_name + '/return', eval_reward,
                                  log_step)

                print(
                    "evaluation after {:4d} episodes | return: {:7.3f} | timesteps {:9n}{:100s}"
                    .format(episode, eval_reward, log_step, ''))
                if best_reward is None or eval_reward > best_reward:
                    torch.save(algo.actor, args.save_actor)
                    best_reward = eval_reward
                    print("\t(best policy so far! saving to {})".format(
                        args.save_actor))

        if progress is not None:
            secs_per_1k, hrs_remaining, min_remaining = progress
            print(
                "episode {:5d} | episode timestep {:5d}/{:5d} | return {:5.1f} | update timesteps: {:7n} | {:3.1f}s/1k samples | approx. {:3d}h {:02d}m remain\t\t\t\t"
                .format(episode, episode_timesteps, args.traj_len,
                        episode_reward, update_steps, secs_per_1k,
                        hrs_remaining, min_remaining),
                end='\r')

        if done:
            if hasattr(algo.actor, 'init_hidden_state'):
                algo.actor.init_hidden_state()

            episode_start = time.time()
            episode_reward = 0
            episode_timesteps = 0
            episode += 1
Exemple #9
0
def run_experiment(args):
    """Entry point for the dynamics extraction algorithm.

    Loads a policy from ``args.policy``, collects (latent, dynamics) samples
    with ray workers via ``collect_data``, splits them 80/20 into train/test
    sets, and trains one regression ``Model`` per dynamics quantity
    (friction, damping, mass, quaternion -- the last is logged as 'slope')
    to predict it from the policy's latent state.  The datasets and the
    per-quantity extractors are saved under ``logger.dir``.

    Args:
        args: parsed command-line namespace.  Reads ``policy``, ``layers``
            (comma-separated ints), ``lr``, ``redis``, ``workers``,
            ``points``, ``epochs`` and ``batch_size``.
    """
    from util.log import create_logger

    locale.setlocale(locale.LC_ALL, '')

    policy = torch.load(args.policy)

    # Policies not explicitly marked ``legacy == False`` use the legacy env.
    legacy = 'legacy' if not (hasattr(policy, 'legacy')
                              and policy.legacy == False) else ''
    env_fn = env_factory(policy.env_name + legacy)

    layers = [int(w) for w in args.layers.split(',')]

    # Run the policy once so its hidden state exists, then read its size.
    env = env_fn()
    policy.init_hidden_state()
    policy(torch.tensor(env.reset()).float())
    latent_dim = get_hiddens(policy).shape[0]

    # One regressor + optimizer per dynamics quantity; the output size comes
    # from the shape of the value the env currently reports.
    models = []
    opts = []
    for fn in [env.get_friction, env.get_damping, env.get_mass, env.get_quat]:
        output_dim = fn().shape[0]
        model = Model(latent_dim, output_dim, layers=layers)
        models.append(model)
        opts.append(optim.Adam(model.parameters(), lr=args.lr, eps=1e-5))

    logger = create_logger(args)

    if not ray.is_initialized():
        if args.redis is not None:
            ray.init(redis_address=args.redis)
        else:
            ray.init(num_cpus=args.workers)

    print("Collecting {:4d} timesteps of data.".format(args.points))
    points_per_worker = max(args.points // args.workers, 1)
    start = time.time()

    frics, damps, masses, quats, x = concat(
        ray.get([
            collect_data.remote(policy, points=points_per_worker)
            for _ in range(args.workers)
        ]))

    # 80/20 train/test split along the first axis (tuple assignment slices
    # the unsplit data on both sides).
    split = int(0.8 * len(x))
    x, test_x = x[:split], x[split:]
    frics, test_frics = frics[:split], frics[split:]
    damps, test_damps = damps[:split], damps[split:]
    masses, test_masses = masses[:split], masses[split:]
    quats, test_quats = quats[:split], quats[split:]

    print(
        "{:3.2f} to collect {} timesteps.  Training set is {}, test set is {}"
        .format(time.time() - start,
                len(x) + len(test_x), len(x), len(test_x)))

    # Writes train_<tag>.pt / test_<tag>.pt for every dataset.
    for tag, train_data, test_data in (('latents', x, test_x),
                                       ('frics', frics, test_frics),
                                       ('damps', damps, test_damps),
                                       ('masses', masses, test_masses),
                                       ('quats', quats, test_quats)):
        torch.save(train_data, os.path.join(logger.dir, 'train_' + tag + '.pt'))
        torch.save(test_data, os.path.join(logger.dir, 'test_' + tag + '.pt'))

    # Targets in the same order as ``models``.
    train_y = [frics, damps, masses, quats]
    test_y = [test_frics, test_damps, test_masses, test_quats]
    order = ['friction', 'damping', 'mass', 'slope']

    for epoch in range(args.epochs):

        # Sample over every training index, including the last one.
        random_indices = SubsetRandomSampler(range(len(x)))
        sampler = BatchSampler(random_indices,
                               args.batch_size,
                               drop_last=False)

        for j, batch_idx in enumerate(sampler):
            batch_x = x[batch_idx]
            batch_ys = [y[batch_idx] for y in train_y]

            for model, batch_y, opt in zip(models, batch_ys, opts):
                loss = 0.5 * (batch_y - model(batch_x)).pow(2).mean()

                opt.zero_grad()
                loss.backward()
                opt.step()

            print("Epoch {:3d} batch {:4d}/{:4d}      ".format(
                epoch, j,
                len(sampler) - 1),
                  end='\r')

        with torch.no_grad():
            print("\nEpoch {:3d} losses:".format(epoch))
            for model, y_tr, y_te, name in zip(models, train_y, test_y, order):
                loss_total = 0.5 * (y_tr - model(x)).pow(2).mean().item()

                preds = model(test_x)
                test_loss = 0.5 * (y_te - preds).pow(2).mean().item()
                # NOTE(review): relative error divides by the targets; if any
                # target component is zero this becomes inf/nan.
                pce = torch.mean(torch.abs((y_te - preds) / y_te))
                err = torch.mean(torch.abs((y_te - preds)))

                logger.add_scalar(logger.arg_hash + '/' + name + '_loss',
                                  test_loss, epoch)
                logger.add_scalar(logger.arg_hash + '/' + name + '_percenterr',
                                  pce, epoch)
                logger.add_scalar(logger.arg_hash + '/' + name + '_abserr',
                                  err, epoch)
                torch.save(model,
                           os.path.join(logger.dir, name + '_extractor.pt'))
                print("\t{:16s}: train {:7.6f} test {:7.6f}".format(
                    name, loss_total, test_loss))
Exemple #10
0
def run_experiment(args):
    from policies.critic import FF_Q, LSTM_Q, GRU_Q
    from policies.actor import FF_Stochastic_Actor, LSTM_Stochastic_Actor, GRU_Stochastic_Actor, FF_Actor, LSTM_Actor, GRU_Actor

    locale.setlocale(locale.LC_ALL, '')

    # wrapper function for creating parallelized envs
    env = env_factory(args.env)()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if hasattr(env, 'seed'):
        env.seed(args.seed)

    obs_space = env.observation_space.shape[0]
    act_space = env.action_space.shape[0]

    replay_buff = ReplayBuffer(args.buffer)

    layers = [int(x) for x in args.layers.split(',')]

    if args.arch == 'lstm':
        q1 = LSTM_Q(obs_space, act_space, env_name=args.env, layers=layers)
        q2 = LSTM_Q(obs_space, act_space, env_name=args.env, layers=layers)

        if args.algo == 'sac':
            actor = LSTM_Stochastic_Actor(obs_space,
                                          act_space,
                                          env_name=args.env,
                                          bounded=True,
                                          layers=layers)
        else:
            actor = LSTM_Actor(obs_space,
                               act_space,
                               env_name=args.env,
                               layers=layers)
    elif args.arch == 'gru':
        q1 = GRU_Q(obs_space, act_space, env_name=args.env, layers=layers)
        q2 = GRU_Q(obs_space, act_space, env_name=args.env, layers=layers)

        if args.algo == 'sac':
            actor = GRU_Stochastic_Actor(obs_space,
                                         act_space,
                                         env_name=args.env,
                                         bounded=True,
                                         layers=layers)
        else:
            actor = GRU_Actor(obs_space,
                              act_space,
                              env_name=args.env,
                              layers=layers)
    elif args.arch == 'ff':
        q1 = FF_Q(obs_space, act_space, env_name=args.env, layers=layers)
        q2 = FF_Q(obs_space, act_space, env_name=args.env, layers=layers)

        if args.algo == 'sac':
            actor = FF_Stochastic_Actor(obs_space,
                                        act_space,
                                        env_name=args.env,
                                        bounded=True,
                                        layers=layers)
        else:
            actor = FF_Actor(obs_space,
                             act_space,
                             env_name=args.env,
                             layers=layers)

    if args.algo == 'sac':
        print('Soft Actor-Critic')
        algo = SAC(actor, q1, q2, torch.prod(torch.Tensor(env.reset().shape)),
                   args)
    elif args.algo == 'td3':
        print('Twin-Delayed Deep Deterministic Policy Gradient')
        algo = TD3(actor, q1, q2, args)
    elif args.algo == 'ddpg':
        print('Deep Deterministic Policy Gradient')
        algo = DDPG(actor, q1, args)

    print("\tenv:            {}".format(args.env))
    print("\tseed:           {}".format(args.seed))
    print("\ttimesteps:      {:n}".format(args.timesteps))
    print("\tactor_lr:       {}".format(args.a_lr))
    print("\tcritic_lr:      {}".format(args.c_lr))
    print("\tdiscount:       {}".format(args.discount))
    print("\ttau:            {}".format(args.tau))
    print("\tbatch_size:     {}".format(args.batch_size))
    print("\twarmup period:  {:n}".format(args.start_timesteps))
    print("\tworkers:        {}".format(args.workers))
    print("\tlayers:         {}".format(args.layers))
    print()

    # create a tensorboard logging object
    logger = create_logger(args)

    if args.save_actor is None:
        args.save_actor = os.path.join(logger.dir, 'actor.pt')

    # Keep track of some statistics for each episode
    training_start = time.time()
    episode_start = time.time()
    best_reward = None

    if not ray.is_initialized():
        #if args.redis is not None:
        #  ray.init(redis_address=args.redis)
        #else:
        #  ray.init(num_cpus=args.workers)
        ray.init(num_cpus=args.workers)

    workers = [
        Off_Policy_Worker.remote(actor, env_factory(args.env))
        for _ in range(args.workers)
    ]

    train_normalizer(algo.actor,
                     args.prenormalize_steps,
                     noise=algo.expl_noise)

    # Fill replay buffer, update policy until n timesteps have passed
    timesteps, i = 0, 0
    state = env.reset().astype(np.float32)
    while i < args.iterations:
        if timesteps < args.timesteps:
            actor_param_id = ray.put(list(algo.actor.parameters()))
            norm_id = ray.put([
                algo.actor.welford_state_mean,
                algo.actor.welford_state_mean_diff, algo.actor.welford_state_n
            ])

            for w in workers:
                w.sync_policy.remote(actor_param_id, input_norm=norm_id)

            buffers = ray.get([
                w.collect_episode.remote(args.expl_noise, args.traj_len)
                for w in workers
            ])

            replay_buff.merge_with(buffers)

            timesteps += sum(len(b.states) for b in buffers)

            #for i in range(len(replay_buff.traj_idx)-1):
            #  for j in range(replay_buff.traj_idx[i], replay_buff.traj_idx[i+1]):
            #    print("traj {:2d} timestep {:3d}, not done {}, reward {},".format(i, j, replay_buff.not_dones[j], replay_buff.rewards[j]))

        if (algo.recurrent and len(replay_buff.traj_idx) > args.batch_size
            ) or (not algo.recurrent and replay_buff.size > args.batch_size):
            i += 1
            loss = []
            for _ in range(args.updates):
                loss.append(
                    algo.update_policy(replay_buff,
                                       batch_size=args.batch_size))
            loss = np.mean(loss, axis=0)

            print('algo {:4s} | explored: {:5n} of {:5n}'.format(
                args.algo, timesteps, args.timesteps),
                  end=' | ')
            if args.algo == 'ddpg':
                print(
                    'iteration {:6n} | actor loss {:6.4f} | critic loss {:6.4f} | buffer size {:6n} / {:6n} ({:4n} trajectories) | {:60s}'
                    .format(i + 1, loss[0], loss[1],
                            replay_buff.size, replay_buff.max_size,
                            len(replay_buff.traj_idx), ''))
                logger.add_scalar(args.env + '/actor loss', loss[0], i)
                logger.add_scalar(args.env + '/critic loss', loss[1], i)
            if args.algo == 'td3':
                print(
                    'iteration {:6n} | actor loss {:6.4f} | critic loss {:6.4f} | buffer size {:6n} / {:6n} ({:4n} trajectories) | {:60s}'
                    .format(i + 1, loss[0], loss[1],
                            replay_buff.size, replay_buff.max_size,
                            len(replay_buff.traj_idx), ''))
                logger.add_scalar(args.env + '/actor loss', loss[0], i)
                logger.add_scalar(args.env + '/critic loss', loss[1], i)

            if i % args.eval_every == 0 and iter != 0:
                eval_reward = eval_policy(algo.actor, env, 5, args.traj_len)
                logger.add_scalar(args.env + '/return', eval_reward, i)

                print("evaluation after {:4d} iterations | return: {:7.3f}".
                      format(i, eval_reward, ''))
                if best_reward is None or eval_reward > best_reward:
                    torch.save(algo.actor, args.save_actor)
                    best_reward = eval_reward
                    print("\t(best policy so far! saving to {})".format(
                        args.save_actor))

        else:
            if algo.recurrent:
                print(
                    "Collected {:5d} of {:5d} warmup trajectories \t\t".format(
                        len(replay_buff.traj_idx), args.batch_size),
                    end='\r')
            else:
                print(
                    "Collected {:5d} of {:5d} warmup trajectories \t\t".format(
                        replay_buff.size, args.batch_size),
                    end='\r')
        """