示例#1
0
def _get_loss(event_file):
    """Return the step and value of the first ``best_loss`` scalar found.

    The file is read as a sequence of framed records: an 8-byte length
    header, a 4-byte masked CRC32C of that header, the serialized event of
    the announced length, and a 4-byte masked CRC32C of the event.  Each
    record is parsed as an ``event_pb2.Event``.  Only the first summary
    value of each event is inspected.

    Args:
        event_file: Path of the event file to scan.

    Returns:
        A ``(step, simple_value)`` tuple taken from the first event whose
        first summary value is tagged ``best_loss``.

    Raises:
        AssertionError: If a record fails one of its CRC checks, or if the
            end of the file is reached without finding a ``best_loss`` value.
    """
    with open(event_file, 'rb') as stream:
        while True:
            header = stream.read(8)
            if not header:
                # Clean end of file: no more records.
                break

            header_crc, = struct.unpack('I', stream.read(4))

            # Raised explicitly instead of via ``assert`` so the integrity
            # checks still run when Python is started with -O.
            if masked_crc32c(header) != header_crc:
                raise AssertionError(
                    'record header CRC mismatch in event file: ' + event_file)

            str_len, = struct.unpack('Q', header)
            event_str = stream.read(str_len)
            body_crc, = struct.unpack('I', stream.read(4))

            if masked_crc32c(event_str) != body_crc:
                raise AssertionError(
                    'record body CRC mismatch in event file: ' + event_file)

            event = event_pb2.Event()
            event.ParseFromString(event_str)

            if len(event.summary.value) > 0:
                value = event.summary.value[0]
                if value.tag == 'best_loss':
                    return event.step, value.simple_value

    # The message names the tag that was actually searched for.
    raise AssertionError(
        'could not find "best_loss" value in event file: ' + event_file)
示例#2
0
 def test_event_file_writer_roundtrip(self):
     """Write one scalar event and check that it is read back unchanged.

     A single summary event is written through ``EventFileWriter`` into a
     fresh temporary directory.  The test expects exactly one file to be
     created there and compares the second record of that file with the
     serialized form of the event that was written.
     """
     tag = 'dummy'
     scalar = 42
     log_dir = self.get_temp_dir()
     writer = EventFileWriter(log_dir)
     expected_event = event_pb2.Event(
         summary=Summary(value=[Summary.Value(tag=tag, simple_value=scalar)]))
     writer.add_event(expected_event)
     writer.close()

     written_files = sorted(glob.glob(os.path.join(log_dir, '*')))
     self.assertEqual(len(written_files), 1)

     reader = PyRecordReader_New(written_files[0])
     # The first record is metadata; the event under test is the second one.
     reader.GetNext()
     reader.GetNext()
     self.assertEqual(expected_event.SerializeToString(), reader.record())
示例#3
0
    def _read_event(self):
        """Decode the record at the current buffer position and step past it.

        The offsets used below imply this record layout: an 8-byte length
        header, 4 bytes that are skipped, the serialized event, then 4 more
        skipped bytes.  NOTE(review): the skipped fields look like the
        record CRCs; they are not verified here -- confirm that is intended.

        Returns:
            The ``event_pb2.Event`` parsed from the record body.
        """
        # The header holds the byte length of the serialized event.
        header_end = self._index + 8
        (event_len,) = struct.unpack('Q', self._buf[self._index:header_end])
        # Step over the 8-byte header and the 4 bytes that follow it.
        self._index += 12

        body_end = self._index + event_len
        event_str = self._buf[self._index:body_end]
        # Step over the body and its 4 trailing bytes.  The index moves
        # before parsing, so it already points at the next record even when
        # parsing raises.
        self._index = body_end + 4

        # Parse errors propagate to the caller unchanged; the former
        # ``except: raise`` wrapper had no effect and was dropped.
        ev = event_pb2.Event()
        ev.ParseFromString(event_str)
        return ev
示例#4
0
    def test_writer(self):
        """Check that scalars written from several processes reach the file.

        ``N_PROC`` processes each write the values ``0 .. TEST_LEN - 1``
        through one shared ``SummaryWriter``.  After the writer is closed,
        ``TEST_LEN * N_PROC`` records are read back from the event file and
        every value is expected to appear exactly ``N_PROC`` times.
        """
        TEST_LEN = 100
        N_PROC = 4
        writer = SummaryWriter()
        event_filename = writer.file_writer.event_writer._ev_writer._file_name

        expected_scalars = list(range(TEST_LEN))

        def emit_scalars():
            # Each process writes every expected value once, pausing for a
            # random 0-90 ms between writes.
            for scalar in expected_scalars:
                writer.add_scalar('many_write_in_func', scalar)
                time.sleep(0.01 * np.random.randint(0, 10))

        workers = [mp.Process(target=emit_scalars) for _ in range(N_PROC)]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()
        writer.close()

        reader = PyRecordReader_New(event_filename)
        reader.GetNext()  # meta data, so skip
        observed = []
        # All of the data should have been flushed by now.
        for _ in range(TEST_LEN * N_PROC):
            reader.GetNext()
            summary_values = event_pb2.Event.FromString(
                reader.record()).summary.value
            observed.append(summary_values[0].simple_value)

        # Once sorted, each value occupies N_PROC consecutive slots.
        observed.sort()
        for position, scalar in enumerate(observed):
            assert scalar == position // N_PROC
