def get_scheduled_sample_inputs(self, done_warm_start, groundtruth_items,
                                generated_items, scheduled_sampling_func):
    """Chooses the inputs to feed according to scheduled sampling.

    The predicates are handed to `tf.case` in this order:
      * warm start is not done yet -> the ground truth items are returned;
      * the model is not training -> the generated items are returned;
      * otherwise (default) -> each ground truth / generated pair is combined
        by `scheduled_sampling_func`.

    Args:
      done_warm_start: whether we are done with warm start or not.
      groundtruth_items: list of ground truth items.
      generated_items: list of generated items.
      scheduled_sampling_func: scheduled sampling function to choose between
        groundtruth items and generated items.

    Returns:
      A mix list of ground truth and generated items.
    """

    def _mix_items():
        """Applies the sampling function pairwise inside a reusable scope."""
        mixed = []
        with tf.variable_scope("scheduled_sampling", reuse=tf.AUTO_REUSE):
            for truth, generated in zip(groundtruth_items, generated_items):
                mixed.append(scheduled_sampling_func(truth, generated))
        return mixed

    still_warming_up = tf.logical_not(done_warm_start)
    not_training = tf.logical_not(self.is_training)
    return tf.case(
        [
            (still_warming_up, lambda: groundtruth_items),
            (not_training, lambda: generated_items),
        ],
        default=_mix_items,
        strict=True)
def beta_schedule(schedule, global_step, final_beta, decay_start, decay_end):
    """Get KL multiplier (beta) based on the schedule.

    Beta is 0 while `global_step` is below `decay_start`, `final_beta` once
    `global_step` exceeds `decay_end`, and follows the selected schedule in
    between.

    Args:
      schedule: name of the schedule; one of "constant", "linear" or
        "noisy_linear_cosine_decay".
      global_step: the current training step.
      final_beta: the value beta takes after `decay_end`.
      decay_start: step at which the schedule starts.
      decay_end: step at which the schedule ends.

    Returns:
      The result of a `tf.case` that selects between 0.0, `final_beta` and the
      scheduled value.

    Raises:
      ValueError: if `decay_start` is greater than `decay_end`, or if
        `schedule` is not one of the supported names.
    """
    if decay_start > decay_end:
        raise ValueError("decay_end is smaller than decay_start.")

    # Since some of the TF schedules do not support incrementing a value,
    # in all of the schedules, we anneal the beta from final_beta to zero
    # and then reverse it at the bottom.
    if schedule == "constant":
        decayed_value = 0.0
    elif schedule == "linear":
        decayed_value = tf.train.polynomial_decay(
            learning_rate=final_beta,
            global_step=global_step - decay_start,
            decay_steps=decay_end - decay_start,
            end_learning_rate=0.0)
    elif schedule == "noisy_linear_cosine_decay":
        decayed_value = tf.train.noisy_linear_cosine_decay(
            learning_rate=final_beta,
            global_step=global_step - decay_start,
            decay_steps=decay_end - decay_start)
    # TODO(mechcoder): Add log_annealing schedule.
    else:
        raise ValueError("Unknown beta schedule.")

    # Reverse the decay: the value grows from 0 towards final_beta. The noisy
    # schedule may overshoot, hence the clamp at zero.
    increased_value = final_beta - decayed_value
    increased_value = tf.maximum(0.0, increased_value)

    beta = tf.case(
        pred_fn_pairs=[
            (tf.less(global_step, decay_start), lambda: 0.0),
            (tf.greater(global_step, decay_end), lambda: final_beta),
        ],
        default=lambda: increased_value)
    return beta
def blend(image1, image2, factor):
    """Blend image1 and image2 using 'factor'.

    Factor can be above 0.0.  A value of 0.0 means only image1 is used.
    A value of 1.0 means only image2 is used.  A value between 0.0 and
    1.0 means we linearly interpolate the pixel values between the two
    images.  A value greater than 1.0 "extrapolates" the difference
    between the two pixel values, and we clip the results to the valid
    pixel range before converting back to uint8.

    Args:
      image1: An image Tensor of type uint8.
      image2: An image Tensor of type uint8.
      factor: A floating point value above 0.0.

    Returns:
      A blended image Tensor of type uint8.
    """

    def _blend():
        # convert_image_dtype rescales integer images into the [0, 1] float
        # range, so all arithmetic below happens on that scale.
        image_1 = tf.image.convert_image_dtype(image1, tf.float32)
        image_2 = tf.image.convert_image_dtype(image2, tf.float32)
        output = image_1 + factor * (image_2 - image_1)
        # Interpolation stays inside the valid range; extrapolation can leave
        # it and must be clipped to [0, 1] (the float image range, not
        # [0, 255]) so the conversion back to uint8 does not overflow.
        output = tf.where_v2(
            tf.logical_and(tf.less(0., factor), tf.less(factor, 1.)),
            x=output,
            y=tf.clip_by_value(output, 0., 1.))
        return tf.image.convert_image_dtype(output, tf.uint8)

    # The two trivial factors short-circuit to the untouched input images.
    pred_fn_pairs = [
        (tf.equal(factor, 0.), lambda: image1),
        (tf.equal(factor, 1.), lambda: image2),
    ]
    return tf.case(
        pred_fn_pairs,
        default=_blend,
        exclusive=True,
        strict=True,
        name='blend')
