Exemplo n.º 1
0
def route_block_pred(x, nexpert, route, expert):
    """Build a routed mixture-of-experts prediction head.

    A shared router (a stack of ``block`` layers described by ``route``)
    produces logits over ``nexpert`` experts. Each expert is its own stack of
    ``block`` layers (described by ``expert``) applied to ``x``, followed by a
    dense layer with 100 outputs. The training output uses an expert sampled
    from the routing distribution; the test output uses the arg-max expert.

    Args:
      x: input tensor fed to the router and to every expert.
      nexpert: number of experts.
      route: sequence of ``(c, f, p)`` tuples passed to ``block`` for the
        router. Must be non-empty: the last ``f`` sizes the flatten.
      expert: sequence of ``(c, f, p)`` tuples passed to ``block`` for each
        expert.

    Returns:
      ``(train_out, test_out, train_idx, test_idx)``: outputs of the sampled
      and of the arg-max expert, and the corresponding expert indices.
    """
    # Shared router trunk; only the last layer's output is needed.
    router = x
    for (c, f, p) in route:
        router = block(router, c, f, p)

    # NOTE(review): reshaping to [1, f] assumes batch size 1 and that the last
    # block emits exactly `f` elements — confirm against `block`.
    flat = tf.reshape(router, [1, f])
    logits = tf.squeeze(dense_block(flat, [f, nexpert]))

    pi = tf.distributions.Categorical(logits=logits)
    train_idx = tf.squeeze(pi.sample(1))
    test_idx = tf.cast(tf.argmax(logits), dtype=tf.int32)

    experts = []
    for _ in range(nexpert):
        hidden = x
        for (c, f, p) in expert:
            hidden = block(hidden, c, f, p)

        flat = tf.reshape(hidden, [1, f])
        experts.append(dense_block(flat, [f, 100]))

    # Bind each expert output as a default argument. A bare
    # `lambda: experts[e]` looks `e` up only when tf.switch_case calls it,
    # after the loop has finished, so every branch would return the last expert.
    branch_fns = {e: (lambda out=out: out) for e, out in enumerate(experts)}

    train_out = tf.switch_case(branch_index=train_idx, branch_fns=branch_fns)
    test_out = tf.switch_case(branch_index=test_idx, branch_fns=branch_fns)

    return train_out, test_out, train_idx, test_idx
Exemplo n.º 2
0
def switch_case(index, values, default=None, name='switch_case'):
    """Switch/Case over integer indices.

    Tensor indices are dispatched with ``tf.switch_case``. Plain Python
    indices are matched directly and the selected callable is invoked.

    Parameters
    ----------
    index : () tensor_like
        Index to switch.
    values : dict(index: callable) or sequence of callable
        Pairs of index/function to execute. A plain sequence of callables is
        treated as ``dict(enumerate(values))``, as ``tf.switch_case`` does.
    default : callable, optional
        Function to call when ``index`` matches no key.
    name : str, default='switch_case'
        Name for the operation (tensor path only).

    Returns
    -------
    value : tensor or array
        Result of the selected callable.

    Raises
    ------
    IndexError
        If ``index`` is a Python value matching no key and no ``default``
        was given.
    """
    if tf.is_tensor(index):
        return tf.switch_case(index, values, default, name)
    # Accept both a mapping and a plain sequence of callables.
    pairs = values.items() if hasattr(values, 'items') else enumerate(values)
    for key, item in pairs:
        if index == key:
            return item()
    # Honour `default` on the Python path as well, mirroring tf.switch_case.
    if default is not None:
        return default()
    raise IndexError('{} is not a key from `values`.'.format(index))
Exemplo n.º 3
0
  def _apply_one_layer(self, image):
    """Apply one randomly chosen augmentation op to `image`.

    An op from IMAGENET_AUG_OPS is picked uniformly at random and run at the
    level from `self._get_level()`. If `self.prob_to_apply` is set, the
    augmented image is kept only with that probability; otherwise `image`
    is returned unchanged.
    """
    level = self._get_level()

    def _make_branch(op_name):
      # Resolve the op and its argument builder once per branch.
      op_fn = augment_ops.NAME_TO_FUNC[op_name]
      to_args = LEVEL_TO_ARG[op_name]
      return lambda: op_fn(image, *to_args(level))

    branches = [_make_branch(op_name) for op_name in IMAGENET_AUG_OPS]

    chosen = tf.random.uniform(
        shape=[], maxval=len(branches), dtype=tf.int32)
    aug_image = tf.switch_case(chosen, branches, default=lambda: image)
    if self.prob_to_apply is None:
      return aug_image
    keep = tf.random.uniform(shape=[], dtype=tf.float32) < self.prob_to_apply
    return tf.cond(keep, lambda: aug_image, lambda: image)
