def get_hardware_strategy(mixed_f16=False):
    """Pick a distribution strategy for the available hardware.

    A TPU cluster is resolved first; if that raises ``ValueError`` the
    function falls back to TensorFlow's default strategy.

    Args:
        mixed_f16: When true, also set the global Keras mixed-precision
            policy: ``'mixed_bfloat16'`` on TPU, ``'mixed_float16'`` otherwise.

    Returns:
        A ``(tpu, strategy)`` tuple. ``tpu`` is the cluster resolver, or
        ``None`` when no TPU could be resolved.
    """
    try:
        # TPU detection. No parameters necessary if TPU_NAME environment
        # variable is set: this is always the case on Kaggle.
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        print('Running on TPU ', tpu.master())
    except ValueError:
        tpu = None

    if tpu:
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        # Non-experimental endpoint; the `experimental` alias is deprecated.
        strategy = tf.distribute.TPUStrategy(tpu)
        policy_name = 'mixed_bfloat16'
    else:
        # Default distribution strategy in Tensorflow. Works on CPU and
        # single GPU.
        strategy = tf.distribute.get_strategy()
        policy_name = 'mixed_float16'

    if mixed_f16:
        policy = mixed_precision.Policy(policy_name)
        mixed_precision.set_global_policy(policy)

    print("REPLICAS: ", strategy.num_replicas_in_sync)
    return tpu, strategy
Exemple #2
0
def evaluate(config, train_dir, weights, customize, nevents):
    """Run the evaluation of a trained model on every validation dataset.

    When ``config`` is None the ``config.yaml`` inside ``train_dir`` is used.
    When ``weights`` is falsy the best checkpoint found in ``train_dir`` is
    loaded instead. Results are written below
    ``<train_dir>/evaluation/epoch_<N>/<dataset name>`` and the model is
    frozen afterwards.
    """
    if config is None:
        config = Path(train_dir) / "config.yaml"
        assert config.exists(
        ), "Could not find config file in train_dir, please provide one with -c <path/to/config>"
    config, _ = parse_config(config, weights=weights)

    if customize:
        customize_fn = customization_functions[customize]
        config = customize_fn(config)

    half_precision = config["setup"]["dtype"] == "float16"
    model_dtype = tf.dtypes.float16 if half_precision else tf.dtypes.float32
    if half_precision:
        mixed_precision.set_global_policy(
            mixed_precision.Policy("mixed_float16"))

    _, num_gpus = get_strategy()
    # physical_devices = tf.config.list_physical_devices('GPU')
    # for dev in physical_devices:
    #    tf.config.experimental.set_memory_growth(dev, True)

    dataset_cfg = config["dataset"]
    model = make_model(config, model_dtype)
    model.build((1, dataset_cfg["padded_num_elem_size"],
                 dataset_cfg["num_input_features"]))

    # the weights must be loaded in the same trainable configuration that
    # the model was set up with
    weights_config = config["setup"].get("weights_config", "all")
    configure_model_weights(model, weights_config)
    if not weights:
        weights = get_best_checkpoint(train_dir)
        print(
            "Loading best weights that could be found from {}".format(weights))
    model.load_weights(weights, by_name=True)

    # epoch number is the second "-"-separated field of the checkpoint name
    checkpoint_name = weights.split("/")[-1]
    iepoch = int(checkpoint_name.split("-")[1])

    for dsname in config["validation_datasets"]:
        ds_test, _ = get_heptfds_dataset(dsname,
                                         config,
                                         num_gpus,
                                         "test",
                                         supervised=False)
        if nevents:
            ds_test = ds_test.take(nevents)
        ds_test = ds_test.batch(5)

        eval_path = (Path(train_dir) / "evaluation" /
                     "epoch_{}".format(iepoch) / dsname)
        eval_dir = str(eval_path)
        Path(eval_dir).mkdir(parents=True, exist_ok=True)
        eval_model(model, ds_test, config, eval_dir)

    freeze_model(model, config, train_dir)
Exemple #3
0
def set_mixed_precision():
    """Enable the global ``mixed_float16`` Keras policy.

    TensorFlow releases before 2.4 go through the experimental
    mixed-precision API; later releases use the module-level
    ``mixed_precision`` API. The resulting compute and variable dtypes are
    logged.
    """
    # Compare (major, minor) as integers. Concatenating the digits of the
    # version string mis-orders versions such as "1.15.5" (-> 1155) and
    # crashes on suffixes such as "2.5.0-rc0".
    major, minor = (int(part) for part in str(tf.__version__).split('.')[:2])
    if (major, minor) < (2, 4):
        from tensorflow.keras.mixed_precision.experimental import Policy, set_policy
        policy = Policy('mixed_float16')
        set_policy(policy)
    else:
        policy = mixed_precision.Policy('mixed_float16')
        mixed_precision.set_global_policy(policy)
    log.info(
        f' Compute dtype: {policy.compute_dtype}, variable dtype: {policy.variable_dtype}'
    )
Exemple #4
0
    def _set_precision(calculation_dtype, calculation_epsilon):
        """Apply the requested float precision to the Keras backend.

        Sets the backend float type and epsilon, and additionally switches
        on the global ``mixed_float16`` policy when the dtype name contains
        ``"float16"``.
        """
        import tensorflow.keras.backend as backend

        # single/half/double precision
        backend.set_floatx(calculation_dtype)
        backend.set_epsilon(calculation_epsilon)

        if "float16" not in calculation_dtype:
            return

        # half precision requested: enable mixed precision as well
        import tensorflow.keras.mixed_precision as mp
        mp.set_global_policy(mp.Policy("mixed_float16"))
Exemple #5
0
def init_nerf_model(D=8,
                    W=256,
                    input_ch=3,
                    input_ch_views=3,
                    output_ch=4,
                    skips=(4,),
                    use_viewdirs=False):
    """Build the NeRF MLP as a Keras functional model.

    The single model input has ``input_ch + input_ch_views`` features and is
    split into a point part and a view part. The point part goes through
    ``D`` Dense+ReLU layers of width ``W``.

    Args:
        D: Number of Dense+ReLU layers in the trunk.
        W: Number of units in each trunk layer.
        input_ch: Number of point features (cast to int).
        input_ch_views: Number of view features (cast to int).
        output_ch: Units of the output layer when ``use_viewdirs`` is false.
        skips: Trunk layer indices after which the point input is
            concatenated to the activations again.
        use_viewdirs: When true, an alpha head (1 unit) and a 256-unit
            bottleneck are computed from the trunk; the bottleneck is
            concatenated with the view features, passed through one
            ``W // 2`` Dense+ReLU layer and a 3-unit Dense layer, and the
            result is concatenated with alpha.

    Returns:
        A ``tf.keras.Model`` mapping the combined input to the outputs
        described above. Every Activation layer is created with a float32
        dtype.
    """

    print('MODEL', input_ch, input_ch_views, type(input_ch),
          type(input_ch_views), use_viewdirs)
    input_ch = int(input_ch)
    input_ch_views = int(input_ch_views)

    # Explicit 1-tuple: `(a + b)` without the trailing comma is just an int.
    inputs = tf.keras.Input(shape=(input_ch + input_ch_views,))
    inputs_pts, inputs_views = tf.split(inputs, [input_ch, input_ch_views], -1)
    inputs_pts.set_shape([None, input_ch])
    inputs_views.set_shape([None, input_ch_views])

    outputs = inputs_pts
    for i in range(D):
        outputs = tf.keras.layers.Dense(W)(outputs)
        outputs = tf.keras.layers.Activation(
            'relu', dtype=mixed_precision.Policy('float32'))(outputs)
        if i in skips:
            # skip connection: feed the raw point features back in
            outputs = tf.concat([inputs_pts, outputs], -1)

    if use_viewdirs:
        alpha_out = tf.keras.layers.Dense(1)(outputs)
        alpha_out = tf.keras.layers.Activation(None,
                                               dtype='float32')(alpha_out)
        bottleneck = tf.keras.layers.Dense(256)(outputs)
        bottleneck = tf.keras.layers.Activation(None,
                                                dtype='float32')(bottleneck)
        inputs_viewdirs = tf.concat([bottleneck, inputs_views],
                                    -1)  # concat viewdirs
        outputs = inputs_viewdirs
        # The supplement to the paper states there are 4 hidden layers here, but this is an error since
        # the experiments were actually run with 1 hidden layer, so we will leave it as 1.
        for i in range(1):
            outputs = tf.keras.layers.Dense(W // 2)(outputs)
            outputs = tf.keras.layers.Activation('relu',
                                                 dtype='float32')(outputs)
        outputs = tf.keras.layers.Dense(3)(outputs)
        outputs = tf.keras.layers.Activation(None, dtype='float32')(outputs)
        outputs = tf.concat([outputs, alpha_out], -1)
    else:
        outputs = tf.keras.layers.Dense(output_ch)(outputs)
        outputs = tf.keras.layers.Activation(None, dtype='float32')(outputs)

    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    return model
Exemple #6
0
def evaluate(config, train_dir, weights, evaluation_dir):
    """Evaluate the trained model in train_dir.

    Args:
        config: Path to the configuration file, or None to use
            ``<train_dir>/config.yaml``.
        train_dir: Training directory; used for the default config, the
            default evaluation directory and the best-checkpoint lookup.
        weights: Weights file to load. When falsy, the best checkpoint found
            in ``train_dir`` is loaded instead.
        evaluation_dir: Output directory for the evaluation, or None to use
            ``<train_dir>/evaluation``. It is created if missing.
    """
    if config is None:
        config = Path(train_dir) / "config.yaml"
        assert config.exists(
        ), "Could not find config file in train_dir, please provide one with -c <path/to/config>"
    config, _ = parse_config(config, weights=weights)

    if evaluation_dir is None:
        eval_dir = str(Path(train_dir) / "evaluation")
    else:
        eval_dir = evaluation_dir

    Path(eval_dir).mkdir(parents=True, exist_ok=True)

    if config["setup"]["dtype"] == "float16":
        model_dtype = tf.dtypes.float16
        policy = mixed_precision.Policy("mixed_float16")
        mixed_precision.set_global_policy(policy)
        # No optimizer is involved in evaluation, so there is nothing to wrap
        # in a LossScaleOptimizer here (the previous wrapping referenced an
        # unassigned local and raised UnboundLocalError).
    else:
        model_dtype = tf.dtypes.float32

    strategy, num_gpus = get_strategy()
    ds_test, _ = get_heptfds_dataset(config["validation_dataset"], config,
                                     num_gpus, "test")
    ds_test = ds_test.batch(5)

    model = make_model(config, model_dtype)
    model.build((1, config["dataset"]["padded_num_elem_size"],
                 config["dataset"]["num_input_features"]))

    # need to load the weights in the same trainable configuration as the model was set up
    configure_model_weights(model, config["setup"].get("weights_config",
                                                       "all"))
    if not weights:
        weights = get_best_checkpoint(train_dir)
        print(
            "Loading best weights that could be found from {}".format(weights))
    model.load_weights(weights, by_name=True)

    eval_model(model, ds_test, config, eval_dir)
    freeze_model(model, config, ds_test.take(1), train_dir)
Exemple #7
0
def setup():
    """Prepare the output directory, distribution strategy and dtype policy.

    The directory ``out/<FLAGS.loss>-<FLAGS.disc_model>`` is wiped and
    recreated. The strategy is chosen from ``FLAGS.strategy`` (``'tpu'``,
    ``'multi_cpu'`` or anything else for the default strategy), and
    ``FLAGS.policy`` is installed as the global mixed-precision policy.

    Returns:
        A ``(strategy, loss_dir)`` tuple.
    """
    # Make base dir
    loss_dir = f'out/{FLAGS.loss}-{FLAGS.disc_model}'
    shutil.rmtree(loss_dir, ignore_errors=True)
    # makedirs (not mkdir) so that a missing parent `out/` is created too
    os.makedirs(loss_dir)

    if FLAGS.strategy == 'tpu':
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)
    elif FLAGS.strategy == 'multi_cpu':
        strategy = tf.distribute.MirroredStrategy(['CPU:0', 'CPU:1'])
    else:
        strategy = tf.distribute.get_strategy()

    # Policy
    policy = mixed_precision.Policy(FLAGS.policy)
    mixed_precision.set_global_policy(policy)

    return strategy, loss_dir
Exemple #8
0
def setup(args):
    """Configure logging, output directory, strategy and dtype policy.

    ``args`` is updated in place: ``args.out`` receives the output directory,
    and ``args.views`` / ``args.with_batch_sims`` receive the dataset
    settings. Unless ``args.load`` is set, any previous content of the output
    directory is removed.

    Returns:
        The selected ``tf.distribute`` strategy.
    """
    # Logging
    logging.set_verbosity(args.log_level.upper())

    # Output directory
    run_name = f'{args.backbone}-{args.feat_norm}'
    args.out = os.path.join(args.base_dir, args.loss, args.data_id, run_name)
    logging.info(f"out directory: '{args.out}'")
    if not args.load:
        if args.out.startswith('gs://'):
            # bucket path: delete everything below it with gsutil
            os.system(f"gsutil -m rm {os.path.join(args.out, '**')}")
        else:
            if os.path.exists(args.out):
                shutil.rmtree(args.out)
            os.makedirs(args.out)
        logging.info(f"cleared any previous work in '{args.out}'")

    # Strategy: TPU first, then multi-GPU, then the two-CPU mirror, then default
    if args.tpu:
        tpu_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu_resolver)
        tf.tpu.experimental.initialize_tpu_system(tpu_resolver)
        strategy = tf.distribute.TPUStrategy(tpu_resolver)
    elif len(tf.config.list_physical_devices('GPU')) > 1:
        strategy = tf.distribute.MirroredStrategy()
    elif args.multi_cpu:
        strategy = tf.distribute.MirroredStrategy(['CPU:0', 'CPU:1'])
    else:
        strategy = tf.distribute.get_strategy()

    # Mixed precision
    mixed_precision.set_global_policy(mixed_precision.Policy(args.policy))

    # Dataset arguments
    args.views = ['image', 'image2']
    args.with_batch_sims = True

    return strategy
Exemple #9
0
def configure_precision(precision=16):
    """Enable the global ``mixed_float16`` policy when ``precision`` is 16.

    Any other value leaves the current policy untouched.
    """
    if precision != 16:
        return
    mixed_precision.set_global_policy(mixed_precision.Policy('mixed_float16'))
