示例#1
0
    def process_output(self, line):
        """
        Handle one line of Caffe output.

        The raw line is appended to ``self.caffe_log`` (and flushed), then the
        Caffe header is split off with ``self.preprocess_output_caffe`` and the
        remaining message is inspected:

        * ``Memory required for data: N`` -> logged at debug level
        * 'error' / 'critical' level -> logged and stored in ``self.exception``

        Arguments:
        line -- one line of output from the Caffe process

        Returns True for every line.
        """
        self.caffe_log.write('%s\n' % line)
        self.caffe_log.flush()

        # parse caffe header
        timestamp, level, message = self.preprocess_output_caffe(line)

        if not message:
            return True

        # memory requirement
        match = re.match(r'Memory required for data:\s+(\d+)', message)
        if match:
            bytes_required = int(match.group(1))
            self.logger.debug('memory required: %s' % utils.sizeof_fmt(bytes_required))
            return True

        if level in ['error', 'critical']:
            self.logger.error('%s: %s' % (self.name(), message))
            self.exception = message
            return True

        return True
示例#2
0
    def process_output(self, line):
        """
        Handle one line of Caffe output.

        The raw line is appended to ``self.caffe_log`` (and flushed), then the
        Caffe header is split off with ``self.preprocess_output_caffe`` and the
        remaining message is inspected:

        * ``Memory required for data: N`` -> logged at debug level
        * 'error' / 'critical' level -> logged and stored in ``self.exception``

        Arguments:
        line -- one line of output from the Caffe process

        Returns True for every line.
        """
        self.caffe_log.write('%s\n' % line)
        self.caffe_log.flush()

        # parse caffe header
        timestamp, level, message = self.preprocess_output_caffe(line)

        if not message:
            return True

        # memory requirement
        match = re.match(r'Memory required for data:\s+(\d+)', message)
        if match:
            bytes_required = int(match.group(1))
            self.logger.debug('memory required: %s' %
                              utils.sizeof_fmt(bytes_required))
            return True

        if level in ['error', 'critical']:
            self.logger.error('%s: %s' % (self.name(), message))
            self.exception = message
            return True

        return True
示例#3
0
def resize_example():
    """
    Resize the bundled example image as described by the submitted form.

    Reads ``width``, ``height``, ``channels``, ``resize_mode``, ``backend`` and
    ``encoding`` from ``flask.request.form`` and resizes
    ``static/images/mona_lisa.jpg`` with ``utils.image.resize_image``. For the
    ``lmdb`` backend with a ``png``/``jpg`` encoding, the image is encoded and
    re-decoded, and the decoded result is embedded.

    Returns an HTML fragment: an ``<img>`` tag followed by the data size
    (formatted by ``utils.sizeof_fmt``). Any exception is caught and returned
    as the string ``"<ExceptionType>: <message>"``.
    """
    import io

    try:
        example_image_path = os.path.join(os.path.dirname(digits.__file__),
                                          'static', 'images', 'mona_lisa.jpg')
        image = utils.image.load_image(example_image_path)

        width = int(flask.request.form['width'])
        height = int(flask.request.form['height'])
        channels = int(flask.request.form['channels'])
        resize_mode = flask.request.form['resize_mode']
        backend = flask.request.form['backend']
        encoding = flask.request.form['encoding']

        image = utils.image.resize_image(
            image,
            height,
            width,
            channels=channels,
            resize_mode=resize_mode,
        )

        if backend != 'lmdb' or encoding == 'none':
            # Unencoded: size of the raw pixel buffer
            # (tobytes() replaces the deprecated tostring()).
            length = len(image.tobytes())
        else:
            # Encoded image data is binary, so it needs a bytes buffer.
            s = io.BytesIO()
            if encoding == 'png':
                PIL.Image.fromarray(image).save(s, format='PNG')
            elif encoding == 'jpg':
                PIL.Image.fromarray(image).save(s, format='JPEG', quality=90)
            else:
                raise ValueError('unrecognized encoding "%s"' % encoding)
            s.seek(0)
            image = PIL.Image.open(s)
            length = len(s.getvalue())

        data = utils.image.embed_image_html(image)

        # CSS needs "height:" (the old "height=" was ignored by browsers).
        return ('<img src="%s" style="width:%spx;height:%spx" />\n<br>\n'
                '<i>Image size: %s</i>'
                % (data, width, height, utils.sizeof_fmt(length)))
    except Exception as e:
        return '%s: %s' % (type(e).__name__, e)
