Example No. 1
 def testInitWithVariableReuse(self):
   with self.test_session():
     p = pruning.Pruning(spec=self.pruning_hparams, sparsity=self.sparsity)
     p_copy = pruning.Pruning(
         spec=self.pruning_hparams, sparsity=self.sparsity)
     variables.global_variables_initializer().run()
     sparsity = p._sparsity.eval()
     self.assertAlmostEqual(sparsity, 0.5)
     self.assertEqual(p._sparsity.eval(), p_copy._sparsity.eval())
Example No. 2
    def testWeightSpecificSparsity(self):
        param_list = [
            "begin_pruning_step=1", "pruning_frequency=1",
            "end_pruning_step=100", "target_sparsity=0.5",
            "weight_sparsity_map=[layer2/weights:0.75]", "threshold_decay=0.0"
        ]
        test_spec = ",".join(param_list)
        pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

        with variable_scope.variable_scope("layer1"):
            w1 = variables.Variable(math_ops.linspace(1.0, 100.0, 100),
                                    name="weights")
            _ = pruning.apply_mask(w1)
        with variable_scope.variable_scope("layer2"):
            w2 = variables.Variable(math_ops.linspace(1.0, 100.0, 100),
                                    name="weights")
            _ = pruning.apply_mask(w2)

        p = pruning.Pruning(pruning_hparams)
        mask_update_op = p.conditional_mask_update_op()
        increment_global_step = state_ops.assign_add(self.global_step, 1)

        with self.test_session() as session:
            variables.global_variables_initializer().run()
            for _ in range(110):
                session.run(mask_update_op)
                session.run(increment_global_step)

            self.assertAllEqual(session.run(pruning.get_weight_sparsity()),
                                [0.5, 0.75])
Example No. 3
 def testInitWithExternalSparsity(self):
     with self.test_session():
         p = pruning.Pruning(spec=self.pruning_hparams,
                             sparsity=self.sparsity)
         variables.global_variables_initializer().run()
         sparsity = p._sparsity.eval()
         self.assertAlmostEqual(sparsity, 0.5)
Example No. 4
 def testConditionalMaskUpdate(self):
     param_list = [
         "pruning_frequency=2", "begin_pruning_step=1", "end_pruning_step=6"
     ]
     test_spec = ",".join(param_list)
     pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
     weights = variables.Variable(math_ops.linspace(1.0, 100.0, 100),
                                  name="weights")
     masked_weights = pruning.apply_mask(weights)
     sparsity = variables.Variable(0.00, name="sparsity")
     # Set up pruning
     p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
     p._spec.threshold_decay = 0.0
     mask_update_op = p.conditional_mask_update_op()
     sparsity_val = math_ops.linspace(0.0, 0.9, 10)
     increment_global_step = state_ops.assign_add(self.global_step, 1)
     non_zero_count = []
     with self.test_session() as session:
         variables.global_variables_initializer().run()
         for i in range(10):
             session.run(state_ops.assign(sparsity, sparsity_val[i]))
             session.run(mask_update_op)
             session.run(increment_global_step)
             non_zero_count.append(np.count_nonzero(masked_weights.eval()))
     # Weights pruned at steps 0, 2, 4, and 6
     expected_non_zero_count = [100, 100, 80, 80, 60, 60, 40, 40, 40, 40]
     self.assertAllEqual(expected_non_zero_count, non_zero_count)
Example No. 5
    def _setup_graph(self):
        '''Set up the model-pruning hyperparameters from self.param_dict
        (falling back to the defaults below) and create the conditional
        mask update op.'''
        default_dict = {
            'name': 'model_pruning',
            'begin_pruning_step': 0,
            'end_pruning_step': 34400,
            'target_sparsity': 0.31,
            'pruning_frequency': 344,
            'sparsity_function_begin_step': 0,
            'sparsity_function_end_step': 34400,
            'sparsity_function_exponent': 2,
        }
        for k, v in self.param_dict.items():
            if k in default_dict:
                default_dict[k] = v

        param_list = ['{}={}'.format(k, v) for k, v in default_dict.items()]
        # param_list = [
        #         "name=cifar10_pruning",
        #         "begin_pruning_step=1000",
        #         "end_pruning_step=20000",
        #         "target_sparsity=0.9",
        #         "sparsity_function_begin_step=1000",
        #         "sparsity_function_end_step=20000"
        # ]

        PRUNE_HPARAMS = ",".join(param_list)
        pruning_hparams = pruning.get_pruning_hparams().parse(PRUNE_HPARAMS)
        self.p = pruning.Pruning(pruning_hparams,
                                 global_step=get_global_step_var())
        self.p.add_pruning_summaries()
        self.mask_update_op = self.p.conditional_mask_update_op()
Example No. 6
 def testInit(self):
     p = pruning.Pruning(self.pruning_hparams)
     self.assertEqual(p._spec.name, "test")
     self.assertAlmostEqual(p._spec.threshold_decay, 0.9)
     self.assertEqual(p._spec.pruning_frequency, 10)
     self.assertEqual(p._spec.sparsity_function_end_step, 100)
     self.assertAlmostEqual(p._spec.target_sparsity, 0.9)
Example No. 7
    def testPerLayerBlockSparsity(self):
        param_list = [
            "block_dims_map=[layer1/weights:1x1,layer2/weights:1x2]",
            "block_pooling_function=AVG", "threshold_decay=0.0"
        ]

        test_spec = ",".join(param_list)
        pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

        with variable_scope.variable_scope("layer1"):
            w1 = constant_op.constant([[-0.1, 0.1], [-0.2, 0.2]],
                                      name="weights")
            pruning.apply_mask(w1)

        with variable_scope.variable_scope("layer2"):
            w2 = constant_op.constant(
                [[0.1, 0.1, 0.3, 0.3], [0.2, 0.2, 0.4, 0.4]], name="weights")
            pruning.apply_mask(w2)

        sparsity = variables.VariableV1(0.5, name="sparsity")

        p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
        mask_update_op = p.mask_update_op()
        with self.cached_session() as session:
            variables.global_variables_initializer().run()
            session.run(mask_update_op)
            mask1_eval = session.run(pruning.get_masks()[0])
            mask2_eval = session.run(pruning.get_masks()[1])

            self.assertAllEqual(session.run(pruning.get_weight_sparsity()),
                                [0.5, 0.5])

            self.assertAllEqual(mask1_eval, [[0.0, 0.0], [1., 1.]])
            self.assertAllEqual(mask2_eval, [[0, 0, 1., 1.], [0, 0, 1., 1.]])
Example No. 8
 def __init__(self,
              input_size,
              output_size,
              model_path: str,
              momentum=0.9,
              reg_str=0.0005,
              scope='ConvNet',
              pruning_start=int(10e4),
              pruning_end=int(10e5),
              pruning_freq=int(10),
              sparsity_start=0,
              sparsity_end=int(10e5),
              target_sparsity=0.0,
              dropout=0.5,
              initial_sparsity=0,
              wd=0.0):
     super(ConvNet, self).__init__(input_size=input_size,
                                   output_size=output_size,
                                   model_path=model_path)
     self.scope = scope
     self.momentum = momentum
     self.reg_str = reg_str
     self.dropout = dropout
     self.logger = get_logger(scope)
     self.wd = wd
     self.logger.info("creating graph...")
     with self.graph.as_default():
         self.global_step = tf.Variable(0, trainable=False)
         self._build_placeholders()
         self.logits = self._build_model()
         self.weights_matrices = pruning.get_masked_weights()
         self.sparsity = pruning.get_weight_sparsity()
         self.loss = self._loss()
         self.train_op = self._optimizer()
         self._create_metrics()
         self.saver = tf.train.Saver(var_list=tf.global_variables())
         self.hparams = pruning.get_pruning_hparams()\
             .parse('name={}, begin_pruning_step={}, end_pruning_step={}, target_sparsity={},'
                    ' sparsity_function_begin_step={},sparsity_function_end_step={},'
                    'pruning_frequency={},initial_sparsity={},'
                    ' sparsity_function_exponent={}'.format(scope,
                                                            pruning_start,
                                                            pruning_end,
                                                            target_sparsity,
                                                            sparsity_start,
                                                            sparsity_end,
                                                            pruning_freq,
                                                            initial_sparsity,
                                                            3))
          # Note that the global step plays an important part in the pruning
          # mechanism: the higher the global step, the closer the sparsity
          # gets to the end sparsity.
         self.pruning_obj = pruning.Pruning(self.hparams,
                                            global_step=self.global_step)
         self.mask_update_op = self.pruning_obj.conditional_mask_update_op()
          # The pruning object defines the pruning mechanism; running
          # mask_update_op is what actually prunes the model. Pruning takes
          # place at each training epoch, with the objective of reaching the
          # end-sparsity hyperparameter.
         self.init_variables(
             tf.global_variables())  # initialize variables in graph
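A minimal sketch (not part of the original class) of how the mask update op above might be driven during training; the method name, placeholder names, and feed dict are assumptions for illustration only:

 def _train_step_sketch(self, sess, batch_x, batch_y):
     # Run one optimization step (assuming self._optimizer() passes
     # self.global_step to minimize, so the step counter advances).
     sess.run(self.train_op, feed_dict={self.x: batch_x, self.y: batch_y})
     # Re-apply the pruning masks; this only updates them while the global
     # step is inside [begin_pruning_step, end_pruning_step].
     sess.run(self.mask_update_op)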
Example No. 9
    def _prune_model(self, session):
        pruning_hparams = pruning.get_pruning_hparams().parse(
            self.pruning_spec)
        p = pruning.Pruning(pruning_hparams, sparsity=self.sparsity)
        self.mask_update_op = p.conditional_mask_update_op()

        variables.global_variables_initializer().run()
        for _ in range(20):
            session.run(self.mask_update_op)
            session.run(self.increment_global_step)
Example No. 10
    def __init__(self, model, data_handle, hyperparams):
        self.model = model
        self.data_handle = data_handle
        self.hyperparams = hyperparams

        # get defined tensor
        self.X = self.model.X
        self.Y = self.model.Y
        self.result = self.model.result
        self.train  = self.model.Utils.is_train
        self.update = self.model.Utils.tensor_updated
        self.learning_rate = tf.placeholder(tf.float32)
        self.global_step = tf.Variable(0, dtype = tf.int32, trainable = False)
        self.weights_decay = self.hyperparams['weights_decay']
        self.global_step_update = tf.assign_add(self.global_step, tf.constant(2, dtype = tf.int32))

        # optimizer
        self.cross_entropy     = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels = self.Y, logits = self.result))
        self.l2_loss           = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()])
        self.loss              = self.l2_loss * self.weights_decay + self.cross_entropy
        # train_step        = tf.train.AdamOptimizer(learning_rate = 0.001).minimize(loss)
        self.train_step        = tf.train.MomentumOptimizer(self.learning_rate, 0.9, use_nesterov = True).minimize(self.loss)
        self.top1              = tf.equal(tf.argmax(self.result, 1), tf.argmax(self.Y, 1))
        self.top1_acc          = tf.reduce_mean(tf.cast(self.top1, "float"))
        self.top5              = tf.nn.in_top_k(predictions = self.result, targets = tf.argmax(self.Y, 1), k = 5) 
        self.top5_acc          = tf.reduce_mean(tf.cast(self.top5, "float"))


        # prune
        if self.hyperparams['enable_prune']:
            pruning_hparams = pruning.get_pruning_hparams()
            pruning_hparams.begin_pruning_step = self.hyperparams['begin_pruning_step']
            pruning_hparams.end_pruning_step   = self.hyperparams['end_pruning_step']
            pruning_hparams.pruning_frequency  = self.hyperparams['pruning_frequency']
            pruning_hparams.target_sparsity    = self.hyperparams['target_sparsity']
            p = pruning.Pruning(pruning_hparams, global_step = self.global_step)
            self.prune_op = p.conditional_mask_update_op()

        # log
        log_prefix = "log" + "_quant_{}".format(self.hyperparams['quant_bits']) + "_prune_{}".format(str(self.hyperparams["enable_prune"])) + "/"
        if not os.path.exists(log_prefix):
            os.mkdir(log_prefix)
        self.fd = open(log_prefix + self.hyperparams['model_name'], "a")
        print("model_name = {}, quant_bits = {}, enable_prune = {}".format(self.hyperparams['model_name'], self.hyperparams['quant_bits'], self.hyperparams['target_sparsity']), file = self.fd)
        print(time.asctime(time.localtime(time.time())) + "   train started", file = self.fd)


        # init_variable
        # config = tf.ConfigProto()
        # config.gpu_options.allow_growth = True
        # config.gpu_options.per_process_gpu_memory_fraction = 0.6
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())
Example No. 11
    def testTrainingStep(self, training_method):

        tf.reset_default_graph()
        g = tf.Graph()
        with g.as_default():

            images, labels = self.get_next()

            global_step, _, _, logits = resnet_train_eval.build_model(
                mode='train',
                images=images,
                labels=labels,
                training_method=training_method,
                num_classes=FLAGS.num_classes,
                depth=FLAGS.resnet_depth,
                width=FLAGS.resnet_width)

            tf.losses.sparse_softmax_cross_entropy(labels=labels,
                                                   logits=logits)

            total_loss = tf.losses.get_total_loss(
                add_regularization_losses=True)

            learning_rate = 0.1

            opt = tf.train.MomentumOptimizer(learning_rate,
                                             momentum=FLAGS.momentum,
                                             use_nesterov=True)

            if training_method in ['threshold']:
                # Create a pruning object using the pruning hyperparameters
                pruning_obj = pruning.Pruning()

                logging.info('starting mask update op')
                mask_update_op = pruning_obj.conditional_mask_update_op()

            # Create the training op
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = opt.minimize(total_loss, global_step)

            init_op = tf.global_variables_initializer()

            with self.test_session() as sess:
                # test that we can train successfully for 1 step
                sess.run(init_op)
                for _ in range(1):
                    sess.run(train_op)
                    if training_method in ['threshold']:
                        sess.run(mask_update_op)
Example No. 12
 def __init__(self,
              actor_input_dim,
              actor_output_dim,
              model_path,
              redundancy=None,
              last_measure=10e4,
              tau=0.01):
     super(StudentActor, self).__init__(model_path=model_path)
     self.actor_input_dim = (None, actor_input_dim)
     self.actor_output_dim = (None, actor_output_dim)
     self.tau = tau
     self.redundancy = redundancy
     self.last_measure = last_measure
     with self.graph.as_default():
         self.actor_global_step = tf.Variable(0, trainable=False)
         self._build_placeholders()
         self.actor_logits = self._build_actor()
         # self.gumbel_dist = self._build_gumbel(self.actor_logits)
         self.loss = self._build_loss()
         self.actor_parameters = tf.get_collection(
             tf.GraphKeys.TRAINABLE_VARIABLES, scope='actor')
         self.actor_pruned_weight_matrices = pruning.get_masked_weights()
         self.actor_train_op = self._build_actor_train_op()
         self.actor_saver = tf.train.Saver(var_list=self.actor_parameters,
                                           max_to_keep=100)
         self.init_variables(tf.global_variables())
         self.sparsity = pruning.get_weight_sparsity()
         self.hparams = pruning.get_pruning_hparams() \
             .parse('name={}, begin_pruning_step={}, end_pruning_step={}, target_sparsity={},'
                    ' sparsity_function_begin_step={},sparsity_function_end_step={},'
                    'pruning_frequency={},initial_sparsity={},'
                    ' sparsity_function_exponent={}'.format('Actor',
                                                            cfg.pruning_start,
                                                            cfg.pruning_end,
                                                            cfg.target_sparsity,
                                                            cfg.sparsity_start,
                                                            cfg.sparsity_end,
                                                            cfg.pruning_freq,
                                                            cfg.initial_sparsity,
                                                            3))
          # Note that the global step plays an important part in the pruning
          # mechanism: the higher the global step, the closer the sparsity
          # gets to the end sparsity.
         self.pruning_obj = pruning.Pruning(
             self.hparams, global_step=self.actor_global_step)
         self.mask_update_op = self.pruning_obj.conditional_mask_update_op()
          # The pruning object defines the pruning mechanism; running
          # mask_update_op is what actually prunes the model. Pruning takes
          # place at each training epoch, with the objective of reaching the
          # end-sparsity hyperparameter.
         self.init_variables(
             tf.global_variables())  # initialize variables in graph
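A training-loop sketch analogous to the one shown after Example No. 8 would apply here as well: each actor update would run self.actor_train_op followed by self.mask_update_op, so the masks are refreshed while the actor's global step is inside the pruning range.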
Example No. 13
 def testUpdateSingleMask(self):
     with self.test_session() as session:
         weights = variables.Variable(math_ops.linspace(1.0, 100.0, 100),
                                      name="weights")
         masked_weights = pruning.apply_mask(weights)
         sparsity = variables.Variable(0.5, name="sparsity")
         p = pruning.Pruning(sparsity=sparsity)
         p._spec.threshold_decay = 0.0
         mask_update_op = p.mask_update_op()
         variables.global_variables_initializer().run()
         masked_weights_val = masked_weights.eval()
         self.assertAllEqual(np.count_nonzero(masked_weights_val), 100)
         session.run(mask_update_op)
         masked_weights_val = masked_weights.eval()
         self.assertAllEqual(np.count_nonzero(masked_weights_val), 51)
Example No. 14
    def _blockMasking(self, hparams, weights, expected_mask):

        threshold = variables.Variable(0.0, name="threshold")
        sparsity = variables.Variable(0.5, name="sparsity")
        test_spec = ",".join(hparams)
        pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

        # Set up pruning
        p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
        with self.test_session():
            variables.global_variables_initializer().run()
            _, new_mask = p._maybe_update_block_mask(weights, threshold)
            # Check if the mask is the same size as the weights
            self.assertAllEqual(new_mask.get_shape(), weights.get_shape())
            mask_val = new_mask.eval()
            self.assertAllEqual(mask_val, expected_mask)
Example No. 15
def set_prune_params(s):
    # Get, Print, and Edit Pruning Hyperparameters
    pruning_hparams = pruning.get_pruning_hparams()
    print("Pruning Hyperparameters:", pruning_hparams)

    # Change hyperparameters to meet our needs
    pruning_hparams.begin_pruning_step = 0
    pruning_hparams.end_pruning_step = 250
    pruning_hparams.pruning_frequency = 1
    pruning_hparams.sparsity_function_end_step = 250
    pruning_hparams.target_sparsity = s

    # Create a pruning object using the pruning specification; an explicitly
    # passed sparsity tensor seems to take priority over the hparams value.
    # (`global_step` is assumed to be defined at module level.)
    p = pruning.Pruning(pruning_hparams, global_step=global_step)
    prune_op = p.conditional_mask_update_op()
    return prune_op
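A minimal usage sketch for the helper above (the training op, step count, and session setup here are assumptions, not part of the original snippet):

prune_op = set_prune_params(0.5)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(250):
        sess.run(train_op)  # hypothetical training op that advances global_step
        sess.run(prune_op)  # updates the masks while the step is in the pruning range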
Example No. 16
 def testPartitionedVariableMasking(self):
     partitioner = partitioned_variables.variable_axis_size_partitioner(40)
     with self.test_session() as session:
         with variable_scope.variable_scope("", partitioner=partitioner):
             sparsity = variables.Variable(0.5, name="Sparsity")
             weights = variable_scope.get_variable(
                 "weights", initializer=math_ops.linspace(1.0, 100.0, 100))
             masked_weights = pruning.apply_mask(
                 weights, scope=variable_scope.get_variable_scope())
         p = pruning.Pruning(sparsity=sparsity, partitioner=partitioner)
         p._spec.threshold_decay = 0.0
         mask_update_op = p.mask_update_op()
         variables.global_variables_initializer().run()
         masked_weights_val = masked_weights.eval()
         session.run(mask_update_op)
         masked_weights_val = masked_weights.eval()
         self.assertAllEqual(np.count_nonzero(masked_weights_val), 51)
Example No. 17
def pruning_params(global_step, begin_step=0, end_step=-1, pruning_freq=10,
                   sparsity_function=2000, target_sparsity=.50, sparsity_exponent=1.0):
    """
    Creates the pruning op
    :param global_step: the global step, needed for pruning
    :param begin_step: the global step at which to begin pruning
    :param end_step: the global step at which to end pruning
    :param pruning_freq: the frequency of global step for when to prune
    :param sparsity_function: the global step used as the end point for the gradual sparsity function
    :param target_sparsity: the target sparsity
    :param sparsity_exponent: the exponent for the sparsity function
    :return: Pruning op
    """
    pruning_hparams = pruning.get_pruning_hparams()
    pruning_hparams.begin_pruning_step = begin_step
    pruning_hparams.end_pruning_step = end_step
    pruning_hparams.pruning_frequency = pruning_freq
    pruning_hparams.sparsity_function_end_step = sparsity_function
    pruning_hparams.target_sparsity = target_sparsity
    pruning_hparams.sparsity_function_exponent = sparsity_exponent
    p = pruning.Pruning(pruning_hparams, global_step=global_step, sparsity=target_sparsity)
    p_op = p.conditional_mask_update_op()
    p.add_pruning_summaries()
    return p_op
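Equivalently, the same specification could be written as a comma-separated hparams string and parsed, as several other examples in this collection do. A sketch using this function's default values (the explicit sparsity argument is dropped here so the gradual sparsity schedule from the hparams applies; global_step is the same tensor the function above receives):

spec = ("begin_pruning_step=0,end_pruning_step=-1,pruning_frequency=10,"
        "sparsity_function_end_step=2000,target_sparsity=0.5,"
        "sparsity_function_exponent=1.0")
pruning_hparams = pruning.get_pruning_hparams().parse(spec)
p = pruning.Pruning(pruning_hparams, global_step=global_step)
p_op = p.conditional_mask_update_op()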
Example No. 18
def train():
    with tf.Graph().as_default():
        with tf.device('/gpu:' + str(GPU_INDEX)):
            pointclouds_pl = MODEL.placeholder_input(BATCH_SIZE, NUM_POINT)
            labels_pl = MODEL.placeholder_label(BATCH_SIZE)
            if not FLAGS.quantize_delay:
                is_training = tf.placeholder(tf.bool, shape=(), name="is_training")
            else:
                is_training = True

            # Note the global_step=batch parameter to minimize.
            # That tells the optimizer to helpfully increment the 'batch' parameter for you every time it trains.
            batch = tf.Variable(0)
            # bn_decay = BN_INIT_DECAY
            bn_decay = get_bn_decay(batch)
            tf.summary.scalar('bn_decay', bn_decay)

            # Get model
            pred, end_points = MODEL.get_network(pointclouds_pl, is_training,
                                                 bn_decay=bn_decay,
                                                 dynamic=DYNAMIC,
                                                 STN=STN,
                                                 scale=SCALE,
                                                 concat_fea=CONCAT)

            # Parse pruning hyperparameters
            pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)

            # Create a pruning object using the pruning specification
            p = pruning.Pruning(pruning_hparams, global_step=batch)

            # Add conditional mask update op. Executing this op will update all
            # the masks in the graph if the current global step is in the range
            # [begin_pruning_step, end_pruning_step] as specified by the pruning spec
            mask_update_op = p.conditional_mask_update_op()

            # Add summaries to keep track of the sparsity in different layers during training
            p.add_pruning_summaries()

            if FLAGS.quantize_delay and FLAGS.quantize_delay > 0:
                quant_scopes = ["DGCNN/get_edge_feature", "DGCNN/get_edge_feature_1", "DGCNN/get_edge_feature_2",
                                "DGCNN/get_edge_feature_3", "DGCNN/get_edge_feature_4", "DGCNN/agg",
                                "DGCNN/transform_net", "DGCNN/Transform", "DGCNN/dgcnn1", "DGCNN/dgcnn2",
                                "DGCNN/dgcnn3", "DGCNN/dgcnn4",
                                "PointNet"]
                tf.contrib.quantize.create_training_graph(
                    quant_delay=FLAGS.quantize_delay)
                for scope in quant_scopes:
                    my_quantization.experimental_create_training_graph(quant_delay=FLAGS.quantize_delay,
                                                                       scope=scope)

            # Get loss
            loss = MODEL.get_loss(pred, labels_pl, end_points)
            regularization_losses = tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES)
            all_losses = []
            all_losses.append(loss)
            all_losses.append(tf.add_n(regularization_losses))
            total_loss = tf.add_n(all_losses)

            # tf.summary.scalar('loss', loss)
            tf.summary.scalar('loss', total_loss)

            correct = tf.equal(tf.argmax(pred, 1), tf.cast(labels_pl, tf.int64))
            accuracy = tf.reduce_sum(tf.cast(correct, tf.float32)) / float(BATCH_SIZE)
            tf.summary.scalar('accuracy', accuracy)

            # if update_ops:
            #     print("BN parameters: ", update_ops)
            #     updates = tf.group(*update_ops)
            #     train_step = control_flow_ops.with_dependencies([updates], batch)

            # Get training operator
            learning_rate = get_learning_rate(batch)
            tf.summary.scalar('learning_rate', learning_rate)
            if OPTIMIZER == 'momentum':
                optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=MOMENTUM)
            elif OPTIMIZER == 'adam':
                optimizer = tf.train.AdamOptimizer(learning_rate)

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies([tf.group(*update_ops)]):
                train_op = optimizer.minimize(total_loss, global_step=batch)
                # train_op = slim.learning.create_train_op(total_loss, optimizer)

            # Add ops to save and restore all the variables.
            saver = tf.train.Saver(max_to_keep=51)

        # Create a session
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        config.log_device_placement = False
        sess = tf.Session(config=config)

        # Add summary writers
        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'train'),
                                             sess.graph)
        test_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'test'))

        # Init variables
        init = tf.global_variables_initializer()
        # To fix the bug introduced in TF 0.12.1 as in
        # http://stackoverflow.com/questions/41543774/invalidargumenterror-for-tensor-bool-tensorflow-0-12-1
        sess.run(init)
        # sess.run(init, {is_training_pl: True})
        if FLAGS.quantize_delay and FLAGS.quantize_delay > 0:
            ops = {'pointclouds_pl': pointclouds_pl,
                   'labels_pl': labels_pl,
                   # 'is_training_pl': is_training,
                   'pred': pred,
                   'loss': loss,
                   'train_op': train_op,
                   'merged': merged,
                   'step': batch,
                   # 'mask_update_op': mask_update_op
                   }
        else:
            ops = {'pointclouds_pl': pointclouds_pl,
                   'labels_pl': labels_pl,
                   'is_training_pl': is_training,
                   'pred': pred,
                   'loss': loss,
                   'train_op': train_op,
                   'merged': merged,
                   'step': batch,
                   # 'mask_update_op': mask_update_op
                   }

        ever_best = 0
        if CHECKPOINT:
            saver.restore(sess, CHECKPOINT)
        for epoch in range(MAX_EPOCH):
            log_string(('**** EPOCH %03d ****' % (epoch))
                       + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + '****')
            sys.stdout.flush()

            ma = train_one_epoch(sess, ops, train_writer)
            if not FLAGS.quantize_delay:
                ma = eval_one_epoch(sess, ops, test_writer)

                # Save the variables to disk.

                if ma > ever_best:
                    save_path = saver.save(sess, os.path.join(LOG_DIR, "model.ckpt"))
                    log_string("Model saved in file: %s" % save_path)
                    ever_best = ma
                log_string("Current model mean accuracy: {}".format(ma))
                log_string("Best model mean accuracy: {}".format(ever_best))
            else:
                if epoch % 5 == 0:
                    if CHECKPOINT:
                        save_path = saver.save(sess, os.path.join(LOG_DIR, "model-r-{}.ckpt".format(str(epoch))))
                    else:
                        save_path = saver.save(sess, os.path.join(LOG_DIR, "model-{}.ckpt".format(str(epoch))))
                    log_string("Model saved in file: %s" % save_path)
Example No. 19
def main(unused_args):
    tf.set_random_seed(FLAGS.seed)
    tf.get_variable_scope().set_use_resource(True)
    np.random.seed(FLAGS.seed)

    # Load the MNIST data and set up an iterator.
    mnist_data = input_data.read_data_sets(FLAGS.mnist,
                                           one_hot=False,
                                           validation_size=0)
    train_images = mnist_data.train.images
    test_images = mnist_data.test.images
    if FLAGS.input_mask_path:
        reader = tf.train.load_checkpoint(FLAGS.input_mask_path)
        input_mask = reader.get_tensor('layer1/mask')
        indices = np.sum(input_mask, axis=1) != 0
        train_images = train_images[:, indices]
        test_images = test_images[:, indices]
    dataset = tf.data.Dataset.from_tensor_slices(
        (train_images, mnist_data.train.labels.astype(np.int32)))
    num_batches = mnist_data.train.images.shape[0] // FLAGS.batch_size
    dataset = dataset.shuffle(buffer_size=mnist_data.train.images.shape[0])
    batched_dataset = dataset.repeat(FLAGS.num_epochs).batch(FLAGS.batch_size)
    iterator = batched_dataset.make_one_shot_iterator()

    test_dataset = tf.data.Dataset.from_tensor_slices(
        (test_images, mnist_data.test.labels.astype(np.int32)))
    num_test_images = mnist_data.test.images.shape[0]
    test_dataset = test_dataset.repeat(FLAGS.num_epochs).batch(num_test_images)
    test_iterator = test_dataset.make_one_shot_iterator()

    # Set up loss function.
    use_model_pruning = FLAGS.training_method != 'baseline'

    if FLAGS.network_type == 'fc':
        cross_entropy_train, _ = mnist_network_fc(
            iterator.get_next(), model_pruning=use_model_pruning)
        cross_entropy_test, accuracy_test = mnist_network_fc(
            test_iterator.get_next(),
            reuse=True,
            model_pruning=use_model_pruning)
    else:
        raise RuntimeError(FLAGS.network_type + ' is an unknown network type.')

    # Remove extra added ones. Current implementation adds the variables twice
    # to the collection. Improve this hacky thing.
    # TODO test the following with the convnet or any other network.
    if use_model_pruning:
        for k in ('masks', 'masked_weights', 'thresholds', 'kernel'):
            # del tf.get_collection_ref(k)[2]
            # del tf.get_collection_ref(k)[2]
            collection = tf.get_collection_ref(k)
            del collection[len(collection) // 2:]
            print(tf.get_collection_ref(k))

    # Set up optimizer and update ops.
    global_step = tf.train.get_or_create_global_step()
    batch_per_epoch = mnist_data.train.images.shape[0] // FLAGS.batch_size

    if FLAGS.optimizer != 'adam':
        if not use_model_pruning:
            boundaries = [
                int(round(s * batch_per_epoch)) for s in [60, 70, 80]
            ]
        else:
            boundaries = [
                int(round(s * batch_per_epoch))
                for s in [FLAGS.lr_drop_epoch, FLAGS.lr_drop_epoch + 20]
            ]
        learning_rate = tf.train.piecewise_constant(
            global_step,
            boundaries,
            values=[
                FLAGS.learning_rate / (3.**i)
                for i in range(len(boundaries) + 1)
            ])
    else:
        learning_rate = FLAGS.learning_rate

    if FLAGS.optimizer == 'adam':
        opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
    elif FLAGS.optimizer == 'momentum':
        opt = tf.train.MomentumOptimizer(learning_rate,
                                         FLAGS.momentum,
                                         use_nesterov=FLAGS.use_nesterov)
    elif FLAGS.optimizer == 'sgd':
        opt = tf.train.GradientDescentOptimizer(learning_rate)
    else:
        raise RuntimeError(FLAGS.optimizer + ' is an unknown optimizer type')
    custom_sparsities = {
        'layer2': FLAGS.end_sparsity * FLAGS.sparsity_scale,
        'layer3': FLAGS.end_sparsity * 0
    }

    if FLAGS.training_method == 'set':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseSETOptimizer(
            opt,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal)
    elif FLAGS.training_method == 'static':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseStaticOptimizer(
            opt,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal)
    elif FLAGS.training_method == 'momentum':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseMomentumOptimizer(
            opt,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            momentum=FLAGS.s_momentum,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            grow_init=FLAGS.grow_init,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            use_tpu=False)
    elif FLAGS.training_method == 'rigl':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseRigLOptimizer(
            opt,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            initial_acc_scale=FLAGS.rigl_acc_scale,
            use_tpu=False)
    elif FLAGS.training_method == 'snip':
        opt = sparse_optimizers.SparseSnipOptimizer(
            opt,
            mask_init_method=FLAGS.mask_init_method,
            default_sparsity=FLAGS.end_sparsity,
            custom_sparsity_map=custom_sparsities,
            use_tpu=False)
    elif FLAGS.training_method in ('scratch', 'baseline', 'prune'):
        pass
    else:
        raise ValueError('Unsupported pruning method: %s' %
                         FLAGS.training_method)

    train_op = opt.minimize(cross_entropy_train, global_step=global_step)

    if FLAGS.training_method == 'prune':
        hparams_string = (
            'begin_pruning_step={0},sparsity_function_begin_step={0},'
            'end_pruning_step={1},sparsity_function_end_step={1},'
            'target_sparsity={2},pruning_frequency={3},'
            'threshold_decay={4}'.format(FLAGS.prune_begin_step,
                                         FLAGS.prune_end_step,
                                         FLAGS.end_sparsity,
                                         FLAGS.pruning_frequency,
                                         FLAGS.threshold_decay))
        pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)
        pruning_hparams.set_hparam(
            'weight_sparsity_map',
            ['{0}:{1}'.format(k, v) for k, v in custom_sparsities.items()])
        print(pruning_hparams)
        pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)
        with tf.control_dependencies([train_op]):
            train_op = pruning_obj.conditional_mask_update_op()
    weight_sparsity_levels = pruning.get_weight_sparsity()
    global_sparsity = sparse_utils.calculate_sparsity(pruning.get_masks())
    tf.summary.scalar('test_accuracy', accuracy_test)
    tf.summary.scalar('global_sparsity', global_sparsity)
    for k, v in zip(pruning.get_masks(), weight_sparsity_levels):
        tf.summary.scalar('sparsity/%s' % k.name, v)
    if FLAGS.training_method in ('prune', 'snip', 'baseline'):
        mask_init_op = tf.no_op()
        tf.logging.info('No mask is set, starting dense.')
    else:
        all_masks = pruning.get_masks()
        mask_init_op = sparse_utils.get_mask_init_fn(all_masks,
                                                     FLAGS.mask_init_method,
                                                     FLAGS.end_sparsity,
                                                     custom_sparsities)

    if FLAGS.save_model:
        saver = tf.train.Saver()
    init_op = tf.global_variables_initializer()
    hyper_params_string = '_'.join([
        FLAGS.network_type,
        str(FLAGS.batch_size),
        str(FLAGS.learning_rate),
        str(FLAGS.momentum), FLAGS.optimizer,
        str(FLAGS.l2_scale), FLAGS.training_method,
        str(FLAGS.prune_begin_step),
        str(FLAGS.prune_end_step),
        str(FLAGS.end_sparsity),
        str(FLAGS.pruning_frequency),
        str(FLAGS.seed)
    ])
    tf.io.gfile.makedirs(FLAGS.save_path)
    filename = os.path.join(FLAGS.save_path, hyper_params_string + '.txt')
    merged_summary_op = tf.summary.merge_all()

    # Run session.
    if not use_model_pruning:
        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(
                FLAGS.save_path, graph=tf.get_default_graph())
            print('Epoch', 'Epoch time', 'Test loss', 'Test accuracy')
            sess.run([init_op])
            tic = time.time()
            with tf.io.gfile.GFile(filename, 'w') as outputfile:
                for i in range(FLAGS.num_epochs * num_batches):
                    sess.run([train_op])

                    if (i % num_batches) == (-1 % num_batches):
                        epoch_time = time.time() - tic
                        loss, accuracy, summary = sess.run([
                            cross_entropy_test, accuracy_test,
                            merged_summary_op
                        ])
                        # Write logs at every test iteration.
                        summary_writer.add_summary(summary, i)
                        log_str = '%d, %.4f, %.4f, %.4f' % (
                            i // num_batches, epoch_time, loss, accuracy)
                        print(log_str)
                        print(log_str, file=outputfile)
                        tic = time.time()
            if FLAGS.save_model:
                saver.save(sess, os.path.join(FLAGS.save_path, 'model.ckpt'))
    else:
        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(
                FLAGS.save_path, graph=tf.get_default_graph())
            log_str = ','.join([
                'Epoch', 'Iteration', 'Test loss', 'Test accuracy',
                'G_Sparsity', 'Sparsity Layer 0', 'Sparsity Layer 1'
            ])
            sess.run(init_op)
            sess.run(mask_init_op)
            tic = time.time()
            mask_records = {}
            with tf.io.gfile.GFile(filename, 'w') as outputfile:
                print(log_str)
                print(log_str, file=outputfile)
                for i in range(FLAGS.num_epochs * num_batches):
                    if (FLAGS.mask_record_frequency > 0
                            and i % FLAGS.mask_record_frequency == 0):
                        mask_vals = sess.run(pruning.get_masks())
                        # Cast into bool to save space.
                        mask_records[i] = [
                            a.astype(np.bool) for a in mask_vals
                        ]
                    sess.run([train_op])
                    weight_sparsity, global_sparsity_val = sess.run(
                        [weight_sparsity_levels, global_sparsity])
                    if (i % num_batches) == (-1 % num_batches):
                        epoch_time = time.time() - tic
                        loss, accuracy, summary = sess.run([
                            cross_entropy_test, accuracy_test,
                            merged_summary_op
                        ])
                        # Write logs at every test iteration.
                        summary_writer.add_summary(summary, i)
                        log_str = '%d, %d, %.4f, %.4f, %.4f, %.4f, %.4f' % (
                            i // num_batches, i, loss, accuracy,
                            global_sparsity_val, weight_sparsity[0],
                            weight_sparsity[1])
                        print(log_str)
                        print(log_str, file=outputfile)
                        mask_vals = sess.run(pruning.get_masks())
                        if FLAGS.network_type == 'fc':
                            sparsities, sizes = get_compressed_fc(mask_vals)
                            print('[COMPRESSED SPARSITIES/SHAPE]: %s %s' %
                                  (sparsities, sizes))
                            print('[COMPRESSED SPARSITIES/SHAPE]: %s %s' %
                                  (sparsities, sizes),
                                  file=outputfile)
                        tic = time.time()
            if FLAGS.save_model:
                saver.save(sess, os.path.join(FLAGS.save_path, 'model.ckpt'))
            if mask_records:
                np.save(os.path.join(FLAGS.save_path, 'mask_records'),
                        mask_records)
Example No. 20
def train_alexnet(
        dataset_name='imagenet',
        prune=False,
        prune_params='',
        learning_rate=conf.learning_rate,
        num_epochs=conf.num_epochs,
        batch_size=conf.batch_size,
        learning_rate_decay_factor=conf.learning_rate_decay_factor,
        num_epochs_per_decay=conf.num_epochs_per_decay,
        dropout_rate=conf.dropout_rate,
        log_step=conf.log_step,
        checkpoint_step=conf.checkpoint_step,
        summary_path=conf.root_path + 'alexnet' + conf.summary_path,
        checkpoint_path=conf.root_path + 'alexnet' + conf.checkpoint_path,
        highest_accuracy_path=conf.root_path + 'alexnet' + conf.highest_accuracy_path,
        default_image_size=227,  #224 in the paper
):
    """prune_params: Comma separated list of pruning-related hyperparameters
       ex:'begin_pruning_step=10000,end_pruning_step=100000,target_sparsity=0.9,sparsity_function_begin_step=10000,sparsity_function_end_step=100000'
    """
    if dataset_name == 'imagenet':
        num_class = conf.imagenet['num_class']
        train_set_size = conf.imagenet['train_set_size']
        validation_set_size = conf.imagenet['validation_set_size']
        label_offset = conf.imagenet['label_offset']
        label_path = conf.imagenet['label_path']
        dataset_path = conf.imagenet['dataset_path']

        x = tf.placeholder(
            tf.float32,
            [batch_size, default_image_size, default_image_size, 3])
        y = tf.placeholder(tf.float32, [batch_size, num_class - label_offset])
        keep_prob = tf.placeholder(tf.float32)  #placeholder for dropout rate
        # prepare to train the model
        model = AlexNet.AlexNet(x,
                                keep_prob,
                                num_class - label_offset, [],
                                prune=prune)
        # Link variable to model output
        score = model.fc8

        # List of trainable variables of the layers we want to train
        var_list = [v for v in tf.trainable_variables()]

        # Op for calculating the loss
        with tf.name_scope("cross_ent"):
            loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(logits=score,
                                                        labels=y))

        global_step = tf.Variable(0, False)
        with tf.name_scope("train"):
            # Get gradients of all trainable variables
            decay_steps = int(train_set_size / batch_size *
                              num_epochs_per_decay)
            learning_rate = tf.train.exponential_decay(
                learning_rate,
                global_step,
                decay_steps,
                learning_rate_decay_factor,
                staircase=True)
            # Create optimizer and apply gradient descent to the trainable variables
            train_op = tf.train.GradientDescentOptimizer(
                learning_rate).minimize(loss, global_step)

        # Evaluation op: Accuracy of the model
        with tf.name_scope("accuracy"):
            correct_pred = tf.equal(tf.argmax(score, 1), tf.argmax(y, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

        if prune:
            # Parse pruning hyperparameters
            prune_params = pruning.get_pruning_hparams().parse(prune_params)
            # Create a pruning object using the pruning specification
            p = pruning.Pruning(prune_params, global_step=global_step)
            # Add conditional mask update op. Executing this op will update all
            # the masks in the graph if the current global step is in the range
            # [begin_pruning_step, end_pruning_step] as specified by the pruning spec
            mask_update_op = p.conditional_mask_update_op()
            # Add summaries to keep track of the sparsity in different layers during training
            p.add_pruning_summaries()

        # Add the variables we train to the summary
        for var in var_list:
            tf.summary.histogram(var.name, var)
        # Add the loss to summary
        tf.summary.scalar('cross_entropy', loss)
        # Add the accuracy to the summary
        tf.summary.scalar('accuracy', accuracy)
        # Merge all summaries together
        merged_summary = tf.summary.merge_all()
        # Initialize the FileWriter
        writer = tf.summary.FileWriter(summary_path)

        # prepare the data
        img_train, label_train, labels_text_train = read_tfrecord(
            'train', dataset_path, default_image_size=default_image_size)
        img_validation, label_validation, labels_text_validation = read_tfrecord(
            'validation', dataset_path, default_image_size=default_image_size)
        coord = tf.train.Coordinator()

        # Initialize a saver to store model checkpoints
        saver = tf.train.Saver()

        with tf.Session() as sess:

            # Initialize all variables
            sess.run(tf.global_variables_initializer())

            # Add the model graph to TensorBoard
            writer.add_graph(sess.graph)

            # Load the pretrained weights into the non-trainable layer
            model.load_initial_weights(sess)

            #start the input pipeline queue
            threads = tf.train.start_queue_runners(sess, coord=coord)

            # load the weights from checkpoint if there exists one
            model_saved = tf.train.get_checkpoint_state(checkpoint_path)
            if model_saved and model_saved.model_checkpoint_path:
                saver.restore(sess, model_saved.model_checkpoint_path)
                print('load model from ' + model_saved.model_checkpoint_path)

            print("{} Start training...".format(datetime.now()))
            print("{} Open Tensorboard at --logdir {}".format(
                datetime.now(), summary_path))

            # Loop over number of epochs
            for epoch in range(num_epochs):
                print("{} Epoch number: {}".format(datetime.now(), epoch + 1))

                highest_accuracy = 0  # highest accuracy so far
                if os.path.exists(highest_accuracy_path):
                    f = open(highest_accuracy_path, 'r')
                    highest_accuracy = float(f.read())
                    f.close()
                    print('highest accuracy from previous training is %f' %
                          highest_accuracy)

                train_batches_per_epoch = int(
                    np.floor(train_set_size / batch_size))
                for step in range(train_batches_per_epoch):
                    # train the model
                    img, l, l_text = sess.run(
                        [img_train, label_train, labels_text_train])
                    _, sc, gl_step, lr = sess.run(
                        [train_op, score, global_step, learning_rate],
                        feed_dict={
                            x: img,
                            y: l,
                            keep_prob: dropout_rate
                        })
                    if prune:
                        # Update the masks by running the mask_update_op
                        sess.run(mask_update_op)

                    # Generate summary with the current batch of data and write to file
                    if step % log_step == 0:
                        s, aq = sess.run([merged_summary, accuracy],
                                         feed_dict={
                                             x: img,
                                             y: l,
                                             keep_prob: 1.
                                         })
                        writer.add_summary(
                            s, epoch * train_batches_per_epoch + step)
                        print(
                            "global_step:" + str(gl_step) + ';learning_rate:' +
                            str(lr) + ';accuracy:', aq)

                    #validate the model and write checkpoint if the accuracy is higher
                    if step % checkpoint_step == 0 and step != 0:
                        val_batches_per_epoch = int(
                            np.floor(validation_set_size / batch_size))
                        print("{} Start validation".format(datetime.now()))
                        test_acc = 0.
                        test_count = 0
                        for _ in range(val_batches_per_epoch):
                            #validate the model
                            img, l, l_text = sess.run([
                                img_validation, label_validation,
                                labels_text_validation
                            ])
                            acc = sess.run(accuracy,
                                           feed_dict={
                                               x: img,
                                               y: l,
                                               keep_prob: 1.
                                           })
                            test_acc += acc
                            test_count += 1
                        test_acc /= test_count
                        print("{} Validation Accuracy = {:.4f}".format(
                            datetime.now(), test_acc))
                        # save the model if it is better than the previous best model
                        if test_acc > highest_accuracy:
                            print("{} Saving checkpoint of model...".format(
                                datetime.now()))
                            highest_accuracy = test_acc
                            # save checkpoint of the model
                            checkpoint_name = os.path.join(
                                checkpoint_path, 'model_epoch' + '.ckpt')
                            # save_path = saver.save(sess, checkpoint_name, global_step=global_step)
                            f = open(highest_accuracy_path, 'w')
                            f.write(str(highest_accuracy))
                            f.close()
                            print("{} Model checkpoint saved at {}".format(
                                datetime.now(), checkpoint_name))
            coord.request_stop()
            coord.join(threads)
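A usage sketch for the function above, taking the pruning hyperparameter string directly from its docstring (the dataset paths and remaining defaults come from conf and are left unchanged):

if __name__ == '__main__':
    train_alexnet(
        dataset_name='imagenet',
        prune=True,
        prune_params='begin_pruning_step=10000,end_pruning_step=100000,'
                     'target_sparsity=0.9,sparsity_function_begin_step=10000,'
                     'sparsity_function_end_step=100000')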
Example No. 21
def build_model():
    """Builds graph for model to train with rewrites for quantization."""
    g = tf.Graph()
    with g.as_default(), tf.device(
            tf.train.replica_device_setter(FLAGS.ps_tasks)):
        inputs, labels = hcl_input(is_training=True)
        #with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=True)):
        logits, _ = mobilenet_v1_prune.mobilenet_v1(
            inputs,
            is_training=True,
            depth_multiplier=FLAGS.depth_multiplier,
            num_classes=FLAGS.num_classes)

        tf.losses.softmax_cross_entropy(labels, logits)

        # Call rewriter to produce graph with fake quant ops and folded batch norms
        # quant_delay delays start of quantization till quant_delay steps, allowing
        # for better model accuracy.
        if FLAGS.quantize:
            tf.contrib.quantize.create_training_graph(
                quant_delay=get_quant_delay())

        total_loss = tf.losses.get_total_loss(name='total_loss')
        # Configure the learning rate using an exponential decay.
        num_epochs_per_decay = 2.5
        hcl_size = 4650035  #3523535
        decay_steps = int(hcl_size / FLAGS.batch_size * num_epochs_per_decay)
        global_step = tf.train.get_or_create_global_step()
        learning_rate = tf.train.exponential_decay(
            get_learning_rate(),
            global_step,  # tf.train.get_or_create_global_step(),
            decay_steps,
            _LEARNING_RATE_DECAY_FACTOR,
            staircase=True)
        opt = tf.train.GradientDescentOptimizer(learning_rate)

        # Get, Print, and Edit Pruning Hyperparameters
        pruning_hparams = pruning.get_pruning_hparams()
        #print("Pruning Hyperparameters:", pruning_hparams)

        # Change hyperparameters to meet our needs
        pruning_hparams.begin_pruning_step = 200000
        #pruning_hparams.end_pruning_step = 250
        #pruning_hparams.pruning_frequency = 1
        #pruning_hparams.sparsity_function_end_step = 250
        pruning_hparams.target_sparsity = .5
        print("Pruning Hyperparameters:", pruning_hparams)

        # Create a pruning object using the pruning specification; an explicitly
        # passed sparsity seems to take priority over the hparams value.
        p = pruning.Pruning(pruning_hparams,
                            global_step=global_step,
                            sparsity=.5)
        prune_op = p.conditional_mask_update_op()

        train_tensor = slim.learning.create_train_op(total_loss, optimizer=opt)

    slim.summaries.add_scalar_summary(total_loss, 'total_loss', 'losses')
    slim.summaries.add_scalar_summary(learning_rate, 'learning_rate',
                                      'training')
    return g, [train_tensor, prune_op]
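A brief sketch of how the returned graph and ops might be driven (running the loop manually is an assumption; the surrounding project may hand these to a slim training loop instead):

g, train_ops = build_model()
with g.as_default():
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(1000):
            # Runs the slim train tensor and the conditional mask update together.
            sess.run(train_ops)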
Example No. 22
def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.contrib.framework.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    # Parse pruning hyperparameters
    pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)

    # Create a pruning object using the pruning hyperparameters
    pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

    # Use the pruning_obj to add ops to the training graph to update the masks
    # The conditional_mask_update_op will update the masks only when the
    # training step is in [begin_pruning_step, end_pruning_step] specified in
    # the pruning spec proto
    mask_update_op = pruning_obj.conditional_mask_update_op()

    # Use the pruning_obj to add summaries to the graph to track the sparsity
    # of each of the layers
    pruning_obj.add_pruning_summaries()

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        self._step = -1

      def before_run(self, run_context):
        self._step += 1
        self._start_time = time.time()
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        duration = time.time() - self._start_time
        loss_value = run_values.results
        if self._step % 10 == 0:
          num_examples_per_step = 128
          examples_per_sec = num_examples_per_step / duration
          sec_per_batch = float(duration)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print(format_str % (datetime.datetime.now(), self._step, loss_value,
                              examples_per_sec, sec_per_batch))

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)
        # Update the masks
        mon_sess.run(mask_update_op)
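A sketch of an alternative wiring (not from the original example): chain the conditional mask update after the train op so the monitored session needs only one run() call per step. The snippet would replace the session block inside train() above; train_op, loss, and pruning_obj are the objects already built there.

    # Make the mask update depend on the train op so one fetch drives both.
    with tf.control_dependencies([train_op]):
      train_and_prune_op = pruning_obj.conditional_mask_update_op()

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss)]) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_and_prune_op)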
Exemplo n.º 23
0
allEventFiles = os.listdir('./logs/')
for file in allEventFiles:
    os.remove('./logs/' + file)

################ PRUNING #####################
PARAM_LIST = [
    "name=FFN_Pruning_Test", "pruning_frequency=10", "target_sparsity=0.5"
]
TEST_HPARAMS = ",".join(PARAM_LIST)

# Parse pruning hyperparameters
pruning_hparams = pruning.get_pruning_hparams().parse(TEST_HPARAMS)
#pruning_hparams = model_pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)

# Create a pruning object using the pruning specification
p = pruning.Pruning(pruning_hparams, global_step=global_step)

# Add conditional mask update op. Executing this op will update all
# the masks in the graph if the current global step is in the range
# [begin_pruning_step, end_pruning_step] as specified by the pruning spec
mask_update_op = p.conditional_mask_update_op()

# Add summaries to keep track of the sparsity in different layers during training
p.add_pruning_summaries()

### Data statistics ###
tic = time.time()
featMean, featStd = dataStatistics.calcMeanAndStd(label_root_train, NFFT,
                                                  STFT_OVERLAP, BIN_WIZE)
toc = time.time()
print(toc - tic)
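Note (added, not part of the snippet above): the conditional_mask_update_op only affects weights that were wrapped with pruning.apply_mask (or built with the model_pruning masked layers), because that wrapping is what creates the mask and threshold variables. A self-contained sketch with illustrative hyperparameter values, assuming TF 1.x with tf.contrib available:

import tensorflow as tf
from tensorflow.contrib.model_pruning.python import pruning

tf.reset_default_graph()
global_step = tf.train.get_or_create_global_step()

# apply_mask() creates the mask/threshold variables that the pruning ops update.
weights = tf.Variable(tf.linspace(1.0, 100.0, 100), name="weights")
masked_weights = pruning.apply_mask(weights)

hparams = pruning.get_pruning_hparams().parse(
    "begin_pruning_step=0,sparsity_function_end_step=40,"
    "pruning_frequency=10,target_sparsity=0.5,threshold_decay=0.0")
p = pruning.Pruning(hparams, global_step=global_step)
mask_update_op = p.conditional_mask_update_op()
increment_step = tf.assign_add(global_step, 1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(60):
        sess.run(mask_update_op)
        sess.run(increment_step)
    # Roughly half of the masked weights should now be zero.
    print(sess.run(tf.nn.zero_fraction(masked_weights)))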
Exemplo n.º 24
0
def train_function(pruning_method, loss, output_dir, use_tpu):
    """Training script for resnet model.

  Args:
   pruning_method: string indicating pruning method used to compress model.
   loss: tensor float32 of the cross entropy + regularization losses.
   output_dir: string tensor indicating the directory to save summaries.
   use_tpu: boolean indicating whether to run script on a tpu.

  Returns:
    host_call: summary tensors to be computed at each training step.
    train_op: the optimization term.
  """

    global_step = tf.train.get_global_step()

    steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
    current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)
    learning_rate = lr_schedule(current_epoch)
    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                           momentum=FLAGS.momentum,
                                           use_nesterov=True)

    if use_tpu:
        # use CrossShardOptimizer when using TPU.
        optimizer = contrib_tpu.CrossShardOptimizer(optimizer)

    # UPDATE_OPS needs to be added as a dependency due to batch norm
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops), tf.name_scope('train'):
        train_op = optimizer.minimize(loss, global_step)

    if not use_tpu:
        if FLAGS.num_workers > 0:
            optimizer = tf.train.SyncReplicasOptimizer(
                optimizer,
                replicas_to_aggregate=FLAGS.num_workers,
                total_num_replicas=FLAGS.num_workers)
            optimizer.make_session_run_hook(True)

    metrics = {
        'global_step': tf.train.get_or_create_global_step(),
        'loss': loss,
        'learning_rate': learning_rate,
        'current_epoch': current_epoch
    }

    if pruning_method == 'threshold':
        # construct the necessary hparams string from the FLAGS
        hparams_string = ('begin_pruning_step={0},'
                          'sparsity_function_begin_step={0},'
                          'end_pruning_step={1},'
                          'sparsity_function_end_step={1},'
                          'target_sparsity={2},'
                          'pruning_frequency={3},'
                          'threshold_decay=0,'
                          'use_tpu={4}'.format(
                              FLAGS.sparsity_begin_step,
                              FLAGS.sparsity_end_step,
                              FLAGS.end_sparsity,
                              FLAGS.pruning_frequency,
                              FLAGS.use_tpu,
                          ))

        # Parse pruning hyperparameters
        pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)

        # The first layer has so few parameters that we don't need to prune it,
        # and pruning it at high sparsity levels has very negative effects.
        if FLAGS.prune_first_layer and FLAGS.first_layer_sparsity >= 0.:
            pruning_hparams.set_hparam(
                'weight_sparsity_map',
                ['resnet_model/initial_conv:%f' % FLAGS.first_layer_sparsity])
        if FLAGS.prune_last_layer and FLAGS.last_layer_sparsity >= 0:
            pruning_hparams.set_hparam(
                'weight_sparsity_map',
                ['resnet_model/final_dense:%f' % FLAGS.last_layer_sparsity])

        # Create a pruning object using the pruning hyperparameters
        pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

        # We override the train op to also update the mask.
        with tf.control_dependencies([train_op]):
            train_op = pruning_obj.conditional_mask_update_op()

        masks = pruning.get_masks()
        metrics.update(utils.mask_summaries(masks))
    elif pruning_method == 'scratch':
        masks = pruning.get_masks()
        # Make sure the masks have the sparsity we expect and that they don't change.
        metrics.update(utils.mask_summaries(masks))
    elif pruning_method == 'variational_dropout':
        masks = utils.add_vd_pruning_summaries(
            threshold=FLAGS.log_alpha_threshold)
        metrics.update(masks)
    elif pruning_method == 'l0_regularization':
        summaries = utils.add_l0_summaries()
        metrics.update(summaries)
    elif pruning_method == 'baseline':
        pass
    else:
        raise ValueError('Unsupported pruning method', FLAGS.pruning_method)

    host_call = (functools.partial(utils.host_call_fn,
                                   output_dir), utils.format_tensors(metrics))

    return host_call, train_op
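The per-layer override above can also be expressed directly in the hparams string through weight_sparsity_map, whose entries use the 'scope_name:sparsity' syntax. A sketch with illustrative step counts and sparsity values (not the values of the original flags):

hparams_string = ('begin_pruning_step=1000,sparsity_function_begin_step=1000,'
                  'end_pruning_step=40000,sparsity_function_end_step=40000,'
                  'target_sparsity=0.9,pruning_frequency=500,threshold_decay=0,'
                  'weight_sparsity_map=[resnet_model/initial_conv:0.3]')
pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)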
Exemplo n.º 25
0
def train_fn(training_method, global_step, total_loss, train_dir, accuracy,
             top_5_accuracy):
  """Training script for resnet model.

  Args:
   training_method: specifies the method used to sparsify networks.
   global_step: the current step of training/eval.
   total_loss: tensor float32 of the cross entropy + regularization losses.
   train_dir: string specifying the directory where summaries are saved.
   accuracy: tensor float32 batch classification accuracy.
   top_5_accuracy: tensor float32 batch classification accuracy (top_5 classes).

  Returns:
    hooks: summary tensors to be computed at each training step.
    eval_metrics: set to None during training.
    train_op: the optimization term.
  """
  # The learning rate roughly drops at every 30k steps.
  boundaries = [30000, 60000, 90000]
  if FLAGS.training_steps_multiplier != 1.0:
    multiplier = FLAGS.training_steps_multiplier
    boundaries = [int(x * multiplier) for x in boundaries]
    tf.logging.info(
        'Learning Rate boundaries are updated with multiplier:%.2f', multiplier)

  learning_rate = tf.train.piecewise_constant(
      global_step,
      boundaries,
      values=[0.1 / (5.**i) for i in range(len(boundaries) + 1)],
      name='lr_schedule')

  optimizer = tf.train.MomentumOptimizer(
      learning_rate, momentum=FLAGS.momentum, use_nesterov=True)

  if training_method == 'set':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseSETOptimizer(
        optimizer, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal)
  elif training_method == 'static':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseStaticOptimizer(
        optimizer, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal)
  elif training_method == 'momentum':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseMomentumOptimizer(
        optimizer, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, momentum=FLAGS.s_momentum,
        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,
        grow_init=FLAGS.grow_init,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal, use_tpu=False)
  elif training_method == 'rigl':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseRigLOptimizer(
        optimizer, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency,
        drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal,
        initial_acc_scale=FLAGS.rigl_acc_scale, use_tpu=False)
  elif training_method == 'snip':
    optimizer = sparse_optimizers.SparseSnipOptimizer(
        optimizer, mask_init_method=FLAGS.mask_init_method,
        default_sparsity=FLAGS.end_sparsity, use_tpu=False)
  elif training_method in ('scratch', 'baseline', 'prune'):
    pass
  else:
    raise ValueError('Unsupported pruning method: %s' % FLAGS.training_method)
  # Create the training op
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(total_loss, global_step)

  if training_method == 'prune':
    # construct the necessary hparams string from the FLAGS
    hparams_string = ('begin_pruning_step={0},'
                      'sparsity_function_begin_step={0},'
                      'end_pruning_step={1},'
                      'sparsity_function_end_step={1},'
                      'target_sparsity={2},'
                      'pruning_frequency={3},'
                      'threshold_decay=0,'
                      'use_tpu={4}'.format(
                          FLAGS.sparsity_begin_step,
                          FLAGS.sparsity_end_step,
                          FLAGS.end_sparsity,
                          FLAGS.pruning_frequency,
                          False,
                      ))
    # Parse pruning hyperparameters
    pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)

    # Create a pruning object using the pruning hyperparameters
    pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

    tf.logging.info('starting mask update op')

    # We override the train op to also update the mask.
    with tf.control_dependencies([train_op]):
      train_op = pruning_obj.conditional_mask_update_op()

  masks = pruning.get_masks()
  mask_metrics = utils.mask_summaries(masks)
  for name, tensor in mask_metrics.items():
    tf.summary.scalar(name, tensor)

  tf.summary.scalar('learning_rate', learning_rate)
  tf.summary.scalar('accuracy', accuracy)
  tf.summary.scalar('total_loss', total_loss)
  tf.summary.scalar('top_5_accuracy', top_5_accuracy)
  # Logging drop_fraction if dynamic sparse training.
  if training_method in ('set', 'momentum', 'rigl', 'static'):
    tf.summary.scalar('drop_fraction', optimizer.drop_fraction)

  summary_op = tf.summary.merge_all()
  summary_hook = tf.train.SummarySaverHook(
      save_secs=300, output_dir=train_dir, summary_op=summary_op)
  hooks = [summary_hook]
  eval_metrics = None

  return hooks, eval_metrics, train_op
Exemplo n.º 26
0
        pruning_hparams = pruning.get_pruning_hparams()
        print("Pruning Hyperparameters:", pruning_hparams)

        # Change hyperparameters to meet our needs
        pruning_hparams.begin_pruning_step = 0
        pruning_hparams.end_pruning_step = 250
        pruning_hparams.pruning_frequency = 1
        pruning_hparams.sparsity_function_end_step = 250
        pruning_hparams.target_sparsity = .9
        global_step = tf.train.get_or_create_global_step()

        #train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss, global_step=global_step)
        reset_global_step_op = tf.assign(global_step, 0)
        
        p = pruning.Pruning(pruning_hparams, global_step=global_step, sparsity=.9)
        prune_op = p.conditional_mask_update_op()
        correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(nlabels, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        with tf.Session() as sess:
            sess.run(tf.initialize_all_variables())

            # Reset the global step counter and begin pruning
            sess.run(reset_global_step_op)
            for epoch in range(epochs):
                for batch in range(batches):
                    batch_xs, batch_ys = image_set.train.next_batch(batch_size)
                    # Prune and retrain
                    sess.run(prune_op)
                    #sess.run(train_op, feed_dict={images: batch_xs, label: batch_ys})
Exemplo n.º 27
0
def train_with_pruning():
    tf.compat.v1.reset_default_graph()

    # Inference
    network = Network(NUM_CLASSES)
    inputs = tf.compat.v1.placeholder(tf.float32, [None, INPUT_SIZE, INPUT_SIZE, INPUT_CHANNEL], 'inputs')
    logits = network.pruning_inference(inputs)

    # loss & accuracy
    labels = tf.compat.v1.placeholder(tf.int64, [None, ], 'labels')
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
    prediction = tf.argmax(tf.nn.softmax(logits), axis=1)
    acc = tf.reduce_mean(tf.cast(tf.equal(prediction, labels), dtype=tf.float32))

    # Create pruning operator
    global_step = tf.train.get_or_create_global_step()
    pruning_hparams = pruning.get_pruning_hparams()
    pruning_hparams.sparsity_function_end_step = 1000
    p = pruning.Pruning(pruning_hparams, global_step=global_step)
    mask_update_op = p.conditional_mask_update_op()
    p.add_pruning_summaries()

    # optimizer
    optimizer = tf.compat.v1.train.MomentumOptimizer(learning_rate=LEARNING_RATE, momentum=0.9)
    train_op = optimizer.minimize(loss, global_step)

    # loading data
    train_next = load_tfrecords('train')
    test_next = load_tfrecords('test')

    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())

        # summaries
        logs_dir = './logs/with_pruning'
        if not os.path.exists(logs_dir):
            os.makedirs(logs_dir)
        tf.compat.v1.summary.scalar('monitor/loss', loss)
        tf.compat.v1.summary.scalar('monitor/acc', acc)
        merged_summary_op = tf.compat.v1.summary.merge_all()
        train_summary_writer = tf.compat.v1.summary.FileWriter(os.path.join(logs_dir, 'train'), graph=sess.graph)
        test_summary_writer = tf.compat.v1.summary.FileWriter(os.path.join(logs_dir, 'test'), graph=sess.graph)

        best_acc = 0
        saver = tf.compat.v1.train.Saver()
        for epoch in range(NUM_EPOCHS):
            # training
            num_steps = TRAIN_SIZE // BATCH_SIZE
            train_acc = 0
            train_loss = 0
            for step in range(num_steps):
                x, y = sess.run(train_next)
                _, summary, train_acc_batch, train_loss_batch = sess.run([train_op, merged_summary_op, acc, loss],
                                                                         feed_dict={inputs: x, labels: y})
                sess.run(mask_update_op)
                train_acc += train_acc_batch
                train_loss += train_loss_batch
                sys.stdout.write("\r epoch %d, step %d, training accuracy %g, training loss %g" %
                                 (epoch + 1, step + 1, train_acc_batch, train_loss_batch))
                sys.stdout.flush()
                train_summary_writer.add_summary(summary, global_step=epoch * num_steps + step)
                train_summary_writer.flush()
            print("\n epoch %d, training accuracy %g, training loss %g" %
                  (epoch + 1, train_acc / num_steps, train_loss / num_steps))

            # testing
            num_steps = TEST_SIZE // BATCH_SIZE
            test_acc = 0
            test_loss = 0
            for step in range(num_steps):
                x, y = sess.run(test_next)
                summary, test_acc_batch, test_loss_batch = sess.run([merged_summary_op, acc, loss],
                                                                    feed_dict={inputs: x, labels: y})
                test_acc += test_acc_batch
                test_loss += test_loss_batch
                test_summary_writer.add_summary(summary, global_step=(epoch * num_steps + step) * (TRAIN_SIZE // TEST_SIZE))
                test_summary_writer.flush()
            print(" epoch %d, testing accuracy %g, testing loss %g" %
                  (epoch + 1, test_acc / num_steps, test_loss / num_steps))

            if test_acc / num_steps > best_acc:
                best_acc = test_acc / num_steps
                saver.save(sess, './ckpt_with_pruning/model')

        print(" Best Testing Accuracy %g" % best_acc)
Exemplo n.º 28
0
    def estimator_spec_train(self, loss, num_async_replicas=1, use_tpu=False):
        """Constructs `tf.estimator.EstimatorSpec` for TRAIN (training) mode."""
        train_op = self.optimize(loss,
                                 num_async_replicas=num_async_replicas,
                                 use_tpu=use_tpu)

        sparsity_technique = self._hparams.get("sparsity_technique")
        if "pruning" in sparsity_technique:
            if not self._hparams.load_masks_from:
                # If we are loading trained masks, don't add the mask update
                # step to the training process and keep the masks static
                with tf.control_dependencies([train_op]):
                    mp_hparams = pruning_hparams(
                        self._hparams, use_tpu,
                        sparsity_technique == "random_pruning")
                    p = magnitude_pruning.Pruning(
                        mp_hparams, global_step=tf.train.get_global_step())
                    mask_update_op = p.conditional_mask_update_op()
                    train_op = mask_update_op
            check_global_sparsity()

        if use_tpu:
            if self._hparams.warm_start_from:

                def scaffold_fn():
                    self.initialize_from_ckpt(self._hparams.warm_start_from)
                    return tf.train.Scaffold()
            elif self._hparams.load_masks_from and self._hparams.load_weights_from:

                def scaffold_fn():
                    self.initialize_masks_from_ckpt(
                        self._hparams.load_masks_from)
                    self.initialize_non_masks_from_ckpt(
                        self._hparams.load_weights_from)
                    return tf.train.Scaffold()
            elif self._hparams.load_masks_from:

                def scaffold_fn():
                    self.initialize_masks_from_ckpt(
                        self._hparams.load_masks_from)
                    return tf.train.Scaffold()
            else:
                scaffold_fn = None

            # Note: important to call this before remove_summaries()
            if self.hparams.tpu_enable_host_call:
                host_call = t2t_model.create_host_call(self.hparams.model_dir)
            else:
                host_call = None

            t2t_model.remove_summaries()

            return contrib_tpu.TPUEstimatorSpec(tf_estimator.ModeKeys.TRAIN,
                                                loss=loss,
                                                train_op=train_op,
                                                host_call=host_call,
                                                scaffold_fn=scaffold_fn)
        else:
            if self._hparams.warm_start_from:
                self.initialize_from_ckpt(self._hparams.warm_start_from)
            elif self._hparams.load_masks_from:
                self.initialize_masks_from_ckpt(self._hparams.load_masks_from)

            return tf_estimator.EstimatorSpec(tf_estimator.ModeKeys.TRAIN,
                                              loss=loss,
                                              train_op=train_op)
Exemplo n.º 29
0
    def build_graph(self, hparams, scope=None):
        """Subclass must implement this method.

        Creates a sequence-to-sequence model with dynamic RNN decoder API.
        Args:
          hparams: Hyperparameter configurations.
          scope: VariableScope for the created subgraph; default "dynamic_seq2seq".

        Returns:
          A tuple of the form (logits, loss_tuple, final_context_state, sample_id),
          where:
            logits: float32 Tensor [batch_size x num_decoder_symbols].
            loss: loss = the total loss / batch_size.
            final_context_state: the final state of decoder RNN.
            sample_id: sampling indices.

        Raises:
          ValueError: if encoder_type differs from mono and bi, or
            attention_option is not (luong | scaled_luong |
            bahdanau | normed_bahdanau).
        """
        utils.print_out("# Creating %s graph ..." % self.mode)

        # Projection
        if not self.extract_encoder_layers:
            with tf.variable_scope(scope or "build_network"):
                with tf.variable_scope("decoder/output_projection"):
                    if hparams.projection_type == 'sparse':
                        self.output_layer = core_layers.MaskedFullyConnected(
                            hparams.tgt_vocab_size,
                            use_bias=False,
                            name="output_projection")
                    elif hparams.projection_type == 'dense':
                        self.output_layer = tf.layers.Dense(
                            hparams.tgt_vocab_size,
                            use_bias=False,
                            name="output_projection")
                    else:
                        raise ValueError("Unknown projection type %s!" %
                                         hparams.projection_type)

        with tf.variable_scope(scope or "dynamic_seq2seq", dtype=self.dtype):
            # Encoder
            if hparams.language_model:  # no encoder for language modeling
                utils.print_out("  language modeling: no encoder")
                self.encoder_outputs = None
                encoder_state = None
            else:
                self.encoder_outputs, encoder_state = self._build_encoder(
                    hparams)

            # Skip decoder if extracting only encoder layers
            if self.extract_encoder_layers:
                return

            # Decoder
            logits, decoder_cell_outputs, sample_id, final_context_state = (
                self._build_decoder(self.encoder_outputs, encoder_state,
                                    hparams))

            # Loss
            if self.mode != tf.contrib.learn.ModeKeys.INFER:
                with tf.device(
                        model_helper.get_device_str(
                            self.num_encoder_layers - 1, self.num_gpus)):
                    loss = self._compute_loss(logits, decoder_cell_outputs)
            else:
                loss = tf.constant(0.0)

            # model pruning
            if hparams.pruning_hparams is not None:
                pruning_hparams = pruning.get_pruning_hparams().parse(
                    hparams.pruning_hparams)
                self.p = pruning.Pruning(pruning_hparams,
                                         global_step=self.global_step)
                self.mask_update_op = self.p.conditional_mask_update_op()
                masks = get_masks()
                thresholds = get_thresholds()
                masks_s = []
                for index, mask in enumerate(masks):
                    masks_s.append(
                        tf.summary.scalar(mask.name + '/sparsity',
                                          tf.nn.zero_fraction(mask)))
                    masks_s.append(
                        tf.summary.scalar(
                            thresholds[index].op.name + '/threshold',
                            thresholds[index]))
                    masks_s.append(
                        tf.summary.histogram(mask.name + '/mask_tensor', mask))
                self.pruning_summary = tf.summary.merge([
                    tf.summary.scalar('sparsity', self.p._sparsity),
                    tf.summary.scalar('last_mask_update_step',
                                      self.p._last_update_step)
                ] + masks_s)
            else:
                self.mask_update_op = tf.no_op()
                self.pruning_summary = tf.no_op()

            return logits, loss, final_context_state, sample_id
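A usage sketch for the graph built above (not part of the original example; model.update and train_sess are assumed names for the model's regular train op and the training session): the outer training loop runs the mask update and fetches the pruning summary together with the normal step.

# Illustrative training-loop step; all names other than mask_update_op and
# pruning_summary are assumptions about the surrounding NMT training code.
_, _, summary, step = train_sess.run(
    [model.update, model.mask_update_op, model.pruning_summary,
     model.global_step])
summary_writer.add_summary(summary, step)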
Exemplo n.º 30
0
def train():
    is_training = True
    # data pipeline
    imgs, true_boxes = gen_data_batch(re.sub(r'examples/', '', cfg.data_path),
                                      cfg.batch_size * cfg.train.num_gpus)
    imgs_split = tf.split(imgs, cfg.train.num_gpus)
    true_boxes_split = tf.split(true_boxes, cfg.train.num_gpus)

    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0.),
                                  trainable=False)
    lr = tf.train.piecewise_constant(global_step, cfg.train.lr_steps,
                                     cfg.train.learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate=lr)

    # Calculate the gradients for each model tower.
    tower_grads = []
    summaries_buf = []
    summaries = set()
    with tf.variable_scope(tf.get_variable_scope()):
        for i in range(cfg.train.num_gpus):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('%s_%d' % (cfg.train.tower, i)) as scope:
                    model = PDetNet(imgs_split[i], true_boxes_split[i],
                                    is_training)
                    loss = model.compute_loss()
                    tf.get_variable_scope().reuse_variables()
                    grads_and_vars = optimizer.compute_gradients(loss)
                    #
                    gradients_norm = summaries_gradients_norm(grads_and_vars)
                    gradients_hist = summaries_gradients_hist(grads_and_vars)
                    #summaries_buf.append(gradients_norm)
                    summaries_buf.append(gradients_hist)
                    ##sum_set = set()
                    ##sum_set.add(tf.summary.scalar("loss", loss))
                    ##summaries_buf.append(sum_set)
                    summaries_buf.append({tf.summary.scalar("loss", loss)})
                    #
                    tower_grads.append(grads_and_vars)
                    if i == 0:
                        current_loss = loss
                        update_op = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                        vars_det = tf.get_collection(
                            tf.GraphKeys.TRAINABLE_VARIABLES, scope="PDetNet")
    grads = average_gradients(tower_grads)
    with tf.control_dependencies(update_op):
        #train_op = optimizer.minimize(loss, global_step=global_step, var_list=vars_det)
        apply_gradient_op = optimizer.apply_gradients(grads,
                                                      global_step=global_step)
        train_op = tf.group(apply_gradient_op, *update_op)

    # GPU config
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    ## Pruning (added by lzlu)
    # Parse pruning hyperparameters
    pruning_hparams = pruning.get_pruning_hparams().parse(
        cfg.prune.pruning_hparams)

    # Create a pruning object using the pruning hyperparameters
    pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

    # Use the pruning_obj to add ops to the training graph to update the masks
    # The conditional_mask_update_op will update the masks only when the
    # training step is in [begin_pruning_step, end_pruning_step] specified in
    # the pruning spec proto
    mask_update_op = pruning_obj.conditional_mask_update_op()

    # Use the pruning_obj to add summaries to the graph to track the sparsity
    # of each of the layers
    pruning_summaries = pruning_obj.add_pruning_summaries()

    summaries |= pruning_summaries
    for summ in summaries_buf:
        summaries |= summ

    summaries.add(tf.summary.scalar('lr', lr))

    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    if cfg.summary.summary_allowed:
        summary_writer = tf.summary.FileWriter(
            logdir=cfg.summary.logs_path,
            graph=sess.graph,
            flush_secs=cfg.summary.summary_secs)

    # Create a saver
    saver = tf.train.Saver()
    ckpt_dir = re.sub(r'examples/', '', cfg.ckpt_path_608)

    if cfg.train.fine_tune == 0:
        # init
        sess.run(tf.global_variables_initializer())
    else:
        saver.restore(sess, cfg.train.rstd_path)

    # running
    for i in range(0, cfg.train.max_batches):
        _, loss_, gstep, sval, _ = sess.run(
            [train_op, current_loss, global_step, summary_op, mask_update_op])
        if (i % 100 == 0):
            print(i, ': ', loss_)
        if i % 1000 == 0 and i < 10000:
            saver.save(sess,
                       ckpt_dir + str(i) + '_plate.ckpt',
                       global_step=global_step,
                       write_meta_graph=False)
        if i % 10000 == 0:
            saver.save(sess,
                       ckpt_dir + str(i) + '_plate.ckpt',
                       global_step=global_step,
                       write_meta_graph=False)
        if cfg.summary.summary_allowed and gstep % cfg.summary.summ_steps == 0:
            summary_writer.add_summary(sval, global_step=gstep)