示例#1
0
def test(cloud_args=None):
    """Evaluate the selected checkpoint(s) and log top-1 / top-5 accuracy.

    ``args.pretrained`` is either used directly as the single checkpoint to
    evaluate, or, when it is a directory, expanded to every ``*.ckpt`` file
    inside it, ordered by the integer embedded in the file name (largest
    first).

    Args:
        cloud_args: Optional overrides forwarded unchanged to ``parse_args``.

    Raises:
        NotImplementedError: If ``get_network`` returns ``None`` for
            ``args.backbone``.
    """
    args = parse_args(cloud_args)
    context.set_context(mode=context.GRAPH_MODE,
                        enable_auto_mixed_precision=True,
                        device_target=args.platform,
                        save_graphs=False)
    # Only pin the device when DEVICE_ID holds a plain non-negative integer.
    device_id_env = os.getenv('DEVICE_ID', "not_set")
    if device_id_env.isdigit():
        context.set_context(device_id=int(device_id_env))

    # Data-parallel context for multi-device evaluation.
    if args.is_distributed:
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            device_num=args.group_size,
            gradients_mean=True)

    args.logger.save_args(args)

    # Resolve the list of checkpoints to evaluate.
    args.logger.important_info('start create network')
    if not os.path.isdir(args.pretrained):
        args.models = [args.pretrained]
    else:
        ckpt_files = glob.glob(os.path.join(args.pretrained, '*.ckpt'))
        print(ckpt_files)
        if args.graph_ckpt:
            def ckpt_order(path):
                # "<prefix>-<number>_<suffix>.ckpt": number after the last '-'.
                stem = os.path.splitext(os.path.basename(path))[0]
                return -1 * int(stem.split('-')[-1].split('_')[0])
        else:
            def ckpt_order(path):
                # "<prefix>_<number>.ckpt": number after the last '_'.
                stem = os.path.splitext(os.path.basename(path))[0]
                return -1 * int(stem.split('_')[-1])
        args.models = sorted(ckpt_files, key=ckpt_order)

    for ckpt_file in args.models:
        # A fresh dataset iterator and network are built per checkpoint.
        eval_set = classification_dataset(args.data_dir,
                                          image_size=args.image_size,
                                          per_batch_size=args.per_batch_size,
                                          max_epoch=1,
                                          rank=args.rank,
                                          group_size=args.group_size,
                                          mode='eval')
        eval_dataloader = eval_set.create_tuple_iterator(output_numpy=True)
        net = get_network(args.backbone,
                          args.num_classes,
                          platform=args.platform)
        if net is None:
            raise NotImplementedError('not implement {}'.format(args.backbone))

        load_pretrain_model(ckpt_file, net, args)

        seen = 0
        hits_top1 = 0
        hits_top5 = 0
        if args.platform != "Ascend":
            auto_mixed_precision(net)
        else:
            net.to_float(mstype.float16)
        net.set_train(False)

        # The clock is restarted after the first batch on rank 0, so that
        # batch is excluded from the throughput figure below.
        t_start = time.time()
        clock_restarted = False
        for data, gt_classes in eval_dataloader:
            logits = net(Tensor(data, mstype.float32)).asnumpy()

            pred_top1 = np.argmax(logits, -1)
            pred_top5 = np.argsort(logits)[:, -5:]

            hits_top1 += np.equal(pred_top1, gt_classes).sum()
            hits_top5 += get_top5_acc(pred_top5, gt_classes)
            seen += args.per_batch_size

            if args.rank == 0 and not clock_restarted:
                t_start = time.time()
                clock_restarted = True

        if args.rank == 0:
            elapsed = time.time() - t_start
            # One batch is subtracted because the clock restarted after it.
            fps = (seen - args.per_batch_size) * args.group_size / elapsed
            args.logger.info(
                'Inference Performance: {:.2f} img/sec'.format(fps))

        # get_result returns a 2-D array: rows are top1, top5, image count.
        results = get_result(args, ckpt_file, hits_top1, hits_top5, seen)
        hits_top1 = results[0, 0]
        hits_top5 = results[1, 0]
        seen = results[2, 0]
        acc1 = 100.0 * hits_top1 / seen
        acc5 = 100.0 * hits_top5 / seen
        args.logger.info('after allreduce eval: top1_correct={}, tot={},'
                         'acc={:.2f}%(TOP1)'.format(hits_top1, seen, acc1))
        args.logger.info('after allreduce eval: top5_correct={}, tot={},'
                         'acc={:.2f}%(TOP5)'.format(hits_top5, seen, acc5))

    if args.is_distributed:
        release()
