def _train_policy(self, dataset):
        """
        Train the model-based policy

        implementation details:
            (a) Train for self._training_epochs number of epochs
            (b) The dataset.random_iterator(...)  method will iterate through the dataset once in a random order
            (c) Use self._training_batch_size for iterating through the dataset
            (d) Keep track of the loss values by appending them to the losses array
        """
        timeit.start('train policy')

        losses = []
        ### PROBLEM 1
        ### YOUR CODE HERE
        for _ in range(self._training_epochs):
            for states, actions, next_states, _, _ in dataset.random_iterator(self._training_batch_size):
                losses.append(self._policy.train_step(states, actions, next_states))

        logger.record_tabular('TrainingLossStart', losses[0])
        logger.record_tabular('TrainingLossFinal', losses[-1])

        timeit.stop('train policy')
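All of the _train_policy examples on this page rely on dataset.random_iterator(batch_size) yielding (states, actions, next_states, rewards, dones) mini-batches that cover the dataset exactly once per epoch in a shuffled order. The dataset class itself is never shown here; the sketch below only illustrates that assumed interface (the Dataset name, the add method, and list-backed storage are assumptions, not the original implementation).

import numpy as np

class Dataset(object):
    """Minimal sketch (names assumed) of the dataset object these snippets use."""

    def __init__(self):
        self._states, self._actions, self._next_states = [], [], []
        self._rewards, self._dones = [], []

    def __len__(self):
        return len(self._states)

    def add(self, state, action, next_state, reward, done):
        # Store a single transition.
        self._states.append(state)
        self._actions.append(action)
        self._next_states.append(next_state)
        self._rewards.append(reward)
        self._dones.append(done)

    def random_iterator(self, batch_size):
        """Yield (states, actions, next_states, rewards, dones) mini-batches,
        covering the whole dataset exactly once in a random order."""
        order = np.random.permutation(len(self))
        arrays = [np.asarray(a) for a in (self._states, self._actions,
                                          self._next_states, self._rewards,
                                          self._dones)]
        for start in range(0, len(self), batch_size):
            idxs = order[start:start + batch_size]
            yield tuple(a[idxs] for a in arrays)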
Example #2
    def _train_policy(self, dataset):
        """
        Train the model-based policy

        implementation details:
            (a) Train for self._training_epochs number of epochs
            (b) The dataset.random_iterator(...)  method will iterate through the dataset once in a random order
            (c) Use self._training_batch_size for iterating through the dataset
            (d) Keep track of the loss values by appending them to the losses array
        """
        timeit.start('train policy')

        losses = []
        ### PROBLEM 1
        ### YOUR CODE HERE
        # (a) Train for self._training_epochs number of epochs
        for _ in range(self._training_epochs):
            # (b) The dataset.random_iterator(...)  method will iterate through the dataset once in a random order
            # (c) Use self._training_batch_size for iterating through the dataset
            epoch_losses = []

            for states, actions, next_states, _, _ in dataset.random_iterator(
                    self._training_batch_size):
                loss = self._policy.train_step(states, actions, next_states)
                epoch_losses.append(loss)
            # (d) Keep track of the loss values by appending them to the losses array
            losses.append(np.mean(epoch_losses))

        logger.record_tabular('TrainingLossStart', losses[0])
        logger.record_tabular('TrainingLossFinal', losses[-1])

        timeit.stop('train policy')
    def _train_policy(self, dataset):
        """
        Train the model-based policy

        implementation details:
            (a) Train for self._training_epochs number of epochs
            (b) The dataset.random_iterator(...)  method will iterate through the dataset once in a random order
            (c) Use self._training_batch_size for iterating through the dataset
            (d) Keep track of the loss values by appending them to the losses array
        """
        timeit.start('train policy')

        losses = []
        ### PROBLEM 1
        ### YOUR CODE HERE
        for _ in range(self._training_epochs):
            for state, action, next_state, _, _ in dataset.random_iterator(
                    self._training_batch_size):
                loss = self._policy.train_step(state, action, next_state)
                losses.append(loss)

        logger.record_tabular('TrainingLossStart', losses[0])
        logger.record_tabular('TrainingLossFinal', losses[-1])

        timeit.stop('train policy')

        plt.figure()
        plt.plot(losses)
        plt.savefig(os.path.join(logger.dir, 'training.png'))
Example #4
    def _train_policy(self, dataset):
        """
        Train the model-based policy

        implementation details:
            (a) Train for self._training_epochs number of epochs
            (b) The dataset.random_iterator(...)  method will iterate through the dataset once in a random order
            (c) Use self._training_batch_size for iterating through the dataset
            (d) Keep track of the loss values by appending them to the losses array
        """
        timeit.start('train policy')

        losses = []
        ### PROBLEM 1
        ### YOUR CODE HERE

        for epoch in range(self._training_epochs):
            for state, action, next_state, _, _ in dataset.random_iterator(
                    self._training_batch_size):
                loss = self._policy.train_step(states=state,
                                               actions=action,
                                               next_states=next_state)
                losses.append(loss)

        logger.record_tabular('TrainingLossStart', losses[0])
        logger.record_tabular('TrainingLossFinal', losses[-1])

        timeit.stop('train policy')
Example #5
    def _train_policy(self, dataset):
        """
        Train the model-based policy

        implementation details:
            (a) Train for self._training_epochs number of epochs
            (b) The dataset.random_iterator(...)  method will iterate through the dataset once in a random order
            (c) Use self._training_batch_size for iterating through the dataset
            (d) Keep track of the loss values by appending them to the losses array
        """
        timeit.start('train policy')

        losses = []
        ### PROBLEM 1
        ### YOUR CODE HERE
        for epoch in range(self._training_epochs):
            for states, actions, next_states, _, _ in dataset.random_iterator(
                    self._training_batch_size):
                loss = self._policy.train_step(states, actions, next_states)
                losses.append(loss)

        logger.record_tabular('TrainingLossStart', losses[0])
        logger.record_tabular('TrainingLossFinal', losses[-1])

        timeit.stop('train policy')
    def _train_policy(self, dataset):
        """
        Train the model-based policy

        implementation details:
            (a) Train for self._training_epochs number of epochs
            (b) The dataset.random_iterator(...)  method will iterate through the dataset once in a random order
            (c) Use self._training_batch_size for iterating through the dataset
            (d) Keep track of the loss values by appending them to the losses array
        """
        timeit.start('train policy')

        losses = []
        # Added: Training policy iteration
        for epoch_num in range(self._training_epochs):
            logger.info('Epoch %i' % (epoch_num + 1))
            for batch_num, (states, actions, next_states, _, _) in enumerate(
                    dataset.random_iterator(self._training_batch_size)):
                loss = self._policy.train_step(states, actions, next_states)
                losses.append(loss)
            logger.info('\tLoss: {:.3f}'.format(losses[-1]))

        logger.record_tabular('TrainingLossStart', losses[0])
        logger.record_tabular('TrainingLossFinal', losses[-1])

        timeit.stop('train policy')
    def _train_policy(self, dataset):
        """
        Train the model-based policy

        """
        timeit.start('train policy')

        losses = []
        ### PROBLEM 1
        ### YOUR CODE HERE
        # (a) Train for self._training_epochs number of epochs
        for ep in range(self._training_epochs):
            # (b) dataset.random_iterator(...) iterates through the dataset once,
            #     yielding mini-batches in a random order
            # (c) Use self._training_batch_size for iterating through the dataset
            _iter = dataset.random_iterator(self._training_batch_size)
            for states, actions, next_states, _, _ in _iter:
                loss = self._policy.train_step(states, actions, next_states)
                losses.append(loss)

        logger.record_tabular('TrainingLossStart', losses[0])
        logger.record_tabular('TrainingLossFinal', losses[-1])

        timeit.stop('train policy')
Example #8
    def train_policy(self, dataset):
        """
        trains the model-based policy
        """
        timeit.start('train policy')

        losses = []
        for _ in range(self.training_epochs):
            loss_total = 0.0
            num_data = 0

            d = dataset.random_iterator(self.training_batch_size)
            for states, actions, next_states, _, _ in d:
                loss = self.policy.train_step(states, actions, next_states)
                loss_total += loss
                num_data += 1

            losses.append(loss_total / num_data)
        logger.record_tabular('TrainingLossStart', losses[0])
        logger.record_tabular('TrainingLossFinal', losses[-1])

        timeit.stop('train policy')
        return
Example #9
    def _train_policy(self, dataset):
        """
        Train the model-based policy

        implementation details:
            (a) Train for self._training_epochs number of epochs
            (b) The dataset.random_iterator(...)  method will iterate through the dataset once in a random order
            (c) Use self._training_batch_size for iterating through the dataset
            (d) Keep track of the loss values by appending them to the losses array
        """
        timeit.start('train policy')

        losses = []
        ### PROBLEM 1
        ### YOUR CODE HERE
        for ep in range(self._training_epochs):
            data_generator = dataset.random_iterator(self._training_batch_size)
            for batch in data_generator:
                states, actions, next_states = batch[:3]
                loss = self._policy.train_step(states, actions, next_states)
                losses.append(loss)

        logger.info('loss start to end: %s, %s' % (losses[0], losses[-1]))
        logger.record_tabular('TrainingLossStart', losses[0])
        logger.record_tabular('TrainingLossFinal', losses[-1])

        timeit.stop('train policy')
Example #10
    def run_q3(self):
        """
        start with random dataset, train policy on dataset, gather rollouts with policy, add to 
        dataset, repeat
        """
        dataset = self.random_dataset

        itr = -1
        logger.info('Iteration {0}'.format(itr))
        logger.record_tabular('Itr', itr)
        self.log(dataset)

        for itr in range(self.num_onpolicy_iters + 1):
            logger.info('Iteration {0}'.format(itr))
            logger.record_tabular('Itr', itr)

            logger.info('Training policy...')
            self.train_policy(dataset)

            logger.info('Gathering rollouts...')
            new_dataset = self.gather_rollouts(self.policy,
                                               self.num_onpolicy_rollouts)

            logger.info('Appending dataset...')
            dataset.append(new_dataset)

            self.log(new_dataset)
Example #11
    def run_q3(self):
        """
        Starting with the random dataset, train the policy on the dataset, gather rollouts with the policy,
        append the new rollouts to the existing dataset, and repeat
        """
        dataset = self._random_dataset

        itr = -1
        logger.info('Iteration {0}'.format(itr))
        logger.record_tabular('Itr', itr)
        self._log(dataset)

        for itr in range(self._num_onpolicy_iters + 1):
            logger.info('Iteration {0}'.format(itr))
            logger.record_tabular('Itr', itr)

            ### PROBLEM 3
            ### YOUR CODE HERE
            logger.info('Training policy...')
            self._train_policy(dataset)

            ### PROBLEM 3
            ### YOUR CODE HERE
            logger.info('Gathering rollouts...')
            new_dataset = self._gather_rollouts(self._policy,
                                                self._num_onpolicy_rollouts)

            ### PROBLEM 3
            ### YOUR CODE HERE
            logger.info('Appending dataset...')
            dataset.append(new_dataset)

            self._log(new_dataset)
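Both run_q3 variants above call a _gather_rollouts(policy, num_rollouts) helper that is not shown. The sketch below, written as a standalone function, illustrates what such a helper typically does; the gym-style env.reset()/env.step() calls, policy.get_action, the max_rollout_length cap, and the reuse of the Dataset sketch near the top of the page are all assumptions, not the original code.

def gather_rollouts(env, policy, num_rollouts, max_rollout_length):
    """Sketch: run the policy in the environment and collect transitions
    into a fresh Dataset (see the Dataset sketch near the top of the page)."""
    dataset = Dataset()
    for _ in range(num_rollouts):
        state = env.reset()
        done = False
        t = 0
        while not done:
            action = policy.get_action(state)            # assumed policy API
            next_state, reward, done, _ = env.step(action)
            done = done or (t + 1 >= max_rollout_length)  # cap episode length
            dataset.add(state, action, next_state, reward, done)
            state = next_state
            t += 1
    return dataset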
Example #12
    def _train_policy(self, dataset):

        # timing for policy training
        timeit.start('train policy')

        losses = []

        # loop for self._training_epochs
        for _ in range(self._training_epochs):

            # iterate over dataset
            for states, actions, next_states, _, _ in \
                    dataset.random_iterator(self._training_batch_size):

                # compute loss
                loss = self._policy.train_step(states, actions, next_states)
                losses.append(loss)

        # perform logging
        logger.record_tabular('TrainingLossStart', losses[0])
        logger.record_tabular('TrainingLossFinal', losses[-1])
        timeit.stop('train policy')
    def _train_policy(self, dataset):
        """
        Train the model-based policy

        implementation details:
            (a) Train for self._training_epochs number of epochs
            (b) The dataset.random_iterator(...)  method will iterate through the dataset once in a random order
            (c) Use self._training_batch_size for iterating through the dataset
            (d) Keep track of the loss values by appending them to the losses array
        """
        timeit.start('train policy')

        ### PROBLEM 1
        ### YOUR CODE HERE
        # Iterate dataset once in an epoch
        losses = []
        for epoch in range(self._training_epochs):
            t_loss = 0

            # enumerate() provides the batch index so the epoch loss can be averaged
            for r_num, (states, actions, next_states, _, _) in enumerate(
                    dataset.random_iterator(self._training_batch_size)):
                loss = self._policy.train_step(states, actions, next_states)
                t_loss += loss
            t_loss = t_loss / (r_num + 1)
            losses.append(t_loss)

        logger.record_tabular('TrainingLossStart', losses[0])
        logger.record_tabular('TrainingLossFinal', losses[-1])

        timeit.stop('train policy')
Example #14
    def log(self):
        end_idxs = np.nonzero(self._dones)[0] + 1

        returns = []

        start_idx = 0
        for end_idx in end_idxs:
            rewards = self._rewards[start_idx:end_idx]
            returns.append(np.sum(rewards))

            start_idx = end_idx

        logger.record_tabular('ReturnAvg', np.mean(returns))
        logger.record_tabular('ReturnStd', np.std(returns))
        logger.record_tabular('ReturnMin', np.min(returns))
        logger.record_tabular('ReturnMax', np.max(returns))
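The indexing in log() above is terse, so here is a tiny self-contained illustration (with made-up numbers) of how the done flags delimit episodes and produce per-episode returns:

import numpy as np

# Toy data: two episodes of lengths 3 and 2.
rewards = np.array([1.0, 1.0, 1.0, 0.5, 0.5])
dones = np.array([False, False, True, False, True])

end_idxs = np.nonzero(dones)[0] + 1   # array([3, 5]): one past each episode end
returns, start_idx = [], 0
for end_idx in end_idxs:
    returns.append(np.sum(rewards[start_idx:end_idx]))
    start_idx = end_idx
print(returns)  # [3.0, 1.0]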
Example #15
    def train(self):
        dataset = self._random_dataset

        itr = -1
        logger.info('Iteration {0}'.format(itr))
        logger.record_tabular('Itr', itr)
        self._log(dataset)

        for itr in range(self._num_onpolicy_iters + 1):
            logger.info('Iteration {0}'.format(itr))
            logger.record_tabular('Itr', itr)

            logger.info('Training policy...')
            self._train_policy(dataset)

            logger.info('Gathering rollouts...')
            new_dataset = self._gather_rollouts(self._policy,
                                                self._num_onpolicy_rollouts)

            logger.info('Appending dataset...')
            dataset.append(new_dataset)

            self._log(new_dataset)
Example #16
    def log(self, write_table_header=False):
        logger.log("Logging data in directory: %s" % logger.get_snapshot_dir())

        logger.record_tabular("Episode", self.num_episodes)

        logger.record_tabular("Accumulated Training Steps",
                              self.num_train_interactions)

        logger.record_tabular("Policy Error", self.logging_policies_error)
        logger.record_tabular("Q-Value Error", self.logging_qvalues_error)
        logger.record_tabular("V-Value Error", self.logging_vvalues_error)

        logger.record_tabular("Alpha", np_ify(self.log_alpha.exp()).item())
        logger.record_tabular("Entropy",
                              np_ify(self.logging_entropy.mean(dim=(0, ))))

        act_mean = np_ify(self.logging_mean.mean(dim=(0, )))
        act_std = np_ify(self.logging_std.mean(dim=(0, )))
        for aa in range(self.action_dim):
            logger.record_tabular("Mean Action %02d" % aa, act_mean[aa])
            logger.record_tabular("Std Action %02d" % aa, act_std[aa])

        # Evaluation Stats to plot
        logger.record_tabular("Test Rewards Mean",
                              np_ify(self.logging_eval_rewards.mean()))
        logger.record_tabular("Test Rewards Std",
                              np_ify(self.logging_eval_rewards.std()))
        logger.record_tabular("Test Returns Mean",
                              np_ify(self.logging_eval_returns.mean()))
        logger.record_tabular("Test Returns Std",
                              np_ify(self.logging_eval_returns.std()))

        # Add the previous times to the logger
        times_itrs = gt.get_times().stamps.itrs
        train_time = times_itrs.get('train', [0])[-1]
        sample_time = times_itrs.get('sample', [0])[-1]
        eval_time = times_itrs.get('eval', [0])[-1]
        epoch_time = train_time + sample_time + eval_time
        total_time = gt.get_times().total
        logger.record_tabular('Train Time (s)', train_time)
        logger.record_tabular('(Previous) Eval Time (s)', eval_time)
        logger.record_tabular('Sample Time (s)', sample_time)
        logger.record_tabular('Epoch Time (s)', epoch_time)
        logger.record_tabular('Total Train Time (s)', total_time)

        # Dump the logger data
        logger.dump_tabular(with_prefix=False,
                            with_timestamp=False,
                            write_header=write_table_header)
        # Save pytorch models
        self.save_training_state()
        logger.log("----")
Example #17
    # Load buffer
    replay_buffer = utils.ReplayBuffer()
    if args.env_name == 'Multigoal-v0':
        replay_buffer.load_point_mass(buffer_name,
                                      bootstrap_dim=4,
                                      dist_cost_coeff=0.01)
    else:
        replay_buffer.load(buffer_name, bootstrap_dim=4)

    evaluations = []

    episode_num = 0
    done = True

    training_iters = 0
    while training_iters < args.max_timesteps:
        pol_vals = policy.train(replay_buffer, iterations=int(args.eval_freq))

        ret_eval, var_ret, median_ret = evaluate_policy(policy)
        evaluations.append(ret_eval)
        np.save("./results/" + file_name, evaluations)

        training_iters += args.eval_freq
        print("Training iterations: " + str(training_iters))
        logger.record_tabular('Training Epochs',
                              int(training_iters // int(args.eval_freq)))
        logger.record_tabular('AverageReturn', ret_eval)
        logger.record_tabular('VarianceReturn', var_ret)
        logger.record_tabular('MedianReturn', median_ret)
        logger.dump_tabular()
Example #18
            mmd_sigma=args.mmd_sigma,
            lagrange_thresh=args.lagrange_thresh,
            use_kl=(args.distance_type == "KL"),
            use_ensemble=(args.use_ensemble_variance != "False"),
            kernel_type=args.kernel_type)

    evaluations = []

    episode_num = 0
    done = True

    training_iters = 0
    while training_iters < args.max_timesteps:
        pol_vals = policy.train(replay_buffer, iterations=int(args.eval_freq))

        # NOTE : @dhruvramani - commented this out for Colab
        # ret_eval, var_ret, median_ret = evaluate_policy(policy)
        # evaluations.append(ret_eval)
        # np.save("./results/" + file_name, evaluations)

        training_iters += args.eval_freq
        print("Training iterations: " + str(training_iters))
        logger.record_tabular('Training Epochs',
                              int(training_iters // int(args.eval_freq)))
        # logger.record_tabular('AverageReturn', ret_eval)
        # logger.record_tabular('VarianceReturn', var_ret)
        # logger.record_tabular('MedianReturn', median_ret)
        logger.dump_tabular()
        print("Iter done")
Example #19
    def train(self, replay_buffer, iterations, batch_size=100):

        for it in range(iterations):
            # Sample replay buffer / batch
            state, action, next_state, reward, not_done = replay_buffer.sample(
                batch_size)

            # Critic Training
            with torch.no_grad():
                _, _, next_action = self.actor_target(next_state,
                                                      self.vae.decode)
                target_Q1, target_Q2 = self.critic_target(
                    next_state, next_action)
                target_Q = self.lmbda * torch.min(target_Q1, target_Q2) + (
                    1 - self.lmbda) * torch.max(target_Q1, target_Q2)
                target_Q = reward + not_done * self.discount * target_Q

            current_Q1, current_Q2 = self.critic(state, action)
            critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
                current_Q2, target_Q)

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            # Actor Training
            latent_actions, mid_actions, actions = self.actor(
                state, self.vae.decode)
            actor_loss = -self.critic.q1(state, actions).mean()
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Update Target Networks
            for param, target_param in zip(self.critic.parameters(),
                                           self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data +
                                        (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(),
                                           self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data +
                                        (1 - self.tau) * target_param.data)

        # Logging
        logger.record_dict(
            create_stats_ordered_dict(
                'Q_target',
                target_Q.cpu().data.numpy(),
            ))
        logger.record_tabular('Actor Loss', actor_loss.cpu().data.numpy())
        logger.record_tabular('Critic Loss', critic_loss.cpu().data.numpy())
        logger.record_dict(
            create_stats_ordered_dict('Actions',
                                      actions.cpu().data.numpy()))
        logger.record_dict(
            create_stats_ordered_dict('Mid Actions',
                                      mid_actions.cpu().data.numpy()))
        logger.record_dict(
            create_stats_ordered_dict('Latent Actions',
                                      latent_actions.cpu().data.numpy()))
        logger.record_dict(
            create_stats_ordered_dict(
                'Latent Actions Norm',
                torch.norm(latent_actions, dim=1).cpu().data.numpy()))
        logger.record_dict(
            create_stats_ordered_dict(
                'Perturbation Norm',
                torch.norm(actions - mid_actions, dim=1).cpu().data.numpy()))
        logger.record_dict(
            create_stats_ordered_dict('Current_Q',
                                      current_Q1.cpu().data.numpy()))
        assert (np.abs(np.mean(target_Q.cpu().data.numpy())) < 1e6)
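As a side note on the target-network updates at the end of this example (and repeated in the next one): both loops implement Polyak averaging by hand. A small helper is a common way to factor this out; the sketch below is such a refactoring, and the soft_update name is ours rather than anything from these snippets.

def soft_update(source, target, tau):
    """Polyak averaging for two torch.nn.Module instances:
    target <- tau * source + (1 - tau) * target."""
    for param, target_param in zip(source.parameters(), target.parameters()):
        target_param.data.copy_(tau * param.data +
                                (1 - tau) * target_param.data)

# Hypothetical usage, replacing the explicit loops above:
#   soft_update(self.critic, self.critic_target, self.tau)
#   soft_update(self.actor, self.actor_target, self.tau)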
Example #20
    def train(self,
              replay_buffer,
              iterations,
              batch_size=100,
              discount=0.99,
              tau=0.005,
              policy_noise=0.2,
              noise_clip=0.5,
              policy_freq=2):

        for it in range(iterations):
            # Sample replay buffer
            x, y, u, r, d, mask = replay_buffer.sample(batch_size)
            state = torch.FloatTensor(x).to(device)
            action = torch.FloatTensor(u).to(device)
            next_state = torch.FloatTensor(y).to(device)
            not_done = torch.FloatTensor(1 - d).to(device)
            reward = torch.FloatTensor(r).to(device)

            recon, mean, std = self.vae(state, action)
            recon_loss = F.mse_loss(recon, action)
            KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) -
                              std.pow(2)).mean()
            vae_loss = recon_loss + 0.5 * KL_loss

            self.vae_optimizer.zero_grad()
            vae_loss.backward()
            self.vae_optimizer.step()

            # Select action according to policy and add clipped noise
            noise = torch.FloatTensor(u).data.normal_(0,
                                                      policy_noise).to(device)
            noise = noise.clamp(-noise_clip, noise_clip)
            next_action = (self.actor_target(next_state) + noise).clamp(
                -self.max_action, self.max_action)

            # Compute the target Q value
            target_Q1, target_Q2 = self.critic_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + (not_done * discount * target_Q).detach()

            # Get current Q estimates
            current_Q1, current_Q2 = self.critic(state, action)
            std_loss = torch.std(
                torch.cat([current_Q1.unsqueeze(0),
                           current_Q2.unsqueeze(0)], 0),
                dim=0,
                unbiased=False,
            )

            # Compute critic loss
            critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
                current_Q2, target_Q)

            # Optimize the critic
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            # Delayed policy updates
            if it % policy_freq == 0:

                # Compute actor loss
                actor_loss = -self.critic.Q1(state, self.actor(state)).mean()

                sampled_actions = self.vae.decode(state)
                actor_actions = self.actor(state)
                action_divergence = ((sampled_actions -
                                      actor_actions)**2).sum(-1)

                # Optimize the actor
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()

                # Update the frozen target models
                for param, target_param in zip(
                        self.critic.parameters(),
                        self.critic_target.parameters()):
                    target_param.data.copy_(tau * param.data +
                                            (1 - tau) * target_param.data)

                for param, target_param in zip(self.actor.parameters(),
                                               self.actor_target.parameters()):
                    target_param.data.copy_(tau * param.data +
                                            (1 - tau) * target_param.data)

        logger.record_dict(
            create_stats_ordered_dict(
                'Q_target',
                target_Q.cpu().data.numpy(),
            ))
        logger.record_tabular('Actor Loss', actor_loss.cpu().data.numpy())
        logger.record_tabular('Critic Loss', critic_loss.cpu().data.numpy())
        logger.record_tabular('Std Loss', std_loss.cpu().data.numpy().mean())
        logger.record_dict(
            create_stats_ordered_dict('Action_Divergence',
                                      action_divergence.cpu().data.numpy()))
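For reference, the KL_loss term above is the standard closed-form KL divergence between the VAE's diagonal Gaussian posterior and a standard normal prior, averaged over latent dimensions and the batch:

D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2)\,\big\|\,\mathcal{N}(0, 1)\big) = -\tfrac{1}{2}\left(1 + \log \sigma^2 - \mu^2 - \sigma^2\right)

which is exactly what -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean() computes.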
Example #21
File: main.py  Project: MahdiehNejati/PLAS
        policy = algos.LatentPerturbation(vae_trainer.vae, state_dim,
                                          action_dim, latent_dim, max_action,
                                          **vars(args))

    evaluations = []
    episode_num = 0
    done = True
    training_iters = 0
    while training_iters < args.max_timesteps:
        # Train
        pol_vals = policy.train(replay_buffer,
                                iterations=int(args.eval_freq),
                                batch_size=args.batch_size)
        training_iters += args.eval_freq
        print("Training iterations: " + str(training_iters))
        logger.record_tabular('Training Epochs',
                              int(training_iters // int(args.eval_freq)))

        # Save Model
        if training_iters % args.save_freq == 0 and args.save_model:
            policy.save('model_' + str(training_iters), folder_name)

        # Eval
        info = eval_policy(policy, env)
        evaluations.append(info['AverageReturn'])
        np.save(os.path.join(folder_name, 'eval'), evaluations)
        eval_dict = eval_critic(policy.select_action, policy.critic.q1, env)
        for k, v in eval_dict.items():
            logger.record_tabular('Eval_critic/' + k, v)

        for k, v in info.items():
            logger.record_tabular(k, v)
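The loop above depends on an eval_policy(policy, env) helper that is not shown. Based on how its return value is used (info['AverageReturn'] and info.items()), a minimal sketch might look like the following; the episode count, the gym-style env API, and the single dict key are assumptions, though policy.select_action is the same call used with eval_critic above.

import numpy as np

def eval_policy(policy, env, eval_episodes=10):
    """Sketch: roll the policy out for a few episodes and report the mean return."""
    returns = []
    for _ in range(eval_episodes):
        state, done, episode_return = env.reset(), False, 0.0
        while not done:
            action = policy.select_action(np.array(state))
            state, reward, done, _ = env.step(action)
            episode_return += reward
        returns.append(episode_return)
    return {'AverageReturn': float(np.mean(returns))}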