import numpy as np
import torch
import mido
import pypianoroll as ppr
from std_msgs.msg import Int32MultiArray  # ROS message type used in vae_interact

# Helper imports; the exact module path is an assumption based on this repo's
# utils directory.
from utils.utilsPreprocessing import (addCuttedOctaves, cutOctaves,
                                      debinarizeMidi,
                                      getSlicedPianorollMatrixNp,
                                      transposeNotesHigherLower, NoteSmoother)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def vae_main(live_instrument, model, args):
    # reset live input clock
    print("\nUser input\n")
    live_instrument.reset_sequence()
    live_instrument.reset_clock()
    while True:
        status_played_notes = live_instrument.clock()
        if status_played_notes:
            sequence = live_instrument.parse_to_matrix()
            live_instrument.reset_sequence()
            break

    # send live recorded sequence through model and get improvisation
    with torch.no_grad():
        # prepare sample for input
        sample = np.array(np.split(sequence, args.bars))
        sample = cutOctaves(sample)
        sample = torch.from_numpy(sample).float().to(device)
        sample = torch.unsqueeze(sample, 1)

        # encode; decode the mean only (no sampling yet)
        mu, logvar = model.encoder(sample)
        # TODO reparameterize to get new sequences here with GUI??
        # reconstruction, soon ~prediction
        pred = model.decoder(mu)

        # reorder prediction: concatenate all predicted bars along time
        pred = pred.squeeze(1)
        prediction = pred[0]
        # TODO TEMP for more sequences
        if pred.size(0) > 1:
            for p in pred[1:]:
                prediction = torch.cat((prediction, p), dim=0)

        prediction = prediction.cpu().numpy()
        # normalize predictions
        prediction /= np.abs(np.max(prediction))

        # threshold midi activations to include rests
        prediction[prediction < (1 - args.temperature)] = 0
        prediction = debinarizeMidi(prediction, prediction=True)
        prediction = addCuttedOctaves(prediction)

        # play predicted sequence note by note
        print("\nPrediction\n")
        live_instrument.computer_play(prediction=prediction)

    live_instrument.reset_sequence()
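
# The normalize-then-threshold step in vae_main recurs in every playback path
# of this module. A minimal standalone sketch of just that post-processing
# (the helper name apply_temperature is illustrative, not part of this repo):
def apply_temperature(prediction, temperature):
    """Scale activations to a max of 1, then zero everything below 1 - temperature."""
    prediction = prediction / np.abs(np.max(prediction))
    prediction[prediction < (1.0 - temperature)] = 0
    return prediction

# e.g. apply_temperature(np.random.rand(96, 60), temperature=0.5)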
def midi_to_embedding(self, sample_path='./utils/midi_files/b-minor_movement21.mid',
                      sample_bar=0):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # make sure the model is in eval mode; the original `if self.model.train():`
    # check flipped the model into train mode as a side effect
    if self.model.training:
        self.model.eval()
    with torch.no_grad():
        sample = getSlicedPianorollMatrixNp(sample_path)
        sample = transposeNotesHigherLower(sample)
        sample = cutOctaves(sample[sample_bar])
        # one bar is a 96-timestep x 60-pitch pianoroll after cutOctaves
        sample = torch.from_numpy(sample.reshape(1, 1, 96, 60)).float().to(device)
        embedding, _ = self.model.encoder(sample)
        return embedding
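
# Hypothetical use of midi_to_embedding, e.g. to compare two bars of the same
# file in latent space; the cosine similarity is illustrative, not something
# this repo does:
#
#   emb_a = controller.midi_to_embedding(sample_bar=0)
#   emb_b = controller.midi_to_embedding(sample_bar=1)
#   sim = torch.cosine_similarity(emb_a, emb_b)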
def reconstruct(file_path, model, start_bar, end_bar, temperature=0.5,
                smooth_threshold=0):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # make sure the model is in eval mode; the original `if model.train():`
    # check switched it to train mode as a side effect
    if model.training:
        model.eval()
    with torch.no_grad():
        sample_np = getSlicedPianorollMatrixNp(file_path)
        sample_np = transposeNotesHigherLower(sample_np)
        sample_np = cutOctaves(sample_np)
        sample_np = sample_np[start_bar:end_bar]
        sample = torch.from_numpy(sample_np).float()
        recon, embed, logvar = model(sample.view(-1, 1, 96, 60).to(device))
        recon = torch.softmax(recon, dim=3)
        recon = recon.squeeze(1).cpu().numpy()
        # recon /= np.abs(np.max(recon))
        recon[recon < (1 - temperature)] = 0

        sample_play = debinarizeMidi(sample_np, prediction=False)
        sample_play = addCuttedOctaves(sample_play)
        recon = debinarizeMidi(recon, prediction=True)
        recon = addCuttedOctaves(recon)

        # concatenate all bars along the time axis
        recon_out = recon[0]
        sample_out = sample_play[0]
        if recon.shape[0] > 1:
            for i in range(recon.shape[0] - 1):
                sample_out = np.concatenate((sample_out, sample_play[i + 1]), axis=0)
                recon_out = np.concatenate((recon_out, recon[i + 1]), axis=0)

        # plot with pypianoroll
        sample_plot = ppr.Track(sample_out)
        ppr.plot(sample_plot)
        recon_plot = ppr.Track(recon_out)
        ppr.plot(recon_plot)

        # smooth output
        smoother = NoteSmoother(recon_out, threshold=smooth_threshold)
        smoothed_seq = smoother.smooth()
        smoother_seq_plot = ppr.Track(smoothed_seq)
        ppr.plot(smoother_seq_plot)
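
# Hypothetical call of reconstruct(), assuming a trained VAE checkpoint; the
# class name VAE and the checkpoint path are illustrative:
#
#   model = VAE()
#   model.load_state_dict(torch.load('./pretrained/vae.pth', map_location='cpu'))
#   reconstruct('./utils/midi_files/b-minor_movement21.mid', model,
#               start_bar=0, end_bar=4, temperature=0.5, smooth_threshold=1)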
def vae_interact(gui):
    live_instrument = gui.live_instrument
    device = gui.device
    model = gui.model.to(device)
    dials = gui.dials

    while True:
        print("\nUser input\n")
        # reset live input clock and prerecorded sequences
        live_instrument.reset_sequence()
        live_instrument.reset_clock()
        while True:
            status_played_notes = live_instrument.clock()
            if status_played_notes:
                sequence = live_instrument.parse_to_matrix()
                live_instrument.reset_sequence()
                break
            if not gui.is_running:
                break
        if not gui.is_running:
            break

        # send live recorded sequence through model and get response
        with torch.no_grad():
            # prepare sample for input
            sample = np.array(np.split(sequence, live_instrument.bars))
            sample = cutOctaves(sample)
            sample = torch.from_numpy(sample).float().to(device)
            sample = torch.unsqueeze(sample, 1)

            # encode
            mu, logvar = model.encoder(sample)

            # reparameterize, steered by the GUI dials; note this scales by
            # 0.5 * exp(logvar) (half the variance), not the standard
            # deviation exp(0.5 * logvar) of the usual VAE sampler
            dial_vals = [dial.value() for dial in dials]
            dial_tensor = (torch.FloatTensor(dial_vals) / 100.).to(device)
            new = mu + (dial_tensor * 0.5 * logvar.exp())
            pred = model.decoder(new).squeeze(1)

            # for more than 1 sequence
            prediction = pred[0]
            if pred.size(0) > 1:
                for p in pred[1:]:
                    prediction = torch.cat((prediction, p), dim=0)

            # back to cpu and normalize
            prediction = prediction.cpu().numpy()
            prediction /= np.abs(np.max(prediction))

            # threshold midi activations to include rests
            prediction[prediction < (1 - gui.slider_temperature.value() / 100.)] = 0
            prediction = debinarizeMidi(prediction, prediction=True)
            prediction = addCuttedOctaves(prediction)
            smoother = NoteSmoother(prediction, threshold=2)
            prediction = smoother.smooth()

            # send to robot
            if gui.chx_simulate_robot.isChecked():
                print("\nPublisher\n")
                note_msg = Int32MultiArray()
                live_instrument.human = False
                live_instrument.reset_clock()
                play_tick = -1
                old_midi_on = np.zeros(1)
                played_notes = []
                while True:
                    done = live_instrument.computer_clock()
                    if live_instrument.current_tick > play_tick:
                        play_tick = live_instrument.current_tick
                        midi_on = np.argwhere(prediction[play_tick] > 0)
                        if midi_on.any():
                            # argwhere returns shape (k, 1); iterate over all k
                            # active pitches (the original midi_on[0] only ever
                            # published the first one)
                            for note in midi_on[:, 0]:
                                if note not in old_midi_on:
                                    current_vel = int(prediction[live_instrument.current_tick, note])
                                    mido_msg = mido.Message('note_on', note=int(note),
                                                            velocity=current_vel)
                                    note_msg.data = mido_msg.bytes()
                                    gui.midi_publisher.publish(note_msg)
                                    played_notes.append(note)
                        else:
                            for note in played_notes:
                                # self.out_port.send(mido.Message('note_off',
                                #                                 note=note))  # , velocity=100))
                                played_notes.pop(0)
                        if old_midi_on.any():
                            for note in old_midi_on[:, 0]:
                                if note not in midi_on:
                                    # self.out_port.send(mido.Message('note_off', note=note))
                                    continue
                        old_midi_on = midi_on
                    if done:
                        live_instrument.human = True
                        live_instrument.reset_clock()
                        break
            # or play in software
            else:
                print("\nPrediction\n")
                live_instrument.computer_play(prediction=prediction)

        live_instrument.reset_sequence()
        if not gui.is_running:
            break
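
# The robot branch above packs a mido note_on into a ROS Int32MultiArray via
# mido.Message.bytes(). A minimal sketch of that packing (the helper name is
# illustrative, not part of this repo):
def note_on_to_msg(note, velocity):
    msg = Int32MultiArray()
    # bytes() yields the raw MIDI triple [status, note, velocity], e.g. [144, 60, 100]
    msg.data = mido.Message('note_on', note=note, velocity=velocity).bytes()
    return msg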
def vae_endless(gui):
    live_instrument = gui.live_instrument
    device = gui.device
    model = gui.model.to(device)
    dials = gui.dials

    print("\nUser input\n")
    # reset live input clock and prerecorded sequences
    live_instrument.reset_sequence()
    live_instrument.reset_clock()
    while True:
        status_played_notes = live_instrument.clock()
        if status_played_notes:
            sequence = live_instrument.parse_to_matrix()
            live_instrument.reset_sequence()
            break
        if not gui.is_running:
            break

    while True:
        # send the last sequence through the model and get a response; the
        # response becomes the next input, so the model plays endlessly
        with torch.no_grad():
            # prepare sample for input
            sample = np.array(np.split(sequence, live_instrument.bars))
            sample = cutOctaves(sample)
            sample = torch.from_numpy(sample).float().to(device)
            sample = torch.unsqueeze(sample, 1)

            # encode
            mu, logvar = model.encoder(sample)

            # reparameterize with variance
            dial_vals = [dial.value() for dial in dials]
            # .to(device) was missing here; without it the addition below
            # fails whenever the model runs on GPU
            dial_tensor = (torch.FloatTensor(dial_vals) / 100.).to(device)
            # print(dial_tensor)
            new = mu + (dial_tensor * 0.5 * logvar.exp())
            pred = model.decoder(new).squeeze(1)

            # for more than 1 sequence
            prediction = pred[0]
            if pred.size(0) > 1:
                for p in pred[1:]:
                    prediction = torch.cat((prediction, p), dim=0)

            # back to cpu and normalize
            prediction = prediction.cpu().numpy()
            prediction /= np.abs(np.max(prediction))

            # threshold midi activations to include rests
            prediction[prediction < (1 - gui.slider_temperature.value() / 100.)] = 0
            prediction = debinarizeMidi(prediction, prediction=True)
            prediction = addCuttedOctaves(prediction)
            smoother = NoteSmoother(prediction, threshold=2)
            prediction = smoother.smooth()

            # play predicted sequence note by note
            print("\nPrediction\n")
            live_instrument.computer_play(prediction=prediction)

        live_instrument.reset_sequence()
        # feed the prediction back in as the next input
        sequence = prediction
        if not gui.is_running:
            break
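
# vae_interact and vae_endless replace the random eps of the standard VAE
# reparameterization z = mu + eps * exp(0.5 * logvar) with user-set dial
# values, and scale by half the variance rather than the standard deviation.
# A minimal sketch of both variants for comparison (illustrative helpers,
# not part of this repo):
def reparameterize_random(mu, logvar):
    # standard VAE sampling: eps ~ N(0, I), std = exp(0.5 * logvar)
    std = torch.exp(0.5 * logvar)
    return mu + torch.randn_like(std) * std

def reparameterize_dials(mu, logvar, dials):
    # deterministic variant matching the dial-controlled code above
    return mu + dials * 0.5 * logvar.exp()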