コード例 #1
0
ファイル: predict.py プロジェクト: delldu/ImageNima
def rate(img_path):
    """
    Rate every image matching a glob pattern with the ImageNima model.

    Files matched by ``img_path`` are processed in sorted order. Each image is
    converted to RGB, resized and center-cropped to 224x224, and scored by the
    model. The 10 output scores are used as weights over the rating values
    1..10 to compute a mean and a standard deviation. These are printed as
    ``"<mean> <std>--- <filename>"``.

    Args:
        img_path: glob pattern selecting the image files to rate.

    Returns:
        A list with one ``(filename, scores, mean, std)`` tuple per image, in
        the order printed. ``scores`` is the raw model output tensor and
        ``mean``/``std`` are Python floats. The list is empty when the
        pattern matches no files.
    """
    model_setenv()
    device = model_device()
    model = get_model()
    model_name = 'models/ImageNima.pth'
    model_load(model, model_name)
    model = model.to(device)
    model.eval()

    image_filenames = sorted(glob.glob(img_path))

    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    # Rating values 1, 2, ..., 10 that weight the model's 10 output scores.
    weighted_votes = (torch.arange(10, dtype=torch.float) + 1).to(device)

    results = []
    # Inference only: disable autograd once for the whole loop.
    with torch.no_grad():
        for filename in image_filenames:
            # Close the underlying file as soon as the RGB copy exists.
            with Image.open(filename) as pil_img:
                img = pil_img.convert('RGB')
            img = transform(img).to(device)

            # Add the batch dimension; CenterCrop(224) fixes the spatial size.
            scores = model(img.view(1, 3, 224, 224))
            mean = torch.matmul(scores, weighted_votes)
            std = torch.sqrt((scores * torch.pow(
                (weighted_votes - mean.view(-1, 1)), 2)).sum(dim=1))

            print("{:.4f} {:.4f}--- {}".format(mean.item(), std.item(), filename))
            results.append((filename, scores, mean.item(), std.item()))

    return results
コード例 #2
0
def train():
    """
    basic train function for the API

    Retrains the models and refreshes the module-level model cache.

    Query parameters:
        regressor: regressor name passed to ``model_train``. When absent, a
            warning is printed and ``"randomforest"`` is used.

    Training always runs with ``test=True`` on ``<THIS_DIR>/cs-train``.

    Returns:
        ``jsonify(True)`` on success, or ``(jsonify([]), 400)`` if training or
        reloading raised an exception (the error message is printed).
    """
    # Without this declaration the reload below only binds function locals,
    # and the module-level cache filled by startup() keeps the stale models.
    global global_data, global_models

    regressor = request.args.get('regressor')
    if regressor is None:
        print(
            "WARNING API (train): received request, but no regressor specified, assuming 'randomforest'"
        )
        regressor = "randomforest"

    print("... training model")
    data_dir = os.path.join(THIS_DIR, "cs-train")
    try:
        model_train(data_dir, test=True, regressor=regressor)
        print("... training complete")
        # reload models and data after re-train
        print("... reloading models in cache")
        global_data, global_models = model_load(training=False)
        return (jsonify(True))
    except Exception as e:
        print("ERROR API (train): model_train returned: {}".format(str(e)))
        return jsonify([]), 400
コード例 #3
0
def main():
    """Run a non-test training pass, then reload the trained model."""
    model_train(test=False)
    model_load()  # the loaded model is not used further here
    print("model training complete.")
コード例 #4
0
def predict():
    """
    Serve a forecast for one country and date given in the JSON body.

    The body must contain ``country``, ``year``, ``month`` and ``day``. An
    optional ``mode == 'test'`` switches ``model_predict`` into test mode. A
    missing body, a missing field, or an empty model set yields
    ``jsonify([])``. Otherwise the prediction dict is returned with numpy
    arrays converted to lists.
    """
    payload = request.json
    if not payload:
        print("ERROR: API (predict): did not receive request data")
        return jsonify([])

    required_fields = ('country', 'year', 'month', 'day')
    for field in required_fields:
        if field not in payload:
            print(
                "ERROR API (predict): received request, but no '{}' found within".format(field)
            )
            return jsonify([])

    is_test = 'mode' in payload and payload['mode'] == 'test'

    country, year, month, day = (payload[field] for field in required_fields)

    all_data, all_models = model_load()
    if not all_models:
        print("ERROR: models are not available")
        return jsonify([])

    raw = model_predict(country, year, month, day,
                        all_models, all_data, test=is_test)

    # jsonify cannot encode numpy arrays, so turn them into plain lists.
    serializable = {
        key: value.tolist() if isinstance(value, np.ndarray) else value
        for key, value in raw.items()
    }
    return jsonify(serializable)