示例#4
0
文件: views.py 项目: 5urprise/DIGITS
def resize_example():
    """
    Resize the bundled example image as described by the submitted form.

    Reads ``width``, ``height``, ``channels``, ``resize_mode``, ``backend`` and
    ``encoding`` from ``flask.request.form`` and resizes
    ``static/images/mona_lisa.jpg`` with ``utils.image.resize_image``. For the
    ``lmdb`` backend with a ``png``/``jpg`` encoding, the image is encoded and
    re-decoded, and the decoded result is embedded.

    Returns an HTML fragment: an ``<img>`` tag followed by the data size
    (formatted by ``utils.sizeof_fmt``). Any exception is caught and returned
    as the string ``"<ExceptionType>: <message>"``.
    """
    import io

    try:
        example_image_path = os.path.join(os.path.dirname(digits.__file__), 'static', 'images', 'mona_lisa.jpg')
        image = utils.image.load_image(example_image_path)

        width = int(flask.request.form['width'])
        height = int(flask.request.form['height'])
        channels = int(flask.request.form['channels'])
        resize_mode = flask.request.form['resize_mode']
        backend = flask.request.form['backend']
        encoding = flask.request.form['encoding']

        image = utils.image.resize_image(image, height, width,
                                         channels=channels,
                                         resize_mode=resize_mode,
                                         )

        if backend != 'lmdb' or encoding == 'none':
            # Unencoded: size of the raw pixel buffer
            # (tobytes() replaces the deprecated tostring()).
            length = len(image.tobytes())
        else:
            # Encoded image data is binary, so it needs a bytes buffer.
            s = io.BytesIO()
            if encoding == 'png':
                PIL.Image.fromarray(image).save(s, format='PNG')
            elif encoding == 'jpg':
                PIL.Image.fromarray(image).save(s, format='JPEG', quality=90)
            else:
                raise ValueError('unrecognized encoding "%s"' % encoding)
            s.seek(0)
            image = PIL.Image.open(s)
            length = len(s.getvalue())

        data = utils.image.embed_image_html(image)

        # CSS needs "height:" (the old "height=" was ignored by browsers).
        return '<img src="%s" style="width:%spx;height:%spx" />\n<br>\n<i>Image size: %s</i>' % (
                data,
                width,
                height,
                utils.sizeof_fmt(length)
                )
    except Exception as e:
        return '%s: %s' % (type(e).__name__, e)
示例#5
0
文件: views.py 项目: chintak/DIGITS
def image_dataset_resize_example():
    """
    Resize the bundled example image as described by the submitted form.

    Reads ``width``, ``height``, ``channels``, ``resize_mode``, ``backend`` and
    ``encoding`` from ``flask.request.form`` and resizes
    ``static/images/mona_lisa.jpg`` with ``utils.image.resize_image``. For the
    ``lmdb`` backend with a ``png``/``jpg`` encoding, the image is encoded and
    re-decoded, and the decoded result is embedded.

    Returns an HTML fragment: an ``<img>`` tag followed by the data size
    (formatted by ``utils.sizeof_fmt``). Any exception is caught and returned
    as the string ``"<ExceptionType>: <message>"``.
    """
    import io

    try:
        example_image_path = os.path.join(os.path.dirname(digits.__file__), "static", "images", "mona_lisa.jpg")
        image = utils.image.load_image(example_image_path)

        width = int(flask.request.form["width"])
        height = int(flask.request.form["height"])
        channels = int(flask.request.form["channels"])
        resize_mode = flask.request.form["resize_mode"]
        backend = flask.request.form["backend"]
        encoding = flask.request.form["encoding"]

        image = utils.image.resize_image(image, height, width, channels=channels, resize_mode=resize_mode)

        if backend != "lmdb" or encoding == "none":
            # Unencoded: size of the raw pixel buffer
            # (tobytes() replaces the deprecated tostring()).
            length = len(image.tobytes())
        else:
            # Encoded image data is binary, so it needs a bytes buffer.
            s = io.BytesIO()
            if encoding == "png":
                PIL.Image.fromarray(image).save(s, format="PNG")
            elif encoding == "jpg":
                PIL.Image.fromarray(image).save(s, format="JPEG", quality=90)
            else:
                raise ValueError('unrecognized encoding "%s"' % encoding)
            s.seek(0)
            image = PIL.Image.open(s)
            length = len(s.getvalue())

        data = utils.image.embed_image_html(image)

        # CSS needs "height:" (the old "height=" was ignored by browsers).
        return '<img src="%s" style="width:%spx;height:%spx" />\n<br>\n<i>Image size: %s</i>' % (
            data,
            width,
            height,
            utils.sizeof_fmt(length),
        )
    except Exception as e:
        return "%s: %s" % (type(e).__name__, e)
