예제 #1
0
def modelarts_pre_process():
    '''Prepare the ModelArts environment before training.

    When ``config.need_modelarts_dataset_unzip`` is set, the archive
    ``face_recognition_dataset.zip`` under ``config.data_path`` is extracted
    in place by the devices whose id is a multiple of
    ``min(get_device_num(), 8)``. Every device then polls once per second
    until the lock file ``/tmp/unzip_sync.lock`` exists; the extracting
    device creates that file when it has finished.

    In all cases ``config.ckpt_path`` is finally rewritten to
    ``<config.output_path>/<rank id>/<config.ckpt_path>``.
    '''
    def unzip(zip_file, save_dir):
        '''Extract ``zip_file`` into ``save_dir``, printing progress.

        Nothing is extracted when ``save_dir/face_recognition_dataset``
        already exists or when ``zip_file`` is not a valid zip archive.
        '''
        import zipfile
        s_time = time.time()
        if os.path.exists(os.path.join(save_dir, "face_recognition_dataset")):
            print("Zip has been extracted.")
            return
        if not zipfile.is_zipfile(zip_file):
            print("This is not zip.")
            return

        # The context manager closes the archive even if extraction fails.
        with zipfile.ZipFile(zip_file, 'r') as fz:
            names = fz.namelist()
            data_num = len(names)
            print("Extract Start...")
            print("unzip file num: {}".format(data_num))
            # Report roughly every 1% of the entries. The step is clamped to
            # 1 so archives with fewer than 100 entries do not trigger a
            # modulo-by-zero.
            step = max(1, data_num // 100)
            for i, name in enumerate(names):
                if i % step == 0:
                    print("unzip percent: {}%".format(i * 100 // data_num),
                          flush=True)
                fz.extract(name, save_dir)

        minutes, seconds = divmod(int(time.time() - s_time), 60)
        print("cost time: {}min:{}s.".format(minutes, seconds))
        print("Extract Done.")

    if config.need_modelarts_dataset_unzip:
        zip_file_1 = os.path.join(config.data_path,
                                  "face_recognition_dataset.zip")
        save_dir_1 = os.path.join(config.data_path)

        sync_lock = "/tmp/unzip_sync.lock"

        # Each server contains 8 devices as most.
        if get_device_id() % min(get_device_num(),
                                 8) == 0 and not os.path.exists(sync_lock):
            print("Zip file path: ", zip_file_1)
            print("Unzip file save dir: ", save_dir_1)
            unzip(zip_file_1, save_dir_1)
            print("===Finish extract data synchronization===")
            try:
                os.mknod(sync_lock)
            except IOError as err:
                # Best effort: the lock may already have been created by
                # another process. If it is still missing, say so, because
                # the wait loop below would otherwise block silently.
                if not os.path.exists(sync_lock):
                    print("Create sync lock {} failed: {}".format(
                        sync_lock, err))

        # Block until the extracting device has published the lock file.
        while not os.path.exists(sync_lock):
            time.sleep(1)

        print("Device: {}, Finish sync unzip data from {} to {}.".format(
            get_device_id(), zip_file_1, save_dir_1))

    config.ckpt_path = os.path.join(config.output_path, str(get_rank_id()),
                                    config.ckpt_path)
예제 #2
0
def run_train():
    '''Build the network, optimizer and callbacks from ``config`` and train.

    All settings are read from the module-level ``config`` object, and the
    values derived here are stored back onto it (``local_rank``,
    ``world_size``, ``logger``, ``steps_per_epoch``, ``num_classes``,
    ``lr_epochs``, ``lrs`` and ``epoch_cnt``).

    Raises:
        ValueError: if ``config.train_stage`` is neither 'base' nor 'beta'
            (case-insensitive), or if ``config.data_dir`` does not exist.
    '''
    config.local_rank = get_rank_id()
    config.world_size = get_device_num()
    config.logger = get_logger(os.path.join(config.ckpt_path, 'logs'),
                               config.local_rank)

    # Validate the configuration before touching the device context.
    if config.train_stage.lower() not in ('base', 'beta'):
        config.logger.info('your train stage is not support.')
        raise ValueError('train stage not support.')

    if not os.path.exists(config.data_dir):
        missing_dir_msg = (
            'ERROR, data_dir is not exists, please set data_dir in config.py')
        config.logger.info(missing_dir_msg)
        raise ValueError(missing_dir_msg)

    if config.is_distributed:
        mode = ParallelMode.HYBRID_PARALLEL
    else:
        mode = ParallelMode.STAND_ALONE
    context.set_auto_parallel_context(parallel_mode=mode,
                                      device_num=config.world_size,
                                      gradients_mean=True)
    if config.is_distributed:
        init()

    # Only ranks that are a multiple of 8 create the checkpoint directory.
    if config.local_rank % 8 == 0 and not os.path.exists(config.ckpt_path):
        os.makedirs(config.ckpt_path)

    de_dataset, steps_per_epoch, num_classes = get_de_dataset(config)
    config.logger.info('de_dataset: %d', de_dataset.get_dataset_size())

    config.steps_per_epoch = steps_per_epoch
    config.num_classes = num_classes
    config.lr_epochs = [int(item) for item in config.lr_epochs.split(',')]
    config.logger.info('config.num_classes: %d', config.num_classes)
    config.logger.info('config.world_size: %d', config.world_size)
    config.logger.info('config.local_rank: %d', config.local_rank)
    config.logger.info('config.lr: %f', config.lr)

    # Round the class count up to a multiple of 16 (data parallel) or of
    # world_size * 16 (model parallel) when nc_16 is enabled.
    if config.nc_16 == 1:
        if config.model_parallel == 0:
            multiple = 16
            aligned_msg = 'data parallel aleardy 16, nums: %d'
        else:
            multiple = config.world_size * 16
            aligned_msg = 'model parallel aleardy 16, nums: %d'
        if config.num_classes % multiple == 0:
            config.logger.info(aligned_msg, config.num_classes)
        else:
            config.num_classes = (config.num_classes // multiple +
                                  1) * multiple

    config.logger.info('for D, loaded, class nums: %d', config.num_classes)
    config.logger.info('steps_per_epoch: %d', config.steps_per_epoch)
    config.logger.info('img_total_num: %d',
                       config.steps_per_epoch * config.per_batch_size)

    # Assemble backbone + metric head, logging entry/exit of each stage.
    config.logger.info('get_backbone----in----')
    backbone = get_backbone(config)
    config.logger.info('get_backbone----out----')
    config.logger.info('get_metric_fc----in----')
    metric_head = get_metric_fc(config)
    config.logger.info('get_metric_fc----out----')
    config.logger.info('DistributedHelper----in----')
    net = DistributedHelper(backbone, metric_head)
    config.logger.info('DistributedHelper----out----')
    config.logger.info('network fp16----in----')
    use_fp16 = config.fp16 == 1
    if use_fp16:
        net.add_flags_recursive(fp16=True)
    config.logger.info('network fp16----out----')

    loss_fn = get_loss(config)
    if use_fp16 and config.model_parallel == 0:
        loss_fn.add_flags_recursive(fp32=True)

    net = load_pretrain(config, net)
    train_net = BuildTrainNetwork(net, loss_fn, config)

    # call warmup_step should behind the config steps_per_epoch
    config.lrs = warmup_step_list(config, gamma=0.1)
    optimizer = Momentum(params=train_net.trainable_params(),
                         learning_rate=list_to_gen(config.lrs),
                         momentum=config.momentum,
                         weight_decay=config.weight_decay)
    loss_scaler = DynamicLossScaleManager(
        init_loss_scale=config.dynamic_init_loss_scale,
        scale_factor=2,
        scale_window=2000)
    model = Model(train_net,
                  optimizer=optimizer,
                  metrics=None,
                  loss_scale_manager=loss_scaler)

    ckpt_interval = config.ckpt_steps
    config.logger.info('save_checkpoint_steps: %d', ckpt_interval)
    if config.max_ckpts != -1:
        max_kept = config.max_ckpts
    else:
        # -1 means "derive it": every checkpoint of the run plus 5 spare.
        max_kept = int(config.steps_per_epoch * config.max_epoch /
                       ckpt_interval) + 5
    config.logger.info('keep_checkpoint_max: %d', max_kept)

    ckpt_settings = CheckpointConfig(save_checkpoint_steps=ckpt_interval,
                                     keep_checkpoint_max=max_kept)
    config.logger.info('max_epoch_train: %d', config.max_epoch)
    checkpoint_cb = ModelCheckpoint(config=ckpt_settings,
                                    directory=config.ckpt_path,
                                    prefix='{}'.format(config.local_rank))
    config.epoch_cnt = 0
    monitor_cb = ProgressMonitor(config)

    # The value passed as the epoch count is the total step count divided
    # by log_interval, and sink_size is log_interval.
    sink_rounds = config.max_epoch * steps_per_epoch // config.log_interval
    model.train(sink_rounds,
                de_dataset,
                callbacks=[monitor_cb, checkpoint_cb],
                sink_size=config.log_interval)
예제 #3
0
from src.dataset import create_dataset
from src.lenet import LeNet5
import mindspore.nn as nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.common import set_seed

# Fix the global random seed.
set_seed(1)

# Choose the checkpoint directory for this rank. Without a local dataset
# copy the checkpoints go under config.output_path; with one, the local
# dataset path is used and the checkpoint path stays relative.
if not os.path.exists(config.data_path_local):
    config.checkpoint_path = os.path.join(config.output_path,
                                          config.checkpoint_path,
                                          str(get_rank_id()))
else:
    config.data_path = config.data_path_local
    config.checkpoint_path = os.path.join(config.checkpoint_path,
                                          str(get_rank_id()))


def modelarts_pre_process():
    """Do nothing; placeholder pre-process hook."""


@moxing_wrapper(pre_process=modelarts_pre_process)
def train_lenet():
    """Configure the MindSpore context for graph mode on the configured device."""
    target = config.device_target
    context.set_context(mode=context.GRAPH_MODE, device_target=target)
예제 #4
0
def modelarts_pre_process():
    """Rewrite ``config.checkpoint_path`` to a per-rank output directory.

    The new value is
    ``<config.output_path>/<rank id>/<previous config.checkpoint_path>``.
    """
    path_parts = [
        config.output_path,
        str(get_rank_id()),
        config.checkpoint_path,
    ]
    config.checkpoint_path = os.path.join(*path_parts)
예제 #5
0
def train_alexnet():
    """Train AlexNet on the dataset selected by ``config.dataset_name``.

    Supported datasets are "cifar10" and "imagenet"; supported device
    targets are "Ascend" and "GPU". With more than one device the run uses
    data-parallel mode and each rank writes checkpoints to
    ``config.checkpoint_path + "_" + <rank>``.

    Raises:
        ValueError: for an unsupported dataset name, an empty training
            dataset, or an unsupported device target.
    """
    print(config)
    for label, getter in (('device id:', get_device_id),
                          ('device num:', get_device_num),
                          ('rank id:', get_rank_id),
                          ('job id:', get_job_id)):
        print(label, getter())

    platform = config.device_target
    context.set_context(mode=context.GRAPH_MODE, device_target=platform)
    context.set_context(save_graphs=False)

    n_devices = get_device_num()
    multi_device = n_devices > 1
    dataset_name = config.dataset_name

    # Reject unknown datasets early; on multi-device cifar10 runs the
    # learning rate is scaled by the device count and the epochs doubled.
    if dataset_name not in ("cifar10", "imagenet"):
        raise ValueError("Unsupported dataset.")
    if dataset_name == "cifar10" and multi_device:
        config.learning_rate = config.learning_rate * n_devices
        config.epoch_size = config.epoch_size * 2

    if multi_device:
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            device_num=n_devices,
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True)
        if platform == "Ascend":
            context.set_context(device_id=get_device_id())
        if platform in ("Ascend", "GPU"):
            init()
    else:
        context.set_context(device_id=get_device_id())

    if dataset_name == "cifar10":
        train_set = create_dataset_cifar10(config.data_path,
                                           config.batch_size,
                                           target=config.device_target)
    elif dataset_name == "imagenet":
        train_set = create_dataset_imagenet(config.data_path,
                                            config.batch_size)
    else:
        raise ValueError("Unsupported dataset.")

    if train_set.get_dataset_size() == 0:
        raise ValueError(
            "Please check dataset size > 0 and batch_size <= dataset size")

    network = AlexNet(config.num_classes, phase='train')

    scale_manager = None
    metrics = None
    # A sink_size of -1 means "one full pass over the dataset per epoch".
    if config.sink_size == -1:
        steps_per_epoch = train_set.get_dataset_size()
    else:
        steps_per_epoch = config.sink_size

    if dataset_name == 'cifar10':
        loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True,
                                                   reduction="mean")
        schedule = get_lr_cifar10(0, config.learning_rate, config.epoch_size,
                                  steps_per_epoch)
        optimizer = nn.Momentum(network.trainable_params(), Tensor(schedule),
                                config.momentum)
        metrics = {"Accuracy": Accuracy()}

    elif dataset_name == 'imagenet':
        loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True,
                                                   reduction="mean")
        schedule = get_lr_imagenet(config.learning_rate, config.epoch_size,
                                   steps_per_epoch)
        optimizer = nn.Momentum(params=get_param_groups(network),
                                learning_rate=Tensor(schedule),
                                momentum=config.momentum,
                                weight_decay=config.weight_decay,
                                loss_scale=config.loss_scale)

        from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
        if config.is_dynamic_loss_scale == 1:
            scale_manager = DynamicLossScaleManager(init_loss_scale=65536,
                                                    scale_factor=2,
                                                    scale_window=2000)
        else:
            scale_manager = FixedLossScaleManager(config.loss_scale,
                                                  drop_overflow_update=False)

    else:
        raise ValueError("Unsupported dataset.")

    # Arguments shared by both platforms; Ascend additionally enables O2
    # mixed precision with batchnorm not kept in fp32.
    shared_model_args = dict(loss_fn=loss_fn,
                             optimizer=optimizer,
                             metrics=metrics,
                             loss_scale_manager=scale_manager)
    if platform == "Ascend":
        model = Model(network,
                      amp_level="O2",
                      keep_batchnorm_fp32=False,
                      **shared_model_args)
    elif platform == "GPU":
        model = Model(network, **shared_model_args)
    else:
        raise ValueError("Unsupported platform.")

    if multi_device:
        ckpt_dir = os.path.join(config.checkpoint_path + "_" +
                                str(get_rank()))
    else:
        ckpt_dir = config.checkpoint_path

    timer_cb = TimeMonitor(data_size=steps_per_epoch)
    ckpt_settings = CheckpointConfig(
        save_checkpoint_steps=config.save_checkpoint_steps,
        keep_checkpoint_max=config.keep_checkpoint_max)
    checkpoint_cb = ModelCheckpoint(prefix="checkpoint_alexnet",
                                    directory=ckpt_dir,
                                    config=ckpt_settings)

    print("============== Starting Training ==============")
    model.train(config.epoch_size,
                train_set,
                callbacks=[timer_cb, checkpoint_cb, LossMonitor()],
                dataset_sink_mode=config.dataset_sink_mode,
                sink_size=config.sink_size)