示例#1
0
文件: test_mdt.py 项目: salilab/mdt
 def test_chain_span_range(self):
     """Check that chain_span_range filters tuple pairs by chain separation.

     A two-chain G/G model is scanned with a TupleDistance feature using
     three different chain windows, and the resulting table sums are
     compared.
     """
     env = self.get_environ()
     mdl = Model(env)
     mdl.build_sequence('G/G')
     aln = Alignment(env)
     aln.append_model(mdl, align_codes='test')
     mlib = self.get_mdt_library()
     mlib.tuple_classes.read('data/dblcls.lib')
     tuple_dist = mdt.features.TupleDistance(
         mlib, bins=mdt.uniform_bins(49, 2.0, 0.2))
     residue_span = (-999, 0, 0, 999)

     def table_sum(**chain_args):
         # Fill a fresh table from aln and return its total count; any
         # chain_span_range is forwarded unchanged (omitted = default).
         table = mdt.Table(mlib, features=tuple_dist)
         table.add_alignment(aln, residue_span_range=residue_span,
                             **chain_args)
         return table.sum()

     # All chain differences should be out of range, so this MDT should
     # end up empty.
     self.assertEqual(table_sum(chain_span_range=(-999, -999, 999, 999)),
                      0.0)
     # The default chain separation should allow intra-chain interactions,
     # so it should yield more (56) than allowing only inter-chain
     # interactions (32).
     self.assertEqual(table_sum(), 56.0)
     self.assertEqual(table_sum(chain_span_range=(-999, -1, 1, 999)), 32.0)
示例#2
0
文件: test_mdt.py 项目: salilab/mdt
    def test_tuple_pair_bond_span_range(self):
        """Check bond_span_range filtering during a tuple pair scan.

        A single alanine is scanned with a TupleDistance feature, first
        without a bond span restriction and then limited to tuples exactly
        one bond apart.
        """
        env = self.get_environ()
        mdl = Model(env)
        mdl.build_sequence('A')
        aln = Alignment(env)
        aln.append_model(mdl, align_codes='test')
        mlib = self.get_mdt_library()
        mlib.bond_classes.read('data/bndgrp.lib')
        mlib.tuple_classes.read('data/trpcls.lib')
        # Tuple-type features for both positions are created on the library
        # before the distance feature; the objects themselves are unused.
        _ = mdt.features.TupleType(mlib)
        _ = mdt.features.TupleType(mlib, pos2=True)
        dist = mdt.features.TupleDistance(mlib,
                                          bins=mdt.uniform_bins(9, 2.0, 0.2))
        same_residue = (0, 0, 0, 0)

        def sample_size(**scan_args):
            # Scan aln into a fresh table and report how many samples it got
            table = mdt.Table(mlib, features=dist)
            table.add_alignment(aln, residue_span_range=same_residue,
                                **scan_args)
            return table.sample_size

        # No bond span restriction
        self.assertEqual(sample_size(), 10.0)
        # Bond span should restrict interactions to 6
        # (C:CA:CB-CA:C:O, CA:C:O-N:CA:C, CA:C:O-N:CA:CB, and the reverse)
        self.assertEqual(sample_size(bond_span_range=(1, 1)), 6.0)
示例#3
0
文件: test_mdt.py 项目: salilab/mdt
    def test_bond_span_range_disulfide(self):
        """Check that disulfide=True adds bonds across disulfide bridges.

        Atom distances in 1HEL are scanned at bond separations of exactly 1
        and exactly 3, with and without disulfide=True, and the extra sample
        counts from the disulfide-aware scan are compared.
        """
        env = self.get_environ()
        mdl = Model(env)
        mdl.read('1HEL.pdb')
        aln = Alignment(env)
        aln.append_model(mdl, align_codes='test')
        mlib = self.get_mdt_library()
        mlib.bond_classes.read('data/bndgrp.lib')
        dist = mdt.features.AtomDistance(mlib,
                                         bins=mdt.uniform_bins(60, 0, 0.5))
        all_residues = (-9999, 0, 0, 9999)

        def disulfide_gain(bond_span):
            # Sample-size difference between a scan with disulfide=True and
            # one that leaves the disulfide argument at its default.
            plain = mdt.Table(mlib, features=dist)
            plain.add_alignment(aln, bond_span_range=bond_span,
                                residue_span_range=all_residues)
            bridged = mdt.Table(mlib, features=dist)
            bridged.add_alignment(aln, bond_span_range=bond_span,
                                  residue_span_range=all_residues,
                                  disulfide=True)
            return bridged.sample_size - plain.sample_size

        # Four disulfide bonds in this structure, so a directly-bonded scan
        # gains four pairs once disulfides are treated as bonds.
        self.assertEqual(disulfide_gain((1, 1)), 4.0)
        # Pairs three bonds apart: 12 more, i.e. 3 per disulfide (presumably
        # CB-CB', CA-SG' and SG-CA' across each bridge).
        self.assertEqual(disulfide_gain((3, 3)), 12.0)