Exemplo n.º 4
0
    def _apply_one_layer(self, data):
        """Apply one randomly chosen augmentation op to `data`.

        An op from IMAGENET_AUG_OPS is picked uniformly at random and run at
        the level from `self._get_level()`. Ops whose last positional
        parameter is named `replace` additionally receive `self.replace` as
        their final argument. If `self.prob_to_apply` is set, the augmented
        result is kept only with that probability.
        """
        level = self._get_level()

        def _make_branch(op_name):
            op_fn = NAME_TO_FUNC[op_name]
            to_args = LEVEL_TO_ARG[op_name]

            def _branch():
                call_args = [data, *to_args(level)]
                param_names = inspect.getfullargspec(op_fn).args
                # `replace` is passed positionally, so it must be the last
                # declared parameter.
                if param_names and param_names[-1] == 'replace':
                    call_args.append(self.replace)
                return op_fn(*call_args)

            return _branch

        branches = [_make_branch(op_name) for op_name in IMAGENET_AUG_OPS]

        chosen = tf.random.uniform(shape=[],
                                   maxval=len(branches),
                                   dtype=tf.int32)
        aug_data = tf.switch_case(chosen, branches, default=lambda: data)
        if self.prob_to_apply is None:
            return aug_data
        keep = (tf.random.uniform(shape=[], dtype=tf.float32) <
                self.prob_to_apply)
        return tf.cond(keep, lambda: aug_data, lambda: data)
Exemplo n.º 5
0
        def _rotate():
            """Rotate the sample by a random multiple of 90 degrees.

            A single k in {1, 2, 3} is drawn and applied consistently to:
              * data['image'],
              * data['groundtruth_boxes'] (via utilities.rotate_rboxes90),
              * data['entity_id_mask'].
            TODO(longshangbang): rotate vertices.

            Returns:
              (rotated_img, rotated_boxes, rotated_mask).
            """
            k = tf.random.uniform([], 1, 4, dtype=tf.int32)
            image = data['image']
            h, w, _ = utilities.resolve_shape(image)

            rotated_img = tf.image.rot90(image, k=k, name='image_rot90k')

            box_rotator = functools.partial(utilities.rotate_rboxes90,
                                            rboxes=data['groundtruth_boxes'],
                                            image_width=w,
                                            image_height=h)
            # One zero-argument branch per rotation count 1, 2, 3.
            box_branches = [
                functools.partial(box_rotator, rotation_count=count)
                for count in (1, 2, 3)
            ]
            # k lies in [1, 3]; switch_case indices start at 0.
            rotated_boxes = tf.switch_case(k - 1, branch_fns=box_branches)

            rotated_mask = tf.image.rot90(data['entity_id_mask'],
                                          k=k,
                                          name='mask_rot90k')
            return rotated_img, rotated_boxes, rotated_mask
Exemplo n.º 6
0
def random_resize_pad(images, height, width):
    """Resize-and-pad `images` to (height, width) with a random interpolation.

    One of seven tf.image.ResizeMethod values is chosen uniformly at random
    and passed to tf.image.resize_with_pad.
    """
    resize_methods = (
        tf.image.ResizeMethod.NEAREST_NEIGHBOR,
        tf.image.ResizeMethod.BICUBIC,
        tf.image.ResizeMethod.AREA,
        tf.image.ResizeMethod.LANCZOS3,
        tf.image.ResizeMethod.LANCZOS5,
        tf.image.ResizeMethod.MITCHELLCUBIC,
        tf.image.ResizeMethod.GAUSSIAN,
    )
    branches = {
        idx: (lambda method=method: tf.image.resize_with_pad(
            images, height, width, method=method))
        for idx, method in enumerate(resize_methods)
    }
    # Scale a float in [0, 1) to an integer index in [0, len(branches)).
    draw = tf.random.uniform([], 0, 1.0) * len(branches)
    return tf.switch_case(tf.cast(draw, tf.int32), branch_fns=branches)
Exemplo n.º 7
0
def application():
    """Run the loan-application flow and dispatch on the CIBIL result.

    Loan bands (from the original notes):
        650 < cibil < 750 :- 10 to 15k loan
        750 < cibil < 850 :- 25 to 40k
        850 < cibil       :- 40k to 80 k

    Returns:
        The output of the selected case handler.
    """
    Kyc()
    b = generate_cibil_score()
    upload_docs = upload_bank_statement(b)
    # tf.switch_case requires callables. Wrap each handler in a lambda so it
    # is invoked by tf.switch_case instead of all four being called up front
    # and their results passed as non-callable branches.
    # NOTE(review): `b` is used directly as the branch index, so it must be
    # an int in [0, 3]. If generate_cibil_score returns a raw score, it must
    # first be bucketed into the bands above — confirm.
    return tf.switch_case(b,
                          branch_fns={
                              0: lambda: case0(b, upload_docs),
                              1: lambda: case1(b, upload_docs),
                              2: lambda: case2(b, upload_docs),
                              3: lambda: loan_not_applicable(b),
                          })
Exemplo n.º 8
0
 def switch_case(self, branch_selector, branch_callables, name=None):
     """Dispatch to one of `branch_callables` using `branch_selector`.

     The tf.switch_case op is emitted inside a 'VM.switch_case' name scope
     while the `_control_flow_v2()` context is active.
     """
     with tf.compat.v2.name_scope('VM.switch_case'), _control_flow_v2():
         return tf.switch_case(branch_selector, branch_callables, name=name)
