def train_model(
        train_generator: Iterable[Tuple[np.ndarray, np.ndarray]],
        model: tf.keras.Model,
        optimizer: tf.keras.optimizers.Optimizer,
        loss: tf.keras.losses.Loss,
        epochs: int,
        val_generator: Iterable[Tuple[np.ndarray, np.ndarray]],
        metrics: List[tf.keras.metrics.Metric],
        callbacks: List[tf.keras.callbacks.Callback],
        path_to_save_results: str,
        class_weights: Optional[Dict[int, float]] = None,
        loss_weights: Optional[Dict[str, float]] = None) -> tf.keras.Model:
    """Compile ``model`` and fit it on ``train_generator``.

    Args:
        train_generator: Iterable of ``(inputs, targets)`` batches for training.
        model: Keras model to compile and train (modified in place).
        optimizer: Optimizer passed to ``model.compile``.
        loss: Loss passed to ``model.compile``.
        epochs: Number of training epochs.
        val_generator: Iterable of ``(inputs, targets)`` batches for validation.
        metrics: Metrics passed to ``model.compile``.
        callbacks: Callbacks passed to ``model.fit``.
        path_to_save_results: Directory created here if missing. This function
            only creates it; presumably the callbacks write into it (confirm).
        class_weights: Optional class-index -> weight mapping for ``fit``.
        loss_weights: Optional output-name -> weight mapping for ``compile``.

    Returns:
        The same ``model`` instance, trained.
    """
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(path_to_save_results, exist_ok=True)

    model.compile(optimizer=optimizer,
                  loss=loss,
                  metrics=metrics,
                  loss_weights=loss_weights)
    model.fit(train_generator,
              epochs=epochs,
              callbacks=callbacks,
              validation_data=val_generator,
              verbose=2,
              class_weight=class_weights)
    return model
Ejemplo n.º 2
0
def compile_bert_classifier(
    model: tf.keras.Model,
    loss: tf.keras.losses.Loss = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True),
    learning_rate: float = 2e-5,
    metrics: List[Union[Text, tf.keras.metrics.Metric]] = None):
  """Compile the BERT classifier using suggested parameters.

  Args:
    model: A keras model. Most likely the output of build_bert_classifier.
    loss: A tf.keras.losses.Loss. The default loss expects integer labels
      (e.g. 0, 1, 2) and logits as model outputs. If the labels are one-hot
      encoded, consider using tf.keras.losses.CategoricalCrossentropy with
      from_logits set to True. The default instance is created once, when this
      function is defined, and is shared by every call that relies on it.
    learning_rate: Learning rate for tf.keras.optimizers.Adam. The three
      suggested learning rates for fine-tuning are [2e-5, 3e-5, 5e-5].
    metrics: Default None will use ['sparse_categorical_accuracy']. A list of
      metric names (strings) or tf.keras.metrics.Metric objects.

  Returns:
    None. The model is compiled in place.
  """
  if metrics is None:
    metrics = ["sparse_categorical_accuracy"]

  model.compile(
      optimizer=tf.keras.optimizers.Adam(learning_rate),
      loss=loss,
      metrics=metrics)
Ejemplo n.º 3
0
def train(model_rs: tf.keras.Model, model_kge: tf.keras.Model, train_data: List[Tuple[int, int, int]],
          test_data: List[Tuple[int, int, int]], kg: List[Tuple[int, int, int]], topk_data: TopkData,
          optimizer_rs=None, optimizer_kge=None, kge_interval=3, epochs=100, batch=512):
    """Alternately train a recommender model and a knowledge-graph embedding model.

    Every outer epoch trains ``model_rs`` for one epoch on the interaction data.
    On epochs whose index is a multiple of ``kge_interval``, ``model_kge`` also
    trains for one epoch on the knowledge-graph triples.

    Args:
        model_rs: recommender model fed ``user_id``, ``item_id`` and ``head_id``
            (``head_id`` is the item column). Compiled with binary cross-entropy.
        model_kge: KG model fed ``item_id``, ``head_id`` (both the first triple
            element), ``relation_id`` and ``tail_id``. Its loss returns the model
            output itself, so the output is used directly as the loss value.
        train_data: ``(user, item, label)`` rows used for training.
        test_data: ``(user, item, label)`` rows used for validation.
        kg: ``(head, relation, tail)`` triples.
        topk_data: passed to ``RsCallback`` together with a score function.
        optimizer_rs: optimizer for ``model_rs``; Adam when None.
        optimizer_kge: optimizer for ``model_kge``; Adam when None.
        kge_interval: run the KG epoch whenever ``epoch % kge_interval == 0``.
        epochs: number of outer epochs.
        batch: batch size for all three datasets.
    """
    rs_optimizer = tf.keras.optimizers.Adam() if optimizer_rs is None else optimizer_rs
    kge_optimizer = tf.keras.optimizers.Adam() if optimizer_kge is None else optimizer_kge

    def int_column(rows, index):
        return tf.constant([row[index] for row in rows], dtype=tf.int32)

    def rs_inputs(rows):
        features = {
            'user_id': int_column(rows, 0),
            'item_id': int_column(rows, 1),
            'head_id': int_column(rows, 1),
        }
        labels = tf.constant([row[2] for row in rows], dtype=tf.float32)
        return features, labels

    def kge_inputs(triples):
        features = {
            'item_id': int_column(triples, 0),
            'head_id': int_column(triples, 0),
            'relation_id': int_column(triples, 1),
            'tail_id': int_column(triples, 2),
        }
        # Placeholder targets: the KG loss ignores y_true.
        placeholder = tf.constant([0] * len(triples), dtype=tf.float32)
        return features, placeholder

    slices = tf.data.Dataset.from_tensor_slices
    train_ds = slices(rs_inputs(train_data)).shuffle(len(train_data)).batch(batch)
    test_ds = slices(rs_inputs(test_data)).batch(batch)
    kg_ds = slices(kge_inputs(kg)).shuffle(len(kg)).batch(batch)

    model_rs.compile(optimizer=rs_optimizer, loss='binary_crossentropy', metrics=['AUC', 'Precision', 'Recall'])
    model_kge.compile(optimizer=kge_optimizer, loss=lambda _y_true, y_pred: y_pred)

    for epoch in range(epochs):
        # One epoch per fit call; initial_epoch keeps the epoch counter continuous.
        rs_callbacks = [RsCallback(topk_data, _get_score_fn(model_rs))]
        model_rs.fit(train_ds, epochs=epoch + 1, initial_epoch=epoch, verbose=0,
                     validation_data=test_ds, callbacks=rs_callbacks)
        if epoch % kge_interval != 0:
            continue
        model_kge.fit(kg_ds, epochs=epoch + 1, initial_epoch=epoch, verbose=0,
                      callbacks=[_KgeCallback()])
Ejemplo n.º 4
0
    def train_model(model: tf.keras.Model,
                    train_data,
                    validation_data,
                    optimizer=None,
                    loss='categorical_crossentropy',
                    epochs=3,
                    verbose=1,
                    batch_size=None,
                    callbacks=None):
        """Compile ``model`` and fit it, returning the Keras ``History``.

        Args:
            model: Keras model to compile and train (modified in place).
            train_data: first positional argument of ``model.fit``.
            validation_data: forwarded to ``model.fit`` as ``validation_data``.
            optimizer: optimizer for ``compile``. When None (now also the
                default), ``tf.keras.optimizers.Adam()`` is used.
            loss: loss passed to ``compile``.
            epochs: forwarded to ``model.fit``.
            verbose: forwarded to ``model.fit``.
            batch_size: forwarded to ``model.fit``.
            callbacks: forwarded to ``model.fit``.

        Returns:
            The ``History`` object returned by ``model.fit``.
        """
        # Reset Keras global state and fix TF/NumPy seeds (presumably for
        # reproducible runs).
        # NOTE(review): clear_session() runs after the caller has already built
        # ``model``; confirm this ordering is intended.
        K.clear_session()
        tf.random.set_seed(51)
        np.random.seed(51)

        opt = tf.keras.optimizers.Adam() if optimizer is None else optimizer

        model.compile(opt, loss=loss, metrics=["acc"])

        history = model.fit(train_data,
                            validation_data=validation_data,
                            epochs=epochs,
                            verbose=verbose,
                            callbacks=callbacks,
                            batch_size=batch_size)
        return history