示例#4
0
文件: test_mdt.py 项目: salilab/mdt
 def test_atom_tuple_pair_exclusions(self):
     """Check that bond/angle exclusions reduce tuple pair sample counts.

     A Gly-Gly model is scanned with a TupleDistance feature, once with
     bond exclusion switched off and once with bond and angle exclusion on.
     """
     env = self.get_environ()
     mdl = Model(env)
     mdl.build_sequence('GG')
     aln = Alignment(env)
     aln.append_model(mdl, align_codes='test')
     mlib = self.get_mdt_library()
     mlib.tuple_classes.read('data/trpcls.lib')
     mlib.bond_classes.read('data/bndgrp.lib')
     mlib.angle_classes.read('data/anggrp.lib')
     mlib.dihedral_classes.read('data/impgrp.lib')
     dist = mdt.features.TupleDistance(mlib,
                                       bins=mdt.uniform_bins(49, 0.0, 0.2))

     def sampled(**exclusions):
         # Scan aln into a fresh table with the given exclude_* flags
         table = mdt.Table(mlib, features=dist)
         table.add_alignment(aln, residue_span_range=(-9999, 0, 0, 9999),
                             **exclusions)
         return table.sample_size

     self.assertEqual(sampled(exclude_bonds=False), 38)
     # Exclusions should cut the number of sample points
     self.assertEqual(sampled(exclude_bonds=True, exclude_angles=True), 20)
示例#5
0
文件: test_mdt.py 项目: salilab/mdt
    def test_atom_pair_exclusions(self):
        """Check bond/angle/dihedral exclusions for atom pair features.

        Exclusion counts are first checked on a single glycine, then a
        Gly-Gly model is used to check exclusion of a dihedral that spans
        both residues.
        """
        env = self.get_environ()
        mdl = Model(env)
        mdl.build_sequence('G')
        aln = Alignment(env)
        aln.append_model(mdl, align_codes='test')
        mlib = self.get_mdt_library()
        for classes, lib_file in (('bond_classes', 'data/bndgrp.lib'),
                                  ('angle_classes', 'data/anggrp.lib'),
                                  ('dihedral_classes', 'data/impgrp.lib')):
            getattr(mlib, classes).read(lib_file)
        dist = mdt.features.AtomDistance(mlib,
                                         bins=mdt.uniform_bins(49, 0.0, 0.2))

        def sampled(alignment, **exclusions):
            # Scan alignment into a fresh table with the given exclude_* flags
            table = mdt.Table(mlib, features=dist)
            table.add_alignment(alignment,
                                residue_span_range=(-9999, 0, 0, 9999),
                                **exclusions)
            return table.sample_size

        # Single Gly: (exclusion keywords, expected sample size)
        gly_cases = (
            ({'exclude_bonds': False}, 10),
            # 3 bonds (N-CA, O-C, C-CA) should be excluded in Gly
            ({'exclude_bonds': True}, 7),
            # A further 2 angles (CA-C-O, N-CA-C) should be excluded in Gly
            ({'exclude_bonds': True, 'exclude_angles': True}, 5),
            # No improper dihedrals
            ({'exclude_bonds': True, 'exclude_angles': True,
              'exclude_dihedrals': True}, 5),
        )
        for exclusions, expected in gly_cases:
            self.assertEqual(sampled(aln, **exclusions), expected)

        mdl = Model(env)
        mdl.build_sequence('GG')
        aln = Alignment(env)
        aln.append_model(mdl, align_codes='test')
        self.assertEqual(sampled(aln), 36)
        # One dihedral (C:CA:+N:O) in Gly-Gly should be excluded
        self.assertEqual(sampled(aln, exclude_dihedrals=True), 35)