def _drop_channels(data):
    """Randomly keeps data["image"] or replaces all but one channel by noise.

    A uniform draw in [0, 1) is compared against `keep_original`: below it the
    image is left untouched, otherwise a channel index drawn from {0, 1, 2} is
    kept and every other channel is filled with uniform noise taken from
    [`noise_min`, `noise_max`).

    Args:
      data: dict holding the image tensor under the "image" key.

    Returns:
      The same dict with "image" replaced by the processed tensor.
    """
    image = data["image"]

    def _drop(keep_i):
        """Keeps channel `keep_i`; all other channels become uniform noise."""
        static_shape = image.get_shape().as_list()
        spatial_dims = static_shape[:-1]
        channel_count = static_shape[-1]
        planes = []
        for channel in range(channel_count):
            if channel == keep_i:
                planes.append(image[:, :, channel:channel + 1])
            else:
                planes.append(
                    tf.random_uniform(spatial_dims + [1], noise_min,
                                      noise_max))
        return tf.concat(planes, axis=2)

    def _drop_random_channel(coin_channel):
        """Dispatches to `_drop` for the index held by `coin_channel`."""
        return tf.case({
            tf.equal(coin_channel, idx): (lambda idx=idx: _drop(idx))
            for idx in range(3)
        })

    keep_roll = tf.random.uniform([], 0.0, 1.0, dtype=tf.float32)
    channel_roll = tf.random.uniform([], 0, 3, dtype=tf.int32)
    keep_pred = tf.less(keep_roll, keep_original)
    drop_pred = tf.greater_equal(keep_roll, keep_original)
    data["image"] = tf.case({
        keep_pred: lambda: image,
        drop_pred: lambda: _drop_random_channel(channel_roll),
    })
    return data
def _produce_posterior_estimate(posterior_dist, posterior_estimate_mode,
                                raw_var_name):
    """Create tensor representing estimate of posterior.

    Args:
      posterior_dist: An instance of `tfp.distributions.Distribution`.
        The variational posterior from which to produce an estimate of the
        variable in question.
      posterior_estimate_mode: A `Tensor` of dtype `tf.string`, which
        determines the inference mode.
      raw_var_name: The name of the variable over which inference is done.

    Returns:
      `z_sample`, a `Tensor` representing an estimate derived from the
      posterior distribution.
    """
    # All three predicates are built up front, even when the distribution has
    # no `last_sample` method and the third one ends up unused.
    is_sample_mode = tf.equal(
        posterior_estimate_mode,
        tf.constant(EstimatorModes.sample),
        name="equal_sample_mode")
    is_mean_mode = tf.equal(
        posterior_estimate_mode,
        tf.constant(EstimatorModes.mean),
        name="equal_mean_mode")
    is_last_sample_mode = tf.equal(
        posterior_estimate_mode,
        tf.constant(EstimatorModes.last_sample),
        name="equal_last_sample_mode")

    def _invalid_mode_branch():
        """Default branch: fails an assertion before producing a value."""
        failure = tf.Assert(
            tf.constant(False),
            data=[tf.constant("Invalid posterior estimate mode.")])
        with tf.control_dependencies([failure]):
            return posterior_dist.mean()

    cases = [
        (is_sample_mode, posterior_dist.sample),
        (is_mean_mode, posterior_dist.mean),
    ]
    # The last-sample mode is only offered by distributions that implement it.
    if hasattr(posterior_dist, "last_sample"):
        cases.append((is_last_sample_mode, posterior_dist.last_sample))

    return tf.case(
        cases,
        exclusive=True,
        default=_invalid_mode_branch,
        name="{}_posterior_estimate".format(raw_var_name))
