Ejemplo n.º 1
0
def generate():
    ''' Sample MIDI from trained generator model

    Builds a ``Generator`` sized by the data loader's song-feature count,
    restores its weights from ``CKPT_DIR/G_FN``, runs one forward pass on
    uniform random noise of shape ``[1, MAX_SEQ_LEN, num_feats]`` and hands
    the result to the data loader to be saved as ``FILENAME``.
    '''
    # prepare model
    dataloader = music_data_utils.MusicDataLoader(datadir=None)
    num_feats = dataloader.get_num_song_features()

    use_gpu = torch.cuda.is_available()
    g_model = Generator(num_feats, use_cuda=use_gpu)

    # Remap the checkpoint's storages to the CPU when CUDA is unavailable;
    # without this a checkpoint saved from a GPU run cannot be deserialized
    # on a CPU-only host.
    map_location = None if use_gpu else 'cpu'
    ckpt = torch.load(os.path.join(CKPT_DIR, G_FN), map_location=map_location)
    g_model.load_state_dict(ckpt)

    # generate from model then save to MIDI file
    g_states = g_model.init_hidden(1)
    z = torch.empty([1, MAX_SEQ_LEN, num_feats]).uniform_()  # random vector
    if use_gpu:
        z = z.cuda()
        g_model.cuda()

    g_model.eval()
    g_feats, _ = g_model(z, g_states)
    # Drop the batch dimension and bring the result back to host memory
    # before converting it to a numpy array.
    song_data = g_feats.squeeze().cpu()
    song_data = song_data.detach().numpy()

    dataloader.save_data(FILENAME, song_data)
    print('Generated {}'.format(FILENAME))
