def def_grads(prims): """ Define gradient function for primitives """ identity = lambda x: x # dot prims('dot').def_grad(lambda ans, a, b: lambda g: ndarray.dot(g, b.T)) prims('dot').def_grad( lambda ans, a, b: lambda g: ndarray.dot(a.T, g), argnum=1) # non-linear prims('tanh').def_grad(lambda ans, x: lambda g: g * (1 - ans ** 2)) prims('exp').def_grad(lambda ans, x: lambda g: g * ans) prims('log').def_grad(lambda ans, x: lambda g: g / x) # reduce prims('sum').def_grad(_sum_grad) # + - * / prims('multiply').def_grad( lambda ans, x, y: _unbroadcast(ans, x, lambda g: g * y)) prims('multiply').def_grad( lambda ans, x, y: _unbroadcast(ans, y, lambda g: x * g), argnum=1) prims('add').def_grad(lambda ans, x, y: _unbroadcast(ans, x, identity)) prims('add').def_grad( lambda ans, x, y: _unbroadcast(ans, y, identity), argnum=1) prims('subtract').def_grad( lambda ans, x, y: _unbroadcast(ans, x, identity)) prims('subtract').def_grad( lambda ans, x, y: _unbroadcast(ans, y, operator.neg), argnum=1) prims('divide').def_grad( lambda ans, x, y: _unbroadcast(ans, x, lambda g: g / y)) prims('divide').def_grad( lambda ans, x, y: _unbroadcast(ans, y, lambda g: -g * x / (y * y)), argnum=1) prims('true_divide').def_grad( lambda ans, x, y: _unbroadcast(ans, x, lambda g: g / y)) prims('true_divide').def_grad( lambda ans, x, y: _unbroadcast(ans, y, lambda g: -g * x / (y * y)), argnum=1) prims('maximum').def_grad(_maximum_grad_gen0) prims('maximum').def_grad(_maximum_grad_gen1, argnum=1) # TODO: minjie prims('max').def_grad_zero() # negate prims('negative').def_grad(lambda ans, x: operator.neg) prims('transpose').def_grad(lambda ans, x: mxnet.nd.transpose) prims('abs').def_grad(lambda ans, x: lambda g: mxnet.nd.sign(x) * g) prims('sign').def_grad_zero() prims('round').def_grad_zero() prims('ceil').def_grad_zero() prims('floor').def_grad_zero() prims('sqrt').def_grad(lambda ans, x: lambda g: g * 0.5 / mxnet.nd.sqrt(x)) prims('sin').def_grad(lambda ans, x: lambda g: g * mxnet.nd.cos(x)) prims('cos').def_grad(lambda ans, x: lambda g: -g * mxnet.nd.sin(x)) prims('power').def_grad( lambda ans, x, y: _unbroadcast(ans, x, lambda g: g * y * mxnet.nd.power(x, y - 1)) ) prims('power').def_grad( lambda ans, x, y: _unbroadcast(ans, y, lambda g: g * mxnet.nd.log(x) * ans), argnum=1) prims('reshape').def_grad( lambda _0, x, _1: lambda g: NDArray.reshape(g, x.shape)) prims('expand_dims').def_grad( lambda ans, x, axis: lambda g: NDArray.reshape(g, x.shape))
def def_grads(prims): """ Define gradient function for primitives """ identity = lambda x: x # dot prims('dot').def_grad(lambda ans, a, b: lambda g: ndarray.dot(g, b.T)) prims('dot').def_grad(lambda ans, a, b: lambda g: ndarray.dot(a.T, g), argnum=1) # non-linear #prims.tanh.def_grad(lambda ans, x: lambda g: g / np.cosh(x) ** 2) prims('exp').def_grad(lambda ans, x: lambda g: g * ans) prims('log').def_grad(lambda ans, x: lambda g: g / x) # reduce prims('sum').def_grad(_sum_grad) # + - * / prims('multiply').def_grad( lambda ans, x, y: _unbroadcast(ans, x, lambda g: g * y)) prims('multiply').def_grad( lambda ans, x, y: _unbroadcast(ans, y, lambda g: x * g), argnum=1) prims('add').def_grad(lambda ans, x, y: _unbroadcast(ans, x, identity)) prims('add').def_grad(lambda ans, x, y: _unbroadcast(ans, y, identity), argnum=1) prims('subtract').def_grad( lambda ans, x, y: _unbroadcast(ans, x, identity)) prims('subtract').def_grad( lambda ans, x, y: _unbroadcast(ans, y, operator.neg), argnum=1) prims('divide').def_grad( lambda ans, x, y: _unbroadcast(ans, x, lambda g: g / y)) prims('divide').def_grad( lambda ans, x, y: _unbroadcast(ans, y, lambda g: -g * x / (y * y)), argnum=1) prims('true_divide').def_grad( lambda ans, x, y: _unbroadcast(ans, x, lambda g: g / y)) prims('true_divide').def_grad( lambda ans, x, y: _unbroadcast(ans, y, lambda g: -g * x / (y * y)), argnum=1) prims('maximum').def_grad(_maximum_grad_gen0) prims('maximum').def_grad(_maximum_grad_gen1, argnum=1) # TODO: minjie prims('max').def_grad_zero() # negate prims('negative').def_grad(lambda ans, x: operator.neg) prims('transpose').def_grad(lambda ans, x: mxnet.nd.transpose) prims('abs').def_grad(lambda ans, x: lambda g: mxnet.nd.sign(x) * g) prims('sign').def_grad_zero() prims('round').def_grad_zero() prims('ceil').def_grad_zero() prims('floor').def_grad_zero() prims('sqrt').def_grad(lambda ans, x: lambda g: g * 0.5 / mxnet.nd.sqrt(x)) prims('sin').def_grad(lambda ans, x: lambda g: g * mxnet.nd.cos(x)) prims('cos').def_grad(lambda ans, x: lambda g: -g * mxnet.nd.sin(x)) prims('power').def_grad(lambda ans, x, y: _unbroadcast( ans, x, lambda g: g * y * mxnet.nd.power(x, y - 1))) prims('power').def_grad(lambda ans, x, y: _unbroadcast( ans, y, lambda g: g * mxnet.nd.log(x) * ans), argnum=1) prims('reshape').def_grad( lambda _0, x, _1: lambda g: NDArray.reshape(g, x.shape)) prims('expand_dims').def_grad( lambda ans, x, axis: lambda g: NDArray.reshape(g, x.shape))
def forward(self, encoder_output: nd.NDArray, label=None, label_lengths=None):
    no_label = label is None or label_lengths is None
    encoder_output = nd.transpose(encoder_output, (0, 2, 3, 1))
    encoder_output = encoder_output.reshape(
        (encoder_output.shape[0], -1, encoder_output.shape[3]))
    batch_max_len = self.max_len if no_label else int(
        label_lengths.max().asscalar()) - 1

    # Initialize hidden states
    encoder_output_mean = encoder_output.mean(axis=1)
    h = self.init_h(encoder_output_mean)
    c = self.init_c(encoder_output_mean)

    # Two lists to store outputs
    predictions = []
    alphas = []

    if not no_label:
        label_embedded = self.embedding(label)
    else:
        bs = encoder_output.shape[0]
        x_t = self.embedding(
            nd.zeros(shape=(bs,), ctx=encoder_output.context))

    for t in range(batch_max_len):
        if not no_label:
            x_t = label_embedded[:, t]
        if self._use_current_state:
            _, [h, c] = self.lstm_cell(x_t, [h, c])
            if self._use_adaptive_attention:
                atten_weights, alpha = self.attention(
                    encoder_output, h, x_t, c)
            else:
                atten_weights, alpha = self.attention(encoder_output, h)
            atten_weights = self.f_beta(h).sigmoid() * atten_weights
            inputs = nd.concat(atten_weights, h, dim=1)
            preds = self.out(self.dropout(inputs))
        else:
            atten_weights, alpha = self.attention(encoder_output, h)
            atten_weights = nd.sigmoid(self.f_beta(h)) * atten_weights
            inputs = nd.concat(x_t, atten_weights, dim=1)
            _, [h, c] = self.lstm_cell(inputs, [h, c])
            preds = self.out(self.dropout(h))
        # Greedy feedback: embed the current prediction as the next input;
        # when labels are given, x_t is overwritten at the top of the next step.
        x_t = self.embedding(preds.argmax(axis=1))
        predictions.append(preds)
        alphas.append(alpha)

    predictions = nd.concat(*[x.expand_dims(axis=1) for x in predictions],
                            dim=1)
    alphas = nd.concat(*[x.expand_dims(axis=1) for x in alphas], dim=1)
    return predictions, alphas
def stat_helper(name, array):
    """Wrapper for executor callback: print the layer name, output shape and
    elapsed time since the previous call."""
    import ctypes
    import time
    from mxnet.ndarray import NDArray
    from mxnet.base import NDArrayHandle, py_str
    array = ctypes.cast(array, NDArrayHandle)
    array = NDArray(array, writable=False)
    array.wait_to_read()
    elapsed = float(time.time() - stat_helper.start_time) * 1000
    if elapsed > .01:
        print(name, array.shape, ('%.1fms' % (elapsed,)))
    stat_helper.start_time = time.time()
def stat_helper(name, array):
    """Wrapper for executor callback: print the layer name, output shape,
    mean/std of the values and elapsed time, and record the shape."""
    import ctypes
    import time
    import numpy as np
    from mxnet.ndarray import NDArray
    from mxnet.base import NDArrayHandle, py_str
    array = ctypes.cast(array, NDArrayHandle)
    array = NDArray(array, writable=False)
    array.wait_to_read()
    elapsed = float(time.time() - stat_helper.start_time) * 1000
    # if elapsed > 0.:
    #     print(name, array.shape, ('%.1fms' % (elapsed,)))
    #     stat_helper.start_time = time.time()
    array = array.asnumpy()
    print(name, array.shape, np.average(array), np.std(array),
          ('%.1fms' % (float(time.time() - stat_helper.start_time) * 1000)))
    stat_helper.internal_shapes.append((name, array.shape))
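# The two `stat_helper` variants above read `stat_helper.start_time` and
# `stat_helper.internal_shapes` as function attributes, so they must be initialized
# before the callback fires. A hedged sketch of that setup (how the callback is
# registered with the executor/monitor is not shown in the source):
import time

stat_helper.start_time = time.time()
stat_helper.internal_shapes = []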
def _partition_tensor(self, size):
    """Zero-copy implementation.

    Note: ndarray works for up to ~4 billion parameters. The two lines below
    are buggy with Horovod -- they cause bad performance:
        tmp = self._tensor.reshape(-1, 1)
        avatar = tmp[start:end]
    """
    number = (self._tensor.size - 1) // size + 1
    if number > self._tensor.shape[0]:
        self._logger.warning(
            "The number of tensor rows (with shape {}) is smaller than "
            "partition number {}.".format(self._tensor.shape, number))
        number = self._tensor.shape[0]
    num_per_partition = self._tensor.shape[0] // number
    partitions_with_extra = self._tensor.shape[0] % number
    partitions = []
    start = 0
    end = num_per_partition
    for i in range(number):
        handle = NDArrayHandle()
        check_call(
            BYTESCHEDULER_LIB.bytescheduler_get_ndarray_avatar(
                self._tensor.handle, byref(handle)))
        avatar = NDArray(handle)[start:end]
        partitions.append(avatar)
        start = end
        end += num_per_partition
        # The last `partitions_with_extra` partitions absorb one extra row each.
        if i >= number - partitions_with_extra - 1:
            end += 1
    return partitions
def _prepare(self): """Post start barrier OP, start OP, comm OP, end OP and end barrier OP to MXNet engine. The function of each kind of OP is explained below. start barrier OP: barrier the start of a parent ByteTask, used to maintain original dependency. start OP: It notifies Core about task readiness. It is also used to delay the start of a child ByteTask. comm OP: the OP that does real communication, e.g., push, pull, allreduce. end OP: an OP that runs after a child ByteTask is finished. It notifies Core about the task completion. end barrier OP: an OP that runs after the parent ByteTask is finished, used to maintain original dependency. """ if self.parent is None: real = self._tensor.handle avatar = NDArrayHandle() check_call( BYTESCHEDULER_LIB.bytescheduler_get_ndarray_avatar( real, byref(avatar))) self._avatar = NDArray(avatar) avatar = self._avatar.handle else: real = self.parent._tensor.handle avatar = self._tensor.handle self._post_start_barrier(avatar, real) self._post_start_op(avatar) # Post real op if self.parent is None: self._post_communication(self._avatar) else: self._post_communication(self._tensor) self._post_end_op(avatar) self._post_end_barrier(avatar, real)
def _partition_single_tensor(self, tensor, size):
    """Partition a single tensor along its first axis.

    Arguments:
        tensor: The NDArray to be partitioned.
        size: An integer. After partitioning, each tensor partition size must
            be equal to or smaller than `size`.

    Returns:
        A list of partitioned tensors.
    """
    number = (tensor.size - 1) // size + 1
    if number > tensor.shape[0]:
        self._logger.warning(
            "The number of tensor rows (with shape {}) is smaller than "
            "partition number {}.".format(tensor.shape, number))
        number = tensor.shape[0]
    num_per_partition = tensor.shape[0] // number
    partitions_with_extra = tensor.shape[0] % number
    partitions = []
    start = 0
    end = num_per_partition
    for i in range(number):
        handle = NDArrayHandle()
        check_call(
            BYTESCHEDULER_LIB.bytescheduler_get_ndarray_avatar(
                tensor.handle, byref(handle)))
        avatar = NDArray(handle)[start:end]
        partitions.append(avatar)
        start = end
        end += num_per_partition
        # The last `partitions_with_extra` partitions absorb one extra row each.
        if i >= number - partitions_with_extra - 1:
            end += 1
    return partitions
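# A standalone illustration of the partitioning arithmetic used by the two methods
# above (example values are assumptions; the ByteScheduler avatar-handle machinery is
# replaced by plain slicing): rows are split evenly along axis 0 and the remainder is
# absorbed by the last partitions.
import mxnet as mx

tensor = mx.nd.arange(40).reshape((10, 4))          # 10 rows, 40 elements
size = 12                                           # max elements per partition
number = (tensor.size - 1) // size + 1              # -> 4 partitions
num_per_partition = tensor.shape[0] // number       # -> 2 rows each
partitions_with_extra = tensor.shape[0] % number    # -> 2 partitions get one extra row

parts, start, end = [], 0, num_per_partition
for i in range(number):
    parts.append(tensor[start:end])
    start, end = end, end + num_per_partition
    if i >= number - partitions_with_extra - 1:
        end += 1
print([p.shape[0] for p in parts])                  # [2, 2, 3, 3]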
def view_classify(img: NDArray, ps: NDArray, version="MNIST"):
    '''Function for viewing an image and its predicted classes.'''
    ps = ps.asnumpy().squeeze()

    fig, (ax1, ax2) = plt.subplots(figsize=(6, 9), ncols=2)
    ax1.imshow(img.asnumpy().squeeze())
    ax1.axis('off')
    ax2.barh(np.arange(10), ps)
    ax2.set_aspect(0.1)
    ax2.set_yticks(np.arange(10))
    if version == "MNIST":
        ax2.set_yticklabels(np.arange(10))
    ax2.set_title('Class Probability')
    ax2.set_xlim(0, 1.1)

    plt.tight_layout()
def view_recon(img: NDArray, recon):
    '''Function for displaying an image (as an MXNet NDArray) and its
    reconstruction (also an NDArray).'''
    fig, axes = plt.subplots(ncols=2, sharex=True, sharey=True)
    axes[0].imshow(img.asnumpy().squeeze())
    axes[1].imshow(recon.asnumpy().squeeze())
    for ax in axes:
        ax.axis('off')
        ax.set_adjustable('box-forced')
def collect(self, name, arr):
    """Callback function for collecting layer output NDArrays."""
    name = py_str(name)
    if self.include_layer is not None and not self.include_layer(name):
        return
    handle = ctypes.cast(arr, NDArrayHandle)
    arr = NDArray(handle, writable=False).copyto(cpu())
    if self.logger is not None:
        self.logger.info("Collecting layer %s output of shape %s" %
                         (name, arr.shape))
    if name in self.nd_dict:
        self.nd_dict[name].append(arr)
    else:
        self.nd_dict[name] = [arr]
def _batchify(data: nd.NDArray, batch_size):
    """Make a batch tensor out of a vector.

    :param data: vector
    :param batch_size: number of batch streams (NN)
    :return: (IN, NN) tensor
    """
    # Work out how cleanly we can divide the dataset into batch_size parts.
    nbatch = len(data) // batch_size
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data[0:nbatch * batch_size]
    # Evenly divide the data across the batch_size batches.
    data = data.reshape(batch_size, -1).transpose()
    # if torch.cuda.is_available():
    #     data = data.cuda()
    return data
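# Usage sketch for `_batchify` (example values are assumptions): 26 tokens with
# batch_size=4 keep 24 tokens and become a (6, 4) tensor, so column j holds the j-th
# batch stream read top to bottom.
import mxnet.ndarray as nd

batched = _batchify(nd.arange(26), 4)
print(batched.shape)   # (6, 4)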
def collect(self, name, arr):
    """Callback function for collecting min and max values from an NDArray."""
    name = py_str(name)
    if self.include_layer is not None and not self.include_layer(name):
        return
    handle = ctypes.cast(arr, NDArrayHandle)
    arr = NDArray(handle, writable=False)
    min_range = ndarray.min(arr).asscalar()
    max_range = ndarray.max(arr).asscalar()
    if name in self.min_max_dict:
        cur_min_max = self.min_max_dict[name]
        self.min_max_dict[name] = (min(cur_min_max[0], min_range),
                                   max(cur_min_max[1], max_range))
    else:
        self.min_max_dict[name] = (min_range, max_range)
    if self.logger is not None:
        self.logger.info("Collecting layer %s min_range=%f, max_range=%f" %
                         (name, min_range, max_range))
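# A possible consumer of the collected `min_max_dict` (an assumption, not part of the
# source): derive a symmetric calibration threshold per layer, of the kind used when
# choosing int8 quantization ranges.
def thresholds_from_min_max(min_max_dict):
    return {name: max(abs(lo), abs(hi))
            for name, (lo, hi) in min_max_dict.items()}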
def train(self, d: HybridSequential, x: NDArray, y: NDArray) -> float:
    # with autograd.record():
    #     loss = (lambda y_hat: self.lossfun(
    #         1, d(concat(x, y_hat, dim=1)), y, y_hat))(self._network(x))
    with autograd.record():
        gen_out = self._network(x)
        y_hat = gen_out
        # x = x.repeat(int(y.shape[0] / x.shape[0]), 0)
        loss = self.lossfun(1, d(concat(x, y_hat, dim=1)),
                            y.reshape((-1, 3, 96, 96)),
                            y_hat.reshape((-1, 3, 96, 96)))
    loss.backward()
    self.trainer.step(1)
    return float(loss.asscalar())
def stat_helper(name, array):
    array = ctypes.cast(array, NDArrayHandle)
    array = NDArray(array, writable=False)
    if not self.activated or not self.re_prog.match(py_str(name)):
        return
    self.queue.append((self.step, py_str(name), stat(array)))
def sq_loss(ey: nd.NDArray, y: nd.NDArray):
    ty = ey.reshape(shape=y.shape)
    loss = (ty - y) ** 2
    loss = loss.sum() / loss.size
    return loss
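# Usage sketch for `sq_loss` (example values are assumptions): the prediction is
# reshaped to the label's shape before taking the mean of the squared differences.
import mxnet.ndarray as nd

y = nd.array([[1.0, 2.0], [3.0, 4.0]])
ey = nd.array([1.5, 2.0, 2.0, 4.0])   # flat prediction, reshaped inside sq_loss
print(sq_loss(ey, y))                 # (0.25 + 0.0 + 1.0 + 0.0) / 4 = 0.3125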
def _prepare(self): """Post start barrier OP, start OP, comm OP, end OP and end barrier OP to MXNet engine. The function of each kind of OP is explained below. start barrier OP: barrier the start of a parent ByteTask, used to maintain original dependency. start OP: It notifies Core about task readiness. It is also used to delay the start of a child ByteTask. comm OP: the OP that does real communication, e.g., push, pull, allreduce. end OP: an OP that runs after a child ByteTask is finished. It notifies Core about the task completion. end barrier OP: an OP that runs after the parent ByteTask is finished, used to maintain original dependency. Below are several key data structures. self._tensor: a list of NDArrays of the same key of all devices. If push_pull, self._tensor includes push list and pull list of NDArrays of all devices. real: the original handle list of self._tensor, used for keep dependency. avatar: a new handle list of self._tensor. """ if self.parent is None: if self.op == "push_pull": push_real = [t.handle for t in self._push_tensor] if isinstance( self._push_tensor, (tuple, list)) else [self._push_tensor.handle] pull_real = [t.handle for t in self._pull_tensor] if isinstance( self._pull_tensor, (tuple, list)) else [self._pull_tensor.handle] assert len(push_real) == len(pull_real) real = push_real + pull_real else: real = [t.handle for t in self._tensor] if isinstance( self._tensor, (tuple, list)) else [self._tensor.handle] avatar = [] for h in real: avatar_h = NDArrayHandle() check_call( BYTESCHEDULER_LIB.bytescheduler_get_ndarray_avatar( h, byref(avatar_h))) avatar.append(avatar_h) if self.op == "push_pull": # push avatar and pull avatar NDArrays self._avatar = [[ NDArray(_) for _ in avatar[:int(len(avatar) / 2)] ], [NDArray(_) for _ in avatar[int(len(avatar) / 2):]]] avatar = [_.handle for _ in self._avatar[0] ] + [_.handle for _ in self._avatar[1]] else: self._avatar = [NDArray(_) for _ in avatar] avatar = [_.handle for _ in self._avatar] else: if self.op == "push_pull": push_real = [ t.handle for t in self.parent._push_tensor ] if isinstance(self.parent._push_tensor, (tuple, list)) else [self.parent._push_tensor.handle] pull_real = [ t.handle for t in self.parent._pull_tensor ] if isinstance(self.parent._pull_tensor, (tuple, list)) else [self.parent._pull_tensor.handle] real = push_real + pull_real push_avatar = [t.handle for t in self._push_tensor] if isinstance( self._push_tensor, (tuple, list)) else [self._push_tensor.handle] pull_avatar = [t.handle for t in self._pull_tensor] if isinstance( self._pull_tensor, (tuple, list)) else [self._pull_tensor.handle] avatar = push_avatar + pull_avatar else: real = [t.handle for t in self.parent._tensor] if isinstance( self.parent._tensor, (tuple, list)) else [self.parent._tensor.handle] avatar = [t.handle for t in self._tensor] if isinstance( self._tensor, (tuple, list)) else [self._tensor.handle] self._post_start_barrier(avatar, real) self._post_start_op(avatar) self._post_push_pull_barrier(avatar) # post real op if self.parent is None: self._post_communication(self._avatar) else: self._post_communication(self._tensor) self._post_end_op(avatar) self._post_end_barrier(avatar, real)
def _repackage_hidden(h: nd.NDArray):
    """Detach hidden states from their computation history."""
    return h.detach()
def argmax_batch(vecs: nd.NDArray):
    # _, idx = torch.max(vecs, 1)
    # return idx
    return vecs.argmax(axis=1)
def argmax(vec: nd.NDArray):
    # _, idx = torch.max(vec, 1)
    # return to_scalar(idx)
    return vec.argmax(1)
def def_grads(prims): """ Define gradient function for primitives """ identity = lambda x: x # dot prims('dot').def_grad( lambda ans, a, b: lambda g: mx.nd.dot(g, b, transpose_b=True)) prims('dot').def_grad( lambda ans, a, b: lambda g: mx.nd.dot(a, g, transpose_a=True), argnum=1) # non-linear prims('tanh').def_grad(lambda ans, x: lambda g: g * (1 - ans**2)) prims('exp').def_grad(lambda ans, x: lambda g: g * ans) prims('log').def_grad(lambda ans, x: lambda g: g / x) # reduce prims('sum').def_grad(_reduce_sum_grad_gen) prims('max').def_grad(_reduce_select_grad_gen) prims('min').def_grad(_reduce_select_grad_gen) # + - * / prims('multiply').def_grad( lambda ans, x, y: _unbroadcast(ans, x, lambda g: g * y)) prims('multiply').def_grad( lambda ans, x, y: _unbroadcast(ans, y, lambda g: x * g), argnum=1) prims('add').def_grad(lambda ans, x, y: _unbroadcast(ans, x, identity)) prims('add').def_grad(lambda ans, x, y: _unbroadcast(ans, y, identity), argnum=1) prims('subtract').def_grad( lambda ans, x, y: _unbroadcast(ans, x, identity)) prims('subtract').def_grad( lambda ans, x, y: _unbroadcast(ans, y, operator.neg), argnum=1) prims('divide').def_grad( lambda ans, x, y: _unbroadcast(ans, x, lambda g: g / y)) prims('divide').def_grad( lambda ans, x, y: _unbroadcast(ans, y, lambda g: -g * x / (y * y)), argnum=1) prims('true_divide').def_grad( lambda ans, x, y: _unbroadcast(ans, x, lambda g: g / y)) prims('true_divide').def_grad( lambda ans, x, y: _unbroadcast(ans, y, lambda g: -g * x / (y * y)), argnum=1) prims('maximum').def_grad(_selection_grad_gen0) prims('maximum').def_grad(_selection_grad_gen1, argnum=1) prims('minimum').def_grad(_selection_grad_gen0) prims('minimum').def_grad(_selection_grad_gen1, argnum=1) # negate prims('negative').def_grad(lambda ans, x: operator.neg) prims('transpose').def_grad(lambda ans, x: mx.nd.transpose) prims('abs').def_grad(lambda ans, x: lambda g: mx.nd.sign(x) * g) prims('sign').def_grad_zero() prims('round').def_grad_zero() prims('ceil').def_grad_zero() prims('floor').def_grad_zero() prims('sqrt').def_grad(lambda ans, x: lambda g: g * 0.5 / mx.nd.sqrt(x)) prims('sin').def_grad(lambda ans, x: lambda g: g * mx.nd.cos(x)) prims('cos').def_grad(lambda ans, x: lambda g: -g * mx.nd.sin(x)) prims('power').def_grad(lambda ans, x, y: _unbroadcast( ans, x, lambda g: g * y * mx.nd.power(x, y - 1))) prims('power').def_grad(lambda ans, x, y: _unbroadcast( ans, y, lambda g: g * mx.nd.log(x) * ans), argnum=1) prims('reshape').def_grad( lambda _0, x, _1: lambda g: NDArray.reshape(g, x.shape)) prims('expand_dims').def_grad( lambda ans, x, axis: lambda g: NDArray.reshape(g, x.shape)) prims('softmax_output').def_grad(_softmax_output_grad)