def main(args):
    """Train a deep Markov model (DMM) with Pyro SVI on toy tone sequences.

    Loads the JSB-chorales dataset, but then immediately *overrides* the
    train/val/test splits with the synthetic ``easiestTones()`` data (the
    dataset load is effectively only kept for reference).  Sets up a
    ClippedAdam optimizer and one of three ELBO estimators (TMC, TraceEnum,
    or plain/JIT Trace) depending on ``args``, then runs the training loop
    with KL annealing and periodic checkpointing/evaluation.

    NOTE(review): assumes ``args`` provides the attributes referenced below
    (log, cuda, mini_batch_size, learning_rate, ... num_epochs) — presumably
    an argparse.Namespace built elsewhere in this file; confirm at call site.
    """

    ## "do-re-mi": one song — an ascending major scale (offsets 0,2,4,5,7,9,11,12
    ## semitones above base pitch 70) as a one-hot piano roll of shape (1, 8, 88).
    ## NOTE(review): defined but unused below — easiestTones() is what is called.
    def easyTones():
        training_seq_lengths = torch.tensor([8] * 1)
        training_data_sequences = torch.zeros(1, 8, 88)
        for i in range(1):
            training_data_sequences[i][0][int(70 - i * 10)] = 1
            training_data_sequences[i][1][int(70 - i * 10) + 2] = 1
            training_data_sequences[i][2][int(70 - i * 10) + 4] = 1
            training_data_sequences[i][3][int(70 - i * 10) + 5] = 1
            training_data_sequences[i][4][int(70 - i * 10) + 7] = 1
            training_data_sequences[i][5][int(70 - i * 10) + 9] = 1
            training_data_sequences[i][6][int(70 - i * 10) + 11] = 1
            training_data_sequences[i][7][int(70 - i * 10) + 12] = 1
        return training_seq_lengths, training_data_sequences

    ## Ten songs, each repeating a single pitch (30 + 5*i) for 8 steps.
    ## NOTE(review): also unused below.
    def superEasyTones():
        training_seq_lengths = torch.tensor([8] * 10)
        training_data_sequences = torch.zeros(10, 8, 88)
        for i in range(10):
            for j in range(8):
                training_data_sequences[i][j][int(30 + i * 5)] = 1
        return training_seq_lengths, training_data_sequences

    ## "do-do-do" repeated: ten identical songs, all 8 steps playing pitch 70.
    def easiestTones():
        training_seq_lengths = torch.tensor([8] * 10)
        training_data_sequences = torch.zeros(10, 8, 88)
        for i in range(10):
            for j in range(8):
                training_data_sequences[i][j][int(70)] = 1
        return training_seq_lengths, training_data_sequences

    # setup logging: everything to args.log, INFO and above mirrored to console
    logging.basicConfig(level=logging.DEBUG, format='%(message)s',
                        filename=args.log, filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger('').addHandler(console)
    logging.info(args)

    data = poly.load_data(poly.JSB_CHORALES)
    # NOTE(review): each real split below is immediately overwritten by the
    # synthetic easiestTones() data — the dataset tensors are never used.
    training_seq_lengths = data['train']['sequence_lengths']
    training_data_sequences = data['train']['sequences']
    training_seq_lengths, training_data_sequences = easiestTones()
    test_seq_lengths = data['test']['sequence_lengths']
    test_data_sequences = data['test']['sequences']
    test_seq_lengths, test_data_sequences = easiestTones()
    val_seq_lengths = data['valid']['sequence_lengths']
    val_data_sequences = data['valid']['sequences']
    val_seq_lengths, val_data_sequences = easiestTones()

    N_train_data = len(training_seq_lengths)
    N_train_time_slices = float(torch.sum(training_seq_lengths))
    # ceil(N_train_data / mini_batch_size) without importing math
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    logging.info("N_train_data: %d avg. training seq. length: %.2f N_mini_batches: %d" %
                 (N_train_data, training_seq_lengths.float().mean(), N_mini_batches))

    # how often we do validation/test evaluation during training
    val_test_frequency = 50
    # the number of samples we use to do the evaluation
    n_eval_samples = 1

    # package repeated copies of val/test data for faster evaluation
    # (i.e. set us up for vectorization)
    def rep(x):
        rep_shape = torch.Size([x.size(0) * n_eval_samples]) + x.size()[1:]
        repeat_dims = [1] * len(x.size())
        repeat_dims[0] = n_eval_samples
        return x.repeat(repeat_dims).reshape(n_eval_samples, -1).transpose(1, 0).reshape(rep_shape)

    # get the validation/test data ready for the dmm: pack into sequences, etc.
    val_seq_lengths = rep(val_seq_lengths)
    test_seq_lengths = rep(test_seq_lengths)
    val_batch, val_batch_reversed, val_batch_mask, val_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * val_data_sequences.shape[0]),
        rep(val_data_sequences), val_seq_lengths, cuda=args.cuda)
    test_batch, test_batch_reversed, test_batch_mask, test_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * test_data_sequences.shape[0]),
        rep(test_data_sequences), test_seq_lengths, cuda=args.cuda)

    # instantiate the dmm
    dmm = DMM(rnn_dropout_rate=args.rnn_dropout_rate, num_iafs=args.num_iafs,
              iaf_dim=args.iaf_dim, use_cuda=args.cuda)

    # setup optimizer (gradient-clipped Adam with learning-rate decay)
    adam_params = {"lr": args.learning_rate, "betas": (args.beta1, args.beta2),
                   "clip_norm": args.clip_norm, "lrd": args.lr_decay,
                   "weight_decay": args.weight_decay}
    adam = ClippedAdam(adam_params)

    # setup inference algorithm: TMC, TraceEnum, or plain (optionally JIT) Trace ELBO
    if args.tmc:
        if args.jit:
            raise NotImplementedError("no JIT support yet for TMC")
        tmc_loss = TraceTMC_ELBO()
        dmm_guide = config_enumerate(dmm.guide, default="parallel",
                                     num_samples=args.tmc_num_samples, expand=False)
        svi = SVI(dmm.model, dmm_guide, adam, loss=tmc_loss)
    elif args.tmcelbo:
        if args.jit:
            raise NotImplementedError("no JIT support yet for TMC ELBO")
        elbo = TraceEnum_ELBO()
        dmm_guide = config_enumerate(dmm.guide, default="parallel",
                                     num_samples=args.tmc_num_samples, expand=False)
        svi = SVI(dmm.model, dmm_guide, adam, loss=elbo)
    else:
        elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
        svi = SVI(dmm.model, dmm.guide, adam, loss=elbo)

    # now we're going to define some functions we need to form the main training loop

    # saves the model and optimizer states to disk
    def save_checkpoint():
        logging.info("saving model to %s..." % args.save_model)
        torch.save(dmm.state_dict(), args.save_model)
        # optimizer-state saving is currently disabled:
        # logging.info("saving optimizer states to %s..." % args.save_opt)
        # adam.save(args.save_opt)
        logging.info("done saving model and optimizer checkpoints to disk.")

    # loads the model and optimizer states from disk
    def load_checkpoint():
        assert exists(args.load_opt) and exists(args.load_model), \
            "--load-model and/or --load-opt misspecified"
        logging.info("loading model from %s..." % args.load_model)
        dmm.load_state_dict(torch.load(args.load_model))
        logging.info("loading optimizer states from %s..." % args.load_opt)
        adam.load(args.load_opt)
        logging.info("done loading model and optimizer states.")

    # prepare a mini-batch and take a gradient step to minimize -elbo
    def process_minibatch(epoch, which_mini_batch, shuffled_indices):
        if args.annealing_epochs > 0 and epoch < args.annealing_epochs:
            # compute the KL annealing factor appropriate for the current
            # mini-batch in the current epoch: ramps linearly from
            # minimum_annealing_factor up to 1.0 over annealing_epochs
            min_af = args.minimum_annealing_factor
            annealing_factor = min_af + (1.0 - min_af) * \
                (float(which_mini_batch + epoch * N_mini_batches + 1) /
                 float(args.annealing_epochs * N_mini_batches))
        else:
            # by default the KL annealing factor is unity
            annealing_factor = 1.0

        # compute which sequences in the training set we should grab
        mini_batch_start = (which_mini_batch * args.mini_batch_size)
        mini_batch_end = np.min([(which_mini_batch + 1) * args.mini_batch_size,
                                 N_train_data])
        mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]
        # grab a fully prepped mini-batch using the helper function in the data loader
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                  training_seq_lengths, cuda=args.cuda)
        # do an actual gradient step
        loss = svi.step(mini_batch, mini_batch_reversed, mini_batch_mask,
                        mini_batch_seq_lengths, annealing_factor)
        # keep track of the training loss
        return loss

    # helper function for doing evaluation
    def do_evaluation():
        # put the RNN into evaluation mode (i.e. turn off drop-out if applicable)
        dmm.rnn.eval()
        # compute the validation and test loss n_samples many times,
        # normalized per time slice
        val_nll = svi.evaluate_loss(val_batch, val_batch_reversed, val_batch_mask,
                                    val_seq_lengths) / float(torch.sum(val_seq_lengths))
        test_nll = svi.evaluate_loss(test_batch, test_batch_reversed, test_batch_mask,
                                     test_seq_lengths) / float(torch.sum(test_seq_lengths))
        # put the RNN back into training mode (i.e. turn on drop-out if applicable)
        dmm.rnn.train()
        return val_nll, test_nll

    # if checkpoint files provided, load model and optimizer states from disk
    # before we start training
    if args.load_opt != '' and args.load_model != '':
        load_checkpoint()

    #################
    # TRAINING LOOP #
    #################
    times = [time.time()]
    for epoch in range(args.num_epochs):
        # if specified, save model and optimizer states to disk every
        # checkpoint_freq epochs
        if args.checkpoint_freq > 0 and epoch > 0 and epoch % args.checkpoint_freq == 0:
            save_checkpoint()

        # accumulator for our estimate of the negative log likelihood
        # (or rather -elbo) for this epoch
        epoch_nll = 0.0
        # prepare mini-batch subsampling indices for this epoch
        shuffled_indices = torch.randperm(N_train_data)

        # process each mini-batch; this is where we take gradient steps
        for which_mini_batch in range(N_mini_batches):
            epoch_nll += process_minibatch(epoch, which_mini_batch, shuffled_indices)

        # report training diagnostics (per-time-slice nll and wall-clock time)
        times.append(time.time())
        epoch_time = times[-1] - times[-2]
        logging.info("[training epoch %04d] %.4f \t\t\t\t(dt = %.3f sec)" %
                     (epoch, epoch_nll / N_train_time_slices, epoch_time))

        # do evaluation on test and validation data and report results
        if val_test_frequency > 0 and epoch > 0 and epoch % val_test_frequency == 0:
            val_nll, test_nll = do_evaluation()
            logging.info("[val/test epoch %04d] %.4f %.4f" % (epoch, val_nll, test_nll))
