Exemple #1
0
    def store_scalar_summaries(self, mode, path, record, n_global_experiences):
        """Log every scalar of `record` to the summary writer associated with `mode`.

        The writer for a mode is created on first use as ``SummaryWriter(path)``
        and cached in ``self.writers``. Later calls for the same mode reuse the
        cached writer, so `path` only matters the first time a mode is seen.

        `record` is wrapped in ``AttrDict`` and flattened. Each flattened entry
        is cast to ``float`` and written under the tag ``"all/" + key``, using
        `n_global_experiences` as the step.
        """
        writers = self.writers
        if mode not in writers:
            writers[mode] = SummaryWriter(path)
        writer = writers[mode]

        scalars = AttrDict(record).flatten()
        for key, value in scalars.items():
            writer.add_scalar("all/" + key, float(value), n_global_experiences)
Exemple #2
0
def compute_required_resources(n_tasks, tasks_per_gpu, gpu_kind, host_name=None):
    """Compute a per-cluster resource request for running `n_tasks` GPU tasks.

    The cluster is identified by matching the host name against known
    prefixes ('beluga'/'blg' -> beluga, 'cedar'/'cdr' -> cedar). Its hardware
    spec is then looked up, optionally specialised by `gpu_kind`.

    Args:
        n_tasks: total number of tasks to run.
        tasks_per_gpu: number of tasks sharing one GPU.
        gpu_kind: ``None`` for the cluster's default spec, or a suffix such as
            ``'p100'``, ``'p100l'`` or ``'v100l'`` (looked up as
            ``'<cluster>_<gpu_kind>'``).
        host_name: host name to classify. Defaults to the current machine's
            ``os.uname().nodename``.

    Returns:
        dict with the following keys:

        * ``n_nodes``: number of nodes needed.
        * ``cpus_per_task``: ``cpus_per_gpu // tasks_per_gpu``.
        * ``mem_per_cpu``: ``mem_per_gpu // cpus_per_gpu``. The unit is that of
          the spec table, presumably MB (confirm).
        * ``tasks_per_node``: every GPU of a node is filled when more than one
          node is needed; otherwise all tasks go on the single node.
        * ``gpu_set``: comma-separated GPU indices used per node.

    Raises:
        Exception: if the host name matches no known cluster.
        KeyError: if the cluster has no spec for `gpu_kind`.
    """
    if host_name is None:
        host_name = os.uname().nodename

    # Canonical cluster name -> host-name prefixes identifying it (checked in order).
    cluster_prefixes = {
        'beluga': ('beluga', 'blg'),
        'cedar': ('cedar', 'cdr'),
    }
    for cluster, prefixes in cluster_prefixes.items():
        if host_name.startswith(prefixes):
            break
    else:
        raise Exception(f"Unknown host: {host_name}")

    specs = {
        'beluga': dict(cpus_per_gpu=10, mem_per_gpu=47000, gpus_per_node=4),
        'cedar': dict(cpus_per_gpu=6, mem_per_gpu=128000 // 4, gpus_per_node=4),
        'cedar_p100': dict(cpus_per_gpu=6, mem_per_gpu=128000 // 4, gpus_per_node=4),
        'cedar_p100l': dict(cpus_per_gpu=6, mem_per_gpu=257000 // 4, gpus_per_node=4),
        'cedar_v100l': dict(cpus_per_gpu=8, mem_per_gpu=192000 // 4, gpus_per_node=4),
    }
    spec_key = cluster if gpu_kind is None else cluster + '_' + gpu_kind
    try:
        spec = specs[spec_key]
    except KeyError:
        raise KeyError(
            f"No resource spec for {spec_key!r}; known specs: {sorted(specs)}") from None

    # Plain subscripting: the specs are ordinary dicts, so attribute access is not available.
    cpus_per_gpu = spec['cpus_per_gpu']
    mem_per_gpu = spec['mem_per_gpu']
    gpus_per_node = spec['gpus_per_node']

    # Ceiling division in integer arithmetic (exact, unlike float ceil).
    n_gpus = int(-(-n_tasks // tasks_per_gpu))
    n_nodes = int(-(-n_gpus // gpus_per_node))

    result = dict(
        n_nodes=n_nodes,
        cpus_per_task=cpus_per_gpu // tasks_per_gpu,
        mem_per_cpu=mem_per_gpu // cpus_per_gpu,
    )

    if n_nodes > 1:
        # Multi-node jobs: every node is filled with all of its GPUs.
        tasks_per_node = gpus_per_node * tasks_per_gpu
        gpus_used_per_node = gpus_per_node
    else:
        tasks_per_node = n_tasks
        gpus_used_per_node = n_gpus

    result.update(tasks_per_node=tasks_per_node,
                  gpu_set=",".join(str(i) for i in range(gpus_used_per_node)))

    return result
Exemple #3
0
    def null_object_set(self, batch_size):
        """Return an object set for `batch_size` examples in which every entry is zero.

        Every per-object tensor has leading shape
        ``(batch_size, self.n_prop_objects)``:

        * ``normalized_box``: last dim 4. It is split into ``yt``, ``xt``,
          ``ys`` and ``xs`` (last dim 1 each), and its first two channels give
          ``abs_posn``.
        * ``attr``: last dim ``self.A``.
        * ``z`` and ``obj``: last dim 1.
        * ``ys_logit``, ``xs_logit`` and ``z_logit``: zero copies of ``ys``,
          ``xs`` and ``z``.

        ``prop_state`` is ``self.cell.initial_state`` evaluated for
        ``batch_size * n_prop_objects`` rows and reshaped to be indexed by
        (batch, object). ``prior_prop_state`` refers to the same tensor.
        """
        n_obj = self.n_prop_objects

        def zeros(depth):
            return tf.zeros((batch_size, n_obj, depth))

        box = zeros(4)
        attr = zeros(self.A)
        z = zeros(1)
        obj = zeros(1)
        objects = AttrDict(normalized_box=box, attr=attr, z=z, obj=obj)

        yt, xt, ys, xs = tf.split(box, 4, axis=-1)

        # `+ 0.0` creates a new tensor op instead of aliasing the source tensor.
        objects.update(
            abs_posn=box[..., :2] + 0.0,
            yt=yt,
            xt=xt,
            ys=ys,
            xs=xs,
            ys_logit=ys + 0.0,
            xs_logit=xs + 0.0,
            z_logit=z + 0.0,
        )

        flat_state = self.cell.initial_state(batch_size * n_obj, tf.float32)
        state_tail = tf_shape(flat_state)[1:]
        objects.prop_state = tf.reshape(flat_state, (batch_size, n_obj, *state_tail))
        objects.prior_prop_state = objects.prop_state

        return objects
Exemple #4
0
    def __call__(self, updater):
        """Render fetched model outputs as a static figure and an mp4 animation.

        Two artifacts are produced:

        * ``'static'``: a grid with one row per timestep and two columns per
          example. The first column is the input; the second is the canvas
          (reconstruction) with a box drawn for every object whose presence
          exceeds 0.5. It is written via ``self.savefig``.
        * ``'moving'``: an animation over the T timesteps. Its rows are the
          input, the canvas with boxes, and one glimpse per object slot. It is
          saved with ffmpeg (hevc) to ``self.path_for('moving', ...)`` and then
          copied into the same directory as ``latest_stage<stage_idx>.mp4``
          (stage index zero-padded to 4 digits).

        Non-negative object ids are assigned colours cyclically from
        ``self._BBOX_COLORS``. Id -1 maps to black ('k'); it presumably marks
        an empty slot (confirm).
        """
        fetched = self._fetch(updater)
        fetched = Config(fetched)
        self._prepare_fetched(updater, fetched)
        o = AttrDict(**fetched)

        # o.inp must be 5-D: (N examples, T timesteps, height, width, presumably channels).
        N, T, _, _, _ = o.inp.shape

        # --- static ---

        fig_width = 2 * N
        fig_height = T
        figsize = self.fig_scale * np.asarray((fig_width, fig_height))
        fig, axes = plt.subplots(fig_height, fig_width, figsize=figsize)
        fig.suptitle("n_updates={}".format(updater.n_updates), fontsize=20, fontweight='bold')
        # plt.subplots squeezes singleton dimensions; restore a 2-D grid.
        axes = axes.reshape((fig_height, fig_width))

        # Only non-negative ids get a colour; filtering (rather than dropping
        # element 0) also copes with an empty obj_id array.
        all_ids = (int(i) for i in np.unique(o.obj_id))
        unique_ids = [i for i in all_ids if i >= 0]

        color_by_id = {i: c for i, c in zip(unique_ids, itertools.cycle(self._BBOX_COLORS))}
        color_by_id[-1] = 'k'

        cmap = self._cmap(o.inp)
        for t, ax in enumerate(axes):
            for n in range(N):
                pres_time = o.presence[n, t, :]
                obj_id_time = o.obj_id[n, t, :]
                self.imshow(ax[2 * n], o.inp[n, t], cmap=cmap)

                n_obj = str(int(np.round(pres_time.sum())))
                id_string = ('{}{}'.format(color_by_id[int(i)], i) for i in o.obj_id[n, t] if i > -1)
                id_string = ', '.join(id_string)
                title = '{}: {}'.format(n_obj, id_string)

                ax[2 * n + 1].set_title(title, fontsize=6 * self.fig_scale)
                self.imshow(ax[2 * n + 1], o.canvas[n, t], cmap=cmap)
                for i, (p, o_id) in enumerate(zip(pres_time, obj_id_time)):
                    if p > .5:
                        # Look the colour up only for present objects, as the animation below does.
                        c = color_by_id[int(o_id)]
                        r = patches.Rectangle(
                            (o.left[n, t, i], o.top[n, t, i]), o.width[n, t, i], o.height[n, t, i],
                            linewidth=self.linewidth, edgecolor=c, facecolor='none')
                        ax[2 * n + 1].add_patch(r)

        for n in range(N):
            axes[0, 2 * n].set_ylabel('gt #{:d}'.format(n))
            axes[0, 2 * n + 1].set_ylabel('rec #{:d}'.format(n))

        for ax in axes.flatten():
            ax.set_axis_off()

        self.savefig('static', fig, updater)

        # --- moving ---

        fig_width = 2 * N
        n_objects = o.obj_id.shape[2]
        fig_height = n_objects + 2
        figsize = self.fig_scale * np.asarray((fig_width, fig_height))
        fig, axes = plt.subplots(fig_height, fig_width, figsize=figsize)
        title_text = fig.suptitle('', fontsize=10)
        axes = axes.reshape((fig_height, fig_width))

        def func(t):
            # Draw frame t: row 0 input, row 1 canvas with boxes, rows 2.. per-object glimpses.
            title_text.set_text("t={}, n_updates={}".format(t, updater.n_updates))

            for i in range(N):
                self.imshow(axes[0, 2*i], o.inp[i, t], cmap=cmap, vmin=0, vmax=1)
                self.imshow(axes[1, 2*i], o.canvas[i, t], cmap=cmap, vmin=0, vmax=1)

                for j in range(n_objects):
                    if o.presence[i, t, j] > .5:
                        c = color_by_id[int(o.obj_id[i, t, j])]
                        r = patches.Rectangle(
                            (o.left[i, t, j], o.top[i, t, j]), o.width[i, t, j], o.height[i, t, j],
                            linewidth=self.linewidth, edgecolor=c, facecolor='none')
                        axes[1, 2*i].add_patch(r)

                    ax = axes[2+j, 2*i]

                    self.imshow(ax, o.presence[i, t, j] * o.glimpse[i, t, j], cmap=cmap)
                    # NOTE(review): p(...) shows the example index i + 1, not the slot index j + 1 -- confirm intent.
                    title = '{:d} with p({:d}) = {:.02f}, id = {}'.format(
                        int(o.presence[i, t, j]), i + 1, o.presence_prob[i, t, j], o.obj_id[i, t, j])
                    ax.set_title(title, fontsize=4 * self.fig_scale)

                    if o.presence[i, t, j] > .5:
                        c = color_by_id[int(o.obj_id[i, t, j])]
                        for spine in 'bottom top left right'.split():
                            ax.spines[spine].set_color(c)
                            ax.spines[spine].set_linewidth(2.)

            for ax in axes.flatten():
                ax.xaxis.set_ticks([])
                ax.yaxis.set_ticks([])

            axes[0, 0].set_ylabel('ground-truth')
            axes[1, 0].set_ylabel('reconstruction')

            for j in range(n_objects):
                axes[j+2, 0].set_ylabel('glimpse #{}'.format(j + 1))

        plt.subplots_adjust(left=0.02, right=.98, top=.95, bottom=0.02, wspace=0.1, hspace=0.15)

        anim = animation.FuncAnimation(fig, func, frames=T, interval=500)

        path = self.path_for('moving', updater, ext="mp4")
        try:
            anim.save(path, writer='ffmpeg', codec='hevc', extra_args=['-preset', 'ultrafast'])
        finally:
            # Release the figure even if encoding fails (e.g. ffmpeg unavailable).
            plt.close(fig)

        shutil.copyfile(
            path,
            os.path.join(
                os.path.dirname(path),
                'latest_stage{:0>4}.mp4'.format(updater.stage_idx)))
Exemple #5
0
class SQAIRUpdater(_Updater):
    """Updater training a SQAIR-style network with an importance-weighted objective.

    ``_build_graph`` does the following:

    * builds the data pipeline and the network;
    * forms per-particle log weights of shape ``(batch_size, k_particles)``;
    * minimises the VIMCO target (``targets.vimco``) plus an L2 penalty using
      RMSProp.

    Scalars are reported as importance-weighted means over the particles.
    Selected tensors are additionally resampled to one particle per example
    for rendering and evaluation.
    """

    VI_TARGETS = 'iwae reinforce'.split()
    TARGETS = VI_TARGETS  # all targets are the VI targets

    lr_schedule = Param()  # replaced by a scheduled value in __init__
    l2_schedule = Param()  # replaced by a scheduled value in __init__
    output_std = Param()
    k_particles = Param()  # samples per example; _build_graph requires > 1
    debug = Param()

    def __init__(self, env, scope=None, **kwargs):
        """Turn the lr / l2 schedule params into scheduled values, then defer to `_Updater`."""
        self.lr_schedule = build_scheduled_value(self.lr_schedule, "lr")
        self.l2_schedule = build_scheduled_value(self.l2_schedule, "l2_weight")

        super().__init__(env, scope=scope, **kwargs)

    def resample(self, *args, axis=-1):
        """Resample each argument along `axis`, keeping one particle per example.

        Resampling (via ``self._resample``) only happens when
        ``k_particles > 1``; otherwise the arguments are returned unchanged.
        Returns a single value when one argument is given, else a list.
        """
        res = list(args)

        if self.k_particles > 1:
            for i, arg in enumerate(res):
                res[i] = self._resample(arg, axis)

        if len(res) == 1:
            res = res[0]

        return res

    def _resample(self, arg, axis=-1):
        """Gather, for every example, the particle drawn in ``self.iw_resampling_idx``.

        The index arithmetic assumes that `axis` of `arg` has size
        ``batch_size * k_particles``, with the particles of one example
        contiguous. The result has size ``batch_size`` along `axis`.
        """
        iw_sample_idx = self.iw_resampling_idx + tf.range(self.batch_size) * self.k_particles

        resampled = index.gather_axis(arg, iw_sample_idx, axis)

        # Restore the static shape, which the gather does not preserve.
        shape = arg.shape.as_list()
        shape[axis] = self.batch_size

        resampled.set_shape(shape)

        return resampled

    def _log_resampled(self, name):
        """Record ``<name>_per_sample`` both resampled and as an importance-weighted mean.

        Raises KeyError if ``self.tensors`` has no ``<name>_per_sample`` entry.
        """
        tensor = self.tensors[name + "_per_sample"]
        self.tensors['resampled_' + name] = self._resample(tensor)
        self.recorded_tensors[name] = self._imp_weighted_mean(tensor)
        self.tensors[name] = self.recorded_tensors[name]

    def _imp_weighted_mean(self, tensor):
        """Return the importance-weighted mean of `tensor` over particles and examples.

        `tensor` is reshaped to ``(-1, batch_size, k_particles)`` and averaged
        over the leading axis. The importance weights sum to 1 over particles
        (softmax). Scaling by ``k_particles`` therefore turns the mean over
        particles into a weighted sum, which is then averaged over examples.
        """
        tensor = tf.reshape(tensor, (-1, self.batch_size, self.k_particles))
        tensor = tf.reduce_mean(tensor, 0)
        return tf.reduce_mean(self.importance_weights * tensor * self.k_particles)

    def compute_validation_pixelwise_mean(self, data):
        """Return the per-pixel mean of ``data["image"]`` over the validation split.

        The image tensor is run with the validation feed dict until
        ``OutOfRangeError``, keeping a running mean over axes 0 and 1
        (presumably batch and time). Returns ``None`` if no batch is produced.
        """
        sess = tf.get_default_session()

        mean = None
        n_points = 0
        feed_dict = self.data_manager.do_val()

        while True:
            try:
                inp = sess.run(data["image"], feed_dict=feed_dict)
            except tf.errors.OutOfRangeError:
                break

            n_new = inp.shape[0] * inp.shape[1]
            if mean is None:
                mean = np.mean(inp, axis=(0, 1))
            else:
                mean = mean * (n_points / (n_points + n_new)) + np.sum(inp, axis=(0, 1)) / (n_points + n_new)
            n_points += n_new
        return mean

    def _build_graph(self):
        """Build the data pipeline, network, VIMCO loss, train op and recorded tensors.

        Raises:
            Exception: if ``k_particles <= 1``.
        """
        self.data_manager = DataManager(datasets=self.env.datasets)
        self.data_manager.build_graph()

        if self.k_particles <= 1:
            raise Exception("`k_particles` must be > 1.")

        data = self.data_manager.iterator.get_next()
        # Computed eagerly (numpy) over the validation split and passed to the network.
        data['mean_img'] = self.compute_validation_pixelwise_mean(data)
        self.batch_size = cfg.batch_size

        self.tensors = AttrDict()

        network_outputs = self.network(data, self.data_manager.is_training)

        network_tensors = network_outputs["tensors"]
        network_recorded_tensors = network_outputs["recorded_tensors"]
        network_losses = network_outputs["losses"]
        # The loss is built here, so the network must not contribute its own.
        assert not network_losses

        self.tensors.update(network_tensors)

        self.recorded_tensors = recorded_tensors = dict(global_step=tf.train.get_or_create_global_step())
        self.recorded_tensors.update(network_recorded_tensors)

        # --- values for training ---

        # Sum over the leading (presumably time) axis, then view as (batch_size, k_particles).
        log_weights = tf.reduce_sum(self.tensors.log_weights_per_timestep, 0)
        self.log_weights = tf.reshape(log_weights, (self.batch_size, self.k_particles))

        self.elbo_vae = tf.reduce_mean(self.log_weights)
        self.elbo_iwae_per_example = targets.iwae(self.log_weights)
        self.elbo_iwae = tf.reduce_mean(self.elbo_iwae_per_example)

        self.normalised_elbo_vae = self.elbo_vae / tf.to_float(self.network.dynamic_n_frames)
        self.normalised_elbo_iwae = self.elbo_iwae / tf.to_float(self.network.dynamic_n_frames)

        # Normalised importance weights (no gradient) drive resampling and weighted means.
        self.importance_weights = tf.stop_gradient(tf.nn.softmax(self.log_weights, -1))
        self.ess = ops.ess(self.importance_weights, average=True)
        self.iw_distrib = tf.distributions.Categorical(probs=self.importance_weights)
        self.iw_resampling_idx = self.iw_distrib.sample()

        # --- count accuracy ---

        if "annotations" in data:
            gt_num_steps = tf.transpose(self.tensors.n_valid_annotations, (1, 0))[:, :, None]
            num_steps_per_sample = tf.reshape(
                self.tensors.num_steps_per_sample, (-1, self.batch_size, self.k_particles))
            count_1norm = tf.abs(num_steps_per_sample - tf.cast(gt_num_steps, tf.float32))
            count_error = tf.cast(count_1norm > 0.5, tf.float32)

            self.recorded_tensors.update(
                count_1norm=self._imp_weighted_mean(count_1norm),
                count_error=self._imp_weighted_mean(count_error),
            )

        # --- losses ---

        log_probs = tf.reduce_sum(self.tensors.discrete_log_prob, 0)
        target = targets.vimco(self.log_weights, log_probs, self.elbo_iwae_per_example)

        target /= tf.to_float(self.network.dynamic_n_frames)
        loss_l2 = targets.l2_reg(self.l2_schedule)
        target += loss_l2

        # --- train op ---

        tvars = tf.trainable_variables()
        pure_gradients = tf.gradients(target, tvars)

        clipped_gradients = pure_gradients
        if self.max_grad_norm is not None and self.max_grad_norm > 0.0:
            clipped_gradients, _ = tf.clip_by_global_norm(pure_gradients, self.max_grad_norm)

        grads_and_vars = list(zip(clipped_gradients, tvars))

        lr = self.lr_schedule
        # Fail at run time if the scheduled learning rate leaves the open interval (0, 1).
        valid_lr = tf.Assert(
            tf.logical_and(tf.less(lr, 1.0), tf.less(0.0, lr)),
            [lr], name="valid_learning_rate")

        opt = tf.train.RMSPropOptimizer(lr, momentum=.9)

        with tf.control_dependencies([valid_lr]):
            self.train_op = opt.apply_gradients(grads_and_vars, global_step=None)

        recorded_tensors.update(
            grad_norm_pure=tf.global_norm(pure_gradients),
            grad_norm_processed=tf.global_norm(clipped_gradients),
            grad_lr_norm=lr * tf.global_norm(clipped_gradients),
        )

        # --- record ---

        self.tensors['mse_per_sample'] = tf.reduce_mean(
            (self.tensors['processed_image'] - self.tensors['canvas']) ** 2, (0, 2, 3, 4))
        self.raw_mse = tf.reduce_mean(self.tensors['mse_per_sample'])

        self._log_resampled('mse')
        self._log_resampled('data_ll')
        self._log_resampled('log_p_z')
        self._log_resampled('log_q_z_given_x')
        self._log_resampled('kl')

        # Step counts are only produced by some networks; log whichever exist.
        # A missing `self.tensors[...]` entry raises KeyError, so catch it too.
        for name in ('num_steps', 'num_disc_steps', 'num_prop_steps'):
            try:
                self._log_resampled(name)
            except (KeyError, AttributeError):
                pass

        recorded_tensors.update(
            raw_mse=self.raw_mse,
            elbo_vae=self.elbo_vae,
            elbo_iwae=self.elbo_iwae,
            normalised_elbo_vae=self.normalised_elbo_vae,
            normalised_elbo_iwae=self.normalised_elbo_iwae,
            ess=self.ess,
            loss=target,
            target=target,
            loss_l2=loss_l2,
        )
        self.train_records = {}

        # --- recorded values ---

        intersection = recorded_tensors.keys() & self.network.eval_funcs.keys()
        assert not intersection, "Key sets have non-zero intersection: {}".format(intersection)

        # --- for rendering and eval ---
        resampled_names = (
            'obj_id canvas glimpse presence_prob presence presence_logit '
            'where_coords num_steps_per_sample'.split())

        for name in resampled_names:
            try:
                # Resample along axis 1, then swap the first two axes.
                resampled_tensor = self.resample(self.tensors[name], axis=1)
                permutation = [1, 0] + list(range(2, len(resampled_tensor.shape)))
                self.tensors['resampled_' + name] = tf.transpose(resampled_tensor, permutation)
            except (KeyError, AttributeError):
                # Tensor not produced by this network.
                pass

        # For running functions, during evaluation, that are not implemented in tensorflow
        self.evaluator = Evaluator(self.network.eval_funcs, self.tensors, self)
Exemple #6
0
    def _build_graph(self):
        """Build the data pipeline, network, VIMCO loss, train op and recorded tensors.

        Raises:
            Exception: if ``k_particles <= 1``.
        """
        self.data_manager = DataManager(datasets=self.env.datasets)
        self.data_manager.build_graph()

        if self.k_particles <= 1:
            raise Exception("`k_particles` must be > 1.")

        data = self.data_manager.iterator.get_next()
        # Computed eagerly over the validation split and passed to the network.
        data['mean_img'] = self.compute_validation_pixelwise_mean(data)
        self.batch_size = cfg.batch_size

        self.tensors = AttrDict()

        network_outputs = self.network(data, self.data_manager.is_training)

        network_tensors = network_outputs["tensors"]
        network_recorded_tensors = network_outputs["recorded_tensors"]
        network_losses = network_outputs["losses"]
        # The loss is built here, so the network must not contribute its own.
        assert not network_losses

        self.tensors.update(network_tensors)

        self.recorded_tensors = recorded_tensors = dict(global_step=tf.train.get_or_create_global_step())
        self.recorded_tensors.update(network_recorded_tensors)

        # --- values for training ---

        # Sum over the leading (presumably time) axis, then view as (batch_size, k_particles).
        log_weights = tf.reduce_sum(self.tensors.log_weights_per_timestep, 0)
        self.log_weights = tf.reshape(log_weights, (self.batch_size, self.k_particles))

        self.elbo_vae = tf.reduce_mean(self.log_weights)
        self.elbo_iwae_per_example = targets.iwae(self.log_weights)
        self.elbo_iwae = tf.reduce_mean(self.elbo_iwae_per_example)

        self.normalised_elbo_vae = self.elbo_vae / tf.to_float(self.network.dynamic_n_frames)
        self.normalised_elbo_iwae = self.elbo_iwae / tf.to_float(self.network.dynamic_n_frames)

        # Normalised importance weights (no gradient) drive resampling and weighted means.
        self.importance_weights = tf.stop_gradient(tf.nn.softmax(self.log_weights, -1))
        self.ess = ops.ess(self.importance_weights, average=True)
        self.iw_distrib = tf.distributions.Categorical(probs=self.importance_weights)
        self.iw_resampling_idx = self.iw_distrib.sample()

        # --- count accuracy ---

        if "annotations" in data:
            gt_num_steps = tf.transpose(self.tensors.n_valid_annotations, (1, 0))[:, :, None]
            num_steps_per_sample = tf.reshape(
                self.tensors.num_steps_per_sample, (-1, self.batch_size, self.k_particles))
            count_1norm = tf.abs(num_steps_per_sample - tf.cast(gt_num_steps, tf.float32))
            count_error = tf.cast(count_1norm > 0.5, tf.float32)

            self.recorded_tensors.update(
                count_1norm=self._imp_weighted_mean(count_1norm),
                count_error=self._imp_weighted_mean(count_error),
            )

        # --- losses ---

        log_probs = tf.reduce_sum(self.tensors.discrete_log_prob, 0)
        target = targets.vimco(self.log_weights, log_probs, self.elbo_iwae_per_example)

        target /= tf.to_float(self.network.dynamic_n_frames)
        loss_l2 = targets.l2_reg(self.l2_schedule)
        target += loss_l2

        # --- train op ---

        tvars = tf.trainable_variables()
        pure_gradients = tf.gradients(target, tvars)

        clipped_gradients = pure_gradients
        if self.max_grad_norm is not None and self.max_grad_norm > 0.0:
            clipped_gradients, _ = tf.clip_by_global_norm(pure_gradients, self.max_grad_norm)

        grads_and_vars = list(zip(clipped_gradients, tvars))

        lr = self.lr_schedule
        # Fail at run time if the scheduled learning rate leaves the open interval (0, 1).
        valid_lr = tf.Assert(
            tf.logical_and(tf.less(lr, 1.0), tf.less(0.0, lr)),
            [lr], name="valid_learning_rate")

        opt = tf.train.RMSPropOptimizer(lr, momentum=.9)

        with tf.control_dependencies([valid_lr]):
            self.train_op = opt.apply_gradients(grads_and_vars, global_step=None)

        recorded_tensors.update(
            grad_norm_pure=tf.global_norm(pure_gradients),
            grad_norm_processed=tf.global_norm(clipped_gradients),
            grad_lr_norm=lr * tf.global_norm(clipped_gradients),
        )

        # --- record ---

        self.tensors['mse_per_sample'] = tf.reduce_mean(
            (self.tensors['processed_image'] - self.tensors['canvas']) ** 2, (0, 2, 3, 4))
        self.raw_mse = tf.reduce_mean(self.tensors['mse_per_sample'])

        self._log_resampled('mse')
        self._log_resampled('data_ll')
        self._log_resampled('log_p_z')
        self._log_resampled('log_q_z_given_x')
        self._log_resampled('kl')

        # Step counts are only produced by some networks; log whichever exist.
        # A missing `self.tensors[...]` entry raises KeyError, so catch it too.
        for name in ('num_steps', 'num_disc_steps', 'num_prop_steps'):
            try:
                self._log_resampled(name)
            except (KeyError, AttributeError):
                pass

        recorded_tensors.update(
            raw_mse=self.raw_mse,
            elbo_vae=self.elbo_vae,
            elbo_iwae=self.elbo_iwae,
            normalised_elbo_vae=self.normalised_elbo_vae,
            normalised_elbo_iwae=self.normalised_elbo_iwae,
            ess=self.ess,
            loss=target,
            target=target,
            loss_l2=loss_l2,
        )
        self.train_records = {}

        # --- recorded values ---

        intersection = recorded_tensors.keys() & self.network.eval_funcs.keys()
        assert not intersection, "Key sets have non-zero intersection: {}".format(intersection)

        # --- for rendering and eval ---
        resampled_names = (
            'obj_id canvas glimpse presence_prob presence presence_logit '
            'where_coords num_steps_per_sample'.split())

        for name in resampled_names:
            try:
                # Resample along axis 1, then swap the first two axes.
                resampled_tensor = self.resample(self.tensors[name], axis=1)
                permutation = [1, 0] + list(range(2, len(resampled_tensor.shape)))
                self.tensors['resampled_' + name] = tf.transpose(resampled_tensor, permutation)
            except (KeyError, AttributeError):
                # Tensor not produced by this network.
                pass

        # For running functions, during evaluation, that are not implemented in tensorflow
        self.evaluator = Evaluator(self.network.eval_funcs, self.tensors, self)
Exemple #7
0
    def get_eval_tensors(self, step, mode="val", data_exclude=None, tensors_exclude=None):
        """ Run `self.model` on either val or test dataset, return data and tensors.

        Args:
            step: passed through to ``self.model(data, step)``.
            mode: ``'val'`` or ``'test'``; selects ``self.data_manager.do_val()``
                or ``do_test()``.
            data_exclude: keys to drop from each batch's data before storing it.
                Accepts None, a whitespace-separated string, or a list. Missing
                keys are ignored.
            tensors_exclude: same as `data_exclude`, applied to the model's
                output tensors.

        Returns:
            ``(data, tensors, record)``:

            * ``data`` and ``tensors``: per-batch values with torch tensors
              converted to numpy and concatenated along axis 0 by
              ``pad_and_concatenate``.
            * ``record``: an AttrDict of recorded scalars, including ``'loss'``
              and ``'loss_<name>'``. Each value is averaged over all evaluated
              points, with every batch weighted by its
              ``recorded_tensors['batch_size']``.

        Raises:
            ValueError: if the selected split yields no batches.
        """

        assert mode in "val test".split()

        def as_key_list(keys):
            # Normalise None / whitespace-separated string / list into a list of keys.
            if keys is None:
                return []
            if isinstance(keys, str):
                return keys.split()
            return keys

        tensors_exclude = as_key_list(tensors_exclude)
        data_exclude = as_key_list(data_exclude)

        def is_leaf(rec):
            return not isinstance(rec, dict)

        def strip_and_to_numpy(structure, exclude):
            # Drop excluded keys (ignoring missing ones), then convert torch tensors to numpy.
            structure = AttrDict(structure)
            for key in exclude:
                try:
                    del structure[key]
                except (KeyError, AttributeError):
                    pass
            return map_structure(
                lambda t: to_np(t) if isinstance(t, torch.Tensor) else t,
                structure, is_leaf=is_leaf)

        self.model.eval()
        if mode == 'val':
            data_iterator = self.data_manager.do_val()
        elif mode == 'test':
            data_iterator = self.data_manager.do_test()
        else:
            raise Exception("Unknown data mode: {}".format(mode))

        _tensors = []
        _data = []

        n_points = 0
        record = None

        with torch.no_grad():
            for data in data_iterator:
                data = AttrDict(data)

                tensors, data, recorded_tensors, losses = self.model(data, step)

                losses = AttrDict(losses)

                loss = 0.0
                for name, tensor in losses.flatten().items():
                    loss += tensor
                    recorded_tensors['loss_' + name] = tensor
                recorded_tensors['loss'] = loss

                # Reduce every recorded tensor/array to its per-batch mean.
                recorded_tensors = map_structure(
                    lambda t: to_np(t.mean()) if isinstance(t, (torch.Tensor, np.ndarray)) else t,
                    recorded_tensors, is_leaf=is_leaf)

                batch_size = recorded_tensors['batch_size']
                n_points += batch_size

                # Accumulate batch_size-weighted sums for every batch, including
                # the first, so dividing by n_points below gives a per-point average.
                weighted = map_structure(
                    lambda rec_t: batch_size * np.mean(rec_t), recorded_tensors, is_leaf=is_leaf)
                if record is None:
                    record = weighted
                else:
                    record = map_structure(
                        lambda rec, rec_w: rec + rec_w, record, weighted, is_leaf=is_leaf)

                _data.append(strip_and_to_numpy(data, data_exclude))
                _tensors.append(strip_and_to_numpy(tensors, tensors_exclude))

        if record is None:
            raise ValueError("No batches were produced for mode '{}'.".format(mode))

        def postprocess(*t):
            return pad_and_concatenate(t, axis=0)

        _tensors = map_structure(postprocess, *_tensors, is_leaf=is_leaf)
        _data = map_structure(postprocess, *_data, is_leaf=is_leaf)

        record = map_structure(lambda rec: rec / n_points, record, is_leaf=is_leaf)
        record = AttrDict(record)

        return _data, _tensors, record
Exemple #8
0
    def update(self, batch_size, step):
        """Perform one training step on the next batch from ``self.train_iterator``.

        The step runs, in order:

        1. A forward pass through ``self.model``. Every ``pytorch_profile_step``
           steps (cfg) the pass is profiled and the profile is printed.
        2. Summing all flattened losses into ``'loss'`` and recording each one
           as ``'loss_<name>'``.
        3. Resetting gradients to None and running backward. Anomaly detection
           is enabled by cfg ``detect_grad_anomalies``.
        4. Grad-norm bookkeeping and optional clipping to ``self.max_grad_norm``.
        5. The optimizer step, the scheduler step (if any), and ``self._update``.

        Stage timings are printed every 100 steps.

        Returns:
            The recorded values, with torch tensors reduced by ``.mean()``.
            Includes the losses, ``grad_norm_pure``, ``grad_norm_clipped``, the
            model's scheduled values, and any dict returned by ``self._update``.
        """
        verbose = step % 100 == 0

        self.model.train()
        batch = AttrDict(next(self.train_iterator))
        self.model.update_global_step(step)

        anomaly_detection = cfg.get('detect_grad_anomalies', False)
        with torch.autograd.set_detect_anomaly(anomaly_detection):
            profile_every = cfg.get('pytorch_profile_step', 0)
            should_profile = profile_every > 0 and step % profile_every == 0

            if should_profile:
                with torch.autograd.profiler.profile(use_cuda=True) as profiler:
                    tensors, batch, recorded, losses = self.model(batch, step)
                print(profiler)
            else:
                with timed_block('forward', verbose):
                    tensors, batch, recorded, losses = self.model(batch, step)

            # --- loss ---

            total_loss = 0.0
            for loss_name, loss_value in AttrDict(losses).flatten().items():
                total_loss += loss_value
                recorded['loss_' + loss_name] = loss_value
            recorded['loss'] = total_loss

            with timed_block('zero_grad', verbose):
                # Setting .grad to None is reportedly cheaper than optimizer.zero_grad()
                # (https://www.youtube.com/watch?v=9mS1fIYj1So, 10:37).
                for param in self.model.parameters():
                    param.grad = None

            with timed_block('loss backward', verbose):
                total_loss.backward()

        with timed_block('process grad', verbose):
            recorder = self.grad_norm_recorder
            if recorder is not None:
                recorder.update()
                if step % self.print_grad_norm_step == 0:
                    recorder.display()

            params = list(self.model.parameters())
            norm_before_clip = grad_norm(params)

            max_norm = self.max_grad_norm
            if max_norm is not None and max_norm > 0.0:
                torch.nn.utils.clip_grad_norm_(params, max_norm)

            norm_after_clip = grad_norm(params)

        with timed_block('optimizer step', verbose):
            self.optimizer.step()

        if self.scheduler is not None:
            self.scheduler.step()

        extra_records = self._update(batch_size)
        if isinstance(extra_records, dict):
            recorded.update(extra_records)

        self._n_experiences += batch_size

        recorded.update(grad_norm_pure=norm_before_clip, grad_norm_clipped=norm_after_clip)
        recorded.update(self.model.get_scheduled_values())

        return map_structure(
            lambda t: t.mean() if isinstance(t, torch.Tensor) else t,
            recorded, is_leaf=lambda rec: not isinstance(rec, dict))
Exemple #9
0
    def _body(self, inp, features, objects, is_posterior):
        """
        Summary of how updates are done for the different variables:

        glimpse': glimpse_params = where + 0.1 * predicted_logit

        where_y/x: new_where_y/x = where_y/x + where_t_scale * tanh(predicted_logit)
        where_h/w: new_where_h/w_logit = where_h/w_logit + predicted_logit
        what: Hard to summarize here, taken from SQAIR. Kind of like an LSTM.
        depth: new_depth_logit = depth_logit + predicted_logit
        obj: new_obj = obj * sigmoid(predicted_logit)

        """
        batch_size, n_objects, _ = tf_shape(features)

        new_objects = AttrDict()

        is_posterior_tf = tf.ones_like(features[..., 0:2])
        if is_posterior:
            is_posterior_tf = is_posterior_tf * [1, 0]
        else:
            is_posterior_tf = is_posterior_tf * [0, 1]

        base_features = tf.concat([features, is_posterior_tf], axis=-1)

        cyt, cxt, ys, xs = tf.split(objects.normalized_box, 4, axis=-1)

        # Do this regardless of is_posterior, otherwise ScopedFunction gets messed up
        glimpse_dim = self.object_shape[0] * self.object_shape[1]
        glimpse_prime_params = apply_object_wise(self.glimpse_prime_network,
                                                 base_features,
                                                 output_size=4 +
                                                 2 * glimpse_dim,
                                                 is_training=self.is_training)

        glimpse_prime_params, glimpse_prime_mask_logit, glimpse_mask_logit = \
            tf.split(glimpse_prime_params, [4, glimpse_dim, glimpse_dim], axis=-1)

        if is_posterior:
            # --- obtain final parameters for glimpse prime by modifying current pose ---

            _yt, _xt, _ys, _xs = tf.split(glimpse_prime_params, 4, axis=-1)

            g_yt = cyt + 0.1 * _yt
            g_xt = cxt + 0.1 * _xt
            g_ys = ys + 0.1 * _ys
            g_xs = xs + 0.1 * _xs

            # --- extract glimpse prime ---

            _, image_height, image_width, _ = tf_shape(inp)
            g_yt, g_xt, g_ys, g_xs = coords_to_image_space(
                g_yt,
                g_xt,
                g_ys,
                g_xs, (image_height, image_width),
                self.anchor_box,
                top_left=False)
            glimpse_prime = extract_affine_glimpse(inp, self.object_shape,
                                                   g_yt, g_xt, g_ys, g_xs,
                                                   self.edge_resampler)
        else:
            g_yt = tf.zeros_like(cyt)
            g_xt = tf.zeros_like(cxt)
            g_ys = tf.zeros_like(ys)
            g_xs = tf.zeros_like(xs)
            glimpse_prime = tf.zeros(
                (batch_size, n_objects, *self.object_shape, self.image_depth))

        glimpse_prime_mask = tf.nn.sigmoid(glimpse_prime_mask_logit + 1.)
        leading_mask_shape = tf_shape(glimpse_prime)[:-1]
        glimpse_prime_mask = tf.reshape(glimpse_prime_mask,
                                        (*leading_mask_shape, 1))

        new_objects.update(
            glimpse_prime_box=tf.concat([g_yt, g_xt, g_ys, g_xs], axis=-1),
            glimpse_prime=glimpse_prime,
            glimpse_prime_mask=glimpse_prime_mask,
        )

        glimpse_prime *= glimpse_prime_mask

        # --- encode glimpse ---

        encoded_glimpse_prime = apply_object_wise(self.glimpse_prime_encoder,
                                                  glimpse_prime,
                                                  n_trailing_dims=3,
                                                  output_size=self.A,
                                                  is_training=self.is_training)

        if not is_posterior:
            encoded_glimpse_prime = tf.zeros((batch_size, n_objects, self.A),
                                             dtype=tf.float32)

        # --- position and scale ---

        # roughly:
        # base_features == temporal_state, encoded_glimpse_prime == hidden_output
        # hidden_output conditions on encoded_glimpse, and that's the only place encoded_glimpse_prime is used.

        # Here SQAIR conditions on the actual location values from the previous timestep, but we leave that out for now.
        d_box_inp = tf.concat([base_features, encoded_glimpse_prime], axis=-1)
        d_box_params = apply_object_wise(self.d_box_network,
                                         d_box_inp,
                                         output_size=8,
                                         is_training=self.is_training)

        d_box_mean, d_box_log_std = tf.split(d_box_params, 2, axis=-1)

        d_box_std = self.std_nonlinearity(d_box_log_std)

        d_box_mean = self.training_wheels * tf.stop_gradient(d_box_mean) + (
            1 - self.training_wheels) * d_box_mean
        d_box_std = self.training_wheels * tf.stop_gradient(d_box_std) + (
            1 - self.training_wheels) * d_box_std

        d_yt_mean, d_xt_mean, d_ys, d_xs = tf.split(d_box_mean, 4, axis=-1)
        d_yt_std, d_xt_std, ys_std, xs_std = tf.split(d_box_std, 4, axis=-1)

        # --- position ---

        # We predict position a bit differently from scale. For scale we want to put a prior on the actual value of
        # the scale, whereas for position we want to put a prior on the difference in position over timesteps.

        d_yt_logit = Normal(loc=d_yt_mean, scale=d_yt_std).sample()
        d_xt_logit = Normal(loc=d_xt_mean, scale=d_xt_std).sample()

        d_yt = self.where_t_scale * tf.nn.tanh(d_yt_logit)
        d_xt = self.where_t_scale * tf.nn.tanh(d_xt_logit)

        new_cyt = cyt + d_yt
        new_cxt = cxt + d_xt

        new_abs_posn = objects.abs_posn + tf.concat([d_yt, d_xt], axis=-1)

        # --- scale ---

        new_ys_mean = objects.ys_logit + d_ys
        new_xs_mean = objects.xs_logit + d_xs

        new_ys_logit = Normal(loc=new_ys_mean, scale=ys_std).sample()
        new_xs_logit = Normal(loc=new_xs_mean, scale=xs_std).sample()

        new_ys = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(
            tf.clip_by_value(new_ys_logit, -10, 10)) + self.min_hw
        new_xs = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(
            tf.clip_by_value(new_xs_logit, -10, 10)) + self.min_hw

        if self.use_abs_posn:
            box_params = tf.concat([
                new_abs_posn, d_yt_logit, d_xt_logit, new_ys_logit,
                new_xs_logit
            ],
                                   axis=-1)
        else:
            box_params = tf.concat(
                [d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit], axis=-1)

        new_objects.update(
            abs_posn=new_abs_posn,
            yt=new_cyt,
            xt=new_cxt,
            ys=new_ys,
            xs=new_xs,
            normalized_box=tf.concat([new_cyt, new_cxt, new_ys, new_xs],
                                     axis=-1),
            d_yt_logit=d_yt_logit,
            d_xt_logit=d_xt_logit,
            ys_logit=new_ys_logit,
            xs_logit=new_xs_logit,
            d_yt_logit_mean=d_yt_mean,
            d_xt_logit_mean=d_xt_mean,
            ys_logit_mean=new_ys_mean,
            xs_logit_mean=new_xs_mean,
            d_yt_logit_std=d_yt_std,
            d_xt_logit_std=d_xt_std,
            ys_logit_std=ys_std,
            xs_logit_std=xs_std,
        )

        # --- attributes ---

        # --- extract a glimpse using new box ---

        if is_posterior:
            _, image_height, image_width, _ = tf_shape(inp)
            _new_cyt, _new_cxt, _new_ys, _new_xs = coords_to_image_space(
                new_cyt,
                new_cxt,
                new_ys,
                new_xs, (image_height, image_width),
                self.anchor_box,
                top_left=False)

            glimpse = extract_affine_glimpse(inp, self.object_shape, _new_cyt,
                                             _new_cxt, _new_ys, _new_xs,
                                             self.edge_resampler)

        else:
            glimpse = tf.zeros(
                (batch_size, n_objects, *self.object_shape, self.image_depth))

        glimpse_mask = tf.nn.sigmoid(glimpse_mask_logit + 1.)
        leading_mask_shape = tf_shape(glimpse)[:-1]
        glimpse_mask = tf.reshape(glimpse_mask, (*leading_mask_shape, 1))

        glimpse *= glimpse_mask

        encoded_glimpse = apply_object_wise(self.glimpse_encoder,
                                            glimpse,
                                            n_trailing_dims=3,
                                            output_size=self.A,
                                            is_training=self.is_training)

        if not is_posterior:
            encoded_glimpse = tf.zeros((batch_size, n_objects, self.A),
                                       dtype=tf.float32)

        # --- predict change in attributes ---

        # so under sqair we mix between three different values for the attributes:
        # 1. value from previous timestep
        # 2. value predicted directly from glimpse
        # 3. value predicted based on update of temporal cell...this update conditions on hidden_output,
        # the prediction in #2., and the where values.

        # How to do this given that we are predicting the change in attr? We could just directly predict
        # the attr instead, but call it d_attr. After all, it is in this function that we control
        # whether d_attr is added to attr.

        # So, make a prediction based on just the input:

        attr_from_inp = apply_object_wise(self.predict_attr_inp,
                                          encoded_glimpse,
                                          output_size=2 * self.A,
                                          is_training=self.is_training)
        attr_from_inp_mean, attr_from_inp_log_std = tf.split(attr_from_inp,
                                                             [self.A, self.A],
                                                             axis=-1)

        attr_from_inp_std = self.std_nonlinearity(attr_from_inp_log_std)

        # And then a prediction which takes the past into account (predicting gate values at the same time):

        attr_from_temp_inp = tf.concat(
            [base_features, box_params, encoded_glimpse], axis=-1)
        attr_from_temp = apply_object_wise(self.predict_attr_temp,
                                           attr_from_temp_inp,
                                           output_size=5 * self.A,
                                           is_training=self.is_training)

        (attr_from_temp_mean, attr_from_temp_log_std, f_gate_logit,
         i_gate_logit, t_gate_logit) = tf.split(attr_from_temp, 5, axis=-1)

        attr_from_temp_std = self.std_nonlinearity(attr_from_temp_log_std)

        # bias the gates
        f_gate = tf.nn.sigmoid(f_gate_logit + 1) * .9999
        i_gate = tf.nn.sigmoid(i_gate_logit + 1) * .9999
        t_gate = tf.nn.sigmoid(t_gate_logit + 1) * .9999

        attr_mean = f_gate * objects.attr + (
            1 - i_gate) * attr_from_inp_mean + (1 -
                                                t_gate) * attr_from_temp_mean
        attr_std = (1 - i_gate) * attr_from_inp_std + (
            1 - t_gate) * attr_from_temp_std

        new_attr = Normal(loc=attr_mean, scale=attr_std).sample()

        # --- apply change in attributes ---

        new_objects.update(
            attr=new_attr,
            d_attr=new_attr - objects.attr,
            d_attr_mean=attr_mean - objects.attr,
            d_attr_std=attr_std,
            f_gate=f_gate,
            i_gate=i_gate,
            t_gate=t_gate,
            glimpse=glimpse,
            glimpse_mask=glimpse_mask,
        )

        # --- z ---

        d_z_inp = tf.concat(
            [base_features, box_params, new_attr, encoded_glimpse], axis=-1)
        d_z_params = apply_object_wise(self.d_z_network,
                                       d_z_inp,
                                       output_size=2,
                                       is_training=self.is_training)

        d_z_mean, d_z_log_std = tf.split(d_z_params, 2, axis=-1)
        d_z_std = self.std_nonlinearity(d_z_log_std)

        d_z_mean = self.training_wheels * tf.stop_gradient(d_z_mean) + (
            1 - self.training_wheels) * d_z_mean
        d_z_std = self.training_wheels * tf.stop_gradient(d_z_std) + (
            1 - self.training_wheels) * d_z_std

        d_z_logit = Normal(loc=d_z_mean, scale=d_z_std).sample()

        new_z_logit = objects.z_logit + d_z_logit
        new_z = self.z_nonlinearity(new_z_logit)

        new_objects.update(
            z=new_z,
            z_logit=new_z_logit,
            d_z_logit=d_z_logit,
            d_z_logit_mean=d_z_mean,
            d_z_logit_std=d_z_std,
        )

        # --- obj ---

        d_obj_inp = tf.concat(
            [base_features, box_params, new_attr, new_z, encoded_glimpse],
            axis=-1)
        d_obj_logit = apply_object_wise(self.d_obj_network,
                                        d_obj_inp,
                                        output_size=1,
                                        is_training=self.is_training)

        d_obj_logit = self.training_wheels * tf.stop_gradient(d_obj_logit) + (
            1 - self.training_wheels) * d_obj_logit
        d_obj_log_odds = tf.clip_by_value(d_obj_logit / self.obj_temp, -10.,
                                          10.)

        d_obj_pre_sigmoid = (self._noisy * concrete_binary_pre_sigmoid_sample(
            d_obj_log_odds, self.obj_concrete_temp) +
                             (1 - self._noisy) * d_obj_log_odds)

        d_obj = tf.nn.sigmoid(d_obj_pre_sigmoid)

        new_obj = objects.obj * d_obj

        new_objects.update(
            d_obj_log_odds=d_obj_log_odds,
            d_obj_prob=tf.nn.sigmoid(d_obj_log_odds),
            d_obj_pre_sigmoid=d_obj_pre_sigmoid,
            d_obj=d_obj,
            obj=new_obj,
        )

        # --- update each object's hidden state --

        cell_input = tf.concat([box_params, new_attr, new_z, new_obj], axis=-1)

        if is_posterior:
            _, new_objects.prop_state = apply_object_wise(
                self.cell, cell_input, objects.prop_state)
            new_objects.prior_prop_state = new_objects.prop_state
        else:
            _, new_objects.prior_prop_state = apply_object_wise(
                self.cell, cell_input, objects.prior_prop_state)
            new_objects.prop_state = new_objects.prior_prop_state

        return new_objects
Exemple #10
0
    def _body(self, inp, features, objects, is_posterior):
        """Propagate every object one step, sampling updated latents.

        Predicts and samples, per object, a change in box position, a new box
        scale, new attributes (`attr`), a new depth logit (`z_logit`) and a
        multiplicative presence update (`obj`), then advances each object's
        recurrent state with `self.cell`.

        Parameters
        ----------
        inp: Tensor
            4-D tensor; dims 1 and 2 are read as image height and width when
            glimpses are extracted (posterior pass only).
        features: Tensor
            Per-object features, unpacked as (batch_size, n_objects, n_features)
            by `tf_shape`.
        objects: AttrDict
            Previous-step object state. Fields read here: normalized_box,
            abs_posn, ys_logit, xs_logit, attr, z_logit, obj, and prop_state
            (posterior pass) or prior_prop_state (prior pass).
        is_posterior: bool
            Used in Python `if` statements, so expected to be a Python bool.
            If True, glimpses are extracted from `inp`; otherwise glimpses and
            their encodings are replaced with zeros, so the networks only see
            `features` and `objects`.

        Returns
        -------
        new_objects: AttrDict
            The updated object state together with the distribution parameters
            (means, stds, log-odds, gates) used to sample it.
        """
        batch_size, n_objects, _ = tf_shape(features)

        new_objects = AttrDict()

        # Two-element flag appended to every object's features:
        # [1, 0] on the posterior pass, [0, 1] on the prior pass.
        is_posterior_tf = tf.ones_like(features[..., 0:2])
        if is_posterior:
            is_posterior_tf = is_posterior_tf * [1, 0]
        else:
            is_posterior_tf = is_posterior_tf * [0, 1]

        base_features = tf.concat([features, is_posterior_tf], axis=-1)

        # Previous box as four columns. It is passed to coords_to_image_space
        # with top_left=False below, so cyt/cxt are presumably box centres.
        cyt, cxt, ys, xs = tf.split(objects.normalized_box, 4, axis=-1)

        if self.learn_glimpse_prime:
            # Do this regardless of is_posterior, otherwise ScopedFunction gets messed up
            glimpse_prime_params = apply_object_wise(
                self.glimpse_prime_network,
                base_features,
                output_size=4,
                is_training=self.is_training)
        else:
            # Placeholder only: nothing below reads it when learn_glimpse_prime is False.
            glimpse_prime_params = tf.zeros_like(base_features[..., :4])

        if is_posterior:

            if self.learn_glimpse_prime:
                # --- obtain final parameters for glimpse prime by modifying current pose ---
                _yt, _xt, _ys, _xs = tf.split(glimpse_prime_params, 4, axis=-1)

                # This is how it is done in SQAIR
                g_yt = cyt + 0.1 * _yt
                g_xt = cxt + 0.1 * _xt
                g_ys = ys + 0.1 * _ys
                g_xs = xs + 0.1 * _xs

                # g_yt = cyt + self.glimpse_prime_scale * tf.nn.tanh(_yt)
                # g_xt = cxt + self.glimpse_prime_scale * tf.nn.tanh(_xt)
                # g_ys = ys + self.glimpse_prime_scale * tf.nn.tanh(_ys)
                # g_xs = xs + self.glimpse_prime_scale * tf.nn.tanh(_xs)
            else:
                # Fixed glimpse prime: same position as the previous box,
                # with its scale multiplied by glimpse_prime_scale.
                g_yt = cyt
                g_xt = cxt
                g_ys = self.glimpse_prime_scale * ys
                g_xs = self.glimpse_prime_scale * xs

            # --- extract glimpse prime ---

            # Map the box through coords_to_image_space (using the image
            # height/width of inp) before extracting the glimpse.
            _, image_height, image_width, _ = tf_shape(inp)
            g_yt, g_xt, g_ys, g_xs = coords_to_image_space(
                g_yt,
                g_xt,
                g_ys,
                g_xs, (image_height, image_width),
                self.anchor_box,
                top_left=False)
            glimpse_prime = extract_affine_glimpse(inp, self.object_shape,
                                                   g_yt, g_xt, g_ys, g_xs,
                                                   self.edge_resampler)
        else:
            # Prior pass: the image is not looked at; use a zero box and a
            # zero glimpse of the same shape.
            g_yt = tf.zeros_like(cyt)
            g_xt = tf.zeros_like(cxt)
            g_ys = tf.zeros_like(ys)
            g_xs = tf.zeros_like(xs)
            glimpse_prime = tf.zeros(
                (batch_size, n_objects, *self.object_shape, self.image_depth))

        new_objects.update(glimpse_prime_box=tf.concat(
            [g_yt, g_xt, g_ys, g_xs], axis=-1), )

        # --- encode glimpse ---

        # The encoder is applied on both passes and its output is then zeroed
        # on the prior pass -- presumably for the same ScopedFunction reason
        # noted above (keep network creation independent of is_posterior).
        encoded_glimpse_prime = apply_object_wise(self.glimpse_prime_encoder,
                                                  glimpse_prime,
                                                  n_trailing_dims=3,
                                                  output_size=self.A,
                                                  is_training=self.is_training)

        if not is_posterior:
            encoded_glimpse_prime = tf.zeros((batch_size, n_objects, self.A),
                                             dtype=tf.float32)

        # --- position and scale ---

        # 8 outputs: 4 means and 4 pre-nonlinearity stds for (yt, xt, ys, xs).
        d_box_inp = tf.concat([base_features, encoded_glimpse_prime], axis=-1)
        d_box_params = apply_object_wise(self.d_box_network,
                                         d_box_inp,
                                         output_size=8,
                                         is_training=self.is_training)

        d_box_mean, d_box_log_std = tf.split(d_box_params, 2, axis=-1)

        d_box_std = self.std_nonlinearity(d_box_log_std)

        # "Training wheels": the forward value is unchanged
        # (tw * x + (1 - tw) * x == x); only the gradient flowing back into the
        # network is scaled by (1 - training_wheels).
        d_box_mean = self.training_wheels * tf.stop_gradient(d_box_mean) + (
            1 - self.training_wheels) * d_box_mean
        d_box_std = self.training_wheels * tf.stop_gradient(d_box_std) + (
            1 - self.training_wheels) * d_box_std

        d_yt_mean, d_xt_mean, d_ys, d_xs = tf.split(d_box_mean, 4, axis=-1)
        d_yt_std, d_xt_std, ys_std, xs_std = tf.split(d_box_std, 4, axis=-1)

        # --- position ---

        # We predict position a bit differently from scale. For scale we want to put a prior on the actual value of
        # the scale, whereas for position we want to put a prior on the difference in position over timesteps.

        d_yt_logit = Normal(loc=d_yt_mean, scale=d_yt_std).sample()
        d_xt_logit = Normal(loc=d_xt_mean, scale=d_xt_std).sample()

        # tanh bounds the magnitude of each positional step by |where_t_scale|.
        d_yt = self.where_t_scale * tf.nn.tanh(d_yt_logit)
        d_xt = self.where_t_scale * tf.nn.tanh(d_xt_logit)

        new_cyt = cyt + d_yt
        new_cxt = cxt + d_xt

        new_abs_posn = objects.abs_posn + tf.concat([d_yt, d_xt], axis=-1)

        # --- scale ---

        # The scale *logit* is updated additively; the new logit is then
        # clipped and squashed into the range between min_hw and max_hw.
        new_ys_mean = objects.ys_logit + d_ys
        new_xs_mean = objects.xs_logit + d_xs

        new_ys_logit = Normal(loc=new_ys_mean, scale=ys_std).sample()
        new_xs_logit = Normal(loc=new_xs_mean, scale=xs_std).sample()

        new_ys = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(
            tf.clip_by_value(new_ys_logit, -10, 10)) + self.min_hw
        new_xs = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(
            tf.clip_by_value(new_xs_logit, -10, 10)) + self.min_hw

        # Used for conditioning
        if self.use_abs_posn:
            box_params = tf.concat([
                new_abs_posn, d_yt_logit, d_xt_logit, new_ys_logit,
                new_xs_logit
            ],
                                   axis=-1)
        else:
            box_params = tf.concat(
                [d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit], axis=-1)

        new_objects.update(
            abs_posn=new_abs_posn,
            yt=new_cyt,
            xt=new_cxt,
            ys=new_ys,
            xs=new_xs,
            normalized_box=tf.concat([new_cyt, new_cxt, new_ys, new_xs],
                                     axis=-1),
            d_yt_logit=d_yt_logit,
            d_xt_logit=d_xt_logit,
            ys_logit=new_ys_logit,
            xs_logit=new_xs_logit,
            d_yt_logit_mean=d_yt_mean,
            d_xt_logit_mean=d_xt_mean,
            ys_logit_mean=new_ys_mean,
            xs_logit_mean=new_xs_mean,
            d_yt_logit_std=d_yt_std,
            d_xt_logit_std=d_xt_std,
            ys_logit_std=ys_std,
            xs_logit_std=xs_std,
            glimpse_prime=glimpse_prime,
        )

        # --- attributes ---

        # --- extract a glimpse using new box ---

        if is_posterior:
            _, image_height, image_width, _ = tf_shape(inp)
            _new_cyt, _new_cxt, _new_ys, _new_xs = coords_to_image_space(
                new_cyt,
                new_cxt,
                new_ys,
                new_xs, (image_height, image_width),
                self.anchor_box,
                top_left=False)

            glimpse = extract_affine_glimpse(inp, self.object_shape, _new_cyt,
                                             _new_cxt, _new_ys, _new_xs,
                                             self.edge_resampler)

        else:
            glimpse = tf.zeros(
                (batch_size, n_objects, *self.object_shape, self.image_depth))

        # As with the glimpse prime: encode unconditionally, then zero the
        # encoding on the prior pass.
        encoded_glimpse = apply_object_wise(self.glimpse_encoder,
                                            glimpse,
                                            n_trailing_dims=3,
                                            output_size=self.A,
                                            is_training=self.is_training)

        if not is_posterior:
            encoded_glimpse = tf.zeros((batch_size, n_objects, self.A),
                                       dtype=tf.float32)

        # --- predict change in attributes ---

        # 2 * A + 1 outputs: A means, A pre-nonlinearity stds, and one gate
        # logit per object.
        d_attr_inp = tf.concat([base_features, box_params, encoded_glimpse],
                               axis=-1)
        d_attr_params = apply_object_wise(self.d_attr_network,
                                          d_attr_inp,
                                          output_size=2 * self.A + 1,
                                          is_training=self.is_training)

        d_attr_mean, d_attr_log_std, gate_logit = tf.split(d_attr_params,
                                                           [self.A, self.A, 1],
                                                           axis=-1)
        d_attr_std = self.std_nonlinearity(d_attr_log_std)

        # The gate is always computed (and stored as d_attr_gate), but only
        # applied to the mean when gate_d_attr is set; the std is never gated.
        gate = tf.nn.sigmoid(gate_logit)

        if self.gate_d_attr:
            d_attr_mean *= gate

        d_attr = Normal(loc=d_attr_mean, scale=d_attr_std).sample()

        # --- apply change in attributes ---

        new_attr = objects.attr + d_attr

        new_objects.update(
            attr=new_attr,
            d_attr=d_attr,
            d_attr_mean=d_attr_mean,
            d_attr_std=d_attr_std,
            glimpse=glimpse,
            d_attr_gate=gate,
        )

        # --- z ---

        d_z_inp = tf.concat(
            [base_features, box_params, new_attr, encoded_glimpse], axis=-1)
        d_z_params = apply_object_wise(self.d_z_network,
                                       d_z_inp,
                                       output_size=2,
                                       is_training=self.is_training)

        d_z_mean, d_z_log_std = tf.split(d_z_params, 2, axis=-1)
        d_z_std = self.std_nonlinearity(d_z_log_std)

        # Same gradient-scaling "training wheels" as for the box parameters.
        d_z_mean = self.training_wheels * tf.stop_gradient(d_z_mean) + (
            1 - self.training_wheels) * d_z_mean
        d_z_std = self.training_wheels * tf.stop_gradient(d_z_std) + (
            1 - self.training_wheels) * d_z_std

        d_z_logit = Normal(loc=d_z_mean, scale=d_z_std).sample()

        new_z_logit = objects.z_logit + d_z_logit
        new_z = self.z_nonlinearity(new_z_logit)

        new_objects.update(
            z=new_z,
            z_logit=new_z_logit,
            d_z_logit=d_z_logit,
            d_z_logit_mean=d_z_mean,
            d_z_logit_std=d_z_std,
        )

        # --- obj ---

        d_obj_inp = tf.concat(
            [base_features, box_params, new_attr, new_z, encoded_glimpse],
            axis=-1)
        d_obj_logit = apply_object_wise(self.d_obj_network,
                                        d_obj_inp,
                                        output_size=1,
                                        is_training=self.is_training)

        d_obj_logit = self.training_wheels * tf.stop_gradient(d_obj_logit) + (
            1 - self.training_wheels) * d_obj_logit
        d_obj_log_odds = tf.clip_by_value(d_obj_logit / self.obj_temp, -10.,
                                          10.)

        # Blend of a relaxed (concrete) Bernoulli sample and the deterministic
        # log-odds, weighted by self._noisy (presumably 1 = noisy, 0 = deterministic).
        d_obj_pre_sigmoid = (self._noisy * concrete_binary_pre_sigmoid_sample(
            d_obj_log_odds, self.obj_concrete_temp) +
                             (1 - self._noisy) * d_obj_log_odds)

        d_obj = tf.nn.sigmoid(d_obj_pre_sigmoid)

        # d_obj lies in (0, 1), so the new presence is the old one scaled down.
        new_obj = objects.obj * d_obj

        new_objects.update(
            d_obj_log_odds=d_obj_log_odds,
            d_obj_prob=tf.nn.sigmoid(d_obj_log_odds),
            d_obj_pre_sigmoid=d_obj_pre_sigmoid,
            d_obj=d_obj,
            obj=new_obj,
        )

        # --- update each object's hidden state --

        # The posterior pass advances prop_state and the prior pass advances
        # prior_prop_state; either way the result is written to both fields.
        cell_input = tf.concat([box_params, new_attr, new_z, new_obj], axis=-1)

        if is_posterior:
            _, new_objects.prop_state = apply_object_wise(
                self.cell, cell_input, objects.prop_state)
            new_objects.prior_prop_state = new_objects.prop_state
        else:
            _, new_objects.prior_prop_state = apply_object_wise(
                self.cell, cell_input, objects.prior_prop_state)
            new_objects.prop_state = new_objects.prior_prop_state

        return new_objects
Exemple #11
0
    def _call(self, inp, features, objects, is_training, is_posterior):
        """Run one propagation step for all objects.

        Builds the sub-networks, records image geometry and training-mode
        tensors on first use, and then delegates to `self._body`.

        Parameters
        ----------
        inp: Tensor
            Image tensor; its last three static dims are read as height, width
            and depth on first use.
        features: Tensor
            Per-object features, forwarded to `self._body`.
        objects: AttrDict
            Previous-step object state, forwarded to `self._body`.
        is_training: bool or Tensor
            Stored on `self` (and as a float via `tf.to_float`) on first use.
        is_posterior: bool
            Forwarded to `self._body`.

        Returns
        -------
        Whatever `self._body` returns (the propagated objects).

        Raises
        ------
        NotImplementedError
            If `self.do_lateral` is set. The lateral path has not been updated
            to use abs_posn; the old sketch is kept in `_call_lateral`.
        """
        print("\n" + "-" * 10 +
              " PropagationLayer(is_posterior={}) ".format(is_posterior) +
              "-" * 10)

        self._build_networks()

        if not self.initialized:
            # Note this limits the re-usability of this module to images
            # with a fixed shape (the shape of the first image it is used on)
            # NOTE(review): this method never sets self.initialized itself;
            # presumably the caller / base class does.
            self.image_height = int(inp.shape[-3])
            self.image_width = int(inp.shape[-2])
            self.image_depth = int(inp.shape[-1])
            self.batch_size = tf.shape(inp)[0]
            self.is_training = is_training
            self.float_is_training = tf.to_float(is_training)

        if self.do_lateral:
            # NotImplementedError subclasses Exception, so callers that caught
            # the previous generic Exception keep working.
            raise NotImplementedError(
                "PropagationLayer: do_lateral is not implemented (the lateral "
                "path has not been updated to use abs_posn).")

        return self._body(inp, features, objects, is_posterior)

    def _call_lateral(self, inp, features, objects, is_training, is_posterior):
        """Unused, out-of-date sketch of object-by-object (lateral) propagation.

        Kept for reference only; `_call` refuses `do_lateral` before this could
        be reached. Each object is propagated in turn, with its features first
        refined by `self.lateral_network` using the already-propagated objects.

        NOTE(review): the per-object `_objects` built below only carries
        normalized_box, attr, z and obj, whereas the `_body` in this file also
        reads fields such as abs_posn, ys_logit, xs_logit, z_logit and
        prop_state -- so this would fail if called as-is.
        """
        batch_size, n_objects, _ = tf_shape(features)

        new_objects = []

        for i in range(n_objects):
            # apply lateral to running objects with the feature vector for
            # the current object

            _features = features[:, i:i + 1, :]

            if i > 0:
                normalized_box = tf.concat(
                    [o.normalized_box for o in new_objects], axis=1)
                attr = tf.concat([o.attr for o in new_objects], axis=1)
                z = tf.concat([o.z for o in new_objects], axis=1)
                obj = tf.concat([o.obj for o in new_objects], axis=1)
                completed_features = tf.concat(
                    [normalized_box[:, :, 2:], attr, z, obj], axis=2)
                completed_locs = normalized_box[:, :, :2]

                current_features = tf.concat([
                    objects.normalized_box[:, i:i + 1, 2:],
                    objects.attr[:, i:i + 1],
                    objects.z[:, i:i + 1],
                    objects.obj[:, i:i + 1],
                ], axis=2)
                current_locs = objects.normalized_box[:, i:i + 1, :2]

                # if i > max_completed_objects:
                #     # top_k_indices
                #     # squared_distances = tf.reduce_sum((completed_locs - current_locs)**2, axis=2)
                #     # _, top_k_indices = tf.nn.top_k(squared_distances, k=max_completed_objects, sorted=False)

                _features = self.lateral_network(completed_locs,
                                                 completed_features,
                                                 current_locs,
                                                 current_features,
                                                 is_training)

            _objects = AttrDict(
                normalized_box=objects.normalized_box[:, i:i + 1],
                attr=objects.attr[:, i:i + 1],
                z=objects.z[:, i:i + 1],
                obj=objects.obj[:, i:i + 1],
            )

            _new_objects = self._body(inp, _features, _objects, is_posterior)
            new_objects.append(_new_objects)

        _new_objects = AttrDict()
        for k in new_objects[0]:
            _new_objects[k] = tf.concat([no[k] for no in new_objects], axis=1)
        return _new_objects
Exemple #12
0
    def extract_stage_data(self, fields=None, bare=False):
        """ Extract stage-by-stage data about the training runs.

        Parameters
        ----------
        fields: str or iterable of str, optional
            Names of the record fields to keep, compared after any "best_"
            prefix has been stripped. A string is split on whitespace. If None
            or empty, every field is kept.
        bare: boolean
            If True, only returns the data. Otherwise, additionally returns the stage-by-stage config and meta-data.

        Returns
        -------
        A nested data structure containing the requested data:

            {param-setting-key: {(repeat, seed): (df, sc, md)}}

        or, when ``bare`` is True:

            {param-setting-key: {(repeat, seed): df}}

        where:
            param-setting-key is a namedtuple of the values of ``self.dist_keys()``
            df is a pandas DataFrame with one row per stage
            sc is a list giving the config for each stage
            md is a dictionary storing metadata (the host plus the value of each dist key)

        Experiments whose data cannot be extracted are skipped: the exception
        traceback is printed and processing continues with the next path.
        """
        stage_data = defaultdict(dict)
        if isinstance(fields, str):
            fields = fields.split()
        # Set for O(1) membership tests; None means "keep every field".
        wanted = set(fields) if fields else None

        config_keys = self.dist_keys()

        KeyTuple = namedtuple(self.__class__.__name__ + "Key", config_keys)

        for exp_path in self.experiment_paths:
            try:
                exp_data = FrozenTrainingLoopData(exp_path)

                # Fetch each config value once; reused for md and for the key.
                config_values = [
                    exp_data.get_config_value(k) for k in config_keys
                ]

                md = {'host': exp_data.host}
                md.update(zip(config_keys, config_values))

                sc = []
                records = []
                for stage in exp_data.history:
                    record = stage.copy()

                    sc.append(record['stage_config'].copy())
                    del record['stage_config']

                    record = AttrDict(record).flatten()

                    for path_key in ('best_path', 'final_path'):
                        if path_key in record:
                            del record[path_key]

                    # Strip the "best_" prefix, then keep only requested fields.
                    # NOTE: if both "best_x" and "x" are present, whichever is
                    # iterated later wins.
                    _record = {}
                    for k, v in record.items():
                        if k.startswith("best_"):
                            k = k[len("best_"):]

                        if wanted is None or k in wanted:
                            _record[k] = v

                    records.append(_record)

                key = KeyTuple(*config_values)

                repeat = exp_data.get_config_value("repeat")
                seed = exp_data.get_config_value("seed")

                df = pd.DataFrame.from_records(records)
                stage_data[key][(repeat, seed)] = df if bare else (df, sc, md)

            except Exception:
                # Deliberate best-effort: report and move on to the next path.
                print(
                    "Exception raised while extracting stage data for path: {}"
                    .format(exp_path))
                traceback.print_exc()

        return stage_data
Exemple #13
0
    def _call(self, inp, inp_features, is_training, is_posterior=True, prop_state=None):
        """Build object latents for every cell of a feature grid in parallel.

        Each grid cell goes through a chain of convolutional lateral networks
        (box -> attr -> z -> obj). Every stage is conditioned on the base features
        and on the outputs of the previous stages.

        Parameters
        ----------
        inp: tensor
            Image batch, unpacked as (batch, height, width, depth) via ``tf_shape``.
        inp_features: tensor
            Feature grid, unpacked as (batch, H, W, channels) via ``tf_shape``.
        is_training: bool tensor
            Training flag. It is only captured on the first call. Sub-networks
            always use ``self.is_training``, not this argument.
        is_posterior: bool
            If True, attributes are conditioned on glimpses cut out of ``inp``.
            Otherwise the glimpse encoding is replaced with zeros.
        prop_state: tensor or None
            If given, its first row is tiled to every grid cell and stored under
            both ``prop_state`` and ``prior_prop_state``.

        Returns
        -------
        AttrDict
            Object tensors shaped (batch, H, W, ...), or (batch, H*W, ...) when
            ``self.flatten`` is set, plus ``pred_n_objects`` and
            ``pred_n_objects_hard`` (soft/rounded sums of ``obj`` per batch element).

        Raises
        ------
        NotImplementedError
            If ``self.B != 1``. Only one object per grid cell is supported.
        """
        print("\n" + "-" * 10 + " ConvGridObjectLayer({}, is_posterior={}) ".format(self.name, is_posterior) + "-" * 10)

        # --- set up sub networks and attributes ---

        self.maybe_build_subnet("box_network", builder=cfg.build_conv_lateral, key="box")
        self.maybe_build_subnet("attr_network", builder=cfg.build_conv_lateral, key="attr")
        self.maybe_build_subnet("z_network", builder=cfg.build_conv_lateral, key="z")
        self.maybe_build_subnet("obj_network", builder=cfg.build_conv_lateral, key="obj")

        self.maybe_build_subnet("object_encoder")

        _, H, W, _ = tf_shape(inp_features)

        if self.B != 1:
            raise NotImplementedError(
                "ConvGridObjectLayer only supports B == 1, got B={}".format(self.B))

        if not self.initialized:
            # Note this limits the re-usability of this module to images
            # with a fixed shape (the shape of the first image it is used on)
            _, self.image_height, self.image_width, self.image_depth = tf_shape(inp)
            self.H = H
            self.W = W
            self.HWB = H*W  # B == 1 (checked above)
            # Runtime batch size (a tensor), so the graph accepts any batch size.
            self.batch_size = tf.shape(inp)[0]
            self.is_training = is_training
            self.float_is_training = tf.to_float(is_training)

        # Two-channel flag appended to every cell's features:
        # [1, 0] when building the posterior, [0, 1] when building the prior.
        is_posterior_tf = tf.ones_like(inp_features[..., :2])
        if is_posterior:
            is_posterior_tf = is_posterior_tf * [1, 0]
        else:
            is_posterior_tf = is_posterior_tf * [0, 1]

        objects = AttrDict()

        base_features = tf.concat([inp_features, is_posterior_tf], axis=-1)

        # --- box ---

        layer_inp = base_features
        n_features = self.n_passthrough_features
        output_size = 8  # raw box outputs consumed by self._build_box

        network_output = self.box_network(layer_inp, output_size + n_features, self.is_training)
        rep_input, features = tf.split(network_output, (output_size, n_features), axis=-1)

        _objects = self._build_box(rep_input, self.is_training)
        objects.update(_objects)

        # --- attr ---

        if is_posterior:
            # --- Get object attributes using object encoder ---

            yt, xt, ys, xs = tf.split(objects['normalized_box'], 4, axis=-1)

            yt, xt, ys, xs = coords_to_image_space(
                yt, xt, ys, xs, (self.image_height, self.image_width), self.anchor_box, top_left=False)

            transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
            warper = snt.AffineGridWarper(
                (self.image_height, self.image_width), self.object_shape, transform_constraints)

            # Affine parameters as (x scale, x shift, y scale, y shift), presumably
            # matching the no_shear_2d parameter order. 2*t - 1 maps [0, 1] to [-1, 1].
            _boxes = tf.concat([xs, 2*xt - 1, ys, 2*yt - 1], axis=-1)
            _boxes = tf.reshape(_boxes, (self.batch_size*H*W, 4))
            grid_coords = warper(_boxes)
            grid_coords = tf.reshape(grid_coords, (self.batch_size, H, W, *self.object_shape, 2,))

            if self.edge_resampler:
                glimpse = resampler_edge.resampler_edge(inp, grid_coords)
            else:
                glimpse = tf.contrib.resampler.resampler(inp, grid_coords)
        else:
            glimpse = tf.zeros((self.batch_size, H, W, *self.object_shape, self.image_depth))

        # Create the object encoder network regardless of is_posterior, otherwise messes with ScopedFunction
        encoded_glimpse = apply_object_wise(
            self.object_encoder, glimpse, n_trailing_dims=3, output_size=self.A, is_training=self.is_training)

        if not is_posterior:
            encoded_glimpse = tf.zeros_like(encoded_glimpse)

        layer_inp = tf.concat([base_features, features, encoded_glimpse, objects['local_box']], axis=-1)
        network_output = self.attr_network(layer_inp, 2 * self.A + n_features, self.is_training)
        attr_mean, attr_log_std, features = tf.split(network_output, (self.A, self.A, n_features), axis=-1)

        attr_std = self.std_nonlinearity(attr_log_std)

        attr = Normal(loc=attr_mean, scale=attr_std).sample()

        objects.update(attr_mean=attr_mean, attr_std=attr_std, attr=attr, glimpse=glimpse)

        # --- z ---

        layer_inp = tf.concat([base_features, features, objects['local_box'], objects['attr']], axis=-1)
        n_features = self.n_passthrough_features

        network_output = self.z_network(layer_inp, 2 + n_features, self.is_training)
        z_mean, z_log_std, features = tf.split(network_output, (1, 1, n_features), axis=-1)
        z_std = self.std_nonlinearity(z_log_std)

        # Blend between gradient-blocked and normal parameters; with
        # training_wheels == 1 no gradient flows into z_mean / z_std.
        z_mean = self.training_wheels * tf.stop_gradient(z_mean) + (1-self.training_wheels) * z_mean
        z_std = self.training_wheels * tf.stop_gradient(z_std) + (1-self.training_wheels) * z_std
        z_logit = Normal(loc=z_mean, scale=z_std).sample()
        z = self.z_nonlinearity(z_logit)

        objects.update(z_logit_mean=z_mean, z_logit_std=z_std, z_logit=z_logit, z=z)

        # --- obj ---

        layer_inp = tf.concat([base_features, features, objects['local_box'], objects['attr'], objects['z']], axis=-1)
        rep_input = self.obj_network(layer_inp, 1, self.is_training)

        _objects = self._build_obj(rep_input, self.is_training)
        objects.update(_objects)

        # --- final ---

        if prop_state is not None:
            objects.prop_state = tf.tile(prop_state[0:1, None, None, :], (self.batch_size, H, W, 1))
            objects.prior_prop_state = tf.tile(prop_state[0:1, None, None, :], (self.batch_size, H, W, 1))

        if self.flatten:
            # Merge the (H, W) grid axes into a single object axis of length H*W.
            _objects = AttrDict()
            for k, v in objects.items():
                _, _, _, *trailing_dims = tf_shape(v)
                _objects[k] = tf.reshape(v, (self.batch_size, self.HWB, *trailing_dims))
            objects = _objects

        # --- misc ---

        flat_objects = tf.reshape(objects.obj, (self.batch_size, -1))

        objects.pred_n_objects = tf.reduce_sum(flat_objects, axis=1)
        objects.pred_n_objects_hard = tf.reduce_sum(tf.round(flat_objects), axis=1)

        return objects
Exemple #14
0
    def _call(self, inp, inp_features, is_training, is_posterior=True, prop_state=None):
        """Build object latents one grid slot at a time, in (h, w, b) raster order.

        For each slot, ``self._get_sequential_context`` builds a context from the
        slots already written into ``program`` and from a learned ``edge_element``.
        The slot is then passed through the box -> attr -> z -> obj networks.
        The slot's concatenated (local_box, attr, z, obj) tensor is written back
        into ``program`` so later slots can condition on it.

        Parameters
        ----------
        inp: tensor
            Image batch, unpacked as (batch, height, width, depth) via ``tf_shape``.
        inp_features: tensor
            Feature grid, unpacked as (batch, H, W, channels). H and W must be static.
        is_training: bool tensor
            Training flag. It is only captured on the first call. Sub-networks
            use ``self.is_training``.
        is_posterior: bool
            If True, attributes are conditioned on glimpses cut out of ``inp``.
            Otherwise the glimpse encoding is zeroed.
        prop_state: tensor or None
            If given, its first row is tiled to every slot and stored as both
            ``prop_state`` and ``prior_prop_state``.

        Returns
        -------
        AttrDict
            Per-slot entries stacked along axis 1, i.e. (batch, H*W*B, ...),
            plus ``pred_n_objects`` / ``pred_n_objects_hard``.
        """
        print("\n" + "-" * 10 + " GridObjectLayer(is_posterior={}) ".format(is_posterior) + "-" * 10)

        # --- set up sub networks and attributes ---

        self.maybe_build_subnet("box_network", builder=cfg.build_lateral, key="box")
        self.maybe_build_subnet("attr_network", builder=cfg.build_lateral, key="attr")
        self.maybe_build_subnet("z_network", builder=cfg.build_lateral, key="z")
        self.maybe_build_subnet("obj_network", builder=cfg.build_lateral, key="obj")

        self.maybe_build_subnet("object_encoder")

        _, H, W, _ = tf_shape(inp_features)
        H = int(H)
        W = int(W)

        if not self.initialized:
            # Note this limits the re-usability of this module to images
            # with a fixed shape (the shape of the first image it is used on)
            self.batch_size, self.image_height, self.image_width, self.image_depth = tf_shape(inp)
            self.H = H
            self.W = W
            self.HWB = H*W*self.B
            self.is_training = is_training
            self.float_is_training = tf.to_float(is_training)

        # --- set up the edge element ---

        # One learned value per component of a slot's program:
        # local_box (4), attr (A), z (1), obj (1). Box and obj entries pass through a sigmoid.
        sizes = [4, self.A, 1, 1]
        sigmoids = [True, False, False, True]
        total_sample_size = sum(sizes)

        if self.edge_weights is None:
            self.edge_weights = tf.get_variable("edge_weights", shape=total_sample_size, dtype=tf.float32)
            if "edge" in self.fixed_weights:
                tf.add_to_collection(FIXED_COLLECTION, self.edge_weights)

        _edge_weights = tf.split(self.edge_weights, sizes, axis=0)
        _edge_weights = [
            (tf.nn.sigmoid(ew) if sigmoid else ew)
            for ew, sigmoid in zip(_edge_weights, sigmoids)]
        edge_element = tf.concat(_edge_weights, axis=0)
        edge_element = tf.tile(edge_element[None, :], (self.batch_size, 1))

        # --- containers for storing built program ---

        # program[h, w, b] holds the slot's (local_box, attr, z, obj) tensor once built.
        # Builtin `object` is used because the `np.object` alias was removed in NumPy 1.24.
        program = np.empty((H, W, self.B), dtype=object)

        # --- build the program ---

        # Two-column flag: [1, 0] for the posterior, [0, 1] for the prior.
        is_posterior_tf = tf.ones((self.batch_size, 2))
        if is_posterior:
            is_posterior_tf = is_posterior_tf * [1, 0]
        else:
            is_posterior_tf = is_posterior_tf * [0, 1]

        if is_posterior:
            # The glimpse warper depends only on the image and glimpse shapes,
            # so it is built once rather than once per slot.
            transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
            warper = snt.AffineGridWarper(
                (self.image_height, self.image_width), self.object_shape, transform_constraints)

        results = []
        for h, w, b in itertools.product(range(H), range(W), range(self.B)):
            built = dict()

            context = self._get_sequential_context(program, h, w, b, edge_element)
            base_features = tf.concat([inp_features[:, h, w, :], context, is_posterior_tf], axis=1)

            # --- box ---

            layer_inp = base_features
            n_features = self.n_passthrough_features
            output_size = 8  # raw box outputs consumed by self._build_box

            network_output = self.box_network(layer_inp, output_size + n_features, self.is_training)
            rep_input, features = tf.split(network_output, (output_size, n_features), axis=1)

            _built = self._build_box(rep_input, self.is_training, hw=(h, w))
            built.update(_built)
            partial_program = built['local_box']

            # --- attr ---

            if is_posterior:
                # --- Get object attributes using object encoder ---

                yt, xt, ys, xs = tf.split(built['normalized_box'], 4, axis=-1)

                yt, xt, ys, xs = coords_to_image_space(
                    yt, xt, ys, xs, (self.image_height, self.image_width), self.anchor_box, top_left=False)

                # (x scale, x shift, y scale, y shift); 2*t - 1 maps [0, 1] to [-1, 1].
                _boxes = tf.concat([xs, 2*xt - 1, ys, 2*yt - 1], axis=-1)

                grid_coords = warper(_boxes)
                grid_coords = tf.reshape(grid_coords, (self.batch_size, 1, *self.object_shape, 2,))
                if self.edge_resampler:
                    glimpse = resampler_edge.resampler_edge(inp, grid_coords)
                else:
                    glimpse = tf.contrib.resampler.resampler(inp, grid_coords)
                glimpse = tf.reshape(glimpse, (self.batch_size, *self.object_shape, self.image_depth))
            else:
                glimpse = tf.zeros((self.batch_size, *self.object_shape, self.image_depth))

            # Create the object encoder network regardless of is_posterior, otherwise messes with ScopedFunction
            encoded_glimpse = self.object_encoder(glimpse, (1, 1, self.A), self.is_training)
            encoded_glimpse = tf.reshape(encoded_glimpse, (self.batch_size, self.A))

            if not is_posterior:
                encoded_glimpse = tf.zeros_like(encoded_glimpse)

            layer_inp = tf.concat(
                [base_features, features, encoded_glimpse, partial_program], axis=1)
            network_output = self.attr_network(layer_inp, 2 * self.A + n_features, self.is_training)
            attr_mean, attr_log_std, features = tf.split(network_output, (self.A, self.A, n_features), axis=1)

            attr_std = self.std_nonlinearity(attr_log_std)

            attr = Normal(loc=attr_mean, scale=attr_std).sample()

            built.update(attr_mean=attr_mean, attr_std=attr_std, attr=attr, glimpse=glimpse)
            partial_program = tf.concat([partial_program, built['attr']], axis=1)

            # --- z ---

            layer_inp = tf.concat([base_features, features, partial_program], axis=1)
            n_features = self.n_passthrough_features

            network_output = self.z_network(layer_inp, 2 + n_features, self.is_training)
            z_mean, z_log_std, features = tf.split(network_output, (1, 1, n_features), axis=1)
            z_std = self.std_nonlinearity(z_log_std)

            # With training_wheels == 1 no gradient flows into z_mean / z_std.
            z_mean = self.training_wheels * tf.stop_gradient(z_mean) + (1-self.training_wheels) * z_mean
            z_std = self.training_wheels * tf.stop_gradient(z_std) + (1-self.training_wheels) * z_std
            z_logit = Normal(loc=z_mean, scale=z_std).sample()
            z = self.z_nonlinearity(z_logit)

            built.update(z_logit_mean=z_mean, z_logit_std=z_std, z_logit=z_logit, z=z)
            partial_program = tf.concat([partial_program, built['z']], axis=1)

            # --- obj ---

            layer_inp = tf.concat([base_features, features, partial_program], axis=1)
            rep_input = self.obj_network(layer_inp, 1, self.is_training)

            _built = self._build_obj(rep_input, self.is_training)
            built.update(_built)

            partial_program = tf.concat([partial_program, built['obj']], axis=1)

            # --- final ---

            results.append(built)

            program[h, w, b] = partial_program
            assert program[h, w, b].shape[1] == total_sample_size

        # Stack per-slot tensors along a new object axis (raster order).
        objects = AttrDict()
        for k in results[0]:
            objects[k] = tf.stack([r[k] for r in results], axis=1)

        if prop_state is not None:
            objects.prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))
            objects.prior_prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))

        # --- misc ---

        objects.pred_n_objects = tf.reduce_sum(objects.obj, axis=(1, 2))
        objects.pred_n_objects_hard = tf.reduce_sum(tf.round(objects.obj), axis=(1, 2))

        return objects
Exemple #15
0
    def _call(self, data, is_training):
        """Prepare the per-call tensors for a batch, then build the representation.

        Parameters
        ----------
        data: dict
            Batch dictionary. ``data["image"]`` is required and its axis 1 is
            treated as time. ``"annotations"`` (with ``"data"``, ``"shapes"`` and
            ``"mask"`` entries), ``"label"``, ``"background"`` and ``"offset"``
            are optional.
        is_training: bool tensor
            When true, the number of frames follows the curriculum
            (``stage_steps`` / ``initial_n_frames`` / ``n_frames_scale``).
            Otherwise every frame is used.

        Returns
        -------
        dict with ``tensors``, ``recorded_tensors`` and ``losses``.
        """
        self.data = data

        inp = data["image"]
        tensors = AttrDict(
            inp=inp,
            is_training=is_training,
            float_is_training=tf.to_float(is_training),
            batch_size=tf.shape(inp)[0],
        )
        self._tensors = tensors

        if "annotations" in data:
            annotations = data["annotations"]
            ann_data = annotations["data"]
            n_annotations = annotations["shapes"][:, 1]
            # First annotation channel masked by the first mask channel, summed per frame.
            masked_flags = ann_data[:, :, :, 0] * tf.to_float(annotations["mask"][:, :, :, 0])
            tensors.update(
                annotations=ann_data,
                n_annotations=n_annotations,
                n_valid_annotations=tf.to_int32(tf.reduce_sum(masked_flags, axis=2)),
            )

        # Optional entries copied through unchanged: (key in data, key in tensors).
        passthrough = (
            ("label", "targets"),
            ("background", "ground_truth_background"),
            ("offset", "offset"),
        )
        for source_key, tensor_key in passthrough:
            if source_key in data:
                tensors.update({tensor_key: data[source_key]})

        max_n_frames = tf_shape(inp)[1]

        if self.stage_steps is None:
            self.current_stage = tf.constant(0, tf.int32)
            train_n_frames = max_n_frames
        else:
            global_step = tf.cast(tf.train.get_or_create_global_step(), tf.int32)
            self.current_stage = global_step // self.stage_steps
            train_n_frames = tf.minimum(
                self.initial_n_frames + self.n_frames_scale * self.current_stage, max_n_frames)

        # Select the curriculum length when training and the full length otherwise.
        # float_is_training is 0 or 1, so the blend acts as a switch.
        train_flag = self.float_is_training
        blended_n_frames = (
            train_flag * tf.cast(train_n_frames, tf.float32)
            + (1 - train_flag) * tf.cast(max_n_frames, tf.float32)
        )
        self.dynamic_n_frames = tf.cast(blended_n_frames, tf.int32)

        tensors.current_stage = self.current_stage
        tensors.dynamic_n_frames = self.dynamic_n_frames

        # Truncate the time axis of per-frame tensors (n_annotations is left as is).
        n_frames = self.dynamic_n_frames
        tensors.inp = tensors.inp[:, :n_frames]
        if 'annotations' in tensors:
            tensors.annotations = tensors.annotations[:, :n_frames]
            tensors.n_valid_annotations = tensors.n_valid_annotations[:, :n_frames]

        self.record_tensors(
            batch_size=tf.to_float(self.batch_size),
            float_is_training=self.float_is_training,
            current_stage=self.current_stage,
            dynamic_n_frames=self.dynamic_n_frames,
        )

        self.losses = {}

        with tf.variable_scope("representation", reuse=self.initialized):
            if self.needs_background:
                self.build_background()
            self.build_representation()

        return {
            "tensors": self._tensors,
            "recorded_tensors": self.recorded_tensors,
            "losses": self.losses,
        }
Exemple #16
0
class VideoNetwork(TensorRecorder):
    """Base network for video batches.

    ``_call`` stores per-call tensors in ``self._tensors``, optionally builds a
    background with ``build_background``, and then calls
    ``build_representation``. That method is not defined in the code visible
    here; presumably subclasses provide it (confirm).
    """
    # Prior mean/std passed to normal_vae for the background attributes.
    attr_prior_mean = Param()
    attr_prior_std = Param()
    # If false, std_nonlinearity returns zeros (no sampling noise).
    noisy = Param()
    # Curriculum over the number of frames: when stage_steps is set,
    # stage = global_step // stage_steps and the training length is
    # min(initial_n_frames + n_frames_scale * stage, total frames).
    stage_steps = Param()
    initial_n_frames = Param()
    n_frames_scale = Param()

    # Sub-networks, filled in lazily via maybe_build_subnet when a learned
    # background mode needs them.
    background_encoder = None
    background_decoder = None

    # When true, _call invokes build_background before build_representation.
    needs_background = True

    # NOTE(review): a mutable class attribute, shared by every instance and
    # subclass that does not override it.
    eval_funcs = dict()

    def __init__(self, env, updater, scope=None, **kwargs):
        """Store the updater and read the video geometry from the training set.

        ``env.datasets['train'].obs_shape`` is unpacked as
        (n_frames, image_height, image_width, image_depth). The remaining
        arguments are forwarded to the base class.
        """
        self.updater = updater

        obs_shape = env.datasets['train'].obs_shape
        self.obs_shape = obs_shape
        (self.n_frames, self.image_height,
         self.image_width, self.image_depth) = obs_shape

        super().__init__(scope=scope, **kwargs)

    def std_nonlinearity(self, std_logit):
        """Map an unbounded logit to a standard deviation in (0, 2).

        The logit is clipped to [-10, 10] before the sigmoid. When ``self.noisy``
        is false, a zero tensor of the same shape is returned instead
        (presumably so samples collapse to their means).
        """
        clipped_logit = tf.clip_by_value(std_logit, -10, 10)
        std = 2 * tf.nn.sigmoid(clipped_logit)
        if self.noisy:
            return std
        return tf.zeros_like(std)

    @property
    def inp(self):
        """Input video tensor for the current call.

        ``_call`` truncates it along axis 1 to ``dynamic_n_frames``.
        """
        return self._tensors["inp"]

    @property
    def batch_size(self):
        """Runtime batch size tensor (``tf.shape(inp)[0]``) stored by ``_call``."""
        return self._tensors["batch_size"]

    @property
    def is_training(self):
        """Training flag passed to the most recent ``_call``."""
        return self._tensors["is_training"]

    @property
    def float_is_training(self):
        """``is_training`` converted with ``tf.to_float``.

        This is 1.0 / 0.0, assuming a boolean flag.
        """
        return self._tensors["float_is_training"]

    def _call(self, data, is_training):
        """Prepare the per-call tensors for a batch, then build the representation.

        Parameters
        ----------
        data: dict
            Batch dictionary. ``data["image"]`` is required and its axis 1 is
            treated as time. ``"annotations"`` (with ``"data"``, ``"shapes"`` and
            ``"mask"`` entries), ``"label"``, ``"background"`` and ``"offset"``
            are optional.
        is_training: bool tensor
            When true, the number of frames follows the curriculum
            (``stage_steps`` / ``initial_n_frames`` / ``n_frames_scale``).
            Otherwise every frame is used.

        Returns
        -------
        dict with ``tensors``, ``recorded_tensors`` and ``losses``.
        """
        self.data = data

        inp = data["image"]
        tensors = AttrDict(
            inp=inp,
            is_training=is_training,
            float_is_training=tf.to_float(is_training),
            batch_size=tf.shape(inp)[0],
        )
        self._tensors = tensors

        if "annotations" in data:
            annotations = data["annotations"]
            ann_data = annotations["data"]
            n_annotations = annotations["shapes"][:, 1]
            # First annotation channel masked by the first mask channel, summed per frame.
            masked_flags = ann_data[:, :, :, 0] * tf.to_float(annotations["mask"][:, :, :, 0])
            tensors.update(
                annotations=ann_data,
                n_annotations=n_annotations,
                n_valid_annotations=tf.to_int32(tf.reduce_sum(masked_flags, axis=2)),
            )

        # Optional entries copied through unchanged: (key in data, key in tensors).
        passthrough = (
            ("label", "targets"),
            ("background", "ground_truth_background"),
            ("offset", "offset"),
        )
        for source_key, tensor_key in passthrough:
            if source_key in data:
                tensors.update({tensor_key: data[source_key]})

        max_n_frames = tf_shape(inp)[1]

        if self.stage_steps is None:
            self.current_stage = tf.constant(0, tf.int32)
            train_n_frames = max_n_frames
        else:
            global_step = tf.cast(tf.train.get_or_create_global_step(), tf.int32)
            self.current_stage = global_step // self.stage_steps
            train_n_frames = tf.minimum(
                self.initial_n_frames + self.n_frames_scale * self.current_stage, max_n_frames)

        # Select the curriculum length when training and the full length otherwise.
        # float_is_training is 0 or 1, so the blend acts as a switch.
        train_flag = self.float_is_training
        blended_n_frames = (
            train_flag * tf.cast(train_n_frames, tf.float32)
            + (1 - train_flag) * tf.cast(max_n_frames, tf.float32)
        )
        self.dynamic_n_frames = tf.cast(blended_n_frames, tf.int32)

        tensors.current_stage = self.current_stage
        tensors.dynamic_n_frames = self.dynamic_n_frames

        # Truncate the time axis of per-frame tensors (n_annotations is left as is).
        n_frames = self.dynamic_n_frames
        tensors.inp = tensors.inp[:, :n_frames]
        if 'annotations' in tensors:
            tensors.annotations = tensors.annotations[:, :n_frames]
            tensors.n_valid_annotations = tensors.n_valid_annotations[:, :n_frames]

        self.record_tensors(
            batch_size=tf.to_float(self.batch_size),
            float_is_training=self.float_is_training,
            current_stage=self.current_stage,
            dynamic_n_frames=self.dynamic_n_frames,
        )

        self.losses = {}

        with tf.variable_scope("representation", reuse=self.initialized):
            if self.needs_background:
                self.build_background()
            self.build_representation()

        return {
            "tensors": self._tensors,
            "recorded_tensors": self.recorded_tensors,
            "losses": self.losses,
        }

    def build_background(self):
        """Build the background tensor selected by ``cfg.background_cfg.mode``.

        The result is stored in ``self._tensors["background"]``, truncated along
        axis 1 to ``self.dynamic_n_frames``. Modes:

        - "colour": the constant colour ``cfg.background_cfg.colour`` broadcast
          over the last axis of a tensor shaped like ``self.inp``.
        - "learn_solid": one learned 3-value colour (sigmoid of 10 * variable)
          broadcast the same way.
        - "scalor": nothing is built here (see the review note in that branch).
        - "learn": encode frame 0 into background attributes, decode them to
          3 values per batch element, and tile over (T, H, W).
        - "learn_and_transform": encode the video into background attributes
          plus per-frame transform latents, decode a background image, and
          warp it per frame with a no-shear affine transform.
        - "data": use ``self._tensors["ground_truth_background"]``.

        Raises
        ------
        Exception
            If the mode is not one of the above.
        """
        if cfg.background_cfg.mode == "colour":
            rgb = np.array(to_rgb(cfg.background_cfg.colour))[None, None, None, :]
            background = rgb * tf.ones_like(self.inp)

        elif cfg.background_cfg.mode == "learn_solid":
            # Learn a solid colour for the background
            self.solid_background_logits = tf.get_variable("solid_background", initializer=[0.0, 0.0, 0.0])
            if "background" in self.fixed_weights:
                tf.add_to_collection(FIXED_COLLECTION, self.solid_background_logits)
            solid_background = tf.nn.sigmoid(10 * self.solid_background_logits)
            background = solid_background[None, None, None, :] * tf.ones_like(self.inp)

        elif cfg.background_cfg.mode == "scalor":
            # NOTE(review): `background` is never bound on this path, so the final
            # line of this method raises UnboundLocalError if this mode is used.
            # Confirm how SCALOR mode is meant to supply its background.
            pass

        elif cfg.background_cfg.mode == "learn":
            self.maybe_build_subnet("background_encoder")
            self.maybe_build_subnet("background_decoder")

            # Here I'm just encoding the first frame...
            bg_attr = self.background_encoder(self.inp[:, 0], 2 * cfg.background_cfg.A, self.is_training)
            bg_attr_mean, bg_attr_log_std = tf.split(bg_attr, 2, axis=-1)
            # NOTE(review): unlike "learn_and_transform", the std here is exp(log_std)
            # rather than self.std_nonlinearity(log_std).
            bg_attr_std = tf.exp(bg_attr_log_std)
            if not self.noisy:
                bg_attr_std = tf.zeros_like(bg_attr_std)

            bg_attr, bg_attr_kl = normal_vae(bg_attr_mean, bg_attr_std, self.attr_prior_mean, self.attr_prior_std)

            self._tensors.update(
                bg_attr_mean=bg_attr_mean,
                bg_attr_std=bg_attr_std,
                bg_attr_kl=bg_attr_kl,
                bg_attr=bg_attr)

            # --- decode ---

            _, T, H, W, _ = tf_shape(self.inp)

            # The decoder yields 3 values per batch element (checked below). They
            # are squashed to (0, 1) and tiled over every frame and pixel.
            background = self.background_decoder(bg_attr, 3, self.is_training)
            assert len(background.shape) == 2 and background.shape[1] == 3
            background = tf.nn.sigmoid(tf.clip_by_value(background, -10, 10))
            background = tf.tile(background[:, None, None, None, :], (1, T, H, W, 1))

        elif cfg.background_cfg.mode == "learn_and_transform":
            self.maybe_build_subnet("background_encoder")
            self.maybe_build_subnet("background_decoder")

            # --- encode ---

            # The encoder returns two outputs: background-attribute parameters
            # (mean and log-std, 2 * A values) and transform parameters
            # (mean and log-std for each of the 4 transform latents).
            n_transform_latents = 4
            n_latents = (2 * cfg.background_cfg.A, 2 * n_transform_latents)

            bg_attr, bg_transform_params = self.background_encoder(self.inp, n_latents, self.is_training)

            # --- bg attributes ---

            bg_attr_mean, bg_attr_log_std = tf.split(bg_attr, 2, axis=-1)
            bg_attr_std = self.std_nonlinearity(bg_attr_log_std)

            bg_attr, bg_attr_kl = normal_vae(bg_attr_mean, bg_attr_std, self.attr_prior_mean, self.attr_prior_std)

            # --- bg location ---

            bg_transform_params = tf.reshape(
                bg_transform_params,
                (self.batch_size, self.dynamic_n_frames, 2*n_transform_latents))

            mean, log_std = tf.split(bg_transform_params, 2, axis=2)
            std = self.std_nonlinearity(log_std)

            # Per-frame transform latents sampled against a standard normal prior.
            logits, kl = normal_vae(mean, std, 0.0, 1.0)

            # integrate across timesteps
            # (each frame's latent is the running sum of the per-frame samples)
            logits = tf.cumsum(logits, axis=1)
            logits = tf.reshape(logits, (self.batch_size*self.dynamic_n_frames, n_transform_latents))

            # Window size h, w lies in (0.5, 0.9). Offsets y, x are tanh scaled by
            # the remaining margin (presumably keeping the window inside the background).
            y, x, h, w = tf.split(logits, n_transform_latents, axis=1)
            h = (0.9 - 0.5) * tf.nn.sigmoid(h) + 0.5
            w = (0.9 - 0.5) * tf.nn.sigmoid(w) + 0.5
            y = (1 - h) * tf.nn.tanh(y)
            x = (1 - w) * tf.nn.tanh(x)

            # --- decode ---

            # Decode one background image per batch element, crop it to bg_shape,
            # and squash it to (0, 1).
            background = self.background_decoder(bg_attr, self.image_depth, self.is_training)
            bg_shape = cfg.background_cfg.bg_shape
            background = background[:, :bg_shape[0], :bg_shape[1], :]
            assert background.shape[1:3] == bg_shape
            background_raw = tf.nn.sigmoid(tf.clip_by_value(background, -10, 10))

            transform_constraints = snt.AffineWarpConstraints.no_shear_2d()

            # Sampling grids of image size over a bg_shape source image.
            warper = snt.AffineGridWarper(
                bg_shape, (self.image_height, self.image_width), transform_constraints)

            # (x scale, x shift, y scale, y shift), one row per (batch, frame).
            transforms = tf.concat([w, x, h, y], axis=-1)
            grid_coords = warper(transforms)

            # One grid per frame, so each frame samples its own window of background_raw.
            grid_coords = tf.reshape(
                grid_coords,
                (self.batch_size, self.dynamic_n_frames, *tf_shape(grid_coords)[1:]))

            background = tf.contrib.resampler.resampler(background_raw, grid_coords)

            self._tensors.update(
                bg_attr_mean=bg_attr_mean,
                bg_attr_std=bg_attr_std,
                bg_attr_kl=bg_attr_kl,
                bg_attr=bg_attr,
                bg_y=tf.reshape(y, (self.batch_size, self.dynamic_n_frames, 1)),
                bg_x=tf.reshape(x, (self.batch_size, self.dynamic_n_frames, 1)),
                bg_h=tf.reshape(h, (self.batch_size, self.dynamic_n_frames, 1)),
                bg_w=tf.reshape(w, (self.batch_size, self.dynamic_n_frames, 1)),
                bg_transform_kl=kl,
                bg_raw=background_raw,
            )

        elif cfg.background_cfg.mode == "data":
            background = self._tensors["ground_truth_background"]

        else:
            raise Exception("Unrecognized background mode: {}.".format(cfg.background_cfg.mode))

        # Keep only the active frames (axis 1 is time).
        self._tensors["background"] = background[:, :self.dynamic_n_frames]