Example #1
 def entropy_function(probs):
     return Categorical(probs=probs).entropy()
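A minimal usage sketch (standalone, not from the original project; the batch of probability rows is made up): each row of `probs` is normalized by `Categorical`, and `entropy()` returns one value per row.

import torch
from torch.distributions import Categorical

def entropy_function(probs):
    return Categorical(probs=probs).entropy()

probs = torch.tensor([[0.5, 0.5], [0.9, 0.1]])
print(entropy_function(probs))  # tensor([0.6931, 0.3251]) -- ln(2) for the uniform row, lower for the skewed one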
Example #2
 def __call__(self, log_action_probs):
     return Categorical(logits=log_action_probs)
Example #3
File: diayn.py Project: zizai/notebooks
 def act(self, s, z):
     return Categorical(torch.exp(self(s, z))).sample().item()
Example #4
 def get_action(self, x, action=None):
     logits = self.actor(self.network(x))
     probs = Categorical(logits=logits)
     if action is None:
         action = probs.sample()
     return action, probs.log_prob(action), probs.entropy()
Example #5
 def _action_distributions(self, encoder_out):
     action_probas = self.controller.actor(encoder_out)
     return action_probas, Categorical(action_probas)  # note: no logits
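A hedged aside on the "# note: no logits" comment above (standalone snippet, not from either project): passing a tensor positionally or via probs= treats it as probabilities, while logits= applies a softmax internally, so probabilities and their logarithms describe the same distribution.

import torch
from torch.distributions import Categorical

p = torch.tensor([0.1, 0.2, 0.7])
d_from_probs = Categorical(p)                # positional argument is treated as probs=
d_from_logits = Categorical(logits=p.log())  # log-probabilities are valid logits
print(torch.allclose(d_from_probs.probs, d_from_logits.probs))  # True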
Example #6
    def _test_mdnrnn_simulate_world(self, use_gpu=False):
        num_epochs = 300
        num_episodes = 400
        batch_size = 200
        action_dim = 2
        seq_len = 5
        state_dim = 2
        simulated_num_gaussians = 2
        mdrnn_num_gaussians = 2
        simulated_num_hidden_layers = 1
        simulated_num_hiddens = 3
        mdnrnn_num_hidden_layers = 1
        mdnrnn_num_hiddens = 10
        adam_lr = 0.01

        replay_buffer = MDNRNNMemoryPool(max_replay_memory_size=num_episodes)
        swm = SimulatedWorldModel(
            action_dim=action_dim,
            state_dim=state_dim,
            num_gaussians=simulated_num_gaussians,
            lstm_num_hidden_layers=simulated_num_hidden_layers,
            lstm_num_hiddens=simulated_num_hiddens,
        )

        possible_actions = torch.eye(action_dim)
        for _ in range(num_episodes):
            cur_state_mem = np.zeros((seq_len, state_dim))
            next_state_mem = np.zeros((seq_len, state_dim))
            action_mem = np.zeros((seq_len, action_dim))
            reward_mem = np.zeros(seq_len)
            not_terminal_mem = np.zeros(seq_len)
            next_mus_mem = np.zeros(
                (seq_len, simulated_num_gaussians, state_dim))

            swm.init_hidden(batch_size=1)
            next_state = torch.randn((1, 1, state_dim))
            for s in range(seq_len):
                cur_state = next_state
                action = possible_actions[np.random.randint(action_dim)].view(
                    1, 1, action_dim)
                next_mus, reward = swm(action, cur_state)

                not_terminal = 1
                if s == seq_len - 1:
                    not_terminal = 0

                # randomly draw for next state
                next_pi = torch.ones(
                    simulated_num_gaussians) / simulated_num_gaussians
                index = Categorical(next_pi).sample((1, )).long().item()
                next_state = next_mus[0, 0, index].view(1, 1, state_dim)

                cur_state_mem[s] = cur_state.detach().numpy()
                action_mem[s] = action.numpy()
                reward_mem[s] = reward.detach().numpy()
                not_terminal_mem[s] = not_terminal
                next_state_mem[s] = next_state.detach().numpy()
                next_mus_mem[s] = next_mus.detach().numpy()

            replay_buffer.insert_into_memory(cur_state_mem, action_mem,
                                             next_state_mem, reward_mem,
                                             not_terminal_mem)

        num_batch = num_episodes // batch_size
        mdnrnn_params = MDNRNNParameters(
            hidden_size=mdnrnn_num_hiddens,
            num_hidden_layers=mdnrnn_num_hidden_layers,
            minibatch_size=batch_size,
            learning_rate=adam_lr,
            num_gaussians=mdrnn_num_gaussians,
        )
        mdnrnn_net = MemoryNetwork(
            state_dim=state_dim,
            action_dim=action_dim,
            num_hiddens=mdnrnn_params.hidden_size,
            num_hidden_layers=mdnrnn_params.num_hidden_layers,
            num_gaussians=mdnrnn_params.num_gaussians,
        )
        if use_gpu and torch.cuda.is_available():
            mdnrnn_net = mdnrnn_net.cuda()
        trainer = MDNRNNTrainer(mdnrnn_network=mdnrnn_net,
                                params=mdnrnn_params,
                                cum_loss_hist=num_batch)

        for e in range(num_epochs):
            for i in range(num_batch):
                training_batch = replay_buffer.sample_memories(
                    batch_size, use_gpu=use_gpu, batch_first=use_gpu)
                losses = trainer.train(training_batch, batch_first=use_gpu)
                logger.info(
                    "{}-th epoch, {}-th minibatch: \n"
                    "loss={}, bce={}, gmm={}, mse={} \n"
                    "cum loss={}, cum bce={}, cum gmm={}, cum mse={}\n".format(
                        e,
                        i,
                        losses["loss"],
                        losses["bce"],
                        losses["gmm"],
                        losses["mse"],
                        np.mean(trainer.cum_loss),
                        np.mean(trainer.cum_bce),
                        np.mean(trainer.cum_gmm),
                        np.mean(trainer.cum_mse),
                    ))

                if (np.mean(trainer.cum_loss) < 0
                        and np.mean(trainer.cum_gmm) < -3.0
                        and np.mean(trainer.cum_bce) < 0.6
                        and np.mean(trainer.cum_mse) < 0.2):
                    return

        assert False, "losses not reduced significantly during training"
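A hedged micro-example of the component draw inside the episode loop above (standalone; num_gaussians stands in for simulated_num_gaussians): a uniform Categorical over the mixture components yields one component index per call.

import torch
from torch.distributions import Categorical

num_gaussians = 2
next_pi = torch.ones(num_gaussians) / num_gaussians  # uniform mixture weights
index = Categorical(next_pi).sample((1,)).long().item()
print(index)  # 0 or 1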
Example #7
def policy_forward(observation):
    logits = policy_network(observation)
    if discrete:
        return Categorical(logits=logits)
    else:
        return Normal(logits, torch.exp(log_std))
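A minimal sketch of calling this helper (the toy policy_network, observation size, and log_std below are assumptions, not values from the original script):

import torch
import torch.nn as nn
from torch.distributions import Categorical, Normal

# Assumed globals for this sketch only; the real script defines its own.
discrete = True
policy_network = nn.Linear(4, 2)  # toy net: 4-dim observation -> 2 action logits (or means)
log_std = torch.zeros(2)          # only used by the continuous branch

def policy_forward(observation):
    logits = policy_network(observation)
    if discrete:
        return Categorical(logits=logits)
    else:
        return Normal(logits, torch.exp(log_std))

pi = policy_forward(torch.randn(4))
a = pi.sample()
print(a.item(), pi.log_prob(a).item())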
Example #8
 def _distribution(self, obs):
     logits = self.logits_net(obs)
     return Categorical(logits=logits)  # pass as logits; a positional argument would be read as probabilities
Example #9
File: ppo_fast.py Project: vwxyzjn/cleanrl
def act(args, agent, i):
    def make_env(gym_id, seed, idx):
        def thunk():
            env = gym.make(gym_id)
            env = wrap_atari(env)
            env = gym.wrappers.RecordEpisodeStatistics(env)
            if args.capture_video:
                if idx == 0:
                    env = ProbsVisualizationWrapper(env)
                    env = Monitor(env, f'videos/{experiment_name}')
            env = wrap_pytorch(
                wrap_deepmind(
                    env,
                    clip_rewards=True,
                    frame_stack=True,
                    scale=False,
                )
            )
            env.seed(seed)
            env.action_space.seed(seed)
            env.observation_space.seed(seed)
            return env
        return thunk
    envs = VecPyTorch(DummyVecEnv([make_env(args.gym_id, args.seed+i, i)]), device)
    assert isinstance(envs.action_space, Discrete), "only discrete action space is supported"
    # ALGO Logic: Storage for epoch data
    obs = torch.zeros((args.num_steps, args.num_envs) + envs.observation_space.shape).to(device)
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.action_space.shape).to(device)
    logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs)).to(device)
    
    # TRY NOT TO MODIFY: start the game
    global_step = 0
    # Note how `next_obs` and `next_done` are used; their usage is equivalent to
    # https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/84a7582477fb0d5c82ad6d850fe476829dddd2e1/a2c_ppo_acktr/storage.py#L60
    next_obs = envs.reset()
    next_done = torch.zeros(args.num_envs).to(device)
    num_updates = args.total_timesteps // args.batch_size
    print(num_updates)
    for update in range(1, num_updates+1):
        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            frac = 1.0 - (update - 1.0) / num_updates
            lrnow = lr(frac)
            optimizer.param_groups[0]['lr'] = lrnow
    
        # TRY NOT TO MODIFY: prepare the execution of the game.
        for step in range(0, args.num_steps):
            global_step += 1 * args.num_envs
            obs[step] = next_obs
            dones[step] = next_done
    
            # ALGO LOGIC: put action logic here
            with torch.no_grad():
                values[step] = agent.get_value(obs[step]).flatten()
                action, logproba, _ = agent.get_action(obs[step])
    
                # visualization
                if args.capture_video:
                    probs_list = np.array(Categorical(
                        logits=agent.actor(agent.forward(obs[step]))).probs[0:1].tolist())
                    envs.env_method("set_probs", probs_list, indices=0)
    
            actions[step] = action
            logprobs[step] = logproba
    
            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs, rs, ds, infos = envs.step(action)
            rewards[step], next_done = rs.view(-1), torch.Tensor(ds).to(device)
    
            for info in infos:
                if 'episode' in info.keys():
                    print(f"global_step={global_step}, episode_reward={info['episode']['r']}")
                    # writer.add_scalar("charts/episode_reward", info['episode']['r'], global_step)
                    break
Example #10
    def forward(self,
                key,
                values,
                lengths,
                text=None,
                isTrain=True,
                isAttended=True,
                random=False):
        '''
        :param key :(N, T, key_size) Output of the Encoder Key projection layer
        :param values: (N, T, value_size) Output of the Encoder Value projection layer
        :param text: (N, text_len) Batch input of text with text_length
        :param isTrain: Train or eval mode
        :return predictions: Returns the character prediction probabilities
        '''
        batch_size = key.shape[0]

        probs = torch.ones(batch_size, dtype=torch.float64).to(DEVICE)

        if (isTrain == True):
            max_len = text.shape[
                1] - 1  # note that y is <sos> ... <eos>, but maxlen should be len(... <eos>)
            embeddings = self.embedding(text)  # (N, text_len, embedding_size)
        else:
            max_len = 250

        # initialization
        predictions = []
        hidden_states = [[
            torch.zeros(batch_size, self.hidden_dim).to(DEVICE),
            torch.zeros(batch_size, self.hidden_dim).to(DEVICE)
        ],
                         [
                             torch.zeros(batch_size, self.key_size).to(DEVICE),
                             torch.zeros(batch_size,
                                         self.hidden_dim).to(DEVICE)
                         ]]

        prediction = torch.zeros(
            batch_size, 1).to(DEVICE) + self.sos_index  # first input is <sos>

        # reset lockedDropout mask
        self.lockeddropout_cell_1.reset_mask()
        self.lockeddropout_cell_2.reset_mask()

        if isAttended:
            attention_context = self.attention(hidden_states[1][0], key,
                                               values, lengths)[0]
        else:
            attention_context = torch.zeros(batch_size, self.value_size).to(
                DEVICE)  # random value when no attention

        for i in range(max_len):
            if isTrain:
                # teacher forcing
                if np.random.random_sample() <= 0.2:
                    # auto-regressive
                    char_embed = self.embedding(prediction.argmax(-1))
                else:
                    # use true label
                    char_embed = embeddings[:, i, :]
            else:
                # auto-regressive
                if random == False:
                    char_embed = self.embedding(prediction.argmax(-1))
                    probs = probs * prediction.softmax(dim=1).max(-1)[0]
                else:
                    pred_probs = prediction.softmax(dim=1)
                    m = Categorical(probs=pred_probs)
                    sample_pred = m.sample()
                    char_embed = self.embedding(sample_pred)
                    probs = probs * pred_probs.gather(
                        1, sample_pred.unsqueeze(1)).squeeze(1)

            inp = torch.cat([char_embed, attention_context], dim=1)
            hidden_states[0] = self.lstm1(inp, hidden_states[0])

            inp_2 = hidden_states[0][0]
            inp_2 = self.lockeddropout_cell_1(inp_2, dropout=self.dropout)
            hidden_states[1] = self.lstm2(inp_2, hidden_states[1])

            ### Compute attention from the output of the second LSTM Cell ###
            if isAttended:
                attention_context = self.attention(hidden_states[1][0], key,
                                                   values, lengths)[0]
            else:
                attention_context = torch.zeros(
                    batch_size, self.value_size).to(
                        DEVICE)  # random value when no attention

            output = hidden_states[1][0]
            output = self.lockeddropout_cell_2(output, dropout=self.dropout)

            # use output and attention_context to predict
            prediction = self.character_prob(
                torch.cat([output, attention_context],
                          dim=1))  # (N, vocab_size)
            predictions.append(prediction.unsqueeze(1))  # (N, 1, vocab_size)

        return torch.cat(predictions, dim=1), probs  # (N, T, vocab_size), (N,)
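A small standalone illustration of the random-sampling branch above (batch and vocabulary sizes are made up): draw a token from the softmax of one prediction step and accumulate its probability with gather.

import torch
from torch.distributions import Categorical

prediction = torch.randn(2, 5)  # (batch, vocab) raw scores for one decoding step
probs = torch.ones(2)           # running sequence probability per batch element

pred_probs = prediction.softmax(dim=1)
sample_pred = Categorical(probs=pred_probs).sample()  # (batch,)
probs = probs * pred_probs.gather(1, sample_pred.unsqueeze(1)).squeeze(1)
print(sample_pred, probs)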
Example #11
    def forward(self, obs, memory):

        # Model associated with first domain
        x1 = obs.image.transpose(1, 3).transpose(2, 3)
        x1 = self.image_conv1(x1)
        x1 = x1.reshape(x1.shape[0], -1)

        if self.use_memory:
            ValueError("Memory not supported yet")
            hidden = (memory[:, :self.semi_memory_size],
                      memory[:, self.semi_memory_size:])
            hidden = self.memory_rnn(x, hidden)
            embedding = hidden[0]
            memory = torch.cat(hidden, dim=1)
        else:
            embedding1 = x1

        if self.use_text:
            ValueError("Text not supported yet")
            embed_text = self._get_embed_text(obs.text)
            embedding = torch.cat((embedding, embed_text), dim=1)

        # Actor
        x1_actor = self.actor1(embedding1)

        # Critic
        x1 = self.critic1(embedding1)
        value1 = x1.squeeze(1)

        ##############################################################

        # Now model associated with second domain
        x2 = obs.image.transpose(1, 3).transpose(2, 3)
        x2 = self.image_conv2(x2)
        x2 = x2.reshape(x2.shape[0], -1)

        if self.use_memory:
            ValueError("Memory not supported yet")
            hidden = (memory[:, :self.semi_memory_size],
                      memory[:, self.semi_memory_size:])
            hidden = self.memory_rnn(x, hidden)
            embedding = hidden[0]
            memory = torch.cat(hidden, dim=1)
        else:
            embedding2 = x2

        if self.use_text:
            ValueError("Text not supported yet")
            embed_text = self._get_embed_text(obs.text)
            embedding = torch.cat((embedding, embed_text), dim=1)

        # Actor
        x2_actor = self.actor2(embedding2)

        # Critic
        x2 = self.critic2(embedding2)
        value2 = x2.squeeze(1)

        ##############################################################

        # Take average of two models

        # Actor average
        x = (x1_actor + x2_actor) / 2
        dist = Categorical(logits=F.log_softmax(x, dim=1))

        # Critic average
        value = (value1 + value2) / 2

        return dist, value, memory
Example #12
    obs = np.empty((args.episode_length, ) + env.observation_space.shape)

    # TODO: put other storage logic here
    values = torch.zeros((args.episode_length))
    neglogprobs = torch.zeros((args.episode_length, ))
    entropys = torch.zeros((args.episode_length, ))

    # TRY NOT TO MODIFY: prepare the execution of the game.
    for step in range(args.episode_length):
        global_step += 1
        obs[step] = next_obs.copy()

        # TODO: put action logic here
        logits = pg.forward(obs[step])
        value = vf.forward(obs[step])
        probs = Categorical(logits=logits)
        action = probs.sample()
        neglogprobs[step] = -probs.log_prob(action)
        values[step] = value
        entropys[step] = probs.entropy()

        # TRY NOT TO MODIFY: execute the game and log data.
        actions[step] = action
        next_obs, rewards[step], dones[step], _ = env.step(
            int(actions[step].numpy()))
        if dones[step]:
            break

    # TODO: training.
    # calculate the discounted rewards, or namely, returns
    returns = np.zeros_like(rewards)
Example #13
 def dist(self, obs) -> torch.distributions.Distribution:
     logits = self.net(obs)
     return Categorical(logits=logits)
Example #14
def valor(env_fn, actor_critic=ActorCritic, ac_kwargs=dict(), disc=Discriminator, dc_kwargs=dict(), seed=0, episodes_per_epoch=40,
        epochs=50, gamma=0.99, pi_lr=3e-4, vf_lr=1e-3, dc_lr=5e-4, train_v_iters=80, train_dc_iters=10, train_dc_interv=10, 
        lam=0.97, max_ep_len=1000, logger_kwargs=dict(), con_dim=5, save_freq=10, k=1):

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    seed += 10000 * proc_id()
    torch.manual_seed(seed)
    np.random.seed(seed)

    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape

    ac_kwargs['action_space'] = env.action_space

    # Model
    actor_critic = actor_critic(input_dim=obs_dim[0]+con_dim, **ac_kwargs)
    disc = disc(input_dim=obs_dim[0], context_dim=con_dim, **dc_kwargs)

    # Buffer
    local_episodes_per_epoch = int(episodes_per_epoch / num_procs())
    buffer = Buffer(con_dim, obs_dim[0], act_dim[0], local_episodes_per_epoch, max_ep_len, train_dc_interv)

    # Count variables
    var_counts = tuple(count_vars(module) for module in
        [actor_critic.policy, actor_critic.value_f, disc.policy])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d, \t d: %d\n'%var_counts)    

    # Optimizers
    train_pi = torch.optim.Adam(actor_critic.policy.parameters(), lr=pi_lr)
    train_v = torch.optim.Adam(actor_critic.value_f.parameters(), lr=vf_lr)
    train_dc = torch.optim.Adam(disc.policy.parameters(), lr=dc_lr)

    # Parameters Sync
    sync_all_params(actor_critic.parameters())
    sync_all_params(disc.parameters())

    def update(e):
        obs, act, adv, pos, ret, logp_old = [torch.Tensor(x) for x in buffer.retrieve_all()]
        
        # Policy
        _, logp, _ = actor_critic.policy(obs, act)
        entropy = (-logp).mean()

        # Policy loss
        pi_loss = -(logp*(k*adv+pos)).mean()

        # Train policy
        train_pi.zero_grad()
        pi_loss.backward()
        average_gradients(train_pi.param_groups)
        train_pi.step()

        # Value function
        v = actor_critic.value_f(obs)
        v_l_old = F.mse_loss(v, ret)
        for _ in range(train_v_iters):
            v = actor_critic.value_f(obs)
            v_loss = F.mse_loss(v, ret)

            # Value function train
            train_v.zero_grad()
            v_loss.backward()
            average_gradients(train_v.param_groups)
            train_v.step()

        # Discriminator
        if (e+1) % train_dc_interv == 0:
            print('Discriminator Update!')
            con, s_diff = [torch.Tensor(x) for x in buffer.retrieve_dc_buff()]
            _, logp_dc, _ = disc(s_diff, con)
            d_l_old = -logp_dc.mean()

            # Discriminator train
            for _ in range(train_dc_iters):
                _, logp_dc, _ = disc(s_diff, con)
                d_loss = -logp_dc.mean()
                train_dc.zero_grad()
                d_loss.backward()
                average_gradients(train_dc.param_groups)
                train_dc.step()

            _, logp_dc, _ = disc(s_diff, con)
            dc_l_new = -logp_dc.mean()
        else:
            d_l_old = 0
            dc_l_new = 0

        # Log the changes
        _, logp, _, v = actor_critic(obs, act)
        pi_l_new = -(logp*(k*adv+pos)).mean()
        v_l_new = F.mse_loss(v, ret)
        kl = (logp_old - logp).mean()
        logger.store(LossPi=pi_loss, LossV=v_l_old, KL=kl, Entropy=entropy, DeltaLossPi=(pi_l_new-pi_loss),
            DeltaLossV=(v_l_new-v_l_old), LossDC=d_l_old, DeltaLossDC=(dc_l_new-d_l_old))
        # logger.store(Adv=adv.reshape(-1).numpy().tolist(), Pos=pos.reshape(-1).numpy().tolist())

    start_time = time.time()
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
    context_dist = Categorical(logits=torch.Tensor(np.ones(con_dim)))
    total_t = 0

    for epoch in range(epochs):
        actor_critic.eval()
        disc.eval()
        for _ in range(local_episodes_per_epoch):
            c = context_dist.sample()
            c_onehot = F.one_hot(c, con_dim).squeeze().float()
            for _ in range(max_ep_len):
                concat_obs = torch.cat([torch.Tensor(o.reshape(1, -1)), c_onehot.reshape(1, -1)], 1)
                a, _, logp_t, v_t = actor_critic(concat_obs)

                buffer.store(c, concat_obs.squeeze().detach().numpy(), a.detach().numpy(), r, v_t.item(), logp_t.detach().numpy())
                logger.store(VVals=v_t)

                o, r, d, _ = env.step(a.detach().numpy()[0])
                ep_ret += r
                ep_len += 1
                total_t += 1

                terminal = d or (ep_len == max_ep_len)
                if terminal:
                    dc_diff = torch.Tensor(buffer.calc_diff()).unsqueeze(0)
                    con = torch.Tensor([float(c)]).unsqueeze(0)
                    _, _, log_p = disc(dc_diff, con)
                    buffer.end_episode(log_p.detach().numpy())
                    logger.store(EpRet=ep_ret, EpLen=ep_len)
                    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, [actor_critic, disc], None)

        # Update
        actor_critic.train()
        disc.train()

        update(epoch)

        # Log
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', total_t)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('LossDC', average_only=True)
        logger.log_tabular('DeltaLossDC', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('Time', time.time()-start_time)
        logger.dump_tabular()
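A quick standalone check (not part of valor itself): the context distribution built above from constant logits is uniform over the con_dim contexts.

import torch
from torch.distributions import Categorical

con_dim = 5
context_dist = Categorical(logits=torch.ones(con_dim))
print(context_dist.probs)  # tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])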
Example #15
 def sample(
     self,
     reconstruct,
 ):
     m = Categorical(torch.exp(reconstruct))
     return m.sample()
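A hedged usage sketch, assuming reconstruct holds log-probabilities (e.g. the output of a log_softmax layer); exp turns them back into probabilities before sampling.

import torch
from torch.distributions import Categorical

reconstruct = torch.log_softmax(torch.randn(3, 5), dim=-1)  # stand-in for the real input
m = Categorical(torch.exp(reconstruct))
print(m.sample())  # one class index per row, shape (3,)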
Example #16
 def get_pi(self, obs):
     return Categorical(logits=self.net(obs))
Example #17
File: Task4.py Project: ruslenghi/PAI2020
 def _distribution(self, obs):
     """Takes the observation and outputs a distribution over actions."""
     logits = self.logits_net(obs)
     return Categorical(logits=logits)
Example #18
    def decode_where(self, input_coord_logits, input_attri_logits,
                     input_patch_vectors, sample_mode):
        """
        Inputs:
            where_states containing
            - **coord_logits**  (bsize, 1, grid_dim)
            - **attri_logits**  (bsize, 1, scale_ratio_dim, grid_dim)
            - **patch_vectors** (bsize, 1, patch_feature_dim, grid_dim)
            sample_mode
              0: top 1, 1: multinomial

        Outputs
            - **sample_inds**   (bsize, 3)
            - **sample_vecs**   (bsize, patch_feature_dim)
        """

        ##############################################################
        # Sampling locations
        ##############################################################
        coord_logits = input_coord_logits.squeeze(1)
        if sample_mode == 0:
            _, sample_coord_inds = torch.max(coord_logits + 1.0,
                                             dim=-1,
                                             keepdim=True)
        else:
            sample_coord_inds = Categorical(coord_logits).sample().unsqueeze(
                -1)

        ##############################################################
        # Sampling attributes and patch vectors
        ##############################################################

        patch_vectors = input_patch_vectors.squeeze(1)
        bsize, tsize, grid_dim = patch_vectors.size()
        aux_pos_inds = sample_coord_inds.expand(bsize, tsize).unsqueeze(-1)
        sample_patch_vectors = torch.gather(patch_vectors, -1,
                                            aux_pos_inds).squeeze(-1)

        attri_logits = input_attri_logits.squeeze(1)
        bsize, tsize, grid_dim = attri_logits.size()
        aux_pos_inds = sample_coord_inds.expand(bsize, tsize).unsqueeze(-1)
        local_logits = torch.gather(attri_logits, -1, aux_pos_inds).squeeze(-1)

        scale_logits = local_logits[:, :self.cfg.num_scales]
        ratio_logits = local_logits[:, self.cfg.num_scales:]

        if sample_mode == 0:
            _, sample_scale_inds = torch.max(scale_logits + 1.0,
                                             dim=-1,
                                             keepdim=True)
            _, sample_ratio_inds = torch.max(ratio_logits + 1.0,
                                             dim=-1,
                                             keepdim=True)
        else:
            sample_scale_inds = Categorical(scale_logits).sample().unsqueeze(
                -1)
            sample_ratio_inds = Categorical(ratio_logits).sample().unsqueeze(
                -1)

        sample_inds = torch.cat(
            [sample_coord_inds, sample_scale_inds, sample_ratio_inds], -1)

        return sample_inds, sample_patch_vectors
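A toy illustration of the two sample_mode branches used above (the batch size, grid size, and normalized scores are assumptions, not the project's tensors):

import torch
from torch.distributions import Categorical

coord_scores = torch.softmax(torch.randn(4, 49), dim=-1)  # (bsize, grid_dim), already normalized

# sample_mode == 0: greedy top-1 index per row
_, greedy_inds = torch.max(coord_scores, dim=-1, keepdim=True)

# sample_mode != 0: multinomial draw per row
sampled_inds = Categorical(coord_scores).sample().unsqueeze(-1)

print(greedy_inds.shape, sampled_inds.shape)  # both torch.Size([4, 1])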
Example #19
    def rewrite(self,
                tm,
                ac_logprobs,
                trace_rec,
                expr_rec,
                candidate_rewrite_pos,
                pending_actions,
                eval_flag,
                max_search_pos,
                reward_thres=None):
        if len(candidate_rewrite_pos) == 0:
            return [], [], [], [], [], []

        candidate_rewrite_pos.sort(reverse=True, key=operator.itemgetter(0))
        if not eval_flag:
            sample_exp_reward_tensor = []
            for idx, (cur_pred_reward, cur_pred_reward_tensor, cur_ac_prob,
                      rewrite_pos,
                      tensor_idx) in enumerate(candidate_rewrite_pos):
                sample_exp_reward_tensor.append(cur_pred_reward_tensor)
            sample_exp_reward_tensor = torch.cat(sample_exp_reward_tensor, 0)
            sample_exp_reward_tensor = torch.exp(sample_exp_reward_tensor * 10)
            sample_exp_reward = sample_exp_reward_tensor.data.cpu().numpy()

        expr = expr_rec[-1]
        extra_reward_rec = []
        extra_action_rec = []
        candidate_tree_managers = []
        candidate_update_tree_idxes = []
        candidate_rewrite_rec = []
        candidate_expr_rec = []
        candidate_pending_actions = []

        if len(pending_actions) > 0:
            for idx, (pred_reward, cur_pred_reward_tensor, cur_ac_prob,
                      rewrite_pos,
                      tensor_idx) in enumerate(candidate_rewrite_pos):
                if len(candidate_tree_managers) > 0 and idx >= max_search_pos:
                    break
                if reward_thres is not None and pred_reward < reward_thres:
                    if eval_flag:
                        break
                    elif np.random.random() > self.cont_prob:
                        continue
                init_expr = tm.to_string(rewrite_pos)
                op_idx = pending_actions[0]
                op_list = self.rewriter.get_rewrite_seq(op_idx)
                op = self.rewriter.get_rewrite_op(op_list[0])
                new_tm, cur_update_tree_idxes = op(tm, rewrite_pos)
                if len(cur_update_tree_idxes) == 0:
                    extra_action_rec.append((ac_logprobs[tensor_idx], op_idx))
                    continue
                cur_expr = str(new_tm)
                if cur_expr in candidate_expr_rec:
                    continue
                candidate_expr_rec.append(cur_expr)
                candidate_update_tree_idxes.append(cur_update_tree_idxes)
                candidate_tree_managers.append(new_tm)
                candidate_rewrite_rec.append(
                    (ac_logprobs[tensor_idx], pred_reward,
                     cur_pred_reward_tensor, rewrite_pos, init_expr,
                     int(op_idx)))
                candidate_pending_actions.append(pending_actions[1:])
                if len(candidate_tree_managers) >= max_search_pos:
                    break
            if len(candidate_tree_managers) > 0:
                return candidate_tree_managers, candidate_update_tree_idxes, candidate_rewrite_rec, candidate_pending_actions, extra_reward_rec, extra_action_rec

        if not eval_flag:
            sample_rewrite_pos_dist = Categorical(sample_exp_reward_tensor)
            sample_rewrite_pos = sample_rewrite_pos_dist.sample(
                sample_shape=[len(candidate_rewrite_pos)])
            #sample_rewrite_pos = torch.multinomial(sample_exp_reward_tensor, len(candidate_rewrite_pos))
            sample_rewrite_pos = sample_rewrite_pos.data.cpu().numpy()
            indexes = np.unique(sample_rewrite_pos, return_index=True)[1]
            sample_rewrite_pos = [
                sample_rewrite_pos[i] for i in sorted(indexes)
            ]
            sample_rewrite_pos = sample_rewrite_pos[:self.num_sample_rewrite_pos]
            sample_exp_reward = [
                sample_exp_reward[i] for i in sample_rewrite_pos
            ]
            sample_rewrite_pos = [
                candidate_rewrite_pos[i] for i in sample_rewrite_pos
            ]
        else:
            sample_rewrite_pos = candidate_rewrite_pos.copy()

        for idx, (pred_reward, cur_pred_reward_tensor, cur_ac_prob,
                  rewrite_pos, tensor_idx) in enumerate(sample_rewrite_pos):
            if len(candidate_tree_managers) > 0 and idx >= max_search_pos:
                break
            if reward_thres is not None and pred_reward < reward_thres:
                if eval_flag:
                    break
                elif np.random.random() > self.cont_prob:
                    continue
            init_expr = tm.to_string(rewrite_pos)
            if eval_flag:
                _, candidate_acs = torch.sort(cur_ac_prob)
                candidate_acs = candidate_acs.data.cpu().numpy()
                candidate_acs = candidate_acs[::-1]
            else:
                candidate_acs_dist = Categorical(cur_ac_prob)
                candidate_acs = candidate_acs_dist.sample(
                    sample_shape=[self.num_actions])
                #candidate_acs = torch.multinomial(cur_ac_prob, self.num_actions)
                candidate_acs = candidate_acs.data.cpu().numpy()
                indexes = np.unique(candidate_acs, return_index=True)[1]
                candidate_acs = [candidate_acs[i] for i in sorted(indexes)]
            cur_active = False
            cur_ac_prob = cur_ac_prob.data.cpu().numpy()
            for i, op_idx in enumerate(candidate_acs):
                if (expr, init_expr, op_idx) in trace_rec:
                    continue
                op_list = self.rewriter.get_rewrite_seq(op_idx)
                op = self.rewriter.get_rewrite_op(op_list[0])
                new_tm, cur_update_tree_idxes = op(tm, rewrite_pos)
                if len(cur_update_tree_idxes) == 0:
                    extra_action_rec.append((ac_logprobs[tensor_idx], op_idx))
                    continue
                cur_expr = str(new_tm)
                if cur_expr in candidate_expr_rec:
                    continue
                candidate_expr_rec.append(cur_expr)
                candidate_update_tree_idxes.append(cur_update_tree_idxes)
                candidate_tree_managers.append(new_tm)
                candidate_rewrite_rec.append(
                    (ac_logprobs[tensor_idx], pred_reward,
                     cur_pred_reward_tensor, rewrite_pos, init_expr,
                     int(op_list[0])))
                candidate_pending_actions.append(op_list[1:])
                cur_active = True
                if len(candidate_tree_managers) >= max_search_pos:
                    break
            if not cur_active:
                extra_reward_rec.append(cur_pred_reward_tensor)
        return candidate_tree_managers, candidate_update_tree_idxes, candidate_rewrite_rec, candidate_pending_actions, extra_reward_rec, extra_action_rec
Example #20
    def forward(self,
                enc_pad,
                enc_len,
                ys=None,
                tf_rate=1.0,
                max_dec_timesteps=500,
                sample=False,
                label_smoothing=True):
        batch_size = enc_pad.size(0)
        if ys is not None:
            # prepare input and output sequences
            bos = ys[0].data.new([self.bos])
            eos = ys[0].data.new([self.eos])
            ys_in = [torch.cat([bos, y], dim=0) for y in ys]
            ys_out = [torch.cat([y, eos], dim=0) for y in ys]
            pad_ys_in = pad_list(ys_in, pad_value=self.eos)
            pad_ys_out = pad_list(ys_out, pad_value=self.eos)
            # get length info
            batch_size, olength = pad_ys_out.size(0), pad_ys_out.size(1)
            # map idx to embedding
            eys = self.embedding(pad_ys_in)

        # initialization
        dec_c = self.zero_state(enc_pad)
        dec_z = self.zero_state(enc_pad)
        c = self.zero_state(enc_pad, dim=self.att_odim)

        w = None
        logits, prediction, ws = [], [], []
        # reset the attention module
        self.attention.reset()

        # loop for each timestep
        olength = max_dec_timesteps if ys is None else olength
        for t in range(olength):
            # supervised learning: using teacher forcing
            if ys is not None:
                # teacher forcing
                tf = True if np.random.random_sample() <= tf_rate else False
                emb = eys[:, t, :] if tf or t == 0 else self.embedding(
                    prediction[-1])
            # else, label the data with greedy/sampling
            else:
                if t == 0:
                    bos = cc(
                        torch.Tensor([self.bos for _ in range(batch_size)
                                      ]).type(torch.LongTensor))
                    emb = self.embedding(bos)
                else:
                    emb = self.embedding(prediction[-1])
            logit, dec_z, dec_c, c, w = \
                    self.forward_step(emb, dec_z, dec_c, c, w, enc_pad, enc_len)

            ws.append(w)
            logits.append(logit)
            if not sample:
                prediction.append(torch.argmax(logit, dim=-1))
            else:
                sampled_indices = Categorical(logits=logit).sample()
                prediction.append(sampled_indices)

        logits = torch.stack(logits, dim=1)
        log_probs = F.log_softmax(logits, dim=2)
        prediction = torch.stack(prediction, dim=1)
        ws = torch.stack(ws, dim=1)
        if ys is not None:
            ys_log_probs = torch.gather(
                log_probs, dim=2, index=pad_ys_out.unsqueeze(2)).squeeze(2)
        else:
            ys_log_probs = torch.gather(
                log_probs, dim=2, index=prediction.unsqueeze(2)).squeeze(2)

        # label smoothing
        if label_smoothing and self.ls_weight > 0 and self.training:
            loss_reg = torch.sum(log_probs * self.vlabeldist, dim=2)
            ys_log_probs = (
                1 - self.ls_weight) * ys_log_probs + self.ls_weight * loss_reg
        return logits, ys_log_probs, prediction, ws
Example #21
         optimizer.param_groups[0]['lr'] = lrnow
 
     # TRY NOT TO MODIFY: prepare the execution of the game.
     for step in range(0, args.num_steps):
         global_step += 1 * args.num_envs
         obs[step] = next_obs
         dones[step] = next_done
 
         # ALGO LOGIC: put action logic here
         with torch.no_grad():
             values[step] = agent.get_value(obs[step]).flatten()
             action, logproba, _ = agent.get_action(obs[step])
 
             # visualization
             if args.capture_video:
                 probs_list = np.array(Categorical(
                     logits=agent.actor(agent.network(obs[step]))).probs[0:1].tolist())
                 envs.env_method("set_probs", probs_list, indices=0)
 
         actions[step] = action
         logprobs[step] = logproba
 
         # TRY NOT TO MODIFY: execute the game and log data.
         next_obs, rs, ds, infos = envs.step(action)
         rewards[step], next_done = rs.view(-1), torch.Tensor(ds).to(device)
 
         for info in infos:
             if 'episode' in info.keys():
                 print(f"global_step={global_step}, episode_reward={info['episode']['r']}")
                 writer.add_scalar("charts/episode_reward", info['episode']['r'], global_step)
                 break
 
Example #22
# mylogprobgrad = torch.autograd.grad(outputs=mylogprob, inputs=(probs), retain_graph=True)[0]

print('rewards', rewards)
print('probs', probs)
print('logits', logits)

#REINFORCE
print('REINFORCE')

# def sample_reinforce_given_class(logits, samp):
#     return logprob

grads = []
for i in range(N):

    dist = Categorical(logits=logits)
    samp = dist.sample()
    logprob = dist.log_prob(samp)
    reward = f(samp)
    gradlogprob = torch.autograd.grad(outputs=logprob,
                                      inputs=(logits),
                                      retain_graph=True)[0]
    grads.append(reward * gradlogprob)

print()
grads = torch.stack(grads).view(N, C)
# print (grads.shape)
grad_mean_reinforce = torch.mean(grads, dim=0)
grad_std_reinforce = torch.std(grads, dim=0)

print('REINFORCE')
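A self-contained variant of the estimator above, with a toy reward f and uniform initial logits filled in as assumptions (N, C, and f are undefined in the snippet): it Monte-Carlo averages f(a) * grad log p(a), the REINFORCE/score-function estimate.

import torch
from torch.distributions import Categorical

C, N = 3, 1000
logits = torch.zeros(C, requires_grad=True)
f = lambda a: (a == 2).float()  # toy reward: 1 when class 2 is sampled, else 0

grads = []
for _ in range(N):
    dist = Categorical(logits=logits)
    samp = dist.sample()
    logprob = dist.log_prob(samp)
    gradlogprob = torch.autograd.grad(outputs=logprob, inputs=logits)[0]
    grads.append(f(samp) * gradlogprob)

grads = torch.stack(grads)
print(grads.mean(dim=0), grads.std(dim=0))  # Monte Carlo estimate of d E[f]/d logits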
Example #23
 def get_pi(self, x):
     logits = self.actor(self.network(x))
     return Categorical(logits=logits)
Example #24
def sample_reinforce_given_class(logits, samp):
    dist = Categorical(logits=logits)
    logprob = dist.log_prob(samp)
    return logprob
Example #25
 def _action_distributions(self, encoder_out):
     actor_logits = self.controller.actor(encoder_out)
     return actor_logits, Categorical(
         logits=actor_logits)  # float tensor of shape [1, num_actions]
Example #26
 def get_action_and_log_prob(self, state_1):
     policy_action_probs = self.policy_net(state_1.to(self.device))
     policy_action_distribution = Categorical(logits=policy_action_probs)
     action = policy_action_distribution.sample()
     log_prob = policy_action_distribution.log_prob(action)
     return action.item(), log_prob
Example #27
    def forward(self, state):
        dist = self.actor(state)
        dist = Categorical(dist)

        return dist
Example #28
 def forward(self, x):
     # out = self.net((x.float() - self.scale) / self.normalize)
     out = self.net(x.float())
     return Categorical(out)
Example #29
 def get_action(self, state):
     with torch.no_grad():
         logits = self.policy(state)
         dist = Categorical(logits=logits)
         a = dist.sample()  # sample action from softmax policy
     return a
Example #30
 def get_policy(obs):
     #print('observations', obs)
     #time.sleep(1)
     logits = pi_net(obs)
     return Categorical(logits=logits)