コード例 #1
0
ファイル: nutszebra_log_model.py プロジェクト: rzel/trainer
 def generate_stat_figure(self):
     """Plot the model statistics stored as JSON under ``self.stat_path``.

     Every JSON file found recursively under ``self.stat_path`` is loaded in
     numeric order of its file-name stem (``2.json`` before ``10.json``).
     Each file is expected to map ``key1 -> key2 -> {'mean', 'max', 'min',
     'var'}``; the values are accumulated into one series per
     ``'{key1}_{key2}'`` entry.  Figures written to ``self.figure_path``:

     * ``model_stat_mean_<key>.jpg``: mean, max and min of one entry;
     * ``model_stat_var_<key>.jpg``: variance of one entry;
     * ``model_stat_all_<stat>.jpg`` for ``stat`` in mean, var, max, min:
       that statistic of every entry overlaid in one figure.
     """
     import os

     files = utility.reg_extract(
         utility.find_files_recursively(self.stat_path), utility.reg_json)
     # Numeric sort on the file-name stem; os.path.basename (rather than
     # split('/')) also handles platforms whose separator is not '/'.
     files = sorted(files,
                    key=lambda s: int(os.path.basename(s).split('.')[0]))
     record_mean = defaultdict(list)
     record_max = defaultdict(list)
     record_min = defaultdict(list)
     record_var = defaultdict(list)
     for f in files:
         tmp = utility.load_json(f)
         for key1 in tmp:
             for key2 in tmp[key1]:
                 key = '{}_{}'.format(key1, key2)
                 stat = tmp[key1][key2]
                 record_mean[key].append(stat['mean'])
                 record_max[key].append(stat['max'])
                 record_min[key].append(stat['min'])
                 record_var[key].append(stat['var'])
     # Per-entry figures.  No plt.draw() needed: savefig renders the figure.
     for key in record_mean:
         plt.clf()
         plt.plot(record_mean[key], label='mean')
         plt.plot(record_max[key], label='max')
         plt.plot(record_min[key], label='min')
         plt.title(key)
         plt.legend(loc='upper left')
         plt.savefig('{}model_stat_mean_{}.jpg'.format(
             self.figure_path, key))
     for key in record_var:
         plt.clf()
         plt.plot(record_var[key])
         plt.title(key)
         plt.savefig('{}model_stat_var_{}.jpg'.format(
             self.figure_path, key))
     # One overlay figure per statistic, every entry on the same axes.
     for name, record in (('mean', record_mean), ('var', record_var),
                          ('max', record_max), ('min', record_min)):
         plt.clf()
         for key in record:
             plt.plot(record[key])
         plt.title('all {}'.format(name))
         plt.savefig('{}model_stat_all_{}.jpg'.format(
             self.figure_path, name))
コード例 #2
0
ファイル: nutszebra_log_model.py プロジェクト: rzel/trainer
 def __init__(self, model, save_path='./'):
     """Set up logging of ``model``'s statistics under ``save_path``.

     Args:
         model: model whose links are collected with
             ``LogModel._track_link`` and stored in ``self.links``.
         save_path (str): base directory for the logs.  A trailing '/' is
             appended when missing; an empty string keeps the log paths
             relative to the current directory (``log/...``).

     Creates ``<save_path>log/model_stat/`` and ``<save_path>log/grad_stat/``
     with ``utility.make_dir``; ``self.figure_path`` is ``<save_path>log/``.
     """
     self.model = model
     self.links = LogModel._track_link(model)
     # endswith() instead of indexing [-1], which raises IndexError on ''.
     # An empty path is left as-is so paths stay relative, not '/log/...'.
     if save_path and not save_path.endswith('/'):
         save_path += '/'
     self.stat_path = save_path + 'log/model_stat/'
     self.grad_path = save_path + 'log/grad_stat/'
     self.figure_path = save_path + 'log/'
     # Both counters start at zero.
     self.stat_count = 0
     self.grad_count = 0
     utility.make_dir(self.stat_path)
     utility.make_dir(self.grad_path)
コード例 #3
0
 def _train(path, prefix):
     """Group training image paths by category.

     Each non-blank line of ``path`` must start with a ``<category>/<name>``
     token; anything after the first whitespace is ignored.  The token
     becomes the image path ``'<prefix>/<category>/<name>.JPEG'``.

     Args:
         path (str): list file, read line by line via ``utility.yield_text``.
         prefix (str): directory prepended to every image path.

     Returns:
         defaultdict(list): category -> image paths, in file order.
     """
     train = defaultdict(list)
     for line in utility.yield_text(path):
         # split() (not split(' ')) also drops a trailing newline/tab, and
         # an empty result identifies a blank line to skip.
         fields = line.split()
         if not fields:
             continue
         category, name = fields[0].split('/')
         train[category].append('{}/{}/{}.JPEG'.format(
             prefix, category, name))
     return train
