コード例 #1
0
    def test_cross_nets_no_change(self):
        """Blobs already on the consuming op's device need no copy ops.

        The FC parameters are created in ``init_net`` under a CUDA:1 device
        scope, and ``data`` is mapped to that same device through
        ``blob_to_device_init``. Every input of the FC op in ``net`` is
        therefore already on CUDA:1, so the copy-injection pass is expected
        to leave ``net`` unchanged: a single FC op with its original inputs,
        output and device option.
        """
        net = core.Net("test")
        init_net = core.Net("init")
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = caffe2_pb2.CUDA
        device_option.cuda_gpu_id = 1

        with core.DeviceScope(device_option):
            weight = init_net.XavierFill([], 'fc_w', shape=[10, 100])
            bias = init_net.ConstantFill([], 'fc_b', shape=[10])
            net.FC(["data", weight, bias], "fc1")

        # No op in either net produces 'data', so its device is supplied
        # explicitly to the pass.
        data_remap = {'data': device_option}
        nets = core.InjectDeviceCopiesAmongNetsWithoutB2D(
            [init_net, net], blob_to_device_init=data_remap)

        # Go through the public Proto() accessor rather than the private
        # _net attribute.
        ops = nets[1].Proto().op
        # "No change" means no copy ops were injected; only the FC remains.
        self.assertEqual(len(ops), 1)
        op = ops[0]
        self.assertEqual(op.type, "FC")
        # Compare whole lists so that extra/renamed inputs are also caught.
        self.assertEqual(list(op.input), ["data", "fc_w", "fc_b"])
        self.assertEqual(list(op.output), ["fc1"])
        # Compare against the same constant used to build device_option
        # instead of a magic number.
        self.assertEqual(op.device_option.device_type, caffe2_pb2.CUDA)
        self.assertEqual(op.device_option.cuda_gpu_id, 1)
        """
コード例 #2
0
ファイル: core_test.py プロジェクト: simmon2014/caffe2
    def test_cross_nets_no_change(self):
        net = core.Net("test")
        init_net = core.Net("init")
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = caffe2_pb2.CUDA
        device_option.cuda_gpu_id = 1

        with core.DeviceScope(device_option):
            weight = init_net.XavierFill([], 'fc_w', shape=[10, 100])
            bias = init_net.ConstantFill([], 'fc_b', shape=[
                10,
            ])
            net.FC(["data", weight, bias], "fc1")

        data_remap = {'data': device_option}
        nets = core.InjectDeviceCopiesAmongNetsWithoutB2D(
            [init_net, net], blob_to_device_init=data_remap)
        ref_str = """
name: ""
op {
  input: "data"
  input: "fc_w"
  input: "fc_b"
  output: "fc1"
  name: ""
  type: "FC"
  device_option {
    device_type: 1
    cuda_gpu_id: 1
  }
}
external_input: "data"
external_input: "fc_w"
external_input: "fc_b"
"""
        nets[1].Proto().name = ''  # Ignore the name
        self.assertEqual(str(nets[1].Proto()).strip(), ref_str.strip())