def sample(self, features, distances, structure):
    """Autoregressively sample a discrete sequence conditioned on structure.

    Args:
        features: per-position input features for the encoder.
        distances: pairwise distance data fed to the relative message step.
        structure: structure object describing pairwise connectivity.

    Returns:
        LongTensor of sampled class indices, one per encoded position.
    """
    # encode: build orientation-aware pairwise features once; they are
    # reused unchanged by every decoding step below.
    distance_data = self.relative(structure, self.rbf)
    relative_data = distance_data.message(distances, distances)
    relative_structure = OrientationStructure(structure, relative_data)
    encoding = self.encoder(features, relative_structure)

    # sampling: start from an all-zero index sequence and fill in one
    # position per step, re-decoding with everything sampled so far.
    sampled = torch.zeros(encoding.size(0), dtype=torch.long, device=features.device)
    sequence = self.prepare_sequence(sampled)
    masked_structure = MaskedStructure(structure, relative_data, sequence, encoding)
    result = self.decoder(encoding, masked_structure)
    hard = hard_one_hot(result)
    sampled[0] = hard[0].argmax(dim=0)
    for idx in range(1, len(sampled)):
        sequence = self.prepare_sequence(sampled)
        masked_structure = MaskedStructure(structure, relative_data, sequence, encoding)
        result = self.decoder(encoding, masked_structure)
        hard = hard_one_hot(result)
        # NOTE(review): the original also re-embedded `hard` into `sequence`
        # here, but that value was always overwritten by prepare_sequence at
        # the top of the next iteration and never read after the final one,
        # so the dead store has been removed.
        sampled[idx] = hard[idx].argmax(dim=0)
    return sampled
def integrate(self, score, data, *args):
    """Run `self.steps` rounds of gradient-guided discrete updates on `data`.

    `score` is an energy function; the gradient of its output w.r.t. the
    (presumably one-hot — TODO confirm) `data.tensor` is turned into a
    proposal distribution, mixed with uniform noise and the current state,
    and re-discretized with `hard_one_hot`.

    Returns:
        The updated data object, detached from the autograd graph.
    """
    data = data.clone()
    # Initial energy evaluation; the value is bound but not otherwise used.
    current_energy, *_ = score(data, *args)
    for idx in range(self.steps):
        make_differentiable(data)
        make_differentiable(args)
        energy = score(data, *args)
        # Scores may return (energy, extras...); keep only the energy term.
        if isinstance(energy, (list, tuple)):
            energy, *_ = energy
        # d(energy)/d(data.tensor); ones_like supplies the upstream gradient
        # for a possibly non-scalar energy.
        gradient = ag.grad(energy, data.tensor, torch.ones_like(energy))[0]
        if self.max_norm:
            gradient = clip_grad_by_norm(gradient, self.max_norm)
        # attempt at gradient based local update of discrete variables:
        # the large negative temperature (-500) makes this close to a hard
        # argmin of the gradient along dim=1.
        grad_prob = (-500 * gradient).softmax(dim=1)
        # Mixture: uniform noise + gradient proposal + current state.
        new_prob = self.noise + self.rate * grad_prob + (
            1 - self.noise - self.rate) * data.tensor
        new_val = hard_one_hot(new_prob.log())
        data.tensor = new_val
        data = data.detach()
    return data
def forward(self, features, sequence, distances, structure):
    """Encode the structure, then decode with differentiable scheduled sampling."""
    # Pairwise orientation features, shared by the encoder and every
    # decoding pass below.
    pair_data = self.relative(structure, self.rbf).message(distances, distances)
    encoding = self.encoder(features, OrientationStructure(structure, pair_data))

    # Initial decode from the prepared input sequence.
    sequence = self.prepare_sequence(sequence)
    result = self.decoder(
        encoding, MaskedStructure(structure, pair_data, sequence, encoding))

    # Scheduled sampling: each round swaps ~25% of positions for the
    # model's own hard predictions and decodes again.
    for _ in range(self.schedule):
        predicted = self.sequence_embedding(hard_one_hot(result))
        swap = torch.rand(predicted.size(0), device=predicted.device) < 0.25
        swap = swap.unsqueeze(1).float()
        sequence = swap * predicted + (1 - swap) * sequence
        result = self.decoder(
            encoding, MaskedStructure(structure, pair_data, sequence, encoding))
    return result
