Example #1
    def _build_graph(self):
        self.data_manager = DataManager(self.env.datasets['train'],
                                        self.env.datasets['val'],
                                        self.env.datasets['test'],
                                        cfg.batch_size)
        self.data_manager.build_graph()

        data = self.data_manager.iterator.get_next()
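        # `data` is a dict mapping field names to batched tensors (cf. the same comment in Example #3).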
        self.inp = data["image"]
        network_outputs = self.network(data, self.data_manager.is_training)

        network_tensors = network_outputs["tensors"]
        network_recorded_tensors = network_outputs["recorded_tensors"]
        network_losses = network_outputs["losses"]

        self.tensors = network_tensors

        self.recorded_tensors = recorded_tensors = dict(global_step=tf.train.get_or_create_global_step())

        # --- loss ---

        self.loss = tf.constant(0., tf.float32)
        for name, tensor in network_losses.items():
            self.loss += tensor
            recorded_tensors['loss_' + name] = tensor
        recorded_tensors['loss'] = self.loss

        # --- train op ---

        if cfg.do_train and not cfg.get('no_gradient', False):
            tvars = self.trainable_variables(for_opt=True)

            self.train_op, self.train_records = build_gradient_train_op(
                self.loss, tvars, self.optimizer_spec, self.lr_schedule,
                self.max_grad_norm, self.noise_schedule, grad_n_record_groups=self.grad_n_record_groups)

        sess = tf.get_default_session()
        # `scheduled_values` is assumed to be attached to the session by the surrounding
        # framework; fall back to an empty dict so the loop is a no-op when it is absent.
        scheduled_values = getattr(sess, 'scheduled_values', None) or {}
        for k, v in scheduled_values.items():
            if k in recorded_tensors:
                recorded_tensors['scheduled_' + k] = v
            else:
                recorded_tensors[k] = v

        # --- recorded values ---

        intersection = recorded_tensors.keys() & network_recorded_tensors.keys()
        assert not intersection, "Key sets have non-zero intersection: {}".format(intersection)
        recorded_tensors.update(network_recorded_tensors)

        intersection = recorded_tensors.keys() & self.network.eval_funcs.keys()
        assert not intersection, "Key sets have non-zero intersection: {}".format(intersection)

        eval_funcs = self.network.eval_funcs or {}

        # For running evaluation functions that are not implemented in TensorFlow.
        self.evaluator = Evaluator(eval_funcs, network_tensors, self)
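The loss bookkeeping in this example is a small reusable pattern: every named loss is added into one scalar for the optimizer and also recorded under a `loss_<name>` key so each component can be logged separately. A minimal pure-Python sketch of the same bookkeeping, with made-up float values standing in for the TensorFlow loss tensors:

network_losses = {"reconstruction": 0.5, "kl": 0.25}  # hypothetical loss components

recorded = {}
total_loss = 0.0
for name, value in network_losses.items():
    total_loss += value                # accumulate the scalar training loss
    recorded["loss_" + name] = value   # keep each component for separate logging
recorded["loss"] = total_loss

print(recorded)  # {'loss_reconstruction': 0.5, 'loss_kl': 0.25, 'loss': 0.75}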
Example #2
    def _build_graph(self):
        self.data_manager = DataManager(self.env.datasets['train'],
                                        self.env.datasets['val'],
                                        self.env.datasets['test'],
                                        cfg.batch_size)
        self.data_manager.build_graph()

        data = self.data_manager.iterator.get_next()
        self.inp = data["image"]
        network_outputs = self.network(data, self.data_manager.is_training)

        network_tensors = network_outputs["tensors"]
        network_recorded_tensors = network_outputs["recorded_tensors"]
        network_losses = network_outputs["losses"]

        self.tensors = network_tensors

        self.recorded_tensors = recorded_tensors = dict(
            global_step=tf.train.get_or_create_global_step())

        # --- loss ---

        self.loss = 0.0
        for name, tensor in network_losses.items():
            self.loss += tensor
            recorded_tensors['loss_' + name] = tensor
        recorded_tensors['loss'] = self.loss

        # --- train op ---

        tvars = self.trainable_variables(for_opt=True)

        self.train_op, self.train_records = build_gradient_train_op(
            self.loss, tvars, self.optimizer_spec, self.lr_schedule,
            self.max_grad_norm, self.noise_schedule)

        # --- recorded values ---

        intersection = recorded_tensors.keys() & network_recorded_tensors.keys()
        assert not intersection, "Key sets have non-zero intersection: {}".format(intersection)
        recorded_tensors.update(network_recorded_tensors)

        intersection = recorded_tensors.keys() & self.network.eval_funcs.keys()
        assert not intersection, "Key sets have non-zero intersection: {}".format(intersection)

        # For running evaluation functions that are not implemented in TensorFlow.
        self.evaluator = Evaluator(self.network.eval_funcs, network_tensors, self)
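`build_gradient_train_op` comes from the surrounding framework and is not shown in these examples. As a rough, hypothetical illustration of what such a helper typically produces (a training op plus a dict of statistics worth logging), here is a minimal TensorFlow 1.x-style sketch that clips gradients by global norm; the real helper also takes the optimizer spec, learning-rate and noise schedules, and (in some variants) a grad_n_record_groups argument, which this sketch ignores:

import tensorflow as tf  # assumes the TensorFlow 1.x API (or tf.compat.v1)


def simple_gradient_train_op(loss, tvars, learning_rate, max_grad_norm):
    """Hypothetical stand-in for `build_gradient_train_op` (illustration only)."""
    optimizer = tf.train.AdamOptimizer(learning_rate)
    grads = tf.gradients(loss, tvars)
    clipped, grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
    global_step = tf.train.get_or_create_global_step()
    train_op = optimizer.apply_gradients(zip(clipped, tvars), global_step=global_step)
    train_records = dict(grad_norm=grad_norm)  # extra values to log with each step
    return train_op, train_records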
Example #3
    def start_stage(self, training_loop, updater, stage_idx):
        # similar to `build_graph`

        self.network = updater.network

        dataset = self.dataset_class(**self.dataset_kwargs)
        self.data_manager = DataManager(val_dataset=dataset,
                                        batch_size=cfg.batch_size)
        self.data_manager.build_graph()

        data = self.data_manager.iterator.get_next()  # a dict mapping from names to tensors
        self.inp = data["image"]
        network_outputs = self.network(data, self.data_manager.is_training)

        network_tensors = network_outputs["tensors"]
        network_recorded_tensors = network_outputs["recorded_tensors"]
        network_losses = network_outputs["losses"]

        self.recorded_tensors = recorded_tensors = dict(
            global_step=tf.train.get_or_create_global_step())

        # --- loss ---

        recorded_tensors['loss'] = 0
        for name, tensor in network_losses.items():
            recorded_tensors['loss'] += tensor
            recorded_tensors['loss_' + name] = tensor
        self.loss = recorded_tensors['loss']

        intersection = recorded_tensors.keys() & network_recorded_tensors.keys()
        assert not intersection, "Key sets have non-zero intersection: {}".format(intersection)
        recorded_tensors.update(network_recorded_tensors)

        intersection = recorded_tensors.keys() & self.network.eval_funcs.keys()
        assert not intersection, "Key sets have non-zero intersection: {}".format(intersection)

        # For running evaluation functions that are not implemented in TensorFlow.
        self.evaluator = Evaluator(self.network.eval_funcs, network_tensors, self)
Example #4
class EvalHook(Hook):
    def __init__(self,
                 dataset_class,
                 plot_step=None,
                 dataset_kwargs=None,
                 **kwargs):
        self.dataset_class = dataset_class
        self.dataset_kwargs = dataset_kwargs or {}
        self.dataset_kwargs['n_examples'] = cfg.n_val
        kwarg_string = "_".join("{}={}".format(k, v)
                                for k, v in self.dataset_kwargs.items())
        name = dataset_class.__name__ + ("_" + kwarg_string if kwarg_string else "")
        self.name = name.replace(" ", "_")
        self.plot_step = plot_step
        super(EvalHook, self).__init__(final=True, **kwargs)

    def start_stage(self, training_loop, updater, stage_idx):
        # similar to `build_graph`

        self.network = updater.network

        dataset = self.dataset_class(**self.dataset_kwargs)
        self.data_manager = DataManager(val_dataset=dataset,
                                        batch_size=cfg.batch_size)
        self.data_manager.build_graph()

        data = self.data_manager.iterator.get_next()  # a dict mapping from names to tensors
        self.inp = data["image"]
        network_outputs = self.network(data, self.data_manager.is_training)

        network_tensors = network_outputs["tensors"]
        network_recorded_tensors = network_outputs["recorded_tensors"]
        network_losses = network_outputs["losses"]

        self.recorded_tensors = recorded_tensors = dict(
            global_step=tf.train.get_or_create_global_step())

        # --- loss ---

        recorded_tensors['loss'] = 0
        for name, tensor in network_losses.items():
            recorded_tensors['loss'] += tensor
            recorded_tensors['loss_' + name] = tensor
        self.loss = recorded_tensors['loss']

        intersection = recorded_tensors.keys() & network_recorded_tensors.keys()
        assert not intersection, "Key sets have non-zero intersection: {}".format(intersection)
        recorded_tensors.update(network_recorded_tensors)

        intersection = recorded_tensors.keys() & self.network.eval_funcs.keys()
        assert not intersection, "Key sets have non-zero intersection: {}".format(intersection)

        # For running evaluation functions that are not implemented in TensorFlow.
        self.evaluator = Evaluator(self.network.eval_funcs, network_tensors, self)

    def step(self, training_loop, updater, step_idx=None):
        feed_dict = self.data_manager.do_val()
        record = collections.defaultdict(float)
        n_points = 0

        sess = tf.get_default_session()

        while True:
            try:
                _record, eval_fetched = sess.run(
                    [self.recorded_tensors, self.evaluator.fetches],
                    feed_dict=feed_dict)
            except tf.errors.OutOfRangeError:
                break

            eval_record = self.evaluator.eval(eval_fetched)
            _record.update(eval_record)

            batch_size = _record['batch_size']

            for k, v in _record.items():
                record[k] += batch_size * v

            n_points += batch_size

        return {self.name: {k: v / n_points for k, v in record.items()}}

    def _plot(self, updater, rollouts):
        plt.ion()

        if updater.dataset.gym_dataset.image_obs:
            obs = rollouts.obs
        else:
            obs = rollouts.image

        fig, axes = square_subplots(rollouts.batch_size, figsize=(5, 5))
        plt.subplots_adjust(top=0.95,
                            bottom=0,
                            left=0,
                            right=1,
                            wspace=0.1,
                            hspace=0.1)

        images = []
        for i, ax in enumerate(axes.flatten()):
            ax.set_aspect("equal")
            ax.set_axis_off()
            image = ax.imshow(np.zeros(obs.shape[2:]))
            images.append(image)

        def animate(t):
            for i in range(rollouts.batch_size):
                images[i].set_array(obs[t, i, :, :, :])

        anim = animation.FuncAnimation(fig,
                                       animate,
                                       frames=len(rollouts),
                                       interval=500)

        path = updater.exp_dir.path_for('plots',
                                        '{}_animation.gif'.format(self.name))
        anim.save(path, writer='imagemagick')

        plt.close(fig)
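The `step` method above computes dataset-wide averages: each batch's recorded values are weighted by that batch's `batch_size`, summed into a running total, and divided by the total number of points once the iterator is exhausted. A pure-Python sketch of that accumulation with made-up batch records:

import collections

# Hypothetical per-batch records, shaped like what sess.run would return.
batches = [
    {"batch_size": 4, "loss": 1.0},
    {"batch_size": 2, "loss": 4.0},
]

record = collections.defaultdict(float)
n_points = 0
for _record in batches:
    batch_size = _record["batch_size"]
    for k, v in _record.items():
        record[k] += batch_size * v   # weight every value by its batch size
    n_points += batch_size

averages = {k: v / n_points for k, v in record.items()}
print(averages["loss"])  # 2.0, i.e. (4 * 1.0 + 2 * 4.0) / 6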
Example #5
class Updater(_Updater):
    optimizer_spec = Param()
    lr_schedule = Param()
    noise_schedule = Param()
    max_grad_norm = Param()
    grad_n_record_groups = Param(None)

    def __init__(self, env, scope=None, **kwargs):
        self.obs_shape = env.obs_shape
        *other, self.image_height, self.image_width, self.image_depth = self.obs_shape
        self.n_frames = other[0] if other else 0
        self.network = cfg.build_network(env, self, scope="network")

        super(Updater, self).__init__(env, scope=scope, **kwargs)

    def trainable_variables(self, for_opt):
        return self.network.trainable_variables(for_opt)

    def _update(self, batch_size):
        if cfg.get('no_gradient', False):
            return dict(train=dict())

        feed_dict = self.data_manager.do_train()

        sess = tf.get_default_session()
        _, record, train_record = sess.run(
            [self.train_op, self.recorded_tensors, self.train_records],
            feed_dict=feed_dict)
        record.update(train_record)

        return dict(train=record)

    def _evaluate(self, _batch_size, mode):
        result = self.evaluator.eval(self.recorded_tensors, self.data_manager,
                                     mode)

        if "MOT:mota" in result and "prior_MOT:mota" in result:
            result['mota_post_prior_sum'] = result["MOT:mota"] + result["prior_MOT:mota"]

        return result

    def _build_graph(self):
        self.data_manager = DataManager(self.env.datasets['train'],
                                        self.env.datasets['val'],
                                        self.env.datasets['test'],
                                        cfg.batch_size)
        self.data_manager.build_graph()

        data = self.data_manager.iterator.get_next()
        self.inp = data["image"]
        network_outputs = self.network(data, self.data_manager.is_training)

        network_tensors = network_outputs["tensors"]
        network_recorded_tensors = network_outputs["recorded_tensors"]
        network_losses = network_outputs["losses"]

        self.tensors = network_tensors

        self.recorded_tensors = recorded_tensors = dict(
            global_step=tf.train.get_or_create_global_step())

        # --- loss ---

        self.loss = tf.constant(0., tf.float32)
        for name, tensor in network_losses.items():
            self.loss += tensor
            recorded_tensors['loss_' + name] = tensor
        recorded_tensors['loss'] = self.loss

        # --- train op ---

        if cfg.do_train and not cfg.get('no_gradient', False):
            tvars = self.trainable_variables(for_opt=True)

            self.train_op, self.train_records = build_gradient_train_op(
                self.loss,
                tvars,
                self.optimizer_spec,
                self.lr_schedule,
                self.max_grad_norm,
                self.noise_schedule,
                grad_n_record_groups=self.grad_n_record_groups)

        sess = tf.get_default_session()
        # `scheduled_values` is assumed to be attached to the session by the surrounding
        # framework; fall back to an empty dict so the loop is a no-op when it is absent.
        scheduled_values = getattr(sess, 'scheduled_values', None) or {}
        for k, v in scheduled_values.items():
            if k in recorded_tensors:
                recorded_tensors['scheduled_' + k] = v
            else:
                recorded_tensors[k] = v

        # --- recorded values ---

        intersection = recorded_tensors.keys() & network_recorded_tensors.keys()
        assert not intersection, "Key sets have non-zero intersection: {}".format(intersection)
        recorded_tensors.update(network_recorded_tensors)

        intersection = recorded_tensors.keys() & self.network.eval_funcs.keys()
        assert not intersection, "Key sets have non-zero intersection: {}".format(intersection)

        eval_funcs = self.network.eval_funcs or {}

        # For running evaluation functions that are not implemented in TensorFlow.
        self.evaluator = Evaluator(eval_funcs, network_tensors, self)
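One detail specific to this variant: scheduled hyperparameter values attached to the session are folded into `recorded_tensors`, and a key that would collide with an existing entry is stored under a `scheduled_` prefix instead of overwriting it. A small pure-Python sketch of that merge rule, with made-up keys and values:

recorded = {"loss": 0.75, "lr": 1e-3}           # hypothetical existing records
scheduled_values = {"lr": 1e-4, "noise": 0.01}  # hypothetical scheduled values

for k, v in scheduled_values.items():
    if k in recorded:
        recorded["scheduled_" + k] = v   # avoid clobbering an existing key
    else:
        recorded[k] = v

print(recorded)
# {'loss': 0.75, 'lr': 0.001, 'scheduled_lr': 0.0001, 'noise': 0.01}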
Example #6
class Updater(_Updater):
    optimizer_spec = Param()
    lr_schedule = Param()
    noise_schedule = Param()
    max_grad_norm = Param()

    def __init__(self, env, scope=None, **kwargs):
        self.obs_shape = env.obs_shape
        *other, self.image_height, self.image_width, self.image_depth = self.obs_shape
        self.n_frames = other[0] if other else 0
        self.network = cfg.build_network(env, self, scope="network")

        super(Updater, self).__init__(env, scope=scope, **kwargs)

    def trainable_variables(self, for_opt):
        return self.network.trainable_variables(for_opt)

    def _update(self, batch_size):
        feed_dict = self.data_manager.do_train()

        sess = tf.get_default_session()
        _, record, train_record = sess.run(
            [self.train_op, self.recorded_tensors, self.train_records],
            feed_dict=feed_dict)
        record.update(train_record)

        return dict(train=record)

    def _evaluate(self, _batch_size, mode):
        if mode == "val":
            feed_dict = self.data_manager.do_val()
        elif mode == "test":
            feed_dict = self.data_manager.do_test()
        else:
            raise Exception("Unknown evaluation mode: {}".format(mode))

        record = collections.defaultdict(float)
        n_points = 0

        sess = tf.get_default_session()

        while True:
            try:
                _record, eval_fetched = sess.run(
                    [self.recorded_tensors, self.evaluator.fetches],
                    feed_dict=feed_dict)
            except tf.errors.OutOfRangeError:
                break

            eval_record = self.evaluator.eval(eval_fetched)
            _record.update(eval_record)

            batch_size = _record['batch_size']

            for k, v in _record.items():
                record[k] += batch_size * v

            n_points += batch_size

        return {k: v / n_points for k, v in record.items()}

    def _build_graph(self):
        self.data_manager = DataManager(self.env.datasets['train'],
                                        self.env.datasets['val'],
                                        self.env.datasets['test'],
                                        cfg.batch_size)
        self.data_manager.build_graph()

        data = self.data_manager.iterator.get_next()
        self.inp = data["image"]
        network_outputs = self.network(data, self.data_manager.is_training)

        network_tensors = network_outputs["tensors"]
        network_recorded_tensors = network_outputs["recorded_tensors"]
        network_losses = network_outputs["losses"]

        self.tensors = network_tensors

        self.recorded_tensors = recorded_tensors = dict(
            global_step=tf.train.get_or_create_global_step())

        # --- loss ---

        self.loss = 0.0
        for name, tensor in network_losses.items():
            self.loss += tensor
            recorded_tensors['loss_' + name] = tensor
        recorded_tensors['loss'] = self.loss

        # --- train op ---

        tvars = self.trainable_variables(for_opt=True)

        self.train_op, self.train_records = build_gradient_train_op(
            self.loss, tvars, self.optimizer_spec, self.lr_schedule,
            self.max_grad_norm, self.noise_schedule)

        # --- recorded values ---

        intersection = recorded_tensors.keys() & network_recorded_tensors.keys()
        assert not intersection, "Key sets have non-zero intersection: {}".format(intersection)
        recorded_tensors.update(network_recorded_tensors)

        intersection = recorded_tensors.keys() & self.network.eval_funcs.keys()
        assert not intersection, "Key sets have non-zero intersection: {}".format(intersection)

        # For running evaluation functions that are not implemented in TensorFlow.
        self.evaluator = Evaluator(self.network.eval_funcs, network_tensors, self)
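A final note on `Updater.__init__` in the last two examples: the starred unpacking `*other, H, W, C = obs_shape` handles both single-image and sequence observation shapes, so `n_frames` comes out as 0 for an (H, W, C) shape and as T for a (T, H, W, C) shape. A small sketch with hypothetical shapes:

def split_obs_shape(obs_shape):
    # Mirrors the unpacking in Updater.__init__: works for (H, W, C) and (T, H, W, C).
    *other, image_height, image_width, image_depth = obs_shape
    n_frames = other[0] if other else 0
    return n_frames, image_height, image_width, image_depth

print(split_obs_shape((64, 64, 3)))     # (0, 64, 64, 3)
print(split_obs_shape((8, 64, 64, 3)))  # (8, 64, 64, 3)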