def test_glob(ctx, parallel):
    """Exercise bf.glob over files and directories created under a temp context.

    ``ctx`` is a context-manager fixture yielding a path (local or remote
    depending on the fixture); ``parallel`` toggles bf.glob's parallel mode.
    """
    contents = b"meow!"
    with ctx() as path:
        dirpath = bf.dirname(path)
        a_path = bf.join(dirpath, "ab")
        with bf.BlobFile(a_path, "wb") as w:
            w.write(contents)
        b_path = bf.join(dirpath, "bb")
        with bf.BlobFile(b_path, "wb") as w:
            w.write(contents)

        def assert_listing_equal(path, desired):
            # compare order-insensitively; glob result order is not guaranteed
            desired = sorted([bf.join(dirpath, p) for p in desired])
            actual = sorted(list(bf.glob(path, parallel=parallel)))
            assert actual == desired, f"{actual} != {desired}"

        # simple wildcard patterns against the two files
        assert_listing_equal(bf.join(dirpath, "*b"), ["ab", "bb"])
        assert_listing_equal(bf.join(dirpath, "a*"), ["ab"])
        assert_listing_equal(bf.join(dirpath, "ab*"), ["ab"])
        assert_listing_equal(bf.join(dirpath, "*"), ["ab", "bb"])
        assert_listing_equal(bf.join(dirpath, "bb"), ["bb"])

        # add nested files to exercise directory-crossing patterns
        path = bf.join(dirpath, "test.txt")
        with bf.BlobFile(path, "wb") as w:
            w.write(contents)
        path = bf.join(dirpath, "subdir", "test.txt")
        bf.makedirs(bf.dirname(path))
        with bf.BlobFile(path, "wb") as f:
            f.write(contents)
        path = bf.join(dirpath, "subdir", "subsubdir", "test.txt")
        if "://" not in path:
            # implicit directory: only local paths need the parent created
            bf.makedirs(bf.dirname(path))
        with bf.BlobFile(path, "wb") as f:
            f.write(contents)

        assert_listing_equal(bf.join(dirpath, "*/test.txt"), ["subdir/test.txt"])
        assert_listing_equal(bf.join(dirpath, "*/*.txt"), ["subdir/test.txt"])
        if "://" in path:
            # local glob doesn't handle ** the same way as remote glob
            assert_listing_equal(
                bf.join(dirpath, "**.txt"),
                ["test.txt", "subdir/test.txt", "subdir/subsubdir/test.txt"],
            )
        else:
            assert_listing_equal(bf.join(dirpath, "**.txt"), ["test.txt"])
        assert_listing_equal(bf.join(dirpath, "*/test"), [])
        assert_listing_equal(bf.join(dirpath, "subdir/test.txt"), ["subdir/test.txt"])

        # directories
        assert_listing_equal(bf.join(dirpath, "*"), ["ab", "bb", "subdir", "test.txt"])
        assert_listing_equal(bf.join(dirpath, "subdir"), ["subdir"])
        assert_listing_equal(bf.join(dirpath, "subdir/"), ["subdir"])
        assert_listing_equal(bf.join(dirpath, "*/"), ["subdir"])
        assert_listing_equal(bf.join(dirpath, "*dir"), ["subdir"])
        assert_listing_equal(bf.join(dirpath, "subdir/*dir"), ["subdir/subsubdir"])
        assert_listing_equal(bf.join(dirpath, "subdir/*dir/"), ["subdir/subsubdir"])
        assert_listing_equal(bf.join(dirpath, "su*ir/*dir/"), ["subdir/subsubdir"])
def find_ema_checkpoint(main_checkpoint, step, rate):
    """Look for the EMA checkpoint matching ``main_checkpoint`` at ``step``/``rate``.

    Returns the EMA checkpoint path if it exists in the same directory as the
    main checkpoint, otherwise None (also None when there is no main checkpoint).
    """
    if main_checkpoint is None:
        return None
    candidate = bf.join(bf.dirname(main_checkpoint), f"ema_{rate}_{(step):06d}.pt")
    return candidate if bf.exists(candidate) else None