def apply_with_random_selector(x: tf.Tensor,
                               func: Callable[[tf.Tensor, tf.Tensor],
                                              tf.Tensor],
                               num_cases: int,
                               seed: tf.Tensor,
                               selected: Optional[int] = None) -> tf.Tensor:
    """Return func(x, sel) for a selector sel in [0, num_cases).

    Args:
      x: input Tensor.
      func: Python function to apply. Each branch calls it with the case
        number as a Python int.
      num_cases: Python int, number of cases to choose from.
      seed: the stateless random seed used when `selected` is None.
      selected: optional index to use instead of sampling one.

    Returns:
      The output of the branch chosen by the (possibly sampled) selector.
    """
    index = selected
    if index is None:
        index = tf.random.stateless_uniform([],
                                            maxval=num_cases,
                                            dtype=tf.int32,
                                            seed=seed)

    def _bind(case):
        # Fix `case` now so each branch sees its own integer.
        return lambda: func(x, case)

    return tf.switch_case(index, [_bind(case) for case in range(num_cases)])
Exemplo n.º 10
0
    def _apply_ops(self, data, op_indices, op_args, prob_to_apply=None):
        """Apply `self.num_layers` augmentation ops to `data` in sequence.

        For layer `n`, `op_indices[n]` selects the op from IMAGENET_AUG_OPS
        and `op_args[n]` is its level. An out-of-range index leaves the data
        unchanged. When `prob_to_apply` is given, each layer's result is kept
        only with that probability.
        """

        def _branch_for(op_name, current, level):
            augment_fn = NAME_TO_FUNC[op_name]
            level_to_args_fn = LEVEL_TO_ARG[op_name]

            def _branch():
                # Pass the leading dimension of `current` as `image_size`
                # (used by the translate ops).
                call_args = [current,
                             *level_to_args_fn(level,
                                               image_size=current.shape[0])]
                param_names = inspect.getfullargspec(augment_fn).args
                # `replace` is passed positionally as the final argument.
                if param_names and param_names[-1] == 'replace':
                    call_args.append(self.replace)
                return augment_fn(*call_args)

            return _branch

        for layer in range(self.num_layers):
            op_index, op_level = op_indices[layer], op_args[layer]
            branches = [_branch_for(op_name, data, op_level)
                        for op_name in IMAGENET_AUG_OPS]
            aug_data = tf.switch_case(op_index,
                                      branches,
                                      default=lambda: data)
            if prob_to_apply is None:
                data = aug_data
            else:
                keep = (tf.random.uniform(shape=[], dtype=tf.float32) <
                        prob_to_apply)
                data = tf.cond(keep, lambda: aug_data, lambda: data)

        return data
Exemplo n.º 11
0
 def call(self, x):
     """Run one of three candidate branches chosen by a hard Gumbel-softmax sample.

     `gumbel_softmax(..., hard=True, return_index=True)` yields hard weights
     and the selected index. Each branch is built by
     `self._create_branch_fn(x, i, hardwts)`.
     """
     hardwts, index = gumbel_softmax(self.alpha,
                                     tau=1.0,
                                     hard=True,
                                     return_index=True)
     branch_fns = []
     for branch_id in range(3):
         branch_fns.append(self._create_branch_fn(x, branch_id, hardwts))
     return tf.switch_case(index, branch_fns)
Exemplo n.º 12
0
 def to_tensor(self, elem: tfds.features.FeaturesDict) -> tf.Tensor:
     """Convert `elem` with the input selected by its `self.label` field.

     `elem[self.label]` is cast to int32 and used as the key into
     `self.mapping`. The chosen input's `to_tensor(elem)` is returned. Keys
     with no entry fall through to the default branch.
     """
     index = tf.dtypes.cast(elem[self.label], tf.int32)

     def _converter(input_):
         return lambda: input_.to_tensor(elem)

     fns = {key: _converter(input_) for key, input_ in self.mapping.items()}
     # NOTE(review): the default yields the string 'NONE'. tf.switch_case
     # needs matching structure/dtype across branches, so this only works if
     # the mapped inputs also return string tensors — confirm.
     return tf.switch_case(branch_index=index,
                           branch_fns=fns,
                           default=lambda: 'NONE')
Exemplo n.º 13
0
 def distort_each_ops(data):
     """Apply one randomly chosen op to data["image"] and return the image.

     The op index is drawn uniformly from [0, len(available_ops)). Indices
     0-15 map to the ops below; any other index falls back to autocontrast.
     """
     op_to_select = tf.random.uniform([], maxval=len(available_ops), dtype=tf.int32)
     image = data["image"]
     branch_fns = {
         0: lambda: cls.autocontrast(image),
         1: lambda: cls.equalize(image),
         2: lambda: cls.invert(image),
         3: lambda: cls.rotate(image, *cls._rotate_level_to_arg(level)),
         4: lambda: cls.posterize(image, int((level / cls._MAX_LEVEL) * 4)),
         5: lambda: cls.solarize(image, int((level / cls._MAX_LEVEL) * 256)),
         6: lambda: cls.solarize_add(image, int((level / cls._MAX_LEVEL) * 110)),
         7: lambda: cls.color(image, *cls._enhance_level_to_arg(level)),
         8: lambda: cls.contrast(image, *cls._enhance_level_to_arg(level)),
         9: lambda: cls.brightness(image, *cls._enhance_level_to_arg(level)),
         10: lambda: cls.sharpness(image, *cls._enhance_level_to_arg(level)),
         11: lambda: cls.shear_x(image, *cls._shear_level_to_arg(level)),
         12: lambda: cls.shear_y(image, *cls._shear_level_to_arg(level)),
         13: lambda: cls.translate_x(image, *cls._translate_level_to_arg(level, translate_const)),
         14: lambda: cls.translate_y(image, *cls._translate_level_to_arg(level, translate_const)),
         15: lambda: cls.cutout(image, int((level / cls._MAX_LEVEL) * cutout_const)),
     }
     return tf.switch_case(op_to_select,
                           branch_fns=branch_fns,
                           default=lambda: cls.autocontrast(image))
