Example #1
    def __init__(self, hparams_, datapath, batch_size, warmup_steps=4000,
                 ckpt_path="music_transformer_ckpt.pt", load_from_checkpoint=False):
        """
        Args:
            hparams_: hyperparameters of the model
            datapath: path to the data to train on
            batch_size: batch size to batch the data
            warmup_steps: number of warmup steps for transformer learning rate schedule
            ckpt_path: path at which to save checkpoints while training; MUST end in .pt or .pth
            load_from_checkpoint (bool, optional): if true, on instantiating the trainer, this will load a previously
                                                   saved checkpoint at ckpt_path
        """
        # get the data
        self.datapath = datapath
        self.batch_size = batch_size
        data = torch.load(datapath).long().to(device)

        # max absolute position must be able to account for the largest sequence in the data
        if hparams_["max_abs_position"] > 0:
            hparams_["max_abs_position"] = max(hparams_["max_abs_position"], data.shape[-1])

        # train / validation split: 80 / 20
        train_len = round(data.shape[0] * 0.8)
        train_data = data[:train_len]
        val_data = data[train_len:]
        print(f"There are {data.shape[0]} samples in the data, {len(train_data)} training samples and {len(val_data)} "
              "validation samples")

        # datasets and dataloaders: split data into first (n-1) and last (n-1) tokens
        self.train_ds = TensorDataset(train_data[:, :-1], train_data[:, 1:])
        self.train_dl = DataLoader(dataset=self.train_ds, batch_size=batch_size, shuffle=True)

        self.val_ds = TensorDataset(val_data[:, :-1], val_data[:, 1:])
        self.val_dl = DataLoader(dataset=self.val_ds, batch_size=batch_size, shuffle=True)

        # create model
        self.model = MusicTransformer(**hparams_).to(device)
        self.hparams = hparams_

        # setup training
        self.warmup_steps = warmup_steps
        self.optimizer = optim.Adam(self.model.parameters(), lr=1.0, betas=(0.9, 0.98))
        self.scheduler = optim.lr_scheduler.LambdaLR(
            self.optimizer,
            lambda x: transformer_lr_schedule(self.hparams['d_model'], x, self.warmup_steps)
        )

        # setup checkpointing / saving
        self.ckpt_path = ckpt_path
        self.train_losses = []
        self.val_losses = []

        # load checkpoint if necessary
        if load_from_checkpoint and os.path.isfile(self.ckpt_path):
            self.load()
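
The optimizer above is created with lr=1.0 so that the LambdaLR multiplier becomes the effective learning rate. The body of transformer_lr_schedule is not shown in this snippet; a minimal sketch, assuming it follows the standard warmup schedule from the original Transformer paper, might look like this:

def transformer_lr_schedule(d_model, step, warmup_steps=4000):
    # Assumed form: lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5),
    # i.e. linear warmup followed by inverse-square-root decay.
    step = max(step, 1)  # guard against step 0 on the scheduler's first call
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)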
Example #2
def main():
    parser = argparse.ArgumentParser("Script to train model on a GPU")
    parser.add_argument("--checkpoint", type=str, default=None,
            help="Optional path to saved model, if none provided, the model is trained from scratch.")
    parser.add_argument("--n_epochs", type=int, default=5,
            help="Number of training epochs.")
    args = parser.parse_args()
    
    sampling_rate = 125
    n_velocity_bins = 32
    seq_length = 1024
    n_tokens = 256 + sampling_rate + n_velocity_bins
    transformer = MusicTransformer(n_tokens, seq_length, 
            d_model = 64, n_heads = 8, d_feedforward=256, 
            depth = 4, positional_encoding=True, relative_pos=True)

    if args.checkpoint is not None:
        state = torch.load(args.checkpoint)
        transformer.load_state_dict(state)
        print(f"Successfully loaded checkpoint at {args.checkpoint}")
    #rule of thumb: 1 minute is roughly 2k tokens
    
    pipeline = PreprocessingPipeline(input_dir="data", stretch_factors=[0.975, 1, 1.025],
            split_size=30, sampling_rate=sampling_rate, n_velocity_bins=n_velocity_bins,
            transpositions=range(-2,3), training_val_split=0.9, max_encoded_length=seq_length+1,
                                    min_encoded_length=257)
    pipeline_start = time.time()
    pipeline.run()
    runtime = time.time() - pipeline_start
    print(f"MIDI pipeline runtime: {runtime / 60 : .1f}m")


    today = datetime.date.today().strftime('%m%d%Y')
    checkpoint = f"saved_models/tf_{today}"

    training_sequences = pipeline.encoded_sequences['training']
    validation_sequences = pipeline.encoded_sequences['validation']
    
    batch_size = 2
    
    train(transformer, training_sequences, validation_sequences,
               epochs = args.n_epochs, evaluate_per = 1,
               batch_size = batch_size, batches_per_print=100,
               padding_index=0, checkpoint_path=checkpoint)
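
The vocabulary size computed above combines three event families. A hedged breakdown, assuming the usual performance-event encoding with 128 note-on and 128 note-off tokens:

NOTE_EVENTS = 2 * 128                                      # assumed: 128 note-on + 128 note-off tokens
n_tokens = NOTE_EVENTS + sampling_rate + n_velocity_bins   # 256 + 125 + 32 = 413 tokens in total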
Example #3
def load_model(filepath):
    """
    Load a MusicTransformer from a saved PyTorch state_dict and hparams. The input filepath should point to a .pt
    file to which a dictionary containing the model state dict and hparams has been saved, e.g.:
    torch.save({
        "state_dict": model.state_dict(),  # model is a MusicTransformer instance
        "hparams": hparams                 # dict of hyperparameters used to build the model
    }, filepath)

    Args:
        filepath (str): path to single .pt file containing the dictionary as described above

    Returns:
        the loaded MusicTransformer model
    """
    from model import MusicTransformer
    from hparams import hparams

    file = torch.load(filepath)
    if "hparams" not in file:
        file["hparams"] = hparams

    model = MusicTransformer(**file["hparams"]).to(device)
    model.load_state_dict(file["state_dict"])
    model.eval()
    return model
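
A minimal usage sketch for load_model, mirroring the save layout described in the docstring (my_model and "my_model.pt" are hypothetical names):

# Save a trained model plus its hyperparameters in the layout load_model expects
torch.save({
    "state_dict": my_model.state_dict(),   # my_model: a trained MusicTransformer (hypothetical)
    "hparams": hparams                     # the dict of kwargs used to build my_model
}, "my_model.pt")

# Later, rebuild the model on `device` and put it in eval mode
model = load_model("my_model.pt")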
Example #4
    def load(self, ckpt_path=None):
        """
        Loads a checkpoint from ckpt_path
        NOTE: OVERWRITES THE MODEL STATE DICT, OPTIMIZER STATE DICT, SCHEDULER STATE DICT, AND HISTORY OF LOSSES

        Args:
            ckpt_path (str, optional): if None, loads the checkpoint at the previously stored self.ckpt_path
                                       else loads the checkpoints from the new passed-in path, and stores this new path
                                       at the member variable self.ckpt_path
        """
        if ckpt_path is not None:
            self.ckpt_path = ckpt_path

        ckpt = torch.load(self.ckpt_path)

        del self.model, self.optimizer, self.scheduler

        # create and load model
        self.model = MusicTransformer(**ckpt["hparams"]).to(device)
        self.hparams = ckpt["hparams"]
        print("Loading the model...", end="")
        print(self.model.load_state_dict(ckpt["model_state_dict"]))

        # create and load optimizer and scheduler
        self.warmup_steps = ckpt["warmup_steps"]
        self.optimizer = optim.Adam(self.model.parameters(), lr=1.0, betas=(0.9, 0.98))
        self.optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        self.scheduler = optim.lr_scheduler.LambdaLR(
            self.optimizer,
            lambda x: transformer_lr_schedule(self.hparams['d_model'], x, self.warmup_steps)
        )
        self.scheduler.load_state_dict(ckpt["scheduler_state_dict"])

        # load loss histories
        self.train_losses = ckpt["train_losses"]
        self.val_losses = ckpt["validation_losses"]

        return
Example #5
def main():

    input_dir = "../data/test"
    training_val_split = 0.7
    #defines, in Hz, the smallest timestep preserved in quantizing MIDIs
    #determines number of timeshift events
    sampling_rate = 125
    #determines number of velocity events
    n_velocity_bins = 32

    #set up data pipeline
    seq_length = 128
    padded_length = 128
    pipeline = PreprocessingPipeline(input_dir=input_dir,
                                     stretch_factors=[0.975, 1, 1.025],
                                     split_size=30,
                                     sampling_rate=sampling_rate,
                                     n_velocity_bins=n_velocity_bins,
                                     transpositions=range(-2, 3),
                                     training_val_split=training_val_split,
                                     max_encoded_length=seq_length + 1,
                                     min_encoded_length=33)

    pipeline.run()
    training_sequences = pipeline.encoded_sequences['training'][:1000]
    validation_sequences = pipeline.encoded_sequences['validation'][:100]
    n_tokens = 256 + sampling_rate + n_velocity_bins
    batch_size = 10
    optim = "adam"
    transformer = MusicTransformer(n_tokens,
                                   seq_length=padded_length,
                                   d_model=4,
                                   d_feedforward=32,
                                   n_heads=4,
                                   positional_encoding=True,
                                   relative_pos=True)

    train(transformer,
          training_sequences,
          validation_sequences,
          epochs=2,
          evaluate_per=1,
          batch_size=batch_size,
          batches_per_print=20,
          padding_index=0,
          checkpoint_path="../saved_models/test_save",
          custom_schedule=True)

    print(sample(transformer, 10))
Example #6
# load data
dataset = Data('dataset/processed')
print(dataset)

# load model
learning_rate = callback.CustomSchedule(par.embedding_dim)
opt = Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

strategy = tf.distribute.MirroredStrategy()

# define model
with strategy.scope():
    mt = MusicTransformer(embedding_dim=256,
                          vocab_size=par.vocab_size,
                          num_layer=6,
                          max_seq=max_seq,
                          dropout=0.2,
                          debug=False,
                          loader_path=load_path)
    mt.compile(optimizer=opt, loss=callback.transformer_dist_train_loss)

    # Train Start
    for e in range(epochs):
        mt.reset_metrics()
        for b in range(len(dataset.files) // batch_size):
            try:
                batch_x, batch_y = dataset.seq2seq_batch(batch_size, max_seq)
            except:
                continue
            result_metrics = mt.train_on_batch(batch_x, batch_y)
            if b % 100 == 0:
Example #7
args = parser.parse_args()
config.load(args.model_dir, args.configs, initialize=True)

# check cuda
if torch.cuda.is_available():
    config.device = torch.device('cuda')
else:
    config.device = torch.device('cpu')

current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
gen_log_dir = 'logs/mt_decoder/generate_' + current_time + '/generate'
gen_summary_writer = SummaryWriter(gen_log_dir)

mt = MusicTransformer(embedding_dim=config.embedding_dim,
                      vocab_size=config.vocab_size,
                      num_layer=config.num_layers,
                      max_seq=config.max_seq,
                      dropout=0,
                      debug=False)
mt.load_state_dict(torch.load(args.model_dir + '/final.pth'))
mt.test()

print(config.condition_file)
if config.condition_file is not None:
    inputs = np.array([encode_midi('dataset/midi/BENABD10.mid')[:500]])
else:
    inputs = np.array([[24, 28, 31]])
inputs = torch.from_numpy(inputs)
result = mt(inputs, config.length, gen_summary_writer)

for i in result:
    print(i)
Example #8
else:
    config.device = torch.device('cpu')
    device_ids = None

# load data
dataset = Data(config.pickle_dir)
print(dataset)

# load model
learning_rate = config.l_r

# define model
mt = MusicTransformer(embedding_dim=config.embedding_dim,
                      vocab_size=config.vocab_size,
                      num_layer=config.num_layers,
                      max_seq=config.max_seq,
                      dropout=config.dropout,
                      debug=config.debug,
                      loader_path=config.load_path)
opt = optim.Adam(mt.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
scheduler = CustomSchedule(config.embedding_dim, optimizer=opt)

# multi-GPU set
if torch.cuda.device_count() > 1:
    single_mt = mt
    mt = DataParallelModel(mt)
    mt.cuda()
else:
    single_mt = mt

# init metric set
Example #9
def main():
    parser = argparse.ArgumentParser("Script to generate MIDI tracks by sampling from a trained model.")

    parser.add_argument("--model", type=str, 
            help="Key in saved_models/model.yaml, helps look up model arguments and path to saved checkpoint.")
    parser.add_argument("--sample_length", type=int, default=512,
            help="number of events to generate")
    parser.add_argument("--temps", nargs="+", type=float, 
            default=[1.0],
            help="space-separated list of temperatures to use when sampling")
    parser.add_argument("--n_trials", type=int, default=3,
            help="number of MIDI samples to generate per experiment")
    parser.add_argument("--live_input", action='store_true', default = False,
            help="if true, take in a seed from a MIDI input controller")

    parser.add_argument("--play_live", action='store_true', default=False,
            help="play sample(s) at end of script if true")
    parser.add_argument("--keep_ghosts", action='store_true', default=True)
    parser.add_argument("--stuck_note_duration", type=int, default=1)

    args=parser.parse_args()

    model = args.model
    '''
    try:
        model_dict = yaml.safe_load(open('saved_models/model.yaml'))[model_key]
    except:
        raise GeneratorError(f"could not find yaml information for key {model_key}")
    '''
    #model_path = model_dict["path"]
    #model_args = model_dict["args"]

    #Change the value here to the model you want to run
    model_path = 'saved_models/'+model

    try:
        state = torch.load(model_path)
    except RuntimeError:
        state = torch.load(model_path, map_location="cpu")
    
    n_velocity_events = 32
    n_time_shift_events = 125

    decoder = SequenceEncoder(n_time_shift_events, n_velocity_events,
           min_events=0)

    if args.live_input:
        pretty_midis = []
        m = 'twinkle.midi'
        with open(m, "rb") as f:
                try:
                    midi_str = six.BytesIO(f.read())
                    pretty_midis.append(pretty_midi.PrettyMIDI(midi_str))
                    #print("Successfully parsed {}".format(m))
                except:
                    print("Could not parse {}".format(m))
        pipeline = PreprocessingPipeline(input_dir="data")
        note_sequence = pipeline.get_note_sequences(pretty_midis)
        note_sequence = [vectorize(ns) for ns in note_sequence]
        prime_sequence = decoder.encode_sequences(note_sequence)
        prime_sequence = prime_sequence[1:6]


    else:
        prime_sequence = []

   #model = MusicTransformer(**model_args)
    model = MusicTransformer(256+125+32, 1024, 
            d_model = 64, n_heads = 8, d_feedforward=256, 
            depth = 4, positional_encoding=True, relative_pos=True)

    model.load_state_dict(state, strict=False)

    temps = args.temps

    trial_key = str(uuid.uuid4())[:6]
    n_trials = args.n_trials

    keep_ghosts = args.keep_ghosts
    stuck_note_duration = None if args.stuck_note_duration == 0 else args.stuck_note_duration

    for temp in temps:
        print(f"sampling temp={temp}")
        note_sequence = []
        for i in range(n_trials):
            print("generating sequence")
            output_sequence = sample(model, prime_sequence = prime_sequence, sample_length=args.sample_length, temperature=temp)
            note_sequence = decoder.decode_sequence(output_sequence, 
                verbose=True, stuck_note_duration=0.5, keep_ghosts=True)

            output_dir = f"output/midis/{trial_key}/"
            file_name = f"sample{i+1}_{temp}"
            write_midi(note_sequence, output_dir, file_name)
Example #10
def main():
    parser = argparse.ArgumentParser(
        "Script to generate MIDI tracks by sampling from a trained model.")

    parser.add_argument(
        "--model_key",
        type=str,
        help=
        "Key in saved_models/model.yaml, helps look up model arguments and path to saved checkpoint."
    )
    parser.add_argument("--sample_length",
                        type=int,
                        default=512,
                        help="number of events to generate")
    parser.add_argument(
        "--temps",
        nargs="+",
        type=float,
        default=[1.0],
        help="space-separated list of temperatures to use when sampling")
    parser.add_argument(
        "--n_trials",
        type=int,
        default=3,
        help="number of MIDI samples to generate per experiment")
    parser.add_argument("--primer",
                        type=str,
                        default=None,
                        help="Path to the primer")

    parser.add_argument("--play_live",
                        action='store_true',
                        default=False,
                        help="play sample(s) at end of script if true")
    parser.add_argument("--keep_ghosts", action='store_true', default=False)
    parser.add_argument("--stuck_note_duration", type=int, default=0)

    args = parser.parse_args()

    model_key = args.model_key

    try:
        model_dict = yaml.safe_load(open('saved_models/model.yaml'))[model_key]
    except:
        raise GeneratorError(
            f"could not find yaml information for key {model_key}")

    model_path = model_dict["path"]
    model_args = model_dict["args"]
    try:
        state = torch.load(model_path)
    except RuntimeError:
        state = torch.load(model_path, map_location="cpu")

    n_velocity_events = 32
    n_time_shift_events = 125

    decoder = SequenceEncoder(n_time_shift_events,
                              n_velocity_events,
                              min_events=0)

    if args.primer:
        # Read midi primer
        midi_str = six.BytesIO(open(args.primer, 'rb').read())
        p = pretty_midi.PrettyMIDI(midi_str)
        piano_data = p.instruments[0]

        notes = apply_sustain(piano_data)
        note_sequence = sorted(notes, key=lambda x: (x.start, x.pitch))
        ns = vectorize(note_sequence)

        prime_sequence = decoder.encode_sequences([ns])[0]
    else:
        prime_sequence = []

    model = MusicTransformer(**model_args)
    model.load_state_dict(state, strict=False)

    temps = args.temps

    trial_key = str(uuid.uuid4())[:6]
    n_trials = args.n_trials

    keep_ghosts = args.keep_ghosts
    stuck_note_duration = None if args.stuck_note_duration == 0 else args.stuck_note_duration

    for temp in temps:
        print(f"sampling temp={temp}")
        note_sequence = []
        for i in range(n_trials):
            print("generating sequence")
            output_sequence = sample(model,
                                     prime_sequence=prime_sequence,
                                     sample_length=args.sample_length,
                                     temperature=temp)
            note_sequence = decoder.decode_sequence(output_sequence,
                                                    verbose=True,
                                                    stuck_note_duration=None)

            output_dir = f"output/{model_key}/{trial_key}/"
            file_name = f"sample{i+1}_{temp}"
            write_midi(note_sequence, output_dir, file_name)
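
The model_dict read from saved_models/model.yaml is expected to hold a checkpoint path and the MusicTransformer keyword arguments. A hedged sketch of the deserialized structure, with hypothetical values and keyword names mirroring the constructor call in Example #9 (the "n_tokens" keyword name is an assumption, since Example #9 passes it positionally):

model_dict = {
    "path": "saved_models/tf_01012021",          # hypothetical checkpoint path
    "args": {                                    # unpacked as MusicTransformer(**model_args)
        "n_tokens": 256 + 125 + 32,
        "seq_length": 1024,
        "d_model": 64,
        "n_heads": 8,
        "d_feedforward": 256,
        "depth": 4,
        "positional_encoding": True,
        "relative_pos": True,
    },
}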
Example #11
mode = args.mode
beam = args.beam
length = args.length
save_path = args.save_path

current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
gen_log_dir = 'logs/mt_decoder/generate_' + current_time + '/generate'
gen_summary_writer = tf.summary.create_file_writer(gen_log_dir)

if mode == 'enc-dec':
    print(">> generate with original seq2seq wise... beam size is {}".format(
        beam))
    mt = MusicTransformer(embedding_dim=256,
                          vocab_size=par.vocab_size,
                          num_layer=6,
                          max_seq=2048,
                          dropout=0.2,
                          debug=False,
                          loader_path=load_path)
else:
    print(">> generate with decoder wise... beam size is {}".format(beam))
    mt = MusicTransformerDecoder(loader_path=load_path)

inputs = encode_midi('dataset/midi/BENABD10.mid')

with gen_summary_writer.as_default():
    result = mt.generate(inputs[:10], beam=beam, length=length, tf_board=True)

for i in result:
    print(i)
Example #12
gen_log_dir = 'logs/mt_decoder/generate_'+current_time+'/generate'
gen_summary_writer = SummaryWriter(gen_log_dir)

# import param
# mt = MusicTransformer(
#     embedding_dim=param.embedding_dim,
#     vocab_size=param.vocab_size,
#     num_layer=param.num_attention_layer,
#     max_seq=param.max_seq,
#     dropout=0,
#     debug=False)

mt = MusicTransformer(
    embedding_dim=config.embedding_dim,
    vocab_size=config.vocab_size,
    num_layer=config.num_layers,
    max_seq=config.max_seq,
    dropout=0,
    debug=False)
# mt.load_state_dict(torch.load(args.model_dir+'/final.pth'))
mt.test()

# def model_size_summary(model):
#     param_num = 0
#     for param in model.parameters():
#         param_num += np.prod(list(param.shape))
#         print(param.shape, "num %d" % np.prod(list(param.shape)))
#     print(param_num, " in total, %.2fmb"%(param_num * 4 / 1024**2))
#
# model_size_summary(mt)
mt.cuda()
Example #13
current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
gen_log_dir = 'logs/mt_decoder/generate_' + current_time + '/generate'
gen_summary_writer = SummaryWriter(gen_log_dir)

# mt = MusicTransformer(
#     embedding_dim=config.embedding_dim,
#     vocab_size=config.vocab_size,
#     num_layer=config.num_layers,
#     max_seq=config.max_seq,
#     dropout=0,
#     debug=False)

transformer = MusicTransformer(n_tokens,
                               seq_length,
                               d_model=64,
                               n_heads=8,
                               d_feedforward=256,
                               depth=4,
                               positional_encoding=True,
                               relative_pos=True)

# mt.load_state_dict(torch.load(args.model_dir+'/final.pth'))
# mt.test()

if args.checkpoint is not None:
    state = torch.load(args.checkpoint)
    transformer.load_state_dict(state)
    print(f"Successfully loaded checkpoint at {args.checkpoint}")
else:
    print(f"NOT FOUND checkpoint")

if config.condition_file is not None:
Example #14
def main():
    parser = argparse.ArgumentParser(
        "Script to generate MIDI tracks by sampling from a trained model.")

    parser.add_argument(
        "--model_key",
        type=str,
        help=
        "Key in saved_models/model.yaml, helps look up model arguments and path to saved checkpoint."
    )
    parser.add_argument("--sample_length",
                        type=int,
                        default=512,
                        help="number of events to generate")
    parser.add_argument(
        "--temps",
        nargs="+",
        type=float,
        default=[1.0],
        help="space-separated list of temperatures to use when sampling")
    parser.add_argument(
        "--n_trials",
        type=int,
        default=3,
        help="number of MIDI samples to generate per experiment")
    parser.add_argument(
        "--live_input",
        action='store_true',
        default=False,
        help="if true, take in a seed from a MIDI input controller")

    parser.add_argument("--play_live",
                        action='store_true',
                        default=False,
                        help="play sample(s) at end of script if true")
    parser.add_argument("--keep_ghosts", action='store_true', default=False)
    parser.add_argument("--stuck_note_duration", type=int, default=0)

    args = parser.parse_args()

    model_key = args.model_key

    try:
        model_dict = yaml.safe_load(open('saved_models/model.yaml'))[model_key]
    except:
        raise GeneratorError(
            f"could not find yaml information for key {model_key}")

    model_path = model_dict["path"]
    model_args = model_dict["args"]
    try:
        state = torch.load(model_path)
    except RuntimeError:
        state = torch.load(model_path, map_location="cpu")

    n_velocity_events = 32
    n_time_shift_events = 125

    decoder = SequenceEncoder(n_time_shift_events,
                              n_velocity_events,
                              min_events=0)

    if args.live_input:
        print("Expecting a midi input...")
        note_sequence = midi_input.read(n_velocity_events, n_time_shift_events)
        prime_sequence = decoder.encode_sequences([note_sequence])[0]

    else:
        prime_sequence = []

    model = MusicTransformer(**model_args)
    model.load_state_dict(state, strict=False)

    temps = args.temps

    trial_key = str(uuid.uuid4())[:6]
    n_trials = args.n_trials

    keep_ghosts = args.keep_ghosts
    stuck_note_duration = None if args.stuck_note_duration == 0 else args.stuck_note_duration

    for temp in temps:
        print(f"sampling temp={temp}")
        note_sequence = []
        for i in range(n_trials):
            print("generating sequence")
            output_sequence = sample(model,
                                     prime_sequence=prime_sequence,
                                     sample_length=args.sample_length,
                                     temperature=temp)
            note_sequence = decoder.decode_sequence(output_sequence,
                                                    verbose=True,
                                                    stuck_note_duration=None)

            output_dir = f"output/{model_key}/{trial_key}/"
            file_name = f"sample{i+1}_{temp}"
            write_midi(note_sequence, output_dir, file_name)

    for temp in temps:
        try:
            subprocess.run([
                'timidity',
                f"output/{model_key}/{trial_key}/sample{i+1}_{temp}.midi"
            ])
        except KeyboardInterrupt:
            continue
Example #15
    config.device = torch.device('cuda')
else:
    config.device = torch.device('cpu')

# load data
dataset = Data(config.pickle_dir)
print(dataset)

# load model
learning_rate = config.l_r

# define model
mt = MusicTransformer(embedding_dim=config.embedding_dim,
                      vocab_size=config.vocab_size,
                      num_layer=config.num_layers,
                      max_seq=config.max_seq,
                      dropout=config.dropout,
                      debug=config.debug,
                      loader_path=config.load_path)
mt.to(config.device)
opt = optim.Adam(mt.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
scheduler = CustomSchedule(config.embedding_dim, optimizer=opt)

# multi-GPU set
if torch.cuda.device_count() > 1:
    single_mt = mt
    mt = torch.nn.DataParallel(mt, output_device=torch.cuda.device_count() - 1)
else:
    single_mt = mt

# init metric set
Example #16
class MusicTransformerTrainer:
    """
    As the transformer is a large model and takes a while to train on a GPU, or even a TPU, I wrote this Trainer
    class to make it easier to load and save checkpoints with the model. The way I've designed it instantiates the
    model, optimizer, and scheduler within the class itself, as there are some problems with passing them in. To get
    these objects back, just call:
        trainer.model
        trainer.optimizer
        trainer.scheduler

    This class also tracks the cumulative losses while training, which you can get back with:
        trainer.train_losses
        trainer.val_losses
    as lists of floats

    To save a checkpoint, call trainer.save()
    To load a checkpoint, call trainer.load( (optional) ckpt_path)
    """

    def __init__(self, hparams_, datapath, batch_size, warmup_steps=4000,
                 ckpt_path="music_transformer_ckpt.pt", load_from_checkpoint=False):
        """
        Args:
            hparams_: hyperparameters of the model
            datapath: path to the data to train on
            batch_size: batch size to batch the data
            warmup_steps: number of warmup steps for transformer learning rate schedule
            ckpt_path: path at which to save checkpoints while training; MUST end in .pt or .pth
            load_from_checkpoint (bool, optional): if true, on instantiating the trainer, this will load a previously
                                                   saved checkpoint at ckpt_path
        """
        # get the data
        self.datapath = datapath
        self.batch_size = batch_size
        data = torch.load(datapath).long().to(device)

        # max absolute position must be able to account for the largest sequence in the data
        if hparams_["max_abs_position"] > 0:
            hparams_["max_abs_position"] = max(hparams_["max_abs_position"], data.shape[-1])

        # train / validation split: 80 / 20
        train_len = round(data.shape[0] * 0.8)
        train_data = data[:train_len]
        val_data = data[train_len:]
        print(f"There are {data.shape[0]} samples in the data, {len(train_data)} training samples and {len(val_data)} "
              "validation samples")

        # datasets and dataloaders: split data into first (n-1) and last (n-1) tokens
        self.train_ds = TensorDataset(train_data[:, :-1], train_data[:, 1:])
        self.train_dl = DataLoader(dataset=self.train_ds, batch_size=batch_size, shuffle=True)

        self.val_ds = TensorDataset(val_data[:, :-1], val_data[:, 1:])
        self.val_dl = DataLoader(dataset=self.val_ds, batch_size=batch_size, shuffle=True)

        # create model
        self.model = MusicTransformer(**hparams_).to(device)
        self.hparams = hparams_

        # setup training
        self.warmup_steps = warmup_steps
        self.optimizer = optim.Adam(self.model.parameters(), lr=1.0, betas=(0.9, 0.98))
        self.scheduler = optim.lr_scheduler.LambdaLR(
            self.optimizer,
            lambda x: transformer_lr_schedule(self.hparams['d_model'], x, self.warmup_steps)
        )

        # setup checkpointing / saving
        self.ckpt_path = ckpt_path
        self.train_losses = []
        self.val_losses = []

        # load checkpoint if necessary
        if load_from_checkpoint and os.path.isfile(self.ckpt_path):
            self.load()

    def save(self, ckpt_path=None):
        """
        Saves a checkpoint at ckpt_path

        Args:
            ckpt_path (str, optional): if None, saves the checkpoint at the previously stored self.ckpt_path
                                       else saves the checkpoints at the new passed-in path, and stores this new path at
                                       the member variable self.ckpt_path
        """
        if ckpt_path is not None:
            self.ckpt_path = ckpt_path

        ckpt = {
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "train_losses": self.train_losses,
            "validation_losses": self.val_losses,
            "warmup_steps": self.warmup_steps,
            "hparams": self.hparams
        }

        torch.save(ckpt, self.ckpt_path)
        return

    def load(self, ckpt_path=None):
        """
        Loads a checkpoint from ckpt_path
        NOTE: OVERWRITES THE MODEL STATE DICT, OPTIMIZER STATE DICT, SCHEDULER STATE DICT, AND HISTORY OF LOSSES

        Args:
            ckpt_path (str, optional): if None, loads the checkpoint at the previously stored self.ckpt_path
                                       else loads the checkpoints from the new passed-in path, and stores this new path
                                       at the member variable self.ckpt_path
        """
        if ckpt_path is not None:
            self.ckpt_path = ckpt_path

        ckpt = torch.load(self.ckpt_path)

        del self.model, self.optimizer, self.scheduler

        # create and load model
        self.model = MusicTransformer(**ckpt["hparams"]).to(device)
        self.hparams = ckpt["hparams"]
        print("Loading the model...", end="")
        print(self.model.load_state_dict(ckpt["model_state_dict"]))

        # create and load optimizer and scheduler
        self.warmup_steps = ckpt["warmup_steps"]
        self.optimizer = optim.Adam(self.model.parameters(), lr=1.0, betas=(0.9, 0.98))
        self.optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        self.scheduler = optim.lr_scheduler.LambdaLR(
            self.optimizer,
            lambda x: transformer_lr_schedule(self.hparams['d_model'], x, self.warmup_steps)
        )
        self.scheduler.load_state_dict(ckpt["scheduler_state_dict"])

        # load loss histories
        self.train_losses = ckpt["train_losses"]
        self.val_losses = ckpt["validation_losses"]

        return

    def fit(self, epochs):
        """
        Training loop to fit the model to the data stored at the passed-in datapath. If a KeyboardInterrupt is raised
        at any time during the training loop, training stops and this method saves a checkpoint at the stored
        ckpt_path.

        Args:
            epochs: number of epochs to train for.

        Returns:
            history of training and validation losses for this training session
        """
        print_interval = epochs // 10 + int(epochs < 10)
        train_losses = []
        val_losses = []
        start = time.time()

        print("Beginning training...")

        try:
            for epoch in range(epochs):
                train_epoch_losses = []
                val_epoch_losses = []

                self.model.train()
                for train_inp, train_tar in self.train_dl:
                    loss = train_step(self.model, self.optimizer, self.scheduler, train_inp, train_tar)
                    train_epoch_losses.append(loss)

                self.model.eval()
                for val_inp, val_tar in self.val_dl:
                    loss = val_step(self.model, val_inp, val_tar)
                    val_epoch_losses.append(loss)

                # mean losses for the epoch
                train_mean = sum(train_epoch_losses) / len(train_epoch_losses)
                val_mean = sum(val_epoch_losses) / len(val_epoch_losses)

                # store complete history of losses in member lists and relative history for this session in output lists
                self.train_losses.append(train_mean)
                train_losses.append(train_mean)
                self.val_losses.append(val_mean)
                val_losses.append(val_mean)

                if ((epoch + 1) % print_interval) == 0:
                    print(f"Epoch {epoch + 1} Time taken {round(time.time() - start, 2)} seconds "
                          f"Train Loss {train_losses[-1]} Val Loss {val_losses[-1]}")
                    # print("Checkpointing...")
                    # self.save()
                    # print("Done")
                    start = time.time()

        except KeyboardInterrupt:
            pass

        print("Checkpointing...")
        self.save()
        print("Done")

        return train_losses, val_losses
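
A minimal usage sketch for the trainer, following the class docstring above. The hparams shown cover only the keys this class itself reads ("d_model", "max_abs_position"); a real hparams dict would hold the full set of MusicTransformer constructor arguments, and "data.pt" is a hypothetical datapath:

hparams_ = {"d_model": 128, "max_abs_position": 0}        # hypothetical, incomplete hparams
trainer = MusicTransformerTrainer(hparams_, datapath="data.pt", batch_size=32,
                                  warmup_steps=4000, ckpt_path="music_transformer_ckpt.pt",
                                  load_from_checkpoint=False)
train_hist, val_hist = trainer.fit(epochs=100)            # checkpoints on finish or Ctrl-C
trainer.save("music_transformer_ckpt.pt")                 # optional explicit checkpoint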
Example #17
import tensorflow as tf
import params as par
from model import MusicTransformer  # assumed location of the local model module
import callback  # assumed location of the local module providing TransformerLoss
from tensorflow.python import enable_eager_execution
from tensorflow.python.keras.optimizer_v2.adam import Adam
from tensorflow.python.keras.optimizer_v2.gradient_descent import SGD
from data import Data
import utils
enable_eager_execution()

tf.executing_eagerly()

if __name__ == '__main__':
    epoch = 100
    batch = 1000

    dataset = Data('dataset/processed/')
    opt = Adam(0.0001)
    # opt = SGD(lr=0.0001, momentum=0.0, decay=0.0, nesterov=False)
    mt = MusicTransformer(embedding_dim=par.embedding_dim,
                          vocab_size=par.vocab_size,
                          num_layer=6,
                          max_seq=100,
                          debug=True)
    mt.compile(optimizer=opt, loss=callback.TransformerLoss())

    for e in range(epoch):
        for b in range(batch):
            batch_x, batch_y = dataset.seq2seq_batch(2, 100)
            result_metrics = mt.train_on_batch(batch_x, batch_y)
            print('Loss: {:6.6}, Accuracy: {:3.2}'.format(
                result_metrics[0], result_metrics[1]))