示例#5
0
    def from_variable(self, leaf, output_name="output"):
        """Write the graph that produces ``leaf`` as a graph event.

        Starting at ``leaf``, the graph is walked backwards through
        ``v.parent`` (the function that produced a variable) and
        ``v.parent.inputs``.  Every variable and function reached becomes a
        ``NodeDef``.  The collected nodes are reversed, wrapped in a
        ``GraphDef``, serialized into an ``event_pb2.Event`` and handed to
        ``self.file_writer.add_event``.

        Args:
            leaf: Variable to start the backward walk from.
            output_name: Node name given to ``leaf`` when it has a parent
                function and is not found among the registered parameters.
        """
        def parse_variable(v, var_num):
            """Recursively add ``v``, its producing function and its inputs.

            ``var_num`` is forwarded to ``add_variable`` but does not
            influence any generated name in the code visible here.
            """
            def add_variable(v, v_idx):
                """Append a 'Variable' node for ``v`` if needed; return its name."""
                # Registered parameters keep their parameter name.
                v_name = parameters.get(v.data, None)
                exist = False
                if not v_name:
                    v_name, exist = get_variable_name(v, v_idx)
                # NOTE(review): for a parameter, ``exist`` stays False, so a
                # node is appended on every call -- confirm duplicates cannot
                # occur when a parameter is reached more than once.
                if not exist:
                    shape_proto = TensorShapeProto(
                        dim=[TensorShapeProto.Dim(size=d) for d in v.shape])

                    # A variable's only graph input is the function that
                    # produced it, if any.
                    if v.parent is None:
                        inputs = []
                    else:
                        inputs = [get_func_name(v.parent)]
                    # print("Variable: {}:{}".format(v_name, inputs))
                    nodes.append(NodeDef(
                        name=v_name.encode(encoding='utf-8'),
                        op='Variable',
                        input=inputs,
                        attr={
                            'shape': AttrValue(shape=shape_proto),
                            'dtype': AttrValue(type=DT_FLOAT)
                        }
                    ))
                return v_name

            def get_unique_variable_name(v_name_base):
                """Return ``v_name_base`` plus the lowest unused integer suffix."""
                v_num = 0
                v_name = v_name_base + str(v_num)
                while v_name in unique_var_names:
                    v_num += 1
                    v_name = v_name_base + str(v_num)
                unique_var_names.add(v_name)
                return v_name

            def get_variable_name(v, v_idx):
                """Return ``(name, already_known)`` for variable ``v``.

                ``v_idx`` is accepted but not used.
                """
                v_name = variables.get(v, None)
                if v_name:
                    return v_name, True
                else:
                    if v.parent is None:
                        # Variables with no producing function are graph inputs.
                        v_name_base = "Input"
                        v_name = get_unique_variable_name(v_name_base)
                    elif not nodes:
                        # No node has been emitted yet, so this is the
                        # variable the walk started from.
                        v_name = output_name
                    else:
                        # Name intermediates after their producing function,
                        # keeping that function's scope prefix.
                        f_name_sections = get_func_name(v.parent).split("/")
                        f_name = f_name_sections[-1]
                        f_scope = f_name_sections[:-1]
                        base_name = "variable<-{}".format(f_name)
                        v_name_base = "/".join(f_scope + [base_name])
                        v_name = get_unique_variable_name(v_name_base)

                    variables[v] = v_name
                    return v_name, False

            def get_func_name(func):
                """Return the cached or newly generated unique name of ``func``."""
                func_name = func_names.get(func, None)
                if func_name:
                    return func_name
                name_scope = loc_var['name_scope']
                # The scope is taken from the first input that is a named
                # parameter (the parameter name minus its last path component).
                # NOTE(review): this reads ``self.parameters`` while the rest
                # of the method uses the local ``parameters`` dict -- confirm
                # the attribute exists and that the difference is intended.
                for v in func.inputs:
                    v_name = self.parameters.get(v.data, None)
                    if v_name:
                        name_scope = '/'.join(v_name.split('/')[:-1])
                        break
                if name_scope:
                    func_name_base = '/'.join([name_scope, func.name])
                else:
                    func_name_base = func.name
                # Append the lowest integer suffix that makes the name unique.
                func_num = 0
                func_name = func_name_base + str(func_num)
                while func_name in unique_func_names:
                    func_num += 1
                    func_name = func_name_base + str(func_num)
                unique_func_names.add(func_name)
                func_names[func] = func_name
                return func_name

            def add_func(v):
                """Add the function producing ``v`` and its input variables.

                Returns False when the function node was already emitted,
                True otherwise.
                """
                # Input variables are added before the already-emitted check.
                input_names = []
                for index, v_input in enumerate(v.parent.inputs):
                    v_name = add_variable(v_input, index)
                    input_names.append(v_name)
                # print("Function: {}:{}".format(get_func_name(v.parent), input_names))
                f_name = get_func_name(v.parent)
                if f_name in func_set:
                    return False
                # Function arguments are flattened into one "k=v,k=v" string.
                attrs = []
                for k, a in v.parent.info.args.items():
                    attr = "{}={}".format(k, a)
                    attrs.append(attr)
                attr_str = ','.join(attrs).encode(encoding='utf-8')
                nodes.append(NodeDef(
                    name=f_name,
                    op=v.parent.info.type_name,
                    input=input_names,
                    attr={"parameters": AttrValue(s=attr_str)}
                ))
                func_set.add(f_name)
                return True

            name_scope = loc_var['name_scope']
            # The very first variable visited is added before anything else
            # so that it is named while ``nodes`` is still empty.
            if not nodes:
                add_variable(v, var_num)
            if v.parent is None:
                add_variable(v, var_num)
            else:
                # Stop descending when this function was already emitted.
                if not add_func(v):
                    return
                for idx, in_var in enumerate(v.parent.inputs):
                    # ``name_scope`` is saved and restored around each
                    # recursive call; nothing visible here changes it.
                    name_scope_stack.append(name_scope)
                    parse_variable(in_var, idx)
                    name_scope = name_scope_stack.pop()

        # State shared by the nested helpers above.
        nodes = []                  # NodeDefs in visiting order (output first)
        variables = {}              # variable -> generated node name
        loc_var = {}
        loc_var['name_scope'] = ''
        name_scope_stack = []
        func_names = {}             # function -> generated node name
        func_set = set()            # names of functions already emitted
        unique_func_names = set()
        unique_var_names = set()
        # Reverse lookup from a parameter's data to its registered name.
        parameters = {v.data: k for k,
                      v in get_parameters(grad_only=False).items()}
        parse_variable(leaf, 0)
        # The walk ran from the output backwards; reverse the node order.
        nodes = nodes[::-1]

        current_graph = GraphDef(node=nodes, versions=VersionDef(producer=22))
        event = event_pb2.Event(
            graph_def=current_graph.SerializeToString())
        self.file_writer.add_event(event)
