Ejemplo n.º 1
0
def initialize():
    """Set up flags, logging, paths and global TensorFlow/slim defaults for a run.

    Steps, in order:
      * disables TF eager execution and calls
        options.initialize_with_logfiles(get_parser()) (presumably parses the
        command line into FLAGS and sets up log files),
      * logs host, pid, the optional comment, the raw command line and FLAGS,
      * sets tfu's data format and dtype,
      * fills in defaults for batch_size_test, checkpoint_dir,
        bone_length_dataset and pred_path, and turns checkpoint_dir, load_path
        and pred_path into absolute paths via util.ensure_absolute_path,
      * enters slim arg scopes for the data format and pooling padding,
      * seeds TF's random number generator and, if FLAGS.gui is set, switches
        matplotlib to the TkAgg backend.

    Raises:
        ValueError: if FLAGS.dtype is neither 'float32' nor 'float16'.
    """
    tf.compat.v1.disable_eager_execution()

    options.initialize_with_logfiles(get_parser())
    logging.info('-- Starting --')
    logging.info(f'Host: {socket.gethostname()}')
    logging.info(f'Process id (pid): {os.getpid()}')

    if FLAGS.comment:
        logging.info(f'Comment: {FLAGS.comment}')
    logging.info(f'Raw command: {" ".join(map(shlex.quote, sys.argv))}')
    logging.info(f'Parsed flags: {FLAGS}')
    tfu.set_data_format(FLAGS.data_format)

    # Reject unknown dtypes instead of silently training in float16.
    supported_dtypes = {'float32': tf.float32, 'float16': tf.float16}
    if FLAGS.dtype not in supported_dtypes:
        raise ValueError(
            f'Training dtype {FLAGS.dtype} not supported, only float16/32.')
    tfu.set_dtype(supported_dtypes[FLAGS.dtype])

    if FLAGS.batch_size_test is None:
        FLAGS.batch_size_test = FLAGS.batch_size

    if FLAGS.checkpoint_dir is None:
        FLAGS.checkpoint_dir = FLAGS.logdir

    FLAGS.checkpoint_dir = util.ensure_absolute_path(
        FLAGS.checkpoint_dir, root=f'{paths.DATA_ROOT}/experiments')
    os.makedirs(FLAGS.checkpoint_dir, exist_ok=True)

    if FLAGS.bone_length_dataset is None:
        FLAGS.bone_length_dataset = FLAGS.dataset

    # Resolve load_path BEFORE deriving the prediction directory from it, so a
    # relative load_path yields the same (absolute) directory the checkpoint is
    # actually loaded from, rather than a directory relative to the CWD.
    if FLAGS.load_path:
        # Accept a path to a checkpoint's .index/.meta file and reduce it to the
        # checkpoint prefix.
        if FLAGS.load_path.endswith(('.index', '.meta')):
            FLAGS.load_path = os.path.splitext(FLAGS.load_path)[0]
        FLAGS.load_path = util.ensure_absolute_path(FLAGS.load_path,
                                                    FLAGS.checkpoint_dir)

    if not FLAGS.pred_path:
        FLAGS.pred_path = f'predictions_{FLAGS.dataset}.npz'
    # Relative prediction paths live next to the loaded checkpoint if there is
    # one, otherwise in the checkpoint directory.
    base = (os.path.dirname(FLAGS.load_path)
            if FLAGS.load_path else FLAGS.checkpoint_dir)
    FLAGS.pred_path = util.ensure_absolute_path(FLAGS.pred_path, base)

    # Override the default data format in slim layers
    enter_context(
        slim.arg_scope([
            slim.conv2d, slim.conv3d, slim.conv3d_transpose,
            slim.conv2d_transpose, slim.avg_pool2d, slim.separable_conv2d,
            slim.max_pool2d, slim.batch_norm, slim.spatial_softmax
        ],
                       data_format=tfu.data_format()))

    # Override default paddings to SAME
    enter_context(
        slim.arg_scope([slim.avg_pool2d, slim.max_pool2d], padding='SAME'))
    tf.compat.v2.random.set_seed(FLAGS.seed)
    if FLAGS.gui:
        plt.switch_backend('TkAgg')
