def _process_after_batch(self, epoch, batch_index, iter_index, packs,
                          **kwargs):
     # Logging
     if chk_d(self._meters, 'counter_log', lambda c: c.check(iter_index)):
         if 'lmd_generate_log' in kwargs.keys():
             kwargs['lmd_generate_log']()
         # (1) Logs
         if 'log_main' in self._logs.keys() and not chk_d(
                 kwargs, 'disable_log'):
             # Update io & optimize timers
             if 'io' in self._meters['timers']:
                 packs['log']['t_io'] = self._meters['timers'][
                     'io'].get_duration_and_reset()
             if 'opt' in self._meters['timers']:
                 packs['log']['t_opt'] = self._meters['timers'][
                     'opt'].get_duration_and_reset()
             # Show information
             log_kwargs = {'items': packs['log']} if 'lmd_process_log' not in kwargs.keys() else \
                 kwargs['lmd_process_log'](packs['log'])
             self._logs['log_main'].info_formatted(
                 [epoch, batch_index, iter_index], **log_kwargs)
         # (2) Tensorboard
         if 'tfboard' in self._logs.keys() and not chk_d(
                 kwargs, 'disable_tfboard'):
             tfboard_add_multi_scalars(self._logs['tfboard'],
                                       packs['tfboard'],
                                       global_step=iter_index)
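
# `chk_d` is a small dict-checking helper from this project whose implementation is
# not shown in these snippets. A minimal sketch that is consistent with the call
# sites in this file (plain truthiness check, an optional predicate, a 'not' flag,
# and '=='/'!=' comparisons); treat it as an assumption, not the project's actual code:
def chk_d(container, key, *args):
    """Return False if `key` is missing, otherwise apply the optional check."""
    if key not in container:
        return False
    value = container[key]
    if not args:                     # e.g. chk_d(kwargs, 'disable_log')
        return bool(value)
    check = args[0]
    if callable(check):              # e.g. chk_d(meters, 'counter_log', lambda c: c.check(i))
        return bool(check(value))
    if check == 'not':               # e.g. chk_d(kwargs, 'use_tfboard', 'not')
        return not value
    if check == '==' and len(args) == 2:   # e.g. chk_d(kwargs, 'type', '==', bool)
        return value == args[1]
    if check == '!=' and len(args) == 2:   # e.g. chk_d(args_dict, 'hfunc', '!=', 'none')
        return value != args[1]
    raise ValueError("Unsupported check: %r" % (check,))
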
 def _process_log_after_step(self, packs, **kwargs):
     # Get iteration for log & tfboard
     iter_checker = kwargs.get('iter_checker', 'iter')
     # Logging
     if chk_d(self._meters, 'counter_log',
              lambda c: c.check(self._meters['i'][iter_checker])):
         if 'lmd_generate_log' in kwargs.keys():
             kwargs['lmd_generate_log']()
         # (1) Logs
         if ('log_main' in self._logs.keys() and 'log' in packs.keys()
                 and not chk_d(kwargs, 'disable_log')):
             # Update io & optimize timers
             if 'io' in self._meters['timers']:
                 packs['log']['t_io'] = self._meters['timers'][
                     'io'].get_duration_and_reset()
             if 'opt' in self._meters['timers']:
                 packs['log']['t_opt'] = self._meters['timers'][
                     'opt'].get_duration_and_reset()
             # Show information
             log_kwargs = {'items': packs['log']} if 'lmd_process_log' not in kwargs.keys() else \
                 kwargs['lmd_process_log'](packs['log'])
             self._logs['log_main'].info_formatted(
                 fet_d(self._meters['i'],
                       *self._logs['log_main'].formatted_counters),
                 **log_kwargs)
         # (2) Tensorboard
         if ('tfboard' in self._logs.keys() and 'tfboard' in packs.keys()
                 and not chk_d(kwargs, 'disable_tfboard')):
             tfboard_add_multi_scalars(self._logs['tfboard'],
                                       packs['tfboard'],
                                       self._meters['i'][iter_checker])
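
# `tfboard_add_multi_scalars` and `fet_d` are also project utilities that are not
# shown here. Hypothetical stand-ins consistent with how they are called above:
# the former writes a {name: value} dict onto a SummaryWriter, the latter fetches
# a sub-dict of the requested keys (here, the counters used to format a log line).
def tfboard_add_multi_scalars(writer, scalars, global_step):
    # Write every scalar in the pack under its own tag.
    for name, value in scalars.items():
        writer.add_scalar(name, value, global_step=global_step)


def fet_d(container, *keys):
    # Fetch only the requested keys from a dict-like container.
    return {k: container[k] for k in keys}
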
 def _init_packs(self, *args, **kwargs):
     """ Init packs. """
     def _init(_k):
         return ValidContainer(
             **(kwargs[_k] if _k in kwargs.keys() else {}))

     # 1. Init
     ret = {}
     # 2. Set packs
     # (1) Log
     if 'log_main' in self._logs.keys() and not chk_d(
             kwargs, 'disable_log'):
         assert 'log' not in args
         ret['log'] = _init('log')
     # (2) TFBoard
     if 'tfboard' in self._logs.keys() and not chk_d(
             kwargs, 'disable_tfboard'):
         assert 'tfboard' not in args
         ret['tfboard'] = _init('tfboard')
     # (3) Others
     if len(args) > 0:
         assert len(set(args)) == len(args)
         for k in args:
             ret[k] = _init(k)
     # Return
     return ret
 def _set_logs(self, **kwargs):
     if not chk_d(kwargs, 'disable_log'):
         self._logs['log_main'] = Logger(
             self._cfg.args.ana_dir,
             'train',
             formatted_prefix=self._cfg.args.desc,
             formatted_counters=kwargs['log_main_counters'],
             append_mode=(self._cfg.args.load_from != -1))
     if hasattr(self._cfg.args,
                'tfboard_dir') and not chk_d(kwargs, 'disable_tfboard'):
         self._logs['tfboard'] = SummaryWriter(
             os.path.join(self._cfg.args.tfboard_dir, self._cfg.args.desc))
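
# The logging setup above assumes `os`, a project `Logger`, and a TensorBoard
# `SummaryWriter` are importable. The writer may come from either package below;
# the paths in this quick check are purely illustrative:
import os
from torch.utils.tensorboard import SummaryWriter   # or: from tensorboardX import SummaryWriter

writer = SummaryWriter(os.path.join('runs', 'demo'))
writer.add_scalar('loss', 0.5, global_step=0)
writer.close()
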
 def _add_tree_args(self, args_dict):
     ################################################################################################################
     # Datasets
     ################################################################################################################
     self.parser.add_argument("--dataset_shuffle",               type=int,   default=1,  choices=[0, 1])
     self.parser.add_argument("--dataset_num_threads",           type=int,   default=0)
     self.parser.add_argument("--dataset_drop_last",             type=bool,  default=True)
     ################################################################################################################
     # Others
     ################################################################################################################
     # Modules
     if args_dict['dataset'] == 'mnist':
         self.parser.set(["input_dim", "num_classes"], [784, 10])
     if args_dict['model'] == 'vib':
         self.parser.add_argument("--vib_softplus_scalar",       type=float, default=-1.0,
                                  help="Subtract a positive scalar to make the std initially small. Set to -1.0 to disable.")
     if args_dict['model'] == 'nib':
         self.parser.add_argument("--nib_log_std",               type=float, default=1.0)
         self.parser.add_argument("--nib_log_std_trainable",     type=bool,  default=True)
     self.parser.add_argument("--enc_hidden_dims",               type=str,   default="[1024,1024]")
     self.parser.add_argument("--dec_hidden_dims",               type=str,   default="[]")
     self.parser.add_argument("--emb_dim",                       type=int,   default=16)
     # Optimization & Lambda
     self.parser.add_argument("--hfunc",                         type=str,   default='none', choices=['none', 'exp', 'pow'])
     if chk_d(args_dict, 'hfunc', '!=', 'none'):
         self.parser.add_argument("--hfunc_param",               type=float, default=1.0)
     self.parser.add_argument("--lambda_kl",                     type=float, default=0.01)
     # Evaluating args
     self.parser.add_argument("--eval_batch_size",               type=int,   default=2560)
     self.parser.add_argument("--eval_attack_epsilons",          type=str,   default='[0.1,0.2,0.3]')
     self.parser.add_argument("--eval_odin_out_data",            type=str,   default='gauss')
     self.parser.add_argument("--eval_odin_temper",              type=float, default=1000)
     self.parser.add_argument("--eval_odin_noise_mag",           type=float, default=0.0014)
     self.parser.add_argument("--eval_odin_num_delta",           type=int,   default=10000)
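
# Several options above (e.g. --enc_hidden_dims "[1024,1024]" or
# --eval_attack_epsilons "[0.1,0.2,0.3]") are declared as strings. How the project
# converts them into Python lists downstream is not shown; a common approach,
# assumed here, is ast.literal_eval:
import ast

enc_hidden_dims = ast.literal_eval("[1024,1024]")     # -> [1024, 1024]
attack_epsilons = ast.literal_eval("[0.1,0.2,0.3]")   # -> [0.1, 0.2, 0.3]
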
 def _fetch_batch_data(self, **kwargs):
     record = not chk_d(kwargs, 'no_record')
     # Fetch data & update iterations
     with self._meters['timers']('io'):
         # Fetch data
         batch_iters, batch_data_deployed = self._deploy_batch_data(
             next(self._data['train']))
         # Update iterations
         if record:
             # (1) Update iter_index
             self._meters['i']['iter'] += batch_iters
             # (2) Move forward
             if 'epoch' in self._meters['i'].keys():
                 num_cur_epoch = self._meters['i'][
                     'num_cur_epoch'] + batch_iters
                 num_train_samples = self._meters['i']['num_train_samples']
                 if num_cur_epoch >= num_train_samples:
                     self._meters['i'][
                         'num_cur_epoch'] = num_cur_epoch % num_train_samples
                     self._meters['i']['batch'] = 0
                     self._meters['i']['epoch'] += 1
                 else:
                     self._meters['i']['num_cur_epoch'] = num_cur_epoch
                     self._meters['i']['batch'] += 1
     # Return
     return batch_data_deployed
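
# A standalone sanity check of the epoch bookkeeping performed above: `_roll` is a
# hypothetical helper that mirrors the counter updates, so the roll-over behaviour
# can be verified in isolation.
def _roll(num_cur_epoch, batch, epoch, batch_iters, num_train_samples):
    num_cur_epoch += batch_iters
    if num_cur_epoch >= num_train_samples:
        # Crossing the epoch boundary: wrap the sample counter, reset batch, advance epoch.
        return num_cur_epoch % num_train_samples, 0, epoch + 1
    return num_cur_epoch, batch + 1, epoch

assert _roll(8, 1, 0, 4, 10) == (2, 0, 1)   # 8 + 4 samples crosses an epoch of 10
assert _roll(2, 0, 1, 4, 10) == (6, 1, 1)   # stays inside the current epoch
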
 def _set_directory_args(self, **kwargs):
     dirs = super(CanonicalConfigTrainPyTorch, self)._set_directory_args()
     # Tensorboard
     if not chk_d(kwargs, 'use_tfboard', 'not'):
         self.args.tfboard_dir = os.path.join(self._exp_dir_path, '../tensorboard')
         dirs.append(self.args.tfboard_dir)
     # Return
     return dirs
 def _process_after_epoch(self, epoch, iter_index, packs):
     """
     :return: Whether to early-stop training (bool); None by default.
     """
     # Update learning rate
     self._update_learning_rate()
     # Save current epoch
     if chk_d(self._meters, 'counter_chkpt', lambda c: c.check(epoch)):
         self._save_checkpoint(epoch)
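
# `counter_chkpt` (and `counter_log` earlier) are assumed to be periodic triggers
# whose `check(index)` returns True every `freq` indices. A minimal hypothetical
# stand-in with that contract:
class FreqCounter:
    def __init__(self, freq):
        self._freq = freq

    def check(self, index):
        # Fire on indices freq-1, 2*freq-1, ...; i.e. every `freq` epochs/steps.
        return (index + 1) % self._freq == 0

assert [i for i in range(10) if FreqCounter(3).check(i)] == [2, 5, 8]
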
 def _process_chkpt_and_lr_after_step(self, **kwargs):
     # Get iteration for chkpt
     iter_checker = kwargs.get('iter_checker', 'step')
     # 1. Learning rate
     self._update_learning_rate()
     # 2. Chkpt
     if chk_d(self._meters, 'counter_chkpt',
              lambda c: c.check(self._meters['i'][iter_checker])):
         self._save_checkpoint(n=self._meters['i'][iter_checker])
 def parse_logger(*args, **kwargs):
     """
     :param args: Either
         - (log_dir, log_name), or
         - (log_path,).
     :return: List of dicts, one per parsed line, each of the form
         {
             time: ...,
             titles: ...,
             counters: ...,
             item_name1: item_value1,
             item_name2: item_value2,
             ...
         }
     """
     # 1. Find logger
     if len(args) == 1:
         logger = args[0]
     else:
         log_dir, log_name = args
         # Search
         logger = []
         for obj in os.listdir(log_dir):
             if obj.endswith("_%s.log" % log_name):
                 logger.append(os.path.join(log_dir, obj))
         assert len(logger) == 1, \
             "Expected exactly one log file matching '*_%s.log', got: %s. " % (log_name, logger)
         logger = logger[0]
     # 2. Parse logger
     # (1) Init result
     r = []
     # (2) Parse each line
     # 1> Read file
     with open(logger, 'r') as f:
         lines = f.readlines()
     # 2> Parsing
     for index, current_line in enumerate(lines):
         # Show progress
         if chk_d(kwargs, 'verbose'):
             show_progress(title="Parsing log file '%s'" %
                           os.path.split(logger)[1],
                           index=index,
                           maximum=len(lines))
         # 1. Try to parse current line
         try:
             current_result = Logger.parse_logger_line(current_line)
             r.append(current_result)
         # 2. Abort on lines that do not match the expected format
         except Exception:
             continue
     # Return
     return r
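
# Hypothetical usage of parse_logger: assuming _set_logs above produced a file
# named '<desc>_train.log' inside an analysis directory, the records can be loaded
# either via (log_dir, log_name) or via the full path. The directory below is
# illustrative only.
records = parse_logger('experiments/ana', 'train', verbose=True)
for rec in records:
    print(rec['time'], rec['counters'])
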
 def add_argument(self, key, **kwargs):
     assert key.startswith("--"), "Argument key must start with '--'. "
     if key[2:] not in self._args_dict.keys():
         # Check duplicate
         self._check_duplicate(key[2:])
         # Set command
         if chk_d(kwargs, 'type', '==', bool):
             kwargs['type'] = str2bool
             if 'default' in kwargs.keys(): kwargs['default'] = str(kwargs['default'])
         self._parser.add_argument(key, **kwargs)
         # Save
         self._parser_names.append(key[2:])
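
# `str2bool` (substituted as the argparse type for bool arguments above) is assumed
# to map textual flags to booleans, roughly like this sketch:
def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise ValueError("Cannot interpret %r as a boolean. " % v)
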
 def __call__(self, *args, **kwargs):
     # Calculate cache
     if not chk_d(kwargs, 'void'):
         # 1. On
         on = [k for k in args if k in self._timers]
         for k in on:
             assert self._timers[k].stat == 'off'
         # 2. Off
         off = [k for k in self._timers.keys() if self._timers[k].stat == 'on']
         # Result
         cache = {'on': on, 'off': off}
     else:
         cache = None
     # Return
     return _TimersManager(self._timers, cache=cache)
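
# `_TimersManager` (returned above) is assumed to be a context manager that, on
# entering, pauses the timers in cache['off'] and resumes those in cache['on'],
# then restores the previous state on exit; a `void` call (cache=None) is a no-op.
# A hypothetical sketch:
class _TimersManager(object):
    def __init__(self, timers, cache=None):
        self._timers, self._cache = timers, cache

    def __enter__(self):
        if self._cache is not None:
            for k in self._cache['off']:
                self._timers[k].pause()
            for k in self._cache['on']:
                self._timers[k].resume()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._cache is not None:
            for k in self._cache['on']:
                self._timers[k].pause()
            for k in self._cache['off']:
                self._timers[k].resume()
        return False
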
 def _train_procedure(self, **kwargs):
     """
     Training procedure
     """
     # 1. Preliminaries
     iter_marker, iter_max = self._set_iterations()
     packs = self._init_packs()
     # 2. Main
     while self._meters['i'][iter_marker] < iter_max:
         # 1. Train
         with self._meters['timers']('opt', void=chk_d(kwargs,
                                                       'dis_t_opt')):
             self._set_to_train_mode()
             self._train_step(packs)
         # 2. Process after train
         early_stop = self._process_after_step(packs)
         if early_stop: return
         # Move forward
         self._meters['i']['step'] += 1
     # Save final result
     self._save_checkpoint(self._meters['i'][iter_marker] - 1)
 def _train_procedure(self, **kwargs):
     """
     Training procedure. 
     """
     # 1. Initialize packs used in training
     packs = self._init_packs()
     # 2. Training
     iter_index = -1
     for epoch in range(self._cfg.args.load_from + 1,
                        self._cfg.args.epochs):
         # 1. Train each batch
         # Start recording io time
         if 'io' in self._meters['timers']:
             self._meters['timers']['io'].resume()
         # Read batch data
         for batch_index, batch_data in enumerate(self._data['train']):
             # Deploy
             batch_iters, batch_data = self._deploy_batch_data(batch_data)
             iter_index += batch_iters
             # End recording io time & start recording optimize time
             if 'io' in self._meters['timers']:
                 self._meters['timers']['io'].pause()
             # Batch optimization
             with self._meters['timers']('opt',
                                         void=chk_d(kwargs, 'dis_t_opt')):
                 self._set_to_train_mode()
                 self._train_batch(epoch, batch_index, iter_index,
                                   batch_data, packs)
             ########################################################################################################
             # After-batch operations
             ########################################################################################################
             self._process_after_batch(epoch, batch_index, iter_index,
                                       packs)
         # 2. Process after epoch
         early_stop = self._process_after_epoch(epoch, iter_index, packs)
         if early_stop: return
     # Save final results
     self._save_checkpoint(self._cfg.args.epochs - 1)
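
# `_deploy_batch_data` is assumed to move a raw (inputs, labels) batch to the
# target device and return (number of samples, deployed batch). A hypothetical
# standalone version for PyTorch tensors:
def deploy_batch_data(batch_data, device):
    images, labels = batch_data
    return images.size(0), (images.to(device), labels.to(device))
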
 def _set_iterations(self, **kwargs):
     # Init iterations
     self._meters['i'] = {}
     # (1) Involving dataset
     if not chk_d(kwargs, 'disable_epoch_batch'):
         self._meters['i']['num_train_samples'] = self._data['train'].num_samples \
             if 'num_train_samples' not in kwargs.keys() else kwargs.pop('num_train_samples')
         self._meters['i'].update({
             'epoch': 0,
             'batch': -1,
             'num_cur_epoch': 0
         })
     # (2) Step & iter
     self._meters['i'].update({'step': 0, 'iter': -1})
     # Generate
     if hasattr(self._cfg.args, 'steps'):
         assert not hasattr(self._cfg.args, 'iters')
         return 'step', self._cfg.args.steps
     elif hasattr(self._cfg.args, 'iters'):
         assert not hasattr(self._cfg.args, 'steps')
         return 'iter', self._cfg.args.iters
     else:
         raise ValueError("Either 'steps' or 'iters' must be specified in the config. ")
 def save_images_to_the_disk(self, visuals, iter_count=0, **kwargs):
     # 1. Save to disk
     for label, image in visuals.items():
         # Path
         img_path = os.path.join(
             self._webpage.image_dir,
             '%s.png' % self._get_image_name(iter_count, label))
         # Save
         try:
             self._lmd_save_image(image, img_path)
         except FileNotFoundError:
             os.makedirs(os.path.split(img_path)[0])
             self._lmd_save_image(image, img_path)
     # 2. Update for website
     if isinstance(self._webpage, HTML):
         # Update key
         for key in list(visuals.keys()):
             if key not in self._visual_labels:
                 self._visual_labels.append(key)
         # Update iter
         self._iter_container.append(iter_count)
         # Flush website
         if chk_d(kwargs, 'flush_website'):
             self.save_website()
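
# `self._lmd_save_image` is a user-supplied callable. One common choice (an
# assumption, not necessarily what this project uses) for torch tensors is
# torchvision's save_image:
from torchvision.utils import save_image

def lmd_save_image(image, img_path):
    # Writes a (B, C, H, W) or (C, H, W) tensor to disk as an image file.
    save_image(image, img_path)
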
 def _set_meters(self, **kwargs):
     # Timers
     self._meters['timers'] = TimersController()
     if not chk_d(kwargs, 'disable_timers'):
         self._meters['timers']['io'] = StopWatch()
         self._meters['timers']['opt'] = StopWatch()
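
# `TimersController` and `StopWatch` are project utilities not shown in these
# snippets. A minimal hypothetical StopWatch compatible with the resume/pause/
# get_duration_and_reset calls used throughout this file:
import time

class StopWatch(object):
    def __init__(self):
        self._elapsed, self._start, self.stat = 0.0, None, 'off'

    def resume(self):
        self._start, self.stat = time.time(), 'on'

    def pause(self):
        if self._start is not None:
            self._elapsed += time.time() - self._start
        self._start, self.stat = None, 'off'

    def get_duration_and_reset(self):
        duration, self._elapsed = self._elapsed, 0.0
        return duration
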