def main():
    parser = argparse.ArgumentParser("Script to train model on a GPU")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Optional path to saved model, if none provided, the model is trained from scratch.")
    parser.add_argument("--n_epochs", type=int, default=5,
                        help="Number of training epochs.")
    args = parser.parse_args()

    sampling_rate = 125
    n_velocity_bins = 32
    seq_length = 1024
    n_tokens = 256 + sampling_rate + n_velocity_bins
    transformer = MusicTransformer(n_tokens, seq_length,
                                   d_model=64, n_heads=8, d_feedforward=256,
                                   depth=4, positional_encoding=True, relative_pos=True)

    if args.checkpoint is not None:
        state = torch.load(args.checkpoint)
        transformer.load_state_dict(state)
        print(f"Successfully loaded checkpoint at {args.checkpoint}")

    # rule of thumb: 1 minute is roughly 2k tokens
    pipeline = PreprocessingPipeline(input_dir="data", stretch_factors=[0.975, 1, 1.025],
                                     split_size=30, sampling_rate=sampling_rate,
                                     n_velocity_bins=n_velocity_bins, transpositions=range(-2, 3),
                                     training_val_split=0.9, max_encoded_length=seq_length + 1,
                                     min_encoded_length=257)
    pipeline_start = time.time()
    pipeline.run()
    runtime = time.time() - pipeline_start
    print(f"MIDI pipeline runtime: {runtime / 60:.1f}m")

    today = datetime.date.today().strftime('%m%d%Y')
    checkpoint = f"saved_models/tf_{today}"

    training_sequences = pipeline.encoded_sequences['training']
    validation_sequences = pipeline.encoded_sequences['validation']

    batch_size = 2

    train(transformer, training_sequences, validation_sequences,
          epochs=args.n_epochs, evaluate_per=1,
          batch_size=batch_size, batches_per_print=100,
          padding_index=0, checkpoint_path=checkpoint)
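# The vocabulary size above (n_tokens = 256 + sampling_rate + n_velocity_bins) follows
# the usual performance-event encoding: 128 note-on plus 128 note-off events, one
# time-shift event per quantization step, and one event per velocity bin. A purely
# illustrative sanity check (these constant names are hypothetical, not from this codebase):
NOTE_ON_EVENTS = 128     # one per MIDI pitch
NOTE_OFF_EVENTS = 128    # one per MIDI pitch
TIME_SHIFT_EVENTS = 125  # sampling_rate: one event per 8 ms step at 125 Hz
VELOCITY_EVENTS = 32     # n_velocity_bins

assert NOTE_ON_EVENTS + NOTE_OFF_EVENTS + TIME_SHIFT_EVENTS + VELOCITY_EVENTS == 413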
def load_model(filepath):
    """
    Load a MusicTransformer from a saved pytorch state_dict and hparams. The input
    filepath should point to a .pt file in which has been saved a dictionary containing
    the model state dict and hparams, ex:
        torch.save({
            "state_dict": model.state_dict(),
            "hparams": hparams (dict)
        }, filepath)

    Args:
        filepath (str): path to single .pt file containing the dictionary as described above

    Returns:
        the loaded MusicTransformer model
    """
    from model import MusicTransformer
    from hparams import hparams

    file = torch.load(filepath)
    if "hparams" not in file:
        file["hparams"] = hparams

    model = MusicTransformer(**file["hparams"]).to(device)
    model.load_state_dict(file["state_dict"])
    model.eval()
    return model
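# A minimal round trip with load_model, assuming the project's model and hparams
# modules; the checkpoint filename is a placeholder:
from model import MusicTransformer
from hparams import hparams  # the same defaults load_model falls back on

model = MusicTransformer(**hparams).to(device)
torch.save({"state_dict": model.state_dict(), "hparams": hparams}, "checkpoint.pt")

restored = load_model("checkpoint.pt")  # returned in eval mode, ready for inference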
def main():
    input_dir = "../data/test"
    training_val_split = 0.7
    # defines, in Hz, the smallest timestep preserved in quantizing MIDIs;
    # determines the number of time-shift events
    sampling_rate = 125
    # determines the number of velocity events
    n_velocity_bins = 32

    # set up data pipeline
    seq_length = 128
    padded_length = 128
    pipeline = PreprocessingPipeline(input_dir=input_dir, stretch_factors=[0.975, 1, 1.025],
                                     split_size=30, sampling_rate=sampling_rate,
                                     n_velocity_bins=n_velocity_bins, transpositions=range(-2, 3),
                                     training_val_split=training_val_split,
                                     max_encoded_length=seq_length + 1, min_encoded_length=33)
    pipeline.run()
    training_sequences = pipeline.encoded_sequences['training'][:1000]
    validation_sequences = pipeline.encoded_sequences['validation'][:100]

    n_tokens = 256 + sampling_rate + n_velocity_bins
    batch_size = 10

    transformer = MusicTransformer(n_tokens, seq_length=padded_length,
                                   d_model=4, d_feedforward=32, n_heads=4,
                                   positional_encoding=True, relative_pos=True)

    train(transformer, training_sequences, validation_sequences,
          epochs=2, evaluate_per=1,
          batch_size=batch_size, batches_per_print=20, padding_index=0,
          checkpoint_path="../saved_models/test_save", custom_schedule=True)

    print(sample(transformer, 10))
# load data
dataset = Data('dataset/processed')
print(dataset)

# load model
learning_rate = callback.CustomSchedule(par.embedding_dim)
opt = Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

strategy = tf.distribute.MirroredStrategy()

# define model
with strategy.scope():
    mt = MusicTransformer(embedding_dim=256, vocab_size=par.vocab_size,
                          num_layer=6, max_seq=max_seq,
                          dropout=0.2, debug=False, loader_path=load_path)
    mt.compile(optimizer=opt, loss=callback.transformer_dist_train_loss)

# Train Start
for e in range(epochs):
    mt.reset_metrics()
    for b in range(len(dataset.files) // batch_size):
        try:
            batch_x, batch_y = dataset.seq2seq_batch(batch_size, max_seq)
        except Exception:
            # skip malformed batches rather than aborting the epoch
            continue
        result_metrics = mt.train_on_batch(batch_x, batch_y)
        if b % 100 == 0:
            print('Loss: {:6.6}, Accuracy: {:3.2}'.format(
                result_metrics[0], result_metrics[1]))
args = parser.parse_args()

config.load(args.model_dir, args.configs, initialize=True)

# check cuda
if torch.cuda.is_available():
    config.device = torch.device('cuda')
else:
    config.device = torch.device('cpu')

current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
gen_log_dir = 'logs/mt_decoder/generate_' + current_time + '/generate'
gen_summary_writer = SummaryWriter(gen_log_dir)

mt = MusicTransformer(
    embedding_dim=config.embedding_dim,
    vocab_size=config.vocab_size,
    num_layer=config.num_layers,
    max_seq=config.max_seq,
    dropout=0,
    debug=False)
mt.load_state_dict(torch.load(args.model_dir + '/final.pth'))
mt.test()

print(config.condition_file)
if config.condition_file is not None:
    # condition generation on the first 500 events of the given MIDI file
    inputs = np.array([encode_midi(config.condition_file)[:500]])
else:
    inputs = np.array([[24, 28, 31]])
inputs = torch.from_numpy(inputs)

result = mt(inputs, config.length, gen_summary_writer)

for i in result:
    print(i)
if torch.cuda.is_available():
    config.device = torch.device('cuda')
else:
    config.device = torch.device('cpu')
    device_ids = None

# load data
dataset = Data(config.pickle_dir)
print(dataset)

# load model
learning_rate = config.l_r

# define model
mt = MusicTransformer(
    embedding_dim=config.embedding_dim,
    vocab_size=config.vocab_size,
    num_layer=config.num_layers,
    max_seq=config.max_seq,
    dropout=config.dropout,
    debug=config.debug,
    loader_path=config.load_path)
opt = optim.Adam(mt.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
scheduler = CustomSchedule(config.embedding_dim, optimizer=opt)

# multi-GPU setup
if torch.cuda.device_count() > 1:
    single_mt = mt
    mt = DataParallelModel(mt)
    mt.cuda()
else:
    single_mt = mt

# init metric set
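# CustomSchedule is used above with Adam's lr initialized to 0 but is not defined in
# this section. A minimal sketch, assuming it writes the standard transformer
# warmup/inverse-sqrt learning rate (Vaswani et al. 2017) into the optimizer each step:
class CustomSchedule:
    def __init__(self, d_model, warmup_steps=4000, optimizer=None):
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        self._step = 0

    def step(self):
        # call once per optimizer step; returns the learning rate it applied
        self._step += 1
        lr = self.d_model ** -0.5 * min(self._step ** -0.5,
                                        self._step * self.warmup_steps ** -1.5)
        for group in self.optimizer.param_groups:
            group['lr'] = lr
        return lr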
def main():
    parser = argparse.ArgumentParser("Script to generate MIDI tracks by sampling from a trained model.")
    parser.add_argument("--model", type=str,
                        help="Name of a checkpoint file in saved_models/ to load.")
    parser.add_argument("--sample_length", type=int, default=512,
                        help="number of events to generate")
    parser.add_argument("--temps", nargs="+", type=float, default=[1.0],
                        help="space-separated list of temperatures to use when sampling")
    parser.add_argument("--n_trials", type=int, default=3,
                        help="number of MIDI samples to generate per experiment")
    parser.add_argument("--live_input", action='store_true', default=False,
                        help="if true, take in a seed from a MIDI input controller")
    parser.add_argument("--play_live", action='store_true', default=False,
                        help="play sample(s) at end of script if true")
    parser.add_argument("--keep_ghosts", action='store_true', default=True)
    parser.add_argument("--stuck_note_duration", type=int, default=1)

    args = parser.parse_args()
    model_name = args.model

    # the checkpoint to run is chosen via --model
    model_path = 'saved_models/' + model_name
    try:
        state = torch.load(model_path)
    except RuntimeError:
        state = torch.load(model_path, map_location="cpu")

    n_velocity_events = 32
    n_time_shift_events = 125

    decoder = SequenceEncoder(n_time_shift_events, n_velocity_events, min_events=0)

    if args.live_input:
        pretty_midis = []
        m = 'twinkle.midi'
        with open(m, "rb") as f:
            try:
                midi_str = six.BytesIO(f.read())
                pretty_midis.append(pretty_midi.PrettyMIDI(midi_str))
            except Exception:
                print("Could not parse {}".format(m))
        pipeline = PreprocessingPipeline(input_dir="data")
        note_sequence = pipeline.get_note_sequences(pretty_midis)
        note_sequence = [vectorize(ns) for ns in note_sequence]
        prime_sequence = decoder.encode_sequences(note_sequence)
        prime_sequence = prime_sequence[1:6]
    else:
        prime_sequence = []

    model = MusicTransformer(256 + 125 + 32, 1024,
                             d_model=64, n_heads=8, d_feedforward=256,
                             depth=4, positional_encoding=True, relative_pos=True)
    model.load_state_dict(state, strict=False)

    temps = args.temps
    trial_key = str(uuid.uuid4())[:6]
    n_trials = args.n_trials

    keep_ghosts = args.keep_ghosts
    stuck_note_duration = None if args.stuck_note_duration == 0 else args.stuck_note_duration

    for temp in temps:
        print(f"sampling temp={temp}")
        note_sequence = []
        for i in range(n_trials):
            print("generating sequence")
            output_sequence = sample(model, prime_sequence=prime_sequence,
                                     sample_length=args.sample_length, temperature=temp)
            note_sequence = decoder.decode_sequence(output_sequence, verbose=True,
                                                    stuck_note_duration=stuck_note_duration,
                                                    keep_ghosts=keep_ghosts)

            output_dir = f"output/midis/{trial_key}/"
            file_name = f"sample{i+1}_{temp}"
            write_midi(note_sequence, output_dir, file_name)
def main():
    parser = argparse.ArgumentParser(
        "Script to generate MIDI tracks by sampling from a trained model.")
    parser.add_argument(
        "--model_key", type=str,
        help="Key in saved_models/model.yaml, helps look up model arguments and path to saved checkpoint.")
    parser.add_argument("--sample_length", type=int, default=512,
                        help="number of events to generate")
    parser.add_argument(
        "--temps", nargs="+", type=float, default=[1.0],
        help="space-separated list of temperatures to use when sampling")
    parser.add_argument(
        "--n_trials", type=int, default=3,
        help="number of MIDI samples to generate per experiment")
    parser.add_argument("--primer", type=str, default=None,
                        help="Path to the primer")
    parser.add_argument("--play_live", action='store_true', default=False,
                        help="play sample(s) at end of script if true")
    parser.add_argument("--keep_ghosts", action='store_true', default=False)
    parser.add_argument("--stuck_note_duration", type=int, default=0)

    args = parser.parse_args()

    model_key = args.model_key
    try:
        model_dict = yaml.safe_load(open('saved_models/model.yaml'))[model_key]
    except (KeyError, FileNotFoundError):
        raise GeneratorError(f"could not find yaml information for key {model_key}")

    model_path = model_dict["path"]
    model_args = model_dict["args"]
    try:
        state = torch.load(model_path)
    except RuntimeError:
        state = torch.load(model_path, map_location="cpu")

    n_velocity_events = 32
    n_time_shift_events = 125

    decoder = SequenceEncoder(n_time_shift_events, n_velocity_events, min_events=0)

    if args.primer:
        # read the MIDI primer
        midi_str = six.BytesIO(open(args.primer, 'rb').read())
        p = pretty_midi.PrettyMIDI(midi_str)
        piano_data = p.instruments[0]
        notes = apply_sustain(piano_data)
        note_sequence = sorted(notes, key=lambda x: (x.start, x.pitch))
        ns = vectorize(note_sequence)
        prime_sequence = decoder.encode_sequences([ns])[0]
    else:
        prime_sequence = []

    model = MusicTransformer(**model_args)
    model.load_state_dict(state, strict=False)

    temps = args.temps
    trial_key = str(uuid.uuid4())[:6]
    n_trials = args.n_trials

    keep_ghosts = args.keep_ghosts
    stuck_note_duration = None if args.stuck_note_duration == 0 else args.stuck_note_duration

    for temp in temps:
        print(f"sampling temp={temp}")
        note_sequence = []
        for i in range(n_trials):
            print("generating sequence")
            output_sequence = sample(model, prime_sequence=prime_sequence,
                                     sample_length=args.sample_length, temperature=temp)
            note_sequence = decoder.decode_sequence(output_sequence, verbose=True,
                                                    stuck_note_duration=stuck_note_duration,
                                                    keep_ghosts=keep_ghosts)

            output_dir = f"output/{model_key}/{trial_key}/"
            file_name = f"sample{i+1}_{temp}"
            write_midi(note_sequence, output_dir, file_name)
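# The generation scripts rely on sample(), which is not shown in this section. A
# minimal sketch of autoregressive temperature sampling, with the signature inferred
# from the call sites (sample(transformer, 10) above, and the keyword calls here);
# the (batch, seq, n_tokens) output shape, the max_len context bound, and the
# fallback start token are assumptions:
import torch
import torch.nn.functional as F

def sample(model, sample_length=512, prime_sequence=None, temperature=1.0, max_len=1024):
    model.eval()
    seq = list(prime_sequence) if prime_sequence else [0]  # placeholder start event
    with torch.no_grad():
        for _ in range(sample_length):
            x = torch.tensor(seq[-max_len:], dtype=torch.long).unsqueeze(0)
            logits = model(x)[0, -1] / temperature  # logits for the next event
            probs = F.softmax(logits, dim=-1)
            seq.append(torch.multinomial(probs, num_samples=1).item())
    return seq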
mode = args.mode
beam = args.beam
length = args.length
save_path = args.save_path

current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
gen_log_dir = 'logs/mt_decoder/generate_' + current_time + '/generate'
gen_summary_writer = tf.summary.create_file_writer(gen_log_dir)

if mode == 'enc-dec':
    print(">> generating with the original seq2seq model... beam size is {}".format(beam))
    mt = MusicTransformer(
        embedding_dim=256,
        vocab_size=par.vocab_size,
        num_layer=6,
        max_seq=2048,
        dropout=0.2,
        debug=False,
        loader_path=load_path)
else:
    print(">> generating with the decoder-only model... beam size is {}".format(beam))
    mt = MusicTransformerDecoder(loader_path=load_path)

inputs = encode_midi('dataset/midi/BENABD10.mid')

with gen_summary_writer.as_default():
    result = mt.generate(inputs[:10], beam=beam, length=length, tf_board=True)

for i in result:
    print(i)
gen_log_dir = 'logs/mt_decoder/generate_' + current_time + '/generate'
gen_summary_writer = SummaryWriter(gen_log_dir)

mt = MusicTransformer(
    embedding_dim=config.embedding_dim,
    vocab_size=config.vocab_size,
    num_layer=config.num_layers,
    max_seq=config.max_seq,
    dropout=0,
    debug=False)
# mt.load_state_dict(torch.load(args.model_dir + '/final.pth'))
mt.test()
mt.cuda()
current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
gen_log_dir = 'logs/mt_decoder/generate_' + current_time + '/generate'
gen_summary_writer = SummaryWriter(gen_log_dir)

transformer = MusicTransformer(n_tokens, seq_length,
                               d_model=64, n_heads=8, d_feedforward=256,
                               depth=4, positional_encoding=True, relative_pos=True)

if args.checkpoint is not None:
    state = torch.load(args.checkpoint)
    transformer.load_state_dict(state)
    print(f"Successfully loaded checkpoint at {args.checkpoint}")
else:
    print("Checkpoint not found")

if config.condition_file is not None:
def main():
    parser = argparse.ArgumentParser(
        "Script to generate MIDI tracks by sampling from a trained model.")
    parser.add_argument(
        "--model_key", type=str,
        help="Key in saved_models/model.yaml, helps look up model arguments and path to saved checkpoint.")
    parser.add_argument("--sample_length", type=int, default=512,
                        help="number of events to generate")
    parser.add_argument(
        "--temps", nargs="+", type=float, default=[1.0],
        help="space-separated list of temperatures to use when sampling")
    parser.add_argument(
        "--n_trials", type=int, default=3,
        help="number of MIDI samples to generate per experiment")
    parser.add_argument(
        "--live_input", action='store_true', default=False,
        help="if true, take in a seed from a MIDI input controller")
    parser.add_argument("--play_live", action='store_true', default=False,
                        help="play sample(s) at end of script if true")
    parser.add_argument("--keep_ghosts", action='store_true', default=False)
    parser.add_argument("--stuck_note_duration", type=int, default=0)

    args = parser.parse_args()

    model_key = args.model_key
    try:
        model_dict = yaml.safe_load(open('saved_models/model.yaml'))[model_key]
    except (KeyError, FileNotFoundError):
        raise GeneratorError(f"could not find yaml information for key {model_key}")

    model_path = model_dict["path"]
    model_args = model_dict["args"]
    try:
        state = torch.load(model_path)
    except RuntimeError:
        state = torch.load(model_path, map_location="cpu")

    n_velocity_events = 32
    n_time_shift_events = 125

    decoder = SequenceEncoder(n_time_shift_events, n_velocity_events, min_events=0)

    if args.live_input:
        print("Expecting a midi input...")
        note_sequence = midi_input.read(n_velocity_events, n_time_shift_events)
        prime_sequence = decoder.encode_sequences([note_sequence])[0]
    else:
        prime_sequence = []

    model = MusicTransformer(**model_args)
    model.load_state_dict(state, strict=False)

    temps = args.temps
    trial_key = str(uuid.uuid4())[:6]
    n_trials = args.n_trials

    keep_ghosts = args.keep_ghosts
    stuck_note_duration = None if args.stuck_note_duration == 0 else args.stuck_note_duration

    for temp in temps:
        print(f"sampling temp={temp}")
        note_sequence = []
        for i in range(n_trials):
            print("generating sequence")
            output_sequence = sample(model, prime_sequence=prime_sequence,
                                     sample_length=args.sample_length, temperature=temp)
            note_sequence = decoder.decode_sequence(output_sequence, verbose=True,
                                                    stuck_note_duration=stuck_note_duration,
                                                    keep_ghosts=keep_ghosts)

            output_dir = f"output/{model_key}/{trial_key}/"
            file_name = f"sample{i+1}_{temp}"
            write_midi(note_sequence, output_dir, file_name)

    if args.play_live:
        # play back every generated sample with timidity
        for temp in temps:
            for i in range(n_trials):
                try:
                    subprocess.run([
                        'timidity',
                        f"output/{model_key}/{trial_key}/sample{i+1}_{temp}.midi"
                    ])
                except KeyboardInterrupt:
                    continue
if torch.cuda.is_available():
    config.device = torch.device('cuda')
else:
    config.device = torch.device('cpu')

# load data
dataset = Data(config.pickle_dir)
print(dataset)

# load model
learning_rate = config.l_r

# define model
mt = MusicTransformer(
    embedding_dim=config.embedding_dim,
    vocab_size=config.vocab_size,
    num_layer=config.num_layers,
    max_seq=config.max_seq,
    dropout=config.dropout,
    debug=config.debug,
    loader_path=config.load_path)
mt.to(config.device)
opt = optim.Adam(mt.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
scheduler = CustomSchedule(config.embedding_dim, optimizer=opt)

# multi-GPU setup
if torch.cuda.device_count() > 1:
    single_mt = mt
    mt = torch.nn.DataParallel(mt, output_device=torch.cuda.device_count() - 1)
else:
    single_mt = mt

# init metric set
class MusicTransformerTrainer:
    """
    As the transformer is a large model and takes a while to train on a GPU, or even a TPU,
    I wrote this Trainer class to make it easier to load and save checkpoints with the model.
    The way I've designed it instantiates the model, optimizer, and scheduler within the
    class itself, as there are some problems with passing them in. But to get these objects
    back, just call:
        trainer.model
        trainer.optimizer
        trainer.scheduler

    This class also tracks the cumulative losses while training, which you can get back,
    as lists of floats, with:
        trainer.train_losses
        trainer.val_losses

    To save a checkpoint, call trainer.save()
    To load a checkpoint, call trainer.load( (optional) ckpt_path)
    """

    def __init__(self, hparams_, datapath, batch_size, warmup_steps=4000,
                 ckpt_path="music_transformer_ckpt.pt", load_from_checkpoint=False):
        """
        Args:
            hparams_: hyperparameters of the model
            datapath: path to the data to train on
            batch_size: batch size to batch the data
            warmup_steps: number of warmup steps for the transformer learning rate schedule
            ckpt_path: path at which to save checkpoints while training; MUST end in .pt or .pth
            load_from_checkpoint (bool, optional): if true, on instantiating the trainer,
                this will load a previously saved checkpoint at ckpt_path
        """
        # get the data
        self.datapath = datapath
        self.batch_size = batch_size
        data = torch.load(datapath).long().to(device)

        # max absolute position must be able to account for the largest sequence in the data
        if hparams_["max_abs_position"] > 0:
            hparams_["max_abs_position"] = max(hparams_["max_abs_position"], data.shape[-1])

        # train / validation split: 80 / 20
        train_len = round(data.shape[0] * 0.8)
        train_data = data[:train_len]
        val_data = data[train_len:]
        print(f"There are {data.shape[0]} samples in the data, {len(train_data)} training samples "
              f"and {len(val_data)} validation samples")

        # datasets and dataloaders: split data into first (n-1) and last (n-1) tokens
        self.train_ds = TensorDataset(train_data[:, :-1], train_data[:, 1:])
        self.train_dl = DataLoader(dataset=self.train_ds, batch_size=batch_size, shuffle=True)
        self.val_ds = TensorDataset(val_data[:, :-1], val_data[:, 1:])
        self.val_dl = DataLoader(dataset=self.val_ds, batch_size=batch_size, shuffle=True)

        # create model
        self.model = MusicTransformer(**hparams_).to(device)
        self.hparams = hparams_

        # setup training
        self.warmup_steps = warmup_steps
        self.optimizer = optim.Adam(self.model.parameters(), lr=1.0, betas=(0.9, 0.98))
        self.scheduler = optim.lr_scheduler.LambdaLR(
            self.optimizer,
            lambda x: transformer_lr_schedule(self.hparams['d_model'], x, self.warmup_steps)
        )

        # setup checkpointing / saving
        self.ckpt_path = ckpt_path
        self.train_losses = []
        self.val_losses = []

        # load checkpoint if necessary
        if load_from_checkpoint and os.path.isfile(self.ckpt_path):
            self.load()

    def save(self, ckpt_path=None):
        """
        Saves a checkpoint at ckpt_path

        Args:
            ckpt_path (str, optional): if None, saves the checkpoint at the previously stored
                self.ckpt_path; else saves the checkpoint at the new passed-in path, and stores
                this new path in the member variable self.ckpt_path
        """
        if ckpt_path is not None:
            self.ckpt_path = ckpt_path

        ckpt = {
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "train_losses": self.train_losses,
            "validation_losses": self.val_losses,
            "warmup_steps": self.warmup_steps,
            "hparams": self.hparams
        }

        torch.save(ckpt, self.ckpt_path)
        return

    def load(self, ckpt_path=None):
        """
        Loads a checkpoint from ckpt_path

        NOTE: OVERWRITES THE MODEL STATE DICT, OPTIMIZER STATE DICT, SCHEDULER STATE DICT,
        AND HISTORY OF LOSSES

        Args:
            ckpt_path (str, optional): if None, loads the checkpoint at the previously stored
                self.ckpt_path; else loads the checkpoint from the new passed-in path, and stores
                this new path in the member variable self.ckpt_path
        """
        if ckpt_path is not None:
            self.ckpt_path = ckpt_path

        ckpt = torch.load(self.ckpt_path)
        del self.model, self.optimizer, self.scheduler

        # create and load model
        self.model = MusicTransformer(**ckpt["hparams"]).to(device)
        self.hparams = ckpt["hparams"]
        print("Loading the model...", end="")
        print(self.model.load_state_dict(ckpt["model_state_dict"]))

        # create and load optimizer and scheduler
        self.warmup_steps = ckpt["warmup_steps"]
        self.optimizer = optim.Adam(self.model.parameters(), lr=1.0, betas=(0.9, 0.98))
        self.optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        self.scheduler = optim.lr_scheduler.LambdaLR(
            self.optimizer,
            lambda x: transformer_lr_schedule(self.hparams['d_model'], x, self.warmup_steps)
        )
        self.scheduler.load_state_dict(ckpt["scheduler_state_dict"])

        # load loss histories
        self.train_losses = ckpt["train_losses"]
        self.val_losses = ckpt["validation_losses"]
        return

    def fit(self, epochs):
        """
        Training loop to fit the model to the data stored at the passed-in datapath.
        On a KeyboardInterrupt at any time during the training loop, this method saves
        a checkpoint at the stored ckpt_path before returning.

        Args:
            epochs: number of epochs to train for

        Returns:
            history of training and validation losses for this training session
        """
        print_interval = epochs // 10 + int(epochs < 10)
        train_losses = []
        val_losses = []
        start = time.time()

        print("Beginning training...")
        try:
            for epoch in range(epochs):
                train_epoch_losses = []
                val_epoch_losses = []

                self.model.train()
                for train_inp, train_tar in self.train_dl:
                    loss = train_step(self.model, self.optimizer, self.scheduler,
                                      train_inp, train_tar)
                    train_epoch_losses.append(loss)

                self.model.eval()
                for val_inp, val_tar in self.val_dl:
                    loss = val_step(self.model, val_inp, val_tar)
                    val_epoch_losses.append(loss)

                # mean losses for the epoch
                train_mean = sum(train_epoch_losses) / len(train_epoch_losses)
                val_mean = sum(val_epoch_losses) / len(val_epoch_losses)

                # store the complete history of losses in the member lists and the
                # relative history for this session in the output lists
                self.train_losses.append(train_mean)
                train_losses.append(train_mean)
                self.val_losses.append(val_mean)
                val_losses.append(val_mean)

                if ((epoch + 1) % print_interval) == 0:
                    print(f"Epoch {epoch + 1} Time taken {round(time.time() - start, 2)} seconds "
                          f"Train Loss {train_losses[-1]} Val Loss {val_losses[-1]}")
                    start = time.time()

        except KeyboardInterrupt:
            pass

        print("Checkpointing...")
        self.save()
        print("Done")

        return train_losses, val_losses
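# transformer_lr_schedule is used as the LambdaLR multiplier above; since the Adam base
# lr is 1.0, the lambda's return value is the learning rate itself. The function is not
# defined in this section; a minimal sketch, assuming it implements the standard
# warmup/inverse-sqrt schedule from Vaswani et al. (2017):
def transformer_lr_schedule(d_model, step, warmup_steps=4000):
    # linear warmup for warmup_steps, then inverse-sqrt decay; LambdaLR passes
    # step starting at 0, so clamp to avoid division by zero
    step = max(step, 1)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# Hypothetical end-to-end use of the trainer (the datapath and hparams are placeholders):
# trainer = MusicTransformerTrainer(hparams, "data/preprocessed_data.pt", batch_size=32)
# session_train_losses, session_val_losses = trainer.fit(epochs=100)
# trainer.save()  # checkpoint written to trainer.ckpt_path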
import params as par
from tensorflow.python import enable_eager_execution
from tensorflow.python.keras.optimizer_v2.adam import Adam
from tensorflow.python.keras.optimizer_v2.gradient_descent import SGD
from data import Data
import utils

enable_eager_execution()
tf.executing_eagerly()

if __name__ == '__main__':
    epoch = 100
    batch = 1000

    dataset = Data('dataset/processed/')

    opt = Adam(0.0001)
    # opt = SGD(lr=0.0001, momentum=0.0, decay=0.0, nesterov=False)

    mt = MusicTransformer(embedding_dim=par.embedding_dim,
                          vocab_size=par.vocab_size,
                          num_layer=6,
                          max_seq=100,
                          debug=True)
    mt.compile(optimizer=opt, loss=callback.TransformerLoss())

    for e in range(epoch):
        for b in range(batch):
            batch_x, batch_y = dataset.seq2seq_batch(2, 100)
            result_metrics = mt.train_on_batch(batch_x, batch_y)
            print('Loss: {:6.6}, Accuracy: {:3.2}'.format(
                result_metrics[0], result_metrics[1]))
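# callback.TransformerLoss is not defined in this section. A plausible sketch, under
# the assumption that it is a padding-masked sparse categorical cross-entropy with
# token 0 treated as padding:
import tensorflow as tf

class TransformerLoss(tf.keras.losses.Loss):
    def call(self, y_true, y_pred):
        loss = tf.keras.losses.sparse_categorical_crossentropy(
            y_true, y_pred, from_logits=True)
        mask = tf.cast(tf.not_equal(y_true, 0), loss.dtype)  # zero out padded positions
        return tf.reduce_sum(loss * mask) / tf.maximum(tf.reduce_sum(mask), 1.0)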