import sys

import numpy as np
import mindspore.common.dtype as mstype
from mindspore import context, Tensor, nn
from mindspore.train.serialization import export

from train_utils import SaveInOut, TrainWrap
from official.cv.mobilenetv2.src.mobilenetV2 import MobileNetV2Backbone, MobileNetV2Head, mobilenet_v2

# Export a trainable MobileNetV2 (backbone + 10-class head) as a MINDIR file
# and, when an output prefix is given on the command line, also dump the
# example inputs/outputs via SaveInOut.

# Execution context is configured before any network is constructed.
context.set_context(
    mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False
)
batch = 16

# Assemble the model: feature backbone followed by a 10-way head.
backbone_net = MobileNetV2Backbone()
head_net = MobileNetV2Head(
    input_channel=backbone_net.out_channels, num_classes=10
)
n = mobilenet_v2(backbone_net, head_net)

# Labels below have the same shape as the logits (not class indices),
# hence sparse=False.
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
optimizer = nn.Momentum(
    n.trainable_params(), learning_rate=0.01, momentum=0.9, use_nesterov=False
)
net = TrainWrap(n, loss_fn, optimizer)

# Example inputs handed to export: a random image batch and all-zero labels.
x = Tensor(np.random.randn(batch, 3, 224, 224), mstype.float32)
label = Tensor(np.zeros((batch, 10), dtype=np.float32))
export(
    net, x, label,
    file_name="mindir/mobilenetv2_train", file_format='MINDIR',
)

# Optional first CLI argument: prefix (used verbatim) for the SaveInOut dump.
if sys.argv[1:]:
    output_prefix = sys.argv[1]
    SaveInOut(output_prefix + "mobilenetv2", x, label, n, net, sparse=False)
# Export a TransferNet (backbone + head) whose optimizer updates only the head.
# NOTE(review): this is a fragment — HEAD, BACKBONE, M (presumably
# ``mindspore``), TransferNet, ParameterTuple, np, TrainWrap and export are
# defined earlier in the original example and are not visible here.

# Start the head from a zero bias.
HEAD.bias.set_data(M.Tensor(np.zeros(HEAD.bias.data.shape, dtype="float32")))

n = TransferNet(BACKBONE, HEAD)

# Only the head's parameters are given to the optimizer and the train wrapper.
trainable_weights_list = list(n.head.trainable_params())
trainable_weights = ParameterTuple(trainable_weights_list)

M.context.set_context(
    mode=M.context.PYNATIVE_MODE, device_target="GPU", save_graphs=False
)

# Example inputs for export: an all-ones image batch and all-zero labels.
BATCH_SIZE = 32
X = M.Tensor(np.ones((BATCH_SIZE, 3, 224, 224)), M.float32)
label = M.Tensor(np.zeros((BATCH_SIZE, 10), dtype=np.float32))

sgd = M.nn.SGD(
    trainable_weights,
    learning_rate=0.01, momentum=0.9, dampening=0.01,
    weight_decay=0.0, nesterov=False, loss_scale=1.0,
)
net = TrainWrap(n, optimizer=sgd, weights=trainable_weights)
export(
    net, X, label,
    file_name="transfer_learning_tod.mindir", file_format='MINDIR',
)

print("Exported")
# Example no. 3
# 0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lenet_export."""

import sys

import numpy as np
import mindspore.common.dtype as mstype
from mindspore import context, Tensor
from mindspore.train.serialization import export
from train_utils import TrainWrap

# ``lenet`` lives outside this directory, so the search path must be extended
# *before* it is imported. (Previously the append ran after the import and
# could not help it.)
# NOTE: a relative sys.path entry is resolved against the current working
# directory, not this file's location.
sys.path.append('../../../cv/lenet/src/')
from lenet import LeNet5  # noqa: E402  (must follow the sys.path tweak)

# Configure the execution context before the network is constructed.
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False)

n = LeNet5()
n.set_train()

# Example inputs passed to export: an all-ones 1x32x32 batch and all-zero labels.
batch_size = 32
x = Tensor(np.ones((batch_size, 1, 32, 32)), mstype.float32)
label = Tensor(np.zeros([batch_size, 10]).astype(np.float32))

# Wrap the network for training and write it to lenet_tod.mindir.
net = TrainWrap(n)
export(net, x, label, file_name="lenet_tod.mindir", file_format='MINDIR')

print("finished exporting")
# Example no. 4
# 0
# Export the backbone and a trainable 1000 -> 10 head as two separate MINDIR files.
# NOTE(review): fragment — BACKBONE, M (presumably ``mindspore``), np,
# TrainWrap and export are defined earlier in the original example.

# 1) Backbone, exported with an all-ones image batch as the example input.
BATCH_SIZE = 16
X = M.Tensor(np.ones((BATCH_SIZE, 3, 224, 224)), M.float32)
export(
    BACKBONE, X,
    file_name="transfer_learning_tod_backbone", file_format='MINDIR',
)

label = M.Tensor(np.zeros((BATCH_SIZE, 10), dtype=np.float32))

# 2) Head: Dense(1000, 10) with N(0, 0.1) weights and a zero bias.
HEAD = M.nn.Dense(1000, 10)
HEAD.weight.set_data(
    M.Tensor(np.random.normal(0, 0.1, HEAD.weight.data.shape).astype("float32"))
)
HEAD.bias.set_data(M.Tensor(np.zeros(HEAD.bias.data.shape, dtype="float32")))

sgd = M.nn.SGD(
    HEAD.trainable_params(),
    learning_rate=0.01, momentum=0.9, dampening=0.01,
    weight_decay=0.0, nesterov=False, loss_scale=1.0,
)
net = TrainWrap(HEAD, optimizer=sgd)

# The head is exported on its own, fed with an all-zero stand-in for the
# backbone's 1000-wide output.
backbone_out = M.Tensor(np.zeros((BATCH_SIZE, 1000), dtype=np.float32))
export(
    net, backbone_out, label,
    file_name="transfer_learning_tod_head", file_format='MINDIR',
)

print("Exported")