def test_isdir(ctx):
    contents = b"meow!"
    with ctx() as path:
        assert not bf.isdir(path)
        _write_contents(path, contents)
        assert not bf.isdir(path)

        dirpath = path + ".dir"
        bf.makedirs(dirpath)
        assert bf.isdir(dirpath)
        assert not bf.isdir(dirpath[:-1])

        filepath = bf.join(path + ".otherdir", "subdir", "file.name")
        if "://" not in path:
            # implicit directory
            bf.makedirs(bf.dirname(filepath))
        dirpath = bf.dirname(bf.dirname(filepath))
        _write_contents(filepath, contents)
        assert bf.isdir(dirpath)
        assert not bf.isdir(dirpath[:-1])
def test_cache_dir(ctx):
    cache_dir = tempfile.mkdtemp()
    contents = b"meow!"
    alternative_contents = b"purr!"
    with ctx() as path:
        with bf.BlobFile(path, mode="wb") as f:
            f.write(contents)
        with bf.BlobFile(path, mode="rb", streaming=False, cache_dir=cache_dir) as f:
            assert f.read() == contents
        content_hash = hashlib.md5(contents).hexdigest()
        cache_path = bf.join(cache_dir, content_hash, bf.basename(path))
        with open(cache_path, "rb") as f:
            assert f.read() == contents
        # alter the cached file to make sure we are not re-reading the remote file
        with open(cache_path, "wb") as f:
            f.write(alternative_contents)
        with bf.BlobFile(path, mode="rb", streaming=False, cache_dir=cache_dir) as f:
            assert f.read() == alternative_contents
def save(self):
    def save_checkpoint(rate, params):
        state_dict = self._master_params_to_state_dict(params)
        if dist.get_rank() == 0:
            logger.log(f"saving model {rate}...")
            if not rate:
                filename = f"model{(self.step+self.resume_step):06d}.pt"
            else:
                filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
            with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
                th.save(state_dict, f)

    # write the main model checkpoint (rate 0) and one checkpoint per EMA rate
    save_checkpoint(0, self.master_params)
    for rate, params in zip(self.ema_rate, self.ema_params):
        save_checkpoint(rate, params)

    # only rank 0 writes the optimizer state
    if dist.get_rank() == 0:
        with bf.BlobFile(
            bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
            "wb",
        ) as f:
            th.save(self.opt.state_dict(), f)

    dist.barrier()
def test_listdir_sharded(ctx):
    contents = b"meow!"
    with ctx() as path:
        dirpath = bf.dirname(path)
        with bf.BlobFile(bf.join(dirpath, "a"), "wb") as w:
            w.write(contents)
        with bf.BlobFile(bf.join(dirpath, "aa"), "wb") as w:
            w.write(contents)
        with bf.BlobFile(bf.join(dirpath, "b"), "wb") as w:
            w.write(contents)
        with bf.BlobFile(bf.join(dirpath, "ca"), "wb") as w:
            w.write(contents)
        bf.makedirs(bf.join(dirpath, "c"))
        with bf.BlobFile(bf.join(dirpath, "c/a"), "wb") as w:
            w.write(contents)
        # this should also test shard_prefix_length=2 but that takes too long
        assert sorted(list(bf.listdir(dirpath, shard_prefix_length=1))) == [
            "a",
            "aa",
            "b",
            "c",
            "ca",
        ]
def test_rmtree(ctx):
    contents = b"meow!"
    with ctx() as path:
        root = bf.dirname(path)
        destroy_path = bf.join(root, "destroy")
        bf.makedirs(destroy_path)
        save_path = bf.join(root, "save")
        bf.makedirs(save_path)

        # implicit dir
        if "://" not in path:
            bf.makedirs(bf.join(destroy_path, "adir"))
        with bf.BlobFile(bf.join(destroy_path, "adir/b"), "wb") as w:
            w.write(contents)

        # explicit dir
        bf.makedirs(bf.join(destroy_path, "bdir"))
        with bf.BlobFile(bf.join(destroy_path, "bdir/b"), "wb") as w:
            w.write(contents)

        bf.makedirs(bf.join(save_path, "somedir"))
        with bf.BlobFile(bf.join(save_path, "somefile"), "wb") as w:
            w.write(contents)

        def assert_listing_equal(path, desired):
            actual = list(bf.walk(path))
            # ordering of os walk is weird, only compare sorted order
            assert sorted(actual) == sorted(desired), f"{actual} != {desired}"

        assert_listing_equal(
            root,
            [
                (root, ["destroy", "save"], []),
                (destroy_path, ["adir", "bdir"], []),
                (bf.join(destroy_path, "adir"), [], ["b"]),
                (bf.join(destroy_path, "bdir"), [], ["b"]),
                (save_path, ["somedir"], ["somefile"]),
                (bf.join(save_path, "somedir"), [], []),
            ],
        )

        bf.rmtree(destroy_path)
        assert_listing_equal(
            root,
            [
                (root, ["save"], []),
                (save_path, ["somedir"], ["somefile"]),
                (bf.join(save_path, "somedir"), [], []),
            ],
        )