Ejemplo n.º 5
0
def train_model(model: tf.keras.Model,
                tr_dataset: tf.data.Dataset,
                val_dataset: tf.data.Dataset,
                checkpoint_dir: str,
                epochs: int = 100,
                patience: int = 10,
                save_one: bool = False):
    """Train ``model`` with Adam and early stopping, checkpointing weights.

    Args:
        model: Keras model producing logits; compiled here with a sparse
            categorical cross-entropy loss on logits.
        tr_dataset: training dataset.
        val_dataset: validation dataset (``val_loss`` drives early stopping).
        checkpoint_dir: directory for weight checkpoints.
        epochs: maximum number of epochs.
        patience: epochs without ``val_loss`` improvement before stopping.
        save_one: if True, keep a single ``ckpt`` file holding the best
            weights; otherwise write ``ckpt_{epoch}`` every epoch.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    def loss(labels, logits):
        return tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits, from_logits=True)

    optimizer = tf.keras.optimizers.Adam()
    model.compile(optimizer=optimizer, loss=loss)
    early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                  patience=patience)

    logger.info("Begin training... (this will take a while)")
    checkpoint_prefix = os.path.join(checkpoint_dir,
                                     "ckpt" if save_one else "ckpt_{epoch}")
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_prefix,
        save_best_only=save_one,
        save_weights_only=True)
    history = model.fit(tr_dataset,
                        epochs=epochs,
                        callbacks=[checkpoint_callback, early_stop],
                        validation_data=val_dataset)
    # stopped_epoch stays 0 unless EarlyStopping actually ended training, so
    # only report "no improvement" when that really happened.
    if early_stop.stopped_epoch > 0:
        logger.info(
            "Training stopped, no improvement after {} epochs".format(patience))
    else:
        logger.info(
            "Training finished after {} epochs".format(len(history.epoch)))
    return history
Ejemplo n.º 6
0
  def compile_model(self,
                    model: tf.keras.Model,
                    optimizer: tf.keras.optimizers.Optimizer,
                    loss=None,
                    train_step: Optional[Callable[..., Any]] = None,
                    validation_step: Optional[Callable[..., Any]] = None,
                    **kwargs) -> tf.keras.Model:
    """Compile ``model`` with task-provided objects and return it.

    Not intended for customized training implementations. Exactly one of
    ``loss`` and ``train_step`` must be given.

    Args:
      model: the keras.Model to compile.
      optimizer: the keras optimizer.
      loss: a callable or list of losses (mutually exclusive with train_step).
      train_step: optional task train-step function. If given, it replaces
        ``model.train_step``, bound to ``model`` and ``model.optimizer``.
      validation_step: optional task validation-step function. If given, it
        replaces ``model.test_step``, bound to ``model``.
      **kwargs: further arguments forwarded to ``keras.Model.compile()``.

    Returns:
      The same model, compiled (and possibly with overridden steps).

    Raises:
      ValueError: if both or neither of ``loss`` and ``train_step`` are given.
    """
    loss_given = loss is not None
    step_given = train_step is not None
    if loss_given == step_given:
      raise ValueError("`loss` and `train_step` should be exclusive to "
                       "each other.")

    model.compile(optimizer=optimizer, loss=loss, **kwargs)

    if train_step:
      model.train_step = functools.partial(
          train_step, model=model, optimizer=model.optimizer)
    if validation_step:
      model.test_step = functools.partial(validation_step, model=model)
    return model
Ejemplo n.º 7
0
def compile_model(model: tf.keras.Model) -> tf.keras.Model:
    """Compile ``model`` with a per-classifier loss and metric list and Adam(1e-3).

    Losses and metrics are keyed by ``classifier.name`` for every entry of the
    module-level ``classifiers`` collection (defined outside this block);
    presumably each name matches a model output name (confirm).
    """
    loss_by_name = dict((c.name, c.loss) for c in classifiers)
    metrics_by_name = dict((c.name, c.metrics) for c in classifiers)
    adam = tf.keras.optimizers.Adam(learning_rate=0.001)
    model.compile(optimizer=adam, loss=loss_by_name, metrics=metrics_by_name)
    return model
def compile_and_fit(model: tf.keras.Model, window: WindowGenerator,
                    model_name: str,
                    patience=cfg.EARLY_STOPPING['patience'],
                    ):
    """
    Compile ``model`` for regression (Adam, MSE loss, MAE metric) and train it.

    Weights are saved every epoch to
    ``<cfg.CHECKPOINT_PATH>/<model_name>/<epoch:04d>.ckpt``. If that directory
    already contains files, weights are first restored with ``load_weight``.
    Early stopping on ``val_loss`` is added only when
    ``cfg.EARLY_STOPPING['enabled'] is True``.

    @param model: Keras model to train (modified in place).
    @param window: object providing ``train`` and ``val`` datasets.
    @param model_name: sub-directory name used for the checkpoints.
    @param patience: early-stopping patience. The default is read from
        ``cfg`` once, when this function is defined.
    @return: the ``History`` returned by ``model.fit``.
    """
    checkpoint_dir = os.path.join(cfg.CHECKPOINT_PATH, model_name)
    make_dir(checkpoint_dir)

    # Resume from existing checkpoints, if any.
    if not is_dir_empty(checkpoint_dir):
        load_weight(model, checkpoint_dir)

    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(
            os.path.join(checkpoint_dir, '{epoch:04d}.ckpt'),
            save_weights_only=True,
            verbose=1,
        ),
    ]

    if cfg.EARLY_STOPPING['enabled'] is True:
        callbacks.append(
            tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                             patience=patience,
                                             mode='min'))

    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.losses.MeanSquaredError(),
                  metrics=[tf.metrics.MeanAbsoluteError()])

    return model.fit(window.train,
                     epochs=cfg.MAX_EPOCH,
                     validation_data=window.val,
                     callbacks=callbacks,
                     verbose=2)
Ejemplo n.º 9
0
def _compile(hp: kt.HyperParameters, model: tf.keras.Model):
    """Compile ``model`` for binary classification on logits.

    ``hp`` is accepted for the tuner callback signature but ignored: the
    learning rate is fixed at 1e-2 instead of being searched. The loss uses
    ``reduction="sum"``, and the ROC/PR AUC metrics are registered as weighted
    metrics.
    """
    del hp  # unused; the learning rate is fixed below
    learning_rate = 1e-2
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True,
                                                 reduction="sum")
    auc_metrics = [
        tf.keras.metrics.AUC(curve="ROC", name="auc_roc", from_logits=True),
        tf.keras.metrics.AUC(curve="PR", name="auc_pr", from_logits=True),
    ]
    model.compile(optimizer=optimizer, loss=loss_fn,
                  weighted_metrics=auc_metrics)
Ejemplo n.º 10
0
def restore_initial_weights(
    model: tf.keras.Model, initial_weights: dict
) -> tf.keras.Model:
    """Reset weight matrices to their initial values while keeping the masks.

    Entries of ``initial_weights`` whose name contains "_b" (biases) or "_m"
    (masks) are skipped. Each remaining entry ``name`` is multiplied
    element-wise by the model variable ``name + "_m:0"`` and assigned to the
    variable ``name``, so masked positions stay zero. The model is then
    re-compiled to discard the optimizer state.

    Args:
        model (tf.keras.Model): Differentiable model defined in Keras and initialized.
        initial_weights (dict): Initial weights of the network, mapping the
            variable name in the TensorFlow graph to a numpy array, e.g.
            `initial_weights_backup = {w.name:w.numpy() for w in model.variables}`.

    Returns:
        tf.keras.Model: The same model object. It is updated in place, so using
        the return value is optional.
    """
    variables_by_name = {var.name: var for var in model.variables}

    kernels = {
        name: value
        for name, value in initial_weights.items()
        if "_b" not in name and "_m" not in name
    }

    for name, value in kernels.items():
        mask = variables_by_name[name + "_m:0"]
        # Zero the masked positions: not necessary, but cleaner.
        variables_by_name[name].assign(tf.multiply(value, mask))

    # Re-compiling is important: it gets rid of the optimizer state.
    model.compile(
        loss=model.loss,
        optimizer=model.optimizer._name,
        metrics=[metric for metric in model.metrics_names if metric != "loss"],
    )

    # Returned only for convenience/compatibility; the model was updated by reference.
    return model
Ejemplo n.º 11
0
def train_and_evaluate_model(model: tf.keras.Model, *, epochs: int = 20,
                             learning_rate: float = 1e-5) -> None:
    """Train and test the transfer learning model

    The model is compiled with Adam, binary cross-entropy on logits and binary
    accuracy. It is trained on the training split, validated on the validation
    split, and finally evaluated on the test split. All splits come from
    ``get_dataset()``.

    Parameters
    ----------
    model : tf.keras.Model
        The transfer learning model; its outputs are treated as logits.
    epochs : int, optional
        Number of training epochs (default 20).
    learning_rate : float, optional
        Adam learning rate (default 1e-5).
    """
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    # The model emits logits, so the class boundary is logit 0 (p = 0.5).
    # The default threshold of 0.5 would correspond to p ~= 0.62.
    metric = tf.keras.metrics.BinaryAccuracy(threshold=0.0)
    model.compile(loss=loss, optimizer=optimizer, metrics=[metric])
    train_dataset, validation_dataset, test_dataset = get_dataset()
    model.fit(train_dataset, epochs=epochs, validation_data=validation_dataset)
    model.evaluate(x=test_dataset)
Ejemplo n.º 12
0
def train(model: tf.keras.Model) -> tf.keras.Model:
    """Compile and train the classifier on the chess piece image dataset.

    Uses the "adam" optimizer, sparse categorical cross-entropy on logits and
    accuracy. Trains for ``TRAINING_EPOCHS`` epochs on the "training" split,
    validating on the "validation" split.

    Returns:
        The same model instance, trained in place.
    """
    loss_fn = losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])

    training, validation = (_get_dataset(split)
                            for split in ("training", "validation"))

    model.fit(training, validation_data=validation, epochs=TRAINING_EPOCHS)
    return model
Ejemplo n.º 13
0
    def train_image_gan(cls, generator:tf.keras.Model,  discriminator:tf.keras.Model, dataset:tf.data.Dataset, gan_input_shape=(32,), **kwargs):
        """Compile the discriminator and build/compile the stacked GAN model.

        The discriminator is compiled on its own and then frozen
        (``trainable = False``) before being stacked on ``generator``.
        Following the usual Keras GAN pattern, training ``gan`` should
        therefore update only the generator.

        NOTE(review): ``dataset`` and ``**kwargs`` are unused, and nothing is
        trained or returned; this snippet looks truncated.
        """
        # Compile the discriminator.
        discriminator_optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.008,clipvalue=1.0,decay=1e-8)
        discriminator.compile(optimizer=discriminator_optimizer,loss='binary_crossentropy')

        # Do not train the discriminator for now (freeze it inside the GAN).
        discriminator.trainable = False

        # Stack: latent input -> generator -> discriminator score.
        gan_input = tf.keras.Input(shape=gan_input_shape)
        gan_output = discriminator(generator(gan_input))

        gan = tf.keras.Model(gan_input,gan_output)
        gan_optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.0004,clipvalue=1.0,decay=1e-8)
        # Previously compiled with no optimizer/loss, leaving gan_optimizer unused
        # and the GAN untrainable.
        gan.compile(optimizer=gan_optimizer,loss='binary_crossentropy')
Ejemplo n.º 14
0
def compiled(
    model: tf.keras.Model,
    loss=None,
    metrics=None,
    optimizer=None,
    run_eagerly: Optional[bool] = None,
) -> tf.keras.Model:
    """Compile ``model`` in place and return it so calls can be chained.

    All arguments are forwarded unchanged to ``model.compile``.
    ``steps_per_execution`` is intentionally not exposed.
    """
    compile_kwargs = dict(
        loss=loss,
        metrics=metrics,
        optimizer=optimizer,
        run_eagerly=run_eagerly,
    )
    model.compile(**compile_kwargs)
    return model
Ejemplo n.º 15
0
def test_can_run_the_lr_finder(model: tf.keras.Model,
                               dataset: tf.data.Dataset):
    """LRFinder records NUM_STEPS lrs/losses spanning [min_lr, max_lr] and saves a figure."""
    import math

    min_lr = 1e-6
    max_lr = 1e-1

    model.compile(optimizer=tf.keras.optimizers.SGD(),
                  loss=tf.keras.losses.MeanSquaredError())

    lrfinder = LRFinder(min_lr, max_lr, num_steps=NUM_STEPS, figname=FIGNAME)

    model.fit(dataset, epochs=1, callbacks=[lrfinder])

    assert len(lrfinder.losses) == NUM_STEPS
    assert len(lrfinder.lrs) == NUM_STEPS
    # The schedule endpoints may not be bit-exact floats, so compare with a
    # tight relative tolerance instead of ==.
    assert math.isclose(lrfinder.lrs[0], min_lr, rel_tol=1e-6)
    assert math.isclose(lrfinder.lrs[-1], max_lr, rel_tol=1e-6)

    # by default should have saved a figure with the results
    assert os.path.exists(lrfinder.figname)
Ejemplo n.º 16
0
def save_model(model: tf.keras.Model, destination: str):
    """
    Save ``model`` as ``<destination>/model.tf`` in TensorFlow SavedModel format.

    The model is first re-compiled with no arguments, which drops any custom
    loss/metrics from its compile configuration. Custom losses/metrics would
    otherwise require custom object resolution when the model is loaded.

    https://github.com/tensorflow/tensorflow/issues/43478

    Args:
        model: tensorflow model
        destination: directory under which "model.tf" is written

    Returns:
        The path the model was saved to.
    """
    # Reset compile-time settings (loss/metrics); this does not reinitialize
    # the layer weights.
    model.compile()
    target = os.path.join(destination, "model.tf")
    logger = logging.getLogger(__name__)
    logger.debug("saving model to %s", target)
    model.save(target, save_format="tf")
    return target
Ejemplo n.º 17
0
def jitter_reset(model: tf.keras.Model,
                 initial_weights: Dict[str, np.array],
                 sd: float = 0.01) -> tf.keras.Model:
    """ Reset weights to their initial values plus white noise; zero the biases.

    Variables whose name contains neither "_bias" nor "_mask" are set to
    ``initial_weights[name]`` plus Gaussian noise with standard deviation
    ``sd``. Variables whose name contains "_bias" are set to zeros. Mask
    variables are left untouched. The model is then re-compiled to reset the
    optimizer.

    Parameters
    ----------
        model (tf.keras.Model): Pruned model.
        initial_weights(Dict[str, np.array]): Initial weights keyed by variable
        name, as saved with the save_weights function.
        sd (float): Standard deviation of noise added.

    Returns
    -------
        tf.keras.Model: The same (jittered) Keras model.
    """
    by_name = {var.name: var for var in model.variables}
    names = [var.name for var in model.variables]
    kernel_names = [n for n in names if "_bias" not in n and "_mask" not in n]
    bias_names = [n for n in names if "_bias" in n]

    noise = tf.random_normal_initializer(stddev=sd)
    for name in kernel_names:
        start = initial_weights[name]
        by_name[name].assign(start + noise(shape=start.shape, dtype="float32"))

    for name in bias_names:
        target = by_name[name]
        target.assign(np.zeros(target.shape, dtype="float32"))

    # Re-compile so the optimizer state is reset.
    model.compile(optimizer=model.optimizer._name,
                  loss=model.loss,
                  metrics=[metric for metric in model.metrics_names
                           if metric != "loss"])

    return model
Ejemplo n.º 18
0
def train(train_x_y, model:tf.keras.Model, val_xy=None, epochs=10, input_dim=None):
    '''
    Train ``model`` with the TF high-level (Keras compile/fit) API.

    :param train_x_y:   training data passed to ``model.fit``
    :param model:       model instance to train (modified in place)
    :param val_xy:      optional validation data, forwarded to ``fit`` as
                        ``validation_data`` (previously ignored)
    :param epochs:      number of training epochs
    :param input_dim:   if truthy, build the model with input shape
                        ``(None, input_dim)`` and print its summary
    :return:            the trained model
    '''
    if input_dim:
        model.build(input_shape=(None, input_dim))
        # summary() prints itself and returns None; print(model.summary())
        # emitted a stray "None".
        model.summary()

    # Choose optimizer and loss for training; the loss expects probabilities.
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
                  loss=tf.losses.BinaryCrossentropy(from_logits=False),
                  metrics=['binary_accuracy'])

    # Train; validation_data=None (the default) means no validation.
    model.fit(train_x_y, epochs=epochs, validation_data=val_xy)
    return model
Ejemplo n.º 19
0
def train(args: Settings, model: tf.keras.Model, dataset: tf.data.Dataset,
          path):
    """Train ``model`` on ``dataset`` and always save its weights when done.

    Ctrl-C (KeyboardInterrupt) ends training early without an error. Whether
    training finishes, is interrupted, or fails, the weights are saved as
    ``model.pb`` (TF format) and ``model.h5`` (HDF5) under ``path['run']``.
    Other exceptions still propagate after saving.

    Args:
        args: settings providing ``update_freq``, ``learning_rate``,
            ``load_ckpt`` and ``num_epoch``.
        model: Keras model; one ``AutoResizeMseLoss`` is used per output.
        dataset: training dataset passed to ``model.fit``.
        path: mapping with ``'ckpt'``, ``'log'`` and ``'run'`` entries that
            support ``/`` (presumably ``pathlib.Path`` objects; confirm).
    """
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(filepath=str(path['ckpt'] /
                                                        '{epoch:03d}.ckpt')),
        tf.keras.callbacks.TensorBoard(
            update_freq=args.update_freq,
            log_dir=path['log'],
            profile_batch='2,66'),
        ShowImageCallback(args,
                          log_dir=path['log'] / 'flow',
                          log_period=args.update_freq)
    ]
    optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)

    losses = [AutoResizeMseLoss() for _ in model.outputs]
    model.compile(
        optimizer=optimizer,
        loss=losses,
    )

    if args.load_ckpt:
        weight_path = Path(args.load_ckpt) / 'variables/variables'
        model.load_weights(weight_path, by_name=False)

    try:
        model.fit(dataset, epochs=args.num_epoch, callbacks=callbacks)
    except KeyboardInterrupt:
        # Deliberate: a manual interrupt stops training; weights are saved below.
        logging.info('training interrupted by user')
    finally:
        for ext, fmt in zip(['pb', 'h5'], ['tf', 'h5']):
            out_file = str(path['run'] / 'model.{}'.format(ext))
            logging.info(
                'saving weights to {} prior to termination ...'.format(
                    out_file))
            model.save_weights(out_file, save_format=fmt)
Ejemplo n.º 20
0
def add_compile(model: tf.keras.Model, **params) -> tf.keras.Model:
    """Compiles the model with the given parameters.

    Parameters
    ----------
    model: tf.keras.Model
        The model to compile (modified in place).
    **params
        ``learning_rate`` (required): learning rate for Adam.
        ``clipnorm`` (optional, default 1): gradient-norm clipping for Adam.

    Returns
    -------
    tf.keras.Model
        The same model, compiled with binary cross-entropy, Adam, and
        accuracy plus multi-label AUC metrics.

    Raises
    ------
    KeyError
        If ``learning_rate`` is missing from ``params``.
    """
    optimizer = Adam(learning_rate=params["learning_rate"],
                     clipnorm=params.get("clipnorm", 1))
    model.compile(
        loss=tf.keras.losses.BinaryCrossentropy(),
        optimizer=optimizer,
        metrics=["accuracy",
                 tf.keras.metrics.AUC(multi_label=True)],
    )
    # The signature and docstring promise the model back; it used to return None.
    return model
Ejemplo n.º 21
0
def train(model: tf.keras.Model, model_path: str, n_epoch: int = 10,
          batch_size: int = 2):
    """Train the similarity network and save its best weights to ``model_path``.

    The data comes from ``load_and_split()``. The model is trained with Adam
    (lr 0.005) on ``simnet_loss``. The learning rate is reduced on a
    ``val_loss`` plateau, and the checkpoint with the lowest ``val_loss`` is
    kept. Afterwards the summary is printed, the loss curve is plotted, and the
    architecture is drawn to ``simnet_model.png``.

    Args:
        model: Keras model to train (modified in place).
        model_path: file path for the best checkpoint.
        n_epoch: number of epochs.
        batch_size: training batch size (default 2, as previously hard-coded).

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    x_train, x_test, y_train, y_test = load_and_split()
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.005)
    model.compile(optimizer=optimizer,
                  loss=simnet_loss)

    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                                  verbose=1, patience=5, min_lr=0.0001)

    mcp_save = ModelCheckpoint(model_path,
                               save_best_only=True, monitor='val_loss', mode='min')

    # Fit and check on the validation split.
    history = model.fit(x_train, y_train,
                        batch_size=batch_size, epochs=n_epoch, workers=8,
                        callbacks=[reduce_lr, mcp_save],
                        validation_data=(x_test, y_test))

    model.summary()
    plot_loss(history)

    tf.keras.utils.plot_model(model, 'simnet_model.png', show_shapes=True, expand_nested=True)
    return history