def test_cov_update_thunks(self):
    """Checks that one covariance matrix is updated per global step."""
    with self._graph.as_default(), self.test_session() as sess:
        fisher_estimator = estimator.FisherEstimatorRoundRobin(
            variables=[self.weights],
            layer_collection=self.layer_collection,
            damping=0.2,
            cov_ema_decay=0.0)

        # Build a single op that runs the i-th update thunk at global step i.
        global_step = tf.train.get_or_create_global_step()
        (cov_variable_thunks, cov_update_op_thunks, _,
         _) = fisher_estimator.create_ops_and_vars_thunks()
        for make_variables in cov_variable_thunks:
            make_variables()

        cov_matrices = []
        for factor in self.layer_collection.get_factors():
            cov_matrices.append(factor.cov)

        step_cases = []
        for step, update_thunk in enumerate(cov_update_op_thunks):
            step_cases.append((tf.equal(global_step, step), update_thunk))
        cov_update_op = tf.case(step_cases)
        increment_global_step = global_step.assign_add(1)

        sess.run(tf.global_variables_initializer())
        initial_cov_values = sess.run(cov_matrices)

        # There must be exactly one update thunk per covariance matrix.
        num_matrices = len(cov_matrices)
        self.assertEqual(num_matrices, len(cov_update_op_thunks))

        # With a single covariance matrix the test would check nothing.
        assert len(cov_matrices) > 1

        for step in range(num_matrices):
            # Count how many matrices still hold their initial value.
            current_cov_values = sess.run(cov_matrices)
            num_unchanged = sum(
                np.allclose(before, after)
                for before, after in zip(initial_cov_values,
                                         current_cov_values))

            # After `step` updates exactly `step` matrices have changed.
            self.assertEqual(num_unchanged, num_matrices - step)

            # Apply this step's update, then advance to the next step.
            sess.run(cov_update_op)
            sess.run(increment_global_step)
def categorical_case(pmf, fns, rand=None):
    """Returns the outputs of fns[i] with probability pmf[i].

    The cumulative sum of `pmf`, prefixed with a zero, delimits one half-open
    interval per function; the function whose interval contains `rand` is the
    one selected by `tf.case`.

    Args:
      pmf: A 1-D tensor of probabilities, the probability mass function.
      fns: A list of callables that return tensors, same length as pmf.
      rand: An optional scalar between 0.0 and 1.0, the output of an RNG.

    Returns:
      A tensor, the output of fns[i] with probability pmf[i].
    """
    if rand is None:
        rand = tf.random_uniform([])

    cumulative = tf.pad(tf.cumsum(pmf), [(1, 0)])
    edge_count = len(fns) + 1
    edges = [cumulative[k] for k in range(edge_count)]

    pairs = []
    for lower, upper, fn in zip(edges[:-1], edges[1:], fns):
        in_interval = (rand >= lower) & (rand < upper)
        pairs.append((in_interval, fn))
    return tf.case(pairs, exclusive=True)
def apply_with_random_selector(x, func, num_cases):
    """Computes func(x, sel), with sel sampled from [0...num_cases-1].

    Args:
      x: input Tensor.
      func: Python function to apply.
      num_cases: Python int32, number of cases to sample sel from.

    Returns:
      The result of func(x, sel), where func receives the value of the
      selector as a python integer, but sel is sampled dynamically.
    """
    sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
    # One branch per candidate selector value. The default argument freezes
    # the Python integer each branch hands to `func`.
    return tf.case([
        (tf.equal(sel, case), lambda case=case: func(x, case))
        for case in range(num_cases)
    ])
def _resize_pp(data):
    """Resizes data["image"], optionally with a randomly chosen method.

    When `randomize_resize_method` is set, one of BILINEAR, NEAREST_NEIGHBOR,
    BICUBIC or AREA is picked uniformly at random; otherwise the image is
    resized to `im_size` with the default method of
    `tf.image.resize_images`.

    Args:
      data: dict holding the image tensor under the "image" key.

    Returns:
      The same dict with "image" replaced by the resized tensor.
    """
    im = data["image"]

    if randomize_resize_method:
        # Pick a random resizing method. The upper bound of an integer
        # uniform draw is exclusive, so 4 is needed for all four branches
        # (0..3) below to be reachable.
        r = tf.random_uniform([], 0, 4, dtype=tf.int32)
        # NOTE(review): tf.case needs callables as branch values, so `_resize`
        # is presumably a factory returning a zero-argument function; confirm
        # against its definition.
        im = tf.case({
            tf.equal(r, tf.cast(0, r.dtype)):
                _resize(im, tf.image.ResizeMethod.BILINEAR, True),
            tf.equal(r, tf.cast(1, r.dtype)):
                _resize(im, tf.image.ResizeMethod.NEAREST_NEIGHBOR, True),
            tf.equal(r, tf.cast(2, r.dtype)):
                _resize(im, tf.image.ResizeMethod.BICUBIC, True),
            # NOTE: use align_corners=False for AREA resize, but True for the
            # others. See https://github.com/tensorflow/tensorflow/issues/6720
            tf.equal(r, tf.cast(3, r.dtype)):
                _resize(im, tf.image.ResizeMethod.AREA, False),
        })
    else:
        im = tf.image.resize_images(im, im_size)

    data["image"] = im
    return data