Exemple #10
0
def run_training(
    model_f,
    lr_f,
    name,
    epochs,
    batch_size,
    steps_per_epoch,
    vid_dir,
    edge_dir,
    train_vid_names,
    val_vid_names,
    frame_size,
    flow_map_size,
    interpolate_ratios,
    patch_size,
    overlap,
    edge_model_f,
    mixed_float=True,
    notebook=True,
    profile=False,
    edge_model_path=None,
    amodel_path=None,
    load_model_path=None,
):
    """Build the cyclic anime model, load weights and fit it.

    ``patch_size``, ``frame_size`` and ``flow_map_size`` are all given in
    (WIDTH, HEIGHT) format.

    Weights must come either from both ``edge_model_path`` and
    ``amodel_path`` or from ``load_model_path``; otherwise a ``ValueError``
    is raised. TensorBoard logs go to ``logs/fit/<name>`` and per-epoch
    weight checkpoints to ``savedmodels/<name>/<epoch>``. The elapsed time
    is printed when fitting ends.
    """
    sub_model_missing = edge_model_path is None or amodel_path is None
    if sub_model_missing and load_model_path is None:
        raise ValueError('Need a path to load model')
    if mixed_float:
        mixed_precision.set_global_policy(
            mixed_precision.Policy('mixed_float16'))

    st = time.time()
    banner = '*' * 50

    a_model = anime_model(model_f, interpolate_ratios, flow_map_size)
    # model shapes are (HEIGHT, WIDTH), hence the swapped patch_size indices
    patch_hw = (patch_size[1], patch_size[0])
    e_model = EdgeModel([patch_hw[0], patch_hw[1], 3], edge_model_f)

    if amodel_path is not None:
        a_model.load_weights(amodel_path).expect_partial()
        print(banner)
        print(f'Anime model loaded from : {amodel_path}')
        print(banner)

    if edge_model_path is not None:
        e_model.load_weights(edge_model_path).expect_partial()
        print(banner)
        print(f'Edge model loaded from : {edge_model_path}')
        print(banner)

    c_model = AnimeModelCyclic(a_model, e_model, patch_hw, overlap)
    if load_model_path is not None:
        c_model.load_weights(load_model_path)
        print(banner)
        print(f'Cyclic model loaded from : {load_model_path}')
        print(banner)
    c_model.compile(optimizer='adam')

    logdir = 'logs/fit/' + name
    # batches 3..5 are profiled on request, otherwise profiling is off
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=logdir,
        histogram_freq=1,
        profile_batch='3,5' if profile else 0,
        update_freq='epoch')

    lr_callback = keras.callbacks.LearningRateScheduler(lr_f, verbose=1)

    savedir = 'savedmodels/' + name + '/{epoch}'
    save_callback = keras.callbacks.ModelCheckpoint(savedir,
                                                    save_weights_only=True,
                                                    verbose=1)

    if notebook:
        tqdm_callback = TqdmNotebookCallback(metrics=['loss'],
                                             leave_inner=False)
    else:
        tqdm_callback = TqdmCallback()

    train_ds = create_train_dataset(vid_dir,
                                    edge_dir,
                                    train_vid_names,
                                    frame_size,
                                    batch_size,
                                    parallel=6)
    val_ds = create_train_dataset(vid_dir,
                                  edge_dir,
                                  val_vid_names,
                                  frame_size,
                                  batch_size,
                                  val_data=True,
                                  parallel=4)

    image_callback = ValFigCallback(val_ds, logdir)

    fit_callbacks = [
        tensorboard_callback,
        lr_callback,
        save_callback,
        tqdm_callback,
        image_callback,
    ]
    c_model.fit(
        x=train_ds,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        callbacks=fit_callbacks,
        verbose=0,
        validation_data=val_ds,
        validation_steps=50,
    )

    delta = time.time() - st
    hours, remain = divmod(delta, 3600)
    minutes, seconds = divmod(remain, 60)
    print(
        f'Took {hours:.0f} hours {minutes:.0f} minutes {seconds:.2f} seconds')
Exemple #11
0
te_meta.to_csv(outdir/"te_meta.csv", index=False)

# Onehot encoding
# ydata = data[args.target[0]].values
# y_onehot = pd.get_dummies(ydata)
# ydata_label = np.argmax(y_onehot.values, axis=1)
# num_classes = len(np.unique(ydata_label))

callbacks = keras_callbacks(outdir, monitor="val_loss")

# Mixed precision
if params.use_fp16:
    print("Train with mixed precision")
    # Detect the available API instead of matching exact version numbers:
    # with a version check for 3 or 4 only, any other Keras version left
    # `policy` unbound and the prints below raised NameError.
    from tensorflow.keras import mixed_precision
    if hasattr(mixed_precision, "set_global_policy"):
        # Non-experimental API (the original "TF 2.4" branch)
        policy = mixed_precision.Policy("mixed_float16")
        mixed_precision.set_global_policy(policy)
    else:
        # Experimental API (the original "TF 2.3" branch)
        from tensorflow.keras.mixed_precision import experimental as mixed_precision
        policy = mixed_precision.Policy("mixed_float16")
        mixed_precision.set_policy(policy)
    print("Compute dtype: %s" % policy.compute_dtype)
    print("Variable dtype: %s" % policy.variable_dtype)


# y_encoding = "onehot"
# y_encoding = "label"  # to be used binary cross-entropy

if params.y_encoding == "onehot":
    if index_col_name in data.columns:
        # Using Yitan's T/V/E splits
Exemple #12
0
def find_lr(config, outdir, figname, logscale):
    """Run the Learning Rate Finder and plot batch loss against LR.

    The model is fitted for 200 single-step epochs with an ``LRFinder``
    callback attached; the resulting plot is saved to ``outdir`` under
    ``figname`` and can be used to choose an LR range.
    """
    config, _ = parse_config(config)

    # The tf.distribute strategy depends on the number of available GPUs
    strategy, num_gpus = get_strategy()

    ds_train, _ = get_datasets(config["train_test_datasets"], config,
                               num_gpus, "train")

    with strategy.scope():
        # Placeholder learning rate; the lr_finder changes it while running
        opt = tf.keras.optimizers.Adam(learning_rate=1e-7)

        half_precision = config["setup"]["dtype"] == "float16"
        if half_precision:
            model_dtype = tf.dtypes.float16
            mixed_precision.set_global_policy(
                mixed_precision.Policy("mixed_float16"))
            opt = mixed_precision.LossScaleOptimizer(opt)
        else:
            model_dtype = tf.dtypes.float32

        model = make_model(config, model_dtype)
        config = set_config_loss(config, config["setup"]["trainable"])

        # Build the layers by declaring the input shape once
        dataset_cfg = config["dataset"]
        model.build((1, dataset_cfg["padded_num_elem_size"],
                     dataset_cfg["num_input_features"]))

        configure_model_weights(model, config["setup"]["trainable"])

        loss_dict, loss_weights = get_loss_dict(config)
        cls_metrics = [
            FlattenedCategoricalAccuracy(name="acc_unweighted",
                                         dtype=tf.float64),
            FlattenedCategoricalAccuracy(use_weights=True,
                                         name="acc_weighted",
                                         dtype=tf.float64),
        ]
        model.compile(
            loss=loss_dict,
            optimizer=opt,
            sample_weight_mode="temporal",
            loss_weights=loss_weights,
            metrics={"cls": cls_metrics},
        )
        model.summary()

        max_steps = 200
        lr_finder = LRFinder(max_steps=max_steps)

        # one step per epoch, so the finder is stepped max_steps times
        model.fit(
            ds_train.repeat(),
            epochs=max_steps,
            callbacks=[lr_finder],
            steps_per_epoch=1,
        )

        lr_finder.plot(save_dir=outdir, figname=figname, log_scale=logscale)
Exemple #13
0
def compute_validation_loss(config, train_dir, weights):
    """Evaluate the trained model on the test split and store the losses.

    When ``config`` is None the ``config.yaml`` inside ``train_dir`` is used.
    When ``weights`` is falsy the best checkpoint found in ``train_dir`` is
    loaded instead. The dictionary returned by ``model.evaluate`` is written
    as one JSON line to ``<train_dir>/losses.txt``.
    """
    if config is None:
        config = Path(train_dir) / "config.yaml"
        assert config.exists(
        ), "Could not find config file in train_dir, please provide one with -c <path/to/config>"
    config, _ = parse_config(config, weights=weights)

    half_precision = config["setup"]["dtype"] == "float16"
    model_dtype = tf.dtypes.float16 if half_precision else tf.dtypes.float32
    if half_precision:
        mixed_precision.set_global_policy(
            mixed_precision.Policy("mixed_float16"))

    strategy, num_gpus = get_strategy()
    ds_test, num_test_steps = get_datasets(config["train_test_datasets"],
                                           config, num_gpus, "test")

    with strategy.scope():
        dataset_cfg = config["dataset"]
        model = make_model(config, model_dtype)
        model.build((1, dataset_cfg["padded_num_elem_size"],
                     dataset_cfg["num_input_features"]))

        # the weights must be loaded in the same trainable configuration
        # that the model was set up with
        weights_config = config["setup"].get("weights_config", "all")
        configure_model_weights(model, weights_config)
        if not weights:
            weights = get_best_checkpoint(train_dir)
            print("Loading best weights that could be found from {}".format(
                weights))
        model.load_weights(weights, by_name=True)

        loss_dict, loss_weights = get_loss_dict(config)

        # two accuracy metrics followed by one recall metric per output class
        cls_metrics = [
            FlattenedCategoricalAccuracy(name="acc_unweighted",
                                         dtype=tf.float64),
            FlattenedCategoricalAccuracy(use_weights=True,
                                         name="acc_weighted",
                                         dtype=tf.float64),
        ]
        for icls in range(config["dataset"]["num_output_classes"]):
            cls_metrics.append(
                SingleClassRecall(icls,
                                  name="rec_cls{}".format(icls),
                                  dtype=tf.float64))

        model.compile(
            loss=loss_dict,
            # sample_weight_mode="temporal",
            loss_weights=loss_weights,
            metrics={"cls": cls_metrics},
        )

        losses = model.evaluate(
            x=ds_test,
            steps=num_test_steps,
            return_dict=True,
        )

    losses_path = "{}/losses.txt".format(train_dir)
    with open(losses_path, "w") as loss_file:
        loss_file.write(json.dumps(losses) + "\n")