コード例 #5
0
def predict():
    """
    basic predict function for the API

    Expects a JSON body with ``country`` and (normally) ``date``. An optional
    ``mode == 'test'`` selects the test model. Returns ``jsonify([])`` when
    the body, the country, or the model is missing.

    Otherwise returns the ``model_predict`` result with numpy arrays converted
    to lists. Its ``predicted`` entry is serialized with ``.to_json()`` after
    its index labels have been turned into strings. ``predicted`` is assumed
    to be a pandas object, because ``.index`` and ``.to_json()`` are used on
    it.
    """

    ## input checking
    if not request.json:
        print("ERROR: API (predict): did not receive request data")
        return jsonify([])

    if 'country' not in request.json:
        print(
            "ERROR API (predict): received request, but no 'country' found within"
        )
        return jsonify([])

    if 'date' not in request.json:
        print(
            "WARNING API (predict): received request, but no 'date' was found within"
        )

    ## set the test flag
    test = 'mode' in request.json and request.json['mode'] == 'test'

    ## extract the inputs; a missing date is only a warning above, so pass
    ## None on instead of raising KeyError here.
    ## NOTE(review): confirm model_predict handles date=None.
    country_input = request.json['country']
    date_input = request.json.get('date')

    ## load model
    model = model_load(test=test)

    if not model:
        print("ERROR: model is not available")
        return jsonify([])

    _result = model_predict(date=date_input,
                            country=country_input,
                            df=None,
                            model=model,
                            test=test)

    ## the predicted frame's index labels are not necessarily JSON-friendly:
    ## stringify them, then embed the frame as a JSON string.
    predicted = _result["predicted"]
    predicted.index = [str(label) for label in predicted.index]
    _result["predicted"] = predicted.to_json()

    ## convert numpy objects to ensure they are serializable
    result = {}
    for key, item in _result.items():
        if isinstance(item, np.ndarray):
            result[key] = item.tolist()
        else:
            result[key] = item

    return (jsonify(result))
コード例 #6
0
def main():
    """Train with ``test=False`` and load the resulting model once."""
    training_options = {'test': False}
    model_train(**training_options)

    _ = model_load()  # loaded model is not used further here

    print("model training complete.")
コード例 #7
0
def predict():
    """
    basic predict function for the API

    Expects a JSON body with ``country``, ``year``, ``month`` and ``day``. An
    optional ``mode == 'test'`` runs ``model_predict`` in test mode. Returns
    ``jsonify([])`` when the body or any field is missing, or when no model
    exists for the country. Otherwise returns the prediction dict with numpy
    arrays converted to lists.
    """

    ## input checking
    if not request.json:
        print("ERROR: API (predict): did not receive request data")
        return jsonify([])

    # Validate every field read below so a missing one yields an empty
    # response instead of a KeyError (HTTP 500).
    for field in ('country', 'year', 'month', 'day'):
        if field not in request.json:
            print(
                "ERROR API (predict): received request, but no '{}' found within".format(field)
            )
            return jsonify([])

    ## set the test flag
    test = 'mode' in request.json and request.json['mode'] == 'test'

    ## extract the query
    country = request.json['country']
    year = request.json['year']
    month = request.json['month']
    day = request.json['day']

    ## load model
    data_dir = os.path.join("data", "cs-train")
    all_data, all_models = model_load(country, data_dir=data_dir)
    # .get: an unknown country must reach the "not available" branch below
    # rather than raise KeyError.
    model = all_models.get(country) if all_models else None

    if not model:
        print("ERROR: model is not available")
        return jsonify([])

    _result = model_predict(country, year, month, day, test=test)

    ## convert numpy objects to ensure they are serializable
    result = {}
    for key, item in _result.items():
        if isinstance(item, np.ndarray):
            result[key] = item.tolist()
        else:
            result[key] = item

    return (jsonify(result))