def _load_optimizer_state(self):
    """Restore optimizer state from the ``opt*.pt`` file next to the resume checkpoint, if present."""
    main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
    opt_checkpoint = bf.join(
        bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
    )
    if not bf.exists(opt_checkpoint):
        # nothing to restore; keep the freshly-initialized optimizer
        return
    logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
    state_dict = dist_util.load_state_dict(opt_checkpoint, map_location=dist_util.dev())
    self.opt.load_state_dict(state_dict)
def test_isdir(ctx):
    """bf.isdir: plain files are not dirs; explicit and implicit dirs are."""
    data = b"meow!"
    with ctx() as path:
        # a nonexistent path, and then a plain file, is never a directory
        assert not bf.isdir(path)
        _write_contents(path, data)
        assert not bf.isdir(path)

        # explicitly created directory
        explicit_dir = path + ".dir"
        bf.makedirs(explicit_dir)
        assert bf.isdir(explicit_dir)
        assert not bf.isdir(explicit_dir[:-1])

        # directory implied by a file living under it
        filepath = bf.join(path + ".otherdir", "subdir", "file.name")
        if "://" not in path:
            # implicit directory: local paths need the parent created explicitly
            bf.makedirs(bf.dirname(filepath))
        implied_dir = bf.dirname(bf.dirname(filepath))
        _write_contents(filepath, data)
        assert bf.isdir(implied_dir)
        assert not bf.isdir(implied_dir[:-1])
def test_listdir(ctx):
    """bf.listdir returns the names of both files and directories."""
    data = b"meow!"
    with ctx() as path:
        base = bf.dirname(path)
        for name in ("a", "b"):
            with bf.BlobFile(bf.join(base, name), "wb") as fh:
                fh.write(data)
        bf.makedirs(bf.join(base, "c"))
        assert sorted(list(bf.listdir(base))) == ["a", "b", "c"]
def test_read_write(ctx, streaming):
    """Round-trip bytes through BlobFile: bulk read and line iteration."""
    data = b"meow!\npurr\n"
    with ctx() as path:
        target = bf.join(path, "a folder", "a.file")
        bf.makedirs(bf.dirname(target))
        with bf.BlobFile(target, "wb", streaming=streaming) as fh:
            fh.write(data)
        with bf.BlobFile(target, "rb", streaming=streaming) as fh:
            assert fh.read() == data
        with bf.BlobFile(target, "rb", streaming=streaming) as fh:
            # iterating the handle yields lines; rejoined they equal the input
            assert b"".join(list(fh)) == data
def test_scanglob(ctx):
    """bf.scanglob yields entries carrying name plus is_file/is_dir flags."""
    data = b"meow!"
    with ctx() as path:
        base = bf.dirname(path)
        for name in ("ab", "bb", "test.txt"):
            with bf.BlobFile(bf.join(base, name), "wb") as fh:
                fh.write(data)
        nested = bf.join(base, "subdir", "test.txt")
        bf.makedirs(bf.dirname(nested))
        with bf.BlobFile(nested, "wb") as fh:
            fh.write(data)
        # "*b*" matches the two files and the subdir, but not test.txt
        entries = sorted(bf.scanglob(bf.join(base, "*b*")))
        assert entries[0].name == "ab" and entries[0].is_file
        assert entries[1].name == "bb" and entries[1].is_file
        assert entries[2].name == "subdir" and entries[2].is_dir
def test_create_local_intermediate_dirs():
    """BlobFile creates missing intermediate dirs for local absolute and relative paths."""
    data = b"meow"
    with _get_temp_local_path() as path:
        base = bf.dirname(path)
        with chdir(base):
            targets = [
                bf.join(base, "dirname", "file.name"),
                bf.join("..", bf.basename(base), "file.name"),
                "./file.name",
                "file.name",
            ]
            for target in targets:
                with bf.BlobFile(target, "wb") as fh:
                    fh.write(data)