Ejemplo n.º 22
0
def _train_bert_multitask_keras_model(
        train_dataset: tf.data.Dataset,
        eval_dataset: tf.data.Dataset,
        model: tf.keras.Model,
        params: BaseParams,
        mirrored_strategy: tf.distribute.MirroredStrategy = None):
    """Compile and fit the multitask model, saving weights and TensorBoard logs.

    Args:
        train_dataset: training dataset.
        eval_dataset: validation dataset.
        model: Keras model, compiled here with no arguments.
        params: provides ``ckpt_dir``, ``train_epoch`` and
            ``train_steps_per_epoch``.
        mirrored_strategy: strategy whose scope compile/fit run under. When
            None (the default), the current default strategy from
            ``tf.distribute.get_strategy()`` is used. Previously a None value
            crashed on ``None.scope()``.
    """
    # With save_best_only=False the checkpoint is written every epoch;
    # ``monitor`` then has no effect.
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(params.ckpt_dir, 'model'),
        save_weights_only=True,
        monitor='val_acc',
        mode='auto',
        save_best_only=False)

    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=params.ckpt_dir)

    strategy = (mirrored_strategy if mirrored_strategy is not None
                else tf.distribute.get_strategy())

    with strategy.scope():
        # compile() with no arguments; presumably the model supplies its own
        # losses/train step (confirm).
        model.compile()
        model.fit(x=train_dataset,
                  validation_data=eval_dataset,
                  epochs=params.train_epoch,
                  callbacks=[model_checkpoint_callback, tensorboard_callback],
                  steps_per_epoch=params.train_steps_per_epoch)
    model.summary()
