class _DataStruct(BaseStruct):
    """Recursive data container: members are either leaf data values
    (including numpy arrays) or nested _DataStruct instances.

    Assumes BaseStruct provides the `_members` dict and item assignment
    (`self[member] = ...`) — TODO confirm against BaseStruct.
    """

    def __init__(self):
        super().__init__()

    def is_data(self, member):
        """A member is plain data iff it is not a nested struct."""
        return not self.is_struct(member)

    def is_struct(self, member):
        """True iff *member* holds a nested _DataStruct."""
        return isinstance(self._members[member], _DataStruct)

    def make_struct(self, member):
        """Create an empty nested struct under *member* if not present."""
        # Idiomatic membership test ('not in'); never clobbers an existing member.
        if member not in self._members:
            self[member] = _DataStruct()

    def to_string(self, indent):
        """Render the member tree, one line per member, recursing into
        nested structs with indent increased by 2."""
        pad = ' ' * indent
        parts = []
        for name, member in self._members.items():
            if isinstance(member, np.ndarray):
                parts.append('%s np.array: %s %s\n' % (pad, name, member.shape))
            elif isinstance(member, _DataStruct):
                parts.append('%s struct: %s\n' % (pad, name))
                parts.append(member.to_string(indent + 2))
            else:
                parts.append('%s %s = %s\n' % (pad, name, str(member)))
        # join() instead of repeated += (avoids quadratic string building)
        if not parts:
            return "%s (empty)\n" % pad
        return ''.join(parts)


register_class('DataStruct', _DataStruct)
def shared_batchnorm(self):
    """Config accessor: whether batchnorm parameters are shared."""
    return self._config['shared_batchnorm']

def correlation_leaky_relu(self):
    """Config accessor: whether the correlation layer uses leaky ReLU."""
    return self._config['correlation_leaky_relu']

def global_step(self):
    """Return the 'global_step' global variable, or None if it does not
    exist yet in the graph."""
    # Query the collection once instead of twice; an empty list is falsy.
    candidates = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                   scope="global_step")
    if candidates:
        return candidates[0]
    return None

register_class('Scope', _Scope)

# Module-level default scope: entered immediately so all subsequent network
# definitions inherit these ops and settings.
bottom_scope = _Scope(None,
                      learn=True,
                      loss_fact=1.0,
                      conv_op=nd.ops.conv,
                      conv_nonlin_op=nd.ops.conv_relu,
                      upconv_op=nd.ops.upconv,
                      upconv_nonlin_op=nd.ops.upconv_relu,
                      weight_decay=0.0,
                      shared_batchnorm=True,
                      correlation_leaky_relu=False)
bottom_scope.__enter__()
super().__init__()

def get_state(self, id):
    """Resolve *id* to a _State.

    Accepts: a _State (returned as-is), None (returns None),
    'evo_name:state_name' strings, or a numeric id searched across all
    evolutions.

    Raises KeyError for an unknown 'evo:state' pair, Exception otherwise.
    """
    if isinstance(id, _State):
        return id
    elif id is None:
        return None
    elif ':' in str(id):
        evo_name, state_name = id.split(':')
        state = self.get_evolution(evo_name).get_state(state_name)
        if state is None:
            raise KeyError('State <' + str(id) + '> not found')
        return state
    else:
        # Narrow try body: only the int() conversion may legitimately raise
        # ValueError here. The original wrapped the whole search loop, which
        # silently swallowed a ValueError raised by evo.get_state().
        try:
            numeric_id = int(id)
        except ValueError:
            pass  # non-numeric id: fall through to the failure below
        else:
            for evo in self.evolutions():
                state = evo.get_state(numeric_id)
                if state is not None:
                    return state
        raise Exception('State ' + str(id) + ' not found.')

