예제 #1
0
파일: tcav.py 프로젝트: justcho5/tcav
  def run(self, num_workers=10, run_parallel=False, overwrite=False, return_proto=False):
    """Run TCAV for all parameters (concept and random).

    Each entry of self.params is passed to self._run_single_set, and the
    per-param results are collected in the order of self.params.

    Args:
      num_workers: number of pool workers used when run_parallel is True.
      run_parallel: if True, run the params on a multiprocessing.Pool;
        otherwise run them one after another in the calling thread.
      overwrite: if True, overwrite any saved CAV files (forwarded to
        _run_single_set).
      return_proto: if True, returns results as a tcav.Results object; else,
        return as a list of dicts.

    Returns:
      results: an object (either a Results proto object or a list of
        dictionaries) containing metrics for TCAV results.
    """
    self.run_parallel = run_parallel
    # for random exp,  a machine with cpu = 30, ram = 300G, disk = 10G and
    # pool worker 50 seems to work.
    tf.logging.info('running %s params', len(self.params))
    now = time.time()
    if run_parallel:
      # NOTE(review): a lambda can only be handed to a thread-backed pool
      # (e.g. `from multiprocessing import dummy as multiprocessing`); a
      # process pool cannot pickle it. Confirm against the file's imports.
      pool = multiprocessing.Pool(num_workers)
      try:
        results = pool.map(
            lambda param: self._run_single_set(param, overwrite=overwrite),
            self.params)
      finally:
        # Always release the workers, also when a param raises.
        pool.terminate()
        pool.join()
    else:
      results = []
      for i, param in enumerate(self.params):
        tf.logging.info('Running param %s of %s', i, len(self.params))
        results.append(self._run_single_set(param, overwrite=overwrite))
    tf.logging.info('Done running %s params. Took %s seconds...',
                    len(self.params), time.time() - now)
    if return_proto:
      return utils.results_to_proto(results)
    return results
    def test_results_to_proto(self):
        # Values shared by the input dict and the expected proto below.
        accuracies = {'c1': .526, 'c2': .414, 'overall': .47}
        dirs = (0, .25, .5)
        alpha = 0.1

        results = [{
            'cav_key': 'c1-c2-b-model-0.1',
            'cav_concept': 'c1',
            'negative_concept': 'c2',
            'target_class': 0,
            'cav_accuracies': dict(accuracies),
            'i_up': 0.342,
            'val_directional_dirs_abs_mean': 0.25,
            'val_directional_dirs_mean': 0.25,
            'val_directional_dirs_std': 0.144,
            'val_directional_dirs': list(dirs),
            'note': 'alpha_' + str(alpha),
            'alpha': alpha,
            'bottleneck': 'b'
        }]
        source = results[0]

        expected = Result()
        # These fields are copied one-to-one from the dict onto the proto.
        for field in ('cav_key', 'cav_concept', 'negative_concept',
                      'target_class', 'i_up',
                      'val_directional_dirs_abs_mean',
                      'val_directional_dirs_mean',
                      'val_directional_dirs_std', 'note', 'alpha',
                      'bottleneck'):
            setattr(expected, field, source[field])
        expected.val_directional_dirs.extend(dirs)

        # The dict keys c1 / c2 / overall map to positive / negative / overall.
        accuracy_proto = expected.cav_accuracies
        accuracy_proto.positive_set_accuracy = accuracies['c1']
        accuracy_proto.negative_set_accuracy = accuracies['c2']
        accuracy_proto.overall_accuracy = accuracies['overall']

        expected_results = Results()
        expected_results.results.append(expected)

        self.assertEqual(expected_results, results_to_proto(results))
예제 #3
0
    def run(self,
            num_workers=10,
            run_parallel=False,
            overwrite=False,
            return_proto=False,
            save_interval=100,
            existing_results=None):
        """Run TCAV for all parameters (concept and random).

        The results gathered so far are periodically pickled to
        'result_<i>.pickle' in the current working directory.

        Args:
          num_workers: number of pool workers used when run_parallel is True.
          run_parallel: run this parallel.
          overwrite: if True, overwrite any saved CAV files.
          return_proto: if True, returns results as a tcav.Results object; else,
            return as a list of dicts.
          save_interval: write a checkpoint whenever the progress counter i is
            a multiple of this value. In sequential mode i is the 0-based
            param index; in parallel mode it is the 1-based count of
            finished params. 0 or None disables checkpointing.
          existing_results: sequential mode only. If not None, params whose
            0-based index is <= existing_results are skipped (resume after a
            checkpoint). Ignored in parallel mode.

        Returns:
          results: an object (either a Results proto object or a list of
            dictionaries) containing metrics for TCAV results. Only params
            run by this call are included.
        """
        # for random exp,  a machine with cpu = 30, ram = 300G, disk = 10G and
        # pool worker 50 seems to work.
        total = len(self.params)
        tf.logging.info('running %s params', total)
        print('params num: ', total)
        results = []

        def _checkpoint(i):
            """Pickle `results` to 'result_<i>.pickle' if i hits the interval.

            Returns True when a checkpoint file was written.
            """
            # Treat 0/None as "never checkpoint" instead of dividing by it.
            if not save_interval or i % save_interval != 0:
                return False
            with open('result_' + str(i) + '.pickle', 'wb') as handle:
                pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)
            print('Finished running param %s of %s' % (i, total))
            return True

        now = time.time()
        if run_parallel:
            # NOTE(review): a lambda can only be handed to a thread-backed
            # pool (e.g. multiprocessing.dummy); a process pool cannot pickle
            # it. Confirm against the file's imports.
            pool = multiprocessing.Pool(num_workers)
            try:
                for i, res in enumerate(
                        pool.imap(
                            lambda p: self._run_single_set(
                                p, overwrite=overwrite,
                                run_parallel=run_parallel),
                            self.params), 1):
                    results.append(res)
                    _checkpoint(i)
            finally:
                # Always release the workers, also when a param raises.
                pool.terminate()
                pool.join()
        else:
            for i, param in enumerate(self.params):
                # Resume support: skip indices already covered by an earlier
                # run. `is not None` so that resuming from index 0 works.
                if existing_results is not None and i <= existing_results:
                    continue
                tf.logging.info('Running param %s of %s', i, total)
                results.append(
                    self._run_single_set(param,
                                         overwrite=overwrite,
                                         run_parallel=run_parallel))
                if _checkpoint(i):
                    print(param.bottleneck, param.concepts,
                          param.target_class)
        tf.logging.info('Done running %s params. Took %s seconds...', total,
                        time.time() - now)
        if return_proto:
            return utils.results_to_proto(results)
        return results