Exemple #14
0
def run(args):
    split_on = "none" if args.split_on is (None or "none") else args.split_on

    # Create project dir (if it doesn't exist)
    # import ipdb; ipdb.set_trace()
    prjdir = cfg.MAIN_PRJDIR / args.prjname
    os.makedirs(prjdir, exist_ok=True)

    # Create outdir (using the loaded hyperparamters) or
    # use content (model) from an existing run
    fea_strs = ["use_tile"]
    args_dict = vars(args)
    fea_names = "_".join(
        [k.split("use_")[-1] for k in fea_strs if args_dict[k] is True])
    prm_file_path = prjdir / f"params_{fea_names}.json"
    if prm_file_path.exists() is False:
        shutil.copy(
            fdir / f"../default_params/default_params_{fea_names}.json",
            prm_file_path)
    params = Params(prm_file_path)

    if args.rundir is not None:
        outdir = Path(args.rundir).resolve()
        assert outdir.exists(), f"The {outdir} doen't exist."
        print_fn = print
    else:
        # outdir = create_outdir(prjdir, args)
        outdir = prjdir / f"{params.base_image_model}_finetuned"

        # Save hyper-parameters
        params.save(outdir / "params.json")

        # Logger
        lg = Logger(outdir / "logger.log")
        print_fn = get_print_func(lg.logger)
        print_fn(f"File path: {fdir}")
        print_fn(f"\n{pformat(vars(args))}")

    # Load dataframe (annotations)
    annotations_file = cfg.DATA_PROCESSED_DIR / args.dataname / cfg.SF_ANNOTATIONS_FILENAME
    dtype = {"image_id": str, "slide": str}
    data = pd.read_csv(annotations_file,
                       dtype=dtype,
                       engine="c",
                       na_values=["na", "NaN"],
                       low_memory=True)
    # data = data.astype({"image_id": str, "slide": str})
    print_fn(data.shape)

    # print_fn("\nFull dataset:")
    # if args.target[0] == "Response":
    #     print_groupby_stat_rsp(data, split_on="Group", print_fn=print_fn)
    # else:
    #     print_groupby_stat_ctype(data, split_on="Group", print_fn=print_fn)
    print_groupby_stat_ctype(data, split_on="Group", print_fn=print_fn)

    # Drop slide dups
    fea_columns = ["slide"]
    data = data.drop_duplicates(subset=fea_columns)

    # Aggregate non-responders to balance the responders
    # import ipdb; ipdb.set_trace()
    # n_samples = data["ctype"].value_counts().min()
    n_samples = 30
    dfs = []
    for ctype, count in data['ctype'].value_counts().items():
        aa = data[data.ctype == ctype]
        if aa.shape[0] > n_samples:
            aa = aa.sample(n=n_samples)
        dfs.append(aa)
    data = pd.concat(dfs, axis=0).reset_index(drop=True)
    print_groupby_stat_ctype(data, split_on="Group", print_fn=print_fn)

    te_size = 0.15
    itr, ite = train_test_split(np.arange(data.shape[0]),
                                test_size=te_size,
                                shuffle=True,
                                stratify=data["ctype_label"].values)
    tr_meta_ = data.iloc[itr, :].reset_index(drop=True)
    te_meta = data.iloc[ite, :].reset_index(drop=True)

    vl_size = 0.10
    itr, ivl = train_test_split(np.arange(tr_meta_.shape[0]),
                                test_size=vl_size,
                                shuffle=True,
                                stratify=tr_meta_["ctype_label"].values)
    tr_meta = tr_meta_.iloc[itr, :].reset_index(drop=True)
    vl_meta = tr_meta_.iloc[ivl, :].reset_index(drop=True)

    print_groupby_stat_ctype(tr_meta, split_on="Group", print_fn=print_fn)
    print_groupby_stat_ctype(vl_meta, split_on="Group", print_fn=print_fn)
    print_groupby_stat_ctype(te_meta, split_on="Group", print_fn=print_fn)

    print_fn(tr_meta.shape)
    print_fn(vl_meta.shape)
    print_fn(te_meta.shape)

    # Determine tfr_dir (the path to TFRecords)
    tfr_dir = (cfg.DATADIR / args.tfr_dir_name).resolve()
    pred_tfr_dir = (cfg.DATADIR / args.pred_tfr_dir_name).resolve()
    label = f"{params.tile_px}px_{params.tile_um}um"
    tfr_dir = tfr_dir / label
    pred_tfr_dir = pred_tfr_dir / label

    # Scalers for each feature set
    ge_scaler, dd1_scaler, dd2_scaler = None, None, None

    ge_cols = [c for c in data.columns if c.startswith("ge_")]
    dd1_cols = [c for c in data.columns if c.startswith("dd1_")]
    dd2_cols = [c for c in data.columns if c.startswith("dd2_")]

    if args.scale_fea:
        if args.use_ge and len(ge_cols) > 0:
            ge_scaler = get_scaler(data[ge_cols])
        if args.use_dd1 and len(dd1_cols) > 0:
            dd1_scaler = get_scaler(data[dd1_cols])
        if args.use_dd2 and len(dd2_cols) > 0:
            dd2_scaler = get_scaler(data[dd2_cols])

    # --------------------------
    # Obtain T/V/E tfr filenames
    # --------------------------
    # List of sample names for T/V/E
    tr_smp_names = list(tr_meta[args.id_name].values)
    vl_smp_names = list(vl_meta[args.id_name].values)
    te_smp_names = list(te_meta[args.id_name].values)

    # TFRecords filenames
    train_tfr_files = get_tfr_files(tfr_dir, tr_smp_names)
    val_tfr_files = get_tfr_files(tfr_dir, vl_smp_names)
    if args.eval is True:
        assert pred_tfr_dir.exists(), f"Dir {pred_tfr_dir} is not found."
        # test_tfr_files = get_tfr_files(tfr_dir, te_smp_names)  # use same tfr_dir for eval
        test_tfr_files = get_tfr_files(pred_tfr_dir, te_smp_names)
        # print_fn("Total samples {}".format(len(train_tfr_files) + len(val_tfr_files) + len(test_tfr_files)))

    assert sorted(tr_smp_names) == sorted(tr_meta[args.id_name].values.tolist(
    )), "Sample names in the tr_smp_names and tr_meta don't match."
    assert sorted(vl_smp_names) == sorted(vl_meta[args.id_name].values.tolist(
    )), "Sample names in the vl_smp_names and vl_meta don't match."
    assert sorted(te_smp_names) == sorted(te_meta[args.id_name].values.tolist(
    )), "Sample names in the te_smp_names and te_meta don't match."

    # -------------------------------
    # Class weight
    # -------------------------------
    tile_cnts = pd.read_csv(tfr_dir / "tile_counts_per_slide.csv")
    tile_cnts.insert(
        loc=0,
        column="tfr_abs_fname",
        value=tile_cnts["tfr_fname"].map(lambda s: str(tfr_dir / s)))
    cat = tile_cnts[tile_cnts["tfr_abs_fname"].isin(train_tfr_files)]

    # import ipdb; ipdb.set_trace()
    ### ap --------------
    # if args.target[0] not in cat.columns:
    #     tile_cnts = tile_cnts[tile_cnts["smp"].isin(tr_meta["smp"])]
    df = tr_meta[["smp", args.target[0]]]
    cat = cat.merge(df, on="smp", how="inner")
    ### ap --------------

    cat = cat.groupby(args.target[0]).agg({
        "smp": "nunique",
        "max_tiles": "sum",
        "n_tiles": "sum",
        "slide": "nunique"
    }).reset_index()
    categories = {}
    for i, row_data in cat.iterrows():
        dct = {
            "num_samples": row_data["smp"],
            "num_tiles": row_data["n_tiles"]
        }
        categories[row_data[args.target[0]]] = dct

    class_weight = calc_class_weights(
        train_tfr_files,
        class_weights_method=params.class_weights_method,
        categories=categories)

    # --------------------------
    # Build tf.data objects
    # --------------------------
    tf.keras.backend.clear_session()

    # import ipdb; ipdb.set_trace()
    if args.use_tile:

        # -------------------------------
        # Parsing funcs
        # -------------------------------
        # import ipdb; ipdb.set_trace()
        if args.target[0] == "Response":
            # Response
            parse_fn = parse_tfrec_fn_rsp
            parse_fn_train_kwargs = {
                "use_tile": args.use_tile,
                "use_ge": args.use_ge,
                "use_dd1": args.use_dd1,
                "use_dd2": args.use_dd2,
                "ge_scaler": ge_scaler,
                "dd1_scaler": dd1_scaler,
                "dd2_scaler": dd2_scaler,
                "id_name": args.id_name,
                "augment": params.augment,
                "application": params.base_image_model,
                # "application": None,
            }
        else:
            # Ctype
            parse_fn = parse_tfrec_fn_ctype
            parse_fn_train_kwargs = {
                "use_tile": args.use_tile,
                "use_ge": args.use_ge,
                "ge_scaler": ge_scaler,
                "id_name": args.id_name,
                "augment": params.augment,
                "target": args.target[0]
            }

        parse_fn_non_train_kwargs = parse_fn_train_kwargs.copy()
        parse_fn_non_train_kwargs["augment"] = False

        # ----------------------------------------
        # Number of tiles/examples in each dataset
        # ----------------------------------------
        # import ipdb; ipdb.set_trace()
        tr_tiles = tile_cnts[tile_cnts[args.id_name].isin(
            tr_smp_names)]["n_tiles"].sum()
        vl_tiles = tile_cnts[tile_cnts[args.id_name].isin(
            vl_smp_names)]["n_tiles"].sum()
        te_tiles = tile_cnts[tile_cnts[args.id_name].isin(
            te_smp_names)]["n_tiles"].sum()

        eval_batch_size = 4 * params.batch_size
        tr_steps = tr_tiles // params.batch_size
        vl_steps = vl_tiles // eval_batch_size
        te_steps = te_tiles // eval_batch_size

        # -------------------------------
        # Create TF datasets
        # -------------------------------
        print("\nCreating TF datasets.")

        # Training
        # import ipdb; ipdb.set_trace()
        train_data = create_tf_data(
            batch_size=params.batch_size,
            deterministic=False,
            include_meta=False,
            interleave=True,
            n_concurrent_shards=params.n_concurrent_shards,  # 32, 64
            parse_fn=parse_fn,
            prefetch=1,  # 2
            repeat=True,
            seed=None,  # cfg.seed,
            shuffle_files=True,
            shuffle_size=params.shuffle_size,  # 8192
            tfrecords=train_tfr_files,
            **parse_fn_train_kwargs)

        # Determine feature shapes from data
        bb = next(train_data.__iter__())

        # Infer dims of features from the data
        # import ipdb; ipdb.set_trace()
        if args.use_ge:
            ge_shape = bb[0]["ge_data"].numpy().shape[1:]
        else:
            ge_shape = None

        if args.use_dd1:
            dd_shape = bb[0]["dd1_data"].numpy().shape[1:]
        else:
            dd_shape = None

        # Print keys and dims
        for i, item in enumerate(bb):
            print(f"\nItem {i}")
            if isinstance(item, dict):
                for k in item.keys():
                    print(f"\t{k}: {item[k].numpy().shape}")
            elif isinstance(item.numpy(), np.ndarray):
                print(item)

        # Evaluation (val, test, train)
        create_tf_data_eval_kwargs = {
            "batch_size": eval_batch_size,
            "include_meta": False,
            "interleave": False,
            "parse_fn": parse_fn,
            "prefetch": None,  # 2
            "repeat": False,
            "seed": None,
            "shuffle_files": False,
            "shuffle_size": None,
        }

        # import ipdb; ipdb.set_trace()
        create_tf_data_eval_kwargs.update({
            "tfrecords": val_tfr_files,
            "include_meta": False
        })
        val_data = create_tf_data(**create_tf_data_eval_kwargs,
                                  **parse_fn_non_train_kwargs)

    # ----------------------
    # Prep for training
    # ----------------------
    # import ipdb; ipdb.set_trace()

    # -------------
    # Train model
    # -------------
    model = None

    # import ipdb; ipdb.set_trace()
    if args.train is True:

        # Callbacks list
        monitor = "val_loss"
        callbacks = keras_callbacks(outdir,
                                    monitor=monitor,
                                    save_best_only=params.save_best_only,
                                    patience=params.patience)

        # Mixed precision
        if params.use_fp16:
            print_fn("\nTrain with mixed precision")
            if int(tf.keras.__version__.split(".")[1]) == 4:  # TF 2.4
                from tensorflow.keras import mixed_precision
                policy = mixed_precision.Policy("mixed_float16")
                mixed_precision.set_global_policy(policy)
            elif int(tf.keras.__version__.split(".")[1]) == 3:  # TF 2.3
                from tensorflow.keras.mixed_precision import experimental as mixed_precision
                policy = mixed_precision.Policy("mixed_float16")
                mixed_precision.set_policy(policy)
            print_fn("Compute dtype: %s" % policy.compute_dtype)
            print_fn("Variable dtype: %s" % policy.variable_dtype)

        # ----------------------
        # Define model
        # ----------------------
        # import ipdb; ipdb.set_trace()

        from tensorflow.keras.layers import Input, Dense, Dropout, Activation, BatchNormalization
        from tensorflow.keras import layers
        from tensorflow.keras import losses
        from tensorflow.keras import optimizers
        from tensorflow.keras.models import Sequential, Model, load_model

        # trainable = True
        trainable = False
        # from_logits = True
        from_logits = False
        fit_verbose = 1
        pretrain = params.pretrain
        pooling = params.pooling
        n_classes = len(sorted(tr_meta[args.target[0]].unique()))

        model_inputs = []
        merge_inputs = []

        if args.use_tile:
            image_shape = (cfg.IMAGE_SIZE, cfg.IMAGE_SIZE, 3)
            tile_input_tensor = tf.keras.Input(shape=image_shape,
                                               name="tile_image")

            base_img_model = tf.keras.applications.Xception(include_top=False,
                                                            weights=pretrain,
                                                            input_shape=None,
                                                            input_tensor=None,
                                                            pooling=pooling)

            print_fn(
                f"\nNumber of layers in the base image model ({params.base_image_model}): {len(base_img_model.layers)}"
            )
            print_fn("Trainable variables: {}".format(
                len(base_img_model.trainable_variables)))
            print_fn("Shape of trainable variables at {}: {}".format(
                0, base_img_model.trainable_variables[0].shape))
            print_fn("Shape of trainable variables at {}: {}".format(
                -1, base_img_model.trainable_variables[-1].shape))

            print_fn("\nFreeze base model.")
            base_img_model.trainable = trainable  # Freeze the base_img_model
            print_fn("Trainable variables: {}".format(
                len(base_img_model.trainable_variables)))

            print_fn("\nPrint some layers")
            print_fn("Name of layer {}: {}".format(
                0, base_img_model.layers[0].name))
            print_fn("Name of layer {}: {}".format(
                -1, base_img_model.layers[-1].name))

            # training=False makes the base model run in inference mode so
            # that batchnorm layers are not updated during the fine-tuning stage.
            # x_tile = base_img_model(tile_input_tensor)
            x_tile = base_img_model(tile_input_tensor, training=False)
            # x_tile = base_img_model(tile_input_tensor, training=trainable)
            model_inputs.append(tile_input_tensor)

            # x_tile = Dense(params.dense1_img, activation=tf.nn.relu, name="dense1_img")(x_tile)
            # x_tile = Dense(params.dense2_img, activation=tf.nn.relu, name="dense2_img")(x_tile)
            # x_tile = BatchNormalization(name="batchnorm_im")(x_tile)
            merge_inputs.append(x_tile)
            del tile_input_tensor, x_tile

        # Merge towers
        if len(merge_inputs) > 1:
            mm = layers.Concatenate(axis=1, name="merger")(merge_inputs)
        else:
            mm = merge_inputs[0]

        # Dense layers of the top classifier
        mm = Dense(params.dense1_top, activation=tf.nn.relu,
                   name="dense1_top")(mm)
        # mm = BatchNormalization(name="batchnorm_top")(mm)
        # mm = Dropout(params.dropout1_top)(mm)

        # Output
        output = Dense(n_classes, activation=tf.nn.relu, name="logits")(mm)
        if from_logits is False:
            output = Activation(tf.nn.softmax, dtype="float32",
                                name="softmax")(output)

        # Assemble final model
        model = Model(inputs=model_inputs, outputs=output)

        metrics = [
            tf.keras.metrics.SparseCategoricalAccuracy(name="CatAcc"),
            tf.keras.metrics.SparseCategoricalCrossentropy(
                from_logits=from_logits, name="CatCrossEnt")
        ]

        if params.optimizer == "SGD":
            optimizer = optimizers.SGD(learning_rate=params.learning_rate,
                                       momentum=0.9,
                                       nesterov=True)
        elif params.optimizer == "Adam":
            optimizer = optimizers.Adam(learning_rate=params.learning_rate)

        loss = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=from_logits)

        model.compile(loss=loss, optimizer=optimizer, metrics=metrics)

        # import ipdb; ipdb.set_trace()
        print_fn("\nBase model")
        base_img_model.summary(print_fn=print_fn)
        print_fn("\nFull model")
        model.summary(print_fn=print_fn)
        print_fn("Trainable variables: {}".format(
            len(model.trainable_variables)))

        print_fn(f"Train steps:      {tr_steps}")
        # print_fn(f"Validation steps: {vl_steps}")

        # ------------
        # Train
        # ------------
        import ipdb
        ipdb.set_trace()
        # tr_steps = 10  # tr_tiles // params.batch_size // 15  # for debugging
        print_fn("\n{}".format(yellow("Train")))
        timer = Timer()
        history = model.fit(x=train_data,
                            validation_data=val_data,
                            steps_per_epoch=tr_steps,
                            validation_steps=vl_steps,
                            class_weight=class_weight,
                            epochs=params.epochs,
                            verbose=fit_verbose,
                            callbacks=callbacks)
        # del train_data, val_data
        timer.display_timer(print_fn)
        plot_prfrm_metrics(history,
                           title="Train stage",
                           name="tn",
                           outdir=outdir)
        model = load_best_model(outdir)  # load best model

        # Save trained model
        print_fn("\nSave trained model.")
        model.save(outdir / "best_model_trained")

        create_tf_data_eval_kwargs.update({
            "tfrecords": test_tfr_files,
            "include_meta": True
        })
        test_data = create_tf_data(**create_tf_data_eval_kwargs,
                                   **parse_fn_non_train_kwargs)

        # Calc hits
        te_tile_preds = calc_tile_preds(test_data, model=model, outdir=outdir)
        te_tile_preds = te_tile_preds.sort_values(["image_id", "tile_id"],
                                                  ascending=True)
        hits_tn = calc_hits(te_tile_preds, te_meta)
        hits_tn.to_csv(outdir / "hits_tn.csv", index=False)

        # ------------
        # Finetune
        # ------------
        # import ipdb; ipdb.set_trace()
        print_fn("\n{}".format(green("Finetune")))
        unfreeze_top_layers = 50
        # Unfreeze layers of the base model
        for layer in base_img_model.layers[-unfreeze_top_layers:]:
            layer.trainable = True
            print_fn("{}: (trainable={})".format(layer.name, layer.trainable))
        print_fn("Trainable variables: {}".format(
            len(model.trainable_variables)))

        model.compile(
            loss=loss,
            optimizer=optimizers.Adam(learning_rate=params.learning_rate / 10),
            metrics=metrics)

        callbacks = keras_callbacks(outdir,
                                    monitor=monitor,
                                    save_best_only=params.save_best_only,
                                    patience=params.patience,
                                    name="finetune")

        total_epochs = history.epoch[-1] + params.finetune_epochs
        timer = Timer()
        history_fn = model.fit(x=train_data,
                               validation_data=val_data,
                               steps_per_epoch=tr_steps,
                               validation_steps=vl_steps,
                               class_weight=class_weight,
                               epochs=total_epochs,
                               initial_epoch=history.epoch[-1] + 1,
                               verbose=fit_verbose,
                               callbacks=callbacks)
        del train_data, val_data
        plot_prfrm_metrics(history_fn,
                           title="Finetune stage",
                           name="fn",
                           outdir=outdir)
        timer.display_timer(print_fn)

        # Save trained model
        print_fn("\nSave finetuned model.")
        model.save(outdir / "best_model_finetuned")
        base_img_model.save(outdir / "best_model_img_base_finetuned")

    if args.eval is True:

        print_fn("\n{}".format(bold("Test set predictions.")))
        timer = Timer()
        # import ipdb; ipdb.set_trace()
        te_tile_preds = calc_tile_preds(test_data, model=model, outdir=outdir)
        te_tile_preds = te_tile_preds.sort_values(["image_id", "tile_id"],
                                                  ascending=True)
        te_tile_preds.to_csv(outdir / "te_tile_preds.csv", index=False)
        # print(te_tile_preds[["image_id", "tile_id", "y_true", "y_pred_label", "prob"]][:20])
        # print(te_tile_preds.iloc[:20, 1:])
        del test_data

        # Calc hits
        hits_fn = calc_hits(te_tile_preds, te_meta)
        hits_fn.to_csv(outdir / "hits_fn.csv", index=False)

        # import ipdb; ipdb.set_trace()
        roc_auc = {}
        import matplotlib.pyplot as plt
        from sklearn.metrics import roc_curve, auc
        fig, ax = plt.subplots(figsize=(8, 6))
        for true in range(0, n_classes):
            if true in te_tile_preds["y_true"].values:
                fpr, tpr, thresh = roc_curve(te_tile_preds["y_true"],
                                             te_tile_preds["prob"],
                                             pos_label=true)
                roc_auc[i] = auc(fpr, tpr)
                plt.plot(fpr,
                         tpr,
                         linestyle='--',
                         label=f"Class {true} vs Rest")
            else:
                roc_auc[i] = None

        # plt.plot([0,0], [1,1], '--', label="Random")
        plt.title("Multiclass ROC Curve")
        plt.xlabel("FPR")
        plt.ylabel("TPR")
        plt.legend(loc="best")
        plt.savefig(outdir / "Multiclass ROC", dpi=70)

        # Average precision score
        from sklearn.metrics import average_precision_score
        y_true_vec = te_tile_preds.y_true.values
        y_true_onehot = np.zeros((y_true_vec.size, n_classes))
        y_true_onehot[np.arange(y_true_vec.size), y_true_vec] = 1
        y_probs = te_tile_preds[[
            c for c in te_tile_preds.columns if "prob_" in c
        ]]
        print_fn("\nAvearge precision")
        print_fn("Micro    {}".format(
            average_precision_score(y_true_onehot, y_probs, average="micro")))
        print_fn("Macro    {}".format(
            average_precision_score(y_true_onehot, y_probs, average="macro")))
        print_fn("Wieghted {}".format(
            average_precision_score(y_true_onehot, y_probs,
                                    average="weighted")))
        print_fn("Samples  {}".format(
            average_precision_score(y_true_onehot, y_probs,
                                    average="samples")))

        import ipdb
        ipdb.set_trace()
        agg_method = "mean"
        # agg_by = "smp"
        agg_by = "image_id"
        smp_preds = agg_tile_preds(te_tile_preds,
                                   agg_by=agg_by,
                                   meta=te_meta,
                                   agg_method=agg_method)

        timer.display_timer(print_fn)

    lg.close_logger()
