Example #1
    def forward(self, input, deterministic=None):

        if deterministic is None:
            deterministic = self.deterministic
        dirac = T.cast(deterministic, 'float32')

        self.mean = T.mean(input, self.axis, keepdims=True)
        self.var = T.var(input, self.axis, keepdims=True)
        if len(self.updates.keys()) == 0:
            self.avgmean, upm, step = T.ExponentialMovingAverage(
                self.mean, self.beta1)
            self.avgvar, upv, step = T.ExponentialMovingAverage(
                self.var,
                self.beta2,
                step=step,
                init=numpy.ones(self.var.shape).astype('float32'))
            self.add_variable(self.avgmean)
            self.add_variable(self.avgvar)
            self.add_update(upm)
            self.add_update(upv)

        self.usemean = self.mean * (1 - dirac) + self.avgmean * dirac
        self.usevar = self.var * (1 - dirac) + self.avgvar * dirac
        return self.W * (input - self.usemean) / \
            (T.sqrt(self.usevar) + self.const) + self.b
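This first snippet is a batch-normalization forward pass: the batch mean and variance are tracked with exponential moving averages, and the `dirac` indicator (1.0 when `deterministic`, 0.0 otherwise) switches between the batch statistics (training) and the running averages (evaluation). A minimal NumPy sketch of the same blending, where `W`, `b` and `const` stand in for the layer's learned scale, bias and numerical-stability constant:

import numpy as np

def batchnorm_forward(x, avg_mean, avg_var, W, b, deterministic, const, axis=0):
    """Illustrative NumPy re-statement of the forward pass above."""
    dirac = np.float32(deterministic)                 # 1.0 at test time, 0.0 at train time
    mean = x.mean(axis=axis, keepdims=True)
    var = x.var(axis=axis, keepdims=True)
    use_mean = mean * (1 - dirac) + avg_mean * dirac  # blend batch vs. running statistics
    use_var = var * (1 - dirac) + avg_var * dirac
    return W * (x - use_mean) / (np.sqrt(use_var) + const) + b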
Example #2
    def forward(self, input, crop_shape, deterministic, padding=0, seed=None):

        self.crop_shape = crop_shape
        # if given only a scalar
        if not hasattr(padding, "__len__"):
            self.pad_shape = [(padding, padding)] * (input.ndim - 1)
        # else
        else:
            self.pad_shape = [(pad,
                               pad) if not hasattr(pad, "__len__") else pad
                              for pad in padding]

        assert len(self.pad_shape) == len(self.crop_shape)
        assert len(self.pad_shape) == (len(input.shape) - 1)

        self.start_indices = list()
        self.fixed_indices = list()
        for i, (pad, dim, crop) in enumerate(
                zip(self.pad_shape, input.shape[1:], self.crop_shape)):
            maxval = pad[0] + pad[1] + dim - crop
            assert maxval >= 0
            self.start_indices.append(
                T.random.randint(
                    minval=0,
                    maxval=maxval,
                    shape=(input.shape[0], 1),
                    dtype="int32",
                    seed=seed + i if seed is not None else seed,
                ))

            self.fixed_indices.append(
                T.ones((input.shape[0], 1), "int32") * (maxval // 2))
        self.start_indices = T.concatenate(self.start_indices, 1)
        self.fixed_indices = T.concatenate(self.fixed_indices, 1)

        dirac = T.cast(deterministic, "float32")

        # pad the input
        pinput = T.pad(input, [(0, 0)] + self.pad_shape)

        routput = T.stack(
            [
                T.dynamic_slice(pinput[n], self.start_indices[n],
                                self.crop_shape) for n in range(input.shape[0])
            ],
            0,
        )
        doutput = T.stack(
            [
                T.dynamic_slice(pinput[n], self.fixed_indices[n],
                                self.crop_shape) for n in range(input.shape[0])
            ],
            0,
        )

        return doutput * dirac + (1 - dirac) * routput
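Examples #2, #5 and #6 all implement random cropping: the input is padded, a start index per sample and per spatial axis is drawn at training time or fixed to the centered offset `maxval // 2` at evaluation time, and a `crop_shape` window is sliced out of each padded sample. A self-contained NumPy sketch of the same per-sample logic (whether the random upper bound is inclusive depends on the backend's `randint` convention, so the bound below is illustrative):

import numpy as np

def random_crop(x, crop_shape, padding, deterministic, rng=np.random):
    """Illustrative NumPy version: pad the input, then crop each sample at a
    random start index (training) or the centered index (evaluation)."""
    pad = [(0, 0)] + [(padding, padding)] * (x.ndim - 1)
    px = np.pad(x, pad)
    out = np.empty((x.shape[0],) + tuple(crop_shape), dtype=x.dtype)
    for n in range(x.shape[0]):
        starts = []
        for (p0, p1), dim, crop in zip(pad[1:], x.shape[1:], crop_shape):
            maxval = p0 + p1 + dim - crop   # largest valid start position
            starts.append(maxval // 2 if deterministic
                          else rng.randint(0, maxval + 1))
        window = tuple(slice(s, s + c) for s, c in zip(starts, crop_shape))
        out[n] = px[n][window]
    return out

Examples #5 and #6 repeat this pattern, with Example #5 replacing the explicit Python loop over the batch by `T.map`.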
Example #3
    def forward(self, input, deterministic=None):
        if deterministic is None:
            deterministic = self.deterministic

        dirac = T.cast(deterministic, 'float32')

        flipped_input = self.flip * T.flip(input, self.axis)\
            + (1 - self.flip) * input

        return input * dirac + flipped_input * (1 - dirac)
Example #4
    def __init__(self, input, p, axis, deterministic, seed=None):

        extra_dims = input.ndim - 1
        flip = T.random.bernoulli(
            shape=(input.shape[0], ) + (1, ) * extra_dims,
            p=p,
            seed=seed,
        )

        dirac = T.cast(deterministic, "float32")

        flipped_input = T.where(flip, T.flip(input, axis), input)

        return input * dirac + flipped_input * (1 - dirac)
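Examples #3 and #4 implement the same random-flip augmentation: a per-sample Bernoulli draw decides which samples are flipped along `axis`, and the augmentation is bypassed entirely when `deterministic` is true. An equivalent NumPy sketch:

import numpy as np

def random_flip(x, p, axis, deterministic, rng=np.random):
    """Illustrative NumPy version: flip each sample along `axis` with
    probability p during training; identity when deterministic."""
    if deterministic:
        return x
    # one Bernoulli draw per sample, broadcast over the remaining axes
    flip = rng.binomial(1, p, size=(x.shape[0],) + (1,) * (x.ndim - 1))
    return np.where(flip, np.flip(x, axis), x)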
Example #5
    def __init__(self, input, crop_shape, deterministic, padding=0, seed=None):

        # if given only a scalar
        if not hasattr(padding, "__len__"):
            pad_shape = [(padding, padding)] * (input.ndim - 1)
        # else
        else:
            pad_shape = [(pad, pad) if not hasattr(pad, "__len__") else pad
                         for pad in padding]

        assert len(pad_shape) == len(crop_shape)
        assert len(pad_shape) == input.ndim - 1

        start_indices = list()
        fixed_indices = list()
        for i, (pad, dim,
                crop) in enumerate(zip(pad_shape, input.shape[1:],
                                       crop_shape)):
            maxval = pad[0] + pad[1] + dim - crop
            start_indices.append(
                T.random.randint(
                    minval=0,
                    maxval=maxval,
                    shape=(input.shape[0], 1),
                    dtype="int32",
                    seed=seed + i if seed is not None else seed,
                ))

            fixed_indices.append(
                T.ones((input.shape[0], 1), "int32") * (maxval // 2))
        start_indices = T.concatenate(start_indices, 1)
        fixed_indices = T.concatenate(fixed_indices, 1)

        dirac = T.cast(deterministic, "float32")

        # pad the input
        pinput = T.pad(input, [(0, 0)] + pad_shape)

        routput = T.map(
            lambda x, indices: T.dynamic_slice(x, indices, crop_shape),
            sequences=[pinput, start_indices],
        )
        doutput = T.map(
            lambda x, indices: T.dynamic_slice(x, indices, crop_shape),
            sequences=[pinput, fixed_indices],
        )

        return doutput * dirac + (1 - dirac) * routput
Example #6
    def forward(self, input, deterministic=None):

        if deterministic is None:
            deterministic = self.deterministic
        dirac = T.cast(deterministic, 'float32')

        # pad the input
        pinput = T.pad(input, [(0, 0)] + self.pad_shape)

        routput = T.stack([
            T.dynamic_slice(pinput[n], self.start_indices[n], self.crop_shape)
            for n in range(input.shape[0])
        ], 0)
        doutput = T.stack([
            T.dynamic_slice(pinput[n], self.fixed_indices[n], self.crop_shape)
            for n in range(input.shape[0])
        ], 0)

        return doutput * dirac + (1 - dirac) * routput
Example #7
    def __init__(
            self,
            *classargs,
            outputs=[],
            updates=None,  # noqa
            device=None,
            backend=None,
            default_value=None):
        """Initialize."""
        # check the given updates (if any) and ensure that they only
        # update Variable objects
        if updates is None:
            updates = {}

        for update in updates.keys():
            if not isinstance(update, t.Variable):
                raise RuntimeError(
                    "{} is not a Variable and cannot be updated".format(
                        update))

        # ensure that all inputs are actual placeholders or variables
        for arg in classargs:
            if not isinstance(arg, t.Tensor):
                raise RuntimeError(
                    "{} is not a Tensor type. Only tensor types can be" +
                    "function inputs".format(arg))

        # gather all roots, they need to be explicit as inputs of the
        # underlying functions otherwise they are treated as constants
        # and any change in their value will not appear when running the
        # function
        outs = list(updates.values())
        outs += [outputs] if isinstance(outputs, t.Tensor) else outputs
        self.all_roots = set(t.getroots(outs))
        self.classargs = classargs
        self.outputs = outputs

        items = list(updates.items())
        self.updates_keys = [item[0] for item in items]
        self.updates_values = [item[1] for item in items]
        for i in range(len(items)):
            if self.updates_keys[i].shape != self.updates_values[i].shape:
                warnings.warn(
                    'Variable and update {} {} '.format(
                        self.updates_keys[i], self.updates_values[i]) +
                    "are not the same shape... attempting to reshape")
                self.updates_values[i] = t.reshape(self.updates_values[i],
                                                   self.updates_keys[i].shape)
            if self.updates_keys[i].dtype != self.updates_values[i].dtype:
                warnings.warn('Variable and update {} {} '.format(
                    self.updates_keys[i], self.updates_values[i]) +
                              "are not the same dtype... attempting to cast")
                self.updates_values[i] = t.cast(self.updates_values[i],
                                                self.updates_keys[i].dtype)

        # check the function inputs, they must be at least contain all the
        # placeholders needed to compute the outputs values
        placeholders_in_root = filter(lambda x: isinstance(x, t.Placeholder),
                                      self.all_roots)

        # check for placeholders that are required but not given as inputs
        non_givens = set(placeholders_in_root) - set(self.classargs)
        if len(non_givens) > 0:
            raise RuntimeError(
                "Missing placeholders form the function inputs: {}".format(
                    non_givens))

        # the roots are made of variables, random tensors, placeholders. We
        # already ensured that all placeholders are given as inputs to the
        # function. Now we must ensure that the other ones will also be given
        # as inputs to not be treated as constants by jax.
        # we also remove update keys because we will explicitly feed them
        self.extra_inputs = set(self.all_roots)\
            - (set(self.classargs).union(self.updates_keys))
        self.extra_inputs = list(self.extra_inputs)

        def jitfn(*jitargs):
            allargs = list(self.classargs) + self.updates_keys +\
                self.extra_inputs
            return t.get([self.outputs, self.updates_values],
                         dict(zip(allargs, jitargs)))

        # we compile our underlying function using jit for performances
        self.jitfn = jax.jit(jitfn, device=device, backend=backend)

        # define the frontend function that takes the input variables as arguments
        # and internally computes and applies the updates (if any)
        def meta(*fnargs, rng):

            # ensure that the number of arguments is correct
            assert len(fnargs) == len(self.classargs)
            for fnarg, classarg in zip(fnargs, self.classargs):
                if hasattr(fnarg, 'shape'):
                    if fnarg.shape != classarg.shape:
                        raise RuntimeError(
                            "wrong input given for {}".format(classarg) +
                            ", given is {}".format(fnarg) +
                            ", shape={}".format(fnarg.shape))

            # retrieve the function outputs and updated values, then apply them
            jitoutputs, jitupdates = self.jitfn(
                *fnargs,
                *t.get(self.updates_keys + self.extra_inputs, {'rng': rng}))
            for key, update in zip(self.updates_keys, jitupdates):
                key.value = update
            if isinstance(jitoutputs, jax.interpreters.xla.DeviceArray):
                return jax.api.device_get(jitoutputs)
            else:
                npy_jitoutputs = [
                    jax.api.device_get(arr) if isinstance(
                        arr, jax.interpreters.xla.DeviceArray) else arr
                    for arr in jitoutputs
                ]
                return npy_jitoutputs

        self.meta = meta
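Example #7 builds a compiled, Theano-style callable: it validates the inputs and updates, gathers every root tensor so that variables and random tensors are fed explicitly rather than frozen as constants, jit-compiles the graph evaluation with `jax.jit`, and returns a `meta` wrapper that writes the freshly computed update values back onto the corresponding variables after every call. A minimal, self-contained sketch of that compile-then-apply-updates pattern with plain JAX (the helper below illustrates the idea and is not the library's actual API):

import jax
import numpy as np

def make_function(step_fn, variables):
    """Sketch: compile `step_fn(*inputs, *variable_values)` once, then apply
    the returned update values back onto `variables` after every call."""
    jitted = jax.jit(step_fn)                        # compile once, like self.jitfn

    def meta(*inputs):
        outputs, new_values = jitted(*inputs, *variables.values())
        for key, new in zip(variables, new_values):  # mirrors `key.value = update`
            variables[key] = np.asarray(new)
        return np.asarray(outputs)

    return meta

# usage: a running counter updated on every call
state = {"count": np.float32(0.0)}
fn = make_function(lambda x, count: (x * count, (count + 1,)), state)
fn(np.float32(2.0))   # returns 0.0, state["count"] becomes 1.0
fn(np.float32(2.0))   # returns 2.0, state["count"] becomes 2.0

Keeping the variables as explicit arguments to the jitted function is what lets their values change between calls; anything not passed in would be traced once and treated as a constant.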
Example #8
    def forward(self, input, deterministic=None):
        if deterministic is None:
            deterministic = self.deterministic
        dirac = T.cast(deterministic, 'float32')
        return input * self.mask * (1 - dirac) + input * dirac
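The last example is a dropout-style layer: the precomputed binary `self.mask` multiplies the input during training, and the input passes through unchanged when `deterministic` is true. A NumPy sketch of the same blending (no 1/(1-p) rescaling appears here, so any such scaling would have to be folded into the mask itself):

import numpy as np

def dropout_forward(x, mask, deterministic):
    """Illustrative NumPy version: masked input at train time, identity at test time."""
    dirac = np.float32(deterministic)
    return x * mask * (1 - dirac) + x * dirac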