Example #1
0
def def_grads(prims):
    """Define gradient functions for primitives.

    Each registration calls ``prims(name).def_grad(fn)`` with a function
    ``fn(ans, *inputs)`` that receives the forward result ``ans`` and the
    forward inputs, and returns a closure mapping the upstream gradient
    ``g`` to the gradient of one input. ``argnum=1`` selects the second
    input; registrations without ``argnum`` are paired with it below and
    presumably cover the first input. ``def_grad_zero()`` presumably
    registers a zero gradient.

    :param prims: callable returning the primitive object for a name.
    """
    identity = lambda x: x
    # dot: d(a.b)/da = g . b^T, d(a.b)/db = a^T . g
    prims('dot').def_grad(lambda ans, a, b: lambda g: ndarray.dot(g, b.T))
    prims('dot').def_grad(
        lambda ans, a, b: lambda g: ndarray.dot(a.T, g), argnum=1)
    # non-linear: tanh' = 1 - tanh^2 and exp' = exp, so reuse the forward
    # result `ans` instead of recomputing it.
    prims('tanh').def_grad(lambda ans, x: lambda g: g * (1 - ans ** 2))
    prims('exp').def_grad(lambda ans, x: lambda g: g * ans)
    prims('log').def_grad(lambda ans, x: lambda g: g / x)
    # reduce
    prims('sum').def_grad(_sum_grad)
    # + - * /
    # NOTE(review): _unbroadcast (not shown here) presumably sums g back
    # down to the operand's shape when the forward op broadcast it.
    prims('multiply').def_grad(
        lambda ans, x, y: _unbroadcast(ans, x, lambda g: g * y))
    prims('multiply').def_grad(
        lambda ans, x, y: _unbroadcast(ans, y, lambda g: x * g), argnum=1)
    prims('add').def_grad(lambda ans, x, y: _unbroadcast(ans, x, identity))
    prims('add').def_grad(
        lambda ans, x, y: _unbroadcast(ans, y, identity), argnum=1)
    prims('subtract').def_grad(
        lambda ans, x, y: _unbroadcast(ans, x, identity))
    prims('subtract').def_grad(
        lambda ans, x, y: _unbroadcast(ans, y, operator.neg), argnum=1)
    prims('divide').def_grad(
        lambda ans, x, y: _unbroadcast(ans, x, lambda g: g / y))
    # d(x/y)/dy = -x / y^2
    prims('divide').def_grad(
        lambda ans, x, y: _unbroadcast(ans, y, lambda g: -g * x / (y * y)),
        argnum=1)
    prims('true_divide').def_grad(
        lambda ans, x, y: _unbroadcast(ans, x, lambda g: g / y))
    prims('true_divide').def_grad(
        lambda ans, x, y: _unbroadcast(ans, y, lambda g: -g * x / (y * y)),
        argnum=1)
    prims('maximum').def_grad(_maximum_grad_gen0)
    prims('maximum').def_grad(_maximum_grad_gen1, argnum=1)
    # TODO: minjie
    # NOTE(review): this zero gradient for `max` is a placeholder, not the
    # true derivative (which would route g to the selected element(s)).
    prims('max').def_grad_zero()
    # unary element-wise / shape ops
    prims('negative').def_grad(lambda ans, x: operator.neg)
    prims('transpose').def_grad(lambda ans, x: mxnet.nd.transpose)
    prims('abs').def_grad(lambda ans, x: lambda g: mxnet.nd.sign(x) * g)
    prims('sign').def_grad_zero()
    prims('round').def_grad_zero()
    prims('ceil').def_grad_zero()
    prims('floor').def_grad_zero()
    # d sqrt(x)/dx = 0.5 / sqrt(x); `ans` already holds sqrt(x).
    prims('sqrt').def_grad(lambda ans, x: lambda g: g * 0.5 / ans)
    prims('sin').def_grad(lambda ans, x: lambda g: g * mxnet.nd.cos(x))
    prims('cos').def_grad(lambda ans, x: lambda g: -g * mxnet.nd.sin(x))
    # d(x^y)/dx = y * x^(y-1); d(x^y)/dy = log(x) * x^y (= log(x) * ans)
    prims('power').def_grad(
        lambda ans, x, y: _unbroadcast(
            ans, x, lambda g: g * y * mxnet.nd.power(x, y - 1)))
    prims('power').def_grad(
        lambda ans, x, y: _unbroadcast(
            ans, y, lambda g: g * mxnet.nd.log(x) * ans),
        argnum=1)
    # Shape-only ops: the gradient is g reshaped back to the input's shape.
    prims('reshape').def_grad(
        lambda _0, x, _1: lambda g: NDArray.reshape(g, x.shape))
    prims('expand_dims').def_grad(
        lambda ans, x, axis: lambda g: NDArray.reshape(g, x.shape))
