def test_inference(self):
  """Checks ResNet18 init/apply shapes, with and without mutable batch stats."""
  resnet_graph, resnet_constants, _ = resnetv1.ResNet18(
      num_classes=10, input_resolution="small")
  net = Model(resnet_graph, resnet_constants)
  variables = net.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))

  # Exactly two variable collections are expected after init.
  self.assertLen(variables, 2)
  for collection in ("params", "batch_stats"):
    self.assertIn(collection, variables)

  # Plain inference, reading one named output from the output dict.
  logits = net.apply(variables,
                     {"input": jnp.ones((10, 32, 32, 3))})["fc/dense"]
  self.assertEqual(logits.shape, (10, 10))

  # With batch_stats mutable, apply also returns the updated collections.
  outputs, updated = net.apply(
      variables, {"input": jnp.ones((10, 32, 32, 3))},
      mutable=["batch_stats"])
  self.assertEqual(outputs["fc/dense"].shape, (10, 10))
  self.assertIn("batch_stats", updated)
def _synthesize(self, subg, props):
  """Enumeratively synthesizes a replacement for `subg`; checks logits change."""
  candidates = EnumerativeSequentialSynthesizer(
      [(subg, props)], 0, max_len=3).synthesize()
  synthesized = candidates[0]
  new_model = Model(synthesized.graph, self.constants)
  new_state = new_model.init(random.PRNGKey(0), self.input)
  new_logits = new_model.apply(new_state, self.input)["fc/logits"]
  # The synthesized graph must not reproduce the reference logits exactly.
  self.assertTrue((new_logits != self.out).any())
def test_multi_input_output(self):
  """Tests a subgraph substitution on a graph with multiple inputs / output ops.

  A ResNet is used because its skip connections give the replaced region more
  than one input/output. The test checks the op count after substitution and
  that the rewritten graph still executes.
  """
  orig_graph, constants, _ = resnetv1.ResNet18(
      num_classes=10, input_resolution="small")
  init_input = jnp.ones((1, 32, 32, 3))
  batch = jnp.ones((10, 32, 32, 3))

  orig_model = Model(orig_graph, constants)
  orig_state = orig_model.init(random.PRNGKey(0), init_input)
  self.assertEqual(orig_model.apply(orig_state, batch).shape, (10, 10))

  conv_node = subgraph.SubgraphNode(
      op=new_op(
          op_name="subgraph/conv0",
          op_type=OpType.CONV,
          op_kwargs={"features": 64, "kernel_size": [1, 1]},
          input_names=["resnet11/skip/relu1"]))
  gelu_node = subgraph.SubgraphNode(
      op=new_op(
          op_name="subgraph/gelu1",
          op_type=OpType.GELU,
          input_names=["subgraph/conv0"]),
      output_names=["resnet_stride1_filtermul1_basic12/relu2"])
  rewritten = subgraph.replace_subgraph(orig_graph, [conv_node, gelu_node])

  # Two new ops (conv / gelu) stand in for three old ones (conv / bn / relu).
  self.assertLen(orig_graph.ops, len(rewritten.ops) + 1)

  rewritten_model = Model(rewritten, constants)
  rewritten_state = rewritten_model.init(random.PRNGKey(0), init_input)
  self.assertEqual(
      rewritten_model.apply(rewritten_state, batch).shape, (10, 10))
def test_inference(self):
  """Checks EfficientNetB0 init/apply output shapes and batch-stat updates."""
  graph, constants, _ = efficientnet.EfficientNetB0(num_classes=10)
  model = Model(graph, constants)
  state = model.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))

  # init is expected to produce exactly these two variable collections.
  self.assertLen(state, 2)
  self.assertIn("params", state)
  self.assertIn("batch_stats", state)

  inp = {"input": jnp.ones((10, 32, 32, 3))}
  out = model.apply(state, inp)["head/out"]
  self.assertEqual(out.shape, (10, 10))

  # With batch_stats mutable, apply also returns the updated collections.
  output_dict, new_state = model.apply(state, inp, mutable=["batch_stats"])
  self.assertEqual(output_dict["head/out"].shape, (10, 10))
  self.assertIn("batch_stats", new_state)
class ModelTest(test.TestCase):
  """Smoke tests for building and running a CifarNet `Model`."""

  def setUp(self):
    super().setUp()
    cifar_graph, cifar_constants, _ = cnn.CifarNet()
    self.cnn = Model(cifar_graph, cifar_constants)
    self.cnn_state = self.cnn.init(
        random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))

  def test_cnn_inference(self):
    """A bare array input yields a (10, 10) array for a batch of 10."""
    logits = self.cnn.apply(self.cnn_state, jnp.ones((10, 32, 32, 3)))
    self.assertEqual(logits.shape, (10, 10))

  def test_cnn_inference_dict(self):
    """A dict input yields a dict of outputs keyed by op name."""
    named_outputs = self.cnn.apply(
        self.cnn_state, {"input": jnp.ones((10, 32, 32, 3))})
    self.assertEqual(named_outputs["fc/logits"].shape, (10, 10))

  def test_cnn_params(self):
    """Pins the total parameter count of CifarNet."""
    param_tree = flax.core.unfreeze(self.cnn_state)["params"]
    self.assertEqual(
        parameter_overview.count_parameters(param_tree), 2192458)
def _synthesize(self, subg, props):
  """Runs weighted progressive synthesis on `subg`; checks logits change."""
  mode = ProgressiveSequentialSynthesizer.Mode.WEIGHTED
  synthesizer = ProgressiveSequentialSynthesizer(
      [(subg, props)], generation=0, mode=mode, max_len=3)
  result = synthesizer.synthesize()[0]
  # Dump the synthesized subgraph for debugging.
  for spec_node in result.subgraph:
    print(spec_node.op.name)
    print(spec_node.output_names)
  candidate = Model(result.graph, self.constants)
  candidate_state = candidate.init(random.PRNGKey(0), self.input)
  logits = candidate.apply(candidate_state, self.input)["fc/logits"]
  self.assertTrue((logits != self.out).any())
