Ejemplo n.º 1
0
    def test_resnet101_caffe(self):
        """Build ResNet101_Caffe from pre-trained Caffe weights with several crop settings.

        Skips when ``self.data_dir`` is unset. Otherwise it builds the model with
        ``random_crop='unique'``, then with ``random_crop=None, offsets=None``,
        printing each summary. Finally it checks that an unrecognized
        ``random_crop`` value makes the constructor raise ``ValueError``.
        """
        if self.data_dir is None:
            unittest.TestCase.skipTest(self, "DLPY_DATA_DIR is not set in the environment variables")

        # Constructor arguments shared by every model built in this test.
        shared_kwargs = dict(
            n_channels=3,
            height=224,
            random_flip='HV',
            pre_trained_weights_file=self.data_dir + 'ResNet-101-model.caffemodel.h5',
            pre_trained_weights=True,
            include_top=False,
            n_classes=120,
        )

        net = ResNet101_Caffe(self.s, random_crop='unique', **shared_kwargs)
        net.print_summary()

        net = ResNet101_Caffe(self.s, random_crop=None, offsets=None, **shared_kwargs)
        net.print_summary()

        # An unrecognized random_crop value must be rejected by the constructor.
        with self.assertRaises(ValueError):
            ResNet101_Caffe(self.s, random_crop='wrong_val', **shared_kwargs)
Ejemplo n.º 2
0
    def test_resnet101_caffe(self):
        """Build ResNet101_Caffe from pre-trained Caffe weights with several crop settings.

        Skips when ``self.data_dir`` is unset, or when the weights file
        ``ResNet-101-model.caffemodel.h5`` is not present on the server.
        Otherwise it builds the model with:

        * ``random_crop='unique'``;
        * ``random_crop=None, offsets=None``;
        * ``random_crop='RESIZETHENCROP', random_mutation='random', offsets=None``.

        Each model's summary is printed. The test also checks that an
        unrecognized ``random_crop`` value makes the constructor raise
        ``ValueError``.
        """
        if self.data_dir is None:
            self.skipTest("DLPY_DATA_DIR is not set in the environment variables")
        file_dependency = self.data_dir + 'ResNet-101-model.caffemodel.h5'
        if not file_exist_on_server(self.s, file_dependency):
            self.skipTest("File, {}, not found.".format(file_dependency))

        # Constructor arguments shared by every model built in this test; the
        # weights path is the one already verified to exist above.
        common_kwargs = dict(
            n_channels=3,
            height=224,
            random_flip='HV',
            pre_trained_weights_file=file_dependency,
            pre_trained_weights=True,
            include_top=False,
            n_classes=120,
        )

        model = ResNet101_Caffe(self.s, random_crop='unique', **common_kwargs)
        model.print_summary()

        model = ResNet101_Caffe(self.s, random_crop=None, offsets=None, **common_kwargs)
        model.print_summary()

        # An unrecognized random_crop value must be rejected by the constructor.
        with self.assertRaises(ValueError):
            ResNet101_Caffe(self.s, random_crop='wrong_val', **common_kwargs)

        # Test random_mutation and crop on VDMML 8.4.
        # NOTE(review): there is no server-version guard here, so this case
        # presumably fails on servers older than VDMML 8.4 — confirm.
        model = ResNet101_Caffe(self.s,
                                random_crop='RESIZETHENCROP',
                                random_mutation='random',
                                offsets=None,
                                **common_kwargs)
        model.print_summary()
Ejemplo n.º 3
0
 def test_resnet101_2(self):
     """Smoke-test ResNet101_Caffe built with only the session argument.

     Constructs the model with every other argument left at its default and
     prints its summary. Nothing is asserted beyond the calls not raising.
     """
     from dlpy.applications import ResNet101_Caffe as _resnet101_caffe
     default_net = _resnet101_caffe(self.s)
     default_net.print_summary()