Example #2
0
def def_grads(prims):
    """Define gradient functions for primitives.

    Each registration calls ``prims(name).def_grad(fn)`` with a function
    ``fn(ans, *inputs)`` that receives the forward result ``ans`` and the
    forward inputs, and returns a closure mapping the upstream gradient
    ``g`` to the gradient of one input. ``argnum=1`` selects the second
    input; registrations without ``argnum`` are paired with it below and
    presumably cover the first input. ``def_grad_zero()`` presumably
    registers a zero gradient.

    :param prims: callable returning the primitive object for a name.
    """
    identity = lambda x: x
    # dot: d(a.b)/da = g . b^T, d(a.b)/db = a^T . g
    prims('dot').def_grad(lambda ans, a, b: lambda g: ndarray.dot(g, b.T))
    prims('dot').def_grad(lambda ans, a, b: lambda g: ndarray.dot(a.T, g),
                          argnum=1)
    # non-linear: tanh' = 1 - tanh^2 and exp' = exp, so reuse the forward
    # result `ans` instead of recomputing it.
    prims('tanh').def_grad(lambda ans, x: lambda g: g * (1 - ans ** 2))
    prims('exp').def_grad(lambda ans, x: lambda g: g * ans)
    prims('log').def_grad(lambda ans, x: lambda g: g / x)
    # reduce
    prims('sum').def_grad(_sum_grad)
    # + - * /
    # NOTE(review): _unbroadcast (not shown here) presumably sums g back
    # down to the operand's shape when the forward op broadcast it.
    prims('multiply').def_grad(
        lambda ans, x, y: _unbroadcast(ans, x, lambda g: g * y))
    prims('multiply').def_grad(
        lambda ans, x, y: _unbroadcast(ans, y, lambda g: x * g), argnum=1)
    prims('add').def_grad(lambda ans, x, y: _unbroadcast(ans, x, identity))
    prims('add').def_grad(lambda ans, x, y: _unbroadcast(ans, y, identity),
                          argnum=1)
    prims('subtract').def_grad(
        lambda ans, x, y: _unbroadcast(ans, x, identity))
    prims('subtract').def_grad(
        lambda ans, x, y: _unbroadcast(ans, y, operator.neg), argnum=1)
    prims('divide').def_grad(
        lambda ans, x, y: _unbroadcast(ans, x, lambda g: g / y))
    # d(x/y)/dy = -x / y^2
    prims('divide').def_grad(
        lambda ans, x, y: _unbroadcast(ans, y, lambda g: -g * x / (y * y)),
        argnum=1)
    prims('true_divide').def_grad(
        lambda ans, x, y: _unbroadcast(ans, x, lambda g: g / y))
    prims('true_divide').def_grad(
        lambda ans, x, y: _unbroadcast(ans, y, lambda g: -g * x / (y * y)),
        argnum=1)
    prims('maximum').def_grad(_maximum_grad_gen0)
    prims('maximum').def_grad(_maximum_grad_gen1, argnum=1)
    # TODO: minjie
    # NOTE(review): this zero gradient for `max` is a placeholder, not the
    # true derivative (which would route g to the selected element(s)).
    prims('max').def_grad_zero()
    # unary element-wise / shape ops
    prims('negative').def_grad(lambda ans, x: operator.neg)
    prims('transpose').def_grad(lambda ans, x: mxnet.nd.transpose)
    prims('abs').def_grad(lambda ans, x: lambda g: mxnet.nd.sign(x) * g)
    prims('sign').def_grad_zero()
    prims('round').def_grad_zero()
    prims('ceil').def_grad_zero()
    prims('floor').def_grad_zero()
    # d sqrt(x)/dx = 0.5 / sqrt(x); `ans` already holds sqrt(x).
    prims('sqrt').def_grad(lambda ans, x: lambda g: g * 0.5 / ans)
    prims('sin').def_grad(lambda ans, x: lambda g: g * mxnet.nd.cos(x))
    prims('cos').def_grad(lambda ans, x: lambda g: -g * mxnet.nd.sin(x))
    # d(x^y)/dx = y * x^(y-1); d(x^y)/dy = log(x) * x^y (= log(x) * ans)
    prims('power').def_grad(
        lambda ans, x, y: _unbroadcast(
            ans, x, lambda g: g * y * mxnet.nd.power(x, y - 1)))
    prims('power').def_grad(
        lambda ans, x, y: _unbroadcast(
            ans, y, lambda g: g * mxnet.nd.log(x) * ans),
        argnum=1)
    # Shape-only ops: the gradient is g reshaped back to the input's shape.
    prims('reshape').def_grad(
        lambda _0, x, _1: lambda g: NDArray.reshape(g, x.shape))
    prims('expand_dims').def_grad(
        lambda ans, x, axis: lambda g: NDArray.reshape(g, x.shape))
