def sample_gmm(batch_size, mixture_weights, spacing=10., std=5., device="cuda"):
    """Draw ``batch_size`` samples from a 1-D Gaussian mixture.

    Component ``c`` is ``Normal(c * spacing, std)``; the component of each
    sample is drawn from ``Categorical(probs=mixture_weights)``.

    Args:
        batch_size: number of samples to draw.
        mixture_weights: 1-D tensor of component probabilities (normalised by
            ``Categorical``).
        spacing: distance between consecutive component means (default 10.).
        std: standard deviation shared by all components (default 5.).
        device: device of the returned tensor. Defaults to ``"cuda"``, which is
            the behaviour this function always had; pass ``"cpu"`` to run
            without a GPU.

    Returns:
        Float tensor of shape ``[batch_size, 1]`` on ``device``.
    """
    cat = Categorical(probs=mixture_weights)
    cluster = cat.sample([batch_size])  # [B] component indices
    mean = (cluster * spacing).float().to(device)
    scale = torch.ones([batch_size], device=device) * std
    samp = Normal(mean, scale).sample()
    return samp.view(batch_size, 1)
def sample_true2():
    """Draw a single point from the reference mixture ``true_mixture_weights``.

    ``true_mixture_weights`` is a module-level global defined elsewhere
    (presumably a 1-D sequence of component probabilities — TODO confirm).

    Returns:
        ``(samp, cluster)``: ``samp`` is a shape-``[1]`` float tensor drawn from
        ``Normal(cluster * 10, 5)`` and ``cluster`` is the sampled component.
    """
    component_dist = Categorical(probs=torch.tensor(true_mixture_weights))
    component = component_dist.sample()
    loc = torch.tensor([component * 10.]).float()
    scale = torch.tensor([5.0]).float()
    return Normal(loc, scale).sample(), component
class OneHotCategorical(Distribution):
    r"""
    Creates a one-hot categorical distribution parameterized by ``probs``
    (or ``logits``).

    Samples are one-hot coded vectors of size ``probs.size(-1)``.

    See also: :class:`torch.distributions.Categorical`

    Example::

        >>> m = OneHotCategorical(torch.tensor([0.25, 0.25, 0.25, 0.25]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
        tensor([0., 0., 1., 0.])

    Args:
        probs (Tensor): event probabilities; normalised along the last
            dimension by the wrapped :class:`Categorical`.
        logits (Tensor): event log-odds, an alternative to ``probs``.
        validate_args (bool, optional): forwarded to :class:`Distribution`.
    """
    # ``arg_constraints`` is the attribute ``Distribution.__init__`` reads for
    # argument validation. Without it, recent PyTorch releases warn on every
    # construction. ``params`` is the legacy name, kept as an alias.
    arg_constraints = {'probs': constraints.simplex}
    params = arg_constraints
    support = constraints.simplex
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None, validate_args=None):
        # Must be set before Distribution.__init__, which may read ``probs``.
        self._categorical = Categorical(probs, logits)
        batch_shape = self._categorical.probs.size()[:-1]
        event_shape = self._categorical.probs.size()[-1:]
        super(OneHotCategorical, self).__init__(batch_shape, event_shape,
                                                validate_args=validate_args)

    @property
    def probs(self):
        """Normalised event probabilities of the wrapped Categorical."""
        return self._categorical.probs

    @property
    def logits(self):
        """Event logits of the wrapped Categorical."""
        return self._categorical.logits

    def sample(self, sample_shape=torch.Size()):
        """Return one-hot samples of shape ``sample_shape + batch + event``."""
        sample_shape = torch.Size(sample_shape)
        probs = self._categorical.probs
        one_hot = probs.new_zeros(self._extended_shape(sample_shape))
        indices = self._categorical.sample(sample_shape)
        # scatter_ needs an index tensor with the same rank as ``one_hot``.
        if indices.dim() < one_hot.dim():
            indices = indices.unsqueeze(-1)
        return one_hot.scatter_(-1, indices, 1)

    def log_prob(self, value):
        """Log-probability of one-hot ``value`` (the position of its max)."""
        indices = value.max(-1)[1]
        return self._categorical.log_prob(indices)

    def entropy(self):
        return self._categorical.entropy()

    def enumerate_support(self):
        """Return all ``n`` one-hot vectors, expanded over the batch shape."""
        probs = self._categorical.probs
        n = self.event_shape[0]
        values = torch.eye(n, dtype=probs.dtype, device=probs.device)
        values = values.view((n,) + (1,) * len(self.batch_shape) + (n,))
        return values.expand((n,) + self.batch_shape + (n,))
def test_gmm_loss(self):
    """Test case 1: fit a 3-component, 2-D diagonal GMM by minimising ``gmm_loss``.

    Draws ``n_samples`` points from a known mixture (``means``/``stds``/``pi``).
    It then optimises free means, pre-exponentiated stds and mixture logits with
    Adam on random minibatches of 128, printing the current parameters every
    ``iterations // 10`` steps.

    NOTE(review): nothing is asserted, so this only fails if an exception is
    raised; it is also slow (100k optimiser steps).
    """
    n_samples = 10000
    means = torch.Tensor([[0., 0.],
                          [1., 1.],
                          [-1., 1.]])
    stds = torch.Tensor([[.03, .05],
                         [.02, .1],
                         [.1, .03]])
    pi = torch.Tensor([.2, .3, .5])
    cat_dist = Categorical(pi)
    indices = cat_dist.sample((n_samples,)).long()
    rands = torch.randn(n_samples, 2)

    samples = means[indices] + rands * stds[indices]

    class _model(nn.Module):
        def __init__(self, gaussians):
            super().__init__()
            self.means = nn.Parameter(torch.Tensor(1, gaussians, 2).normal_())
            self.pre_stds = nn.Parameter(torch.Tensor(1, gaussians, 2).normal_())
            self.pi = nn.Parameter(torch.Tensor(1, gaussians).normal_())

        def forward(self, *inputs):
            # exp keeps the stds positive; softmax keeps pi on the simplex.
            return self.means, torch.exp(self.pre_stds), f.softmax(self.pi, dim=1)

    model = _model(3)
    optimizer = torch.optim.Adam(model.parameters())

    iterations = 100000
    log_step = iterations // 10
    cum_loss = 0
    # Context manager so the progress bar is always closed, even on error.
    with tqdm(total=iterations) as pbar:
        for i in range(iterations):
            batch = samples[torch.LongTensor(128).random_(0, n_samples)]
            m, s, p = model.forward()
            loss = gmm_loss(batch, m, s, p)
            cum_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pbar.set_postfix_str("avg_loss={:10.6f}".format(
                cum_loss / (i + 1)))
            pbar.update(1)
            if i % log_step == log_step - 1:
                print(m)
                print(s)
                print(p)
def sample_true(batch_size):
    """Draw ``batch_size`` points from the reference mixture on the CPU.

    The component is sampled from the module-level ``true_mixture_weights``
    (defined elsewhere). Each point is then drawn from ``Normal(component * 10, 5)``.

    Returns:
        Float tensor of shape ``[batch_size, 1]``.
    """
    weights = torch.tensor(true_mixture_weights)
    components = Categorical(probs=weights).sample([batch_size])  # [B]
    locs = (components * 10.).float()
    scales = 5. * torch.ones([batch_size])
    draws = Normal(locs, scales).sample()
    return draws.view(batch_size, 1)
def reinforce_baseline(surrogate, x, logits, mixtureweights, k=1, get_grad=False):
    """REINFORCE loss for the GMM posterior with a learned surrogate baseline.

    Draws ``k`` cluster assignments from ``Categorical(softmax(logits, dim=1))``.
    For each draw it builds:
      * the score-function loss ``-mean((f - surrogate.net(x)) * log q)``,
        where ``f = log p(x|z) + log p(z) - log q - 1``; both ``f`` and the
        baseline are detached inside this loss;
      * a surrogate loss ``mean(((f - surrogate.net(x)) * d mean(log q)/d logits)**2)``.

    Args:
        surrogate: object exposing ``.net(x)``, used as the baseline.
        x: observations passed to ``logprob_undercomponent`` and ``surrogate.net``.
        logits: cluster scores, softmaxed over dim 1; dim 0 is the batch ``B``.
        mixtureweights: per-cluster weights indexed by the sampled clusters;
            their log is taken, so presumably strictly positive — TODO confirm.
        k: number of sampled assignments.
        get_grad: if true, also compute d net_loss / d logits for every draw.

    Returns:
        dict. 'logq', 'logpx_given_z', 'logpz', 'f', 'net_loss' and 'surr_loss'
        come from the LAST of the ``k`` draws; earlier draws are overwritten, not
        averaged. When ``get_grad`` is true it also holds 'grad_avg' (mean over
        draws, then over dim 0) and 'grad_std' (std over draws, first row).

    NOTE(review): with ``k == 0`` the loop never runs and ``surr_loss`` is
    unbound, so this raises. ``torch.std`` is unbiased by default, so
    'grad_std' is NaN for the default ``k=1``.
    """
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)
    outputs = {}

    cat = Categorical(probs=probs)

    grads =[]
    # net_loss = 0
    for jj in range(k):

        cluster_H = cat.sample()
        outputs['logq'] = logq = cat.log_prob(cluster_H).view(B,1)
        outputs['logpx_given_z'] = logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        outputs['logpz'] = logpz = torch.log(mixtureweights[cluster_H]).view(B,1)
        logpxz = logpx_given_z + logpz #[B,1]
        # Baseline prediction; depends only on x, not on the sampled cluster.
        surr_pred = surrogate.net(x)
        outputs['f'] = f = logpxz - logq - 1.
        # outputs['net_loss'] = net_loss = net_loss - torch.mean((f.detach() ) * logq)
        outputs['net_loss'] = net_loss = - torch.mean((f.detach() - surr_pred.detach()) * logq)
        # net_loss += - torch.mean( -logq.detach()*logq)

        # surr_loss = torch.mean(torch.abs(f.detach() - surr_pred))

        # Gradient of mean log q w.r.t. logits (same shape as logits). It is kept
        # differentiable (create_graph) so that surr_loss can be backpropagated
        # into the surrogate.
        grad_logq =  torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0]
        surr_loss = torch.mean(((f.detach() - surr_pred) * grad_logq )**2)

        if get_grad:
            grad = torch.autograd.grad([net_loss], [logits], create_graph=True, retain_graph=True)[0]
            grads.append(grad)

    # net_loss = net_loss/ k

    if get_grad:
        grads = torch.stack(grads)  # [k, *logits.shape]
        outputs['grad_avg'] = torch.mean(torch.mean(grads, dim=0),dim=0)
        outputs['grad_std'] = torch.std(grads, dim=0)[0]

    outputs['surr_loss'] = surr_loss
    # return net_loss, f, logpx_given_z, logpz, logq
    return outputs
def reinforce(x, logits, mixtureweights, k=1):
    """Score-function (REINFORCE) loss for the GMM posterior, averaged over ``k`` draws.

    For each draw ``z ~ Categorical(softmax(logits, dim=1))`` the per-draw loss is
    ``-mean((f.detach() - 1) * log q(z|x))`` with ``f = log p(x|z) + log p(z) - log q``.

    Returns:
        ``(net_loss, f, logpx_given_z, logpz, logq)``. ``net_loss`` is the mean
        over draws; the other four come from the LAST draw only.
    """
    batch = logits.shape[0]
    dist = Categorical(probs=torch.softmax(logits, dim=1))
    total = 0
    for _ in range(k):
        z = dist.sample()
        logq = dist.log_prob(z).view(batch, 1)
        logpx_given_z = logprob_undercomponent(x, component=z)
        logpz = torch.log(mixtureweights[z]).view(batch, 1)
        f = (logpx_given_z + logpz) - logq  # [B,1]
        total += -torch.mean((f.detach() - 1.) * logq)
    return total / k, f, logpx_given_z, logpz, logq
# ---- REINFORCE ----------------------------------------------------------
# Score-function estimate of d E[f(z)] / d logits from N single-sample draws:
# each gradient sample is f(z) * d log p(z) / d logits. Relies on the
# module-level names N, C, logits and f defined elsewhere.
print ('REINFORCE')

# The distribution is the same on every draw, so build it once.
dist = Categorical(logits=logits)
grads = []
for i in range(N):
    samp = dist.sample()
    logprob = dist.log_prob(samp)
    reward = f(samp)
    (gradlogprob,) = torch.autograd.grad(outputs=logprob, inputs=logits, retain_graph=True)
    grads.append(reward*gradlogprob)
print()

grads = torch.stack(grads).view(N, C)
grad_mean_reinforce = grads.mean(dim=0)
grad_std_reinforce = grads.std(dim=0)

print('REINFORCE')
print('mean:', grad_mean_reinforce)
print('std:', grad_std_reinforce)
print()
def G(rewards, start=0, end=None):
    """Return the plain sum of ``rewards[start:end]`` (no discounting applied here)."""
    return sum(rewards[start:end])


if __name__ == "__main__":
    # Roll out NUM_EPISODES episodes with the current actor. env, actor,
    # NUM_EPISODES and DISCOUNT are module-level names defined elsewhere.
    for episode in range(NUM_EPISODES):
        s, done = env.reset(), False
        states, rewards, log_probs = [], [], []
        while not done:
            s = torch.from_numpy(s).float()
            # actor(s) is Categorical's first positional argument, i.e. `probs`;
            # presumably the actor outputs non-negative weights — TODO confirm.
            p = Categorical(actor(s))
            a = p.sample()
            with torch.no_grad():
                succ, r, done, _ = env.step(a.numpy())
            states.append(s)
            rewards.append(r)
            log_probs.append(p.log_prob(a))
            s = succ
        # DISCOUNT**t * r_t, i.e. discounted from the start of the episode.
        discounted_rewards = [DISCOUNT**t * r for t, r in enumerate(rewards)]
        # cumulative_returns[t] = sum of discounted_rewards[t:]. Each entry
        # re-sums its tail, so this is O(T^2) in episode length.
        cumulative_returns = [
            G(discounted_rewards, t) for t, _ in enumerate(discounted_rewards)
        ]
        states = torch.stack(states)
        # NOTE(review): log_probs and cumulative_returns are computed but no
        # policy update is visible here; the update step appears to be missing.
