def run_classes(self, **kwargs):
    """Run ``self.run_seeds`` for every (or every selected) normal class and aggregate.

    Required keyword arguments:
        it: forwarded as the first positional argument of ``self.run_seeds``.
        logdir: base log directory; each class runs with
            ``pt.join(logdir, 'normal_<c>')``. For plotting, the substring
            ``'{t}'`` in it is replaced by ``time_format(self.start)``.
        dataset: used for the default class list (``no_classes``) and for the
            plot labels (``str_labels``).

    Optional keyword arguments:
        cls_restrictions: iterable of classes to run instead of
            ``range(no_classes(dataset))``; ``None`` means all classes.

    All remaining keyword arguments are forwarded to ``self.run_seeds`` together
    with the per-class ``logdir`` and ``normal_class``.

    After every class -- including when ``run_seeds`` raises -- the ROC curves of
    all classes collected so far are plotted, so completed classes are plotted
    even if a later class fails.

    Returns:
        dict mapping each result key to ``mean_roc`` of its per-class results.
    """
    results = defaultdict(list)
    kwargs = dict(kwargs)  # copy so the caller's mapping is never mutated
    it = kwargs.pop('it')
    logdir = kwargs['logdir']
    restrictions = kwargs.pop('cls_restrictions', None)
    # Only ask for the number of classes when no explicit restriction is given.
    classes = restrictions if restrictions is not None else range(no_classes(kwargs['dataset']))

    def plot_results():
        # Plot one ROC figure per result key over all classes collected so far.
        if not results:
            return
        plot_dir = logdir.replace('{t}', time_format(self.start))
        labels = str_labels(kwargs['dataset'])
        for key, rocs in results.items():
            plot_many_roc(plot_dir, rocs, labels=labels, mean=True, name=key)

    for c in classes:
        kwargs['logdir'] = pt.join(logdir, 'normal_{}'.format(c))
        kwargs['normal_class'] = c
        try:
            res = self.run_seeds(it, **kwargs)
            for key in res:
                results[key].append(res[key])
        finally:
            # Runs after the last class too, so no separate post-loop plot is
            # needed (the original re-plotted identical figures a second time).
            print('Plotting ROC for completed classes up to {}...'.format(c))
            plot_results()
    return {key: mean_roc(rocs) for key, rocs in results.items()}
def run_seeds(self, it: int, **kwargs):
    """Run ``self.run_one`` once per iteration index and aggregate the results.

    Required keyword arguments:
        logdir: base log directory; iteration ``i`` runs with
            ``pt.join(logdir, 'it_<i>')``. For plotting, the substring ``'{t}'``
            in it is replaced by ``time_format(self.start)``.
        viz_ids: passed to ``extract_viz_ids`` together with
            ``kwargs['normal_class']`` (which must therefore be present whenever
            at least one iteration runs) and the iteration index.

    Optional keyword arguments:
        its_restrictions: iterable of iteration indices to run instead of
            ``range(it)``; ``None`` means all of ``range(it)``.

    All remaining keyword arguments are forwarded to ``self.run_one``.

    Returns:
        dict mapping each result key to ``mean_roc`` of its per-iteration results.
    """
    opts = dict(kwargs)
    base_logdir = opts.pop('logdir')
    all_viz_ids = opts.pop('viz_ids')
    seeds = range(it)
    restricted = opts.pop('its_restrictions', None)
    if restricted is not None:
        seeds = restricted

    collected = defaultdict(list)
    for seed in seeds:
        opts['logdir'] = pt.join(base_logdir, 'it_{}'.format(seed))
        seed_viz_ids = extract_viz_ids(all_viz_ids, opts['normal_class'], seed)
        outcome = self.run_one(seed_viz_ids, **opts)
        for name in outcome:
            collected[name].append(outcome[name])

    for name, rocs in collected.items():
        plot_many_roc(
            base_logdir.replace('{t}', time_format(self.start)),
            rocs,
            labels=seeds,
            mean=True,
            name=name,
        )
    return {name: mean_roc(rocs) for name, rocs in collected.items()}