Example #3
0
    def forward(self,
                encoder_output: nd.NDArray,
                label=None,
                label_lengths=None):
        """Decode step by step from encoder features with attention.

        When both ``label`` and ``label_lengths`` are given, step ``t`` is
        fed the embedded ground-truth token ``label_embedded[:, t]``
        (teacher forcing) and runs ``max(label_lengths) - 1`` steps.
        Otherwise decoding starts from token id 0, feeds back the embedding
        of each step's arg-max prediction, and runs ``self.max_len`` steps.

        :param encoder_output: 4-D encoder feature map. Axis 1 is moved last
            and the two middle axes are flattened, so presumably the input
            is (batch, channels, H, W) -- TODO confirm.
        :param label: token ids; after embedding, axis 1 is the time step.
        :param label_lengths: per-sample sequence lengths.
        :return: ``(predictions, alphas)``, each the per-step outputs
            stacked along axis 1.
        """
        no_label = label is None or label_lengths is None

        # (N, a, b, c) -> (N, b, c, a) -> (N, b*c, a): one row per location.
        encoder_output = nd.transpose(encoder_output, (0, 2, 3, 1))
        encoder_output = encoder_output.reshape(
            (encoder_output.shape[0], -1, encoder_output.shape[3]))
        batch_max_len = self.max_len if no_label else int(
            label_lengths.max().asscalar()) - 1

        # Initialize hidden states from the mean over all locations.
        encoder_output_mean = encoder_output.mean(axis=1)
        h = self.init_h(encoder_output_mean)
        c = self.init_c(encoder_output_mean)

        # Per-step outputs, stacked along axis 1 after the loop.
        predictions = []
        alphas = []

        if not no_label:
            label_embedded = self.embedding(label)
        else:
            bs = encoder_output.shape[0]
            # Greedy decoding starts from token id 0.
            x_t = self.embedding(
                nd.zeros(shape=(bs, ), ctx=encoder_output.context))
        for t in range(batch_max_len):
            if not no_label:
                x_t = label_embedded[:, t]
            if self._use_current_state:
                # Update the LSTM first, then attend with the new state.
                _, [h, c] = self.lstm_cell(x_t, [h, c])
                if self._use_adaptive_attention:
                    atten_weights, alpha = self.attention(
                        encoder_output, h, x_t, c)
                else:
                    atten_weights, alpha = self.attention(encoder_output, h)
                atten_weights = self.f_beta(h).sigmoid() * atten_weights
                inputs = nd.concat(atten_weights, h, dim=1)
                preds = self.out(self.dropout(inputs))
            else:
                # Attend with the previous state, then feed context + input
                # to the LSTM.
                atten_weights, alpha = self.attention(encoder_output, h)
                atten_weights = nd.sigmoid(self.f_beta(h)) * atten_weights
                inputs = nd.concat(x_t, atten_weights, dim=1)
                _, [h, c] = self.lstm_cell(inputs, [h, c])
                preds = self.out(self.dropout(h))
            if no_label:
                # Feed back the greedy prediction. With labels the next
                # input comes from label_embedded, so embedding the
                # arg-max would be wasted work.
                x_t = self.embedding(preds.argmax(axis=1))
            predictions.append(preds)
            alphas.append(alpha)
        predictions = nd.stack(*predictions, axis=1)
        alphas = nd.stack(*alphas, axis=1)

        return predictions, alphas