def simplax():
    """Train the GMM posterior encoder with a score-function loss and log diagnostics.

    Despite the name, the active training loop samples discrete clusters from a
    ``Categorical``; the ``RelaxedOneHotCategorical`` path is commented out.
    Only ``encoder`` is updated, by SGD on ``-mean((log p(x,z) - 1) * log q(z|x))``.
    Every 500 steps it records diagnostics. Every 1000 steps, once more than 3
    points exist, it writes plots. At the end it pickles ``steps_list`` and
    ``L2_losses`` to ``exp_dir + "data_simplax.p"``.

    Relies on module-level names defined elsewhere, e.g. ``n_components``,
    ``exp_dir``, ``true_mixture_weights``, ``NN3``, ``H``, ``check_nan``,
    ``logprob_undercomponent``, ``true_posterior``, ``plot_curve``. It calls
    ``.cuda()`` throughout, so it requires a GPU.
    """

    def show_surr_preds():
        """Plot surrogate predictions against log p(x, z) over a grid of x.

        Produces three rows; each row fixes one relaxed cluster sample. Saves
        to ``exp_dir + 'gmm_surr.png'``.
        """

        batch_size = 1

        rows = 3
        cols = 1
        fig = plt.figure(figsize=(10+cols,4+rows), facecolor='white') #, dpi=150)

        for i in range(rows):

            x = sample_true(1).cuda() #.view(1,1)
            logits = encoder.net(x)
            probs = torch.softmax(logits, dim=1)
            cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())
            cluster_S = cat.rsample()
            cluster_H = H(cluster_S)
            logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
            check_nan(logprob_cluster)

            z = cluster_S

            # Evaluate on n_evals evenly spaced x values, with z held fixed.
            n_evals = 40
            x1 = np.linspace(-9,205, n_evals)
            x = torch.from_numpy(x1).view(n_evals,1).float().cuda()
            z = z.repeat(n_evals,1)
            cluster_H = cluster_H.repeat(n_evals,1)
            xz = torch.cat([z,x], dim=1)

            logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
            f = logpxz #- logprob_cluster

            surr_pred = surrogate.net(xz)
            surr_pred = surr_pred.data.cpu().numpy()
            f = f.data.cpu().numpy()

            col =0
            row = i
            ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)

            ax.plot(x1,surr_pred, label='Surr')
            ax.plot(x1,f, label='f')
            ax.set_title(str(cluster_H[0]))
            ax.legend()

        # save_dir = home+'/Documents/Grad_Estimators/GMM/'
        plt_path = exp_dir+'gmm_surr.png'
        plt.savefig(plt_path)
        print ('saved training plot', plt_path)
        plt.close()

    def plot_dist():
        """Plot each weighted component density and their sum to ``gmm_plot_dist.png``.

        Component ``c`` is ``Normal(c * 10, 5)``, scaled by
        ``softmax(needsoftmax_mixtureweight)[c]``.
        """

        mixture_weights = torch.softmax(needsoftmax_mixtureweight, dim=0)

        rows = 1
        cols = 1
        fig = plt.figure(figsize=(10+cols,4+rows), facecolor='white') #, dpi=150)

        col =0
        row = 0
        ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)

        xs = np.linspace(-9,205, 300)
        sum_ = np.zeros(len(xs))

        # C = 20
        for c in range(n_components):
            m = Normal(torch.tensor([c*10.]).float(), torch.tensor([5.0]).float())
            ys = []
            for x in xs:
                # component_i = (torch.exp(m.log_prob(x) )* ((c+5.) / 290.)).numpy()
                component_i = (torch.exp(m.log_prob(x) )* mixture_weights[c]).detach().cpu().numpy()

                ys.append(component_i)

            ys = np.reshape(np.array(ys), [-1])
            sum_ += ys
            ax.plot(xs, ys, label='')

        ax.plot(xs, sum_, label='')

        # save_dir = home+'/Documents/Grad_Estimators/GMM/'
        plt_path = exp_dir+'gmm_plot_dist.png'
        plt.savefig(plt_path)
        print ('saved training plot', plt_path)
        plt.close()

    def get_loss():
        """Return the surrogate L1 loss ``mean(|log p(x,z) - surrogate([z_relaxed, x])|)``.

        ``z`` is a relaxed (Concrete) sample at temperature ``temp``. Only used
        by the surrogate pre-training loop, which is disabled while
        ``surrugate_steps == 0``.
        """

        x = sample_true(batch_size).cuda() #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1)
        cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([temp]).cuda())
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        # cluster_onehot = torch.zeros(n_components)
        # cluster_onehot[cluster_H] = 1.
        logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
        check_nan(logprob_cluster)

        logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
        f = logpxz - logprob_cluster

        surr_input = torch.cat([cluster_S, x], dim=1) #[B,21]
        surr_pred = surrogate.net(surr_input)

        # net_loss = - torch.mean((f.detach()-surr_pred.detach()) * logprob_cluster + surr_pred)
        # loss = - torch.mean(f)
        surr_loss = torch.mean(torch.abs(logpxz.detach()-surr_pred))

        return surr_loss

    def plot_posteriors(needsoftmax_mixtureweight, name=''):
        """Bar-plot the true posterior vs. the encoder's q(z|x) for one sampled x.

        The title shows the L2 error and KL between them. Saves to
        ``exp_dir + 'posteriors' + name + '.png'``.
        """

        x = sample_true(1).cuda()
        trueposterior = true_posterior(x, needsoftmax_mixtureweight).view(n_components)

        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1).view(n_components)

        trueposterior = trueposterior.data.cpu().numpy()
        qz = probs.data.cpu().numpy()

        error = L2_mixtureweights(trueposterior,qz)
        kl = KL_mixutreweights(p=trueposterior, q=qz)

        rows = 1
        cols = 1
        fig = plt.figure(figsize=(8+cols,8+rows), facecolor='white') #, dpi=150)

        col =0
        row = 0
        ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)

        width = .3
        ax.bar(range(len(qz)), trueposterior, width=width, label='True')
        ax.bar(np.array(range(len(qz)))+width, qz, width=width, label='q')
        # ax.bar(np.array(range(len(q_b)))+width+width, q_b, width=width)
        ax.legend()
        ax.grid(True, alpha=.3)
        ax.set_title(str(error) + ' kl:' + str(kl))
        ax.set_ylim(0.,1.)

        # save_dir = home+'/Documents/Grad_Estimators/GMM/'
        plt_path = exp_dir+'posteriors'+name+'.png'
        plt.savefig(plt_path)
        print ('saved training plot', plt_path)
        plt.close()

    def inference_error(needsoftmax_mixtureweight):
        """Return (mean L2 error, mean KL) between the true posterior and q over 10 single-x draws."""

        error_sum = 0
        kl_sum = 0
        n=10
        for i in range(n):

            # if x is None:
            x = sample_true(1).cuda()
            trueposterior = true_posterior(x, needsoftmax_mixtureweight).view(n_components)

            logits = encoder.net(x)
            probs = torch.softmax(logits, dim=1).view(n_components)

            error = L2_mixtureweights(trueposterior.data.cpu().numpy(),probs.data.cpu().numpy())
            kl = KL_mixutreweights(trueposterior.data.cpu().numpy(), probs.data.cpu().numpy())

            error_sum+=error
            kl_sum += kl

        return error_sum/n, kl_sum/n

    #SIMPLAX
    # Mixture logits, randomly initialised. No optimizer below steps them
    # (the `optim` lines are commented out), so 'theta dif' stays constant.
    needsoftmax_mixtureweight = torch.randn(n_components, requires_grad=True, device="cuda")#.cuda()

    print ('current mixuture weights')
    print (torch.softmax(needsoftmax_mixtureweight, dim=0))
    print()

    encoder = NN3(input_size=1, output_size=n_components).cuda()
    surrogate = NN3(input_size=1+n_components, output_size=1).cuda()
    # optim = torch.optim.Adam([needsoftmax_mixtureweight], lr=.00004)
    # optim_net = torch.optim.Adam(encoder.parameters(), lr=.0004)
    # optim_surr = torch.optim.Adam(surrogate.parameters(), lr=.004)
    # optim = torch.optim.Adam([needsoftmax_mixtureweight], lr=.0001)
    # optim_net = torch.optim.Adam(encoder.parameters(), lr=.0001)
    optim_net = torch.optim.SGD(encoder.parameters(), lr=.0001)
    # optim_surr = torch.optim.Adam(surrogate.parameters(), lr=.005)
    temp = 1.
    batch_size = 100
    n_steps = 300000
    # NOTE(review): with surrugate_steps == 0 the surrogate pre-training loop
    # never runs. optim_surr is only defined in comments above, so setting this
    # > 0 would raise NameError.
    surrugate_steps = 0
    k = 1
    # Diagnostic histories fed to plot_curve.
    L2_losses = []
    inf_losses = []
    inf_losses_kl = []
    kl_losses_2 = []
    surr_losses = []
    steps_list =[]
    grad_reparam_list =[]
    grad_reinforce_list =[]
    f_list = []
    logpxz_list = []
    logprob_cluster_list = []
    logpx_list = []
    # logprob_cluster_list = []
    for step in range(n_steps):

        for ii in range(surrugate_steps):
            surr_loss = get_loss()
            optim_surr.zero_grad()
            surr_loss.backward()
            optim_surr.step()

        x = sample_true(batch_size).cuda() #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1)
        # cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([temp]).cuda())
        cat = Categorical(probs=probs)
        # cluster = cat.sample()
        # logprob_cluster = cat.log_prob(cluster.detach())

        net_loss = 0
        loss = 0
        surr_loss = 0
        for jj in range(k):

            # cluster_S = cat.rsample()
            # cluster_H = H(cluster_S)
            cluster_H = cat.sample()

            # cluster_onehot = torch.zeros(n_components)
            # cluster_onehot[cluster_H] = 1.
            # logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
            logprob_cluster = cat.log_prob(cluster_H.detach()).view(batch_size,1)
            # logprob_cluster = cat.log_prob(cluster_S).view(batch_size,1).detach()
            # cat = RelaxedOneHotCategorical(probs=probs.detach(), temperature=torch.tensor([temp]).cuda())
            # logprob_cluster = cat.log_prob(cluster_S).view(batch_size,1)
            check_nan(logprob_cluster)

            logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
            f = logpxz - logprob_cluster

            # surr_input = torch.cat([cluster_S, x], dim=1) #[B,21]
            # Surrogate sees the posterior probs (not the sample) plus x.
            surr_input = torch.cat([probs, x], dim=1) #[B,21]
            surr_pred = surrogate.net(surr_input)

            # net_loss += - torch.mean((logpxz.detach()-surr_pred.detach()) * logprob_cluster + surr_pred - logprob_cluster)
            # net_loss += - torch.mean((logpxz.detach() - surr_pred.detach() - 1.) * logprob_cluster + surr_pred)
            # Active estimator: score function with constant offset 1 and no
            # learned baseline.
            net_loss += - torch.mean((logpxz.detach() - 1.) * logprob_cluster)
            # net_loss += - torch.mean((logpxz.detach()) * logprob_cluster - logprob_cluster)
            loss += - torch.mean(logpxz)
            surr_loss += torch.mean(torch.abs(logpxz.detach()-surr_pred))

        net_loss = net_loss/ k
        loss = loss / k
        surr_loss = surr_loss/ k

        # if step %2==0:
        #     optim.zero_grad()
        #     loss.backward(retain_graph=True)
        #     optim.step()

        # Graph is retained because the diagnostics below differentiate
        # through probs again.
        optim_net.zero_grad()
        net_loss.backward(retain_graph=True)
        optim_net.step()

        # NOTE(review): surr_loss is computed but never backpropagated (the
        # lines below are commented out), so the surrogate keeps its
        # initial weights.
        # optim_surr.zero_grad()
        # surr_loss.backward(retain_graph=True)
        # optim_surr.step()

        # plot_posteriors(name=str(step))

        # kl_batch = compute_kl_batch(x,probs)

        if step%500==0:
            print (step, 'f:', torch.mean(f).cpu().data.numpy(), 'surr_loss:', surr_loss.cpu().data.detach().numpy(),
                        'theta dif:', L2_mixtureweights(true_mixture_weights,torch.softmax(
                                    needsoftmax_mixtureweight, dim=0).cpu().data.detach().numpy()))
            # if step %5000==0:
            #     print (torch.softmax(needsoftmax_mixtureweight, dim=0).cpu().data.detach().numpy())
            #     # test_samp, test_cluster = sample_true2()
            #     # print (test_cluster.cpu().data.numpy(), test_samp.cpu().data.numpy(), torch.softmax(encoder.net(test_samp.cuda().view(1,1)), dim=1))
            #     print ()

            if step > 0:
                L2_losses.append(L2_mixtureweights(true_mixture_weights,torch.softmax(
                                    needsoftmax_mixtureweight, dim=0).cpu().data.detach().numpy()))
                steps_list.append(step)
                surr_losses.append(surr_loss.cpu().data.detach().numpy())

                inf_error, kl_error = inference_error(needsoftmax_mixtureweight)
                inf_losses.append(inf_error)
                inf_losses_kl.append(kl_error)

                kl_batch = compute_kl_batch(x,probs,needsoftmax_mixtureweight)
                kl_losses_2.append(kl_batch)

                logpx = copmute_logpx(x, needsoftmax_mixtureweight)
                logpx_list.append(logpx)

                f_list.append(torch.mean(f).cpu().data.detach().numpy())
                logpxz_list.append(torch.mean(logpxz).cpu().data.detach().numpy())
                logprob_cluster_list.append(torch.mean(logprob_cluster).cpu().data.detach().numpy())

                # i_feel_like_it = 1
                # if i_feel_like_it:

                # inf_losses was just appended to, so this branch is always
                # taken here and the else-branch is effectively unreachable.
                if len(inf_losses) > 0:
                    print ('probs', probs[0])
                    print('logpxz', logpxz[0])
                    print('pred', surr_pred[0])
                    print ('dif', logpxz.detach()[0]-surr_pred.detach()[0])
                    print ('logq', logprob_cluster[0])
                    print ('dif*logq', (logpxz.detach()[0]-surr_pred.detach()[0])*logprob_cluster[0])

                    # Compare gradient magnitudes w.r.t. probs: score-function
                    # term, surrogate (reparam-style) term, and d log q alone.
                    output= torch.mean((logpxz.detach()-surr_pred.detach()) * logprob_cluster, dim=0)[0]
                    output2 = torch.mean(surr_pred, dim=0)[0]
                    output3 = torch.mean(logprob_cluster, dim=0)[0]
                    # input_ = torch.mean(probs, dim=0) #[0]
                    grad_reinforce = torch.autograd.grad(outputs=output, inputs=(probs), retain_graph=True)[0]
                    grad_reparam = torch.autograd.grad(outputs=output2, inputs=(probs), retain_graph=True)[0]
                    grad3 = torch.autograd.grad(outputs=output3, inputs=(probs), retain_graph=True)[0]

                    grad_reinforce = torch.mean(torch.abs(grad_reinforce))
                    grad_reparam = torch.mean(torch.abs(grad_reparam))
                    grad3 = torch.mean(torch.abs(grad3))

                    grad_reparam_list.append(grad_reparam.cpu().data.detach().numpy())
                    grad_reinforce_list.append(grad_reinforce.cpu().data.detach().numpy())
                    # grad_reinforce_list.append(grad_reinforce.cpu().data.detach().numpy())

                    print ('reparam:', grad_reparam.cpu().data.detach().numpy())
                    print ('reinforce:', grad_reinforce.cpu().data.detach().numpy())
                    print ('logqz grad:', grad3.cpu().data.detach().numpy())

                    print ('current mixuture weights')
                    print (torch.softmax(needsoftmax_mixtureweight, dim=0))
                    print()

                    # print ()
                else:
                    grad_reparam_list.append(0.)
                    grad_reinforce_list.append(0.)

            if len(surr_losses) > 3  and step %1000==0:
                plot_curve(steps=steps_list,  thetaloss=L2_losses,
                            infloss=inf_losses, surrloss=surr_losses,
                            grad_reinforce_list=grad_reinforce_list,
                            grad_reparam_list=grad_reparam_list,
                            f_list=f_list, logpxz_list=logpxz_list,
                            logprob_cluster_list=logprob_cluster_list,
                            inf_losses_kl=inf_losses_kl,
                            kl_losses_2=kl_losses_2,
                            logpx_list=logpx_list)

                plot_posteriors(needsoftmax_mixtureweight)
                plot_dist()
                show_surr_preds()

                # print (f)
                # print (surr_pred)

                #Understand surr preds
                # if step %5000==0:

                # if step ==0:

    # Persist the step/theta-error history.
    data_dict = {}
    data_dict['steps'] = steps_list
    data_dict['losses'] = L2_losses

    # save_dir = home+'/Documents/Grad_Estimators/GMM/'
    with open( exp_dir+"data_simplax.p", "wb" ) as f:
        pickle.dump(data_dict, f)
    print ('saved data')