Exemplo n.º 14
0
  def inference(actor_ids, run_ids, env_outputs, raw_rewards):
    """Run one batched inference step for a set of actors.

    Resets per-actor state for actors whose run id changed. Records rewards
    and publishes infos of finished episodes. Runs the agent on one of
    `inference_devices`, chosen round-robin. Appends the step to the unroll
    store, queues completed unrolls, and returns the chosen actions.

    Args:
      actor_ids: ids of the actors in this batch.
      run_ids: current run id per actor. A change from the stored value marks
        an actor that (re)started.
      env_outputs: per-actor environment outputs (`.reward` and `.done` are
        read here).
      raw_rewards: unprocessed rewards, accumulated into `actor_infos`.

    Returns:
      `agent_outputs.action` for each actor.
    """
    # Reset the actors that had their first run or crashed.
    previous_run_ids = actor_run_ids.read(actor_ids)
    actor_run_ids.replace(actor_ids, run_ids)
    reset_indices = tf.where(tf.not_equal(previous_run_ids, run_ids))[:, 0]
    actors_needing_reset = tf.gather(actor_ids, reset_indices)
    if tf.not_equal(tf.shape(actors_needing_reset)[0], 0):
      tf.print('Actor ids needing reset:', actors_needing_reset)
    actor_infos.reset(actors_needing_reset)
    store.reset(actors_needing_reset)
    # Reset actors get a fresh initial state, stored both as their current
    # state and as the first state of their next unroll.
    initial_agent_states = agent.initial_state(
        tf.shape(actors_needing_reset)[0])
    first_agent_states.replace(actors_needing_reset, initial_agent_states)
    agent_states.replace(actors_needing_reset, initial_agent_states)
    actions.reset(actors_needing_reset)

    # Update steps and return.
    # NOTE(review): the tuple looks like (steps, reward, raw_reward) — confirm
    # against actor_infos' definition.
    actor_infos.add(actor_ids, (0, env_outputs.reward, raw_rewards))
    # Publish infos of actors whose episode just ended, then clear them.
    done_ids = tf.gather(actor_ids, tf.where(env_outputs.done)[:, 0])
    info_queue.enqueue_many(actor_infos.read(done_ids))
    actor_infos.reset(done_ids)
    actor_infos.add(actor_ids, (FLAGS.num_action_repeats, 0., 0.))

    # Inference.
    prev_actions = actions.read(actor_ids)
    input_ = encode((prev_actions, env_outputs))
    prev_agent_states = agent_states.read(actor_ids)
    def make_inference_fn(inference_device):
      # Return a zero-argument callable (as tf.switch_case requires) that
      # runs the agent on `inference_device`.
      def device_specific_inference_fn():
        with tf.device(inference_device):
          @tf.function
          def agent_inference(*args):
            return agent(*decode(args))

          return agent_inference(input_, prev_agent_states)

      return device_specific_inference_fn

    # Distribute the inference calls among the inference cores.
    # An incrementing counter modulo the device count gives round-robin order.
    branch_index = inference_iteration.assign_add(1) % len(inference_devices)
    agent_outputs, curr_agent_states = tf.switch_case(branch_index, {
        i: make_inference_fn(inference_device)
        for i, inference_device in enumerate(inference_devices)
    })

    # Append the latest outputs to the unroll and insert completed unrolls in
    # queue.
    completed_ids, unrolls = store.append(
        actor_ids, (prev_actions, env_outputs, agent_outputs))
    unrolls = Unroll(first_agent_states.read(completed_ids), *unrolls)
    unroll_queue.enqueue_many(unrolls)
    # Record these actors' agent state (not yet updated for this step; that
    # happens below) as the first state of their next unroll.
    first_agent_states.replace(completed_ids,
                               agent_states.read(completed_ids))

    # Update current state.
    agent_states.replace(actor_ids, curr_agent_states)
    actions.replace(actor_ids, agent_outputs.action)

    # Return environment actions to actors.
    return agent_outputs.action
