예제 #1
0
class _DataStruct(BaseStruct):
    """Nestable container whose members may themselves be ``_DataStruct``s.

    Members live in ``self._members`` (set up by ``BaseStruct``); leaves are
    arbitrary values, branches are nested ``_DataStruct`` instances.
    """

    def __init__(self):
        super().__init__()

    def is_data(self, member):
        """Return True if ``member`` holds a leaf value rather than a nested struct."""
        return not self.is_struct(member)

    def is_struct(self, member):
        """Return True if ``member`` holds a nested ``_DataStruct``."""
        return isinstance(self._members[member], _DataStruct)

    def make_struct(self, member):
        """Ensure ``member`` exists, creating an empty nested struct only if it is absent."""
        if member not in self._members:
            self[member] = _DataStruct()

    def to_string(self, indent=0):
        """Return an indented, human-readable listing of the members.

        NumPy arrays are shown by name and shape, nested structs are listed
        recursively with two more spaces of indentation, and any other value is
        shown with ``str()``. A struct with no members renders as ``(empty)``.

        Args:
            indent: number of leading spaces for this level (defaults to 0).

        Returns:
            The listing, one newline-terminated line per entry.
        """
        pad = ' ' * indent
        # Collect pieces and join once: repeated ``s +=`` is quadratic in the
        # worst case for deeply nested / large structs.
        parts = []
        for name, member in self._members.items():
            if isinstance(member, np.ndarray):
                parts.append('%s np.array: %s %s\n' % (pad, name, member.shape))
            elif isinstance(member, _DataStruct):
                parts.append('%s struct: %s\n' % (pad, name))
                parts.append(member.to_string(indent + 2))
            else:
                parts.append('%s %s = %s\n' % (pad, name, str(member)))
        # Every appended piece is non-empty, so "no parts" means "no members".
        if not parts:
            return "%s (empty)\n" % pad
        return ''.join(parts)


register_class('DataStruct', _DataStruct)
예제 #2
0
    def shared_batchnorm(self):
        """Return the ``'shared_batchnorm'`` setting from this scope's config."""
        config = self._config
        return config['shared_batchnorm']

    def correlation_leaky_relu(self):
        """Return the ``'correlation_leaky_relu'`` setting from this scope's config."""
        config = self._config
        return config['correlation_leaky_relu']

    def global_step(self):
        """Return the global-step variable of the current graph, or None.

        Looks in ``tf.GraphKeys.GLOBAL_VARIABLES`` filtered by the
        ``"global_step"`` scope and returns the first match.

        Returns:
            The first matching variable, or None if the collection query is empty.
        """
        # Query the collection once; the original ran the same query twice.
        candidates = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                       scope="global_step")
        return candidates[0] if candidates else None


# Hand ``_Scope`` to ``register_class`` under the public name 'Scope'.
register_class('Scope', _Scope)

# Module-level root scope carrying the default settings listed below
# (learning on, loss factor 1.0, the nd.ops conv/upconv ops with and without
# ReLU, no weight decay, shared_batchnorm=True, correlation_leaky_relu=False).
# NOTE(review): the first positional argument is presumably the parent scope
# (None for this root scope) — confirm against _Scope.__init__.
bottom_scope = _Scope(None,
                      learn=True,
                      loss_fact=1.0,
                      conv_op=nd.ops.conv,
                      conv_nonlin_op=nd.ops.conv_relu,
                      upconv_op=nd.ops.upconv,
                      upconv_nonlin_op=nd.ops.upconv_relu,
                      weight_decay=0.0,
                      shared_batchnorm=True,
                      correlation_leaky_relu=False)

# Enter the root scope at import time. No matching __exit__() call is visible
# here, so it presumably stays entered for the rest of the process as the
# default scope — confirm this is intended before changing.
bottom_scope.__enter__()
예제 #3
0
        super().__init__()

    def get_state(self, id):
        """Resolve ``id`` to a state object.

        Accepted forms of ``id``:
          * a ``_State`` instance, which is returned unchanged;
          * ``None``, which returns ``None``;
          * ``"<evolution>:<state>"``, looked up in the named evolution;
          * anything ``int()`` accepts, looked up in each evolution of
            ``self.evolutions()`` in order, with the first hit returned.

        Raises:
            KeyError: if no matching state is found. KeyError subclasses
                Exception, so callers catching the previous bare Exception
                still work.
        """
        if isinstance(id, _State):
            return id
        if id is None:
            return None

        key = str(id)
        if ':' in key:
            # Split the string form: ``id`` itself may be a non-str object
            # whose str() merely contains a colon.
            evo_name, state_name = key.split(':')
            state = self.get_evolution(evo_name).get_state(state_name)
            if state is None:
                raise KeyError('State <' + key + '> not found')
            return state

        # Guard only the conversion. Errors raised by evo.get_state() itself
        # must not be swallowed as "not a number".
        try:
            number = int(id)
        except (TypeError, ValueError):
            pass
        else:
            for evo in self.evolutions():
                state = evo.get_state(number)
                if state is not None:
                    return state

        raise KeyError('State ' + key + ' not found.')