Ejemplo n.º 23
0
def prune(
    model: tf.keras.Model,
    prune_proportion: float,
    criterion: Callable = magnitude_saliency_criterion,
) -> tf.keras.Model:
    """Prune a Keras network in place using a saliency criterion.

    Every variable whose name contains neither "_b" (bias) nor "_m" (mask) is
    treated as a prunable matrix with a companion mask variable named
    ``name + "_m:0"``. For each one, the ``round(size * prune_proportion)``
    entries with the lowest ``criterion`` score are set to zero in both the
    weights and the mask. The model is then re-compiled to drop the optimizer
    state.

    Args:
        model (tf.keras.Model): Differentiable model defined in Keras and initialized.
        prune_proportion (float): Proportion of weights to be pruned.
        criterion (callable, optional): Maps a flat weight vector to per-entry
            saliency scores (lower means pruned first). Defaults to
            magnitude_saliency_criterion.

    Returns:
        tf.keras.Model: The same model object. It is updated by reference, so
        using the return value is optional.
    """
    variables = {var.name: var for var in model.variables}

    prunable = [
        var.name for var in model.variables
        if "_b" not in var.name and "_m" not in var.name
    ]

    for name in prunable:
        mask_name = name + "_m:0"
        weights_np = variables[name].numpy()
        mask_np = variables[mask_name].numpy()
        original_shape = weights_np.shape

        flat_weights = weights_np.reshape(-1)
        flat_mask = mask_np.reshape(-1)

        n_to_prune = np.round(flat_weights.size * prune_proportion).astype(int)
        # Indices of the least salient connections (ascending saliency).
        least_salient = criterion(flat_weights).argsort()[:n_to_prune]

        # Zeroing the weights is optional; zeroing the mask freezes them by
        # blocking the gradient for those connections.
        flat_weights[least_salient] = 0
        flat_mask[least_salient] = 0

        variables[name].assign(flat_weights.reshape(original_shape))
        variables[mask_name].assign(flat_mask.reshape(original_shape))

    # Re-compiling is important: it gets rid of the optimizer state.
    model.compile(
        loss=model.loss,
        optimizer=model.optimizer._name,
        metrics=[metric for metric in model.metrics_names if metric != "loss"],
    )
    return model
