示例#1
0
def test_glob(ctx, parallel):
    """Check ``bf.glob`` against flat files, nested files and directories.

    ``ctx`` is called to obtain a context manager yielding a scratch path; all
    fixtures are created under that path's parent directory. ``parallel`` is
    forwarded unchanged to ``bf.glob``.
    """
    contents = b"meow!"

    def write_blob(target):
        # Create (or overwrite) ``target`` with the shared payload.
        with bf.BlobFile(target, "wb") as sink:
            sink.write(contents)

    with ctx() as path:
        dirpath = bf.dirname(path)

        def assert_listing_equal(pattern, names):
            # Both sides are expressed relative to ``dirpath`` and compared in
            # sorted order, so the order glob yields results in is irrelevant.
            desired = sorted(bf.join(dirpath, name) for name in names)
            actual = sorted(bf.glob(bf.join(dirpath, pattern), parallel=parallel))
            assert actual == desired, f"{actual} != {desired}"

        for name in ("ab", "bb"):
            write_blob(bf.join(dirpath, name))

        # Only the two flat files exist at this point.
        for pattern, names in (
            ("*b", ["ab", "bb"]),
            ("a*", ["ab"]),
            ("ab*", ["ab"]),
            ("*", ["ab", "bb"]),
            ("bb", ["bb"]),
        ):
            assert_listing_equal(pattern, names)

        write_blob(bf.join(dirpath, "test.txt"))

        nested_file = bf.join(dirpath, "subdir", "test.txt")
        bf.makedirs(bf.dirname(nested_file))
        write_blob(nested_file)

        deep_file = bf.join(dirpath, "subdir", "subsubdir", "test.txt")
        has_scheme = "://" in deep_file
        if not has_scheme:
            # Paths without a URL scheme get their parent created explicitly;
            # for scheme-qualified paths the directory is left implicit.
            bf.makedirs(bf.dirname(deep_file))
        write_blob(deep_file)

        if has_scheme:
            # local glob doesn't handle ** the same way as remote glob
            double_star_names = [
                "test.txt",
                "subdir/test.txt",
                "subdir/subsubdir/test.txt",
            ]
        else:
            double_star_names = ["test.txt"]

        for pattern, names in (
            ("*/test.txt", ["subdir/test.txt"]),
            ("*/*.txt", ["subdir/test.txt"]),
            ("**.txt", double_star_names),
            ("*/test", []),
            ("subdir/test.txt", ["subdir/test.txt"]),
            # directories
            ("*", ["ab", "bb", "subdir", "test.txt"]),
            ("subdir", ["subdir"]),
            ("subdir/", ["subdir"]),
            ("*/", ["subdir"]),
            ("*dir", ["subdir"]),
            ("subdir/*dir", ["subdir/subsubdir"]),
            ("subdir/*dir/", ["subdir/subsubdir"]),
            ("su*ir/*dir/", ["subdir/subsubdir"]),
        ):
            assert_listing_equal(pattern, names)
def find_ema_checkpoint(main_checkpoint, step, rate):
    """Return the path of the EMA checkpoint stored beside ``main_checkpoint``.

    The candidate file is named ``ema_<rate>_<step zero-padded to 6>.pt`` and is
    looked up in the directory containing ``main_checkpoint``. Returns ``None``
    when ``main_checkpoint`` is ``None`` or the candidate does not exist.
    """
    if main_checkpoint is not None:
        candidate = bf.join(
            bf.dirname(main_checkpoint), f"ema_{rate}_{step:06d}.pt"
        )
        if bf.exists(candidate):
            return candidate
    return None
示例#3
0
 def _load_optimizer_state(self):
     """Restore ``self.opt`` from the optimizer checkpoint, if one exists.

     The optimizer file is expected next to the main checkpoint (the value of
     ``find_resume_checkpoint()``, falling back to ``self.resume_checkpoint``)
     and is named ``opt<resume_step zero-padded to 6>.pt``. Nothing happens
     when that file is absent.
     """
     main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
     checkpoint_dir = bf.dirname(main_checkpoint)
     opt_checkpoint = bf.join(checkpoint_dir, f"opt{self.resume_step:06}.pt")
     if not bf.exists(opt_checkpoint):
         return
     logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
     self.opt.load_state_dict(
         dist_util.load_state_dict(opt_checkpoint,
                                   map_location=dist_util.dev()))
