def def_grads(prims):
    """Register gradient generators for the supported primitives.

    Each generator is called with the forward result followed by the forward
    arguments and returns a function that maps the incoming gradient ``g`` to
    the gradient of one operand. Registrations made with ``argnum=1`` carry
    the formula for the second operand; the others rely on the default of
    ``def_grad``.

    :param prims: callable mapping a primitive name to an object that exposes
        ``def_grad`` and ``def_grad_zero``.
    """

    def passthrough(g):
        return g

    def register_pair(name, first_gen, second_gen):
        # Two separate lookups, one per operand, first the default slot and
        # then argnum=1.
        prims(name).def_grad(first_gen)
        prims(name).def_grad(second_gen, argnum=1)

    def quotient_wrt_x(ans, x, y):
        return _unbroadcast(ans, x, lambda g: g / y)

    def quotient_wrt_y(ans, x, y):
        return _unbroadcast(ans, y, lambda g: -g * x / (y * y))

    def back_to_shape_of(x):
        # Gradient of a pure shape change: reshape g to the input's shape.
        return lambda g: NDArray.reshape(g, x.shape)

    # Matrix product.
    register_pair(
        'dot',
        lambda ans, a, b: lambda g: ndarray.dot(g, b.T),
        lambda ans, a, b: lambda g: ndarray.dot(a.T, g))

    # Element-wise non-linear functions.
    prims('tanh').def_grad(lambda ans, x: lambda g: g * (1 - ans ** 2))
    prims('exp').def_grad(lambda ans, x: lambda g: g * ans)
    prims('log').def_grad(lambda ans, x: lambda g: g / x)

    # Reduction.
    prims('sum').def_grad(_sum_grad)

    # Binary arithmetic; each raw gradient function is handed to _unbroadcast
    # together with the result and the operand it belongs to.
    register_pair(
        'multiply',
        lambda ans, x, y: _unbroadcast(ans, x, lambda g: g * y),
        lambda ans, x, y: _unbroadcast(ans, y, lambda g: x * g))
    register_pair(
        'add',
        lambda ans, x, y: _unbroadcast(ans, x, passthrough),
        lambda ans, x, y: _unbroadcast(ans, y, passthrough))
    register_pair(
        'subtract',
        lambda ans, x, y: _unbroadcast(ans, x, passthrough),
        lambda ans, x, y: _unbroadcast(ans, y, operator.neg))
    for name in ('divide', 'true_divide'):
        register_pair(name, quotient_wrt_x, quotient_wrt_y)
    register_pair('maximum', _maximum_grad_gen0, _maximum_grad_gen1)
    # TODO: minjie -- the 'max' reduction is registered with a zero gradient.
    prims('max').def_grad_zero()

    # Unary element-wise functions and layout changes.
    prims('negative').def_grad(lambda ans, x: operator.neg)
    prims('transpose').def_grad(lambda ans, x: mxnet.nd.transpose)
    prims('abs').def_grad(lambda ans, x: lambda g: mxnet.nd.sign(x) * g)
    for name in ('sign', 'round', 'ceil', 'floor'):
        prims(name).def_grad_zero()
    prims('sqrt').def_grad(
        lambda ans, x: lambda g: g * 0.5 / mxnet.nd.sqrt(x))
    prims('sin').def_grad(lambda ans, x: lambda g: g * mxnet.nd.cos(x))
    prims('cos').def_grad(lambda ans, x: lambda g: -g * mxnet.nd.sin(x))
    register_pair(
        'power',
        lambda ans, x, y: _unbroadcast(
            ans, x, lambda g: g * y * mxnet.nd.power(x, y - 1)),
        lambda ans, x, y: _unbroadcast(
            ans, y, lambda g: g * mxnet.nd.log(x) * ans))
    prims('reshape').def_grad(lambda _0, x, _1: back_to_shape_of(x))
    prims('expand_dims').def_grad(lambda ans, x, axis: back_to_shape_of(x))