Ejemplo n.º 24
0
    def train_and_eval(self,
                       model: tf.keras.Model,
                       epochs: Optional[int] = None,
                       sparsity: Optional[float] = None):
        """
        Trains a Keras model and returns its validation and test set errors (1.0 - accuracy).
        :param model: A Keras model producing logits (the losses use from_logits=True).
        :param epochs: Overrides the duration of training (defaults to self.config.epochs).
        :param sparsity: Desired sparsity level (for unstructured sparsity); only used
            when pruning is configured and the value is > 0.
        :returns: dict with "val_error" (1.0 - best validation accuracy from the epoch
            pruning finishes, or from epoch 0 without pruning), "test_error"
            (1.0 - test accuracy after training) and "pruned_weights" (from the
            pruning callback, or None).
        :raises NotImplementedError: if pruning and distillation are both enabled.
        """
        dataset = self.config.dataset
        batch_size = self.config.batch_size
        sparsity = sparsity or 0.0

        train = dataset.train_dataset() \
            .shuffle(batch_size * 8) \
            .batch(batch_size) \
            .prefetch(tf.data.experimental.AUTOTUNE)

        val = dataset.validation_dataset() \
            .batch(batch_size) \
            .prefetch(tf.data.experimental.AUTOTUNE)

        # TODO: check if this works, make sure we're excluding the last layer from the student
        if self.pruning and self.distillation:
            raise NotImplementedError()

        if self.distillation:
            teacher = tf.keras.models.load_model(
                self.distillation.distill_from)
            teacher._name = "teacher_"
            teacher.trainable = False

            t, a = self.distillation.temperature, self.distillation.alpha

            # Assemble a parallel model with the teacher and student
            i = tf.keras.Input(shape=dataset.input_shape)
            cxent = tf.keras.losses.CategoricalCrossentropy()

            stud_logits = model(i)
            tchr_logits = teacher(i)

            # Temperature-softened distributions; the soft-target loss is scaled by a * t^2.
            o_stud = tf.keras.layers.Softmax()(stud_logits / t)
            o_tchr = tf.keras.layers.Softmax()(tchr_logits / t)
            teaching_loss = (a * t * t) * cxent(o_tchr, o_stud)

            model = tf.keras.Model(inputs=i, outputs=stud_logits)
            model.add_loss(teaching_loss, inputs=True)

        # NOTE(review): self.dataset is used below while self.config.dataset is used
        # above; confirm they refer to the same object.
        if self.dataset.num_classes == 2:
            loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
            # Outputs are logits, so the decision boundary is logit 0 (p = 0.5).
            # The metric's default 0.5 threshold would mean p ~= 0.62.
            accuracy = tf.keras.metrics.BinaryAccuracy(name="accuracy",
                                                       threshold=0.0)
        else:
            loss = tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True)
            accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
                name="accuracy")
        model.compile(optimizer=self.config.optimizer(),
                      loss=loss,
                      metrics=[accuracy])

        # TODO: adjust metrics by class weight?
        class_weight = dict(enumerate(self.dataset.class_weight())) \
            if self.config.use_class_weight else None
        epochs = epochs or self.config.epochs
        callbacks = self.config.callbacks()
        check_logs_from_epoch = 0

        pruning_cb = None
        if self.pruning and sparsity > 0.0:
            assert 0.0 < sparsity <= 1.0
            self.log.info(f"Target sparsity: {sparsity:.4f}")
            pruning_cb = DPFPruning(
                target_sparsity=sparsity,
                structured=self.pruning.structured,
                start_pruning_at_epoch=self.pruning.start_pruning_at_epoch,
                finish_pruning_by_epoch=self.pruning.finish_pruning_by_epoch)
            # Only judge validation accuracy once pruning has finished.
            check_logs_from_epoch = self.pruning.finish_pruning_by_epoch
            callbacks.append(pruning_cb)

        log = model.fit(train,
                        epochs=epochs,
                        validation_data=val,
                        verbose=1 if debug_mode() else 2,
                        callbacks=callbacks,
                        class_weight=class_weight)

        test = dataset.test_dataset() \
            .batch(batch_size) \
            .prefetch(tf.data.experimental.AUTOTUNE)
        _, test_acc = model.evaluate(test, verbose=0)

        return {
            "val_error":
            1.0 - max(log.history["val_accuracy"][check_logs_from_epoch:]),
            "test_error":
            1.0 - test_acc,
            "pruned_weights":
            pruning_cb.weights if pruning_cb else None
        }