class SubgraphTest(test.TestCase):
  """Tests subgraph replacement and weight inheritance."""

  def setUp(self):
    super().setUp()
    self.graph, self.constants, _ = cnn.CifarNet()
    self.model = Model(self.graph, self.constants)
    self.state = self.model.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))
    # A 1x1 conv reading "conv_layer0/avg_pool", followed by a gelu whose
    # output is exported under the name "conv_layer1/relu".
    self.subgraph = [
        subgraph.SubgraphNode(
            op=new_op(
                op_name="conv_layer1/conv/1",
                op_type=OpType.CONV,
                op_kwargs={"features": 64, "kernel_size": [1, 1]},
                input_names=["conv_layer0/avg_pool"]),),
        subgraph.SubgraphNode(
            op=new_op(
                op_name="conv_layer1/gelu/1",
                op_type=OpType.GELU,
                input_names=["conv_layer1/conv/1"]),
            output_names=["conv_layer1/relu"]),
    ]
    self.new_graph = subgraph.replace_subgraph(self.graph, self.subgraph)
    self.new_model = Model(self.new_graph, self.constants)
    self.new_state = self.new_model.init(random.PRNGKey(0),
                                         jnp.ones((1, 32, 32, 3)))

  def test_subgraph_inserted(self):
    """Tests whether subgraph nodes were inserted."""
    new_op_names = {op.name for op in self.new_graph.ops}
    for node in self.subgraph:
      self.assertIn(node.op.name, new_op_names,
                    f"Did not find {node.op.name} in new graph")

  def test_subgraph_execution(self):
    """Tests whether new graph can be executed."""
    y = self.new_model.apply(self.new_state, jnp.ones((10, 32, 32, 3)))
    self.assertEqual(y.shape, (10, 10))

  def test_subgraph_pruning(self):
    """Tests whether new graph was pruned of old nodes."""
    new_params = flax.core.unfreeze(self.new_state)["params"]
    new_param_count = parameter_overview.count_parameters(new_params)
    params = flax.core.unfreeze(self.state)["params"]
    param_count = parameter_overview.count_parameters(params)
    self.assertLess(new_param_count, param_count)

  def test_weight_inheritance(self):
    """Tests weight inheritance."""
    old_params = flax.core.unfreeze(self.state)["params"]
    new_params = flax.core.unfreeze(self.new_state)["params"]
    frozen_params, trainable_params = subgraph.inherit_params(
        new_params, old_params)
    # Every new param is either inherited (frozen) or freshly trainable.
    self.assertLen(new_params, len(trainable_params) + len(frozen_params))
    for param in ["fc/dense", "fc/logits", "conv_layer0/conv"]:
      self.assertIn(param, frozen_params,
                    f"expected param {param} to be frozen")
    self.assertIn("conv_layer1/conv/1", trainable_params,
                  "expected param conv_layer1/conv/1 to be trainable")

  def test_multi_input_output(self):
    """Tests a subgraph substitution on a graph with multiple inputs / output ops.

    We use a ResNet model, which has skip connections. This test checks that
    the substitution produces the expected number of ops, and also that the
    newly produced graph is still executable.
    """
    graph, constants, _ = resnetv1.ResNet18(
        num_classes=10, input_resolution="small")
    model = Model(graph, constants)
    state = model.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))
    y = model.apply(state, jnp.ones((10, 32, 32, 3)))
    self.assertEqual(y.shape, (10, 10))

    subg = [
        subgraph.SubgraphNode(
            op=new_op(
                op_name="subgraph/conv0",
                op_type=OpType.CONV,
                op_kwargs={"features": 64, "kernel_size": [1, 1]},
                input_names=["resnet11/skip/relu1"])),
        subgraph.SubgraphNode(
            op=new_op(
                op_name="subgraph/gelu1",
                op_type=OpType.GELU,
                input_names=["subgraph/conv0"]),
            output_names=["resnet_stride1_filtermul1_basic12/relu2"]),
    ]
    new_graph = subgraph.replace_subgraph(graph, subg)

    # the subgraph is 2 ops (conv / gelu) replacing 3 ops (conv / bn / relu)
    self.assertLen(graph.ops, len(new_graph.ops) + 1)

    new_model = Model(new_graph, constants)
    new_state = new_model.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))
    y = new_model.apply(new_state, jnp.ones((10, 32, 32, 3)))
    self.assertEqual(y.shape, (10, 10))