示例#2
0
def train(cloud_args=None):
    """Train the classification network through the high-level ``Model`` API.

    Builds the dataset, network, learning-rate schedule, Momentum optimizer,
    cross-entropy loss and loss-scale manager from the parsed arguments, then
    runs ``Model.train`` for ``args.max_epoch`` epochs in dataset-sink mode.

    Args:
        cloud_args: Optional overrides forwarded unchanged to ``parse_args``.

    Raises:
        NotImplementedError: If ``get_network`` returns ``None`` for
            ``args.backbone``.
    """
    args = parse_args(cloud_args)
    context.set_context(mode=context.GRAPH_MODE,
                        enable_auto_mixed_precision=True,
                        device_target=args.platform,
                        save_graphs=False)
    # Only pin the device when DEVICE_ID holds a plain non-negative integer.
    device_id_env = os.getenv('DEVICE_ID', "not_set")
    if device_id_env.isdigit():
        context.set_context(device_id=int(device_id_env))

    # Data-parallel context for multi-device training.
    if args.is_distributed:
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            device_num=args.group_size,
            gradients_mean=True)

    # Training data; the epoch argument is fixed to 1 here and the number of
    # epochs is given to model.train() instead.
    train_set = classification_dataset(args.data_dir,
                                       args.image_size,
                                       args.per_batch_size,
                                       1,
                                       args.rank,
                                       args.group_size,
                                       num_parallel_workers=8)
    train_set.map_model = 4  # !!!important
    args.steps_per_epoch = train_set.get_dataset_size()

    args.logger.save_args(args)

    # Network construction and optional pretrained weights.
    args.logger.important_info('start create network')
    net = get_network(args.backbone,
                      num_classes=args.num_classes,
                      platform=args.platform)
    if net is None:
        raise NotImplementedError('not implement {}'.format(args.backbone))

    load_pretrain_model(args.pretrained, net, args)

    # Learning-rate schedule and optimizer.
    lr_schedule = get_lr(args)
    optimizer = Momentum(params=get_param_groups(net),
                         learning_rate=Tensor(lr_schedule),
                         momentum=args.momentum,
                         weight_decay=args.weight_decay,
                         loss_scale=args.loss_scale)

    # Loss; label smoothing is switched off by zeroing its factor.
    if not args.label_smooth:
        args.label_smooth_factor = 0.0
    criterion = CrossEntropy(smooth_factor=args.label_smooth_factor,
                             num_classes=args.num_classes)

    # Loss scaling: dynamic when requested, otherwise a fixed scale.
    if args.is_dynamic_loss_scale != 1:
        scale_manager = FixedLossScaleManager(args.loss_scale,
                                              drop_overflow_update=False)
    else:
        scale_manager = DynamicLossScaleManager(init_loss_scale=65536,
                                                scale_factor=2,
                                                scale_window=2000)

    model = Model(net,
                  loss_fn=criterion,
                  optimizer=optimizer,
                  loss_scale_manager=scale_manager,
                  metrics={'acc'},
                  amp_level="O3")

    # Callbacks: progress logging always, checkpointing only when this rank
    # has been flagged to save.
    callbacks = [ProgressMonitor(args)]
    if args.rank_save_ckpt_flag:
        save_every = args.ckpt_interval * args.steps_per_epoch
        ckpt_config = CheckpointConfig(save_checkpoint_steps=save_every,
                                       keep_checkpoint_max=args.ckpt_save_max)
        rank_ckpt_dir = os.path.join(args.outputs_dir,
                                     'ckpt_' + str(args.rank) + '/')
        callbacks.append(
            ModelCheckpoint(config=ckpt_config,
                            directory=rank_ckpt_dir,
                            prefix='{}'.format(args.rank)))

    model.train(args.max_epoch,
                train_set,
                callbacks=callbacks,
                dataset_sink_mode=True)