Exemplo n.º 15
0
    def color_jitter_rand(self, x):
        """Apply brightness, contrast, saturation and hue jitter in random order.

        Each adjustment is an identity when its configured strength is zero.
        After every adjustment the result is clipped to `self.clip`, unless
        the two clip bounds sum to zero.
        """
        def brightness(img):
            if self.brightness == 0:
                return img
            return tf.image.random_brightness(img, self.brightness)

        def contrast(img):
            if self.contrast[0] + self.contrast[1] == 0:
                return img
            return tf.image.random_contrast(img, self.contrast[0],
                                            self.contrast[1])

        def saturation(img):
            if self.saturation[0] + self.saturation[1] == 0:
                return img
            return tf.image.random_saturation(img, self.saturation[0],
                                              self.saturation[1])

        def hue(img):
            if self.hue == 0:
                return img
            return tf.image.random_hue(img, self.hue)

        adjusters = (brightness, contrast, saturation, hue)
        order = tf.random.shuffle(tf.range(4))
        for step in range(4):
            # Bind the current `x` so each branch adjusts this step's input.
            branches = {
                pos: (lambda fn=fn, img=x: fn(img))
                for pos, fn in enumerate(adjusters)
            }
            x = tf.switch_case(tf.gather(order, step), branch_fns=branches)
            if self.clip[0] + self.clip[1] != 0:
                x = tf.clip_by_value(x, self.clip[0], self.clip[1])
        return x
Exemplo n.º 16
0
def tuneLabel(label):
    """Map an integer class label to a float16 target value.

    Labels 0-4 map to themselves and label 5 maps to 5.5. No explicit
    default branch is supplied to tf.switch_case.
    """
    targets = (0.0, 1.0, 2.0, 3.0, 4.0, 5.5)
    branches = [lambda v=v: tf.constant(v, dtype=tf.float16) for v in targets]
    index = tf.cast(label, dtype=tf.int32)
    return tf.switch_case(index, branch_fns=branches)
Exemplo n.º 17
0
def _cond(index: TfVal, *operands: TfVal,
          branches: Sequence[core.TypedJaxpr],
          linear: Sequence[bool]):
  """Translate a multi-branch conditional over jaxprs into `tf.switch_case`.

  Each jaxpr in `branches` is interpreted on `operands`. `index` selects the
  branch whose result is returned. `linear` is accepted and ignored.
  """
  del linear  # not used by this lowering
  thunks = []
  for branch in branches:
    # tf.switch_case calls branches with no arguments, so pre-bind operands.
    thunks.append(functools.partial(_interpret_jaxpr, branch, *operands))
  return tf.switch_case(index, thunks)
Exemplo n.º 18
0
def _cond(index: TfVal, *operands: TfValOrUnit,
          branches: Sequence[core.TypedJaxpr],
          linear: Sequence[bool]) -> Sequence[TfValOrUnit]:
  """Translate a multi-branch conditional over jaxprs into `tf.switch_case`.

  Each jaxpr in `branches` is interpreted on `operands`, and `index` selects
  the branch result. That result is passed through `_tfval_add_unit` with
  the first branch's `out_avals` (presumably to restore unit outputs —
  confirm). `linear` is accepted and ignored.
  """
  del linear  # not used by this lowering
  thunks = []
  for branch in branches:
    # tf.switch_case calls branches with no arguments, so pre-bind operands.
    thunks.append(functools.partial(_interpret_jaxpr, branch, *operands))
  res_tf: Sequence[TfVal] = tf.switch_case(index, thunks)
  return _tfval_add_unit(res_tf, branches[0].out_avals)
Exemplo n.º 19
0
    def distort(self, image: tf.Tensor) -> tf.Tensor:
        """Apply the RandAugment policy to `image`.

        Args:
          image: `Tensor` of shape [height, width, 3] representing an image.

        Returns:
          The augmented image, cast back to the input dtype.
        """
        original_dtype = image.dtype
        if original_dtype != tf.uint8:
            # The ops work on uint8 pixels, so clip to [0, 255] first.
            image = tf.cast(tf.clip_by_value(image, 0.0, 255.0),
                            dtype=tf.uint8)

        replace_value = [128] * 3
        min_prob, max_prob = 0.2, 0.8

        for _ in range(self.num_layers):
            # One slot past the last op index selects the identity default.
            op_to_select = tf.random.uniform(
                [], maxval=len(self.available_ops) + 1, dtype=tf.int32)

            branch_fns = []
            for i, op_name in enumerate(self.available_ops):
                prob = tf.random.uniform([],
                                         minval=min_prob,
                                         maxval=max_prob,
                                         dtype=tf.float32)
                func, _, args = _parse_policy_info(op_name, prob,
                                                   self.magnitude,
                                                   replace_value,
                                                   self.cutout_const,
                                                   self.translate_const)
                # pylint:disable=g-long-lambda
                branch = (lambda selected_func=func, selected_args=args:
                          selected_func(image, *selected_args))
                # pylint:enable=g-long-lambda
                branch_fns.append((i, branch))

            aug_image = tf.switch_case(branch_index=op_to_select,
                                       branch_fns=branch_fns,
                                       default=lambda: tf.identity(image))

            if self.prob_to_apply is not None:
                apply_op = (tf.random.uniform(shape=[], dtype=tf.float32) <
                            self.prob_to_apply)
                aug_image = tf.cond(apply_op,
                                    lambda: tf.identity(aug_image),
                                    lambda: tf.identity(image))
            image = aug_image

        return tf.cast(image, dtype=original_dtype)
