コード例 #1
0
  def sample(self, features, distances, structure):
    """Decode a sequence of class indices one position at a time.

    The structure is encoded once. The decoder is then re-run once per
    position, each time fed the indices chosen so far (positions that have
    not been decoded yet hold index 0), and the index for the current
    position is read off the decoder output.

    Args:
      features: input passed to ``self.encoder``; its device is used for
        the tensor of sampled indices.
      distances: passed as both arguments of the relative-feature
        ``message`` call.
      structure: structure object handed to ``self.relative``,
        ``OrientationStructure`` and ``MaskedStructure``.

    Returns:
      A 1-D ``torch.long`` tensor with one entry per row of the encoding.
      An empty encoding yields an empty tensor.
    """
    # encode
    distance_data = self.relative(structure, self.rbf)
    relative_data = distance_data.message(distances, distances)
    relative_structure = OrientationStructure(structure, relative_data)
    encoding = self.encoder(features, relative_structure)

    # sampling: every position starts at index 0 and is overwritten in order
    sampled = torch.zeros(
      encoding.size(0), dtype=torch.long, device=features.device
    )
    for idx in range(len(sampled)):
      # re-embed the current partial sequence and decode against it
      sequence = self.prepare_sequence(sampled)
      masked_structure = MaskedStructure(
        structure, relative_data, sequence, encoding
      )
      result = self.decoder(encoding, masked_structure)
      # hard_one_hot is defined elsewhere; presumably it turns the decoder
      # output into a one-hot choice per position -- TODO confirm
      hard = hard_one_hot(result)
      sampled[idx] = hard[idx].argmax(dim=0)
    return sampled
コード例 #2
0
    def integrate(self, score, data, *args):
        """Run ``self.steps`` gradient-guided updates on a copy of ``data``.

        Each step evaluates ``score``, differentiates the energy with
        respect to ``data.tensor`` and mixes a gradient-derived
        distribution into the current one-hot state before re-discretising.

        Args:
            score: callable ``score(data, *args)`` returning either the
                energy tensor or a list/tuple whose first item is the
                energy.
            data: object exposing ``clone()``, ``detach()`` and a
                ``tensor`` attribute; it is cloned, so the caller's object
                is not reassigned.
            *args: extra positional arguments forwarded to ``score``.

        Returns:
            The updated, detached copy of ``data``. With ``self.steps == 0``
            this is just the clone.
        """
        data = data.clone()
        for _ in range(self.steps):
            make_differentiable(data)
            make_differentiable(args)

            energy = score(data, *args)
            if isinstance(energy, (list, tuple)):
                # only the first returned item is used as the energy
                energy = energy[0]

            gradient = ag.grad(energy, data.tensor, torch.ones_like(energy))[0]
            if self.max_norm:
                gradient = clip_grad_by_norm(gradient, self.max_norm)

            # attempt at gradient based local update of discrete variables:
            # the large negative factor makes the softmax concentrate on the
            # entries along dim 1 with the most negative gradient
            grad_prob = (-500 * gradient).softmax(dim=1)
            # mix of a constant noise term, the gradient-derived distribution
            # and the current state
            new_prob = self.noise + self.rate * grad_prob + (
                1 - self.noise - self.rate) * data.tensor
            # hard_one_hot is defined elsewhere; presumably it discretises
            # the log-probabilities back to one-hot rows -- TODO confirm
            new_val = hard_one_hot(new_prob.log())
            data.tensor = new_val

            data = data.detach()

        return data
コード例 #3
0
  def forward(self, features, sequence, distances, structure):
    """Encode the structure, decode, then refine with scheduled sampling.

    After an initial decoder pass on the prepared input sequence, the
    method performs ``self.schedule`` further passes. Before each pass a
    random subset of rows (each chosen with probability 0.25) of the
    sequence embedding is replaced by the embedding of the decoder's own
    discretised output.

    Args:
      features: input passed to ``self.encoder``.
      sequence: input passed to ``self.prepare_sequence``.
      distances: passed as both arguments of the relative-feature
        ``message`` call.
      structure: structure object handed to ``self.relative``,
        ``OrientationStructure`` and ``MaskedStructure``.

    Returns:
      The decoder output of the final pass.
    """
    pair_data = self.relative(structure, self.rbf).message(
      distances, distances
    )
    encoding = self.encoder(
      features, OrientationStructure(structure, pair_data)
    )

    def decode(embedded):
      # one decoder pass against the given sequence embedding
      return self.decoder(
        encoding, MaskedStructure(structure, pair_data, embedded, encoding)
      )

    # initial evaluation
    embedded = self.prepare_sequence(sequence)
    logits = decode(embedded)

    # differentiable scheduled sampling
    for _ in range(self.schedule):
      drawn = self.sequence_embedding(hard_one_hot(logits))
      # 1.0 where the row takes the model's own sample, 0.0 where it keeps
      # the current embedding
      take_drawn = torch.rand(drawn.size(0), device=drawn.device) < 0.25
      take_drawn = take_drawn.unsqueeze(1).float()
      embedded = take_drawn * drawn + (1 - take_drawn) * embedded
      logits = decode(embedded)

    return logits