class GraphSynthesizerTest(test.TestCase):
  """Tests `GraphSynthesizer` driven by a progressive sequential synthesizer."""

  def setUp(self):
    super().setUp()
    seed = int(time.time())
    logging.info("Seed: %d", seed)
    py_rand.seed(seed)
    graph, constants, _ = cnn.CifarNet()
    self._use_graph(graph, constants)

  def _use_graph(self, graph, constants):
    """Makes `graph` the fixture: model, init state and reference output."""
    self.graph, self.constants = graph, constants
    self.m = Model(self.graph, self.constants)
    self.input = {"input": jnp.ones((5, 32, 32, 3))}
    self.state = self.m.init(random.PRNGKey(0), self.input)
    self.out = self.m.apply(self.state, self.input)[self.graph.output_names[0]]
    self.max_size = int(10e8)

  def _subgraph_model(self, ops, output_names):
    """Wraps `ops` as a SubgraphModel whose last node exports `output_names`."""
    subg = [subgraph.SubgraphNode(op=o) for o in ops]
    subg[-1].output_names = output_names
    return SubgraphModel(self.graph, self.constants, self.state, self.input,
                         subg)

  def _synthesize(self, subg, props):
    """Synthesizes a replacement and checks the graph fingerprint changed."""
    ctr = functools.partial(PSS, generation=0, max_len=3, filter_progress=True)
    synthesizer = GraphSynthesizer([(subg, props)],
                                   sequential_ctr=ctr,
                                   generation=0)
    subg = synthesizer.synthesize()[0]
    subg_spec = subg.subgraph
    logging.info("synthesized...")
    for node in subg_spec:
      logging.info("%s", node.op.name)
      logging.info("%s", node.output_names)
    logging.info("")

    fingerprint_orig = fingerprint_graph(self.graph, self.constants,
                                         self.input)
    fingerprint_new = fingerprint_graph(subg.graph, self.constants, self.input)
    self.assertNotEqual(fingerprint_orig, fingerprint_new)

  def test_synthesizer_easy_one(self):
    """Replacing [conv3x3(features = 64)]."""
    subgraph_model = self._subgraph_model(self.graph.ops[4:5],
                                          self.graph.ops[5].input_names)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    dp = depth.DepthProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, dp])

  def test_synthesizer_easy_two(self):
    """Replacing [conv3x3(features = 64)], also checking the linear property."""
    # Fixed seed; overrides the time-based seed chosen in setUp.
    py_rand.seed(10)
    subgraph_model = self._subgraph_model(self.graph.ops[4:5],
                                          self.graph.ops[5].input_names)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    dp = depth.DepthProperty().infer(subgraph_model)
    lp = linear.LinopProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, dp, lp])

  def test_synthesizer_one(self):
    """Replacing [conv3x3(features = 64), ReLU]."""
    subgraph_model = self._subgraph_model(self.graph.ops[4:6],
                                          self.graph.ops[6].input_names)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    dp = depth.DepthProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, dp])

  def test_synthesizer_two(self):
    """Replacing [conv3x3(features = 64), ReLU, avgpool2x2(strides=2x2)]."""
    subgraph_model = self._subgraph_model(self.graph.ops[4:7],
                                          self.graph.ops[7].input_names)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    lp = linear.LinopProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, lp])

  def test_synthesizer_resnet_small(self):
    """Replacing ResNet18 ops [4] + [9:11] (shape and linear properties)."""
    graph, constants, _ = resnetv1.ResNet18(
        num_classes=10, input_resolution="small")
    self._use_graph(graph, constants)
    subg_ops = [self.graph.ops[4]] + self.graph.ops[9:11]
    subgraph_model = self._subgraph_model(subg_ops,
                                          [f"{subg_ops[-1].name}:0"])
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    lp = linear.LinopProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, lp])

  def test_synthesizer_resnet_big(self):
    """Replacing ResNet18 ops [3:5] + [8:12] (shape and linear properties)."""
    graph, constants, _ = resnetv1.ResNet18(
        num_classes=10, input_resolution="small")
    self._use_graph(graph, constants)
    subg_ops = self.graph.ops[3:5] + self.graph.ops[8:12]
    subgraph_model = self._subgraph_model(subg_ops,
                                          [f"{subg_ops[-1].name}:0"])
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    lp = linear.LinopProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, lp])
def train_and_eval(
    config,
    eval_perf=True,
    callback=None,
):
  """Training loop + eval.

  Args:
    config: The training config. Reads `config_dict` (optionally with a
      `train` sub-config), `output_dir`, `graph`, `subgraph`,
      `inherit_weights`, `freeze_inherited`, `train_subg_outputs`, `init_dir`
      and `checkpoint_steps`.
    eval_perf: Whether to collect performance metrics.
    callback: A callback which accepts the current epoch number and
      performance metrics, and returns True to continue training or False to
      early stop.

  Returns:
    A tuple of the final metrics, number of epochs trained, and the state dict.

  Raises:
    ValueError: if, on a non-host process, the output directory does not
      appear within GFILE_SLEEP_SEC * GFILE_TRIES seconds; or if weight
      inheritance is requested without `config.init_dir`.
    NotImplementedError: if a subgraph is given together with
      inherit_weights, freeze_inherited and train_subg_outputs.
  """
  if "train" in config.config_dict:
    train_config = config.config_dict.train
  else:
    train_config = config.config_dict

  rng = jax.random.PRNGKey(train_config.seed)

  is_host = jax.process_index() == 0
  if is_host:
    # The pool is used to perform operations such as checkpointing in async way.
    pool = multiprocessing.pool.ThreadPool(2)
  else:
    pool = None

  # set up output directory
  if config.output_dir is not None:
    if is_host:
      if gfile.exists(config.output_dir):
        logging.warning("Output directory %s already exists.",
                        config.output_dir)
      else:
        gfile.makedirs(config.output_dir)
      utils.write_to_store(config, f"{config.output_dir}/config")
    else:
      # Non-host processes poll until the host has created the directory.
      ready = False
      for _ in range(GFILE_TRIES):
        ready = gfile.exists(config.output_dir)
        if ready:
          break
        time.sleep(GFILE_SLEEP_SEC)
      if not ready:
        raise ValueError(f"Output directory {config.output_dir} was not "
                         f"created within {GFILE_SLEEP_SEC * GFILE_TRIES} "
                         "secs.")

  # get data
  num_devices = jax.device_count()
  batch_size = train_config.device_batch_size * num_devices
  # NOTE(review): batch_size is built as a multiple of num_devices, so this can
  # only fail if device_batch_size is not an integer.
  if batch_size % num_devices != 0:
    raise ValueError(f"JAX num_devices {num_devices} does not divide "
                     f"batch_size {batch_size}.")
  local_batch_size = batch_size // jax.process_count()
  local_batch_size_eval = local_batch_size * 8
  if is_host:
    logging.info(
        "Global batch size %d on %d hosts results in %d local batch size. "
        "With %d dev per host (%d dev total), that's a %d per-device batch "
        "size.", batch_size, jax.process_count(), local_batch_size,
        jax.local_device_count(), jax.device_count(),
        local_batch_size // jax.local_device_count())

  train_pp = preprocess.get_preprocess_fn(
      train_config.dataset_name, train_config.dataset.train_split,
      **train_config.dataset.get("preprocess_kwargs", {}))
  train_ds = input_pipeline.make_for_train(
      dataset=train_config.dataset_name,
      split=train_config.dataset.train_split,
      preprocess_fn=train_pp,
      batch_size=local_batch_size,
      shuffle_buffer_size=250_000,
      prefetch=2,
      cache_raw=False)
  train_iter = input_pipeline.start_input_pipeline(train_ds, n_prefetch=1)
  ntrain_img = input_pipeline.get_num_examples(
      train_config.dataset_name, train_config.dataset.train_split)
  steps_per_epoch = ntrain_img / batch_size
  total_steps = int(steps_per_epoch * train_config.epochs)

  eval_pp = preprocess.get_preprocess_fn(train_config.dataset_name,
                                         train_config.dataset.val_split)
  eval_ds, eval_steps = input_pipeline.make_for_inference(
      dataset=train_config.dataset_name,
      split=train_config.dataset.val_split,
      preprocess_fn=eval_pp,
      batch_size=local_batch_size_eval,
      cache_final=True,
      cache_raw=False,
      data_dir=None)
  eval_it = input_pipeline.start_input_pipeline(eval_ds, n_prefetch=1)

  # set up model
  graph = config.graph
  if isinstance(graph, tuple):
    graph, constants = graph[0], graph[1]
  else:
    constants = None
  if config.subgraph is not None:
    graph = replace_subgraph(graph, config.subgraph)
    if (config.inherit_weights and config.freeze_inherited and
        config.train_subg_outputs):
      # TODO(charlesjin) finish training with weight inheritance
      output_names = sum([node.output_names for node in config.subgraph], [])
      graph.output_names = output_names
      raise NotImplementedError
  model = Model(graph, constants)

  # We want all parameters to be created in host RAM, not on any device, they'll
  # be sent there later as needed, otherwise we already encountered two
  # situations where we allocate them twice.
  @partial(jax.jit, backend="cpu")
  def init(rng):
    image_size = tuple(train_ds.element_spec["image"].shape[1:])
    dummy_input = jnp.zeros((1,) + image_size, jnp.float32)
    return flax.core.unfreeze(model.init(rng, dummy_input))

  rng, rng_init = jax.random.split(rng)
  state_cpu = init(rng_init)
  params_cpu = state_cpu["params"]
  if "batch_stats" in state_cpu:
    # Non-param variable collections. Currently we only support the additional
    # collection batch_stats, which is Flax's convention for batchnorm.
    coll_cpu = {"batch_stats": state_cpu["batch_stats"]}
  else:
    coll_cpu = {}

  # weight inheritance
  if config.inherit_weights:
    if config.init_dir is None:
      raise ValueError("Cannot inherit weights without parent directory.")
    parent_state = bv_utils.load_checkpoint(None, f"{config.init_dir}/state")
    parent_params = parent_state["params"]
    old_params, new_params = inherit_params(params_cpu, parent_params)
    if config.freeze_inherited:
      trainable_params = new_params
      frozen_params = old_params
    else:
      trainable_params = {**old_params, **new_params}
      frozen_params = {}
  else:
    trainable_params = params_cpu
    frozen_params = {}

  if is_host:
    if trainable_params:
      logging.info("trainable params:")
      for key in trainable_params:
        logging.info(" %s", key)
    else:
      logging.warning("WARNING: no trainable params!")
    if frozen_params:
      logging.info("frozen params:")
      for key in frozen_params:
        logging.info(" %s", key)

  # training step
  # `tx` is looked up at call time; it is bound below when there are
  # trainable params, which is the only case in which train_step runs.
  @partial(jax.pmap, axis_name="batch", donate_argnums=(0, 1, 3,))
  def train_step(opt, params, other_params, coll, data, labels, rng):
    """Trains for a single step."""
    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index("batch"))

    def loss_fn(params):
      all_params = {**params, **other_params}
      logits, new_coll = Model(graph, constants).apply(
          flax.core.freeze({
              "params": all_params,
              **coll
          }),
          data,
          rngs={"dropout": rng_model_local},
          mutable=list(coll.keys()),
          deterministic=False,
          training=True)
      loss = jnp.mean(bv_utils.softmax_xent(logits=logits, labels=labels))
      return loss, (logits, loss, new_coll)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    aux, grads = grad_fn(params)
    # aux is (loss, (logits, loss, new_coll)).
    _, loss, new_coll = aux[1]
    grads = jax.lax.pmean(grads, axis_name="batch")
    updates, opt = tx.update(grads, opt, params)
    params = optax.apply_updates(params, updates)
    return opt, params, new_coll, jax.lax.psum(loss, axis_name="batch"), rng

  cross_replica_mean = jax.pmap(lambda x: jax.lax.pmean(x, axis_name="batch"),
                                axis_name="batch")

  # eval step
  @partial(jax.pmap, axis_name="batch")
  def eval_step(params, coll, data, labels, mask):
    mask *= labels.max(axis=1)
    logits = Model(graph, constants).apply(
        flax.core.freeze({
            "params": params,
            **coll
        }),
        data,
        deterministic=True,
        training=False)
    loss = jnp.mean(bv_utils.softmax_xent(logits=logits, labels=labels))
    top1_idx = jnp.argmax(logits, axis=1)
    top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]
    correct = top1_correct * mask
    return (jax.lax.psum(correct, axis_name="batch"),
            jax.lax.psum(loss, axis_name="batch"),
            jax.lax.psum(mask, axis_name="batch"))

  def eval_model(params, coll, eval_it):
    """Runs eval_steps batches; returns (correct, loss, count, seconds)."""
    total_correct = 0
    total_loss = 0
    total = 0
    eval_time = 0
    eval_start = time.time()
    for _, batch in zip(range(eval_steps), eval_it):
      correct, loss, neval = eval_step(params, coll, batch["image"],
                                       batch["labels"], batch["_mask"])
      total_correct += jnp.sum(correct[0])
      total_loss += jnp.sum(loss[0])
      total += jnp.sum(neval[0])
    if total:
      total.block_until_ready()
    eval_time += time.time() - eval_start
    return total_correct, total_loss, total, eval_time

  if eval_perf and is_host:
    num_params = perf_tools.compute_num_params(params_cpu)
    image_size = tuple(train_ds.element_spec["image"].shape[1:])
    dummy_input = jnp.zeros((1,) + image_size, jnp.float32)
    apply_fn = lambda v, inp: model.apply(  # pylint: disable=g-long-lambda
        v, inp, deterministic=True, training=False)
    flops = perf_tools.compute_num_flops(
        apply_fn,
        True,  # optimize
        flax.core.freeze({
            "params": params_cpu,
            **coll_cpu
        }),
        dummy_input)
    print(f"num_params: {num_params} | flops: {flops}")
  else:
    num_params = 0
    flops = 0

  im_sec_core_eval_measurements = np.array([])
  im_sec_core_train_measurements = np.array([])
  last_step = 0
  checkpoint_extra = dict(
      im_sec_core_eval_measurements=im_sec_core_eval_measurements,
      im_sec_core_train_measurements=im_sec_core_train_measurements,
      step=last_step)

  if config.output_dir is not None:
    checkpoint_path = f"{config.output_dir}/checkpoint.npz"
  else:
    checkpoint_path = None

  if trainable_params:
    # NOTE(review): the optimizer is built from `params_cpu` but initialized
    # and updated on `trainable_params`; these trees differ whenever weights
    # are frozen. Confirm bv_optax.make does not derive masks from this tree.
    tx, _ = bv_optax.make(
        train_config.optim,
        params_cpu,
        sched_kw=dict(
            global_batch_size=batch_size,
            total_steps=total_steps,
            steps_per_epoch=steps_per_epoch))
    opt_cpu = jax.jit(tx.init, backend="cpu")(trainable_params)

    # EMA
    ema_decay = train_config.get("ema_decay", 0)
    if ema_decay:
      end_warmup_step = train_config.get("ema_warmup_steps", 1560)
      ema_state_cpu = {"params": params_cpu, "coll": coll_cpu}
      ema_manager = train_utils.ExponentialMovingAverage(
          ema_state_cpu, ema_decay, end_warmup_step)

      @partial(jax.pmap, axis_name="batch")
      def update_ema(step, params, collection, ema):
        ema_state = {"params": params, "coll": collection}
        return ema.update_moving_average(ema_state, step)
    else:
      update_ema = ema_manager = ema_state_cpu = None

    # Load checkpoint if already exists
    if checkpoint_path and gfile.exists(checkpoint_path):
      checkpoint = {
          "opt": opt_cpu,
          "coll": coll_cpu,
          "params": params_cpu,
          "ema_state": ema_state_cpu,
          "extra": checkpoint_extra
      }
      checkpoint_tree = jax.tree_structure(checkpoint)
      loaded = bv_utils.load_checkpoint(checkpoint_tree, checkpoint_path)
      # bfloat16 type gets lost when data is saved to disk, so we recover it.
      checkpoint = jax.tree_map(bv_utils.recover_dtype, loaded)
      opt_cpu, coll_cpu, params_cpu, ema_state_cpu, checkpoint_extra = (
          checkpoint["opt"], checkpoint["coll"], checkpoint["params"],
          checkpoint["ema_state"], checkpoint["extra"])
      im_sec_core_eval_measurements = checkpoint_extra[
          "im_sec_core_eval_measurements"]
      im_sec_core_train_measurements = checkpoint_extra[
          "im_sec_core_train_measurements"]
      last_step = checkpoint_extra["step"]
      if ema_manager and ema_state_cpu:
        ema_manager = ema_manager.replace(state=ema_state_cpu)
      logging.info("Loaded checkpoint at step %d (%d total).", last_step,
                   total_steps)
  else:
    opt_cpu = None
    update_ema = ema_manager = None

  do_last_eval = True
  eval_is_compiled = False
  checkpoint_writer = None
  # Without an optimizer state there is no step count to read.
  if opt_cpu is not None:
    last_step = bv_optax.get_count(opt_cpu)
  if trainable_params and last_step < total_steps:
    trainable_params_repl = flax_utils.replicate(trainable_params)
    opt_repl = flax_utils.replicate(opt_cpu)
    coll_repl = flax_utils.replicate(coll_cpu)
    rng, rng_loop = jax.random.split(rng, 2)
    rngs_loop = flax_utils.replicate(rng_loop)
    frozen_repl = flax_utils.replicate(frozen_params)
    if ema_manager:
      ema_manager_repl = flax_utils.replicate(ema_manager)
    else:
      ema_manager_repl = None

    def ema_repl_to_state_cpu(ema_manager_repl):
      """Copies device 0's EMA params/collections to host numpy arrays."""
      if ema_manager_repl is None:
        return None
      ema_trainable_params_repl = ema_manager_repl.state["params"]
      ema_trainable_params_cpu = jax.tree_map(lambda x: np.array(x[0]),
                                              ema_trainable_params_repl)
      ema_coll_repl = ema_manager_repl.state["coll"]
      ema_coll_cpu = jax.tree_map(lambda x: np.array(x[0]), ema_coll_repl)
      ema_state_cpu = {"params": ema_trainable_params_cpu, "coll": ema_coll_cpu}
      return ema_state_cpu

    write_checkpoints = (
        is_host and checkpoint_path is not None and config.checkpoint_steps)
    step = last_step
    epoch = int(last_step / steps_per_epoch) + 1
    loss = 0
    train_time = 0
    if is_host:
      logging.info(
          "Training on dataset %s for %d total epochs (starting from %d).",
          train_config.dataset_name, train_config.epochs, epoch)
    for step, train_batch in zip(
        range(last_step + 1, total_steps + 1), train_iter):
      step_start = time.time()
      do_last_eval = True
      (opt_repl, trainable_params_repl, coll_repl, loss_repl,
       rngs_loop) = train_step(opt_repl, trainable_params_repl, frozen_repl,
                               coll_repl, train_batch["image"],
                               train_batch["labels"], rngs_loop)
      if update_ema is not None:
        step_repl = flax_utils.replicate(step)
        ema_manager_repl = update_ema(step_repl, trainable_params_repl,
                                      coll_repl, ema_manager_repl)
      loss += loss_repl[0]
      if step > steps_per_epoch * epoch or step == total_steps:
        line = (f"epoch {epoch:d}"
                f" | train loss {loss / steps_per_epoch:.1f}")
        if coll_cpu:
          coll_repl = cross_replica_mean(coll_repl)
        train_time += time.time() - step_start
        if epoch > 1:
          # Epoch 1 is not measured (presumably to exclude compilation).
          train_im = steps_per_epoch * batch_size
          im_sec_core_train = train_im / num_devices / train_time
          im_sec_core_train_measurements = np.append(
              im_sec_core_train_measurements, im_sec_core_train)
        if epoch % train_config.log_epochs == 0:
          if ema_manager_repl is not None:
            trainable_params_repl_eval = ema_manager_repl.state["params"]
            coll_repl_eval = ema_manager_repl.state["coll"]
          else:
            trainable_params_repl_eval = trainable_params_repl
            coll_repl_eval = coll_repl
          params_repl = {**trainable_params_repl_eval, **frozen_repl}
          correct, loss, n_eval, eval_time = eval_model(params_repl,
                                                        coll_repl_eval,
                                                        eval_it)
          # The first eval is not timed (it includes compilation).
          if eval_is_compiled:
            eval_im = int(n_eval)
            im_sec_core_eval = eval_im / num_devices / eval_time
            im_sec_core_eval_measurements = np.append(
                im_sec_core_eval_measurements, im_sec_core_eval)
          eval_is_compiled = True
          line += (f" | val loss {loss:.2f}"
                   f" | val acc {correct / n_eval * 100:.3f}%"
                   f" ({int(correct)} / {int(n_eval)})")
          do_last_eval = False
          if step < total_steps and callback:
            metrics = Metrics(
                loss=loss,
                acc=correct / n_eval,
                num_params=num_params,
                flops=flops,
                im_sec_core_infer=(np.median(im_sec_core_eval_measurements)
                                   if len(im_sec_core_eval_measurements) else
                                   0),
                im_sec_core_train=(np.median(im_sec_core_train_measurements)
                                   if len(im_sec_core_train_measurements) else
                                   0))
            if not callback(epoch, metrics):
              line += " | EARLY STOPPED"
              if is_host:
                logging.info(line)
              break
        if is_host:
          logging.info(line)
          logging.info("Train measurements stddev: %.2f",
                       np.std(im_sec_core_train_measurements))
          logging.info("Eval measurements stddev: %.2f",
                       np.std(im_sec_core_eval_measurements))
        loss = 0
        epoch += 1
        train_time = 0
      else:
        # Epoch-end steps were already added above; adding them again here
        # would charge this step (and its eval) to the next epoch's timing.
        train_time += time.time() - step_start

      if write_checkpoints and pool and step % config.checkpoint_steps == 0:
        bv_utils.checkpointing_timeout(checkpoint_writer, 10)
        checkpoint_extra[
            "im_sec_core_eval_measurements"] = im_sec_core_eval_measurements
        checkpoint_extra[
            "im_sec_core_train_measurements"] = im_sec_core_train_measurements
        checkpoint_extra["step"] = step
        # We need to transfer the weights over now or else we risk keeping them
        # alive while they'll be updated in a future step, creating hard to
        # debug memory errors (see b/160593526). Also, takes device 0's params
        # only.
        opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
        coll_cpu = jax.tree_map(lambda x: np.array(x[0]), coll_repl)
        trainable_params_cpu = jax.tree_map(lambda x: np.array(x[0]),
                                            trainable_params_repl)
        params_cpu = {**trainable_params_cpu, **frozen_params}
        ema_state_cpu = ema_repl_to_state_cpu(ema_manager_repl)

        # Checkpoint should be a nested dictionary or FLAX datataclasses from
        # `flax.struct`. Both can be present in a checkpoint.
        checkpoint = {
            "opt": opt_cpu,
            "coll": coll_cpu,
            "params": params_cpu,
            "ema_state": ema_state_cpu,
            "extra": checkpoint_extra
        }
        checkpoint_writer = pool.apply_async(bv_utils.save_checkpoint,
                                             (checkpoint, checkpoint_path))

    coll_cpu = jax.tree_map(lambda x: np.array(x[0]), coll_repl)
    opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
    trainable_params_cpu = jax.tree_map(lambda x: np.array(x[0]),
                                        trainable_params_repl)
    params_cpu = {**trainable_params_cpu, **frozen_params}
    params_repl = {**trainable_params_repl, **frozen_repl}
    ema_state_cpu = ema_repl_to_state_cpu(ema_manager_repl)
    if ema_manager:
      coll_repl_eval = ema_manager_repl.state["coll"]
      params_repl_eval = {**ema_manager_repl.state["params"], **frozen_repl}
    else:
      coll_repl_eval = coll_repl
      params_repl_eval = params_repl
  else:
    # No training to run: nothing is trainable, or a loaded checkpoint is
    # already at total_steps.
    epoch = 0
    step = last_step
    coll_repl = flax_utils.replicate(coll_cpu)
    if trainable_params:
      # Keep the (possibly checkpoint-loaded) trainable weights.
      params_cpu = {**params_cpu, **frozen_params}
    else:
      params_cpu = frozen_params
    params_repl = flax_utils.replicate(params_cpu)
    ema_state_cpu = None
    coll_repl_eval = coll_repl
    params_repl_eval = params_repl

  if do_last_eval:
    correct, loss, n_eval, eval_time = eval_model(params_repl_eval,
                                                  coll_repl_eval, eval_it)
    if eval_is_compiled:
      eval_im = int(n_eval)
      im_sec_core_eval = eval_im / num_devices / eval_time
      im_sec_core_eval_measurements = np.append(im_sec_core_eval_measurements,
                                                im_sec_core_eval)
    eval_is_compiled = True

  if eval_perf and not len(im_sec_core_eval_measurements):  # pylint: disable=g-explicit-length-test (can't check len on numpy arrays)
    assert eval_is_compiled
    correct, loss, n_eval, eval_time = eval_model(params_repl_eval,
                                                  coll_repl_eval, eval_it)
    eval_im = int(n_eval)
    im_sec_core_eval = eval_im / num_devices / eval_time
    im_sec_core_eval_measurements = np.append(im_sec_core_eval_measurements,
                                              im_sec_core_eval)

  checkpoint_extra[
      "im_sec_core_eval_measurements"] = im_sec_core_eval_measurements
  checkpoint_extra[
      "im_sec_core_train_measurements"] = im_sec_core_train_measurements
  checkpoint_extra["step"] = step
  checkpoint = {
      "opt": opt_cpu,
      "coll": coll_cpu,
      "params": params_cpu,
      "ema_state": ema_state_cpu,
      "extra": checkpoint_extra
  }
  if checkpoint_path is not None and is_host and pool:
    # Don't let the final write race a still-running in-loop write to the
    # same path.
    if checkpoint_writer is not None:
      checkpoint_writer.wait()
    checkpoint_writer = pool.apply_async(bv_utils.save_checkpoint,
                                         (checkpoint, checkpoint_path))

  metrics = Metrics(
      loss=loss,
      acc=correct / n_eval,
      num_params=num_params,
      flops=flops,
      im_sec_core_infer=(np.median(im_sec_core_eval_measurements)
                         if len(im_sec_core_eval_measurements) else 0),
      im_sec_core_train=(np.median(im_sec_core_train_measurements)
                         if len(im_sec_core_train_measurements) else 0))
  if ema_state_cpu:
    state = ema_state_cpu
  else:
    state = {"coll": coll_cpu, "params": params_cpu}

  if pool is not None:
    # Wait for any pending checkpoint write, then release the worker threads.
    # Otherwise each call leaks a ThreadPool and may return before the final
    # checkpoint is on disk.
    pool.close()
    pool.join()
  return metrics, epoch, state