Exemplo n.º 20
0
def route_block(x, nexpert, nidx, route, expert):
    """Route each of `nidx` slots to one of `nexpert` expert towers.

    A shared router (a stack of ``block`` layers described by ``route``)
    produces an ``[nidx, nexpert]`` logit matrix. For slot ``i``, the
    training expert is sampled from ``Categorical(logits[i])`` and the test
    expert is the arg-max of ``logits[i]``. Each expert is a stack of
    ``block`` layers (described by ``expert``) applied to ``x``. The selected
    experts' outputs for all slots are concatenated along axis 3.

    Args:
      x: input tensor fed to the router and every expert.
      nexpert: number of experts.
      nidx: number of routing slots.
      route: sequence of ``(c, f, p)`` tuples for the router. Must be
        non-empty: the last ``f`` sizes the flatten.
      expert: sequence of ``(c, f, p)`` tuples for each expert.

    Returns:
      ``(train_out, test_out, train_idx, test_idx)``. ``train_idx`` is a list
      of ``nidx`` scalar tensors; ``test_idx`` is an int32 tensor of shape
      ``[nidx]``.
    """
    router = x
    for (c, f, p) in route:
        router = block(router, c, f, p)

    # NOTE(review): assumes batch size 1 and that the last block emits exactly
    # `f` elements — confirm.
    flat = tf.reshape(router, [1, f])
    dense = tf.squeeze(dense_block(flat, [f, nidx * nexpert]))
    logits = tf.reshape(dense, [nidx, nexpert])

    # One categorical over experts per slot. Index with the slot `i`; the
    # previous `route[p]` reused a stale loop variable, so every slot sampled
    # from the same row.
    train_idx = []
    for i in range(nidx):
        pi = tf.distributions.Categorical(logits=logits[i])
        train_idx.append(tf.squeeze(pi.sample(1)))

    # Best expert per slot: reduce over the expert axis (1), giving [nidx].
    test_idx = tf.cast(tf.argmax(logits, axis=1), dtype=tf.int32)

    experts = []
    for _ in range(nexpert):
        hidden = x
        for (c, f, p) in expert:
            hidden = block(hidden, c, f, p)
        experts.append(hidden)

    # Bind each output as a default argument; a bare `lambda: experts[e]`
    # would resolve `e` after the loop and always return the last expert.
    branch_fns = {e: (lambda out=out: out) for e, out in enumerate(experts)}

    train_out = []
    test_out = []
    for i in range(nidx):
        train_out.append(
            tf.switch_case(branch_index=train_idx[i], branch_fns=branch_fns))
        test_out.append(
            tf.switch_case(branch_index=test_idx[i], branch_fns=branch_fns))

    # NOTE(review): concatenating on axis 3 assumes 4-D (NHWC-style) expert
    # outputs — confirm.
    train_out = tf.concat(train_out, axis=3)
    test_out = tf.concat(test_out, axis=3)

    # TODO: an entropy term over the routing distributions was started here
    # but never written (the statement was incomplete); add it if needed.

    return train_out, test_out, train_idx, test_idx
Exemplo n.º 21
0
 def _apply_fn(image, ii):
   """Apply the jitter op at position `ii` of `function_order`.

   Returns:
     (jittered image, ii + 1)
   """
   op_id = function_order[ii]
   jitters = {
       0: lambda: brightness_fn(image),
       1: lambda: contrast_fn(image),
       2: lambda: saturation_fn(image),
   }
   return tf.switch_case(op_id, jitters), ii + 1
Exemplo n.º 22
0
def rot90(image, bboxes, k=(0, 1, 2, 3)):
    """
    Rotate image and bounding boxes counter-clockwise by a random multiple of 90 degrees.

    :param image: 3-D Tensor of shape [height, width, channels]
    :param bboxes: 2-D Tensor of shape (box_number, 4) containing normalized bounding
        boxes in format [ymin, xmin, ymax, xmax]
    :param k: array with multiples of 90 to choose from (taken modulo 4)
    :return: (rotated image, rotated bounding boxes)
    """
    with tf.name_scope("rot90"):
        selected_k = tf.math.floormod(tf.random.shuffle(k)[0], 4)

        image = tf.image.rot90(image, k=selected_k)

        # One branch per k; normalized point (y, x) maps as noted.
        rotate_bboxes = [
            # k=0: identity.
            lambda: bboxes,
            # k=1: (y, x) -> (1 - x, y).
            lambda: tf.stack(
                [
                    tf.math.subtract(1.0, bboxes[:, 3]), bboxes[:, 0],
                    tf.math.subtract(1.0, bboxes[:, 1]), bboxes[:, 2]
                ],
                axis=1,
            ),
            # k=2: (y, x) -> (1 - y, 1 - x).
            lambda: tf.math.subtract(
                1.0,
                tf.stack(
                    [
                        bboxes[:, 2],
                        bboxes[:, 3],
                        bboxes[:, 0],
                        bboxes[:, 1],
                    ],
                    axis=1,
                ),
            ),
            # k=3: (y, x) -> (x, 1 - y).
            lambda: tf.stack(
                [
                    bboxes[:, 1],
                    tf.math.subtract(1.0, bboxes[:, 2]), bboxes[:, 3],
                    tf.math.subtract(1.0, bboxes[:, 0])
                ],
                axis=1,
            ),
        ]

        # Leave an empty box tensor untouched.
        bboxes = tf.cond(
            tf.greater(tf.shape(bboxes)[0], 0),
            lambda: tf.switch_case(selected_k, rotate_bboxes),
            lambda: bboxes,
        )

    return image, bboxes