def input_tensors_to_model_input(input_tensors,
                                 hparams,
                                 is_training,
                                 num_classes=constants.MIDI_PITCHES):
    """Processes an InputTensor into FeatureTensors and LabelTensors.

    The spectrogram and all label tensors are reshaped to 2-D and then padded
    or truncated along their first dimension so that each of them ends up with
    exactly `final_length` rows.

    Args:
      input_tensors: the input record; its `length`, `labels`,
        `label_weights`, `onsets`, `offsets`, `velocities`, `spec`,
        `note_sequence` and `sequence_id` fields are read.
      hparams: hyperparameters; `truncated_length_secs` and
        `max_expected_train_example_len` are read here.
      is_training: Python bool; when true the note sequence and sequence id
        are replaced by constant placeholders.
      num_classes: size of the second dimension of every label tensor.

    Returns:
      A `(FeatureTensors, LabelTensors)` tuple.
    """

    def _fit_rows(tensor, delta):
        """Pads (delta < 0) or truncates (delta > 0) `tensor` by |delta| rows."""
        return tf.case(
            [(delta < 0,
              lambda: tf.pad(tensor, tf.stack([(0, -delta), (0, 0)]))),
             (delta > 0, lambda: tensor[0:-delta])],
            default=lambda: tensor)

    length = tf.cast(input_tensors.length, tf.int32)
    labels = tf.reshape(input_tensors.labels, (-1, num_classes))
    label_weights = tf.reshape(input_tensors.label_weights,
                               (-1, num_classes))
    onsets = tf.reshape(input_tensors.onsets, (-1, num_classes))
    offsets = tf.reshape(input_tensors.offsets, (-1, num_classes))
    velocities = tf.reshape(input_tensors.velocities, (-1, num_classes))
    spec = tf.reshape(input_tensors.spec, (-1, hparams_frame_size(hparams)))

    # Slice specs and labels tensors so they are no longer than
    # truncated_length.
    hparams_truncated_length = tf.cast(
        hparams.truncated_length_secs * hparams_frames_per_second(hparams),
        tf.int32)
    if hparams.truncated_length_secs:
        truncated_length = tf.reduce_min([hparams_truncated_length, length])
    else:
        truncated_length = length

    if is_training:
        truncated_note_sequence = tf.constant(0)
    else:
        truncated_note_sequence = truncate_note_sequence_op(
            input_tensors.note_sequence, truncated_length, hparams)

    # If max_expected_train_example_len is set, ensure that all examples are
    # padded to this length. This results in a fixed shape that can work on
    # TPUs.
    if hparams.max_expected_train_example_len and is_training:
        # In this case, final_length is a constant.
        if hparams.truncated_length_secs:
            assert_op = tf.assert_equal(
                hparams.max_expected_train_example_len,
                hparams_truncated_length)
            # NOTE(review): no op is created inside this scope (final_length
            # is a plain Python value), so the control dependency attaches
            # `assert_op` to nothing; the check only takes effect if
            # tf.assert_equal evaluates it statically - confirm.
            with tf.control_dependencies([assert_op]):
                final_length = hparams.max_expected_train_example_len
        else:
            final_length = hparams.max_expected_train_example_len
    else:
        # In this case, it is min(hparams.truncated_length, length)
        final_length = truncated_length

    spec_delta = tf.shape(spec)[0] - final_length
    spec = _fit_rows(spec, spec_delta)

    # All label tensors are adjusted by the delta measured on `labels`.
    labels_delta = tf.shape(labels)[0] - final_length
    labels = _fit_rows(labels, labels_delta)
    label_weights = _fit_rows(label_weights, labels_delta)
    onsets = _fit_rows(onsets, labels_delta)
    offsets = _fit_rows(offsets, labels_delta)
    velocities = _fit_rows(velocities, labels_delta)

    features = FeatureTensors(
        spec=tf.reshape(spec, (final_length, hparams_frame_size(hparams), 1)),
        length=truncated_length,
        sequence_id=tf.constant(0)
        if is_training else input_tensors.sequence_id)
    labels = LabelTensors(
        labels=tf.reshape(labels, (final_length, num_classes)),
        label_weights=tf.reshape(label_weights, (final_length, num_classes)),
        onsets=tf.reshape(onsets, (final_length, num_classes)),
        offsets=tf.reshape(offsets, (final_length, num_classes)),
        velocities=tf.reshape(velocities, (final_length, num_classes)),
        note_sequence=truncated_note_sequence)

    return features, labels