コード例 #8
0
def main():
    """Train with ``model_train``'s default arguments, then reload the model."""
    model_train()
    model_load()  # return value intentionally unused
    print("model training complete.")
def main(data_dir):
    """
    Train on ``data_dir`` with ``test=False``, then reload the model.

    Args:
        data_dir: directory handed to ``model_train`` as its first argument.
    """
    model_train(data_dir, test=False)
    model_load()  # return value intentionally unused
    print("model training complete.")
コード例 #10
0
def main():
    """Train on ``../cs-train`` (relative to the CWD), then reload the model."""
    model_train(os.path.join("..", "cs-train"))
    model_load()  # return value intentionally unused
    print("model training complete.")
コード例 #11
0
def predict():
    """
    basic predict function for the API

    Expects a JSON body with ``query`` and a ``type``. Only ``type == 'dict'``
    is implemented; a missing ``type`` falls back to ``'numpy'`` (with a
    warning), which is then rejected like any other unsupported type. An
    optional ``mode == 'test'`` runs ``model_predict`` in test mode. Returns
    ``jsonify([])`` on bad input or a missing model. Otherwise returns the
    prediction dict with numpy arrays converted to lists.
    """

    ## input checking
    if not request.json:
        print("ERROR: API (predict): did not receive request data")
        return jsonify([])

    if 'query' not in request.json:
        print(
            "ERROR API (predict): received request, but no 'query' found within"
        )
        return jsonify([])

    # Use the fallback instead of indexing request.json['type'] later, which
    # raised KeyError when the key was absent.
    if 'type' not in request.json:
        print(
            "WARNING API (predict): received request, but no 'type' was found assuming 'numpy'"
        )
        query_type = 'numpy'
    else:
        query_type = request.json['type']

    ## set the test flag
    test = 'mode' in request.json and request.json['mode'] == 'test'

    ## extract the query

    query = request.json['query']
    print(query)

    if query_type != 'dict':
        print(
            "ERROR API (predict): only dict data types have been implemented")
        return jsonify([])

    ## load model
    model = model_load()

    if not model:
        print("ERROR: model is not available")
        return jsonify([])

    _result = model_predict(query, model, test=test)

    ## convert numpy objects to ensure they are serializable
    result = {}
    for key, item in _result.items():
        if isinstance(item, np.ndarray):
            result[key] = item.tolist()
        else:
            result[key] = item

    return (jsonify(result))
コード例 #12
0
    def test_02_load(self):
        """
        test the load functionality
        """

        ## load the model
        all_data, all_models = model_load()

        self.assertTrue(all_data, "model_load() returned no data")
        self.assertTrue(all_models, "model_load() returned no models")
コード例 #13
0
def main():
    """Train on ``./data/cs-train`` with ``test=False``, then reload models."""
    training_dir = os.path.join(".", "data", "cs-train")
    model_train(training_dir, test=False)

    # model_load() yields (data, models); neither is needed afterwards.
    _data, _models = model_load()

    print("model training complete.")
コード例 #14
0
def predict():
    """
    basic predict function for the API

    Expects a JSON body whose ``query`` object holds ``idx`` and ``data``:
      * If ``idx`` is a list, each label is paired with the corresponding
        row of ``data`` and predicted separately.
      * Any other ``idx`` is a single label for one prediction over the whole
        ``data`` array.
    Returns ``jsonify([])`` on malformed input or a missing model. Otherwise
    returns ``{"y_pred": {label: prediction, ...}}`` with numpy arrays
    converted to lists.
    """

    ## input checking
    if not request.json:
        print("ERROR: API (predict): did not receive request data")
        return jsonify([])

    if 'query' not in request.json:
        print(
            "ERROR API (predict): received request, but no 'query' found within"
        )
        return jsonify([])

    if 'type' not in request.json:
        print(
            "WARNING API (predict): received request, but no 'type' was found assuming 'numpy'"
        )

    query = request.json['query']
    if not isinstance(query, dict) or 'idx' not in query or 'data' not in query:
        print("ERROR API (predict): 'query' must contain 'idx' and 'data'")
        return jsonify([])

    idx = query['idx']
    data = np.array(query['data'])

    ## load model
    model = model_load()

    if not model:
        print("ERROR: model is not available")
        return jsonify([])

    _result = {}

    if isinstance(idx, (list, tuple)):
        # One label per row of data.
        for row, label in zip(data, idx):
            _result[label] = model_predict(row, model)
    else:
        # A single label: predict on the whole data array at once. (Scalar
        # labels other than 1 previously crashed in zip().)
        _result[idx] = model_predict(data, model)

    ## convert numpy objects to ensure they are serializable
    result = {'y_pred': {}}
    for key, item in _result.items():
        if isinstance(item, np.ndarray):
            result['y_pred'][key] = item.tolist()
        else:
            result['y_pred'][key] = item

    return (jsonify(result))
