def maybe_fetch_kaggle_dataset(data_root: str, kaggle_id: str, dataset_id: str,
                               kaggle_credential: KaggleCredential) -> None:
    """Download a Kaggle dataset into ``data_root`` via the ``kaggle`` CLI.

    Args:
        data_root: Directory used as the working directory of the CLI process.
        kaggle_id: Kaggle id passed to ``get_kaggle_dataset_id``. When empty or
            ``None`` the anime sketch colorization default is used.
        dataset_id: Dataset id passed to ``get_kaggle_dataset_id``. When empty
            or ``None`` the anime sketch colorization default is used.
        kaggle_credential: Object whose ``username`` and ``key`` attributes are
            exported to the child process environment.

    Raises:
        RuntimeError: If the ``kaggle`` process exits with a non-zero code.
    """
    # Honor the caller's arguments; previously they were always overwritten
    # by the module-level defaults.
    kaggle_id = kaggle_id or ANIME_SKETCH_COLORIZATION_DATASET_KAGGLE_ID
    dataset_id = dataset_id or ANIME_SKETCH_COLORIZATION_DATASET_DATASET_ID
    target_dataset = get_kaggle_dataset_id(kaggle_id, dataset_id)
    fetch_kaggle_dataset_args = [
        'kaggle', 'datasets', 'download', target_dataset
    ]

    # Work on a copy so the credentials are only visible to the child process
    # and do not leak into this process's own environment.
    fork_env = os.environ.copy()
    fork_env[KAGGLE_USERNAME_ENV_ID] = kaggle_credential.username
    fork_env[KAGGLE_KEY_ENV_ID] = kaggle_credential.key

    proc = Popen(fetch_kaggle_dataset_args,
                 cwd=data_root,
                 env=fork_env,
                 stdin=PIPE,
                 stdout=PIPE,
                 stderr=PIPE)
    _, stderr = proc.communicate()

    if proc.returncode != 0:
        global_logger.error(
            'Fetch dataset from Kaggle failed with return code {ret}'.format(
                ret=proc.returncode))
        # The pipes are opened in binary mode; decode so the log shows text
        # rather than a bytes repr.
        error_text = stderr.decode('utf-8', errors='replace')
        global_logger.error('Error message: {msg}'.format(msg=error_text))
        raise RuntimeError('Fetch Kaggle dataset failed.')

    global_logger.info('Fetch Kaggle dataset done.')
def maybe_extract_kaggle_dataset(extract_location: str, zipfile_location: str) -> None:
    """Unpack a zip archive unless the destination path already exists.

    Args:
        extract_location: Destination path; if it exists, extraction is
            skipped and a warning is logged.
        zipfile_location: Path of the zip archive to extract.
    """
    already_present = os.path.exists(extract_location)

    if not already_present:
        with zipfile.ZipFile(zipfile_location, 'r') as archive:
            archive.extractall(extract_location)
        global_logger.info('Extract dataset done.')
        return

    skip_message = 'The {dest} directory already exist. Skip.'.format(
        dest=extract_location)
    global_logger.warn(skip_message)
def main():
    """Parse the command line and run the handler for the requested action.

    Actions that depend on an application (``train``, ``fetch_kaggle_dataset``,
    ``generate_mini_dataset``) only run when ``args.app`` is
    ``'image_coloring'``. An unrecognized action does nothing. ``'Done.'`` is
    logged once the dispatch finishes.
    """
    args = parser.parse_args()
    global_logger.info(args)

    action = args.action
    # The actions are mutually exclusive, so at most one branch runs.
    if action == 'train':
        if args.app == 'image_coloring':
            train_image_coloring(epoch=args.epoch,
                                 batch_size=args.batch_size)
    elif action == 'system_check':
        run_system_check()
    elif action == 'fetch_kaggle_credential':
        show_kaggle_credential()
    elif action == 'fetch_kaggle_dataset':
        if args.app == 'image_coloring':
            fetch_image_coloring_dataset()
    elif action == 'generate_mini_dataset':
        if args.app == 'image_coloring':
            generate_mini_dataset()

    global_logger.info('Done.')
def train_image_coloring(epoch: int, batch_size: int) -> None:
    """Train the image coloring discriminator on real and generated batches.

    Each of the ``epoch`` iterations draws one batch of real samples, asks the
    generator for a matching batch of fake samples, and calls
    ``discriminator.train_on_batch`` once on each.

    Args:
        epoch: Number of loop iterations to run.
        batch_size: Batch size for the dataset and for sample creation.
    """
    dataset_source = AnimeSketchColorizationDatasetGenerator()
    unbatched_dataset = dataset_source.get_tf_dataset()
    batched_dataset = unbatched_dataset.batch(batch_size, drop_remainder=True)

    generator = ImageColoringGeneratorModel()
    discriminator = ImageColoringDiscriminatorModel()
    # NOTE(review): the combined model is built but never used in the loop
    # below, so the generator is not updated here -- confirm this is intended.
    gan = ImageColoringGanModel(generator, discriminator)

    for _ in range(epoch):
        real_samples = get_real_samples(batched_dataset, batch_size)
        color_real, bw_real, labels_real = real_samples

        color_fake, labels_fake = get_fake_samples(generator, bw_real,
                                                   batch_size)

        # NOTE(review): both losses are computed but never logged or returned.
        real_inputs = [color_real, bw_real]
        loss_on_real = discriminator.train_on_batch(real_inputs, labels_real)

        fake_inputs = [color_fake, bw_real]
        loss_on_fake = discriminator.train_on_batch(fake_inputs, labels_fake)

    global_logger.info('Training done.')