def parser(value):
    """Parses one serialized Imagenet example into an (image, label) pair.

    Args:
      value: a serialized example, as accepted by `tf.parse_single_example`.

    Returns:
      A tuple of the preprocessed image cast to float32 and the int32 label
      reshaped to `[1]`.
    """
    feature_spec = {
        'image/encoded':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format':
            tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/class/label':
            tf.FixedLenFeature([], dtype=tf.int64, default_value=-1),
        'image/class/text':
            tf.FixedLenFeature([], dtype=tf.string, default_value=''),
        'image/object/bbox/xmin':
            tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymin':
            tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/xmax':
            tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymax':
            tf.VarLenFeature(dtype=tf.float32),
        'image/object/class/label':
            tf.VarLenFeature(dtype=tf.int64),
    }
    example = tf.parse_single_example(value, feature_spec)

    image_bytes = tf.reshape(
        example['image/encoded'], shape=[], name='encoded_image')
    fmt = example['image/format']

    def _coord_row(sparse_coord):
        """Turns the values of a sparse coordinate feature into one row."""
        return tf.expand_dims(sparse_coord.values, 0)

    xmin = _coord_row(example['image/object/bbox/xmin'])
    ymin = _coord_row(example['image/object/bbox/ymin'])
    xmax = _coord_row(example['image/object/bbox/xmax'])
    ymax = _coord_row(example['image/object/bbox/ymax'])

    # Coordinates are stacked in (y, x) order, then rearranged so the boxes
    # end up in the shape [1, num_boxes, coords].
    stacked = tf.concat([ymin, xmin, ymax, xmax], 0)
    bbox = tf.transpose(tf.expand_dims(stacked, 0), [0, 2, 1])

    def _from_png():
        return tf.image.decode_png(image_bytes, 3)

    def _from_jpeg():
        return tf.image.decode_jpeg(image_bytes, 3)

    # PNG is decoded explicitly (either spelling of the format); every other
    # format falls back to the JPEG decoder.
    is_png = tf.logical_or(tf.equal(fmt, 'png'), tf.equal(fmt, 'PNG'))
    image = tf.case({is_png: _from_png}, default=_from_jpeg, exclusive=True)
    image.set_shape([None, None, 3])

    image = preprocess(image, bbox)

    scalar_label = tf.reshape(example['image/class/label'], shape=[])
    label = tf.cast(scalar_label, dtype=tf.int32, name='cast_label')
    label = tf.reshape(label, [1])
    return tf.cast(image, tf.float32), label
