Example #1
    def __init__(self, input_shape, custom_bottleneck_size=None):
        self.input_shape = input_shape
        self.custom_bottleneck_size = custom_bottleneck_size
        #self.cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE, label_smoothing = 0.1)
        self.cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
        self.mse_loss = keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
        self.learning_rate = 1e-4
        self.generator_optimizer = keras.optimizers.Adam(self.learning_rate, epsilon=1e-04)
        # self.generator_optimizer = keras.optimizers.SGD(self.learning_rate)
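        # Wrap the optimizers below so losses/gradients are scaled under mixed
        # precision, preventing float16 gradient underflow.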
        self.generator_optimizer = mixed_precision.LossScaleOptimizer(self.generator_optimizer)
        self.discriminator_optimizer = keras.optimizers.Adam(self.learning_rate, epsilon=1e-04)
        # self.discriminator_optimizer = keras.optimizers.SGD(self.learning_rate)
        self.discriminator_optimizer = mixed_precision.LossScaleOptimizer(self.discriminator_optimizer)
        # self.encoder_optimizer = keras.optimizers.Adam(self.learning_rate)
        # self.encoder_optimizer = mixed_precision.LossScaleOptimizer(self.encoder_optimizer)
        self.batch_size = 16
        self.epochs = 100
        self.loss_weight = 0.2
        self.createNetwork()
        self.model = CVAE(input_shape)
        self.checkpoint_dir = './training_checkpoints'
        self.checkpoint_prefix = os.path.join(self.checkpoint_dir, "ckpt" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
        self.checkpoint = tf.train.Checkpoint(
            generator_optimizer=self.generator_optimizer,
            discriminator_optimizer=self.discriminator_optimizer,
            # encoder_optimizer=self.encoder_optimizer,
            encoder=self.model.encoder,
            decoder=self.model.decoder,
            discriminator=self.discriminator)
Example #2
    def __init__(self,
                 name,
                 models,
                 lr,
                 clip_norm=None,
                 weight_decay=None,
                 l2_reg=None,
                 wdpattern=r'.*',
                 scales=None,
                 return_grads=False,
                 **kwargs):
        self._models = models if isinstance(models, (list, tuple)) else [models]
        self._clip_norm = clip_norm
        self._weight_decay = weight_decay
        self._l2_reg = l2_reg
        self._wdpattern = wdpattern
        if scales is not None:
            assert isinstance(scales, (list, tuple)), scales
            assert len(scales) == len(self._models), (len(scales), len(self._models))
        self._scales = scales
        self._opt = select_optimizer(name)(lr, **kwargs)
        self._return_grads = return_grads
        # Useful for mixed precision training on GPUs to avoid numerical
        # underflow caused by float16 gradients.
        prec_policy = prec.global_policy()
        self._mpt = prec_policy.compute_dtype != prec_policy.variable_dtype
        if self._mpt:
            logger.info('Mixed precision training will be performed')
            self._opt = prec.LossScaleOptimizer(self._opt)
        # Variables are not initialized here, as the models may not be built yet.
        self._variables = None
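
    # A sketch of how a train step could use self._opt when self._mpt is True
    # (hypothetical; the class above does not show its step method):
    #
    #   with tf.GradientTape() as tape:
    #       loss = compute_loss(...)
    #       scaled = self._opt.get_scaled_loss(loss) if self._mpt else loss
    #   grads = tape.gradient(scaled, variables)
    #   if self._mpt:
    #       grads = self._opt.get_unscaled_gradients(grads)
    #   self._opt.apply_gradients(zip(grads, variables))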
Example #3
def _optimizer_fn_to_optimizer(
    optimizer_fn: Union[Callable, None], model: Model, framework: str,
    mixed_precision: bool
) -> Union[None, tf.optimizers.Optimizer, torch.optim.Optimizer]:
    """A helper function to invoke an optimizer function.

    Args:
        optimizer_fn: The function to be invoked in order to instantiate an optimizer.
        model: The model with which the optimizer should be associated.
        framework: Which backend framework should be used ('tf' or 'torch').
        mixed_precision: Whether to enable mixed-precision training.

    Returns:
        An optimizer instance, or None if `optimizer_fn` was None.
    """
    optimizer = None
    if optimizer_fn:
        if framework == "tf":
            try:
                optimizer = optimizer_fn()
            except Exception:
                raise AssertionError(
                    "optimizer_fn of Tensorflow backend should be callable without args. Please "
                    "make sure model and optimizer_fn are using the same backend"
                )
            # initialize optimizer variables
            _ = optimizer.iterations
            optimizer._create_hypers()
            optimizer._create_slots(model.trainable_variables)
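            # NOTE: _create_hypers/_create_slots are private Keras APIs, used
            # here to force eager creation of the optimizer's state variables.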
            # handle mixed precision loss scaling
            if mixed_precision:
                optimizer = mixed_precision_tf.LossScaleOptimizer(optimizer)
            assert isinstance(
                optimizer, tf.optimizers.Optimizer
            ), "optimizer_fn should generate tensorflow optimizer"
        else:
            try:
                optimizer = optimizer_fn(model.parameters())
            except Exception as e:
                print(
                    "optimizer_fn of Pytorch backend should be callable with a single arg. "
                    "Please make sure model and optimizer_fn are using the same backend")
                raise ValueError(repr(e))
            assert isinstance(
                optimizer, torch.optim.Optimizer
            ), "optimizer_fn should generate pytorch optimizer"
            if mixed_precision:
                setattr(optimizer, "scaler", torch.cuda.amp.GradScaler())
            else:
                setattr(optimizer, "scaler", None)

    return optimizer
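
# Usage sketch (hypothetical model/optimizer names; not part of the function above):
#   tf_opt = _optimizer_fn_to_optimizer(
#       lambda: tf.optimizers.Adam(1e-3), tf_model, "tf", mixed_precision=True)
#   torch_opt = _optimizer_fn_to_optimizer(
#       lambda params: torch.optim.Adam(params, lr=1e-3), torch_model, "torch",
#       mixed_precision=True)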
Example #4
def evaluate(config, train_dir, weights, evaluation_dir):
    """Evaluate the trained model in train_dir"""
    if config is None:
        config = Path(train_dir) / "config.yaml"
        assert config.exists(), \
            "Could not find config file in train_dir, please provide one with -c <path/to/config>"
    config, _ = parse_config(config, weights=weights)

    if evaluation_dir is None:
        eval_dir = str(Path(train_dir) / "evaluation")
    else:
        eval_dir = evaluation_dir

    Path(eval_dir).mkdir(parents=True, exist_ok=True)

    if config["setup"]["dtype"] == "float16":
        model_dtype = tf.dtypes.float16
        policy = mixed_precision.Policy("mixed_float16")
        mixed_precision.set_global_policy(policy)
    else:
        model_dtype = tf.dtypes.float32

    strategy, num_gpus = get_strategy()
    ds_test, _ = get_heptfds_dataset(config["validation_dataset"], config,
                                     num_gpus, "test")
    ds_test = ds_test.batch(5)

    model = make_model(config, model_dtype)
    model.build((1, config["dataset"]["padded_num_elem_size"],
                 config["dataset"]["num_input_features"]))

    # need to load the weights in the same trainable configuration as the model was set up
    configure_model_weights(model, config["setup"].get("weights_config",
                                                       "all"))
    if weights:
        model.load_weights(weights, by_name=True)
    else:
        weights = get_best_checkpoint(train_dir)
        print(
            "Loading best weights that could be found from {}".format(weights))
        model.load_weights(weights, by_name=True)

    eval_model(model, ds_test, config, eval_dir)
    freeze_model(model, config, ds_test.take(1), train_dir)
Example #5
def run_training(args):
    it_network = ImageTransformNet(
        input_shape=hparams['input_size'],
        residual_layers=hparams['residual_layers'],
        residual_filters=hparams['residual_filters'],
        initializer=hparams['initializer'])
    loss_network = LossNetwork(hparams['style_layers'])

    optimizer = tf.keras.optimizers.Adam(
        learning_rate=hparams['learning_rate'])
    optimizer = mixed_precision.LossScaleOptimizer(optimizer)

    ckpt_dir = os.path.join(args.name, 'pretrained')
    ckpt = tf.train.Checkpoint(network=it_network,
                               optimizer=optimizer,
                               step=tf.Variable(0))
    ckpt_manager = tf.train.CheckpointManager(
        ckpt, directory=ckpt_dir, max_to_keep=args.max_ckpt_to_keep)

    ckpt.restore(ckpt_manager.latest_checkpoint)
    log_dir = os.path.join(args.name, 'log_dir')
    writer = tf.summary.create_file_writer(logdir=log_dir)

    print('\n####################################################')
    print('Perceptual Losses for Real-Time Style Transfer Train')
    print('####################################################\n')
    if ckpt_manager.latest_checkpoint:
        print('Restored {} from: {}'.format(args.name,
                                            ckpt_manager.latest_checkpoint))
    else:
        print('Initializing {} from scratch'.format(args.name))
        save_hparams(args.name)
    print('Style image: {}'.format(args.style_img))
    print('Start TensorBoard with: $ tensorboard --logdir ./\n')

    total_loss_avg = tf.keras.metrics.Mean()
    style_loss_avg = tf.keras.metrics.Mean()
    content_loss_avg = tf.keras.metrics.Mean()

    save_hparams(args.name)

    style_img = convert(args.style_img)
    target_feature_maps = loss_network(style_img[tf.newaxis, :])
    target_gram_matrices = [gram_matrix(x) for x in target_feature_maps]
    num_style_layers = len(target_feature_maps)

    dataset = create_ds(args)
    test_content_batch = create_test_batch(args)

    @tf.function
    def test_step(batch):
        prediction = it_network(batch, training=False)
        #prediction_norm = np.array(tf.clip_by_value(prediction, 0, 1)*255, dtype=np.uint8) # Poor quality, no convergence
        #prediction_norm = np.array(tf.clip_by_value(prediction, 0, 255), dtype=np.uint8)
        return deprocess(prediction)

    @tf.function
    def train_step(batch):
        with tf.GradientTape() as tape:
            output_batch = it_network(batch, training=True)
            output_batch = 255 * (output_batch + 1.0) / 2.0  # float deprocess

            # Feed target and output batch through loss_network
            target_batch_feature_maps = loss_network(batch)
            output_batch_feature_maps = loss_network(output_batch)

            c_loss = content_loss(
                target_batch_feature_maps[hparams['content_layer_index']],
                output_batch_feature_maps[hparams['content_layer_index']])
            c_loss *= hparams['content_weight']

            # Get output gram_matrix
            output_gram_matrices = [
                gram_matrix(x) for x in output_batch_feature_maps
            ]
            s_loss = style_loss(target_gram_matrices, output_gram_matrices)
            s_loss *= hparams['style_weight'] / num_style_layers

            total_loss = c_loss + s_loss
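            # Scale the loss while the tape is recording so the scaling op is
            # differentiated; the gradients are unscaled again before applying.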
            scaled_loss = optimizer.get_scaled_loss(total_loss)

        scaled_gradients = tape.gradient(scaled_loss,
                                         it_network.trainable_variables)
        gradients = optimizer.get_unscaled_gradients(scaled_gradients)
        #gradients = tape.gradient(total_loss, it_network.trainable_variables)
        optimizer.apply_gradients(
            zip(gradients, it_network.trainable_variables))

        total_loss_avg(total_loss)
        content_loss_avg(c_loss)
        style_loss_avg(s_loss)

    total_start = time.time()
    for batch_image in dataset:
        start = time.time()
        train_step(batch_image)

        ckpt.step.assign_add(1)
        step_int = int(ckpt.step)  # cast ckpt.step

        if (step_int) % args.ckpt_interval == 0:
            print('Time taken for step {} is {} sec'.format(
                step_int,
                time.time() - start))
            ckpt_manager.save(step_int)
            prediction_norm = test_step(test_content_batch)

            with writer.as_default():
                tf.summary.scalar('total loss',
                                  total_loss_avg.result(),
                                  step=step_int)
                tf.summary.scalar('content loss',
                                  content_loss_avg.result(),
                                  step=step_int)
                tf.summary.scalar('style loss',
                                  style_loss_avg.result(),
                                  step=step_int)
                images = np.reshape(prediction_norm,
                                    (-1, hparams['input_size'][0],
                                     hparams['input_size'][1], 3))
                tf.summary.image('generated image',
                                 images,
                                 step=step_int,
                                 max_outputs=len(test_content_batch))

            print('Total loss: {:.4f}'.format(total_loss_avg.result()))
            print('Content loss: {:.4f}'.format(content_loss_avg.result()))
            print('Style loss: {:.4f}'.format(style_loss_avg.result()))
            print('Total time: {} sec\n'.format(time.time() - total_start))
            total_loss_avg.reset_states()
            content_loss_avg.reset_states()
            style_loss_avg.reset_states()
Example #6
    def __init__(self,
                 args,
                 actor,
                 dl,
                 encoder=None,
                 planner=None,
                 cnn=None,
                 optimizer=Adam,
                 strategy=None,
                 global_batch_size=32):

        self.actor = actor
        self.encoder = encoder
        self.planner = planner
        self.cnn = cnn
        self.strategy = strategy
        self.args = args
        self.dl = dl
        self.global_batch_size = global_batch_size

        if args.fp16:
            # LossScaleOptimizer wraps optimizer *instances*, so wrap each
            # optimizer created below instead of the optimizer class itself.
            base_optimizer = optimizer

            def optimizer(**kwargs):
                return mixed_precision.LossScaleOptimizer(base_optimizer(**kwargs))

        if self.args.num_distribs is None:  # different sized clips due to different sized losses
            actor_clip = 0.06
            encoder_clip = 0.03
            planner_clip = 0.001
        else:
            actor_clip = 400
            encoder_clip = 5
            planner_clip = 0.4

        self.actor_optimizer = optimizer(learning_rate=args.learning_rate,
                                         clipnorm=actor_clip)
        self.encoder_optimizer = optimizer(learning_rate=args.learning_rate,
                                           clipnorm=encoder_clip)
        self.planner_optimizer = optimizer(learning_rate=args.learning_rate,
                                           clipnorm=planner_clip)

        self.nll_action_loss = lambda y, p_y: tf.reduce_sum(-p_y.log_prob(y),
                                                            axis=2)
        self.mae_action_loss = tf.keras.losses.MeanAbsoluteError(
            reduction=tf.keras.losses.Reduction.NONE)
        self.mse_action_loss = tf.keras.losses.MeanSquaredError(
            reduction=tf.keras.losses.Reduction.NONE)

        self.metrics = {}
        self.metrics['train_loss'] = tf.keras.metrics.Mean(name='train_loss')
        self.metrics['valid_loss'] = tf.keras.metrics.Mean(name='valid_loss')
        self.metrics['actor_grad_norm'] = tf.keras.metrics.Mean(
            name='actor_grad_norm')
        self.metrics['encoder_grad_norm'] = tf.keras.metrics.Mean(
            name='encoder_grad_norm')
        self.metrics['planner_grad_norm'] = tf.keras.metrics.Mean(
            name='planner_grad_norm')

        self.metrics['global_grad_norm'] = tf.keras.metrics.Mean(
            name='global_grad_norm')

        self.metrics['train_act_with_enc_loss'] = tf.keras.metrics.Mean(
            name='train_act_with_enc_loss')
        self.metrics['train_act_with_plan_loss'] = tf.keras.metrics.Mean(
            name='train_act_with_plan_loss')
        self.metrics['valid_act_with_enc_loss'] = tf.keras.metrics.Mean(
            name='valid_act_with_enc_loss')
        self.metrics['valid_act_with_plan_loss'] = tf.keras.metrics.Mean(
            name='valid_act_with_plan_loss')

        self.metrics['train_reg_loss'] = tf.keras.metrics.Mean(name='reg_loss')
        self.metrics['valid_reg_loss'] = tf.keras.metrics.Mean(
            name='valid_reg_loss')

        self.metrics['valid_position_loss'] = tf.keras.metrics.Mean(
            name='valid_position_loss')
        self.metrics['valid_max_position_loss'] = lfp.metric.MaxMetric(
            name='valid_max_position_loss')
        self.metrics['valid_rotation_loss'] = tf.keras.metrics.Mean(
            name='valid_rotation_loss')
        self.metrics['valid_max_rotation_loss'] = lfp.metric.MaxMetric(
            name='valid_max_rotation_loss')
        self.metrics['valid_gripper_loss'] = tf.keras.metrics.Mean(
            name='valid_gripper_loss')

        self.chkpt_manager = None
Example #7
def main(args):

    print(args)

    if args.push_to_hub:
        login_to_hub()

    if not isinstance(args.workers, int):
        args.workers = min(16, mp.cpu_count())

    vocab = VOCABS[args.vocab]

    fonts = args.font.split(",")

    # AMP
    if args.amp:
        mixed_precision.set_global_policy("mixed_float16")

    # Load val data generator
    st = time.time()
    val_set = CharacterGenerator(
        vocab=vocab,
        num_samples=args.val_samples * len(vocab),
        cache_samples=True,
        img_transforms=T.Compose(
            [
                T.Resize((args.input_size, args.input_size)),
                # Ensure we have a 90% split of white-background images
                T.RandomApply(T.ColorInversion(), 0.9),
            ]
        ),
        font_family=fonts,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.workers,
        collate_fn=collate_fn,
    )
    print(
        f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
        f"{val_loader.num_batches} batches)"
    )

    # Load doctr model
    model = classification.__dict__[args.arch](
        pretrained=args.pretrained,
        input_shape=(args.input_size, args.input_size, 3),
        num_classes=len(vocab),
        classes=list(vocab),
        include_top=True,
    )

    # Resume weights
    if isinstance(args.resume, str):
        model.load_weights(args.resume)

    batch_transforms = T.Compose(
        [
            T.Normalize(mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301)),
        ]
    )

    if args.test_only:
        print("Running evaluation")
        val_loss, acc = evaluate(model, val_loader, batch_transforms)
        print(f"Validation loss: {val_loss:.6} (Acc: {acc:.2%})")
        return

    st = time.time()

    # Load train data generator
    train_set = CharacterGenerator(
        vocab=vocab,
        num_samples=args.train_samples * len(vocab),
        cache_samples=True,
        img_transforms=T.Compose(
            [
                T.Resize((args.input_size, args.input_size)),
                # Augmentations
                T.RandomApply(T.ColorInversion(), 0.9),
                T.RandomApply(T.ToGray(3), 0.1),
                T.RandomJpegQuality(60),
                T.RandomSaturation(0.3),
                T.RandomContrast(0.3),
                T.RandomBrightness(0.3),
                # Blur
                T.RandomApply(T.GaussianBlur(kernel_shape=(3, 3), std=(0.1, 3)), 0.3),
            ]
        ),
        font_family=fonts,
    )
    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.workers,
        collate_fn=collate_fn,
    )
    print(
        f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
        f"{train_loader.num_batches} batches)"
    )

    if args.show_samples:
        x, target = next(iter(train_loader))
        plot_samples(x, list(map(vocab.__getitem__, target)))
        return

    # Optimizer
    scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
        args.lr,
        decay_steps=args.epochs * len(train_loader),
        decay_rate=1 / (1e3),  # final lr as a fraction of initial lr
        staircase=False,
    )
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=scheduler,
        beta_1=0.95,
        beta_2=0.99,
        epsilon=1e-6,
    )
    if args.amp:
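        # Pairs with the mixed_float16 policy set above: scales the loss to
        # keep float16 gradients from underflowing.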
        optimizer = mixed_precision.LossScaleOptimizer(optimizer)

    # LR Finder
    if args.find_lr:
        lrs, losses = record_lr(model, train_loader, batch_transforms, optimizer, amp=args.amp)
        plot_recorder(lrs, losses)
        return

    # Tensorboard to monitor training
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name

    # W&B
    if args.wb:

        run = wandb.init(
            name=exp_name,
            project="character-classification",
            config={
                "learning_rate": args.lr,
                "epochs": args.epochs,
                "weight_decay": 0.0,
                "batch_size": args.batch_size,
                "architecture": args.arch,
                "input_size": args.input_size,
                "optimizer": "adam",
                "framework": "tensorflow",
                "vocab": args.vocab,
                "scheduler": "exp_decay",
                "pretrained": args.pretrained,
            },
        )

    # Create loss queue
    min_loss = np.inf

    # Training loop
    mb = master_bar(range(args.epochs))
    for epoch in mb:
        fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb, args.amp)

        # Validation loop at the end of each epoch
        val_loss, acc = evaluate(model, val_loader, batch_transforms)
        if val_loss < min_loss:
            print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
            model.save_weights(f"./{exp_name}/weights")
            min_loss = val_loss
        mb.write(f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} (Acc: {acc:.2%})")
        # W&B
        if args.wb:
            wandb.log(
                {
                    "val_loss": val_loss,
                    "acc": acc,
                }
            )

    if args.wb:
        run.finish()

    if args.push_to_hub:
        push_to_hf_hub(model, exp_name, task="classification", run_config=args)

    if args.export_onnx:
        print("Exporting model to ONNX...")
        dummy_input = [tf.TensorSpec([None, args.input_size, args.input_size, 3], tf.float32, name="input")]
        model_path, _ = export_model_to_onnx(model, exp_name, dummy_input)
        print(f"Exported model saved in {model_path}")
Example #8
def find_lr(config, outdir, figname, logscale):
    """Run the Learning Rate Finder to produce a batch loss vs. LR plot from
    which an appropriate LR-range can be determined"""
    config, _ = parse_config(config)

    # Decide tf.distribute.strategy depending on number of available GPUs
    strategy, num_gpus = get_strategy()

    ds_train, num_train_steps = get_datasets(config["train_test_datasets"],
                                             config, num_gpus, "train")

    with strategy.scope():
        opt = tf.keras.optimizers.Adam(
            learning_rate=1e-7
        )  # This learning rate will be changed by the lr_finder
        if config["setup"]["dtype"] == "float16":
            model_dtype = tf.dtypes.float16
            policy = mixed_precision.Policy("mixed_float16")
            mixed_precision.set_global_policy(policy)
            opt = mixed_precision.LossScaleOptimizer(opt)
        else:
            model_dtype = tf.dtypes.float32

        model = make_model(config, model_dtype)
        config = set_config_loss(config, config["setup"]["trainable"])

        # Run model once to build the layers
        model.build((1, config["dataset"]["padded_num_elem_size"],
                     config["dataset"]["num_input_features"]))

        configure_model_weights(model, config["setup"]["trainable"])

        loss_dict, loss_weights = get_loss_dict(config)
        model.compile(
            loss=loss_dict,
            optimizer=opt,
            sample_weight_mode="temporal",
            loss_weights=loss_weights,
            metrics={
                "cls": [
                    FlattenedCategoricalAccuracy(name="acc_unweighted",
                                                 dtype=tf.float64),
                    FlattenedCategoricalAccuracy(use_weights=True,
                                                 name="acc_weighted",
                                                 dtype=tf.float64),
                ]
            },
        )
        model.summary()

        max_steps = 200
        lr_finder = LRFinder(max_steps=max_steps)
        callbacks = [lr_finder]

        model.fit(
            ds_train.repeat(),
            epochs=max_steps,
            callbacks=callbacks,
            steps_per_epoch=1,
        )

        lr_finder.plot(save_dir=outdir, figname=figname, log_scale=logscale)
Example #9
def model_scope(config, total_steps, weights, horovod_enabled=False):
    lr_schedule, optim_callbacks, lr = get_lr_schedule(config,
                                                       steps=total_steps)
    opt = get_optimizer(config, lr_schedule)

    if config["setup"]["dtype"] == "float16":
        model_dtype = tf.dtypes.float16
        policy = mixed_precision.Policy("mixed_float16")
        mixed_precision.set_global_policy(policy)
        opt = mixed_precision.LossScaleOptimizer(opt)
    else:
        model_dtype = tf.dtypes.float32

    model = make_model(config, model_dtype)

    # Build the layers after the element and feature dimensions are specified
    model.build((1, config["dataset"]["padded_num_elem_size"],
                 config["dataset"]["num_input_features"]))

    initial_epoch = 0
    loaded_opt = None

    if weights:
        if lr_schedule:
            raise Exception(
                "Restoring the optimizer state with a learning rate schedule is currently not supported"
            )

        # We need to load the weights in the same trainable configuration as the model was set up
        configure_model_weights(model,
                                config["setup"].get("weights_config", "all"))
        model.load_weights(weights, by_name=True)
        opt_weight_file = weights.replace("hdf5",
                                          "pkl").replace("/weights-", "/opt-")
        if os.path.isfile(opt_weight_file):
            loaded_opt = pickle.load(open(opt_weight_file, "rb"))

        initial_epoch = int(weights.split("/")[-1].split("-")[1])
    model.build((1, config["dataset"]["padded_num_elem_size"],
                 config["dataset"]["num_input_features"]))

    config = set_config_loss(config, config["setup"]["trainable"])
    configure_model_weights(model, config["setup"]["trainable"])
    model.build((1, config["dataset"]["padded_num_elem_size"],
                 config["dataset"]["num_input_features"]))

    print("model weights")
    tw_names = [m.name for m in model.trainable_weights]
    for w in model.weights:
        print("layer={} trainable={} shape={} num_weights={}".format(
            w.name, w.name in tw_names, w.shape, np.prod(w.shape)))

    loss_dict, loss_weights = get_loss_dict(config)

    model.compile(
        loss=loss_dict,
        optimizer=opt,
        sample_weight_mode="temporal",
        loss_weights=loss_weights,
        metrics={
            "cls": [
                FlattenedCategoricalAccuracy(name="acc_unweighted",
                                             dtype=tf.float64),
                FlattenedCategoricalAccuracy(
                    use_weights=True, name="acc_weighted", dtype=tf.float64),
            ] + [
                SingleClassRecall(
                    icls, name="rec_cls{}".format(icls), dtype=tf.float64)
                for icls in range(config["dataset"]["num_output_classes"])
            ]
        },
    )

    model.summary()

    # Set the optimizer weights
    if loaded_opt:

        def model_weight_setting():
            grad_vars = model.trainable_weights
            zero_grads = [tf.zeros_like(w) for w in grad_vars]
            model.optimizer.apply_gradients(zip(zero_grads, grad_vars))
            if model.optimizer.__class__.__module__ == "keras.optimizers.optimizer_v1":
                model.optimizer.optimizer.optimizer.set_weights(
                    loaded_opt["weights"])
            else:
                model.optimizer.set_weights(loaded_opt["weights"])

        # FIXME: check that this still works with multiple GPUs
        strategy = tf.distribute.get_strategy()
        strategy.run(model_weight_setting)

    return model, optim_callbacks, initial_epoch
Example #10
def train():

    parser = config_parser()
    args = parser.parse_args()

    if args.random_seed is not None:
        print('Fixing random seed', args.random_seed)
        np.random.seed(args.random_seed)
        tf.compat.v1.set_random_seed(args.random_seed)

    # Load data

    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
                                                                  recenter=True, bd_factor=.75,
                                                                  spherify=args.spherify)
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        print('Loaded llff', images.shape,
              render_poses.shape, hwf, args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]

        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]

        i_val = i_test
        i_train = np.array([i for i in np.arange(int(images.shape[0])) if
                            (i not in i_test and i not in i_val)])

        print('DEFINING BOUNDS')
        if args.no_ndc:
            near = tf.reduce_min(bds) * .9
            far = tf.reduce_max(bds) * 1.
        else:
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)

    elif args.dataset_type == 'blender':
        images, poses, render_poses, hwf, i_split = load_blender_data(
            args.datadir, args.half_res, args.testskip)
        print('Loaded blender', images.shape,
              render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split

        near = 2.
        far = 6.

        if args.white_bkgd:
            images = images[..., :3]*images[..., -1:] + (1.-images[..., -1:])
        else:
            images = images[..., :3]

    elif args.dataset_type == 'deepvoxels':

        images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape,
                                                                 basedir=args.datadir,
                                                                 testskip=args.testskip)

        print('Loaded deepvoxels', images.shape,
              render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split

        hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1))
        near = hemi_R-1.
        far = hemi_R+1.

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    if args.render_test:
        render_poses = np.array(poses[i_test])

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        with open(f, 'w') as file:
            file.write(open(args.config, 'r').read())

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, models = create_nerf(
        args)

    bds_dict = {
        'near': tf.cast(near, tf.float32),
        'far': tf.cast(far, tf.float32),
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Short circuit if only rendering out from trained model
    if args.render_only:
        print('RENDER ONLY')
        if args.render_test:
            # render_test switches to test poses
            images = images[i_test]
        else:
            # Default is smoother render_poses path
            images = None

        testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format(
            'test' if args.render_test else 'path', start))
        os.makedirs(testsavedir, exist_ok=True)
        print('test poses shape', render_poses.shape)

        rgbs, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test,
                              gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor)
        print('Done rendering', testsavedir)
        imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'),
                         to8b(rgbs), fps=30, quality=8)

        return

    # Create optimizer
    lrate = args.lrate
    if args.lrate_decay > 0:
        lrate = tf.keras.optimizers.schedules.ExponentialDecay(lrate,
                                                               decay_steps=args.lrate_decay * 1000, decay_rate=0.1)
    optimizer = tf.keras.optimizers.Adam(lrate)
    optimizer = mixed_precision.LossScaleOptimizer(optimizer)
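    # With a LossScaleOptimizer, the training loop below must scale the loss
    # (get_scaled_loss) and unscale the gradients (get_unscaled_gradients).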
    models['optimizer'] = optimizer

    global_step = tf.compat.v1.train.get_or_create_global_step()
    global_step.assign(start)

    # Prepare raybatch tensor if batching random rays
    N_rand = args.N_rand
    use_batching = not args.no_batching
    if use_batching:
        # For random ray batching.
        #
        # Constructs an array 'rays_rgb' of shape [N*H*W, 3, 3] where axis=1 is
        # interpreted as,
        #   axis=0: ray origin in world space
        #   axis=1: ray direction in world space
        #   axis=2: observed RGB color of pixel
        print('get rays')
        # get_rays_np() returns rays_origin=[H, W, 3], rays_direction=[H, W, 3]
        # for each pixel in the image. This stack() adds a new dimension.
        rays = [get_rays_np(H, W, focal, p) for p in poses[:, :3, :4]]
        rays = np.stack(rays, axis=0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.concatenate([rays, images[:, None, ...]], 1)
        # [N, H, W, ro+rd+rgb, 3]
        rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4])
        rays_rgb = np.stack([rays_rgb[i]
                             for i in i_train], axis=0)  # train images only
        # [(N-1)*H*W, ro+rd+rgb, 3]
        rays_rgb = np.reshape(rays_rgb, [-1, 3, 3])
        rays_rgb = rays_rgb.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgb)
        print('done')
        i_batch = 0

    N_iters = 1000000
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    # Summary writers
    writer = tf.summary.create_file_writer(
        os.path.join(basedir, 'summaries', expname))

    for i in range(start, N_iters):
        time0 = time.time()

        # Sample random ray batch

        if use_batching:
            # Random over all images
            batch = rays_rgb[i_batch:i_batch+N_rand]  # [B, 2+1, 3*?]
            batch = tf.transpose(batch, [1, 0, 2])

            # batch_rays[i, n, xyz] = ray origin or direction, example_id, 3D position
            # target_s[n, rgb] = example_id, observed color.
            batch_rays, target_s = batch[:2], batch[2]

            i_batch += N_rand
            if i_batch >= rays_rgb.shape[0]:
                np.random.shuffle(rays_rgb)
                i_batch = 0

        else:
            # Random from one image
            img_i = np.random.choice(i_train)
            target = images[img_i]
            pose = poses[img_i, :3, :4]

            if N_rand is not None:
                rays_o, rays_d = get_rays(H, W, focal, pose)
                if i < args.precrop_iters:
                    dH = int(H//2 * args.precrop_frac)
                    dW = int(W//2 * args.precrop_frac)
                    coords = tf.stack(tf.meshgrid(
                        tf.range(H//2 - dH, H//2 + dH), 
                        tf.range(W//2 - dW, W//2 + dW), 
                        indexing='ij'), -1)
                    if i < 10:
                        print('precrop', dH, dW, coords[0,0], coords[-1,-1])
                else:
                    coords = tf.stack(tf.meshgrid(
                        tf.range(H), tf.range(W), indexing='ij'), -1)
                coords = tf.reshape(coords, [-1, 2])
                select_inds = np.random.choice(
                    coords.shape[0], size=[N_rand], replace=False)
                select_inds = tf.gather_nd(coords, select_inds[:, tf.newaxis])
                rays_o = tf.gather_nd(rays_o, select_inds)
                rays_d = tf.gather_nd(rays_d, select_inds)
                batch_rays = tf.stack([rays_o, rays_d], 0)
                target_s = tf.gather_nd(target, select_inds)

        #####  Core optimization loop  #####

        with tf.GradientTape() as tape:

            # Make predictions for color, disparity, accumulated opacity.
            rgb, disp, acc, extras = render(
                H, W, focal, chunk=args.chunk, rays=batch_rays,
                verbose=i < 10, retraw=True, **render_kwargs_train)

            # Compute MSE loss between predicted and true RGB.
            img_loss = img2mse(rgb, target_s)
            trans = extras['raw'][..., -1]
            loss = img_loss
            psnr = mse2psnr(img_loss)

            # Add MSE loss for coarse-grained model
            if 'rgb0' in extras:
                img_loss0 = img2mse(extras['rgb0'], target_s)
                loss += img_loss0
                psnr0 = mse2psnr(img_loss0)

            # Scale the loss inside the tape so the scaling op is recorded,
            # then unscale the gradients before applying them.
            scaled_loss = optimizer.get_scaled_loss(loss)

        scaled_gradients = tape.gradient(scaled_loss, grad_vars)
        gradients = optimizer.get_unscaled_gradients(scaled_gradients)
        optimizer.apply_gradients(zip(gradients, grad_vars))

        dt = time.time()-time0

        #####           end            #####

        # Rest is logging

        def save_weights(net, prefix, i):
            path = os.path.join(
                basedir, expname, '{}_{:06d}.npy'.format(prefix, i))
            np.save(path, net.get_weights())
            print('saved weights at', path)

        if i % args.i_weights == 0:
            for k in models:
                save_weights(models[k], k, i)

        if i % args.i_video == 0 and i > 0:

            rgbs, disps = render_path(
                render_poses, hwf, args.chunk, render_kwargs_test)
            print('Done, saving', rgbs.shape, disps.shape)
            moviebase = os.path.join(
                basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
            imageio.mimwrite(moviebase + 'rgb.mp4',
                             to8b(rgbs), fps=30, quality=8)
            imageio.mimwrite(moviebase + 'disp.mp4',
                             to8b(disps / np.max(disps)), fps=30, quality=8)

            if args.use_viewdirs:
                render_kwargs_test['c2w_staticcam'] = render_poses[0][:3, :4]
                rgbs_still, _ = render_path(
                    render_poses, hwf, args.chunk, render_kwargs_test)
                render_kwargs_test['c2w_staticcam'] = None
                imageio.mimwrite(moviebase + 'rgb_still.mp4',
                                 to8b(rgbs_still), fps=30, quality=8)

        if i % args.i_testset == 0 and i > 0:
            testsavedir = os.path.join(
                basedir, expname, 'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', poses[i_test].shape)
            render_path(poses[i_test], hwf, args.chunk, render_kwargs_test,
                        gt_imgs=images[i_test], savedir=testsavedir)
            print('Saved test set')

        if i % args.i_print == 0 or i < 10:

            print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
            print('iter time {:.05f}'.format(dt))
            tf.summary.experimental.set_step(global_step)
            with writer.as_default():
                tf.summary.scalar('loss', loss)
                tf.summary.scalar('psnr', psnr)
                tf.summary.histogram('tran', trans)
                if args.N_importance > 0:
                    tf.summary.scalar('psnr0', psnr0)
                writer.flush()

            if i % args.i_img == 0:

                # Log a rendered validation view to Tensorboard
                img_i = np.random.choice(i_val)
                target = images[img_i]
                pose = poses[img_i, :3, :4]

                rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, c2w=pose,
                                                **render_kwargs_test)

                psnr = mse2psnr(img2mse(rgb, target))
                
                # Save out the validation image for Tensorboard-free monitoring
                testimgdir = os.path.join(basedir, expname, 'tboard_val_imgs')
                if i==0:
                    os.makedirs(testimgdir, exist_ok=True)
                imageio.imwrite(os.path.join(testimgdir, '{:06d}.png'.format(i)), to8b(rgb))

                with writer.as_default():
                    tf.summary.image('rgb', to8b(rgb)[tf.newaxis])
                    tf.summary.image(
                        'disp', disp[tf.newaxis, ..., tf.newaxis])
                    tf.summary.image(
                        'acc', acc[tf.newaxis, ..., tf.newaxis])

                    tf.summary.scalar('psnr_holdout', psnr)
                    tf.summary.image('rgb_holdout', target[tf.newaxis])
                    writer.flush()

                if args.N_importance > 0:

                    with writer.as_default(): 
                        tf.summary.image(
                            'rgb0', to8b(extras['rgb0'])[tf.newaxis])
                        tf.summary.image(
                            'disp0', extras['disp0'][tf.newaxis, ..., tf.newaxis])
                        tf.summary.image(
                            'z_std', extras['z_std'][tf.newaxis, ..., tf.newaxis])
                        writer.flush()

        global_step.assign_add(1)
Example #11
    def __init__(self, args, actor, dl, encoder=None, planner=None, cnn=None, gripper_cnn=None, img_embed_to_goal_space=None, lang_embed_to_goal_space = None,\
                optimizer=Adam, strategy=None, global_batch_size=32):

        self.actor = actor
        self.encoder = encoder
        self.planner = planner
        self.cnn = cnn
        self.gripper_cnn = gripper_cnn
        self.img_embed_to_goal_space = img_embed_to_goal_space
        self.lang_embed_to_goal_space = lang_embed_to_goal_space
        self.strategy = strategy
        self.args = args
        self.dl = dl
        self.global_batch_size = global_batch_size

        if args.fp16:
            # LossScaleOptimizer wraps optimizer *instances*, so wrap each
            # optimizer created below instead of the optimizer class itself.
            base_optimizer = optimizer

            def optimizer(**kwargs):
                return mixed_precision.LossScaleOptimizer(base_optimizer(**kwargs))

        if self.args.num_distribs is None:  # different sized clips due to different sized losses
            actor_clip = 0.06
            encoder_clip = 0.03
            planner_clip = 0.001
            cnn_clip = 10  # TODO find value if doing non de
            gripper_cnn_clip = 10.0
            mapper_clip = 5.0
        else:
            actor_clip = 400.0
            encoder_clip = 30.0
            planner_clip = 5.0
            cnn_clip = 20.0
            gripper_cnn_clip = 10.0
            mapper_clip = 5.0

        self.temperature = args.temperature
        self.temp_schedule = cosineDecay(min_frac=1 / 16,
                                         max=args.temperature,
                                         decay_steps=20000)

        # A bit boilerplate-y having them all separate, but a cleaner
        # dicts + comprehensions approach made the TPU complain about non-XLA
        # functions, so it stays this way for now.
        self.actor_optimizer = optimizer(learning_rate=args.learning_rate,
                                         clipnorm=actor_clip)
        self.encoder_optimizer = optimizer(learning_rate=args.learning_rate,
                                           clipnorm=encoder_clip)
        self.planner_optimizer = optimizer(learning_rate=args.learning_rate,
                                           clipnorm=planner_clip)
        self.cnn_optimizer = optimizer(learning_rate=args.learning_rate,
                                       clipnorm=cnn_clip)
        self.gripper_cnn_optimizer = optimizer(
            learning_rate=args.learning_rate, clipnorm=gripper_cnn_clip)
        self.img_embed_to_goal_space_optimizer = optimizer(
            learning_rate=args.learning_rate, clipnorm=mapper_clip)
        self.lang_embed_to_goal_space_optimizer = optimizer(
            learning_rate=args.learning_rate, clipnorm=mapper_clip)

        self.nll_action_loss = lambda y, p_y: tf.reduce_sum(-p_y.log_prob(y),
                                                            axis=2)
        self.mae_action_loss = tf.keras.losses.MeanAbsoluteError(
            reduction=tf.keras.losses.Reduction.NONE)
        self.mse_action_loss = tf.keras.losses.MeanSquaredError(
            reduction=tf.keras.losses.Reduction.NONE)

        self.metrics = {}
        self.metrics['train_loss'] = tf.keras.metrics.Mean(name='train_loss')
        self.metrics['valid_loss'] = tf.keras.metrics.Mean(name='valid_loss')
        self.metrics['actor_grad_norm'] = tf.keras.metrics.Mean(
            name='actor_grad_norm')
        self.metrics['encoder_grad_norm'] = tf.keras.metrics.Mean(
            name='encoder_grad_norm')
        self.metrics['planner_grad_norm'] = tf.keras.metrics.Mean(
            name='planner_grad_norm')
        self.metrics['cnn_grad_norm'] = tf.keras.metrics.Mean(
            name='cnn_grad_norm')
        self.metrics['gripper_cnn_grad_norm'] = tf.keras.metrics.Mean(
            name='gripper_cnn_grad_norm')
        self.metrics['img_embed_to_goal_space_norm'] = tf.keras.metrics.Mean(
            name='img_embed_to_goal_space_norm')
        self.metrics['lang_embed_to_goal_space_norm'] = tf.keras.metrics.Mean(
            name='lang_embed_to_goal_space_norm')

        self.metrics['global_grad_norm'] = tf.keras.metrics.Mean(
            name='global_grad_norm')

        self.metrics['train_act_with_enc_loss'] = tf.keras.metrics.Mean(
            name='train_act_with_enc_loss')
        self.metrics['train_act_with_plan_loss'] = tf.keras.metrics.Mean(
            name='train_act_with_plan_loss')
        self.metrics['valid_act_with_enc_loss'] = tf.keras.metrics.Mean(
            name='valid_act_with_enc_loss')
        self.metrics['valid_act_with_plan_loss'] = tf.keras.metrics.Mean(
            name='valid_act_with_plan_loss')

        self.metrics['train_reg_loss'] = tf.keras.metrics.Mean(name='reg_loss')
        self.metrics['valid_reg_loss'] = tf.keras.metrics.Mean(
            name='valid_reg_loss')

        self.metrics['train_discrete_planner_loss'] = tf.keras.metrics.Mean(
            name='train_discrete_planner_loss')
        self.metrics['valid_discrete_planner_loss'] = tf.keras.metrics.Mean(
            name='valid_discrete_planner_loss')

        self.metrics['valid_position_loss'] = tf.keras.metrics.Mean(
            name='valid_position_loss')
        self.metrics['valid_max_position_loss'] = lfp.metric.MaxMetric(
            name='valid_max_position_loss')
        self.metrics['valid_rotation_loss'] = tf.keras.metrics.Mean(
            name='valid_rotation_loss')
        self.metrics['valid_max_rotation_loss'] = lfp.metric.MaxMetric(
            name='valid_max_rotation_loss')
        self.metrics['valid_gripper_loss'] = tf.keras.metrics.Mean(
            name='valid_gripper_loss')

        self.metrics['valid_enc_position_loss'] = tf.keras.metrics.Mean(
            name='valid_enc_position_loss')
        self.metrics['valid_enc_max_position_loss'] = lfp.metric.MaxMetric(
            name='valid_enc_max_position_loss')
        self.metrics['valid_enc_rotation_loss'] = tf.keras.metrics.Mean(
            name='valid_enc_rotation_loss')
        self.metrics['valid_enc_max_rotation_loss'] = lfp.metric.MaxMetric(
            name='valid_enc_max_rotation_loss')
        self.metrics['valid_enc_gripper_loss'] = tf.keras.metrics.Mean(
            name='valid_enc_gripper_loss')

        self.metrics['valid_lang_position_loss'] = tf.keras.metrics.Mean(
            name='valid_lang_position_loss')
        self.metrics['valid_lang_max_position_loss'] = lfp.metric.MaxMetric(
            name='valid_lang_max_position_loss')
        self.metrics['valid_lang_rotation_loss'] = tf.keras.metrics.Mean(
            name='valid_lang_rotation_loss')
        self.metrics['valid_lang_max_rotation_loss'] = lfp.metric.MaxMetric(
            name='valid_lang_max_rotation_loss')
        self.metrics['valid_lang_gripper_loss'] = tf.keras.metrics.Mean(
            name='valid_lang_gripper_loss')

        self.chkpt_manager = None
Example #12
def main(args):

    print(args)

    if args.push_to_hub:
        login_to_hub()

    if not isinstance(args.workers, int):
        args.workers = min(16, mp.cpu_count())

    # AMP
    if args.amp:
        mixed_precision.set_global_policy("mixed_float16")

    st = time.time()
    val_set = DetectionDataset(
        img_folder=os.path.join(args.val_path, "images"),
        label_path=os.path.join(args.val_path, "labels.json"),
        sample_transforms=T.SampleCompose(([
            T.Resize((args.input_size, args.input_size),
                     preserve_aspect_ratio=True,
                     symmetric_pad=True)
        ] if not args.rotation or args.eval_straight else []) + ([
            T.Resize(args.input_size, preserve_aspect_ratio=True
                     ),  # This does not pad
            T.RandomRotate(90, expand=True),
            T.Resize((args.input_size, args.input_size),
                     preserve_aspect_ratio=True,
                     symmetric_pad=True),
        ] if args.rotation and not args.eval_straight else [])),
        use_polygons=args.rotation and not args.eval_straight,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.workers,
    )
    print(
        f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
        f"{val_loader.num_batches} batches)")
    with open(os.path.join(args.val_path, "labels.json"), "rb") as f:
        val_hash = hashlib.sha256(f.read()).hexdigest()

    batch_transforms = T.Compose([
        T.Normalize(mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287)),
    ])

    # Load doctr model
    model = detection.__dict__[args.arch](
        pretrained=args.pretrained,
        input_shape=(args.input_size, args.input_size, 3),
        assume_straight_pages=not args.rotation,
    )

    # Resume weights
    if isinstance(args.resume, str):
        model.load_weights(args.resume)

    # Metrics
    val_metric = LocalizationConfusion(use_polygons=args.rotation
                                       and not args.eval_straight,
                                       mask_shape=(args.input_size,
                                                   args.input_size))
    if args.test_only:
        print("Running evaluation")
        val_loss, recall, precision, mean_iou = evaluate(
            model, val_loader, batch_transforms, val_metric)
        print(
            f"Validation loss: {val_loss:.6} (Recall: {recall:.2%} | Precision: {precision:.2%} | "
            f"Mean IoU: {mean_iou:.2%})")
        return

    st = time.time()
    # Load both train and val data generators
    train_set = DetectionDataset(
        img_folder=os.path.join(args.train_path, "images"),
        label_path=os.path.join(args.train_path, "labels.json"),
        img_transforms=T.Compose([
            # Augmentations
            T.RandomApply(T.ColorInversion(), 0.1),
            T.RandomJpegQuality(60),
            T.RandomSaturation(0.3),
            T.RandomContrast(0.3),
            T.RandomBrightness(0.3),
        ]),
        sample_transforms=T.SampleCompose(([
            T.Resize((args.input_size, args.input_size),
                     preserve_aspect_ratio=True,
                     symmetric_pad=True)
        ] if not args.rotation else []) + ([
            T.Resize(args.input_size, preserve_aspect_ratio=True
                     ),  # This does not pad
            T.RandomRotate(90, expand=True),
            T.Resize((args.input_size, args.input_size),
                     preserve_aspect_ratio=True,
                     symmetric_pad=True),
        ] if args.rotation else [])),
        use_polygons=args.rotation,
    )
    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.workers,
    )
    print(
        f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
        f"{train_loader.num_batches} batches)")
    with open(os.path.join(args.train_path, "labels.json"), "rb") as f:
        train_hash = hashlib.sha256(f.read()).hexdigest()

    if args.show_samples:
        x, target = next(iter(train_loader))
        plot_samples(x, target)
        return

    # Optimizer
    scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
        args.lr,
        decay_steps=args.epochs * len(train_loader),
        decay_rate=1 / (25e4),  # final lr as a fraction of initial lr
        staircase=False,
    )
    optimizer = tf.keras.optimizers.Adam(learning_rate=scheduler,
                                         beta_1=0.95,
                                         beta_2=0.99,
                                         epsilon=1e-6,
                                         clipnorm=5)
    if args.amp:
        optimizer = mixed_precision.LossScaleOptimizer(optimizer)
    # LR Finder
    if args.find_lr:
        lrs, losses = record_lr(model,
                                train_loader,
                                batch_transforms,
                                optimizer,
                                amp=args.amp)
        plot_recorder(lrs, losses)
        return

    # Tensorboard to monitor training
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name

    # W&B
    if args.wb:

        run = wandb.init(
            name=exp_name,
            project="text-detection",
            config={
                "learning_rate": args.lr,
                "epochs": args.epochs,
                "weight_decay": 0.0,
                "batch_size": args.batch_size,
                "architecture": args.arch,
                "input_size": args.input_size,
                "optimizer": "adam",
                "framework": "tensorflow",
                "scheduler": "exp_decay",
                "train_hash": train_hash,
                "val_hash": val_hash,
                "pretrained": args.pretrained,
                "rotation": args.rotation,
            },
        )

    if args.freeze_backbone:
        for layer in model.feat_extractor.layers:
            layer.trainable = False

    min_loss = np.inf

    # Training loop
    mb = master_bar(range(args.epochs))
    for epoch in mb:
        fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb,
                      args.amp)
        # Validation loop at the end of each epoch
        val_loss, recall, precision, mean_iou = evaluate(
            model, val_loader, batch_transforms, val_metric)
        if val_loss < min_loss:
            print(
                f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state..."
            )
            model.save_weights(f"./{exp_name}/weights")
            min_loss = val_loss
        log_msg = f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} "
        if any(val is None for val in (recall, precision, mean_iou)):
            log_msg += "(Undefined metric value, caused by empty GTs or predictions)"
        else:
            log_msg += f"(Recall: {recall:.2%} | Precision: {precision:.2%} | Mean IoU: {mean_iou:.2%})"
        mb.write(log_msg)
        # W&B
        if args.wb:
            wandb.log({
                "val_loss": val_loss,
                "recall": recall,
                "precision": precision,
                "mean_iou": mean_iou,
            })

    if args.wb:
        run.finish()

    if args.push_to_hub:
        push_to_hf_hub(model, exp_name, task="detection", run_config=args)
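
The script above defers the scaled-loss arithmetic to fit_one_epoch. As a rough sketch of what an AMP-aware step built around LossScaleOptimizer typically looks like (the train_step name and the dict-style loss output of the model are assumptions, not the actual doctr implementation):

import tensorflow as tf

@tf.function
def train_step(model, optimizer, images, targets, amp=False):
    with tf.GradientTape() as tape:
        train_loss = model(images, targets, training=True)["loss"]
        if amp:
            # Scale the loss up so small float16 gradients do not flush to zero
            train_loss = optimizer.get_scaled_loss(train_loss)
    grads = tape.gradient(train_loss, model.trainable_weights)
    if amp:
        # Undo the scaling before the weight update
        grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return train_loss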
Example No. 13
    def __init__(self,
                 observation_space,
                 action_space,
                 model_f,
                 m_dir=None,
                 log_name=None,
                 start_step=0,
                 mixed_float=False):
        """
        Parameters
        ----------
        observation_space : gym.Space
            Observation space of the environment.
        action_space : gym.Space
            Action space of the environment. Current agent expects only
            a discrete action space.
        model_f
            A function that returns actor and critic models.
            It should take the observation space and action space as inputs.
            It should not compile the models.
        m_dir : str
            A model directory to load the model from, if there is one to load.
        log_name : str
            A name for the log. If not specified, the current time is used.
            - If m_dir is specified but log_name is not, step counting will
            continue from the loaded run.
            - If both m_dir and log_name are specified, the model is loaded
            from m_dir, but the run is recorded as if it were the first
            training.
        start_step : int
            The total step count starts from start_step.
        mixed_float : bool
            Whether or not to use mixed precision
        """
        # model : The actual training model
        # t_model : Fixed target model
        print(f'Model directory : {m_dir}')
        print(f'Log name : {log_name}')
        print(f'Starting from step {start_step}')
        print(f'Use mixed float? {mixed_float}')
        self.action_space = action_space
        self.action_range = action_space.high - action_space.low
        self.action_shape = action_space.shape
        self.observation_space = observation_space
        self.mixed_float = mixed_float
        if mixed_float:
            policy = mixed_precision.Policy('mixed_float16')
            mixed_precision.set_global_policy(policy)

        assert hp.Algorithm in hp.available_algorithms, "Wrong Algorithm!"

        # Special variables
        if hp.Algorithm == 'V-MPO':

            self.eta = tf.Variable(1.0,
                                   trainable=True,
                                   name='eta',
                                   dtype='float32')
            self.alpha_mu = tf.Variable(1.0,
                                        trainable=True,
                                        name='alpha_mu',
                                        dtype='float32')
            self.alpha_sig = tf.Variable(1.0,
                                         trainable=True,
                                         name='alpha_sig',
                                         dtype='float32')

        elif hp.Algorithm == 'A2C':
            action_num = tf.reduce_prod(self.action_shape)
            self.log_sigma = tf.Variable(tf.fill((action_num,), 0.1),
                                         trainable=True,
                                         name='sigma',
                                         dtype='float32')

        # Inputs
        if hp.ICM_ENABLE:
            actor, critic, icm_models = model_f(observation_space,
                                                action_space)
            encoder, inverse, forward = icm_models
            self.models = {
                'actor': actor,
                'critic': critic,
                'encoder': encoder,
                'inverse': inverse,
                'forward': forward,
            }
        else:
            actor, critic = model_f(observation_space, action_space)
            self.models = {
                'actor': actor,
                'critic': critic,
            }
        targets = ['actor', 'critic']

        # Shared Adam optimizer; in V-MPO the losses are merged into one objective
        common_lr = tf.function(partial(self._lr, 'common'))
        self.common_optimizer = keras.optimizers.Adam(
            learning_rate=common_lr,
            epsilon=hp.lr['common'].epsilon,
            global_clipnorm=hp.lr['common'].grad_clip,
        )
        if self.mixed_float:
            self.common_optimizer = mixed_precision.LossScaleOptimizer(
                self.common_optimizer)

        for name, model in self.models.items():
            lr = tf.function(partial(self._lr, name))
            optimizer = keras.optimizers.Adam(
                learning_rate=lr,
                epsilon=hp.lr[name].epsilon,
                global_clipnorm=hp.lr[name].grad_clip,
            )
            if self.mixed_float:
                optimizer = mixed_precision.LossScaleOptimizer(optimizer)
            model.compile(optimizer=optimizer)
            model.summary()

        # Load model if specified
        if m_dir is not None:
            for name, model in self.models.items():
                model.load_weights(path.join(m_dir, name))
            print(f'model loaded : {m_dir}')

        # Initialize target model
        self.t_models = {}
        for name in targets:
            model = self.models[name]
            self.t_models[name] = keras.models.clone_model(model)
            self.t_models[name].set_weights(model.get_weights())

        # File writer for tensorboard
        if log_name is None:
            self.log_name = datetime.now().strftime('%m_%d_%H_%M_%S')
        else:
            self.log_name = log_name
        self.file_writer = tf.summary.create_file_writer(
            path.join('logs', self.log_name))
        self.file_writer.set_as_default()
        print('Writing logs at logs/' + self.log_name)

        # Scalars
        self.start_training = False
        self.total_steps = tf.Variable(start_step, dtype=tf.int64)

        # Savefile folder directory
        if m_dir is None:
            self.save_dir = path.join('savefiles', self.log_name)
            self.save_count = 0
        else:
            if log_name is None:
                self.save_dir, self.save_count = path.split(m_dir)
                self.save_count = int(self.save_count)
            else:
                self.save_dir = path.join('savefiles', self.log_name)
                self.save_count = 0
        self.model_dir = None
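
The constructor clones fixed target networks but leaves their refresh to later code. A minimal sketch of the usual sync helper, assuming a Polyak-style update (the method name and the tau parameter are hypothetical, not part of the original class):

    def update_targets(self, tau=1.0):
        # tau=1.0 performs a hard copy; tau<1.0 gives a Polyak (soft) average
        for name, t_model in self.t_models.items():
            new_weights = [tau * w + (1.0 - tau) * t_w
                           for w, t_w in zip(self.models[name].get_weights(),
                                             t_model.get_weights())]
            t_model.set_weights(new_weights)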
Example No. 14
def train(config, weights, ntrain, ntest, nepochs, recreate, prefix, plot_freq,
          customize):
    """Train a model defined by config."""
    try:
        from comet_ml import Experiment
        experiment = Experiment(
            project_name="particleflow-tf",
            auto_metric_logging=True,
            auto_param_logging=True,
            auto_histogram_weight_logging=True,
            auto_histogram_gradient_logging=False,
            auto_histogram_activation_logging=False,
        )
    except Exception as e:
        print(f"Failed to initialize comet-ml dashboard: {e}")
        experiment = None

    config_file_path = config
    config, config_file_stem = parse_config(config,
                                            nepochs=nepochs,
                                            weights=weights)

    if plot_freq:
        config["callbacks"]["plot_freq"] = plot_freq

    if customize:
        config = customization_functions[customize](config)

    if recreate or (weights is None):
        outdir = create_experiment_dir(prefix=prefix + config_file_stem + "_",
                                       suffix=platform.node())
    else:
        outdir = str(Path(weights).parent)

    # Decide tf.distribute.strategy depending on number of available GPUs
    strategy, num_gpus = get_strategy()
    #if "CPU" not in strategy.extended.worker_devices[0]:
    #    nvidia_smi_call = "nvidia-smi --query-gpu=timestamp,name,pci.bus_id,pstate,power.draw,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used --format=csv -l 1 -f {}/nvidia_smi_log.csv".format(outdir)
    #    p = subprocess.Popen(shlex.split(nvidia_smi_call))

    ds_train, num_train_steps = get_datasets(config["train_test_datasets"],
                                             config, num_gpus, "train")
    ds_test, num_test_steps = get_datasets(config["train_test_datasets"],
                                           config, num_gpus, "test")
    ds_val, ds_info = get_heptfds_dataset(
        config["validation_dataset"], config, num_gpus, "test",
        config["setup"]["num_events_validation"])
    ds_val = ds_val.batch(5)

    if ntrain:
        ds_train = ds_train.take(ntrain)
        num_train_steps = ntrain
    if ntest:
        ds_test = ds_test.take(ntest)
        num_test_steps = ntest

    print("num_train_steps", num_train_steps)
    print("num_test_steps", num_test_steps)
    total_steps = num_train_steps * config["setup"]["num_epochs"]
    print("total_steps", total_steps)

    if experiment:
        experiment.set_name(outdir)
        experiment.log_code("mlpf/tfmodel/model.py")
        experiment.log_code("mlpf/tfmodel/utils.py")
        experiment.log_code(config_file_path)

    # Copy the config file to the train dir for later reference
    shutil.copy(config_file_path, outdir + "/config.yaml")

    with strategy.scope():
        lr_schedule, optim_callbacks = get_lr_schedule(config,
                                                       steps=total_steps)
        opt = get_optimizer(config, lr_schedule)

        if config["setup"]["dtype"] == "float16":
            model_dtype = tf.dtypes.float16
            policy = mixed_precision.Policy("mixed_float16")
            mixed_precision.set_global_policy(policy)
            opt = mixed_precision.LossScaleOptimizer(opt)
        else:
            model_dtype = tf.dtypes.float32

        model = make_model(config, model_dtype)

        # Build the layers after the element and feature dimensions are specified
        model.build((1, config["dataset"]["padded_num_elem_size"],
                     config["dataset"]["num_input_features"]))

        initial_epoch = 0
        if weights:
            # We need to load the weights in the same trainable configuration as the model was set up
            configure_model_weights(
                model, config["setup"].get("weights_config", "all"))
            model.load_weights(weights, by_name=True)
            initial_epoch = int(weights.split("/")[-1].split("-")[1])
        model.build((1, config["dataset"]["padded_num_elem_size"],
                     config["dataset"]["num_input_features"]))

        config = set_config_loss(config, config["setup"]["trainable"])
        configure_model_weights(model, config["setup"]["trainable"])
        model.build((1, config["dataset"]["padded_num_elem_size"],
                     config["dataset"]["num_input_features"]))

        print("model weights")
        tw_names = [m.name for m in model.trainable_weights]
        for w in model.weights:
            print("layer={} trainable={} shape={} num_weights={}".format(
                w.name, w.name in tw_names, w.shape, np.prod(w.shape)))

        loss_dict, loss_weights = get_loss_dict(config)
        model.compile(
            loss=loss_dict,
            optimizer=opt,
            sample_weight_mode="temporal",
            loss_weights=loss_weights,
            metrics={
                "cls": [
                    FlattenedCategoricalAccuracy(name="acc_unweighted",
                                                 dtype=tf.float64),
                    FlattenedCategoricalAccuracy(use_weights=True,
                                                 name="acc_weighted",
                                                 dtype=tf.float64),
                ] + [
                    SingleClassRecall(
                        icls, name="rec_cls{}".format(icls), dtype=tf.float64)
                    for icls in range(config["dataset"]["num_output_classes"])
                ]
            },
        )
        model.summary()

    callbacks = prepare_callbacks(config["callbacks"],
                                  outdir,
                                  ds_val,
                                  ds_info,
                                  comet_experiment=experiment)
    callbacks.append(optim_callbacks)

    fit_result = model.fit(
        ds_train.repeat(),
        validation_data=ds_test.repeat(),
        epochs=initial_epoch + config["setup"]["num_epochs"],
        callbacks=callbacks,
        steps_per_epoch=num_train_steps,
        validation_steps=num_test_steps,
        initial_epoch=initial_epoch,
    )

    history_path = Path(outdir) / "history"
    with open(history_path / "history.json", "w") as fi:
        json.dump(fit_result.history, fi)

    weights = get_best_checkpoint(outdir)
    print("Loading best weights that could be found from {}".format(weights))
    model.load_weights(weights, by_name=True)

    model.save(outdir + "/model_full", save_format="tf")

    print("Training done.")
Example No. 15
# NOTE: the snippet begins mid-definition; the Conv2D filter count and kernel
# size below are assumptions added only to make the fragment syntactically
# complete.
model.add(layers.Conv2D(64, (3, 3),
                  padding='same',
                  activation='relu'))
model.add(layers.GlobalAveragePooling2D())  # already yields (batch, channels)
model.add(layers.Dense(128, activation='relu'))
model.add(layers.LayerNormalization(axis=1, center=True, scale=True))
model.add(layers.Dense(256, activation='relu'))
model.add(layers.LayerNormalization(axis=1, center=True, scale=True))
model.add(layers.Dense(512, activation='relu'))
model.add(layers.LayerNormalization(axis=1, center=True, scale=True))
model.add(layers.Dense(1024, activation='relu'))
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))

optimizer = tf.keras.optimizers.Adam(learning_rate=2e-5)
opt = mixed_precision.LossScaleOptimizer(optimizer)

model.compile(loss='sparse_categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])


def scheduler(epoch, lr):
    if epoch < 3:
        return lr
    else:
        return lr * tf.math.exp(-0.1)


SCH = tf.keras.callbacks.LearningRateScheduler(scheduler)
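
The snippet stops before the training call; a typical invocation that wires the scheduler callback in (the dataset names are placeholders):

model.fit(x_train, y_train,
          validation_data=(x_val, y_val),
          epochs=10,
          callbacks=[SCH])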
Example No. 16
def main():

    # Main entry point: gather the configuration and run the training script
    print(f'GPU REPLICAS: {strategy.num_replicas_in_sync}')
    t0 = time.time()

    print(f'Train dir: {config.TRAIN_DATADIR}')
    print(f'Validation dir: {config.VAL_DATADIR}')

    # Initialize Callbacks
    callbacks = gen_callbacks(config, config.CALLBACKS_METADATA)

    # open files and get dataset tensor slices
    train_images, train_labels = get_tensorslices(
        data_dir=config.TRAIN_DATADIR, img_id='x', label_id='y'
    )

    # open files and get dataset tensor slices
    val_images, val_labels = get_tensorslices(
        data_dir=config.VAL_DATADIR, img_id='x', label_id='y'
    )

    # extract values for training
    NUM_TRAINING_IMAGES = train_images.shape[0]
    NUM_VALIDATION_IMAGES = val_images.shape[0]
    STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // config.BATCH_SIZE

    print(f'{NUM_TRAINING_IMAGES} training images')
    print(f'{NUM_VALIDATION_IMAGES} validation images')

    # generate training dataset
    train_dataset = \
        tf.data.Dataset.from_tensor_slices((train_images, train_labels))

    # generate validation dataset
    val_dataset = tf.data.Dataset.from_tensor_slices((val_images, val_labels))
    val_dataset = val_dataset.batch(config.VAL_BATCH_SIZE)

    # Create model output directory
    os.makedirs(config.MODEL_SAVEDIR, exist_ok=True)

    # Initialize and compile model
    with strategy.scope():

        # initialize UNet model
        model = unet_batchnorm(
            nclass=config.N_CLASSES, input_size=config.INPUT_SIZE,
            maps=config.MODEL_METADATA['network_maps']
        )

        # initialize optimizer, exit if the optimizer is not supported
        if config.MODEL_METADATA['optimizer_name'] == 'Adadelta':
            optimizer = Adadelta(learning_rate=config.MODEL_METADATA['lr'])
        elif config.MODEL_METADATA['optimizer_name'] == 'Adam':
            optimizer = Adam(learning_rate=config.MODEL_METADATA['lr'])
        else:
            sys.exit('Optimizer provided is not supported.')

        # wrap with dynamic loss scaling to avoid float16 gradient underflow
        optimizer = mixed_precision.LossScaleOptimizer(optimizer)

        # compile model to start training
        model.compile(
            optimizer,
            loss=config.MODEL_METADATA['loss'],
            metrics=config.MODEL_METADATA['metrics']
        )
        model.summary()

    # Disable AutoShard: the data lives in memory, so use in-memory options
    train_dataset = train_dataset.with_options(options)
    val_dataset = val_dataset.with_options(options)

    # Train the model and save to disk
    model.fit(
        get_training_dataset(
            train_dataset,
            config,
            do_aug=config.MODEL_METADATA['do_aug']
        ),
        initial_epoch=config.START_EPOCH,
        epochs=config.N_EPOCHS,
        steps_per_epoch=STEPS_PER_EPOCH,
        validation_data=val_dataset,
        callbacks=callbacks,
        verbose=2
    )

    print(f'Execution time: {time.time() - t0}')
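
The wrap above uses dynamic loss scaling, the default mode: the scale doubles after every 2000 consecutive finite steps and halves whenever gradients overflow. For reference, a sketch of both modes (the learning-rate values are illustrative); a fixed scale skips the adaptation bookkeeping at the cost of choosing the value by hand:

from tensorflow.keras import mixed_precision
from tensorflow.keras.optimizers import Adam

# Dynamic (default): the scale adapts as training proceeds
opt_dynamic = mixed_precision.LossScaleOptimizer(Adam(learning_rate=1e-4))

# Fixed: no adaptation; initial_scale is used for the whole run
opt_fixed = mixed_precision.LossScaleOptimizer(
    Adam(learning_rate=1e-4), dynamic=False, initial_scale=1024)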
Example No. 17
def main(args):

    print(args)

    if args.push_to_hub:
        login_to_hub()

    if not isinstance(args.workers, int):
        args.workers = min(16, mp.cpu_count())

    vocab = VOCABS[args.vocab]
    fonts = args.font.split(",")

    # AMP
    if args.amp:
        mixed_precision.set_global_policy("mixed_float16")

    st = time.time()

    if isinstance(args.val_path, str):
        with open(os.path.join(args.val_path, "labels.json"), "rb") as f:
            val_hash = hashlib.sha256(f.read()).hexdigest()

        # Load val data generator
        val_set = RecognitionDataset(
            img_folder=os.path.join(args.val_path, "images"),
            labels_path=os.path.join(args.val_path, "labels.json"),
            img_transforms=T.Resize((args.input_size, 4 * args.input_size),
                                    preserve_aspect_ratio=True),
        )
    else:
        val_hash = None
        # Load synthetic data generator
        val_set = WordGenerator(
            vocab=vocab,
            min_chars=args.min_chars,
            max_chars=args.max_chars,
            num_samples=args.val_samples * len(vocab),
            font_family=fonts,
            img_transforms=T.Compose([
                T.Resize((args.input_size, 4 * args.input_size),
                         preserve_aspect_ratio=True),
                # Ensure we have a 90% split of white-background images
                T.RandomApply(T.ColorInversion(), 0.9),
            ]),
        )

    val_loader = DataLoader(
        val_set,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.workers,
    )
    print(
        f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
        f"{val_loader.num_batches} batches)")

    # Load doctr model
    model = recognition.__dict__[args.arch](
        pretrained=args.pretrained,
        input_shape=(args.input_size, 4 * args.input_size, 3),
        vocab=vocab,
    )
    # Resume weights
    if isinstance(args.resume, str):
        model.load_weights(args.resume)

    # Metrics
    val_metric = TextMatch()

    batch_transforms = T.Compose([
        T.Normalize(mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301)),
    ])

    if args.test_only:
        print("Running evaluation")
        val_loss, exact_match, partial_match = evaluate(
            model, val_loader, batch_transforms, val_metric)
        print(
            f"Validation loss: {val_loss:.6} (Exact: {exact_match:.2%} | Partial: {partial_match:.2%})"
        )
        return

    st = time.time()

    if isinstance(args.train_path, str):
        # Load train data generator
        base_path = Path(args.train_path)
        parts = ([base_path]
                 if base_path.joinpath("labels.json").is_file() else
                 [base_path.joinpath(sub) for sub in os.listdir(base_path)])
        with open(parts[0].joinpath("labels.json"), "rb") as f:
            train_hash = hashlib.sha256(f.read()).hexdigest()

        train_set = RecognitionDataset(
            parts[0].joinpath("images"),
            parts[0].joinpath("labels.json"),
            img_transforms=T.Compose([
                T.RandomApply(T.ColorInversion(), 0.1),
                T.Resize((args.input_size, 4 * args.input_size),
                         preserve_aspect_ratio=True),
                # Augmentations
                T.RandomJpegQuality(60),
                T.RandomSaturation(0.3),
                T.RandomContrast(0.3),
                T.RandomBrightness(0.3),
            ]),
        )
        if len(parts) > 1:
            for subfolder in parts[1:]:
                train_set.merge_dataset(
                    RecognitionDataset(subfolder.joinpath("images"),
                                       subfolder.joinpath("labels.json")))
    else:
        train_hash = None
        # Load synthetic data generator
        train_set = WordGenerator(
            vocab=vocab,
            min_chars=args.min_chars,
            max_chars=args.max_chars,
            num_samples=args.train_samples * len(vocab),
            font_family=fonts,
            img_transforms=T.Compose([
                T.Resize((args.input_size, 4 * args.input_size),
                         preserve_aspect_ratio=True),
                # Ensure we have a 90% split of white-background images
                T.RandomApply(T.ColorInversion(), 0.9),
                T.RandomJpegQuality(60),
                T.RandomSaturation(0.3),
                T.RandomContrast(0.3),
                T.RandomBrightness(0.3),
            ]),
        )

    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.workers,
    )
    print(
        f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
        f"{train_loader.num_batches} batches)")

    if args.show_samples:
        x, target = next(iter(train_loader))
        plot_samples(x, target)
        return

    # Optimizer
    scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
        args.lr,
        decay_steps=args.epochs * len(train_loader),
        decay_rate=1 / (25e4),  # final lr as a fraction of initial lr
        staircase=False,
    )
    optimizer = tf.keras.optimizers.Adam(learning_rate=scheduler,
                                         beta_1=0.95,
                                         beta_2=0.99,
                                         epsilon=1e-6,
                                         clipnorm=5)
    if args.amp:
        optimizer = mixed_precision.LossScaleOptimizer(optimizer)
    # LR Finder
    if args.find_lr:
        lrs, losses = record_lr(model,
                                train_loader,
                                batch_transforms,
                                optimizer,
                                amp=args.amp)
        plot_recorder(lrs, losses)
        return

    # Tensorboard to monitor training
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name

    # W&B
    if args.wb:
        run = wandb.init(
            name=exp_name,
            project="text-recognition",
            config={
                "learning_rate": args.lr,
                "epochs": args.epochs,
                "weight_decay": 0.0,
                "batch_size": args.batch_size,
                "architecture": args.arch,
                "input_size": args.input_size,
                "optimizer": "adam",
                "framework": "tensorflow",
                "scheduler": "exp_decay",
                "vocab": args.vocab,
                "train_hash": train_hash,
                "val_hash": val_hash,
                "pretrained": args.pretrained,
            },
        )

    min_loss = np.inf

    # Training loop
    mb = master_bar(range(args.epochs))
    for epoch in mb:
        fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb,
                      args.amp)

        # Validation loop at the end of each epoch
        val_loss, exact_match, partial_match = evaluate(
            model, val_loader, batch_transforms, val_metric)
        if val_loss < min_loss:
            print(
                f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state..."
            )
            model.save_weights(f"./{exp_name}/weights")
            min_loss = val_loss
        mb.write(
            f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} "
            f"(Exact: {exact_match:.2%} | Partial: {partial_match:.2%})")
        # W&B
        if args.wb:
            wandb.log({
                "val_loss": val_loss,
                "exact_match": exact_match,
                "partial_match": partial_match,
            })

    if args.wb:
        run.finish()

    if args.push_to_hub:
        push_to_hf_hub(model, exp_name, task="recognition", run_config=args)