def test_glob(ctx, parallel):
    contents = b"meow!"
    with ctx() as path:
        dirpath = bf.dirname(path)
        a_path = bf.join(dirpath, "ab")
        with bf.BlobFile(a_path, "wb") as w:
            w.write(contents)
        b_path = bf.join(dirpath, "bb")
        with bf.BlobFile(b_path, "wb") as w:
            w.write(contents)

        def assert_listing_equal(path, desired):
            desired = sorted([bf.join(dirpath, p) for p in desired])
            actual = sorted(list(bf.glob(path, parallel=parallel)))
            assert actual == desired, f"{actual} != {desired}"

        assert_listing_equal(bf.join(dirpath, "*b"), ["ab", "bb"])
        assert_listing_equal(bf.join(dirpath, "a*"), ["ab"])
        assert_listing_equal(bf.join(dirpath, "ab*"), ["ab"])
        assert_listing_equal(bf.join(dirpath, "*"), ["ab", "bb"])
        assert_listing_equal(bf.join(dirpath, "bb"), ["bb"])

        path = bf.join(dirpath, "test.txt")
        with bf.BlobFile(path, "wb") as w:
            w.write(contents)
        path = bf.join(dirpath, "subdir", "test.txt")
        bf.makedirs(bf.dirname(path))
        with bf.BlobFile(path, "wb") as f:
            f.write(contents)
        path = bf.join(dirpath, "subdir", "subsubdir", "test.txt")
        if "://" not in path:
            # implicit directory
            bf.makedirs(bf.dirname(path))
        with bf.BlobFile(path, "wb") as f:
            f.write(contents)

        assert_listing_equal(bf.join(dirpath, "*/test.txt"), ["subdir/test.txt"])
        assert_listing_equal(bf.join(dirpath, "*/*.txt"), ["subdir/test.txt"])
        if "://" in path:
            # local glob doesn't handle ** the same way as remote glob
            assert_listing_equal(
                bf.join(dirpath, "**.txt"),
                ["test.txt", "subdir/test.txt", "subdir/subsubdir/test.txt"],
            )
        else:
            assert_listing_equal(bf.join(dirpath, "**.txt"), ["test.txt"])
        assert_listing_equal(bf.join(dirpath, "*/test"), [])
        assert_listing_equal(bf.join(dirpath, "subdir/test.txt"), ["subdir/test.txt"])

        # directories
        assert_listing_equal(bf.join(dirpath, "*"), ["ab", "bb", "subdir", "test.txt"])
        assert_listing_equal(bf.join(dirpath, "subdir"), ["subdir"])
        assert_listing_equal(bf.join(dirpath, "subdir/"), ["subdir"])
        assert_listing_equal(bf.join(dirpath, "*/"), ["subdir"])
        assert_listing_equal(bf.join(dirpath, "*dir"), ["subdir"])
        assert_listing_equal(bf.join(dirpath, "subdir/*dir"), ["subdir/subsubdir"])
        assert_listing_equal(bf.join(dirpath, "subdir/*dir/"), ["subdir/subsubdir"])
        assert_listing_equal(bf.join(dirpath, "su*ir/*dir/"), ["subdir/subsubdir"])
