Code Example #1
    def __init__(self, component, num_op_records=None, args=None, kwargs=None):
        """
        Args:
            component (Component): The Component to which this column belongs.
        """
        self.id = self.get_id()

        if num_op_records is None:
            self.op_records = []
            if args is not None:
                for i in range(len(args)):
                    if args[i] is None:
                        continue
                    op_rec = DataOpRecord(op=None, column=self, position=i)
                    # If incoming is an op-rec -> Link them.
                    if isinstance(args[i], DataOpRecord):
                        op_rec.previous = args[i]
                        op = args[i].op
                        if op is not None:
                            op_rec.op = op
                            op_rec.space = get_space_from_op(op)
                        args[i].next.add(op_rec)
                    # Do constant value assignment here (args[i] is guaranteed
                    # non-None at this point; None entries were skipped above).
                    else:
                        op = args[i]
                        if is_constant(op) and not isinstance(op, np.ndarray):
                            op = np.array(op, dtype=convert_dtype(type(op), "np"))
                        op_rec.op = op
                        op_rec.space = get_space_from_op(op)
                        component.constant_op_records.add(op_rec)
                    self.op_records.append(op_rec)

            if kwargs is not None:
                for key in sorted(kwargs.keys()):
                    value = kwargs[key]
                    if value is None:
                        continue
                    op_rec = DataOpRecord(op=None, column=self, kwarg=key)
                    # If incoming is an op-rec -> Link them.
                    if isinstance(value, DataOpRecord):
                        op_rec.previous = value
                        op_rec.op = value.op  # assign op if any
                        value.next.add(op_rec)
                    # Do constant value assignment here (value is guaranteed
                    # non-None at this point; None values were skipped above).
                    else:
                        op = value
                        if is_constant(op):
                            op = np.array(op, dtype=convert_dtype(type(op), "np"))
                        op_rec.op = op
                        op_rec.space = get_space_from_op(op)
                        component.constant_op_records.add(op_rec)
                    self.op_records.append(op_rec)
        else:
            self.op_records = [DataOpRecord(op=None, column=self, position=i) for i in range(num_op_records)]

        # For __str__ purposes.
        self.op_id_list = [o.id for o in self.op_records]

        # The component this column belongs to.
        self.component = component
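
The constant-value branches above normalize plain Python constants into NumPy arrays before recording them. A minimal standalone sketch of that conversion, with a hypothetical dtype lookup standing in for RLgraph's `convert_dtype(..., "np")` helper:

import numpy as np

# Hypothetical stand-in for RLgraph's convert_dtype(type_, "np") mapping.
_NP_DTYPES = {int: np.int64, float: np.float32, bool: np.bool_}

def to_constant_op(value):
    # Mirror the constructor's constant branch: wrap non-ndarray constants
    # into NumPy arrays with an explicitly mapped dtype.
    if not isinstance(value, np.ndarray):
        value = np.array(value, dtype=_NP_DTYPES[type(value)])
    return value

to_constant_op(3)    # -> np.int64 array scalar
to_constant_op(0.5)  # -> np.float32 array scalar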
Code Example #2
    def setup_graph(self):
        """
        Generates the tf-Graph object and enters its scope as default graph.
        Also creates the global time step variable.
        """
        self.graph = tf.Graph()
        self.graph_default_context = self.graph.as_default()
        self.graph_default_context.__enter__()

        # Create global training (update) timestep. Gets increased once per update.
        # Do not include this one in the GLOBAL_STEP collection, as only one variable
        # (`global_timestep`, created below) should be in there.
        self.global_training_timestep = tf.get_variable(
            name="global-training-timestep",
            dtype=util.convert_dtype("int"),
            trainable=False,
            initializer=0,
            collections=["global-training-timestep"])
        # Create global (env-stepping) timestep. Gets increased once per environment step.
        # For vector-envs, gets increased each action by the number of parallel environments.
        self.global_timestep = tf.get_variable(
            name="global-timestep",
            dtype=util.convert_dtype("int"),
            trainable=False,
            initializer=0,
            collections=["global-timestep", tf.GraphKeys.GLOBAL_STEP])

        # Set the random seed graph-wide.
        if self.seed is not None:
            self.logger.info(
                "Initializing TensorFlow graph with seed {}".format(self.seed))
            tf.set_random_seed(self.seed)
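
Because `setup_graph` enters the default-graph scope manually via `__enter__()` (rather than a `with` block), the scope must be left explicitly at shutdown. A minimal sketch of the matching teardown, assuming a `terminate()`-style method on the same executor:

    def terminate(self):
        # Hypothetical counterpart to setup_graph(): leave the default-graph
        # scope that was entered manually via __enter__().
        if self.graph_default_context is not None:
            self.graph_default_context.__exit__(None, None, None)
            self.graph_default_context = None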
Code Example #3
    def setup_graph(self):
        """
        Generates the tf-Graph object and enters its scope as default graph.
        Also creates the global time step variable.
        """
        self.graph = tf.Graph()
        self.graph_default_context = self.graph.as_default()
        self.graph_default_context.__enter__()

        self.global_training_timestep = tf.get_variable(
            name="global-timestep", dtype=util.convert_dtype("int"), trainable=False, initializer=0,
            collections=["global-timestep", tf.GraphKeys.GLOBAL_STEP])

        # Set the random seed graph-wide.
        if self.seed is not None:
            self.logger.info("Initializing TensorFlow graph with seed {}".format(self.seed))
            tf.set_random_seed(self.seed)
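
Note that this condensed variant creates only a single variable, and although it is assigned to `self.global_training_timestep`, it is named "global-timestep" and registered in the GLOBAL_STEP collection. Variables in that collection can later be fetched with TensorFlow's standard helper (TF 1.x API):

        # Fetch the step variable registered above; returns the single variable
        # in the GLOBAL_STEP collection of `graph`, or None if there is none.
        step_var = tf.train.get_global_step(self.graph)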
Code Example #4
    def setup_scaffold(self):
        """
        Creates a tf.train.Scaffold object to be used by the session to initialize variables and to save models
        and summaries.
        Assigns the scaffold object to `self.scaffold`.
        """
        # Determine init_op and ready_op.
        var_list = list(self.graph_builder.root_component.variables.values())

        self.global_training_timestep = tf.get_variable(
            name="global-timestep", dtype=util.convert_dtype("int"), trainable=False, initializer=0,
            collections=["global-timestep", tf.GraphKeys.GLOBAL_STEP])
        var_list.append(self.global_training_timestep)

        # Optimizer variables are not covered by the graph builder's variable registry,
        # so we add them to the var-list manually here.
        # TODO: Let the graph builder do this.
        if self.optimizer is not None:
            var_list.extend(self.optimizer.get_optimizer_variables())
            # If the VF has a separate optimizer (non-shared network), we need to fetch its vars here as well.
            if self.vf_optimizer is not None:
                var_list.extend(self.vf_optimizer.get_optimizer_variables())

        if self.execution_mode == "single":
            self.init_op = tf.variables_initializer(var_list=var_list)
            self.ready_op = tf.report_uninitialized_variables(var_list=var_list)
            # (`ready_for_local_init_op` and `local_init_op` are presumably left at their
            # None defaults here and passed as such to the Scaffold below.)
        else:
            assert self.execution_mode == "distributed",\
                "ERROR: execution_mode can only be 'single' or 'distributed'! Is '{}'.".format(self.execution_mode)
            local_job_and_task = "/job:{}/task:{}/".format(self.execution_spec["distributed_spec"]["job"],
                                                          self.execution_spec["distributed_spec"]["task_index"])
            var_list_local = [var for var in var_list if not var.device or local_job_and_task in var.device]
            var_list_remote = [var for var in var_list if var.device and local_job_and_task not in var.device]
            self.init_op = tf.variables_initializer(var_list=var_list_remote)
            self.ready_for_local_init_op = tf.report_uninitialized_variables(var_list=var_list_remote)
            self.local_init_op = tf.variables_initializer(var_list=var_list_local)
            self.ready_op = tf.report_uninitialized_variables(var_list=var_list)

        def init_fn(scaffold, session):
            # NOTE: `self.load_from_file` is either True or a string value.
            saver_dir = self.saver_spec.get("directory", "") if self.saver_spec else ""
            # - No specific file given -> Use latest checkpoint.
            if self.load_from_file is True:
                assert self.saver_spec is not None,\
                    "ERROR: load_from_file is True but no saver_spec with 'directory' provided"
                file = tf.train.latest_checkpoint(
                    checkpoint_dir=saver_dir,
                    latest_filename=None
                )
            # - File given -> Look for it in cwd, then in our checkpoint directory.
            else:
                assert isinstance(self.load_from_file, str)
                file = self.load_from_file
                if not os.path.isfile(file):
                    file = os.path.join(saver_dir, self.load_from_file)

            if file is not None:
                scaffold.saver.restore(sess=session, save_path=file)

        # Create the tf.train.Scaffold object (skipped entirely if monitoring is disabled).
        if not self.disable_monitoring:
            self.scaffold = tf.train.Scaffold(
                init_op=self.init_op,
                init_feed_dict=None,
                init_fn=init_fn if self.load_from_file else None,
                ready_op=self.ready_op,
                ready_for_local_init_op=self.ready_for_local_init_op,
                local_init_op=self.local_init_op,
                summary_op=self.summary_op,
                saver=self.saver,
                copy_from_scaffold=None
            )
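
The scaffold built here is what a monitored session consumes to run `init_op`, check `ready_op`, and restore checkpoints via `init_fn`. A minimal sketch of how it is typically handed to a session (TF 1.x API; the `master` target, the hook list, and the `monitored_session` attribute name are placeholders):

        session_creator = tf.train.ChiefSessionCreator(
            scaffold=self.scaffold,
            master="",   # placeholder: tf.Session target (empty = in-process)
            config=None
        )
        self.monitored_session = tf.train.MonitoredSession(
            session_creator=session_creator,
            hooks=[]     # placeholder: e.g. checkpoint-/summary-saver hooks
        )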
Code Example #5
    def __init__(self, component, num_op_records=None, args=None, kwargs=None):
        """
        Args:
            component (Component): The Component to which this column belongs.
        """
        self.id = self.get_id()

        if num_op_records is None:
            self.op_records = []
            if args is not None:
                args = list(args)
                for i in range(len(args)):
                    if args[i] is None:
                        continue
                    op_rec = DataOpRecord(op=None, column=self, position=i)

                    # Dict instead of a DataOpRecord -> Translate on the fly into a DataOpRec held by a
                    # ContainerMerger Component.
                    if isinstance(args[i], dict):
                        items = args[i].items()
                        keys = [k for k, _ in items]
                        values = [v for _, v in items]
                        if isinstance(values[0], DataOpRecord):
                            merger_component = values[0].column.component.get_helper_component(
                                "container-merger", _args=list(keys)
                            )
                            args[i] = merger_component.merge(*list(values))
                    # Tuple instead of a DataOpRecord -> Translate on the fly into a DataOpRec held by a
                    # ContainerMerger Component.
                    elif isinstance(args[i], tuple) and isinstance(args[i][0], DataOpRecord):
                        merger_component = args[i][0].column.component.get_helper_component(
                            "container-merger", _args=len(args[i])
                        )
                        args[i] = merger_component.merge(*args[i])

                    # If incoming is an op-rec -> Link them.
                    if isinstance(args[i], DataOpRecord):
                        args[i].connect_to(op_rec)
                    # Do constant value assignment here.
                    elif args[i] is not None:
                        op = args[i]
                        if is_constant(op) and not isinstance(op, np.ndarray):
                            op = np.array(op, dtype=convert_dtype(type(op), "np"))
                        op_rec.op = op
                        op_rec.space = get_space_from_op(op)
                        component.constant_op_records.add(op_rec)

                    self.op_records.append(op_rec)

            if kwargs is not None:
                for key in sorted(kwargs.keys()):
                    value = kwargs[key]
                    if value is None:
                        continue
                    op_rec = DataOpRecord(op=None, column=self, kwarg=key)
                    # If incoming is an op-rec -> Link them.
                    if isinstance(value, DataOpRecord):
                        op_rec.previous = value
                        op_rec.op = value.op  # assign op if any
                        value.next.add(op_rec)
                    # Do constant value assignment here (value is guaranteed
                    # non-None at this point; None values were skipped above).
                    else:
                        op = value
                        if is_constant(op):
                            op = np.array(op, dtype=convert_dtype(type(op), "np"))
                        op_rec.op = op
                        op_rec.space = get_space_from_op(op)
                        component.constant_op_records.add(op_rec)
                    self.op_records.append(op_rec)
        else:
            self.op_records = [DataOpRecord(op=None, column=self, position=i) for i in range(num_op_records)]

        # For __str__ purposes.
        self.op_id_list = [o.id for o in self.op_records]

        # The component this column belongs to.
        self.component = component
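
Compared to Example #1, this variant merges container arguments (dicts or tuples of op-records) on the fly into a single op-record via a ContainerMerger helper component, and factors the op-rec linking into `DataOpRecord.connect_to`. A sketch of what `connect_to` plausibly does, inferred from the inline linking code in Example #1 (not necessarily the library's exact implementation):

    def connect_to(self, next_op_rec):
        # Chain the two records and propagate op/space information forward
        # if it is already available (mirrors Example #1's inline linking).
        next_op_rec.previous = self
        if self.op is not None:
            next_op_rec.op = self.op
            next_op_rec.space = get_space_from_op(self.op)
        self.next.add(next_op_rec)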