Example #4
0
 def _batchify(data: nd.NDArray, batch_size):
     """
     Make a batch tensor out of a vector.

     The vector is cut into ``batch_size`` contiguous chunks of equal
     length, and each chunk becomes one column of the result. Trailing
     elements that do not fill a whole chunk are dropped.

     :param data: vector
     :param batch_size: NN, the number of columns; must be positive
     :return: (IN,NN) tensor, where IN = len(data) // batch_size
     :raises ValueError: if batch_size is not positive
     """
     if batch_size <= 0:
         raise ValueError(
             "batch_size must be positive, got {}".format(batch_size))
     # Length of each column once the data is split batch_size ways.
     nbatch = len(data) // batch_size
     # Trim off any extra elements that wouldn't cleanly fit (remainders).
     data = data[0:nbatch * batch_size]
     # Row i of the reshape is data[i*nbatch:(i+1)*nbatch]; transposing
     # turns those contiguous chunks into columns.
     return data.reshape(batch_size, -1).transpose()
Example #5
0
    def train(self, d: HybridSequential, x: NDArray, y: NDArray, *,
              image_shape=(3, 96, 96)) -> float:
        """Run one optimisation step of ``self._network`` and return the loss.

        Under ``autograd.record()`` the network output ``y_hat`` is computed.
        ``d`` is applied to ``concat(x, y_hat, dim=1)``, and ``self.lossfun``
        is called with ``1``, that result, and ``y`` and ``y_hat`` reshaped
        to ``(-1, *image_shape)``. The loss is then back-propagated and
        ``self.trainer.step(1)`` is applied.

        :param d: network applied to the concatenated input/output
            (presumably a discriminator -- confirm).
        :param x: network input.
        :param y: target batch.
        :param image_shape: per-sample shape used when reshaping ``y`` and
            ``y_hat``; defaults to the previously hard-coded (3, 96, 96).
        :return: the loss value as a Python float.
        """
        target_shape = (-1,) + tuple(image_shape)
        with autograd.record():
            y_hat = self._network(x)
            loss = self.lossfun(1, d(concat(x, y_hat, dim=1)),
                                y.reshape(target_shape),
                                y_hat.reshape(target_shape))

        loss.backward()
        # NOTE(review): step(1) normalises gradients by batch size 1; confirm
        # whether lossfun already averages over the batch.
        self.trainer.step(1)

        return float(loss.asscalar())