def test_join():
    testcases = [
        ("a", "b", "a/b"),
        ("a/b", "c", "a/b/c"),
        ("a/b/", "c", "a/b/c"),
        ("a/b/", "c/", "a/b/c/"),
        ("a/b/", "/c/", "/c/"),
        ("", "", ""),
        # this doesn't work with : in the second path
        (
            "gs://a/b/c",
            "d0123456789-._~!$&'()*+,;=@",
            "gs://a/b/c/d0123456789-._~!$&'()*+,;=@",
        ),
        ("gs://a", "b", "gs://a/b"),
        ("gs://a/b", "c", "gs://a/b/c"),
        ("gs://a/b/", "c", "gs://a/b/c"),
        ("gs://a/b/", "c/", "gs://a/b/c/"),
        ("gs://a/b/", "/c/", "gs://a/c/"),
        ("gs://a/b/", "../c", "gs://a/c"),
        ("gs://a/b/", "../c/", "gs://a/c/"),
        ("gs://a/b/", "../../c/", "gs://a/c/"),
        (
            "https://a.blob.core.windows.net/container",
            "b",
            "https://a.blob.core.windows.net/container/b",
        ),
        (
            "https://a.blob.core.windows.net/container/b",
            "c",
            "https://a.blob.core.windows.net/container/b/c",
        ),
        (
            "https://a.blob.core.windows.net/container/b/",
            "c",
            "https://a.blob.core.windows.net/container/b/c",
        ),
        (
            "https://a.blob.core.windows.net/container/b/",
            "c/",
            "https://a.blob.core.windows.net/container/b/c/",
        ),
        (
            "https://a.blob.core.windows.net/container/b/",
            "/c/",
            "https://a.blob.core.windows.net/container/c/",
        ),
        (
            "https://a.blob.core.windows.net/container/b/",
            "../c",
            "https://a.blob.core.windows.net/container/c",
        ),
        (
            "https://a.blob.core.windows.net/container/b/",
            "../c/",
            "https://a.blob.core.windows.net/container/c/",
        ),
        (
            "https://a.blob.core.windows.net/container/b/",
            "../../c/",
            "https://a.blob.core.windows.net/container/c/",
        ),
        ("gs://test/a/b", "c:d", "gs://test/a/b/c:d"),
    ]
    for input_a, input_b, desired_output in testcases:
        actual_output = bf.join(input_a, input_b)
        assert desired_output == actual_output, f"{input_a} {input_b}"
def main():
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model and diffusion...")
    model, diffusion = create_classifier_and_diffusion(
        **args_to_dict(args, classifier_and_diffusion_defaults().keys())
    )
    model.to(dist_util.dev())
    if args.noised:
        schedule_sampler = create_named_schedule_sampler(
            args.schedule_sampler, diffusion
        )

    resume_step = 0
    if args.resume_checkpoint:
        resume_step = parse_resume_step_from_filename(args.resume_checkpoint)
        if dist.get_rank() == 0:
            logger.log(
                f"loading model from checkpoint: {args.resume_checkpoint}... at {resume_step} step"
            )
            model.load_state_dict(
                dist_util.load_state_dict(
                    args.resume_checkpoint, map_location=dist_util.dev()
                )
            )

    # Needed for creating correct EMAs and fp16 parameters.
    dist_util.sync_params(model.parameters())

    mp_trainer = MixedPrecisionTrainer(
        model=model, use_fp16=args.classifier_use_fp16, initial_lg_loss_scale=16.0
    )

    model = DDP(
        model,
        device_ids=[dist_util.dev()],
        output_device=dist_util.dev(),
        broadcast_buffers=False,
        bucket_cap_mb=128,
        find_unused_parameters=False,
    )

    logger.log("creating data loader...")
    data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=True,
        random_crop=True,
    )
    if args.val_data_dir:
        val_data = load_data(
            data_dir=args.val_data_dir,
            batch_size=args.batch_size,
            image_size=args.image_size,
            class_cond=True,
        )
    else:
        val_data = None

    logger.log("creating optimizer...")
    opt = AdamW(mp_trainer.master_params, lr=args.lr, weight_decay=args.weight_decay)
    if args.resume_checkpoint:
        opt_checkpoint = bf.join(
            bf.dirname(args.resume_checkpoint), f"opt{resume_step:06}.pt"
        )
        logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
        opt.load_state_dict(
            dist_util.load_state_dict(opt_checkpoint, map_location=dist_util.dev())
        )

    logger.log("training classifier model...")

    def forward_backward_log(data_loader, prefix="train"):
        batch, extra = next(data_loader)
        labels = extra["y"].to(dist_util.dev())

        batch = batch.to(dist_util.dev())
        # Noisy images
        if args.noised:
            t, _ = schedule_sampler.sample(batch.shape[0], dist_util.dev())
            batch = diffusion.q_sample(batch, t)
        else:
            t = th.zeros(batch.shape[0], dtype=th.long, device=dist_util.dev())

        for i, (sub_batch, sub_labels, sub_t) in enumerate(
            split_microbatches(args.microbatch, batch, labels, t)
        ):
            logits = model(sub_batch, timesteps=sub_t)
            loss = F.cross_entropy(logits, sub_labels, reduction="none")

            losses = {}
            losses[f"{prefix}_loss"] = loss.detach()
            losses[f"{prefix}_acc@1"] = compute_top_k(
                logits, sub_labels, k=1, reduction="none"
            )
            losses[f"{prefix}_acc@5"] = compute_top_k(
                logits, sub_labels, k=5, reduction="none"
            )
            log_loss_dict(diffusion, sub_t, losses)
            del losses
            loss = loss.mean()
            if loss.requires_grad:
                if i == 0:
                    mp_trainer.zero_grad()
                mp_trainer.backward(loss * len(sub_batch) / len(batch))

    # main training loop: log, optionally evaluate, and periodically checkpoint
    for step in range(args.iterations - resume_step):
        logger.logkv("step", step + resume_step)
        logger.logkv(
            "samples",
            (step + resume_step + 1) * args.batch_size * dist.get_world_size(),
        )
        if args.anneal_lr:
            set_annealed_lr(opt, args.lr, (step + resume_step) / args.iterations)
        forward_backward_log(data)
        mp_trainer.optimize(opt)
        if val_data is not None and not step % args.eval_interval:
            with th.no_grad():
                with model.no_sync():
                    model.eval()
                    forward_backward_log(val_data, prefix="val")
                    model.train()
        if not step % args.log_interval:
            logger.dumpkvs()
        if (
            step
            and dist.get_rank() == 0
            and not (step + resume_step) % args.save_interval
        ):
            logger.log("saving model...")
            save_model(mp_trainer, opt, step + resume_step)

    # save a final checkpoint after the last step
    if dist.get_rank() == 0:
        logger.log("saving model...")
        save_model(mp_trainer, opt, step + resume_step)
    dist.barrier()