Exemplo n.º 23
0
 def func(image):
     """Apply the colour jitter op selected by `i`, clipping to [0, 1].

     `i` selects: 0 brightness, 1 contrast, 2 saturation, 3 hue.
     """
     def _clip01(t):
         return tf.clip_by_value(t, 0, 1)

     jitters = [
         lambda: _clip01(tf.image.random_brightness(image, brightness)),
         lambda: _clip01(
             tf.image.random_contrast(image, 1 - contrast, 1 + contrast)),
         lambda: _clip01(
             tf.image.random_saturation(image, 1 - saturation,
                                        1 + saturation)),
         lambda: _clip01(tf.image.random_hue(image, hue)),
     ]
     return tf.switch_case(i, jitters)
Exemplo n.º 24
0
 def build(self, framework: str) -> None:
     """Set up the index sampler and op dispatcher for `framework`.

     Sets `self.prob_fn` to a Uniform(0, len(self.ops)) distribution. Sets
     `self.invoke_fn(idx, data, state)` to run
     `self.ops[idx].forward(data, state)`.

     Args:
         framework: either 'tf' or 'torch'.

     Raises:
         ValueError: if `framework` is not recognized.
     """
     # NOTE(review): Uniform yields floats; callers presumably cast the
     # sample to an integer index before calling invoke_fn — confirm.
     if framework == 'tf':
         self.prob_fn = tfp.distributions.Uniform(low=0, high=len(self.ops))
         # Bind `op` as a default argument. A bare `lambda: op.forward(...)`
         # resolves `op` only when called, after the comprehension has
         # finished, so every branch would run the last op.
         self.invoke_fn = lambda idx, data, state: tf.switch_case(
             idx, [lambda op=op: op.forward(data, state) for op in self.ops])
     elif framework == 'torch':
         self.prob_fn = torch.distributions.uniform.Uniform(low=0,
                                                            high=len(
                                                                self.ops))
         self.invoke_fn = lambda idx, data, state: self.ops[idx].forward(
             data, state)
     else:
         raise ValueError("unrecognized framework: {}".format(framework))
Exemplo n.º 25
0
    def inference_eval(actor_ids, run_ids, env_outputs, raw_rewards):
        """Run one batched evaluation inference step for a set of actors.

        Resets the state of actors whose run id changed. Runs the agent with
        `is_training=False` on one of `inference_devices`, chosen
        round-robin. Stores the new state and raw action, and returns the
        post-processed action.

        Args:
          actor_ids: ids of the actors in this batch.
          run_ids: current run id per actor. A change from the stored value
            marks an actor that (re)started.
          env_outputs: per-actor environment outputs, encoded with the
            previous actions as the agent input.
          raw_rewards: not used in this function.

        Returns:
          `parametric_action_distribution.postprocess(agent_outputs.action)`.
        """
        # Reset the actors that had their first run or crashed.
        previous_run_ids = actor_run_ids.read(actor_ids)
        actor_run_ids.replace(actor_ids, run_ids)
        reset_indices = tf.where(tf.not_equal(previous_run_ids, run_ids))[:, 0]
        actors_needing_reset = tf.gather(actor_ids, reset_indices)
        if tf.not_equal(tf.shape(actors_needing_reset)[0], 0):
            tf.print('Actor ids needing reset:', actors_needing_reset)

        initial_agent_states = agent.initial_state(
            tf.shape(actors_needing_reset)[0])
        agent_states.replace(actors_needing_reset, initial_agent_states)
        actions.reset(actors_needing_reset)

        # Inference.
        # Actions are stored un-postprocessed (see postprocess_action=False
        # below), so they are post-processed when read back.
        prev_actions = parametric_action_distribution.postprocess(
            actions.read(actor_ids))
        input_ = encode((prev_actions, env_outputs))
        prev_agent_states = agent_states.read(actor_ids)

        def make_inference_fn(inference_device):
            # Return a zero-argument callable (as tf.switch_case requires)
            # that runs the agent on `inference_device`.
            def device_specific_inference_fn():
                with tf.device(inference_device):

                    @tf.function
                    def agent_inference(*args):
                        return agent(*decode(args),
                                     is_training=False,
                                     postprocess_action=False)

                    return agent_inference(*input_, prev_agent_states)

            return device_specific_inference_fn

        # Distribute the inference calls among the inference cores.
        # An incrementing counter modulo the device count gives round-robin order.
        branch_index = inference_iteration_eval.assign_add(1) % len(
            inference_devices)
        agent_outputs, curr_agent_states = tf.switch_case(
            branch_index, {
                i: make_inference_fn(inference_device)
                for i, inference_device in enumerate(inference_devices)
            })

        # Update current state.
        agent_states.replace(actor_ids, curr_agent_states)
        actions.replace(actor_ids, agent_outputs.action)

        # Return environment actions to actors.
        return parametric_action_distribution.postprocess(agent_outputs.action)