コード例 #4
0
 def _val(path, prefix1, prefix2):
     """Group validation image paths by the label of their first object.

     For each non-blank line of ``path`` the first whitespace-separated token
     is the image name.  Its label is the text of the ``name`` child of the
     first ``object`` element in ``'<prefix2>/<name>.xml'``; the image path
     ``'<prefix1>/<name>.JPEG'`` is appended under that label.

     Args:
         path (str): list file, read line by line via ``utility.yield_text``.
         prefix1 (str): directory holding the ``.JPEG`` images.
         prefix2 (str): directory holding the ``.xml`` annotations.

     Returns:
         defaultdict(list): label -> image paths, in file order.

     Raises:
         ValueError: if an annotation file has no ``object`` element.
     """
     val = defaultdict(list)
     for line in utility.yield_text(path):
         fields = line.split()
         if not fields:
             continue
         name = fields[0]
         annotation = '{}/{}.xml'.format(prefix2, name)
         obj = ET.parse(annotation).getroot().find('object')
         # Without this check a missing <object> would either leave the
         # label unbound or silently reuse the previous image's label.
         if obj is None:
             raise ValueError(
                 'no <object> element in {}'.format(annotation))
         key = obj.find('name').text
         val[key].append('{}/{}.JPEG'.format(prefix1, name))
     return val
コード例 #5
0
ファイル: nutszebra_log2.py プロジェクト: rzel/trainer
    def save(self, path):
        """Save ``self.log`` to ``path`` as JSON.

        Edited date:
            161014

        Examples:

        ::

            self.save('./log.json')

        Args:
            path (str): destination file; it has to be json

        Returns:
            True if successful (errors raised by ``utility.save_json``
            propagate to the caller)
        """

        utility.save_json(self.log, path)
        # Honour the documented contract, which previously returned None.
        return True
コード例 #6
0
ファイル: nutszebra_log2.py プロジェクト: rzel/trainer
    def load(self, path):
        """Replace ``self.log`` with the contents of a JSON file.

        Edited date:
            161014

        Example::

            self.load('./log.json')

        Args:
            path (str): path of the JSON file to read (it has to be json).
        """
        loaded_log = utility.load_json(path)
        self.log = loaded_log
コード例 #7
0
 def __init__(self, ilsvrc_path, flag_debug=False):
     """Load, or build and cache, the ILSVRC CLS-LOC image lists.

     For each ``(f, key)`` pair of ``self.filename`` and ``self.keys``: if
     ``<ilsvrc_path>/<key>.pkl`` already exists it is loaded into
     ``self[key]``.  Otherwise the list file
     ``<ilsvrc_path>/ImageSets/CLS-LOC/<f>`` is parsed with ``_train``,
     ``_val`` or ``_test``.  The parser is chosen by whether ``'train'``,
     ``'val'`` or ``'test'`` occurs in ``f``, checked in that order.  The
     result is stored in ``self[key]`` and pickled to ``<key>.pkl``.  Files
     matching none of the three are skipped.  Finally
     ``self.debug(flag_debug)`` is called.

     Args:
         ilsvrc_path (str): dataset root; one trailing '/' is ignored.
         flag_debug (bool): forwarded to ``self.debug``.
     """
     super(LoadDataset, self).__init__()
     ilsvrc_path = ilsvrc_path[:-1] if ilsvrc_path[
         -1] == '/' else ilsvrc_path
     print('loading ILSVRC dataset')
     # List the directory once instead of on every iteration; the set is
     # updated below whenever a new pickle is written.
     cached = set(utility.find_files(ilsvrc_path, affix_flag=True))
     for f, key in six.moves.zip(self.filename, self.keys):
         print('    {}'.format(f))
         pickle_name = '{}.pkl'.format(key)
         pickle_path = '{}/{}'.format(ilsvrc_path, pickle_name)
         if pickle_name in cached:
             print('        Already {} were loaded before'.format(f))
             self[key] = utility.load_pickle(pickle_path)
             continue
         list_path = '{}/ImageSets/CLS-LOC/{}'.format(ilsvrc_path, f)
         if 'train' in f:
             loader, args = self._train, (
                 list_path, '{}/Data/CLS-LOC/train'.format(ilsvrc_path))
         elif 'val' in f:
             loader, args = self._val, (
                 list_path, '{}/Data/CLS-LOC/val'.format(ilsvrc_path),
                 '{}/Annotations/CLS-LOC/val'.format(ilsvrc_path))
         elif 'test' in f:
             loader, args = self._test, (
                 list_path, '{}/Data/CLS-LOC/test'.format(ilsvrc_path))
         else:
             # Unrecognised list file: nothing is loaded for this key.
             continue
         print('        Loading')
         self[key] = loader(*args)
         utility.save_pickle(self[key], pickle_path)
         cached.add(pickle_name)
         print('        Done')
     self.debug(flag_debug)
コード例 #8
0
 def _test(path, prefix):
     """List test image paths.

     The first whitespace-separated token of each non-blank line of ``path``
     is an image name, turned into ``'<prefix>/<name>.JPEG'``.

     Args:
         path (str): list file, read line by line via ``utility.yield_text``.
         prefix (str): directory holding the images.

     Returns:
         list: image paths, in file order.
     """
     not_train = []
     for line in utility.yield_text(path):
         # Skipping blank lines avoids a bogus '<prefix>/.JPEG' entry; split()
         # also drops a trailing newline/tab from the name.
         fields = line.split()
         if not fields:
             continue
         not_train.append('{}/{}.JPEG'.format(prefix, fields[0]))
     return not_train
コード例 #9
0
ファイル: nutszebra_log_model.py プロジェクト: rzel/trainer
 def save(info, path):
     """Write ``info`` to ``path`` using ``utility.save_json``; returns None."""
     destination = path
     utility.save_json(info, destination)