def set_trainable_variables(self, variable_list=None):
        """Restrict the optimizer to a subset of the trainable variables.

        The gradients of `self.total_loss` are recomputed with `self.optimizer`, optionally clipped by
        global norm, and a new `self.train_op` is created. Histogram/scalar summaries of the gradients
        and histograms of the selected variables are merged into `self.train_summary`.
        Basically, this is copied from the training path in `build`.

        The batchnorm statistics can always be updated?

        Args:
            variable_list: A list of sub-strings. A trainable variable is optimized if
                           `substring_in_list(v.name, variable_list)` is true.
                           If None, all variables will be optimized.
        """
        add_train_summary = []
        if variable_list is None:
            tf.logging.info(
                "[Info] Add all trainable variables to the optimizer.")
            # `var_list=None` lets the optimizer pick up every trainable variable itself.
            trainable_variables = None
        else:
            trainable_variables = []
            for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
                if substring_in_list(v.name, variable_list):
                    trainable_variables.append(v)
                    tf.logging.info("[Info] Add %s to trainable list" % v.name)

        # NOTE(review): `build` already opens tf.name_scope("train"). If TensorFlow uniquifies the
        # re-entered scope name, `scope` here may not match the scope the batchnorm update ops were
        # created in, and the UPDATE_OPS filter below could come back empty -- confirm.
        with tf.name_scope("train") as scope:
            grads = self.optimizer.compute_gradients(
                self.total_loss, var_list=trainable_variables)

        if self.params.clip_gradient:
            # Split the (gradient, variable) pairs, clip the gradients by their global l2 norm and pair
            # them up again.
            raw_grads, grad_vars = zip(*grads)
            grads_clip, _ = tf.clip_by_global_norm(
                raw_grads, self.params.clip_gradient_norm)
            # Materialize the pairs: `grads` is iterated for the summaries below AND passed to
            # `apply_gradients`. A bare `zip` is a one-shot iterator on Python 3 and would already be
            # exhausted by the time the train op is created.
            grads = list(zip(grads_clip, grad_vars))

        # The values and gradients are added to summaries
        for grad, var in grads:
            if grad is not None:
                add_train_summary.append(
                    tf.summary.histogram(var.op.name + '/gradients', grad))
                add_train_summary.append(
                    tf.summary.scalar(var.op.name + '/gradients_norm',
                                      tf.norm(grad)))

        if variable_list is None:
            trainable_variables = tf.trainable_variables()
        for var in trainable_variables:
            add_train_summary.append(tf.summary.histogram(var.op.name, var))
        self.train_summary = tf.summary.merge(
            [self.train_summary,
             tf.summary.merge(add_train_summary)])

        # Run the collected update ops (e.g. batchnorm statistics) together with every training step.
        batchnorm_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                                 scope)
        with tf.control_dependencies(batchnorm_update_ops):
            self.train_op = self.optimizer.apply_gradients(grads)
    def get_finetune_model(self, excluded_list):
        """Start from a pre-trained model and other parameters are initialized using default initializer.
        Actually, this function is only called at the first epoch of the fine-tuning, because in succeeded epochs,
        we need to fully load the model rather than loading part of the graph.

        The pre-trained model is saved in the model directory as index 0.
        Backup the pre-trained model and save the new model (with random initialized parameters) as index 0 instead.

        Args:
            excluded_list: A list. Do NOT restore the parameters in the exclude_list. This is useful in fine-tuning
                          an existing model. We load a part of the pre-trained model and leave the other part
                          randomly initialized.
        Raises:
            ValueError: If no checkpoint can be found in the model directory (`self.model`).
        Deprecated:
            data: The training data directory.
            spklist: The spklist is a file map speaker name to the index.
            learning_rate: The learning rate is passed by the main program. The main program can easily tune the
                           learning rate according to the validation accuracy or anything else.
        """
        import glob
        import shutil

        # initialize all variables
        self.sess.run(tf.global_variables_initializer())

        # Load parts of the model: every global variable whose name does not match the excluded list.
        restore_variables = []
        for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
            if substring_in_list(v.name, excluded_list):
                tf.logging.info(
                    "[Info] Ignore %s when loading the checkpoint" % v.name)
            else:
                restore_variables.append(v)
        finetune_saver = tf.train.Saver(var_list=restore_variables)

        ckpt = tf.train.get_checkpoint_state(self.model)
        if ckpt is None or not ckpt.model_checkpoint_path:
            # Fail with a clear message instead of an AttributeError on `None`.
            raise ValueError("No checkpoint found in %s" % self.model)
        # Only the basename of the recorded path is used; the checkpoint is always looked up inside
        # the model directory.
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        checkpoint_prefix = os.path.join(self.model, ckpt_name)
        finetune_saver.restore(self.sess, checkpoint_prefix)

        # Backup the old files. Use the same prefix the weights were restored from, so the files that
        # are backed up are exactly the ones that were loaded.
        for filename in glob.glob(checkpoint_prefix + "*"):
            if filename.endswith('.bak'):
                # Do not create backups of earlier backups.
                continue
            shutil.copyfile(filename, filename + '.bak')

        # Save the new model. The new model is basically the same with the pre-trained one, while parameters
        # NOT in the pre-trained model are random initialized.
        # Set the step to 0.
        self.save(0)
    def build(self,
              mode,
              dim,
              loss_type=None,
              num_speakers=None,
              noupdate_var_list=None):
        """ Build a network.

        Currently, I use placeholder in the graph and feed data during sess.run. So no need to parse
        features and labels.

        Depending on `mode`, one of three paths is built:
          * `predict`: creates `self.pred_features` and exposes `self.embeddings`
            (the endpoint named by `self.params.embedding_node`).
          * `valid`: creates `self.valid_features` / `self.valid_labels`, builds the network and the loss,
            and fills `self.valid_ops`, `self.valid_summary` and `self.valid_summary_writer`.
          * `train`: creates `self.train_features` / `self.train_labels` / `self.learning_rate`, the optimizer,
            the total loss, the summaries and `self.train_op`, and sets `self.is_built`.

        Args:
            mode: `train`, `valid` or `predict`.
            dim: The dimension of the feature.
            loss_type: Which loss function do we use. Could be None when mode == predict
            num_speakers: The total number of speakers. Used in softmax-like network
            noupdate_var_list: In the fine-tuning, some variables are fixed. The list contains their names (or part of their names).
                               We use `noupdate` rather than `notrain` because some variables are not trainable, e.g.
                               the mean and var in the batchnorm layers.

        Raises:
            NotImplementedError: If `loss_type` is not one of the supported losses (train/valid mode only).
            SystemExit: Via `sys.exit` if the configured optimizer is unsupported, or if `momentum` is
                        specified together with the `sgd` optimizer.
        """
        assert mode in ("train", "valid", "predict")
        is_training = (mode == "train")
        reuse_variables = True if self.is_built else None

        # Create a new path for prediction, since the training may build a tower the support multi-GPUs
        if mode == "predict":
            self.pred_features = tf.placeholder(tf.float32,
                                                shape=[None, None, dim],
                                                name="pred_features")
            with tf.name_scope("predict"):
                tf.logging.info("Extract embedding from node %s" %
                                self.params.embedding_node)
                # There is no need to do L2 normalization in this function, because we can do the normalization outside,
                # or simply a cosine similarity can do it.
                # Note that the output node may be different if we use different loss function. For example, if the
                # softmax is used, the output of 2-last layer is used as the embedding. While if the end2end loss is
                # used, the output of the last layer may be a better choice. So it is impossible to specify the
                # embedding node inside the network structure. The configuration will tell the network to output the
                # correct activations as the embeddings.
                _, endpoints = self.entire_network(self.pred_features,
                                                   self.params, is_training,
                                                   reuse_variables)
                self.embeddings = endpoints[self.params.embedding_node]
                if self.saver is None:
                    self.saver = tf.train.Saver()
            return

        # global_step should be defined before loss function since some loss functions use this value to tune
        # some internal parameters.
        if self.global_step is None:
            self.global_step = tf.placeholder(tf.int32, name="global_step")
            self.params.dict["global_step"] = self.global_step

        # If new loss function is added, please modify the code.
        self.loss_type = loss_type
        if loss_type == "softmax":
            self.loss_network = softmax
        elif loss_type == "asoftmax":
            self.loss_network = asoftmax
        elif loss_type == "additive_margin_softmax":
            self.loss_network = additive_margin_softmax
        elif loss_type == "additive_angular_margin_softmax":
            self.loss_network = additive_angular_margin_softmax
        elif loss_type == "semihard_triplet_loss":
            self.loss_network = semihard_triplet_loss
        elif loss_type == "angular_triplet_loss":
            self.loss_network = angular_triplet_loss
        elif loss_type == "generalized_angular_triplet_loss":
            self.loss_network = generalized_angular_triplet_loss
        else:
            raise NotImplementedError("Not implement %s loss" % self.loss_type)

        if mode == "valid":
            tf.logging.info("Building valid network...")
            self.valid_features = tf.placeholder(tf.float32,
                                                 shape=[None, None, dim],
                                                 name="valid_features")
            self.valid_labels = tf.placeholder(tf.int32,
                                               shape=[
                                                   None,
                                               ],
                                               name="valid_labels")
            with tf.name_scope("valid"):
                # We can adjust some parameters in the config when we do validation
                # TODO: I'm not sure whether it is necessary to change the margin for the valid set.
                # TODO: compare the performance!
                # Change the margin for the valid set.
                if loss_type == "asoftmax":
                    train_margin = self.params.asoftmax_m
                    self.params.asoftmax_m = 1
                elif loss_type == "additive_margin_softmax":
                    train_margin = self.params.amsoftmax_m
                    self.params.amsoftmax_m = 0
                elif loss_type == "additive_angular_margin_softmax":
                    train_margin = self.params.arcsoftmax_m
                    self.params.arcsoftmax_m = 0
                elif loss_type == "angular_triplet_loss":
                    # Switch loss to e2e_valid_loss
                    train_loss_network = self.loss_network
                    self.loss_network = e2e_valid_loss

                has_aux_loss = "aux_loss_func" in self.params.dict
                if has_aux_loss:
                    # No auxiliary losses during validation.
                    train_aux_loss_func = self.params.aux_loss_func
                    self.params.aux_loss_func = []

                try:
                    features, endpoints = self.entire_network(
                        self.valid_features, self.params, is_training,
                        reuse_variables)
                    valid_loss, endpoints_loss = self.loss_network(
                        features, self.valid_labels, num_speakers, self.params,
                        is_training, reuse_variables)
                finally:
                    # Always undo the validation-only overrides, even if building the graph failed,
                    # so the shared params / loss function keep their training values.
                    if has_aux_loss:
                        self.params.aux_loss_func = train_aux_loss_func

                    # Change the margin back!!!
                    if loss_type == "asoftmax":
                        self.params.asoftmax_m = train_margin
                    elif loss_type == "additive_margin_softmax":
                        self.params.amsoftmax_m = train_margin
                    elif loss_type == "additive_angular_margin_softmax":
                        self.params.arcsoftmax_m = train_margin
                    elif loss_type == "angular_triplet_loss":
                        self.loss_network = train_loss_network

                endpoints.update(endpoints_loss)

                # We can evaluate other stuff in the valid_ops. Just add the new values to the dict.
                # We may also need to check other values expect for the loss. Leave the task to other functions.
                # During validation, I compute the cosine EER for the final output of the network.
                self.embeddings = endpoints["output"]
                self.endpoints = endpoints

                self.valid_ops["raw_valid_loss"] = valid_loss
                mean_valid_loss, mean_valid_loss_op = tf.metrics.mean(
                    valid_loss)
                self.valid_ops["valid_loss"] = mean_valid_loss
                self.valid_ops["valid_loss_op"] = mean_valid_loss_op
                valid_loss_summary = tf.summary.scalar("loss", mean_valid_loss)
                self.valid_summary = tf.summary.merge([valid_loss_summary])
                if self.saver is None:
                    self.saver = tf.train.Saver(
                        max_to_keep=self.params.keep_checkpoint_max)
                if self.valid_summary_writer is None:
                    self.valid_summary_writer = tf.summary.FileWriter(
                        os.path.join(self.model, "eval"), self.sess.graph)
            return

        tf.logging.info("Building training network...")
        self.train_features = tf.placeholder(tf.float32,
                                             shape=[None, None, dim],
                                             name="train_features")
        self.train_labels = tf.placeholder(tf.int32,
                                           shape=[
                                               None,
                                           ],
                                           name="train_labels")
        self.learning_rate = tf.placeholder(tf.float32, name="learning_rate")

        if "optimizer" not in self.params.dict:
            # The default optimizer is sgd
            self.params.dict["optimizer"] = "sgd"

        if self.params.optimizer == "sgd":
            if "momentum" in self.params.dict:
                sys.exit(
                    "Using sgd as the optimizer and you should not specify the momentum."
                )
            tf.logging.info("***** Using SGD as the optimizer.")
            opt = tf.train.GradientDescentOptimizer(self.learning_rate,
                                                    name="optimizer")
        elif self.params.optimizer == "momentum":
            # SGD with momentum
            # It is also possible to use other optimizers, e.g. Adam.
            tf.logging.info("***** Using Momentum as the optimizer.")
            opt = tf.train.MomentumOptimizer(
                self.learning_rate,
                self.params.momentum,
                use_nesterov=self.params.use_nesterov,
                name="optimizer")
        elif self.params.optimizer == "adam":
            tf.logging.info("***** Using Adam as the optimizer.")
            opt = tf.train.AdamOptimizer(self.learning_rate, name="optimizer")
        else:
            sys.exit("Optimizer %s is not supported." % self.params.optimizer)
        self.optimizer = opt

        # Use name_space here. Create multiple name_spaces if multi-gpus
        # There is a copy in `set_trainable_variables`
        with tf.name_scope("train") as scope:
            features, endpoints = self.entire_network(self.train_features,
                                                      self.params, is_training,
                                                      reuse_variables)
            loss, endpoints_loss = self.loss_network(features,
                                                     self.train_labels,
                                                     num_speakers, self.params,
                                                     is_training,
                                                     reuse_variables)
            self.endpoints = endpoints

            endpoints.update(endpoints_loss)
            regularization_loss = tf.losses.get_regularization_loss()
            total_loss = loss + regularization_loss

            # train_summary contains all the summaries we want to inspect.
            # Get the summaries define in the network and loss function.
            # The summaries in the network and loss function are about the network variables.
            self.train_summary = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                                   scope)
            self.train_summary.append(tf.summary.scalar("loss", loss))
            self.train_summary.append(
                tf.summary.scalar("regularization_loss", regularization_loss))

            # We may have other losses (i.e. penalty term in attention layer)
            penalty_loss = tf.get_collection("PENALTY")
            if len(penalty_loss) != 0:
                penalty_loss = tf.reduce_sum(penalty_loss)
                total_loss += penalty_loss
                self.train_summary.append(
                    tf.summary.scalar("penalty_term", penalty_loss))

            self.total_loss = total_loss
            self.train_summary.append(
                tf.summary.scalar("total_loss", total_loss))
            self.train_summary.append(
                tf.summary.scalar("learning_rate", self.learning_rate))

            # The gradient ops is inside the scope to support multi-gpus
            # Collect the update ops of this scope; when fine-tuning, drop the ones whose name matches
            # the no-update list.
            scope_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                                 scope)
            if noupdate_var_list is not None:
                batchnorm_update_ops = []
                for op in scope_update_ops:
                    if not substring_in_list(op.name, noupdate_var_list):
                        batchnorm_update_ops.append(op)
                        tf.logging.info("[Info] Update %s" % op.name)
                    else:
                        tf.logging.info("[Info] Op %s will not be executed" %
                                        op.name)
            else:
                batchnorm_update_ops = scope_update_ops

            if noupdate_var_list is not None:
                # Only compute gradients for the variables that are not frozen.
                train_var_list = []
                for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
                    if not substring_in_list(v.name, noupdate_var_list):
                        train_var_list.append(v)
                        tf.logging.info("[Info] Train %s" % v.name)
                    else:
                        tf.logging.info("[Info] Var %s will not be updated" %
                                        v.name)
                grads = opt.compute_gradients(total_loss,
                                              var_list=train_var_list)
            else:
                grads = opt.compute_gradients(total_loss)

            # Once the model has been built (even for a tower), we set the flag
            self.is_built = True

        if self.params.clip_gradient:
            # Split the (gradient, variable) pairs and clip the gradients by their global l2 norm.
            raw_grads, grad_vars = zip(*grads)
            grads_clip, _ = tf.clip_by_global_norm(
                raw_grads, self.params.clip_gradient_norm)

            # we follow the instruction in ge2e paper to scale the learning rate for w and b
            # Actually, I wonder that we can just simply set a large value for w (e.g. 20) and fix it.
            # NOTE(review): "ge2e" is rejected by the loss selection above (NotImplementedError), so this
            # branch is currently unreachable.
            if self.loss_type == "ge2e":
                # The parameters w and b must be the last variables in the gradients
                grads_clip = grads_clip[:-2] + [
                    0.01 * grad for grad in grads_clip[-2:]
                ]
                # Simply check the position of w and b
                for var in grad_vars[-2:]:
                    assert ("w" in var.name or "b" in var.name)
            # Keep the pairs in a list (not a one-shot `zip` iterator) so they can safely be iterated
            # more than once, e.g. for gradient summaries, before being applied.
            grads = list(zip(grads_clip, grad_vars))

        # There are some things we can do to the gradients, i.e. learning rate scaling.

        # # The values and gradients are added to summaries
        # for grad, var in grads:
        #     if grad is not None:
        #         self.train_summary.append(tf.summary.histogram(var.op.name + '/gradients', grad))
        #         self.train_summary.append(tf.summary.scalar(var.op.name + '/gradients_norm', tf.norm(grad)))

        self.train_summary.append(activation_summaries(endpoints))
        for var in tf.trainable_variables():
            self.train_summary.append(tf.summary.histogram(var.op.name, var))
        self.train_summary = tf.summary.merge(self.train_summary)

        # Run the selected update ops together with every training step.
        with tf.control_dependencies(batchnorm_update_ops):
            self.train_op = opt.apply_gradients(grads)

        # We want to inspect other values during training?
        self.train_ops["loss"] = total_loss
        self.train_ops["raw_loss"] = loss

        # The model saver
        if self.saver is None:
            self.saver = tf.train.Saver(
                max_to_keep=self.params.keep_checkpoint_max)

        # The training summary writer
        if self.summary_writer is None:
            self.summary_writer = tf.summary.FileWriter(
                self.model, self.sess.graph)
        return