class ProgSequentialTest(test.TestCase):
  """Tests `ProgressiveSequentialSynthesizer` on CifarNet subgraphs."""

  def setUp(self):
    super().setUp()
    seed = int(time.time())
    logging.info("Seed: %d", seed)
    py_rand.seed(seed)
    self.graph, self.constants, _ = cnn.CifarNet()
    self.m = Model(self.graph, self.constants)
    self.input = {"input": jnp.ones((5, 32, 32, 3))}
    self.state = self.m.init(random.PRNGKey(0), self.input)
    self.out = self.m.apply(self.state, self.input)["fc/logits"]
    self.max_size = int(10e8)
    # Set to True to enable `test_synthesizer_hard`.
    self.hard = False

  def _subgraph_model(self, start, end):
    """Wraps graph.ops[start:end] as a subgraph feeding graph.ops[end]."""
    subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[start:end]]
    subg[-1].output_names = self.graph.ops[end].input_names
    return SubgraphModel(self.graph, self.constants, self.state, self.input,
                         subg)

  def _synthesize(self, subg, props):
    """Synthesizes a replacement for `subg`; checks the logits change."""
    synthesizer = ProgressiveSequentialSynthesizer(
        [(subg, props)],
        generation=0,
        mode=ProgressiveSequentialSynthesizer.Mode.WEIGHTED,
        max_len=3)

    subg = synthesizer.synthesize()[0]
    for node in subg.subgraph:
      logging.info("%s", node.op.name)
      logging.info("%s", node.output_names)

    m = Model(subg.graph, self.constants)
    state = m.init(random.PRNGKey(0), self.input)
    out = m.apply(state, self.input)["fc/logits"]
    self.assertTrue((out != self.out).any())

  def test_synthesizer_easy_one(self):
    """Replacing [conv3x3(features = 64)]."""
    subgraph_model = self._subgraph_model(4, 5)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    dp = depth.DepthProperty().infer(subgraph_model)
    # The linear property is deliberately not checked here.
    self._synthesize(subgraph_model, [sp, dp])

  def test_synthesizer_easy_two(self):
    """Replacing [conv3x3(features = 64)], also checking the linear property."""
    subgraph_model = self._subgraph_model(4, 5)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    dp = depth.DepthProperty().infer(subgraph_model)
    lp = linear.LinopProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, dp, lp])

  def test_synthesizer_one(self):
    """Replacing [conv3x3(features = 64), ReLU]."""
    subgraph_model = self._subgraph_model(4, 6)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    dp = depth.DepthProperty().infer(subgraph_model)
    # The linear property is deliberately not checked here.
    self._synthesize(subgraph_model, [sp, dp])

  def test_synthesizer_two(self):
    """Replacing [conv3x3(features = 64), ReLU, avgpool2x2(strides=2x2)]."""
    subgraph_model = self._subgraph_model(4, 7)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    # The depth property is deliberately not checked here.
    lp = linear.LinopProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, lp])

  def test_synthesizer_hard(self):
    """Replacing three ops while checking shape, depth and linear properties."""
    if not self.hard:
      self.skipTest("hard synthesis test disabled (self.hard is False)")
    subgraph_model = self._subgraph_model(4, 7)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    dp = depth.DepthProperty().infer(subgraph_model)
    lp = linear.LinopProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, dp, lp])