Ejemplo n.º 25
0
def train_neural_network(X_train: Tuple[np.ndarray,
                                        np.ndarray], y_train: np.ndarray,
                         X_val: Tuple[np.ndarray,
                                      np.ndarray], y_val: np.ndarray,
                         neural_network: tf.keras.Model, batch_size: int,
                         is_bayesian: bool, training_cycle_length: int,
                         min_learning_rate: float, max_learning_rate: float,
                         max_epochs: int, **kwargs) -> tf.keras.Model:
    """Train a two-input binary classifier with a cyclic learning rate and
    validation-AUC based early stopping.

    Every epoch the optimizer learning rate is set by
    ``calculate_learning_rate`` (bounded by ``min_learning_rate`` and
    ``max_learning_rate``), the model is trained on shuffled mini-batches
    via ``train_on_batch`` and the validation set is scored with ROC AUC.
    The weights with the best validation AUC are checkpointed to a private
    temporary directory and loaded back before returning.

    Args:
        X_train: Pair of training input arrays, sliced along the first axis.
        y_train: Training labels.
        X_val: Pair of validation input arrays.
        y_val: Validation labels.
        neural_network: Keras model; it is compiled and trained in place.
        batch_size: Mini-batch size for training and prediction.
        is_bayesian: If True, validation probabilities are averaged over
            ``kwargs['num_monte_carlo']`` forward passes.
        training_cycle_length: Cycle length (in epochs) handed to
            ``calculate_learning_rate``; must satisfy
            ``0 < training_cycle_length < max_epochs``.
        min_learning_rate: Lower learning-rate bound (also the initial LR).
        max_learning_rate: Upper learning-rate bound.
        max_epochs: Maximum number of epochs.
        **kwargs: ``num_monte_carlo`` (int > 1) is required when
            ``is_bayesian`` is True.

    Returns:
        ``neural_network`` with the best-validation-AUC weights loaded.
    """
    assert training_cycle_length > 0
    assert training_cycle_length < max_epochs
    # Stop after more than two learning-rate cycles without AUC improvement.
    patience = training_cycle_length * 2
    if is_bayesian:
        assert 'num_monte_carlo' in kwargs
        assert isinstance(kwargs['num_monte_carlo'], int)
        assert kwargs['num_monte_carlo'] > 1
    print('Structure of neural network:')
    print('')
    neural_network.summary()
    print('')
    neural_network.compile(
        optimizer=tf.keras.optimizers.Adam(min_learning_rate),
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=["binary_accuracy"],
        experimental_run_tf_function=not is_bayesian)
    # save_weights() on a path without an ".h5" suffix writes a TensorFlow
    # checkpoint ("<path>.index", "<path>.data-*" and a "checkpoint" file in
    # the same directory). A private temporary directory keeps all of those
    # files together and removes every one of them on exit, even on error.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_file_name = os.path.join(temp_dir, 'best_weights')
        best_auc = None
        epochs_without_improving = 0
        n_samples = y_train.shape[0]
        # Round the *quotient* up so the last, partial batch is not dropped.
        n_batches = int(np.ceil(n_samples / float(batch_size)))
        bounds_of_batches = [(idx * batch_size,
                              min((idx + 1) * batch_size, n_samples))
                             for idx in range(n_batches)]
        for epoch in range(max_epochs):
            random.shuffle(bounds_of_batches)
            epoch_accuracy = 0.0
            epoch_loss = 0.0
            # Metrics accumulate across batches (reset_metrics=False below),
            # so they are cleared once per epoch here.
            neural_network.reset_metrics()
            start_time = time.time()
            tf.keras.backend.set_value(
                neural_network.optimizer.lr,
                calculate_learning_rate(epoch_index=epoch,
                                        cycle_length=training_cycle_length,
                                        max_lr=max_learning_rate,
                                        min_lr=min_learning_rate))
            for batch_start, batch_end in bounds_of_batches:
                batch_x = (X_train[0][batch_start:batch_end],
                           X_train[1][batch_start:batch_end])
                batch_y = y_train[batch_start:batch_end]
                epoch_loss, epoch_accuracy = neural_network.train_on_batch(
                    batch_x, batch_y, reset_metrics=False)
            training_duration = time.time() - start_time
            print("Epoch {0}".format(epoch + 1))
            print("  Training time is {0:.3f} secs".format(training_duration))
            print("  Learning rate is {0:.7f}".format(
                float(tf.keras.backend.get_value(
                    neural_network.optimizer.lr))))
            print("  Training measures:")
            print("    loss = {0:.6f}, accuracy = {1:8.6f}".format(
                epoch_loss, epoch_accuracy))
            start_time = time.time()
            if is_bayesian:
                # Average several stochastic forward passes.
                probabilities = tf.reduce_mean([
                    neural_network.predict(X_val, batch_size=batch_size)
                    for _ in range(kwargs['num_monte_carlo'])
                ],
                                               axis=0)
            else:
                probabilities = neural_network.predict(X_val,
                                                       batch_size=batch_size)
            probabilities = np.reshape(probabilities,
                                       newshape=(y_val.shape[0], ))
            validation_duration = time.time() - start_time
            roc_auc = roc_auc_score(y_val, probabilities)
            y_pred = np.asarray(probabilities >= 0.5, dtype=y_val.dtype)
            print("  Validation time is {0:.3f} secs".format(
                validation_duration))
            print("  Validation measures:")
            with warnings.catch_warnings():
                # Precision/F1 warn when no positive predictions are made.
                warnings.filterwarnings("ignore",
                                        category=UndefinedMetricWarning)
                print(
                    "    accuracy = {0:8.6f}, AUC = {1:8.6f}, F1 = {2:.8f}, P = {3:.8f}, R = {4:.8f}"
                    .format(accuracy_score(y_val, y_pred), roc_auc,
                            f1_score(y_val, y_pred),
                            precision_score(y_val, y_pred),
                            recall_score(y_val, y_pred)))
            if best_auc is None:
                best_auc = roc_auc
                neural_network.save_weights(temp_file_name, overwrite=True)
            else:
                if roc_auc > best_auc:
                    best_auc = roc_auc
                    neural_network.save_weights(temp_file_name, overwrite=True)
                    epochs_without_improving = 0
                else:
                    epochs_without_improving += 1
                if epochs_without_improving > patience:
                    print("Early stopping!")
                    break
            del probabilities, y_pred
        if epochs_without_improving <= patience:
            print("Maximal number of epochs is reached!")
        print('')
        # Restore the best-AUC weights before the directory is deleted.
        neural_network.load_weights(temp_file_name)
    print('')
    return neural_network
Ejemplo n.º 26
0
def compile_model(
    model: tf.keras.Model,
    learning_rate: float,
    dataloader: AbstractDataloader,
    loss: str,
    optimizer: str,
    config: dict,
    metrics: List[str] = None
) -> (tf.keras.models.Model, List[tf.keras.callbacks.Callback]):
    """
        Helper function to compile a new model at each variation of the experiment
    :param model: model to be compiled
    :param learning_rate: learning rate for the "adam", "rmsprop" and
        "mini-batch-gradient-descent" optimizers ("adam-transformer" uses
        CustomSchedule(d_model) instead)
    :param dataloader: dataloader, handed to the BLEU evaluation callback
    :param loss: loss function name ("sparse_categorical_cross_entropy" or "mlm_loss")
    :param optimizer: optimizer function name
    :param config: configuration dictionary; config['model']['hyper_params']
        must contain "d_model" or "hidden_size" when optimizer is "adam-transformer"
    :param metrics: list of metric names; None means no metrics
    :return: compiled model and additional callbacks (for metrics which are too slow to run on training set)
    :raises Exception: if "adam-transformer" is requested without d_model/hidden_size
    :raises KeyError: for an unknown loss, optimizer or metric name
    """
    if metrics is None:
        metrics = []

    mapping_metrics = {
        "perplexity": perplexity,  # For language model task
        "perplexity_mlm": perplexity_mlm,
        #        "bleu": bleu,  # For translation task but it's too slow to run during training
        "sparse_accuracy": tf.keras.metrics.SparseCategoricalAccuracy(
        ),  # Generic for classification
    }

    mapping_loss = {
        "sparse_categorical_cross_entropy":
        tf.keras.losses.SparseCategoricalCrossentropy(),
        "mlm_loss":
        mlm_loss
    }

    hyper_params = config['model']['hyper_params']
    if "d_model" in hyper_params:  # From Blaise model
        d_model = hyper_params['d_model']
    elif "hidden_size" in hyper_params:  # From François model
        d_model = hyper_params['hidden_size']
    else:
        if optimizer == "adam-transformer":
            raise Exception(
                "adam-transformer requires d_model or hidden size in config"
            )
        d_model = None  # Unused: only "adam-transformer" reads d_model.

    # Factories rather than instances, so only the requested optimizer is built.
    mapping_optimizer = {
        "adam":
        lambda: tf.keras.optimizers.Adam(learning_rate=learning_rate),
        "rmsprop":
        lambda: tf.keras.optimizers.RMSprop(learning_rate=learning_rate),
        "adam-transformer":
        lambda: tf.keras.optimizers.Adam(CustomSchedule(d_model),
                                         beta_1=0.9,
                                         beta_2=0.98,
                                         epsilon=1e-9),
        "mini-batch-gradient-descent":
        lambda: tf.keras.optimizers.SGD(learning_rate=learning_rate)
    }

    metric_funcs, additional_callbacks = [], []
    for metric in metrics:
        if metric == "bleu":
            # Special case
            #  To be called only on end of epoch because it's too slow
            additional_callbacks.append(
                BleuIntervalEvaluation(dataloader=dataloader))
        else:
            metric_funcs.append(mapping_metrics[metric])

    model.compile(optimizer=mapping_optimizer[optimizer](),
                  loss=mapping_loss[loss],
                  metrics=metric_funcs)

    return model, additional_callbacks
Ejemplo n.º 27
0
def _lth_recompile(model: tf.keras.Model) -> None:
    """Recompile ``model`` with a fresh optimizer of the same name, same loss and metrics."""
    model.compile(optimizer=model.optimizer._name,
                  loss=model.loss,
                  metrics=[m for m in model.metrics_names if m != "loss"])


def _lth_kernel_names(model: tf.keras.Model) -> List[str]:
    """Names of model variables that contain neither "_bias" nor "_mask"."""
    return [
        w.name for w in model.variables
        if ("_bias" not in w.name) and ("_mask" not in w.name)
    ]


def _lth_zero_biases(model: tf.keras.Model, weights: Dict[str, tf.Variable]) -> None:
    """Assign float32 zeros to every variable whose name contains "_bias"."""
    for w_name in [w.name for w in model.variables if "_bias" in w.name]:
        weights[w_name].assign(
            np.zeros(weights[w_name].shape, dtype="float32"))


