Пример #1
0
    def setUp(self):
        """Builds the fixtures used by the tests in this case.

        Creates a CAV whose vectors are assigned by hand, a test model with
        its activation generator, and two TCAV instances: one built from the
        base arguments only and one that additionally receives
        ``self.random_counterpart`` as a trailing positional argument.
        """
        # Plain-value fixtures.
        self.class_id = 0
        self.target = 't1'
        self.bottleneck = 'bn'
        self.concepts = ['c1', 'c2']
        self.random_counterpart = 'random500_1'
        self.cav_dir = None
        self.activation_generator = None
        self.examples = [None] * 3
        self.acts = np.array([[0, 1., 2.]])
        self.hparams = {'model_type': 'linear', 'alpha': .01}

        # The CAV vectors are set directly instead of being trained.
        self.cav = CAV(self.concepts, self.bottleneck, self.hparams)
        self.cav.cavs = [[1., 2., 3.]]

        self.mymodel = TcavTest_model()
        self.act_gen = TcavTest_ActGen(self.mymodel)

        def base_args():
            # Fresh list objects on every call so the two TCAV instances
            # never share the bottleneck / alpha containers.
            return (None, self.target, self.concepts, [self.bottleneck],
                    self.act_gen, [self.hparams['alpha']])

        self.mytcav = TCAV(*base_args())

        counterpart_args = base_args() + (self.random_counterpart,)
        self.mytcav_random_counterpart = TCAV(*counterpart_args)
Пример #2
0
    def setUp(self):
        """Builds the fixtures used by the tests in this case.

        Hyper-parameters are held in a ``tf.contrib.training.HParams``
        object. A CAV with hand-assigned vectors, a test model with its
        activation generator, and two TCAV instances are created; the second
        TCAV additionally receives ``self.random_counterpart`` as a trailing
        positional argument.
        """
        # Plain-value fixtures.
        self.class_id = 0
        self.target = 't1'
        self.bottleneck = 'bn'
        self.concepts = ['c1', 'c2']
        self.random_counterpart = 'random500_1'
        self.cav_dir = None
        self.activation_generator = None
        self.acts = np.array([[0, 1., 2.]])
        self.hparams = tf.contrib.training.HParams(
            model_type='linear', alpha=.01)

        # The CAV vectors are set directly instead of being trained.
        self.cav = CAV(self.concepts, self.bottleneck, self.hparams)
        self.cav.cavs = [[1., 2., 3.]]

        self.mymodel = TcavTest_model()
        self.act_gen = TcavTest_ActGen(self.mymodel)

        def base_args():
            # Fresh list objects on every call so the two TCAV instances
            # never share the bottleneck / alpha containers.
            return (None, self.target, self.concepts, [self.bottleneck],
                    self.act_gen, [self.hparams.alpha])

        self.mytcav = TCAV(*base_args())

        counterpart_args = base_args() + (self.random_counterpart,)
        self.mytcav_random_counterpart = TCAV(*counterpart_args)
Пример #3
0
    def setUp(self):
        """Makes a CAV instance and writes a pickled CAV record to a tmp directory.

        The CAV instance uses preset values. The pickled record is a dict
        with the keys 'concepts', 'bottleneck', 'hparams', 'accuracies',
        'cavs' and 'saved_path', stored under
        ``<FLAGS.tcav_test_tmpdir>/test``. Any existing content of that
        directory is removed first.
        """
        self.hparams = {
            'model_type': 'linear',
            'alpha': .01,
            'max_iter': 1000,
            'tol': 1e-3
        }
        self.concepts = ['concept1', 'concept2']
        self.bottleneck = 'bottleneck'
        self.accuracies = {'concept1': 0.8, 'concept2': 0.5, 'overall': 0.65}
        self.cav_vecs = [[1, 2, 3], [4, 5, 6]]

        self.test_subdirectory = os.path.join(FLAGS.tcav_test_tmpdir, 'test')
        self.cav_dir = self.test_subdirectory
        self.cav_file_name = CAV.cav_key(self.concepts, self.bottleneck,
                                         self.hparams['model_type'],
                                         self.hparams['alpha']) + '.pkl'
        self.save_path = os.path.join(self.cav_dir, self.cav_file_name)
        self.cav = CAV(self.concepts, self.bottleneck, self.hparams)
        # Pretend that it was trained and the cavs are stored.
        self.cav.cavs = np.array(self.cav_vecs)
        shape = (1, 3)
        # For the i-th concept: a (4, 3) array filled with the value i,
        # keyed by the bottleneck name.
        self.acts = {
            concept: {
                self.bottleneck: np.tile(i * np.ones(shape), (4, 1))
            }
            for i, concept in enumerate(self.concepts)
        }

        # Start from an empty directory on every run.
        if os.path.exists(self.cav_dir):
            shutil.rmtree(self.cav_dir)
        # makedirs (rather than mkdir) so a missing parent tmp directory
        # does not make the fixture fail.
        os.makedirs(self.cav_dir)
        # pickle emits bytes, so the file is opened in binary mode.
        with tf.io.gfile.GFile(self.save_path, 'wb') as pkl_file:
            pickle.dump(
                {
                    'concepts': self.concepts,
                    'bottleneck': self.bottleneck,
                    'hparams': self.hparams,
                    'accuracies': self.accuracies,
                    'cavs': self.cav_vecs,
                    'saved_path': self.save_path
                }, pkl_file)