示例#6
0
文件: test_mdt.py 项目: salilab/mdt
 def test_sources(self):
     """Check that an alignment file and an in-memory model give equal MDTs.

     The first table is filled from an alignment read from file. The second
     is filled from an alignment built around a Model whose named atom file
     does not exist, so its structure must come from the model itself.
     """
     env = self.get_environ()
     mlib = self.get_mdt_library()
     dist = mdt.features.AtomDistance(mlib,
                                      bins=mdt.uniform_bins(60, 0, 0.5))
     file_table = mdt.Table(mlib, features=dist)
     model_table = mdt.Table(mlib, features=dist)

     file_aln = Alignment(env, file='test/data/tiny.ali', align_codes='5fd1')
     file_table.add_alignment(file_aln)

     mdl = Model(env, file='test/data/5fd1.atm', model_segment=('1:', '6:'))
     model_aln = Alignment(env)
     # Atom file 'foo' does not exist; all data should be taken from mdl
     model_aln.append_model(mdl, align_codes='foo', atom_files='foo')
     model_table.add_alignment(model_aln)

     self.assertMDTsEqual(file_table, model_table)
示例#7
0
文件: test_mdt.py 项目: salilab/mdt
    def test_bond_span_range(self):
        """Check bond_span_range filtering of atom distance samples.

        Covers the following cases:

        * bonds, dihedrals and their union within a single alanine;
        * the bond between two residues, using a non-symmetric scan;
        * the absence of bonds across chains.
        """
        env = self.get_environ()
        mdl = Model(env)
        mdl.build_sequence('A')
        aln = Alignment(env)
        aln.append_model(mdl, align_codes='test')
        mlib = self.get_mdt_library()
        mlib.bond_classes.read('data/bndgrp.lib')
        dist = mdt.features.AtomDistance(mlib,
                                         bins=mdt.uniform_bins(60, 0, 0.5))

        def sample_size(alignment, features, bond_span, residue_span):
            # Scan alignment into a fresh table and report its sample count
            table = mdt.Table(mlib, features=features)
            table.add_alignment(alignment, bond_span_range=bond_span,
                                residue_span_range=residue_span)
            return table.sample_size

        same_residue = (0, 0, 0, 0)
        # Only 4 direct chemical bonds (N-CA, CA-CB, CA-C, C-O) in ALA; note
        # that the bond library does not include OXT, so the C-OXT
        # interaction is excluded.
        self.assertEqual(sample_size(aln, dist, (1, 1), same_residue), 4.0)
        # Only 2 dihedrals (N-CA-C-O, O-C-CA-CB)
        self.assertEqual(sample_size(aln, dist, (3, 3), same_residue), 2.0)
        # 4 bonds, 4 angles and 2 dihedrals: 10 in total
        self.assertEqual(sample_size(aln, dist, (1, 3), same_residue), 10.0)

        # Check for bonds between residues (just the N-C bond here)
        mdl = Model(env)
        mdl.build_sequence('AA')
        aln = Alignment(env)
        aln.append_model(mdl, align_codes='test')
        # Force a non-symmetric scan (to check handling of bond separation
        # regardless of which order the atom indices are in)
        diff = mdt.features.ResidueIndexDifference(
            mlib, bins=mdt.uniform_bins(21, -10, 1))
        self.assertEqual(
            sample_size(aln, (dist, diff), (0, 1), (-10, -1, 1, 10)), 2.0)

        # Bonds never span chains
        mdl = Model(env)
        mdl.build_sequence('A/A')
        aln = Alignment(env)
        aln.append_model(mdl, align_codes='test')
        self.assertEqual(
            sample_size(aln, dist, (0, 99999), (-10, -1, 1, 10)), 0.0)
示例#8
0
import segmentation_models_pytorch as smp
import torch
from losses.BCEJaccard import LossBinary
from data.final_transforms import hardcore_aug, crazy_custom_aug
from modeller import Model