def test_rmtree(ctx):
    """bf.rmtree removes a directory tree recursively and leaves siblings intact."""
    contents = b"meow!"
    with ctx() as path:
        root = bf.dirname(path)
        destroy_path = bf.join(root, "destroy")
        bf.makedirs(destroy_path)
        save_path = bf.join(root, "save")
        bf.makedirs(save_path)

        # implicit dir
        # fixed idiom: `"://" not in path` rather than `not "://" in path` (PEP 8 / E713)
        if "://" not in path:
            bf.makedirs(bf.join(destroy_path, "adir"))
        with bf.BlobFile(bf.join(destroy_path, "adir/b"), "wb") as w:
            w.write(contents)
        # explicit dir
        bf.makedirs(bf.join(destroy_path, "bdir"))
        with bf.BlobFile(bf.join(destroy_path, "bdir/b"), "wb") as w:
            w.write(contents)
        bf.makedirs(bf.join(save_path, "somedir"))
        with bf.BlobFile(bf.join(save_path, "somefile"), "wb") as w:
            w.write(contents)

        def assert_listing_equal(path, desired):
            actual = list(bf.walk(path))
            # ordering of os walk is weird, only compare sorted order
            assert sorted(actual) == sorted(desired), f"{actual} != {desired}"

        assert_listing_equal(
            root,
            [
                (root, ["destroy", "save"], []),
                (destroy_path, ["adir", "bdir"], []),
                (bf.join(destroy_path, "adir"), [], ["b"]),
                (bf.join(destroy_path, "bdir"), [], ["b"]),
                (save_path, ["somedir"], ["somefile"]),
                (bf.join(save_path, "somedir"), [], []),
            ],
        )

        bf.rmtree(destroy_path)
        assert_listing_equal(
            root,
            [
                (root, ["save"], []),
                (save_path, ["somedir"], ["somefile"]),
                (bf.join(save_path, "somedir"), [], []),
            ],
        )
def test_az_path():
    """az:// scheme paths should round-trip reads/writes like https:// azure paths."""
    data = b"meow!\npurr\n"
    with _get_temp_as_path() as path:
        path = _convert_https_to_az(path)
        path = bf.join(path, "a folder", "a.file")
        path = _convert_https_to_az(path)
        bf.makedirs(_convert_https_to_az(bf.dirname(path)))
        with bf.BlobFile(path, "wb") as fh:
            fh.write(data)
        with bf.BlobFile(path, "rb") as fh:
            assert fh.read() == data
        with bf.BlobFile(path, "rb") as fh:
            # line iteration must reproduce the exact bytes
            assert b"".join(list(fh)) == data
def test_dirname():
    """bf.dirname handles local, gs://, and azure https:// paths, including trailing slashes."""
    cases = [
        ("a", ""),
        ("a/b", "a"),
        ("a/b/c", "a/b"),
        ("a/b/c/", "a/b/c"),
        ("a/b/c/////", "a/b/c"),
        ("", ""),
        ("gs://a", "gs://a"),
        ("gs://a/", "gs://a"),
        ("gs://a/////", "gs://a"),
        ("gs://a/b", "gs://a"),
        ("gs://a/b/c/test.filename", "gs://a/b/c"),
        ("gs://a/b/c/", "gs://a/b"),
        ("gs://a/b/c/////", "gs://a/b"),
        (
            "https://a.blob.core.windows.net/container",
            "https://a.blob.core.windows.net/container",
        ),
        (
            "https://a.blob.core.windows.net/container/",
            "https://a.blob.core.windows.net/container",
        ),
        (
            "https://a.blob.core.windows.net/container/////",
            "https://a.blob.core.windows.net/container",
        ),
        (
            "https://a.blob.core.windows.net/container/b",
            "https://a.blob.core.windows.net/container",
        ),
        (
            "https://a.blob.core.windows.net/container/b/c/test.filename",
            "https://a.blob.core.windows.net/container/b/c",
        ),
        (
            "https://a.blob.core.windows.net/container/b/c/",
            "https://a.blob.core.windows.net/container/b",
        ),
        (
            "https://a.blob.core.windows.net/container/b/c//////",
            "https://a.blob.core.windows.net/container/b",
        ),
    ]
    for arg, expected in cases:
        got = bf.dirname(arg)
        assert expected == got, f"{arg}"
