def relax_grad2(x, logits, b, surrogate, mixtureweights):
    """RELAX-style gradient of the encoder loss w.r.t. `logits`.

    Builds the RELAX surrogate objective for a given sampled component `b`
    and differentiates it through `logits` via autograd.

    Args:
        x: observed data batch (concatenated into the surrogate input).
        logits: [B, C] unnormalized log-probs of the categorical q(b|x).
        b: [B] sampled component indices (NOT resampled here — see the
            commented-out argmax line below).
        surrogate: object whose `.net` maps [z, x, logits] -> control variate c(z).
        mixtureweights: prior mixture weights p(z), indexed by `b`.

    Returns:
        (grad, pb): grad is the [B, C] gradient of the surrogate loss w.r.t.
        logits (with create_graph=True so it can be differentiated again);
        pb = q(b|x) = exp(logq).
    """
    B = logits.shape[0]
    C = logits.shape[1]
    cat = Categorical(logits=logits)
    # u = torch.rand(B,C).clamp(1e-10, 1.-1e-10).cuda()
    # myclamp presumably clamps u away from {0,1} so the double-log is finite
    # — TODO confirm against its definition.
    u = myclamp(torch.rand(B,C).cuda())
    gumbels = -torch.log(-torch.log(u))  # standard Gumbel(0,1) noise
    z = logits + gumbels  # relaxed sample; argmax(z) ~ Categorical(logits)
    # b = torch.argmax(z, dim=1) #.view(B,1)
    logq = cat.log_prob(b).view(B,1)
    # Control variate evaluated at the unconditional relaxed sample z.
    surr_input = torch.cat([z, x, logits.detach()], dim=1)
    cz = surrogate.net(surr_input)
    # z_tilde is the relaxed sample *conditioned* on the discrete outcome b.
    z_tilde = sample_relax_given_b(logits, b)
    surr_input = torch.cat([z_tilde, x, logits.detach()], dim=1)
    cz_tilde = surrogate.net(surr_input)
    logpx_given_z = logprob_undercomponent(x, component=b)
    logpz = torch.log(mixtureweights[b]).view(B,1)
    logpxz = logpx_given_z + logpz #[B,1]
    f = logpxz - logq  # REINFORCE reward (ELBO term)
    # RELAX estimator: score-function term with c(z_tilde) as baseline, plus
    # reparameterized correction (cz - cz_tilde). The extra "- logq" adds the
    # gradient of the entropy-like term.
    net_loss = - torch.mean( (f.detach() - cz_tilde.detach()) * logq - logq + cz - cz_tilde )
    grad = torch.autograd.grad([net_loss], [logits], create_graph=True, retain_graph=True)[0] #[B,C]
    pb = torch.exp(logq)
    return grad, pb
def sample_relax(logits, surrogate):
    """Sample a component via the Gumbel-max trick and evaluate the RELAX
    control variates c(z) and c(z_tilde).

    NOTE(review): reads `B`, `C` and `x` from enclosing/global scope — confirm
    they are defined at the call site.

    Returns:
        b: [B] argmax component indices.
        logprob: [B, 1] log q(b | logits).
        cz: surrogate evaluated at the unconditional relaxed sample z.
        cz_tilde: surrogate evaluated at z_tilde ~ p(z | b), averaged over
            the (currently single-iteration) inner loop.
    """
    cat = Categorical(logits=logits)
    # Clamp keeps u strictly inside (0,1) so -log(-log(u)) is finite.
    u = torch.rand(B,C).clamp(1e-10, 1.-1e-10).cuda()
    gumbels = -torch.log(-torch.log(u))
    z = logits + gumbels
    b = torch.argmax(z, dim=1) #.view(B,1)
    logprob = cat.log_prob(b).view(B,1)
    # czs = []
    # for j in range(1):
    #     z = sample_relax_z(logits)
    #     surr_input = torch.cat([z, x, logits.detach()], dim=1)
    #     cz = surrogate.net(surr_input)
    #     czs.append(cz)
    # czs = torch.stack(czs)
    # cz = torch.mean(czs, dim=0)#.view(1,1)
    surr_input = torch.cat([z, x, logits.detach()], dim=1)
    cz = surrogate.net(surr_input)
    # Average c(z_tilde) over several conditional samples (k=1 for now).
    cz_tildes = []
    for j in range(1):
        z_tilde = sample_relax_given_b(logits, b)
        surr_input = torch.cat([z_tilde, x, logits.detach()], dim=1)
        cz_tilde = surrogate.net(surr_input)
        cz_tildes.append(cz_tilde)
    cz_tildes = torch.stack(cz_tildes)
    cz_tilde = torch.mean(cz_tildes, dim=0) #.view(B,1)
    return b, logprob, cz, cz_tilde
def sample_relax(logits): #, k=1):
    """Draw (z, b, log q(b), z_tilde) for the RELAX estimator.

    z is an unconditional Gumbel-perturbed sample, b = argmax(z), and
    z_tilde is a sample of z conditioned on argmax(z) = b, built from fresh
    uniforms via the standard conditional-Gumbel construction.

    NOTE(review): reads `B` and `C` from enclosing/global scope; the
    `.repeat(B,1)` on probs suggests `logits` is expected to be [1, C] —
    TODO confirm at call sites.
    """
    # u = torch.rand(B,C).clamp(1e-8, 1.-1e-8) #.cuda()
    u = torch.rand(B,C).clamp(1e-12, 1.-1e-12) #.cuda()
    gumbels = -torch.log(-torch.log(u))
    z = logits + gumbels
    b = torch.argmax(z, dim=1)

    cat = Categorical(logits=logits)
    logprob = cat.log_prob(b).view(B,1)

    # Conditional sample for the chosen coordinate: a fresh Gumbel.
    v_k = torch.rand(B,1).clamp(1e-12, 1.-1e-12)
    z_tilde_b = -torch.log(-torch.log(v_k))
    #this way seems biased even tho it shoudlnt be
    # v_k = torch.gather(input=u, dim=1, index=b.view(B,1))
    # z_tilde_b = torch.gather(input=z, dim=1, index=b.view(B,1))

    v = torch.rand(B,C).clamp(1e-12, 1.-1e-12) #.cuda()
    probs = torch.softmax(logits,dim=1).repeat(B,1)
    # print (probs.shape, torch.log(v_k).shape, torch.log(v).shape)
    # Non-argmax coordinates: truncated Gumbels below the chosen max.
    z_tilde = -torch.log((- torch.log(v) / probs) - torch.log(v_k))
    # Overwrite the b-th coordinate with the conditional max sample.
    z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)

    return z, b, logprob, z_tilde
def sample_relax_given_class(logits, samp):
    """Return relaxed samples (z, z_tilde) tied to a *given* class `samp`,
    plus log q(samp | logits).

    NOTE(review): the block below intentionally(?) builds z_tilde twice —
    first coupling z to the conditional sample (z = z_tilde), then drawing a
    second z_tilde with fresh uniforms but the *same* u_b. Confirm this
    double construction is the intended coupling before refactoring.
    Reads `B`, `C` from enclosing/global scope.
    """
    cat = Categorical(logits=logits)

    u = torch.rand(B,C).clamp(1e-8, 1.-1e-8)
    gumbels = -torch.log(-torch.log(u))
    z = logits + gumbels

    b = samp #torch.argmax(z, dim=1)
    logprob = cat.log_prob(b).view(B,1)

    # First conditional construction: z becomes a sample conditioned on b.
    u_b = torch.gather(input=u, dim=1, index=b.view(B,1))
    z_tilde_b = -torch.log(-torch.log(u_b))
    z_tilde = -torch.log((- torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
    z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)
    z = z_tilde

    # Second construction: same u_b, fresh uniforms for the other coordinates.
    u_b = torch.gather(input=u, dim=1, index=b.view(B,1))
    z_tilde_b = -torch.log(-torch.log(u_b))
    u = torch.rand(B,C).clamp(1e-8, 1.-1e-8)
    z_tilde = -torch.log((- torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
    z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)

    return z, z_tilde, logprob
def sample_true2():
    """Draw a single sample from the true mixture of Gaussians.

    Picks a component from `true_mixture_weights` (module-level), then samples
    from N(10 * component, 5).

    Returns:
        (samp, cluster): the scalar observation and the component index.
    """
    mixture = Categorical(probs=torch.tensor(true_mixture_weights))
    component = mixture.sample()
    # Component k has mean 10*k and a fixed std of 5.
    mean = torch.tensor([component * 10.]).float()
    std = torch.tensor([5.0]).float()
    observation = Normal(mean, std).sample()
    return observation, component
def sample_gmm(batch_size, mixture_weights):
    """Sample a batch from a GMM whose component k is N(10*k, 5).

    Args:
        batch_size: number of samples to draw.
        mixture_weights: tensor of component probabilities.

    Returns:
        [batch_size, 1] tensor of samples (on GPU).
    """
    components = Categorical(probs=mixture_weights).sample([batch_size])  # [B]
    # Component k has mean 10*k; std is fixed at 5 for every component.
    means = (components * 10.).float().cuda()
    stds = torch.ones([batch_size]).cuda() * 5.
    draws = Normal(means, stds).sample()
    return draws.view(batch_size, 1)
class OneHotCategorical(Distribution):
    r"""
    Creates a one-hot categorical distribution parameterized by `probs`.

    Samples are one-hot coded vectors of size probs.size(-1).

    See also: :func:`torch.distributions.Categorical`

    Example::

        >>> m = OneHotCategorical(torch.Tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
         0
         0
         1
         0
        [torch.FloatTensor of size 4]

    Args:
        probs (Tensor or Variable): event probabilities
    """
    # Legacy (pre-0.4) torch.distributions API: `params` dict and Variable
    # support below date this class — do not port blindly to modern torch.
    params = {'probs': constraints.simplex}
    support = constraints.simplex
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None):
        # Delegate all bookkeeping to an internal Categorical.
        self._categorical = Categorical(probs, logits)
        batch_shape = self._categorical.probs.size()[:-1]
        event_shape = self._categorical.probs.size()[-1:]
        super(OneHotCategorical, self).__init__(batch_shape, event_shape)

    def sample(self, sample_shape=torch.Size()):
        """Draw one-hot samples by sampling indices and scattering 1s."""
        sample_shape = torch.Size(sample_shape)
        probs = self._categorical.probs
        one_hot = probs.new(self._extended_shape(sample_shape)).zero_()
        indices = self._categorical.sample(sample_shape)
        if indices.dim() < one_hot.dim():
            indices = indices.unsqueeze(-1)
        return one_hot.scatter_(-1, indices, 1)

    def log_prob(self, value):
        # Convert the one-hot `value` back to an index via argmax.
        indices = value.max(-1)[1]
        return self._categorical.log_prob(indices)

    def entropy(self):
        return self._categorical.entropy()

    def enumerate_support(self):
        """Return the identity matrix expanded over the batch shape: one
        one-hot vector per event."""
        probs = self._categorical.probs
        n = self.event_shape[0]
        if isinstance(probs, Variable):
            values = Variable(torch.eye(n, out=probs.data.new(n, n)))
        else:
            values = torch.eye(n, out=probs.new(n, n))
        values = values.view((n,) + (1,) * len(self.batch_shape) + (n,))
        return values.expand((n,) + self.batch_shape + (n,))
def test_gmm_loss(self):
    """ Test case 1: fit a 3-component 2-D GMM to synthetic data and check
    (by eye, via the printed params) that gmm_loss drives the model toward
    the true means/stds/weights. """
    n_samples = 10000
    # Ground-truth mixture: 3 components in 2-D.
    means = torch.Tensor([[0., 0.], [1., 1.], [-1., 1.]])
    stds = torch.Tensor([[.03, .05], [.02, .1], [.1, .03]])
    pi = torch.Tensor([.2, .3, .5])
    cat_dist = Categorical(pi)
    indices = cat_dist.sample((n_samples,)).long()
    rands = torch.randn(n_samples, 2)
    # Reparameterized draws: mean + eps * std, per sampled component.
    samples = means[indices] + rands * stds[indices]

    class _model(nn.Module):
        """Learnable GMM parameters; stds go through exp, weights through
        softmax so the raw parameters are unconstrained."""
        def __init__(self, gaussians):
            super().__init__()
            self.means = nn.Parameter(torch.Tensor(1, gaussians, 2).normal_())
            self.pre_stds = nn.Parameter(torch.Tensor(1, gaussians, 2).normal_())
            self.pi = nn.Parameter(torch.Tensor(1, gaussians).normal_())

        def forward(self, *inputs):
            return self.means, torch.exp(self.pre_stds), f.softmax(self.pi, dim=1)

    model = _model(3)
    optimizer = torch.optim.Adam(model.parameters())

    iterations = 100000
    log_step = iterations // 10
    pbar = tqdm(total=iterations)
    cum_loss = 0
    for i in range(iterations):
        # Random minibatch of 128 samples (with replacement).
        batch = samples[torch.LongTensor(128).random_(0, n_samples)]
        m, s, p = model.forward()
        loss = gmm_loss(batch, m, s, p)
        cum_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pbar.set_postfix_str("avg_loss={:10.6f}".format(
            cum_loss / (i + 1)))
        pbar.update(1)
        if i % log_step == log_step - 1:
            print(m)
            print(s)
            print(p)
def sample_true(batch_size):
    """Sample a batch from the true GMM (component k ~ N(10*k, 5)).

    Uses the module-level `true_mixture_weights` as component probabilities.

    Returns:
        [batch_size, 1] tensor of samples (CPU).
    """
    # print (true_mixture_weights.shape)
    components = Categorical(probs=torch.tensor(true_mixture_weights)).sample([batch_size])  # [B]
    means = (components * 10.).float()
    stds = torch.ones([batch_size]) * 5.
    draws = Normal(means, stds).sample()
    return draws.view(batch_size, 1)
def reinforce_baseline(surrogate, x, logits, mixtureweights, k=1, get_grad=False):
    """REINFORCE with a learned (input-dependent) baseline `surrogate.net(x)`.

    Args:
        surrogate: baseline network; `surrogate.net(x)` predicts f(x).
        x: observed batch.
        logits: [B, C] encoder logits defining q(z|x).
        mixtureweights: prior mixture weights p(z).
        k: number of Monte-Carlo samples (the loop below overwrites rather
            than averages for k > 1 — NOTE(review): confirm k=1 is the only
            intended use).
        get_grad: if True, also return per-sample gradient mean/std.

    Returns:
        dict with logq, logpx_given_z, logpz, f, net_loss, surr_loss and,
        if requested, grad_avg / grad_std.
    """
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)
    outputs = {}
    cat = Categorical(probs=probs)
    grads =[]
    # net_loss = 0
    for jj in range(k):

        cluster_H = cat.sample()
        outputs['logq'] = logq = cat.log_prob(cluster_H).view(B,1)

        outputs['logpx_given_z'] = logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        outputs['logpz'] = logpz = torch.log(mixtureweights[cluster_H]).view(B,1)
        logpxz = logpx_given_z + logpz #[B,1]

        surr_pred = surrogate.net(x)

        # The "- 1." keeps the entropy-gradient term (d/dtheta E[-logq]).
        outputs['f'] = f = logpxz - logq - 1.
        # outputs['net_loss'] = net_loss = net_loss - torch.mean((f.detach() ) * logq)
        # Score-function loss with the baseline subtracted (both detached so
        # only logq carries gradient to the encoder).
        outputs['net_loss'] = net_loss = - torch.mean((f.detach() - surr_pred.detach()) * logq)
        # net_loss += - torch.mean( -logq.detach()*logq)

        # surr_loss = torch.mean(torch.abs(f.detach() - surr_pred))
        # Baseline is trained to minimize the variance of the estimator:
        # ((f - c) * d logq / d logits)^2.
        grad_logq = torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0]
        surr_loss = torch.mean(((f.detach() - surr_pred) * grad_logq )**2)

        if get_grad:
            grad = torch.autograd.grad([net_loss], [logits], create_graph=True, retain_graph=True)[0]
            grads.append(grad)
    # net_loss = net_loss/ k

    if get_grad:
        grads = torch.stack(grads)
        # print (grads.shape)
        outputs['grad_avg'] = torch.mean(torch.mean(grads, dim=0),dim=0)
        outputs['grad_std'] = torch.std(grads, dim=0)[0]

    outputs['surr_loss'] = surr_loss
    # return net_loss, f, logpx_given_z, logpz, logq
    return outputs