コード例 #4
0
  def forward(self, features, sequence, pair_features, structure, protein=None):
    """Encode the structure, decode, then refine with scheduled sampling.

    The first channel of ``pair_features`` is expanded with
    ``gaussian_rbf`` and concatenated with the remaining channels; the
    result is used as pair features for both the encoder and the decoder.
    After an initial decoder pass, ``self.schedule`` further passes are
    made; before each one, rows of the sequence embedding are replaced
    (each with probability 0.25) by the embedding of the decoder's own
    discretised output.

    Args:
      features: input passed to ``self.encoder``.
      sequence: input passed to ``self.prepare_sequence``.
      pair_features: 3-D tensor whose last-axis channel 0 is fed to
        ``gaussian_rbf`` (presumably a distance -- confirm with caller).
      structure: structure object handed to ``OrientationStructure`` and
        ``MaskedStructure``.
      protein: accepted but not used by this method.

    Returns:
      The decoder output of the final pass.
    """
    # featurize distances
    distance = pair_features[:, :, 0]
    # reshape (not view) so a non-contiguous slice cannot raise here
    distance = gaussian_rbf(distance.reshape(-1, 1), *self.rbf).reshape(
      distance.size(0), distance.size(1), -1
    )

    pair_features = torch.cat((distance, pair_features[:, :, 1:]), dim=2)

    relative_structure = OrientationStructure(structure, pair_features)
    encoding = self.encoder(features, relative_structure)

    # initial evaluation
    sequence = self.prepare_sequence(sequence)
    masked_structure = MaskedStructure(
      structure, pair_features, sequence, encoding
    )
    result = self.decoder(encoding, masked_structure)

    # differentiable scheduled sampling
    for _ in range(self.schedule):
      samples = self.sequence_embedding(hard_one_hot(result))
      # 1.0 where the row takes the model's own sample, 0.0 otherwise
      mask = torch.rand(samples.size(0), device=samples.device) < 0.25
      mask = mask.unsqueeze(1).float()
      sequence = mask * samples + (1 - mask) * sequence
      # same pair features as the initial evaluation
      masked_structure = MaskedStructure(
        structure, pair_features, sequence, encoding
      )
      result = self.decoder(encoding, masked_structure)

    return result