示例#4
0
def test_isdir(ctx):
    """Check ``bf.isdir`` for missing paths, files and directories."""
    payload = b"meow!"
    with ctx() as path:
        # Neither a missing path nor a regular file counts as a directory.
        assert not bf.isdir(path)
        _write_contents(path, payload)
        assert not bf.isdir(path)

        made_dir = path + ".dir"
        bf.makedirs(made_dir)
        assert bf.isdir(made_dir)
        # A proper prefix of the directory name must not match.
        assert not bf.isdir(made_dir[:-1])

        nested_file = bf.join(path + ".otherdir", "subdir", "file.name")
        parent = bf.dirname(nested_file)
        if "://" not in path:
            # Paths without a URL scheme get the parent created explicitly;
            # otherwise the directory is left implicit.
            bf.makedirs(parent)
        grandparent = bf.dirname(parent)
        _write_contents(nested_file, payload)
        assert bf.isdir(grandparent)
        assert not bf.isdir(grandparent[:-1])
示例#5
0
def test_listdir(ctx):
    """Check that ``bf.listdir`` returns both files and subdirectories."""
    payload = b"meow!"
    with ctx() as path:
        dirpath = bf.dirname(path)
        # Two regular files ...
        for name in ("a", "b"):
            with bf.BlobFile(bf.join(dirpath, name), "wb") as sink:
                sink.write(payload)
        # ... and one directory.
        bf.makedirs(bf.join(dirpath, "c"))

        listing = list(bf.listdir(dirpath))
        listing.sort()
        assert listing == ["a", "b", "c"]
示例#6
0
def test_read_write(ctx, streaming):
    """Round-trip bytes through ``bf.BlobFile`` with the given ``streaming`` flag.

    The file is written once, then read back twice: once in a single
    ``read()`` call and once by iterating over the file object.
    """
    payload = b"meow!\npurr\n"
    with ctx() as path:
        target = bf.join(path, "a folder", "a.file")
        bf.makedirs(bf.dirname(target))

        with bf.BlobFile(target, "wb", streaming=streaming) as sink:
            sink.write(payload)

        with bf.BlobFile(target, "rb", streaming=streaming) as source:
            assert source.read() == payload

        with bf.BlobFile(target, "rb", streaming=streaming) as source:
            pieces = [piece for piece in source]
            assert b"".join(pieces) == payload
示例#7
0
def test_scanglob(ctx):
    """Check the entries ``bf.scanglob`` yields for a ``*b*`` pattern."""
    payload = b"meow!"

    def write_blob(target):
        with bf.BlobFile(target, "wb") as sink:
            sink.write(payload)

    with ctx() as path:
        dirpath = bf.dirname(path)
        for name in ("ab", "bb", "test.txt"):
            write_blob(bf.join(dirpath, name))

        nested_file = bf.join(dirpath, "subdir", "test.txt")
        bf.makedirs(bf.dirname(nested_file))
        write_blob(nested_file)

        entries = sorted(bf.scanglob(bf.join(dirpath, "*b*")))
        # (expected name, attribute that must be truthy) for the first three
        # entries in sorted order.
        expectations = (
            ("ab", "is_file"),
            ("bb", "is_file"),
            ("subdir", "is_dir"),
        )
        for index, (name, flag) in enumerate(expectations):
            entry = entries[index]
            assert entry.name == name and getattr(entry, flag)
示例#8
0
def test_create_local_intermediate_dirs():
    """Write files via absolute and relative local paths from inside a temp dir.

    The test passes as long as every ``bf.BlobFile(..., "wb")`` write succeeds.
    """
    payload = b"meow"
    with _get_temp_local_path() as path:
        dirpath = bf.dirname(path)
        with chdir(dirpath):
            targets = (
                # below the temp dir, in a subdirectory named "dirname"
                bf.join(dirpath, "dirname", "file.name"),
                # relative path that leaves the cwd and re-enters it by name
                bf.join("..", bf.basename(dirpath), "file.name"),
                "./file.name",
                "file.name",
            )
            for target in targets:
                with bf.BlobFile(target, "wb") as sink:
                    sink.write(payload)