def sample_relax(probs):
    """Sample (z, b, log q(b), z_tilde, gumbels) for the RELAX estimator,
    parameterized by probabilities rather than logits.

    z is an unconditional Gumbel-perturbed sample, b = argmax(z), and z_tilde
    is a relaxed sample conditioned on argmax = b.

    NOTE(review): reads `B` and `C` from enclosing/global scope and runs on
    GPU — confirm both at the call site.

    Returns:
        z: [B, C] unconditional relaxed sample.
        b: [B] argmax component indices.
        logprob: [B, 1] log q(b).
        z_tilde: [B, C] conditional relaxed sample.
        gumbels: the Gumbel noise used for z.
    """
    cat = Categorical(probs=probs)

    # Sample z: Gumbel-max over log-probs.
    u = torch.rand(B,C).cuda()
    u = u.clamp(1e-8, 1.-1e-8)  # keep -log(-log(u)) finite
    gumbels = -torch.log(-torch.log(u))
    z = torch.log(probs) + gumbels
    b = torch.argmax(z, dim=1)
    logprob = cat.log_prob(b).view(B,1)

    # Sample z_tilde conditioned on argmax(z_tilde) = b.
    u_b = torch.rand(B,1).cuda()
    u_b = u_b.clamp(1e-8, 1.-1e-8)
    z_tilde_b = -torch.log(-torch.log(u_b))  # conditional max coordinate

    u = torch.rand(B,C).cuda()
    u = u.clamp(1e-8, 1.-1e-8)
    z_tilde = -torch.log((- torch.log(u) / probs) - torch.log(u_b))
    # BUGFIX: the original `z_tilde[:, b] = z_tilde_b` used a tensor index,
    # which for B > 1 assigns the *columns* listed in b for EVERY row.
    # Use per-row scatter instead (matches the sibling sample_relax_* fns).
    z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)

    return z, b, logprob, z_tilde, gumbels
def sample_relax_given_class_k(logits, samp, k):
    """Like sample_relax_given_class, but averages (z, z_tilde) over k
    independent draws for the fixed class `samp`.

    NOTE(review): reads `B`, `C` from enclosing/global scope; mirrors the
    double z_tilde construction of sample_relax_given_class — confirm that
    coupling is intended before refactoring.
    """
    cat = Categorical(logits=logits)

    b = samp #torch.argmax(z, dim=1)
    logprob = cat.log_prob(b).view(B,1)

    zs = []
    z_tildes = []
    for i in range(k):

        u = torch.rand(B,C).clamp(1e-8, 1.-1e-8)
        gumbels = -torch.log(-torch.log(u))
        z = logits + gumbels

        # First conditional construction: z becomes a sample given b.
        u_b = torch.gather(input=u, dim=1, index=b.view(B,1))
        z_tilde_b = -torch.log(-torch.log(u_b))
        z_tilde = -torch.log((- torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
        z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)
        z = z_tilde

        # Second construction: same u_b, fresh uniforms elsewhere.
        u_b = torch.gather(input=u, dim=1, index=b.view(B,1))
        z_tilde_b = -torch.log(-torch.log(u_b))
        u = torch.rand(B,C).clamp(1e-8, 1.-1e-8)
        z_tilde = -torch.log((- torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
        z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)

        zs.append(z)
        z_tildes.append(z_tilde)

    zs= torch.stack(zs)
    z_tildes= torch.stack(z_tildes)
    z = torch.mean(zs, dim=0)
    z_tilde = torch.mean(z_tildes, dim=0)

    return z, z_tilde, logprob
def reinforce(x, logits, mixtureweights, k=1):
    """Plain REINFORCE estimator for the encoder loss, averaged over k samples.

    Args:
        x: observed batch.
        logits: [B, C] encoder logits defining q(z|x).
        mixtureweights: prior mixture weights p(z).
        k: number of Monte-Carlo samples averaged into net_loss.

    Returns:
        (net_loss, f, logpx_given_z, logpz, logq); the non-loss values are
        those of the *last* sample in the loop.
    """
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)
    cat = Categorical(probs=probs)

    net_loss = 0
    for jj in range(k):

        cluster_H = cat.sample()
        logq = cat.log_prob(cluster_H).view(B,1)

        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B,1)
        logpxz = logpx_given_z + logpz #[B,1]
        f = logpxz - logq
        # Constant baseline of 1; f is detached so only logq carries gradient.
        net_loss += - torch.mean((f.detach() - 1.) * logq)
        # net_loss += - torch.mean( -logq.detach()*logq)

    net_loss = net_loss/ k

    return net_loss, f, logpx_given_z, logpz, logq
def relax_grad(x, logits, b, surrogate, mixtureweights):
    """Per-logit surrogate training loss for RELAX.

    Unlike relax_grad2 (which returns the estimator's gradient), this builds
    the squared per-coordinate gradient of the RELAX objective — the variance
    proxy the surrogate c(.) is trained to minimize.

    Returns:
        (surr_loss, q(b)): surr_loss is the elementwise squared gradient
        term; q(b) = exp(logq).
    """
    B = logits.shape[0]
    C = logits.shape[1]
    cat = Categorical(logits=logits)

    # u = torch.rand(B,C).clamp(1e-10, 1.-1e-10).cuda()
    # myclamp presumably keeps u strictly inside (0,1) — TODO confirm.
    u = myclamp(torch.rand(B,C).cuda())
    gumbels = -torch.log(-torch.log(u))
    z = logits + gumbels
    # b = torch.argmax(z, dim=1) #.view(B,1)
    logq = cat.log_prob(b).view(B,1)

    surr_input = torch.cat([z, x, logits.detach()], dim=1)
    cz = surrogate.net(surr_input)

    z_tilde = sample_relax_given_b(logits, b)
    surr_input = torch.cat([z_tilde, x, logits.detach()], dim=1)
    cz_tilde = surrogate.net(surr_input)

    logpx_given_z = logprob_undercomponent(x, component=b)
    logpz = torch.log(mixtureweights[b]).view(B,1)
    logpxz = logpx_given_z + logpz #[B,1]
    f = logpxz - logq

    # Gradients of each surrogate term w.r.t. logits; create_graph=True so
    # surr_loss itself remains differentiable (to train the surrogate).
    grad_logq = torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0]
    grad_surr_z = torch.autograd.grad([torch.mean(cz)], [logits], create_graph=True, retain_graph=True)[0]
    grad_surr_z_tilde = torch.autograd.grad([torch.mean(cz_tilde)], [logits], create_graph=True, retain_graph=True)[0]
    # surr_loss = torch.mean(((f.detach() - cz_tilde) * grad_logq - grad_logq + grad_surr_z - grad_surr_z_tilde)**2, dim=1, keepdim=True)
    surr_loss = ((f.detach() - cz_tilde) * grad_logq - grad_logq + grad_surr_z - grad_surr_z_tilde)**2

    return surr_loss, torch.exp(logq)
def _distribution(self, obs): """Takes the observation and outputs a distribution over actions.""" logits = self.logits_net(obs) return Categorical(logits=logits)
def compute_run(self, obs):
    """Encode the observation and return (action distribution, state value).

    Returns:
        dist: Categorical over actions, built from log-softmaxed actor output.
        value: [B] critic estimate with the trailing singleton dim removed.
    """
    features = self.encode(obs)
    actor_logits = F.log_softmax(self.actor(features), dim=1)
    policy = Categorical(logits=actor_logits)
    state_value = self.critic(features).squeeze(1)
    return policy, state_value
def forward(self, audio_feature, decode_step, tf_rate=0.0, teacher=None, state_len=None):
    """Run the listener/attend/spell pipeline.

    Args:
        audio_feature: [B, T, D] acoustic features.
        decode_step: number of decoding steps to run.
        tf_rate: teacher-forcing probability per step (used only when
            `teacher` is given).
        teacher: ground-truth token ids for teacher forcing, or None.
        state_len: sequence lengths passed to the encoder.

    Returns:
        (ctc_output, encode_len, att_output, att_maps); CTC/attention parts
        are None when the corresponding head is disabled.
    """
    bs = audio_feature.shape[0]
    # Encode
    encode_feature, encode_len = self.encoder(audio_feature, state_len)
    ctc_output = None
    att_output = None
    att_maps = None

    # CTC based decoding
    if self.joint_ctc:
        ctc_output = self.ctc_layer(encode_feature)

    # Attention based decoding
    if self.joint_att:
        if teacher is not None:
            teacher = self.embed(teacher)

        # Init (init char = <SOS>, reset all rnn state and cell)
        self.decoder.init_rnn(encode_feature)
        self.attention.reset_enc_mem()
        last_char = self.embed(
            torch.zeros((bs), dtype=torch.long).to(
                next(self.decoder.parameters()).device))
        output_char_seq = []
        # BUGFIX: was `[[]] * num_head`, which aliases ONE list across all
        # heads, so every head accumulated every head's attention maps.
        output_att_seq = [[] for _ in range(self.attention.num_head)]

        # Decode
        for t in range(decode_step):
            # Attend (inputs current state of first layer, encoded features)
            attention_score, context = self.attention(
                self.decoder.state_list[0], encode_feature, encode_len)
            # Spell (inputs context + embedded last character)
            decoder_input = torch.cat([last_char, context], dim=-1)
            dec_out = self.decoder(decoder_input)
            # To char
            cur_char = self.char_trans(dec_out)

            # Teacher forcing: with prob tf_rate feed the ground truth,
            # otherwise feed a sampled (training) or argmax (inference) token.
            if (teacher is not None):
                if random.random() <= tf_rate:
                    last_char = teacher[:, t + 1, :]
                else:
                    sampled_char = Categorical(F.softmax(cur_char, dim=-1)).sample()
                    last_char = self.embed(sampled_char)
            else:
                last_char = self.embed(torch.argmax(cur_char, dim=-1))

            output_char_seq.append(cur_char)
            for head, a in enumerate(attention_score):
                output_att_seq[head].append(a.cpu())

        att_output = torch.stack(output_char_seq, dim=1)
        att_maps = [torch.stack(att, dim=1) for att in output_att_seq]

    return ctc_output, encode_len, att_output, att_maps