def def_grads(prims):
    """Register gradient generators for the supported primitives.

    Each generator is called with the forward result followed by the forward
    arguments and returns a function that maps the incoming gradient ``g`` to
    the gradient of one operand. Registrations made with ``argnum=1`` carry
    the formula for the second operand; the others rely on the default of
    ``def_grad``.

    NOTE: no gradient is registered for 'tanh' in this function.

    :param prims: callable mapping a primitive name to an object that exposes
        ``def_grad`` and ``def_grad_zero``.
    """

    def passthrough(g):
        return g

    def register_pair(name, first_gen, second_gen):
        # Two separate lookups, one per operand, first the default slot and
        # then argnum=1.
        prims(name).def_grad(first_gen)
        prims(name).def_grad(second_gen, argnum=1)

    def quotient_wrt_x(ans, x, y):
        return _unbroadcast(ans, x, lambda g: g / y)

    def quotient_wrt_y(ans, x, y):
        return _unbroadcast(ans, y, lambda g: -g * x / (y * y))

    def back_to_shape_of(x):
        # Gradient of a pure shape change: reshape g to the input's shape.
        return lambda g: NDArray.reshape(g, x.shape)

    # Matrix product.
    register_pair(
        'dot',
        lambda ans, a, b: lambda g: ndarray.dot(g, b.T),
        lambda ans, a, b: lambda g: ndarray.dot(a.T, g))

    # Element-wise non-linear functions.
    prims('exp').def_grad(lambda ans, x: lambda g: g * ans)
    prims('log').def_grad(lambda ans, x: lambda g: g / x)

    # Reduction.
    prims('sum').def_grad(_sum_grad)

    # Binary arithmetic; each raw gradient function is handed to _unbroadcast
    # together with the result and the operand it belongs to.
    register_pair(
        'multiply',
        lambda ans, x, y: _unbroadcast(ans, x, lambda g: g * y),
        lambda ans, x, y: _unbroadcast(ans, y, lambda g: x * g))
    register_pair(
        'add',
        lambda ans, x, y: _unbroadcast(ans, x, passthrough),
        lambda ans, x, y: _unbroadcast(ans, y, passthrough))
    register_pair(
        'subtract',
        lambda ans, x, y: _unbroadcast(ans, x, passthrough),
        lambda ans, x, y: _unbroadcast(ans, y, operator.neg))
    for name in ('divide', 'true_divide'):
        register_pair(name, quotient_wrt_x, quotient_wrt_y)
    register_pair('maximum', _maximum_grad_gen0, _maximum_grad_gen1)
    # TODO: minjie -- the 'max' reduction is registered with a zero gradient.
    prims('max').def_grad_zero()

    # Unary element-wise functions and layout changes.
    prims('negative').def_grad(lambda ans, x: operator.neg)
    prims('transpose').def_grad(lambda ans, x: mxnet.nd.transpose)
    prims('abs').def_grad(lambda ans, x: lambda g: mxnet.nd.sign(x) * g)
    for name in ('sign', 'round', 'ceil', 'floor'):
        prims(name).def_grad_zero()
    prims('sqrt').def_grad(
        lambda ans, x: lambda g: g * 0.5 / mxnet.nd.sqrt(x))
    prims('sin').def_grad(lambda ans, x: lambda g: g * mxnet.nd.cos(x))
    prims('cos').def_grad(lambda ans, x: lambda g: -g * mxnet.nd.sin(x))
    register_pair(
        'power',
        lambda ans, x, y: _unbroadcast(
            ans, x, lambda g: g * y * mxnet.nd.power(x, y - 1)),
        lambda ans, x, y: _unbroadcast(
            ans, y, lambda g: g * mxnet.nd.log(x) * ans))
    prims('reshape').def_grad(lambda _0, x, _1: back_to_shape_of(x))
    prims('expand_dims').def_grad(lambda ans, x, axis: back_to_shape_of(x))