def forward(self, features, sequence, pair_features, structure, protein=None):
    """Featurize pair distances, encode, then decode with scheduled sampling.

    Args:
        features: per-position input features for the encoder.
        sequence: input sequence, embedded via `prepare_sequence`.
        pair_features: pair feature tensor; channel 0 is a raw distance
            that gets expanded into RBF features.
        structure: structure object describing pairwise connectivity.
        protein: unused here; accepted for interface compatibility.

    Returns:
        Decoder output after `self.schedule` scheduled-sampling rounds.

    Bug fix: the scheduled-sampling loop referenced an undefined name
    `relative_data` (NameError on the first iteration); it now uses
    `pair_features`, matching the initial evaluation above and the
    sibling `forward` that this method parallels.
    """
    # featurize distances: expand channel 0 into RBF features, keep the rest.
    distance = pair_features[:, :, 0]
    distance = gaussian_rbf(distance.view(-1, 1), *self.rbf).reshape(
        distance.size(0), distance.size(1), -1)
    pair_features = torch.cat((distance, pair_features[:, :, 1:]), dim=2)
    relative_structure = OrientationStructure(structure, pair_features)
    encoding = self.encoder(features, relative_structure)
    # initial evaluation
    sequence = self.prepare_sequence(sequence)
    masked_structure = MaskedStructure(structure, pair_features, sequence, encoding)
    result = self.decoder(encoding, masked_structure)
    # differentiable scheduled sampling: mix ~25% model predictions into
    # the sequence each round and decode again.
    for idx in range(self.schedule):
        samples = self.sequence_embedding(hard_one_hot(result))
        mask = torch.rand(samples.size(0), device=samples.device) < 0.25
        mask = mask.unsqueeze(1).float()
        sequence = mask * samples + (1 - mask) * sequence
        masked_structure = MaskedStructure(structure, pair_features, sequence, encoding)
        result = self.decoder(encoding, masked_structure)
    return result
