Example #1
def convert_dataset(dataset_name,
                    new_dataset_name=None,
                    raw_data_name=None,
                    thresh=None,
                    model_name=None,
                    base_path=utils.BASE_PATH):
    """
    Datasets are assumed to be stored in {base_path}/datasets/{dataset_name}.
    Output is written to {base_path}/datasets/{new_dataset_name}.

    If raw_data_name points to a directory, the .sm files there are used for
    BPM estimation.

    If model_name is provided, step placements are taken from the model's
    predictions and the threshold is estimated automatically.
    """

    if new_dataset_name is None:
        new_dataset_name = f"{dataset_name}_gen_{__version__}"
    print(f"New dataset name: {new_dataset_name}")

    new_ds_path = f"{base_path}/datasets/{new_dataset_name}"
    if not os.path.isdir(new_ds_path):
        os.mkdir(new_ds_path)

    smds = SMDUtils.get_dataset_from_file(dataset_name,
                                          'placement',
                                          chunk_size=-1,
                                          concat=False)

    # Compute threshold if model is provided.
    if model_name is not None:
        model = StepPlacement.RegularizedRecurrentStepPlacementModel()
        model.load_state_dict(torch.load(model_name))
        model.cuda()

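        # The threshold is estimated from only the first 10 songs of the
        # dataset, presumably to keep this step fast.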
        thresh_ds = datautils.ConcatDataset(smds[:10])
        thresh = compute_thresh(model, thresh_ds)

    # Get BPM estimations.
    for smd in smds:
        print(smd.song_name)
        n_diffs = len(smd)

        if raw_data_name is None:
            # Compute BPM.
            pos_labels = []
            for d in smd:
                pos_labels.append(d['step_pos_labels'])

            pos_labels = np.concatenate(pos_labels)

            # For training, use ground truth step positions.
            bpm = bpm_estimator.est_bpm(pos_labels)
        else:
            sm = SMData.SMFile(smd.song_name, raw_data_name, base_path)
            try:
                bpm = bpm_estimator.true_bpm(sm)
            except ValueError as e:
                print(e)
                print(f"Skipping song {smd.song_name}")
                continue

        sec_per_beat = 60 / bpm  # Seconds per beat

        frame_idxs = None
        if model_name is not None:
            predict_loader = datautils.DataLoader(smd)
            outputs_list, labels_list = model.predict(predict_loader,
                                                      return_list=True)
            outputs_list = list(map(lambda l: l[0, :, 0], outputs_list))
            labels_list = list(map(lambda l: l[0, :], labels_list))

            frame_idxs = list(
                map(lambda outputs: np.where(outputs > thresh)[0],
                    outputs_list))

        diff_order, diff_features = get_generation_features(
            smd, bpm, frame_idxs)

        song_path = f'{new_ds_path}/{smd.song_name}'
        fname = f'{song_path}/{smd.song_name}.h5'

        if not os.path.isdir(song_path):
            os.mkdir(song_path)

        with h5py.File(fname, 'w') as hf:
            hf.attrs['song_name'] = smd.song_name
            hf.attrs['diff_names'] = np.array(diff_order).astype('S9')

            for diff in diff_order:
                diff_group = hf.create_group(diff)
                diff_data = diff_features[diff]

                for key in diff_data.keys():
                    diff_group.create_dataset(key, data=diff_data[key])

    return new_dataset_name
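
# A minimal usage sketch for convert_dataset (the dataset and model names here
# are hypothetical; the original entry point is not shown in this fragment).
# With a placement model supplied, the detection threshold is estimated
# automatically as described in the docstring above.
if __name__ == '__main__':
    new_name = convert_dataset('fraxtil_placement',
                               model_name='models/StepPlacement.torch')
    print("Generated dataset:", new_name)
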
Example #2
import argparse

import torch
import torch.utils.data as datautils

from deepSM import utils
from deepSM import SMDUtils
from deepSM import StepPlacement

parser = argparse.ArgumentParser()
parser.add_argument('placement_model', type=str)
parser.add_argument('dataset_name', type=str)
parser.add_argument('--n_batches', type=int, default=2000)
parser.add_argument('--chunk_size', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=128)

args = parser.parse_args()

print("Testing model", args.placement_model)
print("Datset name:", args.dataset_name)

test_dataset = SMDUtils.get_dataset_from_file(args.dataset_name + '_test',
                                              'placement',
                                              chunk_size=args.chunk_size)

test_loader = datautils.DataLoader(test_dataset,
                                   num_workers=4,
                                   batch_size=args.batch_size)

model = StepPlacement.RegularizedRecurrentStepPlacementModel()
model.load_state_dict(torch.load(args.placement_model))
model.cuda()

outputs, labels = model.predict(test_loader, max_batches=args.n_batches)

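# Short model identifier: drop the directory prefix and what is presumably a
# three-character file extension (e.g. ".pt").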
pmodel_str = args.placement_model.split('/')[-1][:-3]
torch.save(outputs, f'outputs_{args.dataset_name}_{pmodel_str}.torch')
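
# The saved file holds the raw (pre-sigmoid) model outputs; a downstream
# consumer could reload and threshold them, for example (0.5 is a placeholder):
#     outputs = torch.load(f'outputs_{args.dataset_name}_{pmodel_str}.torch')
#     preds = 1 / (1 + np.exp(-outputs)) > 0.5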
Example #3
def train(args):
    print("Begin training.")
    train_dataset = SMDUtils.get_dataset_from_file(args.train)
    test_dataset = SMDUtils.get_dataset_from_file(args.test)
    
    train_loader = datautils.DataLoader(
        train_dataset,
        num_workers=args.num_workers,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True)
    
    test_loader = datautils.DataLoader(
        test_dataset,
        num_workers=args.num_workers,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True)
    
    print('N train:', len(train_loader))
    print('N test:', len(test_loader))
    
    train_ts = utils.timestamp()
    
    model = StepPlacement.RecurrentStepPlacementModel()
    
    checkpoint_path = f'{args.output_data_dir}/model.cpt'
    if os.path.exists(checkpoint_path):
        print("Loading weights from" checkpoint_path)
        model.load_state_dict(torch.load(checkpoint_path))
        
    model.cuda()
    
    model.fit(train_loader, 
              args.epochs, 
              args.batch_size, 
              args.checkpoint_freq,
              args.output_data_dir)
    
    torch.save(model.state_dict(), f'{args.model_dir}/StepPlacement.torch')
    
    outputs, labels = model.predict(test_loader, max_batches = 2000)
    
    s3_bucket = 'sagemaker-us-west-1-284801879240'
    sm_env = json.loads(os.environ['SM_TRAINING_ENV'])
    s3_path = sm_env['job_name']
    
    def sigmoid(x):
        return 1 / (1 + np.exp(-x))
    
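    # The model emits raw scores, so sigmoid maps them to probabilities.
    # The low 0.1 cutoff is presumably chosen because positive (step) frames
    # are rare relative to empty frames.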
    preds = sigmoid(outputs) > 0.1
    
    accuracy = 1 - np.mean(np.abs(preds - labels))
    print("Accuracy:", accuracy)
    
    percent_pos = np.mean(preds)
    print("Percent positive:", percent_pos)
    
    fpr, tpr, _ = roc_curve(labels, outputs)
    roc_auc = auc(fpr, tpr)
    
    plt.plot(fpr, tpr)
    roc_buf = io.BytesIO()
    plt.savefig(roc_buf, format='png')
    roc_buf.seek(0)
    utils.upload_image_obj(roc_buf, s3_bucket, f'{s3_path}/roc_auc.png')
    
    print("AUC ROC:", roc_auc)
    
    precision, recall, _ = precision_recall_curve(labels, outputs)
    prauc = auc(recall, precision)
    print("AUC PR:", prauc)
    
    plt.figure()  # New figure so the PR curve is not drawn over the ROC curve.
    plt.plot(recall, precision)
    pr_buf = io.BytesIO()
    plt.savefig(pr_buf, format='png')
    pr_buf.seek(0)
    utils.upload_image_obj(pr_buf, s3_bucket, f'{s3_path}/pr_auc.png')
    
    f1 = f1_score(labels, preds)
    print("F1 score:", f1)
    
    output_metrics = [
        'Training done.',
        f'Accuracy: {accuracy}',
        f'Percent pos: {percent_pos}',
        f'ROC AUC: {roc_auc}',
        f'PRAUC: {prauc}',
        f'F1 score: {f1}'
    ]
    
    output_str = '\n'.join(output_metrics)
    
    utils.notify(output_str)
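

# A minimal entry-point sketch for the training routine above. The argument
# names are inferred from the attributes read off `args` inside train(); the
# defaults are placeholders, not the author's values.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--train', type=str)
    parser.add_argument('--test', type=str)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--checkpoint_freq', type=int, default=100)
    parser.add_argument('--output_data_dir', type=str, default='.')
    parser.add_argument('--model_dir', type=str, default='.')

    train(parser.parse_args())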
Example #4
n_songs = args.n_songs
train_ts = utils.timestamp()

print("Training on dataset", dataset_name)
print("Batch size:", batch_size)
print("Chunk size:", chunk_size)
print("N epochs:", n_epochs)
print("Timestamp:", train_ts)
print("Output dir:", output_dir)

st_time = time.time()

print("Loading data...")

train_dataset = SMDUtils.get_dataset_from_file(dataset_name + '_train',
                                               'gen',
                                               n_songs=n_songs,
                                               chunk_size=chunk_size)

test_dataset = SMDUtils.get_dataset_from_file(dataset_name + '_test',
                                              'gen',
                                              n_songs=n_songs,
                                              chunk_size=chunk_size)

# Train/test sets are pre-generated.
train_loader = datautils.DataLoader(train_dataset,
                                    num_workers=8,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    pin_memory=True)
print("Train dataset size:", len(train_dataset))
Example #5
import argparse

from deepSM import SMDUtils

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Convert raw dataset into step placement dataset.")

    # Looks for the folder data/{dataset}.
    parser.add_argument('dataset', type=str, help='The raw data to process.')
    # Outputs to datasets/{output_name}.
    parser.add_argument('output_name', type=str)
    parser.add_argument('--drop_diffs',
                        type=str,
                        nargs='+',
                        help="Exclude difficulty from processing.")
    parser.add_argument(
        '--test',
        type=float,
        default=-1,
        help="Percent of data in test dataset, if splitting dataset.")

    args = parser.parse_args()

    SMDUtils.save_generated_datasets(args.dataset,
                                     dataset_name=args.output_name,
                                     test_split=None,
                                     drop_diffs=args.drop_diffs)

    if args.test >= 0:
        SMDUtils.train_test_split_dataset(args.output_name, args.test)
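
    # Example invocation (the script and dataset names here are hypothetical):
    #     python prepare_placement_dataset.py fraxtil fraxtil_placement --test 0.25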
Example #6
import argparse

from deepSM import SMDUtils

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Splits dataset into train/test splits.")

    parser.add_argument('dataset', type=str)
    parser.add_argument('--test',
                        type=float,
                        default=0.25,
                        help="Percent of data in test dataset.")

    args = parser.parse_args()

    SMDUtils.train_test_split_dataset(args.dataset, test_split=args.test)
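
    # Example invocation (hypothetical script and dataset names):
    #     python train_test_split.py fraxtil_placement --test 0.25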
Example #7
import argparse

import torch
import torch.utils.data as datautils

from deepSM import utils
from deepSM import SMDUtils
from deepSM import StepPlacement
from deepSM import post_processing

parser = argparse.ArgumentParser()

parser.add_argument('dataset_name', type=str)
parser.add_argument('placement_model', type=str)

args = parser.parse_args()

smds = SMDUtils.get_dataset_from_file(args.dataset_name,
                                      'placement',
                                      chunk_size=-1,
                                      concat=False)

diffs = ['Beginner', 'Easy', 'Medium', 'Hard', 'Challenge']
thresholds = {diff: [] for diff in diffs}

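# Rough target step counts per difficulty, presumably used to calibrate the
# per-difficulty thresholds collected above.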
targets = dict(zip(diffs, [50, 66, 130, 220, 380]))

model = StepPlacement.RegularizedRecurrentStepPlacementModel()
model.load_state_dict(torch.load(args.placement_model))
model.cuda()

for smd in smds:
    print("Loading song", smd.song_name)