def input_tensors_to_model_input(input_tensors,
                                 hparams,
                                 is_training,
                                 num_classes=constants.MIDI_PITCHES):
    """Processes an InputTensor into FeatureTensors and LabelTensors.

    The spectrogram and all label tensors are reshaped to 2-D and then padded
    or truncated along their first dimension so that each of them ends up with
    exactly `final_length` rows. When `hparams.drum_data_map` is set, the
    label tensors are additionally passed through
    `drum_mappings.map_pianoroll` and a string note sequence through
    `drum_mappings.map_sequences`.

    Args:
      input_tensors: the input record; its `length`, `labels`,
        `label_weights`, `onsets`, `offsets`, `velocities`, `spec`,
        `note_sequence` and `sequence_id` fields are read.
      hparams: hyperparameters; `truncated_length_secs`,
        `max_expected_train_example_len` and `drum_data_map` are read here.
      is_training: Python bool; when true the note sequence and sequence id
        are replaced by constant placeholders.
      num_classes: size of the second dimension of every label tensor.

    Returns:
      A `(FeatureTensors, LabelTensors)` tuple.
    """

    def _fit_rows(tensor, delta):
        """Pads (delta < 0) or truncates (delta > 0) `tensor` by |delta| rows."""
        return tf.case(
            [(delta < 0,
              lambda: tf.pad(tensor, tf.stack([(0, -delta), (0, 0)]))),
             (delta > 0, lambda: tensor[0:-delta])],
            default=lambda: tensor)

    length = tf.cast(input_tensors.length, tf.int32)
    labels = tf.reshape(input_tensors.labels, (-1, num_classes))
    label_weights = tf.reshape(input_tensors.label_weights,
                               (-1, num_classes))
    onsets = tf.reshape(input_tensors.onsets, (-1, num_classes))
    offsets = tf.reshape(input_tensors.offsets, (-1, num_classes))
    velocities = tf.reshape(input_tensors.velocities, (-1, num_classes))
    spec = tf.reshape(input_tensors.spec, (-1, hparams_frame_size(hparams)))

    # Slice specs and labels tensors so they are no longer than
    # truncated_length.
    hparams_truncated_length = tf.cast(
        hparams.truncated_length_secs * hparams_frames_per_second(hparams),
        tf.int32)
    if hparams.truncated_length_secs:
        truncated_length = tf.reduce_min([hparams_truncated_length, length])
    else:
        truncated_length = length

    if is_training:
        truncated_note_sequence = tf.constant(0)
    else:
        truncated_note_sequence = truncate_note_sequence_op(
            input_tensors.note_sequence, truncated_length, hparams)

    # If max_expected_train_example_len is set, ensure that all examples are
    # padded to this length. This results in a fixed shape that can work on
    # TPUs.
    if hparams.max_expected_train_example_len and is_training:
        # In this case, final_length is a constant.
        if hparams.truncated_length_secs:
            assert_op = tf.assert_equal(
                hparams.max_expected_train_example_len,
                hparams_truncated_length)
            # NOTE(review): no op is created inside this scope (final_length
            # is a plain Python value), so the control dependency attaches
            # `assert_op` to nothing; the check only takes effect if
            # tf.assert_equal evaluates it statically - confirm.
            with tf.control_dependencies([assert_op]):
                final_length = hparams.max_expected_train_example_len
        else:
            final_length = hparams.max_expected_train_example_len
    else:
        # In this case, it is min(hparams.truncated_length, length)
        final_length = truncated_length

    spec_delta = tf.shape(spec)[0] - final_length
    spec = _fit_rows(spec, spec_delta)

    # All label tensors are adjusted by the delta measured on `labels`.
    labels_delta = tf.shape(labels)[0] - final_length
    labels = _fit_rows(labels, labels_delta)
    label_weights = _fit_rows(label_weights, labels_delta)
    onsets = _fit_rows(onsets, labels_delta)
    offsets = _fit_rows(offsets, labels_delta)
    velocities = _fit_rows(velocities, labels_delta)

    features = FeatureTensors(
        spec=tf.reshape(spec, (final_length, hparams_frame_size(hparams), 1)),
        length=truncated_length,
        sequence_id=tf.constant(0)
        if is_training else input_tensors.sequence_id)
    labels = LabelTensors(
        labels=tf.reshape(labels, (final_length, num_classes)),
        label_weights=tf.reshape(label_weights, (final_length, num_classes)),
        onsets=tf.reshape(onsets, (final_length, num_classes)),
        offsets=tf.reshape(offsets, (final_length, num_classes)),
        velocities=tf.reshape(velocities, (final_length, num_classes)),
        note_sequence=truncated_note_sequence)

    if hparams.drum_data_map:
        labels_dict = labels._asdict()
        # Binary rolls are reduced with 'any', weighted rolls with 'max'.
        for k in ('labels', 'onsets', 'offsets'):
            labels_dict[k] = drum_mappings.map_pianoroll(
                labels_dict[k],
                mapping_name=hparams.drum_data_map,
                reduce_mode='any',
                min_pitch=constants.MIN_MIDI_PITCH)
        for k in ('label_weights', 'velocities'):
            labels_dict[k] = drum_mappings.map_pianoroll(
                labels_dict[k],
                mapping_name=hparams.drum_data_map,
                reduce_mode='max',
                min_pitch=constants.MIN_MIDI_PITCH)
        # During training the note sequence is the integer placeholder
        # created above, which is left untouched; only string note sequences
        # are remapped.
        if labels_dict['note_sequence'].dtype == tf.string:
            labels_dict['note_sequence'] = tf.py_func(
                functools.partial(
                    drum_mappings.map_sequences,
                    mapping_name=hparams.drum_data_map),
                [labels_dict['note_sequence']],
                tf.string,
                name='get_drum_sequences',
                stateful=False)
            labels_dict['note_sequence'].set_shape(())
        labels = LabelTensors(**labels_dict)

    return features, labels