示例#6
0
    def process_output(self, line):
        """
        Handle one line of Caffe training output.

        The raw line is appended to ``self.caffe_log`` (and flushed). The Caffe
        header is then split off with ``self.preprocess_output_caffe`` and the
        message is checked, in order, against:

        * ``Iteration N`` -> ``self.new_iteration(N)`` (processing continues)
        * ``Train|Test net output #K: NAME = VALUE`` -> ``save_train_output`` /
          ``save_val_output`` with the type of the first layer whose ``top``
          contains NAME (``'?'`` if none)
        * ``Iteration N, lr = VALUE`` -> saved as the ``learning_rate`` output
        * while ``self.saving_snapshot`` is set -> snapshot bookkeeping
        * ``Snapshotting to ...`` -> sets ``self.saving_snapshot``
        * ``Memory required for data: N`` -> debug log
        * 'error' / 'critical' level -> logged and stored in ``self.exception``

        Returns True for every line.

        Raises AssertionError if a net output value is NaN.
        """
        # A decimal / scientific-notation number or NaN. Raw string so "\." is
        # a regex escape rather than an (invalid) string escape.
        float_exp = r'(NaN|[-+]?[0-9]*\.?[0-9]+(e[-+]?[0-9]+)?)'

        self.caffe_log.write('%s\n' % line)
        self.caffe_log.flush()
        # parse caffe output
        timestamp, level, message = self.preprocess_output_caffe(line)
        if not message:
            return True

        # iteration updates (no early return: the same line may also carry a
        # learning rate, handled further down)
        match = re.match(r'Iteration (\d+)', message)
        if match:
            self.new_iteration(int(match.group(1)))

        # net output
        match = re.match(r'(Train|Test) net output #(\d+): (\S*) = %s' % float_exp, message, flags=re.IGNORECASE)
        if match:
            phase = match.group(1)
            name = match.group(3)
            value = match.group(4)
            if value.lower() == 'nan':
                # Explicit raise instead of `assert` so the check also runs
                # under `python -O`; the exception type is unchanged.
                raise AssertionError(
                    'Network outputted NaN for "%s" (%s phase). Try decreasing your learning rate.' % (name, phase))
            value = float(value)

            # Find the layer type
            kind = next((layer.type for layer in self.network.layer if name in layer.top), '?')

            if phase.lower() == 'train':
                self.save_train_output(name, kind, value)
            elif phase.lower() == 'test':
                self.save_val_output(name, kind, value)
            return True

        # learning rate updates
        match = re.match(r'Iteration (\d+), lr = %s' % float_exp, message, flags=re.IGNORECASE)
        if match:
            lr = float(match.group(2))
            self.save_train_output('learning_rate', 'LearningRate', lr)
            return True

        # snapshot saved
        if self.saving_snapshot:
            if not message.startswith('Snapshotting solver state'):
                self.logger.warning('caffe output format seems to have changed. Expected "Snapshotting solver state..." after "Snapshotting to..."')
            else:
                self.logger.info('Snapshot saved.')
            self.detect_snapshots()
            self.send_snapshot_update()
            self.saving_snapshot = False
            return True

        # snapshot starting
        match = re.match(r'Snapshotting to (.*)\s*$', message)
        if match:
            self.saving_snapshot = True
            return True

        # memory requirement
        match = re.match(r'Memory required for data:\s+(\d+)', message)
        if match:
            bytes_required = int(match.group(1))
            self.logger.debug('memory required: %s' % utils.sizeof_fmt(bytes_required))
            return True

        if level in ['error', 'critical']:
            self.logger.error('%s: %s' % (self.name(), message))
            self.exception = message
            return True

        return True
示例#7
0
文件: job.py 项目: maotong/DIGITS
 def disk_size_fmt(self):
     """
     Size of this job's directory (``self._dir``) as a display string:
     measured with ``fs.get_tree_size`` and formatted by ``sizeof_fmt``.
     """
     return sizeof_fmt(fs.get_tree_size(self._dir))