示例#3
0
def train(cloud_args=None):
    """Train the network with the loss wrapped in via ``BuildTrainNetwork``.

    Sets up distributed state, logging, dataset, network, loss, learning-rate
    schedule, optimizer and loss scaling from the parsed arguments, then runs
    ``Model.train`` for ``args.max_epoch`` epochs in dataset-sink mode.

    Args:
        cloud_args: Optional overrides forwarded unchanged to ``parse_args``.

    Raises:
        NotImplementedError: If ``get_network`` returns ``None`` for
            ``args.backbone`` or ``args.lr_scheduler`` is neither
            ``'exponential'`` nor ``'cosine_annealing'``.
    """
    args = parse_args(cloud_args)

    # Distributed runs take rank and group size from the communication layer.
    if args.is_distributed:
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()

    if args.is_dynamic_loss_scale == 1:
        # With dynamic loss scaling the Momentum optimizer must not apply a
        # scale of its own.
        args.loss_scale = 1

    # Either only rank 0 saves checkpoints, or every rank does.
    if not args.is_save_on_master or args.rank == 0:
        args.rank_save_ckpt_flag = 1
    else:
        args.rank_save_ckpt_flag = 0

    # Logger writes into a timestamped directory below the checkpoint path.
    run_stamp = datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')
    args.outputs_dir = os.path.join(args.ckpt_path, run_stamp)
    args.logger = get_logger(args.outputs_dir, args.rank)

    # Training data.
    train_set = classification_dataset(args.data_dir, args.image_size,
                                       args.per_batch_size, args.max_epoch,
                                       args.rank, args.group_size)
    train_set.map_model = 4  # !!!important
    args.steps_per_epoch = train_set.get_dataset_size()

    args.logger.save_args(args)

    # Network, flagged for fp16.
    args.logger.important_info('start create network')
    net = get_network(args.backbone, args.num_classes)
    if net is None:
        raise NotImplementedError('not implement {}'.format(args.backbone))
    net.add_flags_recursive(fp16=True)

    # Loss; label smoothing is switched off by zeroing its factor.
    if not args.label_smooth:
        args.label_smooth_factor = 0.0
    loss_fn = CrossEntropy(smooth_factor=args.label_smooth_factor,
                           num_classes=args.num_classes)

    # Optional pretrained weights: optimizer moments are dropped and a leading
    # 'network.' wrapper prefix is stripped from parameter names.
    if os.path.isfile(args.pretrained):
        loaded = load_checkpoint(args.pretrained)
        cleaned = {}
        for name, value in loaded.items():
            if name.startswith('moments.'):
                continue
            target = name[8:] if name.startswith('network.') else name
            cleaned[target] = value
        load_param_into_net(net, cleaned)
        args.logger.info('load model {} success'.format(args.pretrained))

    # Learning-rate schedule.
    if args.lr_scheduler == 'exponential':
        lr_schedule = warmup_step_lr(args.lr,
                                     args.lr_epochs,
                                     args.steps_per_epoch,
                                     args.warmup_epochs,
                                     args.max_epoch,
                                     gamma=args.lr_gamma)
    elif args.lr_scheduler == 'cosine_annealing':
        lr_schedule = warmup_cosine_annealing_lr(args.lr,
                                                 args.steps_per_epoch,
                                                 args.warmup_epochs,
                                                 args.max_epoch,
                                                 args.T_max,
                                                 args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)

    # Optimizer.
    optimizer = Momentum(params=get_param_groups(net),
                         learning_rate=Tensor(lr_schedule),
                         momentum=args.momentum,
                         weight_decay=args.weight_decay,
                         loss_scale=args.loss_scale)

    # The loss is flagged for fp32 while the backbone is flagged for fp16.
    loss_fn.add_flags_recursive(fp32=True)

    # Bundle forward pass and loss into one trainable cell.
    train_net = BuildTrainNetwork(net, loss_fn)
    parallel_mode = (ParallelMode.DATA_PARALLEL
                     if args.is_distributed else ParallelMode.STAND_ALONE)
    if args.is_dynamic_loss_scale != 1:
        scale_manager = FixedLossScaleManager(args.loss_scale,
                                              drop_overflow_update=False)
    else:
        scale_manager = DynamicLossScaleManager(init_loss_scale=65536,
                                                scale_factor=2,
                                                scale_window=2000)

    # Model api changed since TR5_branch 2020/03/09
    context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                      device_num=args.group_size,
                                      parameter_broadcast=True,
                                      mirror_mean=True)
    model = Model(train_net,
                  optimizer=optimizer,
                  metrics=None,
                  loss_scale_manager=scale_manager)

    # Callbacks: progress logging always, checkpointing only when flagged.
    callbacks = [ProgressMonitor(args)]
    if args.rank_save_ckpt_flag:
        # Keep-limit equals the number of checkpoints the run can produce.
        keep_max = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
        ckpt_config = CheckpointConfig(
            save_checkpoint_steps=args.ckpt_interval,
            keep_checkpoint_max=keep_max)
        callbacks.append(
            ModelCheckpoint(config=ckpt_config,
                            directory=args.outputs_dir,
                            prefix='{}'.format(args.rank)))

    model.train(args.max_epoch,
                train_set,
                callbacks=callbacks,
                dataset_sink_mode=True)
