def store_scalar_summaries(self, mode, path, record, n_global_experiences):
    if mode not in self.writers:
        self.writers[mode] = SummaryWriter(path)

    for k, v in AttrDict(record).flatten().items():
        self.writers[mode].add_scalar("all/" + k, float(v), n_global_experiences)
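# Hedged example (not from the source): assuming AttrDict.flatten joins nested keys
# with a separator such as '.', a nested record like
#     {'loss': {'kl': 0.1, 'recon': 0.5}, 'accuracy': 0.9}
# would be logged as the scalars "all/loss.kl", "all/loss.recon" and "all/accuracy",
# all indexed by n_global_experiences.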
def compute_required_resources(n_tasks, tasks_per_gpu, gpu_kind):
    host_name = os.uname().nodename

    aliases = dict(
        beluga='beluga blg'.split(),
        cedar='cedar cdr'.split(),
    )

    for _host_name, _aliases in aliases.items():
        if any(host_name.startswith(a) for a in _aliases):
            host_name = _host_name
            break
    else:
        raise Exception(f"Unknown host: {host_name}")

    specs = AttrDict(
        beluga=dict(cpus_per_gpu=10, mem_per_gpu=47000, gpus_per_node=4),
        cedar=dict(cpus_per_gpu=6, mem_per_gpu=128000 // 4, gpus_per_node=4),
        cedar_p100=dict(cpus_per_gpu=6, mem_per_gpu=128000 // 4, gpus_per_node=4),
        cedar_p100l=dict(cpus_per_gpu=6, mem_per_gpu=257000 // 4, gpus_per_node=4),
        cedar_v100l=dict(cpus_per_gpu=8, mem_per_gpu=192000 // 4, gpus_per_node=4),
    )

    spec_key = host_name + ('' if gpu_kind is None else '_' + gpu_kind)
    spec = specs[spec_key]

    cpus_per_gpu = spec.cpus_per_gpu
    mem_per_gpu = spec.mem_per_gpu
    gpus_per_node = spec.gpus_per_node

    n_gpus = int(np.ceil(n_tasks / tasks_per_gpu))
    n_nodes = int(np.ceil(n_gpus / gpus_per_node))

    result = dict(
        n_nodes=n_nodes,
        cpus_per_task=cpus_per_gpu // tasks_per_gpu,
        mem_per_cpu=mem_per_gpu // cpus_per_gpu,
    )

    if n_nodes > 1:
        result.update(
            tasks_per_node=gpus_per_node * tasks_per_gpu,
            gpu_set=",".join(str(i) for i in range(gpus_per_node)))
    else:
        result.update(
            tasks_per_node=n_tasks,
            gpu_set=",".join(str(i) for i in range(n_gpus)))

    return result
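# Hedged usage sketch (illustrative only; assumes the function is run on a Beluga
# node, i.e. os.uname().nodename starts with "beluga" or "blg"):
#
#   compute_required_resources(n_tasks=8, tasks_per_gpu=2, gpu_kind=None)
#   # n_gpus = ceil(8 / 2) = 4, n_nodes = ceil(4 / 4) = 1, giving
#   # {'n_nodes': 1, 'cpus_per_task': 5, 'mem_per_cpu': 4700,
#   #  'tasks_per_node': 8, 'gpu_set': '0,1,2,3'}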
def null_object_set(self, batch_size):
    n_prop_objects = self.n_prop_objects

    new_objects = AttrDict(
        normalized_box=tf.zeros((batch_size, n_prop_objects, 4)),
        attr=tf.zeros((batch_size, n_prop_objects, self.A)),
        z=tf.zeros((batch_size, n_prop_objects, 1)),
        obj=tf.zeros((batch_size, n_prop_objects, 1)),
    )

    yt, xt, ys, xs = tf.split(new_objects.normalized_box, 4, axis=-1)

    new_objects.update(
        abs_posn=new_objects.normalized_box[..., :2] + 0.0,
        yt=yt,
        xt=xt,
        ys=ys,
        xs=xs,
        ys_logit=ys + 0.0,
        xs_logit=xs + 0.0,
        # d_yt=yt + 0.0,
        # d_xt=xt + 0.0,
        # d_ys=ys + 0.0,
        # d_xs=xs + 0.0,
        # d_attr=new_objects.attr + 0.0,
        # d_z=new_objects.z + 0.0,
        z_logit=new_objects.z + 0.0,
    )

    prop_state = self.cell.initial_state(batch_size * n_prop_objects, tf.float32)
    trailing_shape = tf_shape(prop_state)[1:]
    new_objects.prop_state = tf.reshape(
        prop_state, (batch_size, n_prop_objects, *trailing_shape))
    new_objects.prior_prop_state = new_objects.prop_state

    return new_objects
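# Hedged shape summary (inferred from the construction above; B = batch_size,
# P = n_prop_objects, A = self.A):
#   normalized_box: (B, P, 4)    attr: (B, P, A)    z, obj: (B, P, 1)
#   yt, xt, ys, xs, ys_logit, xs_logit, z_logit: (B, P, 1)
#   prop_state, prior_prop_state: (B, P, *cell_state_trailing_shape)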
def __call__(self, updater): fetched = self._fetch(updater) fetched = Config(fetched) self._prepare_fetched(updater, fetched) o = AttrDict(**fetched) N, T, image_height, image_width, _ = o.inp.shape # --- static --- fig_width = 2 * N fig_height = T figsize = self.fig_scale * np.asarray((fig_width, fig_height)) fig, axes = plt.subplots(fig_height, fig_width, figsize=figsize) fig.suptitle("n_updates={}".format(updater.n_updates), fontsize=20, fontweight='bold') axes = axes.reshape((fig_height, fig_width)) unique_ids = [int(i) for i in np.unique(o.obj_id)] if unique_ids[0] < 0: unique_ids = unique_ids[1:] color_by_id = {i: c for i, c in zip(unique_ids, itertools.cycle(self._BBOX_COLORS))} color_by_id[-1] = 'k' cmap = self._cmap(o.inp) for t, ax in enumerate(axes): for n in range(N): pres_time = o.presence[n, t, :] obj_id_time = o.obj_id[n, t, :] self.imshow(ax[2 * n], o.inp[n, t], cmap=cmap) n_obj = str(int(np.round(pres_time.sum()))) id_string = ('{}{}'.format(color_by_id[int(i)], i) for i in o.obj_id[n, t] if i > -1) id_string = ', '.join(id_string) title = '{}: {}'.format(n_obj, id_string) ax[2 * n + 1].set_title(title, fontsize=6 * self.fig_scale) self.imshow(ax[2 * n + 1], o.canvas[n, t], cmap=cmap) for i, (p, o_id) in enumerate(zip(pres_time, obj_id_time)): c = color_by_id[int(o_id)] if p > .5: r = patches.Rectangle( (o.left[n, t, i], o.top[n, t, i]), o.width[n, t, i], o.height[n, t, i], linewidth=self.linewidth, edgecolor=c, facecolor='none') ax[2 * n + 1].add_patch(r) for n in range(N): axes[0, 2 * n].set_ylabel('gt #{:d}'.format(n)) axes[0, 2 * n + 1].set_ylabel('rec #{:d}'.format(n)) for ax in axes.flatten(): # ax.grid(False) # ax.set_xticks([]) # ax.set_yticks([]) ax.set_axis_off() self.savefig('static', fig, updater) # --- moving --- fig_width = 2 * N n_objects = o.obj_id.shape[2] fig_height = n_objects + 2 figsize = self.fig_scale * np.asarray((fig_width, fig_height)) fig, axes = plt.subplots(fig_height, fig_width, figsize=figsize) title_text = fig.suptitle('', fontsize=10) axes = axes.reshape((fig_height, fig_width)) def func(t): title_text.set_text("t={}, n_updates={}".format(t, updater.n_updates)) for i in range(N): self.imshow(axes[0, 2*i], o.inp[i, t], cmap=cmap, vmin=0, vmax=1) self.imshow(axes[1, 2*i], o.canvas[i, t], cmap=cmap, vmin=0, vmax=1) for j in range(n_objects): if o.presence[i, t, j] > .5: c = color_by_id[int(o.obj_id[i, t, j])] r = patches.Rectangle( (o.left[i, t, j], o.top[i, t, j]), o.width[i, t, j], o.height[i, t, j], linewidth=self.linewidth, edgecolor=c, facecolor='none') axes[1, 2*i].add_patch(r) ax = axes[2+j, 2*i] self.imshow(ax, o.presence[i, t, j] * o.glimpse[i, t, j], cmap=cmap) title = '{:d} with p({:d}) = {:.02f}, id = {}'.format( int(o.presence[i, t, j]), i + 1, o.presence_prob[i, t, j], o.obj_id[i, t, j]) ax.set_title(title, fontsize=4 * self.fig_scale) if o.presence[i, t, j] > .5: c = color_by_id[int(o.obj_id[i, t, j])] for spine in 'bottom top left right'.split(): ax.spines[spine].set_color(c) ax.spines[spine].set_linewidth(2.) 
for ax in axes.flatten(): ax.xaxis.set_ticks([]) ax.yaxis.set_ticks([]) axes[0, 0].set_ylabel('ground-truth') axes[1, 0].set_ylabel('reconstruction') for j in range(n_objects): axes[j+2, 0].set_ylabel('glimpse #{}'.format(j + 1)) plt.subplots_adjust(left=0.02, right=.98, top=.95, bottom=0.02, wspace=0.1, hspace=0.15) anim = animation.FuncAnimation(fig, func, frames=T, interval=500) path = self.path_for('moving', updater, ext="mp4") anim.save(path, writer='ffmpeg', codec='hevc', extra_args=['-preset', 'ultrafast']) plt.close(fig) shutil.copyfile( path, os.path.join( os.path.dirname(path), 'latest_stage{:0>4}.mp4'.format(updater.stage_idx)))
class SQAIRUpdater(_Updater): VI_TARGETS = 'iwae reinforce'.split() TARGETS = VI_TARGETS lr_schedule = Param() l2_schedule = Param() output_std = Param() k_particles = Param() debug = Param() def __init__(self, env, scope=None, **kwargs): self.lr_schedule = build_scheduled_value(self.lr_schedule, "lr") self.l2_schedule = build_scheduled_value(self.l2_schedule, "l2_weight") super().__init__(env, scope=scope, **kwargs) def resample(self, *args, axis=-1): """ Just resample, but potentially applied to several args. """ res = list(args) if self.k_particles > 1: for i, arg in enumerate(res): res[i] = self._resample(arg, axis) if len(res) == 1: res = res[0] return res def _resample(self, arg, axis=-1): iw_sample_idx = self.iw_resampling_idx + tf.range(self.batch_size) * self.k_particles resampled = index.gather_axis(arg, iw_sample_idx, axis) shape = arg.shape.as_list() shape[axis] = self.batch_size resampled.set_shape(shape) return resampled def _log_resampled(self, name): tensor = self.tensors[name + "_per_sample"] self.tensors['resampled_' + name] = self._resample(tensor) self.recorded_tensors[name] = self._imp_weighted_mean(tensor) self.tensors[name] = self.recorded_tensors[name] def _imp_weighted_mean(self, tensor): tensor = tf.reshape(tensor, (-1, self.batch_size, self.k_particles)) tensor = tf.reduce_mean(tensor, 0) return tf.reduce_mean(self.importance_weights * tensor * self.k_particles) def compute_validation_pixelwise_mean(self, data): sess = tf.get_default_session() mean = None n_points = 0 feed_dict = self.data_manager.do_val() while True: try: inp = sess.run(data["image"], feed_dict=feed_dict) except tf.errors.OutOfRangeError: break n_new = inp.shape[0] * inp.shape[1] if mean is None: mean = np.mean(inp, axis=(0, 1)) else: mean = mean * (n_points / (n_points + n_new)) + np.sum(inp, axis=(0, 1)) / (n_points + n_new) n_points += n_new return mean def _build_graph(self): self.data_manager = DataManager(datasets=self.env.datasets) self.data_manager.build_graph() if self.k_particles <= 1: raise Exception("`k_particles` must be > 1.") data = self.data_manager.iterator.get_next() data['mean_img'] = self.compute_validation_pixelwise_mean(data) self.batch_size = cfg.batch_size self.tensors = AttrDict() network_outputs = self.network(data, self.data_manager.is_training) network_tensors = network_outputs["tensors"] network_recorded_tensors = network_outputs["recorded_tensors"] network_losses = network_outputs["losses"] assert not network_losses self.tensors.update(network_tensors) self.recorded_tensors = recorded_tensors = dict(global_step=tf.train.get_or_create_global_step()) self.recorded_tensors.update(network_recorded_tensors) # --- values for training --- log_weights = tf.reduce_sum(self.tensors.log_weights_per_timestep, 0) self.log_weights = tf.reshape(log_weights, (self.batch_size, self.k_particles)) self.elbo_vae = tf.reduce_mean(self.log_weights) self.elbo_iwae_per_example = targets.iwae(self.log_weights) self.elbo_iwae = tf.reduce_mean(self.elbo_iwae_per_example) self.normalised_elbo_vae = self.elbo_vae / tf.to_float(self.network.dynamic_n_frames) self.normalised_elbo_iwae = self.elbo_iwae / tf.to_float(self.network.dynamic_n_frames) self.importance_weights = tf.stop_gradient(tf.nn.softmax(self.log_weights, -1)) self.ess = ops.ess(self.importance_weights, average=True) self.iw_distrib = tf.distributions.Categorical(probs=self.importance_weights) self.iw_resampling_idx = self.iw_distrib.sample() # --- count accuracy --- if "annotations" in data: gt_num_steps = 
tf.transpose(self.tensors.n_valid_annotations, (1, 0))[:, :, None] num_steps_per_sample = tf.reshape( self.tensors.num_steps_per_sample, (-1, self.batch_size, self.k_particles)) count_1norm = tf.abs(num_steps_per_sample - tf.cast(gt_num_steps, tf.float32)) count_error = tf.cast(count_1norm > 0.5, tf.float32) self.recorded_tensors.update( count_1norm=self._imp_weighted_mean(count_1norm), count_error=self._imp_weighted_mean(count_error), ) # --- losses --- log_probs = tf.reduce_sum(self.tensors.discrete_log_prob, 0) target = targets.vimco(self.log_weights, log_probs, self.elbo_iwae_per_example) target /= tf.to_float(self.network.dynamic_n_frames) loss_l2 = targets.l2_reg(self.l2_schedule) target += loss_l2 # --- train op --- tvars = tf.trainable_variables() pure_gradients = tf.gradients(target, tvars) clipped_gradients = pure_gradients if self.max_grad_norm is not None and self.max_grad_norm > 0.0: clipped_gradients, _ = tf.clip_by_global_norm(pure_gradients, self.max_grad_norm) grads_and_vars = list(zip(clipped_gradients, tvars)) lr = self.lr_schedule valid_lr = tf.Assert( tf.logical_and(tf.less(lr, 1.0), tf.less(0.0, lr)), [lr], name="valid_learning_rate") opt = tf.train.RMSPropOptimizer(self.lr_schedule, momentum=.9) with tf.control_dependencies([valid_lr]): self.train_op = opt.apply_gradients(grads_and_vars, global_step=None) recorded_tensors.update( grad_norm_pure=tf.global_norm(pure_gradients), grad_norm_processed=tf.global_norm(clipped_gradients), grad_lr_norm=lr * tf.global_norm(clipped_gradients), ) # gvs = opt.compute_gradients(target) # assert len(gvs) == len(tf.trainable_variables()) # for g, v in gvs: # assert g is not None, 'Gradient for variable {} is None'.format(v) # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # with tf.control_dependencies(update_ops): # self.train_op = opt.apply_gradients(gvs) # --- record --- self.tensors['mse_per_sample'] = tf.reduce_mean( (self.tensors['processed_image'] - self.tensors['canvas']) ** 2, (0, 2, 3, 4)) self.raw_mse = tf.reduce_mean(self.tensors['mse_per_sample']) self._log_resampled('mse') self._log_resampled('data_ll') self._log_resampled('log_p_z') self._log_resampled('log_q_z_given_x') self._log_resampled('kl') try: self._log_resampled('num_steps') self._log_resampled('num_disc_steps') self._log_resampled('num_prop_steps') except AttributeError: pass recorded_tensors.update( raw_mse=self.raw_mse, elbo_vae=self.elbo_vae, elbo_iwae=self.elbo_iwae, normalised_elbo_vae=self.normalised_elbo_vae, normalised_elbo_iwae=self.normalised_elbo_iwae, ess=self.ess, loss=target, target=target, loss_l2=loss_l2, ) self.train_records = {} # --- recorded values --- intersection = recorded_tensors.keys() & self.network.eval_funcs.keys() assert not intersection, "Key sets have non-zero intersection: {}".format(intersection) # --- for rendering and eval --- resampled_names = ( 'obj_id canvas glimpse presence_prob presence presence_logit ' 'where_coords num_steps_per_sample'.split()) for name in resampled_names: try: resampled_tensor = self.resample(self.tensors[name], axis=1) permutation = [1, 0] + list(range(2, len(resampled_tensor.shape))) self.tensors['resampled_' + name] = tf.transpose(resampled_tensor, permutation) except AttributeError: pass # For running functions, during evaluation, that are not implemented in tensorflow self.evaluator = Evaluator(self.network.eval_funcs, self.tensors, self)
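# Hedged sketch (an assumption about what `targets.iwae` computes, not taken from
# the sqair codebase): the per-example IWAE bound is usually the log-mean-exp of the
# per-particle log weights, i.e. logsumexp over the particle axis minus log(K).
# Minimal numpy reference version:
import numpy as np
from scipy.special import logsumexp

def iwae_per_example(log_weights):
    """log_weights: array of shape (batch_size, k_particles)."""
    k_particles = log_weights.shape[-1]
    return logsumexp(log_weights, axis=-1) - np.log(k_particles)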
def get_eval_tensors(self, step, mode="val", data_exclude=None, tensors_exclude=None): """ Run `self.model` on either val or test dataset, return data and tensors. """ assert mode in "val test".split() if tensors_exclude is None: tensors_exclude = [] if isinstance(tensors_exclude, str): tensors_exclude = tensors_exclude.split() if data_exclude is None: data_exclude = [] if isinstance(data_exclude, str): data_exclude = data_exclude.split() self.model.eval() if mode == 'val': data_iterator = self.data_manager.do_val() elif mode == 'test': data_iterator = self.data_manager.do_test() else: raise Exception("Unknown data mode: {}".format(mode)) _tensors = [] _data = [] n_points = 0 n_batches = 0 record = None with torch.no_grad(): for data in data_iterator: data = AttrDict(data) tensors, data, recorded_tensors, losses = self.model(data, step) losses = AttrDict(losses) loss = 0.0 for name, tensor in losses.flatten().items(): loss += tensor recorded_tensors['loss_' + name] = tensor recorded_tensors['loss'] = loss recorded_tensors = map_structure( lambda t: to_np(t.mean()) if isinstance(t, (torch.Tensor, np.ndarray)) else t, recorded_tensors, is_leaf=lambda rec: not isinstance(rec, dict)) batch_size = recorded_tensors['batch_size'] n_points += batch_size n_batches += 1 if record is None: record = recorded_tensors else: record = map_structure( lambda rec, rec_t: rec + batch_size * np.mean(rec_t), record, recorded_tensors, is_leaf=lambda rec: not isinstance(rec, dict)) data = AttrDict(data) for de in data_exclude: try: del data[de] except (KeyError, AttributeError): pass data = map_structure( lambda t: to_np(t) if isinstance(t, torch.Tensor) else t, data, is_leaf=lambda rec: not isinstance(rec, dict)) tensors = AttrDict(tensors) for te in tensors_exclude: try: del tensors[te] except (KeyError, AttributeError): pass tensors = map_structure( lambda t: to_np(t) if isinstance(t, torch.Tensor) else t, tensors, is_leaf=lambda rec: not isinstance(rec, dict)) _tensors.append(tensors) _data.append(data) def postprocess(*t): return pad_and_concatenate(t, axis=0) _tensors = map_structure(postprocess, *_tensors, is_leaf=lambda rec: not isinstance(rec, dict)) _data = map_structure(postprocess, *_data, is_leaf=lambda rec: not isinstance(rec, dict)) record = map_structure( lambda rec: rec / n_points, record, is_leaf=lambda rec: not isinstance(rec, dict)) record = AttrDict(record) return _data, _tensors, record
def update(self, batch_size, step): print_time = step % 100 == 0 self.model.train() data = AttrDict(next(self.train_iterator)) self.model.update_global_step(step) detect_grad_anomalies = cfg.get('detect_grad_anomalies', False) with torch.autograd.set_detect_anomaly(detect_grad_anomalies): profile_step = cfg.get('pytorch_profile_step', 0) if profile_step > 0 and step % profile_step == 0: with torch.autograd.profiler.profile(use_cuda=True) as prof: tensors, data, recorded_tensors, losses = self.model(data, step) print(prof) else: with timed_block('forward', print_time): tensors, data, recorded_tensors, losses = self.model(data, step) # --- loss --- losses = AttrDict(losses) loss = 0.0 for name, tensor in losses.flatten().items(): loss += tensor recorded_tensors['loss_' + name] = tensor recorded_tensors['loss'] = loss with timed_block('zero_grad', print_time): # Apparently this is faster, according to https://www.youtube.com/watch?v=9mS1fIYj1So, 10:37 for param in self.model.parameters(): param.grad = None # self.optimizer.zero_grad() with timed_block('loss backward', print_time): loss.backward() with timed_block('process grad', print_time): if self.grad_norm_recorder is not None: self.grad_norm_recorder.update() if step % self.print_grad_norm_step == 0: self.grad_norm_recorder.display() parameters = list(self.model.parameters()) pure_grad_norm = grad_norm(parameters) if self.max_grad_norm is not None and self.max_grad_norm > 0.0: torch.nn.utils.clip_grad_norm_(parameters, self.max_grad_norm) clipped_grad_norm = grad_norm(parameters) with timed_block('optimizer step', print_time): self.optimizer.step() if self.scheduler is not None: self.scheduler.step() update_result = self._update(batch_size) if isinstance(update_result, dict): recorded_tensors.update(update_result) self._n_experiences += batch_size recorded_tensors.update( grad_norm_pure=pure_grad_norm, grad_norm_clipped=clipped_grad_norm ) scheduled_values = self.model.get_scheduled_values() recorded_tensors.update(scheduled_values) recorded_tensors = map_structure( lambda t: t.mean() if isinstance(t, torch.Tensor) else t, recorded_tensors, is_leaf=lambda rec: not isinstance(rec, dict)) return recorded_tensors
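# Hedged sketch (the real `grad_norm` helper lives elsewhere in the repo; this is an
# assumption about what it returns): the global L2 norm over all parameter gradients,
# i.e. the same quantity torch.nn.utils.clip_grad_norm_ measures before clipping.
import torch

def grad_norm_sketch(parameters):
    norms = [p.grad.detach().norm(2) for p in parameters if p.grad is not None]
    if not norms:
        return 0.0
    return torch.norm(torch.stack(norms), 2).item()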
def _body(self, inp, features, objects, is_posterior): """ Summary of how updates are done for the different variables: glimpse': glimpse_params = where + 0.1 * predicted_logit where_y/x: new_where_y/x = where_y/x + where_t_scale * tanh(predicted_logit) where_h/w: new_where_h/w_logit = where_h/w_logit + predicted_logit what: Hard to summarize here, taken from SQAIR. Kind of like an LSTM. depth: new_depth_logit = depth_logit + predicted_logit obj: new_obj = obj * sigmoid(predicted_logit) """ batch_size, n_objects, _ = tf_shape(features) new_objects = AttrDict() is_posterior_tf = tf.ones_like(features[..., 0:2]) if is_posterior: is_posterior_tf = is_posterior_tf * [1, 0] else: is_posterior_tf = is_posterior_tf * [0, 1] base_features = tf.concat([features, is_posterior_tf], axis=-1) cyt, cxt, ys, xs = tf.split(objects.normalized_box, 4, axis=-1) # Do this regardless of is_posterior, otherwise ScopedFunction gets messed up glimpse_dim = self.object_shape[0] * self.object_shape[1] glimpse_prime_params = apply_object_wise(self.glimpse_prime_network, base_features, output_size=4 + 2 * glimpse_dim, is_training=self.is_training) glimpse_prime_params, glimpse_prime_mask_logit, glimpse_mask_logit = \ tf.split(glimpse_prime_params, [4, glimpse_dim, glimpse_dim], axis=-1) if is_posterior: # --- obtain final parameters for glimpse prime by modifying current pose --- _yt, _xt, _ys, _xs = tf.split(glimpse_prime_params, 4, axis=-1) g_yt = cyt + 0.1 * _yt g_xt = cxt + 0.1 * _xt g_ys = ys + 0.1 * _ys g_xs = xs + 0.1 * _xs # --- extract glimpse prime --- _, image_height, image_width, _ = tf_shape(inp) g_yt, g_xt, g_ys, g_xs = coords_to_image_space( g_yt, g_xt, g_ys, g_xs, (image_height, image_width), self.anchor_box, top_left=False) glimpse_prime = extract_affine_glimpse(inp, self.object_shape, g_yt, g_xt, g_ys, g_xs, self.edge_resampler) else: g_yt = tf.zeros_like(cyt) g_xt = tf.zeros_like(cxt) g_ys = tf.zeros_like(ys) g_xs = tf.zeros_like(xs) glimpse_prime = tf.zeros( (batch_size, n_objects, *self.object_shape, self.image_depth)) glimpse_prime_mask = tf.nn.sigmoid(glimpse_prime_mask_logit + 1.) leading_mask_shape = tf_shape(glimpse_prime)[:-1] glimpse_prime_mask = tf.reshape(glimpse_prime_mask, (*leading_mask_shape, 1)) new_objects.update( glimpse_prime_box=tf.concat([g_yt, g_xt, g_ys, g_xs], axis=-1), glimpse_prime=glimpse_prime, glimpse_prime_mask=glimpse_prime_mask, ) glimpse_prime *= glimpse_prime_mask # --- encode glimpse --- encoded_glimpse_prime = apply_object_wise(self.glimpse_prime_encoder, glimpse_prime, n_trailing_dims=3, output_size=self.A, is_training=self.is_training) if not is_posterior: encoded_glimpse_prime = tf.zeros((batch_size, n_objects, self.A), dtype=tf.float32) # --- position and scale --- # roughly: # base_features == temporal_state, encoded_glimpse_prime == hidden_output # hidden_output conditions on encoded_glimpse, and that's the only place encoded_glimpse_prime is used. # Here SQAIR conditions on the actual location values from the previous timestep, but we leave that out for now. 
d_box_inp = tf.concat([base_features, encoded_glimpse_prime], axis=-1) d_box_params = apply_object_wise(self.d_box_network, d_box_inp, output_size=8, is_training=self.is_training) d_box_mean, d_box_log_std = tf.split(d_box_params, 2, axis=-1) d_box_std = self.std_nonlinearity(d_box_log_std) d_box_mean = self.training_wheels * tf.stop_gradient(d_box_mean) + ( 1 - self.training_wheels) * d_box_mean d_box_std = self.training_wheels * tf.stop_gradient(d_box_std) + ( 1 - self.training_wheels) * d_box_std d_yt_mean, d_xt_mean, d_ys, d_xs = tf.split(d_box_mean, 4, axis=-1) d_yt_std, d_xt_std, ys_std, xs_std = tf.split(d_box_std, 4, axis=-1) # --- position --- # We predict position a bit differently from scale. For scale we want to put a prior on the actual value of # the scale, whereas for position we want to put a prior on the difference in position over timesteps. d_yt_logit = Normal(loc=d_yt_mean, scale=d_yt_std).sample() d_xt_logit = Normal(loc=d_xt_mean, scale=d_xt_std).sample() d_yt = self.where_t_scale * tf.nn.tanh(d_yt_logit) d_xt = self.where_t_scale * tf.nn.tanh(d_xt_logit) new_cyt = cyt + d_yt new_cxt = cxt + d_xt new_abs_posn = objects.abs_posn + tf.concat([d_yt, d_xt], axis=-1) # --- scale --- new_ys_mean = objects.ys_logit + d_ys new_xs_mean = objects.xs_logit + d_xs new_ys_logit = Normal(loc=new_ys_mean, scale=ys_std).sample() new_xs_logit = Normal(loc=new_xs_mean, scale=xs_std).sample() new_ys = float(self.max_hw - self.min_hw) * tf.nn.sigmoid( tf.clip_by_value(new_ys_logit, -10, 10)) + self.min_hw new_xs = float(self.max_hw - self.min_hw) * tf.nn.sigmoid( tf.clip_by_value(new_xs_logit, -10, 10)) + self.min_hw if self.use_abs_posn: box_params = tf.concat([ new_abs_posn, d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit ], axis=-1) else: box_params = tf.concat( [d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit], axis=-1) new_objects.update( abs_posn=new_abs_posn, yt=new_cyt, xt=new_cxt, ys=new_ys, xs=new_xs, normalized_box=tf.concat([new_cyt, new_cxt, new_ys, new_xs], axis=-1), d_yt_logit=d_yt_logit, d_xt_logit=d_xt_logit, ys_logit=new_ys_logit, xs_logit=new_xs_logit, d_yt_logit_mean=d_yt_mean, d_xt_logit_mean=d_xt_mean, ys_logit_mean=new_ys_mean, xs_logit_mean=new_xs_mean, d_yt_logit_std=d_yt_std, d_xt_logit_std=d_xt_std, ys_logit_std=ys_std, xs_logit_std=xs_std, ) # --- attributes --- # --- extract a glimpse using new box --- if is_posterior: _, image_height, image_width, _ = tf_shape(inp) _new_cyt, _new_cxt, _new_ys, _new_xs = coords_to_image_space( new_cyt, new_cxt, new_ys, new_xs, (image_height, image_width), self.anchor_box, top_left=False) glimpse = extract_affine_glimpse(inp, self.object_shape, _new_cyt, _new_cxt, _new_ys, _new_xs, self.edge_resampler) else: glimpse = tf.zeros( (batch_size, n_objects, *self.object_shape, self.image_depth)) glimpse_mask = tf.nn.sigmoid(glimpse_mask_logit + 1.) leading_mask_shape = tf_shape(glimpse)[:-1] glimpse_mask = tf.reshape(glimpse_mask, (*leading_mask_shape, 1)) glimpse *= glimpse_mask encoded_glimpse = apply_object_wise(self.glimpse_encoder, glimpse, n_trailing_dims=3, output_size=self.A, is_training=self.is_training) if not is_posterior: encoded_glimpse = tf.zeros((batch_size, n_objects, self.A), dtype=tf.float32) # --- predict change in attributes --- # so under sqair we mix between three different values for the attributes: # 1. value from previous timestep # 2. value predicted directly from glimpse # 3. 
value predicted based on update of temporal cell...this update conditions on hidden_output, # the prediction in #2., and the where values. # How to do this given that we are predicting the change in attr? We could just directly predict # the attr instead, but call it d_attr. After all, it is in this function that we control # whether d_attr is added to attr. # So, make a prediction based on just the input: attr_from_inp = apply_object_wise(self.predict_attr_inp, encoded_glimpse, output_size=2 * self.A, is_training=self.is_training) attr_from_inp_mean, attr_from_inp_log_std = tf.split(attr_from_inp, [self.A, self.A], axis=-1) attr_from_inp_std = self.std_nonlinearity(attr_from_inp_log_std) # And then a prediction which takes the past into account (predicting gate values at the same time): attr_from_temp_inp = tf.concat( [base_features, box_params, encoded_glimpse], axis=-1) attr_from_temp = apply_object_wise(self.predict_attr_temp, attr_from_temp_inp, output_size=5 * self.A, is_training=self.is_training) (attr_from_temp_mean, attr_from_temp_log_std, f_gate_logit, i_gate_logit, t_gate_logit) = tf.split(attr_from_temp, 5, axis=-1) attr_from_temp_std = self.std_nonlinearity(attr_from_temp_log_std) # bias the gates f_gate = tf.nn.sigmoid(f_gate_logit + 1) * .9999 i_gate = tf.nn.sigmoid(i_gate_logit + 1) * .9999 t_gate = tf.nn.sigmoid(t_gate_logit + 1) * .9999 attr_mean = f_gate * objects.attr + ( 1 - i_gate) * attr_from_inp_mean + (1 - t_gate) * attr_from_temp_mean attr_std = (1 - i_gate) * attr_from_inp_std + ( 1 - t_gate) * attr_from_temp_std new_attr = Normal(loc=attr_mean, scale=attr_std).sample() # --- apply change in attributes --- new_objects.update( attr=new_attr, d_attr=new_attr - objects.attr, d_attr_mean=attr_mean - objects.attr, d_attr_std=attr_std, f_gate=f_gate, i_gate=i_gate, t_gate=t_gate, glimpse=glimpse, glimpse_mask=glimpse_mask, ) # --- z --- d_z_inp = tf.concat( [base_features, box_params, new_attr, encoded_glimpse], axis=-1) d_z_params = apply_object_wise(self.d_z_network, d_z_inp, output_size=2, is_training=self.is_training) d_z_mean, d_z_log_std = tf.split(d_z_params, 2, axis=-1) d_z_std = self.std_nonlinearity(d_z_log_std) d_z_mean = self.training_wheels * tf.stop_gradient(d_z_mean) + ( 1 - self.training_wheels) * d_z_mean d_z_std = self.training_wheels * tf.stop_gradient(d_z_std) + ( 1 - self.training_wheels) * d_z_std d_z_logit = Normal(loc=d_z_mean, scale=d_z_std).sample() new_z_logit = objects.z_logit + d_z_logit new_z = self.z_nonlinearity(new_z_logit) new_objects.update( z=new_z, z_logit=new_z_logit, d_z_logit=d_z_logit, d_z_logit_mean=d_z_mean, d_z_logit_std=d_z_std, ) # --- obj --- d_obj_inp = tf.concat( [base_features, box_params, new_attr, new_z, encoded_glimpse], axis=-1) d_obj_logit = apply_object_wise(self.d_obj_network, d_obj_inp, output_size=1, is_training=self.is_training) d_obj_logit = self.training_wheels * tf.stop_gradient(d_obj_logit) + ( 1 - self.training_wheels) * d_obj_logit d_obj_log_odds = tf.clip_by_value(d_obj_logit / self.obj_temp, -10., 10.) 
d_obj_pre_sigmoid = (self._noisy * concrete_binary_pre_sigmoid_sample( d_obj_log_odds, self.obj_concrete_temp) + (1 - self._noisy) * d_obj_log_odds) d_obj = tf.nn.sigmoid(d_obj_pre_sigmoid) new_obj = objects.obj * d_obj new_objects.update( d_obj_log_odds=d_obj_log_odds, d_obj_prob=tf.nn.sigmoid(d_obj_log_odds), d_obj_pre_sigmoid=d_obj_pre_sigmoid, d_obj=d_obj, obj=new_obj, ) # --- update each object's hidden state -- cell_input = tf.concat([box_params, new_attr, new_z, new_obj], axis=-1) if is_posterior: _, new_objects.prop_state = apply_object_wise( self.cell, cell_input, objects.prop_state) new_objects.prior_prop_state = new_objects.prop_state else: _, new_objects.prior_prop_state = apply_object_wise( self.cell, cell_input, objects.prior_prop_state) new_objects.prop_state = new_objects.prior_prop_state return new_objects
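# Hedged sketch (an assumption about `concrete_binary_pre_sigmoid_sample`, which is
# defined elsewhere in the repo): the standard binary-Concrete relaxation draws
# u ~ Uniform(0, 1) and returns (log_odds + log u - log(1 - u)) / temperature; the
# sigmoid applied above then yields a relaxed Bernoulli sample.
import tensorflow as tf

def concrete_binary_pre_sigmoid_sample_sketch(log_odds, temperature, eps=1e-8):
    u = tf.random_uniform(tf.shape(log_odds), minval=0.0, maxval=1.0)
    logistic_noise = tf.log(u + eps) - tf.log(1.0 - u + eps)
    return (log_odds + logistic_noise) / temperature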
def _body(self, inp, features, objects, is_posterior): batch_size, n_objects, _ = tf_shape(features) new_objects = AttrDict() is_posterior_tf = tf.ones_like(features[..., 0:2]) if is_posterior: is_posterior_tf = is_posterior_tf * [1, 0] else: is_posterior_tf = is_posterior_tf * [0, 1] base_features = tf.concat([features, is_posterior_tf], axis=-1) cyt, cxt, ys, xs = tf.split(objects.normalized_box, 4, axis=-1) if self.learn_glimpse_prime: # Do this regardless of is_posterior, otherwise ScopedFunction gets messed up glimpse_prime_params = apply_object_wise( self.glimpse_prime_network, base_features, output_size=4, is_training=self.is_training) else: glimpse_prime_params = tf.zeros_like(base_features[..., :4]) if is_posterior: if self.learn_glimpse_prime: # --- obtain final parameters for glimpse prime by modifying current pose --- _yt, _xt, _ys, _xs = tf.split(glimpse_prime_params, 4, axis=-1) # This is how it is done in SQAIR g_yt = cyt + 0.1 * _yt g_xt = cxt + 0.1 * _xt g_ys = ys + 0.1 * _ys g_xs = xs + 0.1 * _xs # g_yt = cyt + self.glimpse_prime_scale * tf.nn.tanh(_yt) # g_xt = cxt + self.glimpse_prime_scale * tf.nn.tanh(_xt) # g_ys = ys + self.glimpse_prime_scale * tf.nn.tanh(_ys) # g_xs = xs + self.glimpse_prime_scale * tf.nn.tanh(_xs) else: g_yt = cyt g_xt = cxt g_ys = self.glimpse_prime_scale * ys g_xs = self.glimpse_prime_scale * xs # --- extract glimpse prime --- _, image_height, image_width, _ = tf_shape(inp) g_yt, g_xt, g_ys, g_xs = coords_to_image_space( g_yt, g_xt, g_ys, g_xs, (image_height, image_width), self.anchor_box, top_left=False) glimpse_prime = extract_affine_glimpse(inp, self.object_shape, g_yt, g_xt, g_ys, g_xs, self.edge_resampler) else: g_yt = tf.zeros_like(cyt) g_xt = tf.zeros_like(cxt) g_ys = tf.zeros_like(ys) g_xs = tf.zeros_like(xs) glimpse_prime = tf.zeros( (batch_size, n_objects, *self.object_shape, self.image_depth)) new_objects.update(glimpse_prime_box=tf.concat( [g_yt, g_xt, g_ys, g_xs], axis=-1), ) # --- encode glimpse --- encoded_glimpse_prime = apply_object_wise(self.glimpse_prime_encoder, glimpse_prime, n_trailing_dims=3, output_size=self.A, is_training=self.is_training) if not is_posterior: encoded_glimpse_prime = tf.zeros((batch_size, n_objects, self.A), dtype=tf.float32) # --- position and scale --- d_box_inp = tf.concat([base_features, encoded_glimpse_prime], axis=-1) d_box_params = apply_object_wise(self.d_box_network, d_box_inp, output_size=8, is_training=self.is_training) d_box_mean, d_box_log_std = tf.split(d_box_params, 2, axis=-1) d_box_std = self.std_nonlinearity(d_box_log_std) d_box_mean = self.training_wheels * tf.stop_gradient(d_box_mean) + ( 1 - self.training_wheels) * d_box_mean d_box_std = self.training_wheels * tf.stop_gradient(d_box_std) + ( 1 - self.training_wheels) * d_box_std d_yt_mean, d_xt_mean, d_ys, d_xs = tf.split(d_box_mean, 4, axis=-1) d_yt_std, d_xt_std, ys_std, xs_std = tf.split(d_box_std, 4, axis=-1) # --- position --- # We predict position a bit differently from scale. For scale we want to put a prior on the actual value of # the scale, whereas for position we want to put a prior on the difference in position over timesteps. 
d_yt_logit = Normal(loc=d_yt_mean, scale=d_yt_std).sample() d_xt_logit = Normal(loc=d_xt_mean, scale=d_xt_std).sample() d_yt = self.where_t_scale * tf.nn.tanh(d_yt_logit) d_xt = self.where_t_scale * tf.nn.tanh(d_xt_logit) new_cyt = cyt + d_yt new_cxt = cxt + d_xt new_abs_posn = objects.abs_posn + tf.concat([d_yt, d_xt], axis=-1) # --- scale --- new_ys_mean = objects.ys_logit + d_ys new_xs_mean = objects.xs_logit + d_xs new_ys_logit = Normal(loc=new_ys_mean, scale=ys_std).sample() new_xs_logit = Normal(loc=new_xs_mean, scale=xs_std).sample() new_ys = float(self.max_hw - self.min_hw) * tf.nn.sigmoid( tf.clip_by_value(new_ys_logit, -10, 10)) + self.min_hw new_xs = float(self.max_hw - self.min_hw) * tf.nn.sigmoid( tf.clip_by_value(new_xs_logit, -10, 10)) + self.min_hw # Used for conditioning if self.use_abs_posn: box_params = tf.concat([ new_abs_posn, d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit ], axis=-1) else: box_params = tf.concat( [d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit], axis=-1) new_objects.update( abs_posn=new_abs_posn, yt=new_cyt, xt=new_cxt, ys=new_ys, xs=new_xs, normalized_box=tf.concat([new_cyt, new_cxt, new_ys, new_xs], axis=-1), d_yt_logit=d_yt_logit, d_xt_logit=d_xt_logit, ys_logit=new_ys_logit, xs_logit=new_xs_logit, d_yt_logit_mean=d_yt_mean, d_xt_logit_mean=d_xt_mean, ys_logit_mean=new_ys_mean, xs_logit_mean=new_xs_mean, d_yt_logit_std=d_yt_std, d_xt_logit_std=d_xt_std, ys_logit_std=ys_std, xs_logit_std=xs_std, glimpse_prime=glimpse_prime, ) # --- attributes --- # --- extract a glimpse using new box --- if is_posterior: _, image_height, image_width, _ = tf_shape(inp) _new_cyt, _new_cxt, _new_ys, _new_xs = coords_to_image_space( new_cyt, new_cxt, new_ys, new_xs, (image_height, image_width), self.anchor_box, top_left=False) glimpse = extract_affine_glimpse(inp, self.object_shape, _new_cyt, _new_cxt, _new_ys, _new_xs, self.edge_resampler) else: glimpse = tf.zeros( (batch_size, n_objects, *self.object_shape, self.image_depth)) encoded_glimpse = apply_object_wise(self.glimpse_encoder, glimpse, n_trailing_dims=3, output_size=self.A, is_training=self.is_training) if not is_posterior: encoded_glimpse = tf.zeros((batch_size, n_objects, self.A), dtype=tf.float32) # --- predict change in attributes --- d_attr_inp = tf.concat([base_features, box_params, encoded_glimpse], axis=-1) d_attr_params = apply_object_wise(self.d_attr_network, d_attr_inp, output_size=2 * self.A + 1, is_training=self.is_training) d_attr_mean, d_attr_log_std, gate_logit = tf.split(d_attr_params, [self.A, self.A, 1], axis=-1) d_attr_std = self.std_nonlinearity(d_attr_log_std) gate = tf.nn.sigmoid(gate_logit) if self.gate_d_attr: d_attr_mean *= gate d_attr = Normal(loc=d_attr_mean, scale=d_attr_std).sample() # --- apply change in attributes --- new_attr = objects.attr + d_attr new_objects.update( attr=new_attr, d_attr=d_attr, d_attr_mean=d_attr_mean, d_attr_std=d_attr_std, glimpse=glimpse, d_attr_gate=gate, ) # --- z --- d_z_inp = tf.concat( [base_features, box_params, new_attr, encoded_glimpse], axis=-1) d_z_params = apply_object_wise(self.d_z_network, d_z_inp, output_size=2, is_training=self.is_training) d_z_mean, d_z_log_std = tf.split(d_z_params, 2, axis=-1) d_z_std = self.std_nonlinearity(d_z_log_std) d_z_mean = self.training_wheels * tf.stop_gradient(d_z_mean) + ( 1 - self.training_wheels) * d_z_mean d_z_std = self.training_wheels * tf.stop_gradient(d_z_std) + ( 1 - self.training_wheels) * d_z_std d_z_logit = Normal(loc=d_z_mean, scale=d_z_std).sample() new_z_logit = objects.z_logit + 
d_z_logit new_z = self.z_nonlinearity(new_z_logit) new_objects.update( z=new_z, z_logit=new_z_logit, d_z_logit=d_z_logit, d_z_logit_mean=d_z_mean, d_z_logit_std=d_z_std, ) # --- obj --- d_obj_inp = tf.concat( [base_features, box_params, new_attr, new_z, encoded_glimpse], axis=-1) d_obj_logit = apply_object_wise(self.d_obj_network, d_obj_inp, output_size=1, is_training=self.is_training) d_obj_logit = self.training_wheels * tf.stop_gradient(d_obj_logit) + ( 1 - self.training_wheels) * d_obj_logit d_obj_log_odds = tf.clip_by_value(d_obj_logit / self.obj_temp, -10., 10.) d_obj_pre_sigmoid = (self._noisy * concrete_binary_pre_sigmoid_sample( d_obj_log_odds, self.obj_concrete_temp) + (1 - self._noisy) * d_obj_log_odds) d_obj = tf.nn.sigmoid(d_obj_pre_sigmoid) new_obj = objects.obj * d_obj new_objects.update( d_obj_log_odds=d_obj_log_odds, d_obj_prob=tf.nn.sigmoid(d_obj_log_odds), d_obj_pre_sigmoid=d_obj_pre_sigmoid, d_obj=d_obj, obj=new_obj, ) # --- update each object's hidden state -- cell_input = tf.concat([box_params, new_attr, new_z, new_obj], axis=-1) if is_posterior: _, new_objects.prop_state = apply_object_wise( self.cell, cell_input, objects.prop_state) new_objects.prior_prop_state = new_objects.prop_state else: _, new_objects.prior_prop_state = apply_object_wise( self.cell, cell_input, objects.prior_prop_state) new_objects.prop_state = new_objects.prior_prop_state return new_objects
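# Hedged note on the recurring `training_wheels` pattern above (the helper name below
# is illustrative, not from the source): the convex combination leaves the forward
# value unchanged while scaling the gradient flowing into the predicting network by
# (1 - tw), so when tw is close to 1 these heads receive almost no gradient.
import tensorflow as tf

def apply_training_wheels(x, tw):
    # Forward pass: returns exactly x. Backward pass: the stop_gradient branch passes
    # no gradient, so the gradient w.r.t. the network producing x is scaled by (1 - tw).
    return tw * tf.stop_gradient(x) + (1.0 - tw) * x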
def _call(self, inp, features, objects, is_training, is_posterior): print("\n" + "-" * 10 + " PropagationLayer(is_posterior={}) ".format(is_posterior) + "-" * 10) self._build_networks() if not self.initialized: # Note this limits the re-usability of this module to images # with a fixed shape (the shape of the first image it is used on) self.image_height = int(inp.shape[-3]) self.image_width = int(inp.shape[-2]) self.image_depth = int(inp.shape[-1]) self.batch_size = tf.shape(inp)[0] self.is_training = is_training self.float_is_training = tf.to_float(is_training) if self.do_lateral: # hasn't been updated to make use of abs_posn raise Exception("NotImplemented.") batch_size, n_objects, _ = tf_shape(features) new_objects = [] for i in range(n_objects): # apply lateral to running objects with the feature vector for # the current object _features = features[:, i:i + 1, :] if i > 0: normalized_box = tf.concat( [o.normalized_box for o in new_objects], axis=1) attr = tf.concat([o.attr for o in new_objects], axis=1) z = tf.concat([o.z for o in new_objects], axis=1) obj = tf.concat([o.obj for o in new_objects], axis=1) completed_features = tf.concat( [normalized_box[:, :, 2:], attr, z, obj], axis=2) completed_locs = normalized_box[:, :, :2] current_features = tf.concat([ objects.normalized_box[:, i:i + 1, 2:], objects.attr[:, i:i + 1], objects.z[:, i:i + 1], objects.obj[:, i:i + 1] ], axis=2) current_locs = objects.normalized_box[:, i:i + 1, :2] # if i > max_completed_objects: # # top_k_indices # # squared_distances = tf.reduce_sum((completed_locs - current_locs)**2, axis=2) # # _, top_k_indices = tf.nn.top_k(squared_distances, k=max_completed_objects, sorted=False) _features = self.lateral_network(completed_locs, completed_features, current_locs, current_features, is_training) _objects = AttrDict( normalized_box=objects.normalized_box[:, i:i + 1], attr=objects.attr[:, i:i + 1], z=objects.z[:, i:i + 1], obj=objects.obj[:, i:i + 1], ) _new_objects = self._body(inp, _features, _objects, is_posterior) new_objects.append(_new_objects) _new_objects = AttrDict() for k in new_objects[0]: _new_objects[k] = tf.concat([no[k] for no in new_objects], axis=1) return _new_objects else: return self._body(inp, features, objects, is_posterior)
def extract_stage_data(self, fields=None, bare=False): """ Extract stage-by-stage data about the training runs. Parameters ---------- bare: boolean If True, only returns the data. Otherwise, additionally returns the stage-by-stage config and meta-data. Returns ------- A nested data structure containing the requested data. {param-setting-key: {(repeat, seed): (df, sc, md) where: df is a pandas DataFrame sc is a list giving the config for each stage md is a dictionary storing metadata """ stage_data = defaultdict(dict) if isinstance(fields, str): fields = fields.split() config_keys = self.dist_keys() KeyTuple = namedtuple(self.__class__.__name__ + "Key", config_keys) for exp_path in self.experiment_paths: try: exp_data = FrozenTrainingLoopData(exp_path) md = {} md['host'] = exp_data.host for k in config_keys: md[k] = exp_data.get_config_value(k) sc = [] records = [] for stage in exp_data.history: record = stage.copy() stage_config = record['stage_config'].copy() sc.append(stage_config) del record['stage_config'] record = AttrDict(record).flatten() if 'best_path' in record: del record['best_path'] if 'final_path' in record: del record['final_path'] # Fix and filter keys _record = {} for k, v in record.items(): if k.startswith("best_"): k = k[5:] if (fields and k in fields) or not fields: _record[k] = v records.append(_record) key = KeyTuple(*(exp_data.get_config_value(k) for k in config_keys)) repeat = exp_data.get_config_value("repeat") seed = exp_data.get_config_value("seed") if bare: stage_data[key][( repeat, seed)] = pd.DataFrame.from_records(records) else: stage_data[key][(repeat, seed)] = ( pd.DataFrame.from_records(records), sc, md) except Exception: print( "Exception raised while extracting stage data for path: {}" .format(exp_path)) traceback.print_exc() return stage_data
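# Hedged usage sketch (field names are illustrative; assumes the default bare=False,
# so each value is a (DataFrame, stage_configs, metadata) triple):
#
#   stage_data = experiment_data.extract_stage_data(fields="loss count_error")
#   for key, runs in stage_data.items():            # key: namedtuple of dist keys
#       for (repeat, seed), (df, stage_configs, metadata) in runs.items():
#           print(key, repeat, seed, df["loss"].iloc[-1], metadata["host"])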
def _call(self, inp, inp_features, is_training, is_posterior=True, prop_state=None): print("\n" + "-" * 10 + " ConvGridObjectLayer({}, is_posterior={}) ".format(self.name, is_posterior) + "-" * 10) # --- set up sub networks and attributes --- self.maybe_build_subnet("box_network", builder=cfg.build_conv_lateral, key="box") self.maybe_build_subnet("attr_network", builder=cfg.build_conv_lateral, key="attr") self.maybe_build_subnet("z_network", builder=cfg.build_conv_lateral, key="z") self.maybe_build_subnet("obj_network", builder=cfg.build_conv_lateral, key="obj") self.maybe_build_subnet("object_encoder") _, H, W, n_channels = tf_shape(inp_features) if self.B != 1: raise Exception("NotImplemented") if not self.initialized: # Note this limits the re-usability of this module to images # with a fixed shape (the shape of the first image it is used on) self.batch_size, self.image_height, self.image_width, self.image_depth = tf_shape(inp) self.H = H self.W = W self.HWB = H*W self.batch_size = tf.shape(inp)[0] self.is_training = is_training self.float_is_training = tf.to_float(is_training) is_posterior_tf = tf.ones_like(inp_features[..., :2]) if is_posterior: is_posterior_tf = is_posterior_tf * [1, 0] else: is_posterior_tf = is_posterior_tf * [0, 1] objects = AttrDict() base_features = tf.concat([inp_features, is_posterior_tf], axis=-1) # --- box --- layer_inp = base_features n_features = self.n_passthrough_features output_size = 8 network_output = self.box_network(layer_inp, output_size + n_features, self.is_training) rep_input, features = tf.split(network_output, (output_size, n_features), axis=-1) _objects = self._build_box(rep_input, self.is_training) objects.update(_objects) # --- attr --- if is_posterior: # --- Get object attributes using object encoder --- yt, xt, ys, xs = tf.split(objects['normalized_box'], 4, axis=-1) yt, xt, ys, xs = coords_to_image_space( yt, xt, ys, xs, (self.image_height, self.image_width), self.anchor_box, top_left=False) transform_constraints = snt.AffineWarpConstraints.no_shear_2d() warper = snt.AffineGridWarper( (self.image_height, self.image_width), self.object_shape, transform_constraints) _boxes = tf.concat([xs, 2*xt - 1, ys, 2*yt - 1], axis=-1) _boxes = tf.reshape(_boxes, (self.batch_size*H*W, 4)) grid_coords = warper(_boxes) grid_coords = tf.reshape(grid_coords, (self.batch_size, H, W, *self.object_shape, 2,)) if self.edge_resampler: glimpse = resampler_edge.resampler_edge(inp, grid_coords) else: glimpse = tf.contrib.resampler.resampler(inp, grid_coords) else: glimpse = tf.zeros((self.batch_size, H, W, *self.object_shape, self.image_depth)) # Create the object encoder network regardless of is_posterior, otherwise messes with ScopedFunction encoded_glimpse = apply_object_wise( self.object_encoder, glimpse, n_trailing_dims=3, output_size=self.A, is_training=self.is_training) if not is_posterior: encoded_glimpse = tf.zeros_like(encoded_glimpse) layer_inp = tf.concat([base_features, features, encoded_glimpse, objects['local_box']], axis=-1) network_output = self.attr_network(layer_inp, 2 * self.A + n_features, self.is_training) attr_mean, attr_log_std, features = tf.split(network_output, (self.A, self.A, n_features), axis=-1) attr_std = self.std_nonlinearity(attr_log_std) attr = Normal(loc=attr_mean, scale=attr_std).sample() objects.update(attr_mean=attr_mean, attr_std=attr_std, attr=attr, glimpse=glimpse) # --- z --- layer_inp = tf.concat([base_features, features, objects['local_box'], objects['attr']], axis=-1) n_features = self.n_passthrough_features 
network_output = self.z_network(layer_inp, 2 + n_features, self.is_training) z_mean, z_log_std, features = tf.split(network_output, (1, 1, n_features), axis=-1) z_std = self.std_nonlinearity(z_log_std) z_mean = self.training_wheels * tf.stop_gradient(z_mean) + (1-self.training_wheels) * z_mean z_std = self.training_wheels * tf.stop_gradient(z_std) + (1-self.training_wheels) * z_std z_logit = Normal(loc=z_mean, scale=z_std).sample() z = self.z_nonlinearity(z_logit) objects.update(z_logit_mean=z_mean, z_logit_std=z_std, z_logit=z_logit, z=z) # --- obj --- layer_inp = tf.concat([base_features, features, objects['local_box'], objects['attr'], objects['z']], axis=-1) rep_input = self.obj_network(layer_inp, 1, self.is_training) _objects = self._build_obj(rep_input, self.is_training) objects.update(_objects) # --- final --- if prop_state is not None: objects.prop_state = tf.tile(prop_state[0:1, None, None, :], (self.batch_size, H, W, 1)) objects.prior_prop_state = tf.tile(prop_state[0:1, None, None, :], (self.batch_size, H, W, 1)) if self.flatten: _objects = AttrDict() for k, v in objects.items(): _, _, _, *trailing_dims = tf_shape(v) _objects[k] = tf.reshape(v, (self.batch_size, self.HWB, *trailing_dims)) objects = _objects # --- misc --- flat_objects = tf.reshape(objects.obj, (self.batch_size, -1)) objects.pred_n_objects = tf.reduce_sum(flat_objects, axis=1) objects.pred_n_objects_hard = tf.reduce_sum(tf.round(flat_objects), axis=1) return objects
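# Hedged note on the counts computed just above: pred_n_objects sums the soft obj
# values (continuous, in [0, 1]) over all grid cells, while pred_n_objects_hard
# rounds each obj to {0, 1} first, giving a discrete per-example object count.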
def _call(self, inp, inp_features, is_training, is_posterior=True, prop_state=None): print("\n" + "-" * 10 + " GridObjectLayer(is_posterior={}) ".format(is_posterior) + "-" * 10) # --- set up sub networks and attributes --- self.maybe_build_subnet("box_network", builder=cfg.build_lateral, key="box") self.maybe_build_subnet("attr_network", builder=cfg.build_lateral, key="attr") self.maybe_build_subnet("z_network", builder=cfg.build_lateral, key="z") self.maybe_build_subnet("obj_network", builder=cfg.build_lateral, key="obj") self.maybe_build_subnet("object_encoder") _, H, W, _ = tf_shape(inp_features) H = int(H) W = int(W) if not self.initialized: # Note this limits the re-usability of this module to images # with a fixed shape (the shape of the first image it is used on) self.batch_size, self.image_height, self.image_width, self.image_depth = tf_shape(inp) self.H = H self.W = W self.HWB = H*W*self.B self.is_training = is_training self.float_is_training = tf.to_float(is_training) # --- set up the edge element --- sizes = [4, self.A, 1, 1] sigmoids = [True, False, False, True] total_sample_size = sum(sizes) if self.edge_weights is None: self.edge_weights = tf.get_variable("edge_weights", shape=total_sample_size, dtype=tf.float32) if "edge" in self.fixed_weights: tf.add_to_collection(FIXED_COLLECTION, self.edge_weights) _edge_weights = tf.split(self.edge_weights, sizes, axis=0) _edge_weights = [ (tf.nn.sigmoid(ew) if sigmoid else ew) for ew, sigmoid in zip(_edge_weights, sigmoids)] edge_element = tf.concat(_edge_weights, axis=0) edge_element = tf.tile(edge_element[None, :], (self.batch_size, 1)) # --- containers for storing built program --- program = np.empty((H, W, self.B), dtype=np.object) # --- build the program --- is_posterior_tf = tf.ones((self.batch_size, 2)) if is_posterior: is_posterior_tf = is_posterior_tf * [1, 0] else: is_posterior_tf = is_posterior_tf * [0, 1] results = [] for h, w, b in itertools.product(range(H), range(W), range(self.B)): built = dict() partial_program, features = None, None context = self._get_sequential_context(program, h, w, b, edge_element) base_features = tf.concat([inp_features[:, h, w, :], context, is_posterior_tf], axis=1) # --- box --- layer_inp = base_features n_features = self.n_passthrough_features output_size = 8 network_output = self.box_network(layer_inp, output_size + n_features, self. 
is_training) rep_input, features = tf.split(network_output, (output_size, n_features), axis=1) _built = self._build_box(rep_input, self.is_training, hw=(h, w)) built.update(_built) partial_program = built['local_box'] # --- attr --- if is_posterior: # --- Get object attributes using object encoder --- yt, xt, ys, xs = tf.split(built['normalized_box'], 4, axis=-1) yt, xt, ys, xs = coords_to_image_space( yt, xt, ys, xs, (self.image_height, self.image_width), self.anchor_box, top_left=False) transform_constraints = snt.AffineWarpConstraints.no_shear_2d() warper = snt.AffineGridWarper( (self.image_height, self.image_width), self.object_shape, transform_constraints) _boxes = tf.concat([xs, 2*xt - 1, ys, 2*yt - 1], axis=-1) grid_coords = warper(_boxes) grid_coords = tf.reshape(grid_coords, (self.batch_size, 1, *self.object_shape, 2,)) if self.edge_resampler: glimpse = resampler_edge.resampler_edge(inp, grid_coords) else: glimpse = tf.contrib.resampler.resampler(inp, grid_coords) glimpse = tf.reshape(glimpse, (self.batch_size, *self.object_shape, self.image_depth)) else: glimpse = tf.zeros((self.batch_size, *self.object_shape, self.image_depth)) # Create the object encoder network regardless of is_posterior, otherwise messes with ScopedFunction encoded_glimpse = self.object_encoder(glimpse, (1, 1, self.A), self.is_training) encoded_glimpse = tf.reshape(encoded_glimpse, (self.batch_size, self.A)) if not is_posterior: encoded_glimpse = tf.zeros_like(encoded_glimpse) layer_inp = tf.concat( [base_features, features, encoded_glimpse, partial_program], axis=1) network_output = self.attr_network(layer_inp, 2 * self.A + n_features, self. is_training) attr_mean, attr_log_std, features = tf.split(network_output, (self.A, self.A, n_features), axis=1) attr_std = self.std_nonlinearity(attr_log_std) attr = Normal(loc=attr_mean, scale=attr_std).sample() built.update(attr_mean=attr_mean, attr_std=attr_std, attr=attr, glimpse=glimpse) partial_program = tf.concat([partial_program, built['attr']], axis=1) # --- z --- layer_inp = tf.concat([base_features, features, partial_program], axis=1) n_features = self.n_passthrough_features network_output = self.z_network(layer_inp, 2 + n_features, self.is_training) z_mean, z_log_std, features = tf.split(network_output, (1, 1, n_features), axis=1) z_std = self.std_nonlinearity(z_log_std) z_mean = self.training_wheels * tf.stop_gradient(z_mean) + (1-self.training_wheels) * z_mean z_std = self.training_wheels * tf.stop_gradient(z_std) + (1-self.training_wheels) * z_std z_logit = Normal(loc=z_mean, scale=z_std).sample() z = self.z_nonlinearity(z_logit) built.update(z_logit_mean=z_mean, z_logit_std=z_std, z_logit=z_logit, z=z) partial_program = tf.concat([partial_program, built['z']], axis=1) # --- obj --- layer_inp = tf.concat([base_features, features, partial_program], axis=1) rep_input = self.obj_network(layer_inp, 1, self.is_training) _built = self._build_obj(rep_input, self.is_training) built.update(_built) partial_program = tf.concat([partial_program, built['obj']], axis=1) # --- final --- results.append(built) program[h, w, b] = partial_program assert program[h, w, b].shape[1] == total_sample_size objects = AttrDict() for k in results[0]: objects[k] = tf.stack([r[k] for r in results], axis=1) if prop_state is not None: objects.prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1)) objects.prior_prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1)) # --- misc --- objects.pred_n_objects = tf.reduce_sum(objects.obj, axis=(1, 2)) 
objects.pred_n_objects_hard = tf.reduce_sum(tf.round(objects.obj), axis=(1, 2)) return objects
class VideoNetwork(TensorRecorder):
    attr_prior_mean = Param()
    attr_prior_std = Param()
    noisy = Param()
    stage_steps = Param()
    initial_n_frames = Param()
    n_frames_scale = Param()

    background_encoder = None
    background_decoder = None

    needs_background = True

    eval_funcs = dict()

    def __init__(self, env, updater, scope=None, **kwargs):
        self.updater = updater

        self.obs_shape = env.datasets['train'].obs_shape
        self.n_frames, self.image_height, self.image_width, self.image_depth = self.obs_shape

        super(VideoNetwork, self).__init__(scope=scope, **kwargs)

    def std_nonlinearity(self, std_logit):
        std = 2 * tf.nn.sigmoid(tf.clip_by_value(std_logit, -10, 10))
        if not self.noisy:
            std = tf.zeros_like(std)
        return std

    @property
    def inp(self):
        return self._tensors["inp"]

    @property
    def batch_size(self):
        return self._tensors["batch_size"]

    @property
    def is_training(self):
        return self._tensors["is_training"]

    @property
    def float_is_training(self):
        return self._tensors["float_is_training"]

    def _call(self, data, is_training):
        self.data = data

        inp = data["image"]
        self._tensors = AttrDict(
            inp=inp,
            is_training=is_training,
            float_is_training=tf.to_float(is_training),
            batch_size=tf.shape(inp)[0],
        )

        if "annotations" in data:
            self._tensors.update(
                annotations=data["annotations"]["data"],
                n_annotations=data["annotations"]["shapes"][:, 1],
                n_valid_annotations=tf.to_int32(
                    tf.reduce_sum(
                        data["annotations"]["data"][:, :, :, 0]
                        * tf.to_float(data["annotations"]["mask"][:, :, :, 0]),
                        axis=2
                    )
                )
            )

        if "label" in data:
            self._tensors.update(
                targets=data["label"],
            )

        if "background" in data:
            self._tensors.update(
                ground_truth_background=data["background"],
            )

        if "offset" in data:
            self._tensors.update(
                offset=data["offset"],
            )

        max_n_frames = tf_shape(inp)[1]

        if self.stage_steps is None:
            self.current_stage = tf.constant(0, tf.int32)
            dynamic_n_frames = max_n_frames
        else:
            self.current_stage = tf.cast(tf.train.get_or_create_global_step(), tf.int32) // self.stage_steps
            dynamic_n_frames = tf.minimum(
                self.initial_n_frames + self.n_frames_scale * self.current_stage, max_n_frames)

        dynamic_n_frames = tf.cast(dynamic_n_frames, tf.float32)
        dynamic_n_frames = (
            self.float_is_training * tf.cast(dynamic_n_frames, tf.float32)
            + (1-self.float_is_training) * tf.cast(max_n_frames, tf.float32)
        )
        self.dynamic_n_frames = tf.cast(dynamic_n_frames, tf.int32)

        self._tensors.current_stage = self.current_stage
        self._tensors.dynamic_n_frames = self.dynamic_n_frames

        self._tensors.inp = self._tensors.inp[:, :self.dynamic_n_frames]

        if 'annotations' in self._tensors:
            self._tensors.annotations = self._tensors.annotations[:, :self.dynamic_n_frames]
            # self._tensors.n_annotations = self._tensors.n_annotations[:, :self.dynamic_n_frames]
            self._tensors.n_valid_annotations = self._tensors.n_valid_annotations[:, :self.dynamic_n_frames]

        self.record_tensors(
            batch_size=tf.to_float(self.batch_size),
            float_is_training=self.float_is_training,
            current_stage=self.current_stage,
            dynamic_n_frames=self.dynamic_n_frames,
        )

        self.losses = dict()

        with tf.variable_scope("representation", reuse=self.initialized):
            if self.needs_background:
                self.build_background()

            self.build_representation()

        return dict(
            tensors=self._tensors,
            recorded_tensors=self.recorded_tensors,
            losses=self.losses,
        )

    def build_background(self):
        if cfg.background_cfg.mode == "colour":
            rgb = np.array(to_rgb(cfg.background_cfg.colour))[None, None, None, :]
            background = rgb * tf.ones_like(self.inp)

        elif cfg.background_cfg.mode == "learn_solid":
            # Learn a solid colour for the background
            self.solid_background_logits = tf.get_variable("solid_background", initializer=[0.0, 0.0, 0.0])
            if "background" in self.fixed_weights:
                tf.add_to_collection(FIXED_COLLECTION, self.solid_background_logits)

            solid_background = tf.nn.sigmoid(10 * self.solid_background_logits)
            background = solid_background[None, None, None, :] * tf.ones_like(self.inp)

        elif cfg.background_cfg.mode == "scalor":
            pass

        elif cfg.background_cfg.mode == "learn":
            self.maybe_build_subnet("background_encoder")
            self.maybe_build_subnet("background_decoder")

            # Here I'm just encoding the first frame...
            bg_attr = self.background_encoder(self.inp[:, 0], 2 * cfg.background_cfg.A, self.is_training)
            bg_attr_mean, bg_attr_log_std = tf.split(bg_attr, 2, axis=-1)
            bg_attr_std = tf.exp(bg_attr_log_std)
            if not self.noisy:
                bg_attr_std = tf.zeros_like(bg_attr_std)

            bg_attr, bg_attr_kl = normal_vae(bg_attr_mean, bg_attr_std, self.attr_prior_mean, self.attr_prior_std)

            self._tensors.update(
                bg_attr_mean=bg_attr_mean,
                bg_attr_std=bg_attr_std,
                bg_attr_kl=bg_attr_kl,
                bg_attr=bg_attr)

            # --- decode ---

            _, T, H, W, _ = tf_shape(self.inp)

            background = self.background_decoder(bg_attr, 3, self.is_training)
            assert len(background.shape) == 2 and background.shape[1] == 3
            background = tf.nn.sigmoid(tf.clip_by_value(background, -10, 10))
            background = tf.tile(background[:, None, None, None, :], (1, T, H, W, 1))

        elif cfg.background_cfg.mode == "learn_and_transform":
            self.maybe_build_subnet("background_encoder")
            self.maybe_build_subnet("background_decoder")

            # --- encode ---

            n_transform_latents = 4
            n_latents = (2 * cfg.background_cfg.A, 2 * n_transform_latents)

            bg_attr, bg_transform_params = self.background_encoder(self.inp, n_latents, self.is_training)

            # --- bg attributes ---

            bg_attr_mean, bg_attr_log_std = tf.split(bg_attr, 2, axis=-1)
            bg_attr_std = self.std_nonlinearity(bg_attr_log_std)

            bg_attr, bg_attr_kl = normal_vae(bg_attr_mean, bg_attr_std, self.attr_prior_mean, self.attr_prior_std)

            # --- bg location ---

            bg_transform_params = tf.reshape(
                bg_transform_params,
                (self.batch_size, self.dynamic_n_frames, 2*n_transform_latents))

            mean, log_std = tf.split(bg_transform_params, 2, axis=2)
            std = self.std_nonlinearity(log_std)

            logits, kl = normal_vae(mean, std, 0.0, 1.0)

            # integrate across timesteps
            logits = tf.cumsum(logits, axis=1)
            logits = tf.reshape(logits, (self.batch_size*self.dynamic_n_frames, n_transform_latents))

            y, x, h, w = tf.split(logits, n_transform_latents, axis=1)
            h = (0.9 - 0.5) * tf.nn.sigmoid(h) + 0.5
            w = (0.9 - 0.5) * tf.nn.sigmoid(w) + 0.5
            y = (1 - h) * tf.nn.tanh(y)
            x = (1 - w) * tf.nn.tanh(x)

            # --- decode ---

            background = self.background_decoder(bg_attr, self.image_depth, self.is_training)
            bg_shape = cfg.background_cfg.bg_shape
            background = background[:, :bg_shape[0], :bg_shape[1], :]
            assert background.shape[1:3] == bg_shape
            background_raw = tf.nn.sigmoid(tf.clip_by_value(background, -10, 10))

            transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
            warper = snt.AffineGridWarper(
                bg_shape, (self.image_height, self.image_width), transform_constraints)
            transforms = tf.concat([w, x, h, y], axis=-1)
            grid_coords = warper(transforms)

            grid_coords = tf.reshape(
                grid_coords,
                (self.batch_size, self.dynamic_n_frames, *tf_shape(grid_coords)[1:]))

            background = tf.contrib.resampler.resampler(background_raw, grid_coords)

            self._tensors.update(
                bg_attr_mean=bg_attr_mean,
                bg_attr_std=bg_attr_std,
                bg_attr_kl=bg_attr_kl,
                bg_attr=bg_attr,
                bg_y=tf.reshape(y, (self.batch_size, self.dynamic_n_frames, 1)),
                bg_x=tf.reshape(x, (self.batch_size, self.dynamic_n_frames, 1)),
                bg_h=tf.reshape(h, (self.batch_size, self.dynamic_n_frames, 1)),
                bg_w=tf.reshape(w, (self.batch_size, self.dynamic_n_frames, 1)),
                bg_transform_kl=kl,
                bg_raw=background_raw,
            )

        elif cfg.background_cfg.mode == "data":
            background = self._tensors["ground_truth_background"]

        else:
            raise Exception("Unrecognized background mode: {}.".format(cfg.background_cfg.mode))

        self._tensors["background"] = background[:, :self.dynamic_n_frames]
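# --- illustrative sketch (not part of VideoNetwork) ---
# In the "learn_and_transform" branch above, the transform latents are squashed before being
# handed to the affine warper: h and w become scales in [0.5, 0.9], while y and x become shifts
# with magnitude below 1 - h and 1 - w respectively, presumably so the sampled window stays
# inside the decoded background canvas. A standalone NumPy version of that squashing, with an
# invented name and unused elsewhere, is given here for reference:
import numpy as np  # already imported at module level; repeated so the sketch is self-contained


def _constrain_bg_transform_sketch(y, x, h, w):
    """Map unconstrained latents to scale/shift values for the background warper."""
    def sigmoid(t):
        return 1.0 / (1.0 + np.exp(-t))

    h = (0.9 - 0.5) * sigmoid(h) + 0.5   # height scale in [0.5, 0.9]
    w = (0.9 - 0.5) * sigmoid(w) + 0.5   # width scale in [0.5, 0.9]
    y = (1 - h) * np.tanh(y)             # vertical shift, magnitude below 1 - h
    x = (1 - w) * np.tanh(x)             # horizontal shift, magnitude below 1 - w
    return y, x, h, w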