class EnumSequentialTest(test.TestCase):
    """Tests EnumerativeSequentialSynthesizer and its helpers on CifarNet."""

    def setUp(self):
        super().setUp()
        self.graph, self.constants, _ = cnn.CifarNet()
        self.m = Model(self.graph, self.constants)
        self.input = {"input": jnp.ones((5, 32, 32, 3))}
        self.state = self.m.init(random.PRNGKey(0), self.input)
        self.out = self.m.apply(self.state, self.input)["fc/logits"]
        self.max_size = int(10e8)
        self.hard = False

    def _make_subgraph_model(self, start, stop):
        """Builds a SubgraphModel over self.graph.ops[start:stop].

        The last node of the subgraph is rewired to produce the inputs consumed
        by self.graph.ops[stop].
        """
        subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[start:stop]]
        subg[-1].output_names = self.graph.ops[stop].input_names
        return SubgraphModel(self.graph, self.constants, self.state, self.input,
                             subg)

    def test_sequence_generator(self):
        """Test the sequence_generator function for correctness.

        seq_generator should generate
        [[0], [1], [2],
         [0,0], [0,1], [0,2], [1,0], [1,1], [1,2], ...,
         [0,0,0], [0,0,1], [0,0,2], [0,1,0], [0,1,1], ...,]
        i.e., every sequence of length 1, 2 and 3 in lexicographic order.
        """
        import itertools  # Local import: only this test needs it.

        def el_generator():
            yield from range(3)

        seqs = list(sequence_generator(el_generator, 3))
        self.assertLen(seqs, 3 + 3**2 + 3**3)
        expected = [
            list(combo)
            for length in range(1, 4)
            for combo in itertools.product(range(3), repeat=length)
        ]
        for got, want in zip(seqs, expected):
            self.assertEqual(got, want)

    def test_kwargs_for_op_to_product(self):
        op_kwargs = {"a": [1, 2, 3], "b": [1, 2], "c": [5, 6]}
        input_kwargs = {"d": [1], "e": [1, 2], "f": [3, 4]}
        product = EnumerativeSequentialSynthesizer.kwargs_for_op_to_product(
            op_kwargs, input_kwargs)

        # The product covers every combination of every option list.
        expected_length = 1
        for values in (*op_kwargs.values(), *input_kwargs.values()):
            expected_length *= len(values)
        self.assertLen(product, expected_length)

        op_setting = {"a": 2, "b": 2, "c": 5}
        input_setting = {"d": 1, "e": 2, "f": 3}
        self.assertIn((op_setting, input_setting), product)

    def _synthesize(self, subg, props):
        """Synthesizes a replacement for `subg` and checks the result executes.

        The assertion checks that the re-initialized, synthesized model produces
        logits that differ from the original model's logits.
        """
        synthesizer = EnumerativeSequentialSynthesizer([(subg, props)],
                                                       0,
                                                       max_len=3)
        subg = synthesizer.synthesize()[0]
        m = Model(subg.graph, self.constants)
        state = m.init(random.PRNGKey(0), self.input)
        out = m.apply(state, self.input)["fc/logits"]
        self.assertTrue((out != self.out).any())

    def test_synthesizer_easy_one(self):
        """Replacing [conv3x3(features = 64)].

        Because we do not test linear, this is replaced by dense3x3(features =
        64) due to the enumeration order.
        """
        subgraph_model = self._make_subgraph_model(4, 5)
        sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
        dp = depth.DepthProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [sp, dp])

    def test_synthesizer_easy_two(self):
        """Replacing [conv3x3(features = 64)].

        Because we test all three props, this is replaced by conv3x3(features =
        64) (i.e., an identical op) due to the enumeration order.
        """
        subgraph_model = self._make_subgraph_model(4, 5)
        sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
        dp = depth.DepthProperty().infer(subgraph_model)
        lp = linear.LinopProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [sp, dp, lp])

    def test_synthesizer_one(self):
        """Replacing [conv3x3(features = 64), ReLU].

        Because we do not check for the linear property, [dense(features = 64),
        ReLU] works as well (which is what is synthesized due to the enumeration
        order).
        """
        subgraph_model = self._make_subgraph_model(4, 6)
        sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
        dp = depth.DepthProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [sp, dp])

    def test_synthesizer_two(self):
        """Replacing [conv3x3(features = 64), ReLU, avgpool2x2(strides=2x2)].

        Because we do not check for the depth property, [dense(features = 64),
        avgpool2x2(strides=2x2)] works as well (which is what is synthesized due
        to the enumeration order).
        """
        subgraph_model = self._make_subgraph_model(4, 7)
        sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
        lp = linear.LinopProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [sp, lp])

    def test_synthesizer_hard(self):
        """Replacing [conv3x3, ReLU, avgpool2x2] with all three properties."""
        if not self.hard:
            # Report as skipped rather than silently passing.
            self.skipTest("Hard synthesis test disabled (self.hard is False).")
        subgraph_model = self._make_subgraph_model(4, 7)
        sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
        dp = depth.DepthProperty().infer(subgraph_model)
        lp = linear.LinopProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [sp, dp, lp])