Exemple #15
0
def run(args):
    split_on = "none" if args.split_on is (None or "none") else args.split_on


    # Create project dir (if it doesn't exist)
    import ipdb; ipdb.set_trace()
    prjdir = cfg.MAIN_PRJDIR/args.prjname
    os.makedirs(prjdir, exist_ok=True)


    # Create outdir (using the loaded hyperparameters) or
    # use content (model) from an existing run
    fea_strs = ["use_tile"]
    args_dict = vars(args)
    fea_names = "_".join([k.split("use_")[-1] for k in fea_strs if args_dict[k] is True])
    prm_file_path = prjdir/f"params_{fea_names}.json"
    if prm_file_path.exists() is False:
        shutil.copy(fdir/f"../default_params/default_params_{fea_names}.json", prm_file_path)
    params = Params(prm_file_path)

    if args.rundir is not None:
        outdir = Path(args.rundir).resolve()
        assert outdir.exists(), f"The {outdir} doen't exist."
        print_fn = print
    else:
        outdir = create_outdir_2(prjdir, args)

        # Save hyper-parameters
        params.save(outdir/"params.json")

        # Logger
        lg = Logger(outdir/"logger.log")
        print_fn = get_print_func(lg.logger)
        print_fn(f"File path: {fdir}")
        print_fn(f"\n{pformat(vars(args))}")


    # Load dataframe (annotations)
    annotations_file = cfg.DATA_PROCESSED_DIR/args.dataname/cfg.SF_ANNOTATIONS_FILENAME
    dtype = {"image_id": str, "slide": str}
    data = pd.read_csv(annotations_file, dtype=dtype, engine="c", na_values=["na", "NaN"], low_memory=True)
    # data = data.astype({"image_id": str, "slide": str})
    print_fn(data.shape)


    print_fn("\nFull dataset:")
    if args.target[0] == "Response":
        print_groupby_stat_rsp(data, split_on="Group", print_fn=print_fn)
    else:
        print_groupby_stat_ctype(data, split_on="Group", print_fn=print_fn)


    # Determine tfr_dir (the path to TFRecords)
    tfr_dir = (cfg.DATADIR/args.tfr_dir_name).resolve()
    pred_tfr_dir = (cfg.DATADIR/args.pred_tfr_dir_name).resolve()
    label = f"{params.tile_px}px_{params.tile_um}um"
    tfr_dir = tfr_dir/label
    pred_tfr_dir = pred_tfr_dir/label

    # Create outcomes (for drug response)
    # outcomes = {}
    # unique_outcomes = list(set(data[args.target[0]].values))
    # unique_outcomes.sort()
    # for smp, o in zip(data[args.id_name], data[args.target[0]]):
    #     outcomes[smp] = {"outcome": unique_outcomes.index(o)}


    # Scalers for each feature set
    # import ipdb; ipdb.set_trace()
    ge_scaler, dd1_scaler, dd2_scaler = None, None, None

    ge_cols  = [c for c in data.columns if c.startswith("ge_")]
    dd1_cols = [c for c in data.columns if c.startswith("dd1_")]
    dd2_cols = [c for c in data.columns if c.startswith("dd2_")]

    if args.scale_fea:
        if args.use_ge and len(ge_cols) > 0:
            ge_scaler = get_scaler(data[ge_cols])
        if args.use_dd1 and len(dd1_cols) > 0:
            dd1_scaler = get_scaler(data[dd1_cols])
        if args.use_dd2 and len(dd2_cols) > 0:
            dd2_scaler = get_scaler(data[dd2_cols])


    # Create manifest
    # print_fn("\nCreate/load manifest ...")
    # timer = Timer()
    # manifest = create_manifest(directory=tfr_dir, n_files=None)
    # timer.display_timer(print_fn)


    # -----------------------------------------------
    # Data splits
    # -----------------------------------------------

    # --------------
    # Yitan's splits
    # --------------
    if args.target[0] == "Response":
        if args.use_dd1 is False and args.use_dd2 is False:
            splitdir = cfg.DATADIR/"PDX_Transfer_Learning_Classification/Processed_Data/Data_For_MultiModal_Learning/Data_Partition_Drug_Specific"
            splitdir = splitdir/params.drug_specific
        else:
            splitdir = cfg.DATADIR/"PDX_Transfer_Learning_Classification/Processed_Data/Data_For_MultiModal_Learning/Data_Partition"
    else:
        splitdir = cfg.DATADIR/"PDX_Transfer_Learning_Classification/Processed_Data/Data_For_MultiModal_Learning/Data_Partition"

    tr_id = cast_list(read_lines(str(splitdir/f"cv_{args.split_id}"/"TrainList.txt")), int)
    vl_id = cast_list(read_lines(str(splitdir/f"cv_{args.split_id}"/"ValList.txt")), int)
    te_id = cast_list(read_lines(str(splitdir/f"cv_{args.split_id}"/"TestList.txt")), int)

    # Update ids
    index_col_name = "index"
    tr_id = sorted(set(data[index_col_name]).intersection(set(tr_id)))
    vl_id = sorted(set(data[index_col_name]).intersection(set(vl_id)))
    te_id = sorted(set(data[index_col_name]).intersection(set(te_id)))

    # Subsample train samples
    if args.n_samples > 0:
        if args.n_samples < len(tr_id):
            tr_id = tr_id[:args.n_samples]
        if args.n_samples < len(vl_id):
            vl_id = vl_id[:args.n_samples]
        if args.n_samples < len(te_id):
            te_id = te_id[:args.n_samples]

    
    ### ap --------------
    # Drop slide duplicates
    ###
    fea_columns = ["slide"]
    data = data.drop_duplicates(subset=fea_columns)
    ### ap --------------

    # --------------
    # TidyData
    # --------------
    # TODO: finish and test this class
    # td = TidyData(data,
    #               ge_prfx="ge_",
    #               dd1_prfx="dd1_",
    #               dd2_prfx="dd2_",
    #               index_col_name="index",
    #               split_ids={"tr_id": tr_id, "vl_id": vl_id, "te_id": te_id}
    # )
    # ge_scaler = td.ge_scaler
    # dd1_scaler = td.dd1_scaler
    # dd2_scaler = td.dd2_scaler

    # tr_meta = td.tr_meta
    # vl_meta = td.vl_meta
    # te_meta = td.te_meta
    # tr_meta.to_csv(outdir/"tr_meta.csv", index=False)
    # vl_meta.to_csv(outdir/"vl_meta.csv", index=False)
    # te_meta.to_csv(outdir/"te_meta.csv", index=False)

    # # Variables (dict/dataframes/arrays) that are passed as features to the NN
    # xtr = {"ge_data": td.tr_ge.values, "dd1_data": td.tr_dd1.values, "dd2_data": td.tr_dd2.values}
    # xvl = {"ge_data": td.vl_ge.values, "dd1_data": td.vl_dd1.values, "dd2_data": td.vl_dd2.values}
    # xte = {"ge_data": td.te_ge.values, "dd1_data": td.te_dd1.values, "dd2_data": td.te_dd2.values}

    # --------------
    # w/o TidyData
    # --------------
    kwargs = {"ge_cols": ge_cols,
              "dd1_cols": dd1_cols,
              "dd2_cols": dd2_cols,
              "ge_scaler": ge_scaler,
              "dd1_scaler": dd1_scaler,
              "dd2_scaler": dd2_scaler,
              "ge_dtype": cfg.GE_DTYPE,
              "dd_dtype": cfg.DD_DTYPE,
              "index_col_name": index_col_name,
              "split_on": split_on
              }
    tr_ge, tr_dd1, tr_dd2, tr_meta = split_data_and_extract_fea(data, ids=tr_id, **kwargs)
    vl_ge, vl_dd1, vl_dd2, vl_meta = split_data_and_extract_fea(data, ids=vl_id, **kwargs)
    te_ge, te_dd1, te_dd2, te_meta = split_data_and_extract_fea(data, ids=te_id, **kwargs)

    ### ap --------------
    # Create annotations for slideflow
    ###
    # import ipdb; ipdb.set_trace()
    tr_meta["submitter_id"] = tr_meta["Group"]  # submitter_id (specific patient); Group (specific treatment group)
    vl_meta["submitter_id"] = vl_meta["Group"]
    te_meta["submitter_id"] = te_meta["Group"]
    tr_meta["training_phase"] = "train"
    vl_meta["training_phase"] = "validation"
    te_meta["training_phase"] = "test"
    keep_cols = ["submitter_id", "slide", "model", "patient_id", "specimen_id", "sample_id",
                 "training_phase", "Group", "ctype", "csite", "ctype_label", "csite_label"]
    tr_meta_tmp = tr_meta[keep_cols]
    vl_meta_tmp = vl_meta[keep_cols]
    te_meta_tmp = te_meta[keep_cols]
    tr_meta.to_csv(outdir/"train_annotations.csv", index=False)
    vl_meta.to_csv(outdir/"validation_annotations.csv", index=False)
    te_meta.to_csv(outdir/"test_annotations.csv", index=False)
    sf_df = pd.concat([tr_meta_tmp, vl_meta_tmp, te_meta_tmp], axis=0)
    sf_df.to_csv(outdir/"annotations_for_sf.csv", index=False)
    del tr_meta_tmp, vl_meta_tmp, te_meta_tmp, sf_df
    ### ap --------------

    if args.train is True:
        tr_meta.to_csv(outdir/"tr_meta.csv", index=False)
        vl_meta.to_csv(outdir/"vl_meta.csv", index=False)
        te_meta.to_csv(outdir/"te_meta.csv", index=False)

    ge_shape = (tr_ge.shape[1],)
    dd_shape = (tr_dd1.shape[1],)

    if args.target[0] == "Response":
        print_fn("\nTrain:")
        print_groupby_stat_rsp(tr_meta, split_on="Group", print_fn=print_fn)
        print_fn("\nValidation:")
        print_groupby_stat_rsp(vl_meta, split_on="Group", print_fn=print_fn)
        print_fn("\nTest:")
        print_groupby_stat_rsp(te_meta, split_on="Group", print_fn=print_fn)
    else:
        print_fn("\nTrain:")
        print_groupby_stat_ctype(tr_meta, split_on="Group", print_fn=print_fn)
        print_fn("\nValidation:")
        print_groupby_stat_ctype(vl_meta, split_on="Group", print_fn=print_fn)
        print_fn("\nTest:")
        print_groupby_stat_ctype(te_meta, split_on="Group", print_fn=print_fn)

    # Make sure indices do not overlap
    assert len( set(tr_id).intersection(set(vl_id)) ) == 0, "Overlapping indices btw tr and vl"
    assert len( set(tr_id).intersection(set(te_id)) ) == 0, "Overlapping indices btw tr and te"
    assert len( set(vl_id).intersection(set(te_id)) ) == 0, "Overlapping indices btw vl and te"

    # Print split ratios
    print_fn("")
    print_fn("Train samples {} ({:.2f}%)".format( tr_meta.shape[0], 100*tr_meta.shape[0]/data.shape[0] ))
    print_fn("Val   samples {} ({:.2f}%)".format( vl_meta.shape[0], 100*vl_meta.shape[0]/data.shape[0] ))
    print_fn("Test  samples {} ({:.2f}%)".format( te_meta.shape[0], 100*te_meta.shape[0]/data.shape[0] ))

    tr_grp_unq = set(tr_meta[split_on].values)
    vl_grp_unq = set(vl_meta[split_on].values)
    te_grp_unq = set(te_meta[split_on].values)
    print_fn("")
    print_fn(f"Total intersects on {split_on} btw tr and vl: {len(tr_grp_unq.intersection(vl_grp_unq))}")
    print_fn(f"Total intersects on {split_on} btw tr and te: {len(tr_grp_unq.intersection(te_grp_unq))}")
    print_fn(f"Total intersects on {split_on} btw vl and te: {len(vl_grp_unq.intersection(te_grp_unq))}")
    print_fn(f"Unique {split_on} in tr: {len(tr_grp_unq)}")
    print_fn(f"Unique {split_on} in vl: {len(vl_grp_unq)}")
    print_fn(f"Unique {split_on} in te: {len(te_grp_unq)}")


    # --------------------------
    # Obtain T/V/E tfr filenames
    # --------------------------
    # List of sample names for T/V/E
    tr_smp_names = list(tr_meta[args.id_name].values)
    vl_smp_names = list(vl_meta[args.id_name].values)
    te_smp_names = list(te_meta[args.id_name].values)

    # TFRecords filenames
    train_tfr_files = get_tfr_files(tfr_dir, tr_smp_names)
    val_tfr_files = get_tfr_files(tfr_dir, vl_smp_names)
    if args.eval is True:
        assert pred_tfr_dir.exists(), f"Dir {pred_tfr_dir} is not found."
        # test_tfr_files = get_tfr_files(tfr_dir, te_smp_names)  # use same tfr_dir for eval
        test_tfr_files = get_tfr_files(pred_tfr_dir, te_smp_names)
        # print_fn("Total samples {}".format(len(train_tfr_files) + len(val_tfr_files) + len(test_tfr_files)))

    # Missing tfrecords
    print("\nThese samples miss a tfrecord:")
    df_miss = data.loc[~data[args.id_name].isin(tr_smp_names + vl_smp_names + te_smp_names), ["smp", "image_id"]]
    print(df_miss)

    assert sorted(tr_smp_names) == sorted(tr_meta[args.id_name].values.tolist()), "Sample names in the tr_smp_names and tr_meta don't match."
    assert sorted(vl_smp_names) == sorted(vl_meta[args.id_name].values.tolist()), "Sample names in the vl_smp_names and vl_meta don't match."
    assert sorted(te_smp_names) == sorted(te_meta[args.id_name].values.tolist()), "Sample names in the te_smp_names and te_meta don't match."


    # -------------------------------
    # Class weight
    # -------------------------------
    tile_cnts = pd.read_csv(tfr_dir/"tile_counts_per_slide.csv")
    tile_cnts.insert(loc=0, column="tfr_abs_fname", value=tile_cnts["tfr_fname"].map(lambda s: str(tfr_dir/s)))
    cat = tile_cnts[tile_cnts["tfr_abs_fname"].isin(train_tfr_files)]

    ### ap --------------
    # if args.target[0] not in cat.columns:
    #     tile_cnts = tile_cnts[tile_cnts["smp"].isin(tr_meta["smp"])]
    df = tr_meta[["smp", args.target[0]]]
    cat = cat.merge(df, on="smp", how="inner")
    ### ap --------------

    cat = cat.groupby(args.target[0]).agg({"smp": "nunique", "max_tiles": "sum", "n_tiles": "sum", "slide": "nunique"}).reset_index()
    categories = {}
    for i, row_data in cat.iterrows():
        dct = {"num_samples": row_data["smp"], "num_tiles": row_data["n_tiles"]}
        categories[row_data[args.target[0]]] = dct

    class_weight = calc_class_weights(train_tfr_files,
                                      class_weights_method=params.class_weights_method,
                                      categories=categories)
    # class_weight = {"Response": class_weight}


    # --------------------------
    # Build tf.data objects
    # --------------------------
    tf.keras.backend.clear_session()

    # import ipdb; ipdb.set_trace()
    if args.use_tile:

        # -------------------------------
        # Parsing funcs
        # -------------------------------
        # import ipdb; ipdb.set_trace()
        if args.target[0] == "Response":
            # Response
            parse_fn = parse_tfrec_fn_rsp
            parse_fn_train_kwargs = {
                "use_tile": args.use_tile,
                "use_ge": args.use_ge,
                "use_dd1": args.use_dd1,
                "use_dd2": args.use_dd2,
                "ge_scaler": ge_scaler,
                "dd1_scaler": dd1_scaler,
                "dd2_scaler": dd2_scaler,
                "id_name": args.id_name,
                "augment": params.augment,
                "application": params.base_image_model,
                # "application": None,
            }
        else:
            # Ctype
            parse_fn = parse_tfrec_fn_ctype
            parse_fn_train_kwargs = {
                "use_tile": args.use_tile,
                "use_ge": args.use_ge,
                "ge_scaler": ge_scaler,
                "id_name": args.id_name,
                "augment": params.augment,
                "target": args.target[0]
            }

        parse_fn_non_train_kwargs = parse_fn_train_kwargs.copy()
        parse_fn_non_train_kwargs["augment"] = False

        # ----------------------------------------
        # Number of tiles/examples in each dataset
        # ----------------------------------------
        # import ipdb; ipdb.set_trace()
        tr_tiles = tile_cnts[tile_cnts[args.id_name].isin(tr_smp_names)]["n_tiles"].sum()
        vl_tiles = tile_cnts[tile_cnts[args.id_name].isin(vl_smp_names)]["n_tiles"].sum()
        te_tiles = tile_cnts[tile_cnts[args.id_name].isin(te_smp_names)]["n_tiles"].sum()

        eval_batch_size = 4 * params.batch_size
        tr_steps = tr_tiles // params.batch_size
        vl_steps = vl_tiles // eval_batch_size
        # te_steps = te_tiles // eval_batch_size

        # -------------------------------
        # Create TF datasets
        # -------------------------------
        print("\nCreating TF datasets.")

        # Training
        # import ipdb; ipdb.set_trace()
        train_data = create_tf_data(
            batch_size=params.batch_size,
            deterministic=False,
            include_meta=False,
            interleave=True,
            n_concurrent_shards=params.n_concurrent_shards,  # 32, 64
            parse_fn=parse_fn,
            prefetch=1,  # 2
            repeat=True,
            seed=None,  # cfg.seed,
            shuffle_files=True,
            shuffle_size=params.shuffle_size,  # 8192
            tfrecords=train_tfr_files,
            **parse_fn_train_kwargs)

        # Determine feature shapes from data
        bb = next(train_data.__iter__())

        # Infer dims of features from the data
        # import ipdb; ipdb.set_trace()
        if args.use_ge:
            ge_shape = bb[0]["ge_data"].numpy().shape[1:]
        else:
            ge_shape = None

        if args.use_dd1:
            dd_shape = bb[0]["dd1_data"].numpy().shape[1:]
        else:
            dd_shape = None

        # Print keys and dims
        for i, item in enumerate(bb):
            print(f"\nItem {i}")
            if isinstance(item, dict):
                for k in item.keys():
                    print(f"\t{k}: {item[k].numpy().shape}")
            elif isinstance(item.numpy(), np.ndarray):
                print(item)

        # for i, rec in enumerate(train_data.take(2)):
        #     tf.print(rec[1])

        # Evaluation (val, test, train)
        create_tf_data_eval_kwargs = {
            "batch_size": eval_batch_size,
            "include_meta": False,
            "interleave": False,
            "parse_fn": parse_fn,
            "prefetch": None,  # 2
            "repeat": False,
            "seed": None,
            "shuffle_files": False,
            "shuffle_size": None,
        }

        # import ipdb; ipdb.set_trace()
        create_tf_data_eval_kwargs.update({"tfrecords": val_tfr_files, "include_meta": False})
        val_data = create_tf_data(
            **create_tf_data_eval_kwargs,
            **parse_fn_non_train_kwargs
        )

    # ----------------------
    # Prep for training
    # ----------------------
    # import ipdb; ipdb.set_trace()

    # # Loss and target
    # if args.use_tile:
    #     loss = losses.BinaryCrossentropy(label_smoothing=params.label_smoothing)
    # else:
    #     if params.y_encoding == "onehot":
    #         if index_col_name in data.columns:
    #             # Using Yitan's T/V/E splits
    #             # print(te_meta[["index", "Group", "grp_name", "Response"]])
    #             ytr = pd.get_dummies(tr_meta[args.target[0]].values)
    #             yvl = pd.get_dummies(vl_meta[args.target[0]].values)
    #             yte = pd.get_dummies(te_meta[args.target[0]].values)
    #         else:
    #             ytr = y_onehot.iloc[tr_id, :].reset_index(drop=True)
    #             yvl = y_onehot.iloc[vl_id, :].reset_index(drop=True)
    #             yte = y_onehot.iloc[te_id, :].reset_index(drop=True)
    #         loss = losses.CategoricalCrossentropy()
    #     elif params.y_encoding == "label":
    #         if index_col_name in data.columns:
    #             # Using Yitan's T/V/E splits
    #             ytr = tr_meta[args.target[0]].values
    #             yvl = vl_meta[args.target[0]].values
    #             yte = te_meta[args.target[0]].values
    #             loss = losses.BinaryCrossentropy(label_smoothing=params.label_smoothing)
    #         else:
    #             ytr = ydata_label[tr_id]
    #             yvl = ydata_label[vl_id]
    #             yte = ydata_label[te_id]
    #             loss = losses.SparseCategoricalCrossentropy()
    #     else:
    #         raise ValueError(f"Unknown value for y_encoding ({params.y_encoding}).")


    # -------------
    # Train model
    # -------------
    model = None

    # import ipdb; ipdb.set_trace()
    if args.train is True:

        # Callbacks list
        monitor = "val_loss"
        # monitor = "val_pr-auc"
        callbacks = keras_callbacks(outdir, monitor=monitor,
                                    save_best_only=params.save_best_only,
                                    patience=params.patience)
        # callbacks = keras_callbacks(outdir, monitor="auc", patience=params.patience)

        # Mixed precision
        if params.use_fp16:
            print_fn("\nTrain with mixed precision")
            if int(tf.keras.__version__.split(".")[1]) == 4:  # TF 2.4
                from tensorflow.keras import mixed_precision
                policy = mixed_precision.Policy("mixed_float16")
                mixed_precision.set_global_policy(policy)
            elif int(tf.keras.__version__.split(".")[1]) == 3:  # TF 2.3
                from tensorflow.keras.mixed_precision import experimental as mixed_precision
                policy = mixed_precision.Policy("mixed_float16")
                mixed_precision.set_policy(policy)
            print_fn("Compute dtype: %s" % policy.compute_dtype)
            print_fn("Variable dtype: %s" % policy.variable_dtype)

        # ----------------------
        # Define model
        # ----------------------
        # import ipdb; ipdb.set_trace()

        from tensorflow.keras.layers import Input, Dense, Dropout, Activation, BatchNormalization
        from tensorflow.keras import layers
        from tensorflow.keras import losses
        from tensorflow.keras import optimizers
        from tensorflow.keras.models import Sequential, Model, load_model

        # trainable = True
        trainable = False
        # from_logits = True
        from_logits = False
        fit_verbose = 1
        pretrain = params.pretrain
        pooling = params.pooling
        n_classes = len(sorted(tr_meta[args.target[0]].unique()))

        model_inputs = []
        merge_inputs = []

        if args.use_tile:
            image_shape = (cfg.IMAGE_SIZE, cfg.IMAGE_SIZE, 3)
            tile_input_tensor = tf.keras.Input(shape=image_shape, name="tile_image")

            base_img_model = tf.keras.applications.Xception(
                include_top=False,
                weights=pretrain,
                input_shape=None,
                input_tensor=None,
                pooling=pooling)

            print_fn(f"\nNumber of layers in the base image model ({params.base_image_model}): {len(base_img_model.layers)}")
            print_fn("Trainable variables: {}".format(len(base_img_model.trainable_variables)))
            print_fn("Shape of trainable variables at {}: {}".format(0, base_img_model.trainable_variables[0].shape))
            print_fn("Shape of trainable variables at {}: {}".format(-1, base_img_model.trainable_variables[-1].shape))

            print_fn("\nFreeze base model.")
            base_img_model.trainable = trainable  # Freeze the base_img_model
            print_fn("Trainable variables: {}".format(len(base_img_model.trainable_variables)))

            print_fn("\nPrint some layers")
            print_fn("Name of layer {}: {}".format(0, base_img_model.layers[0].name))
            print_fn("Name of layer {}: {}".format(-1, base_img_model.layers[-1].name))

            # training=False makes the base model to run in inference mode so
            # that batchnorm layers are not updated during the fine-tuning stage.
            # x_tile = base_img_model(tile_input_tensor)
            x_tile = base_img_model(tile_input_tensor, training=False)
            # x_tile = base_img_model(tile_input_tensor, training=trainable)
            model_inputs.append(tile_input_tensor)

            # x_tile = Dense(params.dense1_img, activation=tf.nn.relu, name="dense1_img")(x_tile)
            # x_tile = Dense(params.dense2_img, activation=tf.nn.relu, name="dense2_img")(x_tile)
            # x_tile = BatchNormalization(name="batchnorm_im")(x_tile)
            merge_inputs.append(x_tile)
            del tile_input_tensor, x_tile

        # Merge towers
        if len(merge_inputs) > 1:
            mm = layers.Concatenate(axis=1, name="merger")(merge_inputs)
        else:
            mm = merge_inputs[0]

        # Dense layers of the top classifier
        mm = Dense(params.dense1_top, activation=tf.nn.relu, name="dense1_top")(mm)
        # mm = BatchNormalization(name="batchnorm_top")(mm)
        # mm = Dropout(params.dropout1_top)(mm)

        # Output
        output = Dense(n_classes, activation=tf.nn.relu, name="logits")(mm)
        if from_logits is False:
            output = Activation(tf.nn.softmax, dtype="float32", name="softmax")(output)

        # Assemble final model
        model = Model(inputs=model_inputs, outputs=output)

        metrics = [
            tf.keras.metrics.SparseCategoricalAccuracy(name="CatAcc"),
            tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=from_logits, name="CatCrossEnt")
        ]

        if params.optimizer == "SGD":
            optimizer = optimizers.SGD(learning_rate=params.learning_rate, momentum=0.9, nesterov=True)
        elif params.optimizer == "Adam":
            optimizer = optimizers.Adam(learning_rate=params.learning_rate)

        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=from_logits)

        model.compile(loss=loss, optimizer=optimizer, metrics=metrics)


        # import ipdb; ipdb.set_trace()
        print_fn("\nBase model")
        base_img_model.summary(print_fn=print_fn)
        print_fn("\nFull model")
        model.summary(print_fn=print_fn)
        print_fn("Trainable variables: {}".format(len(model.trainable_variables)))

        print_fn(f"Train steps:      {tr_steps}")
        print_fn(f"Validation steps: {vl_steps}")

        # ------------
        # Train
        # ------------
        # import ipdb; ipdb.set_trace()
        # tr_steps = 10  # tr_tiles // params.batch_size // 15  # for debugging
        print_fn("\n{}".format(yellow("Train")))
        timer = Timer()
        history = model.fit(x=train_data,
                            validation_data=val_data,
                            steps_per_epoch=tr_steps,
                            validation_steps=vl_steps,
                            class_weight=class_weight,
                            epochs=params.epochs,
                            verbose=fit_verbose,
                            callbacks=callbacks)
        # del train_data, val_data
        timer.display_timer(print_fn)
        plot_prfrm_metrics(history, title="Train stage", name="tn", outdir=outdir)
        model = load_best_model(outdir)  # load best model

        # Save trained model
        print_fn("\nSave trained model.")
        model.save(outdir/"best_model_trained")

        create_tf_data_eval_kwargs.update({"tfrecords": test_tfr_files, "include_meta": True})
        test_data = create_tf_data(
            **create_tf_data_eval_kwargs,
            **parse_fn_non_train_kwargs
        )

        # Calc hits
        te_tile_preds = calc_tile_preds(test_data, model=model, outdir=outdir)
        te_tile_preds = te_tile_preds.sort_values(["image_id", "tile_id"], ascending=True)
        hits_tn = calc_hits(te_tile_preds, te_meta)
        hits_tn.to_csv(outdir/"hits_tn.csv", index=False)

        # ------------
        # Finetune
        # ------------
        # import ipdb; ipdb.set_trace()
        print_fn("\n{}".format(green("Finetune")))
        unfreeze_top_layers = 50
        # Unfreeze layers of the base model
        for layer in base_img_model.layers[-unfreeze_top_layers:]:
            layer.trainable = True
            print_fn("{}: (trainable={})".format(layer.name, layer.trainable))
        print_fn("Trainable variables: {}".format(len(model.trainable_variables)))

        model.compile(loss=loss,
                      optimizer=optimizers.Adam(learning_rate=params.learning_rate/10),
                      metrics=metrics)

        callbacks = keras_callbacks(outdir, monitor=monitor,
                                    save_best_only=params.save_best_only,
                                    patience=params.patience,
                                    name="finetune")

        total_epochs = history.epoch[-1] + params.finetune_epochs
        timer = Timer()
        history_fn = model.fit(x=train_data,
                               validation_data=val_data,
                               steps_per_epoch=tr_steps,
                               validation_steps=vl_steps,
                               class_weight=class_weight,
                               epochs=total_epochs,
                               initial_epoch=history.epoch[-1]+1,
                               verbose=fit_verbose,
                               callbacks=callbacks)
        del train_data, val_data
        plot_prfrm_metrics(history_fn, title="Finetune stage", name="fn", outdir=outdir)
        timer.display_timer(print_fn)

        # Save trained model
        print_fn("\nSave finetuned model.")
        model.save(outdir/"best_model_finetuned")
        base_img_model.save(outdir/"best_model_img_base_finetuned")


    if args.eval is True:

        print_fn("\n{}".format(bold("Test set predictions.")))
        timer = Timer()
        # calc_tf_preds(test_data, te_meta, model, outdir, args, name="test", print_fn=print_fn)
        # import ipdb; ipdb.set_trace()
        te_tile_preds = calc_tile_preds(test_data, model=model, outdir=outdir)
        te_tile_preds = te_tile_preds.sort_values(["image_id", "tile_id"], ascending=True)
        te_tile_preds.to_csv(outdir/"te_tile_preds.csv", index=False)
        # print(te_tile_preds[["image_id", "tile_id", "y_true", "y_pred_label", "prob"]][:20])
        # print(te_tile_preds.iloc[:20, 1:])
        del test_data

        # Calc hits
        hits_fn = calc_hits(te_tile_preds, te_meta)
        hits_fn.to_csv(outdir/"hits_fn.csv", index=False)

        # from sklearn.metrics import roc_curve, roc_auc_score, auc, average_precision_score
        # roc_auc = roc_auc_score(te_tile_preds["y_true"], te_tile_preds["prob"], average="macro")

        import ipdb; ipdb.set_trace()
        roc_auc = {}
        import matplotlib.pyplot as plt
        from sklearn.metrics import roc_curve, auc
        fig, ax = plt.subplots(figsize=(8, 6))
        for true in range(0, n_classes):
            if true in te_tile_preds["y_true"].values:
                fpr, tpr, thresh = roc_curve(te_tile_preds["y_true"], te_tile_preds["prob"], pos_label=true)
                roc_auc[i] = auc(fpr, tpr)
                plt.plot(fpr, tpr, linestyle='--', label=f"Class {true} vs Rest")
            else:
                roc_auc[i] = None

        # plt.plot([0,0], [1,1], '--', label="Random")
        plt.title("Multiclass ROC Curve")
        plt.xlabel("FPR")
        plt.ylabel("TPR")
        plt.legend(loc="best")
        plt.savefig(outdir/"Multiclass ROC", dpi=70);

        # Average precision score
        from sklearn.metrics import average_precision_score
        y_true_vec = te_tile_preds.y_true.values
        y_true_onehot = np.zeros((y_true_vec.size, n_classes))
        y_true_onehot[np.arange(y_true_vec.size), y_true_vec] = 1
        y_probs = te_tile_preds[[c for c in te_tile_preds.columns if "prob_" in c]]
        print_fn("\nAvearge precision")
        print_fn("Micro    {}".format(average_precision_score(y_true_onehot, y_probs, average="micro")))
        print_fn("Macro    {}".format(average_precision_score(y_true_onehot, y_probs, average="macro")))
        print_fn("Wieghted {}".format(average_precision_score(y_true_onehot, y_probs, average="weighted")))
        print_fn("Samples  {}".format(average_precision_score(y_true_onehot, y_probs, average="samples")))


        import ipdb; ipdb.set_trace()
        agg_method = "mean"
        # agg_by = "smp"
        agg_by = "image_id"
        smp_preds = agg_tile_preds(te_tile_preds, agg_by=agg_by, meta=te_meta, agg_method=agg_method)

        timer.display_timer(print_fn)

    lg.close_logger()