def valor(args):
    """Train a VALOR-style context-conditioned policy with a context curriculum.

    Each episode samples a context ``c`` from ``context_dist``. The policy input
    is the observation concatenated with ``one_hot(c, max_context_dim)``, and
    every transition goes into ``buffer``. At episode end the discriminator
    ``disc`` scores the trajectory's state differences against ``c``; its third
    output (per the original comments, the log-probability of the context) is
    passed to ``buffer.end_episode``. After every epoch ``update`` trains the
    policy and the value function, and trains the discriminator every
    ``train_dc_interv`` epochs.

    Every ``train_dc_interv`` epochs (after the first), the decoder accuracy
    ``exp(mean log p(c | trajectory))`` drives the curriculum:
      * the number of active contexts grows (x1.5 + 1) when accuracy >= 0.86;
      * it shrinks (x0.75, floor 5) when the mean per-update accuracy gain over
        the last ``stag_num`` updates falls below ``stag_pct``.

    Args:
        args: a mapping, or a namespace (its ``__dict__`` is used). Missing keys
            fall back to the defaults read below.
    """
    if not hasattr(args, "get"):
        args.get = args.__dict__.get

    env_fn = args.get('env_fn', lambda: gym.make('HalfCheetah-v2'))
    actor_critic = args.get('actor_critic', ActorCritic)
    ac_kwargs = args.get('ac_kwargs', {})
    disc = args.get('disc', Discriminator)
    dc_kwargs = args.get('dc_kwargs', {})
    seed = args.get('seed', 0)
    episodes_per_epoch = args.get('episodes_per_epoch', 40)
    epochs = args.get('epochs', 50)
    gamma = args.get('gamma', 0.99)
    pi_lr = args.get('pi_lr', 3e-4)
    vf_lr = args.get('vf_lr', 1e-3)
    dc_lr = args.get('dc_lr', 2e-3)
    train_v_iters = args.get('train_v_iters', 80)
    train_dc_iters = args.get('train_dc_iters', 50)
    train_dc_interv = args.get('train_dc_interv', 2)
    lam = args.get('lam', 0.97)
    max_ep_len = args.get('max_ep_len', 1000)
    logger_kwargs = args.get('logger_kwargs', {})
    context_dim = args.get('context_dim', 4)
    max_context_dim = args.get('max_context_dim', 64)
    save_freq = args.get('save_freq', 10)
    k = args.get('k', 1)

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    # seed += 10000 * proc_id()
    torch.manual_seed(seed)
    np.random.seed(seed)

    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape

    ac_kwargs['action_space'] = env.action_space

    # Model: the policy input is observation + one-hot context (max size).
    actor_critic = actor_critic(input_dim=obs_dim[0] + max_context_dim, **ac_kwargs)
    disc = disc(input_dim=obs_dim[0], context_dim=max_context_dim, **dc_kwargs)

    # Buffer
    local_episodes_per_epoch = episodes_per_epoch  # int(episodes_per_epoch / num_procs())
    buffer = Buffer(max_context_dim, obs_dim[0], act_dim[0],
                    local_episodes_per_epoch, max_ep_len, train_dc_interv)

    # Count variables
    var_counts = tuple(
        count_vars(module)
        for module in [actor_critic.policy, actor_critic.value_f, disc.policy])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d, \t d: %d\n' % var_counts)

    # Optimizers: RL policy, value function (actor-critic), decoder.
    train_pi = torch.optim.Adam(actor_critic.policy.parameters(), lr=pi_lr)
    train_v = torch.optim.Adam(actor_critic.value_f.parameters(), lr=vf_lr)
    train_dc = torch.optim.Adam(disc.policy.parameters(), lr=dc_lr)

    # Parameters Sync
    # sync_all_params(actor_critic.parameters())
    # sync_all_params(disc.parameters())

    def update(e):
        """One training step for policy and value function; discriminator every train_dc_interv epochs."""
        obs, act, adv, pos, ret, logp_old = [
            torch.Tensor(x) for x in buffer.retrieve_all()
        ]

        # Policy
        _, logp, _ = actor_critic.policy(obs, act, batch=False)
        entropy = (-logp).mean()

        # Policy loss: advantage scaled by k plus the `pos` term from the buffer.
        pi_loss = -(logp * (k * adv + pos)).mean()

        # Train policy
        train_pi.zero_grad()
        pi_loss.backward()
        # average_gradients(train_pi.param_groups)
        train_pi.step()

        # Value function
        v = actor_critic.value_f(obs)
        v_l_old = F.mse_loss(v, ret)
        for _ in range(train_v_iters):
            v = actor_critic.value_f(obs)
            v_loss = F.mse_loss(v, ret)

            # Value function train
            train_v.zero_grad()
            v_loss.backward()
            # average_gradients(train_v.param_groups)
            train_v.step()

        # Discriminator
        if (e + 1) % train_dc_interv == 0:
            print('Discriminator Update!')
            con, s_diff = [torch.Tensor(x) for x in buffer.retrieve_dc_buff()]
            _, logp_dc, _ = disc(s_diff, con)
            d_l_old = -logp_dc.mean()

            # Discriminator train
            for _ in range(train_dc_iters):
                _, logp_dc, _ = disc(s_diff, con)
                d_loss = -logp_dc.mean()
                train_dc.zero_grad()
                d_loss.backward()
                # average_gradients(train_dc.param_groups)
                train_dc.step()

            _, logp_dc, _ = disc(s_diff, con)
            dc_l_new = -logp_dc.mean()
        else:
            d_l_old = 0
            dc_l_new = 0

        # Log the changes
        _, logp, _, v = actor_critic(obs, act)
        pi_l_new = -(logp * (k * adv + pos)).mean()
        v_l_new = F.mse_loss(v, ret)
        kl = (logp_old - logp).mean()
        logger.store(LossPi=pi_loss, LossV=v_l_old, KL=kl, Entropy=entropy,
                     DeltaLossPi=(pi_l_new - pi_loss),
                     DeltaLossV=(v_l_new - v_l_old), LossDC=d_l_old,
                     DeltaLossDC=(dc_l_new - d_l_old))
        # logger.store(Adv=adv.reshape(-1).numpy().tolist(), Pos=pos.reshape(-1).numpy().tolist())

    start_time = time.time()
    # Reset observation, reward, done flag and episode statistics.
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
    # Initially uniform over the first context_dim contexts, zero elsewhere.
    context_dim_prob_dict = {
        i: 1 / context_dim if i < context_dim else 0
        for i in range(max_context_dim)
    }
    last_phi_dict = {i: 0 for i in range(context_dim)}
    context_dist = Categorical(
        probs=torch.Tensor(list(context_dim_prob_dict.values())))
    total_t = 0

    # Curriculum state must persist across epochs so that the stagnation
    # window can actually fill up.
    decoder_accs = []
    stag_num = 10     # window length, in curriculum updates
    stag_pct = 0.05   # minimum mean accuracy gain per update in that window

    for epoch in range(epochs):
        # Eval mode while collecting experience.
        actor_critic.eval()
        disc.eval()

        # Run local_episodes_per_epoch episodes, each with a fresh context.
        for index in range(local_episodes_per_epoch):
            c = context_dist.sample()
            c_onehot = F.one_hot(c, max_context_dim).squeeze().float()
            for _ in range(max_ep_len):
                concat_obs = torch.cat(
                    [torch.Tensor(o.reshape(1, -1)), c_onehot.reshape(1, -1)], 1)

                # actor_critic returns (action, _, logp_t, v_t); per the
                # original notes logp_t is the log-prob of the sampled action
                # and v_t the value estimate.
                a, _, logp_t, v_t = actor_critic(concat_obs)

                buffer.store(c, concat_obs.squeeze().detach().numpy(),
                             a.detach().numpy(), r, v_t.item(),
                             logp_t.detach().numpy())
                logger.store(VVals=v_t)

                o, r, d, _ = env.step(a.detach().numpy()[0])
                ep_ret += r
                ep_len += 1
                total_t += 1

                terminal = d or (ep_len == max_ep_len)
                if terminal:
                    # Score the trajectory's state differences against c.
                    dc_diff = torch.Tensor(buffer.calc_diff()).unsqueeze(0)
                    con = torch.Tensor([float(c)]).unsqueeze(0)
                    _, _, log_p = disc(dc_diff, con)
                    buffer.end_episode(log_p.detach().numpy())
                    logger.store(EpRet=ep_ret, EpLen=ep_len)
                    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
                    # One episode per sampled context: without this break an
                    # early `done` would start a second episode under the same
                    # context and leave a partial one spilling into the next.
                    break

        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, [actor_critic, disc], None)

        # Training mode for the update.
        actor_critic.train()
        disc.train()

        update(epoch)

        # Curriculum over contexts (see docstring).
        if (epoch + 1) % train_dc_interv == 0 and epoch > 0:
            con, s_diff = [torch.Tensor(x) for x in buffer.retrieve_dc_buff()]
            print("Context: ", con)
            print("num_contexts", len(con))
            _, logp_dc, _ = disc(s_diff, con)
            log_p_context_sample = logp_dc.mean().detach().numpy()
            print("Log Probability context sample", log_p_context_sample)
            decoder_accuracy = np.exp(log_p_context_sample)
            print("Decoder Accuracy", decoder_accuracy)
            logger.store(LogProbabilityContext=log_p_context_sample,
                         DecoderAccuracy=decoder_accuracy)

            # phi(i) = -mean log p(i | trajectory) per active context; contexts
            # absent from this interval keep their previous score.
            logp_np = logp_dc.detach().numpy()
            con_np = con.detach().numpy()
            phi_dict = {i: 0 for i in range(context_dim)}
            count_dict = {i: 0 for i in range(context_dim)}
            for i in range(len(logp_np)):
                current_con = con_np[i]
                phi_dict[current_con] += logp_np[i]
                count_dict[current_con] += 1
            print(phi_dict)

            phi_dict = {
                ctx: last_phi_dict[ctx] if count_dict[ctx] == 0 else (-1) * total / count_dict[ctx]
                for ctx, total in phi_dict.items()
            }
            # Rank-based weights: hardest context (highest phi) gets 1, next 1/2, ...
            sorted_keys = sorted(phi_dict, key=phi_dict.get, reverse=True)
            rank_dict = {ctx: 1 / (rank + 1) for rank, ctx in enumerate(sorted_keys)}
            rank_dict_sum = sum(rank_dict.values())
            # NOTE(review): these rank-based probabilities are printed but not
            # used to rebuild context_dist; sampling stays uniform over the
            # active contexts.
            context_dim_prob_dict = {
                ctx: rank_dict[ctx] / rank_dict_sum if ctx < context_dim else 0
                for ctx in context_dim_prob_dict
            }
            print(context_dim_prob_dict)

            decoder_accs.append(decoder_accuracy)
            # Stagnated: mean accuracy gain per update over the last stag_num
            # updates is below stag_pct.
            stagnated = (len(decoder_accs) > stag_num and
                         (decoder_accuracy - decoder_accs[-stag_num - 1]) / stag_num < stag_pct)
            new_context_dim = None
            if stagnated:
                new_context_dim = max(int(0.75 * context_dim), 5)
            elif decoder_accuracy >= 0.86:
                new_context_dim = min(int(1.5 * context_dim + 1), max_context_dim)
            if new_context_dim is not None:
                print("new_context_dim: ", new_context_dim)
                new_context_prob_arr = np.array(
                    new_context_dim * [1 / new_context_dim] +
                    (max_context_dim - new_context_dim) * [0])
                context_dist = Categorical(
                    probs=ptu.from_numpy(new_context_prob_arr))
                context_dim = new_context_dim
                # Accuracies measured with a different number of contexts are
                # not comparable; restart the stagnation window.
                decoder_accs.clear()

            # Carry scores forward; newly activated contexts start at the max score.
            for i in range(context_dim):
                if i in phi_dict:
                    last_phi_dict[i] = phi_dict[i]
                elif i not in last_phi_dict:
                    last_phi_dict[i] = max(phi_dict.values())

            buffer.clear_dc_buff()
        else:
            logger.store(LogProbabilityContext=0, DecoderAccuracy=0)

        # Log
        logger.store(ContextDim=context_dim)
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', total_t)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('LossDC', average_only=True)
        logger.log_tabular('DeltaLossDC', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.log_tabular('LogProbabilityContext', average_only=True)
        logger.log_tabular('DecoderAccuracy', average_only=True)
        logger.log_tabular('ContextDim', average_only=True)
        logger.dump_tabular()
def run(self, episodes, steps, train=False, render_once=1e10, saveonce=10):
    """Roll out the model for ``episodes`` episodes of ``steps`` steps each.

    At each step the model output for the flattened env state is used as the
    ``probs`` of a ``Categorical``. The sampled action is one-hot encoded and
    sent to ``self.env.act``.

    When ``train`` is true:
      * each transition goes to ``self.trainer.store_records``;
      * ``self.trainer.update()`` runs after every episode;
      * the recorder saves and plots every ``saveonce`` episodes, and also saves
        the model when ``final_reward`` reaches a new maximum.

    Args:
        episodes: number of episodes.
        steps: env steps per episode.
        train: store transitions and update trainer/recorder.
        render_once: while training, drawing is enabled only on every
            ``render_once``-th episode (always enabled when not training).
        saveonce: save/plot cadence in episodes while training.
    """
    if train:
        assert self.recorder.log_message is not None, "log_message is necessary during training, Instantiate Runner with log message"

    # Read once; models without a `type` attribute are treated as
    # non-recurrent instead of raising AttributeError later in the loop.
    is_mem = getattr(self.model, "type", None) == "mem"
    if is_mem:
        print("Recurrent Model")
        # NOTE(review): this requires the attribute to be *absent*, yet the
        # message reads as if it were required, and hidden_states[-2] is read
        # below while training. Confirm the intended polarity.
        assert not hasattr(self.model, "hidden_states"), "no hidden_states list attribute"
    self.env.display_neural_image = self.visual_activations

    for episode in range(episodes):
        is_render_episode = episode % render_once == render_once - 1
        self.episode_rewards = []
        self.env.reset()
        self.env.enable_draw = not train or is_render_episode

        if is_mem:
            self.model.reset()

        state = self.env.get_state().reshape(-1)
        bar = tqdm(range(steps), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
        trewards = 0

        for step in bar:
            state = T.from_numpy(state).float()
            actions = self.model(state).view(-1)
            c = Categorical(actions)
            action = c.sample()
            prob = c.probs[action]

            u = np.zeros(self.nactions)
            u[action] = 1.0
            newstate, reward = self.env.act(u)
            trewards += reward
            self.episode_rewards.append(trewards)

            if train:
                hidden = self.model.hidden_states[-2] if is_mem else []
                self.trainer.store_records(
                    (state.tolist(), action, reward, prob, hidden, False))

            state = newstate.reshape(-1)
            if is_mem and self.visual_activations:
                u = T.cat(self.activations, dim=0).reshape(-1)
                self.neural_image_values = u.detach().numpy()
                self.activations = []
                # Refresh the weight visualisation at the first step of every
                # 10th episode.
                if episode % 10 == 0 and step == 0:
                    self.update_weights()
                    self.neural_weights = self.weights
                    self.weight_change = True
                if self.model.hidden_vectors is not None:
                    self.hidden_state = self.model.hidden_vectors
            else:
                self.activations = []

            bar.set_description(f"Episode: {episode:4} Rewards : {trewards}")
            if train:
                self.env.step()
            else:
                self.env.step(speed=0)

            self.event_handler()
            self.window.fill((0, 0, 0))

            if self.visual_activations and (not train or is_render_episode):
                if is_mem:
                    self.draw_neural_image()
            self.window.blit(self.env.win, (0, 0))

        if train:
            self.trainer.update()
            self.recorder.newdata(trewards)
            is_save_episode = episode % saveonce == saveonce - 1
            if is_save_episode:
                self.recorder.save()
                self.recorder.plot()

            if is_save_episode and self.recorder.final_reward >= self.current_max_reward:
                self.recorder.save_model(self.model)
                self.current_max_reward = self.recorder.final_reward
    print("******* Run Complete *******")
def decode_one_batch_rl(self, greedy, batch, s_t_1, c_t_1, enc_outputs,
                        enc_features, enc_padding_mask, extend_vocab_zeros,
                        enc_batch_extended, coverage_t, device):
    """Decode one batch without teacher forcing, for RL training.

    With ``greedy`` false, each step samples the next token from the decoder's
    output distribution and keeps its log-probability. With ``greedy`` true,
    each step takes the argmax token (the baseline).

    Returns:
        (decoded_seqs, log_probs). ``decoded_seqs`` is a list with one string
        per batch element; it is "xxx" when fewer than two words are decoded.
        In sampling mode ``log_probs`` is a tensor with the mean token
        log-probability of each sequence, averaged over steps up to and
        including the first ``self.end_id``. In greedy mode it is an empty
        list.
    """
    # No teacher forcing for RL: only the first decoder input is taken from the batch.
    dec_batch, _, max_dec_len, _, _ = get_output_from_batch(self.params, batch, device)
    log_probs = []
    decode_ids = []
    # The padding mask is built at runtime. mask_t[i] stays 1 until sequence i
    # samples end_id. The step that emits end_id is still counted because the
    # mask is recorded before it is updated.
    dec_padding_mask = []
    y_t = dec_batch[:, 0]
    mask_t = torch.ones(len(enc_outputs), dtype=torch.long, device=device)
    for _step in range(min(max_dec_len, self.params.max_dec_steps)):
        y_t_1 = y_t
        # Feed the decoder's updated coverage back in. The original discarded
        # it and passed the initial coverage_t at every step.
        final_dist, s_t_1, c_t_1, _, coverage_t = self.model.decoder(
            y_t_1, s_t_1, c_t_1, enc_outputs, enc_features, enc_padding_mask,
            extend_vocab_zeros, enc_batch_extended, coverage_t)
        if not greedy:
            # sampling
            multi_dist = Categorical(final_dist)
            y_t = multi_dist.sample()
            log_probs.append(multi_dist.log_prob(y_t))
            y_t = y_t.detach()
            dec_padding_mask.append(mask_t.detach().clone())
            # Logical AND of two bool masks. The original `(a) + (b) == 2` is
            # always False for bool tensors in current PyTorch, so the mask
            # was never zeroed after end_id.
            mask_t[(mask_t == 1) & (y_t == self.end_id)] = 0
        else:
            # baseline
            y_t = final_dist.max(1)[1].detach()
        decode_ids.append(y_t)
        # Extended-vocabulary (copied OOV) ids are replaced by UNK for the next input.
        y_t = y_t.masked_fill(y_t >= self.vocab.size(), self.unk_id)

    decode_ids = torch.stack(decode_ids, 1)
    if not greedy:
        dec_padding_mask = torch.stack(dec_padding_mask, 1).float()
        log_probs = torch.stack(log_probs, 1) * dec_padding_mask
        dec_lens = dec_padding_mask.sum(1)
        log_probs = log_probs.sum(1) / dec_lens
        if (dec_lens == 0).any():
            print("Decode lengths encounter zero!")
            print(dec_lens)

    decoded_seqs = []
    for i in range(len(enc_outputs)):
        dec_ids = decode_ids[i].cpu().numpy()
        article_oovs = batch.art_oovs[i]
        dec_words = data.outputids2decwords(dec_ids, self.vocab, article_oovs, self.params.pointer_gen)
        if len(dec_words) < 2:
            dec_seq = "xxx"
        else:
            dec_seq = " ".join(dec_words)
        decoded_seqs.append(dec_seq)
    return decoded_seqs, log_probs
def reinforce(n_components, needsoftmax_mixtureweight=None):
    """Train a categorical posterior network q(z|x) for a 1-D Gaussian mixture.

    Each step draws a batch of 10 points with ``sample_gmm`` under mixture
    weights ``softmax(needsoftmax_mixtureweight)``. An ``NN3`` encoder maps
    each point to logits over ``n_components`` clusters and is trained with
    Adam for 300000 steps on a score-function surrogate loss. Progress is
    printed every 100 steps. The logged (step, loss) pairs are pickled to
    ``exp_dir + "data_simplax.p"``; ``exp_dir`` is defined outside this
    function.

    Args:
        n_components: number of mixture components (encoder output size).
        needsoftmax_mixtureweight: optional unnormalized mixture logits, in
            any form ``torch.tensor`` accepts. Drawn from N(0, 1) when None.

    All tensors are allocated with device="cuda", so a GPU is required.
    """
    if needsoftmax_mixtureweight is None:
        needsoftmax_mixtureweight = torch.randn(n_components, requires_grad=True, device="cuda")
    else:
        needsoftmax_mixtureweight = torch.tensor(needsoftmax_mixtureweight, requires_grad=True, device="cuda")
    mixtureweights = torch.softmax(needsoftmax_mixtureweight, dim=0).float()  # [C]

    encoder = NN3(input_size=1, output_size=n_components, n_residual_blocks=2).cuda()
    optim_net = torch.optim.Adam(encoder.parameters(), lr=1e-5, weight_decay=1e-7)

    batch_size = 10
    n_steps = 300000
    k = 1  # posterior samples per data batch
    # Previously initialised but never filled, so an empty log was pickled.
    data_dict = {'steps': [], 'losses': []}

    for step in range(n_steps):
        x = sample_gmm(batch_size, mixture_weights=mixtureweights)
        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1)
        cat = Categorical(probs=probs)

        net_loss = 0
        for _ in range(k):
            cluster_H = cat.sample()
            logq = cat.log_prob(cluster_H).view(batch_size, 1)
            # log p(x, z) terms only feed the progress printout below.
            logpx_given_z = logprob_undercomponent(x, component=cluster_H)
            logpz = torch.log(mixtureweights[cluster_H]).view(batch_size, 1)
            logpxz = logpx_given_z + logpz  # [B,1]
            # Surrogate whose gradient is mean(logq * grad(logq)): a REINFORCE
            # estimate of grad E_q[log q] (= -grad H(q)).
            net_loss += -torch.mean(-logq.detach() * logq)
        net_loss = net_loss / k

        optim_net.zero_grad()
        # The loss graph is rebuilt from fresh encoder outputs every step and
        # never reused, so retain_graph is unnecessary.
        net_loss.backward()
        optim_net.step()

        if step % 100 == 0:
            data_dict['steps'].append(step)
            data_dict['losses'].append(net_loss.item())
            print()
            print(
                'S:{:5d}'.format(step),
                'Loss:{:.4f}'.format(to_print1(net_loss)),
                'ELBO:{:.4f}'.format(to_print1(logpxz - logq)),
                'lpx|z:{:.4f}'.format(to_print1(logpx_given_z)),
                'lpz:{:.4f}'.format(to_print1(logpz)),
                'lqz:{:.4f}'.format(to_print1(logq)),
            )
            print(to_print2(probs[0]))

    with open(exp_dir + "data_simplax.p", "wb") as f:
        pickle.dump(data_dict, f)
    print('saved data')
# Autoregressive sampling: extend the fixed Chinese prompt below one sampled
# token at a time until it has args.max_len tokens, then print the text.
# NOTE(review): model, device, args, HPchar2int_v2 and FutureMask are defined
# outside this chunk.
model = model.to(device)

# parameter
temperature = args.temperature
softmax = nn.Softmax(0)
char2int = HPchar2int_v2()
int2char = {v: k for k, v in char2int.items()}

# inference
input_sent = '哈利走進霍格華茲'
input_ids = list(map(char2int.get, list(input_sent)))
model.eval()
with torch.no_grad():
    while len(input_ids) < args.max_len:
        input_tensor = torch.LongTensor(input_ids).unsqueeze(0).to(device)
        masks_tensor = FutureMask(input_tensor).to(device)
        outputs, _ = model(input_ids=input_tensor, input_mask=masks_tensor)
        # Temperature-scaled scores at the last position -> next-token distribution.
        outputs = outputs[0, -1, :] / temperature
        outputs = softmax(outputs)
        sampler = Categorical(outputs)
        input_ids.append(sampler.sample().cpu().item())
input_ids = list(map(int2char.get, input_ids))
target_chars = ''.join(input_ids)
# Raw string: '\[' in a plain literal is an invalid escape sequence
# (SyntaxWarning on Python 3.12+). The regex pattern itself is unchanged.
target_chars = re.sub(r'\[NL\]', '\n', target_chars)
print(target_chars)
class OneHotCategorical(Distribution):
    r"""
    One-hot categorical distribution over ``probs.size(-1)`` events.

    The distribution is parameterized by either :attr:`probs` or :attr:`logits`
    and wraps a :class:`torch.distributions.Categorical`. Each sample is a
    vector of zeros with a single one at the drawn category.

    .. note:: :attr:`probs` must be non-negative, finite and have a non-zero
              sum; it is normalized to sum to 1 along the last dimension and
              :attr:`probs` returns the normalized value. :attr:`logits` are
              unnormalized log probabilities (any real number) and are likewise
              normalized; :attr:`logits` returns the normalized value.

    See also: :func:`torch.distributions.Categorical` for specifications of
    :attr:`probs` and :attr:`logits`.

    Example::

        >>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
        tensor([ 0.,  0.,  0.,  1.])

    Args:
        probs (Tensor): event probabilities
        logits (Tensor): event log probabilities (unnormalized)
    """
    arg_constraints = {
        'probs': constraints.simplex,
        'logits': constraints.real_vector,
    }
    support = constraints.one_hot
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None, validate_args=None):
        # The wrapped Categorical must exist before Distribution.__init__,
        # which reads the probs/logits properties for validation.
        base = Categorical(probs, logits)
        self._categorical = base
        super(OneHotCategorical, self).__init__(
            base.batch_shape, base.param_shape[-1:], validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        target_shape = torch.Size(batch_shape)
        expanded = self._get_checked_instance(OneHotCategorical, _instance)
        expanded._categorical = self._categorical.expand(target_shape)
        super(OneHotCategorical, expanded).__init__(
            target_shape, self.event_shape, validate_args=False)
        expanded._validate_args = self._validate_args
        return expanded

    def _new(self, *args, **kwargs):
        return self._categorical._new(*args, **kwargs)

    @property
    def _param(self):
        return self._categorical._param

    @property
    def probs(self):
        return self._categorical.probs

    @property
    def logits(self):
        return self._categorical.logits

    @property
    def mean(self):
        # The expectation of a one-hot vector is the probability vector itself.
        return self._categorical.probs

    @property
    def variance(self):
        p = self._categorical.probs
        return p * (1 - p)

    @property
    def param_shape(self):
        return self._categorical.param_shape

    def sample(self, sample_shape=torch.Size()):
        base = self._categorical
        template = base.probs
        drawn = base.sample(torch.Size(sample_shape))
        return torch.nn.functional.one_hot(drawn, base._num_events).to(template)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        return self._categorical.log_prob(value.max(-1).indices)

    def entropy(self):
        return self._categorical.entropy()

    def enumerate_support(self, expand=True):
        size = self.event_shape[0]
        ref = self._param
        unit_rows = torch.eye(size, dtype=ref.dtype, device=ref.device)
        unit_rows = unit_rows.view((size,) + (1,) * len(self.batch_shape) + (size,))
        if not expand:
            return unit_rows
        return unit_rows.expand((size,) + self.batch_shape + (size,))
def get_action_and_value(self, x, action=None):
    """Evaluate the categorical policy and value head on observation ``x``.

    ``self.actor(x)`` supplies the logits of the action distribution. When
    ``action`` is None, one is sampled from that distribution.

    Returns:
        (action, log_prob(action), entropy, self.critic(x))
    """
    dist = Categorical(logits=self.actor(x))
    chosen = dist.sample() if action is None else action
    return chosen, dist.log_prob(chosen), dist.entropy(), self.critic(x)
# dist = LogitRelaxedBernoulli(torch.Tensor([1.]), bern_param) # dist_bernoulli = Bernoulli(bern_param) C= 2 n_components = C B=1 probs = torch.ones(B,C) bern_param = bern_param.view(B,1) aa = 1 - bern_param probs = torch.cat([aa, bern_param], dim=1) cat = Categorical(probs= probs) grads = [] for i in range(n): b = cat.sample() logprob = cat.log_prob(b.detach()) # b_ = torch.argmax(z, dim=1) logprobgrad = torch.autograd.grad(outputs=logprob, inputs=(bern_param), retain_graph=True)[0] grad = f(b) * logprobgrad grads.append(grad[0][0].data.numpy()) print ('Grad Estimator: Reinfoce categorical') print ('Grad mean', np.mean(grads)) print ('Grad std', np.std(grads)) print () reinforce_cat_grad_means.append(np.mean(grads)) reinforce_cat_grad_stds.append(np.std(grads))
class DIPTransform:
    """Image transform that runs Deep Image Prior (DIP) for a random number of iterations.

    Each call fits a freshly initialized skip network, fed random noise, to the
    input image. The number of optimization steps is sampled from the
    ``stop_iters`` distribution. The network output is returned as a PIL image.
    """

    def __init__(self, image_shape, num_iters, stop_iters, input_depth=32,
                 input_noise_std=0, optimizer='adam', lr=1e-2, plot_every=0,
                 device='cpu'):
        """
        :param image_shape: image shape (CxHxW)
        :param num_iters: number of iterations to overfit
        :param stop_iters: dict int->float specifying categorical distribution
            over percentages in (0, 100] representing the probability that the
            transform runs for KEY/100 * NUM_ITERS iters.
        :param input_depth: depth of random input. Default value taken from paper.
        :param input_noise_std: stdev of noise added to base random input at
            each iter. They say in the paper that this helps sometimes
            (seemingly using 0.03), but default is 0 (no noise).
        :param optimizer: supported optimizers are 'adam' and 'LBFGS'
        :param lr: learning rate. Per paper and code, 1e-2 works best.
        :param plot_every: how often to save images. Doesn't save images if set
            to 0 (default).
        :param device: 'cuda' to use GPU if available.

        For example, with NUM_ITERS = 100 and STOP_ITERS = {10: 0.5, 50: 0.3, 100: 0.2},
        there is a 50% chance that the transform runs for 10 iterations, a 30%
        chance it runs for 50 iterations, and a 20% chance it runs for the full
        100 iterations.
        """
        self.iters, probs = zip(*stop_iters.items())
        self.probs = Categorical(torch.tensor(probs))
        self.num_iters = num_iters
        self.image_shape = image_shape
        self.input_depth = input_depth
        self.input_noise_std = input_noise_std
        self.plot_every = plot_every
        self.device = device
        self.optimizer = optimizer
        self.lr = lr
        self.loss = torch.nn.MSELoss()
        # const strings from provided DIP code
        self.opt_over = 'net'
        self.const_input = 'noise'
        # initialize network
        self.net = self._build_net()

    def _build_net(self):
        """Return a freshly initialized DIP skip network on ``self.device``."""
        return skip(self.input_depth, 3,
                    num_channels_down=[8, 16, 32],
                    num_channels_up=[8, 16, 32],
                    num_channels_skip=[0, 0, 4],
                    upsample_mode='bilinear',
                    need_sigmoid=True, need_bias=True, pad='zeros',
                    act_fun='LeakyReLU').to(self.device)

    def sample_iters(self):
        """Sample an iteration count: a stop_iters key (a percentage) times num_iters // 100."""
        return self.iters[self.probs.sample().item()] * self.num_iters // 100

    def run(self, image, num_iters):
        """Fit a fresh network to ``image`` for ``num_iters`` iterations and return its output.

        :param image: tensor whose shape after the first dimension equals
            ``self.image_shape``.
        """
        assert image.shape[1:] == self.image_shape, 'Wrong shape. Expected {}, got {}.'.format(
            self.image_shape, image.shape[1:])
        # DIP fits a randomly initialized network to each image. Reusing the
        # previous call's weights would warm-start from another image's fit,
        # and the sampled iteration count would no longer control how closely
        # the output reproduces this image.
        self.net = self._build_net()
        net_input = get_noise(self.input_depth, self.const_input, self.image_shape[1:]).to(self.device)
        net_input_saved = net_input.detach().clone()
        noise = net_input.detach().clone()
        p = get_params(self.opt_over, self.net, net_input)

        def closure(iter_num):
            return self.closure(net_input_saved, image, noise, iter_num)

        optimize(self.optimizer, p, closure, self.lr, num_iters, pass_iter=True)
        transformed = self.net(net_input)
        return transformed

    def closure(self, net_input_saved, image, noise, iter_num):
        """One optimization step: MSE between the network output and ``image``, then backward()."""
        net_input = net_input_saved
        if self.input_noise_std > 0:
            # Perturb the fixed input with fresh Gaussian noise (noise is refilled in place).
            net_input = net_input_saved + (noise.normal_() * self.input_noise_std)
        out = self.net(net_input)
        total_loss = self.loss(out, image)
        total_loss.backward()
        if self.plot_every > 0 and iter_num % self.plot_every == 0 and total_loss < 0.01:
            out_np = torch_to_np(out)
            plot_image_grid([np.clip(out_np, 0, 1)], factor=4, nrow=1, show=False,
                            save_path=f'results_dip/imgs/{iter_num}.png')
        # maybe log loss here?

    def __call__(self, sample):
        """
        Takes in, transforms, and returns PIL image given by SAMPLE.
        Transformation is a random number of iterations of DIP. Distribution of number
        of iterations is specified when the transform is initialized.
        """
        np_img = pil_to_np(sample)
        torch_img = np_to_torch(np_img).to(self.device)
        # Debug snapshot of the input. It is written only when plotting is
        # enabled, as documented for plot_every. TODO remove
        if self.plot_every > 0:
            plot_image_grid([np.clip(np_img, 0, 1)], factor=4, nrow=1, show=False,
                            save_path='results_dip/imgs/true.png')
        num_iters = self.sample_iters()
        transformed = self.run(torch_img, num_iters)
        return np_to_pil(torch_to_np(transformed))
def sample_action(self, state):
    """Draw one action from the categorical policy whose logits are ``self(state)``."""
    return Categorical(logits=self(state)).sample()
def select_actions(pi, deterministic=False):
    """Pick actions from the probability tensor ``pi``.

    In deterministic mode, returns the argmax over dim 1 as a Python number
    via ``.item()``, so it expects a single row. Otherwise, returns a sampled
    index tensor with a trailing singleton dimension.
    """
    dist = Categorical(pi)
    if not deterministic:
        return dist.sample().unsqueeze(-1)
    return torch.argmax(pi, dim=1).item()
# Training loop. Each step accumulates, over batch_size single draws from
# sample_true(), the objective f = logprob_undercomponent(...) - log q(cluster),
# where cluster is sampled from the encoder's softmax. It then backpropagates
# the batch-averaged `loss` (= mean of -f) and takes one optimizer step.
# NOTE(review): `net_loss`, the REINFORCE surrogate -f.detach() * log q, is
# computed but never passed to backward(); only `loss` is. Confirm whether
# net_loss was meant to drive the encoder update.
# NOTE(review): L2_losses and steps_list are not appended to in this loop;
# presumably they are used elsewhere in the file.
L2_losses = []
steps_list = []
for step in range(n_steps):
    optim.zero_grad()
    loss = 0
    net_loss = 0
    for i in range(batch_size):
        x = sample_true()
        logits = encoder.net(x)
        # print (logits.shape)
        # print (torch.softmax(logits, dim=0))
        # fsfd
        cat = Categorical(probs= torch.softmax(logits, dim=0))
        cluster = cat.sample()
        logprob_cluster = cat.log_prob(cluster.detach())
        # print (logprob_cluster)
        pxz = logprob_undercomponent(x, component=cluster, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=False)
        f = pxz - logprob_cluster
        # print (f)
        # logprob = logprob_givenmixtureeweights(x, needsoftmax_mixtureweight)
        net_loss += -f.detach() * logprob_cluster
        loss += -f
    loss = loss / batch_size
    net_loss = net_loss / batch_size
    # print (loss, net_loss)
    # NOTE(review): retain_graph=True keeps this step's graph alive; it may only
    # be needed if part of the graph is shared across steps -- confirm.
    loss.backward(retain_graph=True)
    optim.step()
def gen_IDENTIFY(model, task, state, turn, vocab, topk=1, target=True, var='A',
                 exclude=frozenset()):
    """Generate an IDENTIFY(...) logical form for the current turn.

    Args:
        model: passed through to ``get_term_probs`` / ``gen_IDENTIFY1`` /
            ``gen_IDENTIFY2``.
        task: sequence of patches. Its order is changed so the target patch
            comes first.
        state: dialogue state with ``form``, ``intent``, ``distribution`` and
            ``prev``. ``prev`` links to earlier states; the chain is walked
            until ``prev`` is falsy.
        turn: pair ``(speaker, sim_flag)``. Speaker 'S' enables the
            speaker-specific branches.
        vocab, topk: passed through to the helper generators; ``topk`` is not
            otherwise used.
        target: bool; whether the form refers to the target ('T').
        var: variable name used when ``target`` is false.
        exclude: extra term strings to avoid, in addition to every number
            already mentioned in the state history.

    Environment variables MIS_ID1 / MIS_ID2 / MIS_ID3 give the probability of
    deliberately perturbing ``task`` (via ``gen_task``) when the previous
    intent is 2 / 6 / 7. Unset variables mean probability 0.

    Returns:
        str: the generated logical form.
    """
    sim_flag = turn[1]
    turn = turn[0]

    # Collect every number already mentioned anywhere in the state history,
    # so the fallback path below avoids repeating a term.
    clr_terms = set(exclude)
    cur_state = state
    while cur_state:
        clr_terms.update(re.findall('[0-9]+', cur_state.form))
        cur_state = cur_state.prev

    # Previous intent -> environment variable holding its mistake probability.
    mistake_env = {2: 'MIS_ID1', 6: 'MIS_ID2', 7: 'MIS_ID3'}
    prob_mistake = 0.0
    if turn == 'S' and state.intent in mistake_env:
        # Default to 0 instead of crashing in float(None) when unset.
        prob_mistake = float(os.getenv(mistake_env[state.intent], '0'))
    sampled_val = random.random()
    # Strict '<' makes the event probability exactly prob_mistake
    # (random() can return 0.0, so '<=' fired even when prob_mistake == 0).
    if turn == 'S' and sampled_val < prob_mistake and state.intent in mistake_env:
        print('Sampled Value: %f, Mistake Chance: %f, Previous Move: %d' % (sampled_val, prob_mistake, state.intent))
        task = gen_task(task, random.choice([1, 2]))

    if turn == 'S' and (state.prev is None or not target):
        term_probs = get_term_probs(model, task, simulation=sim_flag)
        clr_term = str(int(Categorical(term_probs).sample()))
        prefix = 'IDENTIFY(T,' if target else 'IDENTIFY(A,'
        return prefix + clr_term + ')'
    if turn == 'S' and target and state.intent == 2:
        if random.random() < 0.8:
            term_probs = get_term_probs(model, task, simulation=sim_flag)
            clr_term = str(int(Categorical(term_probs).sample()))
            return 'IDENTIFY(T,' + clr_term + ')'
        term_probs1 = get_term_probs(model, [task[1], task[0], task[2]], simulation=sim_flag)
        term_probs2 = get_term_probs(model, [task[2], task[0], task[1]], simulation=sim_flag)
        clr_term1 = str(int(Categorical(term_probs1).sample()))
        clr_term2 = str(int(Categorical(term_probs2).sample()))
        return ('DISTINGUISH(T,A) AND COMPARE_REF(A,[and],B,C) AND IDENTIFY(B,' + clr_term1
                + ') AND IDENTIFY(C,' + clr_term2 + ')')
    if turn == 'S' and state.intent == 7:
        return gen_IDENTIFY1(model, task, state, turn, vocab, topk=1)
    if turn == 'S' and state.intent == 6:
        sorted_ind = np.argsort(state.distribution)[::-1]
        if 0 in sorted_ind[:2]:
            return gen_IDENTIFY1(model, task, state, turn, vocab, topk=1)
        # The terms in the clarification did not identify the target.
        return gen_IDENTIFY2(model, task, state, (turn, sim_flag), vocab, topk=1)

    # Move the sampled target patch to the front, keeping the others in order.
    if target and len(task) > 1:
        ix = sample_patch(state.distribution, state)
        task = [task[ix]] + [patch for i, patch in enumerate(task) if i != ix]

    start_var = 'T' if target else var
    term_probs = get_term_probs(model, task, simulation=False, turn=turn)
    # Consider up to the 20 most likely terms (fewer for a smaller vocabulary,
    # where topk(k=20) would raise).
    top = term_probs.topk(k=min(20, term_probs.size(0)), dim=0)
    candidates = [str(int(idx)) for idx in top.indices]
    # First candidate not already used. If every candidate is used, fall back
    # to the last one, as the original loop did.
    clr_term = next((c for c in candidates if c not in clr_terms), candidates[-1])
    return 'IDENTIFY(' + start_var + ',' + clr_term + ')'
def compute_loss_policy(pred, args, L, Lambda):
    """Score-function loss for sampled +/-1 node labellings.

    For each of ``args.num_ysampling`` labellings y drawn from the per-node
    2-way softmax of ``pred``:

    - c = y^T L y / 4 is computed. This is presumably the cut value when L is
      a graph Laplacian -- confirm.
    - The loss is sum_y log q(y) * (sign * c + Lambda * (sum y)^2), where
      sign = -1 for ``args.problem == 'max'`` and +1 otherwise.
    - ``acc`` and ``inb`` are averages of c and |sum y| under
      self-normalized weights softmax(log q(y)).

    Shapes follow the original comments: pred is bs x N x 2.

    Returns:
        (loss, acc, z, inb). z = (acc / num_nodes - d / 4) / sqrt(d / 4),
        with d = int(num_nodes * edge_density). In the batched path acc and
        inb are averaged over the batch and inb is rounded.
    """
    L = L.type(dtype)
    pred_prob = softmax(pred).type(dtype)  # pred of size bs x N x 2
    d = int(args.num_nodes * args.edge_density)
    # Applied in both paths; the batched path previously ignored args.problem.
    objective_sign = -1 if getattr(args, 'problem', None) == 'max' else 1

    if args.batch_size == 1:
        m = Categorical(pred_prob[0, :, :])
        y_sampled = m.sample((args.num_ysampling, )).type(dtype)  # num_ysampling x N
        pred_prob_sampled_log = m.log_prob(y_sampled).type(dtype)  # num_ysampling x N
        pred_prob_sampled_sum_log = pred_prob_sampled_log.sum(dim=-1)  # num_ysampling
        y_sampled_label = y_sampled * 2 - 1  # {0,1} -> {-1,+1}; num_ysampling x N
        L = L.squeeze(0).type(dtype)  # N x N
        c = torch.mm(y_sampled_label, torch.mm(L, torch.t(y_sampled_label)))
        c = 1 / 4 * torch.diagonal(c, offset=0)  # num_ysampling
        c_plus_penalty = objective_sign * c + Lambda * y_sampled_label.sum(dim=1).pow(2)
        loss = pred_prob_sampled_sum_log.dot(c_plus_penalty)
        # Same as exp(x) / sum(exp(x)), but does not underflow to 0/0 (NaN)
        # when the summed log-probs are very negative for large N.
        w = torch.softmax(pred_prob_sampled_sum_log, dim=-1)
        acc = w.dot(c)
        z = (acc / args.num_nodes - d / 4) / np.sqrt(d / 4)
        inb = torch.dot(torch.abs(y_sampled_label.sum(dim=1)), w)
    else:
        bs, ns = args.batch_size, args.num_ysampling
        m = Categorical(pred_prob)
        y_sampled = m.sample((ns, )).type(dtype)  # ns x bs x N
        pred_prob_sampled_log = m.log_prob(y_sampled)  # ns x bs x N
        y_sampled = y_sampled.permute(1, 2, 0)  # bs x N x ns
        pred_prob_sampled_sum_log = pred_prob_sampled_log.sum(dim=-1).permute(1, 0)  # bs x ns
        y_sampled_label = y_sampled * 2 - 1
        c = torch.bmm(y_sampled_label.permute(0, 2, 1), torch.bmm(L, y_sampled_label))  # bs x ns x ns
        c = 1 / 4 * torch.diagonal(c, offset=0, dim1=-2, dim2=-1)  # bs x ns
        c_plus_penalty = objective_sign * c + Lambda * y_sampled_label.sum(dim=1).pow(2)  # bs x ns
        loss = torch.bmm(c_plus_penalty.view([bs, 1, ns]),
                         pred_prob_sampled_sum_log.view([bs, ns, 1]))  # bs x 1 x 1
        loss = torch.mean(loss)
        w = torch.softmax(pred_prob_sampled_sum_log, dim=-1)  # bs x ns
        # torch.dot only accepts 1-D tensors (the original raised here);
        # batched inner products need bmm.
        acc = torch.bmm(c.view([bs, 1, ns]), w.view([bs, ns, 1]))
        inb = torch.bmm(torch.abs(y_sampled_label.sum(dim=1)).view([bs, 1, ns]),
                        w.view([bs, ns, 1]))
        acc = torch.mean(acc)
        z = (acc / args.num_nodes - d / 4) / np.sqrt(d / 4)
        inb = torch.mean(inb)
        inb = torch.round(inb)
    return loss, acc, z, inb
class OneHotCategorical(Distribution):
    r"""
    Creates a one-hot categorical distribution parameterized by :attr:`probs` or
    :attr:`logits`.

    Samples are one-hot coded vectors of size ``probs.size(-1)``.

    .. note:: :attr:`probs` must be non-negative, finite and have a non-zero sum,
              and it will be normalized to sum to 1.

    See also: :func:`torch.distributions.Categorical` for specifications of
    :attr:`probs` and :attr:`logits`.

    Example::

        >>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
        tensor([ 0.,  0.,  0.,  1.])

    Args:
        probs (Tensor): event probabilities
        logits (Tensor): event log probabilities
        validate_args (bool, optional): whether to validate parameters and
            ``log_prob`` inputs (Distribution default when None).
    """
    arg_constraints = {'probs': constraints.simplex,
                       'logits': constraints.real_vector}
    # Samples are exact one-hot vectors. The former `constraints.simplex` also
    # accepted soft vectors such as [0.5, 0.5]; log_prob then silently
    # collapsed them with argmax instead of rejecting them under validation.
    support = constraints.one_hot
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None, validate_args=None):
        # Must exist before Distribution.__init__ reads the probs/logits properties.
        self._categorical = Categorical(probs, logits)
        batch_shape = self._categorical.batch_shape
        event_shape = self._categorical.param_shape[-1:]
        super(OneHotCategorical, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a copy with the batch shape broadcast to ``batch_shape`` (previously unsupported)."""
        new = self._get_checked_instance(OneHotCategorical, _instance)
        batch_shape = torch.Size(batch_shape)
        new._categorical = self._categorical.expand(batch_shape)
        super(OneHotCategorical, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        return self._categorical._new(*args, **kwargs)

    @property
    def probs(self):
        return self._categorical.probs

    @property
    def logits(self):
        return self._categorical.logits

    @property
    def mean(self):
        return self._categorical.probs

    @property
    def variance(self):
        return self._categorical.probs * (1 - self._categorical.probs)

    @property
    def param_shape(self):
        return self._categorical.param_shape

    def sample(self, sample_shape=torch.Size()):
        sample_shape = torch.Size(sample_shape)
        probs = self._categorical.probs
        indices = self._categorical.sample(sample_shape)
        # One-hot encode and match the dtype/device of probs.
        return torch.nn.functional.one_hot(indices, self._categorical._num_events).to(probs)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        indices = value.max(-1)[1]
        return self._categorical.log_prob(indices)

    def entropy(self):
        return self._categorical.entropy()

    def enumerate_support(self, expand=True):
        """Return all n one-hot vectors, shaped (n, *batch_shape, n), or (n, 1, ..., 1, n) if not ``expand``."""
        n = self.event_shape[0]
        param = self._categorical._param
        # Build the identity directly instead of resizing a placeholder via
        # `out=`, which is deprecated for non-empty outputs.
        values = torch.eye(n, dtype=param.dtype, device=param.device)
        values = values.view((n, ) + (1, ) * len(self.batch_shape) + (n, ))
        if expand:
            values = values.expand((n, ) + self.batch_shape + (n, ))
        return values
def forward(self, hidden_state=None):
    """
    Generate a message token by token with the sender RNN.

    How each token is produced depends on the flags read below:

    * ``not self.rl`` and ``not self.vqvae`` (original baseline): Gumbel-Softmax
      over ``softmax(self.linear_out(h))`` via ``self.calculate_token_gumbel_softmax``.
    * ``not self.rl`` and ``self.vqvae``: the projected hidden state is matched
      against the embedding table ``self.e``. Without discrete communication
      ``self.vq.apply`` is used. Otherwise a softmin over
      ``EmbeddingtableDistances`` is followed by ``self.hard_max`` or Gumbel-Softmax.
    * ``self.rl``: a Categorical over temperature-scaled log-probabilities is
      sampled while training and argmax'd in eval mode. Its entropies and the
      log-prob of the chosen token are recorded.

    Args:
        hidden_state: encoded image/metadata. It is passed through
            ``self.input_module`` and used to initialise the RNN state.

    Returns:
        tuple ``(messages, seq_lengths, entropy, embeds, sentence_probability,
        loss_2_3_out, message_logits)``:

        * messages: the initial token plus every generated token, stacked along dim 1.
        * seq_lengths: int64 tensor of shape ``(batch,)``, starting at
          ``output_len + 1`` and updated by ``self._calculate_seq_len``.
        * entropy, message_logits: ``(batch, output_len)``. Filled only in the
          RL setting, zeros otherwise.
        * embeds: the RNN inputs of every step, stacked along dim 1.
        * sentence_probability: accumulated by
          ``calculate_token_gumbel_softmax`` in the baseline setting, zeros otherwise.
        * loss_2_3_out: mean over steps of the VQ-VAE codebook + beta *
          commitment loss. Zero when ``not self.vqvae``.
    """
    hidden_state = self.input_module(hidden_state)
    state, batch_size = self._init_state(hidden_state, type(self.rnn))

    # Initial input token. Only the discrete RL setting without VQ-VAE has a sos
    # symbol. In the VQ-VAE case all words come from the unordered embedding
    # table, whose size is not necessarily the vocab size, so no code word can
    # be reserved for sos/eos and the first input stays all-zero.
    first_token = torch.zeros(
        (batch_size, self.vocab_size),
        dtype=torch.float32,
        device=self.device,
    )
    if not self.vqvae and self.discrete_communication and self.rl:
        first_token[:, self.sos_id] = 1.0
    output = [first_token]

    # Keep track of sequence lengths
    initial_length = self.output_len + 1  # add the sos token
    # [initial_length, ..., initial_length]; this gets reduced whenever a message ends somewhere.
    seq_lengths = (
        torch.ones([batch_size], dtype=torch.int64, device=self.device)
        * initial_length
    )

    embeds = []  # keep track of the embedded sequence
    sentence_probability = torch.zeros(
        (batch_size, self.vocab_size), device=self.device
    )
    # These accumulators are only written on some paths (losses_2_3 when
    # self.vqvae, entropy/message_logits when self.rl). Zero-initialise them so
    # the untouched ones are returned as zeros instead of the uninitialised
    # memory torch.empty would give.
    losses_2_3 = torch.zeros(self.output_len, device=self.device)
    entropy = torch.zeros((batch_size, self.output_len), device=self.device)
    message_logits = torch.zeros((batch_size, self.output_len), device=self.device)

    if self.vqvae:
        distance_computer = EmbeddingtableDistances(self.e)

    for i in range(self.output_len):
        emb = torch.matmul(output[-1], self.embedding)
        embeds.append(emb)
        state = self.rnn(emb, state)

        if type(self.rnn) is nn.LSTMCell:
            h, _ = state
        else:
            h = state

        # Per-sample code word indices, filled by the VQ helpers / argmax below.
        indices = [None] * batch_size
        if not self.rl:
            if not self.vqvae:
                # That's the original baseline setting
                p = F.softmax(self.linear_out(h), dim=1)
                token, sentence_probability = self.calculate_token_gumbel_softmax(
                    p, self.tau, sentence_probability, batch_size
                )
            else:
                pre_quant = self.linear_out(h)
                if not self.discrete_communication:
                    token = self.vq.apply(pre_quant, self.e, indices)
                else:
                    distances = distance_computer(pre_quant)
                    softmin = F.softmax(-distances, dim=1)
                    if not self.gumbel_softmax:
                        token = self.hard_max.apply(
                            softmin, indices, self.discrete_latent_number
                        )  # This also updates the indices
                    else:
                        _, indices[:] = torch.max(softmin, dim=1)
                        token, _ = self.calculate_token_gumbel_softmax(
                            softmin, self.tau, 0, batch_size
                        )
        else:
            if not self.vqvae:
                all_logits = F.log_softmax(self.linear_out(h) / self.tau, dim=1)
            else:
                pre_quant = self.linear_out(h)
                distances = distance_computer(pre_quant)
                all_logits = F.log_softmax(-distances / self.tau, dim=1)

            _, indices[:] = torch.max(all_logits, dim=1)

            distr = Categorical(logits=all_logits)
            entropy[:, i] = distr.entropy()

            # Sample while training, act greedily in eval mode.
            if self.training:
                token_index = distr.sample()
            else:
                token_index = all_logits.argmax(dim=1)
            token = to_one_hot(token_index, n_dims=self.vocab_size)
            message_logits[:, i] = distr.log_prob(token_index)

        if not (self.vqvae and not self.discrete_communication and not self.rl):
            # Whenever we have a meaningful eos symbol, we prune the messages in the end
            self._calculate_seq_len(
                seq_lengths, token, initial_length, seq_pos=i + 1
            )

        if self.vqvae:
            loss_2 = torch.mean(
                torch.norm(pre_quant.detach() - self.e[indices], dim=1) ** 2
            )
            loss_3 = torch.mean(
                torch.norm(pre_quant - self.e[indices].detach(), dim=1) ** 2
            )
            # This corresponds to the second and third loss term in VQ-VAE
            loss_2_3 = loss_2 + self.beta * loss_3
            losses_2_3[i] = loss_2_3

        token = token.to(self.device)
        output.append(token)

    messages = torch.stack(output, dim=1)
    loss_2_3_out = torch.mean(losses_2_3)

    return (
        messages,
        seq_lengths,
        entropy,
        torch.stack(embeds, dim=1),
        sentence_probability,
        loss_2_3_out,
        message_logits,
    )
# Per-step buffers for the negative log-probability and entropy of each sampled action.
neglogprobs = torch.zeros((args.episode_length,), device=device)
entropys = torch.zeros((args.episode_length,), device=device)
# TRY NOT TO MODIFY: prepare the execution of the game.
for step in range(args.episode_length):
    global_step += 1
    obs[step] = next_obs.copy()
    # ALGO LOGIC: put action logic here
    # The policy network returns the distribution parameters (logits, std);
    # the value network returns the state value for this step.
    logits, std = pg.forward([obs[step]])
    values[step] = vf.forward([obs[step]])
    # ALGO LOGIC: `env.action_space` specific logic
    if isinstance(env.action_space, Discrete):
        probs = Categorical(logits=logits)
        action = probs.sample()
        actions[step], neglogprobs[step], entropys[step] = action.tolist()[0], -probs.log_prob(action), probs.entropy()
    elif isinstance(env.action_space, Box):
        # Continuous actions: Normal with `logits` used as the mean.
        probs = Normal(logits, std)
        action = probs.sample()
        # NOTE(review): the lower bound is torch.min(low) but the upper bound is
        # torch.min(high). When the per-dimension bounds differ, this clips to
        # [min(low), min(high)]. Confirm the upper bound is not meant to be
        # torch.max(high) (or per-dimension bounds).
        clipped_action = torch.clamp(action, torch.min(torch.Tensor(env.action_space.low)), torch.min(torch.Tensor(env.action_space.high)))
        # The log-prob is taken on the unclipped `action`; only the executed action is clipped.
        actions[step], neglogprobs[step], entropys[step] = clipped_action.tolist()[0], -probs.log_prob(action).sum(), probs.entropy().sum()
    elif isinstance(env.action_space, MultiDiscrete):
        # One logit group per sub-action, sized by env.action_space.nvec.
        logits_categories = torch.split(logits, env.action_space.nvec.tolist(), dim=1)
        action = []
        probs_categories = []
        probs_entropies = torch.zeros((logits.shape[0]))
        neglogprob = torch.zeros((logits.shape[0]))
        for i in range(len(logits_categories)):
            # (loop body not visible in this chunk)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions.normal import Normal
from torch.distributions.binomial import Binomial
from torch.distributions.categorical import Categorical

# Two-outcome categorical with equal probabilities (a fair coin).
S = Categorical(torch.tensor([0.5, 0.5]))
for i in range(100000):
    s = S.sample()
    x = []
    # NOTE(review): this branch is incomplete as written. `u` is not defined
    # anywhere visible, and `else` has no colon or body, so the script cannot run as-is.
    if s==1:
        u
    else
#REINFORCE print ('REINFORCE') # def sample_reinforce_given_class(logits, samp): # return logprob grads = [] for i in range (N): dist = Categorical(logits=logits) samp = dist.sample() logprob = dist.log_prob(samp) reward = f(samp) gradlogprob = torch.autograd.grad(outputs=logprob, inputs=(logits), retain_graph=True)[0] grads.append(reward*gradlogprob) print () grads = torch.stack(grads).view(N,C) # print (grads.shape) grad_mean_reinforce = torch.mean(grads,dim=0) grad_std_reinforce = torch.std(grads,dim=0) print ('REINFORCE') print ('mean:', grad_mean_reinforce) print ('std:', grad_std_reinforce) print ()
L2_losses = []
steps_list = []
for step in range(n_steps):
    optim.zero_grad()
    # `loss` accumulates -f and `net_loss` the score-function surrogate, each
    # averaged over `batch_size` single-sample draws below.
    loss = 0
    net_loss = 0
    for i in range(batch_size):
        x = sample_true()
        logits = encoder.net(x)
        # print (logits.shape)
        # print (torch.softmax(logits, dim=0))
        # fsfd
        # Encoder distribution over clusters for this x (softmax over dim 0).
        cat = Categorical(probs=torch.softmax(logits, dim=0))
        cluster = cat.sample()
        logprob_cluster = cat.log_prob(cluster.detach())
        # print (logprob_cluster)
        pxz = logprob_undercomponent( x, component=cluster, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=False)
        # f = pxz - log q(cluster | x). Presumably the single-sample ELBO term;
        # confirm what logprob_undercomponent returns.
        f = pxz - logprob_cluster
        # print (f)
        # logprob = logprob_givenmixtureeweights(x, needsoftmax_mixtureweight)
        # REINFORCE surrogate: f is detached, so this term's gradient flows only through log q.
        net_loss += -f.detach() * logprob_cluster
        loss += -f
    loss = loss / batch_size
    net_loss = net_loss / batch_size
print('True')
print('[-.5478, .1122, .4422]')
# Absolute difference between the SIMPLAX estimator's mean gradient and `true`
# (presumably the reference values printed above; confirm).
print('dif:', np.abs(grad_mean_simplax.numpy() - true))
print()
#REINFORCE
print('REINFORCE')
# def sample_reinforce_given_class(logits, samp):
#     return logprob
# Monte-Carlo REINFORCE estimate: for each of N samples b ~ Categorical(logits),
# keep f(b) * d log p(b) / d logits.
grads = []
for i in range(N):
    dist = Categorical(logits=logits)
    samp = dist.sample()
    logprob = dist.log_prob(samp)
    reward = f(samp)
    gradlogprob = torch.autograd.grad(outputs=logprob, inputs=(logits), retain_graph=True)[0]
    grads.append(reward * gradlogprob)
print()
# One row per sample: (N, C).
grads = torch.stack(grads).view(N, C)
# print (grads.shape)
grad_mean_reinforce = torch.mean(grads, dim=0)
grad_std_reinforce = torch.std(grads, dim=0)
print('REINFORCE')
print('mean:', grad_mean_reinforce)
def throw(self):
    """Sample one outcome index, treating ``self.dist`` as categorical probabilities."""
    return Categorical(probs=self.dist).sample()
tgt_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=src_key_padding_mask, memory=encoder_output)
''' Now make prediction on next word we're only interested in the values of the embeddings of the current step we're on which we can get by output[decoder_step,:,:] which has dim (batch_size,vocab_size) '''
# Softmax over dim 1 of the current step's decoder output: one distribution per batch row.
word_probs = F.softmax(output[decoding_step, :, :], dim=1)
#IMPLEMENTATION FOR REINFORCE VERY EASY https://pytorch.org/docs/stable/distributions.html
#MAKE SURE THAT THIS BACKPROPS THROUGH EVERY ACTION WE CHOOSE
# NOTE(review): Categorical.sample() is not differentiable. REINFORCE gradients
# must come from m.log_prob(chosen_word), which is not computed in this span.
m = Categorical(word_probs)
chosen_word = m.sample( ) #this generates along dim 1 of word_probs which is what we want since dim 1 contains distributions
dec_input = torch.cat( [dec_input, chosen_word.view(1, -1)], dim=0 ) #append chosen_word as row to end of decoder input for next iteration
policy_net.train()
# Reward for the generated sequences, moved to the training device.
reward = get_bleu_scores(trg_tensor, dec_input, dataset_dict['TGT'], BLEU1=False).to(main_params.device)
#now that we have reward, go back through the path and get gradients
# One flag per batch element, all False initially.
# NOTE(review): equivalent to torch.zeros(..., dtype=torch.bool, device=main_params.device);
# the .type(torch.BoolTensor) hop goes through a CPU tensor first.
eos_reached = torch.zeros(main_params.batch_size, dtype=torch.uint8).type(torch.BoolTensor).to(
    main_params.device)
# Copy of the encoder output that is cut off from the autograd graph.
encoder_output_detached = encoder_output.clone().detach()
def get_action(self, x, action=None):
    """
    Score or sample an action under the actor's categorical policy.

    The actor head is applied to ``self.forward(x)`` and the result is used as
    logits. When ``action`` is None a new action is sampled; otherwise the given
    action is evaluated.

    Returns:
        ``(action, log_prob(action), entropy)``
    """
    policy = Categorical(logits=self.actor(self.forward(x)))
    chosen = policy.sample() if action is None else action
    return chosen, policy.log_prob(chosen), policy.entropy()
def get_action(model, device, state):
    """
    Sample an action from the policy head of ``model`` for ``state``.

    ``model`` is called on ``state`` (converted to a float tensor on ``device``)
    and must return a 3-tuple whose first element holds the action
    probabilities. The other two outputs (value heads) are ignored.

    Returns:
        The sampled action index/indices as a squeezed NumPy array.
    """
    obs = torch.Tensor(state).to(device)
    policy_probs, _value_ext, _value_int = model(obs)
    sampled = Categorical(policy_probs).sample()
    return sampled.data.cpu().numpy().squeeze()
def run_episode(self):
    '''
    Play one four-goose self-play episode of Hungry Geese and store the
    flattened training tensors in ``self.memory``.

    For every goose with a non-empty body on every turn, this records:
    - the observation;
    - the value estimate;
    - the entropy and log-probability of the move sampled from the masked policy;
    - a shaped reward: 1 if the move eats food, else 0.

    When the episode ends, each goose's last reward gets a survival bonus of
    ``(final_reward // 100) / 200``. Discounted returns are then computed with
    ``self.GAMMA``.
    '''
    n_geese = 4
    observations = [[] for _ in range(n_geese)]
    # Action histories start with None: get_action_mask receives the previous
    # action, and there is none before the first move.
    actions = [[None] for _ in range(n_geese)]
    rewards = [[] for _ in range(n_geese)]
    entropy = [[] for _ in range(n_geese)]
    log_pbs = [[] for _ in range(n_geese)]
    values = [[] for _ in range(n_geese)]

    env = make('hungry_geese', debug=False)
    frame = env.reset(num_agents=n_geese)

    while any(entry['status'] == 'ACTIVE' for entry in frame):
        shared_obs = frame[0]['observation']
        step = shared_obs['step']  # read but not used below
        food = shared_obs['food']
        geese = shared_obs['geese']

        for i in range(n_geese):
            goose = geese[i]
            if not goose:
                # Empty body (presumably an eliminated goose): nothing to act for.
                continue

            observations[i].append({'index': i, 'geese': geese[:], 'food': food[:]})
            logits, value = self.predict(observations[i])

            # Mask invalid actions to boost learning speed.
            allowed = get_action_mask(i, geese, actions[i][-1])
            policy = Categorical(logits=torch.where(allowed, logits, to_tensor(-1e5)))
            move = policy.sample()

            # Naive reward mechanism to encourage eating.
            # Feel free to change it to improve model performance.
            rewards[i].append(1 if is_eating(goose, move, food) else 0)
            actions[i].append(action_of(move))
            entropy[i].append(policy.entropy())
            log_pbs[i].append(policy.log_prob(move))
            values[i].append(value)

        # Next frame
        frame = env.step(tails_of(actions))

    # Episode is over: add the survival bonus to each goose's final reward.
    for i in range(n_geese):
        turns = divmod(frame[i]['reward'], 100)[0]
        rewards[i][-1] += turns / 200
        # Ensure shapes are consistent.
        assert len(rewards[i]) == len(log_pbs[i]) == len(values[i])

    # Discounted returns, accumulated backwards from the last turn.
    Q = []
    for goose_rewards in rewards:
        returns = torch.zeros(len(goose_rewards)).detach()
        running = 0
        for t in range(len(goose_rewards) - 1, -1, -1):
            running = goose_rewards[t] + self.GAMMA * running
            returns[t] = running
        Q.append(returns)

    # Flatten training data.
    Q = torch.hstack(Q).detach()
    V = torch.hstack(flatten(values))
    E = torch.hstack(flatten(entropy))
    log_probs = torch.hstack(flatten(log_pbs))

    # Again ensure shapes of training data are consistent.
    assert Q.shape == V.shape == E.shape == log_probs.shape
    self.memory.add(Q, V, E, log_probs)
def make_action(self, state, test=False):
    """
    Pick an action for ``state`` according to ``self.args.exploration_method``.

    Recognised method prefixes, checked in this order:

    * ``greedy``: argmax of the network output.
    * ``epsilon``: with probability ``1 - eps`` act greedily, otherwise pick a
      uniformly random action. ``epsilon_exp`` decays eps from 0.9 towards 0.1
      with the module-level ``episodes_done_num``; any other ``epsilon*`` uses
      eps = 0.1. Test mode always acts greedily.
    * ``boltzmann``: sample from ``softmax(Q / self.args.boltzmann_temperature)``.
    * ``thompson``: argmax of a forward pass with dropout (0.3) during
      training, or without dropout in test mode.

    In test mode ``state`` is converted with ``permute(2, 0, 1).unsqueeze(0)``
    (presumably HxWxC -> 1xCxHxW) and a plain Python int is returned.
    Otherwise an index tensor shaped ``(N, 1)`` is returned.

    Raises:
        ValueError: for an unrecognised exploration method.
    """
    if test:
        state = torch.tensor(state, device='cuda' if use_cuda else 'cpu').permute(
            2, 0, 1).unsqueeze(0)

    method = self.args.exploration_method

    if method.startswith('greedy'):
        with torch.no_grad():
            chosen = torch.softmax(self.online_net(state), 1).max(1)[1].view(-1, 1)
        return chosen.item() if test else chosen

    if method.startswith('epsilon'):
        # Draw the exploration coin before computing the threshold.
        coin = random.random()
        if method.startswith('epsilon_exp'):
            global episodes_done_num
            eps_start, eps_end, eps_decay = 0.9, 0.1, 200
            eps_threshold = eps_end + (eps_start - eps_end) * np.exp(
                -1. * episodes_done_num / eps_decay)
        else:
            eps_threshold = .1

        if not (coin > eps_threshold or test):
            # Explore: a uniformly random action.
            return torch.tensor([[random.randrange(self.num_actions)]],
                                device='cuda' if use_cuda else 'cpu',
                                dtype=torch.long)
        # Exploit: index of the largest predicted value in each row.
        with torch.no_grad():
            best = self.online_net(state).max(1)[1].view(1, 1)
        return best.item() if test else best

    if method.startswith('boltzmann'):
        with torch.no_grad():
            weights = torch.softmax(
                self.online_net(state) / self.args.boltzmann_temperature, 1)
            sampled = Categorical(weights).sample().view(-1, 1)
        return sampled.item() if test else sampled

    if method.startswith('thompson'):
        with torch.no_grad():
            if test:
                scores = self.online_net.forward(state, dropout_rate=0, thompson=False)
            else:
                scores = self.online_net.forward(state, dropout_rate=0.3, thompson=True)
            picked = torch.softmax(scores, 1).max(1)[1].view(-1, 1)
        return picked.item() if test else picked

    raise ValueError("Unknown exploration method")
def run(self, episodes, steps, train=False, render_once=1e10, saveonce=10):
    """
    Play ``episodes`` episodes of ``steps`` environment steps each with the current policy.

    Each step samples an action from ``Categorical(self.model(state))``. The
    action is sent to the environment as a one-hot vector of length
    ``self.nactions``, and the reward is accumulated per episode.

    When ``train`` is true:
    - (reward, log_prob) pairs are stored in ``self.trainer``;
    - the trainer is updated after each episode;
    - every ``saveonce``-th episode the recorder saves and plots;
    - the model is saved when ``self.recorder.final_reward`` reaches
      ``self.current_max_reward``.

    Args:
        episodes: number of episodes to play.
        steps: environment steps per episode.
        train: record transitions and update the policy.
        render_once: while training, drawing is enabled only on every
            ``render_once``-th episode. It is always enabled when not training.
        saveonce: save/plot period in episodes (training only).
    """
    if train:
        assert self.recorder.log_message is not None, "log_message is necessary during training, Instantiate Runner with log message"

    is_recurrent = hasattr(self.model, "type") and self.model.type == "mem"
    if is_recurrent:
        print("Recurrent Model")
    self.env.display_neural_image = self.visual_activations

    for episode in range(episodes):
        self.env.reset()
        self.env.enable_draw = (not train) or episode % render_once == render_once - 1
        if is_recurrent:
            self.model.reset()

        state = self.env.get_state().reshape(-1)
        bar = tqdm(range(steps), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
        episode_reward = 0

        for step in bar:
            dist = Categorical(self.model(T.from_numpy(state).float()))
            action = dist.sample()
            log_prob = dist.log_prob(action)

            one_hot = np.zeros(self.nactions)
            one_hot[action] = 1.0
            next_state, reward = self.env.act(one_hot)
            state = next_state.reshape(-1)
            episode_reward += reward

            if train:
                self.trainer.store_records(reward, log_prob)

            if self.visual_activations:
                flat = T.cat(self.activations, dim=0).reshape(-1)
                self.env.neural_image_values = flat.detach().numpy()
                self.activations = []
                # Refresh the displayed weights on the first step of every 10th episode.
                if episode % 10 == 0 and step == 0:
                    self.update_weights()
                    self.env.neural_weights = self.weights
                    self.env.weight_change = True
                if self.model.hidden_vectors is not None:
                    self.env.hidden_state = self.model.hidden_vectors

            bar.set_description(f"Episode: {episode:4} Rewards : {episode_reward}")
            if train:
                self.env.step()
            else:
                self.env.step(speed=0)

        if train:
            self.trainer.update()
            self.trainer.clear_memory()
            self.recorder.newdata(episode_reward)
            if episode % saveonce == saveonce - 1:
                self.recorder.save()
                self.recorder.plot()
                if self.recorder.final_reward >= self.current_max_reward:
                    self.recorder.save_model(self.model)
                    self.current_max_reward = self.recorder.final_reward

    print("******* Run Complete *******")
def forward(self):
    '''
    Sample one child architecture from the ENAS controller.

    Port of the general controller sampler:
    https://github.com/melodyguan/enas/blob/master/src/cifar10/general_controller.py#L126

    For every layer the LSTM first picks a branch (operation) id. For every
    layer after the first it then makes one binary skip decision per earlier
    layer, using an attention query over the earlier layers' outputs.
    Nothing is returned; the results are stored on ``self``:

    * ``sample_arc``: ``{str(layer_id): [branch_id]}`` for layer 0 and
      ``[branch_id, skip]`` for later layers (``skip`` has ``layer_id`` entries)
    * ``sample_entropy`` / ``sample_log_prob``: summed over all decisions
    * ``skip_count``: total number of sampled skip connections
    * ``skip_penaltys``: mean over layers of the summed
      ``sigmoid(logit) * log(sigmoid(logit) / [1 - skip_target, skip_target])``
    '''
    h0 = None  # setting h0 to None will initialize LSTM state with 0s
    anchors = []
    anchors_w_1 = []

    arc_seq = {}
    entropys = []
    log_probs = []
    skip_count = []
    skip_penaltys = []

    inputs = self.g_emb.weight
    # Create the target on the controller's own device instead of a
    # hard-coded .cuda(), so the sampler also runs on CPU.
    skip_targets = torch.tensor([1.0 - self.skip_target, self.skip_target], device=inputs.device)

    for layer_id in range(self.num_layers):
        if self.search_whole_channels:
            inputs = inputs.unsqueeze(0)
            output, hn = self.w_lstm(inputs, h0)
            output = output.squeeze(0)
            h0 = hn

            logit = self.w_soft(output)
            if self.temperature is not None:
                logit /= self.temperature
            if self.tanh_constant is not None:
                logit = self.tanh_constant * torch.tanh(logit)

            branch_id_dist = Categorical(logits=logit)
            branch_id = branch_id_dist.sample()

            arc_seq[str(layer_id)] = [branch_id]

            log_prob = branch_id_dist.log_prob(branch_id)
            log_probs.append(log_prob.view(-1))
            entropy = branch_id_dist.entropy()
            entropys.append(entropy.view(-1))

            inputs = self.w_emb(branch_id)
            inputs = inputs.unsqueeze(0)
        else:
            # https://github.com/melodyguan/enas/blob/master/src/cifar10/general_controller.py#L171
            assert False, "Not implemented error: search_whole_channels = False"

        output, hn = self.w_lstm(inputs, h0)
        output = output.squeeze(0)
        # Carry the LSTM state forward, as the reference sampler does after
        # every LSTM step (prev_c, prev_h = next_c, next_h). Previously the
        # state from this skip-connection step was discarded.
        h0 = hn

        if layer_id > 0:
            query = torch.cat(anchors_w_1, dim=0)
            query = torch.tanh(query + self.w_attn_2(output))
            query = self.v_attn(query)
            logit = torch.cat([-query, query], dim=1)
            if self.temperature is not None:
                logit /= self.temperature
            if self.tanh_constant is not None:
                logit = self.tanh_constant * torch.tanh(logit)

            skip_dist = Categorical(logits=logit)
            skip = skip_dist.sample()
            skip = skip.view(layer_id)

            arc_seq[str(layer_id)].append(skip)

            skip_prob = torch.sigmoid(logit)
            kl = skip_prob * torch.log(skip_prob / skip_targets)
            kl = torch.sum(kl)
            skip_penaltys.append(kl)

            log_prob = skip_dist.log_prob(skip)
            log_prob = torch.sum(log_prob)
            log_probs.append(log_prob.view(-1))

            entropy = skip_dist.entropy()
            entropy = torch.sum(entropy)
            entropys.append(entropy.view(-1))

            # Calculate average hidden state of all nodes that got skips
            # and use it as input for next step
            skip = skip.type(torch.float)
            skip = skip.view(1, layer_id)
            skip_count.append(torch.sum(skip))
            inputs = torch.matmul(skip, torch.cat(anchors, dim=0))
            inputs /= (1.0 + torch.sum(skip))
        else:
            inputs = self.g_emb.weight

        anchors.append(output)
        anchors_w_1.append(self.w_attn_1(output))

    self.sample_arc = arc_seq

    entropys = torch.cat(entropys)
    self.sample_entropy = torch.sum(entropys)

    log_probs = torch.cat(log_probs)
    self.sample_log_prob = torch.sum(log_probs)

    if skip_count:
        skip_count = torch.stack(skip_count)
        self.skip_count = torch.sum(skip_count)

        skip_penaltys = torch.stack(skip_penaltys)
        self.skip_penaltys = torch.mean(skip_penaltys)
    else:
        # A single-layer controller makes no skip decisions, and torch.stack([])
        # would raise; report zero skips and a zero penalty instead.
        self.skip_count = skip_targets.new_zeros(())
        self.skip_penaltys = skip_targets.new_zeros(())
def train(model, iterator, optimizer, criterion, tag_pad_idx, tag_unk_idx, inside_word_idx, UD_TAGS=None, noise=0):
    """
    Run one training epoch of a sequence tagger.

    When ``noise > 0``, each non-pad gold tag is independently replaced, with
    probability ``noise``, by a tag drawn from the empirical tag frequencies in
    ``UD_TAGS.vocab.freqs`` (label-noise injection). Because ``tags`` is a view
    of ``batch.udtags``, the replacement is written into the batch tensor in place.

    Returns:
        ``(mean loss per batch, correct / labelled tokens)``, where both counts
        come from ``categorical_accuracy``.
    """
    total_loss = 0
    total_correct = 0
    total_labels = 0

    model.train()

    if noise > 0:
        vocab = UD_TAGS.vocab
        freqs = vocab.freqs
        counts = []
        for k in range(len(vocab)):
            tag = vocab.itos[k]
            counts.append(freqs[tag] if tag in freqs else 0)
        # Distribution of replacement tags, plus a per-token corruption coin.
        tag_sampler = Categorical(
            torch.tensor(counts).cuda() / float(sum(freqs.values())))
        flip = Bernoulli(probs=torch.tensor([noise]).cuda())

    for batch in tqdm(iterator):
        text = batch.text
        tags = batch.udtags

        optimizer.zero_grad()

        # text = [sent len, batch size]
        predictions = model(text)
        # predictions = [sent len, batch size, output dim]; tags = [sent len, batch size]
        predictions = predictions.view(-1, predictions.shape[-1])
        tags = tags.view(-1)

        if noise > 0:
            assert 0 <= noise <= 1
            real_tokens = tags != tag_pad_idx
            corrupt = (flip.sample(tags.shape) == 1).squeeze(1).cuda() & real_tokens
            replacements = tag_sampler.sample((torch.sum(corrupt).item(), ))
            tags[corrupt] = torch.tensor(replacements)

        # predictions = [sent len * batch size, output dim]; tags = [sent len * batch size]
        loss = criterion(predictions, tags)
        correct, n_labels = categorical_accuracy(predictions, tags, tag_pad_idx, tag_unk_idx, inside_word_idx)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_correct += correct.item()
        total_labels += n_labels

    return total_loss / len(iterator), total_correct / total_labels
def get_action(self, state):
    """Agent interface: sample an action from the model's logits for ``state``.

    The model output is detached and used as Categorical logits. The first
    element of the sampled batch is returned as a plain Python value.
    """
    obs = torch.FloatTensor(state).to(self.device)
    logits = self.model(obs).detach()
    sampled = Categorical(logits=logits).sample()
    return sampled.cpu().tolist()[0]
class OneHotCategorical(Distribution):
    r"""
    Creates a one-hot categorical distribution parameterized by :attr:`probs` or
    :attr:`logits`.

    Samples are one-hot coded vectors of size ``probs.size(-1)``.

    .. note:: :attr:`probs` will be normalized to sum to 1.

    See also: :func:`torch.distributions.Categorical` for specifications of
    :attr:`probs` and :attr:`logits`.

    Example::

        >>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
        tensor([ 0.,  0.,  0.,  1.])

    Args:
        probs (Tensor): event probabilities
        logits (Tensor): event log probabilities
        validate_args (bool, optional): forwarded to
            :class:`~torch.distributions.Distribution`.
    """
    arg_constraints = {'probs': constraints.simplex}
    support = constraints.simplex
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None, validate_args=None):
        # Parameter handling is delegated to an inner Categorical; its last
        # parameter dimension becomes the one-hot event dimension.
        self._categorical = Categorical(probs, logits)
        batch_shape = self._categorical.batch_shape
        event_shape = self._categorical.param_shape[-1:]
        super(OneHotCategorical, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    def _new(self, *args, **kwargs):
        # Kept for backward compatibility; no longer used internally.
        return self._categorical._new(*args, **kwargs)

    @property
    def probs(self):
        return self._categorical.probs

    @property
    def logits(self):
        return self._categorical.logits

    @property
    def mean(self):
        # The expectation of a one-hot sample is the probability vector itself.
        return self._categorical.probs

    @property
    def variance(self):
        # Each coordinate is a Bernoulli(p_i) indicator.
        return self._categorical.probs * (1 - self._categorical.probs)

    @property
    def param_shape(self):
        return self._categorical.param_shape

    def sample(self, sample_shape=torch.Size()):
        """Draw one-hot samples of shape ``sample_shape + batch_shape + event_shape``."""
        sample_shape = torch.Size(sample_shape)
        probs = self._categorical.probs
        # Allocate with the parameters' dtype/device instead of the legacy
        # ``Tensor.new`` constructor.
        one_hot = torch.zeros(self._extended_shape(sample_shape), dtype=probs.dtype, device=probs.device)
        indices = self._categorical.sample(sample_shape)
        if indices.dim() < one_hot.dim():
            indices = indices.unsqueeze(-1)
        return one_hot.scatter_(-1, indices, 1)

    def log_prob(self, value):
        """Log-probability of the category at the (arg)max position of ``value``."""
        if self._validate_args:
            self._validate_sample(value)
        indices = value.max(-1)[1]
        return self._categorical.log_prob(indices)

    def entropy(self):
        return self._categorical.entropy()

    def enumerate_support(self, expand=True):
        """
        Return all ``n`` one-hot vectors, shaped ``(n,) + batch_shape + (n,)``.

        Args:
            expand (bool): if False, return the un-expanded view of shape
                ``(n,) + (1,) * len(batch_shape) + (n,)``, matching the
                ``expand`` argument of
                :meth:`torch.distributions.Distribution.enumerate_support`.
        """
        n = self.event_shape[0]
        probs = self._categorical.probs
        # Build the identity directly. The former ``self._new((n, n))`` handed a
        # tuple to the legacy ``Tensor.new`` constructor, which interprets it as
        # data (a 2-element tensor). It only worked because ``torch.eye(out=...)``
        # resized that tensor.
        values = torch.eye(n, dtype=probs.dtype, device=probs.device)
        values = values.view((n,) + (1,) * len(self.batch_shape) + (n,))
        if expand:
            values = values.expand((n,) + self.batch_shape + (n,))
        return values
def decode_teacher_forcing(self, y, vl, vg):
    """
    Decode word by word with scheduled teacher forcing, then predict disease labels.

    Args:
        y: gold word ids. ``y[:, 0, j]`` is fed as the next input at step j.
            A position j counts as masked (padding) when no entry of
            ``y[..., j]`` is > 0 across dims 0 and 1. Presumably shaped
            (N, sentences, max_word); TODO confirm.
        vl: local visual features. They are put into the decoder state and pooled
            by the saliency-weighted global average pooling (SWGAP) below.
        vg: global visual features. Dropout is applied, then they initialise
            the LSTM state via ``init_h``/``init_c``.

    Returns:
        ``(words, dis)``. ``words`` stacks the per-step word outputs along dim 1
        (a uniform distribution at masked steps). ``dis`` is the joint
        prediction reshaped to ``(N, DISEASE_NUM, 2)``.
    """
    # Positions j where at least one entry of y[..., j] is a real (> 0) word id.
    mask = ((y > 0).sum(dim=(0, 1)) > 0)
    # Initialize word states
    w = vg.new_full((vg.shape[0], 1), PretrainedEmbeddings.INDEX_START, dtype=torch.long)
    vg = self.dropout(vg)
    h = self.init_h(vg)
    c = self.init_c(vg)
    states = (vl, h, c, None)
    # Process words
    words, hs, alphas = [], [], []
    for j in range(self.max_word):
        if bool(mask[j]):
            p, states = self.proc_word(w, states)
            words.append(p)
            _, h, _, alpha = states
            hs.append(h)
            if alpha is not None:
                alphas.append(alpha)
            # Teacher forcing with probability get_tfr(), otherwise feed back a sampled word.
            if self.teacher_forcing is None or self.teacher_forcing.get_tfr() >= random.random():
                w = y[:, 0, j]
            else:
                p = softmax(p, dim=-1)
                cat = Categorical(probs=p)
                w = cat.sample()
        else:
            # Masked position: skip the decoder and emit a uniform word distribution.
            p = vg.new_ones(vg.shape[0], self.embed_num) / self.embed_num
            words.append(p)
            if self.teacher_forcing is None or self.teacher_forcing.get_tfr() >= random.random():
                w = y[:, 0, j]
            else:
                # Word ids must stay integral like every other w assigned here.
                # vg is a floating-point feature tensor, so new_zeros needs an
                # explicit long dtype (as for the start token above).
                w = vg.new_zeros(vg.shape[0], dtype=torch.long)
    words = torch.stack(words, dim=1)
    # Attention Encoded Text Embedding
    # Concat hidden states NxTxd
    hs = torch.stack(hs, dim=1)
    # Projection to NxTxr rows
    g = softmax(self.aete2(tanh(self.aete1(hs))), dim=2)
    # Weighted sum over T to Nxrxd
    m = g.permute(0, 2, 1).matmul(hs)
    # AETE embedding Nxd
    x1 = m.max(dim=1)[0]
    # Saliency Weighted Global Average Pooling
    alphas = torch.stack(alphas, dim=1)
    if self.multi_image > 1:
        # Spatial attention maps NxMxT
        aws = (alphas * g.max(dim=2)[0].unsqueeze(dim=-1).unsqueeze(dim=-1)).sum(dim=1)
        # SWGAP Nx1024
        x2 = (aws.unsqueeze(dim=-1) * vl).sum(dim=2)
        x2 = x2.max(dim=1)[0]
    else:
        # Spatial attention maps NxT
        aws = (alphas * g.max(dim=2)[0].unsqueeze(dim=-1)).sum(dim=1)
        # SWGAP Nx1024
        x2 = (aws.unsqueeze(dim=-1) * vl).sum(dim=1)
    # Joint Nx14
    dis = self.joint(torch.cat([x1, x2], dim=1))
    dis = dis.view((dis.shape[0], self.DISEASE_NUM, 2))
    return words, dis