class SubgraphModel():
    """A concrete subgraph.

    A concrete subgraph consists of:
    - The full graph, in which the subgraph is *already* embedded, i.e., you
      should call replace_subgraph BEFORE creating a SubgraphModel!
    - An instantiation of the full graph, as specified by the state (i.e.,
      parameters). If None, the subgraph is treated as abstract.
    - An execution of the full graph, as specified by a set of inputs.
    - A specification of the subgraph, as defined by the list of subgraph
      nodes. If None, the subgraph is just the full graph.
    """

    def __init__(self, graph, constants, state, inputs, subgraph=None):
        self.graph = graph
        self.constants = constants
        self.state = state
        self.inputs = inputs
        self.subgraph: SubgraphSpec = subgraph if subgraph else []
        self.input_names = None
        self.output_names = None
        self.original_outputs = graph.output_names

        if subgraph:
            self._subgraph_to_names()

            # graph for graph inputs -> subg inputs
            self.subg_inputs_graph = copy.deepcopy(graph)
            self.subg_inputs_graph.output_names = self.input_names
            self.subg_inputs_model = Model(self.subg_inputs_graph, self.constants)
            self.subg_inputs = None

            # graph for graph inputs -> subg outputs
            self.subg_outputs_graph = copy.deepcopy(graph)
            self.subg_outputs_graph.output_names = self.output_names
            self.subg_outputs_model = Model(self.subg_outputs_graph,
                                            self.constants)
            self.subg_outputs = None

            # graph for subg inputs -> subg outputs
            subg_ops = [node.op for node in subgraph]
            self.subg_graph = new_graph(self.input_names, self.output_names,
                                        subg_ops)
            self.subg_model = Model(self.subg_graph, self.constants)
        else:
            self.input_names = [
                canonicalize_tensor_name(name) for name in graph.input_names
            ]
            self.output_names = [
                canonicalize_tensor_name(name) for name in graph.output_names
            ]

            # subg inputs = inputs to the graph
            self.subg_inputs_graph = None
            self.subg_inputs_model = None
            self.subg_inputs = inputs

            # graph for graph inputs -> subg outputs
            self.subg_outputs_graph = copy.deepcopy(graph)
            self.subg_outputs_model = Model(self.subg_outputs_graph,
                                            self.constants)
            self.subg_outputs = None

            # subg outputs = full graph outputs
            self.subg_graph = self.subg_outputs_graph
            self.subg_model = self.subg_outputs_model

    def _check_state(self):
        """Raises ValueError if no state (parameters) was provided.

        Only None marks an abstract subgraph; an empty state is still a valid
        concrete instantiation (e.g. for a parameter-free graph).
        """
        if self.state is None:
            raise ValueError("Cannot execute subgraph without state.")

    def _subgraph_to_names(self):
        """Populates the incoming and outgoing edges of the subgraph.

        Sets self.input_names to the tensors consumed by subgraph ops but not
        produced inside the subgraph (first-seen order, no duplicates). Sets
        self.output_names to "<op name>:<idx>" for every op output whose entry in
        node.output_names is not None.
        """
        assert self.subgraph
        input_names = []
        output_names = []
        seen_inputs = set()
        # Internal edges of the subgraph; a set gives O(1) membership tests.
        produced = set()
        for node in self.subgraph:
            # check to see which inputs are incoming edges to the subgraph
            for input_name in node.op.input_names:
                # NOTE(review): names are compared verbatim against
                # "<op>:<idx>"; if op input names may omit the ":idx" suffix,
                # an internal edge would be misread as external — confirm.
                if input_name not in produced and input_name not in seen_inputs:
                    seen_inputs.add(input_name)
                    input_names.append(input_name)
            # keep track of produced tensors (internal edges in the subgraph)
            produced.update(
                f"{node.op.name}:{idx}" for idx in range(node.op.num_outputs))
            # only the rewired outputs become externally visible to the graph
            for idx, output_name in enumerate(node.output_names or ()):
                if output_name is not None:
                    output_names.append(f"{node.op.name}:{idx}")
        self.input_names = input_names
        self.output_names = output_names

    def get_subg_inputs(self, graph_inputs, intermediates=False):
        """Returns the inputs to the subgraph given inputs to the full graph.

        Args:
          graph_inputs: The dictionary of input values to the full graph.
          intermediates: If True, the inputs model's graph.output_names is
            temporarily cleared for the call (presumably making Model.apply
            return all intermediate values — verify against Model). The
            original output names are always restored, even if apply raises.

        Returns:
          The inputs to the subgraph.

        Raises:
          ValueError: If execution is necessary, but state is not provided.
        """
        # if no self.subg_inputs_model, then the subgraph is the full graph, so
        # the input to the subgraph is the same as the input to the full graph
        if self.subg_inputs_model is None:
            return graph_inputs

        # execute the subg_inputs_model
        self._check_state()
        if not intermediates:
            return self.subg_inputs_model.apply(self.state, graph_inputs)

        graph = self.subg_inputs_model.graph
        old_output_names = graph.output_names
        graph.output_names = []
        try:
            return self.subg_inputs_model.apply(self.state, graph_inputs)
        finally:
            # Restore unconditionally so a failing apply cannot leave the
            # shared inputs model with its outputs cleared.
            graph.output_names = old_output_names

    def get_default_subg_inputs(self):
        """Returns (and caches) the subgraph inputs for the stored inputs."""
        if self.subg_inputs is not None:
            return self.subg_inputs
        self.subg_inputs = self.get_subg_inputs(self.inputs)
        return self.subg_inputs

    def get_subg_outputs(self, graph_inputs=None):
        """Returns the output from the subgraph given inputs to the full graph.

        Args:
          graph_inputs: The dictionary of input values to the full graph. If
            None, defaults to the stored input values (self.inputs).

        Returns:
          The outputs of the subgraph.

        Raises:
          ValueError: If execution is necessary, but state is not provided.
        """
        # execute the subg_outputs_model
        self._check_state()
        if graph_inputs is None:
            graph_inputs = self.inputs
        return self.subg_outputs_model.apply(self.state, graph_inputs)

    def get_default_subg_outputs(self):
        """Returns (and caches) the subgraph outputs for the stored inputs."""
        if self.subg_outputs is not None:
            return self.subg_outputs
        subg_inputs = self.get_default_subg_inputs()
        self.subg_outputs = self.execute_subg(subg_inputs)
        return self.subg_outputs

    def execute_subg(self, inputs):
        """Returns the output from the subgraph given inputs to the subgraph.

        Args:
          inputs: The dictionary of input values to the subgraph.

        Returns:
          The outputs of the subgraph.

        Raises:
          ValueError: If state is not provided.
        """
        self._check_state()
        return self.subg_model.apply(self.state, inputs)

    def update_subg_outputs(self, output_names):
        """Updates the outputs of the subgraph.

        All requested names are validated before anything is changed. Any
        cached default outputs are discarded, because they were computed for the
        previous output names.

        Args:
          output_names: The list of new output_names.

        Raises:
          ValueError: If output_names are not produced in the subgraph.
        """
        available = {
            f"{op.name}:{idx}"
            for op in self.subg_graph.ops
            for idx in range(op.num_outputs)
        }
        for output_name in output_names:
            if output_name not in available:
                raise ValueError(
                    f"Requested output {output_name} not in subgraph.")

        # Copy so later mutation of the caller's list cannot alter the graphs.
        output_names = list(output_names)
        self.output_names = output_names
        self.subg_graph.output_names = output_names
        self.subg_model.graph.output_names = output_names
        self.subg_outputs_graph.output_names = output_names
        self.subg_outputs_model.graph.output_names = output_names
        # Invalidate the cache used by get_default_subg_outputs.
        self.subg_outputs = None