register_class('EvolutionManager', _EvolutionManager)
file_re = re.compile('([^0-9]*)'+str(self.iter())+'{1}(.*)') for file in os.listdir(self.folder()): file_match = file_re.match(file) if file_match: if re.compile('\.index').match(file_match.group(2)): index = os.path.join(self.folder(), file) elif re.compile('\.meta').match(file_match.group(2)): meta = os.path.join(self.folder(), file) elif re.compile('\.data-\d{5}-of-\d{5}').match(file_match.group(2)): data.append(os.path.join(self.folder(), file)) prefix = file_match.group(1) if index is not None and meta is not None and len(data) > 0: #return data + [index] + [meta] return {'data': data, 'index': index, 'meta': meta, 'prefix': prefix} def folder(self): return os.path.join(nd.evo_manager.training_dir(), self.evo_name(), 'checkpoints') def clean(self): files = self.files() if files is not None: os.remove(files['index']) os.remove(files['meta']) [os.remove(data_file) for data_file in files['data']] def path(self): return os.path.join(self.folder(), self.files()['prefix']+str(self.iter())) register_class('State', _State)
# NOTE(review): this chunk begins mid-statement — the expression below is the
# tail of a comparison/concatenation whose beginning is outside this view.
self.prefix() + str(last_state.iter()) + '"')
# Read the next line of the 'checkpoint' log and validate its format.
line = checkpoint_log.readline()
match = log_line_re.match(line)
if not match:
    return nd.status.CONFIGURATION_ERROR
# No states found on disk: a non-empty 'checkpoint' file is then inconsistent
# (it references snapshots that do not exist).
if not len(states):
    if os.path.isfile(
        os.path.join(self._snapshots_path(), 'checkpoint')) and os.path.getsize(
        os.path.join(self._snapshots_path(), 'checkpoint')) > 0:
        return nd.status.CONFIGURATION_ERROR
return nd.status.SUCCESS

def update_states_log(self):
    # Rewrite the TF-style 'checkpoint' file so it points at the most recent
    # state; truncates any previous content ('w' mode).
    checkpoint_log = open(
        os.path.join(self._snapshots_path(), 'checkpoint'), 'w')
    last_state = self.last_state()
    if last_state is not None:
        checkpoint_log.write('model_checkpoint_path: \"' +
                             self.prefix() + str(last_state.iter()) + '\"\n')

def get_state_path(self, id):
    # Snapshot path prefix for state *id* (prefix + iteration number).
    return os.path.join(self._snapshots_path(), self.prefix() + str(id))

register_class('Evolution', _Evolution)
from netdef_slim.core.base_struct import BaseStruct
from netdef_slim.core.register import register_class

nothing = None


class _Struct(BaseStruct):
    """Minimal member container built on BaseStruct: every member is
    either a leaf value or a nested _Struct."""

    def __init__(self):
        super().__init__()

    def is_struct(self, member):
        """True iff *member* holds a nested _Struct."""
        return isinstance(self._members[member], _Struct)

    def is_data(self, member):
        """True iff *member* is a leaf value (anything but a nested _Struct)."""
        return not self.is_struct(member)

    def to_string(self, indent):
        """Flat repr of all members; *indent* is accepted but unused here."""
        return str(self._members)

    def make_sibling(self, other):
        """Intentional no-op hook, kept for interface compatibility."""
        pass


register_class('Struct', _Struct)
from netdef_slim.schedules.fixed_step_schedule import _FixedStepSchedule
from netdef_slim.core.register import register_class
import tensorflow as tf

nothing = None


class FixedStepSchedule(_FixedStepSchedule):
    """Piecewise-constant learning-rate schedule: the rate starts at
    base_lr and is multiplied by gamma at every step iteration."""

    def __init__(self, name, base_lr, max_iter, steps, gamma=0.5, stretch=1.0):
        super().__init__(name, base_lr, max_iter, steps, gamma, stretch)

    def get_schedule(self, global_step):
        """Build the tf learning-rate tensor for *global_step*."""
        # piecewise_constant needs one more value than boundaries:
        # len(step_iters()) boundaries -> len(step_iters()) + 1 rates.
        rates = [self.base_lr()]
        for _ in self.step_iters():
            rates.append(rates[-1] * self.gamma())
        return tf.train.piecewise_constant(global_step, self.step_iters(), rates)


register_class('FixedStepSchedule', FixedStepSchedule)