Example #1
def get_hardware_strategy(mixed_f16=False):
    try:
        # TPU detection. No parameters necessary if TPU_NAME environment variable is
        # set: this is always the case on Kaggle.
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        print('Running on TPU ', tpu.master())
    except ValueError:
        tpu = None

    if tpu:
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        strategy = tf.distribute.experimental.TPUStrategy(tpu)
        if mixed_f16:
            policy = mixed_precision.Policy('mixed_bfloat16')
            mixed_precision.set_global_policy(policy)
    else:
        # Default distribution strategy in Tensorflow. Works on CPU and single GPU.
        strategy = tf.distribute.get_strategy()
        if mixed_f16:
            policy = mixed_precision.Policy('mixed_float16')
            mixed_precision.set_global_policy(policy)

    print("REPLICAS: ", strategy.num_replicas_in_sync)
    return tpu, strategy
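
A minimal usage sketch (not part of the example; the model and compile settings are placeholders): the returned strategy is meant to wrap model construction so that variables are created on the detected hardware.

import tensorflow as tf

tpu, strategy = get_hardware_strategy(mixed_f16=True)
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(128, activation="relu"),
        # Keep the output layer in float32 so the softmax stays numerically
        # stable under a mixed-precision policy.
        tf.keras.layers.Dense(10, activation="softmax", dtype="float32"),
    ])
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")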
Example #2
    def _set_precision(calculation_dtype, calculation_epsilon):
        # enable single/half/double precision
        K.set_floatx(calculation_dtype)
        K.set_epsilon(calculation_epsilon)

        # enable mixed precision
        if "float16" in calculation_dtype:
            mixed_precision.set_global_policy("mixed_float16")
Example #3
def evaluate(config, train_dir, weights, customize, nevents):
    """Evaluate the trained model in train_dir"""
    if config is None:
        config = Path(train_dir) / "config.yaml"
        assert config.exists(
        ), "Could not find config file in train_dir, please provide one with -c <path/to/config>"
    config, _ = parse_config(config, weights=weights)

    if customize:
        config = customization_functions[customize](config)

    if config["setup"]["dtype"] == "float16":
        model_dtype = tf.dtypes.float16
        policy = mixed_precision.Policy("mixed_float16")
        mixed_precision.set_global_policy(policy)
    else:
        model_dtype = tf.dtypes.float32

    strategy, num_gpus = get_strategy()
    # physical_devices = tf.config.list_physical_devices('GPU')
    # for dev in physical_devices:
    #    tf.config.experimental.set_memory_growth(dev, True)

    model = make_model(config, model_dtype)
    model.build((1, config["dataset"]["padded_num_elem_size"],
                 config["dataset"]["num_input_features"]))

    # need to load the weights in the same trainable configuration as the model was set up
    configure_model_weights(model, config["setup"].get("weights_config",
                                                       "all"))
    if weights:
        model.load_weights(weights, by_name=True)
    else:
        weights = get_best_checkpoint(train_dir)
        print(
            "Loading best weights that could be found from {}".format(weights))
        model.load_weights(weights, by_name=True)

    iepoch = int(weights.split("/")[-1].split("-")[1])

    for dsname in config["validation_datasets"]:
        ds_test, _ = get_heptfds_dataset(dsname,
                                         config,
                                         num_gpus,
                                         "test",
                                         supervised=False)
        if nevents:
            ds_test = ds_test.take(nevents)
        ds_test = ds_test.batch(5)
        eval_dir = str(
            Path(train_dir) / "evaluation" / "epoch_{}".format(iepoch) /
            dsname)
        Path(eval_dir).mkdir(parents=True, exist_ok=True)
        eval_model(model, ds_test, config, eval_dir)

    freeze_model(model, config, train_dir)
Example #4
def set_mixed_precision():
    # Compare the version as a tuple of ints so that multi-digit components
    # (e.g. 2.10, 1.15) compare correctly.
    if tuple(int(v) for v in tf.__version__.split('.')[:2]) < (2, 4):
        from tensorflow.keras.mixed_precision.experimental import Policy, set_policy
        policy = Policy('mixed_float16')
        set_policy(policy)
    else:
        policy = mixed_precision.Policy('mixed_float16')
        mixed_precision.set_global_policy(policy)
    log.info(
        f' Compute dtype: {policy.compute_dtype}, variable dtype: {policy.variable_dtype}'
    )
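
A quick way to see what the logged dtypes mean (a sketch, not part of the example): under mixed_float16, a freshly created Keras layer computes in float16 but keeps its variables in float32.

import tensorflow as tf
from tensorflow.keras import mixed_precision

mixed_precision.set_global_policy('mixed_float16')
layer = tf.keras.layers.Dense(8)
print(layer.compute_dtype)   # 'float16'
print(layer.variable_dtype)  # 'float32'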
Example #5
    def _set_precision(calculation_dtype, calculation_epsilon):
        # enable single/half/double precision
        import tensorflow.keras.backend as K
        K.set_floatx(calculation_dtype)
        K.set_epsilon(calculation_epsilon)

        # enable mixed precision
        if "float16" in calculation_dtype:
            import tensorflow.keras.mixed_precision as mixed_precision
            policy = mixed_precision.Policy("mixed_float16")
            mixed_precision.set_global_policy(policy)
Example #6
    def test__call__with_multiple_io(self, tmpdir):
        # Create and save lower precision model
        set_global_policy('mixed_float16')
        model = mock_multiple_io_model()
        self._test_for_multiple_io(model)
        path = tmpdir.mkdir("tf-keras-vis").join("multiple_io.h5")
        model.save(path)
        # Load and test lower precision model on lower precision environment
        model = load_model(path)
        self._test_for_multiple_io(model)
        # Load and test lower precision model on full precision environment
        set_global_policy('float32')
        model = load_model(path)
        self._test_for_multiple_io(model)
Example #7
def main(args: Namespace) -> None:
    """Run the main program.

    Arguments:
        args: The object containing the commandline arguments
    """
    config = load_config(args.config)

    if config.mixed_precision:
        set_global_policy("mixed_float16")

    generator = get_generator(config)
    GANTrainer.load_generator_weights(generator, args.load_dir)

    helper = GANEvaluator(generator, config)
    helper.generate(args.imgs_per_digit, args.output_dir)
Example #8
def test_works_in_xpdnet_train(model_fun,
                               model_kwargs,
                               n_scales,
                               res,
                               n_iter=10,
                               multicoil=False,
                               use_mixed_precision=False,
                               data_consistency_learning=False):
    # trying mixed precision
    if use_mixed_precision:
        policy_type = 'mixed_float16'
    else:
        policy_type = 'float32'
    mixed_precision.set_global_policy(policy_type)
    run_params = {
        'n_primal': n_primal,
        'multicoil': multicoil,
        'n_scales': n_scales,
        'n_iter': n_iter,
        'refine_smaps': multicoil,
        'res': res,
        'primal_only': not data_consistency_learning,
    }
    model = XPDNet(model_fun, model_kwargs, **run_params)
    default_model_compile(model, lr=1e-3, loss='mae')
    n_coils = 15
    k_shape = (640, 400)
    if multicoil:
        k_shape = (n_coils, *k_shape)
    inputs = [
        tf.ones([1, *k_shape, 1], dtype=tf.complex64),
        tf.ones([1, *k_shape], dtype=tf.complex64),
    ]
    if multicoil:
        inputs += [
            tf.ones([1, *k_shape], dtype=tf.complex64),
        ]
    try:
        model.fit(
            x=inputs,
            y=tf.ones([1, 320, 320, 1]),
            epochs=1,
        )
    except (tf.errors.ResourceExhaustedError, tf.errors.InternalError):
        return False
    else:
        return True
Example #9
def main(args: Namespace) -> None:
    """Run the main program.

    Arguments:
        args: The object containing the commandline arguments
    """
    config = load_config(args.config)

    strategy = MirroredStrategy()
    if config.mixed_precision:
        set_global_policy("mixed_float16")

    train_dataset, test_dataset = get_dataset(args.data_path,
                                              config.gan_batch_size)

    with strategy.scope():
        generator = get_generator(config)
        critic = get_critic(config)

        classifier = Classifier(config)
        ClassifierTrainer.load_weights(classifier, args.load_dir)

    # Save each run into a directory by its timestamp
    log_dir = setup_dirs(
        dirs=[args.save_dir],
        dirs_to_tstamp=[args.log_dir],
        config=config,
        file_name=CONFIG,
    )[0]

    trainer = GANTrainer(
        generator,
        critic,
        classifier,
        strategy,
        train_dataset,
        test_dataset,
        config=config,
        log_dir=log_dir,
        save_dir=args.save_dir,
    )
    trainer.train(
        record_steps=args.record_steps,
        log_graph=args.log_graph,
        save_steps=args.save_steps,
    )
Example #10
def evaluate(config, train_dir, weights, evaluation_dir):
    """Evaluate the trained model in train_dir"""
    if config is None:
        config = Path(train_dir) / "config.yaml"
        assert config.exists(
        ), "Could not find config file in train_dir, please provide one with -c <path/to/config>"
    config, _ = parse_config(config, weights=weights)

    if evaluation_dir is None:
        eval_dir = str(Path(train_dir) / "evaluation")
    else:
        eval_dir = evaluation_dir

    Path(eval_dir).mkdir(parents=True, exist_ok=True)

    if config["setup"]["dtype"] == "float16":
        model_dtype = tf.dtypes.float16
        policy = mixed_precision.Policy("mixed_float16")
        mixed_precision.set_global_policy(policy)
        # No optimizer is built in this evaluation path, so there is nothing
        # to wrap in a mixed_precision.LossScaleOptimizer here.
    else:
        model_dtype = tf.dtypes.float32

    strategy, num_gpus = get_strategy()
    ds_test, _ = get_heptfds_dataset(config["validation_dataset"], config,
                                     num_gpus, "test")
    ds_test = ds_test.batch(5)

    model = make_model(config, model_dtype)
    model.build((1, config["dataset"]["padded_num_elem_size"],
                 config["dataset"]["num_input_features"]))

    # need to load the weights in the same trainable configuration as the model was set up
    configure_model_weights(model, config["setup"].get("weights_config",
                                                       "all"))
    if weights:
        model.load_weights(weights, by_name=True)
    else:
        weights = get_best_checkpoint(train_dir)
        print(
            "Loading best weights that could be found from {}".format(weights))
        model.load_weights(weights, by_name=True)

    eval_model(model, ds_test, config, eval_dir)
    freeze_model(model, config, ds_test.take(1), train_dir)
Example #11
def setup(args):
    # Logging
    tf.get_logger().setLevel('DEBUG' if args.debug else 'WARNING')

    # Policy
    mixed_precision.set_global_policy(args.policy)
    for d in ['bfloat16', 'float16', 'float32']:
        if d in args.policy:
            args.dtype = d
            break

    # Device and strategy
    if args.tpu:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)
    else:
        strategy = tf.distribute.get_strategy()
    return strategy
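
A usage sketch (the Namespace fields are assumptions based on what the function reads: policy, tpu, and debug):

import argparse

args = argparse.Namespace(policy='mixed_bfloat16', tpu=False, debug=False)
strategy = setup(args)
print(args.dtype)                     # 'bfloat16', derived from args.policy
print(strategy.num_replicas_in_sync)  # 1 on a single device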
Example #12
def setup():
    # Make base dir
    loss_dir = f'out/{FLAGS.loss}-{FLAGS.disc_model}'
    shutil.rmtree(loss_dir, ignore_errors=True)
    os.mkdir(loss_dir)

    if FLAGS.strategy == 'tpu':
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)
    elif FLAGS.strategy == 'multi_cpu':
        strategy = tf.distribute.MirroredStrategy(['CPU:0', 'CPU:1'])
    else:
        strategy = tf.distribute.get_strategy()

    # Policy
    policy = mixed_precision.Policy(FLAGS.policy)
    mixed_precision.set_global_policy(policy)

    return strategy, loss_dir
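
One note on the 'multi_cpu' branch (a sketch assuming a stock single-CPU host): MirroredStrategy over 'CPU:0' and 'CPU:1' only works if two logical CPU devices exist, which can be configured before any other TensorFlow operation runs.

import tensorflow as tf

cpus = tf.config.list_physical_devices('CPU')
tf.config.set_logical_device_configuration(
    cpus[0],
    [tf.config.LogicalDeviceConfiguration(),
     tf.config.LogicalDeviceConfiguration()],
)
strategy = tf.distribute.MirroredStrategy(['CPU:0', 'CPU:1'])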
Example #13
def setup(args):
    # Logging
    logging.set_verbosity(args.log_level.upper())

    # Output directory
    args.out = os.path.join(args.base_dir, args.loss, args.data_id,
                            f'{args.backbone}-{args.feat_norm}')
    logging.info(f"out directory: '{args.out}'")
    if not args.load:
        if args.out.startswith('gs://'):
            os.system(f"gsutil -m rm {os.path.join(args.out, '**')}")
        else:
            if os.path.exists(args.out):
                shutil.rmtree(args.out)
            os.makedirs(args.out)
        logging.info(f"cleared any previous work in '{args.out}'")

    # Strategy
    if args.tpu:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)
    elif len(tf.config.list_physical_devices('GPU')) > 1:
        strategy = tf.distribute.MirroredStrategy()
    elif args.multi_cpu:
        strategy = tf.distribute.MirroredStrategy(['CPU:0', 'CPU:1'])
    else:
        strategy = tf.distribute.get_strategy()

    # Mixed precision
    policy = mixed_precision.Policy(args.policy)
    mixed_precision.set_global_policy(policy)

    # Dataset arguments
    args.views, args.with_batch_sims = ['image', 'image2'], True

    return strategy
Example #14
def main(args: Namespace) -> None:
    """Run the main program.

    Arguments:
        args: The object containing the commandline arguments
    """
    config = load_config(args.config)

    strategy = MirroredStrategy()
    if config.mixed_precision:
        set_global_policy("mixed_float16")

    train_dataset, test_dataset = get_dataset(args.data_path,
                                              config.cls_batch_size)

    with strategy.scope():
        model = Classifier(config)

    # Save each run into a directory by its timestamp.
    log_dir = setup_dirs(
        dirs=[args.save_dir],
        dirs_to_tstamp=[args.log_dir],
        config=config,
        file_name=CONFIG,
    )[0]

    trainer = ClassifierTrainer(model, strategy, config=config)
    trainer.train(
        train_dataset,
        test_dataset,
        log_dir=log_dir,
        record_eps=args.record_eps,
        save_dir=args.save_dir,
        save_steps=args.save_steps,
        log_graph=args.log_graph,
    )
Example #15
# verify GPU devices are available and ready
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA
devices = tf.config.list_physical_devices('GPU')
assert len(devices) != 0, "No GPU devices found."

# ------------------------------------------------------------------
# System Configurations
# ------------------------------------------------------------------
if config.MIRROR_STRATEGY:
    strategy = tf.distribute.MirroredStrategy()
    print('Multi-GPU enabled')

if config.MIXED_PRECISION:
    policy = mixed_precision.Policy('mixed_float16')
    mixed_precision.set_global_policy(policy)
    print('Mixed precision enabled')

if config.XLA_ACCELERATE:
    tf.config.optimizer.set_jit(True)
    print('Accelerated Linear Algebra enabled')

# Disable AutoShard; the data lives in memory, so use in-memory options
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = \
    tf.data.experimental.AutoShardPolicy.OFF


# ---------------------------------------------------------------------------
# script train.py
# ---------------------------------------------------------------------------
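
The options object above only takes effect once it is attached to a dataset; a short sketch, assuming `dataset` is the tf.data.Dataset built elsewhere in train.py:

dataset = dataset.with_options(options)
if config.MIRROR_STRATEGY:
    dataset = strategy.experimental_distribute_dataset(dataset)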
Example #16
def main(args):

    print(args)

    if not isinstance(args.workers, int):
        args.workers = min(16, mp.cpu_count())

    # AMP
    if args.amp:
        mixed_precision.set_global_policy("mixed_float16")

    input_shape = (args.size, args.size,
                   3) if isinstance(args.size, int) else None

    # Load docTR model
    model = detection.__dict__[args.arch](
        pretrained=isinstance(args.resume, str),
        assume_straight_pages=not args.rotation,
        input_shape=input_shape,
    )

    # Resume weights
    if isinstance(args.resume, str):
        print(f"Resuming {args.resume}")
        model.load_weights(args.resume).expect_partial()

    input_shape = model.cfg[
        "input_shape"] if input_shape is None else input_shape
    mean, std = model.cfg["mean"], model.cfg["std"]

    st = time.time()
    ds = datasets.__dict__[args.dataset](
        train=True,
        download=True,
        use_polygons=args.rotation,
        sample_transforms=T.Resize(input_shape[:2]),
    )
    # Monkeypatch
    subfolder = ds.root.split("/")[-2:]
    ds.root = str(Path(ds.root).parent.parent)
    ds.data = [(os.path.join(*subfolder, name), target)
               for name, target in ds.data]
    _ds = datasets.__dict__[args.dataset](
        train=False,
        download=True,
        use_polygons=args.rotation,
        sample_transforms=T.Resize(input_shape[:2]),
    )
    subfolder = _ds.root.split("/")[-2:]
    ds.data.extend([(os.path.join(*subfolder, name), target)
                    for name, target in _ds.data])

    test_loader = DataLoader(
        ds,
        batch_size=args.batch_size,
        drop_last=False,
        num_workers=args.workers,
        shuffle=False,
    )
    print(f"Test set loaded in {time.time() - st:.4}s ({len(ds)} samples in "
          f"{len(test_loader)} batches)")

    batch_transforms = T.Normalize(mean=mean, std=std)

    # Metrics
    metric = LocalizationConfusion(use_polygons=args.rotation,
                                   mask_shape=input_shape[:2])

    print("Running evaluation")
    val_loss, recall, precision, mean_iou = evaluate(model, test_loader,
                                                     batch_transforms, metric)
    print(
        f"Validation loss: {val_loss:.6} (Recall: {recall:.2%} | Precision: {precision:.2%} | "
        f"Mean IoU: {mean_iou:.2%})")
Example #17
import os
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Sequential, Model, Input
from tensorflow.keras.layers import Conv2D, Dense, BatchNormalization, Conv2DTranspose, LeakyReLU, InputLayer, Flatten, Reshape, Activation, Dropout
from tensorflow.keras.applications.vgg19 import VGG19
from tensorflow.keras import metrics, backend as K
import time
import random
import cv2 as cv
import datetime
from tensorflow.keras import mixed_precision
from tensorflow.python.ops import math_ops

mixed_precision.set_global_policy('mixed_float16')

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    class CVAE(tf.keras.Model):
        def __init__(self, input_shape, latent_dim = None):
            super(CVAE, self).__init__()
            self.latent_dim = latent_dim
            
            self.encoder = self.encoder_func(input_shape, custom_bottleneck_size = self.latent_dim)
            self.decoder = self.decoder_func(custom_bottleneck_size = latent_dim)
            
        
        def encoder_func(self, input_shape, custom_bottleneck_size = None):
            inputx = Input(shape = input_shape)
            layer = self.convLayer(inputx, 64) #128
Example #18
def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)

    time_str = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
    get_root_logger(
        log_file=os.path.join(args.train_url, f"train_{time_str}.log"))

    if args.fp16:
        from tensorflow.keras import mixed_precision

        policy = mixed_precision.Policy("mixed_float16")
        mixed_precision.set_global_policy(policy)

    train_optimizer = build_tf_optimizers(cfg.dict["optimizer"])

    mirrored_strategy = tf.distribute.MirroredStrategy()

    with mirrored_strategy.scope():
        model = build_tf_models(cfg.dict["model"])
        model.compile(optimizer=train_optimizer, )

    train_dataset_obj = build_datasets(cfg.dict["data"]["train"])
    val_dataset_obj = build_datasets(cfg.dict["data"]["val"])

    if cfg.dict["dataset_type"] == "TFRecordDataset":
        train_dataset = train_dataset_obj()
        val_dataset = val_dataset_obj()

        num_train_samples = cfg.dict["data"]["num_train_samples"]
        num_val_samples = cfg.dict["data"]["num_val_samples"]
    else:
        train_dataset = train_dataset_obj.get_data_dict()
        val_dataset = val_dataset_obj.get_data_dict()

        num_train_samples = len(train_dataset_obj)
        num_val_samples = len(val_dataset_obj)

    _adjust_batchsize(
        cfg,
        mirrored_strategy.num_replicas_in_sync,
        num_train_samples=num_train_samples,
        num_val_samples=num_val_samples,
    )

    _adjust_lr(cfg, num_replicas=mirrored_strategy.num_replicas_in_sync)
    _adjust_callback(cfg, args.train_url)

    for pipeline in cfg.dict["train_pipeline"]:
        train_dataset = build_tf_pipelines(pipeline)(train_dataset)

    for pipeline in cfg.dict["val_pipeline"]:
        val_dataset = build_tf_pipelines(pipeline)(val_dataset)

    callback_list = []
    for callback in cfg.dict["callbacks"]:
        callback_list.append(build_tf_callbacks(callback))

    model.fit(
        x=train_dataset,
        epochs=cfg.dict["train_cfg"]["epochs"],
        steps_per_epoch=cfg.dict["train_cfg"]["steps_per_epoch"],
        validation_data=val_dataset,
        validation_steps=cfg.dict["train_cfg"]["val_steps"],
        callbacks=callback_list,
        verbose=cfg.dict["train_cfg"]["use_keras_progbar"],
    )
Example #19
    def __init__(self,
                 observation_space,
                 action_space,
                 model_f,
                 m_dir=None,
                 log_name=None,
                 start_step=0,
                 mixed_float=False):
        """
        Parameters
        ----------
        observation_space : gym.Space
            Observation space of the environment.
        action_space : gym.Space
            Action space of the environment. Current agent expects only
            a discrete action space.
        model_f
            A function that returns actor, critic models.
            It should take observation space and action space as inputs.
            It should not compile the model.
        m_dir : str
            A model directory to load the model if there's a model to load
        log_name : str
            A name for log. If not specified, will be set to current time.
            - If m_dir is specified yet no log_name is given, it will continue
            counting.
            - If m_dir and log_name are both specified, it will load the model
            from m_dir, but will record logs as if this were the first training.
        start_step : int
            Total step starts from start_step
        mixed_float : bool
            Whether or not to use mixed precision
        """
        # model : The actual training model
        # t_model : Fixed target model
        print('Model directory : {}'.format(m_dir))
        print('Log name : {}'.format(log_name))
        print('Starting from step {}'.format(start_step))
        print(f'Use mixed float? {mixed_float}')
        self.action_space = action_space
        self.action_range = action_space.high - action_space.low
        self.action_shape = action_space.shape
        self.observation_space = observation_space
        self.mixed_float = mixed_float
        if mixed_float:
            policy = mixed_precision.Policy('mixed_float16')
            mixed_precision.set_global_policy(policy)

        assert hp.Algorithm in hp.available_algorithms, "Wrong Algorithm!"

        # Special variables
        if hp.Algorithm == 'V-MPO':

            self.eta = tf.Variable(1.0,
                                   trainable=True,
                                   name='eta',
                                   dtype='float32')
            self.alpha_mu = tf.Variable(1.0,
                                        trainable=True,
                                        name='alpha_mu',
                                        dtype='float32')
            self.alpha_sig = tf.Variable(1.0,
                                         trainable=True,
                                         name='alpha_sig',
                                         dtype='float32')

        elif hp.Algorithm == 'A2C':
            action_num = tf.reduce_prod(self.action_shape)
            self.log_sigma = tf.Variable(tf.fill((action_num,), 0.1),
                                         trainable=True,
                                         name='sigma',
                                         dtype='float32')

        #Inputs
        if hp.ICM_ENABLE:
            actor, critic, icm_models = model_f(observation_space,
                                                action_space)
            encoder, inverse, forward = icm_models
            self.models = {
                'actor': actor,
                'critic': critic,
                'encoder': encoder,
                'inverse': inverse,
                'forward': forward,
            }
        else:
            actor, critic = model_f(observation_space, action_space)
            self.models = {
                'actor': actor,
                'critic': critic,
            }
        targets = ['actor', 'critic']

        # Common Adam optimizer; in V-MPO the losses are merged together
        common_lr = tf.function(partial(self._lr, 'common'))
        self.common_optimizer = keras.optimizers.Adam(
            learning_rate=common_lr,
            epsilon=hp.lr['common'].epsilon,
            global_clipnorm=hp.lr['common'].grad_clip,
        )
        if self.mixed_float:
            self.common_optimizer = mixed_precision.LossScaleOptimizer(
                self.common_optimizer)

        for name, model in self.models.items():
            lr = tf.function(partial(self._lr, name))
            optimizer = keras.optimizers.Adam(
                learning_rate=lr,
                epsilon=hp.lr[name].epsilon,
                global_clipnorm=hp.lr[name].grad_clip,
            )
            if self.mixed_float:
                optimizer = mixed_precision.LossScaleOptimizer(optimizer)
            model.compile(optimizer=optimizer)
            model.summary()

        # Load model if specified
        if m_dir is not None:
            for name, model in self.models.items():
                model.load_weights(path.join(m_dir, name))
            print(f'model loaded : {m_dir}')

        # Initialize target model
        self.t_models = {}
        for name in targets:
            model = self.models[name]
            self.t_models[name] = keras.models.clone_model(model)
            self.t_models[name].set_weights(model.get_weights())

        # File writer for tensorboard
        if log_name is None:
            self.log_name = datetime.now().strftime('%m_%d_%H_%M_%S')
        else:
            self.log_name = log_name
        self.file_writer = tf.summary.create_file_writer(
            path.join('logs', self.log_name))
        self.file_writer.set_as_default()
        print('Writing logs at logs/' + self.log_name)

        # Scalars
        self.start_training = False
        self.total_steps = tf.Variable(start_step, dtype=tf.int64)

        # Savefile folder directory
        if m_dir is None:
            self.save_dir = path.join('savefiles', self.log_name)
            self.save_count = 0
        else:
            if log_name is None:
                self.save_dir, self.save_count = path.split(m_dir)
                self.save_count = int(self.save_count)
            else:
                self.save_dir = path.join('savefiles', self.log_name)
                self.save_count = 0
        self.model_dir = None
Example #20
def main(args):

    print(args)

    if args.push_to_hub:
        login_to_hub()

    if not isinstance(args.workers, int):
        args.workers = min(16, mp.cpu_count())

    # AMP
    if args.amp:
        mixed_precision.set_global_policy("mixed_float16")

    st = time.time()
    val_set = DetectionDataset(
        img_folder=os.path.join(args.val_path, "images"),
        label_path=os.path.join(args.val_path, "labels.json"),
        sample_transforms=T.SampleCompose(([
            T.Resize((args.input_size, args.input_size),
                     preserve_aspect_ratio=True,
                     symmetric_pad=True)
        ] if not args.rotation or args.eval_straight else []) + ([
            T.Resize(args.input_size, preserve_aspect_ratio=True
                     ),  # This does not pad
            T.RandomRotate(90, expand=True),
            T.Resize((args.input_size, args.input_size),
                     preserve_aspect_ratio=True,
                     symmetric_pad=True),
        ] if args.rotation and not args.eval_straight else [])),
        use_polygons=args.rotation and not args.eval_straight,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.workers,
    )
    print(
        f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
        f"{val_loader.num_batches} batches)")
    with open(os.path.join(args.val_path, "labels.json"), "rb") as f:
        val_hash = hashlib.sha256(f.read()).hexdigest()

    batch_transforms = T.Compose([
        T.Normalize(mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287)),
    ])

    # Load doctr model
    model = detection.__dict__[args.arch](
        pretrained=args.pretrained,
        input_shape=(args.input_size, args.input_size, 3),
        assume_straight_pages=not args.rotation,
    )

    # Resume weights
    if isinstance(args.resume, str):
        model.load_weights(args.resume)

    # Metrics
    val_metric = LocalizationConfusion(use_polygons=args.rotation
                                       and not args.eval_straight,
                                       mask_shape=(args.input_size,
                                                   args.input_size))
    if args.test_only:
        print("Running evaluation")
        val_loss, recall, precision, mean_iou = evaluate(
            model, val_loader, batch_transforms, val_metric)
        print(
            f"Validation loss: {val_loss:.6} (Recall: {recall:.2%} | Precision: {precision:.2%} | "
            f"Mean IoU: {mean_iou:.2%})")
        return

    st = time.time()
    # Load both train and val data generators
    train_set = DetectionDataset(
        img_folder=os.path.join(args.train_path, "images"),
        label_path=os.path.join(args.train_path, "labels.json"),
        img_transforms=T.Compose([
            # Augmentations
            T.RandomApply(T.ColorInversion(), 0.1),
            T.RandomJpegQuality(60),
            T.RandomSaturation(0.3),
            T.RandomContrast(0.3),
            T.RandomBrightness(0.3),
        ]),
        sample_transforms=T.SampleCompose(([
            T.Resize((args.input_size, args.input_size),
                     preserve_aspect_ratio=True,
                     symmetric_pad=True)
        ] if not args.rotation else []) + ([
            T.Resize(args.input_size, preserve_aspect_ratio=True
                     ),  # This does not pad
            T.RandomRotate(90, expand=True),
            T.Resize((args.input_size, args.input_size),
                     preserve_aspect_ratio=True,
                     symmetric_pad=True),
        ] if args.rotation else [])),
        use_polygons=args.rotation,
    )
    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.workers,
    )
    print(
        f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
        f"{train_loader.num_batches} batches)")
    with open(os.path.join(args.train_path, "labels.json"), "rb") as f:
        train_hash = hashlib.sha256(f.read()).hexdigest()

    if args.show_samples:
        x, target = next(iter(train_loader))
        plot_samples(x, target)
        return

    # Optimizer
    scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
        args.lr,
        decay_steps=args.epochs * len(train_loader),
        decay_rate=1 / (25e4),  # final lr as a fraction of initial lr
        staircase=False,
    )
    optimizer = tf.keras.optimizers.Adam(learning_rate=scheduler,
                                         beta_1=0.95,
                                         beta_2=0.99,
                                         epsilon=1e-6,
                                         clipnorm=5)
    if args.amp:
        optimizer = mixed_precision.LossScaleOptimizer(optimizer)
    # LR Finder
    if args.find_lr:
        lrs, losses = record_lr(model,
                                train_loader,
                                batch_transforms,
                                optimizer,
                                amp=args.amp)
        plot_recorder(lrs, losses)
        return

    # Tensorboard to monitor training
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name

    # W&B
    if args.wb:

        run = wandb.init(
            name=exp_name,
            project="text-detection",
            config={
                "learning_rate": args.lr,
                "epochs": args.epochs,
                "weight_decay": 0.0,
                "batch_size": args.batch_size,
                "architecture": args.arch,
                "input_size": args.input_size,
                "optimizer": "adam",
                "framework": "tensorflow",
                "scheduler": "exp_decay",
                "train_hash": train_hash,
                "val_hash": val_hash,
                "pretrained": args.pretrained,
                "rotation": args.rotation,
            },
        )

    if args.freeze_backbone:
        for layer in model.feat_extractor.layers:
            layer.trainable = False

    min_loss = np.inf

    # Training loop
    mb = master_bar(range(args.epochs))
    for epoch in mb:
        fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb,
                      args.amp)
        # Validation loop at the end of each epoch
        val_loss, recall, precision, mean_iou = evaluate(
            model, val_loader, batch_transforms, val_metric)
        if val_loss < min_loss:
            print(
                f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state..."
            )
            model.save_weights(f"./{exp_name}/weights")
            min_loss = val_loss
        log_msg = f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} "
        if any(val is None for val in (recall, precision, mean_iou)):
            log_msg += "(Undefined metric value, caused by empty GTs or predictions)"
        else:
            log_msg += f"(Recall: {recall:.2%} | Precision: {precision:.2%} | Mean IoU: {mean_iou:.2%})"
        mb.write(log_msg)
        # W&B
        if args.wb:
            wandb.log({
                "val_loss": val_loss,
                "recall": recall,
                "precision": precision,
                "mean_iou": mean_iou,
            })

    if args.wb:
        run.finish()

    if args.push_to_hub:
        push_to_hf_hub(model, exp_name, task="detection", run_config=args)
Example #21
def main(args):

    print(args)

    if args.push_to_hub:
        login_to_hub()

    if not isinstance(args.workers, int):
        args.workers = min(16, mp.cpu_count())

    vocab = VOCABS[args.vocab]
    fonts = args.font.split(",")

    # AMP
    if args.amp:
        mixed_precision.set_global_policy("mixed_float16")

    st = time.time()

    if isinstance(args.val_path, str):
        with open(os.path.join(args.val_path, "labels.json"), "rb") as f:
            val_hash = hashlib.sha256(f.read()).hexdigest()

        # Load val data generator
        val_set = RecognitionDataset(
            img_folder=os.path.join(args.val_path, "images"),
            labels_path=os.path.join(args.val_path, "labels.json"),
            img_transforms=T.Resize((args.input_size, 4 * args.input_size),
                                    preserve_aspect_ratio=True),
        )
    else:
        val_hash = None
        # Load synthetic data generator
        val_set = WordGenerator(
            vocab=vocab,
            min_chars=args.min_chars,
            max_chars=args.max_chars,
            num_samples=args.val_samples * len(vocab),
            font_family=fonts,
            img_transforms=T.Compose([
                T.Resize((args.input_size, 4 * args.input_size),
                         preserve_aspect_ratio=True),
                # Ensure we have a 90% split of white-background images
                T.RandomApply(T.ColorInversion(), 0.9),
            ]),
        )

    val_loader = DataLoader(
        val_set,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.workers,
    )
    print(
        f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
        f"{val_loader.num_batches} batches)")

    # Load doctr model
    model = recognition.__dict__[args.arch](
        pretrained=args.pretrained,
        input_shape=(args.input_size, 4 * args.input_size, 3),
        vocab=vocab,
    )
    # Resume weights
    if isinstance(args.resume, str):
        model.load_weights(args.resume)

    # Metrics
    val_metric = TextMatch()

    batch_transforms = T.Compose([
        T.Normalize(mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301)),
    ])

    if args.test_only:
        print("Running evaluation")
        val_loss, exact_match, partial_match = evaluate(
            model, val_loader, batch_transforms, val_metric)
        print(
            f"Validation loss: {val_loss:.6} (Exact: {exact_match:.2%} | Partial: {partial_match:.2%})"
        )
        return

    st = time.time()

    if isinstance(args.train_path, str):
        # Load train data generator
        base_path = Path(args.train_path)
        parts = ([base_path]
                 if base_path.joinpath("labels.json").is_file() else
                 [base_path.joinpath(sub) for sub in os.listdir(base_path)])
        with open(parts[0].joinpath("labels.json"), "rb") as f:
            train_hash = hashlib.sha256(f.read()).hexdigest()

        train_set = RecognitionDataset(
            parts[0].joinpath("images"),
            parts[0].joinpath("labels.json"),
            img_transforms=T.Compose([
                T.RandomApply(T.ColorInversion(), 0.1),
                T.Resize((args.input_size, 4 * args.input_size),
                         preserve_aspect_ratio=True),
                # Augmentations
                T.RandomJpegQuality(60),
                T.RandomSaturation(0.3),
                T.RandomContrast(0.3),
                T.RandomBrightness(0.3),
            ]),
        )
        if len(parts) > 1:
            for subfolder in parts[1:]:
                train_set.merge_dataset(
                    RecognitionDataset(subfolder.joinpath("images"),
                                       subfolder.joinpath("labels.json")))
    else:
        train_hash = None
        # Load synthetic data generator
        train_set = WordGenerator(
            vocab=vocab,
            min_chars=args.min_chars,
            max_chars=args.max_chars,
            num_samples=args.train_samples * len(vocab),
            font_family=fonts,
            img_transforms=T.Compose([
                T.Resize((args.input_size, 4 * args.input_size),
                         preserve_aspect_ratio=True),
                # Ensure we have a 90% split of white-background images
                T.RandomApply(T.ColorInversion(), 0.9),
                T.RandomJpegQuality(60),
                T.RandomSaturation(0.3),
                T.RandomContrast(0.3),
                T.RandomBrightness(0.3),
            ]),
        )

    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.workers,
    )
    print(
        f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
        f"{train_loader.num_batches} batches)")

    if args.show_samples:
        x, target = next(iter(train_loader))
        plot_samples(x, target)
        return

    # Optimizer
    scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
        args.lr,
        decay_steps=args.epochs * len(train_loader),
        decay_rate=1 / (25e4),  # final lr as a fraction of initial lr
        staircase=False,
    )
    optimizer = tf.keras.optimizers.Adam(learning_rate=scheduler,
                                         beta_1=0.95,
                                         beta_2=0.99,
                                         epsilon=1e-6,
                                         clipnorm=5)
    if args.amp:
        optimizer = mixed_precision.LossScaleOptimizer(optimizer)
    # LR Finder
    if args.find_lr:
        lrs, losses = record_lr(model,
                                train_loader,
                                batch_transforms,
                                optimizer,
                                amp=args.amp)
        plot_recorder(lrs, losses)
        return

    # Tensorboard to monitor training
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name

    # W&B
    if args.wb:

        run = wandb.init(
            name=exp_name,
            project="text-recognition",
            config={
                "learning_rate": args.lr,
                "epochs": args.epochs,
                "weight_decay": 0.0,
                "batch_size": args.batch_size,
                "architecture": args.arch,
                "input_size": args.input_size,
                "optimizer": "adam",
                "framework": "tensorflow",
                "scheduler": "exp_decay",
                "vocab": args.vocab,
                "train_hash": train_hash,
                "val_hash": val_hash,
                "pretrained": args.pretrained,
            },
        )

    min_loss = np.inf

    # Training loop
    mb = master_bar(range(args.epochs))
    for epoch in mb:
        fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb,
                      args.amp)

        # Validation loop at the end of each epoch
        val_loss, exact_match, partial_match = evaluate(
            model, val_loader, batch_transforms, val_metric)
        if val_loss < min_loss:
            print(
                f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state..."
            )
            model.save_weights(f"./{exp_name}/weights")
            min_loss = val_loss
        mb.write(
            f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} "
            f"(Exact: {exact_match:.2%} | Partial: {partial_match:.2%})")
        # W&B
        if args.wb:
            wandb.log({
                "val_loss": val_loss,
                "exact_match": exact_match,
                "partial_match": partial_match,
            })

    if args.wb:
        run.finish()

    if args.push_to_hub:
        push_to_hf_hub(model, exp_name, task="recognition", run_config=args)
Example #22
def main(args):

    print(args)

    if args.push_to_hub:
        login_to_hub()

    if not isinstance(args.workers, int):
        args.workers = min(16, mp.cpu_count())

    vocab = VOCABS[args.vocab]

    fonts = args.font.split(",")

    # AMP
    if args.amp:
        mixed_precision.set_global_policy("mixed_float16")

    # Load val data generator
    st = time.time()
    val_set = CharacterGenerator(
        vocab=vocab,
        num_samples=args.val_samples * len(vocab),
        cache_samples=True,
        img_transforms=T.Compose(
            [
                T.Resize((args.input_size, args.input_size)),
                # Ensure we have a 90% split of white-background images
                T.RandomApply(T.ColorInversion(), 0.9),
            ]
        ),
        font_family=fonts,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.workers,
        collate_fn=collate_fn,
    )
    print(
        f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
        f"{val_loader.num_batches} batches)"
    )

    # Load doctr model
    model = classification.__dict__[args.arch](
        pretrained=args.pretrained,
        input_shape=(args.input_size, args.input_size, 3),
        num_classes=len(vocab),
        classes=list(vocab),
        include_top=True,
    )

    # Resume weights
    if isinstance(args.resume, str):
        model.load_weights(args.resume)

    batch_transforms = T.Compose(
        [
            T.Normalize(mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301)),
        ]
    )

    if args.test_only:
        print("Running evaluation")
        val_loss, acc = evaluate(model, val_loader, batch_transforms)
        print(f"Validation loss: {val_loss:.6} (Acc: {acc:.2%})")
        return

    st = time.time()

    # Load train data generator
    train_set = CharacterGenerator(
        vocab=vocab,
        num_samples=args.train_samples * len(vocab),
        cache_samples=True,
        img_transforms=T.Compose(
            [
                T.Resize((args.input_size, args.input_size)),
                # Augmentations
                T.RandomApply(T.ColorInversion(), 0.9),
                T.RandomApply(T.ToGray(3), 0.1),
                T.RandomJpegQuality(60),
                T.RandomSaturation(0.3),
                T.RandomContrast(0.3),
                T.RandomBrightness(0.3),
                # Blur
                T.RandomApply(T.GaussianBlur(kernel_shape=(3, 3), std=(0.1, 3)), 0.3),
            ]
        ),
        font_family=fonts,
    )
    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.workers,
        collate_fn=collate_fn,
    )
    print(
        f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
        f"{train_loader.num_batches} batches)"
    )

    if args.show_samples:
        x, target = next(iter(train_loader))
        plot_samples(x, list(map(vocab.__getitem__, target)))
        return

    # Optimizer
    scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
        args.lr,
        decay_steps=args.epochs * len(train_loader),
        decay_rate=1 / (1e3),  # final lr as a fraction of initial lr
        staircase=False,
    )
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=scheduler,
        beta_1=0.95,
        beta_2=0.99,
        epsilon=1e-6,
    )
    if args.amp:
        optimizer = mixed_precision.LossScaleOptimizer(optimizer)

    # LR Finder
    if args.find_lr:
        lrs, losses = record_lr(model, train_loader, batch_transforms, optimizer, amp=args.amp)
        plot_recorder(lrs, losses)
        return

    # Tensorboard to monitor training
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name

    # W&B
    if args.wb:

        run = wandb.init(
            name=exp_name,
            project="character-classification",
            config={
                "learning_rate": args.lr,
                "epochs": args.epochs,
                "weight_decay": 0.0,
                "batch_size": args.batch_size,
                "architecture": args.arch,
                "input_size": args.input_size,
                "optimizer": "adam",
                "framework": "tensorflow",
                "vocab": args.vocab,
                "scheduler": "exp_decay",
                "pretrained": args.pretrained,
            },
        )

    # Create loss queue
    min_loss = np.inf

    # Training loop
    mb = master_bar(range(args.epochs))
    for epoch in mb:
        fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb, args.amp)

        # Validation loop at the end of each epoch
        val_loss, acc = evaluate(model, val_loader, batch_transforms)
        if val_loss < min_loss:
            print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
            model.save_weights(f"./{exp_name}/weights")
            min_loss = val_loss
        mb.write(f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} (Acc: {acc:.2%})")
        # W&B
        if args.wb:
            wandb.log(
                {
                    "val_loss": val_loss,
                    "acc": acc,
                }
            )

    if args.wb:
        run.finish()

    if args.push_to_hub:
        push_to_hf_hub(model, exp_name, task="classification", run_config=args)

    if args.export_onnx:
        print("Exporting model to ONNX...")
        dummy_input = [tf.TensorSpec([None, args.input_size, args.input_size, 3], tf.float32, name="input")]
        model_path, _ = export_model_to_onnx(model, exp_name, dummy_input)
        print(f"Exported model saved in {model_path}")
Example #23
def compute_validation_loss(config, train_dir, weights):
    """Evaluate the trained model in train_dir"""
    if config is None:
        config = Path(train_dir) / "config.yaml"
        assert config.exists(
        ), "Could not find config file in train_dir, please provide one with -c <path/to/config>"
    config, _ = parse_config(config, weights=weights)

    if config["setup"]["dtype"] == "float16":
        model_dtype = tf.dtypes.float16
        policy = mixed_precision.Policy("mixed_float16")
        mixed_precision.set_global_policy(policy)
    else:
        model_dtype = tf.dtypes.float32

    strategy, num_gpus = get_strategy()
    ds_test, num_test_steps = get_datasets(config["train_test_datasets"],
                                           config, num_gpus, "test")

    with strategy.scope():
        model = make_model(config, model_dtype)
        model.build((1, config["dataset"]["padded_num_elem_size"],
                     config["dataset"]["num_input_features"]))

        # need to load the weights in the same trainable configuration as the model was set up
        configure_model_weights(model,
                                config["setup"].get("weights_config", "all"))
        if weights:
            model.load_weights(weights, by_name=True)
        else:
            weights = get_best_checkpoint(train_dir)
            print("Loading best weights that could be found from {}".format(
                weights))
            model.load_weights(weights, by_name=True)

        loss_dict, loss_weights = get_loss_dict(config)
        model.compile(
            loss=loss_dict,
            # sample_weight_mode="temporal",
            loss_weights=loss_weights,
            metrics={
                "cls": [
                    FlattenedCategoricalAccuracy(name="acc_unweighted",
                                                 dtype=tf.float64),
                    FlattenedCategoricalAccuracy(use_weights=True,
                                                 name="acc_weighted",
                                                 dtype=tf.float64),
                ] + [
                    SingleClassRecall(
                        icls, name="rec_cls{}".format(icls), dtype=tf.float64)
                    for icls in range(config["dataset"]["num_output_classes"])
                ]
            },
        )

        losses = model.evaluate(
            x=ds_test,
            steps=num_test_steps,
            return_dict=True,
        )
    with open("{}/losses.txt".format(train_dir), "w") as loss_file:
        loss_file.write(json.dumps(losses) + "\n")
Example #24
def run(args):
    split_on = "none" if args.split_on is (None or "none") else args.split_on


    # Create project dir (if it doesn't exist)
    # import ipdb; ipdb.set_trace()
    prjdir = cfg.MAIN_PRJDIR/args.prjname
    os.makedirs(prjdir, exist_ok=True)


    # Create outdir (using the loaded hyperparamters) or
    # use content (model) from an existing run
    fea_strs = ["use_tile"]
    args_dict = vars(args)
    fea_names = "_".join([k.split("use_")[-1] for k in fea_strs if args_dict[k] is True])
    prm_file_path = prjdir/f"params_{fea_names}.json"
    if prm_file_path.exists() is False:
        shutil.copy(fdir/f"../default_params/default_params_{fea_names}.json", prm_file_path)
    params = Params(prm_file_path)

    if args.rundir is not None:
        outdir = Path(args.rundir).resolve()
        assert outdir.exists(), f"The {outdir} doen't exist."
        print_fn = print
    else:
        outdir = create_outdir_2(prjdir, args)

        # Save hyper-parameters
        params.save(outdir/"params.json")

        # Logger
        lg = Logger(outdir/"logger.log")
        print_fn = get_print_func(lg.logger)
        print_fn(f"File path: {fdir}")
        print_fn(f"\n{pformat(vars(args))}")


    # Load dataframe (annotations)
    annotations_file = cfg.DATA_PROCESSED_DIR/args.dataname/cfg.SF_ANNOTATIONS_FILENAME
    dtype = {"image_id": str, "slide": str}
    data = pd.read_csv(annotations_file, dtype=dtype, engine="c", na_values=["na", "NaN"], low_memory=True)
    # data = data.astype({"image_id": str, "slide": str})
    print_fn(data.shape)


    print_fn("\nFull dataset:")
    if args.target[0] == "Response":
        print_groupby_stat_rsp(data, split_on="Group", print_fn=print_fn)
    else:
        print_groupby_stat_ctype(data, split_on="Group", print_fn=print_fn)


    # Determine tfr_dir (the path to TFRecords)
    tfr_dir = (cfg.DATADIR/args.tfr_dir_name).resolve()
    pred_tfr_dir = (cfg.DATADIR/args.pred_tfr_dir_name).resolve()
    label = f"{params.tile_px}px_{params.tile_um}um"
    tfr_dir = tfr_dir/label
    pred_tfr_dir = pred_tfr_dir/label

    # Create outcomes (for drug response)
    # outcomes = {}
    # unique_outcomes = list(set(data[args.target[0]].values))
    # unique_outcomes.sort()
    # for smp, o in zip(data[args.id_name], data[args.target[0]]):
    #     outcomes[smp] = {"outcome": unique_outcomes.index(o)}


    # Scalers for each feature set
    # import ipdb; ipdb.set_trace()
    ge_scaler, dd1_scaler, dd2_scaler = None, None, None

    ge_cols  = [c for c in data.columns if c.startswith("ge_")]
    dd1_cols = [c for c in data.columns if c.startswith("dd1_")]
    dd2_cols = [c for c in data.columns if c.startswith("dd2_")]

    if args.scale_fea:
        if args.use_ge and len(ge_cols) > 0:
            ge_scaler = get_scaler(data[ge_cols])
        if args.use_dd1 and len(dd1_cols) > 0:
            dd1_scaler = get_scaler(data[dd1_cols])
        if args.use_dd2 and len(dd2_cols) > 0:
            dd2_scaler = get_scaler(data[dd2_cols])


    # Create manifest
    # print_fn("\nCreate/load manifest ...")
    # timer = Timer()
    # manifest = create_manifest(directory=tfr_dir, n_files=None)
    # timer.display_timer(print_fn)


    # -----------------------------------------------
    # Data splits
    # -----------------------------------------------

    # --------------
    # Yitan's splits
    # --------------
    if args.target[0] == "Response":
        if args.use_dd1 is False and args.use_dd2 is False:
            splitdir = cfg.DATADIR/"PDX_Transfer_Learning_Classification/Processed_Data/Data_For_MultiModal_Learning/Data_Partition_Drug_Specific"
            splitdir = splitdir/params.drug_specific
        else:
            splitdir = cfg.DATADIR/"PDX_Transfer_Learning_Classification/Processed_Data/Data_For_MultiModal_Learning/Data_Partition"
    else:
        splitdir = cfg.DATADIR/"PDX_Transfer_Learning_Classification/Processed_Data/Data_For_MultiModal_Learning/Data_Partition"

    tr_id = cast_list(read_lines(str(splitdir/f"cv_{args.split_id}"/"TrainList.txt")), int)
    vl_id = cast_list(read_lines(str(splitdir/f"cv_{args.split_id}"/"ValList.txt")), int)
    te_id = cast_list(read_lines(str(splitdir/f"cv_{args.split_id}"/"TestList.txt")), int)

    # Update ids
    index_col_name = "index"
    tr_id = sorted(set(data[index_col_name]).intersection(set(tr_id)))
    vl_id = sorted(set(data[index_col_name]).intersection(set(vl_id)))
    te_id = sorted(set(data[index_col_name]).intersection(set(te_id)))

    # Subsample train samples
    if args.n_samples > 0:
        if args.n_samples < len(tr_id):
            tr_id = tr_id[:args.n_samples]
        if args.n_samples < len(vl_id):
            vl_id = vl_id[:args.n_samples]
        if args.n_samples < len(te_id):
            te_id = te_id[:args.n_samples]

    
    ### ap --------------
    # Drop slide duplicates
    ###
    fea_columns = ["slide"]
    data = data.drop_duplicates(subset=fea_columns)
    ### ap --------------

    # --------------
    # TidyData
    # --------------
    # TODO: finish and test this class
    # td = TidyData(data,
    #               ge_prfx="ge_",
    #               dd1_prfx="dd1_",
    #               dd2_prfx="dd2_",
    #               index_col_name="index",
    #               split_ids={"tr_id": tr_id, "vl_id": vl_id, "te_id": te_id}
    # )
    # ge_scaler = td.ge_scaler
    # dd1_scaler = td.dd1_scaler
    # dd2_scaler = td.dd2_scaler

    # tr_meta = td.tr_meta
    # vl_meta = td.vl_meta
    # te_meta = td.te_meta
    # tr_meta.to_csv(outdir/"tr_meta.csv", index=False)
    # vl_meta.to_csv(outdir/"vl_meta.csv", index=False)
    # te_meta.to_csv(outdir/"te_meta.csv", index=False)

    # # Variables (dict/dataframes/arrays) that are passed as features to the NN
    # xtr = {"ge_data": td.tr_ge.values, "dd1_data": td.tr_dd1.values, "dd2_data": td.tr_dd2.values}
    # xvl = {"ge_data": td.vl_ge.values, "dd1_data": td.vl_dd1.values, "dd2_data": td.vl_dd2.values}
    # xte = {"ge_data": td.te_ge.values, "dd1_data": td.te_dd1.values, "dd2_data": td.te_dd2.values}

    # --------------
    # w/o TidyData
    # --------------
    kwargs = {"ge_cols": ge_cols,
              "dd1_cols": dd1_cols,
              "dd2_cols": dd2_cols,
              "ge_scaler": ge_scaler,
              "dd1_scaler": dd1_scaler,
              "dd2_scaler": dd2_scaler,
              "ge_dtype": cfg.GE_DTYPE,
              "dd_dtype": cfg.DD_DTYPE,
              "index_col_name": index_col_name,
              "split_on": split_on
              }
    tr_ge, tr_dd1, tr_dd2, tr_meta = split_data_and_extract_fea(data, ids=tr_id, **kwargs)
    vl_ge, vl_dd1, vl_dd2, vl_meta = split_data_and_extract_fea(data, ids=vl_id, **kwargs)
    te_ge, te_dd1, te_dd2, te_meta = split_data_and_extract_fea(data, ids=te_id, **kwargs)

    ### ap --------------
    # Create annotations for slideflow
    ###
    # import ipdb; ipdb.set_trace()
    tr_meta["submitter_id"] = tr_meta["Group"]  # submitter_id (specific patient); Group (specific treatment group)
    vl_meta["submitter_id"] = vl_meta["Group"]
    te_meta["submitter_id"] = te_meta["Group"]
    tr_meta["training_phase"] = "train"
    vl_meta["training_phase"] = "validation"
    te_meta["training_phase"] = "test"
    keep_cols = ["submitter_id", "slide", "model", "patient_id", "specimen_id", "sample_id",
                 "training_phase", "Group", "ctype", "csite", "ctype_label", "csite_label"]
    tr_meta_tmp = tr_meta[keep_cols]
    vl_meta_tmp = vl_meta[keep_cols]
    te_meta_tmp = te_meta[keep_cols]
    tr_meta.to_csv(outdir/"train_annotations.csv", index=False)
    vl_meta.to_csv(outdir/"validation_annotations.csv", index=False)
    te_meta.to_csv(outdir/"test_annotations.csv", index=False)
    sf_df = pd.concat([tr_meta_tmp, vl_meta_tmp, te_meta_tmp], axis=0)
    sf_df.to_csv(outdir/"annotations_for_sf.csv", index=False)
    del tr_meta_tmp, vl_meta_tmp, te_meta_tmp, sf_df
    ### ap --------------

    if args.train is True:
        tr_meta.to_csv(outdir/"tr_meta.csv", index=False)
        vl_meta.to_csv(outdir/"vl_meta.csv", index=False)
        te_meta.to_csv(outdir/"te_meta.csv", index=False)

    ge_shape = (tr_ge.shape[1],)
    dd_shape = (tr_dd1.shape[1],)

    if args.target[0] == "Response":
        print_fn("\nTrain:")
        print_groupby_stat_rsp(tr_meta, split_on="Group", print_fn=print_fn)
        print_fn("\nValidation:")
        print_groupby_stat_rsp(vl_meta, split_on="Group", print_fn=print_fn)
        print_fn("\nTest:")
        print_groupby_stat_rsp(te_meta, split_on="Group", print_fn=print_fn)
    else:
        print_fn("\nTrain:")
        print_groupby_stat_ctype(tr_meta, split_on="Group", print_fn=print_fn)
        print_fn("\nValidation:")
        print_groupby_stat_ctype(vl_meta, split_on="Group", print_fn=print_fn)
        print_fn("\nTest:")
        print_groupby_stat_ctype(te_meta, split_on="Group", print_fn=print_fn)

    # Make sure indices do not overlap
    assert len( set(tr_id).intersection(set(vl_id)) ) == 0, "Overlapping indices btw tr and vl"
    assert len( set(tr_id).intersection(set(te_id)) ) == 0, "Overlapping indices btw tr and te"
    assert len( set(vl_id).intersection(set(te_id)) ) == 0, "Overlapping indices btw vl and te"

    # Print split ratios
    print_fn("")
    print_fn("Train samples {} ({:.2f}%)".format( tr_meta.shape[0], 100*tr_meta.shape[0]/data.shape[0] ))
    print_fn("Val   samples {} ({:.2f}%)".format( vl_meta.shape[0], 100*vl_meta.shape[0]/data.shape[0] ))
    print_fn("Test  samples {} ({:.2f}%)".format( te_meta.shape[0], 100*te_meta.shape[0]/data.shape[0] ))

    tr_grp_unq = set(tr_meta[split_on].values)
    vl_grp_unq = set(vl_meta[split_on].values)
    te_grp_unq = set(te_meta[split_on].values)
    print_fn("")
    print_fn(f"Total intersects on {split_on} btw tr and vl: {len(tr_grp_unq.intersection(vl_grp_unq))}")
    print_fn(f"Total intersects on {split_on} btw tr and te: {len(tr_grp_unq.intersection(te_grp_unq))}")
    print_fn(f"Total intersects on {split_on} btw vl and te: {len(vl_grp_unq.intersection(te_grp_unq))}")
    print_fn(f"Unique {split_on} in tr: {len(tr_grp_unq)}")
    print_fn(f"Unique {split_on} in vl: {len(vl_grp_unq)}")
    print_fn(f"Unique {split_on} in te: {len(te_grp_unq)}")


    # --------------------------
    # Obtain T/V/E tfr filenames
    # --------------------------
    # List of sample names for T/V/E
    tr_smp_names = list(tr_meta[args.id_name].values)
    vl_smp_names = list(vl_meta[args.id_name].values)
    te_smp_names = list(te_meta[args.id_name].values)

    # TFRecords filenames
    train_tfr_files = get_tfr_files(tfr_dir, tr_smp_names)
    val_tfr_files = get_tfr_files(tfr_dir, vl_smp_names)
    if args.eval is True:
        assert pred_tfr_dir.exists(), f"Dir {pred_tfr_dir} is not found."
        # test_tfr_files = get_tfr_files(tfr_dir, te_smp_names)  # use same tfr_dir for eval
        test_tfr_files = get_tfr_files(pred_tfr_dir, te_smp_names)
        # print_fn("Total samples {}".format(len(train_tfr_files) + len(val_tfr_files) + len(test_tfr_files)))

    # Missing tfrecords
    print("\nThese samples miss a tfrecord:")
    df_miss = data.loc[~data[args.id_name].isin(tr_smp_names + vl_smp_names + te_smp_names), ["smp", "image_id"]]
    print(df_miss)

    assert sorted(tr_smp_names) == sorted(tr_meta[args.id_name].values.tolist()), "Sample names in the tr_smp_names and tr_meta don't match."
    assert sorted(vl_smp_names) == sorted(vl_meta[args.id_name].values.tolist()), "Sample names in the vl_smp_names and vl_meta don't match."
    assert sorted(te_smp_names) == sorted(te_meta[args.id_name].values.tolist()), "Sample names in the te_smp_names and te_meta don't match."


    # -------------------------------
    # Class weight
    # -------------------------------
    tile_cnts = pd.read_csv(tfr_dir/"tile_counts_per_slide.csv")
    tile_cnts.insert(loc=0, column="tfr_abs_fname", value=tile_cnts["tfr_fname"].map(lambda s: str(tfr_dir/s)))
    cat = tile_cnts[tile_cnts["tfr_abs_fname"].isin(train_tfr_files)]

    ### ap --------------
    # if args.target[0] not in cat.columns:
    #     tile_cnts = tile_cnts[tile_cnts["smp"].isin(tr_meta["smp"])]
    df = tr_meta[["smp", args.target[0]]]
    cat = cat.merge(df, on="smp", how="inner")
    ### ap --------------

    cat = cat.groupby(args.target[0]).agg({"smp": "nunique", "max_tiles": "sum", "n_tiles": "sum", "slide": "nunique"}).reset_index()
    categories = {}
    for i, row_data in cat.iterrows():
        dct = {"num_samples": row_data["smp"], "num_tiles": row_data["n_tiles"]}
        categories[row_data[args.target[0]]] = dct

    class_weight = calc_class_weights(train_tfr_files,
                                      class_weights_method=params.class_weights_method,
                                      categories=categories)
    # class_weight = {"Response": class_weight}


    # --------------------------
    # Build tf.data objects
    # --------------------------
    tf.keras.backend.clear_session()

    # import ipdb; ipdb.set_trace()
    if args.use_tile:

        # -------------------------------
        # Parsing funcs
        # -------------------------------
        # import ipdb; ipdb.set_trace()
        if args.target[0] == "Response":
            # Response
            parse_fn = parse_tfrec_fn_rsp
            parse_fn_train_kwargs = {
                "use_tile": args.use_tile,
                "use_ge": args.use_ge,
                "use_dd1": args.use_dd1,
                "use_dd2": args.use_dd2,
                "ge_scaler": ge_scaler,
                "dd1_scaler": dd1_scaler,
                "dd2_scaler": dd2_scaler,
                "id_name": args.id_name,
                "augment": params.augment,
                "application": params.base_image_model,
                # "application": None,
            }
        else:
            # Ctype
            parse_fn = parse_tfrec_fn_ctype
            parse_fn_train_kwargs = {
                "use_tile": args.use_tile,
                "use_ge": args.use_ge,
                "ge_scaler": ge_scaler,
                "id_name": args.id_name,
                "augment": params.augment,
                "target": args.target[0]
            }

        parse_fn_non_train_kwargs = parse_fn_train_kwargs.copy()
        parse_fn_non_train_kwargs["augment"] = False

        # ----------------------------------------
        # Number of tiles/examples in each dataset
        # ----------------------------------------
        # import ipdb; ipdb.set_trace()
        tr_tiles = tile_cnts[tile_cnts[args.id_name].isin(tr_smp_names)]["n_tiles"].sum()
        vl_tiles = tile_cnts[tile_cnts[args.id_name].isin(vl_smp_names)]["n_tiles"].sum()
        te_tiles = tile_cnts[tile_cnts[args.id_name].isin(te_smp_names)]["n_tiles"].sum()

        eval_batch_size = 4 * params.batch_size
        tr_steps = tr_tiles // params.batch_size
        vl_steps = vl_tiles // eval_batch_size
        # te_steps = te_tiles // eval_batch_size

        # -------------------------------
        # Create TF datasets
        # -------------------------------
        print("\nCreating TF datasets.")

        # Training
        # import ipdb; ipdb.set_trace()
        train_data = create_tf_data(
            batch_size=params.batch_size,
            deterministic=False,
            include_meta=False,
            interleave=True,
            n_concurrent_shards=params.n_concurrent_shards,  # 32, 64
            parse_fn=parse_fn,
            prefetch=1,  # 2
            repeat=True,
            seed=None,  # cfg.seed,
            shuffle_files=True,
            shuffle_size=params.shuffle_size,  # 8192
            tfrecords=train_tfr_files,
            **parse_fn_train_kwargs)

        # Determine feature shapes from data
        bb = next(iter(train_data))

        # Infer dims of features from the data
        # import ipdb; ipdb.set_trace()
        if args.use_ge:
            ge_shape = bb[0]["ge_data"].numpy().shape[1:]
        else:
            ge_shape = None

        if args.use_dd1:
            dd_shape = bb[0]["dd1_data"].numpy().shape[1:]
        else:
            dd_shape = None

        # Print keys and dims
        for i, item in enumerate(bb):
            print(f"\nItem {i}")
            if isinstance(item, dict):
                for k in item.keys():
                    print(f"\t{k}: {item[k].numpy().shape}")
            elif isinstance(item.numpy(), np.ndarray):
                print(item)

        # for i, rec in enumerate(train_data.take(2)):
        #     tf.print(rec[1])

        # Evaluation (val, test, train)
        create_tf_data_eval_kwargs = {
            "batch_size": eval_batch_size,
            "include_meta": False,
            "interleave": False,
            "parse_fn": parse_fn,
            "prefetch": None,  # 2
            "repeat": False,
            "seed": None,
            "shuffle_files": False,
            "shuffle_size": None,
        }

        # import ipdb; ipdb.set_trace()
        create_tf_data_eval_kwargs.update({"tfrecords": val_tfr_files, "include_meta": False})
        val_data = create_tf_data(
            **create_tf_data_eval_kwargs,
            **parse_fn_non_train_kwargs
        )

    # ----------------------
    # Prep for training
    # ----------------------
    # import ipdb; ipdb.set_trace()

    # # Loss and target
    # if args.use_tile:
    #     loss = losses.BinaryCrossentropy(label_smoothing=params.label_smoothing)
    # else:
    #     if params.y_encoding == "onehot":
    #         if index_col_name in data.columns:
    #             # Using Yitan's T/V/E splits
    #             # print(te_meta[["index", "Group", "grp_name", "Response"]])
    #             ytr = pd.get_dummies(tr_meta[args.target[0]].values)
    #             yvl = pd.get_dummies(vl_meta[args.target[0]].values)
    #             yte = pd.get_dummies(te_meta[args.target[0]].values)
    #         else:
    #             ytr = y_onehot.iloc[tr_id, :].reset_index(drop=True)
    #             yvl = y_onehot.iloc[vl_id, :].reset_index(drop=True)
    #             yte = y_onehot.iloc[te_id, :].reset_index(drop=True)
    #         loss = losses.CategoricalCrossentropy()
    #     elif params.y_encoding == "label":
    #         if index_col_name in data.columns:
    #             # Using Yitan's T/V/E splits
    #             ytr = tr_meta[args.target[0]].values
    #             yvl = vl_meta[args.target[0]].values
    #             yte = te_meta[args.target[0]].values
    #             loss = losses.BinaryCrossentropy(label_smoothing=params.label_smoothing)
    #         else:
    #             ytr = ydata_label[tr_id]
    #             yvl = ydata_label[vl_id]
    #             yte = ydata_label[te_id]
    #             loss = losses.SparseCategoricalCrossentropy()
    #     else:
    #         raise ValueError(f"Unknown value for y_encoding ({params.y_encoding}).")


    # -------------
    # Train model
    # -------------
    model = None

    # import ipdb; ipdb.set_trace()
    if args.train is True:

        # Callbacks list
        monitor = "val_loss"
        # monitor = "val_pr-auc"
        callbacks = keras_callbacks(outdir, monitor=monitor,
                                    save_best_only=params.save_best_only,
                                    patience=params.patience)
        # callbacks = keras_callbacks(outdir, monitor="auc", patience=params.patience)

        # Mixed precision
        if params.use_fp16:
            print_fn("\nTrain with mixed precision")
            if int(tf.keras.__version__.split(".")[1]) >= 4:  # TF 2.4+
                from tensorflow.keras import mixed_precision
                policy = mixed_precision.Policy("mixed_float16")
                mixed_precision.set_global_policy(policy)
            elif int(tf.keras.__version__.split(".")[1]) == 3:  # TF 2.3
                from tensorflow.keras.mixed_precision import experimental as mixed_precision
                policy = mixed_precision.Policy("mixed_float16")
                mixed_precision.set_policy(policy)
            print_fn("Compute dtype: %s" % policy.compute_dtype)
            print_fn("Variable dtype: %s" % policy.variable_dtype)

        # ----------------------
        # Define model
        # ----------------------
        # import ipdb; ipdb.set_trace()

        from tensorflow.keras.layers import Input, Dense, Dropout, Activation, BatchNormalization
        from tensorflow.keras import layers
        from tensorflow.keras import losses
        from tensorflow.keras import optimizers
        from tensorflow.keras.models import Sequential, Model, load_model

        # trainable = True
        trainable = False
        # from_logits = True
        from_logits = False
        fit_verbose = 1
        pretrain = params.pretrain
        pooling = params.pooling
        n_classes = len(sorted(tr_meta[args.target[0]].unique()))

        model_inputs = []
        merge_inputs = []

        if args.use_tile:
            image_shape = (cfg.IMAGE_SIZE, cfg.IMAGE_SIZE, 3)
            tile_input_tensor = tf.keras.Input(shape=image_shape, name="tile_image")

            base_img_model = tf.keras.applications.Xception(
                include_top=False,
                weights=pretrain,
                input_shape=None,
                input_tensor=None,
                pooling=pooling)

            print_fn(f"\nNumber of layers in the base image model ({params.base_image_model}): {len(base_img_model.layers)}")
            print_fn("Trainable variables: {}".format(len(base_img_model.trainable_variables)))
            print_fn("Shape of trainable variables at {}: {}".format(0, base_img_model.trainable_variables[0].shape))
            print_fn("Shape of trainable variables at {}: {}".format(-1, base_img_model.trainable_variables[-1].shape))

            print_fn("\nFreeze base model.")
            base_img_model.trainable = trainable  # Freeze the base_img_model
            print_fn("Trainable variables: {}".format(len(base_img_model.trainable_variables)))

            print_fn("\nPrint some layers")
            print_fn("Name of layer {}: {}".format(0, base_img_model.layers[0].name))
            print_fn("Name of layer {}: {}".format(-1, base_img_model.layers[-1].name))

            # training=False makes the base model run in inference mode so that
            # batch-norm statistics are not updated during the fine-tuning stage.
            # x_tile = base_img_model(tile_input_tensor)
            x_tile = base_img_model(tile_input_tensor, training=False)
            # x_tile = base_img_model(tile_input_tensor, training=trainable)
            model_inputs.append(tile_input_tensor)

            # x_tile = Dense(params.dense1_img, activation=tf.nn.relu, name="dense1_img")(x_tile)
            # x_tile = Dense(params.dense2_img, activation=tf.nn.relu, name="dense2_img")(x_tile)
            # x_tile = BatchNormalization(name="batchnorm_im")(x_tile)
            merge_inputs.append(x_tile)
            del tile_input_tensor, x_tile

        # Merge towers
        if len(merge_inputs) > 1:
            mm = layers.Concatenate(axis=1, name="merger")(merge_inputs)
        else:
            mm = merge_inputs[0]

        # Dense layers of the top classifier
        mm = Dense(params.dense1_top, activation=tf.nn.relu, name="dense1_top")(mm)
        # mm = BatchNormalization(name="batchnorm_top")(mm)
        # mm = Dropout(params.dropout1_top)(mm)

        # Output
        output = Dense(n_classes, activation=tf.nn.relu, name="logits")(mm)
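        # Under mixed precision, keep the final softmax in float32 (dtype="float32" below)
        # so the output probabilities stay numerically stable.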
        if from_logits is False:
            output = Activation(tf.nn.softmax, dtype="float32", name="softmax")(output)

        # Assemble final model
        model = Model(inputs=model_inputs, outputs=output)

        metrics = [
            tf.keras.metrics.SparseCategoricalAccuracy(name="CatAcc"),
            tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=from_logits, name="CatCrossEnt")
        ]

        if params.optimizer == "SGD":
            optimizer = optimizers.SGD(learning_rate=params.learning_rate, momentum=0.9, nesterov=True)
        elif params.optimizer == "Adam":
            optimizer = optimizers.Adam(learning_rate=params.learning_rate)

        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=from_logits)

        model.compile(loss=loss, optimizer=optimizer, metrics=metrics)


        # import ipdb; ipdb.set_trace()
        print_fn("\nBase model")
        base_img_model.summary(print_fn=print_fn)
        print_fn("\nFull model")
        model.summary(print_fn=print_fn)
        print_fn("Trainable variables: {}".format(len(model.trainable_variables)))

        print_fn(f"Train steps:      {tr_steps}")
        print_fn(f"Validation steps: {vl_steps}")

        # ------------
        # Train
        # ------------
        # import ipdb; ipdb.set_trace()
        # tr_steps = 10  # tr_tiles // params.batch_size // 15  # for debugging
        print_fn("\n{}".format(yellow("Train")))
        timer = Timer()
        history = model.fit(x=train_data,
                            validation_data=val_data,
                            steps_per_epoch=tr_steps,
                            validation_steps=vl_steps,
                            class_weight=class_weight,
                            epochs=params.epochs,
                            verbose=fit_verbose,
                            callbacks=callbacks)
        # del train_data, val_data
        timer.display_timer(print_fn)
        plot_prfrm_metrics(history, title="Train stage", name="tn", outdir=outdir)
        model = load_best_model(outdir)  # load best model

        # Save trained model
        print_fn("\nSave trained model.")
        model.save(outdir/"best_model_trained")

        create_tf_data_eval_kwargs.update({"tfrecords": test_tfr_files, "include_meta": True})
        test_data = create_tf_data(
            **create_tf_data_eval_kwargs,
            **parse_fn_non_train_kwargs
        )

        # Calc hits
        te_tile_preds = calc_tile_preds(test_data, model=model, outdir=outdir)
        te_tile_preds = te_tile_preds.sort_values(["image_id", "tile_id"], ascending=True)
        hits_tn = calc_hits(te_tile_preds, te_meta)
        hits_tn.to_csv(outdir/"hits_tn.csv", index=False)

        # ------------
        # Finetune
        # ------------
        # import ipdb; ipdb.set_trace()
        print_fn("\n{}".format(green("Finetune")))
        unfreeze_top_layers = 50
        # Unfreeze layers of the base model
        for layer in base_img_model.layers[-unfreeze_top_layers:]:
            layer.trainable = True
            print_fn("{}: (trainable={})".format(layer.name, layer.trainable))
        print_fn("Trainable variables: {}".format(len(model.trainable_variables)))

        model.compile(loss=loss,
                      optimizer=optimizers.Adam(learning_rate=params.learning_rate/10),
                      metrics=metrics)

        callbacks = keras_callbacks(outdir, monitor=monitor,
                                    save_best_only=params.save_best_only,
                                    patience=params.patience,
                                    name="finetune")

        total_epochs = history.epoch[-1] + params.finetune_epochs
        timer = Timer()
        history_fn = model.fit(x=train_data,
                               validation_data=val_data,
                               steps_per_epoch=tr_steps,
                               validation_steps=vl_steps,
                               class_weight=class_weight,
                               epochs=total_epochs,
                               initial_epoch=history.epoch[-1]+1,
                               verbose=fit_verbose,
                               callbacks=callbacks)
        del train_data, val_data
        plot_prfrm_metrics(history_fn, title="Finetune stage", name="fn", outdir=outdir)
        timer.display_timer(print_fn)

        # Save trained model
        print_fn("\nSave finetuned model.")
        model.save(outdir/"best_model_finetuned")
        base_img_model.save(outdir/"best_model_img_base_finetuned")


    if args.eval is True:

        print_fn("\n{}".format(bold("Test set predictions.")))
        timer = Timer()
        # calc_tf_preds(test_data, te_meta, model, outdir, args, name="test", print_fn=print_fn)
        # import ipdb; ipdb.set_trace()
        te_tile_preds = calc_tile_preds(test_data, model=model, outdir=outdir)
        te_tile_preds = te_tile_preds.sort_values(["image_id", "tile_id"], ascending=True)
        te_tile_preds.to_csv(outdir/"te_tile_preds.csv", index=False)
        # print(te_tile_preds[["image_id", "tile_id", "y_true", "y_pred_label", "prob"]][:20])
        # print(te_tile_preds.iloc[:20, 1:])
        del test_data

        # Calc hits
        hits_fn = calc_hits(te_tile_preds, te_meta)
        hits_fn.to_csv(outdir/"hits_fn.csv", index=False)

        # from sklearn.metrics import roc_curve, roc_auc_score, auc, average_precision_score
        # roc_auc = roc_auc_score(te_tile_preds["y_true"], te_tile_preds["prob"], average="macro")

        # import ipdb; ipdb.set_trace()
        roc_auc = {}
        import matplotlib.pyplot as plt
        from sklearn.metrics import roc_curve, auc
        fig, ax = plt.subplots(figsize=(8, 6))
        for true in range(n_classes):
            if true in te_tile_preds["y_true"].values:
                fpr, tpr, thresh = roc_curve(te_tile_preds["y_true"], te_tile_preds["prob"], pos_label=true)
                roc_auc[true] = auc(fpr, tpr)
                plt.plot(fpr, tpr, linestyle='--', label=f"Class {true} vs Rest")
            else:
                roc_auc[true] = None

        # plt.plot([0,0], [1,1], '--', label="Random")
        plt.title("Multiclass ROC Curve")
        plt.xlabel("FPR")
        plt.ylabel("TPR")
        plt.legend(loc="best")
        plt.savefig(outdir/"Multiclass ROC", dpi=70);

        # Average precision score
        from sklearn.metrics import average_precision_score
        y_true_vec = te_tile_preds.y_true.values
        y_true_onehot = np.zeros((y_true_vec.size, n_classes))
        y_true_onehot[np.arange(y_true_vec.size), y_true_vec] = 1
        y_probs = te_tile_preds[[c for c in te_tile_preds.columns if "prob_" in c]]
        print_fn("\nAvearge precision")
        print_fn("Micro    {}".format(average_precision_score(y_true_onehot, y_probs, average="micro")))
        print_fn("Macro    {}".format(average_precision_score(y_true_onehot, y_probs, average="macro")))
        print_fn("Wieghted {}".format(average_precision_score(y_true_onehot, y_probs, average="weighted")))
        print_fn("Samples  {}".format(average_precision_score(y_true_onehot, y_probs, average="samples")))


        # import ipdb; ipdb.set_trace()
        agg_method = "mean"
        # agg_by = "smp"
        agg_by = "image_id"
        smp_preds = agg_tile_preds(te_tile_preds, agg_by=agg_by, meta=te_meta, agg_method=agg_method)

        timer.display_timer(print_fn)

    lg.close_logger()
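
The TF-version guard used above (stable API in TF 2.4+ vs the experimental API in TF 2.3) can be pulled out into a small helper. A minimal standalone sketch, assuming TensorFlow 2.3 or newer; the helper name is illustrative and not part of the original script:

import tensorflow as tf

def set_mixed_precision_policy():
    minor = int(tf.__version__.split(".")[1])
    if minor >= 4:  # TF 2.4+: stable API
        from tensorflow.keras import mixed_precision
        mixed_precision.set_global_policy("mixed_float16")
        policy = mixed_precision.global_policy()
    else:  # TF 2.3: experimental API
        from tensorflow.keras.mixed_precision import experimental as mixed_precision
        policy = mixed_precision.Policy("mixed_float16")
        mixed_precision.set_policy(policy)
    print("Compute dtype:", policy.compute_dtype)
    print("Variable dtype:", policy.variable_dtype)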
Ejemplo n.º 25
0
def run_training(
    model_f,
    lr_f,
    name,
    epochs,
    batch_size,
    steps_per_epoch,
    vid_dir,
    edge_dir,
    train_vid_names,
    val_vid_names,
    frame_size,
    flow_map_size,
    interpolate_ratios,
    patch_size,
    overlap,
    edge_model_f,
    mixed_float=True,
    notebook=True,
    profile=False,
    edge_model_path=None,
    amodel_path=None,
    load_model_path=None,
):
    """
    patch_size, frame_size and flow_map_size are all in
        (WIDTH, HEIGHT) format
    """
    if ((edge_model_path is None) or (amodel_path is None))\
        and (load_model_path is None):
        raise ValueError('Need a path to load model')
    if mixed_float:
        policy = mixed_precision.Policy('mixed_float16')
        mixed_precision.set_global_policy(policy)

    st = time.time()

    a_model = anime_model(model_f, interpolate_ratios, flow_map_size)
    e_model = EdgeModel([patch_size[1], patch_size[0], 3], edge_model_f)

    if amodel_path is not None:
        a_model.load_weights(amodel_path).expect_partial()
        print('*' * 50)
        print(f'Anime model loaded from : {amodel_path}')
        print('*' * 50)

    if edge_model_path is not None:
        e_model.load_weights(edge_model_path).expect_partial()
        print('*' * 50)
        print(f'Edge model loaded from : {edge_model_path}')
        print('*' * 50)

    c_model = AnimeModelCyclic(
        a_model,
        e_model,
        (patch_size[1], patch_size[0]),
        overlap,
    )
    if load_model_path is not None:
        c_model.load_weights(load_model_path)
        print('*' * 50)
        print(f'Cyclic model loaded from : {load_model_path}')
        print('*' * 50)
    c_model.compile(optimizer='adam')

    logdir = 'logs/fit/' + name
    if profile:
        tensorboard_callback = tf.keras.callbacks.TensorBoard(
            log_dir=logdir,
            histogram_freq=1,
            profile_batch='3,5',
            update_freq='epoch')
    else:
        tensorboard_callback = tf.keras.callbacks.TensorBoard(
            log_dir=logdir,
            histogram_freq=1,
            profile_batch=0,
            update_freq='epoch')

    lr_callback = keras.callbacks.LearningRateScheduler(lr_f, verbose=1)

    savedir = 'savedmodels/' + name + '/{epoch}'
    save_callback = keras.callbacks.ModelCheckpoint(savedir,
                                                    save_weights_only=True,
                                                    verbose=1)

    if notebook:
        tqdm_callback = TqdmNotebookCallback(metrics=['loss'],
                                             leave_inner=False)
    else:
        tqdm_callback = TqdmCallback()

    train_ds = create_train_dataset(vid_dir,
                                    edge_dir,
                                    train_vid_names,
                                    frame_size,
                                    batch_size,
                                    parallel=6)
    val_ds = create_train_dataset(vid_dir,
                                  edge_dir,
                                  val_vid_names,
                                  frame_size,
                                  batch_size,
                                  val_data=True,
                                  parallel=4)

    image_callback = ValFigCallback(val_ds, logdir)

    c_model.fit(
        x=train_ds,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        callbacks=[
            tensorboard_callback,
            lr_callback,
            save_callback,
            tqdm_callback,
            image_callback,
        ],
        verbose=0,
        validation_data=val_ds,
        validation_steps=50,
    )

    delta = time.time() - st
    hours, remain = divmod(delta, 3600)
    minutes, seconds = divmod(remain, 60)
    print(
        f'Took {hours:.0f} hours {minutes:.0f} minutes {seconds:.2f} seconds')
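
Note that this example compiles with a plain optimizer='adam' while the mixed_float16 policy is active; in recent releases (TF 2.4+) Model.compile is documented to wrap the optimizer in a LossScaleOptimizer automatically when that policy is set. A minimal sketch to check this behaviour (illustrative toy model, not the AnimeModelCyclic above):

import tensorflow as tf
from tensorflow.keras import mixed_precision

mixed_precision.set_global_policy("mixed_float16")
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")
print(isinstance(model.optimizer, mixed_precision.LossScaleOptimizer))  # expected: True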
Ejemplo n.º 26
0
def main(args):

    print(args)

    if not isinstance(args.workers, int):
        args.workers = min(16, mp.cpu_count())

    # AMP
    if args.amp:
        mixed_precision.set_global_policy("mixed_float16")

    # Load doctr model
    model = recognition.__dict__[args.arch](
        pretrained=args.resume is None,
        input_shape=(args.input_size, 4 * args.input_size, 3),
        vocab=VOCABS[args.vocab],
    )

    # Resume weights
    if isinstance(args.resume, str):
        model.load_weights(args.resume)

    st = time.time()
    ds = datasets.__dict__[args.dataset](
        train=True,
        download=True,
        recognition_task=True,
        use_polygons=args.regular,
        img_transforms=T.Resize((args.input_size, 4 * args.input_size),
                                preserve_aspect_ratio=True),
    )

    _ds = datasets.__dict__[args.dataset](
        train=False,
        download=True,
        recognition_task=True,
        use_polygons=args.regular,
        img_transforms=T.Resize((args.input_size, 4 * args.input_size),
                                preserve_aspect_ratio=True),
    )
    ds.data.extend([(np_img, target) for np_img, target in _ds.data])

    test_loader = DataLoader(
        ds,
        batch_size=args.batch_size,
        drop_last=False,
        num_workers=args.workers,
        shuffle=False,
    )
    print(f"Test set loaded in {time.time() - st:.4}s ({len(ds)} samples in "
          f"{len(test_loader)} batches)")

    mean, std = model.cfg["mean"], model.cfg["std"]
    batch_transforms = T.Normalize(mean=mean, std=std)

    # Metrics
    val_metric = TextMatch()

    print("Running evaluation")
    val_loss, exact_match, partial_match = evaluate(model, test_loader,
                                                    batch_transforms,
                                                    val_metric)
    print(
        f"Validation loss: {val_loss:.6} (Exact: {exact_match:.2%} | Partial: {partial_match:.2%})"
    )
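
The AMP flag here is applied before the recognition model is instantiated, which matters: each Keras layer captures the global policy at construction time. A small sketch of the ordering (generic layers, not the doctr model):

import tensorflow as tf
from tensorflow.keras import mixed_precision

mixed_precision.set_global_policy("mixed_float16")  # 1) set the policy first
layer = tf.keras.layers.Dense(8)                    # 2) then build layers/models
print(layer.compute_dtype, layer.dtype)             # float16 float32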
Ejemplo n.º 27
0
def find_lr(config, outdir, figname, logscale):
    """Run the Learning Rate Finder to produce a batch loss vs. LR plot from
    which an appropriate LR-range can be determined"""
    config, _ = parse_config(config)

    # Decide tf.distribute.strategy depending on number of available GPUs
    strategy, num_gpus = get_strategy()

    ds_train, num_train_steps = get_datasets(config["train_test_datasets"],
                                             config, num_gpus, "train")

    with strategy.scope():
        opt = tf.keras.optimizers.Adam(
            learning_rate=1e-7
        )  # This learning rate will be changed by the lr_finder
        if config["setup"]["dtype"] == "float16":
            model_dtype = tf.dtypes.float16
            policy = mixed_precision.Policy("mixed_float16")
            mixed_precision.set_global_policy(policy)
            opt = mixed_precision.LossScaleOptimizer(opt)
        else:
            model_dtype = tf.dtypes.float32

        model = make_model(config, model_dtype)
        config = set_config_loss(config, config["setup"]["trainable"])

        # Run model once to build the layers
        model.build((1, config["dataset"]["padded_num_elem_size"],
                     config["dataset"]["num_input_features"]))

        configure_model_weights(model, config["setup"]["trainable"])

        loss_dict, loss_weights = get_loss_dict(config)
        model.compile(
            loss=loss_dict,
            optimizer=opt,
            sample_weight_mode="temporal",
            loss_weights=loss_weights,
            metrics={
                "cls": [
                    FlattenedCategoricalAccuracy(name="acc_unweighted",
                                                 dtype=tf.float64),
                    FlattenedCategoricalAccuracy(use_weights=True,
                                                 name="acc_weighted",
                                                 dtype=tf.float64),
                ]
            },
        )
        model.summary()

        max_steps = 200
        lr_finder = LRFinder(max_steps=max_steps)
        callbacks = [lr_finder]

        model.fit(
            ds_train.repeat(),
            epochs=max_steps,
            callbacks=callbacks,
            steps_per_epoch=1,
        )

        lr_finder.plot(save_dir=outdir, figname=figname, log_scale=logscale)
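
The LossScaleOptimizer wrapper above is what protects float16 gradients from underflow. Under model.fit this is handled transparently; in a custom training loop the scaling has to be done by hand. A sketch under that assumption (model and loss_fn are placeholders, not from the example):

import tensorflow as tf
from tensorflow.keras import mixed_precision

opt = mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(1e-3))

@tf.function
def train_step(model, x, y, loss_fn):
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)
        loss = loss_fn(y, y_pred)
        scaled_loss = opt.get_scaled_loss(loss)        # scale up to avoid float16 underflow
    scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
    grads = opt.get_unscaled_gradients(scaled_grads)   # undo the scaling before applying
    opt.apply_gradients(zip(grads, model.trainable_variables))
    return loss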
Ejemplo n.º 28
0
def initialization():
    global logging_level

    parser = argparse.ArgumentParser(
        prog=os.path.basename(sys.argv[0]),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description=__doc__)

    groupM = parser.add_argument_group("Mandatory")
    groupM.add_argument(
        '-m',
        '--model_dir',
        type=check_dir,
        required=True,
        help=
        "Model directory, metadata, classifier and SentencePiece models will be saved in the same directory"
    )
    groupM.add_argument('-s',
                        '--source_lang',
                        required=True,
                        help="Source language")
    groupM.add_argument('-t',
                        '--target_lang',
                        required=True,
                        help="Target language")
    groupM.add_argument(
        '--mono_train',
        type=argparse.FileType('r'),
        default=None,
        required=False,
        help=
        "File containing monolingual sentences of both languages shuffled together, used to train SentencePiece embeddings. Not required for XLMR."
    )
    groupM.add_argument(
        '--parallel_train',
        type=argparse.FileType('r'),
        default=None,
        required=True,
        help="TSV file containing parallel sentences to train the classifier")
    groupM.add_argument(
        '--parallel_valid',
        type=argparse.FileType('r'),
        default=None,
        required=True,
        help="TSV file containing parallel sentences for validation")

    groupO = parser.add_argument_group('Options')
    groupO.add_argument('-S',
                        '--source_tokenizer_command',
                        help="Source language tokenizer full command")
    groupO.add_argument('-T',
                        '--target_tokenizer_command',
                        help="Target language tokenizer full command")
    #groupO.add_argument('-f', '--source_word_freqs', type=argparse.FileType('r'), default=None, required=False, help="L language gzipped list of word frequencies")
    groupO.add_argument(
        '-F',
        '--target_word_freqs',
        type=argparse.FileType('r'),
        default=None,
        required=False,
        help=
        "R language gzipped list of word frequencies (needed for frequence based noise)"
    )
    groupO.add_argument(
        '--block_size',
        type=check_positive,
        default=10000,
        help=
        "Sentence pairs per block when apliying multiprocessing in the noise function"
    )
    groupO.add_argument('-p',
                        '--processes',
                        type=check_positive,
                        default=max(1,
                                    cpu_count() - 1),
                        help="Number of process to use")
    groupO.add_argument(
        '-g',
        '--gpu',
        type=check_positive_or_zero,
        help=
        "Which GPU use, starting from 0. Will set the CUDA_VISIBLE_DEVICES.")
    groupO.add_argument('--mixed_precision',
                        action='store_true',
                        default=False,
                        help="Use mixed precision float16 for training")
    groupO.add_argument(
        '--save_train',
        type=str,
        default=None,
        help=
        "Save the generated training dataset into a file. If the file already exists the training dataset will be loaded from there."
    )
    groupO.add_argument(
        '--save_valid',
        type=str,
        default=None,
        help=
        "Save the generated validation dataset into a file. If the file already exists the validation dataset will be loaded from there."
    )
    groupO.add_argument(
        '--distilled',
        action='store_true',
        help=
        'Enable Knowledge Distillation training. It needs a pre-built training set with raw scores from a teacher model.'
    )
    groupO.add_argument(
        '--seed',
        default=None,
        type=int,
        help="Seed for random number generation. By default, no seeed is used."
    )

    # Classifier training options
    groupO.add_argument('--classifier_type',
                        choices=model_classes.keys(),
                        default="dec_attention",
                        help="Neural network architecture of the classifier")
    groupO.add_argument(
        '--batch_size',
        type=check_positive,
        default=None,
        help=
        "Batch size during classifier training. If None, default architecture value will be used."
    )
    groupO.add_argument(
        '--steps_per_epoch',
        type=check_positive,
        default=None,
        help=
        "Number of batch updates per epoch during training. If None, default architecture value will be used or the full dataset size."
    )
    groupO.add_argument(
        '--epochs',
        type=check_positive,
        default=None,
        help=
        "Number of epochs for training. If None, default architecture value will be used."
    )
    groupO.add_argument(
        '--patience',
        type=check_positive,
        default=None,
        help=
        "Stop training when validation has stopped improving after PATIENCE number of epochs"
    )

    # Negative sampling options
    groupO.add_argument(
        '--pos_ratio',
        default=1,
        type=int,
        help=
        "Ratio of positive samples used to oversample on validation and test sets"
    )
    groupO.add_argument('--rand_ratio',
                        default=3,
                        type=int,
                        help="Ratio of negative samples misaligned randomly")
    groupO.add_argument(
        '--womit_ratio',
        default=3,
        type=int,
        help="Ratio of negative samples misaligned by randomly omitting words")
    groupO.add_argument(
        '--freq_ratio',
        default=3,
        type=int,
        help=
        "Ratio of negative samples misaligned by replacing words by frequence (needs --target_word_freq)"
    )
    groupO.add_argument(
        '--fuzzy_ratio',
        default=0,
        type=int,
        help="Ratio of negative samples misaligned by fuzzy matching")
    groupO.add_argument(
        '--neighbour_mix',
        default=False,
        type=bool,
        help="If use negative samples misaligned by neighbourhood")

    # P**n removal training options
    groupO.add_argument(
        '--porn_removal_train',
        type=argparse.FileType('r'),
        help=
        "File with training dataset for FastText classifier. Each sentence must contain at the beginning the '__label__negative' or '__label__positive' according to FastText convention. It should be lowercased and tokenized."
    )
    groupO.add_argument(
        '--porn_removal_test',
        type=argparse.FileType('r'),
        help=
        "Test set to compute precision and accuracy of the p**n removal classifier"
    )
    groupO.add_argument('--porn_removal_file',
                        type=str,
                        default="porn_removal.bin",
                        help="P**n removal classifier output file")
    groupO.add_argument(
        '--porn_removal_side',
        choices=['sl', 'tl'],
        default="sl",
        help=
        "Whether the p**n removal should be applied at the source or at the target language."
    )

    # LM fluency filter training options
    groupO.add_argument(
        '--noisy_examples_file_sl',
        type=str,
        help=
        "File with noisy text in the SL. These are used to estimate the perplexity of noisy text."
    )
    groupO.add_argument(
        '--noisy_examples_file_tl',
        type=str,
        help=
        "File with noisy text in the TL. These are used to estimate the perplexity of noisy text."
    )
    groupO.add_argument(
        '--lm_dev_size',
        type=check_positive_or_zero,
        default=2000,
        help=
        "Number of sentences to be removed from clean text before training LMs. These are used to estimate the perplexity of clean text."
    )
    groupO.add_argument('--lm_file_sl',
                        type=str,
                        help="SL language model output file.")
    groupO.add_argument('--lm_file_tl',
                        type=str,
                        help="TL language model output file.")
    groupO.add_argument(
        '--lm_training_file_sl',
        type=str,
        help=
        "SL text from which the SL LM is trained. If this parameter is not specified, SL LM is trained from the SL side of the input file, after removing --lm_dev_size sentences."
    )
    groupO.add_argument(
        '--lm_training_file_tl',
        type=str,
        help=
        "TL text from which the TL LM is trained. If this parameter is not specified, TL LM is trained from the TL side of the input file, after removing --lm_dev_size sentences."
    )
    groupO.add_argument(
        '--lm_clean_examples_file_sl',
        type=str,
        help=
        "File with clean text in the SL. Used to estimate the perplexity of clean text. This option must be used together with --lm_training_file_sl and both files must not have common sentences. This option replaces --lm_dev_size."
    )
    groupO.add_argument(
        '--lm_clean_examples_file_tl',
        type=str,
        help=
        "File with clean text in the TL. Used to estimate the perplexity of clean text. This option must be used together with --lm_training_file_tl and both files must not have common sentences. This option replaces --lm_dev_size."
    )

    groupL = parser.add_argument_group('Logging')
    groupL.add_argument('-q',
                        '--quiet',
                        action='store_true',
                        help='Silent logging mode')
    groupL.add_argument('--debug',
                        action='store_true',
                        help='Debug logging mode')
    groupL.add_argument('--logfile',
                        type=argparse.FileType('a'),
                        default=sys.stderr,
                        help="Store log to a file")

    args = parser.parse_args()

    if args.freq_ratio > 0 and args.target_word_freqs is None:
        raise Exception(
            "Frequence based noise needs target language word frequencies")
    if args.mono_train is None and args.classifier_type != 'xlmr':
        raise Exception(
            "Argument --mono_train not found, required when not training XLMR classifier"
        )

    if args.seed is not None:
        np.random.seed(args.seed)
        random.seed(args.seed)
        os.environ["PYTHONHASHSEED"] = str(args.seed)
        tf.random.set_seed(args.seed)

    if args.gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    elif "CUDA_VISIBLE_DEVICES" not in os.environ or os.environ[
            "CUDA_VISIBLE_DEVICES"] == "":
        import psutil
        cpus = psutil.cpu_count(logical=False)
        # Set number of threads for CPU training
        tf.config.threading.set_intra_op_parallelism_threads(
            min(cpus, args.processes))
        tf.config.threading.set_inter_op_parallelism_threads(
            min(2, args.processes))

    if args.mixed_precision:
        from tensorflow.keras import mixed_precision
        mixed_precision.set_global_policy('mixed_float16')

    # Remove trailing / in model dir
    args.model_dir = args.model_dir.rstrip('/')

    # If the model files are basenames, prepend model path
    if args.lm_file_sl and args.lm_file_sl.count('/') == 0:
        args.lm_file_sl = args.model_dir + '/' + args.lm_file_sl
    if args.lm_file_tl and args.lm_file_tl.count('/') == 0:
        args.lm_file_tl = args.model_dir + '/' + args.lm_file_tl
    if args.porn_removal_file and args.porn_removal_file.count('/') == 0:
        args.porn_removal_file = args.model_dir + '/' + args.porn_removal_file

    # Logging
    logging_setup(args)
    logging_level = logging.getLogger().level
    if logging_level < logging.INFO:
        tf.get_logger().setLevel('INFO')
    else:
        tf.get_logger().setLevel('CRITICAL')

    return args
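
Here --mixed_precision is a plain boolean flag and does not check what hardware is available. A hedged sketch of a stricter guard (an assumption, not part of the original script): only enable mixed_float16 when a GPU is present, since it mainly pays off on Tensor Core GPUs (compute capability 7.0+).

import tensorflow as tf

def maybe_enable_mixed_precision(requested):
    if requested and tf.config.list_physical_devices("GPU"):
        from tensorflow.keras import mixed_precision
        mixed_precision.set_global_policy("mixed_float16")
        return True
    return False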
Ejemplo n.º 29
0
def model_scope(config, total_steps, weights, horovod_enabled=False):
    lr_schedule, optim_callbacks, lr = get_lr_schedule(config,
                                                       steps=total_steps)
    opt = get_optimizer(config, lr_schedule)

    if config["setup"]["dtype"] == "float16":
        model_dtype = tf.dtypes.float16
        policy = mixed_precision.Policy("mixed_float16")
        mixed_precision.set_global_policy(policy)
        opt = mixed_precision.LossScaleOptimizer(opt)
    else:
        model_dtype = tf.dtypes.float32

    model = make_model(config, model_dtype)

    # Build the layers after the element and feature dimensions are specified
    model.build((1, config["dataset"]["padded_num_elem_size"],
                 config["dataset"]["num_input_features"]))

    initial_epoch = 0
    loaded_opt = None

    if weights:
        if lr_schedule:
            raise Exception(
                "Restoring the optimizer state with a learning rate schedule is currently not supported"
            )

        # We need to load the weights in the same trainable configuration as the model was set up
        configure_model_weights(model,
                                config["setup"].get("weights_config", "all"))
        model.load_weights(weights, by_name=True)
        opt_weight_file = weights.replace("hdf5",
                                          "pkl").replace("/weights-", "/opt-")
        if os.path.isfile(opt_weight_file):
            loaded_opt = pickle.load(open(opt_weight_file, "rb"))

        initial_epoch = int(weights.split("/")[-1].split("-")[1])
    model.build((1, config["dataset"]["padded_num_elem_size"],
                 config["dataset"]["num_input_features"]))

    config = set_config_loss(config, config["setup"]["trainable"])
    configure_model_weights(model, config["setup"]["trainable"])
    model.build((1, config["dataset"]["padded_num_elem_size"],
                 config["dataset"]["num_input_features"]))

    print("model weights")
    tw_names = [m.name for m in model.trainable_weights]
    for w in model.weights:
        print("layer={} trainable={} shape={} num_weights={}".format(
            w.name, w.name in tw_names, w.shape, np.prod(w.shape)))

    loss_dict, loss_weights = get_loss_dict(config)

    model.compile(
        loss=loss_dict,
        optimizer=opt,
        sample_weight_mode="temporal",
        loss_weights=loss_weights,
        metrics={
            "cls": [
                FlattenedCategoricalAccuracy(name="acc_unweighted",
                                             dtype=tf.float64),
                FlattenedCategoricalAccuracy(
                    use_weights=True, name="acc_weighted", dtype=tf.float64),
            ] + [
                SingleClassRecall(
                    icls, name="rec_cls{}".format(icls), dtype=tf.float64)
                for icls in range(config["dataset"]["num_output_classes"])
            ]
        },
    )

    model.summary()

    # Set the optimizer weights
    if loaded_opt:

        def model_weight_setting():
            grad_vars = model.trainable_weights
            zero_grads = [tf.zeros_like(w) for w in grad_vars]
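            # Applying one batch of zero gradients forces the optimizer to create its slot
            # variables, so their shapes match when restoring with set_weights() below.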
            model.optimizer.apply_gradients(zip(zero_grads, grad_vars))
            if model.optimizer.__class__.__module__ == "keras.optimizers.optimizer_v1":
                model.optimizer.optimizer.optimizer.set_weights(
                    loaded_opt["weights"])
            else:
                model.optimizer.set_weights(loaded_opt["weights"])

        # FIXME: check that this still works with multiple GPUs
        strategy = tf.distribute.get_strategy()
        strategy.run(model_weight_setting)

    return model, optim_callbacks, initial_epoch
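
The optimizer state restored above comes from an "opt-*.pkl" file whose writer is not shown. A sketch of what the saving side could look like (an assumption; the file naming mirrors the replace() calls above):

import pickle

def save_optimizer_state(model, weights_path):
    opt_path = weights_path.replace("hdf5", "pkl").replace("/weights-", "/opt-")
    with open(opt_path, "wb") as f:
        pickle.dump({"weights": model.optimizer.get_weights()}, f)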
Ejemplo n.º 30
0
def enable_amp():
    mixed_precision.set_global_policy("mixed_float16")
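
A short usage sketch for this helper (illustrative): call it once, before any layers are created, and every subsequent Keras layer computes in float16 while keeping its variables in float32.

from tensorflow.keras import mixed_precision

enable_amp()
policy = mixed_precision.global_policy()
print(policy.name, policy.compute_dtype, policy.variable_dtype)  # mixed_float16 float16 float32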