Example #1
    def create_training_op(self, loss: tf_compat.Tensor,
                           params: Dict[str, Any]) -> tf_compat.Operation:
        """
        Create training op for optimization

        :param loss: the loss tensor
        :param params: the model function params
        :return: an Operation minimizing loss
        """
        global_step = tf_compat.train.get_or_create_global_step()

        optimizer_const = {}
        for opt_name in dir(tf_compat.train):
            opt_cls = getattr(tf_compat.train, opt_name)
            if inspect.isclass(opt_cls) and issubclass(
                    opt_cls, tf_compat.train.Optimizer):
                optimizer_const[opt_name] = opt_cls

        optimizer_name = params.get("optimizer", "AdamOptimizer")
        if optimizer_name not in optimizer_const:
            raise ValueError(
                "Unsupported optimizer: {}".format(optimizer_name))
        optimizer_params = params.get("optimizer_params", {})
        optimizer = optimizer_const[optimizer_name](**optimizer_params)

        with tf_compat.name_scope("train"):
            # We use tf.layers.batch_normalization to support older versions of TF,
            # which requires explicitly modeling the dependency between the moving
            # average/variance update ops and the training op
            update_ops = tf_compat.get_collection(
                tf_compat.GraphKeys.UPDATE_OPS)
            with tf_compat.control_dependencies(update_ops):
                training_op = optimizer.minimize(loss, global_step=global_step)
        return training_op
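
A minimal usage sketch for the method above, assuming an instance of the enclosing model-function class is available as model_fn and that loss is an existing scalar tensor in the graph; the params keys mirror the ones the method reads:

# Hypothetical wiring; model_fn and loss are assumed to exist already.
params = {
    "optimizer": "GradientDescentOptimizer",      # any tf_compat.train Optimizer subclass name
    "optimizer_params": {"learning_rate": 0.01},  # kwargs forwarded to the optimizer constructor
}
train_op = model_fn.create_training_op(loss, params)

with tf_compat.Session() as sess:
    sess.run(tf_compat.global_variables_initializer())
    sess.run(train_op)  # one optimization step; also runs the batch-norm update ops
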
Example #2
def get_scheduled_update_op(
    pruning_op_vars: List[PruningOpVars],
    ks_group: str,
):
    """
    Retrieves (if previously created) or creates a single grouped update operation
    that runs the mask update ops for all of the given pruning op vars,
    caching it in the graph collections under the given group.

    :param pruning_op_vars: list of named tuples containing the pruning ops and
        variables whose update ops should be grouped into a single update
    :param ks_group: the group identifier the scope should be created under
    :return: the update operation to run in a session
    """
    update_op = tf_compat.get_collection(
        PruningScope.collection_name(ks_group, PruningScope.OP_COND_UPDATE))
    update_op = update_op[0] if len(update_op) > 0 else None

    if update_op is None:
        update_op = tf_compat.group(
            *[op_var.update for op_var in pruning_op_vars])

        # add return state to collections
        tf_compat.add_to_collection(
            PruningScope.collection_name(ks_group,
                                         PruningScope.OP_COND_UPDATE),
            update_op,
        )

    return update_op
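
A hedged usage sketch, assuming pruning_op_vars was produced by a helper such as get_or_create_graph_ops_pruning (see Example #6), that train_op and num_train_steps already exist, and that "pruning_group" is an arbitrary group identifier:

update_op = get_scheduled_update_op(pruning_op_vars, "pruning_group")

with tf_compat.Session() as sess:
    sess.run(tf_compat.global_variables_initializer())
    for _ in range(num_train_steps):
        sess.run(train_op)   # regular optimization step
        sess.run(update_op)  # conditionally applies the grouped mask updates
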
Example #3
    def saver(key: str, remove_dynamic_tl_vars: bool = False) -> tf_compat.train.Saver:
        """
        Get a tf compat saver that contains only the variables for the desired
        architecture specified by key.
        Note: the architecture must already have been created in the current graph
        for this to work.

        :param key: the model key (name) to get a saver instance for
        :param remove_dynamic_tl_vars: True to remove the vars that are used for
            transfer learning (have a different shape and should not be restored),
            False to keep all vars in the Saver
        :return: a Saver object with the appropriate vars for the model to restore
        """
        if key not in ModelRegistry._CONSTRUCTORS:
            raise ValueError(
                "key {} is not in the model registry; available: {}".format(
                    key, ModelRegistry._CONSTRUCTORS
                )
            )
        base_name = ModelRegistry._ATTRIBUTES[key].base_name_scope
        saver_vars = [
            var
            for var in tf_compat.get_collection(tf_compat.GraphKeys.TRAINABLE_VARIABLES)
            if base_name in var.name
        ]
        saver_vars.extend(
            [
                var
                for var in tf_compat.global_variables()
                if ("moving_mean" in var.name or "moving_variance" in var.name)
                and base_name in var.name
            ]
        )

        if remove_dynamic_tl_vars:
            tl_ignore_tens = ModelRegistry._ATTRIBUTES[key].tl_ignore_tens

            def _check_ignore(var: tf_compat.Variable) -> bool:
                for ignore in tl_ignore_tens:
                    if re.match(ignore, var.name):
                        return True

                return False

            saver_vars = [var for var in saver_vars if not _check_ignore(var)]

        saver = tf_compat.train.Saver(saver_vars)

        return saver
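
A hedged restore sketch, assuming the architecture for the given key has already been built in the current graph through the same ModelRegistry and that checkpoint_path points at matching pretrained weights; the key name is illustrative:

saver = ModelRegistry.saver("resnet50", remove_dynamic_tl_vars=True)  # key is an example

with tf_compat.Session() as sess:
    sess.run(tf_compat.global_variables_initializer())  # initialize everything first
    saver.restore(sess, checkpoint_path)                # then restore the registered vars
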
Example #4
    def create_metrics(
        self,
        net_outputs: Union[tf_compat.Tensor, Dict[str, tf_compat.Tensor]],
        labels: Union[tf_compat.Tensor, Dict[str, tf_compat.Tensor]],
        params: Dict[str, Any],
    ) -> Tuple[
        Dict[str, Tuple[tf_compat.Tensor, tf_compat.Operation]],
        Dict[str, tf_compat.Operation],
    ]:
        """
        Create metrics for evaluation

        :param net_outputs: output tensors of the model graph
        :param labels: ground truth labels
        :param params: the model function params
        :return: a tuple of a dict mapping each metric name to its
            (value tensor, update op) pair and a dict mapping each metric name to
            the op that resets its running variables
        """
        metrics = params.get("metrics", [])

        metrics_dict = {}
        metrics_initializers_dict = {}
        with tf_compat.name_scope("metrics"):
            for metric in metrics:
                if metric == "accuracy":
                    labels_argmax = tf_compat.argmax(labels, 1)
                    net_outputs_argmax = tf_compat.argmax(net_outputs, 1)
                    metrics_dict["accuracy"] = tf_compat.metrics.accuracy(
                        labels_argmax,
                        net_outputs_argmax,
                        name="accuracy_metric",
                    )
                    # collect the running total/count variables created for the
                    # accuracy metric so they can be reset between evaluations
                    running_vars = tf_compat.get_collection(
                        tf_compat.GraphKeys.LOCAL_VARIABLES,
                        scope="metrics/accuracy_metric",
                    )
                    running_vars_initializer = tf_compat.variables_initializer(
                        var_list=running_vars)
                    metrics_initializers_dict[
                        metric] = running_vars_initializer
                else:
                    raise ValueError("Unsupported metric: {}".format(metric))

        return (metrics_dict, metrics_initializers_dict)
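
A short evaluation sketch, assuming model_fn is an instance of the enclosing class, net_outputs and labels are tensors already wired to an evaluation input pipeline, and num_eval_steps is defined elsewhere:

params = {"metrics": ["accuracy"]}
metrics_dict, metrics_init_dict = model_fn.create_metrics(net_outputs, labels, params)
acc_value, acc_update = metrics_dict["accuracy"]

with tf_compat.Session() as sess:
    sess.run(tf_compat.global_variables_initializer())
    sess.run(tf_compat.local_variables_initializer())  # metric vars are local variables
    sess.run(metrics_init_dict["accuracy"])            # reset the running total/count
    for _ in range(num_eval_steps):
        sess.run(acc_update)                           # accumulate over the eval batches
    print(sess.run(acc_value))                         # final accuracy over the run
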
Example #5
def get_or_create_ks_schedule_ops(
    global_step: tf_compat.Tensor,
    begin_step: int,
    end_step: int,
    update_step_freq: int,
    init_sparsity: float,
    final_sparsity: float,
    exponent: float,
    ks_group: str,
) -> Tuple[tf_compat.Tensor, tf_compat.Tensor]:
    """
    Creates or retrieves (if previously created) a gradual schedule
    for model pruning (kernel sparsity).
    Creates a sparsity tensor that ramps from init_sparsity to final_sparsity,
    starting at begin_step and ending at end_step.
    Uses the global_step to map those.
    Additionally creates an update_ready tensor that is True if an update
    to the sparsity tensor should be run, False otherwise.

    :param global_step: the global optimizer step for the training graph
    :param begin_step: the global step to begin pruning at
    :param end_step: the global step to end pruning at
    :param update_step_freq: the number of global steps between each weight update
    :param init_sparsity: the initial sparsity value to enforce on a
        weight tensor at begin_step
    :param final_sparsity: the final sparsity value to enforce on a
        weight tensor at end_step
    :param exponent: the exponent to use for interpolating between init_sparsity
        and final_sparsity; higher values produce larger sparsity steps at the
        beginning than at the end, e.g. linear (1) vs cubic (3)
    :param ks_group: the group identifier the scope should be created under
    :return: a tuple containing the signal for update_ready and the target sparsity
    """
    update_ready = tf_compat.get_collection(
        PruningScope.collection_name(ks_group, PruningScope.OP_UPDATE_READY))
    sparsity = tf_compat.get_collection(
        PruningScope.collection_name(ks_group, PruningScope.OP_SPARSITY))

    update_ready = update_ready[0] if len(update_ready) > 0 else None
    sparsity = sparsity[0] if len(sparsity) > 0 else None

    if update_ready is None or sparsity is None:
        update_ready, sparsity = create_ks_schedule_ops(
            global_step,
            begin_step,
            end_step,
            update_step_freq,
            init_sparsity,
            final_sparsity,
            exponent,
            ks_group,
        )
        # add return state to collections
        tf_compat.add_to_collection(
            PruningScope.collection_name(ks_group,
                                         PruningScope.OP_UPDATE_READY),
            update_ready,
        )
        tf_compat.add_to_collection(
            PruningScope.collection_name(ks_group, PruningScope.OP_SPARSITY),
            sparsity)

    return update_ready, sparsity
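
An illustrative call with step boundaries and sparsity targets chosen only for the example; the returned tensors are the ones consumed by the pruning helpers in Example #6:

global_step = tf_compat.train.get_or_create_global_step()
update_ready, sparsity = get_or_create_ks_schedule_ops(
    global_step,
    begin_step=0,
    end_step=10000,
    update_step_freq=100,  # re-evaluate the masks every 100 global steps
    init_sparsity=0.05,
    final_sparsity=0.85,
    exponent=3.0,          # cubic ramp: larger sparsity jumps early on
    ks_group="pruning_group",
)
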
Example #6
def get_or_create_graph_ops_pruning(
    graph: tf_compat.Graph,
    var_names: List[str],
    sparsity: tf_compat.Tensor,
    update_ready: tf_compat.Tensor,
    leave_enabled: bool,
    is_after_end_step: tf_compat.Tensor,
    ks_group: str,
    mask_creator: PruningMaskCreator,
) -> List[PruningOpVars]:
    """
    Creates or retrieves (if previously created) the necessary variables
    and operators to gradually apply sparsity to a given list of operators in a graph.

    Handles setting a mask on an operator to the given sparsity.
    Sets the mask based on pruning away the lowest absolute magnitude weights.

    :param graph: the tf graph to pull the operator out of for applying the pruning to
    :param var_names: the names or regex patterns of names of variables to prune in the
        graph to the given sparsity
    :param sparsity: the target sparsity to use for assigning the masks
    :param update_ready: a boolean tensor; when True the masks are updated from
        the current sparsity, when False they are left unchanged
    :param leave_enabled: True to continue masking the weights after end_epoch,
        False to stop masking
    :param is_after_end_step: tensor that is true if the current global step
        is after end_epoch
    :param ks_group: the group identifier the scope should be created under
    :param mask_creator: object that defines how the sparsity masks are created
    :return: a list of the created or retrieved PruningOpVars named tuples, each
        containing the pruned op, its input, the mask update op, the mask variable,
        and the masked tensor
    """
    ops = tf_compat.get_collection(
        PruningScope.collection_name(ks_group, PruningScope.OPS))
    ops_input = tf_compat.get_collection(
        PruningScope.collection_name(ks_group, PruningScope.OPS_INPUT))
    mask_updates = tf_compat.get_collection(
        PruningScope.collection_name(ks_group, PruningScope.OP_MASK_UPDATE))
    masks = tf_compat.get_collection(
        PruningScope.collection_name(ks_group, PruningScope.VAR_MASK))
    maskeds = tf_compat.get_collection(
        PruningScope.collection_name(ks_group, PruningScope.OP_MASKED_VAR))

    if (len(ops) < 1 or len(ops_input) < 1 or len(mask_updates) < 1
            or len(masks) < 1 or len(maskeds) < 1):  # create new pruning ops
        pruning_op_vars = create_graph_ops_pruning(
            graph,
            var_names,
            sparsity,
            update_ready,
            leave_enabled,
            is_after_end_step,
            ks_group,
            mask_creator,
        )
    else:  # use collection pruning ops
        pruning_op_vars = []
        for op, op_input, mask_update, mask, masked in zip(
                ops, ops_input, mask_updates, masks, maskeds):
            pruning_op_vars.append(
                PruningOpVars(op, op_input, mask_update, mask, masked))

    return pruning_op_vars
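
A hedged sketch tying the schedule tensors from Example #5 to the pruning ops above; the variable name pattern and the mask_creator instance are assumptions (mask_creator stands for any PruningMaskCreator built elsewhere):

graph = tf_compat.get_default_graph()
is_after_end_step = tf_compat.greater(global_step, 10000)  # matches end_step above
pruning_op_vars = get_or_create_graph_ops_pruning(
    graph,
    var_names=["mobilenet/.*/weights"],  # names or regex patterns; purely illustrative
    sparsity=sparsity,                   # from get_or_create_ks_schedule_ops (Example #5)
    update_ready=update_ready,
    leave_enabled=True,
    is_after_end_step=is_after_end_step,
    ks_group="pruning_group",
    mask_creator=mask_creator,           # a PruningMaskCreator instance, assumed built elsewhere
)
update_op = get_scheduled_update_op(pruning_op_vars, "pruning_group")  # see Example #2
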
def train(args, save_dir, logs_dir):
    # setup dataset
    with tf_compat.device("/cpu:0"):
        train_dataset, _ = _create_dataset(args, train=True)
        val_dataset, num_classes = _create_dataset(args, train=False)
        # calc steps
        train_steps = math.ceil(len(train_dataset) / args.train_batch_size)
        val_steps = math.ceil(len(val_dataset) / args.test_batch_size)
        # build datasets
        train_dataset = _build_dataset(args, train_dataset,
                                       args.train_batch_size)
        val_dataset = _build_dataset(args, val_dataset, args.test_batch_size)
    handle, iterator, (train_iter, val_iter) = create_split_iterators_handle(
        [train_dataset, val_dataset])

    # set up model graph
    images, labels = iterator.get_next()
    training = tf_compat.placeholder(dtype=tf_compat.bool, shape=[])
    outputs = _create_model(args, num_classes, images, training)

    # set up training objects
    loss = batch_cross_entropy_loss(outputs, labels)
    acc = accuracy(outputs, labels)
    global_step = tf_compat.train.get_or_create_global_step()
    train_op = tf_compat.train.AdamOptimizer(learning_rate=args.init_lr,
                                             **args.optim_args).minimize(
                                                 loss, global_step=global_step)
    update_ops = tf_compat.get_collection(tf_compat.GraphKeys.UPDATE_OPS)
    LOGGER.info("Created update ops for training")

    # set up sparseml modifier ops
    add_mods = (ConstantPruningModifier(
        params="__ALL__") if args.sparse_transfer_learn else None)
    manager = ScheduledModifierManager.from_yaml(file_path=args.recipe_path,
                                                 add_modifiers=add_mods)
    mod_ops, mod_extras = manager.create_ops(train_steps, global_step)

    with tf_compat.Session() as sess:
        # set up tensorboard logging
        summary_writer = tf_compat.summary.FileWriter(logs_dir, sess.graph)
        summaries = tf_compat.summary.merge_all()
        LOGGER.info("Logging to tensorboard at {}".format(logs_dir))

        # initialize variables, load pretrained weights, initialize modifiers
        train_iter_handle, val_iter_handle = sess.run(
            [train_iter.string_handle(),
             val_iter.string_handle()])
        sess.run(tf_compat.global_variables_initializer())
        LOGGER.info("Initialized graph variables")
        _load_model(args, sess)
        manager.initialize_session()
        LOGGER.info("Initialized SparseML modifiers")

        best_loss = None
        for epoch in range(manager.max_epochs):
            # train
            LOGGER.info("Training for epoch {}...".format(epoch))
            sess.run(train_iter.initializer)
            train_acc, train_loss = [], []
            for step in range(train_steps):
                _, __, meas_step, meas_loss, meas_acc, meas_summ = sess.run(
                    [train_op, update_ops, global_step, loss, acc, summaries],
                    feed_dict={
                        handle: train_iter_handle,
                        training: True
                    },
                )
                if step >= train_steps - 1:
                    # log the general summaries on the last training step
                    summary_writer.add_summary(meas_summ, meas_step)
                # run modifier ops
                sess.run(mod_ops)
                # summarize
                write_simple_summary(summary_writer, "Train/Loss", meas_loss,
                                     meas_step)
                write_simple_summary(summary_writer, "Train/Acc",
                                     meas_acc * 100.0, meas_step)
                train_acc.append(meas_acc)
                train_loss.append(meas_loss)
            LOGGER.info("Epoch {} - Train Loss: {}, Train Acc: {}".format(
                epoch,
                numpy.mean(train_loss).item(),
                numpy.mean(train_acc).item()))

            # val
            LOGGER.info("Validating for epoch {}...".format(epoch))
            sess.run(val_iter.initializer)
            val_acc, val_loss = [], []
            for step in range(val_steps):
                meas_loss, meas_acc = sess.run(
                    [loss, acc],
                    feed_dict={
                        handle: val_iter_handle,
                        training: False
                    },
                )
                val_acc.append(meas_acc)
                val_loss.append(meas_loss)
                write_simple_summary(summary_writer, "Val/Loss",
                                     numpy.mean(val_loss).item(), epoch)
                write_simple_summary(summary_writer, "Val/Acc",
                                     numpy.mean(val_acc).item(), epoch)
            val_loss = numpy.mean(val_loss).item()
            LOGGER.info("Epoch {} - Val Loss: {}, Val Acc: {}".format(
                epoch, val_loss,
                numpy.mean(train_acc).item()))
            if epoch >= args.save_best_after and (best_loss is None
                                                  or val_loss <= best_loss):
                _save_checkpoint(args, sess, save_dir, "checkpoint-best")
                best_loss = val_loss
            if args.save_epochs and epoch in args.save_epochs:
                _save_checkpoint(args, sess, save_dir,
                                 "checkpoint-epoch-{}".format(epoch))

        # cleanup graph and save final checkpoint
        manager.complete_graph()
        checkpoint_path = _save_checkpoint(args, sess, save_dir,
                                           "final-checkpoint")
    LOGGER.info("Running ONNX export flow")
    export(
        args,
        save_dir,
        checkpoint_path=checkpoint_path,
        skip_samples=True,
        num_classes=num_classes,
        opset=11,
    )
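
A hedged invocation sketch for the training script, listing only the args attributes that this excerpt itself reads; the helpers (_create_dataset, _create_model, _load_model, _save_checkpoint, export) almost certainly require additional attributes that are omitted here:

from types import SimpleNamespace

args = SimpleNamespace(
    train_batch_size=64,
    test_batch_size=64,
    init_lr=1e-3,
    optim_args={},                 # extra kwargs forwarded to AdamOptimizer
    sparse_transfer_learn=False,   # True adds a ConstantPruningModifier over all params
    recipe_path="recipe.yaml",     # path to the SparseML recipe (illustrative)
    save_best_after=0,
    save_epochs=[],
)
train(args, save_dir="runs/model", logs_dir="runs/logs")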