コード例 #15
0
def main():
    """Train on ``DATA_DIR/cs-train`` with ``test=False``, then reload."""
    training_dir = os.path.join(DATA_DIR, "cs-train")
    model_train(training_dir, test=False)
    model_load()  # return value intentionally unused
    print("model training complete.")
コード例 #16
0
def main():
    """Train the 'sl'-prefixed models on DATA_DIR, reload them, list names."""
    prefix = 'sl'
    model_train(data_dir=DATA_DIR, prefix=prefix, test=False)

    _, all_models = model_load(country='all',
                               prefix=prefix,
                               data_dir=DATA_DIR,
                               training=False)
    loaded_names = ",".join(all_models.keys())
    print("... models loaded: ", loaded_names)

    print("model training complete.")
コード例 #17
0
def main():
    """Train (test mode) on ``./data/cs-train``, reload and list the models."""
    print("TRAINING MODELS")
    model_train(os.path.join(".", "data", "cs-train"), test=True)

    print("LOADING MODELS")
    _, all_models = model_load()
    names = ",".join(all_models.keys())
    print("... models loaded: ", names)

    print("model training complete.")
コード例 #18
0
ファイル: ModelTests.py プロジェクト: abhivp/ai_wf_capstone
    def test_02_load(self):
        """
        test the load functionality
        """

        # Load the model
        model_data, models = model_load(country='united_kingdom',
                                        prefix='test',
                                        data_dir=data_dir,
                                        training=False)
        # Fail with a clear message (not an IndexError) if nothing loaded.
        self.assertTrue(models, "model_load() returned no models")
        model = next(iter(models.values()))
        self.assertTrue(hasattr(model, 'predict'))
        self.assertTrue(hasattr(model, 'fit'))
コード例 #19
0
def predict():
    """
    basic predict function for the API

    Expects ``country``, ``year``, ``month`` and ``day`` in the JSON body. An
    optional ``mode == 'test'`` enables test mode.
      * A missing country returns ``jsonify(False)``.
      * A missing body, other missing fields, or a country with no model
        return ``jsonify([])``.
    Otherwise returns the prediction dict with numpy arrays converted to
    lists.
    """

    ## input checking
    if not request.json:
        print("ERROR: API (predict): did not receive request data")
        return jsonify([])

    print(request.json)
    if 'country' not in request.json:
        print(
            "ERROR API (predict): received request, but no 'country' found within"
        )
        return jsonify(False)

    # The date fields are read below; validate them so a missing one does
    # not surface as a KeyError (HTTP 500).
    for field in ('year', 'month', 'day'):
        if field not in request.json:
            print(
                "ERROR API (predict): received request, but no '{}' found within".format(field)
            )
            return jsonify([])

    ## set the test flag
    test = 'mode' in request.json and request.json['mode'] == 'test'

    ## extract the query parameters
    country = request.json['country']
    year = request.json['year']
    month = request.json['month']
    day = request.json['day']

    print(country, year, month, day)

    ## load model
    data_dir = os.path.join("data", "cs-train")
    all_data, all_models = model_load(data_dir=data_dir)
    # .get: an unknown country should take the "not available" branch.
    model = all_models.get(country) if all_models else None

    if not model:
        print("ERROR: model is not available")
        return jsonify([])

    ## predict
    _result = model_predict(country, year, month, day, test=test)

    ## convert numpy objects to ensure they are serializable
    result = {}
    for key, item in _result.items():
        if isinstance(item, np.ndarray):
            result[key] = item.tolist()
        else:
            result[key] = item

    return (jsonify(result))
