Example #1
    def __init__(self, params):
        self.params = params
        self.kappa = self.params['kappa']

        self.dtype = self.params['dtype']
        self.device = self.params['device']

        self.models = []
        self.priors = []
        self.optims = []

        for i in range(self.params['ens_size']):
            model = MLP(self.params['model_params']).to(device=self.device)
            self.models.append(model)
            self.optims.append(
                torch.optim.Adam(model.parameters(),
                                 lr=self.params['lr'],
                                 weight_decay=self.params['reg']))

            prior = MLP(self.params['model_params']).to(device=self.device)
            prior.eval()
            self.priors.append(prior)
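A brief usage sketch for the ensemble above, assuming the frozen priors are combined with the trainable networks as in randomized-prior-function ensembles; the combination rule and the ensemble_predict name are assumptions, since the snippet only shows the constructor:

import torch

def ensemble_predict(ens, x):
    # Hypothetical forward pass: each trainable model plus its frozen prior,
    # scaled by kappa, evaluated without tracking gradients.
    with torch.no_grad():
        preds = [model(x) + ens.kappa * prior(x)
                 for model, prior in zip(ens.models, ens.priors)]
    return torch.stack(preds)  # shape: (ens_size, batch, output_dim)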
Example #2
class BCAgent(POLOAgent):
    """
    An agent extending POLO that uses behavior cloning on the planner's
    predicted actions as a prior for MPC.
    """
    def __init__(self, params):
        super(BCAgent, self).__init__(params)

        # Initialize policy network
        pol_params = self.params['p-bc']['pol_params']
        pol_params['input_size'] = self.N
        pol_params['output_size'] = self.M
        if 'final_activation' not in pol_params:
            pol_params['final_activation'] = torch.tanh

        self.pol = MLP(pol_params)

        # Create policy optimizer
        ppar = self.params['p-bc']['pol_optim']
        self.pol_optim = torch.optim.Adam(self.pol.parameters(),
                                          lr=ppar['lr'],
                                          weight_decay=ppar['reg'])

        # Use a replay buffer that will save planner actions
        self.pol_buf = ReplayBuffer(self.N, self.M,
                                    self.params['p-bc']['buf_size'])

        # Logging (store cum_rew, cum_emp_rew)
        self.hist['pols'] = np.zeros((self.T, 2))

        self.has_pol = True

        self.pol_cache = ()

    def get_action(self):
        """
        BCAgent generates a planned trajectory using the behavior-cloned policy
        and then optimizes it via MPC.
        """
        self.pol.eval()

        # Run a rollout using the policy starting from the current state
        infos = self.get_traj_info()

        self.hist['pols'][self.time] = infos[3:5]
        self.pol_cache = (infos[0], infos[2])

        self.prior_actions = infos[1]

        # Generate trajectory via MPC with the prior actions as a prior
        action = super(BCAgent, self).get_action(prior=self.prior_actions)

        # Add final planning trajectory to BC buffer
        fin_states, fin_rews = self.cache[2], self.cache[3]
        fin_states = np.concatenate(([self.prev_obs], fin_states[1:]))
        pb_pct = self.params['p-bc']['pb_pct']
        pb_len = int(pb_pct * fin_states.shape[0])
        for t in range(pb_len):
            self.pol_buf.update(fin_states[t], fin_states[t + 1], fin_rews[t],
                                self.planned_actions[t], False)

        return action

    def do_updates(self):
        """
        Learn from the saved buffer of planned actions.
        """
        super(BCAgent, self).do_updates()

        if self.time % self.params['p-bc']['update_freq'] == 0:
            self.update_pol()

    def update_pol(self):
        """
        Update the policy via BC on the planner actions.
        """
        self.pol.train()

        params = self.params['p-bc']

        # Generate batches for training
        size = min(self.pol_buf.size, self.pol_buf.total_in)
        num_inds = params['batch_size'] * params['grad_steps']
        inds = np.random.randint(0, size, size=num_inds)

        states = self.pol_buf.buffer['s'][inds]
        acts = self.pol_buf.buffer['a'][inds]

        states = torch.tensor(states, dtype=self.dtype)
        actions = torch.tensor(acts, dtype=self.dtype)

        for i in range(params['grad_steps']):
            bi, ei = i * params['batch_size'], (i + 1) * params['batch_size']

            # Train based on L2 distance between actions and predictions
            preds = self.pol.forward(states[bi:ei])
            preds = torch.squeeze(preds, dim=-1)
            targets = torch.squeeze(actions[bi:ei], dim=-1)

            loss = torch.nn.functional.mse_loss(preds, targets)

            self.pol_optim.zero_grad()
            loss.backward()
            self.pol_optim.step()
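
    # Behavior-cloning objective used above: minimize the mean squared error
    # || pi_theta(s_t) - a_t_planner ||^2 over state/action pairs drawn from
    # the planner buffer pol_buf.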

    def get_traj_info(self):
        """
        Run the policy for a full trajectory and return details about the
        trajectory.
        """
        env_state = self.env.sim.get_state() if self.mujoco else None

        infos = traj.eval_traj(copy.deepcopy(self.env),
                               env_state,
                               self.prev_obs,
                               mujoco=self.mujoco,
                               perturb=self.perturb,
                               H=self.H,
                               gamma=self.gamma,
                               act_mode='deter',
                               pt=(self.pol, 0),
                               terminal=self.val_ens,
                               tvel=self.tvel)

        return infos

    def print_logs(self):
        """
        BC-specific logging information.
        """
        bi, ei = super(BCAgent, self).print_logs()

        self.print('BC metrics', mode='head')

        self.print('policy traj rew', self.hist['pols'][self.time - 1][0])
        self.print('policy traj emp rew', self.hist['pols'][self.time - 1][1])

        return bi, ei

    def test_policy(self):
        """
        Run the BC action selection mechanism.
        """
        env = copy.deepcopy(self.env)
        obs = env.reset()

        if self.tvel is not None:
            env.set_target_vel(self.tvel)
            obs = env._get_obs()

        env_state = env.sim.get_state() if self.mujoco else None
        infos = traj.eval_traj(env,
                               env_state,
                               obs,
                               mujoco=self.mujoco,
                               perturb=self.perturb,
                               H=self.eval_len,
                               gamma=1,
                               act_mode='deter',
                               pt=(self.pol, 0),
                               tvel=self.tvel)

        self.hist['pol_test'][self.time] = infos[3]
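
A minimal sketch of the 'p-bc' parameter block that BCAgent reads above; the key names are taken from the snippet, while every concrete value (and the MLP layout inside pol_params) is an illustrative assumption:

p_bc = {
    'pol_params': {'hidden_sizes': [64, 64]},  # MLP layout is assumed
    'pol_optim': {'lr': 1e-3, 'reg': 1e-4},
    'buf_size': 10000,       # replay buffer capacity for planner actions
    'pb_pct': 0.5,           # fraction of each planned trajectory stored for BC
    'update_freq': 1,        # call update_pol() every update_freq steps
    'batch_size': 32,
    'grad_steps': 10,
}
params = {'p-bc': p_bc}      # merged into the agent's full parameter dict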
Example #3
class VPGAgent(Agent):
    """
    An agent running online policy gradient. VPGAgent itself uses REINFORCE,
    but it can be subclassed for other policy gradient algorithms.
    """
    def __init__(self, params):
        super(VPGAgent, self).__init__(params)
        self.H = self.params['pg']['H']
        self.lam = self.params['pg']['lam']

        # Initialize policy network
        pol_params = self.params['pg']['pol_params']
        pol_params['input_size'] = self.N
        pol_params['output_size'] = self.M
        if 'final_activation' not in pol_params:
            pol_params['final_activation'] = torch.tanh

        self.pol = MLP(pol_params)

        # Stds are not state-dependent
        init_log_std = -0.8 * torch.ones(self.M)  # ~0.45
        self.log_std = torch.nn.Parameter(init_log_std, requires_grad=True)

        # Create policy optimizer
        ppar = self.params['pg']['pol_optim']
        self.pol_params = list(self.pol.parameters()) + [self.log_std]
        self.pol_optim = torch.optim.Adam(self.pol_params,
                                          lr=ppar['lr'],
                                          weight_decay=ppar['reg'])

        # Create value function and optimizer
        val_params = self.params['pg']['val_params']
        val_params['input_size'] = self.N
        val_params['output_size'] = 1

        self.val = MLP(val_params)

        vpar = self.params['pg']['val_optim']
        self.val_optim = torch.optim.Adam(self.val.parameters(),
                                          lr=vpar['lr'],
                                          weight_decay=vpar['reg'])

        # Logging
        self.hist['ent'] = np.zeros(self.T)

    def get_dist(self, s):
        """
        Create a pytorch normal distribution from
        the policy network for state s.
        """
        s = torch.tensor(s, dtype=self.dtype)
        mu = self.pol.forward(s)
        std = self.log_std.exp()

        return torch.distributions.Normal(mu, std)

    def get_ent(self):
        """
        Return the current entropy (multivariate Gaussian).
        """
        std = self.log_std.exp()
        tpe = 2 * np.pi * np.e
        return .5 * torch.log(tpe * torch.prod(std))
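
    # For reference, the entropy of a diagonal Gaussian with stds sigma_i is
    # 0.5 * sum_i log(2 * pi * e * sigma_i ** 2); the scalar logged here is a
    # compact summary that increases monotonically with the sum of log-stds.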

    def get_action(self):
        """
        Get an action by running the policy.
        """
        self.pol.eval()

        if self.params['pg']['run_deterministic']:
            x = torch.tensor(self.prev_obs, dtype=self.dtype)
            act = self.pol.forward(x).detach().cpu().numpy()
        else:
            act = sample_pol(self.pol, self.log_std, self.prev_obs)

        act = np.clip(act, self.params['env']['min_act'],
                      self.params['env']['max_act'])

        self.hist['ent'][self.time] = self.get_ent().detach().cpu().numpy()

        return act

    def do_updates(self):
        """
        Performs actor and critic updates.
        """
        if self.time % self.params['pg']['update_every'] == 0 or self.time == 1:
            plan_time = 0
            H, num_rollouts = self.H, self.params['pg']['num_rollouts']
            for i in range(self.params['pg']['num_iter']):
                # Sample rollouts using ground truth model
                check = time.time()
                rollouts = self.sample_rollouts(H, num_rollouts)
                plan_time += time.time() - check

                # Performs value updates alongside advantage calculation
                rews = self.update_pol(rollouts)

            # Time spent generating rollouts should be considered planning time
            self.hist['plan_time'][self.time - 1] += plan_time
            self.hist['update_time'][self.time - 1] -= plan_time

    def sample_rollouts(self, H, num_rollouts):
        """
        Use traj module to sample rollouts using the policy.
        """
        env_state = self.env.sim.get_state() if self.mujoco else None

        self.pol.eval()
        rollouts = traj.generate_trajectories(
            num_rollouts,
            self.env,
            env_state,
            self.prev_obs,
            mujoco=self.mujoco,
            perturb=self.perturb,
            H=self.H,
            gamma=self.gamma,
            act_mode='gauss',
            pt=(sample_pol, self.pol, self.log_std),
            terminal=None,
            tvel=self.tvel,
            num_cpu=self.params['pg']['num_cpu'])

        return rollouts

    def update_val(self, obs, targets):
        """
        Update value function with MSE loss.
        """
        preds = self.val.forward(obs)
        preds = torch.squeeze(preds, dim=-1)

        loss = torch.nn.functional.mse_loss(targets, preds)

        self.val_optim.zero_grad()
        loss.backward(retain_graph=True)
        self.val_optim.step()

        return loss.item()

    def calc_advs(self, obs, rews, update_vals=True):
        """
        Calculate advantages for updating the policy (and optionally the value
        function). Can use either rewards-to-go or GAE.
        """
        num_rollouts, H = obs.shape[:2]

        self.val.eval()

        if not self.params['pg']['use_gae']:
            # Calculate terminal values
            fin_obs = obs[:, -1]
            fin_vals = self.val.forward(fin_obs)
            fin_vals = torch.squeeze(fin_vals, dim=-1)

            # Calculate rewards-to-go
            rtg = torch.zeros((num_rollouts, H))
            for k in reversed(range(H)):
                if k < H - 1:
                    rtg[:, k] += self.gamma * rtg[:, k + 1]
                else:
                    rtg[:, k] += self.gamma * fin_vals
                rtg[:, k] += rews[:, k]

            if update_vals:
                self.val.train()
                self.update_val(obs, rtg)

            # Normalize advantages for policy gradient
            for k in range(H):
                rtg[:, k] -= torch.mean(rtg[:, k])

            return rtg

        # Generalized Advantage Estimation (GAE)
        prev_obs = torch.tensor(self.prev_obs, dtype=self.dtype)
        orig_val = self.val.forward(prev_obs)
        vals = torch.squeeze(self.val.forward(obs), dim=-1)

        deltas = torch.zeros(rews.shape)
        advs = torch.zeros((num_rollouts, H))

        lg = self.lam * self.gamma
        for k in reversed(range(H)):
            prev_vals = vals[:, k - 1] if k > 0 else orig_val
            deltas[:, k] = self.gamma * vals[:, k] + rews[:, k] - prev_vals

            if k == H - 1:
                advs[:, k] = deltas[:, k]
            else:
                advs[:, k] = lg * advs[:, k + 1] + deltas[:, k]

        advs = advs.detach()

        # Optionally, also update the value functions
        if update_vals:
            self.val.train()

            # It is reasonable to train on advs or deltas
            dvals = advs

            # Have to perform trick to match deltas with prev vals
            fvals = torch.stack([orig_val for _ in range(vals.shape[0])],
                                dim=0)
            rets = torch.cat(
                [fvals + dvals[:, :1], vals[:, :-1] + dvals[:, 1:]], dim=-1)
            fobs = torch.unsqueeze(prev_obs, dim=0)
            fobs = torch.stack([fobs for _ in range(vals.shape[0])], dim=0)
            obs = torch.cat([fobs, obs[:, :-1]], dim=1)

            self.update_val(obs, rets)

        # Normalize advantages for policy gradient
        advs -= torch.mean(advs)
        advs /= 1e-3 + torch.std(advs)

        return advs
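
    # GAE recursion implemented above, for reference:
    #   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    #   A_t     = delta_t + gamma * lam * A_{t+1}
    # with vals[:, k] standing in for V(s_{t+1}) and prev_vals for V(s_t).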

    def get_pol_loss(self, logprob, advs, orig_logprob=None):
        """
        For REINFORCE, the policy loss is the logprobs times the advantages. It
        is important that the logprobs carry the gradient so that we can
        backpropagate through them in the policy update.
        """
        return torch.mean(logprob * advs)

    def get_logprob(self, pol, log_std, obs, acts):
        """
        Get log probabilities for the actions, keeping the gradients.
        """
        num_rollouts, H = obs.shape[0:2]

        pol.train()

        dist = self.get_dist(obs)
        logprob = dist.log_prob(acts).sum(-1)

        return logprob

    def update_pol(self, rollouts, orig_logprob=None):
        """
        Update the policy on the on-policy rollouts.
        """
        H = rollouts[0][0].shape[0]

        self.pol.train()

        obs = np.zeros((len(rollouts), self.H, self.N))
        acts = np.zeros((len(rollouts), self.H, self.M))
        rews = torch.zeros((len(rollouts), self.H))
        for i in range(len(rollouts)):
            for k in range(self.H):
                obs[i, k] = rollouts[i][0][k]
                acts[i, k] = rollouts[i][1][k]
                rews[i, k] = rollouts[i][2][k]

        obs = torch.tensor(obs, dtype=self.dtype)
        acts = torch.tensor(acts, dtype=self.dtype)

        # Perform updates for multiple steps on the value function
        if self.params['pg']['use_gae']:
            for _ in range(self.params['pg']['val_steps']):
                advs = self.calc_advs(obs, rews, update_vals=True)
        else:
            advs = self.calc_advs(obs, rews, update_vals=False)

        # Perform updates for multiple epochs on the policy
        bsize = self.params['pg']['batch_size']
        for _ in range(self.params['pg']['pol_steps']):
            inds = np.random.permutation(len(rollouts))

            binds = inds[:bsize]
            bobs, bacts = obs[binds], acts[binds]
            brews, badvs = rews[binds], advs[binds]

            if orig_logprob is not None:
                bprobs = orig_logprob[binds]
            else:
                bprobs = None

            # Get a logprob that has gradients
            logprob = self.get_logprob(self.pol, self.log_std, bobs, bacts)
            if not self.continue_updates(logprob, bprobs):
                break

            # Compute policy loss (i.e. gradient ascent)
            J = -self.get_pol_loss(logprob, badvs, orig_logprob=bprobs)

            # Apply entropy bonus
            ent_coef = self.params['pg']['pol_optim']['ent_temp']
            if ent_coef != 0:
                J -= ent_coef * self.get_ent()

            self.pol_optim.zero_grad()
            J.backward()
            # Clip gradients after backward() so that clipping actually applies
            torch.nn.utils.clip_grad_norm_(self.pol.parameters(),
                                           self.params['pg']['grad_clip'])
            self.pol_optim.step()

            # Clamp stds to be within set bounds
            log_min = np.log(self.params['pg']['min_std'])
            log_min = torch.tensor(log_min, dtype=self.dtype)
            log_max = np.log(self.params['pg']['max_std'])
            log_max = torch.tensor(log_max, dtype=self.dtype)
            self.log_std.data = torch.clamp(self.log_std.data, log_min,
                                            log_max)

        return rews

    def continue_updates(self, logprob, orig_logprob=None):
        """
        Hook deciding whether or not to continue policy updates (always True here).
        """
        return True

    def print_logs(self):
        """
        Policy gradient-specific logging information.
        """
        bi, ei = super(VPGAgent, self).print_logs()

        self.print('policy gradient metrics', mode='head')

        self.print('entropy avg', np.mean(self.hist['ent'][bi:ei]))
        self.print('sigma avg',
                   np.mean(torch.exp(self.log_std).detach().cpu().numpy()))

        return bi, ei
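Similarly, a sketch of the 'pg' parameter block that VPGAgent reads above; the key names mirror the lookups in the snippet and all concrete values are illustrative assumptions:

pg = {
    'H': 64,                     # rollout horizon
    'lam': 0.95,                 # GAE lambda
    'pol_params': {'hidden_sizes': [64, 64]},   # MLP layout is assumed
    'pol_optim': {'lr': 3e-4, 'reg': 0.0, 'ent_temp': 0.0},
    'val_params': {'hidden_sizes': [64, 64]},
    'val_optim': {'lr': 1e-3, 'reg': 0.0},
    'run_deterministic': False,
    'update_every': 1,
    'num_rollouts': 16,
    'num_iter': 1,
    'num_cpu': 1,
    'use_gae': True,
    'val_steps': 5,
    'batch_size': 16,
    'pol_steps': 10,
    'grad_clip': 1.0,
    'min_std': 0.05,
    'max_std': 1.0,
}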
Example #4
class MFEC:
    def __init__(self, env, args, device='cpu'):
        """
        Instantiate an MFEC agent.

        Parameters
        ----------
        env: gym.Env
            gym environment to train on
        args: args class from argparser
            args are from train.py: see train.py for help with each arg
        device: string
            'cpu' or 'cuda:0' depending on use_cuda flag from train.py
        """
        self.environment_type = args.environment_type
        self.env = env
        self.actions = range(self.env.action_space.n)
        self.frames_to_stack = args.frames_to_stack
        self.Q_train_algo = args.Q_train_algo
        self.use_Q_max = args.use_Q_max
        self.force_knn = args.force_knn
        self.weight_neighbors = args.weight_neighbors
        self.delta = args.delta
        self.device = device
        self.rs = np.random.RandomState(args.seed)

        # Hyperparameters
        self.epsilon = args.initial_epsilon
        self.final_epsilon = args.final_epsilon
        self.epsilon_decay = args.epsilon_decay
        self.gamma = args.gamma
        self.lr = args.lr
        self.q_lr = args.q_lr

        # Autoencoder for state embedding network
        self.vae_batch_size = args.vae_batch_size  # batch size for training VAE
        self.vae_epochs = args.vae_epochs  # number of epochs to run VAE
        self.embedding_type = args.embedding_type
        self.SR_embedding_type = args.SR_embedding_type
        self.embedding_size = args.embedding_size
        self.in_height = args.in_height
        self.in_width = args.in_width

        if self.embedding_type == 'VAE':
            self.vae_train_frames = args.vae_train_frames
            self.vae_loss = VAELoss()
            self.vae_print_every = args.vae_print_every
            self.load_vae_from = args.load_vae_from
            self.vae_weights_file = args.vae_weights_file
            self.vae = VAE(self.frames_to_stack, self.embedding_size,
                           self.in_height, self.in_width)
            self.vae = self.vae.to(self.device)
            self.optimizer = get_optimizer(args.optimizer,
                                           self.vae.parameters(), self.lr)
        elif self.embedding_type == 'random':
            self.projection = self.rs.randn(
                self.embedding_size, self.in_height * self.in_width *
                self.frames_to_stack).astype(np.float32)
        elif self.embedding_type == 'SR':
            self.SR_train_algo = args.SR_train_algo
            self.SR_gamma = args.SR_gamma
            self.SR_epochs = args.SR_epochs
            self.SR_batch_size = args.SR_batch_size
            self.n_hidden = args.n_hidden
            self.SR_train_frames = args.SR_train_frames
            self.SR_filename = args.SR_filename
            if self.SR_embedding_type == 'random':
                self.projection = np.random.randn(
                    self.embedding_size,
                    self.in_height * self.in_width).astype(np.float32)
                if self.SR_train_algo == 'TD':
                    self.mlp = MLP(self.embedding_size, self.n_hidden)
                    self.mlp = self.mlp.to(self.device)
                    self.loss_fn = nn.MSELoss(reduction='mean')
                    params = self.mlp.parameters()
                    self.optimizer = get_optimizer(args.optimizer, params,
                                                   self.lr)

        # QEC
        self.max_memory = args.max_memory
        self.num_neighbors = args.num_neighbors
        self.qec = QEC(self.actions, self.max_memory, self.num_neighbors,
                       self.use_Q_max, self.force_knn, self.weight_neighbors,
                       self.delta, self.q_lr)

        #self.state = np.empty(self.embedding_size, self.projection.dtype)
        #self.action = int
        self.memory = []
        self.print_every = args.print_every
        self.episodes = 0

    def choose_action(self, values):
        """
        Choose an action epsilon-greedily according to the Q-estimates.
        """
        # Exploration
        if self.rs.random_sample() < self.epsilon:
            self.action = self.rs.choice(self.actions)

        # Exploitation
        else:
            best_actions = np.argwhere(values == np.max(values)).flatten()
            self.action = self.rs.choice(best_actions)

        return self.action

    def TD_update(self, prev_embedding, prev_action, reward, values, time):
        # On-policy value estimate of the current state (epsilon-greedy)
        # Expected Sarsa
        v_t = (1 - self.epsilon) * np.max(values) + self.epsilon * np.mean(values)
        value = reward + self.gamma * v_t
        self.qec.update(prev_embedding, prev_action, value, time - 1)
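
    # Target used above: r_t + gamma * [(1 - eps) * max_a Q(s_{t+1}, a)
    #                                   + eps * mean_a Q(s_{t+1}, a)]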

    def MC_update(self):
        value = 0.0
        for _ in range(len(self.memory)):
            experience = self.memory.pop()
            value = value * self.gamma + experience["reward"]
            self.qec.update(
                experience["state"],
                experience["action"],
                value,
                experience["time"],
            )

    def add_to_memory(self, state_embedding, action, reward, time):
        self.memory.append({
            "state": state_embedding,
            "action": action,
            "reward": reward,
            "time": time,
        })

    def run_episode(self):
        """
        Train an MFEC agent for a single episode:
            Interact with environment
            Perform update
        """
        self.episodes += 1
        RENDER_SPEED = 0.04
        RENDER = False

        episode_frames = 0
        total_reward = 0
        total_steps = 0

        # Update epsilon
        if self.epsilon > self.final_epsilon:
            self.epsilon = self.epsilon * self.epsilon_decay

        #self.env.seed(random.randint(0, 1000000))
        state = self.env.reset()
        if self.environment_type == 'fourrooms':
            fewest_steps = self.env.shortest_path_length(self.env.state)
        done = False
        time = 0
        while not done:
            time += 1
            if self.embedding_type == 'random':
                state = np.array(state).flatten()
                state_embedding = np.dot(self.projection, state)
            elif self.embedding_type == 'VAE':
                state = torch.tensor(state).permute(2, 0, 1)  #(H,W,C)->(C,H,W)
                state = state.unsqueeze(0).to(self.device)
                with torch.no_grad():
                    mu, logvar = self.vae.encoder(state)
                    state_embedding = torch.cat([mu, logvar], 1)
                    state_embedding = state_embedding.squeeze()
                    state_embedding = state_embedding.cpu().numpy()
            elif self.embedding_type == 'SR':
                if self.SR_train_algo == 'TD':
                    state = np.array(state).flatten()
                    state_embedding = np.dot(self.projection, state)
                    with torch.no_grad():
                        state_embedding = self.mlp(
                            torch.tensor(state_embedding)).cpu().numpy()
                elif self.SR_train_algo == 'DP':
                    s = self.env.state
                    state_embedding = self.true_SR_dict[s]
            state_embedding = state_embedding / np.linalg.norm(state_embedding)
            if RENDER:
                self.env.render()
                # NOTE: the step counter `time` shadows the time module, so this
                # sleep would fail if RENDER were ever enabled.
                time.sleep(RENDER_SPEED)

            # Get estimated value of each action
            values = [
                self.qec.estimate(state_embedding, action)
                for action in self.actions
            ]

            action = self.choose_action(values)
            state, reward, done, _ = self.env.step(action)
            if self.Q_train_algo == 'MC':
                self.add_to_memory(state_embedding, action, reward, time)
            elif self.Q_train_algo == 'TD':
                if time > 1:
                    self.TD_update(prev_embedding, prev_action, prev_reward,
                                   values, time)
            prev_reward = reward
            prev_embedding = state_embedding
            prev_action = action
            total_reward += reward
            total_steps += 1
            episode_frames += self.env.skip

        if self.Q_train_algo == 'MC':
            self.MC_update()
        if self.episodes % self.print_every == 0:
            print("KNN usage:", np.mean(self.qec.knn_usage))
            self.qec.knn_usage = []
            print("Proportion of replace:", np.mean(self.qec.replace_usage))
            self.qec.replace_usage = []
        if self.environment_type == 'fourrooms':
            n_extra_steps = total_steps - fewest_steps
            return n_extra_steps, episode_frames, total_reward
        else:
            return episode_frames, total_reward

    def warmup(self):
        """
        Collect frames from a random policy and train the state embedding
        (VAE or successor representation), depending on embedding_type.
        """
        if self.embedding_type == 'VAE':
            if self.load_vae_from is not None:
                self.vae.load_state_dict(torch.load(self.load_vae_from))
                self.vae = self.vae.to(self.device)
            else:
                # Collect 1 million frames from random policy
                print("Generating dataset to train VAE from random policy")
                vae_data = []
                state = self.env.reset()
                total_frames = 0
                while total_frames < self.vae_train_frames:
                    action = random.randint(0, self.env.action_space.n - 1)
                    state, reward, done, _ = self.env.step(action)
                    vae_data.append(state)
                    total_frames += self.env.skip
                    if done:
                        state = self.env.reset()
                # Dataset, Dataloader for 1 million frames
                vae_data = torch.tensor(
                    vae_data
                )  # (N x H x W x C) - (1mill/skip X 84 X 84 X frames_to_stack)
                vae_data = vae_data.permute(0, 3, 1, 2)  # (N x C x H x W)
                vae_dataset = TensorDataset(vae_data)
                vae_dataloader = DataLoader(vae_dataset,
                                            batch_size=self.vae_batch_size,
                                            shuffle=True)
                # Training loop
                print("Training VAE")
                self.vae.train()
                for epoch in range(self.vae_epochs):
                    train_loss = 0
                    for batch_idx, batch in enumerate(vae_dataloader):
                        batch = batch[0].to(self.device)
                        self.optimizer.zero_grad()
                        recon_batch, mu, logvar = self.vae(batch)
                        loss = self.vae_loss(recon_batch, batch, mu, logvar)
                        train_loss += loss.item()
                        loss.backward()
                        self.optimizer.step()
                        if batch_idx % self.vae_print_every == 0:
                            msg = 'VAE Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
                                epoch, batch_idx * len(batch),
                                len(vae_dataloader.dataset),
                                loss.item() / len(batch))
                            print(msg)
                    print('====> Epoch {} Average loss: {:.4f}'.format(
                        epoch, train_loss / len(vae_dataloader.dataset)))
                    if self.vae_weights_file is not None:
                        torch.save(self.vae.state_dict(),
                                   self.vae_weights_file)
            self.vae.eval()
        elif self.embedding_type == 'SR':
            if self.SR_embedding_type == 'random':
                if self.SR_train_algo == 'TD':
                    total_frames = 0
                    transitions = []
                    while total_frames < self.SR_train_frames:
                        observation = self.env.reset()
                        s_t = self.env.state  # will not work on Atari
                        done = False
                        while not done:
                            action = np.random.randint(self.env.action_space.n)
                            observation, reward, done, _ = self.env.step(
                                action)
                            s_tp1 = self.env.state  # will not work on Atari
                            transitions.append((s_t, s_tp1))
                            total_frames += self.env.skip
                            s_t = s_tp1
                    # Dataset, Dataloader
                    dataset = SRDataset(self.env, self.projection, transitions)
                    dataloader = DataLoader(dataset,
                                            batch_size=self.SR_batch_size,
                                            shuffle=True)
                    train_losses = []
                    #Training loop
                    for epoch in range(self.SR_epochs):
                        for batch_idx, batch in enumerate(dataloader):
                            self.optimizer.zero_grad()
                            e_t, e_tp1 = batch
                            e_t = e_t.to(self.device)
                            e_tp1 = e_tp1.to(self.device)
                            mhat_t = self.mlp(e_t)
                            mhat_tp1 = self.mlp(e_tp1)
                            target = e_t + self.gamma * mhat_tp1.detach()
                            loss = self.loss_fn(mhat_t, target)
                            loss.backward()
                            self.optimizer.step()
                            train_losses.append(loss.item())
                        print("Epoch:", epoch, "Average loss",
                              np.mean(train_losses))

                    emb_reps = np.zeros(
                        [self.env.n_states, self.embedding_size])
                    SR_reps = np.zeros(
                        [self.env.n_states, self.embedding_size])
                    labels = []
                    room_size = self.env.room_size
                    for i, (state,
                            obs) in enumerate(self.env.state_dict.items()):
                        emb = np.dot(self.projection, obs.flatten())
                        emb_reps[i, :] = emb
                        with torch.no_grad():
                            emb = torch.tensor(emb).to(self.device)
                            SR = self.mlp(emb).cpu().numpy()
                        SR_reps[i, :] = SR
                        if state[0] < room_size + 1 and state[1] < room_size + 1:
                            label = 0
                        elif state[0] > room_size + 1 and state[1] < room_size + 1:
                            label = 1
                        elif state[0] < room_size + 1 and state[1] > room_size + 1:
                            label = 2
                        elif state[0] > room_size + 1 and state[1] > room_size + 1:
                            label = 3
                        else:
                            label = 4
                        labels.append(label)
                    np.save('%s_SR_reps.npy' % (self.SR_filename), SR_reps)
                    np.save('%s_emb_reps.npy' % (self.SR_filename), emb_reps)
                    np.save('%s_labels.npy' % (self.SR_filename), labels)
                elif self.SR_train_algo == 'MC':
                    pass
                elif self.SR_train_algo == 'DP':
                    # Use this to ensure same order every time
                    idx_to_state = {
                        i: state
                        for i, state in enumerate(self.env.state_dict.keys())
                    }
                    state_to_idx = {v: k for k, v in idx_to_state.items()}
                    T = np.zeros([self.env.n_states, self.env.n_states])
                    for i, s in idx_to_state.items():
                        for a in range(4):
                            self.env.state = s
                            _, _, _, _ = self.env.step(a)
                            s_tp1 = self.env.state
                            T[state_to_idx[s], state_to_idx[s_tp1]] += 0.25
                    true_SR = np.eye(self.env.n_states)
                    done = False
                    t = 0
                    while not done:
                        t += 1
                        new_SR = true_SR + (self.SR_gamma**t) * (np.matmul(
                            true_SR, T))
                        done = np.max(np.abs(true_SR - new_SR)) < 1e-10
                        true_SR = new_SR
                    self.true_SR_dict = {}
                    for s, obs in self.env.state_dict.items():
                        idx = state_to_idx[s]
                        self.true_SR_dict[s] = true_SR[idx, :]
        else:
            pass  # random projection doesn't require warmup
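
A minimal driver sketch for the MFEC agent above (hypothetical; the real entry point is train.py, which is not shown): warm up the embedding once, then run episodes repeatedly.

agent = MFEC(env, args, device='cpu')
agent.warmup()
for _ in range(num_episodes):        # num_episodes is an assumed variable
    stats = agent.run_episode()      # (episode_frames, total_reward), or
                                     # (n_extra_steps, frames, reward) on fourrooms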
Example #5
            model.cuda(device=args.gpu)
    else:
        if args.model == 'MLP':
            model = MLP(embedding_indexer=seq_indexer,
                        gpu=args.gpu,
                        feat_num=label_indexer.__len__(),
                        dropout=args.dropout_rate)
        elif args.model == 'CNN':
            model = TextCNN(embedding_indexer=seq_indexer,
                            gpu=args.gpu,
                            feat_num=label_indexer.__len__(),
                            dropout=args.dropout_rate,
                            kernel_size=[2, 3, 5])

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 betas=(0.9, 0.999),
                                 eps=1e-08,
                                 weight_decay=0,
                                 amsgrad=False)
    eval = sst2F1Eval()
    best_score = 0.0
    count = 0
    for epoch in range(args.num_epoch):
        train_loss = 0.0
        k = 0
        for x, y in tqdm(train_loader):
            padded_text, lens, mask = seq_indexer.add_padding_tensor(
                x, gpu=args.gpu)
            label = label_indexer.instance2tensor(y, gpu=args.gpu)
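            # The original snippet is cut off here; a typical continuation of
            # the training step, assuming the model maps the padded batch to
            # class logits (the forward signature is an assumption), would be:
            optimizer.zero_grad()
            logits = model(padded_text)
            loss = criterion(logits, label)
            train_loss += loss.item()
            loss.backward()
            optimizer.step()
            k += 1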