class OneHotCategorical(Distribution):
    r"""
    Creates a one-hot categorical distribution parameterized by :attr:`probs` or
    :attr:`logits`.

    Samples are one-hot coded vectors of size ``probs.size(-1)``.

    .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
              and it will be normalized to sum to 1 along the last dimension. attr:`probs`
              will return this normalized value.
              The `logits` argument will be interpreted as unnormalized log probabilities
              and can therefore be any real number. It will likewise be normalized so that
              the resulting probabilities sum to 1 along the last dimension. attr:`logits`
              will return this normalized value.

    See also: :func:`torch.distributions.Categorical` for specifications of
    :attr:`probs` and :attr:`logits`.

    Example::

        >>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
        tensor([ 0.,  0.,  0.,  1.])

    Args:
        probs (Tensor): event probabilities
        logits (Tensor): event log probabilities (unnormalized)
    """
    arg_constraints = {'probs': constraints.simplex,
                       'logits': constraints.real_vector}
    support = constraints.one_hot
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None, validate_args=None):
        # All distribution bookkeeping is delegated to an internal Categorical.
        self._categorical = Categorical(probs, logits)
        batch_shape = self._categorical.batch_shape
        event_shape = self._categorical.param_shape[-1:]
        super(OneHotCategorical, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        # Returns a new instance over the expanded batch shape without
        # re-validating args (validate_args=False on the re-init).
        new = self._get_checked_instance(OneHotCategorical, _instance)
        batch_shape = torch.Size(batch_shape)
        new._categorical = self._categorical.expand(batch_shape)
        super(OneHotCategorical, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        return self._categorical._new(*args, **kwargs)

    @property
    def _param(self):
        return self._categorical._param

    @property
    def probs(self):
        return self._categorical.probs

    @property
    def logits(self):
        return self._categorical.logits

    @property
    def mean(self):
        # E[one-hot] is the probability vector itself.
        return self._categorical.probs

    @property
    def variance(self):
        # Each coordinate is Bernoulli(p_i): variance p_i * (1 - p_i).
        return self._categorical.probs * (1 - self._categorical.probs)

    @property
    def param_shape(self):
        return self._categorical.param_shape

    def sample(self, sample_shape=torch.Size()):
        """Sample indices from the inner Categorical and one-hot encode them."""
        sample_shape = torch.Size(sample_shape)
        probs = self._categorical.probs
        num_events = self._categorical._num_events
        indices = self._categorical.sample(sample_shape)
        return torch.nn.functional.one_hot(indices, num_events).to(probs)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        # Recover the index from the one-hot `value` via argmax.
        indices = value.max(-1)[1]
        return self._categorical.log_prob(indices)

    def entropy(self):
        return self._categorical.entropy()

    def enumerate_support(self, expand=True):
        """Identity matrix viewed/expanded over the batch shape: one one-hot
        vector per event."""
        n = self.event_shape[0]
        values = torch.eye(n, dtype=self._param.dtype, device=self._param.device)
        values = values.view((n,) + (1,) * len(self.batch_shape) + (n,))
        if expand:
            values = values.expand((n,) + self.batch_shape + (n,))
        return values
# Training loop: per-sample REINFORCE on the encoder (script-level code;
# relies on module globals: optim, batch_size, sample_true, encoder,
# logprob_undercomponent, needsoftmax_mixtureweight).
n_steps = 100000
L2_losses = []
steps_list = []
for step in range(n_steps):

    optim.zero_grad()

    loss = 0
    net_loss = 0
    for i in range(batch_size):
        x = sample_true()
        logits = encoder.net(x)
        # NOTE: softmax over dim=0 — logits here is presumably a 1-D vector
        # per single sample (batch handled by the outer Python loop).
        cat = Categorical(probs=torch.softmax(logits, dim=0))
        cluster = cat.sample()
        logprob_cluster = cat.log_prob(cluster.detach())
        # print (logprob_cluster)
        pxz = logprob_undercomponent(
            x, component=cluster,
            needsoftmax_mixtureweight=needsoftmax_mixtureweight,
            cuda=False)
        f = pxz - logprob_cluster  # ELBO term for this sample
        # REINFORCE: f is detached so only logprob_cluster carries gradient.
        net_loss += -f.detach() * logprob_cluster
        loss += -f
    loss = loss / batch_size
    net_loss = net_loss / batch_size
#REINFORCE print ('REINFORCE') # def sample_reinforce_given_class(logits, samp): # return logprob grads = [] for i in range (N): dist = Categorical(logits=logits) samp = dist.sample() logprob = dist.log_prob(samp) reward = f(samp) gradlogprob = torch.autograd.grad(outputs=logprob, inputs=(logits), retain_graph=True)[0] grads.append(reward*gradlogprob) print () grads = torch.stack(grads).view(N,C) # print (grads.shape) grad_mean_reinforce = torch.mean(grads,dim=0) grad_std_reinforce = torch.std(grads,dim=0) print ('REINFORCE') print ('mean:', grad_mean_reinforce) print ('std:', grad_std_reinforce)
# Training loop: per-sample REINFORCE on the encoder, with backward()
# (script-level code; relies on module globals: optim, batch_size,
# sample_true, encoder, logprob_undercomponent, needsoftmax_mixtureweight).
n_steps = 100000
L2_losses = []
steps_list = []
for step in range(n_steps):

    optim.zero_grad()

    loss = 0
    net_loss = 0
    for i in range(batch_size):
        x = sample_true()
        logits = encoder.net(x)
        # NOTE: softmax over dim=0 — one sample at a time, batch handled by
        # the outer Python loop.
        cat = Categorical(probs= torch.softmax(logits, dim=0))
        cluster = cat.sample()
        logprob_cluster = cat.log_prob(cluster.detach())
        # print (logprob_cluster)
        pxz = logprob_undercomponent(x, component=cluster,
                                     needsoftmax_mixtureweight=needsoftmax_mixtureweight,
                                     cuda=False)
        f = pxz - logprob_cluster  # ELBO term for this sample
        # REINFORCE: f is detached so only logprob_cluster carries gradient.
        net_loss += -f.detach() * logprob_cluster
        loss += -f
    loss = loss / batch_size
    net_loss = net_loss / batch_size

    # print (loss, net_loss)
    # retain_graph so net_loss (sharing the graph) can also be backpropped.
    loss.backward(retain_graph=True)
def get_policy(self, obs):
    """Build the action distribution for `obs`.

    Returns a Categorical over discrete actions, or a diagonal Normal
    (mean = network output, std = exp(log_std)) for continuous actions.
    """
    net_out = self.policy(obs)
    if not self.discrete_action:
        return Normal(net_out, self.log_std.exp())
    return Categorical(logits=net_out)
def valor(args):
    """VALOR training loop with a curriculum over the context dimension.

    Trains an actor-critic policy conditioned on a sampled context, plus a
    discriminator (decoder) that recovers the context from trajectory
    differences. The context distribution starts uniform over `context_dim`
    of `max_context_dim` slots and is re-weighted (and grown/shrunk) based on
    decoder accuracy.

    Args:
        args: object with `.get(name, default)` semantics (a dict-like or a
            plain namespace — patched below) carrying all hyperparameters.
    """
    if not hasattr(args, "get"):
        # Allow plain namespaces/argparse results to be used like dicts.
        args.get = args.__dict__.get
    env_fn = args.get('env_fn', lambda: gym.make('HalfCheetah-v2'))
    actor_critic = args.get('actor_critic', ActorCritic)
    ac_kwargs = args.get('ac_kwargs', {})
    disc = args.get('disc', Discriminator)
    dc_kwargs = args.get('dc_kwargs', {})
    seed = args.get('seed', 0)
    episodes_per_epoch = args.get('episodes_per_epoch', 40)
    epochs = args.get('epochs', 50)
    gamma = args.get('gamma', 0.99)
    pi_lr = args.get('pi_lr', 3e-4)
    vf_lr = args.get('vf_lr', 1e-3)
    dc_lr = args.get('dc_lr', 2e-3)
    train_v_iters = args.get('train_v_iters', 80)
    train_dc_iters = args.get('train_dc_iters', 50)
    train_dc_interv = args.get('train_dc_interv', 2)
    lam = args.get('lam', 0.97)
    max_ep_len = args.get('max_ep_len', 1000)
    logger_kwargs = args.get('logger_kwargs', {})
    context_dim = args.get('context_dim', 4)
    max_context_dim = args.get('max_context_dim', 64)
    save_freq = args.get('save_freq', 10)
    k = args.get('k', 1)

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    # seed += 10000 * proc_id()
    torch.manual_seed(seed)
    np.random.seed(seed)

    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape
    ac_kwargs['action_space'] = env.action_space

    # Model: policy input is observation concatenated with a one-hot context.
    actor_critic = actor_critic(input_dim=obs_dim[0] + max_context_dim, **ac_kwargs)
    disc = disc(input_dim=obs_dim[0], context_dim=max_context_dim, **dc_kwargs)

    # Buffer
    local_episodes_per_epoch = episodes_per_epoch  # int(episodes_per_epoch / num_procs())
    buffer = Buffer(max_context_dim, obs_dim[0], act_dim[0],
                    local_episodes_per_epoch, max_ep_len, train_dc_interv)

    # Count variables
    var_counts = tuple(
        count_vars(module)
        for module in [actor_critic.policy, actor_critic.value_f, disc.policy])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d, \t d: %d\n' % var_counts)

    # Optimizers
    #Optimizer for RL Policy
    train_pi = torch.optim.Adam(actor_critic.policy.parameters(), lr=pi_lr)
    #Optimizer for value function (for actor-critic)
    train_v = torch.optim.Adam(actor_critic.value_f.parameters(), lr=vf_lr)
    #Optimizer for decoder
    train_dc = torch.optim.Adam(disc.policy.parameters(), lr=dc_lr)
    #pdb.set_trace()

    # Parameters Sync
    #sync_all_params(actor_critic.parameters())
    #sync_all_params(disc.parameters())
    '''
    Training function
    '''
    def update(e):
        # One policy / value / (periodically) discriminator update from the
        # buffer contents.
        obs, act, adv, pos, ret, logp_old = [
            torch.Tensor(x) for x in buffer.retrieve_all()
        ]

        # Policy
        #pdb.set_trace()
        _, logp, _ = actor_critic.policy(obs, act, batch=False)
        #pdb.set_trace()
        entropy = (-logp).mean()

        # Policy loss: advantage weighted by k plus the decoder reward `pos`.
        pi_loss = -(logp * (k * adv + pos)).mean()

        # Train policy (Go through policy update)
        train_pi.zero_grad()
        pi_loss.backward()
        # average_gradients(train_pi.param_groups)
        train_pi.step()

        # Value function
        v = actor_critic.value_f(obs)
        v_l_old = F.mse_loss(v, ret)
        for _ in range(train_v_iters):
            v = actor_critic.value_f(obs)
            v_loss = F.mse_loss(v, ret)

            # Value function train
            train_v.zero_grad()
            v_loss.backward()
            # average_gradients(train_v.param_groups)
            train_v.step()

        # Discriminator: trained only every train_dc_interv epochs.
        if (e + 1) % train_dc_interv == 0:
            print('Discriminator Update!')
            con, s_diff = [torch.Tensor(x) for x in buffer.retrieve_dc_buff()]
            _, logp_dc, _ = disc(s_diff, con)
            d_l_old = -logp_dc.mean()

            # Discriminator train
            for _ in range(train_dc_iters):
                _, logp_dc, _ = disc(s_diff, con)
                d_loss = -logp_dc.mean()
                train_dc.zero_grad()
                d_loss.backward()
                # average_gradients(train_dc.param_groups)
                train_dc.step()

            _, logp_dc, _ = disc(s_diff, con)
            dc_l_new = -logp_dc.mean()
        else:
            d_l_old = 0
            dc_l_new = 0

        # Log the changes
        _, logp, _, v = actor_critic(obs, act)
        pi_l_new = -(logp * (k * adv + pos)).mean()
        v_l_new = F.mse_loss(v, ret)
        kl = (logp_old - logp).mean()
        logger.store(LossPi=pi_loss, LossV=v_l_old, KL=kl, Entropy=entropy,
                     DeltaLossPi=(pi_l_new - pi_loss),
                     DeltaLossV=(v_l_new - v_l_old), LossDC=d_l_old,
                     DeltaLossDC=(dc_l_new - d_l_old))
        # logger.store(Adv=adv.reshape(-1).numpy().tolist(), Pos=pos.reshape(-1).numpy().tolist())

    start_time = time.time()
    #Resets observations, rewards, done boolean
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

    #Creates context distribution where each logit is equal to one (This is first place to make change)
    context_dim_prob_dict = {
        i: 1 / context_dim if i < context_dim else 0
        for i in range(max_context_dim)
    }
    last_phi_dict = {i: 0 for i in range(context_dim)}

    context_dist = Categorical(
        probs=torch.Tensor(list(context_dim_prob_dict.values())))
    total_t = 0

    for epoch in range(epochs):
        #Sets actor critic and decoder (discriminator) into eval mode
        actor_critic.eval()
        disc.eval()
        #Runs the policy local_episodes_per_epoch before updating the policy
        for index in range(local_episodes_per_epoch):
            # Sample from context distribution and one-hot encode it (Step 2)
            # Every time we run the policy we sample a new context
            c = context_dist.sample()
            c_onehot = F.one_hot(c, max_context_dim).squeeze().float()

            for _ in range(max_ep_len):
                concat_obs = torch.cat(
                    [torch.Tensor(o.reshape(1, -1)), c_onehot.reshape(1, -1)], 1)
                '''
                Feeds in observation and context into actor_critic which
                spits out a distribution
                Label is a sample from the observation
                pi is the action sampled
                logp is the log probability of some other action a
                logp_pi is the log probability of pi
                v_t is the value function
                '''
                a, _, logp_t, v_t = actor_critic(concat_obs)

                #Stores context and all other info about the state in the buffer
                buffer.store(c, concat_obs.squeeze().detach().numpy(),
                             a.detach().numpy(), r, v_t.item(),
                             logp_t.detach().numpy())
                logger.store(VVals=v_t)

                o, r, d, _ = env.step(a.detach().numpy()[0])
                ep_ret += r
                ep_len += 1
                total_t += 1

                terminal = d or (ep_len == max_ep_len)
                if terminal:
                    # Key stuff with discriminator
                    dc_diff = torch.Tensor(buffer.calc_diff()).unsqueeze(0)
                    #Context
                    con = torch.Tensor([float(c)]).unsqueeze(0)
                    #Feed in differences between each state in your trajectory and a specific context
                    #Here, this is just the log probability of the label it thinks it is
                    _, _, log_p = disc(dc_diff, con)
                    buffer.end_episode(log_p.detach().numpy())
                    logger.store(EpRet=ep_ret, EpLen=ep_len)
                    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, [actor_critic, disc], None)

        # Sets actor_critic and discriminator into training mode
        actor_critic.train()
        disc.train()

        update(epoch)

        #Need to implement curriculum learning here to update context distribution
        '''
        #Psuedocode:
        Loop through each of d episodes taken in local_episodes_per_epoch and check log probability from discrimantor
        If >= 0.86, increase k in the following manner: k = min(int(1.5*k + 1), Kmax)
        Kmax = 64
        '''
        # NOTE(review): decoder_accs is reset every epoch, so the stagnation
        # test below can never see more than one entry — confirm intent.
        decoder_accs = []
        stag_num = 10
        stag_pct = 0.05
        if (epoch + 1) % train_dc_interv == 0 and epoch > 0:
            #pdb.set_trace()
            con, s_diff = [torch.Tensor(x) for x in buffer.retrieve_dc_buff()]
            print("Context: ", con)
            print("num_contexts", len(con))
            _, logp_dc, _ = disc(s_diff, con)
            log_p_context_sample = logp_dc.mean().detach().numpy()
            print("Log Probability context sample", log_p_context_sample)
            decoder_accuracy = np.exp(log_p_context_sample)
            print("Decoder Accuracy", decoder_accuracy)
            logger.store(LogProbabilityContext=log_p_context_sample,
                         DecoderAccuracy=decoder_accuracy)
            '''
            Create score (phi(i)) = -log_p_context_sample.mean() for each specific context
            Assign phis to each specific context
            Get p(i) in the following manner: (phi(i) + epsilon)
            Get Probabilities by doing p(i)/sum of all p(i)'s
            '''
            # Per-context average decoder log-prob.
            logp_np = logp_dc.detach().numpy()
            con_np = con.detach().numpy()
            phi_dict = {i: 0 for i in range(context_dim)}
            count_dict = {i: 0 for i in range(context_dim)}
            for i in range(len(logp_np)):
                current_con = con_np[i]
                phi_dict[current_con] += logp_np[i]
                count_dict[current_con] += 1
            print(phi_dict)
            # Unseen contexts keep their previous score; seen ones get the
            # negated mean log-prob (higher = harder for the decoder).
            phi_dict = {
                k: last_phi_dict[k] if count_dict[k] == 0 else
                (-1) * v / count_dict[k]
                for (k, v) in phi_dict.items()
            }
            # Rank-based re-weighting: hardest context gets weight 1, next
            # 1/2, 1/3, ... then normalize.
            sorted_dict = dict(
                sorted(phi_dict.items(), key=lambda item: item[1], reverse=True))
            sorted_dict_keys = list(sorted_dict.keys())
            rank_dict = {
                sorted_dict_keys[i]: 1 / (i + 1)
                for i in range(len(sorted_dict_keys))
            }
            rank_dict_sum = sum(list(rank_dict.values()))
            context_dim_prob_dict = {
                k: rank_dict[k] / rank_dict_sum if k < context_dim else 0
                for k in context_dim_prob_dict.keys()
            }
            print(context_dim_prob_dict)
            decoder_accs.append(decoder_accuracy)
            stagnated = (len(decoder_accs) > stag_num and
                         (decoder_accs[-stag_num - 1] - decoder_accuracy) /
                         stag_num < stag_pct)
            # Shrink the curriculum when stagnated, grow it when the decoder
            # is accurate enough (>= 0.86).
            if stagnated:
                new_context_dim = max(int(0.75 * context_dim), 5)
            elif decoder_accuracy >= 0.86:
                new_context_dim = min(int(1.5 * context_dim + 1), max_context_dim)
            if stagnated or decoder_accuracy >= 0.86:
                print("new_context_dim: ", new_context_dim)
                new_context_prob_arr = np.array(
                    new_context_dim * [1 / new_context_dim] +
                    (max_context_dim - new_context_dim) * [0])
                context_dist = Categorical(
                    probs=ptu.from_numpy(new_context_prob_arr))
                context_dim = new_context_dim
                # Carry scores forward; brand-new contexts get the current max.
                for i in range(context_dim):
                    if i in phi_dict:
                        last_phi_dict[i] = phi_dict[i]
                    elif i not in last_phi_dict:
                        last_phi_dict[i] = max(phi_dict.values())
                buffer.clear_dc_buff()
        else:
            logger.store(LogProbabilityContext=0, DecoderAccuracy=0)

        # Log
        logger.store(ContextDim=context_dim)
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', total_t)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('LossDC', average_only=True)
        logger.log_tabular('DeltaLossDC', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.log_tabular('LogProbabilityContext', average_only=True)
        logger.log_tabular('DecoderAccuracy', average_only=True)
        logger.log_tabular('ContextDim', average_only=True)
        logger.dump_tabular()
def sample_reinforce_given_class(logits, samp):
    """Return log q(samp) under Categorical(logits) — the score term used by
    the REINFORCE estimator for a fixed class."""
    return Categorical(logits=logits).log_prob(samp)
# mylogprobgrad = torch.autograd.grad(outputs=mylogprob, inputs=(probs), retain_graph=True)[0] print('rewards', rewards) print('probs', probs) print('logits', logits) #REINFORCE print('REINFORCE') # def sample_reinforce_given_class(logits, samp): # return logprob grads = [] for i in range(N): dist = Categorical(logits=logits) samp = dist.sample() logprob = dist.log_prob(samp) reward = f(samp) gradlogprob = torch.autograd.grad(outputs=logprob, inputs=(logits), retain_graph=True)[0] grads.append(reward * gradlogprob) print() grads = torch.stack(grads).view(N, C) # print (grads.shape) grad_mean_reinforce = torch.mean(grads, dim=0) grad_std_reinforce = torch.std(grads, dim=0) print('REINFORCE')
#REINFORCE # dist = LogitRelaxedBernoulli(torch.Tensor([1.]), bern_param) # dist_bernoulli = Bernoulli(bern_param) C= 2 n_components = C B=1 probs = torch.ones(B,C) bern_param = bern_param.view(B,1) aa = 1 - bern_param probs = torch.cat([aa, bern_param], dim=1) cat = Categorical(probs= probs) grads = [] for i in range(n): b = cat.sample() logprob = cat.log_prob(b.detach()) # b_ = torch.argmax(z, dim=1) logprobgrad = torch.autograd.grad(outputs=logprob, inputs=(bern_param), retain_graph=True)[0] grad = f(b) * logprobgrad grads.append(grad[0][0].data.numpy()) print ('Grad Estimator: Reinfoce categorical') print ('Grad mean', np.mean(grads)) print ('Grad std', np.std(grads))
def decode_where(self, input_coord_logits, input_attri_logits, input_patch_vectors, sample_mode): """ Inputs: where_states containing - **coord_logits** (bsize, 1, grid_dim) - **attri_logits** (bsize, 1, scale_ratio_dim, grid_dim) - **patch_vectors** (bsize, 1, patch_feature_dim, grid_dim) sample_mode 0: top 1, 1: multinomial Outputs - **sample_inds** (bsize, 3) - **sample_vecs** (bsize, patch_feature_dim) """ ############################################################## # Sampling locations ############################################################## coord_logits = input_coord_logits.squeeze(1) if sample_mode == 0: _, sample_coord_inds = torch.max(coord_logits + 1.0, dim=-1, keepdim=True) else: sample_coord_inds = Categorical(coord_logits).sample().unsqueeze( -1) ############################################################## # Sampling attributes and patch vectors ############################################################## patch_vectors = input_patch_vectors.squeeze(1) bsize, tsize, grid_dim = patch_vectors.size() aux_pos_inds = sample_coord_inds.expand(bsize, tsize).unsqueeze(-1) sample_patch_vectors = torch.gather(patch_vectors, -1, aux_pos_inds).squeeze(-1) attri_logits = input_attri_logits.squeeze(1) bsize, tsize, grid_dim = attri_logits.size() aux_pos_inds = sample_coord_inds.expand(bsize, tsize).unsqueeze(-1) local_logits = torch.gather(attri_logits, -1, aux_pos_inds).squeeze(-1) scale_logits = local_logits[:, :self.cfg.num_scales] ratio_logits = local_logits[:, self.cfg.num_scales:] if sample_mode == 0: _, sample_scale_inds = torch.max(scale_logits + 1.0, dim=-1, keepdim=True) _, sample_ratio_inds = torch.max(ratio_logits + 1.0, dim=-1, keepdim=True) else: sample_scale_inds = Categorical(scale_logits).sample().unsqueeze( -1) sample_ratio_inds = Categorical(ratio_logits).sample().unsqueeze( -1) sample_inds = torch.cat( [sample_coord_inds, sample_scale_inds, sample_ratio_inds], -1) return sample_inds, sample_patch_vectors
def forward(self, hidden_state=None):
    """
    Performs a forward pass, unrolling the sender RNN for self.output_len steps.

    If training, use Gumbel Softmax (hard) for sampling, else use discrete
    sampling. Hidden state here represents the encoded image/metadata -
    initializes the RNN from it.

    Returns a 7-tuple:
        (messages, seq_lengths, entropy, embeds, sentence_probability,
         loss_2_3_out, message_logits)
    """
    # hidden_state = self.input_module(hidden_state)
    state, batch_size = self._init_state(hidden_state, type(self.rnn))
    # Init output
    if not (self.vqvae and not self.discrete_communication and not self.rl):
        # one-hot SOS token as the first "word" of the message
        output = [
            torch.zeros(
                (batch_size, self.vocab_size),
                dtype=torch.float32,
                device=self.device,
            )
        ]
        output[0][:, self.sos_id] = 1.0
    else:
        # In vqvae case, there is no sos symbol, since all words come from the unordered embedding table.
        # It is not possible to index code words by sos or eos symbols, since the number of codewords
        # is not necessarily the vocab size!
        output = [
            torch.zeros(
                (batch_size, self.vocab_size),
                dtype=torch.float32,
                device=self.device,
            )
        ]
    # Keep track of sequence lengths
    initial_length = self.output_len + 1  # add the sos token
    seq_lengths = (
        torch.ones([batch_size], dtype=torch.int64, device=self.device)
        * initial_length
    )
    # [initial_length, initial_length, ..., initial_length]. This gets reduced whenever it ends somewhere.
    embeds = []  # keep track of the embedded sequence
    sentence_probability = torch.zeros((batch_size, self.vocab_size), device=self.device)
    losses_2_3 = torch.empty(self.output_len, device=self.device)
    entropy = torch.empty((batch_size, self.output_len), device=self.device)
    message_logits = torch.empty((batch_size, self.output_len), device=self.device)
    if self.vqvae:
        distance_computer = EmbeddingtableDistances(self.e)
    for i in range(self.output_len):
        # embed the previous token and advance the RNN one step
        emb = torch.matmul(output[-1], self.embedding)
        embeds.append(emb)
        state = self.rnn.forward(emb, state)
        if type(self.rnn) is nn.LSTMCell:
            h, _ = state
        else:
            h = state
        indices = [None] * batch_size
        if not self.rl:
            if not self.vqvae:
                # That's the original baseline setting
                p = F.softmax(self.linear_out(h), dim=1)
                token, sentence_probability = self.calculate_token_gumbel_softmax(
                    p, self.tau, sentence_probability, batch_size)
            else:
                pre_quant = self.linear_out(h)
                if not self.discrete_communication:
                    # straight-through vector quantization
                    token = self.vq.apply(pre_quant, self.e, indices)
                else:
                    distances = distance_computer(pre_quant)
                    # softmin over distances: closer codewords get higher probability
                    softmin = F.softmax(-distances, dim=1)
                    if not self.gumbel_softmax:
                        token = self.hard_max.apply(
                            softmin, indices, self.discrete_latent_number
                        )  # This also updates the indices
                    else:
                        _, indices[:] = torch.max(softmin, dim=1)
                        token, _ = self.calculate_token_gumbel_softmax(
                            softmin, self.tau, 0, batch_size)
        else:
            # RL path: sample from a temperature-scaled categorical
            if not self.vqvae:
                all_logits = F.log_softmax(self.linear_out(h) / self.tau, dim=1)
            else:
                pre_quant = self.linear_out(h)
                distances = distance_computer(pre_quant)
                all_logits = F.log_softmax(-distances / self.tau, dim=1)
                _, indices[:] = torch.max(all_logits, dim=1)
            distr = Categorical(logits=all_logits)
            entropy[:, i] = distr.entropy()
            if self.training:
                token_index = distr.sample()
                token = to_one_hot(token_index, n_dims=self.vocab_size)
            else:
                # greedy decoding at eval time
                token_index = all_logits.argmax(dim=1)
                token = to_one_hot(token_index, n_dims=self.vocab_size)
            message_logits[:, i] = distr.log_prob(token_index)
        if not (self.vqvae and not
                self.discrete_communication and not self.rl):
            # Whenever we have a meaningful eos symbol, we prune the messages in the end
            self._calculate_seq_len(seq_lengths, token, initial_length, seq_pos=i + 1)
        if self.vqvae:
            # codebook loss (stop-grad on encoder) + commitment loss (stop-grad on codebook)
            loss_2 = torch.mean(
                torch.norm(pre_quant.detach() - self.e[indices], dim=1)**2)
            loss_3 = torch.mean(
                torch.norm(pre_quant - self.e[indices].detach(), dim=1)**2)
            loss_2_3 = (
                loss_2 + self.beta * loss_3
            )  # This corresponds to the second and third loss term in VQ-VAE
            losses_2_3[i] = loss_2_3
        token = token.to(self.device)
        output.append(token)
    messages = torch.stack(output, dim=1)
    loss_2_3_out = torch.mean(losses_2_3)
    return (
        messages,
        seq_lengths,
        entropy,
        torch.stack(embeds, dim=1),
        sentence_probability,
        loss_2_3_out,
        message_logits,
    )
def evaluate_actions_sil(pi, actions):
    """Evaluate taken actions under the policy `pi`.

    Returns a pair of (..., 1)-shaped tensors: the log-probability of each
    action and the entropy of the corresponding action distribution.
    """
    dist = Categorical(pi)
    flat_actions = actions.squeeze(-1)
    log_probs = dist.log_prob(flat_actions).unsqueeze(-1)
    entropies = dist.entropy().unsqueeze(-1)
    return log_probs, entropies
def select_actions(pi, deterministic=False):
    """Pick an action from the probability tensor `pi`.

    Deterministic mode returns the argmax as a plain Python int; otherwise a
    sampled action tensor of shape (..., 1) is returned.
    """
    dist = Categorical(pi)
    if not deterministic:
        return dist.sample().unsqueeze(-1)
    return torch.argmax(pi, dim=1).item()
def forward(self, obs, memory, instr_embedding=None):
    """
    Actor-critic forward pass over an observation batch.

    Returns a dict with the action distribution ('dist'), state value
    ('value'), updated recurrent memory ('memory') and any auxiliary-head
    outputs ('extra_predictions').
    """
    # Compute the instruction embedding if the caller did not supply one
    if self.use_desc and instr_embedding is None:
        if self.enable_instr and self.arch == "fusion":
            instr_embedding, instr_embedding2 = self._get_instr_embedding(
                obs.instr)
        else:
            instr_embedding = self._get_instr_embedding(obs.instr)
    if self.use_desc and self.lang_model == "attgru":
        # outputs: B x L x D
        # memory: B x M
        mask = (obs.instr != 0).float()
        instr_embedding = instr_embedding[:, :mask.shape[1]]
        keys = self.memory2key(memory)
        # NOTE(review): "+ 1000 * mask" strongly boosts non-padding positions
        # before the softmax rather than setting padded ones to -inf — verify
        # this matches the intended masking behavior.
        pre_softmax = (keys[:, None, :] * instr_embedding).sum(2) + 1000 * mask
        attention = F.softmax(pre_softmax, dim=1)
        instr_embedding = (instr_embedding * attention[:, :, None]).sum(1)
    # image comes in as B x H x W x C; convert to B x C x H x W
    x = torch.transpose(torch.transpose(obs.image, 1, 3), 2, 3)
    if self.arch.startswith("expert_filmcnn"):
        x = self.image_conv(x)
        for controler in self.controllers:
            x = controler(x, instr_embedding)
        x = F.relu(self.film_pool(x))
    elif self.arch == "fusion":
        # old fusion model
        '''
        x = self.image_conv(x)
        w = self.w_conv(x)
        N,_,W,H = w.shape
        w = w.view([N, self.instr_sents, -1])
        w = F.softmax(w,dim=1)
        y = torch.matmul(instr_embedding, w).view([N, 128, W, H])
        '''
        # new fusion model: separate cnns for image extractor and attention module input
        x_feat = self.image_conv(x)
        w = self.w_conv(x)
        N, _, W, H = w.shape
        w = w.view([N, self.instr_sents + 1, -1])
        w = F.softmax(w, dim=1)
        # drop the last attention slot, mix sentence embeddings into a feature map
        y = torch.matmul(instr_embedding, w[:, :-1]).view([N, 128, W, H])
        x = torch.cat([x_feat, y], axis=1)
        x = self.combined_conv(x)
        x = x.view(x.shape[0], x.shape[1], 1, 1)
        if self.enable_instr:
            for controler in self.controllers:
                x = controler(x, instr_embedding2)
            x = F.relu(x)
    else:
        x = self.image_conv(x)
    x = x.reshape(x.shape[0], -1)
    if self.use_memory:
        # memory packs LSTM (h, c) side by side; split, step, re-pack
        hidden = (memory[:, :self.semi_memory_size],
                  memory[:, self.semi_memory_size:])
        hidden = self.memory_rnn(x, hidden)
        embedding = hidden[0]
        memory = torch.cat(hidden, dim=1)
    else:
        embedding = x
    if self.use_desc and not "filmcnn" in self.arch and not "fusion" in self.arch:
        embedding = torch.cat((embedding,
                               instr_embedding), dim=1)
    if hasattr(self, 'aux_info') and self.aux_info:
        extra_predictions = {
            info: self.extra_heads[info](embedding)
            for info in self.extra_heads
        }
    else:
        extra_predictions = dict()
    x = self.actor(embedding)
    dist = Categorical(logits=F.log_softmax(x, dim=1))
    x = self.critic(embedding)
    value = x.squeeze(1)
    return {
        'dist': dist,
        'value': value,
        'memory': memory,
        'extra_predictions': extra_predictions
    }
def update_parameters(self, exps):
    """
    Run PPO updates over the collected experience `exps` for self.epochs
    epochs and return averaged training statistics as a dict.
    """
    # Collect experiences
    for _ in range(self.epochs):
        # Initialize log values
        log_entropies = []
        log_values = []
        log_policy_losses = []
        log_value_losses = []
        log_grad_norms = []
        for inds in self._get_batches_starting_indexes():
            # Initialize batch values
            batch_entropy = 0
            batch_value = 0
            batch_policy_loss = 0
            batch_value_loss = 0
            batch_loss = 0
            # Initialize memory
            if self.acmodel.recurrent:
                memory = exps.memory[inds]
            for i in range(self.recurrence):
                # Create a sub-batch of experience
                sb = exps[inds + i]
                # Compute loss
                if self.acmodel.recurrent:
                    if self.variable_view:
                        dist, gaze, value, memory = self.acmodel(
                            sb.obs, memory * sb.mask)
                        # Combine action and gaze distributions into single 1D distribution
                        # gaze_dist = gaze[0].probs.ger(gaze[1].probs)
                        # dist = Categorical(dist.probs.ger(gaze_dist.view(-1)))
                        full_dist = []
                        for j in range(inds.shape[0]):
                            # outer products build the joint (action x gaze) distribution
                            gaze_dist = gaze[0].probs[j].ger(
                                gaze[1].probs[j]).view([-1])
                            full_action_space_dist = Categorical(
                                (dist.probs[j].ger(gaze_dist)).view(-1))
                            full_dist.append(full_action_space_dist.probs)
                        dist = Categorical(torch.stack(full_dist))
                        # flatten (action, gaze_row, gaze_col) into a single joint index;
                        # 22/11 are the joint-space strides — presumably 11 gaze bins per axis,
                        # TODO confirm against the model definition
                        sb.action = sb.action.long(
                        ) * 22 + sb.gaze[:, 0] * 11 + sb.gaze[:, 1]
                    else:
                        dist, value, memory = self.acmodel(
                            sb.obs, memory * sb.mask)
                else:
                    dist, value = self.acmodel(sb.obs)
                entropy = dist.entropy().mean()
                # PPO clipped surrogate objective
                ratio = torch.exp(dist.log_prob(sb.action) - sb.log_prob)
                surr1 = ratio * sb.advantage
                surr2 = torch.clamp(ratio, 1.0 - self.clip_eps,
                                    1.0 + self.clip_eps) * sb.advantage
                policy_loss = -torch.min(surr1, surr2).mean()
                # clipped value loss
                value_clipped = sb.value + torch.clamp(
                    value - sb.value, -self.clip_eps, self.clip_eps)
                surr1 = (value - sb.returnn).pow(2)
                surr2 = (value_clipped - sb.returnn).pow(2)
                value_loss = torch.max(surr1, surr2).mean()
                loss = policy_loss - self.entropy_coef * entropy + self.value_loss_coef * value_loss
                # Update batch values
                batch_entropy += entropy.item()
                batch_value += value.mean().item()
                batch_policy_loss += policy_loss.item()
                batch_value_loss += value_loss.item()
                batch_loss += loss
                # Update memories for next epoch
                if self.acmodel.recurrent and i < self.recurrence - 1:
                    exps.memory[inds + i + 1] = memory.detach()
            # Update batch values
            batch_entropy /= self.recurrence
            batch_value /= self.recurrence
            batch_policy_loss /= self.recurrence
            batch_value_loss /= self.recurrence
            batch_loss /= self.recurrence
            # Update actor-critic
            self.optimizer.zero_grad()
            batch_loss.backward()
            # global grad norm measured before clipping
            grad_norm = sum(
                p.grad.data.norm(2).item()**2
                for p in self.acmodel.parameters())**0.5
            torch.nn.utils.clip_grad_norm_(self.acmodel.parameters(),
                                           self.max_grad_norm)
            self.optimizer.step()
            # Update log values
            log_entropies.append(batch_entropy)
            log_values.append(batch_value)
            log_policy_losses.append(batch_policy_loss)
            log_value_losses.append(batch_value_loss)
            log_grad_norms.append(grad_norm)
    # Log some values (averaged over the batches of the last epoch)
    logs = {
        "entropy": numpy.mean(log_entropies),
        "value": numpy.mean(log_values),
        "policy_loss": numpy.mean(log_policy_losses),
        "value_loss": numpy.mean(log_value_losses),
        "grad_norm": numpy.mean(log_grad_norms)
    }
    return logs
def forward(self, comm, obs, memory, instr_embedding=None):
    """
    Student actor-critic forward pass conditioned on a teacher message.

    Returns a dict with the action distribution ('dist'), value ('value'),
    updated memory ('memory_student') and auxiliary-head outputs
    ('extra_predictions').
    """
    # Encode the incoming communication message with an RNN
    message_embedding = torch.matmul(comm, self.comm_embed.weight)
    _, hidden = self.comm_encoder_rnn(message_embedding)
    message_encoded = hidden[-1]
    if self.student_obs_type == "vision":
        # Calculating instruction embedding
        if self.use_instr and instr_embedding is None:
            if self.lang_model == 'gru':
                _, hidden = self.instr_rnn(self.word_embedding(obs.instr))
                instr_embedding = hidden[-1]
        # Calculating the image embedding (B x H x W x C -> B x C x H x W)
        x = torch.transpose(torch.transpose(obs.image, 1, 3), 2, 3)
        if self.arch.startswith("expert_filmcnn"):
            image_embedding = self.image_conv(x)
            # Calculating FiLM_embedding from image and instruction embedding
            for controler in self.controllers:
                x = controler(image_embedding, instr_embedding)
            FiLM_embedding = F.relu(self.film_pool(x))
        else:
            FiLM_embedding = self.image_conv(x)
        FiLM_embedding = FiLM_embedding.reshape(FiLM_embedding.shape[0], -1)
        # Going through the memory layer
        if self.use_memory:
            hidden = (memory[:, :self.semi_memory_size],
                      memory[:, self.semi_memory_size:])
            hidden = self.memory_rnn(FiLM_embedding, hidden)
            embedding = hidden[0]
            memory = torch.cat(hidden, dim=1)
        else:
            embedding = x
        if self.use_instr and not "filmcnn" in self.arch:
            embedding = torch.cat((embedding, instr_embedding), dim=1)
        if hasattr(self, 'aux_info') and self.aux_info:
            extra_predictions = {
                info: self.extra_heads[info](embedding)
                for info in self.extra_heads
            }
        else:
            extra_predictions = dict()
    elif self.student_obs_type == "blind":
        # blind student: only the message carries information
        embedding = torch.zeros(message_embedding.shape[0],
                                self.semi_memory_size)
        extra_predictions = {}
        if torch.cuda.is_available():
            embedding = embedding.cuda()
    else:
        raise ValueError(
            "Student observation type must be either vision or blind")
    #Policy part
    policy_input = torch.cat((embedding, message_encoded), dim=1)
    policy_input = self.dropout(policy_input)
    x = self.actor(policy_input)
    dist = Categorical(logits=F.log_softmax(x, dim=1))
    x = self.critic(policy_input)
    value = x.squeeze(1)
    return {
        'dist': dist,
        'value': value,
        'memory_student': memory,
        'extra_predictions': extra_predictions
    }
def __init__(self, probs=None, logits=None, validate_args=None):
    """Build the one-hot wrapper around an underlying Categorical.

    Exactly one of `probs` / `logits` should be given; the event shape is
    the size of the category axis of the base distribution.
    """
    base = Categorical(probs, logits)
    self._categorical = base
    super(OneHotCategorical, self).__init__(
        base.batch_shape,
        base.param_shape[-1:],
        validate_args=validate_args,
    )
def sample_action(self, state):
    """Run the policy network on `state` and sample one discrete action."""
    logits = self(state)
    return Categorical(logits=logits).sample()
def get_action_and_value(self, x, action=None):
    """Return (action, log_prob, entropy, value) for observation batch `x`.

    If `action` is None a fresh action is sampled from the actor's
    distribution; otherwise the given action is evaluated.
    """
    dist = Categorical(logits=self.actor(x))
    chosen = dist.sample() if action is None else action
    return chosen, dist.log_prob(chosen), dist.entropy(), self.critic(x)
def __init__(self, temperature, probs=None, logits=None, validate_args=None):
    """Set up the relaxed categorical: store the temperature and derive the
    batch/event shapes from the underlying Categorical."""
    self.temperature = temperature
    base = Categorical(probs, logits)
    self._categorical = base
    super(ExpRelaxedCategorical, self).__init__(
        base.batch_shape,
        base.param_shape[-1:],
        validate_args=validate_args,
    )
def trainAttention(model=NaiveAttention,dimbed=50,dimout=2,maxiter=epochs,epsilon=0.01,reg=None,verbose=True):
    # The "reg" parameter selects whether to regularize with an entropy
    # criterion on the attention weights.
    # Trains one of the attention variants with SGD, checkpointing the model
    # to disk after every epoch and logging losses/accuracies to TensorBoard.
    writer = SummaryWriter("runs")
    if model == NaiveAttention:
        print("\n//////////////////// Attention-based LinNet : naive baseline /////////////////\n")
        name = "attentionbase.pch"
        etiq = "/Base_SGD"
    elif model == SimpleAttention:
        print("\n///////////////// Attention-based LinNet : basic implementation //////////////\n")
        name = "attentionclassic.pch"
        etiq = "/Classic_SGD"
    elif model == FurtherAttention:
        print("\n///////////////// Attention-based LinNet : further improvements //////////////\n")
        name = "attentionfurther.pch"
        etiq = "/Further_SGD_regul"
    elif model == LSTMAttention:
        print("\n//////////////////// Attention-based LinNet : adding an LSTM /////////////////\n")
        name = "attentionlstm.pch"
        etiq = "/LSTM_SGD"
    elif model == BILSTMAttention:
        print("\n//////////////////// Attention-based LinNet : adding an BiLSTM /////////////////\n")
        name = "attentionbilstm.pch"
        etiq = "/BiLSTM_SGD"
    # Creating a checkpointed model
    savepath = Path(name)
    if savepath.is_file():
        print("Restarting from previous state.")
        with savepath.open("rb") as fp :
            state = torch.load(fp)
    else:
        lin = model(dimbed,dimout).to(device)
        optim = torch.optim.SGD(params=lin.parameters(),lr=epsilon)
        state = State(lin,optim)
    loss = nn.CrossEntropyLoss()
    # Training the model
    for epoch in tqdm(range(state.epoch,maxiter)):
        state.model = state.model.train()
        losstrain = 0
        accytrain = 0
        divtrain = 0
        for x, y in train_loader:
            state.optim.zero_grad()
            y = y.to(device)
            if model == NaiveAttention:
                preds = state.model(x)
            else:
                preds, attns = state.model(x)
            if model != NaiveAttention:
                entropytrain = Categorical(probs = attns.squeeze(2).t()).entropy()
            # NOTE(review): with the naive model, `reg` must stay falsy —
            # otherwise `entropytrain` is undefined here.
            penalty = reg * torch.sum(entropytrain) if reg else 0
            ltrain = loss(preds,y.long()) + penalty
            ltrain.backward()
            state.optim.step()
            state.iteration += 1
            acctr = sum((preds.argmax(1) == y)).item() / y.shape[0]
            losstrain += ltrain
            accytrain += acctr
            divtrain += 1
            #if model != NaiveAttention:
            #    entropytrain = Categorical(probs = attns.squeeze(2).t()).entropy()
        state.model = state.model.eval()
        losstest = 0
        accytest = 0
        divtest = 0
        for x, y in test_loader:
            with torch.no_grad():
                y = y.to(device)
                if model == NaiveAttention :
                    preds = state.model(x)
                else :
                    preds, attns = state.model(x)
                ltest = loss(preds,y.long())
                accts = sum((preds.argmax(1) == y)).item() / y.shape[0]
                losstest += ltest
                accytest += accts
                divtest += 1
        # Saving the loss
        writer.add_scalars('Attention/Loss'+etiq,{'train':losstrain/divtrain,'test':losstest/divtest},epoch)
        writer.add_scalars('Attention/Accuracy'+etiq,{'train':accytrain/divtrain,'test':accytest/divtest},epoch)
        if model != NaiveAttention :
            # entropy histograms of the last batch's attention weights
            entropytest = Categorical(probs = attns.squeeze(2).t()).entropy()
            writer.add_histogram('Attention/EntropyTest'+etiq,entropytest,epoch)
            writer.add_histogram('Attention/EntropyTrain'+etiq,entropytrain,epoch)
        if verbose:
            print('\nLOSS: \t\ttrain',(losstrain/divtrain).item(),'\t\ttest',(losstest/divtest).item())
            print('ACCURACY: \ttrain',accytrain/divtrain,'\t\ttest',accytest/divtest)
        # Saving the current state after each epoch
        with savepath.open ("wb") as fp:
            state.epoch = epoch+1
            torch.save(state, fp)
    print("\n\n\033[1mDone.\033[0m\n")
    writer.flush()
    writer.close()
# ////////////////////////////////////////////////////////////////////////////////////////////////// </training loop> ////
def train(meta_decoder, decoder_optimizer, fclayers_for_hyper_params):
    """
    One REINFORCE step of the meta-controller: autoregressively sample a
    hyper-parameter descriptor, compute its reward against a moving-average
    baseline, and backpropagate the policy-gradient loss.
    """
    global moving_average
    global moving_average_alpha
    decoder_hidden = meta_decoder.initHidden()
    decoder_optimizer.zero_grad()
    output = torch.zeros([1, 1, meta_decoder.output_size], device=device)
    softmax = nn.Softmax(dim=1)
    softmax_outputs_stored = list()
    loss = 0
    #
    # first three decoder steps produce the shared hyper-parameter heads
    for i in range(3):
        output, decoder_hidden = meta_decoder(output, decoder_hidden)
        #print(hyper_params[i])
        softmax_outputs_stored.append(
            softmax(fclayers_for_hyper_params[hyper_params[i][0]](output)))
    #
    # the third head decides the interaction type
    output_interaction = softmax_outputs_stored[-1]
    type_of_interaction = Categorical(output_interaction).sample().tolist()[0]
    if type_of_interaction == 0:
        # PairwiseEuDist
        for i in range(3, 4):
            output, decoder_hidden = meta_decoder(output, decoder_hidden)
            softmax_outputs_stored.append(
                softmax(fclayers_for_hyper_params[hyper_params[i][0]](output)))
    elif type_of_interaction == 1:
        # PairwiseLog
        # no hyper-params for this interaction type
        pass
    else:
        # PointwiseMLPCE
        for i in range(4, 7):
            output, decoder_hidden = meta_decoder(output, decoder_hidden)
            softmax_outputs_stored.append(
                softmax(fclayers_for_hyper_params[hyper_params[i][0]](output)))
    #
    # sample a concrete index from every stored head
    resulted_str = []
    for outputs in softmax_outputs_stored:
        print("softmax_outputs: ", outputs)
        idx = Categorical(outputs).sample()
        resulted_str.append(idx.tolist()[0])
    resulted_str[
        2] = type_of_interaction  # the type of interaction has already been sampled before
    resulted_idx = resulted_str
    resulted_str = "_".join(map(str, resulted_str))
    print("resulted_str: " + resulted_str)
    #
    reward = calc_reward_given_descriptor(resulted_str)
    # -19013 is the sentinel meaning "moving average not initialized yet"
    if moving_average == -19013:
        moving_average = reward
        reward = 0.0
    else:
        tmp = reward
        reward = reward - moving_average
        moving_average = moving_average_alpha * tmp + (
            1.0 - moving_average_alpha) * moving_average
    #
    print("current reward: " + str(reward))
    print("current moving average: " + str(moving_average))
    # REINFORCE objective: sum of log-probs of the sampled choices, scaled
    # by the baseline-subtracted reward
    expectedReward = 0
    for i in range(len(softmax_outputs_stored)):
        logprob = torch.log(softmax_outputs_stored[i][0][resulted_idx[i]])
        expectedReward += logprob * reward
    loss = -expectedReward
    print('loss:', loss)
    # finally, backpropagate the loss according to the policy
    loss.backward()
    decoder_optimizer.step()
def HLAX(surrogate, surrogate2, x, logits, mixtureweights, k=1):
    """
    Hybrid LAX-style gradient estimator mixing a relaxed (pathwise) term and
    a discrete (score-function) term, blended by a learned alpha in [0,1].

    Returns (net_loss, f_b, logpx_given_z, logpz, logq_b, surr_loss,
    surr_dif, grad_path, grad_score, mean alpha).
    """
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)
    cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())
    cat_bernoulli = Categorical(probs=probs)
    net_loss = 0
    surr_loss = 0
    for jj in range(k):
        # relaxed sample and its hard (discretized) counterpart
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        logq_z = cat.log_prob(cluster_S.detach()).view(B,1)
        logq_b = cat_bernoulli.log_prob(cluster_H.detach()).view(B,1)
        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B,1)
        logpxz = logpx_given_z + logpz #[B,1]
        f_z = logpxz - logq_z - 1.
        f_b = logpxz - logq_b - 1.
        surr_input = torch.cat([cluster_S, x], dim=1) #[B,21]
        # surr_pred, alpha = surrogate.net(surr_input)
        surr_pred = surrogate.net(surr_input)
        # alpha blends the relaxed estimator against plain REINFORCE
        alpha = torch.sigmoid(surrogate2.net(x))
        net_loss += - torch.mean( alpha.detach()*(f_z.detach() - surr_pred.detach()) * logq_z + alpha.detach()*surr_pred + (1-alpha.detach())*(f_b.detach() ) * logq_b)
        # surr_loss += torch.mean(torch.abs(f_z.detach() - surr_pred))
        # per-sample gradients of log q and the surrogate w.r.t. logits,
        # averaged over the category axis
        grad_logq_z = torch.mean( torch.autograd.grad([torch.mean(logq_z)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        grad_logq_b = torch.mean( torch.autograd.grad([torch.mean(logq_b)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        grad_surr = torch.mean( torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        # print (alpha.shape, f_z.shape, surr_pred.shape, grad_logq_z.shape, grad_surr.shape)
        # fsdfa
        # grad_surr = torch.autograd.grad([surr_pred[0]], [logits], create_graph=True, retain_graph=True)[0]
        # print (grad_surr)
        # fsdfasd
        # surrogate is trained to minimize the squared magnitude of the
        # combined gradient estimator (variance-reduction objective)
        surr_loss += torch.mean( (alpha*(f_z.detach() - surr_pred) * grad_logq_z + alpha*grad_surr + (1-alpha)*(f_b.detach()) * grad_logq_b )**2 )
        surr_dif = torch.mean(torch.abs(f_z.detach() - surr_pred))
        # gradd = torch.autograd.grad([surr_loss], [alpha], create_graph=True,
        #     retain_graph=True)[0]
        # print (gradd)
        # fdsf
        # diagnostics: magnitudes of the pathwise and score-function parts
        grad_path = torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0]
        grad_score = torch.autograd.grad([torch.mean((f_z.detach() - surr_pred.detach()) * logq_z)], [logits], create_graph=True, retain_graph=True)[0]
        grad_path = torch.mean(torch.abs(grad_path))
        grad_score = torch.mean(torch.abs(grad_score))
    net_loss = net_loss/ k
    surr_loss = surr_loss/ k
    return net_loss, f_b, logpx_given_z, logpz, logq_b, surr_loss, surr_dif, grad_path, grad_score, torch.mean(alpha)
def _distribution(self, obs): logits = self.logits_net(obs) return Categorical(logits=logits)
def get_pi(self, x):
    """Return the policy distribution for a batch of channel-last images.

    The input is permuted from NHWC to NCHW and scaled from [0, 255] to
    [0, 1] before being fed through the shared network and the actor head.
    """
    scaled = x.permute((0, 3, 1, 2)) / 255.0
    features = self.network(scaled)
    return Categorical(logits=self.actor(features))
def get_action(self, x, action=None):
    """Return (action, log_prob, entropy) for input `x`.

    Samples a new action when `action` is None, otherwise evaluates the
    provided one under the current policy.
    """
    dist = Categorical(logits=self.actor(self.forward(x)))
    chosen = dist.sample() if action is None else action
    return chosen, dist.log_prob(chosen), dist.entropy()
def get_policy(self, observations: np.array):
    """Map raw observations to a categorical distribution over actions."""
    return Categorical(logits=self.policy_net(observations))
def simplax():
    """
    Train a GMM encoder with a REINFORCE-style estimator plus a learned
    surrogate (SIMPLAX experiment); periodically plots diagnostics and
    pickles the loss curves to exp_dir.
    """
    def show_surr_preds():
        # Plot surrogate predictions against the true f over a grid of x values
        batch_size = 1
        rows = 3
        cols = 1
        fig = plt.figure(figsize=(10+cols,4+rows), facecolor='white') #, dpi=150)
        for i in range(rows):
            x = sample_true(1).cuda() #.view(1,1)
            logits = encoder.net(x)
            probs = torch.softmax(logits, dim=1)
            cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())
            cluster_S = cat.rsample()
            cluster_H = H(cluster_S)
            logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
            check_nan(logprob_cluster)
            z = cluster_S
            n_evals = 40
            x1 = np.linspace(-9,205, n_evals)
            x = torch.from_numpy(x1).view(n_evals,1).float().cuda()
            z = z.repeat(n_evals,1)
            cluster_H = cluster_H.repeat(n_evals,1)
            xz = torch.cat([z,x], dim=1)
            logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
            f = logpxz #- logprob_cluster
            surr_pred = surrogate.net(xz)
            surr_pred = surr_pred.data.cpu().numpy()
            f = f.data.cpu().numpy()
            col =0
            row = i
            # print (row)
            ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)
            ax.plot(x1,surr_pred, label='Surr')
            ax.plot(x1,f, label='f')
            ax.set_title(str(cluster_H[0]))
            ax.legend()
        # save_dir = home+'/Documents/Grad_Estimators/GMM/'
        plt_path = exp_dir+'gmm_surr.png'
        plt.savefig(plt_path)
        print ('saved training plot', plt_path)
        plt.close()

    def plot_dist():
        # Plot each mixture component density and their (weighted) sum
        mixture_weights = torch.softmax(needsoftmax_mixtureweight, dim=0)
        rows = 1
        cols = 1
        fig = plt.figure(figsize=(10+cols,4+rows), facecolor='white') #, dpi=150)
        col =0
        row = 0
        ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)
        xs = np.linspace(-9,205, 300)
        sum_ = np.zeros(len(xs))
        # C = 20
        for c in range(n_components):
            m = Normal(torch.tensor([c*10.]).float(), torch.tensor([5.0]).float())
            ys = []
            for x in xs:
                # component_i = (torch.exp(m.log_prob(x) )* ((c+5.) / 290.)).numpy()
                component_i = (torch.exp(m.log_prob(x) )* mixture_weights[c]).detach().cpu().numpy()
                ys.append(component_i)
            ys = np.reshape(np.array(ys), [-1])
            sum_ += ys
            ax.plot(xs, ys, label='')
        ax.plot(xs, sum_, label='')
        # save_dir = home+'/Documents/Grad_Estimators/GMM/'
        plt_path = exp_dir+'gmm_plot_dist.png'
        plt.savefig(plt_path)
        print ('saved training plot', plt_path)
        plt.close()

    def get_loss():
        # Surrogate regression loss: |logp(x,z) - surrogate(S, x)|
        x = sample_true(batch_size).cuda() #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1)
        cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([temp]).cuda())
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        # cluster_onehot = torch.zeros(n_components)
        # cluster_onehot[cluster_H] = 1.
        logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
        check_nan(logprob_cluster)
        logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
        f = logpxz - logprob_cluster
        surr_input = torch.cat([cluster_S, x], dim=1) #[B,21]
        surr_pred = surrogate.net(surr_input)
        # net_loss = - torch.mean((f.detach()-surr_pred.detach()) * logprob_cluster + surr_pred)
        # loss = - torch.mean(f)
        surr_loss = torch.mean(torch.abs(logpxz.detach()-surr_pred))
        return surr_loss

    def plot_posteriors(needsoftmax_mixtureweight, name=''):
        # Bar plot comparing the true posterior with the encoder's q(z|x)
        x = sample_true(1).cuda()
        trueposterior = true_posterior(x, needsoftmax_mixtureweight).view(n_components)
        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1).view(n_components)
        trueposterior = trueposterior.data.cpu().numpy()
        qz = probs.data.cpu().numpy()
        error = L2_mixtureweights(trueposterior,qz)
        kl = KL_mixutreweights(p=trueposterior, q=qz)
        rows = 1
        cols = 1
        fig = plt.figure(figsize=(8+cols,8+rows), facecolor='white') #, dpi=150)
        col =0
        row = 0
        ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)
        width = .3
        ax.bar(range(len(qz)), trueposterior, width=width, label='True')
        ax.bar(np.array(range(len(qz)))+width, qz, width=width, label='q')
        # ax.bar(np.array(range(len(q_b)))+width+width, q_b, width=width)
        ax.legend()
        ax.grid(True, alpha=.3)
        ax.set_title(str(error) + ' kl:' + str(kl))
        ax.set_ylim(0.,1.)
        # save_dir = home+'/Documents/Grad_Estimators/GMM/'
        plt_path = exp_dir+'posteriors'+name+'.png'
        plt.savefig(plt_path)
        print ('saved training plot', plt_path)
        plt.close()

    def inference_error(needsoftmax_mixtureweight):
        # Average L2 and KL divergence between true posterior and q over n draws
        error_sum = 0
        kl_sum = 0
        n=10
        for i in range(n):
            # if x is None:
            x = sample_true(1).cuda()
            trueposterior = true_posterior(x, needsoftmax_mixtureweight).view(n_components)
            logits = encoder.net(x)
            probs = torch.softmax(logits, dim=1).view(n_components)
            error = L2_mixtureweights(trueposterior.data.cpu().numpy(),probs.data.cpu().numpy())
            kl = KL_mixutreweights(trueposterior.data.cpu().numpy(), probs.data.cpu().numpy())
            error_sum+=error
            kl_sum += kl
        return error_sum/n, kl_sum/n
        # fsdfa

    #SIMPLAX
    needsoftmax_mixtureweight = torch.randn(n_components, requires_grad=True, device="cuda")#.cuda()
    print ('current mixuture weights')
    print (torch.softmax(needsoftmax_mixtureweight, dim=0))
    print()
    encoder = NN3(input_size=1, output_size=n_components).cuda()
    surrogate = NN3(input_size=1+n_components, output_size=1).cuda()
    # optim = torch.optim.Adam([needsoftmax_mixtureweight], lr=.00004)
    # optim_net = torch.optim.Adam(encoder.parameters(), lr=.0004)
    # optim_surr = torch.optim.Adam(surrogate.parameters(), lr=.004)
    # optim = torch.optim.Adam([needsoftmax_mixtureweight], lr=.0001)
    # optim_net = torch.optim.Adam(encoder.parameters(), lr=.0001)
    optim_net = torch.optim.SGD(encoder.parameters(), lr=.0001)
    # optim_surr = torch.optim.Adam(surrogate.parameters(), lr=.005)
    temp = 1.
    batch_size = 100
    n_steps = 300000
    surrugate_steps = 0
    k = 1
    # histories for the diagnostic plots
    L2_losses = []
    inf_losses = []
    inf_losses_kl = []
    kl_losses_2 = []
    surr_losses = []
    steps_list =[]
    grad_reparam_list =[]
    grad_reinforce_list =[]
    f_list = []
    logpxz_list = []
    logprob_cluster_list = []
    logpx_list = []
    # logprob_cluster_list = []
    for step in range(n_steps):
        # optional pre-training steps for the surrogate (currently 0)
        for ii in range(surrugate_steps):
            surr_loss = get_loss()
            optim_surr.zero_grad()
            surr_loss.backward()
            optim_surr.step()
        x = sample_true(batch_size).cuda() #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1)
        # print (probs)
        # fsdafsa
        # cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([temp]).cuda())
        cat = Categorical(probs=probs)
        # cluster = cat.sample()
        # logprob_cluster = cat.log_prob(cluster.detach())
        net_loss = 0
        loss = 0
        surr_loss = 0
        for jj in range(k):
            # cluster_S = cat.rsample()
            # cluster_H = H(cluster_S)
            cluster_H = cat.sample()
            # print (cluster_H.shape)
            # print (cluster_H[0])
            # fsad
            # cluster_onehot = torch.zeros(n_components)
            # cluster_onehot[cluster_H] = 1.
            # logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
            logprob_cluster = cat.log_prob(cluster_H.detach()).view(batch_size,1)
            # logprob_cluster = cat.log_prob(cluster_S).view(batch_size,1).detach()
            # cat = RelaxedOneHotCategorical(probs=probs.detach(), temperature=torch.tensor([temp]).cuda())
            # logprob_cluster = cat.log_prob(cluster_S).view(batch_size,1)
            check_nan(logprob_cluster)
            logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
            f = logpxz - logprob_cluster
            # print (f)
            # surr_input = torch.cat([cluster_S, x], dim=1) #[B,21]
            surr_input = torch.cat([probs, x], dim=1) #[B,21]
            surr_pred = surrogate.net(surr_input)
            # print (f.shape)
            # print (surr_pred.shape)
            # print (logprob_cluster.shape)
            # fsadfsa
            # net_loss += - torch.mean((logpxz.detach()-surr_pred.detach()) * logprob_cluster + surr_pred - logprob_cluster)
            # net_loss += - torch.mean((logpxz.detach() - surr_pred.detach() - 1.) * logprob_cluster + surr_pred)
            # REINFORCE with a constant baseline of 1 (surrogate terms disabled above)
            net_loss += - torch.mean((logpxz.detach() - 1.) * logprob_cluster)
            # net_loss += - torch.mean((logpxz.detach()) * logprob_cluster - logprob_cluster)
            loss += - torch.mean(logpxz)
            surr_loss += torch.mean(torch.abs(logpxz.detach()-surr_pred))
        net_loss = net_loss/ k
        loss = loss / k
        surr_loss = surr_loss/ k
        # if step %2==0:
        # optim.zero_grad()
        # loss.backward(retain_graph=True)
        # optim.step()
        optim_net.zero_grad()
        net_loss.backward(retain_graph=True)
        optim_net.step()
        # optim_surr.zero_grad()
        # surr_loss.backward(retain_graph=True)
        # optim_surr.step()
        # print (torch.mean(f).cpu().data.numpy())
        # plot_posteriors(name=str(step))
        # fsdf
        # kl_batch = compute_kl_batch(x,probs)
        if step%500==0:
            print (step, 'f:', torch.mean(f).cpu().data.numpy(), 'surr_loss:', surr_loss.cpu().data.detach().numpy(), 'theta dif:', L2_mixtureweights(true_mixture_weights,torch.softmax( needsoftmax_mixtureweight, dim=0).cpu().data.detach().numpy()))
            # if step %5000==0:
            #     print (torch.softmax(needsoftmax_mixtureweight, dim=0).cpu().data.detach().numpy())
            #     # test_samp, test_cluster = sample_true2()
            #     # print (test_cluster.cpu().data.numpy(), test_samp.cpu().data.numpy(), torch.softmax(encoder.net(test_samp.cuda().view(1,1)), dim=1))
            #     print ()
            if step > 0:
                # record diagnostics every 500 steps
                L2_losses.append(L2_mixtureweights(true_mixture_weights,torch.softmax( needsoftmax_mixtureweight, dim=0).cpu().data.detach().numpy()))
                steps_list.append(step)
                surr_losses.append(surr_loss.cpu().data.detach().numpy())
                inf_error, kl_error = inference_error(needsoftmax_mixtureweight)
                inf_losses.append(inf_error)
                inf_losses_kl.append(kl_error)
                kl_batch = compute_kl_batch(x,probs,needsoftmax_mixtureweight)
                kl_losses_2.append(kl_batch)
                logpx = copmute_logpx(x, needsoftmax_mixtureweight)
                logpx_list.append(logpx)
                f_list.append(torch.mean(f).cpu().data.detach().numpy())
                logpxz_list.append(torch.mean(logpxz).cpu().data.detach().numpy())
                logprob_cluster_list.append(torch.mean(logprob_cluster).cpu().data.detach().numpy())
                # i_feel_like_it = 1
                # if i_feel_like_it:
                if len(inf_losses) > 0:
                    # compare REINFORCE vs reparam-style gradient magnitudes w.r.t. probs
                    print ('probs', probs[0])
                    print('logpxz', logpxz[0])
                    print('pred', surr_pred[0])
                    print ('dif', logpxz.detach()[0]-surr_pred.detach()[0])
                    print ('logq', logprob_cluster[0])
                    print ('dif*logq', (logpxz.detach()[0]-surr_pred.detach()[0])*logprob_cluster[0])
                    output= torch.mean((logpxz.detach()-surr_pred.detach()) * logprob_cluster, dim=0)[0]
                    output2 = torch.mean(surr_pred, dim=0)[0]
                    output3 = torch.mean(logprob_cluster, dim=0)[0]
                    # input_ = torch.mean(probs, dim=0) #[0]
                    # print (probs.shape)
                    # print (output.shape)
                    # print (input_.shape)
                    grad_reinforce = torch.autograd.grad(outputs=output, inputs=(probs), retain_graph=True)[0]
                    grad_reparam = torch.autograd.grad(outputs=output2, inputs=(probs), retain_graph=True)[0]
                    grad3 = torch.autograd.grad(outputs=output3, inputs=(probs), retain_graph=True)[0]
                    # print (grad)
                    # print (grad_reinforce.shape)
                    # print (grad_reparam.shape)
                    grad_reinforce = torch.mean(torch.abs(grad_reinforce))
                    grad_reparam = torch.mean(torch.abs(grad_reparam))
                    grad3 = torch.mean(torch.abs(grad3))
                    # print (grad_reinforce)
                    # print (grad_reparam)
                    # dfsfda
                    grad_reparam_list.append(grad_reparam.cpu().data.detach().numpy())
                    grad_reinforce_list.append(grad_reinforce.cpu().data.detach().numpy())
                    # grad_reinforce_list.append(grad_reinforce.cpu().data.detach().numpy())
                    print ('reparam:', grad_reparam.cpu().data.detach().numpy())
                    print ('reinforce:', grad_reinforce.cpu().data.detach().numpy())
                    print ('logqz grad:', grad3.cpu().data.detach().numpy())
                    print ('current mixuture weights')
                    print (torch.softmax(needsoftmax_mixtureweight, dim=0))
                    print()
                    # print ()
                else:
                    grad_reparam_list.append(0.)
                    grad_reinforce_list.append(0.)
            if len(surr_losses) > 3 and step %1000==0:
                plot_curve(steps=steps_list, thetaloss=L2_losses, infloss=inf_losses, surrloss=surr_losses, grad_reinforce_list=grad_reinforce_list, grad_reparam_list=grad_reparam_list, f_list=f_list, logpxz_list=logpxz_list, logprob_cluster_list=logprob_cluster_list, inf_losses_kl=inf_losses_kl, kl_losses_2=kl_losses_2, logpx_list=logpx_list)
                plot_posteriors(needsoftmax_mixtureweight)
                plot_dist()
                show_surr_preds()
                # print (f)
                # print (surr_pred)
                #Understand surr preds
            # if step %5000==0:
            # if step ==0:
            # fasdf
    data_dict = {}
    data_dict['steps'] = steps_list
    data_dict['losses'] = L2_losses
    # save_dir = home+'/Documents/Grad_Estimators/GMM/'
    with open( exp_dir+"data_simplax.p", "wb" ) as f:
        pickle.dump(data_dict, f)
    print ('saved data')