def forward(self, encoder_output: nd.NDArray, label=None, label_lengths=None):
    """Run the attention decoder over the encoder output, step by step.

    The 4-D ``encoder_output`` has axis 1 moved to the end and its two middle
    axes merged, giving a 3-D array (presumably batch x locations x channels;
    confirm against the encoder). The LSTM state is initialised from the mean
    over the merged axis.

    When both ``label`` and ``label_lengths`` are given, the embedded label
    column ``t`` is the input of step ``t`` and the number of steps is
    ``max(label_lengths) - 1``. Otherwise decoding starts from the embedding
    of index 0, feeds back the arg-max of each prediction, and runs for
    ``self.max_len`` steps.

    :param encoder_output: 4-D array produced by the encoder.
    :param label: optional label indices passed to ``self.embedding``.
    :param label_lengths: optional array of label lengths.
    :return: ``(predictions, alphas)``; the per-step outputs joined along a
        new axis 1.
    """
    use_labels = label is not None and label_lengths is not None

    # Axis 1 goes last, then the two middle axes are merged into one.
    features = nd.transpose(encoder_output, (0, 2, 3, 1))
    batch_size = features.shape[0]
    features = features.reshape((batch_size, -1, features.shape[3]))

    if use_labels:
        num_steps = int(label_lengths.max().asscalar()) - 1
    else:
        num_steps = self.max_len

    # Initial hidden and cell state come from the averaged features.
    pooled = features.mean(axis=1)
    h = self.init_h(pooled)
    c = self.init_c(pooled)

    if use_labels:
        label_embedded = self.embedding(label)
    else:
        x_t = self.embedding(
            nd.zeros(shape=(batch_size, ), ctx=features.context))

    step_preds = []
    step_alphas = []
    for t in range(num_steps):
        if use_labels:
            x_t = label_embedded[:, t]

        if not self._use_current_state:
            # Attend with the previous state, then advance the LSTM on the
            # input concatenated with the gated context.
            context, alpha = self.attention(features, h)
            context = nd.sigmoid(self.f_beta(h)) * context
            lstm_input = nd.concat(x_t, context, dim=1)
            _, [h, c] = self.lstm_cell(lstm_input, [h, c])
            preds = self.out(self.dropout(h))
        else:
            # Advance the LSTM first, then attend with the fresh state.
            _, [h, c] = self.lstm_cell(x_t, [h, c])
            if self._use_adaptive_attention:
                context, alpha = self.attention(features, h, x_t, c)
            else:
                context, alpha = self.attention(features, h)
            context = self.f_beta(h).sigmoid() * context
            out_input = nd.concat(context, h, dim=1)
            preds = self.out(self.dropout(out_input))

        # Next input candidate; replaced at the top of the loop when labels
        # are supplied.
        x_t = self.embedding(preds.argmax(axis=1))
        step_preds.append(preds)
        step_alphas.append(alpha)

    predictions = nd.concat(
        *[p.expand_dims(axis=1) for p in step_preds], dim=1)
    alphas = nd.concat(
        *[a.expand_dims(axis=1) for a in step_alphas], dim=1)
    return predictions, alphas
def _batchify(data: nd.NDArray, batch_size):
    """Make a batch tensor out of a vector.

    Trailing elements that do not fill a complete row are dropped. The rest
    is reshaped to ``(batch_size, -1)`` and transposed, so for a 1-D input
    the result has ``len(data) // batch_size`` rows and ``batch_size``
    columns, each column holding one contiguous chunk of ``data``.

    :param data: vector (assumed 1-D; the shape above relies on that)
    :param batch_size: NN, number of columns; must be a positive integer
    :return: (IN, NN) tensor
    :raises ValueError: if ``batch_size`` is not positive, or if ``data`` has
        fewer than ``batch_size`` elements so not even one row can be built
    """
    if batch_size < 1:
        # Guard against ZeroDivisionError and meaningless negative shapes.
        raise ValueError(
            'batch_size must be a positive integer, got %r' % (batch_size,))

    # Number of complete rows available.
    nbatch = len(data) // batch_size
    if nbatch == 0:
        # An empty slice would only fail later, inside the array backend.
        raise ValueError(
            'data has %d elements, fewer than batch_size=%d'
            % (len(data), batch_size))

    # Trim off the remainder, then lay the data out as batch_size columns.
    data = data[0:nbatch * batch_size]
    return data.reshape(batch_size, -1).transpose()
def train(self, d: HybridSequential, x: NDArray, y: NDArray) -> float:
    """Perform one training step of ``self._network`` and return the loss.

    The network output for ``x`` is concatenated with ``x`` along axis 1 and
    passed through ``d`` (presumably a discriminator; confirm with the
    caller). ``self.lossfun`` receives the constant ``1``, the output of
    ``d``, and both ``y`` and the network output reshaped to
    ``(-1, 3, 96, 96)``.

    :param d: network applied to the concatenated input and prediction.
    :param x: input batch for ``self._network``.
    :param y: target batch.
    :return: the loss as a Python float (via ``asscalar``).
    """
    image_shape = (-1, 3, 96, 96)

    with autograd.record():
        y_hat = self._network(x)
        d_out = d(concat(x, y_hat, dim=1))
        loss = self.lossfun(
            1, d_out, y.reshape(image_shape), y_hat.reshape(image_shape))

    loss.backward()
    self.trainer.step(1)
    return float(loss.asscalar())