示例#8
0
文件: forms.py 项目: joakandr/DIGITS
class ModelForm(Form):
    """
    Form for creating a model.

    Collects the dataset, the training schedule (epochs, snapshot/validation
    intervals, seed, batch size and accumulation), solver type and
    learning-rate policy, the network (standard / previous / pretrained /
    custom) and its framework, GPU selection, model name and shuffle option.

    Methods named ``validate_<field>`` act as WTForms inline validators for
    the field of that name; the remaining helper methods are referenced
    directly from the ``validators=[...]`` lists below.
    """

    ### Methods

    def selection_exists_in_choices(form, field):
        """Raise a ValidationError unless ``field.data`` is one of ``field.choices``."""
        if not any(choice[0] == field.data for choice in field.choices):
            raise validators.ValidationError(
                "Selected job doesn't exist. Maybe it was deleted by another user."
            )

    def validate_NetParameter(form, field):
        """
        Validate the network text with the selected framework's
        ``validate_network``; a ``BadNetworkError`` is reported as a
        ValidationError.
        """
        fw = frameworks.get_framework_by_id(form['framework'].data)
        try:
            # validate_network raises a BadNetworkError on an invalid network
            fw.validate_network(field.data)
        except frameworks.errors.BadNetworkError as e:
            raise validators.ValidationError('Bad network: %s' % e.message)

    def validate_file_exists(form, field):
        """
        For a server-side path (a StringField while the client-side option is
        unchecked), require that the file exists. Empty paths are accepted.
        """
        from_client = bool(form.python_layer_from_client.data)

        filename = ''
        if not from_client and field.type == 'StringField':
            filename = field.data

        if filename == '':
            return

        if not os.path.isfile(filename):
            raise validators.ValidationError(
                'Server side file, %s, does not exist.' % filename)

    def validate_py_ext(form, field):
        """
        Require a ``.py`` or ``.pyc`` extension on the Python layer file: the
        uploaded file name for a client-side FileField, or the path for a
        server-side StringField. Empty names are accepted.
        """
        from_client = bool(form.python_layer_from_client.data)

        filename = ''
        if from_client and field.type == 'FileField':
            filename = flask.request.files[field.name].filename
        elif not from_client and field.type == 'StringField':
            filename = field.data

        if filename == '':
            return

        (root, ext) = os.path.splitext(filename)
        if ext != '.py' and ext != '.pyc':
            raise validators.ValidationError(
                'Python file, %s, needs .py or .pyc extension.' % filename)

    ### Fields

    # The options for this get set in the view (since they are dynamic)
    dataset = utils.forms.SelectField(
        'Select Dataset',
        choices=[],
        tooltip="Choose the dataset to use for this model.")

    python_layer_from_client = utils.forms.BooleanField(
        u'Use client-side file', default=False)

    python_layer_client_file = utils.forms.FileField(
        u'Client-side file',
        validators=[validate_py_ext],
        tooltip=
        "Choose a Python file on the client containing layer definitions.")
    python_layer_server_file = utils.forms.StringField(
        u'Server-side file',
        validators=[validate_file_exists, validate_py_ext],
        tooltip=
        "Choose a Python file on the server containing layer definitions.")

    train_epochs = utils.forms.IntegerField(
        'Training epochs',
        validators=[validators.NumberRange(min=1)],
        default=30,
        tooltip="How many passes through the training data?")

    snapshot_interval = utils.forms.FloatField(
        'Snapshot interval (in epochs)',
        default=1,
        validators=[
            validators.NumberRange(min=0),
        ],
        tooltip="How many epochs of training between taking a snapshot?")

    val_interval = utils.forms.FloatField(
        'Validation interval (in epochs)',
        default=1,
        validators=[validators.NumberRange(min=0)],
        tooltip=
        "How many epochs of training between running through one pass of the validation data?"
    )

    random_seed = utils.forms.IntegerField(
        'Random seed',
        validators=[
            validators.NumberRange(min=0),
            validators.Optional(),
        ],
        tooltip=
        "If you provide a random seed, then back-to-back runs with the same model and dataset should give identical results."
    )

    batch_size = utils.forms.MultiIntegerField(
        'Batch size',
        validators=[
            utils.forms.MultiNumberRange(min=1),
            utils.forms.MultiOptional(),
        ],
        tooltip=
        "How many images to process at once. If blank, values are used from the network definition."
    )

    batch_accumulation = utils.forms.IntegerField(
        'Batch Accumulation',
        validators=[
            validators.NumberRange(min=1),
            validators.Optional(),
        ],
        tooltip=
        "Accumulate gradients over multiple batches (useful when you need a bigger batch size for training but it doesn't fit in memory)."
    )

    ### Solver types

    solver_type = utils.forms.SelectField(
        'Solver type',
        choices=[
            ('SGD', 'Stochastic gradient descent (SGD)'),
            ('NESTEROV', "Nesterov's accelerated gradient (NAG)"),
            ('ADAGRAD', 'Adaptive gradient (AdaGrad)'),
            ('RMSPROP', 'RMSprop'),
            ('ADADELTA', 'AdaDelta'),
            ('ADAM', 'Adam'),
        ],
        default='SGD',
        tooltip="What type of solver will be used?",
    )

    def validate_solver_type(form, field):
        """
        Reject a solver type the selected framework does not support.
        Does nothing when ``get_framework_by_id`` returns None.
        """
        # Look the framework up by the submitted id (.data), as
        # validate_NetParameter does. Passing the HiddenField object itself
        # presumably never matched a framework id, so the check was skipped.
        fw = frameworks.get_framework_by_id(form.framework.data)
        if fw is not None:
            if not fw.supports_solver_type(field.data):
                raise validators.ValidationError(
                    'Solver type not supported by this framework')

    ### Learning rate

    learning_rate = utils.forms.MultiFloatField(
        'Base Learning Rate',
        default=0.01,
        validators=[
            utils.forms.MultiNumberRange(min=0),
        ],
        tooltip=
        "Affects how quickly the network learns. If you are getting NaN for your loss, you probably need to lower this value."
    )

    lr_policy = wtforms.SelectField('Policy',
                                    choices=[
                                        ('fixed', 'Fixed'),
                                        ('step', 'Step Down'),
                                        ('multistep',
                                         'Step Down (arbitrary steps)'),
                                        ('exp', 'Exponential Decay'),
                                        ('inv', 'Inverse Decay'),
                                        ('poly', 'Polynomial Decay'),
                                        ('sigmoid', 'Sigmoid Decay'),
                                    ],
                                    default='step')

    lr_step_size = wtforms.FloatField('Step Size', default=33)
    lr_step_gamma = wtforms.FloatField('Gamma', default=0.1)
    lr_multistep_values = wtforms.StringField('Step Values', default="50,85")

    def validate_lr_multistep_values(form, field):
        """With the 'multistep' policy, every comma-separated entry must parse as a float."""
        if form.lr_policy.data == 'multistep':
            for value in field.data.split(','):
                try:
                    float(value)
                except ValueError:
                    raise validators.ValidationError('invalid value')

    lr_multistep_gamma = wtforms.FloatField('Gamma', default=0.5)
    lr_exp_gamma = wtforms.FloatField('Gamma', default=0.95)
    lr_inv_gamma = wtforms.FloatField('Gamma', default=0.1)
    lr_inv_power = wtforms.FloatField('Power', default=0.5)
    lr_poly_power = wtforms.FloatField('Power', default=3)
    lr_sigmoid_step = wtforms.FloatField('Step', default=50)
    lr_sigmoid_gamma = wtforms.FloatField('Gamma', default=0.1)

    ### Network

    # Use a SelectField instead of a HiddenField so that the default value
    # is used when nothing is provided (through the REST API)
    method = wtforms.SelectField(
        u'Network type',
        choices=[
            ('standard', 'Standard network'),
            ('previous', 'Previous network'),
            ('pretrained', 'Pretrained network'),
            ('custom', 'Custom network'),
        ],
        default='standard',
    )

    ## framework - hidden field, set by Javascript to the selected framework ID
    framework = wtforms.HiddenField(
        'framework',
        validators=[
            validators.AnyOf(
                [fw.get_id() for fw in frameworks.get_frameworks()],
                message='The framework you choose is not currently supported.')
        ],
        default=frameworks.get_frameworks()[0].get_id())

    # The options for this get set in the view (since they are dependent on the data type)
    standard_networks = wtforms.RadioField(
        'Standard Networks',
        validators=[
            validate_required_iff(method='standard'),
        ],
    )

    previous_networks = wtforms.RadioField(
        'Previous Networks',
        choices=[],
        validators=[
            validate_required_iff(method='previous'),
            selection_exists_in_choices,
        ],
    )

    pretrained_networks = wtforms.RadioField(
        'Pretrained Networks',
        choices=[],
        validators=[
            validate_required_iff(method='pretrained'),
            selection_exists_in_choices,
        ],
    )

    custom_network = utils.forms.TextAreaField(
        'Custom Network',
        validators=[
            validate_required_iff(method='custom'),
            validate_NetParameter,
        ],
    )

    custom_network_snapshot = utils.forms.TextField(
        'Pretrained model(s)',
        tooltip=
        "Paths to pretrained model files, separated by '%s'. Only edit this field if you understand how fine-tuning works in caffe or torch."
        % os.path.pathsep)

    def validate_custom_network_snapshot(form, field):
        """For a custom network, every non-empty path in the os.pathsep-separated list must exist."""
        if form.method.data == 'custom':
            for filename in field.data.strip().split(os.path.pathsep):
                if filename and not os.path.exists(filename):
                    raise validators.ValidationError(
                        'File "%s" does not exist' % filename)

    # Select one of several GPUs
    select_gpu = wtforms.RadioField(
        'Select which GPU you would like to use',
        choices=[('next', 'Next available')] + [(
            index,
            '#%s - %s (%s memory)' %
            (index, get_device(index).name,
             sizeof_fmt(
                 get_nvml_info(index)['memory']['total']
                 if get_nvml_info(index) and 'memory' in get_nvml_info(index)
                 else get_device(index).totalGlobalMem)),
        ) for index in config_value('gpu_list').split(',') if index],
        default='next',
    )

    # Select N of several GPUs
    select_gpus = utils.forms.SelectMultipleField(
        'Select which GPU[s] you would like to use',
        choices=[(
            index,
            '#%s - %s (%s memory)' %
            (index, get_device(index).name,
             sizeof_fmt(
                 get_nvml_info(index)['memory']['total']
                 if get_nvml_info(index) and 'memory' in get_nvml_info(index)
                 else get_device(index).totalGlobalMem)),
        ) for index in config_value('gpu_list').split(',') if index],
        tooltip=
        "The job won't start until all of the chosen GPUs are available.")

    # XXX For testing
    # The Flask test framework can't handle SelectMultipleFields correctly
    select_gpus_list = wtforms.StringField(
        'Select which GPU[s] you would like to use (comma separated)')

    def validate_select_gpus(form, field):
        """If the comma-separated select_gpus_list is filled in, it replaces this field's data."""
        if form.select_gpus_list.data:
            field.data = form.select_gpus_list.data.split(',')

    # Use next available N GPUs
    select_gpu_count = wtforms.IntegerField(
        'Use this many GPUs (next available)',
        validators=[
            validators.NumberRange(min=1,
                                   max=len(
                                       config_value('gpu_list').split(',')))
        ],
        default=1,
    )

    def validate_select_gpu_count(form, field):
        """When no count is given but specific GPUs are selected, clear errors and stop validating."""
        if field.data is None:
            if form.select_gpus.data:
                # Make this field optional
                field.errors[:] = []
                raise validators.StopValidation()

    model_name = utils.forms.StringField(
        'Model Name',
        validators=[validators.DataRequired()],
        tooltip=
        "An identifier, later used to refer to this model in the Application.")

    # allows shuffling data during training (for frameworks that support this, as indicated by
    # their Framework.can_shuffle_data() method)
    shuffle = utils.forms.BooleanField(
        'Shuffle Train Data',
        default=True,
        tooltip='For every epoch, shuffle the data before training.')
