Ejemplo n.º 1
0
    def compute_loss(self, minibatch):
        """Compute the angular loss for one minibatch.

        Runs the network on the amino-acid sequences, derives the actual
        dihedral angles from the ground-truth coordinates, and returns the
        angular difference between predicted and actual angles.
        """
        original_aa_string, actual_coords_list, _ = minibatch

        emissions, _backbone_atoms_padded, _batch_sizes = \
            self._get_network_emissions(original_aa_string)

        padded_coords = torch.nn.utils.rnn.pad_sequence(actual_coords_list)
        if self.use_gpu:
            padded_coords = padded_coords.cuda()

        angle_start = time.time()
        # The angle helper expects batch sizes as a tensor.
        if isinstance(_batch_sizes[0], int):
            _batch_sizes = torch.tensor(_batch_sizes)
        emissions_actual, _ = calculate_dihedral_angles_over_minibatch(
            padded_coords, _batch_sizes, self.use_gpu)
        write_out("Angle calculation time:", time.time() - angle_start)

        if self.use_gpu:
            emissions_actual = emissions_actual.cuda()

        return calc_angular_difference(emissions, emissions_actual)
Ejemplo n.º 2
0
    def compute_loss(self, minibatch, processed_minibatches, minimum_updates):
        """Compute a weighted blend of average dRMSD and angular loss.

        During the first 40% of ``minimum_updates`` the dRMSD weight ramps
        up linearly from 0 toward 0.4; after that it stays fixed at 0.4,
        with the remaining weight going to the angular loss.
        """
        original_aa_string, actual_coords_list, _ = minibatch

        emissions, _backbone_atoms_padded, _batch_sizes = \
            self._get_network_emissions(original_aa_string)

        padded_coords = torch.nn.utils.rnn.pad_sequence(actual_coords_list)
        if self.use_gpu:
            padded_coords = padded_coords.cuda()

        angle_start = time.time()
        # The angle helper expects batch sizes as a tensor.
        if isinstance(_batch_sizes[0], int):
            _batch_sizes = torch.tensor(_batch_sizes)
        emissions_actual, _ = calculate_dihedral_angles_over_minibatch(
            padded_coords, _batch_sizes, self.use_gpu)
        drmsd_avg = calc_avg_drmsd_over_minibatch(
            _backbone_atoms_padded, padded_coords, _batch_sizes)
        write_out("Angle calculation time:", time.time() - angle_start)

        if self.use_gpu:
            emissions_actual = emissions_actual.cuda()
            drmsd_avg = drmsd_avg.cuda()

        angular_loss = calc_angular_difference(emissions, emissions_actual)

        # Linear warm-up of the dRMSD weight over the first 40% of updates.
        multiplier = 0.4
        if processed_minibatches < minimum_updates * (40 / 100):
            multiplier = processed_minibatches / minimum_updates

        normalized_angular_loss = angular_loss / 5
        normalized_drmsd_avg = drmsd_avg / 100
        return (normalized_drmsd_avg * multiplier
                + normalized_angular_loss * (1 - multiplier))
Ejemplo n.º 3
0
    def compute_loss(self, minibatch):
        """Compute the angular loss for one minibatch, skipping NaN inputs.

        Returns None when either the input sequences or the ground-truth
        coordinates contain NaN values; otherwise returns the angular
        difference between predicted and actual dihedral angles.
        """
        original_aa_string, actual_coords_list, _ = minibatch

        def _contains_nan(tensors):
            # True when any tensor in the iterable holds at least one NaN.
            return any(
                np.isnan(t.cpu().detach().numpy()).any() for t in tensors)

        if _contains_nan(original_aa_string) or \
                _contains_nan(actual_coords_list):
            return None

        emissions, _backbone_atoms_padded, _batch_sizes = \
            self._get_network_emissions(original_aa_string)
        assert not np.isnan(emissions.cpu().detach().numpy()).any()

        padded_coords, batch_sizes_coords = \
            torch.nn.utils.rnn.pad_packed_sequence(
                torch.nn.utils.rnn.pack_sequence(actual_coords_list))
        assert not np.isnan(padded_coords.cpu().detach().numpy()).any()
        if self.use_gpu:
            padded_coords = padded_coords.cuda()

        angle_start = time.time()
        emissions_actual, _ = calculate_dihedral_angles_over_minibatch(
            padded_coords, batch_sizes_coords, self.use_gpu)
        write_out("Angle calculation time:", time.time() - angle_start)

        if self.use_gpu:
            emissions_actual = emissions_actual.cuda()

        return calc_angular_difference(emissions, emissions_actual)