# Hand ``_EvolutionManager`` to ``register_class`` under the name 'EvolutionManager'.
register_class('EvolutionManager', _EvolutionManager)
예제 #4
0
파일: state.py 프로젝트: wpfhtl/netdef_slim
        file_re = re.compile('([^0-9]*)'+str(self.iter())+'{1}(.*)')
        for file in os.listdir(self.folder()):
            file_match = file_re.match(file)
            if file_match:
                if re.compile('\.index').match(file_match.group(2)):
                    index = os.path.join(self.folder(), file)
                elif re.compile('\.meta').match(file_match.group(2)):
                    meta = os.path.join(self.folder(), file)
                elif re.compile('\.data-\d{5}-of-\d{5}').match(file_match.group(2)):
                    data.append(os.path.join(self.folder(), file))
                prefix = file_match.group(1)
        if index is not None and meta is not None and len(data) > 0:
            #return data + [index] + [meta]
            return {'data': data, 'index': index, 'meta': meta, 'prefix': prefix}

    def folder(self):
        """Return ``<training dir>/<evolution name>/checkpoints`` for this state."""
        training_dir = nd.evo_manager.training_dir()
        return os.path.join(training_dir, self.evo_name(), 'checkpoints')

    def clean(self):
        """Delete this state's checkpoint files: the index, meta and every data shard.

        Does nothing when ``self.files()`` returns None (no complete set of
        checkpoint files was found).
        """
        files = self.files()
        if files is None:
            return
        os.remove(files['index'])
        os.remove(files['meta'])
        # Plain loop: the original built a throwaway list purely for side effects.
        for data_file in files['data']:
            os.remove(data_file)

    def path(self):
        """Return ``<folder>/<file prefix><iteration>`` for this state's checkpoint.

        When ``files()`` finds no complete checkpoint it returns None, and
        subscripting that raises TypeError.
        """
        directory = self.folder()
        found = self.files()
        stem = found['prefix'] + str(self.iter())
        return os.path.join(directory, stem)

# Hand ``_State`` to ``register_class`` under the name 'State'.
register_class('State', _State)
예제 #5
0
                                     self.prefix() + str(last_state.iter()) +
                                     '"')

            line = checkpoint_log.readline()
            match = log_line_re.match(line)
            if not match:
                return nd.status.CONFIGURATION_ERROR
        if not len(states):
            if os.path.isfile(
                    os.path.join(self._snapshots_path(),
                                 'checkpoint')) and os.path.getsize(
                                     os.path.join(self._snapshots_path(),
                                                  'checkpoint')) > 0:
                return nd.status.CONFIGURATION_ERROR

        return nd.status.SUCCESS

    def update_states_log(self):
        """Rewrite the ``checkpoint`` file in the snapshots directory.

        The file is always truncated. If there is a last state, a single line
        ``model_checkpoint_path: "<prefix><iter>"`` pointing at it is written.
        """
        log_path = os.path.join(self._snapshots_path(), 'checkpoint')
        # ``with`` guarantees the handle is flushed and closed. The original
        # left it open and relied on garbage collection.
        with open(log_path, 'w') as checkpoint_log:
            last_state = self.last_state()
            if last_state is not None:
                checkpoint_log.write('model_checkpoint_path: "' + self.prefix() +
                                     str(last_state.iter()) + '"\n')

    def get_state_path(self, id):
        """Return ``<snapshots dir>/<prefix><id>`` for the state identified by ``id``."""
        snapshots_dir = self._snapshots_path()
        state_file = self.prefix() + str(id)
        return os.path.join(snapshots_dir, state_file)


# Hand ``_Evolution`` to ``register_class`` under the name 'Evolution'.
register_class('Evolution', _Evolution)
예제 #6
0
from netdef_slim.core.base_struct import BaseStruct
from netdef_slim.core.register import register_class

# Module-level name bound to None; it is not referenced in this module's code.
# NOTE(review): presumably kept for external importers — confirm before removing.
nothing = None


class _Struct(BaseStruct):
    """Plain struct container.

    Nested ``_Struct`` members count as structs; any other member counts as data.
    """

    def __init__(self):
        super().__init__()

    def is_data(self, member):
        """Return True when ``member`` is not a nested struct (per ``is_struct``)."""
        nested = self.is_struct(member)
        return not nested

    def is_struct(self, member):
        """Return True when the value stored under ``member`` is a ``_Struct``."""
        value = self._members[member]
        return isinstance(value, _Struct)

    def to_string(self, indent):
        """Return ``str()`` of the member mapping; ``indent`` is currently ignored."""
        return str(self._members)

    def make_sibling(self, other):
        """No-op hook; ``other`` is ignored."""


register_class('Struct', _Struct)
예제 #7
0
from netdef_slim.schedules.fixed_step_schedule import _FixedStepSchedule
from netdef_slim.core.register import register_class
import tensorflow as tf

# Module-level name bound to None; it is not referenced in this module's code.
# NOTE(review): presumably kept for external importers — confirm before removing.
nothing = None

class FixedStepSchedule(_FixedStepSchedule):
    """TensorFlow step-decay learning-rate schedule.

    The rate starts at ``base_lr()`` and is multiplied by ``gamma()`` at each
    boundary returned by ``step_iters()``.
    """

    def __init__(self, name, base_lr, max_iter, steps, gamma=0.5, stretch=1.0):
        # Every argument is forwarded, in order, to the base class.
        super().__init__(name, base_lr, max_iter, steps, gamma, stretch)

    def get_schedule(self, global_step):
        """Return a piecewise-constant learning-rate tensor driven by ``global_step``.

        The k-th interval (0-based) uses ``base_lr`` multiplied by ``gamma``
        k times, so there is exactly one more rate than there are boundaries.
        """
        rates = [self.base_lr()]
        for _ in self.step_iters():
            rates.append(rates[-1] * self.gamma())
        return tf.train.piecewise_constant(global_step, self.step_iters(), rates)


register_class('FixedStepSchedule', FixedStepSchedule)