def load_model_from_run(backbone_name,
                        load_model_from,
                        load_weights_from=None,
                        skip_mismatch=True):
    """Rebuild a model saved under a run name, optionally overriding weights.

    The architecture is restored from the run's JSON file when it exists.
    Otherwise the run's weights file is loaded as a full model with
    ``load_model(..., compile=False)``.

    Args:
        backbone_name: Name passed to ``backbone()``. Its ``custom_objects``
            are supplied when deserializing the model.
        load_model_from: Run name whose JSON file (or, failing that, weights
            file) defines the model.
        load_weights_from: Optional run name whose weights file is loaded
            into the model by layer name (``by_name=True``). It is applied
            whichever way the model was restored. If that file does not
            exist, a warning is issued and the model keeps its current
            weights.
        skip_mismatch: Forwarded to ``model.load_weights``.

    Returns:
        The restored model (this function never compiles it).

    Raises:
        ValueError: If neither the JSON file nor the weights file exists
            for ``load_model_from``.
    """
    import warnings

    custom_objects = backbone(backbone_name).custom_objects
    json_path = get_json_filename(load_model_from)

    if os.path.exists(json_path):
        with open(json_path, 'r') as json_file:
            json_string = json_file.read()
        model = model_from_json(json_string, custom_objects=custom_objects)
    else:
        model_h5_path = get_weights_filename(load_model_from)
        if not os.path.exists(model_h5_path):
            # Wrap in a tuple so a non-str run name cannot break formatting.
            raise ValueError("run with name %s doesn't exist" %
                             (load_model_from,))
        model = load_model(model_h5_path,
                           custom_objects=custom_objects,
                           compile=False)

    # Previously this override only ran on the JSON branch, so it was
    # silently ignored whenever the model came from a full weights file.
    if load_weights_from:
        weights_path = get_weights_filename(load_weights_from)
        if os.path.exists(weights_path):
            model.load_weights(weights_path,
                               by_name=True,
                               skip_mismatch=skip_mismatch)
        else:
            # Keep the best-effort behaviour (no exception), but make the
            # fallback visible. Otherwise the caller silently receives a
            # model without the weights they asked for.
            warnings.warn(
                "weights for run %s not found at %s; keeping the model's "
                "current weights" % (load_weights_from, weights_path),
                stacklevel=2)
    return model
def save_model_to_run(model, run_name):
    """Persist ``model`` under ``run_name`` as JSON architecture plus weights.

    The output of ``model.to_json()`` is written to the run's JSON file.
    The weights are then written to the run's weights file via
    ``model.save_weights``. These are the same two filenames
    ``load_model_from_run`` resolves.
    """
    arch_file = get_json_filename(run_name)
    weights_file = get_weights_filename(run_name)

    with open(arch_file, mode='w') as handle:
        handle.write(model.to_json())

    model.save_weights(weights_file)
# Example #3
# 0
def load_model_weights_from(model, weights, skip_mismatch):
    """Load weights into ``model`` by layer name.

    ``weights`` is first tried as a filesystem path. If nothing exists there,
    it is treated as a run name and resolved with ``get_weights_filename``.
    Passing ``None`` is a no-op.

    Raises:
        ValueError: If neither the direct path nor the run's weights file
            exists.
    """
    if weights is None:
        return

    # Only resolve the run-name path when the direct path is absent.
    direct_hit = os.path.exists(weights)
    resolved = weights if direct_hit else get_weights_filename(weights)

    if not direct_hit and not os.path.exists(resolved):
        raise ValueError('Unknown weights to load from!', weights, resolved)

    model.load_weights(resolved, by_name=True, skip_mismatch=skip_mismatch)
# Example #4
# 0
def config_seg_callbacks(run_name=None):
    """Build the training callbacks for segmentation runs.

    Always returned:
      * ``ValidationPrediction`` with the confusion matrix disabled.
      * ``ReduceLROnPlateau`` watching ``val_loss`` (factor 0.5, patience 2,
        floor 1e-7).

    When ``run_name`` is truthy, two more are appended:
      * ``ModelCheckpoint`` keeping only the best weights by ``val_loss``,
        written to the run's weights file.
      * ``CSVLogger`` writing to the run's CSV file.
    """
    callbacks = [ValidationPrediction(show_confusion_matrix=False)]
    callbacks.append(
        ReduceLROnPlateau(monitor='val_loss',
                          factor=0.5,
                          patience=2,
                          verbose=1,
                          mode='auto',
                          min_lr=1e-7))

    if not run_name:
        return callbacks

    callbacks.append(
        ModelCheckpoint(get_weights_filename(run_name),
                        monitor='val_loss',
                        save_best_only=True,
                        save_weights_only=True,
                        verbose=True))
    callbacks.append(CSVLogger(filename=get_csv_filename(run_name)))
    return callbacks