def _copy_dataset_files(data_ids, src_dir: str, dst_dir: str,
                        extension: str) -> None:
    """Copy ``<id>.<extension>`` from ``src_dir`` to ``dst_dir`` for each id."""
    for data_id in data_ids:
        file_name = '{id}.{ext}'.format(id=data_id, ext=extension)
        shutil.copyfile(os.path.join(src_dir, file_name),
                        os.path.join(dst_dir, file_name))


def generate_mini_dataset(sample_count: int = 10) -> None:
    """Build a small DEV dataset by copying a few files from the PROD dataset.

    The DEV extract location is recreated, then up to ``sample_count`` train
    images, validation images and colorgram JSON files are copied over from
    the PROD extract location.

    Args:
        sample_count: Maximum number of files to copy per category. Defaults
            to 10, the value that used to be hard-coded.

    Raises:
        ValueError: If ``sample_count`` is negative.
    """
    if sample_count < 0:
        # A negative slice bound would silently copy "all but the last N".
        raise ValueError('sample_count must be non-negative.')

    # Instantiated for its side effects only; presumably this makes sure the
    # PROD dataset is available locally -- TODO confirm.
    _ = AnimeSketchColorizationDatasetGenerator(type='PROD')

    dev_dataset_root = get_data_root('DEV')
    prod_dataset_root = get_data_root('PROD')
    dev_dataset_location = get_extract_location(
        dev_dataset_root, ANIME_SKETCH_COLORIZATION_DATASET_DATASET_ID)
    prod_dataset_location = get_extract_location(
        prod_dataset_root, ANIME_SKETCH_COLORIZATION_DATASET_DATASET_ID)

    recreate_dir(dev_dataset_location)

    dev_colorgram_location = get_colorgram_location(dev_dataset_location)
    dev_train_location = get_train_location(dev_dataset_location)
    dev_val_location = get_val_location(dev_dataset_location)
    prod_colorgram_location = get_colorgram_location(prod_dataset_location)
    prod_train_location = get_train_location(prod_dataset_location)
    prod_val_location = get_val_location(prod_dataset_location)

    for dev_location in (dev_colorgram_location, dev_train_location,
                         dev_val_location):
        create_dir_if_not_exist(dev_location)

    # Resolve every id list before copying anything, as the original did.
    train_ids = get_data_ids(prod_train_location)[:sample_count]
    val_ids = get_data_ids(prod_val_location)[:sample_count]
    colorgram_ids = get_data_ids(prod_colorgram_location)[:sample_count]

    _copy_dataset_files(train_ids, prod_train_location, dev_train_location,
                        'png')
    _copy_dataset_files(val_ids, prod_val_location, dev_val_location, 'png')
    _copy_dataset_files(colorgram_ids, prod_colorgram_location,
                        dev_colorgram_location, 'json')

    global_logger.info('Generate mini dataset done.')
def show_kaggle_credential() -> None:
    """Log the Kaggle username and key returned by ``get_kaggle_credential``."""
    credential = get_kaggle_credential()

    username_line = 'Kaggle username is {username}.'.format(
        username=credential.username)
    global_logger.info(username_line)

    # NOTE(review): this writes the API key to the log in plaintext; consider
    # masking it if the logs are persisted or shared.
    key_line = 'Kaggle key is {key}.'.format(key=credential.key)
    global_logger.info(key_line)
def run_system_check():
    """Log the number of GPU devices found, then the device list itself.

    The devices are those reported by ``tf.config.list_physical_devices``.
    """
    gpu_devices = tf.config.list_physical_devices("GPU")
    gpu_count = len(gpu_devices)
    global_logger.info('Found {cnt} GPU devices.'.format(cnt=gpu_count))
    global_logger.info(gpu_devices)
def fetch_image_coloring_dataset():
    """Instantiate the PROD dataset generator and log completion.

    The generator is created for its side effects only; the instance is not
    used afterwards.
    """
    # Presumably the constructor fetches and extracts the dataset -- TODO
    # confirm against AnimeSketchColorizationDatasetGenerator.
    prod_generator = AnimeSketchColorizationDatasetGenerator(type='PROD')
    global_logger.info('Fetch image coloring dataset. Done.')