Example #1
def load_model(filepath):
    """
    Load a MusicTransformer from a saved PyTorch state_dict and hparams. The input filepath should point to a .pt
    file containing a saved dictionary with the model state dict and hparams, e.g. saved via:
    torch.save({
        "state_dict": model.state_dict(),
        "hparams": hparams (dict)
    }, filepath)

    Args:
        filepath (str): path to single .pt file containing the dictionary as described above

    Returns:
        the loaded MusicTransformer model
    """
    from model import MusicTransformer
    from hparams import hparams

    file = torch.load(filepath)
    if "hparams" not in file:
        file["hparams"] = hparams

    model = MusicTransformer(**file["hparams"]).to(device)
    model.load_state_dict(file["state_dict"])
    model.eval()
    return model
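
For reference, a minimal sketch of the matching save step, assuming a trained model instance and its hparams dict are in scope (the filename is illustrative):

import torch

# Save the checkpoint in the dictionary layout that load_model expects.
torch.save({
    "state_dict": model.state_dict(),
    "hparams": hparams,
}, "music_transformer.pt")

model = load_model("music_transformer.pt")  # round-trips the model back into memory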
Example #2
    def __init__(self, hparams_, datapath, batch_size, warmup_steps=4000,
                 ckpt_path="music_transformer_ckpt.pt", load_from_checkpoint=False):
        """
        Args:
            hparams_: hyperparameters of the model
            datapath: path to the data to train on
            batch_size: batch size to batch the data
            warmup_steps: number of warmup steps for transformer learning rate schedule
            ckpt_path: path at which to save checkpoints while training; MUST end in .pt or .pth
            load_from_checkpoint (bool, optional): if true, on instantiating the trainer, this will load a previously
                                                   saved checkpoint at ckpt_path
        """
        # get the data
        self.datapath = datapath
        self.batch_size = batch_size
        data = torch.load(datapath).long().to(device)

        # max absolute position must be able to account for the largest sequence in the data
        if hparams_["max_abs_position"] > 0:
            hparams_["max_abs_position"] = max(hparams_["max_abs_position"], data.shape[-1])

        # train / validation split: 80 / 20
        train_len = round(data.shape[0] * 0.8)
        train_data = data[:train_len]
        val_data = data[train_len:]
        print(f"There are {data.shape[0]} samples in the data, {len(train_data)} training samples and {len(val_data)} "
              "validation samples")

        # datasets and dataloaders: inputs are the first (n-1) tokens of each sequence, targets the last (n-1)
        self.train_ds = TensorDataset(train_data[:, :-1], train_data[:, 1:])
        self.train_dl = DataLoader(dataset=self.train_ds, batch_size=batch_size, shuffle=True)

        self.val_ds = TensorDataset(val_data[:, :-1], val_data[:, 1:])
        self.val_dl = DataLoader(dataset=self.val_ds, batch_size=batch_size, shuffle=True)

        # create model
        self.model = MusicTransformer(**hparams_).to(device)
        self.hparams = hparams_

        # setup training
        self.warmup_steps = warmup_steps
        self.optimizer = optim.Adam(self.model.parameters(), lr=1.0, betas=(0.9, 0.98))
        self.scheduler = optim.lr_scheduler.LambdaLR(
            self.optimizer,
            lambda x: transformer_lr_schedule(self.hparams['d_model'], x, self.warmup_steps)
        )

        # setup checkpointing / saving
        self.ckpt_path = ckpt_path
        self.train_losses = []
        self.val_losses = []

        # load checkpoint if necessary
        if load_from_checkpoint and os.path.isfile(self.ckpt_path):
            self.load()
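
The scheduler above multiplies a base learning rate of 1.0 by transformer_lr_schedule. A sketch of a plausible implementation, assuming it follows the warmup schedule from "Attention Is All You Need":

def transformer_lr_schedule(d_model, step, warmup_steps=4000):
    # Noam schedule: linear warmup for warmup_steps, then inverse-sqrt decay.
    step = max(step, 1)  # LambdaLR starts at step 0; avoid division by zero
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)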
Example #3
def main():

    input_dir = "../data/test"
    training_val_split = 0.7
    # defines, in Hz, the smallest timestep preserved when quantizing MIDIs;
    # also determines the number of time-shift events
    sampling_rate = 125
    # determines the number of velocity events
    n_velocity_bins = 32

    # set up the data pipeline
    seq_length = 128
    padded_length = 128
    pipeline = PreprocessingPipeline(input_dir=input_dir,
                                     stretch_factors=[0.975, 1, 1.025],
                                     split_size=30,
                                     sampling_rate=sampling_rate,
                                     n_velocity_bins=n_velocity_bins,
                                     transpositions=range(-2, 3),
                                     training_val_split=training_val_split,
                                     max_encoded_length=seq_length + 1,
                                     min_encoded_length=33)

    pipeline.run()
    training_sequences = pipeline.encoded_sequences['training'][:1000]
    validation_sequences = pipeline.encoded_sequences['validation'][:100]
    n_tokens = 256 + sampling_rate + n_velocity_bins
    batch_size = 10
    optim = "adam"
    transformer = MusicTransformer(n_tokens,
                                   seq_length=padded_length,
                                   d_model=4,
                                   d_feedforward=32,
                                   n_heads=4,
                                   positional_encoding=True,
                                   relative_pos=True)

    train(transformer,
          training_sequences,
          validation_sequences,
          epochs=2,
          evaluate_per=1,
          batch_size=batch_size,
          batches_per_print=20,
          padding_index=0,
          checkpoint_path="../saved_models/test_save",
          custom_schedule=True)

    print(sample(transformer, 10))
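
A note on the vocabulary arithmetic: n_tokens = 256 + sampling_rate + n_velocity_bins follows the usual performance-event encoding, presumably 128 note-on plus 128 note-off pitches, one time-shift event per quantized step, and one event per velocity bin:

sampling_rate = 125      # one time-shift event per quantized step
n_velocity_bins = 32     # one event per velocity bin
n_tokens = 256 + sampling_rate + n_velocity_bins  # 256 note on/off events + the above
assert n_tokens == 413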
Example #4
File: run.py  Project: BShennette/Pno-ai
def main():
    parser = argparse.ArgumentParser("Script to train model on a GPU")
    parser.add_argument("--checkpoint", type=str, default=None,
            help="Optional path to saved model, if none provided, the model is trained from scratch.")
    parser.add_argument("--n_epochs", type=int, default=5,
            help="Number of training epochs.")
    args = parser.parse_args()
    
    sampling_rate = 125
    n_velocity_bins = 32
    seq_length = 1024
    n_tokens = 256 + sampling_rate + n_velocity_bins
    transformer = MusicTransformer(n_tokens, seq_length,
            d_model=64, n_heads=8, d_feedforward=256,
            depth=4, positional_encoding=True, relative_pos=True)

    if args.checkpoint is not None:
        state = torch.load(args.checkpoint)
        transformer.load_state_dict(state)
        print(f"Successfully loaded checkpoint at {args.checkpoint}")
    # rule of thumb: 1 minute of MIDI is roughly 2k tokens
    
    pipeline = PreprocessingPipeline(input_dir="data", stretch_factors=[0.975, 1, 1.025],
            split_size=30, sampling_rate=sampling_rate, n_velocity_bins=n_velocity_bins,
            transpositions=range(-2,3), training_val_split=0.9, max_encoded_length=seq_length+1,
                                    min_encoded_length=257)
    pipeline_start = time.time()
    pipeline.run()
    runtime = time.time() - pipeline_start
    print(f"MIDI pipeline runtime: {runtime / 60 : .1f}m")


    today = datetime.date.today().strftime('%m%d%Y')
    checkpoint = f"saved_models/tf_{today}"

    training_sequences = pipeline.encoded_sequences['training']
    validation_sequences = pipeline.encoded_sequences['validation']
    
    batch_size = 2
    
    train(transformer, training_sequences, validation_sequences,
            epochs=args.n_epochs, evaluate_per=1,
            batch_size=batch_size, batches_per_print=100,
            padding_index=0, checkpoint_path=checkpoint)
Example #5
    def load(self, ckpt_path=None):
        """
        Loads a checkpoint from ckpt_path
        NOTE: OVERWRITES THE MODEL STATE DICT, OPTIMIZER STATE DICT, SCHEDULER STATE DICT, AND HISTORY OF LOSSES

        Args:
            ckpt_path (str, optional): if None, loads the checkpoint at the previously stored self.ckpt_path
                                       else loads the checkpoints from the new passed-in path, and stores this new path
                                       at the member variable self.ckpt_path
        """
        if ckpt_path is not None:
            self.ckpt_path = ckpt_path

        ckpt = torch.load(self.ckpt_path)

        del self.model, self.optimizer, self.scheduler

        # create and load model
        self.model = MusicTransformer(**ckpt["hparams"]).to(device)
        self.hparams = ckpt["hparams"]
        print("Loading the model...", end="")
        print(self.model.load_state_dict(ckpt["model_state_dict"]))

        # create and load optimizer and scheduler
        self.warmup_steps = ckpt["warmup_steps"]
        self.optimizer = optim.Adam(self.model.parameters(), lr=1.0, betas=(0.9, 0.98))
        self.optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        self.scheduler = optim.lr_scheduler.LambdaLR(
            self.optimizer,
            lambda x: transformer_lr_schedule(self.hparams['d_model'], x, self.warmup_steps)
        )
        self.scheduler.load_state_dict(ckpt["scheduler_state_dict"])

        # load loss histories
        self.train_losses = ckpt["train_losses"]
        self.val_losses = ckpt["validation_losses"]

        return
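
A matching save method would persist exactly the keys that load() reads; a sketch under that assumption:

    def save(self, ckpt_path=None):
        """Sketch of the counterpart to load(): writes the same checkpoint keys."""
        if ckpt_path is not None:
            self.ckpt_path = ckpt_path
        torch.save({
            "hparams": self.hparams,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "warmup_steps": self.warmup_steps,
            "train_losses": self.train_losses,
            "validation_losses": self.val_losses,
        }, self.ckpt_path)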
Example #6
# check cuda
if torch.cuda.is_available():
    config.device = torch.device('cuda')
else:
    config.device = torch.device('cpu')

# load data
dataset = Data(config.pickle_dir)
print(dataset)

# learning rate
learning_rate = config.l_r

# define model
mt = MusicTransformer(embedding_dim=config.embedding_dim,
                      vocab_size=config.vocab_size,
                      num_layer=config.num_layers,
                      max_seq=config.max_seq,
                      dropout=config.dropout,
                      debug=config.debug,
                      loader_path=config.load_path)
mt.to(config.device)
opt = optim.Adam(mt.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
scheduler = CustomSchedule(config.embedding_dim, optimizer=opt)

# multi-GPU setup
if torch.cuda.device_count() > 1:
    single_mt = mt
    mt = torch.nn.DataParallel(mt, output_device=torch.cuda.device_count() - 1)
else:
    single_mt = mt

# init metric set
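
The optimizer above is created with lr=0 because CustomSchedule is expected to overwrite the learning rate on every step. A hypothetical sketch of such a class (the real implementation lives in the project's code):

class CustomSchedule:
    # Hypothetical: applies the inverse-sqrt warmup schedule to the wrapped
    # optimizer's param groups on each call to step().
    def __init__(self, d_model, warmup_steps=4000, optimizer=None):
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        self._step_num = 0

    def step(self):
        self._step_num += 1
        lr = self.d_model ** -0.5 * min(self._step_num ** -0.5,
                                        self._step_num * self.warmup_steps ** -1.5)
        for group in self.optimizer.param_groups:
            group['lr'] = lr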
Example #7
# load data
dataset = Data('dataset/processed')
print(dataset)

# set up the optimizer and learning-rate schedule
learning_rate = callback.CustomSchedule(par.embedding_dim)
opt = Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

strategy = tf.distribute.MirroredStrategy()

# define model
with strategy.scope():
    mt = MusicTransformer(embedding_dim=256,
                          vocab_size=par.vocab_size,
                          num_layer=6,
                          max_seq=max_seq,
                          dropout=0.2,
                          debug=False,
                          loader_path=load_path)
    mt.compile(optimizer=opt, loss=callback.transformer_dist_train_loss)

    # Train Start
    for e in range(epochs):
        mt.reset_metrics()
        for b in range(len(dataset.files) // batch_size):
            try:
                batch_x, batch_y = dataset.seq2seq_batch(batch_size, max_seq)
            except Exception:  # skip batches that fail to load
                continue
            result_metrics = mt.train_on_batch(batch_x, batch_y)
            if b % 100 == 0:
Example #8
args = parser.parse_args()
config.load(args.model_dir, args.configs, initialize=True)

# check cuda
if torch.cuda.is_available():
    config.device = torch.device('cuda')
else:
    config.device = torch.device('cpu')

current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
gen_log_dir = 'logs/mt_decoder/generate_' + current_time + '/generate'
gen_summary_writer = SummaryWriter(gen_log_dir)

mt = MusicTransformer(embedding_dim=config.embedding_dim,
                      vocab_size=config.vocab_size,
                      num_layer=config.num_layers,
                      max_seq=config.max_seq,
                      dropout=0,
                      debug=False)
mt.load_state_dict(torch.load(args.model_dir + '/final.pth'))
mt.test()

print(config.condition_file)
if config.condition_file is not None:
    inputs = np.array([encode_midi(config.condition_file)[:500]])
else:
    inputs = np.array([[24, 28, 31]])
inputs = torch.from_numpy(inputs)
result = mt(inputs, config.length, gen_summary_writer)

for i in result:
    print(i)
Example #9
def main():
    parser = argparse.ArgumentParser("Script to generate MIDI tracks by sampling from a trained model.")

    parser.add_argument("--model", type=str, 
            help="Key in saved_models/model.yaml, helps look up model arguments and path to saved checkpoint.")
    parser.add_argument("--sample_length", type=int, default=512,
            help="number of events to generate")
    parser.add_argument("--temps", nargs="+", type=float, 
            default=[1.0],
            help="space-separated list of temperatures to use when sampling")
    parser.add_argument("--n_trials", type=int, default=3,
            help="number of MIDI samples to generate per experiment")
    parser.add_argument("--live_input", action='store_true', default = False,
            help="if true, take in a seed from a MIDI input controller")

    parser.add_argument("--play_live", action='store_true', default=False,
            help="play sample(s) at end of script if true")
    parser.add_argument("--keep_ghosts", action='store_true', default=True)
    parser.add_argument("--stuck_note_duration", type=int, default=1)

    args = parser.parse_args()

    model = args.model
    # try:
    #     model_dict = yaml.safe_load(open('saved_models/model.yaml'))[model_key]
    # except (KeyError, OSError):
    #     raise GeneratorError(f"could not find yaml information for key {model_key}")
    # model_path = model_dict["path"]
    # model_args = model_dict["args"]

    # Change the value here to the model you want to run
    model_path = 'saved_models/' + model

    try:
        state = torch.load(model_path)
    except RuntimeError:
        state = torch.load(model_path, map_location="cpu")
    
    n_velocity_events = 32
    n_time_shift_events = 125

    decoder = SequenceEncoder(n_time_shift_events, n_velocity_events,
                              min_events=0)

    if args.live_input:
        pretty_midis = []
        m = 'twinkle.midi'
        with open(m, "rb") as f:
                try:
                    midi_str = six.BytesIO(f.read())
                    pretty_midis.append(pretty_midi.PrettyMIDI(midi_str))
                    #print("Successfully parsed {}".format(m))
                except:
                    print("Could not parse {}".format(m))
        pipeline = PreprocessingPipeline(input_dir="data")
        note_sequence = pipeline.get_note_sequences(pretty_midis)
        note_sequence = [vectorize(ns) for ns in note_sequence]
        prime_sequence = decoder.encode_sequences(note_sequence)
        prime_sequence = prime_sequence[1:6]


    else:
        prime_sequence = []

    # model = MusicTransformer(**model_args)
    model = MusicTransformer(256 + 125 + 32, 1024,
            d_model=64, n_heads=8, d_feedforward=256,
            depth=4, positional_encoding=True, relative_pos=True)

    model.load_state_dict(state, strict=False)

    temps = args.temps

    trial_key = str(uuid.uuid4())[:6]
    n_trials = args.n_trials

    keep_ghosts = args.keep_ghosts
    stuck_note_duration = None if args.stuck_note_duration == 0 else args.stuck_note_duration

    for temp in temps:
        print(f"sampling temp={temp}")
        note_sequence = []
        for i in range(n_trials):
            print("generating sequence")
            output_sequence = sample(model, prime_sequence=prime_sequence,
                                     sample_length=args.sample_length, temperature=temp)
            note_sequence = decoder.decode_sequence(output_sequence,
                verbose=True, stuck_note_duration=stuck_note_duration, keep_ghosts=keep_ghosts)

            output_dir = f"output/midis/{trial_key}/"
            file_name = f"sample{i+1}_{temp}"
            write_midi(note_sequence, output_dir, file_name)
Example #10
# load data
dataset = Data('dataset/processed')
print(dataset)

# set up the optimizer and learning-rate schedule
learning_rate = callback.CustomSchedule(par.embedding_dim)
opt = Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

strategy = tf.distribute.MirroredStrategy()

# define model
with strategy.scope():
    mt = MusicTransformer(embedding_dim=256,
                          vocab_size=par.vocab_size,
                          num_layer=6,
                          max_seq=max_seq,
                          dropout=0.1,
                          debug=False)
    mt.compile(optimizer=opt, loss=callback.transformer_dist_train_loss)

    # Train Start
    for e in range(epochs):
        mt.reset_metrics()
        for b in range(len(dataset.files) // batch_size):
            try:
                batch_x, batch_y = dataset.seq2seq_batch(batch_size, max_seq)
            except Exception:  # skip batches that fail to load
                continue
            result_metrics = mt.train_on_batch(batch_x, batch_y)
            if b % 100 == 0:
                eval_x, eval_y = dataset.seq2seq_batch(batch_size, max_seq,
Example #11
current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
gen_log_dir = 'logs/mt_decoder/generate_' + current_time + '/generate'
gen_summary_writer = SummaryWriter(gen_log_dir)

# mt = MusicTransformer(
#     embedding_dim=config.embedding_dim,
#     vocab_size=config.vocab_size,
#     num_layer=config.num_layers,
#     max_seq=config.max_seq,
#     dropout=0,
#     debug=False)

transformer = MusicTransformer(n_tokens,
                               seq_length,
                               d_model=64,
                               n_heads=8,
                               d_feedforward=256,
                               depth=4,
                               positional_encoding=True,
                               relative_pos=True)

# mt.load_state_dict(torch.load(args.model_dir+'/final.pth'))
# mt.test()

if args.checkpoint is not None:
    state = torch.load(args.checkpoint)
    transformer.load_state_dict(state)
    print(f"Successfully loaded checkpoint at {args.checkpoint}")
else:
    print(f"NOT FOUND checkpoint")

if config.condition_file is not None:
Example #12
def main():
    parser = argparse.ArgumentParser(
        "Script to generate MIDI tracks by sampling from a trained model.")

    parser.add_argument(
        "--model_key",
        type=str,
        help=
        "Key in saved_models/model.yaml, helps look up model arguments and path to saved checkpoint."
    )
    parser.add_argument("--sample_length",
                        type=int,
                        default=512,
                        help="number of events to generate")
    parser.add_argument(
        "--temps",
        nargs="+",
        type=float,
        default=[1.0],
        help="space-separated list of temperatures to use when sampling")
    parser.add_argument(
        "--n_trials",
        type=int,
        default=3,
        help="number of MIDI samples to generate per experiment")
    parser.add_argument(
        "--live_input",
        action='store_true',
        default=False,
        help="if true, take in a seed from a MIDI input controller")

    parser.add_argument("--play_live",
                        action='store_true',
                        default=False,
                        help="play sample(s) at end of script if true")
    parser.add_argument("--keep_ghosts", action='store_true', default=False)
    parser.add_argument("--stuck_note_duration", type=int, default=0)

    args = parser.parse_args()

    model_key = args.model_key

    try:
        model_dict = yaml.safe_load(open('saved_models/model.yaml'))[model_key]
    except (KeyError, OSError):
        raise GeneratorError(
            f"could not find yaml information for key {model_key}")

    model_path = model_dict["path"]
    model_args = model_dict["args"]
    try:
        state = torch.load(model_path)
    except RuntimeError:
        state = torch.load(model_path, map_location="cpu")

    n_velocity_events = 32
    n_time_shift_events = 125

    decoder = SequenceEncoder(n_time_shift_events,
                              n_velocity_events,
                              min_events=0)

    if args.live_input:
        print("Expecting a midi input...")
        note_sequence = midi_input.read(n_velocity_events, n_time_shift_events)
        prime_sequence = decoder.encode_sequences([note_sequence])[0]

    else:
        prime_sequence = []

    model = MusicTransformer(**model_args)
    model.load_state_dict(state, strict=False)

    temps = args.temps

    trial_key = str(uuid.uuid4())[:6]
    n_trials = args.n_trials

    keep_ghosts = args.keep_ghosts
    stuck_note_duration = None if args.stuck_note_duration == 0 else args.stuck_note_duration

    for temp in temps:
        print(f"sampling temp={temp}")
        note_sequence = []
        for i in range(n_trials):
            print("generating sequence")
            output_sequence = sample(model,
                                     prime_sequence=prime_sequence,
                                     sample_length=args.sample_length,
                                     temperature=temp)
            note_sequence = decoder.decode_sequence(output_sequence,
                                                    verbose=True,
                                                    keep_ghosts=keep_ghosts,
                                                    stuck_note_duration=stuck_note_duration)

            output_dir = f"output/{model_key}/{trial_key}/"
            file_name = f"sample{i+1}_{temp}"
            write_midi(note_sequence, output_dir, file_name)

    # play samples back only when requested; iterate over every trial, not just the last
    if args.play_live:
        for temp in temps:
            for i in range(n_trials):
                try:
                    subprocess.run([
                        'timidity',
                        f"output/{model_key}/{trial_key}/sample{i+1}_{temp}.midi"
                    ])
                except KeyboardInterrupt:
                    continue
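
A hypothetical invocation of this script, with flag names taken from the parser above (the script name and model key are illustrative):

# python generate.py --model_key my_model --sample_length 1024 --temps 0.8 1.0 --n_trials 3 --play_live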
Example #13
def main():
    parser = argparse.ArgumentParser(
        "Script to generate MIDI tracks by sampling from a trained model.")

    parser.add_argument(
        "--model_key",
        type=str,
        help=
        "Key in saved_models/model.yaml, helps look up model arguments and path to saved checkpoint."
    )
    parser.add_argument("--sample_length",
                        type=int,
                        default=512,
                        help="number of events to generate")
    parser.add_argument(
        "--temps",
        nargs="+",
        type=float,
        default=[1.0],
        help="space-separated list of temperatures to use when sampling")
    parser.add_argument(
        "--n_trials",
        type=int,
        default=3,
        help="number of MIDI samples to generate per experiment")
    parser.add_argument("--primer",
                        type=str,
                        default=None,
                        help="Path to the primer")

    parser.add_argument("--play_live",
                        action='store_true',
                        default=False,
                        help="play sample(s) at end of script if true")
    parser.add_argument("--keep_ghosts", action='store_true', default=False)
    parser.add_argument("--stuck_note_duration", type=int, default=0)

    args = parser.parse_args()

    model_key = args.model_key

    try:
        model_dict = yaml.safe_load(open('saved_models/model.yaml'))[model_key]
    except (KeyError, OSError):
        raise GeneratorError(
            f"could not find yaml information for key {model_key}")

    model_path = model_dict["path"]
    model_args = model_dict["args"]
    try:
        state = torch.load(model_path)
    except RuntimeError:
        state = torch.load(model_path, map_location="cpu")

    n_velocity_events = 32
    n_time_shift_events = 125

    decoder = SequenceEncoder(n_time_shift_events,
                              n_velocity_events,
                              min_events=0)

    if args.primer:
        # Read midi primer
        with open(args.primer, 'rb') as f:
            midi_str = six.BytesIO(f.read())
        p = pretty_midi.PrettyMIDI(midi_str)
        piano_data = p.instruments[0]

        notes = apply_sustain(piano_data)
        note_sequence = sorted(notes, key=lambda x: (x.start, x.pitch))
        ns = vectorize(note_sequence)

        prime_sequence = decoder.encode_sequences([ns])[0]
    else:
        prime_sequence = []

    model = MusicTransformer(**model_args)
    model.load_state_dict(state, strict=False)

    temps = args.temps

    trial_key = str(uuid.uuid4())[:6]
    n_trials = args.n_trials

    keep_ghosts = args.keep_ghosts
    stuck_note_duration = None if args.stuck_note_duration == 0 else args.stuck_note_duration

    for temp in temps:
        print(f"sampling temp={temp}")
        note_sequence = []
        for i in range(n_trials):
            print("generating sequence")
            output_sequence = sample(model,
                                     prime_sequence=prime_sequence,
                                     sample_length=args.sample_length,
                                     temperature=temp)
            note_sequence = decoder.decode_sequence(output_sequence,
                                                    verbose=True,
                                                    keep_ghosts=keep_ghosts,
                                                    stuck_note_duration=stuck_note_duration)

            output_dir = f"output/{model_key}/{trial_key}/"
            file_name = f"sample{i+1}_{temp}"
            write_midi(note_sequence, output_dir, file_name)
Example #14
import params as par
from tensorflow.python import enable_eager_execution
from tensorflow.python.keras.optimizer_v2.adam import Adam
from tensorflow.python.keras.optimizer_v2.gradient_descent import SGD
from data import Data
import utils
enable_eager_execution()

tf.executing_eagerly()

if __name__ == '__main__':
    epoch = 100
    batch = 1000

    dataset = Data('dataset/processed/')
    opt = Adam(0.0001)
    # opt = SGD(lr=0.0001, momentum=0.0, decay=0.0, nesterov=False)
    mt = MusicTransformer(embedding_dim=par.embedding_dim,
                          vocab_size=par.vocab_size,
                          num_layer=6,
                          max_seq=100,
                          debug=True)
    mt.compile(optimizer=opt, loss=callback.TransformerLoss())

    for e in range(epoch):
        for b in range(batch):
            batch_x, batch_y = dataset.seq2seq_batch(2, 100)
            result_metrics = mt.train_on_batch(batch_x, batch_y)
            print('Loss: {:6.6}, Accuracy: {:3.2}'.format(
                result_metrics[0], result_metrics[1]))