Example 1
    def test_optimizer_context(self):
        from caffe2.python import brew, optimizer
        from caffe2.python.model_helper import ModelHelper
        from caffe2.python.optimizer import (
            SgdOptimizer, add_weight_decay, build_sgd
        )
        from caffe2.python.optimizer_context import UseOptimizer

        model = ModelHelper(name="test", arg_scope={'order': 'NCHW'})
        count = optimizer._optimizer_instance_count['SgdOptimizer']
        cnv_optim = SgdOptimizer(0.15)
        weight_optim = SgdOptimizer(0.2)
        bias_optim = SgdOptimizer(0.1)

        with UseOptimizer(cnv_optim):
            cnv = brew.conv(model, 'data', 'cnv', 32, 32, 4)
        with UseOptimizer({'WEIGHT': weight_optim, 'BIAS': bias_optim}):
            a = brew.fc(model, cnv, 'a', 100, 200)
        pred = brew.fc(model, a, 'b', 200, 5)
        (softmax, loss) = model.SoftmaxWithLoss(
            [pred, 'label'],
            ['softmax', 'loss'],
        )
        model.AddGradientOperators([loss])

        add_weight_decay(model, weight_decay=1e-4)
        # use the following optimizer if none specified in param_info
        build_sgd(model, 0.11)
        expected_weight_grad = {'b_w_grad', 'a_w_grad', 'cnv_w_grad'}
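        # Note: caffe2's SgdOptimizer emits LearningRate ops with a negated
        # base_lr (the parameter update adds lr * grad, so descending requires
        # a negative rate), hence the negative expected values below.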
        expected_learning_rate = {
            "SgdOptimizer_{}_lr_cpu".format(count): -0.15,
            "SgdOptimizer_{}_lr_cpu".format(count + 1): -0.2,
            "SgdOptimizer_{}_lr_cpu".format(count + 2): -0.1,
            "SgdOptimizer_{}_lr_cpu".format(count + 3): -0.11
        }

        for op in model.net.Proto().op:
            # Check in the proto that all weights are decayed and that
            # non-weights are not decayed.
            if op.type == 'WeightedSum' and 'wd_0_0' in op.input:
                if op.output[0] not in expected_weight_grad:
                    print(
                        "Unexpected param for weight_decay: {}".
                        format(op.output[0])
                    )
                self.assertIn(op.output[0], expected_weight_grad)
                expected_weight_grad.remove(op.output[0])
            # Check the learning rate for each parameter
            if op.type == 'LearningRate':
                val = 0
                for arg in op.arg:
                    if arg.name == 'base_lr':
                        val = arg.f
                self.assertAlmostEqual(
                    val,
                    expected_learning_rate[op.output[0]]
                )

        self.assertEqual(
            expected_weight_grad,
            set(),
            "Not all weights were decayed: {}".format(expected_weight_grad)
        )
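
The pattern exercised above is the UseOptimizer context manager: layers built by brew inside the block resolve their optimizers from the active context, either keyed by parameter tag ('WEIGHT', 'BIAS') or applied as a single default, instead of being passed explicitly. Below is a minimal sketch of just that mechanism, assuming OptimizerContext exposes the same current()/get_optimizer accessors that RegularizerContext demonstrates in Example 2:

    from caffe2.python.optimizer import SgdOptimizer
    from caffe2.python.optimizer_context import OptimizerContext, UseOptimizer

    weight_optim = SgdOptimizer(0.2)
    bias_optim = SgdOptimizer(0.1)

    # Inside the block, the per-tag optimizers are retrievable from the
    # active context; brew helpers use the same lookup when building layers.
    with UseOptimizer({'WEIGHT': weight_optim, 'BIAS': bias_optim}):
        ctx = OptimizerContext.current()
        assert ctx.get_optimizer('WEIGHT') is weight_optim
        assert ctx.get_optimizer('BIAS') is bias_optim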
Example 2
    def test_regularizer_context(self, X):
        import numpy as np

        from caffe2.python import layer_model_instantiator, schema
        from caffe2.python.optimizer import SgdOptimizer
        from caffe2.python.regularizer import L1Norm
        from caffe2.python.regularizer_context import (
            RegularizerContext, UseRegularizer
        )

        weight_reg_out = L1Norm(0.2)
        bias_reg_out = L1Norm(0)
        regularizers = {"WEIGHT": weight_reg_out, "BIAS": bias_reg_out}

        output_dims = 2
        input_record = self.new_record(schema.Scalar((np.float32, (5, ))))
        schema.FeedRecord(input_record, [X])

        with UseRegularizer(regularizers):
            weight_reg = RegularizerContext.current().get_regularizer("WEIGHT")
            bias_reg = RegularizerContext.current().get_regularizer("BIAS")
            optim = SgdOptimizer(0.15)

            assert weight_reg == weight_reg_out, \
                "failed to get the correct weight reg from the context"
            assert bias_reg == bias_reg_out, \
                "failed to get the correct bias reg from the context"
            fc_output = self.model.FC(
                input_record,
                output_dims,
                weight_optim=optim,
                bias_optim=optim,
                weight_reg=weight_reg,
                bias_reg=bias_reg,
            )
            # model.output_schema has to be a struct
            self.model.output_schema = schema.Struct(("fc_output", fc_output))

            self.assertEqual(schema.Scalar((np.float32, (output_dims, ))),
                             fc_output)

            _, train_net = layer_model_instantiator.generate_training_nets(
                self.model)
            ops = train_net.Proto().op
            ops_type_list = [op.type for op in ops]
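            # The two L1Norm regularizers (weight and bias) each contribute
            # an LpNorm op and a matching LpNormGradient; the Scale ops
            # apply the coefficients on the forward and gradient paths.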
            assert ops_type_list.count("LpNorm") == 2
            assert ops_type_list.count("Scale") == 4
            assert ops_type_list.count("LpNormGradient") == 2