예제 #1
0
def main(args):
    """Train a model from the job arguments, save it, and report RMSE.

    Steps:
      1. Fetch the file named by ``args['train_file']`` locally.
      2. Build the user/item maps and the train/test sparse matrices.
      3. Train row/column outputs with ``model.train_model`` and persist them.
      4. Log train and test RMSE, computed with ``wals.get_rmse``.

    Args:
      args: dict of job arguments. Keys read here are 'train_file',
        'data_type' and 'hypertune'; the whole dict is also forwarded to
        the ``model`` / ``util`` helpers.
    """
    # Stage 1: fetch the input locally and split it into train/test sets.
    local_path = util.ensure_local_file(args['train_file'])
    (user_map, item_map,
     train_matrix, test_matrix) = model.create_test_and_train_sets(
         args, local_path, args['data_type'])

    # Stage 2: fit the model on the training split only.
    row_factors, col_factors = model.train_model(args, train_matrix)

    # Stage 3: persist. 'user_ratings' data is saved via save_model_json and
    # additionally gets per-user item weights; anything else goes through
    # the plain save_model path.
    if args['data_type'] != 'user_ratings':
        model.save_model(args, user_map, item_map, row_factors, col_factors)
    else:
        model.save_model_json(args, user_map, item_map, row_factors,
                              col_factors)
        model.save_user_items_w(args, model.get_user_items_w(local_path))

    # Stage 4: evaluate on both splits.
    train_rmse = wals.get_rmse(row_factors, col_factors, train_matrix)
    test_rmse = wals.get_rmse(row_factors, col_factors, test_matrix)

    if args['hypertune']:
        # Only the held-out (test) RMSE is reported as the tuning metric.
        util.write_hptuning_metric(args, test_rmse)

    for message in ('train RMSE = %.2f' % train_rmse,
                    'test RMSE = %.2f' % test_rmse):
        tf.logging.info(message)
예제 #2
0
    def __iter__(self):
        """Yield one ``AudioMetaData`` per audio entry in the NSynth archive.

        The archive at ``self._url`` is made local via ``ensure_local_file``
        and read in two passes:

        * Pass 1 loads the ``*.json`` metadata, keyed by sample id.
        * Pass 2 yields every other regular file. Its raw bytes are wrapped
          in a ``PreDownload`` and its metadata entry is spread into
          ``AudioMetaData`` as keyword arguments.

        Raises:
            ValueError: if the archive has audio entries but no JSON
                metadata file.
            KeyError: if an audio entry's id is missing from the metadata.
        """
        base_url = 'https://magenta.tensorflow.org/datasets/nsynth'
        local_data = ensure_local_file(self._url, self.path)

        # Pass 1: load the metadata first. Entry order inside the tar is not
        # guaranteed, so it must be available before any audio is yielded.
        # If several JSON files exist, the last one wins (original behavior).
        # Non-file entries are skipped: extractfile() returns None for them.
        json_data = None
        with tarfile.open(local_data) as tar:
            for info in tar:
                if not info.isfile():
                    continue
                if not fnmatch.fnmatch(info.name, '*.json'):
                    continue
                json_flo = tar.extractfile(member=info)
                try:
                    json_data = json.load(json_flo)
                finally:
                    json_flo.close()

        # Pass 2: yield each non-JSON regular file with its metadata.
        with tarfile.open(local_data) as tar:
            for info in tar:
                if fnmatch.fnmatch(info.name, '*.json'):
                    continue
                if not info.isfile():
                    continue
                if json_data is None:
                    # Fail with a clear message instead of a TypeError from
                    # subscripting None below.
                    raise ValueError(
                        'no *.json metadata found in {0}'.format(local_data))
                _id = os.path.splitext(os.path.basename(info.name))[0]
                wav_flo = tar.extractfile(member=info)
                try:
                    raw_audio = wav_flo.read()
                finally:
                    wav_flo.close()
                pdl = PreDownload(raw_audio, '{0}/{1}'.format(base_url, _id))
                yield AudioMetaData(
                    uri=pdl,
                    web_url=base_url,
                    **json_data[_id])