def get_policy(obs):
    """Return the action distribution for an observation batch.

    Runs the module-level policy network on ``obs`` and wraps the raw
    scores in a ``Categorical`` distribution.
    """
    action_logits = logits_net(obs)
    return Categorical(logits=action_logits)
def __init__(self, probs=None, logits=None):
    """Initialize by delegating parameter handling to ``Categorical``.

    The trailing dimension of the normalized probability tensor is the
    event dimension; every leading dimension forms the batch shape.
    """
    self._categorical = Categorical(probs, logits)
    # Split the parameter shape once instead of calling .size() twice.
    param_size = self._categorical.probs.size()
    super(OneHotCategorical, self).__init__(param_size[:-1], param_size[-1:])
def sample_candidates(model, n_samples: int, encoder_output,
                      masks: Dict[str, Tensor], max_output_length: int,
                      labels: dict = None):
    """
    Sample n_samples sequences from the model

    In each decoding step, find the k most likely partial hypotheses.

    :param decoder:
    :param size: size of the beam
    :param encoder_output:
    :param masks:
    :param max_output_length:
    :return:
        - stacked_output: dim?,
        - scores: dim?
    """
    # init
    transformer = model.is_transformer
    any_mask = next(iter(masks.values()))
    batch_size = any_mask.size(0)
    att_vectors = None  # not used for Transformer
    device = encoder_output.device

    masks.pop("trg", None)  # mutating one of the inputs is not good

    # Recurrent models only: initialize RNN hidden state
    if not transformer and model.decoder.bridge_layer is not None:
        hidden = model.decoder.bridge_layer(encoder_output.hidden)
    else:
        hidden = None

    # tile encoder states and decoder initial states beam_size times
    if hidden is not None:
        # layers x batch*k x dec_hidden_size
        hidden = tile(hidden, n_samples, dim=1)

    # encoder_output: batch*k x src_len x enc_hidden_size
    # NOTE(review): the return value is discarded here, so this only works
    # if EncoderOutput.tile mutates in place — confirm against its definition.
    encoder_output.tile(n_samples, dim=0)
    masks = {k: tile(v, n_samples, dim=0) for k, v in masks.items()}

    # Transformer only: create target mask
    masks["trg"] = any_mask.new_ones([1, 1, 1]) if transformer else None

    # the structure holding all batch_size * k partial hypotheses
    alive_seq = torch.full((batch_size * n_samples, 1), model.bos_index,
                           dtype=torch.long, device=device)
    # the structure indicating, for each hypothesis, whether it has
    # encountered eos yet (if it has, stop updating the hypothesis
    # likelihood)
    is_finished = torch.zeros(batch_size * n_samples, dtype=torch.bool,
                              device=device)
    # for each (batch x n_samples) sequence, there is a log probability
    seq_probs = torch.zeros(batch_size * n_samples, device=device)

    for step in range(1, max_output_length + 1):
        # Transformers re-consume the whole prefix; RNNs only the last token.
        dec_input = alive_seq if transformer \
            else alive_seq[:, -1].view(-1, 1)

        # decode a step
        probs, hidden, att_scores, att_vectors = model.decode(
            trg_input=dec_input,
            encoder_output=encoder_output,
            masks=masks,
            decoder_hidden=hidden,
            prev_att_vector=att_vectors,
            unroll_steps=1,
            labels=labels,
            generate="true")
        # batch*k x trg_vocab
        # probs = model.decoder.gen_func(logits[:, -1], dim=-1).squeeze(1)

        # Ancestral sampling: draw one token per hypothesis from the
        # per-step distribution, then accumulate its log-probability —
        # but only for hypotheses that have not yet emitted eos.
        next_ids = Categorical(probs).sample().unsqueeze(1)  # batch*k x 1
        next_scores = probs.gather(1, next_ids).squeeze(1)  # batch*k
        seq_probs = torch.where(is_finished, seq_probs,
                                seq_probs + next_scores.log())

        # append latest prediction
        # batch_size*k x hyp_len
        alive_seq = torch.cat([alive_seq, next_ids], -1)

        # update which hypotheses are finished
        is_finished = is_finished | next_ids.eq(model.eos_index).squeeze(1)

        if is_finished.all():
            break

    # final_outputs: batch x n_samples x len
    final_outputs = alive_seq.view(batch_size, n_samples, -1)
    seq_probs = seq_probs.view(batch_size, n_samples)

    return final_outputs, seq_probs