コード例 #20
0
def predict():
    """
    basic predict function for the API

    Expects a JSON body with a ``query`` object whose keys are passed to
    ``model_predict`` as keyword arguments. An optional ``mode == 'test'``
    sets ``test=True``. Returns ``jsonify([])`` on bad input or when no
    models are loaded. Otherwise returns the prediction dict with numpy
    arrays converted to lists.
    """
    print(f"Request: {request.json}")
    ## input checking
    if not request.json:
        print("ERROR: API (predict): did not receive request data")
        return jsonify([])

    if 'query' not in request.json:
        print(
            "ERROR API (predict): received request, but no 'query' found within"
        )
        return jsonify([])

    ## set the test flag
    test = 'mode' in request.json and request.json['mode'] == 'test'

    ## extract the query
    query = request.json['query']
    # The query is splatted into keyword arguments below; anything other than
    # a JSON object would raise TypeError (HTTP 500) there.
    if not isinstance(query, dict):
        print("ERROR API (predict): 'query' must be a JSON object")
        return jsonify([])

    ## load model
    data_dir = os.path.join(".", "data", "cs-production")
    all_data, all_models = model_load(data_dir=data_dir, training=False)

    if not all_models:
        print("ERROR: model is not available")
        return jsonify([])

    _result = model_predict(**query, test=test)

    ## convert numpy objects to ensure they are serializable
    result = {}
    for key, item in _result.items():
        if isinstance(item, np.ndarray):
            result[key] = item.tolist()
        else:
            result[key] = item
    print(f"Result: {result}")
    # Build the response once and log the same object that is returned.
    response = jsonify(result)
    print(f"JSON: {response}")
    return response
コード例 #21
0
def predict():
    """
    basic predict function for the API

    Expects a JSON body with ``query`` and, optionally, ``type``. Only
    ``type == 'numpy'`` is implemented; a missing ``type`` falls back to
    ``'numpy'`` with a warning. Returns ``jsonify([])`` on bad input or a
    missing model. Otherwise returns the prediction dict with numpy arrays
    converted to lists.
    """

    print(request.json)

    ## input checking
    if not request.json:
        print("ERROR: API (predict): did not receive request data")
        return jsonify([])

    if 'query' not in request.json:
        print("ERROR API (predict): received request, but no 'query' found within")
        return jsonify([])

    # Honour the announced fallback instead of indexing request.json['type']
    # later, which raised KeyError when the key was absent.
    if 'type' not in request.json:
        print("WARNING API (predict): received request, but no 'type' was found assuming 'numpy'")
        query_type = 'numpy'
    else:
        query_type = request.json['type']

    query = request.json['query']

    if query_type == 'numpy':
        query = np.array(query)
    else:
        print("ERROR API (predict): only numpy data types have been implemented")
        return jsonify([])

    ## load model
    model = model_load()

    if not model:
        print("ERROR: model is not available")
        return jsonify([])

    _result = model_predict(query, model)

    ## convert numpy objects to ensure they are serializable
    result = {}
    for key, item in _result.items():
        if isinstance(item, np.ndarray):
            result[key] = item.tolist()
        else:
            result[key] = item

    return (jsonify(result))
