def purge_checkpoints(log_dir_root, target_dir, verbose):
    """Move most checkpoints of every log dir under `log_dir_root` into `target_dir`.

    For every checkpoint dir matching `Saver.ckpt_dir_for_log_dir(log_dir_root/*)`:
      - if it holds 5 or fewer checkpoints, it is skipped;
      - otherwise three checkpoints stay in place (entries 2, len // 2 and -1 of
        `Saver.all_ckpts_with_iterations`) and the files of all other checkpoints
        are moved into `Saver.ckpt_dir_for_log_dir(target_dir/<basename of log dir>)`,
        which is created if missing.

    NOTE(review): which checkpoints are "kept" assumes `all_ckpts_with_iterations`
    returns them ordered by iteration -- TODO confirm.

    :param log_dir_root: directory containing the log dirs to purge.
    :param target_dir: directory receiving the moved checkpoint files.
    :param verbose: if truthy, print skipped log dirs and every moved file.
    """
    vprint = print if verbose else no_op.NoOp
    ckpt_dir_glob = Saver.ckpt_dir_for_log_dir(path.join(log_dir_root, '*'))
    ckpt_dir_matches = sorted(glob.glob(ckpt_dir_glob))
    for ckpt_dir in ckpt_dir_matches:
        log_dir = Saver.log_dir_from_ckpt_dir(ckpt_dir)
        all_ckpts = Saver.all_ckpts_with_iterations(ckpt_dir)
        if len(all_ckpts) <= 5:
            vprint('Skipping {}'.format(log_dir))
            continue
        target_log_dir = path.join(target_dir, path.basename(log_dir))
        target_ckpt_dir = Saver.ckpt_dir_for_log_dir(target_log_dir)
        os.makedirs(target_ckpt_dir, exist_ok=True)
        ckpts_to_keep = {all_ckpts[2], all_ckpts[len(all_ckpts) // 2], all_ckpts[-1]}
        # Same set as `set(all_ckpts) - ckpts_to_keep`, but iterated in list order so the
        # moves (and verbose output) happen in a deterministic order.
        ckpts_to_move = [ckpt for ckpt in all_ckpts if ckpt not in ckpts_to_keep]
        vprint('Moving to {}:'.format(target_ckpt_dir))
        for _, ckpt_to_move in ckpts_to_move:
            # ckpt_to_move is /path/to/dir/ckpt-7000. Match the file ckpt-7000 itself plus
            # ckpt-7000.<ext> (.data-*, .meta, .index). A bare '*' suffix would also match
            # ckpt-70000.*, i.e. files of a different (possibly kept) checkpoint.
            # glob.escape guards against glob metacharacters in the directory path.
            escaped_prefix = glob.escape(ckpt_to_move)
            ckpt_files = glob.glob(escaped_prefix) + glob.glob(escaped_prefix + '.*')
            for ckpt_file in ckpt_files:
                vprint('- {}'.format(ckpt_file))
                shutil.move(ckpt_file, target_ckpt_dir)
def iter_ckpt_dirs(log_dir_root, job_ids_str):
    """Yield the single training checkpoint dir matching each job id.

    Each job id is turned into the glob `Saver.ckpt_dir_for_log_dir(log_dir_root/<job_id>*)`.
    Job ids with zero or several matches are reported on stdout and skipped.

    :param log_dir_root: directory containing the log dirs; must exist.
    :param job_ids_str: job ids separated by ',' or ';'. Whitespace around ids and empty
        entries (e.g. from a trailing separator) are ignored.
    :raises AssertionError: if `log_dir_root` does not exist or no job id is given
        (raised on the first `next()`, since this is a generator).
    """
    assert os.path.exists(log_dir_root), 'Invalid log dir: {}'.format(log_dir_root)
    # Strip each id and drop empty tokens: an empty id would become the glob
    # `log_dir_root/*` and match every log dir, and ' 2' would never match '2...'.
    job_ids = [job_id.strip() for job_id in job_ids_str.replace(';', ',').split(',')]
    job_ids = [job_id for job_id in job_ids if job_id]
    assert len(job_ids) > 0, 'No job_ids!'
    for job_id in job_ids:
        # ckpt_dir_for_log_dir appends 'ckpts', which ensures that we only get training log dirs
        # as matches, and not other or previous validation dirs.
        ckpt_dir_glob = Saver.ckpt_dir_for_log_dir(path.join(log_dir_root, job_id + '*'))
        ckpt_dir_matches = glob.glob(ckpt_dir_glob)
        if len(ckpt_dir_matches) == 0:
            print('*** ERR: No matches for {}'.format(ckpt_dir_glob))
            continue
        if len(ckpt_dir_matches) > 1:
            print('*** ERR: Multiple matches for {}: {}'.format(
                ckpt_dir_glob, '\n'.join(ckpt_dir_matches)))
            continue
        yield ckpt_dir_matches[0]
def train(autoencoder_config_path, probclass_config_path,
          restore_manager: RestoreManager,
          log_dir_root,
          datasets: Datasets,
          train_flags: TrainFlags,
          ckpt_interval_hours: float,
          description: str):
    """Build the training and test graphs for autoencoder + probability classifier and run training.

    Steps: parse both configs, pick or create the log dir, optionally record the run via
    `_write_to_sheets`, build the train graph (encode -> decode, bitcost, distortions, loss,
    train op) and the test graph (under the 'test' name scope), set up train/test loggers,
    then start a session and hand over to `train_loop`.

    :param autoencoder_config_path: path of the autoencoder config, parsed by `config_parser.parse`.
    :param probclass_config_path: path of the probability-classifier config.
    :param restore_manager: optional. If truthy and `restore_manager.continue_in_ckpt_dir` is
        truthy, training continues in `restore_manager.log_dir`; otherwise a new unique log dir
        is created (with `restore_dir=restore_manager.ckpt_dir` when a manager is given). When
        given, variables are not initialized and `restore_manager.restore(sess)` is called instead.
    :param log_dir_root: root under which new log dirs are created.
    :param datasets: provides `.train`, `.test` and `.codec_distance`.
    :param train_flags: forwarded to `train_loop`.
    :param ckpt_interval_hours: forwarded to `Saver` as `keep_checkpoint_every_n_hours`
        (`max_to_keep=1`).
    :param description: if truthy, the run is recorded via `_write_to_sheets`.
    """
    ae_config, ae_config_rel_path = config_parser.parse(
        autoencoder_config_path)
    pc_config, pc_config_rel_path = config_parser.parse(probclass_config_path)
    print_configs(('ae_config', ae_config), ('pc_config', pc_config))

    # Either resume in the existing log dir or create a fresh one.
    continue_in_ckpt_dir = restore_manager and restore_manager.continue_in_ckpt_dir
    if continue_in_ckpt_dir:
        logdir = restore_manager.log_dir
    else:
        logdir = logdir_helpers.create_unique_log_dir(
            [ae_config_rel_path, pc_config_rel_path], log_dir_root,
            restore_dir=restore_manager.ckpt_dir if restore_manager else None)
    print(_LOG_DIR_FORMAT.format(logdir))

    if description:
        _write_to_sheets(logdir_helpers.log_date_from_log_dir(logdir),
                         ae_config_rel_path, pc_config_rel_path,
                         description,
                         git_ref=_get_git_ref(),
                         log_dir_root=log_dir_root,
                         is_continue=continue_in_ckpt_dir)

    ae_cls = autoencoder.get_network_cls(ae_config)
    pc_cls = probclass.get_network_cls(pc_config)

    # Instantiate autoencoder and probability classifier
    ae = ae_cls(ae_config)
    pc = pc_cls(pc_config, num_centers=ae_config.num_centers)

    # train ---
    # NOTE(review): the training pipeline is built with shuffle=False -- confirm this is intended.
    ip_train = inputpipeline.InputPipeline(
        inputpipeline.get_dataset(datasets.train),
        ae_config.crop_size,
        batch_size=ae_config.batch_size,
        shuffle=False,
        num_preprocess_threads=NUM_PREPROCESS_THREADS,
        num_crops_per_img=NUM_CROPS_PER_IMG)
    x_train = ip_train.get_batch()

    enc_out_train = ae.encode(
        x_train, is_training=True)  # qbar is masked by the heatmap
    x_out_train = ae.decode(enc_out_train.qbar, is_training=True)
    # stop_gradient is beneficial for training. it prevents multiple gradients flowing into the heatmap.
    pc_in = tf.stop_gradient(enc_out_train.qbar)
    bc_train = pc.bitcost(pc_in, enc_out_train.symbols, is_training=True,
                          pad_value=pc.auto_pad_value(ae))
    bpp_train = bits.bitcost_to_bpp(bc_train, x_train)
    d_train = Distortions(ae_config, x_train, x_out_train, is_training=True)
    # summing over channel dimension gives 2D heatmap
    # (assumes the heatmap's axis 1 is the channel axis -- TODO confirm against the autoencoder)
    heatmap2D = (tf.reduce_sum(enc_out_train.heatmap, 1)
                 if enc_out_train.heatmap is not None else None)

    # loss ---
    total_loss, H_real, pc_comps, ae_comps = get_loss(
        ae_config, ae, pc, d_train.d_loss_scaled, bc_train, enc_out_train.heatmap)
    train_op = get_train_op(ae_config, pc_config, ip_train, pc.variables(), total_loss)

    # test ---
    # Test graph: uses the hard-quantized representation (qhard) and is_training=False.
    with tf.name_scope('test'):
        ip_test = inputpipeline.InputPipeline(
            inputpipeline.get_dataset(datasets.test),
            ae_config.crop_size,
            batch_size=ae_config.batch_size,
            num_preprocess_threads=NUM_PREPROCESS_THREADS,
            num_crops_per_img=1,
            big_queues=False,
            shuffle=False)
        x_test = ip_test.get_batch()
        enc_out_test = ae.encode(x_test, is_training=False)
        x_out_test = ae.decode(enc_out_test.qhard, is_training=False)
        bc_test = pc.bitcost(enc_out_test.qhard, enc_out_test.symbols,
                             is_training=False, pad_value=pc.auto_pad_value(ae))
        bpp_test = bits.bitcost_to_bpp(bc_test, x_test)
        d_test = Distortions(ae_config, x_test, x_out_test, is_training=False)

    try:  # Try to get codec distance for current dataset
        codec_distance_ms_ssim = CodecDistance(datasets.codec_distance, codec='bpg', metric='ms-ssim')
        # A ValueError from .distance() yields NaN instead of failing the py_func.
        get_distance = functools_ext.catcher(
            ValueError, handler=functools_ext.const(np.nan), f=codec_distance_ms_ssim.distance)
        get_distance = functools_ext.compose(np.float32, get_distance)  # cast to float32
        d_BPG_test = tf.py_func(get_distance, [bpp_test, d_test.ms_ssim],
                                tf.float32, stateful=False, name='d_BPG')
        d_BPG_test.set_shape(())
    except CodecDistanceReadException as e:
        # Fall back to a constant NaN so the summaries/console output below still work.
        print('Cannot compute CodecDistance: {}'.format(e))
        d_BPG_test = tf.constant(np.nan, shape=(), name='ConstNaN')

    # ---
    train_logger = Logger()
    test_logger = Logger()
    # NOTE(review): distortion_name is not used anywhere below in this function.
    distortion_name = ae_config.distortion_to_minimize

    train_logger.add_summaries(d_train.summaries_with_prefix('train'))
    # Visualize components of losses
    train_logger.add_summaries([
        tf.summary.scalar('train/PC_loss/{}'.format(name), comp) for name, comp in pc_comps])
    train_logger.add_summaries([
        tf.summary.scalar('train/AE_loss/{}'.format(name), comp) for name, comp in ae_comps])
    train_logger.add_summaries([tf.summary.scalar('train/bpp', bpp_train)])
    train_logger.add_console_tensor('loss={:.3f}', total_loss)
    train_logger.add_console_tensor('ms_ssim={:.3f}', d_train.ms_ssim)
    train_logger.add_console_tensor('bpp={:.3f}', bpp_train)
    train_logger.add_console_tensor('H_real={:.3f}', H_real)

    test_logger.add_summaries(d_test.summaries_with_prefix('test'))
    test_logger.add_summaries([
        tf.summary.scalar('test/bpp', bpp_test),
        tf.summary.scalar('test/distance_BPG_MS-SSIM', d_BPG_test),
        tf.summary.image('test/x_in', prep_for_image_summary(x_test, n=3, name='x_in')),
        tf.summary.image('test/x_out', prep_for_image_summary(x_out_test, n=3, name='x_out'))])
    # NOTE(review): heatmap2D is derived from the *training* encoder output (enc_out_train) but is
    # logged as 'test/hm' through the test logger -- confirm this is intended.
    if heatmap2D is not None:
        test_logger.add_summaries([
            tf.summary.image(
                'test/hm',
                prep_for_grayscale_image_summary(heatmap2D, n=3, autoscale=True, name='hm'))])
    test_logger.add_console_tensor('ms_ssim={:.3f}', d_test.ms_ssim)
    test_logger.add_console_tensor('bpp={:.3f}', bpp_test)
    test_logger.add_summaries([
        tf.summary.histogram('centers', ae.get_centers_variable()),
        tf.summary.histogram(
            'test/qbar', enc_out_test.qbar[:ae_config.batch_size // 2, ...])])
    test_logger.add_console_tensor('d_BPG={:.6f}', d_BPG_test)
    test_logger.add_console_tensor(Logger.Numpy1DFormatter('centers={}'), ae.get_centers_variable())

    print('Starting session and queues...')
    # Variables are only initialized for a fresh run; otherwise they come from the restore below.
    with tf_helpers.start_queues_in_sess(
            init_vars=restore_manager is None) as (sess, coord):
        train_logger.finalize_with_sess(sess)
        test_logger.finalize_with_sess(sess)

        if restore_manager:
            restore_manager.restore(sess)

        saver = Saver(Saver.ckpt_dir_for_log_dir(logdir),
                      max_to_keep=1,
                      keep_checkpoint_every_n_hours=ckpt_interval_hours)

        train_loop(ae_config, sess, coord, train_op, train_logger, test_logger,
                   train_flags, logdir, saver, is_restored=restore_manager is not None)
def _get_restore_ckpt_dir(restore_flag):
    """Resolve `restore_flag` to a checkpoint directory.

    `restore_flag` may either already be a checkpoint dir, or be a log dir whose
    checkpoint dir (via `Saver.ckpt_dir_for_log_dir`) is valid.

    :raises ValueError: if neither interpretation yields a checkpoint dir.
    """
    if Saver.is_ckpt_dir(restore_flag):
        return restore_flag
    # Not a ckpt dir itself -- try interpreting it as a log dir.
    derived_ckpt_dir = Saver.ckpt_dir_for_log_dir(restore_flag)
    if not Saver.is_ckpt_dir(derived_ckpt_dir):
        raise ValueError('Invalid ckpt dir: {}'.format(restore_flag))
    return derived_ckpt_dir