Exemple #1
0
    def test_parse_links(self):
        """Check that ``parse`` returns train/test/validation dataloaders.

        ``parse`` must return exactly three loaders (train, test, valid).
        Each loader is iterated once so that every batch is unpacked into
        three items. The number of batches actually produced is then checked
        against ``len(loader)``, and ``len(loader)`` against the expected
        split size for the ``links.txt`` fixture.
        """
        batch_size = 1
        self.links_file = get_data_path('links.txt')
        self.fasta_file = get_data_path('prots.fa')

        res = parse(self.fasta_file, self.links_file,
                    training_column=4,
                    batch_size=batch_size,
                    num_workers=1, arm_the_gpu=False)
        self.assertEqual(len(res), 3)

        train, test, valid = res

        # (split name, loader, expected number of batches); the third entry
        # makes sure that a validation dataloader is added as well.
        expected_sizes = (
            ('train', train, 83),
            ('test', test, 12),
            ('valid', valid, 5),
        )
        for name, loader, expected in expected_sizes:
            num_batches = 0
            # Unpacking into three names checks the shape of every batch.
            for _g, _p, _n in loader:
                num_batches += 1
            self.assertEqual(len(loader), expected,
                             msg='unexpected length for %s split' % name)
            # The batches were previously counted but the count was never
            # checked; make sure iteration agrees with the reported length.
            self.assertEqual(num_batches, len(loader),
                             msg='batch count mismatch for %s split' % name)
Exemple #2
0
    def setUp(self):
        """Build the fixtures shared by the PPI binder tests.

        A ``DummyModel`` stands in for the pre-trained model. Its parameters
        have ``requires_grad`` switched off before it is wrapped by
        ``PPIBinder``.
        """
        # TODO: fix simple_ppi with dummy model
        self.fasta_file = get_data_path('prots.fa')
        # NOTE(review): resolved against the current working directory, so
        # the tests presumably must be launched from the folder holding data/.
        self.links_dir = os.path.abspath('data/links_files')

        # Stand-in for the pre-trained model.
        vocab_size = len(dictionary)
        hidden = 10
        self.emb_dimension = 3
        self.pretrained_model = DummyModel(vocab_size, hidden)

        # Freeze the stand-in's weights.
        for weight in self.pretrained_model.parameters():
            weight.requires_grad = False

        self.sampler = NegativeSampler(self.fasta_file)
        self.dataloader = InteractionDataDirectory(
            self.fasta_file,
            self.links_dir,
            training_column=4,
        )
        # Despite their names, these hold lists of file paths.
        self.pos_dataloader = [get_data_path('positives.txt')]
        self.neg_dataloader = [get_data_path('negatives.txt')]

        # Binder model built on top of the frozen stand-in.
        self.ppi_model = PPIBinder(hidden, self.emb_dimension,
                                   self.pretrained_model)
Exemple #3
0
    def setUp(self):
        """Load the FASTA and links fixtures and build ``self.pairs``."""
        self.links_file = get_data_path('links.txt')
        self.fasta_file = get_data_path('prots.fa')

        self.seqs = list(SeqIO.parse(self.fasta_file, format='fasta'))
        links = pd.read_table(self.links_file, header=None)

        # Clean every record up front, then index the cleaned records by id;
        # a repeated id keeps the last record seen.
        cleaned = [clean(record) for record in self.seqs]
        seq_lookup = {record.id: record for record in cleaned}
        self.pairs = preprocess(seq_lookup, links)
Exemple #4
0
    def test_parse_negative(self):
        """Parse the ``negative.txt`` fixture and check which splits exist.

        The train split is expected to be ``None``. The test and validation
        splits must be present, and the validation split must have length 2.
        """
        self.links_file = get_data_path('negative.txt')
        self.fasta_file = get_data_path('prots.fa')

        splits = parse(self.fasta_file, self.links_file,
                       training_column=4, batch_size=1,
                       num_workers=1, arm_the_gpu=False)
        self.assertEqual(len(splits), 3)

        train, test, valid = splits
        self.assertIsNone(train)
        self.assertIsNotNone(test)
        self.assertIsNotNone(valid)
        self.assertEqual(len(valid), 2)
Exemple #5
0
    def setUp(self):
        """Set up fixture paths and the machine-specific model/data locations.

        ``self.checkpoint`` and ``self.data_dir`` are absolute paths on one
        particular machine (they appear to hold uniref50 checkpoints and
        pretraining data, judging by the path names). They can be overridden
        with the ``GERT_CHECKPOINT_DIR`` and ``GERT_DATA_DIR`` environment
        variables. When those are unset, the defaults are exactly the paths
        this setUp has always ended up using.
        """
        self.fasta_file = get_data_path('prots.fa')
        # NOTE(review): despite the name this looks like a directory, and it
        # is resolved against the current working directory.
        self.links_file = os.path.abspath('data/links_files')
        self.logging1 = 'logging1'
        self.logging2 = 'logging2'
        self.modelpath = 'model.pkt'

        # Other locations previously listed for these artifacts (popeye
        # scratch): /simons/scratch/jmorton/mgt/checkpoints/uniref90 and
        # /simons/scratch/jmorton/mgt/data/uniref50. They were always
        # overwritten by the paths below, so they were never in effect.
        self.checkpoint = os.environ.get(
            'GERT_CHECKPOINT_DIR',
            '/mnt/home/jmorton/research/gert/data/full/uniref50/checkpoints')
        self.data_dir = os.environ.get(
            'GERT_DATA_DIR',
            '/mnt/home/jmorton/research/gert/data/full/uniref50/pretrain_data')
Exemple #6
0
 def setUp(self):
     """Resolve the ``links.txt`` and ``prots.fa`` fixture paths.

     Both paths come from the project helper ``get_data_path``.
     """
     self.links_file = get_data_path('links.txt')
     self.fasta_file = get_data_path('prots.fa')