# Resume training from a saved checkpoint: rebuild the U-Net, load the
# stored weights, move the network to the GPU and continue training with the
# project's Model wrapper.
# The checkpoint must be a mapping with a "model_state_dict" entry, as the
# indexing below requires.
checkpoint = torch.load("last.pth")
weights = checkpoint["model_state_dict"]
# No `classes` argument, so smp's default output-class count is used; the
# architecture must match the saved weights for load_state_dict to succeed.
model = smp.Unet("se_resnext50_32x4d")
model.load_state_dict(weights)
# NOTE(review): the network is put into eval mode right before training;
# confirm that Model.train() switches it back to train mode, otherwise
# dropout/batch-norm behave as at inference time during fine-tuning.
model.eval()
model.cuda()
# LossBinary comes from losses.BCEJaccard; the meaning of its 0.3 argument
# is defined there.
modeler = Model(transforms=crazy_custom_aug, criterion=LossBinary(0.3))
# Second argument is presumably the number of epochs -- TODO confirm.
modeler.train(model, 10000)
示例#9
0
# from models.AlbuNet.AlbuNet import AlbuNet
from modeller import Model
from data.final_transforms import very_light_aug, light_aug, hardcore_aug
from losses.BCEJaccard import LossBinary
from catalyst.contrib.criterion import FocalLossMultiClass
import segmentation_models_pytorch as smp

# Three training stages, built with the very_light / light / hardcore
# augmentation pipelines and a multi-class focal loss, run one after another
# on the same network below.
model1_stage = Model(transforms=very_light_aug,
                     criterion=FocalLossMultiClass())
model2_stage = Model(transforms=light_aug, criterion=FocalLossMultiClass())
model3_stage = Model(transforms=hardcore_aug, criterion=FocalLossMultiClass())

# 6-class U-Net with an ImageNet-pretrained SE-ResNeXt50 encoder (an AlbuNet
# network was used here previously).
net = smp.Unet("se_resnext50_32x4d", encoder_weights="imagenet", classes=6)
net.cuda()

# (stage, banner printed before it, second argument to Model.train --
# presumably an epoch count, TODO confirm).
# NOTE(review): the first banner says "No Augmentations" although stage 1
# is built with very_light_aug -- confirm the label is intended.
schedule = (
    (model1_stage, "No Augmentations", 300),
    (model2_stage, "Light Augmentations", 400),
    (model3_stage, "Hardcore Augmentations", 50000),
)
for stage, banner, budget in schedule:
    print(banner)
    stage.train(net, budget)
示例#10
0
import segmentation_models_pytorch as smp
import torch
from losses.BCEJaccard import LossBinary
from catalyst.contrib.criterion import FocalLossMultiClass
from data.final_transforms import hardcore_aug, crazy_custom_aug
from modeller import Model

# Resume training of a 6-class U-Net from the "instance.pth" checkpoint,
# using the hardcore_aug transforms and a multi-class focal loss.
# The checkpoint must be a mapping with a "model_state_dict" entry, as the
# indexing below requires.
checkpoint = torch.load("instance.pth")
weights = checkpoint["model_state_dict"]
# Architecture (encoder and classes=6) must match the saved weights for
# load_state_dict to succeed.
model = smp.Unet("se_resnext50_32x4d", classes=6)
model.load_state_dict(weights)
# NOTE(review): the network is put into eval mode right before training;
# confirm that Model.train() switches it back to train mode, otherwise
# dropout/batch-norm behave as at inference time during fine-tuning.
model.eval()
model.cuda()
modeler = Model(transforms=hardcore_aug, criterion=FocalLossMultiClass())
# Second argument is presumably the number of epochs -- TODO confirm.
modeler.train(model, 10000)
示例#11
0
from models.AlbuNet.AlbuNet import AlbuNet
from modeller import Model
from data.final_transforms import very_light_aug, light_aug, hardcore_aug
from losses.BCEJaccard import LossBinary
import segmentation_models_pytorch as smp

# Three training stages run one after another on the same network below.
# Stage 1 passes no criterion, so it relies on Model's default (not visible
# here); stages 2 and 3 use LossBinary(0.3).
model1_stage = Model(transforms=very_light_aug)
model2_stage = Model(transforms=light_aug, criterion=LossBinary(0.3))
model3_stage = Model(transforms=hardcore_aug, criterion=LossBinary(0.3))

# U-Net with an ImageNet-pretrained SE-ResNeXt50 encoder and smp's default
# number of output classes (an AlbuNet network was used here previously).
net = smp.Unet("se_resnext50_32x4d", encoder_weights="imagenet")
net.cuda()

# (stage, banner printed before it, second argument to Model.train --
# presumably an epoch count, TODO confirm).
# NOTE(review): the first banner says "No Augmentations" although stage 1
# is built with very_light_aug -- confirm the label is intended.
schedule = (
    (model1_stage, "No Augmentations", 30),
    (model2_stage, "Light Augmentations", 40),
    (model3_stage, "Hardcore Augmentations", 50),
)
for stage, banner, budget in schedule:
    print(banner)
    stage.train(net, budget)