示例#1
0
def main():
    """Train the ``efficientnet_b0`` network on GPU, optionally data-parallel.

    Command-line options come from the module-level ``parser`` and
    hyper-parameters from ``cfg``. The function builds the network, dataset,
    loss, optimizer and callbacks, optionally restores weights from the
    checkpoint named by ``args.resume``, and then calls ``Model.train``.

    Raises:
        ValueError: If GPU training was not requested (``args.GPU`` is
            falsy), or if ``cfg.opt`` is neither ``'sgd'`` nor ``'rmsprop'``.
    """
    args, _ = parser.parse_known_args()
    rank_id, rank_size = 0, 1

    context.set_context(mode=context.GRAPH_MODE)

    # GPU is the only supported target for both the single-device and the
    # distributed path, so reject anything else once, up front.
    if not args.GPU:
        raise ValueError("Only supported GPU training.")

    if args.distributed:
        init("nccl")
        context.set_context(device_target='GPU')
        context.reset_auto_parallel_context()
        rank_id = get_rank()
        rank_size = get_group_size()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            device_num=rank_size)
    else:
        context.set_context(device_target='GPU')

    net = efficientnet_b0(
        num_classes=cfg.num_classes,
        drop_rate=cfg.drop,
        drop_connect_rate=cfg.drop_connect,
        global_pool=cfg.gp,
        bn_tf=cfg.bn_tf,
    )

    train_data_url = args.data_path
    train_dataset = create_dataset(cfg.batch_size,
                                   train_data_url,
                                   workers=cfg.workers,
                                   distributed=args.distributed)
    batches_per_epoch = train_dataset.get_dataset_size()

    # Both monitors are configured with the number of batches in one epoch.
    loss_cb = LossMonitor(per_print_times=batches_per_epoch)
    loss = LabelSmoothingCrossEntropy(smooth_factor=cfg.smoothing)
    time_cb = TimeMonitor(data_size=batches_per_epoch)
    loss_scale_manager = FixedLossScaleManager(cfg.loss_scale,
                                               drop_overflow_update=False)

    # NOTE: callbacks are attached on every rank; the checkpoint callback
    # writes to a per-rank directory so ranks do not share a path.
    callbacks = [time_cb, loss_cb]

    if cfg.save_checkpoint:
        config_ck = CheckpointConfig(
            save_checkpoint_steps=batches_per_epoch,
            keep_checkpoint_max=cfg.keep_checkpoint_max)
        ckpoint_cb = ModelCheckpoint(prefix=cfg.model,
                                     directory=f"./ckpt_{rank_id}/",
                                     config=config_ck)
        callbacks.append(ckpoint_cb)

    lr = Tensor(
        get_lr(base_lr=cfg.lr,
               total_epochs=cfg.epochs,
               steps_per_epoch=batches_per_epoch,
               decay_steps=cfg.decay_epochs,
               decay_rate=cfg.decay_rate,
               warmup_steps=cfg.warmup_epochs,
               warmup_lr_init=cfg.warmup_lr_init,
               global_epoch=cfg.resume_start_epoch))

    if cfg.opt == 'sgd':
        optimizer = SGD(net.trainable_params(),
                        learning_rate=lr,
                        momentum=cfg.momentum,
                        weight_decay=cfg.weight_decay,
                        loss_scale=cfg.loss_scale)
    elif cfg.opt == 'rmsprop':
        optimizer = RMSProp(net.trainable_params(),
                            learning_rate=lr,
                            decay=0.9,
                            weight_decay=cfg.weight_decay,
                            momentum=cfg.momentum,
                            epsilon=cfg.opt_eps,
                            loss_scale=cfg.loss_scale)
    else:
        # Without this branch `optimizer` would stay unbound and the failure
        # would only surface later as an UnboundLocalError at Model(...).
        raise ValueError(
            f"Unsupported optimizer {cfg.opt!r}; expected 'sgd' or 'rmsprop'.")

    loss.add_flags_recursive(fp32=True, fp16=False)

    if args.resume:
        ckpt = load_checkpoint(args.resume)
        load_param_into_net(net, ckpt)

    model = Model(net,
                  loss,
                  optimizer,
                  loss_scale_manager=loss_scale_manager,
                  amp_level=cfg.amp_level)

    # When resuming, only the epochs remaining after `resume_start_epoch`
    # are trained; otherwise the full schedule runs.
    if args.resume:
        epochs_to_run = cfg.epochs - cfg.resume_start_epoch
    else:
        epochs_to_run = cfg.epochs

    model.train(epochs_to_run,
                train_dataset,
                callbacks=callbacks,
                dataset_sink_mode=True)
示例#2
0
def efficinetnet(*args, **kwargs):
    """Forward every argument to ``efficientnet_b0`` and return its result.

    This is a thin alias: positional and keyword arguments are passed
    through unchanged, and nothing is done to the returned object.
    """
    network = efficientnet_b0(*args, **kwargs)
    return network
示例#3
0
def create_network(name, *args, **kwargs):
    """Create the network identified by ``name``.

    The only recognised name is ``"efficinetnet"``, for which the remaining
    positional and keyword arguments are forwarded to ``efficientnet_b0``
    and its result is returned.

    Raises:
        NotImplementedError: If ``name`` is anything else.
    """
    # Guard clause: reject unknown names before touching the factory.
    if name != "efficinetnet":
        raise NotImplementedError(f"{name} is not implemented in the repo")
    return efficientnet_b0(*args, **kwargs)
示例#4
0
# Register the target-device option. The parser accepts three back ends and
# defaults to GPU.
# NOTE(review): `parser` is created outside this excerpt -- confirm it is
# defined above, along with the other options read below (`device_id`, ...).
parser.add_argument("--device_target",
                    type=str,
                    choices=["Ascend", "GPU", "CPU"],
                    default="GPU",
                    help="device target")
args = parser.parse_args()

# This runs at module level (outside the `__main__` guard): select graph
# mode and bind the device target and device id taken from the command line.
context.set_context(mode=context.GRAPH_MODE,
                    device_target=args.device_target,
                    device_id=args.device_id)

# Script entry point: rebuild the network, load trained weights, and export
# the model to a file.
if __name__ == "__main__":
    # The parser accepts "Ascend" and "CPU", but only "GPU" is allowed past
    # this point; any other value aborts the export here.
    if args.device_target != "GPU":
        raise ValueError("Only supported GPU now.")

    # Construct the network from the hyper-parameters held in `cfg`.
    net = efficientnet_b0(
        num_classes=cfg.num_classes,
        drop_rate=cfg.drop,
        drop_connect_rate=cfg.drop_connect,
        global_pool=cfg.gp,
        bn_tf=cfg.bn_tf,
    )

    # Load the checkpoint parameters into the network, then turn off
    # training mode before exporting.
    ckpt = load_checkpoint(args.ckpt_file)
    load_param_into_net(net, ckpt)
    net.set_train(False)

    # All-ones float32 input of shape [batch_size, 3, height, width], passed
    # to `export` together with the network.
    image = Tensor(
        np.ones([cfg.batch_size, 3, args.height, args.width], np.float32))
    export(net, image, file_name=args.file_name, file_format=args.file_format)