示例#9
0
class ModelForm(Form):
    """
    Model-creation form.

    Fields cover the dataset, the training schedule, solver type,
    learning-rate policy, network choice (standard / previous / custom), GPU
    selection and the model name. Methods named ``validate_<field>`` are
    WTForms inline validators for that field. The two helpers
    ``selection_exists_in_choices`` and ``validate_NetParameter`` are
    validator callables referenced from ``validators=[...]`` lists.
    """

    # -- validator helpers --------------------------------------------------

    def selection_exists_in_choices(form, field):
        """Fail validation when ``field.data`` is not a value in ``field.choices``."""
        if not any(choice[0] == field.data for choice in field.choices):
            raise validators.ValidationError(
                "Selected job doesn't exist. Maybe it was deleted by another user."
            )

    def validate_NetParameter(form, field):
        """Fail validation unless ``field.data`` parses into a caffe NetParameter (text format)."""
        net_param = caffe_pb2.NetParameter()
        try:
            text_format.Merge(field.data, net_param)
        except text_format.ParseError as err:
            raise validators.ValidationError(
                'Not a valid NetParameter: %s' % err)

    # -- training schedule --------------------------------------------------

    # choices are filled in by the view, since they are dynamic
    dataset = wtforms.SelectField('Select Dataset', choices=[])

    train_epochs = wtforms.IntegerField(
        'Training epochs',
        validators=[validators.NumberRange(min=1)],
        default=30,
    )

    snapshot_interval = wtforms.FloatField(
        'Snapshot interval (in epochs)',
        default=1,
        validators=[validators.NumberRange(min=0)],
    )

    val_interval = wtforms.FloatField(
        'Validation interval (in epochs)',
        default=1,
        validators=[validators.NumberRange(min=0)],
    )

    random_seed = wtforms.IntegerField(
        'Random seed',
        validators=[validators.NumberRange(min=0), validators.Optional()],
    )

    batch_size = wtforms.IntegerField(
        'Batch size',
        validators=[validators.NumberRange(min=1), validators.Optional()],
    )

    # -- solver -------------------------------------------------------------

    solver_type = wtforms.SelectField(
        'Solver type',
        choices=[
            ('SGD', 'Stochastic gradient descent (SGD)'),
            ('ADAGRAD', 'Adaptive gradient (AdaGrad)'),
            ('NESTEROV', "Nesterov's accelerated gradient (NAG)"),
        ],
        default='SGD',
    )

    # -- learning rate ------------------------------------------------------

    learning_rate = wtforms.FloatField(
        'Base Learning Rate',
        default=0.01,
        validators=[validators.NumberRange(min=0)],
    )

    lr_policy = wtforms.SelectField(
        'Policy',
        choices=[
            ('fixed', 'Fixed'),
            ('step', 'Step Down'),
            ('multistep', 'Step Down (arbitrary steps)'),
            ('exp', 'Exponential Decay'),
            ('inv', 'Inverse Decay'),
            ('poly', 'Polynomial Decay'),
            ('sigmoid', 'Sigmoid Decay'),
        ],
        default='step',
    )

    # 'step' policy parameters
    lr_step_size = wtforms.FloatField('Step Size', default=33)
    lr_step_gamma = wtforms.FloatField('Gamma', default=0.1)
    # 'multistep' policy parameters
    lr_multistep_values = wtforms.StringField('Step Values', default="50,85")

    def validate_lr_multistep_values(form, field):
        """Under the 'multistep' policy, each comma-separated step must parse as a float."""
        if form.lr_policy.data != 'multistep':
            return
        for step in field.data.split(','):
            try:
                float(step)
            except ValueError:
                raise validators.ValidationError('invalid value')

    lr_multistep_gamma = wtforms.FloatField('Gamma', default=0.5)
    # parameters of the remaining policies
    lr_exp_gamma = wtforms.FloatField('Gamma', default=0.95)
    lr_inv_gamma = wtforms.FloatField('Gamma', default=0.1)
    lr_inv_power = wtforms.FloatField('Power', default=0.5)
    lr_poly_power = wtforms.FloatField('Power', default=3)
    lr_sigmoid_step = wtforms.FloatField('Step', default=50)
    lr_sigmoid_gamma = wtforms.FloatField('Gamma', default=0.1)

    # -- network ------------------------------------------------------------

    # A SelectField rather than a HiddenField, so that the default applies
    # when the value is omitted (e.g. REST API requests).
    method = wtforms.SelectField(
        u'Network type',
        choices=[
            ('standard', 'Standard network'),
            ('previous', 'Previous network'),
            ('custom', 'Custom network'),
        ],
        default='standard',
    )

    # choices are filled in by the view, since they depend on the data type
    standard_networks = wtforms.RadioField(
        'Standard Networks',
        validators=[validate_required_iff(method='standard')],
    )

    previous_networks = wtforms.RadioField(
        'Previous Networks',
        choices=[],
        validators=[
            validate_required_iff(method='previous'),
            selection_exists_in_choices,
        ],
    )

    custom_network = wtforms.TextAreaField(
        'Custom Network',
        validators=[
            validate_required_iff(method='custom'),
            validate_NetParameter,
        ],
    )

    custom_network_snapshot = wtforms.TextField('Pretrained model')

    def validate_custom_network_snapshot(form, field):
        """For a custom network, a non-blank snapshot path must exist on disk."""
        if form.method.data != 'custom':
            return
        snapshot = field.data.strip()
        if snapshot and not os.path.exists(snapshot):
            raise validators.ValidationError('File does not exist')

    # -- GPU selection ------------------------------------------------------

    # a single GPU (or the next free one)
    select_gpu = wtforms.RadioField(
        'Select which GPU you would like to use',
        choices=[('next', 'Next available')] + [
            (
                index,
                '#%s - %s%s' % (
                    index,
                    get_device(index).name,
                    (' (%s memory)' % sizeof_fmt(get_nvml_info(index)['memory']['total'])
                     if get_nvml_info(index) and 'memory' in get_nvml_info(index)
                     else ''),
                ),
            )
            for index in config_value('gpu_list').split(',')
            if index
        ],
        default='next',
    )

    # a specific set of GPUs
    select_gpus = wtforms.SelectMultipleField(
        'Select which GPU[s] you would like to use',
        choices=[
            (
                index,
                '#%s - %s%s' % (
                    index,
                    get_device(index).name,
                    (' (%s memory)' % sizeof_fmt(get_nvml_info(index)['memory']['total'])
                     if get_nvml_info(index) and 'memory' in get_nvml_info(index)
                     else ''),
                ),
            )
            for index in config_value('gpu_list').split(',')
            if index
        ],
    )

    # XXX For testing: the Flask test client cannot drive a
    # SelectMultipleField, so GPUs can also be given as a comma-separated list.
    select_gpus_list = wtforms.StringField(
        'Select which GPU[s] you would like to use (comma separated)')

    def validate_select_gpus(form, field):
        """A non-empty ``select_gpus_list`` overrides this field's data."""
        gpu_list = form.select_gpus_list.data
        if gpu_list:
            field.data = gpu_list.split(',')

    # the next N available GPUs
    select_gpu_count = wtforms.IntegerField(
        'Use this many GPUs (next available)',
        validators=[
            validators.NumberRange(
                min=1,
                max=len(config_value('gpu_list').split(',')),
            )
        ],
        default=1,
    )

    def validate_select_gpu_count(form, field):
        """The count is optional when explicit GPUs were chosen: clear errors and stop."""
        if field.data is None and form.select_gpus.data:
            field.errors[:] = []
            raise validators.StopValidation()

    model_name = wtforms.StringField(
        'Model Name',
        validators=[validators.DataRequired()],
    )
