def save_samples_to_json(features: List[Dict[str, Any]],
                         config: ml_collections.ConfigDict, step: int):
  """Save samples to a json file."""
  save_samples_for_this_step = (
      config.get('save_samples_every_steps') and
      (step % config.get('save_samples_every_steps') == 0))
  process_index = jax.process_index()
  accepted_processes = config.get('save_samples_process_ids', 0)
  if isinstance(accepted_processes, list):
    save_samples_for_this_process = (process_index in accepted_processes)
  elif accepted_processes == -1:
    save_samples_for_this_process = True
  else:
    save_samples_for_this_process = (process_index == accepted_processes)

  if save_samples_for_this_step and save_samples_for_this_process:
    logging.info('Saving samples at step %d, process %d', step, process_index)
    path = os.path.join(config.model_dir, 'samples',
                        'step_%d.process_%d.json' % (step, process_index))
    tf.io.gfile.makedirs(os.path.dirname(path))
    with tf.io.gfile.GFile(path, 'ab') as fp:
      for batch in features:
        json.dump(batch, fp)
        fp.write('\n')
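# Hypothetical usage sketch (not from the original source): a minimal config
# carrying the fields save_samples_to_json reads above. Values are invented
# for illustration, and a JAX/TF runtime is assumed to be available.
import ml_collections

sample_config = ml_collections.ConfigDict()
sample_config.model_dir = '/tmp/model'          # root dir for 'samples/' files
sample_config.save_samples_every_steps = 1000   # save every 1000 steps
sample_config.save_samples_process_ids = -1     # -1 saves on every process
save_samples_to_json([{'sample': [1, 2, 3]}], sample_config, step=2000)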
def config_to_opt_args(config: ml_collections.ConfigDict):
  """Collects optimizer kwargs from `config`, dropping any that are unset."""
  opt_kwargs = dict(
      eps=config.get('opt_eps'),
      decay=config.get('opt_decay'),
      momentum=config.get('opt_momentum'),
      beta1=config.get('opt_beta1'),
      beta2=config.get('opt_beta2'),
      weight_decay=config.get('opt_weight_decay', 0))
  opt_kwargs = {k: v for k, v in opt_kwargs.items() if v is not None}
  return opt_kwargs
def get_optimizer(
    config: ml_collections.ConfigDict) -> tf.keras.optimizers.Optimizer:
  """Returns an optimizer based on the given configuration.

  Supports the Adam optimizer. Default values for optional optimizer
  parameters come from TensorFlow Core v2.4.1:
  https://www.tensorflow.org/api_docs/python/tf/keras/optimizers.

  Args:
    config: A ConfigDict containing a `config.opt` sub-config.

  Returns:
    A tf.keras optimizer.

  Raises:
    ValueError: `config` missing the `opt` sub-config.
    ValueError: `config.opt` contains an invalid optimizer class.
    ValueError: `config.schedule` contains an invalid schedule class.
  """
  opt_config = config.get('opt', None)
  if opt_config is None:
    raise ValueError(f'Provided `config` missing `opt` sub-config: {config}')

  initial_learning_rate = opt_config.get('learning_rate', 0.001)
  steps_per_epoch = (
      config.dataset.num_train_examples / config.dataset.batch_size)

  schedule_config = config.get('schedule', {})
  schedule_type = schedule_config.get('schedule', None)
  if schedule_type == 'exponential':
    decay_steps = int(schedule_config.epochs_per_decay * steps_per_epoch)
    learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate,
        decay_steps=decay_steps,
        decay_rate=schedule_config.decay_rate,
        staircase=schedule_config.staircase)
  elif schedule_type is None:
    learning_rate = initial_learning_rate
    print(f'No LR schedule provided. Using a fixed LR of "{learning_rate}".')
  else:
    raise ValueError(f'Unknown scheduler name: {schedule_type}')

  opt_type = opt_config.get('optimizer', None)
  if opt_type == 'adam':
    opt = tf.keras.optimizers.Adam(
        learning_rate=learning_rate,
        beta_1=opt_config.get('beta_1', 0.9),
        beta_2=opt_config.get('beta_2', 0.999),
        epsilon=opt_config.get('epsilon', 1e-07),
        amsgrad=opt_config.get('amsgrad', False))
    start_step = int(steps_per_epoch * config.train.initial_epoch)
    if opt_config.get('use_model_averaging', True):
      opt = tfa.optimizers.MovingAverage(
          opt, average_decay=0.9999, start_step=start_step)
    return opt

  raise ValueError(f'Unknown optimizer name: {opt_type}')
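# Hypothetical usage sketch (assumptions, not from the original source): a
# minimal ConfigDict exercising the `opt` and `schedule` sub-configs that
# get_optimizer reads above. Dataset sizes and hyperparameters are invented.
import ml_collections

config = ml_collections.ConfigDict()
config.dataset = ml_collections.ConfigDict(
    dict(num_train_examples=50000, batch_size=32))
config.train = ml_collections.ConfigDict(dict(initial_epoch=0))
config.opt = ml_collections.ConfigDict(
    dict(optimizer='adam', learning_rate=1e-3, use_model_averaging=False))
config.schedule = ml_collections.ConfigDict(
    dict(schedule='exponential', epochs_per_decay=2, decay_rate=0.9,
         staircase=True))
optimizer = get_optimizer(config)  # Adam with an ExponentialDecay LR schedule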
def get_defaults():
  config = ConfigDict()

  # Covariance matrix stabilising jitter.
  config.jitter = 1e-6

  # Parameter transformations.
  config.transformations = transformations = ConfigDict()
  transformations.positive_transform = tfb.Softplus
  transformations.identity_transform = tfb.Identity
  transformations.lengthscale = "positive_transform"
  transformations.variance = "positive_transform"
  transformations.obs_noise = "positive_transform"
  transformations.latent = "identity_transform"
  transformations.basis_fns = "identity_transform"
  return config
def build_model_graph(
    config: ml_collections.ConfigDict,
    outcomes: List[ml_collections.ConfigDict]) -> tf.keras.Model:
  """Returns a tf.keras.Model configured with the given ConfigDict.

  Args:
    config: A ConfigDict containing a `config.model` sub-config.
    outcomes: A list of outcome ConfigDict instances.

  Returns:
    A tf.keras.Model.

  Raises:
    ValueError: `config` missing the `model` sub-config.
    ValueError: `config.model` contains an invalid model backbone.
  """
  model_config = config.get('model', None)
  if model_config is None:
    raise ValueError(f'Provided `config` missing `model` sub-config: {config}')

  model_backbone = model_config.get('backbone', None)

  # Config specifies model choice.
  if model_backbone == 'inceptionv3':
    return inceptionv3(model_config, outcomes)

  raise ValueError(f'Unknown model backbone: {model_backbone}')
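# Hypothetical usage sketch (assumptions, not from the original source): a
# `model` sub-config plus a single regression outcome, matching the fields
# that build_model_graph and inceptionv3 read; `DEFAULT_IMAGE_SHAPE` is
# supplied by the surrounding module, and all values here are illustrative.
import ml_collections

config = ml_collections.ConfigDict()
config.model = ml_collections.ConfigDict(
    dict(backbone='inceptionv3', weights=None, pooling='avg',
         weight_decay=0.0, backbone_drop_rate=0.2))
outcomes = [ml_collections.ConfigDict(dict(type='regression', name='age'))]
model = build_model_graph(config, outcomes)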
def main(cfg, make_env, experiment_name):
    exp_dir = os.path.join(cfg.save_dir, experiment_name)
    if not os.path.exists(exp_dir):
        os.makedirs(exp_dir)
        with open(os.path.join(exp_dir, "config.yaml"), "w") as fp:
            yaml.dump(ConfigDict.to_dict(cfg), fp)
    else:
        raise ValueError("Experiment already exists.")

    logger = Logger(
        exp_dir,
        save_tb=True,
        log_frequency=cfg.log_frequency,
        agent="sac",
    )
    utils.set_seed_everywhere(cfg.seed)
    device = torch.device(cfg.device)

    env = make_env()
    cfg.sac.obs_dim = env.observation_space.shape[0]
    cfg.sac.action_dim = env.action_space.shape[0]
    cfg.sac.action_range = [
        float(env.action_space.low.min()),
        float(env.action_space.high.max()),
    ]
    agent = SACAgent(cfg.sac, device)

    replay_buffer = ReplayBuffer(
        env.observation_space.shape,
        env.action_space.shape,
        int(cfg.replay_buffer_capacity),
        device,
    )
    video_recorder = VideoRecorder(exp_dir if cfg.save_video else None)

    train(env, logger, video_recorder, cfg, agent, replay_buffer)
def dummy_input(config: ml_collections.ConfigDict) -> Dict[Text, Any]:
  """Produces model-specific dummy input batch. See BaseTask."""
  if config.get('max_length_with_entity_tokens') is not None:
    max_length = config.max_length_with_entity_tokens
  else:
    max_length = config.model_config.encoder_config.max_length

  bsz = config.per_device_batch_size
  text_shape = (bsz, max_length)
  mention_shape = (config.max_mentions,)  # note trailing comma: 1-D shape tuple
  int_type = jnp.int32

  position_ids = np.arange(max_length)
  position_ids = np.tile(position_ids, (bsz, 1))

  dummy_input = {
      'text_ids': jnp.ones(text_shape, int_type),
      'text_mask': jnp.ones(text_shape, int_type),
      'position_ids': jnp.asarray(position_ids, int_type),
      'segment_ids': jnp.zeros(text_shape, int_type),
      'classifier_target': jnp.ones(bsz, int_type),
      'mention_start_positions': jnp.zeros(mention_shape, int_type),
      'mention_end_positions': jnp.zeros(mention_shape, int_type),
      'mention_mask': jnp.ones(mention_shape, int_type),
      'mention_batch_positions': jnp.ones(mention_shape, int_type),
  }
  return dummy_input
def get_default_config():
  config = ConfigDict()
  config.online = False
  config.prefix = 'SimpleSAC'
  config.project = 'sac'
  config.output_dir = '/tmp/SimpleSAC'
  config.random_delay = 0.0
  config.experiment_id = ''
  return config
def load_config_from_dir(exp_dir):
    """Load experiment config."""
    # A missing config file raises FileNotFoundError directly.
    with open(os.path.join(exp_dir, "config.yaml"), "r") as fp:
        cfg = yaml.load(fp, Loader=yaml.FullLoader)
    return ConfigDict(cfg)
def get_config():
  config = ConfigDict()
  config.constructor = "xgboost.XGBRegressor"
  config.hparams = ConfigDict({
      'tree_method': "gpu_hist",
      'booster': "gbtree",
      'n_estimators': 250,
      'learning_rate': 1e-2,
      'max_depth': 5,
      'reg_alpha': 1.0,
      'reg_lambda': 1.0,
      'min_child_weight': 0.0,
      'subsample': 0.8,
      'colsample_bytree': 0.8,
      'num_parallel_tree': 1,
  })
  return config
def get_config(self):
  config = ConfigDict()
  config.hidden_size = 128
  config.ff_size = 256
  config.num_heads = 2
  config.num_encoder_layers = 2
  config.num_symbols = 8
  return config
def load_data(idx, mode='test'):
  run = api.run(os.path.join(USER, PROJECT, idx))
  data_module, _ = load_dataset(ConfigDict(run.config))
  data_module.prepare_data()
  if mode == 'test':
    data_module.setup('test')
    return data_module.test_dataloader()
  elif mode == 'train':
    data_module.setup('fit')
    return data_module.train_dataloader()
  raise ValueError(f'Unknown mode: {mode}')
def build_outcome_head(config: ml_collections.ConfigDict,
                       inputs: tf.Tensor,
                       l2: float = 0.0) -> tf.Tensor:
  """Returns an output head tensor configured for the given outcome.

  Supports regression, binary classification, and multinomial classification
  outcomes. Note: binary classification labels are assumed to be of shape
  (2,). Binary heads consist of a `tf.keras.layers.Dense(2)` with a softmax
  activation.

  Args:
    config: An outcome ConfigDict.
    inputs: The backbone output tensor; used as the input to the head.
    l2: The l2 regularization factor used in `tf.keras.layers.Dense` layers.

  Returns:
    A tensor representing the output of the given head.

  Raises:
    ValueError: `config` missing a valid `type`.
    ValueError: A binary classification config uses num_classes=1 rather than
      num_classes=2.
  """
  outcome_type = config.get('type', None)
  if outcome_type is None:
    raise ValueError(f'Provided `config` missing `type`: {config}')

  l2_regularizer = tf.keras.regularizers.L2(l2) if l2 else None

  if outcome_type == 'regression':
    head = tf.keras.layers.Dense(
        1,
        dtype=tf.float32,
        name=config.name,
        kernel_regularizer=l2_regularizer)
    return head(inputs)

  if outcome_type == 'classification':
    if config.num_classes < 2:
      raise ValueError('Binary heads should specify `config.num_classes=2`. '
                       'Binary labels are assumed to be one-hot vectors.')
    head = tf.keras.layers.Dense(
        config.num_classes,
        activation='softmax',
        dtype=tf.float32,
        name=config.name,
        kernel_regularizer=l2_regularizer)
    return head(inputs)

  raise ValueError(f'Unknown outcome type: {outcome_type}')
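# Hypothetical usage sketch (assumptions, not from the original source):
# outcome configs for the two head types handled above; the `name` values
# and the 2048-dim backbone output are invented for illustration.
import ml_collections
import tensorflow as tf

regression_outcome = ml_collections.ConfigDict(
    dict(type='regression', name='age'))
binary_outcome = ml_collections.ConfigDict(
    dict(type='classification', num_classes=2, name='smoker'))

backbone_output = tf.keras.Input(shape=(2048,))
head = build_outcome_head(binary_outcome, backbone_output)  # softmax over 2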
def inceptionv3(model_config: ml_collections.ConfigDict,
                outcomes: List[ml_collections.ConfigDict]) -> tf.keras.Model:
  """Returns an InceptionV3 architecture as defined by the configuration.

  See https://tensorflow.org/api_docs/python/tf/keras/applications/InceptionV3.

  Args:
    model_config: A ConfigDict containing model hyperparameters.
    outcomes: A list of outcome ConfigDict instances.

  Returns:
    An InceptionV3-based model.
  """
  input_shape = model_config.get('input_shape', DEFAULT_IMAGE_SHAPE)
  backbone = tf.keras.applications.InceptionV3(
      include_top=False,
      weights=model_config.get('weights', 'imagenet'),
      input_shape=input_shape,
      pooling=model_config.get('pooling', 'avg'))
  weight_decay = model_config.get('weight_decay', 0.0)
  if weight_decay:
    backbone = add_l2_regularizers(
        backbone, tf.keras.layers.Conv2D, l2=weight_decay)
  backbone_drop_rate = model_config.get('backbone_drop_rate', 0.2)

  inputs_image = tf.keras.Input(shape=input_shape, name='image')
  hid = backbone(inputs_image)
  hid = tf.keras.layers.Dropout(backbone_drop_rate)(hid)

  outputs = []
  for outcome in outcomes:
    outputs.append(build_outcome_head(outcome, hid, l2=weight_decay))

  model = tf.keras.Model(
      inputs=[inputs_image], outputs=outputs, name=model_config.backbone)
  model.summary()
  print(f'Number of l2 regularizers: {len(model.losses)}.')
  return model
def setup_experiment(exp_dir):
    """Initializes a training experiment."""
    if os.path.exists(exp_dir):
        if not FLAGS.resume:
            raise ValueError(
                "Experiment already exists. Run with --resume to continue.")
        with open(os.path.join(exp_dir, "config.yaml"), "r") as fp:
            cfg = yaml.load(fp, Loader=yaml.FullLoader)
        FLAGS.config.update(cfg)
    else:
        os.makedirs(exp_dir)
        with open(os.path.join(exp_dir, "config.yaml"), "w") as fp:
            yaml.dump(ConfigDict.to_dict(FLAGS.config), fp)
def create_train_state(config: ml_collections.ConfigDict, params, model_state):
  """Create initial training state."""
  dynamic_scale = None
  platform = jax.local_devices()[0].platform
  if config.half_precision and platform == 'gpu':
    dynamic_scale = flax.optim.DynamicScale()

  opt_kwargs = dict(
      eps=config.get('opt_eps'),
      beta1=config.get('opt_beta1'),
      beta2=config.get('opt_beta2'),
      weight_decay=config.get('opt_weight_decay', 0))
  opt_kwargs = {k: v for k, v in opt_kwargs.items() if v is not None}  # remove unset
  optimizer = create_optim(config.opt, params, **opt_kwargs)
  ema = EmaState.create(config.ema_decay, optimizer.target, model_state)
  state = TrainState(
      step=0,
      optimizer=optimizer,
      model_state=model_state,
      dynamic_scale=dynamic_scale,
      ema=ema)
  return state
def test_factorized_attention(self):
  config = ConfigDict()
  config.hidden_size = 256
  config.ff_size = 256
  config.num_encoder_layers = 2
  config.num_heads = 2
  fact = layers.FactorizedAttention(config)
  inputs = tf.random.uniform(shape=(8, 8, 8, 256))
  output = fact(inputs)
  self.assertEqual(output.shape, (8, 8, 8, 256))
def make_loss_fn(
    cls, config: ml_collections.ConfigDict
) -> Callable[..., Tuple[float, MetricGroups, Dict[str, Any]]]:
  """Creates task loss function.

  See BaseTask. Model is trained using entity linking loss.

  Args:
    config: contains experiment hyperparameters.

  Returns:
    Loss function.
  """
  el_score_mode = config.get('el_score_mode', 'dot')

  def loss_fn(
      model_config: ml_collections.FrozenConfigDict,
      model_params: Dict[str, Any],
      model_vars: Dict[str, Any],  # pylint: disable=unused-argument
      batch: Dict[str, Any],
      deterministic: bool,
      dropout_rng: Optional[Dict[str, Array]] = None,
  ) -> Tuple[float, MetricGroups, Dict[str, Any]]:
    """Task-specific loss function. See BaseTask."""
    loss_helpers, logging_helpers = cls.build_model(model_config).apply(  # pylint: disable=unused-variable
        {'params': model_params},
        batch,
        deterministic=deterministic,
        rngs=dropout_rng)

    mention_target_ids = batch['mention_target_ids']
    mention_target_ids = mention_target_ids * batch['mention_target_weights']

    (loss, el_final_metrics, _) = mention_losses.entity_linking_loss(
        loss_helpers['target_mention_encodings'],
        loss_helpers['entity_embeddings'], mention_target_ids,
        batch['mention_target_weights'], el_score_mode)
    metrics = {'agg': el_final_metrics}
    return loss, metrics, {}

  return loss_fn
def __init__(self, config: ml_collections.ConfigDict):
  self.memory_reduction = config.memory_reduction
  self.memory_entity_id_pattern = config.memory_entity_id_pattern
  self.memory_text_pattern = config.memory_text_pattern
  self.memory_positions_pattern = config.memory_positions_pattern
  self.save_k_retrieval = config.get('save_k_retrieval', 10)
  if self.save_k_retrieval is not None and config.model_config.encoder_config.get(
      'k_top_post_selection') is None:
    raise Exception('save_k_retrieval only allowed with k_top_post_selection')

  # Lazy load for memory
  self.memory_entity_id = None
  self.memory_text = None
  self.memory_positions = None

  # TODO(urikz): Move `memory_prop` to `data_utils.load_sharded_array`,
  # e.g. array = array[:int(memory_prop * array.shape[0])]
  assert config.memory_prop is None
def get_loss(config: ml_collections.ConfigDict) -> tf.losses.Loss:
  """Returns a loss for use in training and evaluation.

  Args:
    config: A ConfigDict containing a `config.loss` name.

  Returns:
    A loss function.

  Raises:
    ValueError: `config.loss` is missing or an unknown loss name.
  """
  loss_name = config.get('loss', None)
  if loss_name == 'ce':
    return tf.keras.losses.CategoricalCrossentropy(from_logits=False)
  if loss_name == 'bce':
    return tf.keras.losses.BinaryCrossentropy(from_logits=False)
  if loss_name == 'mse':
    return tf.keras.losses.MeanSquaredError()
  raise ValueError(f'Unknown loss name: {loss_name}')
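# Hypothetical usage sketch (not from the original source): one of the three
# loss names get_loss accepts, selected via an otherwise illustrative config.
import ml_collections

loss_config = ml_collections.ConfigDict(dict(loss='bce'))
loss_fn = get_loss(loss_config)  # -> tf.keras.losses.BinaryCrossentropy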
def get_config():
  """Experiment configuration."""
  config = ConfigDict()

  # Data.
  config.dataset = 'imagenet'
  config.downsample = True
  config.downsample_res = 64
  config.resolution = [256, 256]
  config.random_channel = True

  # Training.
  config.batch_size = 1
  config.max_train_steps = 15000
  config.save_checkpoint_secs = 900
  config.num_epochs = -1
  config.polyak_decay = 0.999
  config.eval_num_examples = 20000
  config.eval_batch_size = 16
  config.eval_checkpoint_wait_secs = -1

  config.optimizer = ConfigDict()
  config.optimizer.type = 'rmsprop'
  config.optimizer.learning_rate = 3e-4

  # Model.
  config.model = ConfigDict()
  config.model.hidden_size = 32
  config.model.ff_size = 32
  config.model.num_heads = 1
  config.model.num_encoder_layers = 1
  config.model.resolution = [64, 64]
  config.model.name = 'color_upsampler'

  config.sample = ConfigDict()
  config.sample.gen_data_dir = ''
  config.sample.log_dir = 'samples'
  config.sample.batch_size = 1
  config.sample.mode = 'argmax'
  config.sample.num_samples = 1
  config.sample.num_outputs = 1
  config.sample.skip_batches = 0
  config.sample.gen_file = 'gen0'

  return config
def train_and_evaluate(config: ml_collections.ConfigDict,
                       workdir: str) -> TrainState:
  """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries are written to.

  Returns:
    Final TrainState.
  """
  writer = metric_writers.create_default_writer(
      logdir=workdir, just_logging=jax.host_id() != 0)

  rng = random.PRNGKey(0)

  image_size = 224

  if config.batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  local_batch_size = config.batch_size // jax.process_count()

  platform = jax.local_devices()[0].platform

  if config.half_precision:
    if platform == 'tpu':
      input_dtype = tf.bfloat16
    else:
      input_dtype = tf.float16
  else:
    input_dtype = tf.float32

  dataset_builder = tfds.builder(config.dataset)
  train_iter = create_input_iter(
      dataset_builder, local_batch_size, image_size, input_dtype, train=True,
      cache=config.cache)
  eval_iter = create_input_iter(
      dataset_builder, local_batch_size, image_size, input_dtype, train=False,
      cache=config.cache)

  steps_per_epoch = (
      dataset_builder.info.splits['train'].num_examples // config.batch_size)

  if config.num_train_steps == -1:
    num_steps = int(steps_per_epoch * config.num_epochs)
  else:
    num_steps = config.num_train_steps

  if config.steps_per_eval == -1:
    num_validation_examples = dataset_builder.info.splits[
        'validation'].num_examples
    steps_per_eval = num_validation_examples // config.batch_size
  else:
    steps_per_eval = config.steps_per_eval

  steps_per_checkpoint = steps_per_epoch * 10

  base_learning_rate = config.learning_rate * config.batch_size / 256.

  model_cls = getattr(models, config.model)
  model = create_model(
      model_cls=model_cls, half_precision=config.half_precision)

  learning_rate_fn = create_learning_rate_fn(
      config, base_learning_rate, steps_per_epoch)

  state = create_train_state(rng, config, model, image_size, learning_rate_fn)
  state = restore_checkpoint(state, workdir)
  # step_offset > 0 if restarting from checkpoint
  step_offset = int(state.step)
  state = jax_utils.replicate(state)

  p_train_step = jax.pmap(
      functools.partial(train_step, learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  train_metrics = []
  hooks = []
  if jax.process_index() == 0:
    hooks += [periodic_actions.Profile(num_profile_steps=5, logdir=workdir)]
  train_metrics_last_t = time.time()
  logging.info('Initial compilation, this might take some minutes...')
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    state, metrics = p_train_step(state, batch)
    for h in hooks:
      h(step)
    if step == step_offset:
      logging.info('Initial compilation completed.')

    if config.get('log_every_steps'):
      train_metrics.append(metrics)
      if (step + 1) % config.log_every_steps == 0:
        train_metrics = common_utils.get_metrics(train_metrics)
        summary = {
            f'train_{k}': v
            for k, v in jax.tree_map(lambda x: x.mean(), train_metrics).items()
        }
        summary['steps_per_second'] = config.log_every_steps / (
            time.time() - train_metrics_last_t)
        writer.write_scalars(step + 1, summary)
        train_metrics = []
        train_metrics_last_t = time.time()

    if (step + 1) % steps_per_epoch == 0:
      epoch = step // steps_per_epoch
      eval_metrics = []

      # sync batch statistics across replicas
      state = sync_batch_stats(state)
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        metrics = p_eval_step(state, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
      logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                   summary['loss'], summary['accuracy'] * 100)
      writer.write_scalars(
          step + 1, {f'eval_{key}': val for key, val in summary.items()})
      writer.flush()
    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      state = sync_batch_stats(state)
      save_checkpoint(state, workdir)

  # Wait until computations are done before exiting
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()

  return state
def get_config():
  """Experiment configuration."""
  config = ConfigDict()

  # Data.
  config.dataset = 'imagenet'
  config.downsample = True
  config.downsample_res = 64
  config.resolution = [256, 256]

  # Training.
  config.batch_size = 7
  config.max_train_steps = 450000
  config.save_checkpoint_secs = 900
  config.num_epochs = -1
  config.polyak_decay = 0.999
  config.eval_num_examples = 20000
  config.eval_batch_size = 16
  config.eval_checkpoint_wait_secs = -1

  # loss hparams.
  config.loss_factor = 0.99
  config.encoder_loss_factor = 0.01

  config.optimizer = ConfigDict()
  config.optimizer.type = 'rmsprop'
  config.optimizer.learning_rate = 3e-4

  # Model.
  config.model = ConfigDict()
  config.model.hidden_size = 512
  config.model.stage = 'encoder_decoder'
  config.model.resolution = [64, 64]
  config.model.name = 'coltran_core'

  # encoder
  config.model.encoder = ConfigDict()
  config.model.encoder.ff_size = 512
  config.model.encoder.hidden_size = 512
  config.model.encoder.num_heads = 4
  config.model.encoder.num_encoder_layers = 4
  config.model.encoder.dropout = 0.0

  # decoder
  config.model.decoder = ConfigDict()
  config.model.decoder.ff_size = 512
  config.model.decoder.hidden_size = 512
  config.model.decoder.resolution = [64, 64]
  config.model.decoder.num_heads = 4
  config.model.decoder.num_inner_layers = 2
  config.model.decoder.num_outer_layers = 2
  config.model.decoder.dropout = 0.0
  config.model.decoder.skip = True
  config.model.decoder.cond_mlp = 'affine'
  config.model.decoder.cond_mlp_act = 'identity'
  config.model.decoder.cond_ln_act = 'identity'
  config.model.decoder.cond_ln = True
  config.model.decoder.cond_ln_seq = 'sc'
  config.model.decoder.cond_ln_sp_ave = 'learnable'
  config.model.decoder.cond_ln_init = 'glorot_uniform'
  config.model.decoder.cond_att_init = 'glorot_uniform'
  config.model.decoder.cond_att_v = True
  config.model.decoder.cond_att_q = True
  config.model.decoder.cond_att_k = True
  config.model.decoder.cond_att_scale = True
  config.model.decoder.cond_att_act = 'identity'

  config.sample = ConfigDict()
  config.sample.log_dir = ''
  config.sample.batch_size = 1
  config.sample.mode = 'sample'
  config.sample.num_samples = 1
  config.sample.num_outputs = 1
  config.sample.skip_batches = 0
  config.sample.gen_file = 'gen0'

  return config
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """
  tf.io.gfile.makedirs(workdir)

  vocab_path = config.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(workdir, "sentencepiece_model")
    config.vocab_path = vocab_path
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info("Initializing dataset.")
  train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
      n_devices=jax.local_device_count(),
      config=config,
      reverse_translation=config.reverse_translation,
      vocab_path=vocab_path)

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode("utf-8")

  if config.num_predict_steps > 0:
    predict_ds = predict_ds.take(config.num_predict_steps)

  logging.info("Initializing model, optimizer, and step functions.")

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=config.share_embeddings,
      logits_via_embedding=config.logits_via_embedding,
      dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
      emb_dim=config.emb_dim,
      num_heads=config.num_heads,
      num_layers=config.num_layers,
      qkv_dim=config.qkv_dim,
      mlp_dim=config.mlp_dim,
      max_len=max(config.max_target_length, config.max_eval_target_length),
      dropout_rate=config.dropout_rate,
      attention_dropout_rate=config.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng = jax.random.split(rng)
  input_shape = (config.per_device_batch_size, config.max_target_length)
  target_shape = (config.per_device_batch_size, config.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      config.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=config.weight_decay)
  optimizer = optimizer_def.create(initial_variables["params"])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)
  if start_step == 0:
    writer.write_hparams(dict(config))

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=config.learning_rate,
      warmup_steps=config.warmup_steps)

  # compile multidevice versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=config.label_smoothing),
      axis_name="batch",
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(eval_step, config=eval_config), axis_name="batch")
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=config.max_predict_length,
          config=predict_config),
      axis_name="batch")
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step, config=predict_config, beam_size=config.beam_size),
      axis_name="batch",
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info("Starting training loop.")
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  if jax.host_id() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
    ]
  train_metrics = []
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceAnnotation("train", step_num=step):
        batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter)))
        optimizer, metrics = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step % config.eval_every_steps == 0 or is_last_step:
        with report_progress.timed("training_metrics"):
          logging.info("Gathering training metrics.")
          train_metrics = common_utils.get_metrics(train_metrics)
          lr = train_metrics.pop("learning_rate").mean()
          metrics_sums = jax.tree_map(jnp.sum, train_metrics)
          denominator = metrics_sums.pop("denominator")
          summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
          summary["learning_rate"] = lr
          summary = {"train_" + k: v for k, v in summary.items()}
          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed("eval"):
          eval_results = evaluate(
              p_eval_step=p_eval_step,
              target=optimizer.target,
              eval_ds=eval_ds,
              num_eval_steps=config.num_eval_steps)
          writer.write_scalars(
              step, {"eval_" + k: v for k, v in eval_results.items()})

        with report_progress.timed("translate_and_bleu"):
          exemplars, bleu_score = translate_and_calculate_bleu(
              p_pred_step=p_pred_step,
              p_init_cache=p_init_cache,
              target=optimizer.target,
              predict_ds=predict_ds,
              decode_tokens=decode_tokens,
              max_predict_length=config.max_predict_length)
          writer.write_scalars(step, {"bleu": bleu_score})
          writer.write_texts(step, {"samples": exemplars})

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = (
          step % config.checkpoint_every_steps == 0 or is_last_step)
      if config.save_checkpoints and save_checkpoint and jax.host_id() == 0:
        with report_progress.timed("checkpoint"):
          checkpoints.save_checkpoint(workdir,
                                      jax_utils.unreplicate(optimizer), step)
def inference_time(config: ml_collections.ConfigDict, workdir: str):
  """Runs a number of steps and measures inference time."""
  assert config.batch, f'Expected --config.batch={config.batch} > 0'
  assert config.num_classes, (
      f'Expected --config.num_classes={config.num_classes} > 0')
  assert config.image_size, (
      f'Expected --config.image_size={config.image_size} > 0')

  # Build VisionTransformer architecture
  model_config = config_lib.MODEL_CONFIGS[config.model_name]
  model = models.VisionTransformer(
      num_classes=config.num_classes, **model_config)

  # Make sure initial model parameters (before replication) are on CPU only.
  @functools.partial(jax.jit, backend='cpu')
  def init(rng):
    return model.init(
        rng,
        # Discard the "num_local_devices" dimension for initialization.
        inputs=jnp.ones([1, config.image_size, config.image_size, 3],
                        jnp.float32),
        train=False)

  variables = init(jax.random.PRNGKey(0))

  params_repl = flax_utils.replicate(variables['params'])

  # pmap replicates the models over all TPUs/GPUs
  vit_fn_repl = jax.pmap(functools.partial(model.apply, train=False))
  images = jnp.ones([
      jax.local_device_count(), config.batch // jax.local_device_count(),
      config.image_size, config.image_size, 3
  ], jnp.float32)

  writer = metric_writers.create_default_writer(workdir, asynchronous=False)
  writer.write_hparams(config.to_dict())

  logging.info('Starting training loop; initial compile can take a while...')
  logits = vit_fn_repl(flax.core.FrozenDict(params=params_repl), images)
  logits.block_until_ready()
  logging.info('Done.')

  logging.info('Going to run %d inferences WITHOUT measuring...',
               config.initial_steps)
  for _ in range(config.initial_steps):
    logits = vit_fn_repl(flax.core.FrozenDict(params=params_repl), images)
    logits.block_until_ready()

  logging.info('Going to run %d inferences measuring...', config.steps)
  times = []
  # Measure `config.steps` timed inferences (the unmeasured warm-up above
  # already consumed `config.initial_steps`).
  for _ in range(config.steps):
    t0 = time.time()
    logits = vit_fn_repl(flax.core.FrozenDict(params=params_repl), images)
    logits.block_until_ready()
    times.append(time.time() - t0)
  logging.info('times=%s', times)
  imgs_sec_core = config.batch / jax.local_device_count() / np.array(times)
  logging.info('imgs_sec_core_min=%f', imgs_sec_core.min())
  logging.info('imgs_sec_core_max=%f', imgs_sec_core.max())
  logging.info('imgs_sec_core_mean=%f', imgs_sec_core.mean())
  logging.info('imgs_sec_core_std=%f', imgs_sec_core.std())
  writer.write_scalars(
      0,
      dict(
          imgs_sec_core_min=imgs_sec_core.min(),
          imgs_sec_core_max=imgs_sec_core.max(),
          imgs_sec_core_mean=imgs_sec_core.mean(),
          imgs_sec_core_std=imgs_sec_core.std(),
      ))
def train_and_evaluate(config: ml_collections.ConfigDict,
                       workdir: str) -> TrainState:
  """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries are written to.

  Returns:
    The final train state that includes the trained parameters.
  """
  # Prepare datasets.
  train_dataset = input_pipeline.TextDataset(
      tfds_name='glue/sst2', split='train')
  eval_dataset = input_pipeline.TextDataset(
      tfds_name='glue/sst2', split='validation')
  train_batches = train_dataset.get_bucketed_batches(
      config.batch_size,
      config.bucket_size,
      max_input_length=config.max_input_length,
      drop_remainder=True,
      shuffle=True,
      shuffle_seed=config.seed)
  eval_batches = eval_dataset.get_batches(batch_size=config.batch_size)

  # Keep track of vocab size in the config so that the embedder knows it.
  config.vocab_size = len(train_dataset.vocab)

  # Compile step functions.
  train_step_fn = jax.jit(train_step)
  eval_step_fn = jax.jit(eval_step)

  # Create model and a state that contains the parameters.
  rng = jax.random.PRNGKey(config.seed)
  model = model_from_config(config)
  state = create_train_state(rng, config, model)

  summary_writer = tensorboard.SummaryWriter(workdir)
  summary_writer.hparams(dict(config))

  # Main training loop.
  logging.info('Starting training...')
  for epoch in range(1, config.num_epochs + 1):
    # Train for one epoch.
    rng, epoch_rng = jax.random.split(rng)
    rngs = {'dropout': epoch_rng}
    state, train_metrics = train_epoch(
        train_step_fn, state, train_batches, epoch, rngs)

    # Evaluate current model on the validation data.
    eval_metrics = evaluate_model(eval_step_fn, state, eval_batches, epoch)

    # Write metrics to TensorBoard.
    summary_writer.scalar('train_loss', train_metrics.loss, epoch)
    summary_writer.scalar('train_accuracy', train_metrics.accuracy * 100,
                          epoch)
    summary_writer.scalar('eval_loss', eval_metrics.loss, epoch)
    summary_writer.scalar('eval_accuracy', eval_metrics.accuracy * 100, epoch)

  summary_writer.flush()
  return state
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs training interleaved with evaluation."""
  # Setup input pipeline
  dataset_info = input_pipeline.get_dataset_info(config.dataset, 'train')
  ds_train, ds_test = input_pipeline.get_datasets(config)
  batch = next(iter(ds_train))
  logging.info(ds_train)
  logging.info(ds_test)

  # Build VisionTransformer architecture
  model_cls = {
      'ViT': models.VisionTransformer,
      'Mixer': models.MlpMixer
  }[config.get('model_type', 'ViT')]
  model = model_cls(num_classes=dataset_info['num_classes'], **config.model)

  def init_model():
    return model.init(
        jax.random.PRNGKey(0),
        # Discard the "num_local_devices" dimension for initialization.
        jnp.ones(batch['image'].shape[1:], batch['image'].dtype.name),
        train=False)

  # Use JIT to make sure params reside in CPU memory.
  variables = jax.jit(init_model, backend='cpu')()

  model_or_filename = config.get('model_or_filename')
  if model_or_filename:
    # Loading model from repo published with "How to train your ViT? Data,
    # Augmentation, and Regularization in Vision Transformers" paper.
    # https://arxiv.org/abs/2106.10270
    if '-' in model_or_filename:
      filename = model_or_filename
    else:
      # Select best checkpoint from i21k pretraining by final upstream
      # validation accuracy.
      df = checkpoint.get_augreg_df(directory=config.pretrained_dir)
      sel = df.filename.apply(
          lambda filename: filename.split('-')[0] == model_or_filename)
      best = df.loc[sel].query('ds=="i21k"').sort_values('final_val').iloc[-1]
      filename = best.filename
      logging.info('Selected filename="%s" for "%s" with final_val=%.3f',
                   filename, model_or_filename, best.final_val)
    pretrained_path = os.path.join(config.pretrained_dir, f'{filename}.npz')
  else:
    # ViT / Mixer papers
    filename = config.model.name
    pretrained_path = os.path.join(config.pretrained_dir, f'{filename}.npz')
  if not tf.io.gfile.exists(pretrained_path):
    raise ValueError(
        f'Could not find "{pretrained_path}" - you can download models from '
        '"gs://vit_models/imagenet21k" or directly set '
        '--config.pretrained_dir="gs://vit_models/imagenet21k".')
  params = checkpoint.load_pretrained(
      pretrained_path=pretrained_path,
      init_params=variables['params'],
      model_config=config.model)

  total_steps = config.total_steps
  lr_fn = utils.create_learning_rate_schedule(total_steps, config.base_lr,
                                              config.decay_type,
                                              config.warmup_steps)

  update_fn_repl = make_update_fn(
      apply_fn=model.apply, accum_steps=config.accum_steps, lr_fn=lr_fn)
  infer_fn_repl = jax.pmap(functools.partial(model.apply, train=False))

  # Create optimizer and replicate it over all TPUs/GPUs
  opt = momentum_clip.Optimizer(
      dtype=config.optim_dtype,
      grad_norm_clip=config.grad_norm_clip).create(params)

  initial_step = 1
  opt, initial_step = flax_checkpoints.restore_checkpoint(
      workdir, (opt, initial_step))
  logging.info('Will start/continue training at initial_step=%d', initial_step)

  opt_repl = flax.jax_utils.replicate(opt)

  # Delete references to the objects that are not needed anymore
  del opt
  del params

  # Prepare the learning-rate and pre-fetch it to device to avoid delays.
  update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0))

  # Setup metric writer & hooks.
  writer = metric_writers.create_default_writer(workdir, asynchronous=False)
  writer.write_hparams(config.to_dict())

  hooks = [
      periodic_actions.Profile(logdir=workdir),
      periodic_actions.ReportProgress(
          num_train_steps=total_steps, writer=writer),
  ]

  # Run training loop
  logging.info('Starting training loop; initial compile can take a while...')
  t0 = lt0 = time.time()
  lstep = initial_step
  for step, batch in zip(
      range(initial_step, total_steps + 1),
      input_pipeline.prefetch(ds_train, config.prefetch)):

    with jax.profiler.StepTraceContext('train', step_num=step):
      opt_repl, loss_repl, update_rng_repl = update_fn_repl(
          opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)

    for hook in hooks:
      hook(step)

    if step == initial_step:
      logging.info('First step took %.1f seconds.', time.time() - t0)
      t0 = time.time()
      lt0, lstep = time.time(), step

    # Report training metrics
    if config.progress_every and step % config.progress_every == 0:
      img_sec_core_train = (config.batch * (step - lstep) /
                            (time.time() - lt0)) / jax.device_count()
      lt0, lstep = time.time(), step
      writer.write_scalars(
          step,
          dict(
              train_loss=float(flax.jax_utils.unreplicate(loss_repl)),
              img_sec_core_train=img_sec_core_train))
      done = step / total_steps
      logging.info(f'Step: {step}/{total_steps} {100*done:.1f}%, '  # pylint: disable=logging-format-interpolation
                   f'img/sec/core: {img_sec_core_train:.1f}, '
                   f'ETA: {(time.time()-t0)/done*(1-done)/3600:.2f}h')

    # Run evaluation
    if ((config.eval_every and step % config.eval_every == 0) or
        (step == total_steps)):
      accuracies = []
      lt0 = time.time()
      for test_batch in input_pipeline.prefetch(ds_test, config.prefetch):
        logits = infer_fn_repl(
            dict(params=opt_repl.target), test_batch['image'])
        accuracies.append(
            (np.argmax(logits, axis=-1) == np.argmax(
                test_batch['label'], axis=-1)).mean())
      accuracy_test = np.mean(accuracies)
      img_sec_core_test = (
          config.batch_eval * ds_test.cardinality().numpy() /
          (time.time() - lt0) / jax.device_count())
      lt0 = time.time()

      lr = float(lr_fn(step))
      logging.info(f'Step: {step} '  # pylint: disable=logging-format-interpolation
                   f'Learning rate: {lr:.7f}, '
                   f'Test accuracy: {accuracy_test:0.5f}, '
                   f'img/sec/core: {img_sec_core_test:.1f}')
      writer.write_scalars(
          step,
          dict(
              accuracy_test=accuracy_test,
              lr=lr,
              img_sec_core_test=img_sec_core_test))

    # Store checkpoint (gate on checkpoint_every, not eval_every).
    if ((config.checkpoint_every and step % config.checkpoint_every == 0) or
        step == total_steps):
      checkpoint_path = flax_checkpoints.save_checkpoint(
          workdir, (flax.jax_utils.unreplicate(opt_repl), step), step)
      logging.info('Stored checkpoint at step %d to "%s"', step,
                   checkpoint_path)

  return flax.jax_utils.unreplicate(opt_repl)
def generate(config: ml_collections.ConfigDict):
  """Generates memories."""
  # Establish host information
  local_device_count = jax.local_device_count()
  device_count = jax.device_count()
  process_count = jax.process_count()
  process_index = jax.process_index()

  task = memory_generation_task.MemoryGenerationTask
  model_config = ml_collections.FrozenConfigDict(config.model_config)
  model = task.build_model(model_config)
  p_predict_step = jax.pmap(
      functools.partial(
          task.make_prediction_fn(config),
          model_config,
      ),
      axis_name='batch')
  rng = jax.random.PRNGKey(config.seed)

  # Initialization needs to be pmapped because models use collective ops.
  # Create dummy input
  dummy_input = {
      key: jnp.tile(value, (local_device_count,) + (1,) * value.ndim)
      for key, value in task.dummy_input(config).items()
  }

  rng, init_rng = jax.random.split(rng)
  init_rng = jax.random.split(init_rng, local_device_count)

  logging.info('Initializing model.')
  initial_variables = jax.pmap(
      model.init, 'batch', static_broadcasted_argnums=2)(init_rng, dummy_input,
                                                         True)
  logging.info('Finished initializing model.')
  initial_variables = initial_variables.unfreeze()

  if config.load_weights is not None:
    logging.info('Loading model weights from file')
    loaded_variables = task.load_weights(config)
    unexpected, missing = checkpoint_utils.merge_nested_dicts(
        initial_variables, loaded_variables)
    logging.info('*** Unexpected features: ***')
    for feature_name in unexpected:
      logging.info('\t%s', feature_name)
    # In the prediction mode we don't allow any features to be missing
    # pylint: disable=g-explicit-length-test
    if len(missing) > 0:
      raise ValueError('Missing features: %s' % ','.join(missing))

  # model_params = jax_utils.unreplicate(initial_variables['params'])
  model_params = initial_variables['params']
  model_vars = {
      key: value for key, value in initial_variables.items() if key != 'params'
  }
  # We access model params only from train state.
  del initial_variables

  writer = metric_writers.create_default_writer(
      config.output_dir, just_logging=process_index > 0)

  max_length = config.get('max_length_with_entity_tokens',
                          model_config.encoder_config.max_length)
  num_total_memories = math.ceil(config.num_total_memories / process_count)
  memory_saver = memory_generation_task.MemorySaver(
      num_total_memories=num_total_memories,
      memory_dim=config.memory_dim,
      max_length=max_length,
      max_mentions_per_sample=config.max_mentions_per_sample,
      memory_key_dim=config.get('memory_key_dim'))
  n_samples = 0
  data_iter = get_data_iterator(config)

  logging.info('Start memory generation.')
  with metric_writers.ensure_flushes(writer):
    for step, batch in enumerate(data_iter):
      batch = jax.tree_map(jnp.asarray, batch)
      predictions = p_predict_step(
          model_params,
          model_vars,
          batch,
      )
      predictions = jax.device_get(predictions)
      memory_saver.add_memories(batch, predictions)
      n_devices, batch_size, _ = batch['text_ids'].shape
      logging.log_first_n(
          logging.INFO, 'Process %d / %d: '
          'Finished generating step %d, local devices %d, batch size %d', 5,
          process_index, process_count, step, n_devices, batch_size)

      n_samples += device_count * config.per_device_batch_size
      if (step % config.log_every_steps == 0 or
          memory_saver.get_num_memories() >= num_total_memories):
        writer.write_scalars(
            step,
            dict(
                n_memories=memory_saver.get_num_memories(),
                n_samples=n_samples))

      if memory_saver.get_num_memories() >= num_total_memories:
        break

  logging.info('Process %d / %d: Finished generating memories: %d out of %d',
               process_index, process_count, memory_saver.get_num_memories(),
               num_total_memories)

  start_time = time.time()
  logging.info('Process %d / %d: Start saving generated memories to files.',
               process_index, process_count)
  memory_saver.save(
      config.output_dir,
      num_shards=config.num_shards,
      stride=process_count,
      offset=process_index,
      shard_size_divisible=config.shard_size_divisible)
  logging.info(
      'Process %d / %d: Finished saving generated memories to files in %.2f seconds',
      process_index, process_count,
      time.time() - start_time)
def _get_augment_element_fn(
    dataset_config: ml_collections.ConfigDict
) -> Callable[[TensorDict, TensorDict, TensorDict], TensorDictTriple]:
  """Returns a function that augments the `IMAGE_KEY` input tensor.

  The following transformations are applied if specified in `dataset_config`:

  - tf.image.random_flip_left_right
  - tf.image.random_flip_up_down
  - tf.image.random_brightness
  - tf.image.random_hue
  - tf.image.random_saturation
  - tf.image.random_contrast

  Important: The returned function requires that image tensors have dtype
  tf.float32 and have values in range [0, 1]. Augmented images are then
  clipped back to the [0, 1] range.

  Args:
    dataset_config: A ConfigDict used to build the set of applied
      augmentations.

  Returns:
    A function that applies the set of augmentations to an input TensorDict's
    `IMAGE_KEY` image.
  """
  horizontal_flip = dataset_config.get('random_horizontal_flip', False)
  vertical_flip = dataset_config.get('random_vertical_flip', False)
  brightness_max_delta = dataset_config.get('random_brightness_max_delta',
                                            None)
  hue_max_delta = dataset_config.get('random_hue_max_delta', None)

  saturation_lower = dataset_config.get('random_saturation_lower', None)
  saturation_upper = dataset_config.get('random_saturation_upper', None)
  apply_saturation = saturation_lower and saturation_upper
  if apply_saturation and (saturation_upper <= saturation_lower):
    raise ValueError(
        f'Invalid saturation range: ({saturation_lower}, {saturation_upper})')

  contrast_lower = dataset_config.get('random_contrast_lower', None)
  contrast_upper = dataset_config.get('random_contrast_upper', None)
  apply_contrast = contrast_lower and contrast_upper
  if apply_contrast and (contrast_upper <= contrast_lower):
    raise ValueError(
        f'Invalid contrast range: ({contrast_lower}, {contrast_upper})')

  def _augment_element_fn(inputs: TensorDict, labels: TensorDict,
                          weights: TensorDict) -> TensorDictTriple:
    image = inputs[IMAGE_KEY]

    # Ensure images are in the expected format. Image augmentations assume that
    # the image tensor is a tf.float32 and contains pixels in range [0, 1].
    tf.debugging.assert_type(image, tf.float32)
    tf.debugging.assert_less_equal(tf.math.reduce_max(image), 1.0)
    tf.debugging.assert_greater_equal(tf.math.reduce_min(image), 0.0)

    if horizontal_flip:
      image = tf.image.random_flip_left_right(image)
    if vertical_flip:
      image = tf.image.random_flip_up_down(image)
    if brightness_max_delta:
      image = tf.image.random_brightness(image, max_delta=brightness_max_delta)
    if hue_max_delta:
      image = tf.image.random_hue(image, max_delta=hue_max_delta)
    if apply_saturation:
      image = tf.image.random_saturation(
          image, lower=saturation_lower, upper=saturation_upper)
    if apply_contrast:
      image = tf.image.random_contrast(
          image, lower=contrast_lower, upper=contrast_upper)

    # Clip image back to [0.0, 1.0] prior to architecture-specific centering.
    image = tf.clip_by_value(image, 0.0, 1.0)
    inputs[IMAGE_KEY] = image
    return inputs, labels, weights

  return _augment_element_fn
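# Hypothetical usage sketch (assumptions, not from the original source): a
# dataset config enabling a subset of the augmentations above; the ranges are
# invented, and `ds` stands in for a tf.data.Dataset of (inputs, labels,
# weights) triples.
import ml_collections

dataset_config = ml_collections.ConfigDict(
    dict(random_horizontal_flip=True,
         random_brightness_max_delta=0.1,
         random_contrast_lower=0.8,
         random_contrast_upper=1.2))
augment_fn = _get_augment_element_fn(dataset_config)
# ds = ds.map(augment_fn)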
)
parser.add_argument(
    "--print_batch_metrics",
    action='store_true',
    default=False,
    help="Set to print metrics for every batch.")
args = parser.parse_args()

assert (
    args.checkpoint is not None
), "A checkpoint needs to be specified via commandline argument (--checkpoint)"
assert (
    args.config is not None
), "A config needs to be specified via commandline argument (--config)"

with open(args.config) as f:
    cfg = ConfigDict(yaml.load(f, Loader=yaml.Loader))
cfg.checkpoint = args.checkpoint

if args.annotations is not None:
    cfg.test_annotations = args.annotations
if args.imagedir is not None:
    cfg.test_imagedir = args.imagedir
if args.uncertainty_threshold is not None:
    cfg.uncertainty_threshold = args.uncertainty_threshold
if args.uncertainty_gate_type is not None:
    cfg.uncertainty_gate_type = args.uncertainty_gate_type
if args.weighted_prediction is not None:
    cfg.weighted_prediction = args.weighted_prediction

assert (