def main(args):
    """Train an Encoder/Prior/Emitter sequence autoencoder on Van der Pol data.

    Builds ``args.N_songs`` synthetic 2-D trajectories via
    ``musics.VanDelPol``, then minimizes ``lam * D_Wass(posterior, prior)
    + reconstruction_error`` with SGD + exponential LR decay, periodically
    saving reconstruction/generation plots and a final results dict under
    ``saveData/<timestamp>/``.

    Fixes vs. previous revision: plot title typo "Genetration" -> "Generation";
    local variable `traing_data` renamed to `training_data`.

    NOTE(review): assumes ``args`` provides N_songs, length, T, mini_batch_size,
    z_dim, rnn_dim, transition_dim, emission_dim, cuda, learning_rate, lam,
    num_epochs, checkpoint_freq — presumably an argparse.Namespace; confirm.
    """

    ## Save the loss curves (total loss + reconstruction error) at the end.
    def saveGraph(loss_list, sub_error_list, now):
        FS = 10
        fig = plt.figure()
        plt.plot(loss_list, label="LOSS")
        plt.plot(sub_error_list, label="Reconstruction Error")
        plt.ylim(bottom=0)
        plt.title("Loss")
        plt.xlabel("epoch", fontsize=FS)
        plt.ylabel("loss", fontsize=FS)
        plt.legend()
        fig.savefig(os.path.join("saveData", now, "LOSS.png"))
        plt.close()

    ## Plot a training trajectory against its reconstruction (2-D phase plot).
    def saveReconSinGraph(train_data, recon_data, length, path, number):
        FS = 10
        fig = plt.figure()
        ax = fig.add_subplot()
        plt.plot(train_data[:, 0], train_data[:, 1], label="Training data")
        plt.plot(recon_data.detach().numpy()[:, 0],
                 recon_data.detach().numpy()[:, 1],
                 label="Reconstructed data")
        plt.title("Reconstruction")
        plt.legend()
        fig.savefig(os.path.join(path, "Reconstruction" + str(number) + ".png"))
        plt.close()

    ## Plot a trajectory generated from the prior (2-D phase plot).
    ## NOTE(review): `length` is unused but kept for interface compatibility.
    def saveGeneSinGraph(gene_data, length, path, number):
        FS = 10
        fig = plt.figure()
        ax = fig.add_subplot()
        plt.plot(gene_data.detach().numpy()[:, 0],
                 gene_data.detach().numpy()[:, 1],
                 label="Generated data")
        plt.title("Generation")  # fixed typo: was "Genetration"
        plt.legend()
        fig.savefig(os.path.join(path, "Generation" + str(number) + ".png"))
        plt.close()

    FS = 10
    plt.rcParams["font.size"] = FS

    ## Max sequence length in JSB is 129; e.g. a length-60 song has all-zero
    ## data for steps 61-129.  NOTE(review): the dataset tensors loaded here
    ## are immediately overwritten by synthetic Van der Pol data below.
    data = poly.load_data(poly.JSB_CHORALES)
    training_seq_lengths = data['train']['sequence_lengths'][:args.N_songs]
    training_data_sequences = data['train']['sequences'][:args.N_songs, :8]
    training_seq_lengths = torch.tensor([8] * args.N_songs)

    # build N_songs Van der Pol trajectories with random initial conditions
    training_data = []
    for i in range(args.N_songs):
        training_data.append(
            musics.VanDelPol(torch.randn(2), args.length, args.T))
    training_data_sequences = torch.stack(training_data)
    training_seq_lengths = torch.tensor([args.length] * args.N_songs)

    data_dim = training_data_sequences.size(-1)
    N_train_data = len(training_seq_lengths)
    N_train_time_slices = float(torch.sum(training_seq_lengths))
    # ceil(N_train_data / mini_batch_size)
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    z_dim = args.z_dim
    rnn_dim = args.rnn_dim
    transition_dim = args.transition_dim
    emission_dim = args.emission_dim
    encoder = Encoder(input_dim=training_data_sequences.size(-1),
                      z_dim=z_dim,
                      rnn_dim=rnn_dim,
                      N_z0=args.N_songs)
    prior = Prior(z_dim=z_dim,
                  transition_dim=transition_dim,
                  N_z0=args.N_songs,
                  use_cuda=args.cuda)
    decoder = Emitter(input_dim=training_data_sequences.size(-1),
                      z_dim=z_dim,
                      emission_dim=emission_dim,
                      use_cuda=args.cuda)

    # Create optimizer over all three modules' parameters
    params = list(encoder.parameters()) + list(prior.parameters()) + list(
        decoder.parameters())
    optimizer = optim.SGD(params, lr=args.learning_rate)
    # Add learning rate scheduler (very slow exponential decay)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)

    # make directory for output data, keyed by timestamp
    now = datetime.datetime.now().strftime('%Y%m%d_%H_%M')
    os.makedirs(os.path.join("saveData", now), exist_ok=True)

    # save args as text for reproducibility
    f = open(os.path.join("saveData", now, 'args.txt'), 'w')  # open in write mode
    for name, site in vars(args).items():
        f.write(name + " = ")
        f.write(str(site) + "\n")
    f.close()  # close the file

    #######################
    #### TRAINING LOOP ####
    #######################
    times = [time.time()]
    losses = []
    recon_errors = []
    pbar = tqdm.tqdm(range(args.num_epochs))
    for epoch in pbar:
        epoch_nll = 0
        epoch_recon_error = 0
        shuffled_indices = torch.randperm(N_train_data)
        # process each mini-batch; this is where we take gradient steps
        for which_mini_batch in range(N_mini_batches):
            # compute which sequences in the training set we should grab
            mini_batch_start = (which_mini_batch * args.mini_batch_size)
            mini_batch_end = np.min([
                (which_mini_batch + 1) * args.mini_batch_size, N_train_data
            ])
            mini_batch_indices = shuffled_indices[
                mini_batch_start:mini_batch_end]
            # grab a fully prepped mini-batch using the helper in the data loader
            mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
                = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                      training_seq_lengths, cuda=args.cuda)
            # reset gradients
            optimizer.zero_grad()
            loss = 0
            N_repeat = 1
            for i in range(N_repeat):
                seiki = 1 / N_repeat  # "seiki" = normalization weight per repeat
                # posterior latents from the encoder, prior latents for comparison
                pos_z, pos_z_loc, pos_z_scale = encoder(
                    mini_batch, mini_batch_reversed, mini_batch_mask,
                    mini_batch_seq_lengths)
                pri_z, pri_z_loc, pri_z_scale = prior(
                    length=mini_batch.size(1), N_generate=mini_batch.size(0))
                # Regularizer: Wasserstein distance between posterior and prior
                regularizer = seiki * D_Wass(pos_z_loc, pri_z_loc, pos_z_scale,
                                             pri_z_scale)
                # Reconstruction error, normalized by batch/time/feature dims
                reconstruction_error = seiki * torch.norm(
                    mini_batch - decoder(pos_z),
                    dim=2).sum() / mini_batch.size(0) / mini_batch.size(
                        1) / mini_batch.size(2)
                loss += args.lam * regularizer + reconstruction_error
                reconed_x = decoder(pos_z)

            # do an actual gradient step
            loss.backward()
            optimizer.step()
            scheduler.step()
            epoch_nll += loss
            epoch_recon_error += reconstruction_error

        # report training diagnostics
        times.append(time.time())
        losses.append(epoch_nll)
        recon_errors.append(epoch_recon_error)
        pbar.set_description("LOSS = %f " % epoch_nll)

        # periodic checkpoint: plot reconstructions/generations of the last batch
        if epoch % args.checkpoint_freq == 0:
            path = os.path.join("saveData", now, "Epoch%d" % epoch)
            os.makedirs(path, exist_ok=True)
            for i in range(len(mini_batch)):
                saveReconSinGraph(mini_batch[i], reconed_x[i], args.length,
                                  path, i)
                saveGeneSinGraph(decoder(pri_z)[i], args.length, path, i)

        # final epoch: save plots one more time under Epoch<num_epochs>
        if epoch == args.num_epochs - 1:
            path = os.path.join("saveData", now, "Epoch%d" % args.num_epochs)
            os.makedirs(path, exist_ok=True)
            for i in range(len(mini_batch)):
                saveReconSinGraph(mini_batch[i], reconed_x[i], args.length,
                                  path, i)
                saveGeneSinGraph(decoder(pri_z)[i], args.length, path, i)

    # persist loss curves and a results dict with model states
    saveGraph(losses, recon_errors, now)
    saveDic = {
        "optimizer": optimizer,
        "mini_batch": mini_batch,
        "Encoder_dic": encoder.state_dict,
        "Prior_dic": prior.state_dict,
        "Emitter_dic": decoder.state_dict,
        "epoch_times": times,
        "losses": losses
    }
    # NOTE(review): these store the bound *methods* state_dict, not their
    # return values — presumably intentional elsewhere, but confirm; calling
    # state_dict() is the usual pattern.
    torch.save(saveDic, os.path.join("saveData", now, "DMM_dic"))