def sample(self, primary, mode, gt_ignore, angles, orientation, neighbours, protein):
    """Corrupt ~1/20 of positions, rescore, and commit predictions at ~1/10.

    Mutates `primary.tensor` in place and returns `primary`.
    """
    corrupted = primary.clone()
    length = corrupted.tensor.size(0)
    # Overwrite length // 20 random positions with random one-hot classes
    # (20 classes, matching the amino-acid alphabet — TODO confirm).
    flip_at = torch.randint(0, length, (length // 20, ))
    flip_to = torch.randint(0, 20, (length // 20, ))
    corrupted.tensor[flip_at] = 0
    corrupted.tensor[flip_at, flip_to] = 1
    logits = self.sampler(
        corrupted, mode, gt_ignore, angles, orientation, neighbours, protein)
    # Commit the model's hard predictions at length // 10 random positions.
    commit_at = torch.randint(0, logits.size(0), (logits.size(0) // 10, ))
    predictions = hard_one_hot(logits)
    primary.tensor[commit_at] = predictions[commit_at]
    return primary
def integrate(self, score, data, *args):
    """Run `self.steps` rounds of stochastic local updates driven by score deltas.

    Each step asks `score(..., return_deltas=True)` for per-position deltas,
    builds a proposal distribution from them (hard argmax one-hot, or a
    tempered softmax when `self.scale` is set), and resamples a random
    `self.rate` fraction of positions from that proposal.

    Returns:
        The updated data object, detached from the autograd graph.
    """
    data = data.clone()
    # NOTE(review): removed the original's unused `result = data.clone()`.
    # The initial score evaluation is kept in case `score` has warm-up side
    # effects — TODO confirm whether it can be dropped too.
    current_energy = score(data, *args)
    for idx in range(self.steps):
        energy, deltas = score(data, *args, return_deltas=True)
        # attempt at gradient based local update of discrete variables:
        if self.scale is not None:
            # Tempered softmax over deltas as a soft proposal.
            grad_prob = (self.scale * deltas).softmax(dim=1)
        else:
            # Hard proposal: one-hot at the per-position argmax delta.
            # (The original computed this unconditionally and then threw it
            # away whenever `self.scale` was set.)
            grad_prob = torch.zeros_like(deltas)
            grad_prob[torch.arange(deltas.size(0)), deltas.argmax(dim=1)] = 1
        # Resample a random subset of positions at rate `self.rate`.
        access = torch.rand(deltas.size(0), dtype=torch.float,
                            device=deltas.device) < self.rate
        data.tensor[access] = hard_one_hot(grad_prob[access].log())
        data = data.detach()
    return data
def integrate(self, score, data, *args):
    """Run `self.steps` rounds of perturb-and-resample updates on `data`.

    Each step perturbs the data, scores it for per-position deltas, builds a
    proposal distribution (hard argmax one-hot, or a tempered softmax when
    `self.scale` is set), and resamples the positions chosen by
    `self.pick_positions`.

    Returns:
        The updated data object, detached from the autograd graph.
    """
    data = data.clone()
    # NOTE(review): removed two dead locals from the original
    # (`result = data.clone()` and the never-used `access_cache` list).
    # The initial score evaluation is kept in case `score` has warm-up side
    # effects — TODO confirm whether it can be dropped too.
    current_energy = score(data, *args)
    for idx in range(self.steps):
        data = self.perturb(data)
        energy, deltas = score(data, *args, return_deltas=True)
        # attempt at gradient based local update of discrete variables:
        if self.scale is not None:
            # Tempered softmax over deltas as a soft proposal.
            grad_prob = (self.scale * deltas).softmax(dim=1)
        else:
            # Hard proposal: one-hot at the per-position argmax delta.
            # (The original computed this unconditionally and then threw it
            # away whenever `self.scale` was set.)
            grad_prob = torch.zeros_like(deltas)
            grad_prob[torch.arange(deltas.size(0)), deltas.argmax(dim=1)] = 1
        # Positions to resample, derived from connectivity data in the last
        # two extra arguments — presumably structure objects; verify callers.
        access = self.pick_positions(args[-2].connections, args[-1].counts)
        data.tensor[access] = hard_one_hot(grad_prob[access].log())
        data = data.detach()
    return data
def integrate(self, score, data, *args):
    """Run `self.steps` rounds of delta-driven mixture updates on `data`.

    Each step scores the data for per-position deltas, builds a proposal
    distribution (hard argmax one-hot, or a tempered softmax when
    `self.scale` is set), mixes it with uniform noise and the current
    state, and re-discretizes every position with `hard_one_hot`.

    Returns:
        The updated data object, detached from the autograd graph.
    """
    data = data.clone()
    # NOTE(review): removed the original's unused `result = data.clone()`.
    # The initial score evaluation is kept in case `score` has warm-up side
    # effects — TODO confirm whether it can be dropped too.
    current_energy = score(data, *args)
    for idx in range(self.steps):
        # Marked differentiable even though no explicit autograd call follows;
        # presumably `score` needs grad-enabled inputs internally — confirm.
        make_differentiable(data)
        make_differentiable(args)
        energy, deltas = score(data, *args, return_deltas=True)
        # attempt at gradient based local update of discrete variables:
        if self.scale is not None:
            # Tempered softmax over deltas as a soft proposal.
            grad_prob = (self.scale * deltas).softmax(dim=1)
        else:
            # Hard proposal: one-hot at the per-position argmax delta.
            # (The original computed this unconditionally and then threw it
            # away whenever `self.scale` was set.)
            grad_prob = torch.zeros_like(deltas)
            grad_prob[torch.arange(deltas.size(0)), deltas.argmax(dim=1)] = 1
        # Mixture: uniform noise + delta proposal + current state.
        new_prob = self.noise + self.rate * grad_prob + (
            1 - self.noise - self.rate) * data.tensor
        new_val = hard_one_hot(new_prob.log())
        data.tensor = new_val
        data = data.detach()
    return data