Ejemplo n.º 2
0
def main(_):
  """Pretrain the RNN-GAN on the MIDI data and save sampled MIDI files.

  Validates the required flags, creates the output directories, then for each
  of ``FLAGS.max_epoch`` epochs: grows the training song length by 4 events
  until it reaches the ``FLAGS.songlength`` value seen at start-up (rebuilding
  the model with reused variables each time), assigns the decayed learning
  rate, runs one pretraining epoch through ``run_epoch`` and writes the first
  sampled pattern to ``<traindir>/generated_data``. Afterwards ``run_epoch``
  is called once in 'test' mode and one more sample is saved.

  Raises:
    ValueError: if ``--datadir`` or ``--traindir`` is not set.
  """
  if not FLAGS.datadir:
    raise ValueError("Must set --datadir to midi music dir.")
  if not FLAGS.traindir:
    raise ValueError("Must set --traindir to dir where I can save model and plots.")
  restore_flags()
  generated_data_dir = os.path.join(FLAGS.traindir, 'generated_data')
  for directory in (FLAGS.traindir, generated_data_dir):
    try:
      os.makedirs(directory)
    except OSError as err:
      # An already existing directory is the expected case. Anything else is
      # reported instead of being silently swallowed, but stays non-fatal
      # here, as before.
      if not os.path.isdir(directory):
        print('Could not create directory {}: {}'.format(directory, err))
  loader = music_data_utils.MusicDataLoader(FLAGS.datadir, FLAGS.select_validation_percentage, FLAGS.select_test_percentage, FLAGS.works_per_composer, FLAGS.pace_events, synthetic=None, tones_per_cell=FLAGS.tones_per_cell, single_composer=FLAGS.composer)
  num_song_features = loader.get_num_song_features()
  num_meta_features = loader.get_num_meta_features()
  songlength_ceiling = FLAGS.songlength
  songlength = 0
  # Bound up front so the final output filename can still be built when
  # FLAGS.max_epoch is 0 and the epoch loop never runs.
  i = 0
  with tf.Graph().as_default(), tf.Session() as session:
    with tf.variable_scope("model", reuse=None) as scope:
      scope.set_regularizer(tf.contrib.layers.l2_regularizer(scale=FLAGS.reg_scale))
      m = RNNGAN(is_training=True, num_song_features=num_song_features, num_meta_features=num_meta_features)
    print("Created model with fresh parameters.")
    session.run(tf.global_variables_initializer())
    for i in range(FLAGS.max_epoch):
      if songlength < songlength_ceiling:
        # Curriculum: train on progressively longer excerpts. The model is
        # rebuilt for the new length with reuse=True so it shares the
        # variables created above.
        songlength += 4
        print('Changing songlength, now training on {} events from songs.'.format(songlength))
        FLAGS.songlength = songlength
        with tf.variable_scope("model", reuse=True) as scope:
          scope.set_regularizer(tf.contrib.layers.l2_regularizer(scale=FLAGS.reg_scale))
          m = RNNGAN(is_training=True, num_song_features=num_song_features, num_meta_features=num_meta_features)
      # The learning rate decays exponentially once epochs_before_decay
      # epochs have passed.
      m.assign_lr(session, FLAGS.learning_rate * FLAGS.lr_decay ** max(i - FLAGS.epochs_before_decay, 0.0))
      print("Epoch: {} Learning rate: {:.3f}".format(i, session.run(m.lr)))
      run_epoch(session, m, loader, 'train', m.opt_pretraining, pretraining = True, verbose=True)
      song_data = sample(session, m, batch=True)
      midi_time = time.time()
      midi_patterns = [loader.get_midi_pattern(d) for d in song_data]
      print('done. time: {}'.format(time.time()-midi_time))
      filename = os.path.join(generated_data_dir, 'out-{}-{}.mid'.format(i, datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')))
      # Only the first pattern of the sampled batch is written to disk.
      loader.save_midi_pattern(filename, midi_patterns[0])
      sys.stdout.flush()
    run_epoch(session, m, loader, 'test', tf.no_op())
    song_data = sample(session, m)
    filename = os.path.join(generated_data_dir, 'out-{}-{}.mid'.format(i, datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')))
    loader.save_data(filename, song_data)
    print('Saved {}.'.format(filename))
Ejemplo n.º 3
0
def generate(n):
    ''' Sample MIDI from trained generator model

    Runs the generator ``n`` times on the same noise tensor, feeding the
    hidden state returned by each pass into the next one, and saves the
    concatenation of all passes to ``FILENAME``.

    :param n: (int) number of consecutive generator passes to chain (>= 1)
    :return: none
    :raises ValueError: if ``n`` is smaller than 1
    '''
    if n < 1:
        raise ValueError('n must be at least 1, got {}'.format(n))

    # prepare model
    dataloader = music_data_utils.MusicDataLoader(datadir=None)
    num_feats = dataloader.get_num_song_features()

    use_gpu = torch.cuda.is_available()
    g_model = Generator(num_feats, use_cuda=use_gpu)

    # On a CPU-only host the checkpoint's storages must be remapped to the
    # CPU, otherwise a checkpoint saved from a GPU run fails to load.
    map_location = None if use_gpu else 'cpu'
    ckpt = torch.load(os.path.join(CKPT_DIR, G_FN), map_location=map_location)
    g_model.load_state_dict(ckpt)

    # generate from model then save to MIDI file
    g_states = g_model.init_hidden(1)
    z = torch.empty([1, MAX_SEQ_LEN, num_feats]).uniform_()  # random vector
    if use_gpu:
        z = z.cuda()
        g_model.cuda()

    g_model.eval()

    chunks = []
    # Inference only: no_grad keeps the hidden state that is carried from
    # pass to pass from accumulating autograd history.
    with torch.no_grad():
        for _ in range(n):
            g_feats, g_states = g_model(z, g_states)
            chunks.append(g_feats.squeeze().cpu().detach().numpy())

    if len(chunks) > 1:
        full_song_data = np.concatenate(chunks, axis=0)
    else:
        full_song_data = chunks[0]

    # Save the whole chained sequence; the original saved only the last pass.
    dataloader.save_data(FILENAME, full_song_data)
    print('Full sequence shape: ', full_song_data.shape)
    print('Generated {}'.format(FILENAME))
Ejemplo n.º 4
0
def main(_):
  """Train the RNN-GAN (or only sample from it) and write checkpoints, plots and MIDI.

  Steps, in order:
    * validate ``--datadir`` / ``--traindir`` and create the output
      directories below ``FLAGS.traindir`` (summaries, plots, generated_data);
    * resume the epoch counter from ``<traindir>/global_step.pkl`` if present;
    * build the data loader and the model, restoring variables from the latest
      checkpoint in ``FLAGS.traindir`` when there is one;
    * unless ``FLAGS.sample`` is set, run the epoch loop (pretraining epochs
      first, adversarial training afterwards), each epoch followed by a
      validation pass, optional checkpointing, gnuplot input/plots, a sampled
      MIDI file and MIDI statistics, then finish with a test pass;
    * finally sample once more and save the result.

  Raises:
    ValueError: if ``--datadir`` or ``--traindir`` is not set.
  """
  if not FLAGS.datadir:
    raise ValueError("Must set --datadir to midi music dir.")
  if not FLAGS.traindir:
    raise ValueError("Must set --traindir to dir where I can save model and plots.")

  restore_flags()

  summaries_dir = os.path.join(FLAGS.traindir, 'summaries')
  plots_dir = os.path.join(FLAGS.traindir, 'plots')
  generated_data_dir = os.path.join(FLAGS.traindir, 'generated_data')
  for directory in (FLAGS.traindir, summaries_dir, plots_dir, generated_data_dir):
    try:
      os.makedirs(directory)
    except OSError as err:
      # An already existing directory is the expected case. Anything else is
      # reported instead of being silently swallowed, but stays non-fatal
      # here, as before.
      if not os.path.isdir(directory):
        print('Could not create directory {}: {}'.format(directory, err))

  # The experiment label is the last non-empty component of traindir (this
  # skips the empty component produced by a trailing slash).
  directorynames = FLAGS.traindir.split('/')
  experiment_label = ''
  while not experiment_label:
    experiment_label = directorynames.pop()

  global_step_filename = os.path.join(FLAGS.traindir, 'global_step.pkl')
  global_step = -1
  if os.path.exists(global_step_filename):
    # Binary mode: the file is written with 'wb' below and pickle data is
    # bytes, so text mode breaks loading on Python 3.
    with open(global_step_filename, 'rb') as f:
      global_step = pkl.load(f)
  # Resume with the epoch after the last saved one (0 on a fresh run).
  global_step += 1

  synthetic=None
  if FLAGS.synthetic_chords:
    synthetic = 'chords'
    print('Training on synthetic chords!')
  if FLAGS.composer is not None:
    print('Single composer: {}'.format(FLAGS.composer))
  loader = music_data_utils.MusicDataLoader(FLAGS.datadir, FLAGS.select_validation_percentage, FLAGS.select_test_percentage, FLAGS.works_per_composer, FLAGS.pace_events, synthetic=synthetic, tones_per_cell=FLAGS.tones_per_cell, single_composer=FLAGS.composer)
  if FLAGS.synthetic_chords:
    # This is just a print out, to check the generated data.
    batch = loader.get_batch(batchsize=1, songlength=400)
    loader.get_midi_pattern([batch[1][0][event] for event in range(batch[1].shape[1])])

  num_song_features = loader.get_num_song_features()
  print('num_song_features:{}'.format(num_song_features))
  num_meta_features = loader.get_num_meta_features()
  print('num_meta_features:{}'.format(num_meta_features))

  train_start_time = time.time()
  checkpoint_path = os.path.join(FLAGS.traindir, "model.ckpt")

  songlength_ceiling = FLAGS.songlength

  if global_step < FLAGS.pretraining_epochs:
    # During pretraining the song length grows by 4 events per epoch, capped
    # at the configured length.
    FLAGS.songlength = int(min((global_step+1)*4,songlength_ceiling))

  with tf.Graph().as_default(), tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)) as session:
    with tf.variable_scope("model", reuse=None) as scope:
      scope.set_regularizer(tf.contrib.layers.l2_regularizer(scale=FLAGS.reg_scale))
      m = RNNGAN(is_training=True, num_song_features=num_song_features, num_meta_features=num_meta_features)

    if FLAGS.initialize_d:
      # Restore only the generator variables from the checkpoint and give the
      # discriminator fresh values.
      vars_to_restore = {}
      for v in tf.trainable_variables():
        if v.name.startswith('model/G/'):
          print(v.name[:-2])
          vars_to_restore[v.name[:-2]] = v
      saver = tf.train.Saver(vars_to_restore)
      ckpt = tf.train.get_checkpoint_state(FLAGS.traindir)
      if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
        print("Reading model parameters from %s" % ckpt.model_checkpoint_path,end=" ")
        saver.restore(session, ckpt.model_checkpoint_path)
        session.run(tf.initialize_variables([v for v in tf.trainable_variables() if v.name.startswith('model/D/')]))
      else:
        print("Created model with fresh parameters.")
        session.run(tf.initialize_all_variables())
      # From here on checkpoints contain every variable again.
      saver = tf.train.Saver(tf.all_variables())
    else:
      saver = tf.train.Saver(tf.all_variables())
      ckpt = tf.train.get_checkpoint_state(FLAGS.traindir)
      if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
        print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
        saver.restore(session, ckpt.model_checkpoint_path)
      else:
        print("Created model with fresh parameters.")
        session.run(tf.initialize_all_variables())

    run_metadata = None
    if FLAGS.profiling:
      run_metadata = tf.RunMetadata()

    # Bound up front so the final sample's filename can be built even when
    # FLAGS.sample is set or the epoch loop below has nothing left to run.
    i = global_step
    if not FLAGS.sample:
      train_g_loss,train_d_loss = 1.0,1.0
      for i in range(global_step, FLAGS.max_epoch):
        lr_decay = FLAGS.lr_decay ** max(i - FLAGS.epochs_before_decay, 0.0)

        # NOTE(review): this tests the epoch the run was resumed at
        # (global_step), not the current epoch i - confirm that is intended.
        if global_step < FLAGS.pretraining_epochs:
          new_songlength = int(min((i+1)*4,songlength_ceiling))
        else:
          new_songlength = songlength_ceiling
        if new_songlength != FLAGS.songlength:
          print('Changing songlength, now training on {} events from songs.'.format(new_songlength))
          FLAGS.songlength = new_songlength
          # Rebuild the model for the new length, sharing the existing
          # variables.
          with tf.variable_scope("model", reuse=True) as scope:
            scope.set_regularizer(tf.contrib.layers.l2_regularizer(scale=FLAGS.reg_scale))
            m = RNNGAN(is_training=True, num_song_features=num_song_features, num_meta_features=num_meta_features)

        if not FLAGS.adam:
          m.assign_lr(session, FLAGS.learning_rate * lr_decay)

        save = False
        do_exit = False

        print("Epoch: {} Learning rate: {:.3f}, pretraining: {}".format(i, session.run(m.lr), (i<FLAGS.pretraining_epochs)))
        if i<FLAGS.pretraining_epochs:
          opt_d = tf.no_op()
          if FLAGS.pretraining_d:
            opt_d = m.opt_d
          train_g_loss,train_d_loss = run_epoch(session, m, loader, 'train', m.opt_pretraining, opt_d, pretraining = True, verbose=True, run_metadata=run_metadata, pretraining_d=FLAGS.pretraining_d)
          if FLAGS.pretraining_d:
            try:
              print("Epoch: {} Pretraining loss: G: {:.3f}, D: {:.3f}".format(i, train_g_loss, train_d_loss))
            except (TypeError, ValueError):
              # A loss that cannot take the float format (e.g. None) is
              # printed unformatted instead.
              print(train_g_loss)
              print(train_d_loss)
          else:
            print("Epoch: {} Pretraining loss: G: {:.3f}".format(i, train_g_loss))
        else:
          train_g_loss,train_d_loss = run_epoch(session, m, loader, 'train', m.opt_d, m.opt_g, verbose=True, run_metadata=run_metadata)
          try:
            print("Epoch: {} Train loss: G: {:.3f}, D: {:.3f}".format(i, train_g_loss, train_d_loss))
          except (TypeError, ValueError):
            print("Epoch: {} Train loss: G: {}, D: {}".format(i, train_g_loss, train_d_loss))
        valid_g_loss,valid_d_loss = run_epoch(session, m, loader, 'validation', tf.no_op(), tf.no_op())
        try:
          print("Epoch: {} Valid loss: G: {:.3f}, D: {:.3f}".format(i, valid_g_loss, valid_d_loss))
        except (TypeError, ValueError):
          print("Epoch: {} Valid loss: G: {}, D: {}".format(i, valid_g_loss, valid_d_loss))

        if train_d_loss == 0.0 and train_g_loss == 0.0:
          print('Both G and D train loss are zero. Exiting.')
          save = True
          do_exit = True
        if i % FLAGS.epochs_per_checkpoint == 0:
          save = True
        if FLAGS.exit_after > 0 and time.time() - train_start_time > FLAGS.exit_after*60:
          print("%s: Has been running for %d seconds. Will exit (exiting after %d minutes)."%(datetime.datetime.today().strftime('%Y-%m-%d %H:%M:%S'), (int)(time.time() - train_start_time), FLAGS.exit_after))
          save = True
          do_exit = True

        if save:
          saver.save(session, checkpoint_path, global_step=i)
          with open(global_step_filename, 'wb') as f:
            pkl.dump(i, f)
          if FLAGS.profiling:
            # Create the Timeline object, and write it to a json
            tl = timeline.Timeline(run_metadata.step_stats)
            ctf = tl.generate_chrome_trace_format()
            with open(os.path.join(plots_dir, 'timeline.json'), 'w') as f:
              f.write(ctf)
          print('{}: Saving done!'.format(i))

        if train_d_loss is None: #pretraining
          train_d_loss = 0.0
          valid_d_loss = 0.0
          valid_g_loss = 0.0
        gnuplot_input_filename = os.path.join(plots_dir, 'gnuplot-input.txt')
        if not os.path.exists(gnuplot_input_filename):
          with open(gnuplot_input_filename, 'w') as f:
            f.write('# global-step learning-rate train-g-loss train-d-loss valid-g-loss valid-d-loss\n')
        with open(gnuplot_input_filename, 'a') as f:
          try:
            f.write('{} {:.4f} {:.2f} {:.2f} {:.3} {:.3f}\n'.format(i, m.lr.eval(), train_g_loss, train_d_loss, valid_g_loss, valid_d_loss))
          except (TypeError, ValueError):
            f.write('{} {} {} {} {} {}\n'.format(i, m.lr.eval(), train_g_loss, train_d_loss, valid_g_loss, valid_d_loss))
        # The gnuplot command files are written once and reused afterwards.
        if not os.path.exists(os.path.join(plots_dir, 'gnuplot-commands-loss.txt')):
          with open(os.path.join(plots_dir, 'gnuplot-commands-loss.txt'), 'a') as f:
            f.write('set terminal postscript eps color butt "Times" 14\nset yrange [0:400]\nset output "loss.eps"\nplot \'gnuplot-input.txt\' using ($1):($3) title \'train G\' with linespoints, \'gnuplot-input.txt\' using ($1):($4) title \'train D\' with linespoints, \'gnuplot-input.txt\' using ($1):($5) title \'valid G\' with linespoints, \'gnuplot-input.txt\' using ($1):($6) title \'valid D\' with linespoints, \n')
        if not os.path.exists(os.path.join(plots_dir, 'gnuplot-commands-midistats.txt')):
          with open(os.path.join(plots_dir, 'gnuplot-commands-midistats.txt'), 'a') as f:
            f.write('set terminal postscript eps color butt "Times" 14\nset yrange [0:127]\nset xrange [0:70]\nset output "midistats.eps"\nplot \'midi_stats.gnuplot\' using ($1):(100*$3) title \'Scale consistency, %\' with linespoints, \'midi_stats.gnuplot\' using ($1):($6) title \'Tone span, halftones\' with linespoints, \'midi_stats.gnuplot\' using ($1):($10) title \'Unique tones\' with linespoints, \'midi_stats.gnuplot\' using ($1):($23) title \'Intensity span, units\' with linespoints, \'midi_stats.gnuplot\' using ($1):(100*$24) title \'Polyphony, %\' with linespoints, \'midi_stats.gnuplot\' using ($1):($12) title \'3-tone repetitions\' with linespoints\n')
        try:
          Popen(['gnuplot','gnuplot-commands-loss.txt'], cwd=plots_dir)
          Popen(['gnuplot','gnuplot-commands-midistats.txt'], cwd=plots_dir)
        except OSError:
          # Raised when the gnuplot executable (or plots_dir) is missing;
          # plotting is optional, so training continues.
          print('failed to run gnuplot. Please do so yourself: gnuplot gnuplot-commands.txt cwd={}'.format(plots_dir))

        song_data = sample(session, m, batch=True)
        print('formatting midi...')
        midi_time = time.time()
        midi_patterns = [loader.get_midi_pattern(d) for d in song_data]
        print('done. time: {}'.format(time.time()-midi_time))

        filename = os.path.join(generated_data_dir, 'out-{}-{}-{}.mid'.format(experiment_label, i, datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')))
        # Only the first pattern of the sampled batch is written to disk.
        loader.save_midi_pattern(filename, midi_patterns[0])

        print('getting stats...')
        stats_time = time.time()
        stats = [get_all_stats(p) for p in midi_patterns]
        print('done. time: {}'.format(time.time()-stats_time))
        # Patterns for which no statistics could be extracted yield None.
        stats = [stat for stat in stats if stat is not None]
        if len(stats):
          stats_keys_string = ['scale']
          stats_keys = ['scale_score', 'tone_min', 'tone_max', 'tone_span', 'freq_min', 'freq_max', 'freq_span', 'tones_unique', 'repetitions_2', 'repetitions_3', 'repetitions_4', 'repetitions_5', 'repetitions_6', 'repetitions_7', 'repetitions_8', 'repetitions_9', 'estimated_beat', 'estimated_beat_avg_ticks_off', 'intensity_min', 'intensity_max', 'intensity_span', 'polyphony_score', 'top_2_interval_difference', 'top_3_interval_difference', 'num_tones']
          statsfilename = os.path.join(plots_dir, 'midi_stats.gnuplot')
          if not os.path.exists(statsfilename):
            with open(statsfilename, 'a') as f:
              f.write('# Average numers over one minibatch of size {}.\n'.format(FLAGS.batch_size))
              f.write('# global-step {} {}\n'.format(' '.join([s.replace(' ', '_') for s in stats_keys_string]), ' '.join(stats_keys)))
          with open(statsfilename, 'a') as f:
            # One line per epoch: string-valued stats come from the first
            # pattern, numeric stats are averaged over the batch.
            f.write('{} {} {}\n'.format(i, ' '.join(['{}'.format(stats[0][key].replace(' ', '_')) for key in stats_keys_string]), ' '.join(['{:.3f}'.format(sum([s[key] for s in stats])/float(len(stats))) for key in stats_keys])))
          print('Saved {}.'.format(filename))

        if do_exit:
          if FLAGS.call_after is not None:
            print("%s: Will call \"%s\" before exiting."%(datetime.datetime.today().strftime('%Y-%m-%d %H:%M:%S'), FLAGS.call_after))
            res = call(FLAGS.call_after.split(" "))
            print('{}: call returned {}.'.format(datetime.datetime.today().strftime('%Y-%m-%d %H:%M:%S'), res))
          sys.exit()
        sys.stdout.flush()

      test_g_loss,test_d_loss = run_epoch(session, m, loader, 'test', tf.no_op(), tf.no_op())
      print("Test loss G: %.3f, D: %.3f" %(test_g_loss, test_d_loss))

    song_data = sample(session, m)
    filename = os.path.join(generated_data_dir, 'out-{}-{}-{}.mid'.format(experiment_label, i, datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')))
    loader.save_data(filename, song_data)
    print('Saved {}.'.format(filename))
Ejemplo n.º 5
0
def main(args):
    ''' Training sequence

    Builds the generator/discriminator pair with their optimizers and loss
    criteria, optionally restores either model from ``CKPT_DIR``, optionally
    pretrains the discriminator and then the generator, runs
    ``args.num_epochs`` adversarial epochs, and finally saves the model
    weights unless saving is disabled.

    :param args: parsed command-line options. Attributes read here: use_sgd,
        g_lrn_rate, d_lrn_rate, feature_matching, label_smoothing, load_g,
        load_d, no_pretraining, d_pretraining_epochs, g_pretraining_epochs,
        num_epochs, conditional_freezing, no_save_g, no_save_d
    :return: none
    '''
    dataloader = music_data_utils.MusicDataLoader(DATA_DIR,
                                                  single_composer=COMPOSER)
    num_feats = dataloader.get_num_song_features()

    # First checking if GPU is available
    train_on_gpu = torch.cuda.is_available()
    if train_on_gpu:
        print('Training on GPU.')
    else:
        print('No GPU available, training on CPU.')

    model = {
        'g': Generator(num_feats, use_cuda=train_on_gpu),
        'd': Discriminator(num_feats, use_cuda=train_on_gpu)
    }

    if args.use_sgd:
        optimizer = {
            'g':
            optim.SGD(model['g'].parameters(),
                      lr=args.g_lrn_rate,
                      momentum=0.9),
            'd':
            optim.SGD(model['d'].parameters(),
                      lr=args.d_lrn_rate,
                      momentum=0.9)
        }
    else:
        optimizer = {
            'g': optim.Adam(model['g'].parameters(), args.g_lrn_rate),
            'd': optim.Adam(model['d'].parameters(), args.d_lrn_rate)
        }

    criterion = {
        'g': nn.MSELoss(reduction='sum') if args.feature_matching else GLoss(),
        'd': DLoss(args.label_smoothing)
    }

    # On a CPU-only host the checkpoints' storages must be remapped to the
    # CPU, otherwise weights saved from a GPU run fail to load.
    map_location = None if train_on_gpu else 'cpu'

    if args.load_g:
        g_path = os.path.join(CKPT_DIR, G_FN)
        ckpt = torch.load(g_path, map_location=map_location)
        model['g'].load_state_dict(ckpt)
        print("Continue training of %s" % g_path)

    if args.load_d:
        d_path = os.path.join(CKPT_DIR, D_FN)
        ckpt = torch.load(d_path, map_location=map_location)
        model['d'].load_state_dict(ckpt)
        print("Continue training of %s" % d_path)

    if train_on_gpu:
        model['g'].cuda()
        model['d'].cuda()

    if not args.no_pretraining:
        # Pretrain the discriminator with the generator frozen, then the
        # generator with the discriminator frozen.
        for ep in range(args.d_pretraining_epochs):
            model, _ = run_epoch(model,
                                 optimizer,
                                 criterion,
                                 dataloader,
                                 ep,
                                 args.d_pretraining_epochs,
                                 freeze_g=True,
                                 pretraining=True)

        for ep in range(args.g_pretraining_epochs):
            model, _ = run_epoch(model,
                                 optimizer,
                                 criterion,
                                 dataloader,
                                 ep,
                                 args.g_pretraining_epochs,
                                 freeze_d=True,
                                 pretraining=True)

    freeze_d = False
    for ep in range(args.num_epochs):
        # if ep % args.freeze_d_every == 0:
        #     freeze_d = not freeze_d

        model, trn_acc = run_epoch(model,
                                   optimizer,
                                   criterion,
                                   dataloader,
                                   ep,
                                   args.num_epochs,
                                   freeze_d=freeze_d)
        if args.conditional_freezing:
            # conditional freezing: the discriminator is frozen for the next
            # epoch whenever the value returned by run_epoch reaches 95.0
            freeze_d = False
            if trn_acc >= 95.0:
                freeze_d = True

    if not args.no_save_g:
        torch.save(model['g'].state_dict(), os.path.join(CKPT_DIR, G_FN))
        print("Saved generator: %s" % os.path.join(CKPT_DIR, G_FN))

    if not args.no_save_d:
        torch.save(model['d'].state_dict(), os.path.join(CKPT_DIR, D_FN))
        print("Saved discriminator: %s" % os.path.join(CKPT_DIR, D_FN))
Ejemplo n.º 6
0
def _print_file_stats(filename):
    """
    helper prints out the statistics of a single midi file
    :param filename: (string) directory + name of the midi file
    :return: none
    """
    print('File: {}'.format(filename))
    dl = music_data_utils.MusicDataLoader(datadir=None,
                                          select_validation_percentage=0.0,
                                          select_test_percentage=0.0)
    song_data = dl.read_one_file(filename=filename, pace_events=True)
    midi_pattern = dl.get_midi_pattern(song_data)
    stats = get_all_stats(midi_pattern)
    if stats is None:
        print('Could not extract stats.')
        return

    print('ML scale estimate: {}: {:.2f}'.format(
        stats['scale'], stats['scale_score']))
    print('Min tone: {}, {:.1f} Hz.'.format(
        tone_to_tone_name(stats['tone_min']), stats['freq_min']))
    print('Max tone: {}, {:.1f} Hz.'.format(
        tone_to_tone_name(stats['tone_max']), stats['freq_max']))
    print('Span: {} tones, {:.1f} Hz.'.format(stats['tone_span'],
                                              stats['freq_span']))
    print('Overall number of tones: {}'.format(stats['num_tones']))
    print('Unique tones: {}'.format(stats['tones_unique']))
    for r in range(2, 10):  # xrange in Python2, range in Python3
        print('Repetitions of len {}: {}'.format(
            r, stats['repetitions_{}'.format(r)]))
    print('Estimated beat: {}. Avg ticks off: {:.2f}.'.format(
        stats['estimated_beat'],
        stats['estimated_beat_avg_ticks_off']))
    print('Intensity: min: {}, max: {}.'.format(
        stats['intensity_min'], stats['intensity_max']))
    print('Polyphonous events: {:.2f}.'.format(
        stats['polyphony_score']))
    print('Top intervals:')
    for interval, score in stats['top_10_intervals']:
        print('{}: {:.2f}.'.format(interval, score))
    print('Top 2 interval difference: {}.'.format(
        stats['top_2_interval_difference']))
    print('Top 3 interval difference: {}.'.format(
        stats['top_3_interval_difference']))


def print_stats(file_dir="", more_than_one=False, filename=""):
    """
    method prints out statistics of one midi file/several midi files in
    a directory
    :param file_dir: (string) directory where files are saved; only used
        when more_than_one is true
    :param more_than_one: (boolean) true if more than one file
    :param filename: (string) directory + name of single midi file; only
        used when more_than_one is false
    :return: none
    """
    if not more_than_one:
        _print_file_stats(filename)
        return

    # Every *.mid file directly inside file_dir, in the order glob yields them.
    for midi_file in glob.glob(file_dir + "/*.mid"):
        _print_file_stats(midi_file)
Ejemplo n.º 7
0
def save_stats_in_gnuplot_format(plots_dir, midi_dir):
    """
    method creates files in directory
        gnuplot-commands-midistats.txt: commands for settings in gnuplot program
        midi_stats.gnuplot: file with header information and formatted statistics
    The commands file and the header of the stats file are only written when
    the respective file does not exist yet; statistics lines are appended on
    every call that yields at least one set of statistics.
    :param plots_dir: (string) directory name where the plots are saved
    :param midi_dir: (string) directory name where the midi files are
    :return: none
    """
    commands_filename = os.path.join(plots_dir,
                                     'gnuplot-commands-midistats.txt')
    if not os.path.exists(commands_filename):
        with open(commands_filename, 'a') as f:
            f.write(
                'set terminal postscript eps color butt "Times" 14\nset yrange [0:400]\nset xrange [0:26]\nset output "midistats.eps"\nplot \'midi_stats.gnuplot\' using ($1):(100*$3) title \'Scale consistency, %\' with points ps 2, \'midi_stats.gnuplot\' using ($1):($6) title \'Tone span, halftones\' with points ps 2, \'midi_stats.gnuplot\' using ($1):($10) title \'Unique tones\' with points ps 2, \'midi_stats.gnuplot\' using ($1):($13) title \'4-tone-repetitions\' with points ps 2, \'midi_stats.gnuplot\' using ($1):($15) title \'6-tone repetitions\' with points ps 2, \'midi_stats.gnuplot\' using ($1):($17) title \'8-tone repetitions\' with points ps 2, \'midi_stats.gnuplot\' using ($1):($27) title \'Number of tones\' with points ps 2\n'
            )

    # NOTE(review): gnuplot is started here, before this call appends its
    # statistics below, and Popen does not wait for it - confirm that
    # plotting the previously stored data is intended.
    try:
        Popen(['gnuplot', 'gnuplot-commands-midistats.txt'], cwd=plots_dir)
    except OSError:
        # Raised when the gnuplot executable (or plots_dir) is missing;
        # plotting is optional, so the statistics are still collected.
        print(
            'failed to run gnuplot. Please do so yourself: gnuplot gnuplot-commands.txt cwd={}'
            .format(plots_dir))

    ### Getting stats ###
    print('getting stats...')
    stats_time = time.time()
    patterns = []
    for midi_file in glob.glob(midi_dir + "/*.mid"):
        print('File: {}'.format(midi_file))
        dl = music_data_utils.MusicDataLoader(datadir=None,
                                              select_validation_percentage=0.0,
                                              select_test_percentage=0.0)
        song_data = dl.read_one_file(filename=midi_file, pace_events=True)
        patterns.append(dl.get_midi_pattern(song_data))
    stats = [get_all_stats(p) for p in patterns]
    print('done. time: {}'.format(time.time() - stats_time))
    print(stats)
    # Files for which no statistics could be extracted yield None.
    stats = [stat for stat in stats if stat is not None]

    if len(stats):
        stats_keys_string = ['scale']
        stats_keys = [
            'scale_score', 'tone_min', 'tone_max', 'tone_span', 'freq_min',
            'freq_max', 'freq_span', 'tones_unique', 'repetitions_2',
            'repetitions_3', 'repetitions_4', 'repetitions_5', 'repetitions_6',
            'repetitions_7', 'repetitions_8', 'repetitions_9',
            'estimated_beat', 'estimated_beat_avg_ticks_off', 'intensity_min',
            'intensity_max', 'intensity_span', 'polyphony_score',
            'top_2_interval_difference', 'top_3_interval_difference',
            'num_tones'
        ]
        statsfilename = os.path.join(plots_dir, 'midi_stats.gnuplot')

        if not os.path.exists(statsfilename):
            with open(statsfilename, 'a') as f:
                f.write('# global-step {} {}\n'.format(
                    ' '.join([s.replace(' ', '_') for s in stats_keys_string]),
                    ' '.join(stats_keys)))

        # One line per file: running index, the string-valued stats with
        # spaces replaced by underscores, then every numeric stat with three
        # decimals.
        stat_lines = []
        for index, s in enumerate(stats):
            text_part = ''.join(
                s[key].replace(' ', '_') for key in stats_keys_string)
            number_part = ''.join(' %.3f' % s[key] for key in stats_keys)
            stat_lines.append('%i %s%s\n' % (index, text_part, number_part))
        all_stats_string = ''.join(stat_lines)

        with open(statsfilename, 'a') as f:
            # Get summary line and write it into file
            f.write(get_gnuplot_line(patterns, len(patterns),
                                     showheader=False))
            # Write all statistics as gnuplot_lines into file
            f.write(all_stats_string)
        print("Saved stats")
Ejemplo n.º 8
0
def main(args):
    ''' Training sequence

    Builds the generator/discriminator pair, optionally resumes from saved
    checkpoints, runs the optional pretraining phases followed by the
    adversarial training epochs, then saves the weights and (optionally) a
    plot of the per-epoch losses.

    Args:
        args: parsed command-line namespace. Attributes read here:
            data_dir, composers, redo_split, use_sgd, g_lrn_rate,
            d_lrn_rate, feature_matching, label_smoothing, load_g, load_d,
            no_pretraining, d_pretraining_epochs, g_pretraining_epochs,
            num_epochs, conditional_freezing, no_save_g, no_save_d,
            plot_loss.
    '''
    dataloader = music_data_utils.MusicDataLoader(
        args.data_dir, composers=args.composers, redo_split=args.redo_split)
    num_feats = dataloader.get_num_song_features()

    # First checking if GPU is available
    train_on_gpu = torch.cuda.is_available()
    if train_on_gpu:
        print('Training on GPU.')
    else:
        print('No GPU available, training on CPU.')

    model = {
        'g': Generator(num_feats, use_cuda=train_on_gpu),
        'd': Discriminator(num_feats, use_cuda=train_on_gpu)
    }

    # Checkpoint locations are used for both loading and saving.
    ckpt_paths = {
        'g': os.path.join(CKPT_DIR, G_FN),
        'd': os.path.join(CKPT_DIR, D_FN)
    }

    for key, resume in (('g', args.load_g), ('d', args.load_d)):
        if not resume:
            continue
        # Deserialize onto the CPU so a checkpoint written on a GPU machine
        # can still be resumed on a CPU-only host; the models are moved to
        # the GPU below when one is available.
        state = torch.load(ckpt_paths[key], map_location='cpu')
        model[key].load_state_dict(state)
        print("Continue training of %s" % ckpt_paths[key])

    # Make sure the checkpoint directory exists *before* training, so a
    # missing or unwritable directory is reported up front instead of after
    # every epoch has already run.
    will_save = not (args.no_save_g and args.no_save_d)
    if will_save and CKPT_DIR:
        os.makedirs(CKPT_DIR, exist_ok=True)

    # Move the models to the GPU before the optimizers capture their
    # parameters.
    if train_on_gpu:
        model['g'].cuda()
        model['d'].cuda()

    if args.use_sgd:
        optimizer = {
            'g': optim.SGD(model['g'].parameters(), lr=args.g_lrn_rate, momentum=0.9),
            'd': optim.SGD(model['d'].parameters(), lr=args.d_lrn_rate, momentum=0.9)
        }
    else:
        optimizer = {
            'g': optim.Adam(model['g'].parameters(), args.g_lrn_rate),
            'd': optim.Adam(model['d'].parameters(), args.d_lrn_rate)
        }

    criterion = {
        'g': nn.MSELoss(reduction='sum') if args.feature_matching else GLoss(),
        'd': DLoss(args.label_smoothing)
    }

    if not args.no_pretraining:
        # Discriminator pretraining: the generator is frozen.
        for ep in range(args.d_pretraining_epochs):
            model, *_ = run_epoch(model, optimizer, criterion, dataloader,
                                  ep, args.d_pretraining_epochs,
                                  freeze_g=True, pretraining=True)

        # Generator pretraining: the discriminator is frozen.
        for ep in range(args.g_pretraining_epochs):
            model, *_ = run_epoch(model, optimizer, criterion, dataloader,
                                  ep, args.g_pretraining_epochs,
                                  freeze_d=True, pretraining=True)

    freeze_d = False
    losses = []  # one [trn_g, trn_d, val_g, val_d] row per epoch

    for ep in range(args.num_epochs):
        model, trn_acc, trn_g_loss, trn_d_loss, val_g_loss, val_d_loss = run_epoch(
            model, optimizer, criterion, dataloader, ep, args.num_epochs,
            freeze_d=freeze_d)

        losses.append([trn_g_loss, trn_d_loss, val_g_loss, val_d_loss])

        if args.conditional_freezing:
            # Freeze the discriminator for the next epoch whenever its
            # training accuracy reached 95%; unfreeze it otherwise.
            freeze_d = bool(trn_acc >= 95.0)

    if not args.no_save_g:
        torch.save(model['g'].state_dict(), ckpt_paths['g'])
        print("Saved generator: %s" % ckpt_paths['g'])

    if not args.no_save_d:
        torch.save(model['d'].state_dict(), ckpt_paths['d'])
        print("Saved discriminator: %s" % ckpt_paths['d'])

    if args.plot_loss:
        curve_labels = ('G Training Loss', 'D Training Loss',
                        'G Validation Loss', 'D Validation Loss')
        fig, ax = plt.subplots()
        for column, label in enumerate(curve_labels):
            ax.plot([row[column] for row in losses], label=label)
        ax.legend()
        fig.savefig('loss_' + str(args.num_epochs) + '_' + time.strftime("%m%d%Y_%H%M%S") + '.png')
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
Ejemplo n.º 9
0
def main(_):
    if not FLAGS.datadir:
        raise ValueError("Must set --datadir to midi music dir.")
    if not FLAGS.traindir:
        raise ValueError(
            "Must set --traindir to dir where I can save model and plots.")

    restore_flags()

    summaries_dir = None
    plots_dir = None
    generated_data_dir = None
    summaries_dir = os.path.join(FLAGS.traindir, 'summaries')
    plots_dir = os.path.join(FLAGS.traindir, 'plots')
    generated_data_dir = os.path.join(FLAGS.traindir, 'generated_data')
    try:
        os.makedirs(FLAGS.traindir)
    except:
        pass
    try:
        os.makedirs(summaries_dir)
    except:
        pass
    try:
        os.makedirs(plots_dir)
    except:
        pass
    try:
        os.makedirs(generated_data_dir)
    except:
        pass
    directorynames = FLAGS.traindir.split('/')
    experiment_label = ''
    while not experiment_label:
        experiment_label = directorynames.pop()

    global_step = -1
    if os.path.exists(os.path.join(FLAGS.traindir, 'global_step.pkl')):
        with open(os.path.join(FLAGS.traindir, 'global_step.pkl'), 'r') as f:
            global_step = pkl.load(f)
    global_step += 1

    xfeatures_filename = os.path.join(FLAGS.traindir, 'num_x_features.pkl')
    zfeatures_filename = os.path.join(FLAGS.traindir, 'num_z_features.pkl')

    loader = music_data_utils.MusicDataLoader(FLAGS.datadir)
    num_x_features = loader.get_num_x_features()
    print('num_x_features:{}'.format(num_x_features))
    num_z_features = loader.get_num_z_features()
    print('num_z_features:{}'.format(num_z_features))
    num_meta_features = loader.get_num_meta_features()
    print('num_meta_features:{}'.format(num_meta_features))

    train_start_time = time.time()
    checkpoint_path = os.path.join(FLAGS.traindir, "model.ckpt")

    songlength_ceiling = FLAGS.songlength

    if global_step < FLAGS.pretraining_epochs:
        #FLAGS.songlength = int(min(((global_step+10)/10)*10,songlength_ceiling))
        FLAGS.songlength = int(min((global_step + 1) * 4, songlength_ceiling))

    with tf.Graph().as_default(), tf.Session(config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as session:
        with tf.variable_scope("model", reuse=None) as scope:
            scope.set_regularizer(
                tf.contrib.layers.l2_regularizer(scale=FLAGS.reg_scale))
            m = RNNGAN(is_training=True,
                       num_x_features=num_x_features,
                       num_z_features=num_z_features,
                       num_meta_features=num_meta_features)

        if FLAGS.initialize_d:
            vars_to_restore = {}
            for v in tf.trainable_variables():
                if v.name.startswith('model/G/'):
                    print(v.name[:-2])
                    vars_to_restore[v.name[:-2]] = v
            saver = tf.train.Saver(vars_to_restore)
            ckpt = tf.train.get_checkpoint_state(FLAGS.traindir)
            if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
                print("Reading model parameters from %s" %
                      ckpt.model_checkpoint_path)
                saver.restore(session, ckpt.model_checkpoint_path)
                session.run(
                    tf.initialize_variables([
                        v for v in tf.trainable_variables()
                        if v.name.startswith('model/D/')
                    ]))
            else:
                print("Created model with fresh parameters.")
                session.run(tf.initialize_all_variables())
            saver = tf.train.Saver(tf.all_variables())
        else:
            saver = tf.train.Saver(tf.all_variables())
            ckpt = tf.train.get_checkpoint_state(FLAGS.traindir)
            if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
                print("Reading model parameters from %s" %
                      ckpt.model_checkpoint_path)
                saver.restore(session, ckpt.model_checkpoint_path)
            else:
                print("Created model with fresh parameters.")
                session.run(tf.initialize_all_variables())

        summary_op = tf.merge_all_summaries()
        summary_writer = tf.train.SummaryWriter(FLAGS.traindir,
                                                graph_def=session.graph_def)

        run_metadata = None
        if FLAGS.profiling:
            run_metadata = tf.RunMetadata()
        if not FLAGS.sample:
            train_g_loss, train_d_loss = 1.0, 1.0
            for i in range(global_step, FLAGS.max_epoch):
                lr_decay = FLAGS.lr_decay**max(i - FLAGS.epochs_before_decay,
                                               0.0)

                if global_step < FLAGS.pretraining_epochs:
                    #new_songlength = int(min(((i+10)/10)*10,songlength_ceiling))
                    new_songlength = int(min((i + 1) * 4, songlength_ceiling))
                else:
                    new_songlength = songlength_ceiling
                if new_songlength != FLAGS.songlength:
                    print(
                        'Changing songlength, now training on {} events from songs.'
                        .format(new_songlength))
                    FLAGS.songlength = new_songlength
                    with tf.variable_scope("model", reuse=True) as scope:
                        scope.set_regularizer(
                            tf.contrib.layers.l2_regularizer(
                                scale=FLAGS.reg_scale))
                        m = RNNGAN(is_training=True,
                                   num_x_features=num_x_features,
                                   num_z_features=num_z_features,
                                   num_meta_features=num_meta_features)

                if not FLAGS.adam:
                    m.assign_lr(session, FLAGS.learning_rate * lr_decay)

                save = False
                do_exit = False

                print(
                    "Epoch: {} Learning rate: {:.3f}, pretraining: {}".format(
                        i, session.run(m.lr), (i < FLAGS.pretraining_epochs)))
                if i < FLAGS.pretraining_epochs:
                    opt_d = tf.no_op()
                    if FLAGS.pretraining_d:
                        opt_d = m.opt_d
                    train_g_loss, train_d_loss = run_epoch(
                        session,
                        m,
                        loader,
                        'train',
                        m.opt_pretraining,
                        opt_d,
                        pretraining=True,
                        verbose=True,
                        run_metadata=run_metadata,
                        pretraining_d=FLAGS.pretraining_d)
                    if FLAGS.pretraining_d:
                        try:
                            print(
                                "Epoch: {} Pretraining loss: G: {:.3f}, D: {:.3f}"
                                .format(i, train_g_loss, train_d_loss))
                        except:
                            print(train_g_loss)
                            print(train_d_loss)
                    else:
                        print("Epoch: {} Pretraining loss: G: {:.3f}".format(
                            i, train_g_loss))
                else:
                    train_g_loss, train_d_loss = run_epoch(
                        session,
                        m,
                        loader,
                        'train',
                        m.opt_d,
                        m.opt_g,
                        verbose=True,
                        run_metadata=run_metadata)
                    try:
                        print("Epoch: {} Train loss: G: {:.3f}, D: {:.3f}".
                              format(i, train_g_loss, train_d_loss))
                    except:
                        print("Epoch: {} Train loss: G: {}, D: {}".format(
                            i, train_g_loss, train_d_loss))
                valid_g_loss, valid_d_loss = run_epoch(session, m,
                                                       loader, 'validation',
                                                       tf.no_op(), tf.no_op())
                try:
                    print("Epoch: {} Valid loss: G: {:.3f}, D: {:.3f}".format(
                        i, valid_g_loss, valid_d_loss))
                except:
                    print("Epoch: {} Valid loss: G: {}, D: {}".format(
                        i, valid_g_loss, valid_d_loss))

                if train_d_loss == 0.0 and train_g_loss == 0.0:
                    print('Both G and D train loss are zero. Exiting.')
                    save = True
                    do_exit = True
                if i % FLAGS.epochs_per_checkpoint == 0:
                    save = True
                if FLAGS.exit_after > 0 and time.time(
                ) - train_start_time > FLAGS.exit_after * 60:
                    print(
                        "%s: Has been running for %d seconds. Will exit (exiting after %d minutes)."
                        % (datetime.datetime.today().strftime(
                            '%Y-%m-%d %H:%M:%S'),
                           (int)(time.time() - train_start_time),
                           FLAGS.exit_after))
                    save = True
                    do_exit = True

                if save:
                    saver.save(session, checkpoint_path, global_step=i)
                    with open(os.path.join(FLAGS.traindir, 'global_step.pkl'),
                              'w') as f:
                        pkl.dump(i, f)
                    if FLAGS.profiling:
                        # Create the Timeline object, and write it to a json
                        tl = timeline.Timeline(run_metadata.step_stats)
                        ctf = tl.generate_chrome_trace_format()
                        with open(os.path.join(plots_dir, 'timeline.json'),
                                  'w') as f:
                            f.write(ctf)
                    print('{}: Saving done!'.format(i))

                step_time, loss = 0.0, 0.0
                if train_d_loss is None:  #pretraining
                    train_d_loss = 0.0
                    valid_d_loss = 0.0
                    valid_g_loss = 0.0
                if not os.path.exists(
                        os.path.join(plots_dir, 'gnuplot-input.txt')):
                    with open(os.path.join(plots_dir, 'gnuplot-input.txt'),
                              'w') as f:
                        f.write(
                            '# global-step learning-rate train-g-loss train-d-loss valid-g-loss valid-d-loss\n'
                        )
                with open(os.path.join(plots_dir, 'gnuplot-input.txt'),
                          'a') as f:
                    try:
                        f.write(
                            '{} {:.4f} {:.2f} {:.2f} {:.3} {:.3f}\n'.format(
                                i, m.lr.eval(), train_g_loss, train_d_loss,
                                valid_g_loss, valid_d_loss))
                    except:
                        f.write('{} {} {} {} {} {}\n'.format(
                            i, m.lr.eval(), train_g_loss, train_d_loss,
                            valid_g_loss, valid_d_loss))
                if not os.path.exists(
                        os.path.join(plots_dir, 'gnuplot-commands-loss.txt')):
                    with open(
                            os.path.join(plots_dir,
                                         'gnuplot-commands-loss.txt'),
                            'a') as f:
                        f.write(
                            'set terminal postscript eps color butt "Times" 14\nset yrange [0:400]\nset output "loss.eps"\nplot \'gnuplot-input.txt\' using ($1):($3) title \'train G\' with linespoints, \'gnuplot-input.txt\' using ($1):($4) title \'train D\' with linespoints, \'gnuplot-input.txt\' using ($1):($5) title \'valid G\' with linespoints, \'gnuplot-input.txt\' using ($1):($6) title \'valid D\' with linespoints, \n'
                        )
                if not os.path.exists(
                        os.path.join(plots_dir,
                                     'gnuplot-commands-midistats.txt')):
                    with open(
                            os.path.join(plots_dir,
                                         'gnuplot-commands-midistats.txt'),
                            'a') as f:
                        f.write(
                            'set terminal postscript eps color butt "Times" 14\nset yrange [0:127]\nset xrange [0:70]\nset output "midistats.eps"\nplot \'midi_stats.gnuplot\' using ($1):(100*$3) title \'Scale consistency, %\' with linespoints, \'midi_stats.gnuplot\' using ($1):($6) title \'Tone span, halftones\' with linespoints, \'midi_stats.gnuplot\' using ($1):($10) title \'Unique tones\' with linespoints, \'midi_stats.gnuplot\' using ($1):($23) title \'Intensity span, units\' with linespoints, \'midi_stats.gnuplot\' using ($1):(100*$24) title \'Polyphony, %\' with linespoints, \'midi_stats.gnuplot\' using ($1):($12) title \'3-tone repetitions\' with linespoints\n'
                        )
                try:
                    Popen(['gnuplot', 'gnuplot-commands-loss.txt'],
                          cwd=plots_dir)
                    Popen(['gnuplot', 'gnuplot-commands-midistats.txt'],
                          cwd=plots_dir)
                except:
                    print(
                        'failed to run gnuplot. Please do so yourself: gnuplot gnuplot-commands.txt cwd={}'
                        .format(plots_dir))

                if do_exit:
                    if FLAGS.call_after is not None:
                        print("%s: Will call \"%s\" before exiting." %
                              (datetime.datetime.today().strftime(
                                  '%Y-%m-%d %H:%M:%S'), FLAGS.call_after))
                        res = call(FLAGS.call_after.split(" "))
                        print('{}: call returned {}.'.format(
                            datetime.datetime.today().strftime(
                                '%Y-%m-%d %H:%M:%S'), res))
                    exit()
                sys.stdout.flush()

            test_g_loss, test_d_loss = run_epoch(session, m, loader, 'test',
                                                 tf.no_op(), tf.no_op())
            print("Test loss G: %.3f, D: %.3f" % (test_g_loss, test_d_loss))