def main(args):
    """Train the Encoder/Prior/Emitter model on selectable toy data and export results.

    Depending on flags in ``args``, trains on synthetic tone datasets from
    ``musics`` (doremifasorasido / dododorerere / dodododododo / sin), though
    the unconditional ``musics.createNewTrainingData`` call below currently
    overrides whichever dataset was selected.  Minimizes
    ``lam * D_Wass(posterior, prior) + reconstruction_error`` with Adam, and
    at checkpoints saves either MIDI files (multi-pitch data) or curve plots
    (1-D data) plus a results dict under ``saveData/<timestamp>/``.

    NOTE(review): assumes ``args`` is an argparse.Namespace with the many
    attributes referenced below (N_songs, length, R, lam, N_generate, ...).
    """

    ## Save the loss curves (total loss + reconstruction error) at the end.
    def saveGraph(loss_list, sub_error_list, now):
        FS = 10
        fig = plt.figure()
        # plt.rcParams["font.size"] = FS
        plt.plot(loss_list, label="LOSS")
        plt.plot(sub_error_list, label="Reconstruction Error")
        plt.ylim(bottom=0)
        plt.title("Loss")
        plt.xlabel("epoch", fontsize=FS)
        plt.ylabel("loss", fontsize=FS)
        plt.legend()
        fig.savefig(os.path.join("saveData", now, "LOSS.png"))
        plt.close()

    ## Plot a 1-D training sequence against its reconstruction over [0, 2*pi].
    def saveReconSinGraph(train_data, recon_data, length, path, number):
        FS = 10
        fig = plt.figure()
        # plt.rcParams["font.size"] = FS
        x = np.linspace(0, 2 * np.pi, length)
        plt.plot(x, train_data, label="Training data")
        plt.plot(x, recon_data.detach().numpy(), label="Reconstructed data")
        plt.title("Sin Curves")
        plt.xlabel("time", fontsize=FS)
        plt.ylabel("y", fontsize=FS)
        plt.legend()
        fig.savefig(os.path.join(path, "Reconstruction" + str(number) + ".png"))
        plt.close()

    ## Plot a 1-D sequence generated from the prior over [0, 2*pi].
    def saveGeneSinGraph(gene_data, length, path, number):
        FS = 10
        fig = plt.figure()
        # plt.rcParams["font.size"] = FS
        x = np.linspace(0, 2 * np.pi, length)
        plt.plot(x, gene_data.detach().numpy(), label="Generated data")
        plt.title("Sin Curves")
        plt.xlabel("time", fontsize=FS)
        plt.ylabel("y", fontsize=FS)
        plt.legend()
        fig.savefig(os.path.join(path, "Generation" + str(number) + ".png"))
        plt.close()

    ## Save generated tones and training tones as MIDI files.
    def save_as_midi(song, path="", name="default.mid", BPM=120, velocity=100):
        pm = pretty_midi.PrettyMIDI(resolution=960, initial_tempo=BPM)  # create a pretty_midi object
        instrument = pretty_midi.Instrument(0)  # an Instrument is something like a track
        for i, tones in enumerate(song):
            # indices of active pitches at step i (one-hot/multi-hot rows)
            which_tone = torch.nonzero((tones == 1), as_tuple=False).reshape(-1)
            if len(which_tone) == 0:
                # silent step: emit a zero-velocity placeholder note
                note = pretty_midi.Note(
                    velocity=0, pitch=0, start=i,
                    end=i + 1)  # a Note corresponds to a NoteOn/NoteOff event pair
                instrument.notes.append(note)
            else:
                for which in which_tone:
                    note = pretty_midi.Note(
                        velocity=velocity, pitch=int(which), start=i,
                        end=i + 1)  # a Note corresponds to a NoteOn/NoteOff event pair
                    instrument.notes.append(note)
        pm.instruments.append(instrument)
        pm.write(os.path.join(path, name))  # write out the midi file

    # Sample binary tone rolls from Bernoulli(probs) per time step.
    def generate_Xs(batch_data):
        songs_list = []
        for i, song in enumerate(batch_data):
            tones_container = []
            # NOTE(review): loop variable `time` shadows the `time` module
            # inside this helper; harmless here but worth renaming.
            for time in range(batch_data.size(1)):
                p = dist.Bernoulli(probs=song[time])
                tone = p.sample()
                tones_container.append(tone)
            tones_container = torch.stack(tones_container)
            songs_list.append(tones_container)
        return songs_list

    # Save up to N_songs random (generated, training) MIDI pairs.
    # NOTE(review): defined but not called in this function body.
    def saveSongs(songs_list, mini_batch, N_songs, path):
        if len(songs_list) != len(mini_batch):
            assert False
        if N_songs <= len(songs_list):
            song_No = random.sample(range(len(songs_list)), k=N_songs)
        else:
            song_No = random.sample(range(len(songs_list)), k=len(songs_list))
        for i, Number in enumerate(song_No):
            save_as_midi(song=songs_list[Number],
                         path=path,
                         name="No%d_Gene.midi" % i)
            save_as_midi(song=mini_batch[Number],
                         path=path,
                         name="No%d_Tran.midi" % i)

    FS = 10
    plt.rcParams["font.size"] = FS

    ## Max sequence length in JSB is 129; e.g. a length-60 song has all-zero
    ## data for steps 61-129.
    data = poly.load_data(poly.JSB_CHORALES)
    training_seq_lengths = data['train']['sequence_lengths'][:args.N_songs]
    training_data_sequences = data['train']['sequences'][:args.N_songs, :8]
    training_seq_lengths = torch.tensor([8] * args.N_songs)

    # optional synthetic datasets selected by flags
    if args.doremifasorasido:
        ## do-do-do, re-re-re, mi-mi-mi, do-re-mi
        training_seq_lengths, training_data_sequences = musics.doremifasorasido(
            args.N_songs)
    if args.dododorerere:
        ## do-do-do, re-re-re
        training_seq_lengths, training_data_sequences = musics.dododorerere(
            args.N_songs)
    if args.dodododododo:
        ## do-do-do only
        training_seq_lengths, training_data_sequences = musics.dodododododo(
            args.N_songs)
    if args.sin:
        training_data_sequences = musics.createSin_allChanged(
            args.N_songs, args.length)
        training_seq_lengths = torch.tensor([args.length] * args.N_songs)

    # NOTE(review): this unconditional override discards whatever the flags
    # above selected — confirm whether that is intended.
    training_data_sequences = musics.createNewTrainingData(
        args.N_songs, args.length)
    # (removed: commented-out alternative using musics.Nonlinear per song)
    training_seq_lengths = torch.tensor([args.length] * args.N_songs)

    data_dim = training_data_sequences.size(-1)
    N_train_data = len(training_seq_lengths)
    N_train_time_slices = float(torch.sum(training_seq_lengths))
    # ceil(N_train_data / mini_batch_size)
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    z_dim = args.z_dim  # e.g. 100 or 2
    rnn_dim = args.rnn_dim  # e.g. 200 or 30
    transition_dim = args.transition_dim  # e.g. 200 or 30
    emission_dim = args.emission_dim  # e.g. 100 or 30
    encoder = Encoder(input_dim=training_data_sequences.size(-1),
                      z_dim=z_dim,
                      rnn_dim=rnn_dim,
                      N_z0=args.N_songs)
    prior = Prior(z_dim=z_dim,
                  transition_dim=transition_dim,
                  N_z0=args.N_songs,
                  use_cuda=args.cuda,
                  R=args.R)
    decoder = Emitter(input_dim=training_data_sequences.size(-1),
                      z_dim=z_dim,
                      emission_dim=emission_dim,
                      use_cuda=args.cuda)

    # Create optimizer over all three modules' parameters
    params = list(encoder.parameters()) + list(prior.parameters()) + list(
        decoder.parameters())
    optimizer = optim.Adam(params, lr=args.learning_rate)
    # Add learning rate scheduler (very slow exponential decay; 1.0 = no decay)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)

    # make directory for output data, keyed by timestamp
    now = datetime.datetime.now().strftime('%Y%m%d_%H_%M')
    os.makedirs(os.path.join("saveData", now), exist_ok=True)

    # save args as text for reproducibility
    f = open(os.path.join("saveData", now, 'args.txt'), 'w')  # open in write mode
    for name, site in vars(args).items():
        f.write(name + " = ")
        f.write(str(site) + "\n")
    f.close()  # close the file

    #######################
    #### TRAINING LOOP ####
    #######################
    times = [time.time()]
    losses = []
    recon_errors = []
    pbar = tqdm.tqdm(range(args.num_epochs))
    for epoch in pbar:
        epoch_nll = 0
        epoch_recon_error = 0
        shuffled_indices = torch.randperm(N_train_data)
        # process each mini-batch; this is where we take gradient steps
        for which_mini_batch in range(N_mini_batches):
            # compute which sequences in the training set we should grab
            mini_batch_start = (which_mini_batch * args.mini_batch_size)
            mini_batch_end = np.min([
                (which_mini_batch + 1) * args.mini_batch_size, N_train_data
            ])
            mini_batch_indices = shuffled_indices[
                mini_batch_start:mini_batch_end]
            # grab a fully prepped mini-batch using the helper in the data loader
            mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
                = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                      training_seq_lengths, cuda=args.cuda)
            # reset gradients
            optimizer.zero_grad()
            loss = 0
            N_repeat = 1
            for i in range(N_repeat):
                seiki = 1 / N_repeat  # "seiki" = normalization weight per repeat
                # posterior latents from the encoder, prior latents for comparison
                pos_z, pos_z_loc, pos_z_scale = encoder(
                    mini_batch, mini_batch_reversed, mini_batch_mask,
                    mini_batch_seq_lengths)
                pri_z, pri_z_loc, pri_z_scale = prior(
                    length=mini_batch.size(1), N_generate=mini_batch.size(0))
                # (removed: commented-out NaN diagnostics on pos/pri loc/scale)
                # Regularizer: Wasserstein distance between posterior and prior
                # (alternatives D_KL / D_JS_Monte previously tried, left disabled)
                regularizer = seiki * D_Wass(pos_z_loc, pri_z_loc, pos_z_scale,
                                             pri_z_scale)
                # Reconstruction error, normalized by batch/time/feature dims
                reconstruction_error = seiki * torch.norm(
                    mini_batch - decoder(pos_z),
                    dim=2).sum() / mini_batch.size(0) / mini_batch.size(
                        1) / mini_batch.size(2)
                loss += args.lam * regularizer + reconstruction_error
                reconed_x = decoder(pos_z)
            # (removed: large commented-out block with an older loss formulation,
            # covariance penalty, NaN detector, and intermediate saveDic dumps)

            # do an actual gradient step
            loss.backward()
            optimizer.step()
            scheduler.step()
            epoch_nll += loss
            epoch_recon_error += reconstruction_error

        # report training diagnostics
        times.append(time.time())
        losses.append(epoch_nll)
        recon_errors.append(epoch_recon_error)
        pbar.set_description("LOSS = %f " % epoch_nll)

        # periodic checkpoint: export MIDI (multi-pitch) or plots (1-D) plus state
        if epoch % args.checkpoint_freq == 0:
            path = os.path.join("saveData", now, "Epoch%d" % epoch)
            os.makedirs(path, exist_ok=True)
            if data_dim != 1:
                reco_list = generate_Xs(reconed_x)
                gene_list = generate_Xs(decoder(pri_z))
                for i in range(args.N_generate):
                    save_as_midi(song=reco_list[i],
                                 path=path,
                                 name="No%d_Reco.midi" % i)
                    save_as_midi(song=mini_batch[i],
                                 path=path,
                                 name="No%d_Real.midi" % i)
                    save_as_midi(song=gene_list[i],
                                 path=path,
                                 name="No%d_Gene.midi" % i)
            else:
                for i in range(len(mini_batch)):
                    saveReconSinGraph(mini_batch[i], reconed_x[i], args.length,
                                      path, i)
                    saveGeneSinGraph(decoder(pri_z)[i], args.length, path, i)
            saveDic = {
                "optimizer": optimizer,
                "mini_batch": mini_batch,
                "Encoder_dic": encoder.state_dict,
                "Prior_dic": prior.state_dict,
                "Emitter_dic": decoder.state_dict,
                "epoch_times": times,
                "losses": losses
            }
            torch.save(saveDic, os.path.join(path, "DMM_dic"))

        # final epoch: export once more under Epoch<num_epochs>
        if epoch == args.num_epochs - 1:
            path = os.path.join("saveData", now, "Epoch%d" % args.num_epochs)
            os.makedirs(path, exist_ok=True)
            if data_dim != 1:
                reco_list = generate_Xs(reconed_x)
                gene_list = generate_Xs(decoder(pri_z))
                for i in range(args.N_generate):
                    save_as_midi(song=reco_list[i],
                                 path=path,
                                 name="No%d_Reco.midi" % i)
                    save_as_midi(song=mini_batch[i],
                                 path=path,
                                 name="No%d_Real.midi" % i)
                    save_as_midi(song=gene_list[i],
                                 path=path,
                                 name="No%d_Gene.midi" % i)
            else:
                for i in range(len(mini_batch)):
                    saveReconSinGraph(mini_batch[i], reconed_x[i], args.length,
                                      path, i)
                    saveGeneSinGraph(decoder(pri_z)[i], args.length, path, i)
            saveDic = {
                "optimizer": optimizer,
                "mini_batch": mini_batch,
                "Encoder_dic": encoder.state_dict,
                "Prior_dic": prior.state_dict,
                "Emitter_dic": decoder.state_dict,
                "epoch_times": times,
                "losses": losses
            }
            torch.save(saveDic, os.path.join(path, "DMM_dic"))

    # persist loss curves and a final results dict with model states
    # NOTE(review): these store the bound *methods* state_dict, not their
    # return values — presumably intentional elsewhere, but confirm.
    saveGraph(losses, recon_errors, now)
    saveDic = {
        "optimizer": optimizer,
        "mini_batch": mini_batch,
        "Encoder_dic": encoder.state_dict,
        "Prior_dic": prior.state_dict,
        "Emitter_dic": decoder.state_dict,
        "epoch_times": times,
        "losses": losses
    }
    torch.save(saveDic, os.path.join("saveData", now, "DMM_dic"))