def _enas_layer(self, layer_id, prev_layers, start_idx, out_filters,
                is_training):
    """Builds one layer whose operation is chosen by `self.sample_arc`.

    With `self.whole_channels` a single branch is selected by the value found
    at `self.sample_arc[start_idx]`. Otherwise every branch is built on a
    slice of channels described by `self.sample_arc`, and the branches are
    concatenated and mixed by a masked 1x1 convolution. For `layer_id > 0`
    the previous layers flagged in `self.sample_arc` are added as skip
    connections.

    Args:
      layer_id: current layer
      prev_layers: cache of previous layers. for skip connections
      start_idx: where to start looking at. technically, we can infer this
        from layer_id, but why bother...
      out_filters: number of output filters of the layer.
      is_training: for batch_norm

    Returns:
      The output tensor of the layer.
    """
    inputs = prev_layers[-1]
    if self.whole_channels:
        if self.data_format == "NHWC":
            inp_h = inputs.get_shape()[1].value
            inp_w = inputs.get_shape()[2].value
        elif self.data_format == "NCHW":
            inp_h = inputs.get_shape()[2].value
            inp_w = inputs.get_shape()[3].value

        count = self.sample_arc[start_idx]
        branches = {}
        # Each lambda binds its own `y` through a default argument. A plain
        # `lambda: y` would look `y` up only when tf.case invokes it, i.e.
        # after `y` has been rebound to the last branch, making every
        # predicate return that last branch.
        with tf.variable_scope("branch_0"):
            y = self._conv_branch(inputs, 3, is_training, out_filters,
                                  out_filters, start_idx=0)
            branches[tf.equal(count, 0)] = lambda y=y: y
        with tf.variable_scope("branch_1"):
            y = self._conv_branch(inputs, 3, is_training, out_filters,
                                  out_filters, start_idx=0, separable=True)
            branches[tf.equal(count, 1)] = lambda y=y: y
        with tf.variable_scope("branch_2"):
            y = self._conv_branch(inputs, 5, is_training, out_filters,
                                  out_filters, start_idx=0)
            branches[tf.equal(count, 2)] = lambda y=y: y
        with tf.variable_scope("branch_3"):
            y = self._conv_branch(inputs, 5, is_training, out_filters,
                                  out_filters, start_idx=0, separable=True)
            branches[tf.equal(count, 3)] = lambda y=y: y
        if self.num_branches >= 5:
            with tf.variable_scope("branch_4"):
                y = self._pool_branch(inputs, is_training, out_filters,
                                      "avg", start_idx=0)
                branches[tf.equal(count, 4)] = lambda y=y: y
        if self.num_branches >= 6:
            with tf.variable_scope("branch_5"):
                y = self._pool_branch(inputs, is_training, out_filters,
                                      "max", start_idx=0)
                branches[tf.equal(count, 5)] = lambda y=y: y

        # The fallback tensor follows the layout of the real branches.
        if self.data_format == "NHWC":
            default_shape = [self.batch_size, inp_h, inp_w, out_filters]
        else:
            default_shape = [self.batch_size, out_filters, inp_h, inp_w]
        out = tf.case(
            branches,
            default=lambda: tf.constant(0, tf.float32, shape=default_shape),
            exclusive=True)

        if self.data_format == "NHWC":
            out.set_shape([None, inp_h, inp_w, out_filters])
        elif self.data_format == "NCHW":
            out.set_shape([None, out_filters, inp_h, inp_w])
    else:
        # Two entries per branch: the start index and the channel count.
        count = self.sample_arc[start_idx:start_idx + 2 * self.num_branches]
        branches = []
        with tf.variable_scope("branch_0"):
            branches.append(
                self._conv_branch(inputs, 3, is_training, count[1],
                                  out_filters, start_idx=count[0]))
        with tf.variable_scope("branch_1"):
            branches.append(
                self._conv_branch(inputs, 3, is_training, count[3],
                                  out_filters, start_idx=count[2],
                                  separable=True))
        with tf.variable_scope("branch_2"):
            branches.append(
                self._conv_branch(inputs, 5, is_training, count[5],
                                  out_filters, start_idx=count[4]))
        with tf.variable_scope("branch_3"):
            branches.append(
                self._conv_branch(inputs, 5, is_training, count[7],
                                  out_filters, start_idx=count[6],
                                  separable=True))
        if self.num_branches >= 5:
            with tf.variable_scope("branch_4"):
                branches.append(
                    self._pool_branch(inputs, is_training, count[9], "avg",
                                      start_idx=count[8]))
        if self.num_branches >= 6:
            with tf.variable_scope("branch_5"):
                branches.append(
                    self._pool_branch(inputs, is_training, count[11], "max",
                                      start_idx=count[10]))

        with tf.variable_scope("final_conv"):
            w = create_weight(
                "w", [self.num_branches * out_filters, out_filters])
            # Keep only the weight rows that belong to the channels each
            # branch actually uses.
            w_mask = tf.constant(
                [False] * (self.num_branches * out_filters), tf.bool)
            new_range = tf.range(
                0, self.num_branches * out_filters, dtype=tf.int32)
            for i in range(self.num_branches):
                start = out_filters * i + count[2 * i]
                new_mask = tf.logical_and(
                    start <= new_range,
                    new_range < start + count[2 * i + 1])
                w_mask = tf.logical_or(w_mask, new_mask)
            w = tf.boolean_mask(w, w_mask)
            w = tf.reshape(w, [1, 1, -1, out_filters])

            inp = prev_layers[-1]
            if self.data_format == "NHWC":
                branches = tf.concat(branches, axis=3)
            elif self.data_format == "NCHW":
                branches = tf.concat(branches, axis=1)
                N = tf.shape(inp)[0]
                H = inp.get_shape()[2].value
                W = inp.get_shape()[3].value
                branches = tf.reshape(branches, [N, -1, H, W])
            out = tf.nn.conv2d(
                branches, w, [1, 1, 1, 1], "SAME",
                data_format=self.data_format)
            out = batch_norm(out, is_training, data_format=self.data_format)
            out = tf.nn.relu(out)

    if layer_id > 0:
        if self.whole_channels:
            skip_start = start_idx + 1
        else:
            skip_start = start_idx + 2 * self.num_branches
        skip = self.sample_arc[skip_start:skip_start + layer_id]
        with tf.variable_scope("skip"):
            res_layers = []
            for i in range(layer_id):
                # tf.cond calls both lambdas right away, so `i` is read
                # while it still holds the current loop value.
                res_layers.append(
                    tf.cond(
                        tf.equal(skip[i], 1),
                        lambda: prev_layers[i],
                        lambda: tf.zeros_like(prev_layers[i])))
            res_layers.append(out)
            out = tf.add_n(res_layers)
            out = batch_norm(out, is_training, data_format=self.data_format)

    return out
