Example #1
    def _find_latest_checkpoint(self):
        """
        :return: tuple (stage, checkpoint)
                 if no checkpoint is found in any stage, then (first non-empty stage, 0) is returned
        """
        stage = self._find_latest_stage()
        if stage is not None:
            checkpoint_file_path = gspath.join(self.args.training_dir,
                                               f"stage-{stage}",
                                               f"stage-{stage}-ckpt-*.data*")
            checkpoint_files = gspath.findall(checkpoint_file_path)
            matches = (re.search(r'stage-\d+-ckpt-(\d+)\.data', checkpoint_file)
                       for checkpoint_file in checkpoint_files)
            checkpoint_no = [int(match.group(1)) for match in matches if match is not None]
        else:
            stage = 0
            while stage < self.model_factory.stages and \
                    self.model_factory.stage_total_songs[stage] == 0:
                stage += 1
            checkpoint_no = [0]

        return stage, max(checkpoint_no)
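
For reference, a minimal standalone sketch of the checkpoint-number extraction used above; the file names are hypothetical stand-ins for what gspath.findall would return:

import re

# hypothetical checkpoint files, shaped like gspath.findall results
checkpoint_files = [
    "gs://bucket/train/stage-3/stage-3-ckpt-12.data-00000-of-00001",
    "gs://bucket/train/stage-3/stage-3-ckpt-15.data-00000-of-00001",
    "gs://bucket/train/stage-3/stage-3-ckpt-15.index",  # skipped: not a .data file
]

matches = (re.search(r'stage-\d+-ckpt-(\d+)\.data', f) for f in checkpoint_files)
checkpoint_no = [int(m.group(1)) for m in matches if m is not None]
print(max(checkpoint_no))  # 15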
Example #2
  def __init__(self, strategy, model_factory: mp3net.MP3netFactory, stage, mode, summary_dir, args):
    """Model Driver owns the actual TranceModel object and takes care of running it on a distributed environment

    :param strategy:      distribution strategy to run the model on
    :param model_factory: factory that builds the model for the given stage
    :param stage:         training stage to build the model for
    :param mode:          running mode ('train', 'eval' or 'infer')
    :param summary_dir:   directory where summaries are written
    :param args:          user arguments
    """
    self.strategy = strategy
    self.args = args
    self.global_batch_size = self.args.batch_size

    with self.strategy.scope():
      # use bfloat16 on tpu
      if self.args.runtime_tpu:
        # note: on GPUs we could use 'mixed_float16', but then we would need
        # to wrap the optimizers in a LossScaleOptimizer (not needed for bfloat16)
        policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
        tf.keras.mixed_precision.experimental.set_policy(policy)

      print('Precision:')
      print(f'  Compute dtype:  {tf.keras.mixed_precision.experimental.global_policy().compute_dtype}')
      print(f'  Variable dtype: {tf.keras.mixed_precision.experimental.global_policy().variable_dtype}')
      print()
      print('Discriminator/Generator balance:')
      print(f'  {self.args.n_discr} discriminator updates for each generator update')
      print()

      # build keras model
      print(f"{datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')}: Building model...")
      self.model = model_factory.build_model(stage, self.global_batch_size, mode)

      # set up optimizers
      #   https://towardsdatascience.com/adam-latest-trends-in-deep-learning-optimization-6be9a291375c
      # SpecGAN lr=1e-4, beta_1=0.5, beta_2=0.9  <=== works best
      # ProGAN  lr=1e-3, beta_1=0.1, beta_2=0.99
      # BigGAN  lr_G=1e-4 lr_D=4e-4 (for batch<1024), beta_1=0.0, beta_2=0.999
      # https://cs231n.github.io/neural-networks-3/
      self.discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5, beta_2=0.9)
      self.generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5, beta_2=0.9)

      # set up checkpoint manager
      print(f"{datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')}: Setting up checkpoint manager...")
      self.checkpoint = tf.train.Checkpoint(generator_optimizer=self.generator_optimizer,
                                            discriminator_optimizer=self.discriminator_optimizer,
                                            generator=self.model.generator,
                                            discriminator=self.model.discriminator,
                                            song_counter=self.model.song_counter)  # what to save
      self.checkpoint_dir = gspath.join(args.training_dir, "stage-{}".format(self.model.stage))
      self.checkpoint_manager = tf.train.CheckpointManager(self.checkpoint, self.checkpoint_dir, max_to_keep=5,
                                                           checkpoint_name="stage-{}-ckpt".format(self.model.stage))
      self.checkpoint_freq = args.train_checkpoint_freq

      # set up summary
      self.summary_audio_repr = AudioRepresentation(self.model.sample_rate, self.model.freq_n, compute_dtype=tf.float32)
      self.summary_freq = args.summary_freq
      self.summary_dir = summary_dir
      self.summary_writer = None
      print(f"{datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')}: Summaries are written to {self.summary_dir}")
Example #3
    def progressive_training(self):
        stage_start = self._find_latest_stage()
        if stage_start is not None:
            print(
                "Restoring stage {} from checkpoint file".format(stage_start))
        else:
            print("Starting anew from stage 0")
            stage_start = 0

        model_weights = None
        for stage in range(stage_start, self.model_factory.stages):
            print()
            if self.model_factory.stage_total_songs[stage] == 0:
                print("======= Empty stage {0} =======".format(stage))
            else:
                blocks_n = self.model_factory.blocks_n[stage]
                freq_n = self.model_factory.freq_n[stage]
                print("======= Entering stage {0} =======".format(stage))
                print("  * resolution            = {0:d} x {1:d}".format(
                    blocks_n, freq_n))
                print("  * total number of songs = {0:,}".format(
                    self.model_factory.stage_total_songs[stage]))
                print()

                # 1. set up summary writer (a different writer per stage, since the tensors differ per stage)
                summary_dir = gspath.join(self.args.summary_dir,
                                          f"stage-{stage}")

                # 2. start model driver
                model_driver = DistributedModelDriver(self.strategy,
                                                      self.model_factory,
                                                      stage,
                                                      mode='train',
                                                      summary_dir=summary_dir,
                                                      args=self.args)
                if model_weights is None:
                    model_driver.load_latest_checkpoint()
                else:
                    model_driver.load_from_model_weights(
                        model_weights,
                        self.model_factory.map_layer_names_for_model_growth)

                model_driver.print_model_summary()

                # 3. run model
                data_files = self.find_data_files(stage, self.args.data_dir,
                                                  self.model_factory)
                model_weights = model_driver.training_loop(data_files)

                print("======= Exiting stage {} =======".format(stage))
Example #4
def get_training_dir(args, confirm=True):
    print("Determining training directory...")

    # figure out the training directory
    training_dir = None
    if args.training_sub_dir is not None:
        # user specified a tag to use
        training_dir = gspath.join(args.training_base_dir,
                                   args.training_sub_dir)

        if not gspath.is_dir(training_dir):
            raise ValueError(
                "Training directory {} does not exist... exiting".format(
                    training_dir))

        if confirm:
            answer = input(
                "Do you want to reuse the training directory {}? ".format(
                    training_dir))
            if answer.strip().lower().startswith("y"):
                print("Ok, re-using directory {}".format(training_dir))
            else:
                args.training_sub_dir = None

    if args.training_sub_dir is None:
        # creating new training directory
        args.training_sub_dir = "train_{0}".format(
            datetime.datetime.utcnow().strftime("%Y%m%d%H%M%S"))
        training_dir = gspath.join(args.training_base_dir,
                                   args.training_sub_dir)
        print("Training in new directory {}".format(training_dir))

        # Make training dir
        if not gspath.is_dir(training_dir):
            gspath.mkdir(training_dir)

    print(f"  Training directory is {training_dir}")
    return training_dir
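
A minimal usage sketch, assuming gspath mirrors os.path/os.makedirs semantics and substituting a hypothetical argparse.Namespace for the real CLI arguments:

import argparse

# hypothetical arguments; the real ones come from the training CLI
args = argparse.Namespace(training_base_dir="/tmp/mp3net",
                          training_sub_dir=None)

# with training_sub_dir=None, a fresh train_<timestamp> directory is created
training_dir = get_training_dir(args, confirm=False)
print(args.training_sub_dir)  # e.g. train_20210101120000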
Example #5
  def inference_loop(self):
    inference_dir = gspath.join(self.args.training_dir, "infer")
    if not gspath.is_dir(inference_dir):
      gspath.mkdir(inference_dir)

    with self.strategy.scope():
      _, l1, l2, l3, l4, _ = self.infer_step()

      timestamp = datetime.datetime.utcnow().strftime("%Y%m%d%H%M%S")

      imageio.imwrite(os.path.join(inference_dir, 'l{0}_64x16_deep_strip_mean_of_squares_intensity.png'.format(timestamp)),
                      self.model.audio_representation.repr_to_spectrogram(self._collate_probe_tensors(tf.sqrt(tf.reduce_mean(l1 ** 2, axis=-1, keepdims=True))),
                                                                          intensity=True)[0, :, :])
      imageio.imwrite(os.path.join(inference_dir, 'l{0}_64x16_deep_strip_ave_intensity.png'.format(timestamp)),
                      self.model.audio_representation.repr_to_spectrogram(self._collate_probe_tensors(tf.reduce_mean(l1, axis=-1, keepdims=True)),
                                                                          intensity=True)[0, :, :])
      imageio.imwrite(os.path.join(inference_dir, 'l{0}_64x16_deep_strip_intensity.png'.format(timestamp)),
                      self.model.audio_representation.repr_to_spectrogram(self._collate_probe_tensors(l1),
                                                                          intensity=True)[0, :, :])

      imageio.imwrite(os.path.join(inference_dir, 'l{0}_256x32_deep_strip_ave_intensity.png'.format(timestamp)),
                      self.model.audio_representation.repr_to_spectrogram(self._collate_probe_tensors(tf.reduce_mean(l2, axis=-1, keepdims=True)),
                                                                          intensity=True)[0, :, :])
      imageio.imwrite(os.path.join(inference_dir, 'l{0}_256x32_deep_strip_intensity.png'.format(timestamp)),
                      self.model.audio_representation.repr_to_spectrogram(self._collate_probe_tensors(l2),
                                                                          intensity=True)[0, :, :])
      imageio.imwrite(os.path.join(inference_dir, 'l{0}_256x32_blurred_output_intensity.png'.format(timestamp)),
                      self.model.audio_representation.repr_to_spectrogram(self.model.blur_layer(l4),
                                                                          intensity=True)[0, :, :])
      imageio.imwrite(os.path.join(inference_dir, 'l{0}_256x32_blurred_output.png'.format(timestamp)),
                      self.model.audio_representation.repr_to_spectrogram(self.model.blur_layer(l4),
                                                                          intensity=False)[0, :, :])

      imageio.imwrite(os.path.join(inference_dir, 'l{0}_1024x64_output_strip_intensity.png'.format(timestamp)),
                      self.model.audio_representation.repr_to_spectrogram(self._collate_probe_tensors(l3),
                                                                          intensity=True)[0, :, :])

      imageio.imwrite(os.path.join(inference_dir, 'l{0}_1024x64_output.png'.format(timestamp)),
                      self.model.audio_representation.repr_to_spectrogram(l4,
                                                                          intensity=False)[0, :, :])
      imageio.imwrite(os.path.join(inference_dir, 'l{0}_1024x64_output_intensity.png'.format(timestamp)),
                      self.model.audio_representation.repr_to_spectrogram(l4,
                                                                          intensity=True)[0, :, :])

      wave_data = self.model.audio_representation.repr_to_audio(l4)[0, :, :]

      audio_utils.save_audio(os.path.join(inference_dir, 'l{0}_audio.wav'.format(timestamp)),
                             wave_data, self.model.audio_representation.sample_rate)
Example #6
    def _find_latest_stage(self):
        """
        :return: number of the latest saved model stage, or None if no checkpoints were found in any stage
        """
        checkpoint_file_path = gspath.join(self.args.training_dir, "stage-*",
                                           "stage-*-ckpt-*.data*")
        checkpoint_files = gspath.findall(checkpoint_file_path)
        matches = (re.search(r'stage-(\d+)-ckpt-\d+\.data', checkpoint_file)
                   for checkpoint_file in checkpoint_files)
        stages = [int(match.group(1)) for match in matches if match is not None]

        return max(stages) if stages else None
Example #7
    @staticmethod
    def find_data_files(training_stage, data_dir, model_factory):
        freq_n = model_factory.freq_n[training_stage]
        channels_n = model_factory.channels_n[training_stage]

        # find appropriate pre-processed data files
        file_pattern = f'*_sr{model_factory.sample_rate}_Nx{freq_n}x{channels_n}.tfrecord'
        input_file_path = gspath.join(data_dir, file_pattern)
        input_files = gspath.findall(input_file_path)
        if len(input_files) == 0:
            raise ValueError(
                f'Did not find any preprocessed files matching {file_pattern} in directory {data_dir}'
            )
        print("Found {0} data files in {1}: ".format(len(input_files),
                                                     data_dir))
        for input_filepath in input_files:
            print("  {}".format(gspath.split(input_filepath)[1]))

        return input_files
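
For illustration, the tfrecord file pattern resolves as follows; the sample rate and stage dimensions are hypothetical values:

sample_rate = 22050  # hypothetical
freq_n = 64          # hypothetical
channels_n = 2       # hypothetical

file_pattern = f'*_sr{sample_rate}_Nx{freq_n}x{channels_n}.tfrecord'
print(file_pattern)  # *_sr22050_Nx64x2.tfrecord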
Example #8
    def evaluation_loop(self):
        current_stage, current_checkpoint_no = self._find_latest_checkpoint()
        print(
            f"{datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')}: Most recent is stage-{current_stage} and checkpoint-{current_checkpoint_no}..."
        )

        # loop over different stages
        while True:
            print(
                f"{datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')}: Entering stage-{current_stage}..."
            )

            print(
                f"{datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')}: Initializing model and data for stage-{current_stage}..."
            )

            # set up summary writer (different writer for different stage, since different tensors etc)
            summary_dir = gspath.join(self.args.summary_dir,
                                      f"stage-{current_stage}")

            model_driver = DistributedModelDriver(self.strategy,
                                                  self.model_factory,
                                                  current_stage,
                                                  mode='eval',
                                                  summary_dir=summary_dir,
                                                  args=self.args)
            data_files = self.find_data_files(current_stage,
                                              self.args.data_dir,
                                              self.model_factory)
            dataset_iter = model_driver.get_distributed_dataset(data_files)

            # step back one, so the most recent checkpoint gets (re)processed by the loop below
            current_checkpoint_no = max(0, current_checkpoint_no - 1)

            # eternal evaluation loop of writing summaries for stage = current_stage
            while True:
                new_stage, new_checkpoint_no = self._find_latest_checkpoint()

                if new_stage > current_stage:
                    print(
                        f"{datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')}: New stage found..."
                    )
                    current_stage = new_stage
                    current_checkpoint_no = new_checkpoint_no - 1  # minus 1 since it's still untreated!
                    break

                if (new_stage == current_stage
                        and new_checkpoint_no > current_checkpoint_no):
                    print(
                        f"{datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')}: New checkpoint found..."
                    )
                    # wait a bit more, to make sure all checkpoint files are saved to disk
                    time.sleep(5)

                    model_driver.load_latest_checkpoint()
                    model_driver.evaluation_loop(dataset_iter)
                    print(
                        f"{datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')}: Waiting for new checkpoint..."
                    )

                current_stage = new_stage
                current_checkpoint_no = new_checkpoint_no

                time.sleep(5)
Example #9
  def write_summary_gan(self, reals_batch_per_replica, step):
    # make sure a summary_writer is open; roll over to a new file every 5 * summary_freq steps
    if self.summary_writer is not None and tf.math.floormod(step, 5 * self.summary_freq) == 0 and step > 0:
      self.summary_writer.close()
      self.summary_writer = None

    if self.summary_writer is None:
      self.summary_writer = SummaryWriter(self.summary_dir)
      print(f'{datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")}: Opened new summary file... ')

    print(f'{datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")}: Step {step}, writing summary... ', end='')

    # compute 1 full batch on all replicas and track the metrics
    loss_tuple, metric_tuple, score_tuple, fakes_tuple = self.eval_step_gan(reals_batch_per_replica)

    # unpack
    g_loss, d_loss = loss_tuple
    wdistance, grad_penalty, drift = metric_tuple
    fake_scores, real_scores = score_tuple
    fakes_all, fakes_all_noised = fakes_tuple

    # make the reals here
    reals_with_stddev = self._merge_batch_over_replicas(reals_batch_per_replica)

    reals = reals_with_stddev[:, :, :, :, 0]
    masking_threshold = reals_with_stddev[:, :, :, :, 1]
    reals_noised = self.model.audio_representation.add_noise(reals, masking_threshold)

    # find best, worst and mid scores (best first, ie. higher scores first)
    fakes_sorted_by_index = tf.argsort(tf.squeeze(fake_scores), direction='DESCENDING')
    reals_sorted_by_index = tf.argsort(tf.squeeze(real_scores), direction='DESCENDING')
    n = fakes_sorted_by_index.shape[0]

    fakes = tf.gather(fakes_all, [fakes_sorted_by_index[0], fakes_sorted_by_index[n // 2], fakes_sorted_by_index[-1]], axis=0)
    reals = tf.gather(reals, [reals_sorted_by_index[0], reals_sorted_by_index[n // 2], reals_sorted_by_index[-1]], axis=0)
    reals_noise = tf.gather(reals_noised, [reals_sorted_by_index[0], reals_sorted_by_index[n // 2], reals_sorted_by_index[-1]], axis=0)

    np.set_printoptions(edgeitems=10, linewidth=500)
    print()
    print("Fakes[0, :, :, 0]:")
    print(fakes[0, :, :, 0])
    print(f"Minimum = {tf.reduce_min(tf.where(fakes == 0, 1., tf.abs(fakes)))}")

    fakes = tf.cast(fakes, dtype=tf.float32)
    reals_noise = tf.cast(reals_noise, dtype=tf.float32)

    # write scalars...
    self.summary_writer.add_scalar('10_W_distance', wdistance.numpy(), global_step=step)

    self.summary_writer.add_scalar('20_loss/D', d_loss.numpy(), global_step=step)
    self.summary_writer.add_scalar('20_loss/D_GP', grad_penalty.numpy(), global_step=step)
    self.summary_writer.add_scalar('20_loss/D_drift', drift.numpy(), global_step=step)
    self.summary_writer.add_scalar('20_loss/G', g_loss.numpy(), global_step=step)

    self.summary_writer.add_scalar("50_progression/Fade-in", tf.reshape(self.model.fade_in(), shape=[]).numpy(), global_step=step)
    self.summary_writer.add_scalar("50_progression/Drown", self.model.drown().numpy(), global_step=step)

    fake_tonality = self.summary_audio_repr.tonality(fakes)
    real_tonality = self.summary_audio_repr.tonality(reals_noise)
    self.summary_writer.add_scalars("70_tonality/fakes", {'0_ave': tf.reduce_mean(fake_tonality).numpy(),
                                              '1_best': tf.reduce_mean(fake_tonality[0:1, :, :, :]).numpy()}, global_step=step)
    self.summary_writer.add_scalar("70_tonality/reals", tf.reduce_mean(real_tonality).numpy(), global_step=step)

    # spectrogram
    for i in range(3):
      # fake spectrograms
      self.summary_writer.add_images(f'10_fake/{i}',
                         self.summary_audio_repr.repr_to_spectrogram(fakes)[i, :, :, :].numpy(), global_step=step, dataformats='HWC')
      self.summary_writer.add_images(f'11_fake_intensity/{i}',
                         self.summary_audio_repr.repr_to_spectrogram(fakes, intensity=True, cmap=cm.CMRmap)[i, :, :, :].numpy(), global_step=step, dataformats='HWC')

      # real spectrograms
      self.summary_writer.add_images(f'20_real_noise/{i}',
                         self.summary_audio_repr.repr_to_spectrogram(reals_noise)[i, :, :, :].numpy(), global_step=step, dataformats='HWC')
      self.summary_writer.add_images(f'21_real_noise_intensity/{i}',
                         self.summary_audio_repr.repr_to_spectrogram(reals_noise, intensity=True, cmap=cm.CMRmap)[i, :, :, :].numpy(), global_step=step, dataformats='HWC')

    # audio (only if model is in final stage)
    if self.model.freq_n == self.model.audio_representation.freq_n:
      infer_dir = self.args.infer_dir
      if infer_dir is not None and not gspath.is_dir(infer_dir):
        gspath.mkdir(infer_dir)

      wav_fake = self.summary_audio_repr.repr_to_audio(fakes)
      wav_real_noise = self.summary_audio_repr.repr_to_audio(reals_noise)
      for i in range(3):
        self.summary_writer.add_audio(f'1_fake/{i}', wav_fake[i, :, :].numpy(), global_step=step, sample_rate=self.model.sample_rate)
        self.summary_writer.add_audio(f'2_real_noise/{i}', wav_real_noise[i, :, :].numpy(), global_step=step, sample_rate=self.model.sample_rate)

        if infer_dir is not None:
          audio_utils.save_audio(gspath.join(infer_dir, f'fake_sample{i}.wav'), wav_fake[i, :, :].numpy(), sample_rate=self.model.sample_rate, out_format='wav')
          audio_utils.save_audio(gspath.join(infer_dir, f'real_noise_sample{i}.wav'), wav_real_noise[i, :, :].numpy(), sample_rate=self.model.sample_rate, out_format='wav')

    # histograms
    self.summary_writer.add_histogram('1_fake', fake_scores.numpy(), global_step=step)
    self.summary_writer.add_histogram('2_real', real_scores.numpy(), global_step=step)

    self.summary_writer.add_scalars("11_score", {'0_fake_min': tf.reduce_min(fake_scores).numpy(),
                                      '1_fake_max': tf.reduce_max(fake_scores).numpy(),
                                      '2_real_min': tf.reduce_min(real_scores).numpy(),
                                      '3_real_max': tf.reduce_max(real_scores).numpy()}, global_step=step)

    print('done')

    return step
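
The writer calls above (add_scalar, add_scalars, add_images, add_audio, add_histogram) match the tensorboardX SummaryWriter API rather than tf.summary; a minimal sketch of the same pattern, assuming tensorboardX is the writer in use:

from tensorboardX import SummaryWriter

writer = SummaryWriter('/tmp/summaries')  # hypothetical log directory
writer.add_scalar('20_loss/D', 0.5, global_step=0)
writer.add_scalars('11_score', {'0_fake_min': -1.0, '1_fake_max': 1.0}, global_step=0)
writer.close()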
Example #10
def execute(args, strategy=None):
    # set up device strategy
    if args.runtime_tpu:
        # rely on the externally defined tpu_strategy
        if strategy is None:
            raise Exception("Strategy should be defined for TPU run")
        print("Running on TPU")
    elif args.mode == 'eval':
        # # make only CPU visible for evaluation loop (otherwise it tries to grab the cuda already in use)
        # physical_devices = tf.config.list_physical_devices('CPU')
        # tf.config.set_visible_devices(physical_devices)
        # device = "/cpu:0"

        device = "/gpu:0"
        strategy = tf.distribute.OneDeviceStrategy(device)
        print("Running on {}".format(device))
    else:
        device = "/gpu:0"
        strategy = tf.distribute.OneDeviceStrategy(device)
        print("Running on {}".format(device))

    model_trainer = ProgressiveTrainer(strategy, args)

    # start tensorboard
    tensorboard_process = None
    try:
        if args.runtime_launch_tensorboard:
            print("Launching tensorboard...")
            exec_str = "tensorboard --logdir={0}".format(args.training_dir)
            print('  ' + exec_str)
            tensorboard_process = subprocess.Popen(exec_str)
            print(f'  PID = {tensorboard_process.pid}')

        # decode running mode
        if args.mode == 'train':
            # Save args
            filepath = gspath.join(
                args.training_dir, 'args_{0}.txt'.format(
                    datetime.datetime.utcnow().strftime("%Y%m%d%H%M%S")))
            data_str = '\n'.join([
                str(k) + ',' + str(v)
                for k, v in sorted(vars(args).items(), key=lambda x: x[0])
            ])
            print('Writing arguments to {}'.format(filepath))
            gspath.write(filepath, data_str)

            model_trainer.progressive_training()

        elif args.mode == 'eval':
            model_trainer.evaluation_loop()

        elif args.mode == 'infer':
            model_trainer.inference_loop()

        else:
            raise NotImplementedError()

    finally:
        if tensorboard_process is not None:
            # kill tensorboard (taskkill is Windows-specific; see the cross-platform sketch below)
            subprocess.call(
                ['taskkill', '/F', '/T', '/PID',
                 str(tensorboard_process.pid)])
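
The taskkill call above assumes a Windows host; a cross-platform shutdown, sketched with only the standard library, could look like this:

import subprocess
import sys

def stop_tensorboard(process: subprocess.Popen):
    # taskkill ends the whole process tree on Windows; elsewhere terminate() suffices
    if sys.platform == 'win32':
        subprocess.call(['taskkill', '/F', '/T', '/PID', str(process.pid)])
    else:
        process.terminate()
        process.wait(timeout=10)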