Ejemplo n.º 2
0
def resnet_arg_scope(weight_decay=1e-4,
                     batch_norm_decay=0.997,
                     batch_norm_epsilon=1e-5,
                     batch_norm_scale=True):
    """Build a slim arg scope with ResNet-style defaults.

    Within the scope:
      * slim.conv2d / slim.conv3d get L2 weight regularization, variance-scaling
        initialization, ReLU activation and batch normalization,
      * slim.batch_norm gets the same batch-norm parameters directly,
      * slim.max_pool2d uses 'SAME' padding.

    tfu.is_training() and tfu.data_format() are read when this function is
    called, so the scope reflects the global settings at that moment.

    Args:
      weight_decay: Scale of the L2 weight regularizer on conv weights.
      batch_norm_decay: Moving-average decay for batch normalization.
      batch_norm_epsilon: Epsilon added to the variance in batch normalization.
      batch_norm_scale: Whether batch normalization learns a scale (gamma).

    Returns:
      The arg scope captured inside the nested slim.arg_scope blocks, typically
      re-entered with `with slim.arg_scope(resnet_arg_scope()): ...`.
    """
    batch_norm_params = dict(decay=batch_norm_decay,
                             epsilon=batch_norm_epsilon,
                             scale=batch_norm_scale,
                             is_training=tfu.is_training(),
                             fused=True,
                             data_format=tfu.data_format())

    with slim.arg_scope(
            [slim.conv2d, slim.conv3d],
            weights_regularizer=slim.l2_regularizer(weight_decay),
            weights_initializer=slim.variance_scaling_initializer(),
            activation_fn=tf.nn.relu,
            normalizer_fn=slim.batch_norm,
            normalizer_params=batch_norm_params):
        # Also configure standalone slim.batch_norm calls the same way as the
        # batch norm applied inside conv layers.
        with slim.arg_scope([slim.batch_norm], **batch_norm_params):
            with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc:
                return arg_sc
Ejemplo n.º 3
0
def initialize():
    """Parse global flags, log run information and set TF/slim defaults.

    Side effects, in order:
      * parse_and_set_global_flags() and setup_logging() are called,
      * host, pid, optional comment, quoted command line and FLAGS are logged,
      * tfu's data format and dtype are set from FLAGS,
      * slim arg scopes for data format and pooling padding are entered via
        enter_context,
      * matplotlib switches to the TkAgg backend if FLAGS.gui is set,
      * TF's random seed is set to FLAGS.seed.

    Raises:
        ValueError: if FLAGS.dtype is not 'float32' or 'float16'.
    """
    global FLAGS
    parse_and_set_global_flags()
    setup_logging()

    startup_messages = [
        '-- Starting --',
        f'Host: {socket.gethostname()}',
        f'Process id (pid): {os.getpid()}',
    ]
    for message in startup_messages:
        logging.info(message)

    if FLAGS.comment:
        logging.info(f'Comment: {FLAGS.comment}')
    quoted_command = ' '.join(shlex.quote(arg) for arg in sys.argv)
    logging.info(f'Raw command: {quoted_command}')
    logging.info(f'Parsed flags: {FLAGS}')
    tfu.set_data_format(FLAGS.data_format)

    dtype_by_name = {'float32': tf.float32, 'float16': tf.float16}
    chosen_dtype = dtype_by_name.get(FLAGS.dtype)
    if chosen_dtype is None:
        raise ValueError(f'Training dtype {FLAGS.dtype} not supported, only float16/32.')
    tfu.set_dtype(chosen_dtype)

    # Make these slim layers follow the globally configured data format.
    format_aware_layers = [
        slim.conv2d, slim.conv3d, slim.conv3d_transpose, slim.conv2d_transpose,
        slim.avg_pool2d, slim.separable_conv2d, slim.max_pool2d, slim.batch_norm,
        slim.spatial_softmax,
    ]
    enter_context(slim.arg_scope(format_aware_layers, data_format=tfu.data_format()))

    # Override slim's default pooling padding with 'SAME'.
    pooling_layers = [slim.avg_pool2d, slim.max_pool2d]
    enter_context(slim.arg_scope(pooling_layers, padding='SAME'))

    if FLAGS.gui:
        plt.switch_backend('TkAgg')

    tf.set_random_seed(FLAGS.seed)