コード例 #22
0
def setting(model_config, checkpoint_path):
    """
    Evaluate a saved checkpoint on the path-loss test split.

    Steps:
      1. Build the test dataloader from ``model_config``.
      2. Create the network with ``model.model_load(model_config)``.
      3. Restore ``checkpoint['model_state_dict']`` from ``checkpoint_path``.
      4. Run ``test`` on CUDA when available, otherwise on the CPU.

    Args:
        model_config: dict providing at least ``input_dir``, ``model``,
            ``num_workers``, ``batch_size`` and ``shuffle``. It also needs
            ``input_size`` (used when ``model == 'DNN'``) or
            ``sequence_length`` (used otherwise).
        checkpoint_path: path of a checkpoint dict with a
            ``'model_state_dict'`` entry.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # The loader's input_size comes from 'input_size' for DNN models and from
    # 'sequence_length' for every other model type.
    if model_config['model'] == 'DNN':
        input_size = model_config['input_size']
    else:
        input_size = model_config['sequence_length']

    _, test_dataloader, _ = data_loader.load_path_loss_with_detail_dataset(
        input_dir=model_config['input_dir'],
        model_type=model_config['model'],
        num_workers=model_config['num_workers'],
        batch_size=model_config['batch_size'],
        shuffle=model_config['shuffle'],
        input_size=input_size)
    nn_model = model.model_load(model_config)

    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only
    # host (without it torch.load fails to deserialize CUDA tensors there).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    nn_model.load_state_dict(checkpoint['model_state_dict'])
    # Keep the parameters on the same device test() is asked to use.
    nn_model.to(device)
    test(model_config=model_config,
         nn_model=nn_model,
         dataloader=test_dataloader,
         device=device)
コード例 #23
0
ファイル: app.py プロジェクト: richardcure/capstone
def predict():
    """
    basic predict function for the API

    Expects ``country``, ``year``, ``month`` and ``day`` in the JSON body.
    Returns ``jsonify([])`` when the body or any field is missing, or when no
    model could be loaded for the country. Otherwise returns the
    ``model_predict`` result.
    """

    ## input checking
    if not request.json:
        print("ERROR: API (predict): did not receive request data")
        return jsonify([])

    # Validate every field read below so a missing one yields an empty
    # response instead of a KeyError (HTTP 500).
    for field in ('country', 'year', 'month', 'day'):
        if field not in request.json:
            print(
                "ERROR API (predict): received request, but no '{}' found within".format(field)
            )
            return jsonify([])

    country = request.json['country']
    year = request.json['year']
    month = request.json['month']
    day = request.json['day']
    model = model_load(country)
    if not model:
        print("ERROR: model is not available")
        return jsonify([])

    print("... predicting")
    result = model_predict(country, year, month, day, model)
    print("... prediction complete")
    return (jsonify(result))
コード例 #24
0
def main():
    """
    Run ``model_predict`` (test mode) for every production date of every country.

    Loads data and models from ``data/cs-production``. For each country with
    at least one row in ``X``, one query per entry of its ``dates`` list is
    built; dates are parsed as ``%Y-%m-%d``. Each result, or an error string
    when the country has no model, is printed with a running counter.
    """
    print("LOADING MODELS")
    production_data_dir = os.path.join("data", "cs-production")
    all_data, all_models = model_load(data_dir=production_data_dir)

    print("... models loaded: ", ",".join(all_models.keys()))

    count = 0

    for country, country_data in all_data.items():

        # Countries with no production rows are skipped.
        if country_data['X'].shape[0] <= 0:
            continue

        # Model availability does not depend on the date: decide it once.
        if country in all_models:
            country_model = all_models[country]
            missing_msg = None
        else:
            country_model = None
            missing_msg = "ERROR (model_predict) - model for country '{}' could not be found".format(
                country)

        for date in country_data['dates']:

            dt = datetime.strptime(date, '%Y-%m-%d')

            query = {
                'country': country,
                'year': str(dt.year),
                'month': str(dt.month),
                'day': str(dt.day)
            }

            if missing_msg is not None:
                result = missing_msg
            else:
                result = model_predict(query,
                                       data=country_data,
                                       model=country_model,
                                       test=True)

            count += 1

            print('result[', count, ']: ', result)

    print("model test predict complete.")
コード例 #25
0
def predict():
    """
    Predict for the country and date given in the JSON body.

    Requires ``country``, ``day``, ``month`` and ``year``. An optional
    ``mode == 'test'`` runs ``model_predict`` in test mode. Returns
    ``jsonify([])`` when input is missing or no model exists for the country.
    Otherwise returns the prediction dict with numpy arrays converted to
    lists.
    """
    if not request.json:
        print("No request data")
        return jsonify([])

    required = (
        ('country', "Please provide the country name"),
        ('day', "Please provide the day"),
        ('month', "Please provide the month"),
        ('year', "Please provide the year"),
    )
    for field, message in required:
        if field not in request.json:
            print(message)
            return jsonify([])

    test = 'mode' in request.json and request.json['mode'] == 'test'
    country = request.json['country']
    day = request.json['day']
    month = request.json['month']
    year = request.json['year']
    data_dir = os.path.join(".", "data", "cs-train")
    all_data, all_models = model_load(data_dir=data_dir)
    # .get: an unknown country takes the "no models" branch instead of
    # raising KeyError.
    model = all_models.get(country) if all_models else None
    if not model:
        print("No models available")
        return jsonify([])
    _result = model_predict(country, year, month, day, test=test)
    result = {}
    for key, item in _result.items():
        if isinstance(item, np.ndarray):
            result[key] = item.tolist()
        else:
            result[key] = item
    return (jsonify(result))
コード例 #26
0
    parser.add_argument('--checkpoint',
                        type=str,
                        default="models/VideoZoom.pth",
                        help="checkpint file")
    parser.add_argument('--input',
                        type=str,
                        default="dataset/predict/input",
                        help="input folder")
    parser.add_argument('--output',
                        type=str,
                        default="dataset/predict/output",
                        help="output folder")
    args = parser.parse_args()

    model = get_model()
    model_load(model, args.checkpoint)
    device = model_device()
    model.to(device)
    model.eval()

    enable_amp(model)

    totensor = transforms.ToTensor()
    toimage = transforms.ToPILImage()

    video = Video()
    video.reset(args.input)
    progress_bar = tqdm(total=len(video))

    h = video.height
    w = video.width
コード例 #27
0
def model_load(downloaded_path):
    """
    Load a saved model through the project ``model`` module.

    Thin wrapper that forwards ``downloaded_path`` to ``model.model_load``
    and returns its result. The original note names the checkpoint
    "vggnet_model_acc_0.69.hdf5" as the file loaded here.
    """
    loader = model.model_load
    return loader(downloaded_path)
コード例 #28
0
def test_and_train(model_name='speech2speech', retrain=True):
    """
    Test and/or train on the given dataset.

    Steps:
      1. Load the 'raw' dataset and normalize every split with
         ``normalize_sample``.
      2. Get a model:
           * if ``MODEL_DIR/<model_name>/model.json`` exists and ``retrain``
             is False, load and compile the existing model;
           * otherwise build, compile, train and save a new one.
      3. Evaluate the model on the test split.
      4. Reconstruct test example 32 to audio with Griffin-Lim (input,
         target and prediction), play it, and write the three WAV files under
         ``MODEL_DIR/<model_name>/audio_output``.

    @param model_name name of model to save.
    @param retrain True if retrain, False if load from pretrained model
    """

    (X_train, y_train), (X_val, y_val), (X_test, y_test) = load_dataset(
        'raw',
        nfft=NFFT,
        hop_len=HOP_LENGTH,
        fs=FS,
        stacked_frames=STACKED_FRAMES,
        chunk=CHUNK)

    X_train_norm = normalize_sample(X_train)
    X_val_norm = normalize_sample(X_val)
    X_test_norm = normalize_sample(X_test)

    y_train_norm = normalize_sample(y_train)
    y_val_norm = normalize_sample(y_val)
    y_test_norm = normalize_sample(y_test)

    print("X shape:", X_train_norm.shape, "y shape:", y_train_norm.shape)

    model_dir = os.path.join(MODEL_DIR, model_name)

    # Check if model already exists and retrain is not being called again
    if os.path.isfile(os.path.join(model_dir, 'model.json')) and not retrain:
        model = model_load(model_name)
        # Compile the model
        model.compile(loss=LOSS, optimizer=OPTIMIZER, metrics=METRICS)
    else:
        if not os.path.isdir(model_dir):
            create_model_directory(model_name)

        model = gen_model(tuple(X_train_norm.shape[1:]))
        print('Created Model...')

        model.compile(loss=LOSS, optimizer=OPTIMIZER, metrics=METRICS)
        print('Metrics for Model...')

        tf.keras.utils.plot_model(model,
                                  show_shapes=True,
                                  dpi=96,
                                  to_file=os.path.join(model_dir, 'model.png'))
        print(model.metrics_names)

        early_stopping_callback = tf.keras.callbacks.EarlyStopping(
            monitor='val_loss', patience=5, restore_best_weights=True)

        # Warm start: load previous weights and require training to beat
        # their validation loss.
        # NOTE(review): this tests path(MODEL_DIR, model_name), which looks
        # like the model *directory* created above, so isfile() is
        # presumably never true. The checkpoint below is written to
        # MODEL_DIR/<model_name>/model.h5 — confirm the intended file.
        if (os.path.isfile(path(MODEL_DIR, model_name))):
            model.load_weights(path(MODEL_DIR, model_name))
            baseline_val_loss = model.evaluate(X_val_norm, y_val_norm)[0]
            print(baseline_val_loss)
            early_stopping_callback = tf.keras.callbacks.EarlyStopping(
                monitor='val_loss',
                patience=5,
                restore_best_weights=True,
                baseline=baseline_val_loss)

        log_dir = os.path.join(
            LOGS_DIR, 'files',
            datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

        # Created but currently not passed to fit(); add it to cbs to enable
        # TensorBoard logging.
        tensorboard_callback = tf.keras.callbacks.TensorBoard(
            log_dir=log_dir, update_freq='batch')

        model_checkpoint_callback = ModelCheckpoint(monitor='val_loss',
                                                    filepath=os.path.join(
                                                        model_dir, 'model.h5'),
                                                    save_best_only=True,
                                                    save_weights_only=False,
                                                    mode='min')

        # fit the keras model on the dataset
        cbs = [early_stopping_callback, model_checkpoint_callback]

        model.fit(X_train_norm,
                  y_train_norm,
                  epochs=EPOCHS,
                  validation_data=(X_val_norm, y_val_norm),
                  verbose=1,
                  callbacks=cbs,
                  batch_size=BATCH_SIZE)
        print('Model Fit...')

        model = save_model(model, model_name)

    # Assumes METRICS yields exactly [mse, accuracy, rmse], in this order,
    # after the loss — TODO confirm against the METRICS definition.
    loss, mse, accuracy, rmse = model.evaluate(X_test_norm, y_test_norm,
                                               verbose=0)
    print(
        'Testing accuracy: {}, Testing MSE: {}, Testing Loss: {}, Testing RMSE: {}'
        .format(accuracy * 100, mse, loss, rmse))

    # Reconstruct one fixed test example (not random).
    idx = 32
    print(idx)
    x_sample = X_test_norm[idx]
    y = y_test_norm[idx]

    # Add the batch dimension expected by model.predict.
    X = np.expand_dims(x_sample, axis=0)
    print(X.shape)

    y_pred = np.squeeze(model.predict(X), axis=0)

    print(y.shape)
    print(y_pred.shape)

    # librosa's griffinlim expects (frequency_bins, frames); the samples are
    # presumably stored frames-first, hence the transposes.
    output_sound = librosa.core.griffinlim(y_pred.T)
    input_sound = librosa.core.griffinlim(x_sample.T)
    target_sound = librosa.core.griffinlim(y.T)

    # Play and plot all
    play_sound(input_sound, output_sound, target_sound, FS)

    if not os.path.isdir(os.path.join(model_dir, 'audio_output')):
        create_model_directory(os.path.join(model_name, 'audio_output'))

    # NOTE: librosa.output (write_wav) was removed in librosa 0.8; this call
    # requires librosa < 0.8.
    for wav_name, sound in (('input.wav', input_sound),
                            ('target.wav', target_sound),
                            ('predicted.wav', output_sound)):
        librosa.output.write_wav(path(MODEL_DIR, model_name, 'audio_output',
                                      wav_name),
                                 sound,
                                 sr=FS,
                                 norm=True)
コード例 #29
0
def startup():
    """Load all models (``training=False``) into the module-level cache."""
    global global_data, global_models
    print(".. loading models")
    loaded = model_load(training=False)
    global_data, global_models = loaded
    print(".. all models loaded")
コード例 #30
0
ファイル: train.py プロジェクト: delldu/ImageColor
                        default="output/ImageColor_D.pth", help="checkpoint D file")
    parser.add_argument('--bs', type=int, default=16, help="batch size")
    parser.add_argument('--lr', type=float, default=1e-4, help="learning rate")
    parser.add_argument('--epochs', type=int, default=1000)
    args = parser.parse_args()

    # Create directory to store weights
    if not os.path.exists(args.outputdir):
        os.makedirs(args.outputdir)

    # get model
    model = get_model(trainning=True)
    device = model_device()
    model.set_optimizer(args.lr)

    model_load(model.net_G, args.checkpoint_g)
    if model.use_D:
        model_load(model.net_D, args.checkpoint_d)

    model.to(device)

    lr_scheduler_G = optim.lr_scheduler.StepLR(
        model.optimizer_G, step_size=100, gamma=0.1)
    if model.use_D:
        lr_scheduler_D = optim.lr_scheduler.StepLR(
            model.optimizer_D, step_size=100, gamma=0.1)

    # get data loader
    train_dl, valid_dl = get_data(trainning=True, bs=args.bs)

    for epoch in range(args.epochs):