Ejemplo n.º 4
0
    def compute_loss(self, minibatch):
        """Compute angular loss and average dRMSD for one minibatch.

        The minibatch additionally carries PSSM features and a token that
        are fed to the network alongside the sequences.

        Returns a ``(angular_loss, drmsd_avg)`` tuple.
        """
        original_aa_string, actual_coords_list, _, pssms, token = minibatch

        emissions, _backbone_atoms_padded, _batch_sizes = \
            self._get_network_emissions(original_aa_string, pssms, token)

        padded_coords, batch_sizes_coords = \
            torch.nn.utils.rnn.pad_packed_sequence(
                torch.nn.utils.rnn.pack_sequence(actual_coords_list))
        if self.use_gpu:
            padded_coords = padded_coords.cuda()

        angle_start = time.time()
        emissions_actual, _ = calculate_dihedral_angles_over_minibatch(
            padded_coords, batch_sizes_coords, self.use_gpu)
        drmsd_avg = calc_avg_drmsd_over_minibatch(
            _backbone_atoms_padded, padded_coords, _batch_sizes)
        write_out("Angle calculation time:", time.time() - angle_start)

        if self.use_gpu:
            emissions_actual = emissions_actual.cuda()
            drmsd_avg = drmsd_avg.cuda()

        angular_loss = calc_angular_difference(emissions, emissions_actual)
        return angular_loss, drmsd_avg
Ejemplo n.º 5
0
def seq_and_angle_loss(pred_seqs,
                       padded_seqs,
                       pred_dihedrals,
                       padded_dihedrals,
                       mask,
                       device,
                       VOCAB_SIZE=21,
                       use_mask=False):
    """Compute sequence cross-entropy, sequence accuracy and angular loss.

    All inputs are padded along the sequence dimension.
    Assumed shapes (TODO confirm against callers):
        pred_seqs:        (batch, seq_len, VOCAB_SIZE) log-probabilities.
        padded_seqs:      (batch, seq_len) integer labels; 0 is padding.
        pred_dihedrals:   (batch, seq_len, 3) predicted dihedral angles.
        padded_dihedrals: (batch, seq_len, 3) target dihedral angles.
        mask:             (batch, seq_len) validity mask; only used when
                          ``use_mask`` is True.
        device:           device on which the all-ones fallback mask lives.

    Returns ``(seq_cross_ent_loss, seq_acc, angular_loss)``.
    """
    if not use_mask:
        # No masking requested: treat every position as valid.
        # bool dtype (not byte): masked_select requires a BoolTensor in
        # modern PyTorch.
        mask = torch.ones(mask.shape, dtype=torch.bool, device=device)

    # Cross entropy over the vocabulary; label 0 is padding and excluded.
    # NLLLoss expects (batch, classes, seq_len), hence the permute.
    # reduction='mean' replaces the deprecated size_average=True.
    criterion = torch.nn.NLLLoss(reduction='mean', ignore_index=0)
    seq_cross_ent_loss = criterion(
        pred_seqs.permute([0, 2, 1]).contiguous(), padded_seqs)

    # Token-level accuracy, counting only non-padding positions (label > 0).
    padded_seqs = padded_seqs.flatten()
    pred_seqs = pred_seqs.view(-1, VOCAB_SIZE)
    seq_mask = (padded_seqs > 0).float()
    nb_tokens = int(torch.sum(seq_mask).item())
    top_preds = pred_seqs.max(dim=1)[1]
    seq_acc = torch.sum(
        torch.eq(top_preds, padded_seqs).type(torch.float) *
        seq_mask) / nb_tokens

    # Angular loss over the three dihedral angles, restricted to positions
    # the mask marks as valid.  Cast to bool so masked_select accepts the
    # mask even when callers pass a byte/int mask.
    mask = mask.view(mask.shape[0], mask.shape[1], 1).expand(-1, -1, 3)
    mask = mask.bool()
    angular_loss = calc_angular_difference(
        torch.masked_select(pred_dihedrals, mask),
        torch.masked_select(padded_dihedrals, mask))

    return seq_cross_ent_loss, seq_acc, angular_loss