def _drop_random_channel(coin_channel):
    """Calls `_drop` with the channel index (0, 1 or 2) equal to `coin_channel`.

    Args:
      coin_channel: scalar tensor compared against the indices 0, 1 and 2.

    Returns:
      The result of `tf.case` over the three `_drop` branches.
    """
    return tf.case({
        tf.equal(coin_channel, idx): (lambda idx=idx: _drop(idx))
        for idx in range(3)
    })
def test_inv_update_thunks(self):
    """Checks that one inverse matrix is updated per global step."""
    with self._graph.as_default(), self.test_session() as sess:
        fisher_estimator = estimator.FisherEstimatorRoundRobin(
            variables=[self.weights],
            layer_collection=self.layer_collection,
            damping=0.2,
            cov_ema_decay=0.0)

        # Build a single op that runs the i-th inverse update at global
        # step i.
        global_step = tf.train.get_or_create_global_step()
        (cov_variable_thunks, _, inv_variable_thunks,
         inv_update_op_thunks) = fisher_estimator.create_ops_and_vars_thunks()
        for make_variables in cov_variable_thunks:
            make_variables()
        for make_variables in inv_variable_thunks:
            make_variables()

        inv_matrices = []
        for factor in self.layer_collection.get_factors():
            inv_matrices.extend(factor._matpower_by_exp_and_damping.values())

        step_cases = []
        for step, update_thunk in enumerate(inv_update_op_thunks):
            step_cases.append((tf.equal(global_step, step), update_thunk))
        inv_update_op = tf.case(step_cases)
        increment_global_step = global_step.assign_add(1)

        sess.run(tf.global_variables_initializer())
        initial_inv_values = sess.run(inv_matrices)

        # Ensure there's one update per inverse matrix. This is true as long
        # as there's no fan-in/fan-out or parameter re-use.
        num_matrices = len(inv_matrices)
        self.assertEqual(num_matrices, len(inv_update_op_thunks))

        # With a single inverse matrix the test would check nothing.
        assert len(inv_matrices) > 1

        # Assign each covariance matrix a value other than the identity. This
        # ensures that the inverse matrices are updated to something
        # different as well.
        cov_bump_ops = []
        for factor in self.layer_collection.get_factors():
            cov_bump_ops.append(
                factor._cov.add_to_average(
                    2 * tf.eye(int(factor._cov_shape[0]))))
        sess.run(cov_bump_ops)

        for step in range(num_matrices):
            # Count how many inverses still hold their initial value.
            current_inv_values = sess.run(inv_matrices)
            num_unchanged = sum(
                np.allclose(before, after)
                for before, after in zip(initial_inv_values,
                                         current_inv_values))

            # After `step` updates exactly `step` inverses have changed.
            self.assertEqual(num_unchanged, num_matrices - step)

            # Apply this step's update, then advance to the next step.
            sess.run(inv_update_op)
            sess.run(increment_global_step)