예제 #4
0
파일: tcav.py 프로젝트: munema/tcav
    def run(self,
            num_workers=10,
            run_parallel=False,
            overwrite=False,
            return_proto=False):
        """Run TCAV for all parameters (concept and random).

        The sequential path filters self.params using instance flags:
          * make_random falsy: params whose first concept name contains
            'random' are skipped.
          * make_random truthy: only 'random' params are run, and only when
            their result file in self.tcav_dir does not exist yet.
          * true_cav truthy: only params whose (first concept, bottleneck)
            lies in self.concepts x self.bottlenecks are run, each pair at
            most once.
        NOTE(review): the parallel path runs every param unfiltered; confirm
        whether that is intended.

        Unless return_proto is set, the result list is also saved with
        pickle_dump under self.tcav_dir when make_random is falsy (the file
        name is prefixed with 'trueCAV-' when true_cav is truthy).

        Args:
          num_workers: number of pool workers used when run_parallel is True.
          run_parallel: run this parallel.
          overwrite: if True, overwrite any saved CAV files.
          return_proto: if True, returns results as a tcav.Results object; else,
            return as a list of dicts.

        Returns:
          results: an object (either a Results proto object or a list of
            dictionaries) containing metrics for TCAV results.
        """
        # for random exp,  a machine with cpu = 30, ram = 300G, disk = 10G and
        # pool worker 50 seems to work.
        tf.logging.info('running %s params', len(self.params))
        tf.logging.info('training with alpha={}'.format(self.alphas))
        total = len(self.params)

        # Suffix identifying the gradient variant; used both in the per-param
        # result file names and in the final pickle file name.
        keyword = ''
        if self.logit_grad:
            keyword += ':logit_grad'
        if self.grad_nomalize:
            keyword += ':grad_nomalize'

        if self.true_cav:
            true_cav_concepts = set(self.concepts)
            true_cav_bottlenecks = set(self.bottlenecks)
        else:
            true_cav_concepts = true_cav_bottlenecks = set()
        # (concept, bottleneck) pairs already run in true-CAV mode.
        done_pairs = set()

        def _should_skip(param):
            """Return True if the sequential loop must not run `param`.

            In true-CAV mode this also records the param's (concept,
            bottleneck) pair as done, so a later duplicate is skipped.
            """
            concept = param.concepts[0]
            # Skip random concepts.
            if 'random' in concept and not self.make_random:
                return True
            # Compute random concepts only, and only those without a result
            # file yet.
            if self.make_random and (
                    'random' not in concept or os.path.exists(
                        self.tcav_dir + '{}:{}:{}:{}_{}{}'.format(
                            param.bottleneck, param.target_class,
                            param.alpha, concept, param.concepts[1],
                            keyword))):
                return True
            # Compute with the true CAV: each eligible pair exactly once.
            if self.true_cav:
                pair = (concept, param.bottleneck)
                # `or`, not `and`: a param outside concepts x bottlenecks must
                # be skipped (with `and` it fell through to a KeyError).
                if (concept not in true_cav_concepts
                        or param.bottleneck not in true_cav_bottlenecks
                        or pair in done_pairs):
                    return True
                done_pairs.add(pair)
            return False

        results = []
        now = time.time()
        if run_parallel:
            # NOTE(review): a lambda can only be handed to a thread-backed
            # pool (e.g. multiprocessing.dummy); a process pool cannot pickle
            # it. Confirm against the file's imports.
            pool = multiprocessing.Pool(num_workers)
            try:
                for i, res in enumerate(
                        pool.imap(
                            lambda p: self._run_single_set(
                                p, overwrite=overwrite,
                                run_parallel=run_parallel),
                            self.params), 1):
                    tf.logging.info('Finished running param %s of %s', i,
                                    total)
                    results.append(res)
            finally:
                # Always release the workers, also when a param raises.
                pool.terminate()
                pool.join()
        else:
            for i, param in enumerate(self.params):
                tf.logging.info('Running param %s of %s', i, total)
                if _should_skip(param):
                    continue
                results.append(
                    self._run_single_set(param,
                                         overwrite=overwrite,
                                         run_parallel=run_parallel))
        tf.logging.info('Done running %s params. Took %s seconds...', total,
                        time.time() - now)

        if return_proto:
            return utils.results_to_proto(results)
        if not self.make_random:
            prefix = 'trueCAV-' if self.true_cav else ''
            pickle_dump(results,
                        self.tcav_dir + prefix + self.project_name + keyword)
        return results