def main(H: HParams):
    layout = H.reward_model_spec.run_params.all_gpu_layout()

    reward_model = RewardModel(task_hparams=H.task, spec=H.reward_model_spec, layout=layout)

    setup_logging_with_pacific_tz()

    act_dtype = torch.float16 if H.fp16_activations else torch.float32

    results_dir = bf.join(
        os.environ.get("OUTPUT_DIR", os.path.join("/tmp/jobs", os.getenv("JOB_NAME"))),
        "results",
    )
    bf.makedirs(results_dir)

    if layout.is_logging_rank:
        with open(bf.join(results_dir, "task_hparams.json"), "w") as f:
            json.dump(H.task.to_json(), f)
        with open(bf.join(results_dir, "hparams.json"), "w") as f:
            json.dump(H.to_json(), f)

    # Creates files for printing. Only the replica root prints the files
    output_file_name = os.devnull
    if layout.is_replica_root:
        fname = f"samples.{layout.replica_idx}.jsonl"
        output_file_name = bf.join(results_dir, fname)
        print(f"Outputs will be written to {output_file_name}")

    input_iter = make_jsonl_samples_iter(H.input_path, layout=layout)

    replica_rewards = []

    with open(output_file_name, "a") as out_f:
        input_idx = 0
        for input in input_iter:
            with Timer() as timer:
                query_tokens = torch.tensor(input["context_tokens"])
                assert_shape_eq(
                    query_tokens, (H.task.query.length,), "Context tokens shape mismatch"
                )
                response_tokens = torch.tensor(input["sample_tokens"])
                assert_eq(response_tokens.dim(), 2)

                n_responses = response_tokens.size(0)

                # score the sampled responses for this query with the reward model
                results = reward_model.reward(
                    query_tokens=query_tokens.unsqueeze(0),
                    response_tokens=response_tokens.unsqueeze(0),
                    act_dtype=act_dtype,
                )

                rewards = to_numpy(results["reward"].reshape((n_responses,)))

                if layout.is_replica_root:
                    replica_rewards.append(rewards)

                    output = {**input, H.output_key: rewards}
                    out_f.write(json.dumps(jsonl_encoding.encode_example(output)) + "\n")

            input_idx += 1
            if layout.is_replica_root:
                print(f"Batch {input_idx}. Took {timer.interval} seconds")

    # gather per-replica rewards and report summary statistics on replica 0
    if layout.is_replica_root:
        print(f"Wrote {input_idx} batches to {output_file_name}")

        replica_rewards = np.stack(replica_rewards, axis=0)
        all_rewards = reward_model.dp_comm.mpi_all_gather(replica_rewards, "rewards")
        if layout.replica_idx == 0:
            all_rewards = np.concatenate(all_rewards, axis=0)
            print(f"Mean reward: {all_rewards.mean():.3f}")
            if all_rewards.shape[1] > 1:
                print(f"Stddev within a query: {all_rewards.std(axis=1, ddof=1).mean():.3}")
                print(f"Stddev across queries: {all_rewards.std(axis=0, ddof=1).mean():.3}")
    return dict(output_path=results_dir)