def resnet50_predict(args_opt):
    """Run ResNet-50 inference over every image in the downloaded data folder.

    Downloads the dataset and a trained checkpoint from OBS (via MoXing),
    loads the checkpoint into a ResNet-50 network, preprocesses each image
    and prints the predicted label.

    Args:
        args_opt: parsed CLI args providing ``data_url`` (OBS dataset path)
            and ``checkpoint_path`` (OBS path to the ``.ckpt`` file).
    """
    class_num = cfg.class_num
    local_data_path = '/cache/data'
    # Idiom fix: take the checkpoint basename with a negative index instead
    # of the original manual `slice[len(slice) - 1]` dance.
    ckpt_file = args_opt.checkpoint_path.split('/')[-1]
    local_ckpt_path = '/cache/' + ckpt_file

    # set graph mode and parallel mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)

    # data download: dataset and checkpoint are copied from OBS to local cache
    print('Download data.')
    mox.file.copy_parallel(src_url=args_opt.data_url, dst_url=local_data_path)
    mox.file.copy_parallel(src_url=args_opt.checkpoint_path, dst_url=local_ckpt_path)

    # load checkpoint into net and switch to inference mode
    net = resnet50(class_num=class_num)
    param_dict = load_checkpoint(local_ckpt_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)

    # preprocess and predict each image in the data directory
    images = os.listdir(local_data_path)
    for image in images:
        img = data_preprocess(os.path.join(local_data_path, image))
        # predict model: NCHW input, batch of 1; argmax over class scores
        res = net(Tensor(img.reshape((1, 3, 224, 224)), mindspore.float32)).asnumpy()
        predict_label = label_list[res[0].argmax()]
        print("预测的蘑菇标签为:\n\t" + predict_label + "\n")
def create_network(name, *args, **kwargs):
    """Factory returning a network instance selected by *name*.

    Bug fix: the original signature was ``(name, **kwargs)`` while the body
    forwarded ``*args`` — an undefined name — so every recognized network
    name raised ``NameError``. Adding ``*args`` to the signature is
    backward compatible (existing keyword-only callers are unaffected).

    Args:
        name: one of ``'resnet50'``, ``'resnet101'``, ``'se_resnet50'``.
        *args: positional arguments forwarded to the network constructor.
        **kwargs: keyword arguments forwarded to the network constructor.

    Raises:
        NotImplementedError: if *name* is not a known network.
    """
    if name == 'resnet50':
        return resnet50(*args, **kwargs)
    if name == 'resnet101':
        return resnet101(*args, **kwargs)
    if name == 'se_resnet50':
        return se_resnet50(*args, **kwargs)
    raise NotImplementedError(f"{name} is not implemented in the repo")
def resnet50_train(args_opt):
    """Train ResNet-50 on Ascend and upload the resulting checkpoints.

    Pulls the dataset from OBS into local cache, trains with O2 mixed
    precision and a fixed loss scale, checkpoints periodically, then copies
    the checkpoint directory back to OBS.

    Args:
        args_opt: parsed CLI args providing ``epoch_size``, ``data_url``
            (OBS dataset path) and ``train_url`` (OBS output path).
    """
    epoch_size = args_opt.epoch_size
    batch_size, class_num = cfg.batch_size, cfg.class_num
    scale_value = cfg.loss_scale
    local_data_path = '/cache/data'
    local_ckpt_path = '/cache/ckpt_file'

    # set graph mode and parallel mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)

    # data download
    print('Download data.')
    mox.file.copy_parallel(src_url=args_opt.data_url, dst_url=local_data_path)

    # create dataset
    print('Create train and evaluate dataset.')
    train_dataset = create_dataset(dataset_path=local_data_path, do_train=True,
                                   repeat_num=epoch_size, batch_size=batch_size)
    steps_per_epoch = train_dataset.get_dataset_size()
    print('Create dataset success.')

    # network, loss ('mean' reduction), LR schedule and momentum optimizer
    net = resnet50(class_num=class_num)
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    lr = Tensor(get_lr(global_step=0, total_epochs=epoch_size, steps_per_epoch=steps_per_epoch))
    opt = Momentum(net.trainable_params(), lr, momentum=0.9,
                   weight_decay=1e-4, loss_scale=scale_value)
    scale_manager = FixedLossScaleManager(scale_value, False)

    # amp_level="O2": hybrid precision — everything except batchnorm runs in
    # float16; keep_batchnorm_fp32=False forces batchnorm to float16 as well.
    model = Model(net, amp_level="O2", keep_batchnorm_fp32=False, loss_fn=loss,
                  optimizer=opt, loss_scale_manager=scale_manager, metrics={'acc'})

    # callbacks: step timing, throughput (ips), per-epoch loss, checkpointing
    callbacks = [TimeMonitor(data_size=steps_per_epoch),
                 PerformanceCallback(batch_size),
                 LossMonitor()]
    config_ck = CheckpointConfig(
        save_checkpoint_steps=cfg.save_checkpoint_epochs * steps_per_epoch,
        keep_checkpoint_max=cfg.keep_checkpoint_max)
    callbacks.append(ModelCheckpoint(prefix="resnet",
                                     directory=local_ckpt_path, config=config_ck))

    print(f'Start run training, total epoch: {epoch_size}.')
    model.train(epoch_size, train_dataset, callbacks=callbacks)

    # upload checkpoint files
    print('Upload checkpoint.')
    mox.file.copy_parallel(src_url=local_ckpt_path, dst_url=args_opt.train_url)
def resnet50_eval(args_opt):
    """Evaluate a trained ResNet-50 checkpoint on the downloaded dataset.

    Downloads dataset and checkpoint from OBS, restores the network, and
    reports top-1 / top-5 accuracy using a label-smoothing cross-entropy
    loss (smoothing disabled when ``cfg.use_label_smooth`` is false).

    Args:
        args_opt: parsed CLI args providing ``data_url`` (OBS dataset path)
            and ``checkpoint_path`` (OBS path to the ``.ckpt`` file).
    """
    class_num = cfg.class_num
    local_data_path = '/cache/data'
    # last path component of the checkpoint becomes the local file name
    parts = args_opt.checkpoint_path.split('/')
    local_ckpt_path = '/cache/' + parts[len(parts) - 1]

    # set graph mode and parallel mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)

    # data download
    print('Download data.')
    mox.file.copy_parallel(src_url=args_opt.data_url, dst_url=local_data_path)
    mox.file.copy_parallel(src_url=args_opt.checkpoint_path, dst_url=local_ckpt_path)

    # evaluation dataset (no augmentation)
    dataset = create_dataset(dataset_path=local_data_path, do_train=False,
                             batch_size=cfg.batch_size)

    # restore the network from the checkpoint and freeze it for eval
    net = resnet50(class_num=class_num)
    load_param_into_net(net, load_checkpoint(local_ckpt_path))
    net.set_train(False)

    # a zero smooth_factor makes CrossEntropySmooth plain cross-entropy
    if not cfg.use_label_smooth:
        cfg.label_smooth_factor = 0.0
    loss = CrossEntropySmooth(sparse=True, reduction='mean',
                              smooth_factor=cfg.label_smooth_factor,
                              num_classes=cfg.class_num)
    model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})

    # eval model
    res = model.eval(dataset)
    print("result:", res, "ckpt=", args_opt.checkpoint_path)