示例#9
0
def test_rmtree(ctx):
    """Check that ``bf.rmtree`` removes one subtree and leaves its sibling intact."""
    contents = b"meow!"

    def write_blob(target):
        with bf.BlobFile(target, "wb") as sink:
            sink.write(contents)

    def assert_listing_equal(path, desired):
        actual = list(bf.walk(path))
        # ordering of os walk is weird, only compare sorted order
        assert sorted(actual) == sorted(desired), f"{actual} != {desired}"

    with ctx() as path:
        root = bf.dirname(path)
        destroy_path = bf.join(root, "destroy")
        save_path = bf.join(root, "save")
        for top_level in (destroy_path, save_path):
            bf.makedirs(top_level)

        adir = bf.join(destroy_path, "adir")
        bdir = bf.join(destroy_path, "bdir")
        somedir = bf.join(save_path, "somedir")

        # implicit dir: only created explicitly for paths without a URL scheme
        if "://" not in path:
            bf.makedirs(adir)
        write_blob(bf.join(destroy_path, "adir/b"))

        # explicit dir
        bf.makedirs(bdir)
        write_blob(bf.join(destroy_path, "bdir/b"))

        bf.makedirs(somedir)
        write_blob(bf.join(save_path, "somefile"))

        # Entries that must be present both before and after the removal.
        kept = [
            (save_path, ["somedir"], ["somefile"]),
            (somedir, [], []),
        ]

        assert_listing_equal(
            root,
            [
                (root, ["destroy", "save"], []),
                (destroy_path, ["adir", "bdir"], []),
                (adir, [], ["b"]),
                (bdir, [], ["b"]),
            ] + kept,
        )

        bf.rmtree(destroy_path)

        assert_listing_equal(root, [(root, ["save"], [])] + kept)
示例#10
0
def test_az_path():
    """Round-trip bytes through ``bf.BlobFile`` using a converted ``az`` path.

    Every intermediate path is passed through ``_convert_https_to_az`` again
    after ``bf.join`` / ``bf.dirname``, exactly as many times as before.
    """
    payload = b"meow!\npurr\n"
    with _get_temp_as_path() as path:
        az_root = _convert_https_to_az(path)
        target = _convert_https_to_az(bf.join(az_root, "a folder", "a.file"))
        bf.makedirs(_convert_https_to_az(bf.dirname(target)))

        with bf.BlobFile(target, "wb") as sink:
            sink.write(payload)

        with bf.BlobFile(target, "rb") as source:
            assert source.read() == payload

        with bf.BlobFile(target, "rb") as source:
            pieces = [piece for piece in source]
            assert b"".join(pieces) == payload
示例#11
0
def test_dirname():
    """Check ``bf.dirname`` on plain, ``gs://`` and Azure ``https://`` paths."""
    az = "https://a.blob.core.windows.net/container"

    # (input, expected dirname) pairs
    plain_cases = [
        ("a", ""),
        ("a/b", "a"),
        ("a/b/c", "a/b"),
        ("a/b/c/", "a/b/c"),
        ("a/b/c/////", "a/b/c"),
        ("", ""),
    ]
    gs_cases = [
        ("gs://a", "gs://a"),
        ("gs://a/", "gs://a"),
        ("gs://a/////", "gs://a"),
        ("gs://a/b", "gs://a"),
        ("gs://a/b/c/test.filename", "gs://a/b/c"),
        ("gs://a/b/c/", "gs://a/b"),
        ("gs://a/b/c/////", "gs://a/b"),
    ]
    az_cases = [
        (az, az),
        (az + "/", az),
        (az + "/////", az),
        (az + "/b", az),
        (az + "/b/c/test.filename", az + "/b/c"),
        (az + "/b/c/", az + "/b"),
        (az + "/b/c//////", az + "/b"),
    ]

    for input_, desired_output in plain_cases + gs_cases + az_cases:
        actual_output = bf.dirname(input_)
        # The failing input is used as the assertion message.
        assert desired_output == actual_output, f"{input_}"
示例#12
0
def test_walk(ctx, topdown):
    """Check the tuples and ordering ``bf.walk`` yields for a three-level tree."""
    payload = b"meow!"

    def write_blob(target):
        with bf.BlobFile(target, "wb") as sink:
            sink.write(payload)

    with ctx() as path:
        dirpath = bf.dirname(path)
        write_blob(bf.join(dirpath, "a"))
        bf.makedirs(bf.join(dirpath, "c/d"))
        write_blob(bf.join(dirpath, "c/d/b"))

        # One (dirpath, dirnames, filenames) tuple per level of the tree.
        top = (dirpath, ["c"], ["a"])
        middle = (bf.join(dirpath, "c"), ["d"], [])
        leaf = (bf.join(dirpath, "c", "d"), [], ["b"])

        # With topdown falsy the expected order is exactly reversed.
        expected = [top, middle, leaf] if topdown else [leaf, middle, top]
        assert list(bf.walk(dirpath, topdown=topdown)) == expected
