Example #1
    loss_scale_manager = None
    metrics = None
    # One epoch is the full dataset unless a fixed sink size is configured.
    step_per_epoch = ds_train.get_dataset_size() if args.sink_size == -1 else args.sink_size
    if args.dataset_name == 'cifar10':
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
        lr = Tensor(
            get_lr_cifar10(0, cfg.learning_rate, cfg.epoch_size,
                           step_per_epoch))
        opt = nn.Momentum(network.trainable_params(), lr, cfg.momentum)
        metrics = {"Accuracy": Accuracy()}

    elif args.dataset_name == 'imagenet':
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
        lr = Tensor(get_lr_imagenet(cfg, step_per_epoch))
        opt = nn.Momentum(params=get_param_groups(network),
                          learning_rate=lr,
                          momentum=cfg.momentum,
                          weight_decay=cfg.weight_decay,
                          loss_scale=cfg.loss_scale)

        from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
        if cfg.is_dynamic_loss_scale == 1:
            loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536,
                                                         scale_factor=2,
                                                         scale_window=2000)
        else:
            loss_scale_manager = FixedLossScaleManager(
                cfg.loss_scale, drop_overflow_update=False)
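The schedule helpers get_lr_cifar10 and get_lr_imagenet come from elsewhere in the repository and are not shown on this page. A minimal sketch of what get_lr_cifar10 might look like, assuming a simple step-decay policy that emits one learning rate per training step for Tensor() to wrap (the decay boundaries and factors are hypothetical); get_lr_imagenet is analogous, but its signature differs between the excerpts here, so it is not sketched:

import numpy as np

def get_lr_cifar10(global_step, lr_init, total_epochs, steps_per_epoch):
    """Per-step LR array for nn.Momentum; hypothetical 10x drops at 50%/75%."""
    total_steps = total_epochs * steps_per_epoch
    lr_each_step = []
    for i in range(total_steps):
        if i < total_steps * 0.5:
            lr = lr_init
        elif i < total_steps * 0.75:
            lr = lr_init * 0.1
        else:
            lr = lr_init * 0.01
        lr_each_step.append(lr)
    return np.array(lr_each_step[global_step:], dtype=np.float32)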
Example #2
    network = AlexNet(cfg.num_classes)

    loss_scale_manager = None
    metrics = None
    if args.dataset_name == 'cifar10':
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
        lr = Tensor(
            get_lr_cifar10(0, cfg.learning_rate, cfg.epoch_size,
                           ds_train.get_dataset_size()))
        opt = nn.Momentum(network.trainable_params(), lr, cfg.momentum)
        metrics = {"Accuracy": Accuracy()}

    elif args.dataset_name == 'imagenet':
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")

        lr = Tensor(get_lr_imagenet(cfg, ds_train.get_dataset_size()))

        opt = nn.Momentum(params=get_param_groups(network),
                          learning_rate=lr,
                          momentum=cfg.momentum,
                          weight_decay=cfg.weight_decay,
                          loss_scale=cfg.loss_scale)

        from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
        if cfg.is_dynamic_loss_scale == 1:
            loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536,
                                                         scale_factor=2,
                                                         scale_window=2000)
        else:
            loss_scale_manager = FixedLossScaleManager(
                cfg.loss_scale, drop_overflow_update=False)
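get_param_groups is likewise defined elsewhere. In MindSpore model-zoo scripts this helper conventionally splits the trainable parameters so that weight decay skips biases and BatchNorm gamma/beta; a sketch under that assumption (whether this repository's version does exactly this is not confirmed by the excerpt):

def get_param_groups(network):
    """Group params so nn.Momentum applies weight decay only where intended."""
    decay_params, no_decay_params = [], []
    for param in network.trainable_params():
        if param.name.endswith(('.bias', '.gamma', '.beta')):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    return [{'params': no_decay_params, 'weight_decay': 0.0},
            {'params': decay_params}]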
Example #3
    loss_scale_manager = None
    metrics = None
    step_per_epoch = ds_train.get_dataset_size() if args.sink_size == -1 else args.sink_size
    if args.dataset_name == 'cifar10':
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
        lr = Tensor(
            get_lr_cifar10(0, cfg.learning_rate, cfg.epoch_size,
                           step_per_epoch))
        opt = nn.Momentum(network.trainable_params(), lr, cfg.momentum)
        metrics = {"Accuracy": Accuracy()}

    elif args.dataset_name == 'imagenet':
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
        lr = Tensor(
            get_lr_imagenet(cfg.learning_rate, cfg.epoch_size, step_per_epoch))
        opt = nn.Momentum(params=get_param_groups(network),
                          learning_rate=lr,
                          momentum=cfg.momentum,
                          weight_decay=cfg.weight_decay,
                          loss_scale=cfg.loss_scale)

        from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
        if cfg.is_dynamic_loss_scale == 1:
            loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536,
                                                         scale_factor=2,
                                                         scale_window=2000)
        else:
            loss_scale_manager = FixedLossScaleManager(
                cfg.loss_scale, drop_overflow_update=False)
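All three excerpts pick a loss-scale manager the same way: DynamicLossScaleManager starts at 65536, shrinks the scale when gradients overflow, and grows it again after scale_window (2000) consecutive overflow-free steps, while FixedLossScaleManager keeps cfg.loss_scale constant (drop_overflow_update=False applies every update even on overflow). A standalone illustration of the dynamic policy in plain Python, not MindSpore code (the floor of 1 is an assumption):

def update_loss_scale(scale, overflow, good_steps, scale_factor=2, scale_window=2000):
    """One step of the dynamic policy: halve on overflow, double after a clean window."""
    if overflow:
        return max(scale / scale_factor, 1), 0
    good_steps += 1
    if good_steps >= scale_window:
        return scale * scale_factor, 0
    return scale, good_steps

scale, good_steps = 65536, 0
for overflow in [False] * 2000 + [True]:
    scale, good_steps = update_loss_scale(scale, overflow, good_steps)
print(scale)  # 65536.0: doubled to 131072 after 2000 clean steps, halved on the overflow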
Example #4
def train_alexnet():
    print(config)
    print('device id:', get_device_id())
    print('device num:', get_device_num())
    print('rank id:', get_rank_id())
    print('job id:', get_job_id())

    device_target = config.device_target
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=device_target,
                        save_graphs=False)

    device_num = get_device_num()
    if config.dataset_name == "cifar10":
        if device_num > 1:
            config.learning_rate = config.learning_rate * device_num
            config.epoch_size = config.epoch_size * 2
    elif config.dataset_name == "imagenet":
        pass
    else:
        raise ValueError("Unsupported dataset.")

    if device_num > 1:
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
        if device_target == "Ascend":
            context.set_context(device_id=get_device_id())
            init()
        elif device_target == "GPU":
            init()
    else:
        context.set_context(device_id=get_device_id())

    if config.dataset_name == "cifar10":
        ds_train = create_dataset_cifar10(config.data_path,
                                          config.batch_size,
                                          target=config.device_target)
    elif config.dataset_name == "imagenet":
        ds_train = create_dataset_imagenet(config.data_path, config.batch_size)
    else:
        raise ValueError("Unsupported dataset.")

    if ds_train.get_dataset_size() == 0:
        raise ValueError(
            "Empty dataset: check that dataset size > 0 and batch_size <= dataset size")

    network = AlexNet(config.num_classes, phase='train')

    loss_scale_manager = None
    metrics = None
    step_per_epoch = ds_train.get_dataset_size() if config.sink_size == -1 else config.sink_size
    if config.dataset_name == 'cifar10':
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
        lr = Tensor(
            get_lr_cifar10(0, config.learning_rate, config.epoch_size,
                           step_per_epoch))
        opt = nn.Momentum(network.trainable_params(), lr, config.momentum)
        metrics = {"Accuracy": Accuracy()}

    elif config.dataset_name == 'imagenet':
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
        lr = Tensor(
            get_lr_imagenet(config.learning_rate, config.epoch_size,
                            step_per_epoch))
        opt = nn.Momentum(params=get_param_groups(network),
                          learning_rate=lr,
                          momentum=config.momentum,
                          weight_decay=config.weight_decay,
                          loss_scale=config.loss_scale)

        from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
        if config.is_dynamic_loss_scale == 1:
            loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536,
                                                         scale_factor=2,
                                                         scale_window=2000)
        else:
            loss_scale_manager = FixedLossScaleManager(
                config.loss_scale, drop_overflow_update=False)

    else:
        raise ValueError("Unsupported dataset.")

    if device_target == "Ascend":
        model = Model(network,
                      loss_fn=loss,
                      optimizer=opt,
                      metrics=metrics,
                      amp_level="O2",
                      keep_batchnorm_fp32=False,
                      loss_scale_manager=loss_scale_manager)
    elif device_target == "GPU":
        model = Model(network,
                      loss_fn=loss,
                      optimizer=opt,
                      metrics=metrics,
                      loss_scale_manager=loss_scale_manager)
    else:
        raise ValueError("Unsupported platform.")

    if device_num > 1:
        ckpt_save_dir = config.checkpoint_path + "_" + str(get_rank())
    else:
        ckpt_save_dir = config.checkpoint_path

    time_cb = TimeMonitor(data_size=step_per_epoch)
    config_ck = CheckpointConfig(
        save_checkpoint_steps=config.save_checkpoint_steps,
        keep_checkpoint_max=config.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_alexnet",
                                 directory=ckpt_save_dir,
                                 config=config_ck)

    print("============== Starting Training ==============")
    model.train(config.epoch_size,
                ds_train,
                callbacks=[time_cb, ckpoint_cb,
                           LossMonitor()],
                dataset_sink_mode=config.dataset_sink_mode,
                sink_size=config.sink_size)
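Example #4 is the only self-contained function on this page. A minimal entry point for it, assuming config is the module-level configuration object the function already reads; the fields in the comment are exactly those referenced in the body above:

# Fields train_alexnet() reads from config:
#   device_target, dataset_name, data_path, batch_size, num_classes,
#   learning_rate, epoch_size, momentum, weight_decay, loss_scale,
#   is_dynamic_loss_scale, sink_size, checkpoint_path,
#   save_checkpoint_steps, keep_checkpoint_max, dataset_sink_mode
if __name__ == "__main__":
    train_alexnet()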