def loss(actions_logits, action, discounted_reward):
    """REINFORCE policy-gradient loss: -log pi(action) scaled by the return."""
    dist = Categorical(logits=actions_logits)
    log_p = dist.log_prob(action)
    return (-log_p * discounted_reward).view(-1)
    # - (distribution.entropy() * (10 ** -2))  # TODO: this hyperparameter should be set differently
class ExpRelaxedCategorical(Distribution):
    r"""
    Creates a ExpRelaxedCategorical parameterized by `probs` and `temperature`.
    Returns the log of a point in the simplex.
    Based on the interface to OneHotCategorical.

    Implementation based on [1].

    See also: :func:`torch.distributions.OneHotCategorical`

    Args:
        temperature (Tensor): relaxation temperature
        probs (Tensor): event probabilities
        logits (Tensor): the log probability of each event.

    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random
    Variables (Maddison et al, 2017)

    [2] Categorical Reparametrization with Gumbel-Softmax
    (Jang et al, 2017)
    """
    arg_constraints = {'probs': constraints.simplex}
    # Samples are logs of simplex points, hence unconstrained reals.
    support = constraints.real
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        # Delegate probs/logits normalization and validation to Categorical.
        self._categorical = Categorical(probs, logits)
        self.temperature = temperature
        batch_shape = self._categorical.batch_shape
        event_shape = self._categorical.param_shape[-1:]
        super(ExpRelaxedCategorical, self).__init__(batch_shape, event_shape,
                                                    validate_args=validate_args)

    def _new(self, *args, **kwargs):
        return self._categorical._new(*args, **kwargs)

    @property
    def param_shape(self):
        return self._categorical.param_shape

    @property
    def logits(self):
        return self._categorical.logits

    @property
    def probs(self):
        return self._categorical.probs

    def rsample(self, sample_shape=torch.Size()):
        # Reparameterized sample via the Gumbel-softmax trick:
        # perturb logits with Gumbel noise, divide by temperature,
        # and return the log-softmax (a log-simplex point).
        sample_shape = torch.Size(sample_shape)
        # clamp_probs keeps the uniforms away from 0/1 so the double log is finite.
        uniforms = clamp_probs(self.logits.new(self._extended_shape(sample_shape)).uniform_())
        gumbels = -((-(uniforms.log())).log())
        scores = (self.logits + gumbels) / self.temperature
        return scores - _log_sum_exp(scores)

    def log_prob(self, value):
        K = self._categorical._num_events
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        # log normalizer of the ExpConcrete density:
        # log Gamma(K) + (K - 1) * log(temperature).
        log_scale = (self.temperature.new(self.temperature.shape).fill_(K).lgamma() -
                     self.temperature.log().mul(-(K - 1)))
        score = logits - value.mul(self.temperature)
        score = (score - _log_sum_exp(score)).sum(-1)
        return score + log_scale
