Пример #1
0
    def add_init_params(self, init_net):
        '''
        Append the initializer op of every layer parameter to ``init_net``.

        Parameters whose ``initializer`` is falsy are skipped. When an
        initializer carries no ``device_option`` of its own and a device
        scope is currently active, a fresh copy of the op is built and the
        current device scope is written into that copy; the parameter's own
        initializer op is left untouched. The resulting op is appended only
        if no op already present in ``init_net`` matches it according to
        ``utils.OpAlmostEqual`` with ``debug_info`` ignored, so duplicate
        initializers are never added twice.
        '''
        for param in self.params:
            candidate = param.initializer
            device_scope = scope.CurrentDeviceScope()
            if not candidate:
                continue

            # Only inject the ambient device scope when the op does not
            # already pin a device itself; copy first so the parameter's
            # stored initializer is not mutated.
            lacks_device = not candidate.HasField('device_option')
            if lacks_device and device_scope:
                scoped_op = caffe2_pb2.OperatorDef()
                scoped_op.CopyFrom(param.initializer)
                scoped_op.device_option.CopyFrom(device_scope)
                candidate = scoped_op

            # Skip ops equivalent (modulo debug_info) to one already present.
            if not any(
                    utils.OpAlmostEqual(existing, candidate, 'debug_info')
                    for existing in init_net._net.op):
                # TODO(amalevich): Either return back to lambdas, that add
                # all params (looks a bit safer and breaking less
                # abstractions) or extend Net interface to this type of
                # operations better
                # TODO(xlwang) init_net._net.op has type google.protobuf.\
                # internal.containers.RepeatedCompositeFieldContainer, but
                # the version of protobuf in fbcode does not support append
                # so extend is used
                init_net._net.op.extend([candidate])
Пример #2
0
    def maybe_add_global_constant(self, name, *args, **kwargs):
        """
        Return the global constant registered under ``name``, adding it if new.

        Allows adding global constants ad hoc without duplication. If
        ``name`` is not yet in ``self.global_constants``, the call is
        delegated to ``self.add_global_constant(name, *args, **kwargs)`` and
        its result is returned.

        If ``name`` is already registered, nothing new is added. Instead,
        the initializer op that ``args``/``kwargs`` would produce is rebuilt
        and compared, ignoring ``debug_info``, with the initializer recorded
        when the constant was first registered. The existing blob name is
        then returned.

        Args:
            name: key of the global constant in ``self.global_constants``.
            *args, **kwargs: forwarded to
                ``LayerModelHelper._get_global_constant_initializer_op`` (for
                an existing constant) or to ``add_global_constant`` (for a
                new one).

        Returns:
            The existing blob name for ``name``, or whatever
            ``add_global_constant`` returns for a new constant.

        Raises:
            AssertionError: if ``name`` is already registered with a
                different initializer. This is a plain ``assert``, so the
                check is skipped when Python runs with ``-O``.
        """
        if name not in self.global_constants:
            return self.add_global_constant(name, *args, **kwargs)

        blob_name = self.global_constants[name]
        new_initializer = \
            LayerModelHelper._get_global_constant_initializer_op(
                blob_name, *args, **kwargs
            )
        previous_initializer = self.global_constant_initializers[blob_name]
        # The original initializer must match the one intended now.
        # The message arguments are ordered so that "previous" shows the
        # stored initializer and "now" shows the newly requested one.
        assert utils.OpAlmostEqual(
            new_initializer,
            previous_initializer,
            'debug_info'
        ), \
            "conflict initializers for global constant %s, " \
            "previous %s, now %s" % (
                blob_name, str(previous_initializer),
                str(new_initializer))
        return blob_name