示例#13
0
def test_scandir(ctx):
    """Check name, path, type flags and ``stat`` of ``bf.scandir`` entries."""
    payload = b"meow!"
    with ctx() as path:
        dirpath = bf.dirname(path)
        for name in ("a", "b"):
            with bf.BlobFile(bf.join(dirpath, name), "wb") as sink:
                sink.write(payload)
        bf.makedirs(bf.join(dirpath, "c"))

        entries = sorted(bf.scandir(dirpath))
        expected_names = ["a", "b", "c"]
        assert [entry.name for entry in entries] == expected_names
        assert [entry.path for entry in entries] == [
            bf.join(dirpath, name) for name in expected_names
        ]
        assert [entry.is_dir for entry in entries] == [False, False, True]
        assert [entry.is_file for entry in entries] == [True, True, False]

        # Files report their size through ``stat``; the directory entry is
        # expected to have no ``stat`` at all.
        assert entries[0].stat.size == len(payload)
        assert entries[1].stat.size == len(payload)
        assert entries[2].stat is None
示例#14
0
def test_listdir_sharded(ctx):
    """Check ``bf.listdir`` with ``shard_prefix_length=1`` on a small tree."""
    payload = b"meow!"

    def write_blob(target):
        with bf.BlobFile(target, "wb") as sink:
            sink.write(payload)

    with ctx() as path:
        dirpath = bf.dirname(path)
        for name in ("a", "aa", "b", "ca"):
            write_blob(bf.join(dirpath, name))
        bf.makedirs(bf.join(dirpath, "c"))
        write_blob(bf.join(dirpath, "c/a"))

        # this should also test shard_prefix_length=2 but that takes too long
        listing = list(bf.listdir(dirpath, shard_prefix_length=1))
        listing.sort()
        assert listing == ["a", "aa", "b", "c", "ca"]
示例#15
0
def test_rmdir(ctx):
    """Check ``bf.rmdir`` on missing, empty and non-empty directories."""
    payload = b"meow!"
    with ctx() as path:
        parent = bf.dirname(path)

        # this is an error for a local path but not for a blob path
        bf.rmdir(bf.join(parent, "fakedirname"))

        target_dir = bf.join(parent, "dirname")
        bf.makedirs(target_dir)
        assert bf.exists(target_dir)
        bf.rmdir(target_dir)
        assert not bf.exists(target_dir)

        # double delete is fine
        bf.rmdir(target_dir)

        # implicit dir: writing the file is what brings "dirname" back
        inner_file = bf.join(parent, "dirname", "name")
        _write_contents(inner_file, payload)
        with pytest.raises(OSError):
            # not empty dir
            bf.rmdir(target_dir)

        # Once the file is gone the directory can be removed again.
        bf.remove(inner_file)
        bf.rmdir(target_dir)
