def resnet50_predict(args_opt):
    """Print a predicted label for every entry in the downloaded data directory.

    Copies ``args_opt.data_url`` and ``args_opt.checkpoint_path`` to local
    ``/cache`` locations, loads the checkpoint into a ResNet-50 network and
    runs the network once per entry listed in the local data directory.

    Args:
        args_opt: Options object; its ``data_url`` and ``checkpoint_path``
            attributes are read.
    """
    num_classes = cfg.class_num
    data_dir = '/cache/data'
    # The local checkpoint keeps the text after the last '/' of the remote path.
    ckpt_name = args_opt.checkpoint_path.rsplit('/', 1)[-1]
    ckpt_local = '/cache/' + ckpt_name

    # Graph mode on the Ascend target, without dumping graphs.
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)

    # Fetch the input data and the checkpoint to the local cache.
    print('Download data.')
    mox.file.copy_parallel(src_url=args_opt.data_url, dst_url=data_dir)
    mox.file.copy_parallel(src_url=args_opt.checkpoint_path, dst_url=ckpt_local)

    # Build the network, restore its parameters and switch off training mode.
    net = resnet50(class_num=num_classes)
    load_param_into_net(net, load_checkpoint(ckpt_local))
    net.set_train(False)

    # One forward pass per directory entry, each reshaped to a batch of one.
    for file_name in os.listdir(data_dir):
        pixels = data_preprocess(os.path.join(data_dir, file_name))
        batch = Tensor(pixels.reshape((1, 3, 224, 224)), mindspore.float32)
        scores = net(batch).asnumpy()

        # The index of the highest score selects the label to report.
        top_label = label_list[scores[0].argmax()]
        print("预测的蘑菇标签为:\n\t" + top_label + "\n")
def create_network(name, *args, **kwargs):
    """Build one of the supported networks by name.

    Args:
        name: Network identifier: 'resnet50', 'resnet101' or 'se_resnet50'.
        *args: Positional arguments forwarded to the network constructor.
        **kwargs: Keyword arguments forwarded to the network constructor.

    Returns:
        Whatever the selected network constructor returns.

    Raises:
        NotImplementedError: If ``name`` is not one of the supported networks.
    """
    # ``*args`` must be declared in the signature: the constructors below
    # forward it, and without the declaration the name ``args`` is undefined.
    if name == 'resnet50':
        return resnet50(*args, **kwargs)
    if name == 'resnet101':
        return resnet101(*args, **kwargs)
    if name == 'se_resnet50':
        return se_resnet50(*args, **kwargs)
    raise NotImplementedError(f"{name} is not implemented in the repo")
