def _process_after_batch(self, epoch, batch_index, iter_index, packs, **kwargs):
    # Logging
    if chk_d(self._meters, 'counter_log', lambda c: c.check(iter_index)):
        if 'lmd_generate_log' in kwargs:
            kwargs['lmd_generate_log']()
        # (1) Logs
        if 'log_main' in self._logs and not chk_d(kwargs, 'disable_log'):
            # Update io & optimize timers
            if 'io' in self._meters['timers']:
                packs['log']['t_io'] = self._meters['timers']['io'].get_duration_and_reset()
            if 'opt' in self._meters['timers']:
                packs['log']['t_opt'] = self._meters['timers']['opt'].get_duration_and_reset()
            # Show information
            log_kwargs = {'items': packs['log']} if 'lmd_process_log' not in kwargs else \
                kwargs['lmd_process_log'](packs['log'])
            self._logs['log_main'].info_formatted([epoch, batch_index, iter_index], **log_kwargs)
        # (2) Tensorboard
        if 'tfboard' in self._logs and not chk_d(kwargs, 'disable_tfboard'):
            tfboard_add_multi_scalars(self._logs['tfboard'], packs['tfboard'], global_step=iter_index)
def _process_log_after_step(self, packs, **kwargs):
    # Get iteration for log & tfboard
    iter_checker = kwargs.get('iter_checker', 'iter')
    # Logging
    if chk_d(self._meters, 'counter_log', lambda c: c.check(self._meters['i'][iter_checker])):
        if 'lmd_generate_log' in kwargs:
            kwargs['lmd_generate_log']()
        # (1) Logs
        if 'log_main' in self._logs and 'log' in packs and not chk_d(kwargs, 'disable_log'):
            # Update io & optimize timers
            if 'io' in self._meters['timers']:
                packs['log']['t_io'] = self._meters['timers']['io'].get_duration_and_reset()
            if 'opt' in self._meters['timers']:
                packs['log']['t_opt'] = self._meters['timers']['opt'].get_duration_and_reset()
            # Show information
            log_kwargs = {'items': packs['log']} if 'lmd_process_log' not in kwargs else \
                kwargs['lmd_process_log'](packs['log'])
            self._logs['log_main'].info_formatted(
                fet_d(self._meters['i'], *self._logs['log_main'].formatted_counters), **log_kwargs)
        # (2) Tensorboard
        if 'tfboard' in self._logs and 'tfboard' in packs and not chk_d(kwargs, 'disable_tfboard'):
            tfboard_add_multi_scalars(self._logs['tfboard'], packs['tfboard'], self._meters['i'][iter_checker])
def _init_packs(self, *args, **kwargs):
    """ Init packages. """
    def _init(_k):
        return ValidContainer(**kwargs.get(_k, {}))

    # 1. Init
    ret = {}
    # 2. Set packages
    # (1) Log
    if 'log_main' in self._logs and not chk_d(kwargs, 'disable_log'):
        assert 'log' not in args
        ret['log'] = _init('log')
    # (2) TFBoard
    if 'tfboard' in self._logs and not chk_d(kwargs, 'disable_tfboard'):
        assert 'tfboard' not in args
        ret['tfboard'] = _init('tfboard')
    # (3) Others
    if len(args) > 0:
        assert len(set(args)) == len(args)
        for k in args:
            ret[k] = _init(k)
    # Return
    return ret
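# Usage sketch for `_init_packs` (illustrative, assuming a trainer where `_set_logs`
# has registered 'log_main' and 'tfboard'):
#     packs = self._init_packs()                  # -> {'log': ValidContainer(), 'tfboard': ValidContainer()}
#     packs = self._init_packs('vis')             # additionally creates an empty 'vis' pack
#     packs = self._init_packs(log={'lr': 1e-3})  # pre-fills the 'log' pack via kwargs
# Training steps accumulate values into these packs, and the after-batch/after-step
# hooks above flush them to the logger & tensorboard.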
def _set_logs(self, **kwargs):
    if not chk_d(kwargs, 'disable_log'):
        self._logs['log_main'] = Logger(
            self._cfg.args.ana_dir, 'train',
            formatted_prefix=self._cfg.args.desc,
            formatted_counters=kwargs['log_main_counters'],
            append_mode=self._cfg.args.load_from != -1)
    if hasattr(self._cfg.args, 'tfboard_dir') and not chk_d(kwargs, 'disable_tfboard'):
        self._logs['tfboard'] = SummaryWriter(os.path.join(self._cfg.args.tfboard_dir, self._cfg.args.desc))
def _add_tree_args(self, args_dict):
    ################################################################################################################
    # Datasets
    ################################################################################################################
    self.parser.add_argument("--dataset_shuffle", type=int, default=1, choices=[0, 1])
    self.parser.add_argument("--dataset_num_threads", type=int, default=0)
    self.parser.add_argument("--dataset_drop_last", type=bool, default=True)
    ################################################################################################################
    # Others
    ################################################################################################################
    # Modules
    if args_dict['dataset'] == 'mnist':
        self.parser.set(["input_dim", "num_classes"], [784, 10])
    if args_dict['model'] == 'vib':
        self.parser.add_argument(
            "--vib_softplus_scalar", type=float, default=-1.0,
            help="Subtracting a positive scalar makes the std initially small. Set to -1.0 to disable.")
    if args_dict['model'] == 'nib':
        self.parser.add_argument("--nib_log_std", type=float, default=1.0)
        self.parser.add_argument("--nib_log_std_trainable", type=bool, default=True)
    self.parser.add_argument("--enc_hidden_dims", type=str, default="[1024,1024]")
    self.parser.add_argument("--dec_hidden_dims", type=str, default="[]")
    self.parser.add_argument("--emb_dim", type=int, default=16)
    # Optimization & Lambda
    self.parser.add_argument("--hfunc", type=str, default='none', choices=['none', 'exp', 'pow'])
    if chk_d(args_dict, 'hfunc', '!=', 'none'):
        self.parser.add_argument("--hfunc_param", type=float, default=1.0)
    self.parser.add_argument("--lambda_kl", type=float, default=0.01)
    # Evaluating args
    self.parser.add_argument("--eval_batch_size", type=int, default=2560)
    self.parser.add_argument("--eval_attack_epsilons", type=str, default='[0.1,0.2,0.3]')
    self.parser.add_argument("--eval_odin_out_data", type=str, default='gauss')
    self.parser.add_argument("--eval_odin_temper", type=float, default=1000)
    self.parser.add_argument("--eval_odin_noise_mag", type=float, default=0.0014)
    self.parser.add_argument("--eval_odin_num_delta", type=int, default=10000)
def _fetch_batch_data(self, **kwargs):
    record = not chk_d(kwargs, 'no_record')
    # Fetch data & update iterations
    with self._meters['timers']('io'):
        # Fetch data
        batch_iters, batch_data_deployed = self._deploy_batch_data(next(self._data['train']))
        # Update iterations
        if record:
            # (1) Update iter_index
            self._meters['i']['iter'] += batch_iters
            # (2) Move forward
            if 'epoch' in self._meters['i']:
                num_cur_epoch = self._meters['i']['num_cur_epoch'] + batch_iters
                num_train_samples = self._meters['i']['num_train_samples']
                if num_cur_epoch >= num_train_samples:
                    self._meters['i']['num_cur_epoch'] = num_cur_epoch % num_train_samples
                    self._meters['i']['batch'] = 0
                    self._meters['i']['epoch'] += 1
                else:
                    self._meters['i']['num_cur_epoch'] = num_cur_epoch
                    self._meters['i']['batch'] += 1
    # Return
    return batch_data_deployed
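# Worked example of the epoch/batch bookkeeping above (illustrative numbers only):
# with num_train_samples=100, num_cur_epoch=96 and a fetched batch of batch_iters=8,
# the epoch boundary is crossed, so num_cur_epoch becomes (96 + 8) % 100 = 4,
# 'batch' resets to 0 and 'epoch' advances by 1; otherwise num_cur_epoch simply
# accumulates and 'batch' increments.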
def _set_directory_args(self, **kwargs):
    dirs = super(CanonicalConfigTrainPyTorch, self)._set_directory_args()
    # Tensorboard
    if not chk_d(kwargs, 'use_tfboard', 'not'):
        self.args.tfboard_dir = os.path.join(self._exp_dir_path, '../tensorboard')
        dirs.append(self.args.tfboard_dir)
    # Return
    return dirs
def _process_after_epoch(self, epoch, iter_index, packs):
    """
    :return: Whether to early-stop training (bool); None by default.
    """
    # Update learning rate
    self._update_learning_rate()
    # Save current epoch
    if chk_d(self._meters, 'counter_chkpt', lambda c: c.check(epoch)):
        self._save_checkpoint(epoch)
def _process_chkpt_and_lr_after_step(self, **kwargs):
    # Get iteration for chkpt
    iter_checker = kwargs.get('iter_checker', 'step')
    # 1. Learning rate
    self._update_learning_rate()
    # 2. Chkpt
    if chk_d(self._meters, 'counter_chkpt', lambda c: c.check(self._meters['i'][iter_checker])):
        self._save_checkpoint(n=self._meters['i'][iter_checker])
def parse_logger(*args, **kwargs):
    """
    :param args:
        - log_dir, log_name, or
        - log_path.
    :return: List of dicts of {
        - time:
        - titles:
        - counters:
        - item_name1: item_value1;
        - item_name2: item_value2;
        ...
    }
    """
    # 1. Find logger
    if len(args) == 1:
        logger = args[0]
    else:
        log_dir, log_name = args
        # Search
        logger = [os.path.join(log_dir, obj) for obj in os.listdir(log_dir) if obj.endswith("_%s.log" % log_name)]
        assert len(logger) == 1, "Exactly one logger should match the given log name, but found: \n\t%s. " % logger
        logger = logger[0]
    # 2. Parse logger
    # (1) Init result
    r = []
    # (2) Read & parse each line
    with open(logger, 'r') as f:
        lines = f.readlines()
    for index, current_line in enumerate(lines):
        # Show progress
        if chk_d(kwargs, 'verbose'):
            show_progress(title="Parsing log file '%s'" % os.path.split(logger)[1], index=index, maximum=len(lines))
        # Try to parse the current line; skip lines that are not formatted log records.
        try:
            r.append(Logger.parse_logger_line(current_line))
        except Exception:
            continue
    # Return
    return r
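# Usage sketch (hypothetical directory & file names; only the arguments documented
# in the docstring above are used):
#     records = parse_logger('/path/to/ana_dir', 'train', verbose=True)  # searches for '*_train.log'
#     records = parse_logger('/path/to/ana_dir/xxx_train.log')           # or pass the log path directly
# Each element of `records` is one parsed line, e.g. records[-1]['counters'] gives
# the counters of the last logged iteration.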
def add_argument(self, key, **kwargs):
    assert key.startswith("--"), "Argument key must start with '--'. "
    if key[2:] not in self._args_dict:
        # Check duplicate
        self._check_duplicate(key[2:])
        # Set command
        if chk_d(kwargs, 'type', '==', bool):
            kwargs['type'] = str2bool
            if 'default' in kwargs:
                kwargs['default'] = str(kwargs['default'])
        self._parser.add_argument(key, **kwargs)
        # Save
        self._parser_names.append(key[2:])
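# Why the str2bool swap above matters (standard argparse behavior, not specific to
# this wrapper): with a plain `type=bool`, argparse evaluates bool("False"), which
# is True because any non-empty string is truthy. Routing through str2bool instead
# makes e.g. `--dataset_drop_last False` (declared in `_add_tree_args` with
# type=bool) parse to the Python bool False; defaults are stringified so they pass
# through the same conversion.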
def __call__(self, *args, **kwargs):
    # Calculate cache
    if not chk_d(kwargs, 'void'):
        # 1. On: requested timers that exist in the controller (must not be running yet)
        on = [k for k in args if k in self._timers]
        for k in on:
            assert self._timers[k].stat == 'off'
        # 2. Off: timers that are currently running
        off = [k for k in self._timers if self._timers[k].stat == 'on']
        # Result
        cache = {'on': on, 'off': off}
    else:
        cache = None
    # Return
    return _TimersManager(self._timers, cache=cache)
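# Usage sketch (see `_fetch_batch_data` / the train procedures below): entering the
# returned manager presumably switches on the timers cached under 'on' and pauses
# those under 'off', restoring them on exit; with void=True the manager is a no-op.
#     with self._meters['timers']('opt', void=chk_d(kwargs, 'dis_t_opt')):
#         self._set_to_train_mode()
#         self._train_step(packs)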
def _train_procedure(self, **kwargs):
    """ Training procedure. """
    # 1. Preliminaries
    iter_marker, iter_max = self._set_iterations()
    packs = self._init_packs()
    # 2. Main
    while self._meters['i'][iter_marker] < iter_max:
        # 1. Train
        with self._meters['timers']('opt', void=chk_d(kwargs, 'dis_t_opt')):
            self._set_to_train_mode()
            self._train_step(packs)
        # 2. Process after train
        early_stop = self._process_after_step(packs)
        if early_stop:
            return
        # Move forward
        self._meters['i']['step'] += 1
    # Save final result
    self._save_checkpoint(self._meters['i'][iter_marker] - 1)
def _train_procedure(self, **kwargs):
    """ Training procedure. """
    # 1. Initialize packs used in training
    packs = self._init_packs()
    # 2. Training
    iter_index = -1
    for epoch in range(self._cfg.args.load_from + 1, self._cfg.args.epochs):
        # 1. Train each batch
        # Start recording io time
        if 'io' in self._meters['timers']:
            self._meters['timers']['io'].resume()
        # Read batch data
        for batch_index, batch_data in enumerate(self._data['train']):
            # Deploy
            batch_iters, batch_data = self._deploy_batch_data(batch_data)
            iter_index += batch_iters
            # End recording io time & start recording optimize time
            if 'io' in self._meters['timers']:
                self._meters['timers']['io'].pause()
            # Batch optimization
            with self._meters['timers']('opt', void=chk_d(kwargs, 'dis_t_opt')):
                self._set_to_train_mode()
                self._train_batch(epoch, batch_index, iter_index, batch_data, packs)
            ########################################################################################################
            # After-batch operations
            ########################################################################################################
            self._process_after_batch(epoch, batch_index, iter_index, packs)
        # 2. Process after epoch
        early_stop = self._process_after_epoch(epoch, iter_index, packs)
        if early_stop:
            return
    # Save final results
    self._save_checkpoint(self._cfg.args.epochs - 1)
def _set_iterations(self, **kwargs):
    # Init iterations
    self._meters['i'] = {}
    # (1) Involving dataset
    if not chk_d(kwargs, 'disable_epoch_batch'):
        self._meters['i']['num_train_samples'] = self._data['train'].num_samples \
            if 'num_train_samples' not in kwargs else kwargs.pop('num_train_samples')
        self._meters['i'].update({'epoch': 0, 'batch': -1, 'num_cur_epoch': 0})
    # (2) Step & iter
    self._meters['i'].update({'step': 0, 'iter': -1})
    # Generate
    if hasattr(self._cfg.args, 'steps'):
        assert not hasattr(self._cfg.args, 'iters')
        return 'step', self._cfg.args.steps
    elif hasattr(self._cfg.args, 'iters'):
        assert not hasattr(self._cfg.args, 'steps')
        return 'iter', self._cfg.args.iters
    else:
        raise ValueError("Either 'steps' or 'iters' must be given in the config args. ")
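# Example of the returned marker (depends on whether the config defines `steps` or
# `iters`; exactly one of the two must be set):
#     iter_marker, iter_max = self._set_iterations()   # e.g. ('step', cfg.args.steps)
# The step-based `_train_procedure` above then loops while
# self._meters['i'][iter_marker] < iter_max.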
def save_images_to_the_disk(self, visuals, iter_count=0, **kwargs):
    # 1. Save to disk
    for label, image in visuals.items():
        # Path
        img_path = os.path.join(self._webpage.image_dir, '%s.png' % self._get_image_name(iter_count, label))
        # Save, creating the parent directory on demand
        try:
            self._lmd_save_image(image, img_path)
        except FileNotFoundError:
            os.makedirs(os.path.split(img_path)[0])
            self._lmd_save_image(image, img_path)
    # 2. Update for website
    if isinstance(self._webpage, HTML):
        # Update keys
        for key in visuals.keys():
            if key not in self._visual_labels:
                self._visual_labels.append(key)
        # Update iter
        self._iter_container.append(iter_count)
        # Flush website
        if chk_d(kwargs, 'flush_website'):
            self.save_website()
def _set_meters(self, **kwargs):
    # Timers
    self._meters['timers'] = TimersController()
    if not chk_d(kwargs, 'disable_timers'):
        self._meters['timers']['io'] = StopWatch()
        self._meters['timers']['opt'] = StopWatch()