Example #1
File: bbvi.py Project: daemon/vivi
    def forward(self, target_class, eps=1E-5, reinforce=False):
        input_dict = defaultdict(list)
        logits = F.linear(self.wordvec_params, self.embedding.weight)
        c_dist = cat.OneHotCategorical(logits=logits)
        rc_dist = rcat.RelaxedOneHotCategorical(2 / 3, logits=logits)
        probs = F.softmax(logits, 1)
        logp_btheta = []
        indices_lst = []
        z_lst = []
        z_tilde_lst = []
        for _ in range(self.batch_size):
            if reinforce:
                # Plain REINFORCE: draw a hard one-hot sample and record its log-probability.
                b = c_dist.sample()
                indices_lst.append(b.max(1)[1])
                logp_btheta.append(c_dist.log_prob(b).sum())
                continue
            # Relaxed (Gumbel-softmax) sample z and the hard one-hot sample b derived from it.
            z = rc_dist.rsample()
            indices = z.max(1)[1]
            b = torch.zeros_like(z).to(z.device)
            b[torch.arange(0, b.size(0)), indices] = 1
            logp_btheta.append(c_dist.log_prob(b).sum())
            # Conditional relaxed sample z_tilde given b (REBAR/RELAX-style), built from
            # clamped uniform noise u.
            u = torch.empty(*z.size()).uniform_().to(z.device).clamp_(
                eps, 1 - eps)
            u.requires_grad = True
            vb = u[b.byte()].unsqueeze(-1).expand_as(probs)
            z_tilde = u.log().mul(-1).log().mul(-1).mul(b) + \
                u.log().mul(-1).div(probs).sub(vb.log()).log().mul(-1).mul(1 - b)
            indices_lst.append(indices)
            z_lst.append(z)
            z_tilde_lst.append(z_tilde)

        if not reinforce:
            # Evaluate the surrogate (control-variate) network at z and z_tilde.
            z = torch.stack(z_lst)
            z_tilde = torch.stack(z_tilde_lst)
            c1 = self.surrogate(z.view(z.size(0), -1)).squeeze(-1)
            c2 = self.surrogate(z_tilde.view(z_tilde.size(0), -1)).squeeze(-1)
        else:
            c1 = c2 = 0

        indices = torch.stack(indices_lst)
        lp_tm = F.log_softmax(self.target_model(indices), 1)
        lp_tm = lp_tm[:, target_class]
        lp_lm = self.language_model(indices).sum(1)
        entropy = (c_dist.entropy() * self.batch_size).mean()

        f_b = -lp_tm - lp_lm - entropy
        # RELAX-style estimator: score-function term with control variate c2 plus the
        # reparameterized correction c1 - c2.
        loss = (f_b - c2).detach() * torch.stack(logp_btheta) + c1 - c2
        loss = loss.mean()
        if not reinforce:
            torch.autograd.backward(loss, create_graph=True, retain_graph=True)
            # Train the surrogate by minimizing the squared gradient of the loss
            # (variance reduction of the estimator).
            loss_grad = torch.autograd.grad([loss], [self.logits],
                                            create_graph=True,
                                            retain_graph=True)[0]
            torch.autograd.backward((loss_grad**2).mean(),
                                    create_graph=True,
                                    retain_graph=True)
        else:
            loss.backward()
        return loss.item()
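For reference, a minimal standalone sketch (not part of the original repository; the shapes and temperature below are illustrative assumptions) of the two sampling paths used above: a hard OneHotCategorical sample for the REINFORCE branch, and a reparameterized RelaxedOneHotCategorical sample for the relaxed branch.

import torch
import torch.nn.functional as F
from torch.distributions import OneHotCategorical, RelaxedOneHotCategorical

logits = torch.randn(4, 10, requires_grad=True)  # (sequence length, vocab size), assumed

# REINFORCE path: non-differentiable sample; gradients flow through log_prob only.
c_dist = OneHotCategorical(logits=logits)
b = c_dist.sample()
score_term = c_dist.log_prob(b).sum()  # to be multiplied by a detached cost

# Relaxed path: rsample() is reparameterized, so gradients flow through z itself.
rc_dist = RelaxedOneHotCategorical(temperature=torch.tensor(2 / 3), logits=logits)
z = rc_dist.rsample()
hard_b = F.one_hot(z.argmax(dim=1), num_classes=z.size(1)).float()  # discretized sample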
Example #2
def get_predicted_letters(outputs, temp):
    """
    Sample from the distributions from the network.
    :param distributions: 2d tensor. Output from network.
    :return: List of tensors. The predicted letters in one hot encoding.
    """
    distributions = softmax_temp(outputs, temp)
    sampler = one_hot_categorical.OneHotCategorical(distributions)
    prediction = sampler.sample()

    return prediction
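The helper softmax_temp is not shown in this example; a plausible minimal definition, assuming outputs are raw logits and temp is a positive temperature, would be:

import torch

def softmax_temp(outputs, temp):
    # Temperature-scaled softmax over the last dimension (sketch, not the original helper).
    return torch.softmax(outputs / temp, dim=-1)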
Example #3
File: bbvi.py Project: daemon/vivi
    def sample(self):
        logits = F.linear(self.wordvec_params, self.embedding.weight)
        c_dist = cat.OneHotCategorical(logits=logits)
        return c_dist.sample()
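A standalone sketch of the same projection-then-sample pattern, with assumed shapes: wordvec_params (seq_len, emb_dim) is projected onto an embedding matrix (vocab_size, emb_dim), giving per-position logits over the vocabulary.

import torch
import torch.nn.functional as F
from torch.distributions import OneHotCategorical

wordvec_params = torch.randn(4, 300)        # (seq_len, emb_dim), assumed
embedding_weight = torch.randn(20000, 300)  # (vocab_size, emb_dim), assumed
logits = F.linear(wordvec_params, embedding_weight)  # (seq_len, vocab_size)
sample = OneHotCategorical(logits=logits).sample()   # one one-hot row per position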
Example #4
File: rollout.py Project: puyuan1996/MARL
    def generate_episode(self, episode_num=None, evaluate=False, epsilon=0):
        if self.args.replay_dir != '' and evaluate and episode_num == 0:  # prepare to save the replay during evaluation
            self.env.close()
        o, u, r, s, avail_u, u_onehot, terminate, padded = [], [], [], [], [], [], [], []
        self.env.reset()
        terminated = False
        win_tag = False
        step = 0
        episode_reward = 0  # cumulative rewards
        last_action = np.zeros((self.args.n_agents, self.args.n_actions))
        if hasattr(self.agents, 'init_hidden'):  # qmix_ac
            self.agents.init_hidden(1)  # episode_num=1
        else:  # qmix
            self.agents.policy.init_hidden(1)  # episode_num=1

        # sample z for maven
        if self.args.alg == 'maven':
            state = self.env.get_state()
            state = torch.tensor(state, dtype=torch.float32)
            if self.args.cuda:
                state = state.cuda()
            z_prob = self.agents.policy.z_policy(state)
            maven_z = one_hot_categorical.OneHotCategorical(z_prob).sample()
            maven_z = list(maven_z.cpu())

        while not terminated and step < self.episode_limit:
            # time.sleep(0.2)
            obs = self.env.get_obs()
            state = self.env.get_state()
            actions, avail_actions, actions_onehot = [], [], []
            for agent_id in range(self.n_agents):
                avail_action = self.env.get_avail_agent_actions(agent_id)
                if self.args.alg == 'maven':
                    action = self.agents.choose_action(obs[agent_id],
                                                       last_action[agent_id],
                                                       agent_id, avail_action,
                                                       epsilon, maven_z,
                                                       evaluate)
                else:
                    action = self.agents.choose_action(obs[agent_id],
                                                       last_action[agent_id],
                                                       agent_id, avail_action,
                                                       epsilon, None, evaluate)
                # generate the one-hot vector of the action
                action_onehot = np.zeros(self.args.n_actions)
                action_onehot[action] = 1
                actions.append(action)
                actions_onehot.append(action_onehot)
                avail_actions.append(avail_action)
                last_action[agent_id] = action_onehot

            reward, terminated, info = self.env.step(actions)
            win_tag = bool(terminated and info.get('battle_won', False))
            o.append(obs)  # list(array)
            s.append(state)  # array (48,)
            u.append(np.reshape(actions, [self.n_agents, 1]))  # array (3,1)
            u_onehot.append(actions_onehot)  # list(array)
            avail_u.append(avail_actions)  # list(list)
            r.append([reward])
            terminate.append([terminated])
            padded.append([0.])
            episode_reward += reward
            step += 1
            if self.args.epsilon_anneal_scale == 'step':
                epsilon = epsilon - self.anneal_epsilon if epsilon > self.min_epsilon else epsilon
        # last obs
        o.append(obs)
        s.append(state)
        o_next = o[1:]
        s_next = s[1:]
        o = o[:-1]
        s = s[:-1]
        # get avail_action for the last obs, because target_q needs avail_action during training
        avail_actions = []
        for agent_id in range(self.n_agents):
            avail_action = self.env.get_avail_agent_actions(agent_id)
            avail_actions.append(avail_action)
        avail_u.append(avail_actions)
        avail_u_next = avail_u[1:]
        avail_u = avail_u[:-1]

        # if step < self.episode_limit, pad up to the episode limit
        for i in range(step, self.episode_limit):
            o.append(np.zeros((self.n_agents, self.obs_shape)))
            u.append(np.zeros([self.n_agents, 1]))
            s.append(np.zeros(self.state_shape))
            r.append([0.])
            o_next.append(np.zeros((self.n_agents, self.obs_shape)))
            s_next.append(np.zeros(self.state_shape))
            u_onehot.append(np.zeros((self.n_agents, self.n_actions)))
            avail_u.append(np.zeros((self.n_agents, self.n_actions)))
            avail_u_next.append(np.zeros((self.n_agents, self.n_actions)))
            padded.append([1.])
            terminate.append([1.])

        episode = dict(o=o.copy(),
                       s=s.copy(),
                       u=u.copy(),
                       r=r.copy(),
                       avail_u=avail_u.copy(),
                       o_next=o_next.copy(),
                       s_next=s_next.copy(),
                       avail_u_next=avail_u_next.copy(),
                       u_onehot=u_onehot.copy(),
                       padded=padded.copy(),
                       terminated=terminate.copy())
        # add episode dim
        for key in episode.keys():
            episode[key] = np.array([episode[key]])  # e.g. (1, 60, 3, 30), (1, 60, 48), ...
        if not evaluate:
            self.epsilon = epsilon
        if self.args.alg == 'maven':
            episode['z'] = np.array([maven_z.copy()])
        if evaluate and episode_num == self.args.evaluate_epoch - 1 and self.args.replay_dir != '':
            self.env.save_replay()
            self.env.close()
        return episode, episode_reward, win_tag
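For context, a minimal sketch of the MAVEN-style latent sampling done at the start of the episode (the layer sizes here are illustrative assumptions, not the project's actual z_policy): the global state is mapped to a probability vector and a one-hot latent z is drawn once per episode.

import torch
from torch.distributions import one_hot_categorical

state = torch.randn(48)  # global state; shape assumed
z_policy = torch.nn.Sequential(torch.nn.Linear(48, 16), torch.nn.Softmax(dim=-1))
z_prob = z_policy(state)  # probabilities over the latent noise dimensions
maven_z = one_hot_categorical.OneHotCategorical(z_prob).sample()  # one-hot latent for the episode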
Example #5
    def generate_episode(self, episode_num=None, evaluate=False):
        o, u, r, s, avail_u, u_onehot, terminate, padded = [], [], [], [], [], [], [], []
        self.env.reset()
        terminated = False
        step = 0
        episode_reward = 0  # cumulative rewards
        last_action = np.zeros((self.args.n_agents, self.args.n_actions))
        self.agents.policy.init_hidden(1)  # initialize hidden_state

        # epsilon
        epsilon = 0 if evaluate else self.epsilon
        if self.args.epsilon_anneal_scale == 'episode':
            epsilon = epsilon - self.anneal_epsilon if epsilon > self.min_epsilon else epsilon
        if self.args.epsilon_anneal_scale == 'epoch':
            if episode_num == 0:
                epsilon = epsilon - self.anneal_epsilon if epsilon > self.min_epsilon else epsilon

        # sample z for maven
        if self.args.alg == 'maven':
            state = self.env.get_state()
            state = torch.tensor(state, dtype=torch.float32)
            if self.args.cuda:
                state = state.cuda()
            z_prob = self.agents.policy.z_policy(state)
            maven_z = one_hot_categorical.OneHotCategorical(z_prob).sample()
            maven_z = list(maven_z.cpu())
        while not terminated:
            # time.sleep(0.2)
            obs = self.env.get_obs()
            state = self.env.get_state()
            actions, avail_actions, actions_onehot = [], [], []
            for agent_id in range(self.n_agents):
                avail_action = self.env.get_avail_agent_actions(agent_id)
                if self.args.alg == 'maven':
                    action = self.agents.choose_action(obs[agent_id], last_action[agent_id], agent_id,
                                                       avail_action, epsilon, maven_z, evaluate)
                else:
                    action = self.agents.choose_action(obs[agent_id], last_action[agent_id], agent_id,
                                                       avail_action, epsilon, evaluate)
                # generate the one-hot (0/1) vector for the chosen action
                action_onehot = np.zeros(self.args.n_actions)
                action_onehot[action] = 1
                actions.append(action)
                actions_onehot.append(action_onehot)
                avail_actions.append(avail_action)
                last_action[agent_id] = action_onehot

            reward, terminated, _ = self.env.step(actions)
            if step == self.episode_limit - 1:
                terminated = 1

            o.append(obs)
            s.append(state)
            # actions passed to the environment must be a list of integers, one per agent;
            # in the buffer, each agent's action is stored as a 1-d vector
            u.append(np.reshape(actions, [self.n_agents, 1]))
            u_onehot.append(actions_onehot)
            avail_u.append(avail_actions)
            r.append([reward])
            terminate.append([terminated])
            padded.append([0.])
            episode_reward += reward
            step += 1
            # if terminated:
            #     time.sleep(1)
            if self.args.epsilon_anneal_scale == 'step':
                epsilon = epsilon - self.anneal_epsilon if epsilon > self.min_epsilon else epsilon
        # handle the last obs
        o.append(obs)
        s.append(state)
        o_next = o[1:]
        s_next = s[1:]
        o = o[:-1]
        s = s[:-1]
        # the last obs needs its avail_action computed separately, since target_q will need it later
        avail_actions = []
        for agent_id in range(self.n_agents):
            avail_action = self.env.get_avail_agent_actions(agent_id)
            avail_actions.append(avail_action)
        avail_u.append(avail_actions)
        avail_u_next = avail_u[1:]
        avail_u = avail_u[:-1]

        # the returned episode must have length self.episode_limit, so pad if it is shorter
        for i in range(step, self.episode_limit):  # missing entries are zero-filled, with padded set to 1
            o.append(np.zeros((self.n_agents, self.obs_shape)))
            u.append(np.zeros([self.n_agents, 1]))
            s.append(np.zeros(self.state_shape))
            r.append([0.])
            o_next.append(np.zeros((self.n_agents, self.obs_shape)))
            s_next.append(np.zeros(self.state_shape))
            u_onehot.append(np.zeros((self.n_agents, self.n_actions)))
            avail_u.append(np.zeros((self.n_agents, self.n_actions)))
            avail_u_next.append(np.zeros((self.n_agents, self.n_actions)))
            padded.append([1.])
            terminate.append([1.])

        '''
        (o[n], u[n], r[n], o_next[n], avail_u[n], u_onehot[n]) form the n-th transition; each field has
        shape (n_episodes, n_transitions, n_agents, own feature dim). avail_u holds the actions executable
        from the current obs, but computing target_q needs the next obs and its executable actions.
        '''
        episode = dict(o=o.copy(),
                       s=s.copy(),
                       u=u.copy(),
                       r=r.copy(),
                       avail_u=avail_u.copy(),
                       o_next=o_next.copy(),
                       s_next=s_next.copy(),
                       avail_u_next=avail_u_next.copy(),
                       u_onehot=u_onehot.copy(),
                       padded=padded.copy(),
                       terminated=terminate.copy()
                       )
        # the buffer stores 4-d arrays, while the episode built here is only 3-d
        # (transition, agent, shape), so add an episode dimension
        for key in episode.keys():
            episode[key] = np.array([episode[key]])
        if not evaluate:
            self.epsilon = epsilon
        if self.args.alg == 'maven':
            episode['z'] = np.array([maven_z.copy()])
        return episode, episode_reward
Example #6
    def generate_episode(self, episode_num=None, evaluate=False):
        o, u, r, s, avail_u, u_onehot, terminate, padded = [], [], [], [], [], [], [], []
        self.env.reset()
        terminated = False
        step = 0
        episode_reward = 0  # cumulative rewards
        last_action = np.zeros((self.args.n_agents, self.args.n_actions))
        self.agents.policy.init_hidden(1)

        # epsilon
        epsilon = 0 if evaluate else self.epsilon
        if self.args.epsilon_anneal_scale == 'episode':
            epsilon = epsilon - self.anneal_epsilon if epsilon > self.min_epsilon else epsilon
        if self.args.epsilon_anneal_scale == 'epoch':
            if episode_num == 0:
                epsilon = epsilon - self.anneal_epsilon if epsilon > self.min_epsilon else epsilon

        # sample z for maven
        if self.args.alg == 'maven':
            state = self.env.get_state()
            state = torch.tensor(state, dtype=torch.float32)
            if self.args.cuda:
                state = state.cuda()
            z_prob = self.agents.policy.z_policy(state)
            maven_z = one_hot_categorical.OneHotCategorical(z_prob).sample()
            maven_z = list(maven_z.cpu())
        while not terminated:
            # time.sleep(0.2)

            obs = self.env.get_obs()
            state = self.env.get_state()
            actions, avail_actions, actions_onehot = [], [], []
            for agent_id in range(self.n_agents):
                avail_action = self.env.get_avail_agent_actions(agent_id)
                if self.args.alg == 'maven':
                    action = self.agents.choose_action(obs[agent_id],
                                                       last_action[agent_id],
                                                       agent_id, avail_action,
                                                       epsilon, maven_z,
                                                       evaluate)
                else:
                    action = self.agents.choose_action(obs[agent_id],
                                                       last_action[agent_id],
                                                       agent_id, avail_action,
                                                       epsilon, evaluate)
                # generate the one-hot vector of the action
                action_onehot = np.zeros(self.args.n_actions)
                action_onehot[action] = 1
                actions.append(action)
                actions_onehot.append(action_onehot)
                avail_actions.append(avail_action)
                last_action[agent_id] = action_onehot

            unuse_o, reward, terminated, _ = self.env.step(actions)

            if step == self.episode_limit - 1:
                terminated = 1

            o.append(obs)
            s.append(state)
            u.append(np.reshape(actions, [self.n_agents, 1]))
            u_onehot.append(actions_onehot)
            avail_u.append(avail_actions)
            r.append([reward])
            terminate.append([terminated])
            padded.append([0.])
            episode_reward += reward
            step += 1  # note: s' is not saved here?
            # if terminated:
            #     time.sleep(1)
            if self.args.epsilon_anneal_scale == 'step':
                epsilon = epsilon - self.anneal_epsilon if epsilon > self.min_epsilon else epsilon
        # last obs
        o.append(obs)
        s.append(state)
        o_next = o[1:]
        s_next = s[1:]
        o = o[:-1]
        s = s[:-1]
        # get avail_action for the last obs, because target_q needs avail_action during training
        avail_actions = []
        for agent_id in range(self.n_agents):
            avail_action = self.env.get_avail_agent_actions(agent_id)
            avail_actions.append(avail_action)
        avail_u.append(avail_actions)
        avail_u_next = avail_u[1:]
        avail_u = avail_u[:-1]

        # if step < self.episode_limit, pad up to the episode limit
        for i in range(step, self.episode_limit):
            o.append(np.zeros((self.n_agents, self.obs_shape)))
            u.append(np.zeros([self.n_agents, 1]))
            s.append(np.zeros(self.state_shape))
            r.append([0.])
            o_next.append(np.zeros((self.n_agents, self.obs_shape)))
            s_next.append(np.zeros(self.state_shape))
            u_onehot.append(np.zeros((self.n_agents, self.n_actions)))
            avail_u.append(np.zeros((self.n_agents, self.n_actions)))
            avail_u_next.append(np.zeros((self.n_agents, self.n_actions)))
            padded.append([1.])
            terminate.append([1.])

        episode = dict(o=o.copy(),
                       s=s.copy(),
                       u=u.copy(),
                       r=r.copy(),
                       avail_u=avail_u.copy(),
                       o_next=o_next.copy(),
                       s_next=s_next.copy(),
                       avail_u_next=avail_u_next.copy(),
                       u_onehot=u_onehot.copy(),
                       padded=padded.copy(),
                       terminated=terminate.copy())
        # add episode dim
        for key in episode.keys():
            episode[key] = np.array([episode[key]])
        if not evaluate:
            self.epsilon = epsilon
        if self.args.alg == 'maven':
            episode['z'] = np.array([maven_z.copy()])
        return episode, episode_reward
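The fixed-length padding used by all three generate_episode variants above can be isolated as follows (all values are illustrative assumptions): every field is zero-filled up to episode_limit, and padded marks which transitions are fill.

import numpy as np

episode_limit, step = 60, 42               # assumed values
n_agents, n_actions, obs_shape = 3, 9, 30  # assumed values

o, u_onehot, padded, terminate = [], [], [], []
for _ in range(step, episode_limit):
    o.append(np.zeros((n_agents, obs_shape)))         # zero observation
    u_onehot.append(np.zeros((n_agents, n_actions)))  # zero action
    padded.append([1.])                               # mark this transition as padding
    terminate.append([1.])                            # padded steps count as terminated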
Example #7
    def sample(self, wav_onehot, lc_sparse, speaker_inds, jitter_index, n_rep):
        """
        Generate n_rep samples, using lc_sparse and speaker_inds for local and global
        conditioning.  

        wav_onehot: full length wav vector
        lc_sparse: full length local conditioning vector derived from full
        wav_onehot
        """
        # initialize model geometry
        mfcc_vc = self.vc['beg'].parent
        up_vc = self.vc['pre_upsample'].child
        beg_grcc_vc = self.vc['beg_grcc']
        end_vc = self.vc['end_grcc']

        # calculate full output range
        wav_gr = vconv.GridRange((0, 1e12), (0, wav_onehot.size()[2]), 1)
        full_out_gr = vconv.output_range(mfcc_vc, end_vc, wav_gr)
        n_ts = full_out_gr.sub_length()

        # calculate starting input range for single timestep
        one_gr = vconv.GridRange((0, 1e12), (0, 1), 1)
        vconv.compute_inputs(end_vc, one_gr)

        # calculate starting position of wav
        wav_beg = int(beg_grcc_vc.input_gr.sub[0] - mfcc_vc.input_gr.sub[0])
        # wav_end = int(beg_grcc_vc.input_gr.sub[1] - mfcc_vc.input_gr.sub[0])
        wav_onehot = wav_onehot[:,:,wav_beg:]

        # !!! hack - I'm not sure why the int() cast is necessary
        n_init_ts = int(beg_grcc_vc.in_len())

        lc_sparse = lc_sparse.repeat(n_rep, 1, 1)
        jitter_index = jitter_index.repeat(n_rep, 1)
        speaker_inds = speaker_inds.repeat(n_rep)

        # precalculate conditioning vector for all timesteps
        D1 = lc_sparse.size()[1]
        lc_jitter = torch.take(lc_sparse,
                jitter_index.unsqueeze(1).expand(-1, D1, -1))
        lc_conv = self.lc_conv(lc_jitter) 
        lc_dense = self.lc_upsample(lc_conv)
        cond = self.cond(lc_dense, speaker_inds)
        n_ts = cond.size()[2]

        
        # cond_loff, cond_roff = vconv.output_offsets(mfcc_vc, up_end_vc)

        # span to resynthesize (zeroing this span out is currently commented out below)
        start_pos = 26000
        n_samples = 20000
        end_pos = start_pos + n_samples

        # wav_onehot[...,n_init_ts:] = 0
        wav_onehot = wav_onehot.repeat(n_rep, 1, 1)
        # wav_onehot[...,start_pos:end_pos] = 0

        # assert cond.size()[2] == wav_onehot.size()[2]

        # loop through timesteps
        # inrange = torch.tensor((0, n_init_ts), dtype=torch.int32)
        inrange = torch.tensor((start_pos - n_init_ts, start_pos), dtype=torch.int32)
        # end_ind = torch.tensor([n_ts], dtype=torch.int32)
        end_ind = torch.tensor([end_pos], dtype=torch.int32)

        # inefficient - this recalculates intermediate activations for the
        # entire receptive fields, rather than just the advancing front
        while not torch.equal(inrange[1], end_ind[0]):
        # while inrange[1] != end_ind[0]:
            sig = self.base_layer(wav_onehot[:,:,inrange[0]:inrange[1]]) 
            sig, skp_sum = self.conv_layers[0](sig, cond[:,:,inrange[0]:inrange[1]])
            for layer in self.conv_layers[1:]:
                sig, skp = layer(sig, cond[:,:,inrange[0]:inrange[1]])
                skp_sum += skp

            post1 = self.post1(self.relu(skp_sum))
            quant = self.post2(self.relu(post1))
            cat = dcat.OneHotCategorical(logits=quant.squeeze(2))
            wav_onehot[1:,:,inrange[1]] = cat.sample()[1:,...]
            inrange += 1
            if inrange[0] % 100 == 0:
                print(inrange, end_ind[0])

        
        # convert to value format
        quant_range = wav_onehot.new(list(range(self.n_quant)))
        wav = torch.matmul(wav_onehot.permute(0,2,1), quant_range)
        torch.set_printoptions(threshold=100000)
        pad = 5
        print('padding = {}'.format(pad))
        print('original')
        print(wav[0,start_pos-pad:end_pos+pad])
        print('synth')
        print(wav[1,start_pos-pad:end_pos+pad])

        # print(wav[:,end_pos:end_pos + 10000])
        print('synth range std: {}, baseline std: {}'.format(
            wav[:,start_pos:end_pos].std(), wav[:,end_pos:].std()
            ))

        return wav
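A minimal sketch (shapes are illustrative assumptions) of the one-hot-to-value conversion performed at the end of sample(): a (batch, n_quant, timesteps) one-hot waveform is collapsed to quantization levels by a matmul against [0, 1, ..., n_quant - 1].

import torch
from torch.distributions import OneHotCategorical

n_quant, batch, n_ts = 256, 2, 8
logits = torch.randn(batch, n_ts, n_quant)
wav_onehot = OneHotCategorical(logits=logits).sample().permute(0, 2, 1)  # (batch, n_quant, n_ts)

quant_range = wav_onehot.new_tensor(list(range(n_quant)))     # [0, 1, ..., n_quant - 1]
wav = torch.matmul(wav_onehot.permute(0, 2, 1), quant_range)  # (batch, n_ts) quantization levels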