class DIPTransform:
    """Image transform that partially overfits a Deep Image Prior network
    to an input image and returns the network's reconstruction after a
    randomly sampled number of optimization iterations."""

    def __init__(self, image_shape, num_iters, stop_iters, input_depth=32,
                 input_noise_std=0, optimizer='adam', lr=1e-2, plot_every=0,
                 device='cpu'):
        """
        :param image_shape: image shape (CxHxW)
        :param num_iters: number of iterations to overfit
        :param stop_iters: dict int->float specifying categorical distribution over
            percentages in (0, 100] representing the probability that the transform
            runs for KEY/100 * NUM_ITERS iters.
        :param input_depth: depth of random input. Default value taken from paper.
        :param input_noise_std: stdev of noise added to base random input at each
            iter. They say in the paper that this helps sometimes (seemingly using
            0.03), but default is 0 (no noise).
        :param optimizer: supported optimizers are 'adam' and 'LBFGS'
        :param lr: learning rate. Per paper and code, 1e-2 works best.
        :param plot_every: how often to save images. Doesn't save images if set
            to 0 (default).
        :device: 'cuda' to use GPU if available.

        For example, with NUM_ITERS = 100 and STOP_ITERS = {10: 0.5, 50: 0.3,
        100: 0.2}, there is a 50% chance that the transform runs for 10
        iterations, a 30% chance it runs for 50 iterations, and a 20% chance it
        runs for the full 100 iterations.
        """
        # self.iters holds the percentage keys; self.probs samples an index into it.
        self.iters, probs = zip(*stop_iters.items())
        self.probs = Categorical(torch.tensor(probs))
        self.num_iters = num_iters
        self.image_shape = image_shape
        self.input_depth = input_depth
        self.input_noise_std = input_noise_std
        self.plot_every = plot_every
        self.device = device
        self.optimizer = optimizer
        self.lr = lr
        self.loss = torch.nn.MSELoss()
        # const strings from provided DIP code
        self.opt_over = 'net'
        self.const_input = 'noise'
        # initialize network (encoder-decoder with skip connections, per DIP paper)
        self.net = skip(input_depth, 3,
                        num_channels_down=[8, 16, 32],
                        num_channels_up=[8, 16, 32],
                        num_channels_skip=[0, 0, 4],
                        upsample_mode='bilinear',
                        need_sigmoid=True, need_bias=True,
                        pad='zeros', act_fun='LeakyReLU').to(self.device)

    def sample_iters(self):
        # Draw a percentage from the categorical and convert it to an
        # absolute iteration count (integer division by 100).
        return self.iters[self.probs.sample()] * self.num_iters // 100

    def run(self, image, num_iters):
        """Overfit the net to `image` for `num_iters` steps; return the output."""
        assert image.shape[
            1:] == self.image_shape, 'Wrong shape. Expected {}, got {}.'.format(
                self.image_shape, image.shape[1:])
        # run net for num_iters iterations
        net_input = get_noise(self.input_depth, self.const_input,
                              self.image_shape[1:]).to(self.device)
        net_input_saved = net_input.detach().clone()
        noise = net_input.detach().clone()
        p = get_params(self.opt_over, self.net, net_input)
        #def local_closure(iter_num):
        #    self.closure(net_input_saved, image, noise, iter_num)
        lambda_closure = lambda iter_num: self.closure(net_input_saved, image,
                                                       noise, iter_num)
        optimize(self.optimizer, p, lambda_closure, self.lr, num_iters,
                 pass_iter=True)
        transformed = self.net(net_input)
        return transformed

    def closure(self, net_input_saved, image, noise, iter_num):
        """One optimization step: forward, MSE loss against `image`, backward."""
        net_input = net_input_saved
        if self.input_noise_std > 0:
            # Perturb the fixed base input each step (noise.normal_() mutates in place).
            net_input = net_input_saved + (noise.normal_() * self.input_noise_std)
        out = self.net(net_input)
        total_loss = self.loss(out, image)
        total_loss.backward()
        #print(total_loss)
        # NOTE(review): plotting is additionally gated on total_loss < 0.01, so
        # early (high-loss) iterations are never saved — confirm this is intended.
        if self.plot_every > 0 and iter_num % self.plot_every == 0 and total_loss < 0.01:
            out_np = torch_to_np(out)
            plot_image_grid([np.clip(out_np, 0, 1)], factor=4, nrow=1, show=False,
                            save_path=f'results_dip/imgs/{iter_num}.png')
        # maybe log loss here?

    def __call__(self, sample):
        """
        Takes in, transforms, and returns PIL image given by SAMPLE.
        Transformation is a random number of iterations of DIP. Distribution of
        number of iterations is specified when the transform is initialized.
        """
        torch_img = np_to_torch(tmp := pil_to_np(sample)).to(self.device)
        plot_image_grid([np.clip(tmp, 0, 1)], factor=4, nrow=1, show=False,
                        save_path='results_dip/imgs/true.png')  # TODO remove
        num_iters = self.sample_iters()
        transformed = self.run(torch_img, num_iters)
        return np_to_pil(torch_to_np(transformed))