def reset_weights(model: tf.keras.Model,
                  initial_weights: Dict[str, np.array],
                  reset_mode: str = "rewind",
                  same_sign: bool = True,
                  jitter_sd: float = 0.01,
                  constant: float = None) -> tf.keras.Model:
    """ Function to reinitialize weights (kernel+bias) of the pruned network in order
    to perform LTH.

    Parameters
    ----------
        model (tf.keras.Model): Pruned model to reinitialize.
        initial_weights (Dict[str, np.array]): Dictionary of initial weights saved with
        save_weights function, keyed by variable name.
        reset_mode (str): How to initialise the unpruned weights. There are five available
        modes:
            "rewind": Rewind all non-mask variables to the initial ones. Default mode.
            "jitter": Rewind and add noise to the initial weights (delegates to
            jitter_reset).
            "random": Reinitialise kernels with random normal weights using each
            layer's init_sd; biases are zeroed.
            "reshuffle": Reinitialise by reshuffling the kept kernel weights' initial
            values; biases are zeroed.
            "constant": Set the kept kernel weights to a positive or negative constant;
            biases are zeroed.
        same_sign (bool): Specify whether same sign as original initialization is kept
        when resetting weights.
        jitter_sd (float): Standard deviation of added noise in "jitter" mode.
        constant (float): Constant of reinitialization when selecting "constant" mode.
        By default (None or 0), each layer uses its own init_sd, i.e. the standard
        deviation of the original distribution in that layer.

    Returns
    -------
        tf.keras.Model: Keras pruned model reinitialized (the optimizer is reset by
        recompiling).

    Raises
    ------
        ValueError: If reset_mode is not one of the modes above.
    """
    if reset_mode == "jitter":
        return jitter_reset(model, initial_weights, jitter_sd)
    if reset_mode not in ("rewind", "random", "reshuffle", "constant"):
        raise ValueError("Invalid mode!")

    weights = {w.name: w for w in model.variables}

    if reset_mode == "rewind":
        # Kernels and biases go back to their initial values; masks stay.
        for w_name in [w.name for w in model.variables if "_mask" not in w.name]:
            weights[w_name].assign(initial_weights[w_name])
        _lth_recompile(model)
        return model

    if reset_mode == "random":
        # Dictionary containing initialization sd for each layer
        init_sd = {layer.name: layer.init_sd for layer in model.layers}
        for w_name in _lth_kernel_names(model):
            w0 = initial_weights[w_name]
            # Layer name is the variable-name prefix before the last "/".
            sd = init_sd[os.path.split(w_name)[0]]
            k_init = tf.random_normal_initializer(stddev=sd)
            w = k_init(shape=weights[w_name].shape, dtype="float32").numpy()
            if same_sign:
                # The zero-mean normal is symmetric, so flipping the sign of
                # entries that disagree with w0 preserves the distribution.
                w = np.where((w0 >= 0) == (w >= 0), w, -w)
            weights[w_name].assign(w)

    elif reset_mode == "reshuffle":
        for w_name in _lth_kernel_names(model):
            w0 = initial_weights[w_name].copy()
            keep = weights[w_name + "_mask:0"].numpy() > 0.5
            # Select kept entries with the same mask test used for assignment,
            # so an initial weight of exactly 0 cannot shorten the sample.
            kept = w0[keep]
            shuffled = kept.copy()
            np.random.shuffle(shuffled)
            if same_sign:
                # Permute positive and negative values among themselves; the
                # counts match because `shuffled` is a permutation of `kept`.
                signed = np.empty_like(kept)
                signed[kept >= 0] = shuffled[shuffled >= 0]
                signed[kept < 0] = shuffled[shuffled < 0]
                shuffled = signed
            w0[keep] = shuffled
            weights[w_name].assign(w0)

    else:  # reset_mode == "constant"
        init_sd = {layer.name: layer.init_sd for layer in model.layers}
        for w_name in _lth_kernel_names(model):
            w0 = initial_weights[w_name]
            keep = weights[w_name + "_mask:0"].numpy() > 0.5
            # Per-layer default; never overwrite the `constant` argument, so
            # each layer falls back to its own init_sd.
            layer_constant = constant if constant else init_sd[os.path.split(w_name)[0]]
            w = w0.copy()
            w[keep] = layer_constant
            if same_sign:
                positive = w0 >= 0
            else:
                positive = np.random.rand(*keep.shape) - 0.5 >= 0
            w = np.where(positive, w, -w)
            weights[w_name].assign(w)

    _lth_zero_biases(model, weights)
    # Compile model to reset optimizer
    _lth_recompile(model)
    return model


def export_saved_model(model: tf.keras.Model,
                       input_shape: Tuple[int, int, int, int, int],
                       export_path: str = '/tmp/movinet/',
                       causal: bool = False,
                       bundle_input_init_states_fn: bool = True,
                       checkpoint_path: Optional[str] = None) -> None:
    """Export a MoViNet model to a TensorFlow SavedModel.

    Args:
        model: The tf.keras.Model to export.
        input_shape: 5D spatiotemporal input shape
            [batch_size, num_frames, image_height, image_width, num_channels].
            The whole field, or individual positions, may be None for
            dynamic input.
        export_path: Directory the SavedModel is written to.
        causal: Export the model in causal (streaming) mode, with explicit
            state inputs and outputs.
        bundle_input_init_states_fn: In causal mode, also export
            `init_states` as a signature next to `call`. Not needed when the
            input shape is static (e.g. for TF Lite).
        checkpoint_path: Optional checkpoint to restore before exporting;
            when empty the model keeps its current initialization.
    """
    # Build with 1 in place of every unknown dimension. Only the trailing
    # dimension matters for building the model and obtaining output states,
    # and small sizes keep the export fast.
    concrete_shape = [dim if dim is not None else 1 for dim in input_shape]
    model.build(concrete_shape)

    # compile() creates some Keras-internal variables.
    model.compile()

    if checkpoint_path:
        ckpt = tf.train.Checkpoint(model=model)
        restore_status = ckpt.restore(checkpoint_path)
        restore_status.assert_existing_objects_matched()

    if not causal:
        _ = model(tf.ones(concrete_shape))
        tf.keras.models.save_model(model, export_path)
        return

    # First call produces the output states. Second call feeds them back so
    # the `states` inputs are built with the full output state shapes.
    dummy_frames = tf.ones(concrete_shape)
    _, produced_states = model({
        **model.init_states(concrete_shape), 'image': dummy_frames
    })
    _ = model({**produced_states, 'image': dummy_frames})

    def predict(inputs):
        # Wrapper that gives the logits output an explicit name.
        logits, next_states = model(inputs)
        return {**next_states, 'logits': logits}

    state_specs = model.initial_state_specs(input_shape)
    input_specs = {
        spec_name: tf.TensorSpec(state_spec.shape, name=spec_name,
                                 dtype=state_spec.dtype)
        for spec_name, state_spec in state_specs.items()
    }
    input_specs['image'] = tf.TensorSpec(input_shape,
                                         dtype=model.dtype,
                                         name='image')

    predict_fn = tf.function(predict,
                             jit_compile=True).get_concrete_function(input_specs)
    init_states_fn = tf.function(
        model.init_states,
        jit_compile=True).get_concrete_function(
            tf.TensorSpec([5], dtype=tf.int32))

    signatures = ({'call': predict_fn, 'init_states': init_states_fn}
                  if bundle_input_init_states_fn else predict_fn)
    tf.keras.models.save_model(model, export_path, signatures=signatures)