示例#16
0
def main():
    """Train a classifier under ``torch.distributed``, optionally resuming.

    In order: parse the command-line arguments, set up distributed state and
    logging, build the classifier and diffusion objects, optionally restore
    model and optimizer state from ``args.resume_checkpoint``, then run the
    training loop with periodic validation, log dumps and checkpoint saves.

    When the resume step is already at or past ``args.iterations`` the loop
    body never runs; in that case no final checkpoint is written and every
    rank proceeds straight to the closing barrier.
    """
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model and diffusion...")
    model, diffusion = create_classifier_and_diffusion(
        **args_to_dict(args,
                       classifier_and_diffusion_defaults().keys()))
    model.to(dist_util.dev())
    if args.noised:
        # Only created (and only referenced below) when training on noised
        # inputs.
        schedule_sampler = create_named_schedule_sampler(
            args.schedule_sampler, diffusion)

    resume_step = 0
    if args.resume_checkpoint:
        resume_step = parse_resume_step_from_filename(args.resume_checkpoint)
        # Only rank 0 reads the model weights from disk; the sync_params call
        # below is presumably what propagates them to the other ranks.
        if dist.get_rank() == 0:
            logger.log(
                f"loading model from checkpoint: {args.resume_checkpoint}... at {resume_step} step"
            )
            model.load_state_dict(
                dist_util.load_state_dict(args.resume_checkpoint,
                                          map_location=dist_util.dev()))

    # Needed for creating correct EMAs and fp16 parameters.
    dist_util.sync_params(model.parameters())

    mp_trainer = MixedPrecisionTrainer(model=model,
                                       use_fp16=args.classifier_use_fp16,
                                       initial_lg_loss_scale=16.0)

    model = DDP(
        model,
        device_ids=[dist_util.dev()],
        output_device=dist_util.dev(),
        broadcast_buffers=False,
        bucket_cap_mb=128,
        find_unused_parameters=False,
    )

    logger.log("creating data loader...")
    data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=True,
        random_crop=True,
    )
    if args.val_data_dir:
        val_data = load_data(
            data_dir=args.val_data_dir,
            batch_size=args.batch_size,
            image_size=args.image_size,
            class_cond=True,
        )
    else:
        val_data = None

    logger.log("creating optimizer...")
    opt = AdamW(mp_trainer.master_params,
                lr=args.lr,
                weight_decay=args.weight_decay)
    if args.resume_checkpoint:
        # The optimizer state is expected next to the model checkpoint, named
        # after the zero-padded resume step.
        opt_checkpoint = bf.join(bf.dirname(args.resume_checkpoint),
                                 f"opt{resume_step:06}.pt")
        logger.log(
            f"loading optimizer state from checkpoint: {opt_checkpoint}")
        opt.load_state_dict(
            dist_util.load_state_dict(opt_checkpoint,
                                      map_location=dist_util.dev()))

    logger.log("training classifier model...")

    def forward_backward_log(data_loader, prefix="train"):
        """Run one batch through the model in microbatches and log metrics.

        Loss and top-1 / top-5 accuracy are logged under ``prefix``. Gradients
        are accumulated only when the loss requires grad, so calling this
        inside ``th.no_grad()`` performs a forward pass only.
        """
        batch, extra = next(data_loader)
        labels = extra["y"].to(dist_util.dev())

        batch = batch.to(dist_util.dev())
        # Noisy images
        if args.noised:
            t, _ = schedule_sampler.sample(batch.shape[0], dist_util.dev())
            batch = diffusion.q_sample(batch, t)
        else:
            t = th.zeros(batch.shape[0], dtype=th.long, device=dist_util.dev())

        for i, (sub_batch, sub_labels, sub_t) in enumerate(
                split_microbatches(args.microbatch, batch, labels, t)):
            logits = model(sub_batch, timesteps=sub_t)
            loss = F.cross_entropy(logits, sub_labels, reduction="none")

            losses = {}
            losses[f"{prefix}_loss"] = loss.detach()
            losses[f"{prefix}_acc@1"] = compute_top_k(logits,
                                                      sub_labels,
                                                      k=1,
                                                      reduction="none")
            losses[f"{prefix}_acc@5"] = compute_top_k(logits,
                                                      sub_labels,
                                                      k=5,
                                                      reduction="none")
            log_loss_dict(diffusion, sub_t, losses)
            del losses
            loss = loss.mean()
            if loss.requires_grad:
                # Clear gradients once per batch, before the first
                # microbatch's backward pass.
                if i == 0:
                    mp_trainer.zero_grad()
                # Weight each microbatch by its share of the full batch.
                mp_trainer.backward(loss * len(sub_batch) / len(batch))

    # ``step`` stays None when the range below is empty (resume step already
    # at or past args.iterations). Without this sentinel the final save would
    # raise NameError on rank 0 while the other ranks wait at the barrier.
    step = None
    for step in range(args.iterations - resume_step):
        logger.logkv("step", step + resume_step)
        logger.logkv(
            "samples",
            (step + resume_step + 1) * args.batch_size * dist.get_world_size(),
        )
        if args.anneal_lr:
            set_annealed_lr(opt, args.lr,
                            (step + resume_step) / args.iterations)
        forward_backward_log(data)
        mp_trainer.optimize(opt)
        if val_data is not None and not step % args.eval_interval:
            with th.no_grad():
                with model.no_sync():
                    model.eval()
                    forward_backward_log(val_data, prefix="val")
                    model.train()
        if not step % args.log_interval:
            logger.dumpkvs()
        # Skip the very first step of this run, and save from rank 0 only.
        if (step and dist.get_rank() == 0
                and not (step + resume_step) % args.save_interval):
            logger.log("saving model...")
            save_model(mp_trainer, opt, step + resume_step)

    # Final checkpoint, written only if at least one step actually ran.
    if step is not None and dist.get_rank() == 0:
        logger.log("saving model...")
        save_model(mp_trainer, opt, step + resume_step)
    dist.barrier()