Exemple #16
0
def main():
    """Pretrain a ``TTAForPretraining`` model under the ``mixed_float16`` policy.

    Steps, in order:

    1. Install the global Keras mixed-precision policy and print its compute
       and variable dtypes.
    2. Parse the command-line arguments, load the model configuration from
       ``args.model_config`` and call the model once on symbolic int64 inputs
       before printing its summary.
    3. Load the SentencePiece model stored at ``args.spm_model`` (BOS and EOS
       tokens are added by the tokenizer).
    4. Build the training set (shuffled, repeated indefinitely) and the dev
       set (shuffled, capped at 10000 lines) from text-line files;
       ``args.dev_data`` is a comma-separated list of paths.
    5. Compile with Adam driven by ``LinearWarmupAndDecayScheduler`` and fit,
       writing weight-only checkpoints to ``./models/model-{epoch}``.
    """
    policy = mixed_precision.Policy("mixed_float16")
    mixed_precision.set_global_policy(policy)
    print("Compute dtype: %s" % policy.compute_dtype)
    print("Variable dtype: %s" % policy.variable_dtype)

    args = parser.parse_args()
    tta_config = TTAConfig.from_json(args.model_config)
    model = TTAForPretraining(tta_config)

    # Call the model once on symbolic inputs before asking for the summary.
    symbolic_inputs = {
        feature_name: tf.keras.Input(shape=[None], dtype=tf.int64)
        for feature_name in ("input_word_ids", "input_mask")
    }
    model(symbolic_inputs)
    model.summary()

    with open(args.spm_model, "rb") as spm_file:
        serialized_spm = spm_file.read()
    tokenizer = text.SentencepieceTokenizer(serialized_spm, add_bos=True, add_eos=True)

    def preprocess_and_make_label(strings: tf.Tensor):
        """Tokenize a batch of lines into dense model inputs plus labels.

        Labels are the token ids themselves, padded with -1; the word ids and
        the mask use the default padding value of ``to_tensor``.
        """
        token_ids = tokenizer.tokenize(strings)
        ragged_mask = tf.ragged.map_flat_values(tf.ones_like, token_ids)

        # Every dense output shares the same (batch, max_position_ids) shape.
        dense_shape = [token_ids.shape[0], tta_config.max_position_ids]
        dense_word_ids = token_ids.to_tensor(shape=dense_shape)
        dense_labels = token_ids.to_tensor(shape=dense_shape, default_value=-1)
        dense_mask = ragged_mask.to_tensor(shape=dense_shape)

        features = {
            "input_word_ids": dense_word_ids,
            "input_mask": dense_mask,
        }
        return features, dense_labels

    def batch_and_encode(lines):
        """Batch raw text lines and map them through the tokenizing step."""
        batched = lines.batch(args.batch_size)
        return batched.map(preprocess_and_make_label, num_parallel_calls=tf.data.AUTOTUNE)

    train_lines = tf.data.TextLineDataset(args.train_data, num_parallel_reads=tf.data.AUTOTUNE)
    trainset = batch_and_encode(train_lines.shuffle(100000).repeat())

    dev_lines = tf.data.TextLineDataset(args.dev_data.split(","), num_parallel_reads=tf.data.AUTOTUNE)
    devset = batch_and_encode(dev_lines.shuffle(50000).take(10000))

    print(f"Total step: {args.steps_per_epoch * args.target_epoch}")
    print(f"learning rate: {args.learning_rate}, warmup ratio: {args.warmup_ratio}")

    lr_schedule = LinearWarmupAndDecayScheduler(
        args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        total_steps=args.steps_per_epoch * args.target_epoch,
    )
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
        loss=sparse_categorical_crossentropy_with_ignore,
        metrics=[sparse_categorical_accuracy_with_ignore],
    )

    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        "./models/model-{epoch}",
        save_best_only=True,
        verbose=True,
        save_weights_only=True,
    )
    model.fit(
        trainset,
        validation_data=devset,
        steps_per_epoch=args.steps_per_epoch,
        epochs=args.target_epoch,
        callbacks=[checkpoint_callback],
    )
