def test_train_lenet_with_new_interface(num_classes=10, epoch=20, batch_size=32):
    """Train LeNet5 on a constant dummy batch via ForwardValueAndGrad and check convergence.

    Builds a LeNet5 + softmax-cross-entropy loss cell, runs `epoch` optimizer steps
    on the same all-ones input batch, and asserts the final loss has dropped into
    the (0.001, 0.01) band.

    Args:
        num_classes (int): number of output classes for LeNet5.
        epoch (int): number of training steps (one batch per step).
        batch_size (int): samples per step.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    network = LeNet5(num_classes)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_with_criterion = WithLossCell(network, criterion)
    net_with_criterion.set_train()
    weights = ParameterTuple(network.trainable_params())
    optimizer = nn.Momentum(weights, 0.1, 0.9)
    # sens_param=True: the sensitivity (initial gradient) is passed explicitly per call.
    train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights,
                                        get_by_list=True, sens_param=True)
    # Inputs are constant, so build them once outside the loop instead of per step.
    data = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([batch_size]).astype(np.int32))
    sens = Tensor(np.ones([1]).astype(np.float32))
    losses = []
    for _ in range(epoch):  # loop index is unused
        loss, grads = train_network(data, label, sens)
        grads = F.identity(grads)
        optimizer(grads)
        losses.append(loss)
    # Final loss must have converged: strictly between 0.001 and 0.01.
    assert 0.001 < losses[-1].asnumpy() < 0.01
def test_big_batchSize_with_new_interface(num_classes=10, epoch=8, batch_size=338):
    """Train ResNet50 with a large batch via ForwardValueAndGrad and check the loss drops.

    Uses a fixed scalar sensitivity (sens=1.0) configured on the wrapper instead of
    passing it per call, then runs `epoch` optimizer steps on the same constant batch.

    Args:
        num_classes (int): number of output classes for resnet50.
        epoch (int): number of training steps (one batch per step).
        batch_size (int): samples per step (deliberately large).
    """
    net = resnet50(num_classes)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_with_criterion = WithLossCell(net, criterion)
    net_with_criterion.set_train()
    weights = ParameterTuple(
        filter(lambda x: x.requires_grad, net.get_parameters()))
    optimizer = Momentum(weights, 0.1, 0.9)
    # sens is fixed at construction time, so the training call takes only (data, label).
    train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights,
                                        get_by_list=True, sens_param=True, sens=1.0)
    # Inputs are constant, so build them once outside the loop instead of per step.
    data = Tensor(np.ones([batch_size, 3, 224, 224]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([batch_size]).astype(np.int32))
    losses = []
    for _ in range(epoch):  # loop index is unused
        loss, grads = train_network(data, label)
        grads = F.identity(grads)
        optimizer(grads)
        losses.append(loss)
    # Looser bound than the LeNet test: only require the loss to have fallen below 0.8.
    assert losses[-1].asnumpy() < 0.8
def construct(self, x, clip_norm):
    """Scale `x` so its L2 norm along `self.axis` does not exceed `clip_norm`.

    Computes x * clip_norm / max(||x||_2, clip_norm); when the norm is already
    below clip_norm the division cancels the multiplication and x is unchanged.
    A select guard substitutes 1.0 for zero sums so sqrt stays safe.
    """
    squared = F.square(x)
    sq_sum = self.cast(self.reduce_sum(squared, self.axis), mstype.float32)
    nonzero = self.greater_(sq_sum, self.zero)
    safe_fill = self.fill(self.dtype(nonzero), self.shape(nonzero), 1.0)
    # Replace zero sums with 1.0 before sqrt to avoid a sqrt(0) gradient issue.
    safe_sum = self.select_(nonzero, sq_sum, self.cast(safe_fill, self.dtype(sq_sum)))
    norm = self.select_(nonzero, self.sqrt(safe_sum), sq_sum)
    scaled = x * clip_norm
    denom = self.max_op(norm, clip_norm)
    clipped = self.cast(scaled, mstype.float32) / self.expand_dims(denom, -1)
    clipped = self.reshape(clipped, self.shape(x))
    return identity(clipped)
def construct(self, x, clip_norm):
    """Scale `x` by its global L2 norm so it does not exceed `clip_norm`.

    ms_function-decorated variant for pynative mode: reduces over all axes,
    validates the input dtype, and skips the multiply when clip_norm equals one.
    """
    squared = F.square(x)
    sq_sum = self.cast(self.reduce_sum(squared), mstype.float32)
    nonzero = self.greater_(sq_sum, 0)
    safe_fill = self.fill(self.dtype(nonzero), self.shape(nonzero), 1.0)
    # Replace zero sums with 1.0 before sqrt to avoid a sqrt(0) gradient issue.
    safe_sum = self.select_(nonzero, sq_sum, self.cast(safe_fill, self.dtype(sq_sum)))
    norm = self.select_(nonzero, self.sqrt(safe_sum), sq_sum)
    _dtype_check(self.dtype(x))
    # x * 1 is a no-op, so skip the multiply entirely when clip_norm == 1.
    if _is_equal_one(clip_norm):
        scaled = x
    else:
        scaled = x * clip_norm
    denom = self.max_op(norm, clip_norm)
    clipped = self.cast(scaled, mstype.float32) / self.expand_dims(denom, -1)
    clipped = self.reshape(clipped, self.shape(x))
    return identity(clipped)