示例#10
0
 def disk_size_fmt(self):
     """
     Human-facing size string for the job directory ``self._dir``, as
     computed by ``fs.get_tree_size`` and rendered by ``sizeof_fmt``.
     """
     return sizeof_fmt(fs.get_tree_size(self._dir))
示例#11
0
    def process_output(self, line):
        """
        Handle one line of Caffe training output.

        The raw line is appended to ``self.caffe_log`` and flushed. After the
        Caffe header is split off (``self.preprocess_output_caffe``), the
        message is handled by the first rule that applies:

        * ``self.saving_snapshot`` set -> snapshot bookkeeping, flag cleared
        * ``Iteration N, <..loss..> = V`` -> ``(epoch, V)`` appended to
          ``self.train_loss_updates``
        * ``Iteration N, lr = V`` -> ``(epoch, V)`` appended to
          ``self.lr_updates`` (NaN values are skipped)
        * any other ``Iteration N`` -> iteration update only
        * ``Test net output #K: <..loss..> = V`` -> appended to
          ``self.val_loss_updates`` (NaN skipped)
        * ``Test net output #K: <..acc..> = V`` -> ``(epoch, 100*V, K)``
          appended to ``self.val_accuracy_updates`` (NaN skipped)
        * ``Snapshotting to ...`` -> sets ``self.saving_snapshot``
        * ``Memory required for data: N`` -> debug log
        * 'error' / 'critical' level -> logged and stored in ``self.exception``

        Epochs come from ``self.iteration_to_epoch``; validation entries use
        ``self.current_iteration``.

        Returns True for every line.

        Raises AssertionError if the training loss is NaN.
        """
        self.caffe_log.write('%s\n' % line)
        self.caffe_log.flush()

        # parse caffe header
        timestamp, level, message = self.preprocess_output_caffe(line)

        if not message:
            return True

        # A decimal / scientific-notation number or NaN. Raw string so "\." is
        # a regex escape rather than an (invalid) string escape.
        float_exp = r'(NaN|[-+]?[0-9]*\.?[0-9]+(e[-+]?[0-9]+)?)'

        # snapshot saved
        if self.saving_snapshot:
            self.logger.info('Snapshot saved.')
            self.detect_snapshots()
            self.send_snapshot_update()
            self.saving_snapshot = False
            return True

        # loss updates
        # IGNORECASE (as on the test-output patterns below) so a lowercase
        # "nan" is caught by the NaN check instead of silently falling through
        # to "other iteration updates".
        match = re.match(r'Iteration (\d+), \w*loss\w* = %s' % float_exp, message, flags=re.IGNORECASE)
        if match:
            i = int(match.group(1))
            loss = match.group(2)
            if loss.lower() == 'nan':
                # Explicit raise instead of `assert` so the check also runs
                # under `python -O`; the exception type is unchanged.
                raise AssertionError('Network reported NaN for training loss. Try decreasing your learning rate.')
            loss = float(loss)
            self.train_loss_updates.append((self.iteration_to_epoch(i), loss))
            self.logger.debug('Iteration %d/%d, loss=%s' % (i, self.solver.max_iter, loss))
            self.send_iteration_update(i)
            self.send_data_update()
            return True

        # learning rate updates
        match = re.match(r'Iteration (\d+), lr = %s' % float_exp, message, flags=re.IGNORECASE)
        if match:
            i = int(match.group(1))
            lr = match.group(2)
            if lr.lower() != 'nan':
                lr = float(lr)
                self.lr_updates.append((self.iteration_to_epoch(i), lr))
            self.send_iteration_update(i)
            return True

        # other iteration updates
        match = re.match(r'Iteration (\d+)', message)
        if match:
            i = int(match.group(1))
            self.send_iteration_update(i)
            return True

        # testing loss updates
        match = re.match(r'Test net output #\d+: \w*loss\w* = %s' % float_exp, message, flags=re.IGNORECASE)
        if match:
            loss = match.group(1)
            if loss.lower() != 'nan':
                loss = float(loss)
                self.val_loss_updates.append((self.iteration_to_epoch(self.current_iteration), loss))
                self.send_data_update()
            return True

        # testing accuracy updates
        match = re.match(r'Test net output #(\d+): \w*acc\w* = %s' % float_exp, message, flags=re.IGNORECASE)
        if match:
            index = int(match.group(1))
            a = match.group(2)
            if a.lower() != 'nan':
                a = float(a) * 100
                self.logger.debug('Network accuracy #%d: %s' % (index, a))
                self.val_accuracy_updates.append((self.iteration_to_epoch(self.current_iteration), a, index))
                self.send_data_update(important=True)
            return True

        # snapshot starting
        match = re.match(r'Snapshotting to (.*)\s*$', message)
        if match:
            self.saving_snapshot = True
            return True

        # memory requirement
        match = re.match(r'Memory required for data:\s+(\d+)', message)
        if match:
            bytes_required = int(match.group(1))
            self.logger.debug('memory required: %s' % utils.sizeof_fmt(bytes_required))
            return True

        if level in ['error', 'critical']:
            self.logger.error('%s: %s' % (self.name(), message))
            self.exception = message
            return True

        return True