class OneHotCategorical(Distribution):
    r"""
    Creates a one-hot categorical distribution parameterized by :attr:`probs` or
    :attr:`logits`.

    Samples are one-hot coded vectors of size ``probs.size(-1)``.

    .. note:: :attr:`probs` will be normalized to be summing to 1.

    See also: :func:`torch.distributions.Categorical` for specifications of
    :attr:`probs` and :attr:`logits`.

    Example::

        >>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
        tensor([ 0.,  0.,  0.,  1.])

    Args:
        probs (Tensor): event probabilities
        logits (Tensor): event log probabilities
    """
    arg_constraints = {'probs': constraints.simplex}
    support = constraints.simplex
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None, validate_args=None):
        # Delegate probs/logits normalization and validation to Categorical.
        self._categorical = Categorical(probs, logits)
        batch_shape = self._categorical.batch_shape
        event_shape = self._categorical.param_shape[-1:]
        super(OneHotCategorical, self).__init__(batch_shape, event_shape,
                                                validate_args=validate_args)

    def _new(self, *args, **kwargs):
        return self._categorical._new(*args, **kwargs)

    @property
    def probs(self):
        return self._categorical.probs

    @property
    def logits(self):
        return self._categorical.logits

    @property
    def mean(self):
        return self._categorical.probs

    @property
    def variance(self):
        # Per-coordinate Bernoulli variance p * (1 - p).
        return self._categorical.probs * (1 - self._categorical.probs)

    @property
    def param_shape(self):
        return self._categorical.param_shape

    def sample(self, sample_shape=torch.Size()):
        # Draw integer indices from the underlying Categorical, then
        # scatter 1s into a zero tensor to one-hot encode them.
        sample_shape = torch.Size(sample_shape)
        probs = self._categorical.probs
        one_hot = probs.new(self._extended_shape(sample_shape)).zero_()
        indices = self._categorical.sample(sample_shape)
        if indices.dim() < one_hot.dim():
            indices = indices.unsqueeze(-1)
        return one_hot.scatter_(-1, indices, 1)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        # Recover the index of the hot coordinate, then score it
        # under the underlying Categorical.
        indices = value.max(-1)[1]
        return self._categorical.log_prob(indices)

    def entropy(self):
        return self._categorical.entropy()

    def enumerate_support(self):
        # All n one-hot vectors (identity matrix rows), broadcast over
        # the batch shape.
        n = self.event_shape[0]
        values = self._new((n, n))
        torch.eye(n, out=values)
        values = values.view((n,) + (1,) * len(self.batch_shape) + (n,))
        return values.expand((n,) + self.batch_shape + (n,))
def get_pi_value_and_aux_value(self, x):
    """Return the policy distribution, value estimate, and auxiliary value.

    ``x`` is assumed to be a channels-last image batch in [0, 255]
    (TODO confirm against callers); it is converted to channels-first
    floats in [0, 1] before the shared trunk.
    """
    hidden = self.network(x.permute((0, 3, 1, 2)) / 255.0)
    pi = Categorical(logits=self.actor(hidden))
    # The critic sees a detached trunk output, so the value loss does
    # not backpropagate into the shared network; the aux critic does.
    value = self.critic(hidden.detach())
    aux_value = self.aux_critic(hidden)
    return pi, value, aux_value