def test_walk(ctx, topdown):
    """bf.walk yields (dirpath, subdirs, files) triples, top-down or bottom-up."""
    data = b"meow!"
    with ctx() as path:
        base = bf.dirname(path)
        with bf.BlobFile(bf.join(base, "a"), "wb") as fh:
            fh.write(data)
        bf.makedirs(bf.join(base, "c/d"))
        with bf.BlobFile(bf.join(base, "c/d/b"), "wb") as fh:
            fh.write(data)
        expected = [
            (base, ["c"], ["a"]),
            (bf.join(base, "c"), ["d"], []),
            (bf.join(base, "c", "d"), [], ["b"]),
        ]
        if not topdown:
            # bottom-up traversal visits the same triples in reverse order
            expected.reverse()
        assert list(bf.walk(base, topdown=topdown)) == expected
def test_scandir(ctx):
    """bf.scandir entries expose name, path, is_dir/is_file, and stat (files only)."""
    data = b"meow!"
    with ctx() as path:
        base = bf.dirname(path)
        for name in ("a", "b"):
            with bf.BlobFile(bf.join(base, name), "wb") as fh:
                fh.write(data)
        bf.makedirs(bf.join(base, "c"))
        entries = sorted(bf.scandir(base))
        assert [e.name for e in entries] == ["a", "b", "c"]
        assert [e.path for e in entries] == [bf.join(base, name) for name in ["a", "b", "c"]]
        assert [e.is_dir for e in entries] == [False, False, True]
        assert [e.is_file for e in entries] == [True, True, False]
        # files carry a stat with their size; the directory entry has stat None
        assert entries[0].stat.size == len(data)
        assert entries[1].stat.size == len(data)
        assert entries[2].stat is None
def test_listdir_sharded(ctx):
    """bf.listdir with shard_prefix_length=1 must still return the complete listing."""
    data = b"meow!"
    with ctx() as path:
        base = bf.dirname(path)
        for name in ("a", "aa", "b", "ca"):
            with bf.BlobFile(bf.join(base, name), "wb") as fh:
                fh.write(data)
        bf.makedirs(bf.join(base, "c"))
        with bf.BlobFile(bf.join(base, "c/a"), "wb") as fh:
            fh.write(data)
        # this should also test shard_prefix_length=2 but that takes too long
        assert sorted(bf.listdir(base, shard_prefix_length=1)) == [
            "a",
            "aa",
            "b",
            "c",
            "ca",
        ]
def test_rmdir(ctx):
    """bf.rmdir removes empty dirs, tolerates repeat deletes, and rejects non-empty dirs."""
    data = b"meow!"
    with ctx() as path:
        base = bf.dirname(path)
        # this is an error for a local path but not for a blob path
        bf.rmdir(bf.join(base, "fakedirname"))

        target = bf.join(base, "dirname")
        bf.makedirs(target)
        assert bf.exists(target)
        bf.rmdir(target)
        assert not bf.exists(target)
        # double delete is fine
        bf.rmdir(target)

        # implicit dir
        inner = bf.join(base, "dirname", "name")
        _write_contents(inner, data)
        with pytest.raises(OSError):
            # not empty dir
            bf.rmdir(target)
        bf.remove(inner)
        bf.rmdir(target)