示例#6
0
    def from_graph_def(self, graph_def):
        """Convert ``graph_def`` into a ``GraphDef`` event and write it.

        Nodes are emitted in this order: one 'Parameter' node per entry of
        ``graph_def.parameters``, one 'Variable' node per entry of
        ``graph_def.inputs``, then for every function its output variables,
        its input variables and finally the function node itself.  The
        resulting ``GraphDef`` is serialized into an ``event_pb2.Event`` and
        handed to ``self.file_writer.add_event``.

        Args:
            graph_def: Object exposing ``variables``, ``parameters``,
                ``functions`` and ``inputs`` mappings.  The values of
                ``variables``, ``parameters`` and ``inputs`` must have a
                ``shape``; each function entry must provide ``'inputs'``,
                ``'outputs'`` and ``'type'``.
        """
        def tensor_node(name, op, shape, node_inputs):
            """Build a NodeDef with 'shape' and DT_FLOAT 'dtype' attributes."""
            shape_proto = TensorShapeProto(
                dim=[TensorShapeProto.Dim(size=d) for d in shape])
            return NodeDef(
                name=name.encode(encoding='utf-8'),
                op=op,
                input=node_inputs,
                attr={
                    'shape': AttrValue(shape=shape_proto),
                    'dtype': AttrValue(type=DT_FLOAT)
                }
            )

        variables = graph_def.variables
        parameters = graph_def.parameters
        functions = graph_def.functions
        inputs = graph_def.inputs
        nodes = []
        # Maps a tensor name to a node that a producing function can still be
        # attached to when it is encountered later.
        scope = {}

        for n, v in parameters.items():
            node = tensor_node(n, 'Parameter', v.shape, [])
            nodes.append(node)
            scope[n] = node

        for n, v in inputs.items():
            nodes.append(tensor_node(n, 'Variable', v.shape, []))

        for func_name, func in functions.items():
            for o in func['outputs']:
                if o in scope:
                    # The node already exists: record this function as its
                    # producer.
                    scope[o].input.extend([func_name])
                elif o in variables:
                    nodes.append(tensor_node(
                        o, 'Variable', variables[o].shape, [func_name]))
            for i in func['inputs']:
                if i in variables:
                    # The node is named after the input itself.  The previous
                    # code used the stale loop variable ``o`` from the outputs
                    # loop here, which mislabelled the node and the scope
                    # entry (or raised NameError when ``o`` was unbound).
                    # NOTE(review): an input shared by several functions is
                    # still emitted once per consumer -- confirm whether
                    # duplicate node names are acceptable.
                    node = tensor_node(i, 'Variable', variables[i].shape, [])
                    nodes.append(node)
                    scope[i] = node
            nodes.append(NodeDef(
                name=func_name,
                op=func['type'],
                input=func['inputs'],
                attr={"arguments": AttrValue(s='a=1'.encode(encoding='utf-8'))}
            ))

        current_graph = GraphDef(node=nodes, versions=VersionDef(producer=22))
        event = event_pb2.Event(
            graph_def=current_graph.SerializeToString())
        self.file_writer.add_event(event)