예제 #3
0
def main(args):
    """Train a model from parsed arguments, save it, and report test RMSE.

    ``args`` is accessed by attribute (presumably an argparse-style
    namespace — confirm against the caller). ``args.train_file`` names the
    input; the whole object is forwarded to ``model.train_model``,
    ``model.save_model`` and ``util.write_hptuning_metric``. The test-set
    RMSE is always written as the hyperparameter-tuning metric.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    # Fetch the input and split it into train/test sparse matrices.
    local_path = util.ensure_local_file(args.train_file)
    splits = model.create_test_and_train_sets(local_path)
    user_map, item_map, train_matrix, test_matrix = splits

    # Fit on the training split and persist the result.
    row_factors, col_factors = model.train_model(args, train_matrix)
    model.save_model(args, user_map, item_map, row_factors, col_factors)

    # Report held-out error for tuning.
    util.write_hptuning_metric(
        args, wals.get_rmse(row_factors, col_factors, test_matrix))
def main(args):
  """Train a model on the first training file, save it, and log RMSE.

  Args:
    args: dict of job arguments. Reads 'train_files' (only element 0 is
      used), 'data_type' and 'hypertune'; the dict is also forwarded to
      the ``model`` / ``util`` helpers.
  """
  first_train_file = args['train_files'][0]
  local_path = util.ensure_local_file(first_train_file)
  user_map, item_map, train_matrix, test_matrix = (
      model.create_test_and_train_sets(args, local_path, args['data_type']))

  # Fit on the training split and persist the result.
  row_factors, col_factors = model.train_model(args, train_matrix)
  model.save_model(args, user_map, item_map, row_factors, col_factors)

  # Evaluate on train then test, in that order.
  train_rmse, test_rmse = [
      wals.get_rmse(row_factors, col_factors, matrix)
      for matrix in (train_matrix, test_matrix)]

  if args['hypertune']:
    # Only the held-out (test) RMSE is reported as the tuning metric.
    util.write_hptuning_metric(args, test_rmse)

  tf.logging.info('train RMSE = %.2f' % train_rmse)
  tf.logging.info('test RMSE = %.2f' % test_rmse)
예제 #5
0
    def __iter__(self):
        """Yield one ``AudioMetaData`` per file in ``<self.path>/train_data``.

        Metadata comes from the CSV at ``self._metadata``, which is made
        local via ``ensure_local_file``. Rows are indexed by their ``id``
        column. Each audio file's id is its filename without extension, and
        that row's columns are passed to ``AudioMetaData`` as keyword
        arguments, along with ``samplerate=int(self._samplerate)``.

        Raises:
            KeyError: if an audio file's id has no row in the metadata CSV.
        """
        import sys

        base_url = 'https://homes.cs.washington.edu/~thickstn/media'
        local_metadata = ensure_local_file(self._metadata, self.path)

        # The csv module expects a binary-mode file on Python 2. On Python 3
        # it needs a text-mode file opened with newline=''. Opening with
        # 'rb' on Python 3 makes DictReader fail with "iterator should
        # return strings, not bytes".
        if sys.version_info[0] >= 3:
            metadata_file = open(local_metadata, 'r', newline='',
                                 encoding='utf-8')
        else:
            metadata_file = open(local_metadata, 'rb')
        with metadata_file as f:
            metadata = {row['id']: row for row in csv.DictReader(f)}

        train_audio_path = os.path.join(self.path, 'train_data')

        for filename in os.listdir(train_audio_path):
            full_path = os.path.join(train_audio_path, filename)
            _id = os.path.splitext(filename)[0]
            url = '{0}/{1}'.format(base_url, _id)
            meta = metadata[_id]
            samples = AudioSamples.from_file(full_path)
            uri = PreDownload(samples.encode().read(), url)
            yield AudioMetaData(uri=uri,
                                samplerate=int(self._samplerate),
                                **meta)