Example #6
0
def def_grads(prims):
    """Define gradient functions for primitives.

    Each registration calls ``prims(name).def_grad(fn)`` with a function
    ``fn(ans, *inputs)`` that receives the forward result ``ans`` and the
    forward inputs, and returns a closure mapping the upstream gradient
    ``g`` to the gradient of one input. ``argnum=1`` selects the second
    input; registrations without ``argnum`` are paired with it below and
    presumably cover the first input. ``def_grad_zero()`` presumably
    registers a zero gradient.

    :param prims: callable returning the primitive object for a name.
    """
    identity = lambda x: x
    # dot: d(a.b)/da = g . b^T, d(a.b)/db = a^T . g; the transpose flags
    # avoid materialising b.T / a.T.
    prims('dot').def_grad(
        lambda ans, a, b: lambda g: mx.nd.dot(g, b, transpose_b=True))
    prims('dot').def_grad(
        lambda ans, a, b: lambda g: mx.nd.dot(a, g, transpose_a=True),
        argnum=1)
    # non-linear: tanh' = 1 - tanh^2 and exp' = exp, so reuse the forward
    # result `ans` instead of recomputing it.
    prims('tanh').def_grad(lambda ans, x: lambda g: g * (1 - ans**2))
    prims('exp').def_grad(lambda ans, x: lambda g: g * ans)
    prims('log').def_grad(lambda ans, x: lambda g: g / x)
    # reduce
    prims('sum').def_grad(_reduce_sum_grad_gen)
    prims('max').def_grad(_reduce_select_grad_gen)
    prims('min').def_grad(_reduce_select_grad_gen)
    # + - * /
    # NOTE(review): _unbroadcast (not shown here) presumably sums g back
    # down to the operand's shape when the forward op broadcast it.
    prims('multiply').def_grad(
        lambda ans, x, y: _unbroadcast(ans, x, lambda g: g * y))
    prims('multiply').def_grad(
        lambda ans, x, y: _unbroadcast(ans, y, lambda g: x * g), argnum=1)
    prims('add').def_grad(lambda ans, x, y: _unbroadcast(ans, x, identity))
    prims('add').def_grad(lambda ans, x, y: _unbroadcast(ans, y, identity),
                          argnum=1)
    prims('subtract').def_grad(
        lambda ans, x, y: _unbroadcast(ans, x, identity))
    prims('subtract').def_grad(
        lambda ans, x, y: _unbroadcast(ans, y, operator.neg), argnum=1)
    prims('divide').def_grad(
        lambda ans, x, y: _unbroadcast(ans, x, lambda g: g / y))
    # d(x/y)/dy = -x / y^2
    prims('divide').def_grad(
        lambda ans, x, y: _unbroadcast(ans, y, lambda g: -g * x / (y * y)),
        argnum=1)
    prims('true_divide').def_grad(
        lambda ans, x, y: _unbroadcast(ans, x, lambda g: g / y))
    prims('true_divide').def_grad(
        lambda ans, x, y: _unbroadcast(ans, y, lambda g: -g * x / (y * y)),
        argnum=1)
    # element-wise selection ops share generic selection-gradient helpers
    prims('maximum').def_grad(_selection_grad_gen0)
    prims('maximum').def_grad(_selection_grad_gen1, argnum=1)
    prims('minimum').def_grad(_selection_grad_gen0)
    prims('minimum').def_grad(_selection_grad_gen1, argnum=1)
    # unary element-wise / shape ops
    prims('negative').def_grad(lambda ans, x: operator.neg)
    prims('transpose').def_grad(lambda ans, x: mx.nd.transpose)
    prims('abs').def_grad(lambda ans, x: lambda g: mx.nd.sign(x) * g)
    prims('sign').def_grad_zero()
    prims('round').def_grad_zero()
    prims('ceil').def_grad_zero()
    prims('floor').def_grad_zero()
    # d sqrt(x)/dx = 0.5 / sqrt(x); `ans` already holds sqrt(x).
    prims('sqrt').def_grad(lambda ans, x: lambda g: g * 0.5 / ans)
    prims('sin').def_grad(lambda ans, x: lambda g: g * mx.nd.cos(x))
    prims('cos').def_grad(lambda ans, x: lambda g: -g * mx.nd.sin(x))
    # d(x^y)/dx = y * x^(y-1); d(x^y)/dy = log(x) * x^y (= log(x) * ans)
    prims('power').def_grad(
        lambda ans, x, y: _unbroadcast(
            ans, x, lambda g: g * y * mx.nd.power(x, y - 1)))
    prims('power').def_grad(
        lambda ans, x, y: _unbroadcast(
            ans, y, lambda g: g * mx.nd.log(x) * ans),
        argnum=1)
    # Shape-only ops: the gradient is g reshaped back to the input's shape.
    prims('reshape').def_grad(
        lambda _0, x, _1: lambda g: NDArray.reshape(g, x.shape))
    prims('expand_dims').def_grad(
        lambda ans, x, axis: lambda g: NDArray.reshape(g, x.shape))
    prims('softmax_output').def_grad(_softmax_output_grad)
Example #7
0
def sq_loss(ey: nd.NDArray, y: nd.NDArray):
    """Return the mean squared error between ``ey`` and ``y``.

    ``ey`` is first reshaped to ``y.shape``. The squared differences are
    summed and divided by their element count.
    """
    squared_err = (ey.reshape(shape=y.shape) - y)**2
    total = squared_err.sum()
    return total / squared_err.size