def create_network(name, *args, **kwargs):
    """Create the network registered under ``name``.

    Args:
        name: Identifier of the requested network. Only ``"renext50"`` is
            supported.
        *args: Extra positional arguments forwarded to ``get_network``.
        **kwargs: Extra keyword arguments forwarded to ``get_network``.

    Returns:
        Whatever ``get_network`` returns for the requested backbone.

    Raises:
        NotImplementedError: If ``name`` is not a supported identifier.
    """
    # NOTE(review): "renext50" looks like a misspelling of "resnext50", but it
    # is the key callers pass in and the name handed to get_network, so it is
    # kept as-is -- confirm before renaming.
    if name == "renext50":
        # Return the constructed network; the previous code discarded it and
        # returned an undefined name, which raised NameError.
        return get_network("renext50", *args, **kwargs)
    raise NotImplementedError(f"{name} is not implemented in the repo")
示例#5
0
def _rank_result_files(rank, model_tag):
    """Return the (top1, top5, img_tot) ``.npy`` paths used by ``rank`` for one checkpoint."""
    return (
        '/cache/top1_rank_{}_{}.npy'.format(rank, model_tag),
        '/cache/top5_rank_{}_{}.npy'.format(rank, model_tag),
        '/cache/img_tot_rank_{}_{}.npy'.format(rank, model_tag),
    )