def resnet50_train(args_opt):
    """Train ResNet-50 on the downloaded dataset and upload the checkpoints.

    Copies ``args_opt.data_url`` to a local cache directory, trains for
    ``args_opt.epoch_size`` epochs with a Momentum optimizer, writes
    checkpoints locally and finally copies them to ``args_opt.train_url``.

    Args:
        args_opt: Options object; its ``epoch_size``, ``data_url`` and
            ``train_url`` attributes are read.
    """
    total_epochs = args_opt.epoch_size
    per_batch = cfg.batch_size
    num_classes = cfg.class_num
    scale_value = cfg.loss_scale
    data_dir = '/cache/data'
    ckpt_dir = '/cache/ckpt_file'

    # Graph mode on the Ascend target, without dumping graphs.
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        save_graphs=False)

    # Fetch the training data to the local cache.
    print('Download data.')
    mox.file.copy_parallel(src_url=args_opt.data_url, dst_url=data_dir)

    # Build the training dataset and record how many steps it yields.
    print('Create train and evaluate dataset.')
    train_set = create_dataset(dataset_path=data_dir,
                               do_train=True,
                               repeat_num=total_epochs,
                               batch_size=per_batch)
    steps_per_epoch = train_set.get_dataset_size()
    print('Create dataset success.')

    # Network, mean-reduced sparse softmax cross-entropy loss, learning-rate
    # schedule and Momentum optimizer.
    net = resnet50(class_num=num_classes)
    criterion = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    schedule = get_lr(global_step=0,
                      total_epochs=total_epochs,
                      steps_per_epoch=steps_per_epoch)
    lr = Tensor(schedule)
    optimizer = Momentum(net.trainable_params(),
                         lr,
                         momentum=0.9,
                         weight_decay=1e-4,
                         loss_scale=scale_value)
    # The same loss-scale value is handed to both the optimizer and this
    # fixed loss-scale manager.
    scale_manager = FixedLossScaleManager(scale_value, False)

    # Mixed-precision level "O2" with keep_batchnorm_fp32=False is requested
    # from Model; accuracy is the tracked metric.
    model = Model(net,
                  amp_level="O2",
                  keep_batchnorm_fp32=False,
                  loss_fn=criterion,
                  optimizer=optimizer,
                  loss_scale_manager=scale_manager,
                  metrics={'acc'})

    # Callbacks: timing, throughput, loss reporting and checkpoint saving.
    timer = TimeMonitor(data_size=steps_per_epoch)
    throughput = PerformanceCallback(per_batch)
    loss_report = LossMonitor()
    # Checkpoint interval is expressed in steps: configured epochs * steps per epoch.
    ckpt_settings = CheckpointConfig(
        save_checkpoint_steps=cfg.save_checkpoint_epochs * steps_per_epoch,
        keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpt_saver = ModelCheckpoint(prefix="resnet",
                                 directory=ckpt_dir,
                                 config=ckpt_settings)
    callbacks = [timer, throughput, loss_report, ckpt_saver]

    print(f'Start run training, total epoch: {total_epochs}.')
    model.train(total_epochs, train_set, callbacks=callbacks)

    # Push the locally written checkpoints to the output location.
    print('Upload checkpoint.')
    mox.file.copy_parallel(src_url=ckpt_dir, dst_url=args_opt.train_url)
def resnet50_eval(args_opt):
    """Evaluate a ResNet-50 checkpoint and print top-1 / top-5 accuracy.

    Copies ``args_opt.data_url`` and ``args_opt.checkpoint_path`` to local
    ``/cache`` locations, loads the checkpoint into a ResNet-50 network and
    prints the metrics returned by ``Model.eval``.

    Note:
        When ``cfg.use_label_smooth`` is falsy this function overwrites
        ``cfg.label_smooth_factor`` with ``0.0``.

    Args:
        args_opt: Options object; its ``data_url`` and ``checkpoint_path``
            attributes are read.
    """
    num_classes = cfg.class_num
    data_dir = '/cache/data'
    # The local checkpoint keeps the text after the last '/' of the remote path.
    ckpt_name = args_opt.checkpoint_path.rsplit('/', 1)[-1]
    ckpt_local = '/cache/' + ckpt_name

    # Graph mode on the Ascend target, without dumping graphs.
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)

    # Fetch the evaluation data and the checkpoint to the local cache.
    print('Download data.')
    mox.file.copy_parallel(src_url=args_opt.data_url, dst_url=data_dir)
    mox.file.copy_parallel(src_url=args_opt.checkpoint_path, dst_url=ckpt_local)

    # Evaluation dataset (do_train=False).
    eval_set = create_dataset(dataset_path=data_dir, do_train=False, batch_size=cfg.batch_size)

    # Build the network, restore its parameters and switch off training mode.
    net = resnet50(class_num=num_classes)
    restored = load_checkpoint(ckpt_local)
    load_param_into_net(net, restored)
    net.set_train(False)

    # Label smoothing is disabled by forcing the factor to zero.
    if not cfg.use_label_smooth:
        cfg.label_smooth_factor = 0.0
    criterion = CrossEntropySmooth(sparse=True,
                                   reduction='mean',
                                   smooth_factor=cfg.label_smooth_factor,
                                   num_classes=cfg.class_num)
    tracked = {'top_1_accuracy', 'top_5_accuracy'}
    model = Model(net, loss_fn=criterion, metrics=tracked)

    # Run the evaluation and report the result with the checkpoint's source path.
    outcome = model.eval(eval_set)
    print("result:", outcome, "ckpt=", args_opt.checkpoint_path)