def main():
    """Train a (optionally noised-input) classifier with DDP and mixed precision.

    Sets up distributed training, builds the classifier + diffusion process,
    optionally resumes model/optimizer state, then runs the training loop with
    periodic validation, logging, and checkpointing.
    """
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model and diffusion...")
    model, diffusion = create_classifier_and_diffusion(
        **args_to_dict(args, classifier_and_diffusion_defaults().keys())
    )
    model.to(dist_util.dev())
    if args.noised:
        # sampler that picks diffusion timesteps used to noise classifier inputs
        schedule_sampler = create_named_schedule_sampler(
            args.schedule_sampler, diffusion
        )

    resume_step = 0
    if args.resume_checkpoint:
        resume_step = parse_resume_step_from_filename(args.resume_checkpoint)
        if dist.get_rank() == 0:
            # only rank 0 loads from disk; sync_params below broadcasts to peers
            logger.log(
                f"loading model from checkpoint: {args.resume_checkpoint}... at {resume_step} step"
            )
            model.load_state_dict(
                dist_util.load_state_dict(args.resume_checkpoint,
                                          map_location=dist_util.dev()))

    # Needed for creating correct EMAs and fp16 parameters.
    dist_util.sync_params(model.parameters())

    mp_trainer = MixedPrecisionTrainer(model=model,
                                       use_fp16=args.classifier_use_fp16,
                                       initial_lg_loss_scale=16.0)

    model = DDP(
        model,
        device_ids=[dist_util.dev()],
        output_device=dist_util.dev(),
        broadcast_buffers=False,
        bucket_cap_mb=128,
        find_unused_parameters=False,
    )

    logger.log("creating data loader...")
    data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=True,
        random_crop=True,
    )
    if args.val_data_dir:
        val_data = load_data(
            data_dir=args.val_data_dir,
            batch_size=args.batch_size,
            image_size=args.image_size,
            class_cond=True,
        )
    else:
        val_data = None

    logger.log(f"creating optimizer...")
    opt = AdamW(mp_trainer.master_params, lr=args.lr, weight_decay=args.weight_decay)
    if args.resume_checkpoint:
        # optimizer state lives next to the model checkpoint as opt<step>.pt
        opt_checkpoint = bf.join(bf.dirname(args.resume_checkpoint),
                                 f"opt{resume_step:06}.pt")
        logger.log(
            f"loading optimizer state from checkpoint: {opt_checkpoint}")
        opt.load_state_dict(
            dist_util.load_state_dict(opt_checkpoint,
                                      map_location=dist_util.dev()))

    logger.log("training classifier model...")

    def forward_backward_log(data_loader, prefix="train"):
        # One pass: forward, per-microbatch loss/metrics, backward (train only).
        batch, extra = next(data_loader)
        labels = extra["y"].to(dist_util.dev())

        batch = batch.to(dist_util.dev())
        # Noisy images
        if args.noised:
            t, _ = schedule_sampler.sample(batch.shape[0], dist_util.dev())
            batch = diffusion.q_sample(batch, t)
        else:
            # t=0 everywhere: classify clean images
            t = th.zeros(batch.shape[0], dtype=th.long, device=dist_util.dev())

        for i, (sub_batch, sub_labels, sub_t) in enumerate(
                split_microbatches(args.microbatch, batch, labels, t)):
            logits = model(sub_batch, timesteps=sub_t)
            loss = F.cross_entropy(logits, sub_labels, reduction="none")

            losses = {}
            losses[f"{prefix}_loss"] = loss.detach()
            losses[f"{prefix}_acc@1"] = compute_top_k(logits,
                                                      sub_labels,
                                                      k=1,
                                                      reduction="none")
            losses[f"{prefix}_acc@5"] = compute_top_k(logits,
                                                      sub_labels,
                                                      k=5,
                                                      reduction="none")
            log_loss_dict(diffusion, sub_t, losses)
            del losses
            loss = loss.mean()
            if loss.requires_grad:
                if i == 0:
                    mp_trainer.zero_grad()
                # scale so microbatch gradients sum to the full-batch gradient
                mp_trainer.backward(loss * len(sub_batch) / len(batch))

    for step in range(args.iterations - resume_step):
        logger.logkv("step", step + resume_step)
        logger.logkv(
            "samples",
            (step + resume_step + 1) * args.batch_size * dist.get_world_size(),
        )
        if args.anneal_lr:
            set_annealed_lr(opt, args.lr, (step + resume_step) / args.iterations)
        forward_backward_log(data)
        mp_trainer.optimize(opt)
        if val_data is not None and not step % args.eval_interval:
            # no_sync avoids DDP gradient all-reduce during eval-only forward
            with th.no_grad():
                with model.no_sync():
                    model.eval()
                    forward_backward_log(val_data, prefix="val")
                    model.train()
        if not step % args.log_interval:
            logger.dumpkvs()
        if (step and dist.get_rank() == 0
                and not (step + resume_step) % args.save_interval):
            logger.log("saving model...")
            save_model(mp_trainer, opt, step + resume_step)

    # final checkpoint after the loop completes
    if dist.get_rank() == 0:
        logger.log("saving model...")
        save_model(mp_trainer, opt, step + resume_step)
    dist.barrier()