def test(cloud_args=None):
    """Evaluate the selected checkpoint(s) and log top-1 / top-5 accuracy.

    ``args.pretrained`` is either used directly as the single checkpoint to
    evaluate, or, when it is a directory, expanded to every ``*.ckpt`` file
    inside it, ordered by the integer embedded in the file name (largest
    first).

    In distributed mode each rank writes its partial counts to ``.npy`` files
    under ``/cache``, waits until every rank's files exist, and then sums them
    locally.

    Args:
        cloud_args: Optional overrides forwarded unchanged to ``parse_args``.

    Raises:
        NotImplementedError: If ``get_network`` returns ``None`` for
            ``args.backbone``.
    """
    args = parse_args(cloud_args)
    context.set_context(mode=context.GRAPH_MODE,
                        enable_auto_mixed_precision=True,
                        device_target=args.platform,
                        save_graphs=False)
    # Only pin the device when DEVICE_ID holds a plain non-negative integer.
    if os.getenv('DEVICE_ID', "not_set").isdigit():
        context.set_context(device_id=int(os.getenv('DEVICE_ID')))

    # init distributed
    if args.is_distributed:
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                          device_num=args.group_size,
                                          parameter_broadcast=True,
                                          mirror_mean=True)
    else:
        args.rank = 0
        args.group_size = 1

    args.outputs_dir = os.path.join(
        args.log_path,
        datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

    args.logger = get_logger(args.outputs_dir, args.rank)
    args.logger.save_args(args)

    # network
    args.logger.important_info('start create network')
    if os.path.isdir(args.pretrained):
        models = list(glob.glob(os.path.join(args.pretrained, '*.ckpt')))
        print(models)
        if args.graph_ckpt:
            def ckpt_order(path):
                # "<prefix>-<number>_<suffix>.ckpt": number after the last '-'.
                stem = os.path.splitext(os.path.split(path)[-1])[0]
                return -1 * int(stem.split('-')[-1].split('_')[0])
        else:
            def ckpt_order(path):
                # "<prefix>_<number>.ckpt": number after the last '_'.
                stem = os.path.splitext(os.path.split(path)[-1])[0]
                return -1 * int(stem.split('_')[-1])
        args.models = sorted(models, key=ckpt_order)
    else:
        args.models = [
            args.pretrained,
        ]

    for model in args.models:
        de_dataset = classification_dataset(args.data_dir,
                                            image_size=args.image_size,
                                            per_batch_size=args.per_batch_size,
                                            max_epoch=1,
                                            rank=args.rank,
                                            group_size=args.group_size,
                                            mode='eval')
        eval_dataloader = de_dataset.create_tuple_iterator()
        network = get_network(args.backbone,
                              args.num_classes,
                              platform=args.platform)
        if network is None:
            raise NotImplementedError('not implement {}'.format(args.backbone))

        # Drop optimizer moments and strip a leading 'network.' wrapper prefix
        # so the parameter names match the bare network.
        param_dict = load_checkpoint(model)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            if key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values

        load_param_into_net(network, param_dict_new)
        args.logger.info('load model {} success'.format(model))

        img_tot = 0
        top1_correct = 0
        top5_correct = 0
        if args.platform == "Ascend":
            network.to_float(mstype.float16)
        else:
            auto_mixed_precision(network)
        network.set_train(False)
        # The clock is restarted after the first batch on rank 0, so that
        # batch is excluded from the throughput figure below.
        t_end = time.time()
        it = 0
        for data, gt_classes in eval_dataloader:
            output = network(Tensor(data, mstype.float32))
            output = output.asnumpy()

            top1_output = np.argmax(output, (-1))
            top5_output = np.argsort(output)[:, -5:]

            t1_correct = np.equal(top1_output, gt_classes).sum()
            top1_correct += t1_correct
            top5_correct += get_top5_acc(top5_output, gt_classes)
            img_tot += args.per_batch_size

            if args.rank == 0 and it == 0:
                t_end = time.time()
                it = 1
        if args.rank == 0:
            time_used = time.time() - t_end
            # One batch is subtracted because the clock restarted after it.
            fps = (img_tot - args.per_batch_size) * args.group_size / time_used
            args.logger.info(
                'Inference Performance: {:.2f} img/sec'.format(fps))
        results = [[top1_correct], [top5_correct], [img_tot]]
        args.logger.info('before results={}'.format(results))
        if args.is_distributed:
            # The checkpoint path with '/' removed tags this checkpoint's files.
            model_md5 = model.replace('/', '')
            tmp_dir = '/cache'
            # Several ranks reach this point concurrently; exist_ok avoids the
            # check-then-create race that could raise FileExistsError.
            os.makedirs(tmp_dir, exist_ok=True)

            # Publish this rank's counts. Each file is written under a
            # temporary name and renamed into place, so a peer that sees the
            # final path never loads a partially written file.
            own_files = _rank_result_files(args.rank, model_md5)
            own_values = (top1_correct, top5_correct, img_tot)
            for final_path, value in zip(own_files, own_values):
                tmp_path = final_path + '.tmp.npy'
                np.save(tmp_path, value)
                os.replace(tmp_path, final_path)

            # Wait for every rank's files, sleeping between polls instead of
            # spinning on the filesystem.
            # NOTE(review): these files are never removed here, so leftovers
            # from an earlier run with the same checkpoint path would satisfy
            # this wait immediately -- confirm /cache is cleaned between runs.
            all_files = [
                _rank_result_files(other_rank, model_md5)
                for other_rank in range(args.group_size)
            ]
            while not all(
                    os.path.exists(path)
                    for rank_files in all_files
                    for path in rank_files):
                time.sleep(0.1)

            # Sum the per-rank counts.
            top1_correct_all = 0
            top5_correct_all = 0
            img_tot_all = 0
            for top1_npy, top5_npy, img_tot_npy in all_files:
                top1_correct_all += np.load(top1_npy)
                top5_correct_all += np.load(top5_npy)
                img_tot_all += np.load(img_tot_npy)
            results = [[top1_correct_all], [top5_correct_all], [img_tot_all]]
        results = np.array(results)

        args.logger.info('after results={}'.format(results))
        top1_correct = results[0, 0]
        top5_correct = results[1, 0]
        img_tot = results[2, 0]
        acc1 = 100.0 * top1_correct / img_tot
        acc5 = 100.0 * top5_correct / img_tot
        args.logger.info('after allreduce eval: top1_correct={}, tot={},'
                         'acc={:.2f}%(TOP1)'.format(top1_correct, img_tot,
                                                    acc1))
        args.logger.info('after allreduce eval: top5_correct={}, tot={},'
                         'acc={:.2f}%(TOP5)'.format(top5_correct, img_tot,
                                                    acc5))
    if args.is_distributed:
        release()
示例#6
0
                    help="output file name.")