Exemplo n.º 26
0
 def augment(batch):
     """Apply one of 8 rotation/flip transforms to `batch`, chosen at random.

     The transpose perm [0, 2, 1, 3] swaps axes 1 and 2. This implies a 4-D
     batch with the spatial axes at positions 1 and 2.
     """
     which_aug = tf.random.uniform(shape=(),
                                   dtype=tf.dtypes.int32,
                                   minval=0,
                                   maxval=8)
     transforms = [
         lambda: batch,
         lambda: tf.transpose(batch, perm=[0, 2, 1, 3]),
         lambda: tf.image.flip_up_down(batch),
         lambda: tf.image.flip_left_right(tf.image.rot90(batch, k=1)),
         lambda: tf.image.flip_left_right(batch),
         lambda: tf.image.rot90(batch, k=1),
         lambda: tf.image.rot90(batch, k=2),
         lambda: tf.image.rot90(batch, k=3),
     ]
     return tf.switch_case(which_aug, transforms)
Exemplo n.º 27
0
    def combined_labels(self, nb_lesions):
        """Collapse a lesion count into a binary label.

        0 lesions maps to 0, and 1-4 lesions map to 1. Counts outside [0, 4]
        get no explicit default branch and follow tf.switch_case's
        out-of-range rule.
        """
        def no_lesion():
            return tf.constant(0)

        def has_lesion():
            return tf.constant(1)

        # NOTE(review): an unused branch returning tf.constant(2) was removed
        # as dead code. If a third class was intended, it was never wired in.
        return tf.switch_case(nb_lesions,
                              branch_fns={
                                  0: no_lesion,
                                  1: has_lesion,
                                  2: has_lesion,
                                  3: has_lesion,
                                  4: has_lesion,
                              })
Exemplo n.º 28
0
def isup_to_smoothed_labels(label):
    """Turn a grade label in 0-5 into an order-aware smoothed target vector.

    The true grade keeps 2/3 of the mass. The rest goes to neighbouring
    grades and shrinks with distance. Every row sums to 1.
    """
    grade = tf.cast(label, dtype=tf.int32)
    # Row g is the target distribution for grade g.
    smoothing_table = (
        (2/3,  2/9,  1/9,  0,    0,    0),
        (1/6,  2/3,  1/9,  1/18, 0,    0),
        (1/18, 1/9,  2/3,  1/9,  1/18, 0),
        (0,    1/18, 1/9,  2/3,  1/9,  1/18),
        (0,    0,    1/18, 1/9,  2/3,  1/6),
        (0,    0,    0,    1/9,  2/9,  2/3),
    )
    branches = {
        g: (lambda row=row: tf.constant(list(row), dtype=tf.float32))
        for g, row in enumerate(smoothing_table)
    }
    return tf.switch_case(grade, branches)
Exemplo n.º 29
0
def apply_augmentation_op(image, op_index, op_level, prob_to_apply):
    """Apply the augmentation op at `op_index` in IMAGENET_AUG_OPS to `image`.

    Args:
      image: input image tensor.
      op_index: scalar index into IMAGENET_AUG_OPS. An out-of-range index
        returns `image` unchanged.
      op_level: magnitude passed to the op's LEVEL_TO_ARG converter.
      prob_to_apply: if not None, probability of keeping the augmented image.

    Returns:
      The augmented image, or `image` when the op is not applied.
    """
    def _make_branch(op_name):
        op_fn = augment_ops.NAME_TO_FUNC[op_name]
        to_args = LEVEL_TO_ARG[op_name]
        return lambda: op_fn(image, *to_args(op_level))

    branches = [_make_branch(op_name) for op_name in IMAGENET_AUG_OPS]
    aug_image = tf.switch_case(op_index, branches, default=lambda: image)
    if prob_to_apply is None:
        return aug_image
    keep = tf.random.uniform(shape=[], dtype=tf.float32) < prob_to_apply
    return tf.cond(keep, lambda: aug_image, lambda: image)
Exemplo n.º 30
0
def apply_augmentation_op(data, op_index, op_level, prob_to_apply):
    """Apply the augmentation op at `op_index` in AUG_OPS to `data`.

    Args:
      data: input tensor.
      op_index: scalar index into AUG_OPS. An out-of-range index returns
        `data` unchanged.
      op_level: magnitude passed to the op's LEVEL_TO_ARG converter.
      prob_to_apply: if not None, probability of keeping the augmented data.

    Returns:
      The augmented data, or `data` when the op is not applied.
    """
    def _make_branch(op_name):
        op_fn = NAME_TO_FUNC[op_name]
        to_args = LEVEL_TO_ARG[op_name]
        return lambda: op_fn(data, *to_args(op_level))

    branches = [_make_branch(op_name) for op_name in AUG_OPS]
    aug_data = tf.switch_case(op_index, branches, default=lambda: data)
    if prob_to_apply is None:
        return aug_data
    keep = tf.random.uniform(shape=[], dtype=tf.float32) < prob_to_apply
    return tf.cond(keep, lambda: aug_data, lambda: data)