Exemple #17
0
def main():
    """Train a config-defined Keras model under ``tf.distribute.MirroredStrategy``.

    The configuration file named by ``args.config`` supplies the optimizer,
    model, datasets, pipelines, callbacks and fit settings. A timestamped log
    file is created under ``args.train_url``; when ``args.fp16`` is set the
    global Keras policy is switched to ``mixed_float16`` before anything is
    built.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # time.strftime formats the current local time when no time tuple is given.
    timestamp = time.strftime("%Y%m%d%H%M%S")
    log_path = os.path.join(args.train_url, f"train_{timestamp}.log")
    get_root_logger(log_file=log_path)

    if args.fp16:
        from tensorflow.keras import mixed_precision

        mixed_precision.set_global_policy(mixed_precision.Policy("mixed_float16"))

    train_optimizer = build_tf_optimizers(cfg.dict["optimizer"])
    mirrored_strategy = tf.distribute.MirroredStrategy()

    # The model is built and compiled inside the strategy scope.
    with mirrored_strategy.scope():
        model = build_tf_models(cfg.dict["model"])
        model.compile(optimizer=train_optimizer, )

    train_dataset_obj = build_datasets(cfg.dict["data"]["train"])
    val_dataset_obj = build_datasets(cfg.dict["data"]["val"])

    uses_tfrecords = cfg.dict["dataset_type"] == "TFRecordDataset"
    if uses_tfrecords:
        # The dataset objects are callables here; sample counts come from the config.
        train_dataset = train_dataset_obj()
        val_dataset = val_dataset_obj()
        num_train_samples = cfg.dict["data"]["num_train_samples"]
        num_val_samples = cfg.dict["data"]["num_val_samples"]
    else:
        # Otherwise the objects expose their data as a dict and report their own length.
        train_dataset = train_dataset_obj.get_data_dict()
        val_dataset = val_dataset_obj.get_data_dict()
        num_train_samples = len(train_dataset_obj)
        num_val_samples = len(val_dataset_obj)

    replica_count = mirrored_strategy.num_replicas_in_sync
    _adjust_batchsize(
        cfg,
        replica_count,
        num_train_samples=num_train_samples,
        num_val_samples=num_val_samples,
    )
    _adjust_lr(cfg, num_replicas=mirrored_strategy.num_replicas_in_sync)
    _adjust_callback(cfg, args.train_url)

    # cfg.dict is re-read below on purpose: the _adjust_* helpers receive cfg
    # and may have changed it.
    for stage_cfg in cfg.dict["train_pipeline"]:
        train_stage = build_tf_pipelines(stage_cfg)
        train_dataset = train_stage(train_dataset)

    for stage_cfg in cfg.dict["val_pipeline"]:
        val_stage = build_tf_pipelines(stage_cfg)
        val_dataset = val_stage(val_dataset)

    callback_list = [build_tf_callbacks(callback_cfg) for callback_cfg in cfg.dict["callbacks"]]

    train_cfg = cfg.dict["train_cfg"]
    model.fit(
        x=train_dataset,
        epochs=train_cfg["epochs"],
        steps_per_epoch=train_cfg["steps_per_epoch"],
        validation_data=val_dataset,
        validation_steps=train_cfg["val_steps"],
        callbacks=callback_list,
        verbose=train_cfg["use_keras_progbar"],
    )
Exemple #18
0
    def __init__(self,
                 observation_space,
                 action_space,
                 model_f,
                 m_dir=None,
                 log_name=None,
                 start_step=0,
                 mixed_float=False):
        """Build the agent: models, optimizers, target models, logging and save paths.

        Parameters
        ----------
        observation_space : gym.Space
            Observation space of the environment.
        action_space : gym.Space
            Action space of the environment. Its ``high``, ``low`` and
            ``shape`` attributes are read.
            NOTE(review): an earlier version of this docstring said only a
            discrete action space is supported, but the code reads
            ``high``/``low`` - confirm which kind of space is expected.
        model_f
            A function that returns the actor and critic models (plus a
            triple of ICM models when ``hp.ICM_ENABLE`` is set).
            It should take observation space and action space as inputs.
            It should not compile the model.
        m_dir : str
            A model directory to load the model if there's a model to load.
            When given without ``log_name``, its last path component must be
            an integer (it is parsed as the save counter).
        log_name : str
            A name for log. If not specified, will be set to current time.
            - If m_dir is specified yet no log_name is given, it will continue
            counting.
            - If m_dir and log_name are both specified, it will load model from
            m_dir, but will record as it is the first training.
        start_step : int
            Total step starts from start_step
        mixed_float : bool
            Whether or not to use mixed precision

        Raises
        ------
        AssertionError
            If ``hp.Algorithm`` is not one of ``hp.available_algorithms``.
        """
        # self.models : the models that are actually trained
        # self.t_models : fixed target copies of the actor and critic
        print('Model directory : {}'.format(m_dir))
        print('Log name : {}'.format(log_name))
        print('Starting from step {}'.format(start_step))
        print(f'Use mixed float? {mixed_float}')
        self.action_space = action_space
        # high - low of the action space; presumably element-wise over a
        # Box-like space - TODO confirm.
        self.action_range = action_space.high - action_space.low
        self.action_shape = action_space.shape
        self.observation_space = observation_space
        self.mixed_float = mixed_float
        if mixed_float:
            # Set the global policy before model_f() is called further down.
            policy = mixed_precision.Policy('mixed_float16')
            mixed_precision.set_global_policy(policy)

        assert hp.Algorithm in hp.available_algorithms, "Wrong Algorithm!"

        # Algorithm-specific trainable variables (kept in float32)
        if hp.Algorithm == 'V-MPO':

            self.eta = tf.Variable(1.0,
                                   trainable=True,
                                   name='eta',
                                   dtype='float32')
            self.alpha_mu = tf.Variable(1.0,
                                        trainable=True,
                                        name='alpha_mu',
                                        dtype='float32')
            self.alpha_sig = tf.Variable(1.0,
                                         trainable=True,
                                         name='alpha_sig',
                                         dtype='float32')

        elif hp.Algorithm == 'A2C':
            # Total number of action components (product of the action shape).
            action_num = tf.reduce_prod(self.action_shape)
            # NOTE(review): `(action_num)` is a parenthesised scalar, not a
            # 1-tuple - confirm tf.fill accepts it as `dims` here. The
            # attribute is called log_sigma but the variable is named 'sigma'
            # and starts at 0.1 - confirm whether it stores sigma or log(sigma).
            self.log_sigma = tf.Variable(tf.fill((action_num), 0.1),
                                         trainable=True,
                                         name='sigma',
                                         dtype='float32')

        # Build the models; with ICM enabled model_f also returns the
        # (encoder, inverse, forward) models.
        if hp.ICM_ENABLE:
            actor, critic, icm_models = model_f(observation_space,
                                                action_space)
            encoder, inverse, forward = icm_models
            self.models = {
                'actor': actor,
                'critic': critic,
                'encoder': encoder,
                'inverse': inverse,
                'forward': forward,
            }
        else:
            actor, critic = model_f(observation_space, action_space)
            self.models = {
                'actor': actor,
                'critic': critic,
            }
        # Only these models get a target copy below.
        targets = ['actor', 'critic']

        # Common ADAM optimizer; in V-MPO loss is merged together.
        # The learning rate is a callable: self._lr bound to the 'common' key.
        common_lr = tf.function(partial(self._lr, 'common'))
        self.common_optimizer = keras.optimizers.Adam(
            learning_rate=common_lr,
            epsilon=hp.lr['common'].epsilon,
            global_clipnorm=hp.lr['common'].grad_clip,
        )
        if self.mixed_float:
            # Under mixed precision the optimizer is wrapped for loss scaling.
            self.common_optimizer = mixed_precision.LossScaleOptimizer(
                self.common_optimizer)

        # Each model additionally gets its own Adam optimizer, configured from
        # hp.lr[<model name>], and is compiled with it.
        for name, model in self.models.items():
            lr = tf.function(partial(self._lr, name))
            optimizer = keras.optimizers.Adam(
                learning_rate=lr,
                epsilon=hp.lr[name].epsilon,
                global_clipnorm=hp.lr[name].grad_clip,
            )
            if self.mixed_float:
                optimizer = mixed_precision.LossScaleOptimizer(optimizer)
            model.compile(optimizer=optimizer)
            model.summary()

        # Load model if specified: one weight file per model name under m_dir
        if m_dir is not None:
            for name, model in self.models.items():
                model.load_weights(path.join(m_dir, name))
            print(f'model loaded : {m_dir}')

        # Initialize target models as clones that start with the same weights
        # (done after the optional weight loading above)
        self.t_models = {}
        for name in targets:
            model = self.models[name]
            self.t_models[name] = keras.models.clone_model(model)
            self.t_models[name].set_weights(model.get_weights())

        # File writer for tensorboard, installed as the default summary writer
        if log_name is None:
            self.log_name = datetime.now().strftime('%m_%d_%H_%M_%S')
        else:
            self.log_name = log_name
        self.file_writer = tf.summary.create_file_writer(
            path.join('logs', self.log_name))
        self.file_writer.set_as_default()
        print('Writing logs at logs/' + self.log_name)

        # Scalars
        self.start_training = False
        self.total_steps = tf.Variable(start_step, dtype=tf.int64)

        # Savefile folder directory
        if m_dir is None:
            # Fresh run: save under a new folder named after the log.
            self.save_dir = path.join('savefiles', self.log_name)
            self.save_count = 0
        else:
            if log_name is None:
                # Continue an earlier run: m_dir is split into
                # <save_dir>/<save_count>; int() raises ValueError if the
                # last component is not numeric.
                self.save_dir, self.save_count = path.split(m_dir)
                self.save_count = int(self.save_count)
            else:
                # Weights were loaded from m_dir, but saving starts over
                # under the new log name.
                self.save_dir = path.join('savefiles', self.log_name)
                self.save_count = 0
        self.model_dir = None
Exemple #19
0
def build(model_fn: Callable[[], Union[Model, List[Model]]],
          optimizer_fn: Union[str, Scheduler, Callable, List[str], List[Callable], List[Scheduler], None],
          weights_path: Union[str, None, List[Union[str, None]]] = None,
          model_name: Union[str, List[str], None] = None,
          mixed_precision: bool = False) -> Union[Model, List[Model]]:
    """Instantiate the model(s) from `model_fn` and pair each with an optimizer.

    Usage with TensorFlow models / optimizers:
    ```python
    model_def = fe.architecture.tensorflow.LeNet
    model = fe.build(model_fn = model_def, optimizer_fn="adam")
    model = fe.build(model_fn = model_def, optimizer_fn=lambda: tf.optimizers.Adam(lr=0.1))
    model = fe.build(model_fn = model_def, optimizer_fn="adam", weights_path="~/weights.h5")
    ```

    Usage with PyTorch models / optimizers:
    ```python
    model_def = fe.architecture.pytorch.LeNet
    model = fe.build(model_fn = model_def, optimizer_fn="adam")
    model = fe.build(model_fn = model_def, optimizer_fn=lambda x: torch.optim.Adam(params=x, lr=0.1))
    model = fe.build(model_fn = model_def, optimizer_fn="adam", weights_path="~/weights.pt")
    ```

    Args:
        model_fn: A function that defines the model(s); it is called with no arguments.
        optimizer_fn: Optimizer string/definition, or a list of them. The number of entries must match the number of
            models produced by `model_fn`. When empty or None, every model is paired with None.
        weights_path: Path(s) from which to load model weights. If given, the number of paths must match the number
            of models produced by `model_fn`.
        model_name: Name(s) of the model(s), used for logging. If not given, names are generated from a counter
            stored on this function ("model", "model1", "model2", ...), which persists across calls.
        mixed_precision: Whether to enable mixed-precision network operations.

    Returns:
        The built model when `model_fn` produced exactly one, otherwise the list of built models.

    Raises:
        AssertionError: If the numbers of models, optimizers, weight paths and names disagree.
    """
    # The name counter lives on the function object so it survives between calls.
    if not hasattr(build, "count"):
        build.count = 0

    # TensorFlow needs its global policy in place before any model is created, and the framework of the model is
    # not known yet at this point, so the TensorFlow policy is set unconditionally.
    policy_name = 'mixed_float16' if mixed_precision else 'float32'
    mixed_precision_tf.set_global_policy(mixed_precision_tf.Policy(policy_name))

    # With more than one CUDA device visible, make sure a MirroredStrategy is the active TensorFlow strategy.
    if torch.cuda.device_count() > 1 and not isinstance(tf.distribute.get_strategy(),
                                                        tf.distribute.MirroredStrategy):
        tf.distribute.experimental_set_strategy(tf.distribute.MirroredStrategy())

    models = to_list(model_fn())
    optimizer_fn = to_list(optimizer_fn)
    n_models = len(models)

    # No optimizer supplied: pair every model with None.
    if not optimizer_fn:
        optimizer_fn = [None] * n_models

    # No names supplied: draw consecutive names from the persistent counter.
    if not model_name:
        offset = build.count
        model_name = ["model" if offset + i == 0 else "model{}".format(offset + i) for i in range(n_models)]
        build.count = offset + n_models
    model_name = to_list(model_name)

    # No weights supplied: nothing is loaded for any model.
    weights_path = to_list(weights_path) if weights_path else [None] * n_models

    assert len(models) == len(optimizer_fn) == len(weights_path) == len(model_name), \
        "Found inconsistency in number of models, optimizers, model_name or weights"

    # Compile each model and record how it was built; entries are replaced in place.
    several_models = n_models > 1
    for idx, (model, optimizer_def, weight, name) in enumerate(zip(models, optimizer_fn, weights_path, model_name)):
        compiled = _fe_compile(model, optimizer_def, weight, name, mixed_precision)
        models[idx] = trace_model(compiled,
                                  model_idx=idx if several_models else -1,
                                  model_fn=model_fn,
                                  optimizer_fn=optimizer_def,
                                  weights_path=weight)

    return models[0] if n_models == 1 else models
Exemple #20
0
def run_training(
    model_f,
    lr_f,
    name,
    epochs,
    batch_size,
    steps_per_epoch,
    train_vid_paths,
    val_vid_paths,
    test_vid_paths,
    frame_size,
    interpolate_ratios,
    mixed_float=True,
    notebook=True,
    profile=False,
    load_model_path=None,
):
    """Train an ``AnimeModel`` with Keras and print the elapsed wall-clock time.

    The model is compiled with the 'adam' optimizer and a mean-absolute-error
    loss. Weights are checkpointed to ``savedmodels/<name>/<epoch>`` and
    TensorBoard logs go to ``logs/fit/<name>``.

    Parameters
    ----------
    model_f, interpolate_ratios
        Passed through to ``AnimeModel`` together with the symbolic input.
    lr_f
        Schedule function handed to ``keras.callbacks.LearningRateScheduler``.
    name : str
        Run name used in the log and checkpoint paths.
    epochs, steps_per_epoch
        Forwarded to ``fit``; ``steps_per_epoch`` is also the TensorBoard
        update frequency.
    batch_size
        Batch size used for both the training and validation datasets.
    train_vid_paths, val_vid_paths
        Inputs for ``create_train_dataset``.
    test_vid_paths
        Not used by this function.
    frame_size
        Indexable pair; the model input is ``(frame_size[1], frame_size[0], 6)``.
    mixed_float : bool
        If true, set the global Keras policy to ``mixed_float16``.
    notebook : bool
        Selects ``TqdmNotebookCallback`` over ``TqdmCallback``.
    profile : bool
        If true, TensorBoard profiles batches '3,5'; otherwise profiling is off.
    load_model_path
        Optional weights file loaded before compiling.
    """
    if mixed_float:
        mixed_precision.set_global_policy(mixed_precision.Policy('mixed_float16'))

    st = time.time()

    # frame_size is indexed as [0], [1]; the input shape uses them in swapped order.
    input_shape = (frame_size[1], frame_size[0], 6)
    mymodel = AnimeModel(keras.Input(input_shape), model_f, interpolate_ratios)
    if load_model_path:
        mymodel.load_weights(load_model_path)
        print(f'Loaded from : {load_model_path}')
    mymodel.compile(
        optimizer='adam',
        loss=keras.losses.MeanAbsoluteError(),
    )

    # The two TensorBoard configurations differ only in the profiled batches.
    logdir = 'logs/fit/' + name
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=logdir,
        histogram_freq=1,
        profile_batch='3,5' if profile else 0,
        update_freq=steps_per_epoch)

    lr_callback = keras.callbacks.LearningRateScheduler(lr_f, verbose=1)

    # '{epoch}' is left as a placeholder for ModelCheckpoint to fill in.
    savedir = 'savedmodels/' + name + '/{epoch}'
    save_callback = keras.callbacks.ModelCheckpoint(savedir,
                                                    save_weights_only=True,
                                                    verbose=1)

    tqdm_callback = (TqdmNotebookCallback(metrics=['loss'], leave_inner=False)
                     if notebook else TqdmCallback())

    train_ds = create_train_dataset(train_vid_paths, frame_size, batch_size)
    # NOTE(review): the meaning of the fourth positional argument (True) is
    # defined by create_train_dataset and is not visible here.
    val_ds = create_train_dataset(val_vid_paths, frame_size, batch_size, True)

    image_callback = ValFigCallback(val_ds, logdir)

    all_callbacks = [
        tensorboard_callback,
        lr_callback,
        save_callback,
        tqdm_callback,
        image_callback,
    ]
    # Keras' own progress output is silenced (verbose=0).
    mymodel.fit(
        x=train_ds,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        callbacks=all_callbacks,
        verbose=0,
        validation_data=val_ds,
        validation_steps=10,
    )

    hours, leftover = divmod(time.time() - st, 3600)
    minutes, seconds = divmod(leftover, 60)
    print(
        f'Took {hours:.0f} hours {minutes:.0f} minutes {seconds:.2f} seconds')
Exemple #21
0
def train(config, weights, ntrain, ntest, nepochs, recreate, prefix, plot_freq,
          customize):
    """Train a model defined by config.

    Args:
        config: Path to the configuration file. It is parsed with
            ``parse_config`` and copied to ``<outdir>/config.yaml``.
        weights: Optional path to a weights file to start from. The starting
            epoch is read from the second ``-``-separated field of the file
            name. When given (and ``recreate`` is falsy) the parent directory
            of this file is reused as the output directory.
        ntrain: If truthy, the training dataset is truncated with
            ``take(ntrain)`` and ``ntrain`` is used as steps per epoch.
        ntest: If truthy, the test dataset is truncated with ``take(ntest)``
            and ``ntest`` is used as the number of validation steps.
        nepochs: Forwarded to ``parse_config``.
        recreate: If truthy, always create a fresh experiment directory.
        prefix: Prefix of the experiment directory name.
        plot_freq: If truthy, overrides ``config["callbacks"]["plot_freq"]``.
        customize: Optional key into ``customization_functions``; the selected
            function is applied to the parsed config.

    Side effects:
        Writes the config copy, ``history/history.json`` and the final model
        (``model_full``) below the output directory.
    """
    # comet-ml is optional: any failure while importing or creating the
    # experiment only disables the dashboard, it does not abort the training.
    try:
        from comet_ml import Experiment
        experiment = Experiment(
            project_name="particleflow-tf",
            auto_metric_logging=True,
            auto_param_logging=True,
            auto_histogram_weight_logging=True,
            auto_histogram_gradient_logging=False,
            auto_histogram_activation_logging=False,
        )
    except Exception:
        print("Failed to initialize comet-ml dashboard")
        experiment = None

    # Keep the original path: `config` is rebound to the parsed config below.
    config_file_path = config
    config, config_file_stem = parse_config(config,
                                            nepochs=nepochs,
                                            weights=weights)

    if plot_freq:
        config["callbacks"]["plot_freq"] = plot_freq

    if customize:
        config = customization_functions[customize](config)

    if recreate or (weights is None):
        outdir = create_experiment_dir(prefix=prefix + config_file_stem + "_",
                                       suffix=platform.node())
    else:
        # Continue in the directory the weights were loaded from.
        outdir = str(Path(weights).parent)

    # Decide tf.distribute.strategy depending on number of available GPUs
    strategy, num_gpus = get_strategy()
    #if "CPU" not in strategy.extended.worker_devices[0]:
    #    nvidia_smi_call = "nvidia-smi --query-gpu=timestamp,name,pci.bus_id,pstate,power.draw,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used --format=csv -l 1 -f {}/nvidia_smi_log.csv".format(outdir)
    #    p = subprocess.Popen(shlex.split(nvidia_smi_call))

    ds_train, num_train_steps = get_datasets(config["train_test_datasets"],
                                             config, num_gpus, "train")
    ds_test, num_test_steps = get_datasets(config["train_test_datasets"],
                                           config, num_gpus, "test")
    ds_val, ds_info = get_heptfds_dataset(
        config["validation_dataset"], config, num_gpus, "test",
        config["setup"]["num_events_validation"])
    ds_val = ds_val.batch(5)

    # Optional truncation of the datasets, e.g. for quick test runs.
    if ntrain:
        ds_train = ds_train.take(ntrain)
        num_train_steps = ntrain
    if ntest:
        ds_test = ds_test.take(ntest)
        num_test_steps = ntest

    print("num_train_steps", num_train_steps)
    print("num_test_steps", num_test_steps)
    total_steps = num_train_steps * config["setup"]["num_epochs"]
    print("total_steps", total_steps)

    if experiment:
        experiment.set_name(outdir)
        experiment.log_code("mlpf/tfmodel/model.py")
        experiment.log_code("mlpf/tfmodel/utils.py")
        experiment.log_code(config_file_path)

    # Copy the config file to the train dir for later reference
    shutil.copy(config_file_path, outdir + "/config.yaml")

    with strategy.scope():
        lr_schedule, optim_callbacks = get_lr_schedule(config,
                                                       steps=total_steps)
        opt = get_optimizer(config, lr_schedule)

        if config["setup"]["dtype"] == "float16":
            model_dtype = tf.dtypes.float16
            policy = mixed_precision.Policy("mixed_float16")
            mixed_precision.set_global_policy(policy)
            # Wrap the optimizer for loss scaling under the float16 policy.
            opt = mixed_precision.LossScaleOptimizer(opt)
        else:
            model_dtype = tf.dtypes.float32

        model = make_model(config, model_dtype)

        # Build the layers after the element and feature dimensions are specified
        model.build((1, config["dataset"]["padded_num_elem_size"],
                     config["dataset"]["num_input_features"]))

        initial_epoch = 0
        if weights:
            # We need to load the weights in the same trainable configuration as the model was set up
            configure_model_weights(
                model, config["setup"].get("weights_config", "all"))
            model.load_weights(weights, by_name=True)
            # File name is expected to look like "<prefix>-<epoch>-...".
            initial_epoch = int(weights.split("/")[-1].split("-")[1])
        # NOTE(review): the repeated build() calls are kept from the original;
        # presumably they refresh the model after the trainable flags change.
        model.build((1, config["dataset"]["padded_num_elem_size"],
                     config["dataset"]["num_input_features"]))

        config = set_config_loss(config, config["setup"]["trainable"])
        configure_model_weights(model, config["setup"]["trainable"])
        model.build((1, config["dataset"]["padded_num_elem_size"],
                     config["dataset"]["num_input_features"]))

        print("model weights")
        tw_names = [m.name for m in model.trainable_weights]
        for w in model.weights:
            print("layer={} trainable={} shape={} num_weights={}".format(
                w.name, w.name in tw_names, w.shape, np.prod(w.shape)))

        loss_dict, loss_weights = get_loss_dict(config)
        model.compile(
            loss=loss_dict,
            optimizer=opt,
            sample_weight_mode="temporal",
            loss_weights=loss_weights,
            metrics={
                "cls": [
                    FlattenedCategoricalAccuracy(name="acc_unweighted",
                                                 dtype=tf.float64),
                    FlattenedCategoricalAccuracy(use_weights=True,
                                                 name="acc_weighted",
                                                 dtype=tf.float64),
                ] + [
                    SingleClassRecall(
                        icls, name="rec_cls{}".format(icls), dtype=tf.float64)
                    for icls in range(config["dataset"]["num_output_classes"])
                ]
            },
        )
        model.summary()

    callbacks = prepare_callbacks(config["callbacks"],
                                  outdir,
                                  ds_val,
                                  ds_info,
                                  comet_experiment=experiment)
    callbacks.append(optim_callbacks)

    # Both datasets are repeated, so the step counts bound every epoch.
    fit_result = model.fit(
        ds_train.repeat(),
        validation_data=ds_test.repeat(),
        epochs=initial_epoch + config["setup"]["num_epochs"],
        callbacks=callbacks,
        steps_per_epoch=num_train_steps,
        validation_steps=num_test_steps,
        initial_epoch=initial_epoch,
    )

    # NOTE(review): assumes the "history" directory already exists (presumably
    # created by prepare_callbacks) - confirm.
    history_file = Path(outdir) / "history" / "history.json"
    with open(history_file, "w") as fi:
        json.dump(fit_result.history, fi)

    weights = get_best_checkpoint(outdir)
    print("Loading best weights that could be found from {}".format(weights))
    model.load_weights(weights, by_name=True)

    model.save(outdir + "/model_full", save_format="tf")

    print("Training done.")
Exemple #22
0
def model_scope(config, total_steps, weights, horovod_enabled=False):
    """Create and compile the model, optionally restoring a previous state.

    Args:
        config: Parsed configuration dict; the ``"setup"`` and ``"dataset"``
            sections are read here.
        total_steps: Total number of training steps, forwarded to
            ``get_lr_schedule``.
        weights: Optional path to a weights file. When given, the model
            weights are loaded from it, the starting epoch is read from the
            second ``-``-separated field of the file name, and a matching
            optimizer pickle (``hdf5`` -> ``pkl``, ``/weights-`` -> ``/opt-``)
            is restored if that file exists.
        horovod_enabled: Accepted for interface compatibility; not used in
            this function.

    Returns:
        Tuple ``(model, optim_callbacks, initial_epoch)``.

    Raises:
        Exception: If ``weights`` is given while a learning rate schedule is
            configured.
    """
    # The third return value is not needed here.
    lr_schedule, optim_callbacks, _ = get_lr_schedule(config,
                                                      steps=total_steps)
    opt = get_optimizer(config, lr_schedule)

    if config["setup"]["dtype"] == "float16":
        model_dtype = tf.dtypes.float16
        policy = mixed_precision.Policy("mixed_float16")
        mixed_precision.set_global_policy(policy)
        opt = mixed_precision.LossScaleOptimizer(opt)
    else:
        model_dtype = tf.dtypes.float32

    model = make_model(config, model_dtype)

    # Build the layers after the element and feature dimensions are specified
    model.build((1, config["dataset"]["padded_num_elem_size"],
                 config["dataset"]["num_input_features"]))

    initial_epoch = 0
    loaded_opt = None

    if weights:
        if lr_schedule:
            raise Exception(
                "Restoring the optimizer state with a learning rate schedule is currently not supported"
            )

        # We need to load the weights in the same trainable configuration as the model was set up
        configure_model_weights(model,
                                config["setup"].get("weights_config", "all"))
        model.load_weights(weights, by_name=True)
        opt_weight_file = weights.replace("hdf5",
                                          "pkl").replace("/weights-", "/opt-")
        if os.path.isfile(opt_weight_file):
            # NOTE(security): unpickling can execute arbitrary code; only
            # restore optimizer files that come from a trusted source.
            # The context manager closes the file handle deterministically.
            with open(opt_weight_file, "rb") as opt_file:
                loaded_opt = pickle.load(opt_file)

        initial_epoch = int(weights.split("/")[-1].split("-")[1])
    # NOTE(review): the repeated build() calls are kept from the original;
    # presumably they refresh the model after the trainable flags change.
    model.build((1, config["dataset"]["padded_num_elem_size"],
                 config["dataset"]["num_input_features"]))

    config = set_config_loss(config, config["setup"]["trainable"])
    configure_model_weights(model, config["setup"]["trainable"])
    model.build((1, config["dataset"]["padded_num_elem_size"],
                 config["dataset"]["num_input_features"]))

    print("model weights")
    tw_names = [m.name for m in model.trainable_weights]
    for w in model.weights:
        print("layer={} trainable={} shape={} num_weights={}".format(
            w.name, w.name in tw_names, w.shape, np.prod(w.shape)))

    loss_dict, loss_weights = get_loss_dict(config)

    model.compile(
        loss=loss_dict,
        optimizer=opt,
        sample_weight_mode="temporal",
        loss_weights=loss_weights,
        metrics={
            "cls": [
                FlattenedCategoricalAccuracy(name="acc_unweighted",
                                             dtype=tf.float64),
                FlattenedCategoricalAccuracy(
                    use_weights=True, name="acc_weighted", dtype=tf.float64),
            ] + [
                SingleClassRecall(
                    icls, name="rec_cls{}".format(icls), dtype=tf.float64)
                for icls in range(config["dataset"]["num_output_classes"])
            ]
        },
    )

    model.summary()

    # Set the optimizer weights
    if loaded_opt:

        def model_weight_setting():
            # Apply all-zero gradients once before calling set_weights;
            # presumably this makes the optimizer create its variables first.
            grad_vars = model.trainable_weights
            zero_grads = [tf.zeros_like(w) for w in grad_vars]
            model.optimizer.apply_gradients(zip(zero_grads, grad_vars))
            if model.optimizer.__class__.__module__ == "keras.optimizers.optimizer_v1":
                model.optimizer.optimizer.optimizer.set_weights(
                    loaded_opt["weights"])
            else:
                model.optimizer.set_weights(loaded_opt["weights"])

        # FIXME: check that this still works with multiple GPUs
        strategy = tf.distribute.get_strategy()
        strategy.run(model_weight_setting)

    return model, optim_callbacks, initial_epoch
def run_training(
    backbone_f,
    lr_f,
    name,
    epochs,
    steps_per_epoch,
    # batch_size,
    intermediate_filters,
    kernel_size,
    stride,
    rfcn_window,
    anchor_ratios,
    anchor_scales,
    class_names,
    bbox_sizes,
    train_dir,
    val_dir,
    img_size,
    frozen_layers=None,
    mixed_float=True,
    notebook=True,
    load_model_path=None,
    profile=False,
):
    """Build an ``ObjectDetector``, optionally freeze parts of it, and fit it.

    Args:
        backbone_f, intermediate_filters, kernel_size, stride, rfcn_window,
        anchor_ratios, anchor_scales: Forwarded to ``ObjectDetector``.
        lr_f: Schedule function given to ``LearningRateScheduler``.
        name: Run name; used for ``logs/fit/<name>`` and
            ``savedmodels/<name>/{epoch}``.
        epochs: Number of epochs passed to ``fit``.
        steps_per_epoch: Steps per epoch passed to ``fit``.
        class_names: Class names, or None. The model gets
            ``len(class_names) + 1`` classes, or 1 class when None.
        bbox_sizes: Forwarded to ``create_train_dataset``.
        train_dir: Training data location for ``create_train_dataset``.
        val_dir: Validation data location; the resulting dataset is only
            handed to ``ValFigCallback``.
        img_size:
            (WIDTH, HEIGHT)
        frozen_layers: Collection with zero or more of 'rfcn', 'rfcn_cls',
            'rfcn_reg', 'rpn', 'rpn_inter', 'rpn_cls', 'rpn_reg', 'backbone'.
            'rfcn' / 'rpn' freeze the whole group and take precedence over the
            per-layer keys. None (the default) freezes nothing.
        mixed_float: If truthy, set the global policy to 'mixed_float16'.
        notebook: Selects the notebook or the plain tqdm progress callback.
        load_model_path: Optional weights path to load before training.
        profile: If truthy, TensorBoard profiles batches 7 to 9.
    """
    # Avoid a shared mutable default; None means "freeze nothing".
    if frozen_layers is None:
        frozen_layers = ()

    if mixed_float:
        policy = mixed_precision.Policy('mixed_float16')
        mixed_precision.set_global_policy(policy)

    st = time.time()

    # NOTE(review): `inputs` is not used afterwards; the call is kept because
    # removing it could shift Keras' auto-generated layer names - confirm.
    inputs = keras.Input((img_size[0], img_size[1], 3))
    if class_names is None:
        num_classes = 1
    else:
        num_classes = len(class_names) + 1
    mymodel = ObjectDetector(backbone_f, intermediate_filters, kernel_size,
                             stride, img_size, num_classes, rfcn_window,
                             anchor_ratios, anchor_scales)

    if load_model_path:
        mymodel.load_weights(load_model_path).expect_partial()
        print('loaded from : ' + load_model_path)

    #-------------------- Freeze R-FCN
    if 'rfcn' in frozen_layers:
        print('####################Freezing rfcn layers')
        mymodel.rfcn_cls_conv.trainable = False
        mymodel.rfcn_reg_conv.trainable = False
    else:
        if 'rfcn_cls' in frozen_layers:
            print('####################Freezing rfcn_cls layers')
            mymodel.rfcn_cls_conv.trainable = False
        if 'rfcn_reg' in frozen_layers:
            print('####################Freezing rfcn_reg layers')
            mymodel.rfcn_reg_conv.trainable = False
    #--------------------Freeze RPN
    if 'rpn' in frozen_layers:
        print('####################Freezing rpn layers')
        mymodel.rpn_inter_conv.trainable = False
        mymodel.rpn_cls_conv.trainable = False
        mymodel.rpn_reg_conv.trainable = False
    else:
        if 'rpn_inter' in frozen_layers:
            print('####################Freezing rpn_inter layers')
            mymodel.rpn_inter_conv.trainable = False
        if 'rpn_cls' in frozen_layers:
            print('####################Freezing rpn_cls layers')
            mymodel.rpn_cls_conv.trainable = False
        if 'rpn_reg' in frozen_layers:
            print('####################Freezing rpn_reg layers')
            mymodel.rpn_reg_conv.trainable = False
    #---------------------Freeze Backbone
    if 'backbone' in frozen_layers:
        print('####################Freezing backbone layers')
        mymodel.backbone_model.trainable = False

    # No loss is passed to compile here.
    mymodel.compile(optimizer='adam', )

    logdir = 'logs/fit/' + name
    # profile_batch=0 disables profiling; '7,9' profiles batches 7 to 9.
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=logdir,
        histogram_freq=1,
        profile_batch='7,9' if profile else 0,
        update_freq='epoch')
    lr_callback = keras.callbacks.LearningRateScheduler(lr_f, verbose=1)

    # '{epoch}' is left as a placeholder for ModelCheckpoint to fill in.
    savedir = 'savedmodels/' + name + '/{epoch}'
    save_callback = keras.callbacks.ModelCheckpoint(savedir,
                                                    save_weights_only=True,
                                                    verbose=1)

    if notebook:
        tqdm_callback = TqdmNotebookCallback(metrics=['loss'],
                                             leave_inner=False)
    else:
        tqdm_callback = TqdmCallback()

    train_ds = create_train_dataset(
        train_dir,
        img_size,
        class_names,
        bbox_sizes,
        buffer_size=1000,
    )
    val_ds = create_train_dataset(
        val_dir,
        img_size,
        class_names,
        bbox_sizes,
        buffer_size=100,
        val_data=True,
    )

    image_callback = ValFigCallback(val_ds, logdir)

    # verbose=0 because progress is reported through the tqdm callback.
    mymodel.fit(
        x=train_ds,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        # steps_per_epoch=10,
        callbacks=[
            tensorboard_callback,
            lr_callback,
            save_callback,
            tqdm_callback,
            image_callback,
        ],
        verbose=0,
        # validation_data=val_ds,
        # validation_steps=100,
    )

    # Report the finish time and the elapsed wall-clock time.
    delta = time.time() - st
    hours, remain = divmod(delta, 3600)
    minutes, seconds = divmod(remain, 60)
    print(datetime.now().strftime('%Y/%m/%d %H:%M:%S'))
    print(
        f'Took {hours:.0f} hours {minutes:.0f} minutes {seconds:.2f} seconds')
Exemple #24
0
# Script-level setup: report the TensorFlow version, select the GPUs and apply
# the optional distribution / precision / XLA settings from `config`.
print(f"Tensorflow ver. {tf.__version__}")

# verify GPU devices are available and ready
# The environment variable is set before the devices are listed, so the listing
# reflects the selection made in config.CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA
devices = tf.config.list_physical_devices('GPU')
# NOTE(review): `assert` is stripped when Python runs with -O, in which case
# this check is skipped.
assert len(devices) != 0, "No GPU devices found."

# ------------------------------------------------------------------
# System Configurations
# ------------------------------------------------------------------
# NOTE(review): `strategy` is only bound when config.MIRROR_STRATEGY is truthy;
# confirm that later code does not use it otherwise.
if config.MIRROR_STRATEGY:
    strategy = tf.distribute.MirroredStrategy()
    print('Multi-GPU enabled')

# Set the global Keras dtype policy to mixed float16.
if config.MIXED_PRECISION:
    policy = mixed_precision.Policy('mixed_float16')
    mixed_precision.set_global_policy(policy)
    print('Mixed precision enabled')

# Turn on XLA JIT compilation through the TensorFlow optimizer config.
if config.XLA_ACCELERATE:
    tf.config.optimizer.set_jit(True)
    print('Accelerated Linear Algebra enabled')

# Disable AutoShard, data lives in memory, use in memory options
# `options` is only created here; it is not applied to any dataset in this part.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = \
    tf.data.experimental.AutoShardPolicy.OFF


# ---------------------------------------------------------------------------
# script train.py
def use_mixed_precision():
    """Set the global Keras mixed-precision policy to 'mixed_float16'."""
    # Imported locally, so the name is only bound inside this function.
    from tensorflow.keras import mixed_precision
    policy = mixed_precision.Policy('mixed_float16')
    mixed_precision.set_global_policy(policy)