Ejemplo n.º 4
0
def conv2d_same(inputs,
                num_outputs,
                kernel_size,
                stride=1,
                rate=1,
                centered_stride=False,
                scope=None,
                **kwargs):
    """Strided 2-D convolution with 'SAME' padding.

    When stride > 1 (and centered_stride is False), we do explicit zero-padding,
    followed by conv2d with 'VALID' padding.

    Note that

       net = conv2d_same(inputs, num_outputs, 3, stride=stride)

    is equivalent to

       net = tf.contrib.layers.conv2d(inputs, num_outputs, 3, stride=1,
       padding='SAME')
       net = subsample(net, factor=stride)

    whereas

       net = tf.contrib.layers.conv2d(inputs, num_outputs, 3, stride=stride,
       padding='SAME')

    is different when the input's height or width is even, which is why we add the
    current function. For more details, see ResnetUtilsTest.testConv2DSameEven().

    Args:
      inputs: A 4-D tensor, laid out as NHWC if tfu.data_format() == 'NHWC' and
        otherwise as channels-first (NCHW).
      num_outputs: An integer, the number of output filters.
      kernel_size: An int, or a (kernel_height, kernel_width) pair.
      stride: An integer, the output stride.
      rate: An int, or a (rate_height, rate_width) pair, the atrous rate.
      centered_stride: If True, always call conv2d with its own 'SAME' padding
        and the given stride, skipping the explicit zero-padding path.
      scope: Scope.
      **kwargs: Extra keyword arguments forwarded to layers_lib.conv2d.

    Returns:
      output: A 4-D tensor of size [batch, height_out, width_out, channels] with
        the convolution output (channels-first if the data format is not NHWC).
    """
    if stride == 1 or centered_stride:
        padding = 'SAME'
    else:
        # Pad symmetrically (extra pixel at the end for odd totals) so that the
        # result does not depend on whether the input size is even, then
        # convolve without further padding.
        (pad_top, pad_bottom), (pad_left, pad_right) = _same_padding_amounts(
            kernel_size, rate)
        spatial_paddings = [[pad_top, pad_bottom], [pad_left, pad_right]]
        if tfu.data_format() == 'NHWC':
            paddings = [[0, 0], *spatial_paddings, [0, 0]]
        else:
            paddings = [[0, 0], [0, 0], *spatial_paddings]
        inputs = array_ops.pad(inputs, paddings)
        padding = 'VALID'

    return layers_lib.conv2d(inputs,
                             num_outputs,
                             kernel_size,
                             stride=stride,
                             rate=rate,
                             padding=padding,
                             scope=scope,
                             **kwargs)


def _to_hw_pair(value):
    """Return value as a (height, width) pair; a scalar is used for both."""
    try:
        height, width = value
    except TypeError:
        # Not unpackable (e.g. a plain int): use it for both dimensions.
        return value, value
    return height, width


def _same_padding_amounts(kernel_size, rate):
    """Return ((pad_beg, pad_end) for height, (pad_beg, pad_end) for width).

    The total padding per dimension is the dilated kernel extent minus one,
    split with the extra pixel (if any) at the end.
    """
    amounts = []
    for size, dilation in zip(_to_hw_pair(kernel_size), _to_hw_pair(rate)):
        effective_size = size + (size - 1) * (dilation - 1)
        pad_total = effective_size - 1
        pad_beg = pad_total // 2
        amounts.append((pad_beg, pad_total - pad_beg))
    return tuple(amounts)