コード例 #5
0
 def sample(self, primary, mode, gt_ignore, angles, orientation, neighbours,
            protein):
     """Resample a random subset of rows of ``primary`` from the sampler.

     A clone of ``primary`` is corrupted by overwriting one row in twenty
     (``size // 20`` distinct rows) with a random one-hot row over the
     first 20 columns. ``self.sampler`` is run on the corrupted clone, and
     ``size // 10`` randomly drawn rows of ``primary.tensor`` are replaced
     by the discretised sampler output.

     Args:
         primary: object exposing ``clone()`` and a 2-D ``tensor``
             attribute. Its ``tensor`` is modified in place.
         mode, gt_ignore, angles, orientation, neighbours, protein:
             forwarded unchanged to ``self.sampler``.

     Returns:
         The same ``primary`` object that was passed in, after the in-place
         update.
     """
     inputs = primary.clone()
     size = inputs.tensor.size(0)
     count = size // 20
     # distinct rows: a repeated row paired with two different values would
     # end up with two ones and no longer be one-hot
     positions = torch.randperm(size)[:count]
     values = torch.randint(0, 20, (count, ))
     inputs.tensor[positions] = 0
     inputs.tensor[positions, values] = 1
     logits = self.sampler(inputs, mode, gt_ignore, angles, orientation,
                           neighbours, protein)
     # repeats are harmless here: every repeat writes the same row value
     positions = torch.randint(0, logits.size(0), (logits.size(0) // 10, ))
     drawn = hard_one_hot(logits)
     primary.tensor[positions] = drawn[positions]
     return primary
コード例 #6
0
    def integrate(self, score, data, *args):
        """Run ``self.steps`` delta-guided updates on a copy of ``data``.

        Each step asks ``score`` for per-entry deltas, turns them into a
        per-row distribution, and re-draws a random subset of rows (each
        chosen with probability ``self.rate``) from that distribution.

        Args:
            score: callable accepting ``(data, *args, return_deltas=True)``
                and returning an ``(energy, deltas)`` pair; ``deltas`` is
                indexed as a 2-D tensor.
            data: object exposing ``clone()``, ``detach()`` and a
                ``tensor`` attribute; it is cloned before being modified.
            *args: extra positional arguments forwarded to ``score``.

        Returns:
            The updated, detached copy of ``data``.
        """
        data = data.clone()
        for _ in range(self.steps):
            _, deltas = score(data, *args, return_deltas=True)

            # attempt at gradient based local update of discrete variables:
            if self.scale is not None:
                # soft distribution over dim 1, sharpened by self.scale
                grad_prob = (self.scale * deltas).softmax(dim=1)
            else:
                # one-hot at the largest delta of each row
                grad_prob = torch.zeros_like(deltas)
                rows = torch.arange(deltas.size(0))
                grad_prob[rows, deltas.argmax(dim=1)] = 1

            # rows to update this step
            access = torch.rand(deltas.size(0),
                                dtype=torch.float,
                                device=deltas.device)
            access = access < self.rate
            # hard_one_hot is defined elsewhere; presumably it discretises
            # the log-probabilities to one-hot rows -- TODO confirm
            data.tensor[access] = hard_one_hot(grad_prob[access].log())

            data = data.detach()

        return data
コード例 #7
0
    def integrate(self, score, data, *args):
        """Run ``self.steps`` perturb-and-update steps on a copy of ``data``.

        Each step perturbs the state with ``self.perturb``, asks ``score``
        for per-entry deltas, turns them into a per-row distribution, and
        re-draws the rows selected by ``self.pick_positions`` from it.

        Args:
            score: callable accepting ``(data, *args, return_deltas=True)``
                and returning an ``(energy, deltas)`` pair; ``deltas`` is
                indexed as a 2-D tensor.
            data: object exposing ``clone()``, ``detach()`` and a
                ``tensor`` attribute; it is cloned before being modified.
            *args: extra positional arguments forwarded to ``score``. The
                second-to-last must expose ``connections`` and the last
                must expose ``counts``.

        Returns:
            The updated, detached copy of ``data``.
        """
        data = data.clone()
        for _ in range(self.steps):
            data = self.perturb(data)

            _, deltas = score(data, *args, return_deltas=True)

            # attempt at gradient based local update of discrete variables:
            if self.scale is not None:
                # soft distribution over dim 1, sharpened by self.scale
                grad_prob = (self.scale * deltas).softmax(dim=1)
            else:
                # one-hot at the largest delta of each row
                grad_prob = torch.zeros_like(deltas)
                rows = torch.arange(deltas.size(0))
                grad_prob[rows, deltas.argmax(dim=1)] = 1

            # rows to update this step
            access = self.pick_positions(args[-2].connections, args[-1].counts)
            # hard_one_hot is defined elsewhere; presumably it discretises
            # the log-probabilities to one-hot rows -- TODO confirm
            data.tensor[access] = hard_one_hot(grad_prob[access].log())

            data = data.detach()

        return data
コード例 #8
0
    def integrate(self, score, data, *args):
        """Run ``self.steps`` delta-guided updates on a copy of ``data``.

        Each step asks ``score`` for per-entry deltas, turns them into a
        per-row distribution, mixes it with a constant noise term and the
        current state, and re-discretises the whole state.

        Args:
            score: callable accepting ``(data, *args, return_deltas=True)``
                and returning an ``(energy, deltas)`` pair; ``deltas`` is
                indexed as a 2-D tensor.
            data: object exposing ``clone()``, ``detach()`` and a
                ``tensor`` attribute; it is cloned before being modified.
            *args: extra positional arguments forwarded to ``score``.

        Returns:
            The updated, detached copy of ``data``.
        """
        data = data.clone()
        for _ in range(self.steps):
            # NOTE(review): nothing in this method differentiates directly;
            # presumably score needs gradients enabled to produce deltas --
            # confirm before removing these calls.
            make_differentiable(data)
            make_differentiable(args)

            _, deltas = score(data, *args, return_deltas=True)

            # attempt at gradient based local update of discrete variables:
            if self.scale is not None:
                # soft distribution over dim 1, sharpened by self.scale
                grad_prob = (self.scale * deltas).softmax(dim=1)
            else:
                # one-hot at the largest delta of each row
                grad_prob = torch.zeros_like(deltas)
                rows = torch.arange(deltas.size(0))
                grad_prob[rows, deltas.argmax(dim=1)] = 1

            # mix of a constant noise term, the delta-derived distribution
            # and the current state
            new_prob = self.noise + self.rate * grad_prob + (
                1 - self.noise - self.rate) * data.tensor
            # hard_one_hot is defined elsewhere; presumably it discretises
            # the log-probabilities to one-hot rows -- TODO confirm
            new_val = hard_one_hot(new_prob.log())
            data.tensor = new_val

            data = data.detach()

        return data