def def_grads(prims):
    """Register gradient generators for the supported primitives.

    Each generator is called with the forward result followed by the forward
    arguments and returns a function that maps the incoming gradient ``g`` to
    the gradient of one operand. Registrations made with ``argnum=1`` carry
    the formula for the second operand; the others rely on the default of
    ``def_grad``.

    :param prims: callable mapping a primitive name to an object that exposes
        ``def_grad`` and ``def_grad_zero``.
    """

    def passthrough(g):
        return g

    def register_pair(name, first_gen, second_gen):
        # Two separate lookups, one per operand, first the default slot and
        # then argnum=1.
        prims(name).def_grad(first_gen)
        prims(name).def_grad(second_gen, argnum=1)

    def quotient_wrt_x(ans, x, y):
        return _unbroadcast(ans, x, lambda g: g / y)

    def quotient_wrt_y(ans, x, y):
        return _unbroadcast(ans, y, lambda g: -g * x / (y * y))

    def back_to_shape_of(x):
        # Gradient of a pure shape change: reshape g to the input's shape.
        return lambda g: NDArray.reshape(g, x.shape)

    # Matrix product; the transposes are requested through the dot flags.
    register_pair(
        'dot',
        lambda ans, a, b: lambda g: mx.nd.dot(g, b, transpose_b=True),
        lambda ans, a, b: lambda g: mx.nd.dot(a, g, transpose_a=True))

    # Element-wise non-linear functions.
    prims('tanh').def_grad(lambda ans, x: lambda g: g * (1 - ans**2))
    prims('exp').def_grad(lambda ans, x: lambda g: g * ans)
    prims('log').def_grad(lambda ans, x: lambda g: g / x)

    # Reductions; 'max' and 'min' share one generator.
    prims('sum').def_grad(_reduce_sum_grad_gen)
    for name in ('max', 'min'):
        prims(name).def_grad(_reduce_select_grad_gen)

    # Binary arithmetic; each raw gradient function is handed to _unbroadcast
    # together with the result and the operand it belongs to.
    register_pair(
        'multiply',
        lambda ans, x, y: _unbroadcast(ans, x, lambda g: g * y),
        lambda ans, x, y: _unbroadcast(ans, y, lambda g: x * g))
    register_pair(
        'add',
        lambda ans, x, y: _unbroadcast(ans, x, passthrough),
        lambda ans, x, y: _unbroadcast(ans, y, passthrough))
    register_pair(
        'subtract',
        lambda ans, x, y: _unbroadcast(ans, x, passthrough),
        lambda ans, x, y: _unbroadcast(ans, y, operator.neg))
    for name in ('divide', 'true_divide'):
        register_pair(name, quotient_wrt_x, quotient_wrt_y)
    for name in ('maximum', 'minimum'):
        register_pair(name, _selection_grad_gen0, _selection_grad_gen1)

    # Unary element-wise functions and layout changes.
    prims('negative').def_grad(lambda ans, x: operator.neg)
    prims('transpose').def_grad(lambda ans, x: mx.nd.transpose)
    prims('abs').def_grad(lambda ans, x: lambda g: mx.nd.sign(x) * g)
    for name in ('sign', 'round', 'ceil', 'floor'):
        prims(name).def_grad_zero()
    prims('sqrt').def_grad(lambda ans, x: lambda g: g * 0.5 / mx.nd.sqrt(x))
    prims('sin').def_grad(lambda ans, x: lambda g: g * mx.nd.cos(x))
    prims('cos').def_grad(lambda ans, x: lambda g: -g * mx.nd.sin(x))
    register_pair(
        'power',
        lambda ans, x, y: _unbroadcast(
            ans, x, lambda g: g * y * mx.nd.power(x, y - 1)),
        lambda ans, x, y: _unbroadcast(
            ans, y, lambda g: g * mx.nd.log(x) * ans))
    prims('reshape').def_grad(lambda _0, x, _1: back_to_shape_of(x))
    prims('expand_dims').def_grad(lambda ans, x, axis: back_to_shape_of(x))

    # Output layer.
    prims('softmax_output').def_grad(_softmax_output_grad)
def sq_loss(ey: nd.NDArray, y: nd.NDArray):
    """Return the mean of the squared element-wise difference of two arrays.

    ``ey`` is reshaped to the shape of ``y`` before subtracting, so the two
    only need to hold the same number of elements.

    :param ey: estimated values.
    :param y: reference values; its shape is the one used.
    :return: sum of squared differences divided by the element count.
    """
    residual = ey.reshape(shape=y.shape) - y
    squared = residual**2
    return squared.sum() / squared.size