Ejemplo n.º 6
0
    def evaluate_model(self, data_loader):
        """Evaluate the model over a validation data loader.

        Accumulates a dRMSD-based loss and an angular loss across every
        protein, appends the mean RMSD/dRMSD to the instance's history
        lists, and writes the first protein's actual and predicted
        structures to PDB files for visualization.

        Returns a ``(loss, data, prediction_data, angular_loss)`` tuple
        where ``data`` is a dict of plotting payloads (PDB text, phi/psi
        angles in degrees, historical RMSD/dRMSD averages) and
        ``prediction_data`` is always None.
        """
        loss = 0
        angular_loss = 0
        data_total = []   # (sequence, actual positions, pred angles, pred atoms)
        dRMSD_list = []
        RMSD_list = []
        for _, data in enumerate(data_loader, 0):
            primary_sequence, tertiary_positions, _mask, pssm, token = data
            start = time.time()
            # Forward pass on the validation minibatch.
            predicted_angles, backbone_atoms, _batch_sizes = self(
                primary_sequence, pssm, token)
            write_out("Apply model to validation minibatch:",
                      time.time() - start)
            # Move predictions to CPU in batch-major order for per-protein
            # metric computation below.
            cpu_predicted_angles = predicted_angles.transpose(
                0, 1).cpu().detach()
            cpu_predicted_backbone_atoms = backbone_atoms.transpose(
                0, 1).cpu().detach()
            minibatch_data = list(
                zip(primary_sequence, tertiary_positions, cpu_predicted_angles,
                    cpu_predicted_backbone_atoms))
            data_total.extend(minibatch_data)
            # Derive the ground-truth dihedral angles from the coordinates.
            actual_coords_list_padded, batch_sizes_coords = torch.nn.utils.rnn.pad_packed_sequence(
                torch.nn.utils.rnn.pack_sequence(tertiary_positions))
            if self.use_gpu:
                actual_coords_list_padded = actual_coords_list_padded.cuda()
            emissions_actual, _ = calculate_dihedral_angles_over_minibatch(
                actual_coords_list_padded, batch_sizes_coords, self.use_gpu)

            start = time.time()
            for primary_sequence, tertiary_positions, _predicted_pos, predicted_backbone_atoms \
                    in minibatch_data:
                actual_coords = tertiary_positions.transpose(
                    0, 1).contiguous().view(-1, 3)
                # Truncate padded predictions to the protein's true length.
                predicted_coords = predicted_backbone_atoms[:len(primary_sequence)] \
                    .transpose(0, 1).contiguous().view(-1, 3).detach()

                if self.use_gpu:
                    emissions_actual = emissions_actual.cuda()
                # NOTE(review): this compares the whole minibatch's predicted
                # angles to the whole minibatch's actual angles once per
                # protein, so each minibatch contributes the same value
                # len(minibatch_data) times — confirm this is intended.
                angular_loss += float(
                    calc_angular_difference(predicted_angles,
                                            emissions_actual))

                rmsd = calc_rmsd(predicted_coords, actual_coords)
                drmsd = calc_drmsd(predicted_coords, actual_coords)
                RMSD_list.append(rmsd)
                dRMSD_list.append(drmsd)
                error = float(drmsd)
                loss += error
                end = time.time()
            # NOTE(review): `end` is only assigned inside the loop above and
            # would be unbound here for an empty minibatch.
            write_out("Calculate validation loss for minibatch took:",
                      end - start)
        # Average the accumulated losses over the whole dataset.
        loss /= data_loader.dataset.__len__()
        angular_loss /= data_loader.dataset.__len__()
        self.historical_rmsd_avg_values.append(
            float(torch.Tensor(RMSD_list).mean()))
        self.historical_drmsd_avg_values.append(
            float(torch.Tensor(dRMSD_list).mean()))

        # Dump the first protein (actual vs. predicted) to PDB files.
        prim = data_total[0][0]
        pos = data_total[0][1]
        pos_pred = data_total[0][3]
        if self.use_gpu:
            pos = pos.cuda()
            pos_pred = pos_pred.cuda()
        angles = calculate_dihedral_angles(pos, self.use_gpu)
        angles_pred = calculate_dihedral_angles(pos_pred, self.use_gpu)

        write_to_pdb(get_structure_from_angles(prim, angles), "test")
        write_to_pdb(get_structure_from_angles(prim, angles_pred), "test_pred")

        # Collect the plotting payload.
        # NOTE(review): these open(...) handles are never closed — consider
        # a `with` block in a future change.
        data = {}
        data["pdb_data_pred"] = open("output/protein_test_pred.pdb",
                                     "r").read()
        data["pdb_data_true"] = open("output/protein_test.pdb", "r").read()
        # The [1:] / [:-1] slices drop one terminal residue per angle —
        # presumably because phi/psi are undefined at the chain termini;
        # confirm against calculate_dihedral_angles' output layout.
        data["phi_actual"] = list(
            [math.degrees(float(v)) for v in angles[1:, 1]])
        data["psi_actual"] = list(
            [math.degrees(float(v)) for v in angles[:-1, 2]])
        data["phi_predicted"] = list(
            [math.degrees(float(v)) for v in angles_pred[1:, 1]])
        data["psi_predicted"] = list(
            [math.degrees(float(v)) for v in angles_pred[:-1, 2]])
        data["rmsd_avg"] = self.historical_rmsd_avg_values
        data["drmsd_avg"] = self.historical_drmsd_avg_values

        prediction_data = None

        return loss, data, prediction_data, angular_loss