def main(args):
    """Train and evaluate a Pyro HMM variant on the JSB chorales dataset.

    Selects a model from the module-level ``models`` registry, trains its
    conditional probability tables by MAP Baum-Welch (AutoDelta guide with
    the hidden state enumerated out), then reports train/test loss per
    observation and the parameter count.
    """
    if args.cuda:
        torch.set_default_tensor_type("torch.cuda.FloatTensor")

    logging.info("Loading data")
    data = poly.load_data(poly.JSB_CHORALES)

    logging.info("-" * 40)
    model = models[args.model]
    logging.info("Training {} on {} sequences".format(
        model.__name__, len(data["train"]["sequences"])))
    sequences = data["train"]["sequences"]
    lengths = data["train"]["sequence_lengths"]

    # find all the notes that are present at least once in the training set
    present_notes = (sequences == 1).sum(0).sum(0) > 0
    # remove notes that are never played (we remove 37/88 notes)
    sequences = sequences[..., present_notes]

    # optionally truncate sequences to a maximum length for faster runs
    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
        sequences = sequences[:, :args.truncate]
    num_observations = float(lengths.sum())
    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()

    # We'll train using MAP Baum-Welch, i.e. MAP estimation while marginalizing
    # out the hidden state x. This is accomplished via an automatic guide that
    # learns point estimates of all of our conditional probability tables,
    # named probs_*.
    guide = AutoDelta(
        poutine.block(model,
                      expose_fn=lambda msg: msg["name"].startswith("probs_")))

    # To help debug our tensor shapes, let's print the shape of each site's
    # distribution, value, and log_prob tensor. Note this information is
    # automatically printed on most errors inside SVI.
    if args.print_shapes:
        first_available_dim = -2 if model is model_0 else -3
        guide_trace = poutine.trace(guide).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        model_trace = poutine.trace(
            poutine.replay(poutine.enum(model, first_available_dim),
                           guide_trace)).get_trace(sequences,
                                                   lengths,
                                                   args=args,
                                                   batch_size=args.batch_size)
        logging.info(model_trace.format_shapes())

    # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting.
    # All of our models have two plates: "data" and "tones".
    optim = Adam({"lr": args.learning_rate})
    if args.tmc:
        if args.jit:
            raise NotImplementedError(
                "jit support not yet added for TraceTMC_ELBO")
        elbo = TraceTMC_ELBO(max_plate_nesting=1 if model is model_0 else 2)
        # configure multi-sample TMC on every parallel-enumerated sample site
        tmc_model = poutine.infer_config(
            model,
            lambda msg: {
                "num_samples": args.tmc_num_samples,
                "expand": False
            } if msg["infer"].get("enumerate", None) == "parallel" else {},
        )  # noqa: E501
        svi = SVI(tmc_model, guide, optim, elbo)
    else:
        Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
        elbo = Elbo(
            max_plate_nesting=1 if model is model_0 else 2,
            strict_enumeration_warning=(model is not model_7),
            jit_options={"time_compilation": args.time_compilation},
        )
        svi = SVI(model, guide, optim, elbo)

    # We'll train on small minibatches.
    logging.info("Step\tLoss")
    for step in range(args.num_steps):
        loss = svi.step(sequences,
                        lengths,
                        args=args,
                        batch_size=args.batch_size)
        logging.info("{: >5d}\t{}".format(step, loss / num_observations))

    if args.jit and args.time_compilation:
        logging.debug("time to compile: {} s.".format(
            elbo._differentiable_loss.compile_time))

    # We evaluate on the entire training dataset,
    # excluding the prior term so our results are comparable across models.
    train_loss = elbo.loss(model,
                           guide,
                           sequences,
                           lengths,
                           args,
                           include_prior=False)
    logging.info("training loss = {}".format(train_loss / num_observations))

    # Finally we evaluate on the test dataset.
    logging.info("-" * 40)
    logging.info("Evaluating on {} test sequences".format(
        len(data["test"]["sequences"])))
    sequences = data["test"]["sequences"][..., present_notes]
    lengths = data["test"]["sequence_lengths"]
    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
    num_observations = float(lengths.sum())

    # note that since we removed unseen notes above (to make the problem a bit
    # easier and for numerical stability) this test loss may not be directly
    # comparable to numbers reported on this dataset elsewhere.
    test_loss = elbo.loss(model,
                          guide,
                          sequences,
                          lengths,
                          args=args,
                          include_prior=False)
    logging.info("test loss = {}".format(test_loss / num_observations))

    # We expect models with higher capacity to perform better,
    # but eventually overfit to the training set.
    capacity = sum(
        value.reshape(-1).size(0) for value in pyro.get_param_store().values())
    logging.info("{} capacity = {} parameters".format(model.__name__, capacity))
