def _find_latest_checkpoint(self):
    """
    :return: tuple (stage, checkpoint)
             if no checkpoint in any stage is found, then (model_start_stage, 0) is returned
    """
    stage = self._find_latest_stage()
    if stage is not None:
        checkpoint_file_path = gspath.join(self.args.training_dir, f"stage-{stage}",
                                           f"stage-{stage}-ckpt-*.data*")
        checkpoint_files = gspath.findall(checkpoint_file_path)
        checkpoint_no = [
            int(re.search(r'stage-\d+-ckpt-(\d+)\.data', checkpoint_file).group(1))
            for checkpoint_file in checkpoint_files
            if re.search(r'stage-\d+-ckpt-(\d+)\.data', checkpoint_file) is not None
        ]
    else:
        # no checkpoints anywhere: start at the first non-empty stage
        # (bounds check first, so we never index past the last stage)
        stage = 0
        while stage < self.model_factory.stages and self.model_factory.stage_total_songs[stage] == 0:
            stage += 1
        checkpoint_no = [0]
    return stage, max(checkpoint_no)
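# A minimal, self-contained sketch of the filename parsing above. The file
# names below are hypothetical; the real ones come from gspath.findall().
def _demo_parse_checkpoint_numbers():
    import re
    files = [
        "training/stage-3/stage-3-ckpt-7.data-00000-of-00002",
        "training/stage-3/stage-3-ckpt-12.data-00000-of-00002",
        "training/stage-3/checkpoint",  # no match, gets filtered out
    ]
    pattern = re.compile(r'stage-\d+-ckpt-(\d+)\.data')
    numbers = []
    for f in files:
        m = pattern.search(f)
        if m is not None:
            numbers.append(int(m.group(1)))
    assert max(numbers) == 12  # the latest checkpoint number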
def __init__(self, strategy, model_factory: mp3net.MP3netFactory, stage, mode, summary_dir, args):
    """Model Driver owns the actual TranceModel object and takes care of running it
    on a distributed environment

    :param strategy:   distribution strategy to run the model on
    :param args:       user arguments
    """
    self.strategy = strategy
    self.args = args
    self.global_batch_size = self.args.batch_size

    with self.strategy.scope():
        # use bfloat16 on tpu
        if self.args.runtime_tpu:
            # note: on GPUs we could use 'mixed_float16' but then we would need
            # to wrap the optimizers in LossScaleOptimizer! (not needed for bfloat16)
            policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
            tf.keras.mixed_precision.experimental.set_policy(policy)

        print('Precision:')
        print(f'  Compute dtype:  {tf.keras.mixed_precision.experimental.global_policy().compute_dtype}')
        print(f'  Variable dtype: {tf.keras.mixed_precision.experimental.global_policy().variable_dtype}')
        print()
        print('Discriminator/Generator balance:')
        print(f'  {self.args.n_discr} discriminator updates for each generator update')
        print()

        # build keras model
        print(f"{datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')}: Building model...")
        self.model = model_factory.build_model(stage, self.global_batch_size, mode)

        # set up optimizers
        # https://towardsdatascience.com/adam-latest-trends-in-deep-learning-optimization-6be9a291375c
        #   SpecGAN lr=1e-4, beta_1=0.5, beta_2=0.9   <=== works best
        #   ProGAN  lr=1e-3, beta_1=0.1, beta_2=0.99
        #   BigGAN  lr_G=1e-4 lr_D=4e-4 (for batch<1024), beta_1=0.0, beta_2=0.999
        # https://cs231n.github.io/neural-networks-3/
        self.discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5, beta_2=0.9)
        self.generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5, beta_2=0.9)

        # set up checkpoint manager
        print(f"{datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')}: Setting up checkpoint manager...")
        self.checkpoint = tf.train.Checkpoint(generator_optimizer=self.generator_optimizer,
                                              discriminator_optimizer=self.discriminator_optimizer,
                                              generator=self.model.generator,
                                              discriminator=self.model.discriminator,
                                              song_counter=self.model.song_counter)  # what to save
        self.checkpoint_dir = gspath.join(args.training_dir, "stage-{}".format(self.model.stage))
        self.checkpoint_manager = tf.train.CheckpointManager(self.checkpoint, self.checkpoint_dir,
                                                             max_to_keep=5,
                                                             checkpoint_name="stage-{}-ckpt".format(self.model.stage))
        self.checkpoint_freq = args.train_checkpoint_freq

        # set up summary
        self.summary_audio_repr = AudioRepresentation(self.model.sample_rate, self.model.freq_n,
                                                      compute_dtype=tf.float32)
        self.summary_freq = args.summary_freq
        self.summary_dir = summary_dir
        self.summary_writer = None
        print(f"{datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')}: Summaries are written to {self.summary_dir}")
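# A minimal sketch of the tf.train.Checkpoint / CheckpointManager pattern used
# in __init__ above, reduced to a single variable; the directory is illustrative.
def _demo_checkpoint_roundtrip(ckpt_dir="/tmp/demo-ckpt"):
    import tensorflow as tf
    step = tf.Variable(0, dtype=tf.int64)
    checkpoint = tf.train.Checkpoint(step=step)    # what to save
    manager = tf.train.CheckpointManager(checkpoint, ckpt_dir,
                                         max_to_keep=5,
                                         checkpoint_name="stage-0-ckpt")
    step.assign(42)
    manager.save()                                 # writes stage-0-ckpt-1.*
    step.assign(0)
    checkpoint.restore(manager.latest_checkpoint)  # reads the value back
    assert int(step.numpy()) == 42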
def progressive_training(self):
    stage_start = self._find_latest_stage()
    if stage_start is not None:
        print("Restoring stage {} from checkpoint file".format(stage_start))
    else:
        print("Starting anew from stage 0")
        stage_start = 0

    model_weights = None
    for stage in range(stage_start, self.model_factory.stages):
        print()
        if self.model_factory.stage_total_songs[stage] == 0:
            print("======= Empty stage {0} =======".format(stage))
        else:
            blocks_n = self.model_factory.blocks_n[stage]
            freq_n = self.model_factory.freq_n[stage]
            print("======= Entering stage {0} =======".format(stage))
            print(f'Stage {stage}:')
            print("  * resolution            = {0:d} x {1:d}".format(blocks_n, freq_n))
            print("  * total number of songs = {0:,}".format(self.model_factory.stage_total_songs[stage]))
            print()

            # set up summary writer (different writer for different stage, since different tensors etc)
            summary_dir = gspath.join(self.args.summary_dir, f"stage-{stage}")

            # 2. start model driver
            model_driver = DistributedModelDriver(self.strategy, self.model_factory, stage,
                                                  mode='train', summary_dir=summary_dir, args=self.args)
            if model_weights is None:
                model_driver.load_latest_checkpoint()
            else:
                model_driver.load_from_model_weights(model_weights,
                                                     self.model_factory.map_layer_names_for_model_growth)
            model_driver.print_model_summary()

            # 3. run model
            data_files = self.find_data_files(stage, self.args.data_dir, self.model_factory)
            model_weights = model_driver.training_loop(data_files)

            print("======= Exiting stage {} =======".format(stage))
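# Hypothetical sketch of how weights could be carried from one stage into the
# next, grown model by matching layer names; the project's actual logic lives
# in load_from_model_weights and map_layer_names_for_model_growth.
def _demo_copy_weights_by_name(src_model, dst_model, name_map=None):
    name_map = name_map or {}  # dst layer name -> src layer name
    src_layers = {layer.name: layer for layer in src_model.layers}
    for layer in dst_model.layers:
        src_layer = src_layers.get(name_map.get(layer.name, layer.name))
        # copy only when the source layer exists and all weight shapes line up
        if src_layer is not None and \
                [w.shape for w in src_layer.get_weights()] == [w.shape for w in layer.get_weights()]:
            layer.set_weights(src_layer.get_weights())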
def get_training_dir(args, confirm=True):
    print("Determining training directory...")

    # figure out the training directory
    training_dir = None
    if args.training_sub_dir is not None:
        # user specified a tag to use
        training_dir = gspath.join(args.training_base_dir, args.training_sub_dir)
        if not gspath.is_dir(training_dir):
            raise ValueError("Training directory {} does not exist... exiting".format(training_dir))
        if confirm:
            answer = input("Do you want to reuse the training directory {}? ".format(training_dir))
            # startswith() avoids an IndexError on empty input
            if answer.strip().lower().startswith("y"):
                print("Ok, re-using directory {}".format(training_dir))
            else:
                args.training_sub_dir = None

    if args.training_sub_dir is None:
        # creating new training directory
        args.training_sub_dir = "train_{0}".format(datetime.datetime.utcnow().strftime("%Y%m%d%H%M%S"))
        training_dir = gspath.join(args.training_base_dir, args.training_sub_dir)
        print("Training in new directory {}".format(training_dir))

    # make training dir
    if not gspath.is_dir(training_dir):
        gspath.mkdir(training_dir)

    print(f"  Training directory is {training_dir}")
    return training_dir
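# Tiny illustration of the timestamped naming convention used for new training
# directories (the returned value is time-dependent, e.g. "train_20240131235959").
def _demo_training_sub_dir_name():
    import datetime
    name = "train_{0}".format(datetime.datetime.utcnow().strftime("%Y%m%d%H%M%S"))
    assert name.startswith("train_") and len(name) == len("train_") + 14
    return name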
def inference_loop(self):
    inference_dir = gspath.join(self.args.training_dir, "infer")
    if not gspath.is_dir(inference_dir):
        gspath.mkdir(inference_dir)

    with self.strategy.scope():
        _, l1, l2, l3, l4, _ = self.infer_step()

        timestamp = datetime.datetime.utcnow().strftime("%Y%m%d%H%M%S")

        # spectrograms of intermediate probe tensors (l1..l3) and final output (l4)
        spectrograms = [
            # (filename suffix, representation tensor, intensity flag)
            ('64x16_deep_strip_mean_of_squares_intensity',
             self._collate_probe_tensors(tf.sqrt(tf.reduce_mean(l1 ** 2, axis=-1, keepdims=True))), True),
            ('64x16_deep_strip_ave_intensity',
             self._collate_probe_tensors(tf.reduce_mean(l1, axis=-1, keepdims=True)), True),
            ('64x16_deep_strip_intensity', self._collate_probe_tensors(l1), True),
            ('256x32_deep_strip_ave_intensity',
             self._collate_probe_tensors(tf.reduce_mean(l2, axis=-1, keepdims=True)), True),
            ('256x32_deep_strip_intensity', self._collate_probe_tensors(l2), True),
            ('256x32_blurred_output_intensity', self.model.blur_layer(l4), True),
            ('256x32_blurred_output', self.model.blur_layer(l4), False),
            ('1024x64_output_strip_intensity', self._collate_probe_tensors(l3), True),
            ('1024x64_output', l4, False),
            ('1024x64_output_intensity', l4, True),
        ]
        for suffix, repr_tensor, intensity in spectrograms:
            imageio.imwrite(os.path.join(inference_dir, 'l{0}_{1}.png'.format(timestamp, suffix)),
                            self.model.audio_representation.repr_to_spectrogram(repr_tensor, intensity=intensity)[0, :, :])

        # audio of the final output
        wave_data = self.model.audio_representation.repr_to_audio(l4)[0, :, :]
        audio_utils.save_audio(os.path.join(inference_dir, 'l{0}_audio.wav'.format(timestamp)),
                               wave_data, self.model.audio_representation.sample_rate)
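# Minimal sketch of dumping a float spectrogram-like array to a PNG, as
# inference_loop does above; the array here is random stand-in data, and we
# rescale to uint8 explicitly rather than relying on implicit conversion.
def _demo_write_spectrogram_png(path="/tmp/demo_spectrogram.png"):
    import numpy as np
    import imageio
    spectrogram = np.random.rand(64, 16).astype(np.float32)  # fake data
    lo, hi = spectrogram.min(), spectrogram.max()
    img = ((spectrogram - lo) / max(hi - lo, 1e-12) * 255.0).astype(np.uint8)
    imageio.imwrite(path, img)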
def _find_latest_stage(self):
    """
    :return: number of the latest saved model stage, or None if no checkpoints
             in any stage were found
    """
    checkpoint_file_path = gspath.join(self.args.training_dir, "stage-*", "stage-*-ckpt-*.data*")
    checkpoint_files = gspath.findall(checkpoint_file_path)
    stages = [
        int(re.search(r'stage-(\d+)-ckpt-\d+\.data', checkpoint_file).group(1))
        for checkpoint_file in checkpoint_files
        if re.search(r'stage-(\d+)-ckpt-\d+\.data', checkpoint_file) is not None
    ]
    return max(stages) if stages else None
@staticmethod
def find_data_files(training_stage, data_dir, model_factory):
    freq_n = model_factory.freq_n[training_stage]
    channels_n = model_factory.channels_n[training_stage]

    # find appropriate pre-processed data files
    file_pattern = f'*_sr{model_factory.sample_rate}_Nx{freq_n}x{channels_n}.tfrecord'
    input_file_path = gspath.join(data_dir, file_pattern)
    input_files = gspath.findall(input_file_path)
    if len(input_files) == 0:
        raise ValueError(f'Did not find any preprocessed file {file_pattern} in directory {data_dir}')

    print("Found {0} data files in {1}:".format(len(input_files), data_dir))
    for input_filepath in input_files:
        print("  {}".format(gspath.split(input_filepath)[1]))
    return input_files
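# Hypothetical illustration of the pre-processed file naming convention that
# the glob above relies on; sample rate, frequency bins and channel count are
# made-up values.
def _demo_data_file_pattern():
    import fnmatch
    file_pattern = '*_sr22050_Nx64x2.tfrecord'
    candidates = ['mysong_sr22050_Nx64x2.tfrecord',    # matches
                  'mysong_sr22050_Nx128x2.tfrecord']   # wrong resolution
    assert fnmatch.filter(candidates, file_pattern) == ['mysong_sr22050_Nx64x2.tfrecord']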
def evaluation_loop(self):
    current_stage, current_checkpoint_no = self._find_latest_checkpoint()
    print(f"{datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')}: "
          f"Most recent is stage-{current_stage} and checkpoint-{current_checkpoint_no}...")

    # loop over different stages
    while True:
        print(f"{datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')}: Entering stage-{current_stage}...")
        print(f"{datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')}: "
              f"Initializing model and data for stage-{current_stage}...")

        # set up summary writer (different writer for different stage, since different tensors etc)
        summary_dir = gspath.join(self.args.summary_dir, f"stage-{current_stage}")

        model_driver = DistributedModelDriver(self.strategy, self.model_factory, current_stage,
                                              mode='eval', summary_dir=summary_dir, args=self.args)
        data_files = self.find_data_files(current_stage, self.args.data_dir, self.model_factory)
        dataset_iter = model_driver.get_distributed_dataset(data_files)

        # step back one checkpoint, so the latest one gets evaluated below
        current_checkpoint_no = max(0, current_checkpoint_no - 1)

        # eternal evaluation loop of writing summaries for stage = current_stage
        while True:
            new_stage, new_checkpoint_no = self._find_latest_checkpoint()
            if new_stage > current_stage:
                print(f"{datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')}: New stage found...")
                current_stage = new_stage
                current_checkpoint_no = new_checkpoint_no - 1  # minus 1 since it's still untreated!
                break
            if new_stage == current_stage and new_checkpoint_no > current_checkpoint_no:
                print(f"{datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')}: New checkpoint found...")
                # wait a bit more, to make sure all checkpoint files are saved to disk
                model_driver.load_latest_checkpoint()
                model_driver.evaluation_loop(dataset_iter)
                print(f"{datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')}: Waiting for new checkpoint...")
                current_stage = new_stage
                current_checkpoint_no = new_checkpoint_no
            time.sleep(5)
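# Stripped-down sketch of the polling pattern in evaluation_loop above: wait
# for a newer checkpoint number, process it, repeat. poll_fn and process_fn
# are placeholders standing in for _find_latest_checkpoint and the evaluation
# pass; like the loop above, this runs forever.
def _demo_poll_for_new_checkpoints(poll_fn, process_fn, start_no=0, poll_s=5):
    import time
    current_no = start_no
    while True:
        new_no = poll_fn()
        if new_no > current_no:
            process_fn(new_no)
            current_no = new_no
        time.sleep(poll_s)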
def write_summary_gan(self, reals_batch_per_replica, step):
    # make sure a summary_writer is open; recycle it every 5*summary_freq steps
    if self.summary_writer is not None and tf.math.floormod(step, 5 * self.summary_freq) == 0 and step > 0:
        self.summary_writer.close()
        self.summary_writer = None
    if self.summary_writer is None:
        self.summary_writer = SummaryWriter(self.summary_dir)
        print(f'{datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")}: Opened new summary file... ')

    print(f'{datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")}: Step {step}, writing summary... ', end='')

    # compute 1 full batch on all replicas and track the metrics
    loss_tuple, metric_tuple, score_tuple, fakes_tuple = self.eval_step_gan(reals_batch_per_replica)

    # unpack
    g_loss, d_loss = loss_tuple
    wdistance, grad_penalty, drift = metric_tuple
    fake_scores, real_scores = score_tuple
    fakes_all, fakes_all_noised = fakes_tuple

    # make the reals here
    reals_with_stddev = self._merge_batch_over_replicas(reals_batch_per_replica)
    reals = reals_with_stddev[:, :, :, :, 0]
    masking_threshold = reals_with_stddev[:, :, :, :, 1]
    reals_noised = self.model.audio_representation.add_noise(reals, masking_threshold)

    # find best, worst and mid scores (best first, i.e. higher scores first)
    fakes_sorted_by_index = tf.argsort(tf.squeeze(fake_scores), direction='DESCENDING')
    reals_sorted_by_index = tf.argsort(tf.squeeze(real_scores), direction='DESCENDING')
    n = fakes_sorted_by_index.shape[0]
    fakes = tf.gather(fakes_all,
                      [fakes_sorted_by_index[0], fakes_sorted_by_index[n // 2], fakes_sorted_by_index[-1]], axis=0)
    reals = tf.gather(reals,
                      [reals_sorted_by_index[0], reals_sorted_by_index[n // 2], reals_sorted_by_index[-1]], axis=0)
    reals_noise = tf.gather(reals_noised,
                            [reals_sorted_by_index[0], reals_sorted_by_index[n // 2], reals_sorted_by_index[-1]], axis=0)

    # debug output
    np.set_printoptions(edgeitems=10, linewidth=500)
    print()
    print("Fakes[0, :, :, 0]:")
    print(fakes[0, :, :, 0])
    print(f"Minimum = {tf.reduce_min(tf.where(fakes == 0, 1., tf.abs(fakes)))}")

    fakes = tf.cast(fakes, dtype=tf.float32)
    reals_noise = tf.cast(reals_noise, dtype=tf.float32)

    # write scalars...
    self.summary_writer.add_scalar('10_W_distance', wdistance.numpy(), global_step=step)
    self.summary_writer.add_scalar('20_loss/D', d_loss.numpy(), global_step=step)
    self.summary_writer.add_scalar('20_loss/D_GP', grad_penalty.numpy(), global_step=step)
    self.summary_writer.add_scalar('20_loss/D_drift', drift.numpy(), global_step=step)
    self.summary_writer.add_scalar('20_loss/G', g_loss.numpy(), global_step=step)
    self.summary_writer.add_scalar("50_progression/Fade-in",
                                   tf.reshape(self.model.fade_in(), shape=[]).numpy(), global_step=step)
    self.summary_writer.add_scalar("50_progression/Drown", self.model.drown().numpy(), global_step=step)

    fake_tonality = self.summary_audio_repr.tonality(fakes)
    real_tonality = self.summary_audio_repr.tonality(reals_noise)
    self.summary_writer.add_scalars("70_tonality/fakes",
                                    {'0_ave': tf.reduce_mean(fake_tonality).numpy(),
                                     '1_best': tf.reduce_mean(fake_tonality[0:1, :, :, :]).numpy()},
                                    global_step=step)
    self.summary_writer.add_scalar("70_tonality/reals", tf.reduce_mean(real_tonality).numpy(), global_step=step)

    # spectrograms
    for i in range(3):
        # fake spectrograms
        self.summary_writer.add_images(f'10_fake/{i}',
                                       self.summary_audio_repr.repr_to_spectrogram(fakes)[i, :, :, :].numpy(),
                                       global_step=step, dataformats='HWC')
        self.summary_writer.add_images(f'11_fake_intensity/{i}',
                                       self.summary_audio_repr.repr_to_spectrogram(fakes, intensity=True, cmap=cm.CMRmap)[i, :, :, :].numpy(),
                                       global_step=step, dataformats='HWC')

        # real spectrograms
        self.summary_writer.add_images(f'20_real_noise/{i}',
                                       self.summary_audio_repr.repr_to_spectrogram(reals_noise)[i, :, :, :].numpy(),
                                       global_step=step, dataformats='HWC')
        self.summary_writer.add_images(f'21_real_noise_intensity/{i}',
                                       self.summary_audio_repr.repr_to_spectrogram(reals_noise, intensity=True, cmap=cm.CMRmap)[i, :, :, :].numpy(),
                                       global_step=step, dataformats='HWC')

    # audio (only if model is in final stage)
    if self.model.freq_n == self.model.audio_representation.freq_n:
        infer_dir = self.args.infer_dir
        if infer_dir is not None and not gspath.is_dir(infer_dir):
            gspath.mkdir(infer_dir)

        wav_fake = self.summary_audio_repr.repr_to_audio(fakes)
        wav_real_noise = self.summary_audio_repr.repr_to_audio(reals_noise)
        for i in range(3):
            self.summary_writer.add_audio(f'1_fake/{i}', wav_fake[i, :, :].numpy(),
                                          global_step=step, sample_rate=self.model.sample_rate)
            self.summary_writer.add_audio(f'2_real_noise/{i}', wav_real_noise[i, :, :].numpy(),
                                          global_step=step, sample_rate=self.model.sample_rate)
            if infer_dir is not None:
                audio_utils.save_audio(gspath.join(infer_dir, f'fake_sample{i}.wav'),
                                       wav_fake[i, :, :].numpy(),
                                       sample_rate=self.model.sample_rate, out_format='wav')
                audio_utils.save_audio(gspath.join(infer_dir, f'real_noise_sample{i}.wav'),
                                       wav_real_noise[i, :, :].numpy(),
                                       sample_rate=self.model.sample_rate, out_format='wav')

    # histograms
    self.summary_writer.add_histogram('1_fake', fake_scores.numpy(), global_step=step)
    self.summary_writer.add_histogram('2_real', real_scores.numpy(), global_step=step)
    self.summary_writer.add_scalars("11_score",
                                    {'0_fake_min': tf.reduce_min(fake_scores).numpy(),
                                     '1_fake_max': tf.reduce_max(fake_scores).numpy(),
                                     '2_real_min': tf.reduce_min(real_scores).numpy(),
                                     '3_real_max': tf.reduce_max(real_scores).numpy()},
                                    global_step=step)

    print('done')
    return step
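# Minimal sketch of the SummaryWriter calls in write_summary_gan, assuming the
# tensorboardX-style API that the code above appears to use; the log directory
# and values are illustrative.
def _demo_summary_writer(logdir="/tmp/demo-summaries"):
    import numpy as np
    from tensorboardX import SummaryWriter
    writer = SummaryWriter(logdir)
    writer.add_scalar('10_W_distance', 0.5, global_step=0)
    writer.add_scalars('11_score', {'0_fake_min': -1.0, '1_fake_max': 1.0}, global_step=0)
    writer.add_histogram('1_fake', np.random.randn(64), global_step=0)
    writer.close()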
def execute(args, strategy=None):
    # set up device strategy
    if args.runtime_tpu:
        # rely on the externally defined tpu_strategy
        if strategy is None:
            raise Exception("Strategy should be defined for TPU run")
        print("Running on TPU")
    elif args.mode == 'eval':
        # # make only CPU visible for evaluation loop (otherwise it tries to grab the cuda already in use)
        # physical_devices = tf.config.list_physical_devices('CPU')
        # tf.config.set_visible_devices(physical_devices)
        # device = "/cpu:0"
        device = "/gpu:0"
        strategy = tf.distribute.OneDeviceStrategy(device)
        print("Running on {}".format(device))
    else:
        device = "/gpu:0"
        strategy = tf.distribute.OneDeviceStrategy(device)
        print("Running on {}".format(device))

    model_trainer = ProgressiveTrainer(strategy, args)

    # start tensorboard
    tensorboard_process = None
    try:
        if args.runtime_launch_tensorboard:
            print("Launching tensorboard...")
            exec_str = "tensorboard --logdir={0}".format(args.training_dir)
            print('  ' + exec_str)
            tensorboard_process = subprocess.Popen(exec_str)
            print(f'  PID = {tensorboard_process.pid}')

        # decode running mode
        if args.mode == 'train':
            # save args
            filepath = gspath.join(args.training_dir,
                                   'args_{0}.txt'.format(datetime.datetime.utcnow().strftime("%Y%m%d%H%M%S")))
            data_str = '\n'.join([str(k) + ',' + str(v)
                                  for k, v in sorted(vars(args).items(), key=lambda x: x[0])])
            print('Writing arguments to {}'.format(filepath))
            gspath.write(filepath, data_str)

            model_trainer.progressive_training()
        elif args.mode == 'eval':
            model_trainer.evaluation_loop()
        elif args.mode == 'infer':
            model_trainer.inference_loop()
        else:
            raise NotImplementedError()
    finally:
        if tensorboard_process is not None:
            # kill tensorboard (taskkill is Windows-specific)
            subprocess.call(['taskkill', '/F', '/T', '/PID', str(tensorboard_process.pid)])
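# Minimal sketch of the tf.distribute.OneDeviceStrategy pattern used in
# execute above, with a CPU fallback when no GPU is present (an assumption,
# since execute always picks "/gpu:0" for non-TPU runs).
def _demo_one_device_strategy():
    import tensorflow as tf
    device = "/gpu:0" if tf.config.list_physical_devices('GPU') else "/cpu:0"
    strategy = tf.distribute.OneDeviceStrategy(device)
    with strategy.scope():
        v = tf.Variable(1.0)  # variables created here land on the chosen device
    return strategy.run(lambda: v + 1.0)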