Ejemplo n.º 29
0
def _train_model(model: tf.keras.Model, database: Dict[str, float], num_epochs: int, test_set: Optional[List[str]],
                 batch_size: int = 32, validation_split: float = 0.1, bootstrap: bool = False,
                 random_state: int = 1, learning_rate: float = 1e-3, patience: int = None,
                 timeout: float = None) -> Union[Tuple[List, dict], Tuple[List, dict, List[float]]]:
    """Train a model

    Args:
        model: Model to be trained
        database: Training dataset of molecule mapped to a property
        test_set: Hold-out set. If provided, this function will return predictions on this set
        num_epochs: Maximum number of epochs to run
        batch_size: Number of molecules per training batch
        validation_split: Fraction of molecules used for the training/validation split
        bootstrap: Whether to perform a bootstrap sample of the dataset
        random_state: Seed to the random number generator. Ensures entries do not move between train
            and validation set as the database becomes larger
        learning_rate: Learning rate for the Adam optimizer
        patience: Number of epochs without improvement before terminating training.
            Defaults to ``num_epochs // 8``.
        timeout: Maximum training time in seconds
    Returns:
        weights: Updated weights, as a list of numpy arrays
        history: Training history (``History.history`` dict)
        test_pred: Predictions on ``test_set`` (first output column), returned only
            when ``test_set`` is provided
    Raises:
        ValueError: If the training loss or any final weight is NaN
    """
    # Compile the model with a new optimizer
    #  We find that it is best to reset the optimizer before updating
    model.compile(tf.keras.optimizers.Adam(learning_rate=learning_rate), 'mean_squared_error')

    # Separate the database into molecules and properties
    smiles, y = zip(*database.items())
    smiles = np.array(smiles)
    y = np.array(y)

    # Make the training and validation splits
    rng = np.random.RandomState(random_state)
    train_split = rng.rand(len(smiles)) > validation_split
    train_X = smiles[train_split]
    train_y = y[train_split]
    valid_X = smiles[~train_split]
    valid_y = y[~train_split]

    # Perform a bootstrap sample of the training data
    if bootstrap:
        sample = rng.choice(len(train_X), size=(len(train_X),), replace=True)
        train_X = train_X[sample]
        train_y = train_y[sample]

    # Make the training data loaders
    train_loader = GraphLoader(train_X, train_y, batch_size=batch_size, shuffle=True)
    val_loader = GraphLoader(valid_X, valid_y, batch_size=batch_size, shuffle=False)

    # Make the callbacks
    # The scheduler multiplies the LR by `decay_rate` at the start of every
    # epoch, a geometric decay from `learning_rate` towards `final_learn_rate`.
    # With one epoch there is nothing to decay, and the exponent's
    # denominator (num_epochs - 1) would be zero.
    final_learn_rate = 1e-6
    init_learn_rate = learning_rate
    if num_epochs > 1:
        decay_rate = (final_learn_rate / init_learn_rate) ** (1. / (num_epochs - 1))
    else:
        decay_rate = 1.0

    def lr_schedule(epoch, lr):
        return lr * decay_rate

    if patience is None:
        patience = num_epochs // 8

    early_stopping = cb.EarlyStopping(patience=patience, restore_best_weights=True)
    my_callbacks = [
        LRLogger(),
        EpochTimeLogger(),
        cb.LearningRateScheduler(lr_schedule),
        early_stopping,
        cb.TerminateOnNaN(),
        train_loader  # So the shuffling gets called
    ]
    if timeout is not None:
        my_callbacks.append(TimeLimitCallback(timeout))

    # Run the desired number of epochs
    history = model.fit(train_loader, epochs=num_epochs, validation_data=val_loader,
                        verbose=False, shuffle=False, callbacks=my_callbacks)

    # If a timeout is used, make sure we are using the best weights
    #  The training may have exited without storing the best weights.
    #  best_weights stays None if no epoch completed, so guard against it.
    if timeout is not None and early_stopping.best_weights is not None:
        model.set_weights(early_stopping.best_weights)

    # Check if there is a NaN loss
    if np.isnan(history.history['loss']).any():
        raise ValueError('Training failed due to a NaN loss.')

    # If provided, evaluate model on test set
    test_pred = None
    if test_set is not None:
        test_pred = evaluate_mpnn([model], test_set, batch_size, cache=False)

    # Convert weights to numpy arrays (avoids mmap issues)
    weights = []
    for v in model.get_weights():
        v = np.array(v)
        if np.isnan(v).any():
            raise ValueError('Found some NaN weights.')
        weights.append(v)

    # Once we are finished training call "clear_session"
    tf.keras.backend.clear_session()
    if test_pred is None:
        return weights, history.history
    return weights, history.history, test_pred[:, 0].tolist()
Ejemplo n.º 30
0
    def train_image_classification(
        cls,
        train_data:tf.data.Dataset, train_size:int, batch_size:int,
        validation_data:tf.data.Dataset, validation_size:int,
        shuffle_size:int,
        model:tf.keras.Model,
        callbacks:List[tf.keras.callbacks.Callback],
        optimizer:tf.keras.optimizers.Optimizer,
        loss:tf.keras.losses.Loss,
        max_epoch:int = 5, resume:bool = True):
        """Run image-classification training.

        Both datasets are mapped with
        ``ImageDatasetUtil.dict_to_classification_tuple()``, repeated
        indefinitely, batched and prefetched (the training set is also
        shuffled when ``shuffle_size`` is non-zero). The model is compiled
        with ``metrics=["acc"]`` and fitted. When ``ResumeExecutor`` reports
        a resumable run, its saved state is loaded and training continues
        from the epoch after the recorded one.

        Parameters:
            train_data (tf.data.Dataset): Training data.
            train_size (int): Number of training examples.
            batch_size (int): Batch size used for training and validation.
            validation_data (tf.data.Dataset): Validation data.
            validation_size (int): Number of validation examples.
            shuffle_size (int): Shuffle buffer size; 0 disables shuffling.
            model (tf.keras.Model): Model to train.
            callbacks (List[tf.keras.callbacks.Callback]): Passed to ``fit``.
            optimizer (tf.keras.optimizers.Optimizer): Passed to ``compile``.
            loss (tf.keras.losses.Loss): Passed to ``compile``.
            max_epoch (int): ``epochs`` argument of ``fit``.
            resume (bool): NOTE(review): not read by this method; whether to
                resume is decided solely by ``ResumeExecutor``.

        Returns:
            The ``History`` object returned by ``model.fit``.

        Raises:
            SystemExit: If ``ResumeExecutor`` reports the training as already
                completed.

        Example:
            import tftk

            tftk.Context.init_context(
                TRAINING_NAME="example_training1",
                TRAINING_BASE_DIR="./tmp",
            )
            tftk.ENABLE_SUSPEND_RESUME_TRAINING()
            tftk.USE_MIXED_PRECISION()
        """
        import sys

        autotune = tf.data.experimental.AUTOTUNE

        train_data = train_data.map(ImageDatasetUtil.dict_to_classification_tuple(),
                                    num_parallel_calls=autotune).repeat()
        if shuffle_size != 0:
            train_data = train_data.shuffle(shuffle_size)
        train_data = train_data.batch(batch_size).prefetch(autotune)

        validation_data = validation_data.map(ImageDatasetUtil.dict_to_classification_tuple(),
                                              num_parallel_calls=autotune)
        validation_data = validation_data.repeat().batch(batch_size).prefetch(autotune)

        model.compile(optimizer=optimizer, loss=loss, metrics=["acc"])
        model.summary()

        initial_epoch = 0
        exe = ResumeExecutor.get_instance()

        if IS_ON_COLABOLATORY_WITH_GOOGLE_DRIVE():
            Colaboratory.copy_resume_data_from_google_drive()
        else:
            print("google drive is not found.")

        if exe.is_resumable_training():
            print("This is resume training!!")
            exe.resume_model(model)
            # The first resume value is the last completed epoch.
            initial_epoch, _, _, _ = exe.resume_values()
            initial_epoch = initial_epoch + 1
            print("resuming epoch", initial_epoch, "max_epoch", max_epoch)
        elif exe.is_train_ended():
            print("Training is completed.")
            sys.exit()

        # Both datasets repeat forever, so fit() needs explicit step counts.
        # Keep at least one step so a dataset smaller than one batch still
        # trains/validates instead of yielding 0 steps.
        steps_per_epoch = max(1, train_size // batch_size)
        validation_steps = max(1, validation_size // batch_size)
        history = model.fit(
            train_data,
            callbacks=callbacks,
            validation_data=validation_data,
            steps_per_epoch=steps_per_epoch,
            validation_steps=validation_steps,
            epochs=max_epoch, initial_epoch=initial_epoch)

        tf.keras.backend.clear_session()
        return history