# Output format of the exported model file; one of AIR, ONNX or MINDIR.
parser.add_argument("--file_format",
                    type=str,
                    choices=["AIR", "ONNX", "MINDIR"],
                    default="AIR",
                    help="file format")
# Backend the graph is built for.
parser.add_argument("--device_target",
                    type=str,
                    default="Ascend",
                    choices=["Ascend", "GPU", "CPU"],
                    help="device target (default: Ascend)")
# NOTE: arguments are parsed and the context is configured at import time,
# not only when the file is run as a script.
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE,
                    device_target=args.device_target,
                    device_id=args.device_id)

if __name__ == '__main__':
    # Build the network and load the checkpoint weights into it.
    net = get_network(num_classes=config.num_classes,
                      platform=args.device_target)

    param_dict = load_checkpoint(args.ckpt_file)
    load_param_into_net(net, param_dict)
    # Random float32 input of shape [batch_size, 3, height, width] handed to
    # export(); presumably only its shape and dtype matter -- TODO confirm.
    input_shp = [args.batch_size, 3, args.height, args.width]
    input_array = Tensor(
        np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
    export(net,
           input_array,
           file_name=args.file_name,
           file_format=args.file_format)
示例#7
0
    args.image_size = config.image_size
    args.num_classes = config.num_classes
    args.backbone = config.backbone

    args.image_size = list(map(int, config.image_size.split(',')))
    args.image_height = args.image_size[0]
    args.image_width = args.image_size[1]
    args.export_format = config.export_format
    args.export_file = config.export_file
    return args


if __name__ == '__main__':
    # Script entry point: load a checkpoint into the selected backbone and
    # export it to the configured file format.
    args_export = parse_args()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_export.platform)

    net = get_network(args_export.backbone,
                      num_classes=args_export.num_classes,
                      platform=args_export.platform)

    # Load the checkpoint weights into the freshly built network.
    param_dict = load_checkpoint(args_export.pretrained)
    load_param_into_net(net, param_dict)
    # Random float32 input with batch size 1 and 3 channels, handed to
    # export(); presumably only its shape and dtype matter -- TODO confirm.
    input_shp = [1, 3, args_export.image_height, args_export.image_width]
    input_array = Tensor(
        np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
    export(net,
           input_array,
           file_name=args_export.export_file,
           file_format=args_export.export_format)
示例#8
0
def train(cloud_args=None):
    """Train the classification network on Ascend or another backend.

    Sets up the execution context, distributed state, logging, dataset,
    network, learning-rate schedule, optimizer, loss and loss scaling from the
    parsed arguments, then runs ``Model.train`` for ``args.max_epoch`` epochs
    in dataset-sink mode.

    Args:
        cloud_args: Optional overrides forwarded unchanged to ``parse_args``.

    Raises:
        NotImplementedError: If ``get_network`` returns ``None`` for
            ``args.backbone`` or ``args.lr_scheduler`` is neither
            ``'exponential'`` nor ``'cosine_annealing'``.
    """
    args = parse_args(cloud_args)
    context.set_context(mode=context.GRAPH_MODE,
                        enable_auto_mixed_precision=True,
                        device_target=args.platform,
                        save_graphs=False)
    # Only pin the device when DEVICE_ID holds a plain non-negative integer.
    device_id_env = os.getenv('DEVICE_ID', "not_set")
    if device_id_env.isdigit():
        context.set_context(device_id=int(device_id_env))

    # Distributed setup; non-Ascend platforms initialise with "nccl".
    if not args.is_distributed:
        args.rank = 0
        args.group_size = 1
    else:
        if args.platform != "Ascend":
            init("nccl")
        else:
            init()
        args.rank = get_rank()
        args.group_size = get_group_size()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            device_num=args.group_size,
            parameter_broadcast=True,
            mirror_mean=True)

    if args.is_dynamic_loss_scale == 1:
        # With dynamic loss scaling the Momentum optimizer must not apply a
        # scale of its own.
        args.loss_scale = 1

    # Either only rank 0 saves checkpoints, or every rank does.
    if not args.is_save_on_master or args.rank == 0:
        args.rank_save_ckpt_flag = 1
    else:
        args.rank_save_ckpt_flag = 0

    # Logger writes into a timestamped directory below the checkpoint path.
    run_stamp = datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')
    args.outputs_dir = os.path.join(args.ckpt_path, run_stamp)
    args.logger = get_logger(args.outputs_dir, args.rank)

    # Training data; the epoch argument is fixed to 1 here and the number of
    # epochs is given to model.train() instead.
    train_set = classification_dataset(args.data_dir,
                                       args.image_size,
                                       args.per_batch_size,
                                       1,
                                       args.rank,
                                       args.group_size,
                                       num_parallel_workers=8)
    train_set.map_model = 4  # !!!important
    args.steps_per_epoch = train_set.get_dataset_size()

    args.logger.save_args(args)

    # Network construction.
    args.logger.important_info('start create network')
    net = get_network(args.backbone,
                      args.num_classes,
                      platform=args.platform)
    if net is None:
        raise NotImplementedError('not implement {}'.format(args.backbone))

    # Optional pretrained weights: optimizer moments are dropped and a leading
    # 'network.' wrapper prefix is stripped from parameter names.
    if os.path.isfile(args.pretrained):
        loaded = load_checkpoint(args.pretrained)
        cleaned = {}
        for name, value in loaded.items():
            if name.startswith('moments.'):
                continue
            target = name[8:] if name.startswith('network.') else name
            cleaned[target] = value
        load_param_into_net(net, cleaned)
        args.logger.info('load model {} success'.format(args.pretrained))

    # Learning-rate schedule.
    if args.lr_scheduler == 'exponential':
        lr_schedule = warmup_step_lr(args.lr,
                                     args.lr_epochs,
                                     args.steps_per_epoch,
                                     args.warmup_epochs,
                                     args.max_epoch,
                                     gamma=args.lr_gamma)
    elif args.lr_scheduler == 'cosine_annealing':
        lr_schedule = warmup_cosine_annealing_lr(args.lr,
                                                 args.steps_per_epoch,
                                                 args.warmup_epochs,
                                                 args.max_epoch,
                                                 args.T_max,
                                                 args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)

    # Optimizer.
    optimizer = Momentum(params=get_param_groups(net),
                         learning_rate=Tensor(lr_schedule),
                         momentum=args.momentum,
                         weight_decay=args.weight_decay,
                         loss_scale=args.loss_scale)

    # Loss; label smoothing is switched off by zeroing its factor.
    if not args.label_smooth:
        args.label_smooth_factor = 0.0
    criterion = CrossEntropy(smooth_factor=args.label_smooth_factor,
                             num_classes=args.num_classes)

    # Loss scaling: dynamic when requested, otherwise a fixed scale.
    if args.is_dynamic_loss_scale != 1:
        scale_manager = FixedLossScaleManager(args.loss_scale,
                                              drop_overflow_update=False)
    else:
        scale_manager = DynamicLossScaleManager(init_loss_scale=65536,
                                                scale_factor=2,
                                                scale_window=2000)

    # Ascend uses amp_level "O3"; every other platform converts the network
    # with auto_mixed_precision() and leaves amp_level at the Model default.
    model_options = dict(loss_fn=criterion,
                         optimizer=optimizer,
                         loss_scale_manager=scale_manager,
                         metrics={'acc'})
    if args.platform == "Ascend":
        model_options['amp_level'] = "O3"
    else:
        auto_mixed_precision(net)
    model = Model(net, **model_options)

    # Callbacks: progress logging always, checkpointing only when flagged.
    callbacks = [ProgressMonitor(args)]
    if args.rank_save_ckpt_flag:
        save_every = args.ckpt_interval * args.steps_per_epoch
        ckpt_config = CheckpointConfig(save_checkpoint_steps=save_every,
                                       keep_checkpoint_max=args.ckpt_save_max)
        callbacks.append(
            ModelCheckpoint(config=ckpt_config,
                            directory=args.outputs_dir,
                            prefix='{}'.format(args.rank)))

    model.train(args.max_epoch,
                train_set,
                callbacks=callbacks,
                dataset_sink_mode=True)
示例#9
0
    args, _ = parser.parse_known_args()
    args.image_size = config.image_size
    args.num_classes = config.num_classes

    args.image_size = list(map(int, config.image_size.split(',')))
    args.image_height = args.image_size[0]
    args.image_width = args.image_size[1]
    args.export_format = config.export_format
    args.export_file = config.export_file
    return args


if __name__ == '__main__':
    # Script entry point: load a checkpoint into the network and export it to
    # the configured file format.
    args_export = parse_args()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_export.platform)

    net = get_network(num_classes=args_export.num_classes,
                      platform=args_export.platform)

    # Load the checkpoint weights into the freshly built network.
    param_dict = load_checkpoint(args_export.pretrained)
    load_param_into_net(net, param_dict)
    # Random float32 input with batch size 1 and 3 channels, handed to
    # export(); presumably only its shape and dtype matter -- TODO confirm.
    input_shp = [1, 3, args_export.image_height, args_export.image_width]
    input_array = Tensor(
        np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
    export(net,
           input_array,
           file_name=args_export.export_file,
           file_format=args_export.export_format)