示例#12
0
    def process_output(self, line):
        """
        Parse one line of Caffe output and update this task's state.

        The raw line is always appended to ``self.caffe_log``. The message
        part returned by ``preprocess_output_caffe`` is then matched against
        the Caffe log formats this task understands: training loss, learning
        rate, iteration progress, test-net loss and accuracy, snapshot
        start/finish and the data memory requirement.

        Arguments:
        line -- one line of Caffe output

        Returns True for every line (no output is reported as unrecognized).

        Raises AssertionError if Caffe reports NaN for the training loss.
        """
        self.caffe_log.write('%s\n' % line)
        self.caffe_log.flush()

        # parse caffe header
        _timestamp, level, message = self.preprocess_output_caffe(line)

        if not message:
            return True

        # A number as it appears in the Caffe log. C++ streams typically print
        # NaN as "nan" or "-nan" rather than "NaN", so the NaN alternative is
        # case-insensitive and may carry a sign.
        # Group layout (relied on below): the whole number, then the optional
        # exponent.
        float_exp = r'([-+]?[Nn][Aa][Nn]|[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?)'

        # snapshot saved: the line following "Snapshotting to ..." is taken as
        # confirmation that the snapshot has been written
        if self.saving_snapshot:
            self.logger.info('Snapshot saved.')
            self.detect_snapshots()
            self.send_snapshot_update()
            self.saving_snapshot = False
            return True

        # training loss updates
        match = re.match(r'Iteration (\d+), \w*loss\w* = %s' % float_exp, message)
        if match:
            iteration = int(match.group(1))
            loss = match.group(2)
            # Explicit raise rather than `assert` so the check survives `python -O`.
            if 'nan' in loss.lower():
                raise AssertionError(
                    'Network reported NaN for training loss. '
                    'Try decreasing your learning rate.')
            loss = float(loss)
            self.train_loss_updates.append((self.iteration_to_epoch(iteration), loss))
            self.logger.debug('Iteration %d/%d, loss=%s' % (iteration, self.solver.max_iter, loss))
            self.send_iteration_update(iteration)
            self.send_data_update()
            return True

        # learning rate updates (a NaN learning rate is not recorded)
        match = re.match(r'Iteration (\d+), lr = %s' % float_exp, message)
        if match:
            iteration = int(match.group(1))
            lr = match.group(2)
            if 'nan' not in lr.lower():
                self.lr_updates.append((self.iteration_to_epoch(iteration), float(lr)))
            self.send_iteration_update(iteration)
            return True

        # other iteration updates
        match = re.match(r'Iteration (\d+)', message)
        if match:
            self.send_iteration_update(int(match.group(1)))
            return True

        # testing loss updates (NaN values are skipped)
        match = re.match(r'Test net output #\d+: \w*loss\w* = %s' % float_exp, message, flags=re.IGNORECASE)
        if match:
            loss = match.group(1)
            if 'nan' not in loss.lower():
                self.val_loss_updates.append((self.iteration_to_epoch(self.current_iteration), float(loss)))
                self.send_data_update()
            return True

        # testing accuracy updates, stored as a percentage together with the
        # test-net output index (NaN values are skipped)
        match = re.match(r'Test net output #(\d+): \w*acc\w* = %s' % float_exp, message, flags=re.IGNORECASE)
        if match:
            index = int(match.group(1))
            accuracy = match.group(2)
            if 'nan' not in accuracy.lower():
                accuracy = float(accuracy) * 100
                self.logger.debug('Network accuracy #%d: %s' % (index, accuracy))
                self.val_accuracy_updates.append((self.iteration_to_epoch(self.current_iteration), accuracy, index))
                self.send_data_update(important=True)
            return True

        # snapshot starting; completion is handled on the next line (above)
        match = re.match(r'Snapshotting to (.*)\s*$', message)
        if match:
            self.saving_snapshot = True
            return True

        # memory requirement
        match = re.match(r'Memory required for data:\s+(\d+)', message)
        if match:
            bytes_required = int(match.group(1))
            self.logger.debug('memory required: %s' % utils.sizeof_fmt(bytes_required))
            return True

        if level in ['error', 'critical']:
            self.logger.error('%s: %s' % (self.name(), message))
            self.exception = message
            return True

        return True