Beispiel #1
0
def shake_shake_block(x, conv_filters, stride, hparams):
  """Builds one shake-shake residual block.

  Two parallel branches are built from `x`, each in its own variable scope
  ('branch_1' and 'branch_2'). They are combined with a shake-shake op and
  the result is added to a skip connection.

  Args:
    x: Input tensor. Its last dimension is compared with `conv_filters` to
      decide whether the skip path can be an identity.
    conv_filters: Filter count passed to both branches, and to the
      downsampling skip path when it differs from `x.shape[-1]`.
    stride: Stride passed through to `shake_shake_block_branch`.
    hparams: Hyperparameters. Reads `mode`, and in training mode also
      `shakeshake_type` ('batch', 'image' or 'equal').

  Returns:
    `skip + shaken`, where `shaken` is given the static shape of the first
    branch.

  Raises:
    ValueError: In training mode, if `hparams.shakeshake_type` is not one of
      'batch', 'image' or 'equal'.
  """
  with tf.variable_scope('branch_1'):
    branch1 = shake_shake_block_branch(x, conv_filters, stride)
  with tf.variable_scope('branch_2'):
    branch2 = shake_shake_block_branch(x, conv_filters, stride)

  # Identity skip only when the channel count already matches.
  # NOTE(review): this assumes stride > 1 only occurs together with a change
  # in filter count; otherwise the identity skip may not match the branches'
  # spatial shape -- TODO confirm with callers.
  if x.shape[-1] == conv_filters:
    skip = tf.identity(x)
  else:
    skip = downsampling_residual_branch(x, conv_filters)

  # TODO(rshin): Use different alpha for each image in batch.
  if hparams.mode == tf.contrib.learn.ModeKeys.TRAIN:
    if hparams.shakeshake_type == 'batch':
      shaken = common_layers.shakeshake2(branch1, branch2)
    elif hparams.shakeshake_type == 'image':
      shaken = common_layers.shakeshake2_indiv(branch1, branch2)
    elif hparams.shakeshake_type == 'equal':
      shaken = common_layers.shakeshake2_py(branch1, branch2, equal=True)
    else:
      # Report the offending hparam value; `shaken` is unbound on this path.
      raise ValueError(
          'Invalid shakeshake_type: {!r}'.format(hparams.shakeshake_type))
  else:
    # Outside training, always use the equal=True combination.
    # NOTE(review): presumably a deterministic combination of the branches --
    # confirm in common_layers.shakeshake2_py.
    shaken = common_layers.shakeshake2_py(branch1, branch2, equal=True)
  # Restore the static shape, which the shake ops may not propagate.
  shaken.set_shape(branch1.get_shape())

  return skip + shaken
Beispiel #2
0
def shakeshake_binary_module(x, y, hparams):
  """Combines two inputs with `common_layers.shakeshake2`.

  Args:
    x: First operand forwarded to `shakeshake2`.
    y: Second operand forwarded to `shakeshake2`.
    hparams: Accepted but ignored (presumably kept so this module shares a
      signature with other binary modules -- confirm with callers).

  Returns:
    The value returned by `common_layers.shakeshake2(x, y)`.
  """
  del hparams  # Not needed by this module.
  mixed = common_layers.shakeshake2(x, y)
  return mixed