def main(args):
    """Train a DMM on polyphonic music (JSB chorales or synthetic toy tones)
    with a Wasserstein-style loss (EMD or WGAN critic), periodically saving
    generated songs as MIDI, a loss curve, and model state dicts under
    ``saveData/<timestamp>/``.
    """

    ## Toy data: a descending series of ascending Do-Re-Mi major scales.
    def easyTones():
        top = 70       # highest root pitch (renamed from `max` to avoid shadowing the builtin)
        interval = 10  # pitch gap between consecutive sequences' roots
        N = 10
        # NOTE(review): for i >= 8, top - i * interval is negative, so the
        # pitch index wraps around via negative indexing — confirm intended.
        training_seq_lengths = torch.tensor([8] * N)
        training_data_sequences = torch.zeros(N, 8, 88)
        # Major-scale offsets in semitones from the root note.
        scale = [0, 2, 4, 5, 7, 9, 11, 12]
        for i in range(N):
            root = int(top - i * interval)
            for j, offset in enumerate(scale):
                training_data_sequences[i][j][root + offset] = 1
        return training_seq_lengths, training_data_sequences

    ## Toy data: each sequence repeats one pitch; the pitch differs per sequence.
    def superEasyTones():
        training_seq_lengths = torch.tensor([8] * 10)
        training_data_sequences = torch.zeros(10, 8, 88)
        for i in range(10):
            for j in range(8):
                training_data_sequences[i][j][int(30 + i * 5)] = 1
        return training_seq_lengths, training_data_sequences

    ## Toy data: every sequence repeats the very same pitch (70).
    def easiestTones():
        training_seq_lengths = torch.tensor([8] * 10)
        training_data_sequences = torch.zeros(10, 8, 88)
        for i in range(10):
            for j in range(8):
                training_data_sequences[i][j][70] = 1
        return training_seq_lengths, training_data_sequences

    # Tile `x` n_eval_samples times along the batch dimension so repeated
    # evaluation samples can be vectorized into a single batch.
    def rep(x):
        rep_shape = torch.Size([x.size(0) * n_eval_samples]) + x.size()[1:]
        repeat_dims = [1] * len(x.size())
        repeat_dims[0] = n_eval_samples
        return x.repeat(repeat_dims).reshape(n_eval_samples, -1).transpose(
            1, 0).reshape(rep_shape)

    ## Save the per-epoch loss curve as a PNG under saveData/<now>/.
    def saveGraph(loss_list, now):
        FS = 10  # font size
        fig = plt.figure()
        plt.rcParams["font.size"] = FS
        plt.plot(loss_list)
        plt.ylim(bottom=0)
        plt.title("Wasserstein Loss")
        plt.xlabel("epoch", fontsize=FS)
        plt.ylabel("loss", fontsize=FS)
        fig.savefig(os.path.join("saveData", now, "LOSS.png"))

    ## Save one song (a time x 88 binary piano-roll tensor) as a MIDI file.
    def save_as_midi(song, path="", name="default.mid", BPM=120, velocity=100):
        # A pretty_midi object holds instruments; an instrument is like a track.
        pm = pretty_midi.PrettyMIDI(resolution=960, initial_tempo=BPM)
        instrument = pretty_midi.Instrument(0)
        for i, tones in enumerate(song):
            which_tone = torch.nonzero((tones == 1), as_tuple=False).reshape(-1)
            if len(which_tone) == 0:
                # Silent time step: emit a zero-velocity placeholder note.
                note = pretty_midi.Note(velocity=0, pitch=0, start=i, end=i + 1)
                instrument.notes.append(note)
            else:
                # One Note (NoteOn/NoteOff pair) per active pitch.
                for which in which_tone:
                    note = pretty_midi.Note(velocity=velocity, pitch=int(which),
                                            start=i, end=i + 1)
                    instrument.notes.append(note)
        pm.instruments.append(instrument)
        pm.write(os.path.join(path, name))

    # Run the DMM on a mini-batch and sample binary tones from the Bernoulli
    # probabilities it emits, one sampled song per batch element.
    def generate_Xs(dmm, mini_batch, mini_batch_reversed, mini_batch_mask,
                    mini_batch_seq_lengths):
        songs_list = []
        generated_mini_batch = dmm(mini_batch, mini_batch_reversed,
                                   mini_batch_mask, mini_batch_seq_lengths)
        for i, song in enumerate(generated_mini_batch):
            tones_container = []
            # renamed loop var from `time` to avoid shadowing the time module
            for t in range(mini_batch_seq_lengths[i]):
                tone = dist.Bernoulli(probs=song[t]).sample()
                tones_container.append(tone)
            songs_list.append(torch.stack(tones_container))
        return songs_list

    ## Save up to N_songs random (generated, training) song pairs as MIDI.
    def saveSongs(songs_list, mini_batch, N_songs, path):
        if len(songs_list) != len(mini_batch):
            # was `assert False`; raise explicitly so the check survives `python -O`
            raise ValueError("songs_list and mini_batch must have equal length")
        k = min(N_songs, len(songs_list))
        song_No = random.sample(range(len(songs_list)), k=k)
        for i, Number in enumerate(song_No):
            save_as_midi(song=songs_list[Number], path=path,
                         name="No%d_Gene.midi" % i)
            save_as_midi(song=mini_batch[Number], path=path,
                         name="No%d_Tran.midi" % i)

    ## JSB chorales: sequences are zero-padded up to the max length (129);
    ## e.g. a length-60 sequence is all zeros from step 61 onward.
    data = poly.load_data(poly.JSB_CHORALES)
    training_seq_lengths = data['train']['sequence_lengths'][:3]
    training_data_sequences = data['train']['sequences'][:3, :8]
    # NOTE(review): this declares 10 sequences of length 8 while only 3
    # sequences were kept above — confirm the script is only ever run with
    # one of --eas/--sup/--est, which overwrite both tensors consistently.
    training_seq_lengths = torch.tensor([8] * 10)
    if args.eas:
        ## Do-Re-Mi scales
        training_seq_lengths, training_data_sequences = easyTones()
    if args.sup:
        ## one repeated pitch per sequence
        training_seq_lengths, training_data_sequences = superEasyTones()
    if args.est:
        ## the same repeated pitch everywhere
        training_seq_lengths, training_data_sequences = easiestTones()

    test_seq_lengths = data['test']['sequence_lengths'][:3]
    test_seq_lengths = torch.tensor([8, 8, 8])
    test_data_sequences = data['test']['sequences'][:3, :8]
    val_seq_lengths = data['valid']['sequence_lengths']
    val_data_sequences = data['valid']['sequences']

    N_train_data = len(training_seq_lengths)
    N_train_time_slices = float(torch.sum(training_seq_lengths))
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    # how often we do validation/test evaluation during training
    val_test_frequency = 50
    # the number of samples we use to do the evaluation (consumed by rep())
    n_eval_samples = 1

    # instantiate the model, the WGAN critic network, and the Wasserstein loss
    dmm = DMM(use_cuda=args.cuda, rnn_check=args.rck, rpt=args.rpt)
    WN = WGAN_network(use_cuda=args.cuda)
    W = WassersteinLoss(WN, N_loops=args.N_loops, lr=args.w_learning_rate,
                        use_cuda=args.cuda)

    optimizer = optim.Adam(dmm.parameters(), lr=args.learning_rate)
    # gamma=1.0 keeps the learning rate constant; the scheduler is kept so a
    # decaying gamma can be swapped in for experiments.
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 1.)

    # make a per-run directory for all artifacts
    now = datetime.datetime.now().strftime('%Y%m%d_%H_%M')
    os.makedirs(os.path.join("saveData", now), exist_ok=True)

    # save args as text; `with` guarantees the file is closed even on error
    with open(os.path.join("saveData", now, 'args.txt'), 'w') as f:
        for name, site in vars(args).items():
            f.write(name + " = ")
            f.write(str(site) + "\n")

    #######################
    #### TRAINING LOOP ####
    #######################
    times = [time.time()]
    losses = []
    pbar = tqdm.tqdm(range(args.num_epochs))
    for epoch in pbar:
        epoch_nll = 0
        # deterministic order; use torch.randperm(N_train_data) to shuffle
        shuffled_indices = torch.arange(N_train_data)

        # process each mini-batch; this is where we take gradient steps
        for which_mini_batch in range(N_mini_batches):
            # compute which sequences in the training set we should grab
            mini_batch_start = which_mini_batch * args.mini_batch_size
            mini_batch_end = min((which_mini_batch + 1) * args.mini_batch_size,
                                 N_train_data)
            mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]

            # grab a fully prepped mini-batch using the data-loader helper
            mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
                = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                      training_seq_lengths, cuda=args.cuda)

            optimizer.zero_grad()

            # generate a mini-batch from the training mini-batch
            generated_mini_batch = dmm(mini_batch, mini_batch_reversed,
                                       mini_batch_mask, mini_batch_seq_lengths)

            # NOTE(review): if neither --wass nor --wgan is set, `loss` is
            # undefined here and backward() raises NameError — confirm the
            # CLI enforces exactly one of the two flags.
            if args.wass:
                loss = EMD(mini_batch, generated_mini_batch)
            if args.wgan:
                loss = W.calc(mini_batch, generated_mini_batch)

            loss.backward()
            optimizer.step()
            scheduler.step()

            if args.rnn_clip is not None:  # `is not None`, not `!= None`
                # WGAN-style weight clipping on the RNN parameters
                for p in dmm.rnn.parameters():
                    p.data.clamp_(-args.rnn_clip, args.rnn_clip)

            # .item() detaches the scalar so each mini-batch's autograd graph
            # is freed instead of being retained (and later pickled) via
            # the epoch_nll / losses accumulators
            epoch_nll += loss.item()

        # report training diagnostics
        times.append(time.time())
        losses.append(epoch_nll)
        pbar.set_description("LOSS = %f " % epoch_nll)

        # periodic checkpoint: dump generated vs. training songs as MIDI
        if epoch % args.checkpoint_freq == 0:
            path = os.path.join("saveData", now, "Epoch%d" % epoch)
            os.makedirs(path, exist_ok=True)
            songs_list = generate_Xs(dmm, mini_batch, mini_batch_reversed,
                                     mini_batch_mask, mini_batch_seq_lengths)
            saveSongs(songs_list, mini_batch, N_songs=2, path=path)

        # final epoch: save songs, the loss curve, and the model checkpoint
        if epoch == args.num_epochs - 1:
            path = os.path.join("saveData", now, "Epoch%d" % args.num_epochs)
            os.makedirs(path, exist_ok=True)
            songs_list = generate_Xs(dmm, mini_batch, mini_batch_reversed,
                                     mini_batch_mask, mini_batch_seq_lengths)
            saveSongs(songs_list, mini_batch, N_songs=2, path=path)
            saveGraph(losses, now)
            # BUG FIX: previously the bound methods `dmm.state_dict` /
            # `WN.state_dict` were stored (not called), so the checkpoint
            # contained no parameter dicts at all.
            saveDic = {
                "DMM_dic": dmm.state_dict(),
                "WGAN_Network_dic": WN.state_dict(),
                "epoch_times": times,
                "losses": losses,
            }
            torch.save(saveDic, os.path.join("saveData", now, "DMM_dic"))