Example #1
            load_param_into_net(net, param_dict)

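        # loss function, dynamic learning-rate schedule and Momentum optimizer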
        loss = LossNet()
        lr = Tensor(dynamic_lr(config, rank_size=device_num, start_steps=config.pretrain_epoch_size * dataset_size),
                    mstype.float32)
        opt = Momentum(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,
                       weight_decay=config.weight_decay, loss_scale=config.loss_scale)

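        # wrap the backbone and loss in a one-step training cell (distributed gradient reduction when run_distribute is set)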
        net_with_loss = WithLossCell(net, loss)
        if args_opt.run_distribute:
            net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale, reduce_flag=True,
                                   mean=True, degree=device_num)
        else:
            net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale)

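        # callbacks: step timing, loss logging and optional periodic checkpointing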
        time_cb = TimeMonitor(data_size=dataset_size)
        loss_cb = LossCallBack(rank_id=rank)
        cb = [time_cb, loss_cb]
        if config.save_checkpoint:
            ckptconfig = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * dataset_size,
                                          keep_checkpoint_max=config.keep_checkpoint_max)
            save_checkpoint_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
            ckpoint_cb = ModelCheckpoint(prefix='mask_rcnn', directory=save_checkpoint_path, config=ckptconfig)
            cb += [ckpoint_cb]

        model = Model(net)
        start_time = time.time()
        model.train(config.epoch_size, dataset, callbacks=cb)
        end_time = time.time()
        print("total cost time :",end_time-start_time)
Example #2
                              cfg.batch_size, cfg.epoch_size)
    step_size = ds_train.get_dataset_size()

    # define fusion network
    network = LeNet5Fusion(cfg.num_classes)
    # define network loss
    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False,
                                                sparse=True,
                                                reduction="mean")
    # define network optimization
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)

    # call back and monitor
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size *
                                   step_size,
                                   keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet",
                                    directory=local_train_url,
                                    config=config_ckpt)

    # define model
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    print("============== Starting Training ==============")
    model.train(cfg['epoch_size'],
                ds_train,
                callbacks=[time_cb, ckpt_callback,
                           LossMonitor()],
                dataset_sink_mode=args.dataset_sink_mode)
    print("============== End Training ==============")
Example #3
                                                                cell.weight.default_input.shape(),
                                                                cell.weight.default_input.dtype())
        if isinstance(cell, nn.Dense):
            cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(),
                                                                cell.weight.default_input.shape(),
                                                                cell.weight.default_input.dtype())
    if not config.label_smooth:
        config.label_smooth_factor = 0.0
    loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
    if args_opt.do_train:
        dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True,
                                 repeat_num=epoch_size, batch_size=config.batch_size)
        step_size = dataset.get_dataset_size()
        loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)

        # learning rate strategy with cosine
        lr = Tensor(warmup_cosine_annealing_lr(config.lr, step_size, config.warmup_epochs, config.epoch_size))
        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
                       config.weight_decay, config.loss_scale)
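        # O2 mixed precision with a fixed loss-scale manager (batchnorm not kept in fp32)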
        model = Model(net, loss_fn=loss, optimizer=opt, amp_level='O2', keep_batchnorm_fp32=False,
                      loss_scale_manager=loss_scale, metrics={'acc'})
        time_cb = TimeMonitor(data_size=step_size)
        loss_cb = LossMonitor()
        cb = [time_cb, loss_cb]
        if config.save_checkpoint:
            config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs*step_size,
                                         keep_checkpoint_max=config.keep_checkpoint_max)
            ckpt_cb = ModelCheckpoint(prefix="resnet", directory=config.save_checkpoint_path, config=config_ck)
            cb += [ckpt_cb]
        model.train(epoch_size, dataset, callbacks=cb)
Example #4
def run_pretrain():
    """pre-train bert_clue"""
    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument(
        '--device_target',
        type=str,
        default='Ascend',
        choices=['Ascend', 'GPU'],
        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute",
                        type=str,
                        default="false",
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size",
                        type=int,
                        default="1",
                        help="Epoch size, default is 1.")
    parser.add_argument("--device_id",
                        type=int,
                        default=0,
                        help="Device id, default is 0.")
    parser.add_argument("--device_num",
                        type=int,
                        default=1,
                        help="Use device nums, default is 1.")
    parser.add_argument("--enable_save_ckpt",
                        type=str,
                        default="true",
                        help="Enable save checkpoint, default is true.")
    parser.add_argument("--enable_lossscale",
                        type=str,
                        default="true",
                        help="Use lossscale or not, default is not.")
    parser.add_argument("--do_shuffle",
                        type=str,
                        default="true",
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink",
                        type=str,
                        default="true",
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps",
                        type=int,
                        default="1",
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--save_checkpoint_path",
                        type=str,
                        default="",
                        help="Save checkpoint path")
    parser.add_argument("--load_checkpoint_path",
                        type=str,
                        default="",
                        help="Load checkpoint file path")
    parser.add_argument("--save_checkpoint_steps",
                        type=int,
                        default=1000,
                        help="Save checkpoint steps, "
                        "default is 1000.")
    parser.add_argument("--train_steps",
                        type=int,
                        default=-1,
                        help="Training Steps, default is -1, "
                        "meaning run all steps according to epoch number.")
    parser.add_argument("--save_checkpoint_num",
                        type=int,
                        default=1,
                        help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir",
                        type=str,
                        default="",
                        help="Schema path, it is better to use absolute path")

    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(variable_memory_max_size="30GB")
    ckpt_save_dir = args_opt.save_checkpoint_path
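    # distributed setup: HCCL on Ascend / NCCL on GPU, data-parallel context with mirrored-mean gradients and all-reduce fusion split points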
    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init('hccl')
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init('nccl')
            device_num = D.get_group_size()
            rank = D.get_rank()
            ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(
                rank) + '/'

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            mirror_mean=True,
            device_num=device_num)
        from mindspore.parallel._auto_parallel_context import auto_parallel_context
        if bert_net_cfg.num_hidden_layers == 12:
            if bert_net_cfg.use_relative_positions:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [29, 58, 87, 116, 145, 174, 203, 217])
            else:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [28, 55, 82, 109, 136, 163, 190, 205])
        elif bert_net_cfg.num_hidden_layers == 24:
            if bert_net_cfg.use_relative_positions:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [30, 90, 150, 210, 270, 330, 390, 421])
            else:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [38, 93, 148, 203, 258, 313, 368, 397])
    else:
        rank = 0
        device_num = 1

    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
        logger.warning('GPU only supports fp32 for now, running with fp32.')
        bert_net_cfg.compute_type = mstype.float32

    ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num,
                                               rank, args_opt.do_shuffle,
                                               args_opt.enable_data_sink,
                                               args_opt.data_sink_steps,
                                               args_opt.data_dir,
                                               args_opt.schema_dir)
    if args_opt.train_steps > 0:
        new_repeat_count = min(
            new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)

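    # pick the optimizer from the config: Lamb, Momentum or AdamWeightDecayDynamicLR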
    if cfg.optimizer == 'Lamb':
        optimizer = Lamb(netwithloss.trainable_params(),
                         decay_steps=ds.get_dataset_size() * new_repeat_count,
                         start_learning_rate=cfg.Lamb.start_learning_rate,
                         end_learning_rate=cfg.Lamb.end_learning_rate,
                         power=cfg.Lamb.power,
                         warmup_steps=cfg.Lamb.warmup_steps,
                         weight_decay=cfg.Lamb.weight_decay,
                         eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(netwithloss.trainable_params(),
                             learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(
            netwithloss.trainable_params(),
            decay_steps=ds.get_dataset_size() * new_repeat_count,
            learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
            end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
            power=cfg.AdamWeightDecayDynamicLR.power,
            weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
            eps=cfg.AdamWeightDecayDynamicLR.eps,
            warmup_steps=cfg.AdamWeightDecayDynamicLR.warmup_steps)
    else:
        raise ValueError(
            "Unsupported optimizer {}, only [Lamb, Momentum, AdamWeightDecayDynamicLR] are supported"
            .format(cfg.optimizer))
    callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()]
    if args_opt.enable_save_ckpt == "true":
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args_opt.save_checkpoint_steps,
            keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert',
                                     directory=ckpt_save_dir,
                                     config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(netwithloss, param_dict)

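    # optionally wrap the network with dynamic loss scaling, otherwise use the plain one-step cell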
    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)
        netwithgrads = BertTrainOneStepWithLossScaleCell(
            netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)

    model = Model(netwithgrads)
    model.train(new_repeat_count,
                ds,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"))
Example #5
    # the args flags decide which functions to run (training and/or evaluation)
    if not args_opt.do_eval and args_opt.run_distribute:
        context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL)
        auto_parallel_context().set_all_reduce_fusion_split_indices([140])
        init()

    epoch_size = args_opt.epoch_size
    net = resnet50(args_opt.batch_size, args_opt.num_classes)
    ls = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)

    model = Model(net, loss_fn=ls, optimizer=opt, metrics={'acc'})

    # as for train, users could use model.train
    if args_opt.do_train:
        dataset = create_dataset()
        batch_num = dataset.get_dataset_size()
        config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=35)
        ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10", directory="./", config=config_ck)
        loss_cb = LossMonitor()
        model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])

    # as for evaluation, users could use model.eval
    if args_opt.do_eval:
        if args_opt.checkpoint_path:
            param_dict = load_checkpoint(args_opt.checkpoint_path)
            load_param_into_net(net, param_dict)
        eval_dataset = create_dataset(training=False)
        res = model.eval(eval_dataset)
        print("result: ", res)
Example #6
def train():
    """Train function."""
    args = parse_args()
    profiler = network_init(args)

    loss_meter = AverageMeter('loss')
    parallel_init(args)

    network = YOLOV3DarkNet53(is_training=True)
    # default is kaiming-normal
    default_recurisive_init(network)
    load_yolov3_params(args, network)

    network = YoloWithLossCell(network)
    args.logger.info('finish get network')

    config = ConfigYOLOV3DarkNet53()
    config.label_smooth = args.label_smooth
    config.label_smooth_factor = args.label_smooth_factor

    if args.training_shape:
        config.multi_scale = [conver_training_shape(args)]
    if args.resize_rate:
        config.resize_rate = args.resize_rate

    ds, data_size = create_yolo_dataset(image_dir=args.data_root,
                                        anno_path=args.annFile,
                                        is_training=True,
                                        batch_size=args.per_batch_size,
                                        max_epoch=args.max_epoch,
                                        device_num=args.group_size,
                                        rank=args.rank,
                                        config=config)
    args.logger.info('Finish loading dataset')

    args.steps_per_epoch = int(data_size / args.per_batch_size /
                               args.group_size)

    if not args.ckpt_interval:
        args.ckpt_interval = args.steps_per_epoch

    lr = get_lr(args)
    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)
    is_gpu = context.get_context("device_target") == "GPU"
    if is_gpu:
        loss_scale_value = 1.0
        loss_scale = FixedLossScaleManager(loss_scale_value,
                                           drop_overflow_update=False)
        network = amp.build_train_network(network,
                                          optimizer=opt,
                                          loss_scale_manager=loss_scale,
                                          level="O2",
                                          keep_batchnorm_fp32=False)
        keep_loss_fp32(network)
    else:
        network = TrainingWrapper(network, opt, sens=args.loss_scale)
        network.set_train()

    if args.rank_save_ckpt_flag:
        # checkpoint save
        ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
        ckpt_config = CheckpointConfig(
            save_checkpoint_steps=args.ckpt_interval,
            keep_checkpoint_max=ckpt_max_num)
        save_ckpt_path = os.path.join(args.outputs_dir,
                                      'ckpt_' + str(args.rank) + '/')
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=save_ckpt_path,
                                  prefix='{}'.format(args.rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = network
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 1
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    old_progress = -1
    t_end = time.time()
    data_loader = ds.create_dict_iterator(output_numpy=True, num_epochs=1)

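    # manual training loop: convert each batch to Tensors, run one step and drive checkpoint/logging callbacks by hand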
    for i, data in enumerate(data_loader):
        images = data["image"]
        input_shape = images.shape[2:4]
        args.logger.info('iter[{}], shape{}'.format(i, input_shape[0]))

        images = Tensor.from_numpy(images)

        batch_y_true_0 = Tensor.from_numpy(data['bbox1'])
        batch_y_true_1 = Tensor.from_numpy(data['bbox2'])
        batch_y_true_2 = Tensor.from_numpy(data['bbox3'])
        batch_gt_box0 = Tensor.from_numpy(data['gt_box1'])
        batch_gt_box1 = Tensor.from_numpy(data['gt_box2'])
        batch_gt_box2 = Tensor.from_numpy(data['gt_box3'])

        loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2,
                       batch_gt_box0, batch_gt_box1, batch_gt_box2)
        loss_meter.update(loss.asnumpy())

        if args.rank_save_ckpt_flag:
            # ckpt progress
            cb_params.cur_step_num = i + 1  # current step number
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        if i % args.log_interval == 0:
            time_used = time.time() - t_end
            epoch = int(i / args.steps_per_epoch)
            per_step_time = time_used / args.log_interval
            fps = args.per_batch_size * (
                i - old_progress) * args.group_size / time_used
            if args.rank == 0:
                args.logger.info(
                    'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{},'
                    ' per_step_time:{}'.format(epoch, i, loss_meter, fps,
                                               lr[i], per_step_time))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if (i + 1) % args.steps_per_epoch == 0 and args.rank_save_ckpt_flag:
            cb_params.cur_epoch_num += 1

        if args.need_profiler:
            if i == 10:
                profiler.analyse()
                break
    args.logger.info('==========end training===============')
Example #7
def train():
    '''
    finetune function
    '''
    # BertCLS train for classification
    # BertNER train for sequence labeling

    if cfg.task == 'NER':
        tag_to_index = None
        if cfg.use_crf:
            tag_to_index = json.loads(open(cfg.label2id_file).read())
            print(tag_to_index)
            max_val = len(tag_to_index)
            tag_to_index["<START>"] = max_val
            tag_to_index["<STOP>"] = max_val + 1
            number_labels = len(tag_to_index)
        else:
            number_labels = cfg.num_labels

        netwithloss = BertNER(bert_net_cfg,
                              cfg.batch_size,
                              True,
                              num_labels=number_labels,
                              use_crf=cfg.use_crf,
                              tag_to_index=tag_to_index,
                              dropout_prob=0.1)
    elif cfg.task == 'Classification':
        netwithloss = BertCLS(bert_net_cfg,
                              True,
                              num_labels=cfg.num_labels,
                              dropout_prob=0.1,
                              assessment_method=cfg.assessment_method)
    else:
        raise Exception("task error, NER or Classification is supported.")

    dataset = get_dataset(data_file=cfg.data_file, batch_size=cfg.batch_size)
    steps_per_epoch = dataset.get_dataset_size()
    print('steps_per_epoch:', steps_per_epoch)

    # optimizer
    steps_per_epoch = dataset.get_dataset_size()
    if cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = BertLearningRate(
            learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
            end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
            warmup_steps=int(steps_per_epoch * cfg.epoch_num * 0.1),
            decay_steps=steps_per_epoch * cfg.epoch_num,
            power=optimizer_cfg.AdamWeightDecay.power)
        params = netwithloss.trainable_params()
        decay_params = list(
            filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(
            filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x),
                   params))
        group_params = [{
            'params':
            decay_params,
            'weight_decay':
            optimizer_cfg.AdamWeightDecay.weight_decay
        }, {
            'params': other_params,
            'weight_decay': 0.0
        }]
        optimizer = AdamWeightDecay(group_params,
                                    lr_schedule,
                                    eps=optimizer_cfg.AdamWeightDecay.eps)
    elif cfg.optimizer == 'Lamb':
        lr_schedule = BertLearningRate(
            learning_rate=optimizer_cfg.Lamb.learning_rate,
            end_learning_rate=optimizer_cfg.Lamb.end_learning_rate,
            warmup_steps=int(steps_per_epoch * cfg.epoch_num * 0.1),
            decay_steps=steps_per_epoch * cfg.epoch_num,
            power=optimizer_cfg.Lamb.power)
        optimizer = Lamb(netwithloss.trainable_params(),
                         learning_rate=lr_schedule)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(
            netwithloss.trainable_params(),
            learning_rate=optimizer_cfg.Momentum.learning_rate,
            momentum=optimizer_cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported.")

    # configure checkpoint saving and load the pre-trained checkpoint into the network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
                                   keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix=cfg.ckpt_prefix,
                                 directory=cfg.ckpt_dir,
                                 config=ckpt_config)
    param_dict = load_checkpoint(cfg.pre_training_ckpt)
    load_param_into_net(netwithloss, param_dict)

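    # dynamic loss-scale update cell wrapped into the BERT finetune one-step cell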
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32,
                                             scale_factor=2,
                                             scale_window=1000)
    netwithgrads = BertFinetuneCell(netwithloss,
                                    optimizer=optimizer,
                                    scale_update_cell=update_cell)
    model = Model(netwithgrads)
    callbacks = [
        TimeMonitor(dataset.get_dataset_size()),
        LossCallBack(dataset.get_dataset_size()), ckpoint_cb
    ]
    model.train(cfg.epoch_num,
                dataset,
                callbacks=callbacks,
                dataset_sink_mode=True)
Example #8
def main():
    cfg, args = init_argument()
    loss_meter = AverageMeter('loss')
    # dataloader
    cfg.logger.info('start create dataloader')
    de_dataset, steps_per_epoch, class_num = get_de_dataset(cfg)
    cfg.steps_per_epoch = steps_per_epoch
    cfg.logger.info('step per epoch: %s', cfg.steps_per_epoch)
    de_dataloader = de_dataset.create_tuple_iterator()
    cfg.logger.info('class num original: %s', class_num)
    if class_num % 16 != 0:
        class_num = (class_num // 16 + 1) * 16
    cfg.class_num = class_num
    cfg.logger.info('change the class num to: %s', cfg.class_num)
    cfg.logger.info('end create dataloader')

    # backbone and loss
    cfg.logger.important_info('start create network')
    create_network_start = time.time()

    network = SphereNet(num_layers=cfg.net_depth,
                        feature_dim=cfg.embedding_size,
                        shape=cfg.input_size)
    if args.device_target == 'CPU':
        head = CombineMarginFC(embbeding_size=cfg.embedding_size,
                               classnum=cfg.class_num)
    else:
        head = CombineMarginFCFp16(embbeding_size=cfg.embedding_size,
                                   classnum=cfg.class_num)
    criterion = CrossEntropy()

    # load the pretrained model
    if os.path.isfile(cfg.pretrained):
        param_dict = load_checkpoint(cfg.pretrained)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        cfg.logger.info('load model %s success', cfg.pretrained)

    # mixed precision training
    if args.device_target == 'CPU':
        network.add_flags_recursive(fp32=True)
        head.add_flags_recursive(fp32=True)
    else:
        network.add_flags_recursive(fp16=True)
        head.add_flags_recursive(fp16=True)
    criterion.add_flags_recursive(fp32=True)

    train_net = BuildTrainNetworkWithHead(network, head, criterion)

    # optimizer and lr scheduler
    lr = step_lr(lr=cfg.lr,
                 epoch_size=cfg.epoch_size,
                 steps_per_epoch=cfg.steps_per_epoch,
                 max_epoch=cfg.max_epoch,
                 gamma=cfg.lr_gamma)
    opt = SGD(params=train_net.trainable_params(),
              learning_rate=lr,
              momentum=cfg.momentum,
              weight_decay=cfg.weight_decay,
              loss_scale=cfg.loss_scale)

    # package training process, adjust lr + forward + backward + optimizer
    train_net = TrainOneStepCell(train_net, opt, sens=cfg.loss_scale)

    # checkpoint save
    if cfg.local_rank == 0:
        ckpt_max_num = cfg.max_epoch * cfg.steps_per_epoch // cfg.ckpt_interval
        train_config = CheckpointConfig(
            save_checkpoint_steps=cfg.ckpt_interval,
            keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=train_config,
                                  directory=cfg.outputs_dir,
                                  prefix='{}'.format(cfg.local_rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = train_net
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 1
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    train_net.set_train()
    t_end = time.time()
    t_epoch = time.time()
    old_progress = -1

    cfg.logger.important_info('====start train====')
    for i, total_data in enumerate(de_dataloader):
        data, gt = total_data
        data = Tensor(data)
        gt = Tensor(gt)

        loss = train_net(data, gt)
        loss_meter.update(loss.asnumpy())

        # ckpt
        if cfg.local_rank == 0:
            cb_params.cur_step_num = i + 1  # current step number
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        # logging loss, fps, ...
        if i == 0:
            time_for_graph_compile = time.time() - create_network_start
            cfg.logger.important_info('{}, graph compile time={:.2f}s'.format(
                cfg.task, time_for_graph_compile))

        if i % cfg.log_interval == 0 and cfg.local_rank == 0:
            time_used = time.time() - t_end
            epoch = int(i / cfg.steps_per_epoch)
            fps = cfg.per_batch_size * (
                i - old_progress) * cfg.world_size / time_used
            cfg.logger.info(
                'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr={}'.format(
                    epoch, i, loss_meter, fps, lr[i]))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if i % cfg.steps_per_epoch == 0 and cfg.local_rank == 0:
            epoch_time_used = time.time() - t_epoch
            epoch = int(i / cfg.steps_per_epoch)
            fps = cfg.per_batch_size * cfg.world_size * cfg.steps_per_epoch / epoch_time_used
            cfg.logger.info(
                '=================================================')
            cfg.logger.info(
                'epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(
                    epoch, i, fps))
            cfg.logger.info(
                '=================================================')
            t_epoch = time.time()

    cfg.logger.important_info('====train end====')
Example #9
    else:
        if not args.label_smooth:
            args.label_smooth_factor = 0.0
        loss = CrossEntropy(smooth_factor=args.label_smooth_factor,
                            num_classes=args.num_classes)

        loss_scale_manager = FixedLossScaleManager(args.loss_scale,
                                                   drop_overflow_update=False)
        model = Model(network,
                      loss_fn=loss,
                      optimizer=opt,
                      loss_scale_manager=loss_scale_manager,
                      amp_level="O2")

    # define callbacks
    time_cb = TimeMonitor(data_size=batch_num)
    loss_cb = LossMonitor(per_print_times=batch_num)
    callbacks = [time_cb, loss_cb]
    if args.rank_save_ckpt_flag:
        ckpt_config = CheckpointConfig(
            save_checkpoint_steps=args.ckpt_interval * args.steps_per_epoch,
            keep_checkpoint_max=args.ckpt_save_max)
        save_ckpt_path = os.path.join(args.outputs_dir,
                                      'ckpt_' + str(args.rank) + '/')
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=save_ckpt_path,
                                  prefix='{}'.format(args.rank))
        callbacks.append(ckpt_cb)

    model.train(args.max_epoch, dataset, callbacks=callbacks)
Example #10
        default="CPU",
        choices=['Ascend', 'GPU', 'CPU'],
        help='device where the code will be implemented (default: CPU)')
    args = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args.device_target)
    dataset_sink_mode = not args.device_target == "CPU"
    # download mnist dataset
    download_dataset()
    # learning rate setting
    lr = 0.01
    momentum = 0.9
    dataset_size = 1
    mnist_path = "./MNIST_Data"
    # define the loss function
    net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    train_epoch = 1
    # create the network
    net = LeNet5()
    # define the optimizer
    net_opt = nn.Momentum(net.trainable_params(), lr, momentum)
    config_ck = CheckpointConfig(save_checkpoint_steps=1875,
                                 keep_checkpoint_max=10)
    # save the network model and parameters for subsequent fine-tuning
    ckpoint = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
    # group layers into an object with training and evaluation features
    model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    train_net(model, train_epoch, mnist_path, dataset_size, ckpoint,
              dataset_sink_mode)
    test_net(net, model, mnist_path)
Example #11
    # define opt
    opt = Adam(params=net.trainable_params(), learning_rate=lr, eps=1e-07)

    # define loss, model
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="sum")

    model = Model(net,
                  loss_fn=loss,
                  optimizer=opt,
                  metrics={"Accuracy": Accuracy()})

    # define callbacks
    time_cb = TimeMonitor(data_size=step_size)
    loss_cb = LossMonitor()
    cb = [time_cb, loss_cb]
    if config.save_checkpoint:
        config_ck = CheckpointConfig(
            save_checkpoint_steps=2500,
            keep_checkpoint_max=config.keep_checkpoint_max)
        ckpt_cb = ModelCheckpoint(prefix="cnn_direction_model",
                                  directory=ckpt_save_dir,
                                  config=config_ck)
        cb += [ckpt_cb]

    # train model
    model.train(config.epoch_size,
                dataset,
                callbacks=cb,
                dataset_sink_mode=False)
Example #12
def do_train(dataset=None,
             network=None,
             load_checkpoint_path="",
             save_checkpoint_path="",
             epoch_num=1):
    """ do train """
    if load_checkpoint_path == "":
        raise ValueError(
            "Pretrain model missed, finetune task must load pretrain model!")
    steps_per_epoch = dataset.get_dataset_size()
    # optimizer
    if optimizer_cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = BertLearningRate(
            learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
            end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
            warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
            decay_steps=steps_per_epoch * epoch_num,
            power=optimizer_cfg.AdamWeightDecay.power)
        params = network.trainable_params()
        decay_params = list(
            filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(
            filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x),
                   params))
        group_params = [{
            'params':
            decay_params,
            'weight_decay':
            optimizer_cfg.AdamWeightDecay.weight_decay
        }, {
            'params': other_params,
            'weight_decay': 0.0
        }]

        optimizer = AdamWeightDecay(group_params,
                                    lr_schedule,
                                    eps=optimizer_cfg.AdamWeightDecay.eps)
    elif optimizer_cfg.optimizer == 'Lamb':
        lr_schedule = BertLearningRate(
            learning_rate=optimizer_cfg.Lamb.learning_rate,
            end_learning_rate=optimizer_cfg.Lamb.end_learning_rate,
            warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
            decay_steps=steps_per_epoch * epoch_num,
            power=optimizer_cfg.Lamb.power)
        optimizer = Lamb(network.trainable_params(), learning_rate=lr_schedule)
    elif optimizer_cfg.optimizer == 'Momentum':
        optimizer = Momentum(
            network.trainable_params(),
            learning_rate=optimizer_cfg.Momentum.learning_rate,
            momentum=optimizer_cfg.Momentum.momentum)
    else:
        raise Exception(
            "Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]"
        )

    # configure checkpoint saving and load the pre-trained checkpoint into the network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
                                   keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(
        prefix="squad",
        directory=None if save_checkpoint_path == "" else save_checkpoint_path,
        config=ckpt_config)
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(network, param_dict)

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32,
                                             scale_factor=2,
                                             scale_window=1000)
    netwithgrads = BertSquadCell(network,
                                 optimizer=optimizer,
                                 scale_update_cell=update_cell)
    model = Model(netwithgrads)
    callbacks = [
        TimeMonitor(dataset.get_dataset_size()),
        LossCallBack(), ckpoint_cb
    ]
    model.train(epoch_num, dataset, callbacks=callbacks)
Example #13
def train():
    """Train function."""
    args = parse_args()
    args.logger.save_args(args)

    if args.need_profiler:
        from mindspore.profiler.profiling import Profiler
        profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)

    loss_meter = AverageMeter('loss')

    context.reset_auto_parallel_context()
    parallel_mode = ParallelMode.STAND_ALONE
    degree = 1
    if args.is_distributed:
        parallel_mode = ParallelMode.DATA_PARALLEL
        degree = get_group_size()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree)

    network = YOLOV3DarkNet53(is_training=True)
    # default is kaiming-normal
    default_recurisive_init(network)
    load_yolov3_quant_params(args, network)

    config = ConfigYOLOV3DarkNet53()
    # convert fusion network to quantization aware network
    if config.quantization_aware:
        network = build_quant_network(network)

    network = YoloWithLossCell(network)
    args.logger.info('finish get network')

    config.label_smooth = args.label_smooth
    config.label_smooth_factor = args.label_smooth_factor

    if args.training_shape:
        config.multi_scale = [conver_training_shape(args)]

    if args.resize_rate:
        config.resize_rate = args.resize_rate

    ds, data_size = create_yolo_dataset(image_dir=args.data_root, anno_path=args.annFile, is_training=True,
                                        batch_size=args.per_batch_size, max_epoch=args.max_epoch,
                                        device_num=args.group_size, rank=args.rank, config=config)
    args.logger.info('Finish loading dataset')

    args.steps_per_epoch = int(data_size / args.per_batch_size / args.group_size)

    if not args.ckpt_interval:
        args.ckpt_interval = args.steps_per_epoch

    lr = get_lr(args)

    opt = Momentum(params=get_param_groups(network), learning_rate=Tensor(lr), momentum=args.momentum,
                   weight_decay=args.weight_decay, loss_scale=args.loss_scale)

    network = TrainingWrapper(network, opt)
    network.set_train()

    if args.rank_save_ckpt_flag:
        # checkpoint save
        ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
                                       keep_checkpoint_max=ckpt_max_num)
        save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
        ckpt_cb = ModelCheckpoint(config=ckpt_config, directory=save_ckpt_path, prefix='{}'.format(args.rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = network
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 1
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    old_progress = -1
    t_end = time.time()
    data_loader = ds.create_dict_iterator(output_numpy=True, num_epochs=1)

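    # manual training loop: ground-truth boxes are preprocessed per batch before each forward/backward step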
    shape_record = ShapeRecord()
    for i, data in enumerate(data_loader):
        images = data["image"]
        input_shape = images.shape[2:4]
        args.logger.info('iter[{}], shape{}'.format(i, input_shape[0]))
        shape_record.set(input_shape)

        images = Tensor.from_numpy(images)
        annos = data["annotation"]
        if args.group_size == 1:
            batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \
                batch_preprocess_true_box(annos, config, input_shape)
        else:
            batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \
                batch_preprocess_true_box_single(annos, config, input_shape)

        batch_y_true_0 = Tensor.from_numpy(batch_y_true_0)
        batch_y_true_1 = Tensor.from_numpy(batch_y_true_1)
        batch_y_true_2 = Tensor.from_numpy(batch_y_true_2)
        batch_gt_box0 = Tensor.from_numpy(batch_gt_box0)
        batch_gt_box1 = Tensor.from_numpy(batch_gt_box1)
        batch_gt_box2 = Tensor.from_numpy(batch_gt_box2)

        input_shape = Tensor(tuple(input_shape[::-1]), ms.float32)
        loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1,
                       batch_gt_box2, input_shape)
        loss_meter.update(loss.asnumpy())

        if args.rank_save_ckpt_flag:
            # ckpt progress
            cb_params.cur_step_num = i + 1  # current step number
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        if i % args.log_interval == 0:
            time_used = time.time() - t_end
            epoch = int(i / args.steps_per_epoch)
            fps = args.per_batch_size * (i - old_progress) * args.group_size / time_used
            if args.rank == 0:
                args.logger.info(
                    'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{}'.format(epoch, i, loss_meter, fps, lr[i]))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if (i + 1) % args.steps_per_epoch == 0 and args.rank_save_ckpt_flag:
            cb_params.cur_epoch_num += 1

        if args.need_profiler:
            if i == 10:
                profiler.analyse()
                break

    args.logger.info('==========end training===============')
Example #14
def main():
    parser = argparse.ArgumentParser(description="YOLOv3 train")
    parser.add_argument("--only_create_dataset", type=ast.literal_eval, default=False,
                        help="If set it true, only create Mindrecord, default is False.")
    parser.add_argument("--distribute", type=ast.literal_eval, default=False, help="Run distribute, default is False.")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate, default is 0.001.")
    parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink")
    parser.add_argument("--epoch_size", type=int, default=50, help="Epoch size, default is 50")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
    parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained checkpoint file path")
    parser.add_argument("--pre_trained_epoch_size", type=int, default=0, help="Pretrained epoch size")
    parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.")
    parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
    parser.add_argument("--mindrecord_dir", type=str, default="./Mindrecord_train",
                        help="Mindrecord directory. If the mindrecord_dir is empty, it wil generate mindrecord file by "
                             "image_dir and anno_path. Note if mindrecord_dir isn't empty, it will use mindrecord_dir "
                             "rather than image_dir and anno_path. Default is ./Mindrecord_train")
    parser.add_argument("--image_dir", type=str, default="", help="Dataset directory, "
                                                                  "the absolute image path is joined by the image_dir "
                                                                  "and the relative path in anno_path")
    parser.add_argument("--anno_path", type=str, default="", help="Annotation path.")
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
    if args_opt.distribute:
        device_num = args_opt.device_num
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                          device_num=device_num)
        init()
        rank = args_opt.device_id % device_num
    else:
        rank = 0
        device_num = 1

    print("Start create dataset!")

    # It will generate mindrecord file in args_opt.mindrecord_dir,
    # and the file name is yolo.mindrecord0, 1, ... file_num.
    if not os.path.isdir(args_opt.mindrecord_dir):
        os.makedirs(args_opt.mindrecord_dir)

    prefix = "yolo.mindrecord"
    mindrecord_file = os.path.join(args_opt.mindrecord_dir, prefix + "0")
    if not os.path.exists(mindrecord_file):
        if os.path.isdir(args_opt.image_dir) and os.path.exists(args_opt.anno_path):
            print("Create Mindrecord.")
            data_to_mindrecord_byte_image(args_opt.image_dir,
                                          args_opt.anno_path,
                                          args_opt.mindrecord_dir,
                                          prefix,
                                          8)
            print("Create Mindrecord Done, at {}".format(args_opt.mindrecord_dir))
        else:
            raise ValueError('image_dir {} or anno_path {} does not exist'.format(\
                              args_opt.image_dir, args_opt.anno_path))

    if not args_opt.only_create_dataset:
        loss_scale = float(args_opt.loss_scale)

        # When creating MindDataset, use the first mindrecord file, such as yolo.mindrecord0.
        dataset = create_yolo_dataset(mindrecord_file,
                                      batch_size=args_opt.batch_size, device_num=device_num, rank=rank)
        dataset_size = dataset.get_dataset_size()
        print("Create dataset done!")

        net = yolov3_resnet18(ConfigYOLOV3ResNet18())
        net = YoloWithLossCell(net, ConfigYOLOV3ResNet18())
        init_net_param(net, "XavierUniform")

        # checkpoint
        ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs)
        ckpoint_cb = ModelCheckpoint(prefix="yolov3", directory='./ckpt_' + str(rank) + '/', config=ckpt_config)

        if args_opt.pre_trained:
            if args_opt.pre_trained_epoch_size <= 0:
                raise KeyError("pre_trained_epoch_size must be greater than 0.")
            param_dict = load_checkpoint(args_opt.pre_trained)
            load_param_into_net(net, param_dict)
        total_epoch_size = 60
        if args_opt.distribute:
            total_epoch_size = 160
        lr = Tensor(get_lr(learning_rate=args_opt.lr, start_step=args_opt.pre_trained_epoch_size * dataset_size,
                           global_step=total_epoch_size * dataset_size,
                           decay_step=1000, decay_rate=0.95, steps=True))
        opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale)
        net = TrainingWrapper(net, opt, loss_scale)

        callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]

        model = Model(net)
        dataset_sink_mode = False
        if args_opt.mode == "sink":
            print("In sink mode, one epoch return a loss.")
            dataset_sink_mode = True
        print("Start train YOLOv3, the first epoch will be slower because of the graph compilation.")
        model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode)
Example #15
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)

    train_net = BuildTrainNetwork(network, criterion)

    # mixed precision training
    criterion.add_flags_recursive(fp32=True)

    # package training process
    train_net = TrainOneStepCell(train_net, opt, sens=args.loss_scale)
    context.reset_auto_parallel_context()

    # checkpoint
    if args.local_rank == 0:
        ckpt_max_num = args.max_epoch
        train_config = CheckpointConfig(save_checkpoint_steps=args.steps_per_epoch, keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=train_config, directory=args.outputs_dir, prefix='{}'.format(args.local_rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = train_net
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 0
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    train_net.set_train()
    t_end = time.time()
    t_epoch = time.time()
    old_progress = -1

    i = 0
    for _, (data, gt_classes) in enumerate(de_dataloader):
Example #16
        raise ValueError("Unsupported platform.")

    dataset = create_dataset(cfg.data_path, 1)
    batch_num = dataset.get_dataset_size()

    net = GoogleNet(num_classes=cfg.num_classes)
    # Continue training if set pre_trained to be True
    if cfg.pre_trained:
        param_dict = load_checkpoint(cfg.checkpoint_path)
        load_param_into_net(net, param_dict)
    lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum,
                   weight_decay=cfg.weight_decay)
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)

    if device_target == "Ascend":
        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
                      amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
        ckpt_save_dir = "./"
    else: # GPU
        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
                      amp_level="O2", keep_batchnorm_fp32=True, loss_scale_manager=None)
        ckpt_save_dir = "./ckpt_" + str(get_rank()) + "/"

    config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max)
    time_cb = TimeMonitor(data_size=batch_num)
    ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory=ckpt_save_dir, config=config_ck)
    loss_cb = LossMonitor()
    model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
    print("train success")
Example #17
def run_pretrain():
    """pre-train bert_clue"""
    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument(
        '--device_target',
        type=str,
        default='Ascend',
        choices=['Ascend', 'GPU'],
        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute",
                        type=str,
                        default="false",
                        choices=["true", "false"],
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size",
                        type=int,
                        default="1",
                        help="Epoch size, default is 1.")
    parser.add_argument("--device_id",
                        type=int,
                        default=0,
                        help="Device id, default is 0.")
    parser.add_argument("--device_num",
                        type=int,
                        default=1,
                        help="Use device nums, default is 1.")
    parser.add_argument("--enable_save_ckpt",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Enable save checkpoint, default is true.")
    parser.add_argument("--enable_lossscale",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Use lossscale or not, default is not.")
    parser.add_argument("--do_shuffle",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps",
                        type=int,
                        default="1",
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument(
        "--accumulation_steps",
        type=int,
        default="1",
        help=
        "Accumulating gradients N times before weight update, default is 1.")
    parser.add_argument("--save_checkpoint_path",
                        type=str,
                        default="",
                        help="Save checkpoint path")
    parser.add_argument("--load_checkpoint_path",
                        type=str,
                        default="",
                        help="Load checkpoint file path")
    parser.add_argument("--save_checkpoint_steps",
                        type=int,
                        default=1000,
                        help="Save checkpoint steps, "
                        "default is 1000.")
    parser.add_argument("--train_steps",
                        type=int,
                        default=-1,
                        help="Training Steps, default is -1, "
                        "meaning run all steps according to epoch number.")
    parser.add_argument("--save_checkpoint_num",
                        type=int,
                        default=1,
                        help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir",
                        type=str,
                        default="",
                        help="Schema path, it is better to use absolute path")
    parser.add_argument("--enable_graph_kernel",
                        type=str,
                        default="auto",
                        choices=["auto", "true", "false"],
                        help="Accelerate by graph kernel, default is auto.")

    parser.add_argument("--optimizer",
                        type=str,
                        default="AdamWeightDecay",
                        choices=["AdamWeightDecay", "Lamb", "Momentum"],
                        help="Optimizer, default is AdamWeightDecay.")
    parser.add_argument(
        "--enable_global_norm",
        type=str,
        default="true",
        choices=["true", "false"],
        help="Enable gloabl norm for grad clip, default is true.")
    parser.add_argument("--batch_size",
                        type=int,
                        default=32,
                        help="Batch size, default is 32.")
    parser.add_argument("--dtype",
                        type=str,
                        default="fp32",
                        choices=["fp32", "fp16"],
                        help="dtype, default is fp32.")

    args_opt = parser.parse_args()
    cfg.optimizer = args_opt.optimizer
    cfg.batch_size = args_opt.batch_size
    cfg.enable_global_norm = (args_opt.enable_global_norm == "true")
    bert_net_cfg.compute_type = mstype.float32 if args_opt.dtype == "fp32" else mstype.float16
    logger.warning("\nargs_opt: {}".format(args_opt))
    logger.warning("\ncfg: {}".format(cfg))

    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    is_auto_enable_graph_kernel = _auto_enable_graph_kernel(
        args_opt.device_target, args_opt.enable_graph_kernel)
    if args_opt.enable_graph_kernel == "true" or is_auto_enable_graph_kernel:
        context.set_context(enable_graph_kernel=True)
    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init()
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init()
            device_num = D.get_group_size()
            rank = D.get_rank()
        ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(
            get_rank()) + '/'

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            device_num=device_num)
        _set_bert_all_reduce_split(args_opt.device_target,
                                   context.get_context('enable_graph_kernel'))
    else:
        rank = 0
        device_num = 1

    if args_opt.accumulation_steps > 1:
        logger.info("accumulation steps: {}".format(
            args_opt.accumulation_steps))
        logger.info("global batch size: {}".format(
            cfg.batch_size * args_opt.accumulation_steps))
        if args_opt.enable_data_sink == "true":
            args_opt.data_sink_steps *= args_opt.accumulation_steps
            logger.info("data sink steps: {}".format(args_opt.data_sink_steps))
        if args_opt.enable_save_ckpt == "true":
            args_opt.save_checkpoint_steps *= args_opt.accumulation_steps
            logger.info("save checkpoint steps: {}".format(
                args_opt.save_checkpoint_steps))

    ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle,
                             args_opt.data_dir, args_opt.schema_dir)
    net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)

    new_repeat_count = (args_opt.epoch_size * ds.get_dataset_size()
                        // args_opt.data_sink_steps)
    if args_opt.train_steps > 0:
        train_steps = args_opt.train_steps * args_opt.accumulation_steps
        new_repeat_count = min(new_repeat_count,
                               train_steps // args_opt.data_sink_steps)
    else:
        args_opt.train_steps = (args_opt.epoch_size * ds.get_dataset_size()
                                // args_opt.accumulation_steps)
        logger.info("train steps: {}".format(args_opt.train_steps))

    optimizer = _get_optimizer(args_opt, net_with_loss)
    callback = [
        TimeMonitor(args_opt.data_sink_steps),
        LossCallBack(ds.get_dataset_size())
    ]
    if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(
            8, device_num) == 0:
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args_opt.save_checkpoint_steps,
            keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(
            prefix='checkpoint_bert',
            directory=None if ckpt_save_dir == "" else ckpt_save_dir,
            config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(net_with_loss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)

        if args_opt.accumulation_steps <= 1:
            if cfg.optimizer == 'AdamWeightDecay' and args_opt.device_target == 'GPU':
                net_with_grads = BertTrainOneStepWithLossScaleCellForAdam(
                    net_with_loss,
                    optimizer=optimizer,
                    scale_update_cell=update_cell)
            else:
                net_with_grads = BertTrainOneStepWithLossScaleCell(
                    net_with_loss,
                    optimizer=optimizer,
                    scale_update_cell=update_cell)
        else:
            accumulation_steps = args_opt.accumulation_steps
            net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(
                net_with_loss,
                optimizer=optimizer,
                scale_update_cell=update_cell,
                accumulation_steps=accumulation_steps,
                enable_global_norm=cfg.enable_global_norm)
    else:
        net_with_grads = BertTrainOneStepCell(net_with_loss,
                                              optimizer=optimizer)

    model = Model(net_with_grads)
    model.train(new_repeat_count,
                ds,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
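
To make the bookkeeping above easier to follow, the stand-alone sketch below (with hypothetical sizes; only the arithmetic mirrors the snippet) shows how gradient accumulation stretches the sink, checkpoint, and repeat counts.

# Hypothetical values, chosen only to illustrate the arithmetic used above.
epoch_size = 40
dataset_size = 320000          # stand-in for ds.get_dataset_size()
batch_size = 32
accumulation_steps = 4
data_sink_steps = 1
save_checkpoint_steps = 1000

# Each optimizer update consumes accumulation_steps micro-batches.
global_batch_size = batch_size * accumulation_steps                # 128
# Sink and checkpoint intervals are counted in micro-batches, so they are
# stretched by the same factor to keep their original meaning.
data_sink_steps *= accumulation_steps                               # 4
save_checkpoint_steps *= accumulation_steps                         # 4000
# One sink "repeat" of model.train() consumes data_sink_steps batches.
new_repeat_count = epoch_size * dataset_size // data_sink_steps     # 3200000
# Optimizer updates when --train_steps is left at its default of -1.
train_steps = epoch_size * dataset_size // accumulation_steps       # 3200000
print(global_batch_size, data_sink_steps, new_repeat_count, train_steps)
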
Exemplo n.º 18
0
def train_on_ascend():
    config = config_ascend_quant if args_opt.quantization_aware else config_ascend
    print("training args: {}".format(args_opt))
    print("training configure: {}".format(config))
    print("parallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size))
    epoch_size = config.epoch_size

    # distribute init
    if run_distribute:
        context.set_auto_parallel_context(device_num=rank_size,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          parameter_broadcast=True,
                                          mirror_mean=True)
        init()

    # define network
    network = mobilenetV2(num_classes=config.num_classes)
    # define loss
    if config.label_smooth > 0:
        loss = CrossEntropyWithLabelSmooth(smooth_factor=config.label_smooth, num_classes=config.num_classes)
    else:
        loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
    # define dataset
    dataset = create_dataset(dataset_path=args_opt.dataset_path,
                             do_train=True,
                             config=config,
                             device_target=args_opt.device_target,
                             repeat_num=1,
                             batch_size=config.batch_size)
    step_size = dataset.get_dataset_size()
    # load pre trained ckpt
    if args_opt.pre_trained:
        param_dict = load_checkpoint(args_opt.pre_trained)
        load_param_into_net(network, param_dict)

    # convert fusion network to quantization aware network
    if config.quantization_aware:
        network = quant.convert_quant_network(network,
                                              bn_fold=True,
                                              per_channel=[True, False],
                                              symmetric=[True, False])

    # get learning rate
    lr = Tensor(get_lr(global_step=config.start_epoch * step_size,
                       lr_init=0,
                       lr_end=0,
                       lr_max=config.lr,
                       warmup_epochs=config.warmup_epochs,
                       total_epochs=epoch_size + config.start_epoch,
                       steps_per_epoch=step_size))

    # define optimization
    opt = nn.Momentum(filter(lambda x: x.requires_grad, network.get_parameters()), lr, config.momentum,
                      config.weight_decay)
    # define model
    model = Model(network, loss_fn=loss, optimizer=opt)

    print("============== Starting Training ==============")
    callback = None
    if rank_id == 0:
        callback = [Monitor(lr_init=lr.asnumpy())]
        if config.save_checkpoint:
            config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
                                         keep_checkpoint_max=config.keep_checkpoint_max)
            ckpt_cb = ModelCheckpoint(prefix="mobilenetV2",
                                      directory=config.save_checkpoint_path,
                                      config=config_ck)
            callback += [ckpt_cb]
    model.train(epoch_size, dataset, callbacks=callback)
    print("============== End Training ==============")
Exemplo n.º 19
0
def test_train():
    """train entry method"""
    if args.is_distributed:
        if args.device_target == "Ascend":
            init()
            context.set_context(device_id=args.device_id)
        elif args.device_target == "GPU":
            init()

        args.rank = get_rank()
        args.group_size = get_group_size()
        device_num = args.group_size
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            device_num=device_num,
            parallel_mode=ParallelMode.DATA_PARALLEL,
            parameter_broadcast=True,
            gradients_mean=True)
    else:
        context.set_context(device_id=args.device_id)
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args.device_target)

    if not os.path.exists(args.output_path):
        os.makedirs(args.output_path)

    layers = cfg.layers
    num_factors = cfg.num_factors
    epochs = args.train_epochs

    ds_train, num_train_users, num_train_items = create_dataset(
        test_train=True,
        data_dir=args.data_path,
        dataset=args.dataset,
        train_epochs=1,
        batch_size=args.batch_size,
        num_neg=args.num_neg)
    print("ds_train.size: {}".format(ds_train.get_dataset_size()))

    ncf_net = NCFModel(num_users=num_train_users,
                       num_items=num_train_items,
                       num_factors=num_factors,
                       model_layers=layers,
                       mf_regularization=0,
                       mlp_reg_layers=[0.0, 0.0, 0.0, 0.0],
                       mf_dim=16)
    loss_net = NetWithLossClass(ncf_net)
    train_net = TrainStepWrap(loss_net)

    train_net.set_train()

    model = Model(train_net)
    callback = LossMonitor(per_print_times=ds_train.get_dataset_size())
    ckpt_config = CheckpointConfig(
        save_checkpoint_steps=(4970845 + args.batch_size - 1) //
        (args.batch_size),
        keep_checkpoint_max=100)
    ckpoint_cb = ModelCheckpoint(prefix='NCF',
                                 directory=args.checkpoint_path,
                                 config=ckpt_config)
    model.train(epochs,
                ds_train,
                callbacks=[
                    TimeMonitor(ds_train.get_dataset_size()), callback,
                    ckpoint_cb
                ],
                dataset_sink_mode=True)
Exemplo n.º 20
0
def train():
    args = parse_args()
    cfg = FCN8s_VOC2012_cfg
    device_num = int(os.environ.get("DEVICE_NUM", 1))
    context.set_context(mode=context.GRAPH_MODE,
                        enable_auto_mixed_precision=True,
                        save_graphs=False,
                        device_target="Ascend",
                        device_id=args.device_id)
    # init multicards training
    args.rank = 0
    args.group_size = 1
    if device_num > 1:
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                          gradients_mean=True,
                                          device_num=device_num)
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()

    # dataset
    dataset = data_generator.SegDataset(image_mean=cfg.image_mean,
                                        image_std=cfg.image_std,
                                        data_file=cfg.data_file,
                                        batch_size=cfg.batch_size,
                                        crop_size=cfg.crop_size,
                                        max_scale=cfg.max_scale,
                                        min_scale=cfg.min_scale,
                                        ignore_label=cfg.ignore_label,
                                        num_classes=cfg.num_classes,
                                        num_readers=2,
                                        num_parallel_calls=4,
                                        shard_id=args.rank,
                                        shard_num=args.group_size)
    dataset = dataset.get_dataset(repeat=1)

    net = FCN8s(n_class=cfg.num_classes)
    loss_ = loss.SoftmaxCrossEntropyLoss(cfg.num_classes, cfg.ignore_label)

    # load pretrained vgg16 parameters to init FCN8s
    if cfg.ckpt_vgg16:
        param_vgg = load_checkpoint(cfg.ckpt_vgg16)
        param_dict = {}
        for layer_id in range(1, 6):
            sub_layer_num = 2 if layer_id < 3 else 3
            for sub_layer_id in range(sub_layer_num):
                # conv param
                y_weight = 'conv{}.{}.weight'.format(layer_id,
                                                     3 * sub_layer_id)
                x_weight = 'vgg16_feature_extractor.conv{}_{}.0.weight'.format(
                    layer_id, sub_layer_id + 1)
                param_dict[y_weight] = param_vgg[x_weight]
                # BatchNorm param
                y_gamma = 'conv{}.{}.gamma'.format(layer_id,
                                                   3 * sub_layer_id + 1)
                y_beta = 'conv{}.{}.beta'.format(layer_id,
                                                 3 * sub_layer_id + 1)
                x_gamma = 'vgg16_feature_extractor.conv{}_{}.1.gamma'.format(
                    layer_id, sub_layer_id + 1)
                x_beta = 'vgg16_feature_extractor.conv{}_{}.1.beta'.format(
                    layer_id, sub_layer_id + 1)
                param_dict[y_gamma] = param_vgg[x_gamma]
                param_dict[y_beta] = param_vgg[x_beta]
        load_param_into_net(net, param_dict)
    # load pretrained FCN8s
    elif cfg.ckpt_pre_trained:
        param_dict = load_checkpoint(cfg.ckpt_pre_trained)
        load_param_into_net(net, param_dict)

    # optimizer
    iters_per_epoch = dataset.get_dataset_size()

    lr_scheduler = CosineAnnealingLR(cfg.base_lr,
                                     cfg.train_epochs,
                                     iters_per_epoch,
                                     cfg.train_epochs,
                                     warmup_epochs=0,
                                     eta_min=0)
    lr = Tensor(lr_scheduler.get_lr())

    # loss scale
    manager_loss_scale = FixedLossScaleManager(cfg.loss_scale,
                                               drop_overflow_update=False)

    optimizer = nn.Momentum(params=net.trainable_params(),
                            learning_rate=lr,
                            momentum=0.9,
                            weight_decay=0.0001,
                            loss_scale=cfg.loss_scale)

    model = Model(net,
                  loss_fn=loss_,
                  loss_scale_manager=manager_loss_scale,
                  optimizer=optimizer,
                  amp_level="O3")

    # callback for saving ckpts
    time_cb = TimeMonitor(data_size=iters_per_epoch)
    loss_cb = LossMonitor()
    cbs = [time_cb, loss_cb]

    if args.rank == 0:
        config_ck = CheckpointConfig(
            save_checkpoint_steps=cfg.save_steps,
            keep_checkpoint_max=cfg.keep_checkpoint_max)
        ckpoint_cb = ModelCheckpoint(prefix=cfg.model,
                                     directory=cfg.train_dir,
                                     config=config_ck)
        cbs.append(ckpoint_cb)

    model.train(cfg.train_epochs, dataset, callbacks=cbs)
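
The checkpoint-remapping loop above is easier to read once the generated key names are printed; this stand-alone sketch only reproduces the format strings from the loop, so the real keys depend on the VGG16 checkpoint actually loaded.

# Stand-alone illustration of the VGG16 -> FCN8s parameter-name mapping built above.
for layer_id in range(1, 6):
    sub_layer_num = 2 if layer_id < 3 else 3
    for sub_layer_id in range(sub_layer_num):
        src = 'vgg16_feature_extractor.conv{}_{}.0.weight'.format(layer_id, sub_layer_id + 1)
        dst = 'conv{}.{}.weight'.format(layer_id, 3 * sub_layer_id)
        print('{:45s} -> {}'.format(src, dst))
# First line printed: vgg16_feature_extractor.conv1_1.0.weight -> conv1.0.weight
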
Exemplo n.º 21
0
def train(cfg):

    context.set_context(mode=context.GRAPH_MODE,
                        device_target='GPU',
                        save_graphs=False)
    if cfg['ngpu'] > 1:
        init("nccl")
        context.set_auto_parallel_context(
            device_num=get_group_size(),
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True)
        cfg['ckpt_path'] = cfg['ckpt_path'] + "ckpt_" + str(get_rank()) + "/"
    else:
        raise ValueError("Distributed training requires cfg['ngpu'] > 1.")

    batch_size = cfg['batch_size']
    max_epoch = cfg['epoch']

    momentum = cfg['momentum']
    weight_decay = cfg['weight_decay']
    initial_lr = cfg['initial_lr']
    gamma = cfg['gamma']
    training_dataset = cfg['training_dataset']
    num_classes = 2
    negative_ratio = 7
    stepvalues = (cfg['decay1'], cfg['decay2'])

    ds_train = create_dataset(training_dataset,
                              cfg,
                              batch_size,
                              multiprocessing=True,
                              num_worker=cfg['num_workers'])
    print('dataset size is:', ds_train.get_dataset_size())

    steps_per_epoch = math.ceil(ds_train.get_dataset_size())

    multibox_loss = MultiBoxLoss(num_classes, cfg['num_anchor'],
                                 negative_ratio, cfg['batch_size'])
    backbone = resnet50(1001)
    backbone.set_train(True)

    if cfg['pretrain'] and cfg['resume_net'] is None:
        pretrained_res50 = cfg['pretrain_path']
        param_dict_res50 = load_checkpoint(pretrained_res50)
        load_param_into_net(backbone, param_dict_res50)
        print('Load resnet50 from [{}] done.'.format(pretrained_res50))

    net = RetinaFace(phase='train', backbone=backbone)
    net.set_train(True)

    if cfg['resume_net'] is not None:
        pretrain_model_path = cfg['resume_net']
        param_dict_retinaface = load_checkpoint(pretrain_model_path)
        load_param_into_net(net, param_dict_retinaface)
        print('Resume Model from [{}] Done.'.format(cfg['resume_net']))

    net = RetinaFaceWithLossCell(net, multibox_loss, cfg)

    lr = adjust_learning_rate(initial_lr,
                              gamma,
                              stepvalues,
                              steps_per_epoch,
                              max_epoch,
                              warmup_epoch=cfg['warmup_epoch'])

    if cfg['optim'] == 'momentum':
        opt = mindspore.nn.Momentum(net.trainable_params(), lr, momentum)
    elif cfg['optim'] == 'sgd':
        opt = mindspore.nn.SGD(params=net.trainable_params(),
                               learning_rate=lr,
                               momentum=momentum,
                               weight_decay=weight_decay,
                               loss_scale=1)
    else:
        raise ValueError("Unsupported optimizer: {}".format(cfg['optim']))

    net = TrainingWrapper(net, opt)

    model = Model(net)

    config_ck = CheckpointConfig(
        save_checkpoint_steps=cfg['save_checkpoint_steps'],
        keep_checkpoint_max=cfg['keep_checkpoint_max'])
    ckpoint_cb = ModelCheckpoint(prefix="RetinaFace",
                                 directory=cfg['ckpt_path'],
                                 config=config_ck)

    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    callback_list = [LossMonitor(), time_cb, ckpoint_cb]

    print("============== Starting Training ==============")
    model.train(max_epoch,
                ds_train,
                callbacks=callback_list,
                dataset_sink_mode=True)
Exemplo n.º 22
0
def train_and_eval(config):
    """
    Train and evaluate the Wide&Deep model.
    """
    set_seed(1000)
    data_path = config.data_path
    batch_size = config.batch_size
    sparse = config.sparse
    epochs = config.epochs
    if config.dataset_type == "tfrecord":
        dataset_type = DataType.TFRECORD
    elif config.dataset_type == "mindrecord":
        dataset_type = DataType.MINDRECORD
    else:
        dataset_type = DataType.H5
    print("epochs is {}".format(epochs))
    ds_train = create_dataset(data_path,
                              train_mode=True,
                              epochs=1,
                              batch_size=batch_size,
                              rank_id=get_rank(),
                              rank_size=get_group_size(),
                              data_type=dataset_type)
    ds_eval = create_dataset(data_path,
                             train_mode=False,
                             epochs=1,
                             batch_size=batch_size,
                             rank_id=get_rank(),
                             rank_size=get_group_size(),
                             data_type=dataset_type)
    print("ds_train.size: {}".format(ds_train.get_dataset_size()))
    print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

    net_builder = ModelBuilder()

    train_net, eval_net = net_builder.get_net(config)
    train_net.set_train()
    auc_metric = AUCMetric()

    model = Model(train_net,
                  eval_network=eval_net,
                  metrics={"auc": auc_metric})

    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)

    callback = LossCallBack(config=config)
    ckptconfig = CheckpointConfig(
        save_checkpoint_steps=ds_train.get_dataset_size(),
        keep_checkpoint_max=5)
    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                 directory=config.ckpt_path + '/ckpt_' +
                                 str(get_rank()) + '/',
                                 config=ckptconfig)
    out = model.eval(ds_eval, dataset_sink_mode=(not sparse))
    print("=====" * 5 + "model.eval() initialized: {}".format(out))
    callback_list = [
        TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback
    ]
    if get_rank() == 0:
        callback_list.append(ckpoint_cb)
    model.train(epochs,
                ds_train,
                callbacks=callback_list,
                sink_size=ds_train.get_dataset_size(),
                dataset_sink_mode=(not sparse))
Exemplo n.º 23
0
def run_transformer_train():
    """
    Transformer training.
    """
    parser = argparse_init()
    args, _ = parser.parse_known_args()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=args.device_id)
    context.set_context(reserve_class_name_in_scope=False,
                        enable_auto_mixed_precision=False)

    if args.distribute == "true":
        device_num = args.device_num
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            mirror_mean=True,
            parameter_broadcast=True,
            device_num=device_num)
        D.init()
        rank_id = args.device_id % device_num
    else:
        device_num = 1
        rank_id = 0
    dataset, repeat_count = create_transformer_dataset(
        epoch_count=args.epoch_size,
        rank_size=device_num,
        rank_id=rank_id,
        do_shuffle=args.do_shuffle,
        enable_data_sink=args.enable_data_sink,
        dataset_path=args.data_path)

    netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True)

    if args.checkpoint_path:
        parameter_dict = load_checkpoint(args.checkpoint_path)
        load_param_into_net(netwithloss, parameter_dict)

    lr = Tensor(
        create_dynamic_lr(
            schedule="constant*rsqrt_hidden*linear_warmup*rsqrt_decay",
            training_steps=dataset.get_dataset_size() * args.epoch_size,
            learning_rate=cfg.lr_schedule.learning_rate,
            warmup_steps=cfg.lr_schedule.warmup_steps,
            hidden_size=transformer_net_cfg.hidden_size,
            start_decay_step=cfg.lr_schedule.start_decay_step,
            min_lr=cfg.lr_schedule.min_lr), mstype.float32)
    optimizer = Adam(netwithloss.trainable_params(), lr)

    callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack()]
    if args.enable_save_ckpt == "true":
        if device_num == 1 or (device_num > 1 and rank_id == 0):
            ckpt_config = CheckpointConfig(
                save_checkpoint_steps=args.save_checkpoint_steps,
                keep_checkpoint_max=args.save_checkpoint_num)
            ckpoint_cb = ModelCheckpoint(prefix='transformer',
                                         directory=args.save_checkpoint_path,
                                         config=ckpt_config)
            callbacks.append(ckpoint_cb)

    if args.enable_lossscale == "true":
        scale_manager = DynamicLossScaleManager(
            init_loss_scale=cfg.init_loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)
        update_cell = scale_manager.get_update_cell()
        netwithgrads = TransformerTrainOneStepWithLossScaleCell(
            netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = TransformerTrainOneStepCell(netwithloss,
                                                   optimizer=optimizer)

    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    model.train(repeat_count,
                dataset,
                callbacks=callbacks,
                dataset_sink_mode=(args.enable_data_sink == "true"))
Exemplo n.º 24
0
def train_and_eval(config):
    """
    test_train_eval
    """
    set_seed(1000)
    data_path = config.data_path
    batch_size = config.batch_size
    epochs = config.epochs
    if config.dataset_type == "tfrecord":
        dataset_type = DataType.TFRECORD
    elif config.dataset_type == "mindrecord":
        dataset_type = DataType.MINDRECORD
    else:
        dataset_type = DataType.H5
    parameter_server = bool(config.parameter_server)
    if cache_enable:
        config.full_batch = True
    print("epochs is {}".format(epochs))
    if config.full_batch:
        context.set_auto_parallel_context(full_batch=True)
        ds.config.set_seed(1)
        ds_train = create_dataset(data_path,
                                  train_mode=True,
                                  epochs=1,
                                  batch_size=batch_size * get_group_size(),
                                  data_type=dataset_type)
        ds_eval = create_dataset(data_path,
                                 train_mode=False,
                                 epochs=1,
                                 batch_size=batch_size * get_group_size(),
                                 data_type=dataset_type)
    else:
        ds_train = create_dataset(data_path,
                                  train_mode=True,
                                  epochs=1,
                                  batch_size=batch_size,
                                  rank_id=get_rank(),
                                  rank_size=get_group_size(),
                                  data_type=dataset_type)
        ds_eval = create_dataset(data_path,
                                 train_mode=False,
                                 epochs=1,
                                 batch_size=batch_size,
                                 rank_id=get_rank(),
                                 rank_size=get_group_size(),
                                 data_type=dataset_type)
    print("ds_train.size: {}".format(ds_train.get_dataset_size()))
    print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

    net_builder = ModelBuilder()

    train_net, eval_net = net_builder.get_net(config)
    train_net.set_train()
    auc_metric = AUCMetric()

    model = Model(train_net,
                  eval_network=eval_net,
                  metrics={"auc": auc_metric})

    if cache_enable:
        config.stra_ckpt = os.path.join(
            config.stra_ckpt + "-{}".format(get_rank()), "strategy.ckpt")
        context.set_auto_parallel_context(
            strategy_ckpt_save_file=config.stra_ckpt)

    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)

    callback = LossCallBack(config=config)
    if _is_role_worker():
        if cache_enable:
            ckptconfig = CheckpointConfig(
                save_checkpoint_steps=ds_train.get_dataset_size() * epochs,
                keep_checkpoint_max=1,
                integrated_save=False)
        else:
            ckptconfig = CheckpointConfig(
                save_checkpoint_steps=ds_train.get_dataset_size(),
                keep_checkpoint_max=5)
    else:
        ckptconfig = CheckpointConfig(save_checkpoint_steps=1,
                                      keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                 directory=config.ckpt_path + '/ckpt_' +
                                 str(get_rank()) + '/',
                                 config=ckptconfig)
    callback_list = [
        TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback
    ]
    if get_rank() == 0:
        callback_list.append(ckpoint_cb)
    model.train(epochs,
                ds_train,
                callbacks=callback_list,
                dataset_sink_mode=(parameter_server and cache_enable))
Exemplo n.º 25
0
                      keep_batchnorm_fp32=False,
                      loss_scale_manager=loss_scale_manager)
    elif device_target == "GPU":
        model = Model(network,
                      loss_fn=loss,
                      optimizer=opt,
                      metrics=metrics,
                      loss_scale_manager=loss_scale_manager)
    else:
        raise ValueError("Unsupported platform.")

    if device_num > 1:
        ckpt_save_dir = os.path.join(args.ckpt_path + "_" + str(get_rank()))
    else:
        ckpt_save_dir = args.ckpt_path

    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    config_ck = CheckpointConfig(
        save_checkpoint_steps=ds_train.get_dataset_size(),
        keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_alexnet",
                                 directory=ckpt_save_dir,
                                 config=config_ck)

    print("============== Starting Training ==============")
    model.train(cfg.epoch_size,
                ds_train,
                callbacks=[time_cb, ckpoint_cb,
                           LossMonitor()],
                dataset_sink_mode=args.dataset_sink_mode)
Exemplo n.º 26
0
def run_pretrain():
    """pre-train ernie"""
    parser = argparse_init()
    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    is_auto_enable_graph_kernel = _auto_enable_graph_kernel(
        args_opt.device_target, args_opt.enable_graph_kernel)
    _set_graph_kernel_context(args_opt.device_target,
                              args_opt.enable_graph_kernel,
                              is_auto_enable_graph_kernel)
    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init()
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init()
            device_num = D.get_group_size()
            rank = D.get_rank()
        ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(
            get_rank()) + '/'

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            device_num=device_num)
        _set_ernie_all_reduce_split()
    else:
        rank = 0
        device_num = 1

    _check_compute_type(args_opt, is_auto_enable_graph_kernel)

    if args_opt.accumulation_steps > 1:
        logger.info("accumulation steps: {}".format(
            args_opt.accumulation_steps))
        logger.info("global batch size: {}".format(
            cfg.batch_size * args_opt.accumulation_steps))
        if args_opt.enable_data_sink == "true":
            args_opt.data_sink_steps *= args_opt.accumulation_steps
            logger.info("data sink steps: {}".format(args_opt.data_sink_steps))
        if args_opt.enable_save_ckpt == "true":
            args_opt.save_checkpoint_steps *= args_opt.accumulation_steps
            logger.info("save checkpoint steps: {}".format(
                args_opt.save_checkpoint_steps))

    ds = create_ernie_dataset(device_num, rank, args_opt.do_shuffle,
                              args_opt.data_dir, args_opt.schema_dir)
    net_with_loss = ErnieNetworkWithLoss(ernie_net_cfg, True)

    new_repeat_count = (args_opt.epoch_size * ds.get_dataset_size()
                        // args_opt.data_sink_steps)
    if args_opt.train_steps > 0:
        train_steps = args_opt.train_steps * args_opt.accumulation_steps
        new_repeat_count = min(new_repeat_count,
                               train_steps // args_opt.data_sink_steps)
    else:
        args_opt.train_steps = (args_opt.epoch_size * ds.get_dataset_size()
                                // args_opt.accumulation_steps)
        logger.info("train steps: {}".format(args_opt.train_steps))

    optimizer = _get_optimizer(args_opt, net_with_loss)
    callback = [
        TimeMonitor(args_opt.data_sink_steps),
        LossCallBack(ds.get_dataset_size())
    ]
    if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(
            8, device_num) == 0:
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args_opt.save_checkpoint_steps,
            keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(
            prefix='checkpoint_ernie',
            directory=None if ckpt_save_dir == "" else ckpt_save_dir,
            config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(net_with_loss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)
        accumulation_steps = args_opt.accumulation_steps
        enable_global_norm = cfg.enable_global_norm
        if accumulation_steps <= 1:
            if cfg.optimizer == 'AdamWeightDecay' and args_opt.device_target == 'GPU':
                net_with_grads = ErnieTrainOneStepWithLossScaleCellForAdam(
                    net_with_loss,
                    optimizer=optimizer,
                    scale_update_cell=update_cell)
            else:
                net_with_grads = ErnieTrainOneStepWithLossScaleCell(
                    net_with_loss,
                    optimizer=optimizer,
                    scale_update_cell=update_cell)
        else:
            allreduce_post = args_opt.distribute == "false" or args_opt.allreduce_post_accumulation == "true"
            net_with_accumulation = (
                ErnieTrainAccumulationAllReducePostWithLossScaleCell
                if allreduce_post else
                ErnieTrainAccumulationAllReduceEachWithLossScaleCell)
            net_with_grads = net_with_accumulation(
                net_with_loss,
                optimizer=optimizer,
                scale_update_cell=update_cell,
                accumulation_steps=accumulation_steps,
                enable_global_norm=enable_global_norm)
    else:
        net_with_grads = ErnieTrainOneStepCell(net_with_loss,
                                               optimizer=optimizer,
                                               enable_clip_grad=True)

    model = Model(net_with_grads)
    model.train(new_repeat_count,
                ds,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
Exemplo n.º 27
0
    ds1 = ds1.batch(batch_size, drop_remainder=True)
    ds1 = ds1.repeat(repeat_size)

    return ds1


if __name__ == "__main__":
    # This configure can run both in pynative mode and graph mode
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=cfg.device_target)
    network = LeNet5()
    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False,
                                                sparse=True,
                                                reduction="mean")
    config_ck = CheckpointConfig(
        save_checkpoint_steps=cfg.save_checkpoint_steps,
        keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
                                 directory='./trained_ckpt_file/',
                                 config=config_ck)

    # get training dataset
    ds_train = generate_mnist_dataset(os.path.join(cfg.data_path, "train"),
                                      cfg.batch_size, cfg.epoch_size)

    if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0:
        raise ValueError(
            "batch_size should be evenly divisible by micro_batches")
    # Create a factory class of DP mechanisms; these mechanisms add noise to the gradients during training.
    # initial_noise_multiplier should be greater than 1.0; otherwise the privacy budget would be huge, which
    # means the privacy protection effect is weak. Mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise
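
The comment above describes MindArmour-style DP mechanisms. As a rough, framework-independent illustration of the 'Gaussian' option (a sketch only, not the MindArmour API; the real library's clipping and noise rules may differ), the idea is to clip each per-sample gradient and add Gaussian noise scaled by initial_noise_multiplier.

# Rough NumPy sketch of a Gaussian DP step; illustration only, not the MindArmour API.
import numpy as np

def dp_gaussian_step(per_sample_grads, norm_bound=1.0,
                     initial_noise_multiplier=1.5, seed=0):
    rng = np.random.default_rng(seed)
    clipped = []
    for g in per_sample_grads:
        norm = np.linalg.norm(g)
        # Clip each per-sample gradient to norm_bound.
        clipped.append(g * min(1.0, norm_bound / (norm + 1e-12)))
    mean_grad = np.mean(clipped, axis=0)
    # Larger initial_noise_multiplier -> more noise -> stronger privacy.
    sigma = initial_noise_multiplier * norm_bound / len(per_sample_grads)
    return mean_grad + rng.normal(0.0, sigma, size=mean_grad.shape)

fake_grads = [np.random.randn(10) for _ in range(8)]   # stand-in micro-batch gradients
print(dp_gaussian_step(fake_grads))
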
Exemplo n.º 28
0
def train_net(args_opt,
              cross_valid_ind=1,
              epochs=400,
              batch_size=16,
              lr=0.0001,
              cfg=None):
    rank = 0
    group_size = 1
    data_dir = args_opt.data_url
    run_distribute = args_opt.run_distribute
    if run_distribute:
        init()
        group_size = get_group_size()
        rank = get_rank()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                          device_num=group_size,
                                          gradients_mean=False)
    need_slice = False
    if cfg['model'] == 'unet_medical':
        net = UNetMedical(n_channels=cfg['num_channels'],
                          n_classes=cfg['num_classes'])
    elif cfg['model'] == 'unet_nested':
        net = NestedUNet(in_channel=cfg['num_channels'],
                         n_class=cfg['num_classes'],
                         use_deconv=cfg['use_deconv'],
                         use_bn=cfg['use_bn'],
                         use_ds=cfg['use_ds'])
        need_slice = cfg['use_ds']
    elif cfg['model'] == 'unet_simple':
        net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
    else:
        raise ValueError("Unsupported model: {}".format(cfg['model']))

    if cfg['resume']:
        param_dict = load_checkpoint(cfg['resume_ckpt'])
        if cfg['transfer_training']:
            filter_checkpoint_parameter_by_list(param_dict,
                                                cfg['filter_weight'])
        load_param_into_net(net, param_dict)

    if 'use_ds' in cfg and cfg['use_ds']:
        criterion = MultiCrossEntropyWithLogits()
    else:
        criterion = CrossEntropyWithLogits()
    if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
        repeat = cfg['repeat']
        dataset_sink_mode = True
        per_print_times = 0
        train_dataset = create_cell_nuclei_dataset(data_dir,
                                                   cfg['img_size'],
                                                   repeat,
                                                   batch_size,
                                                   is_train=True,
                                                   augment=True,
                                                   split=0.8,
                                                   rank=rank,
                                                   group_size=group_size)
        valid_dataset = create_cell_nuclei_dataset(
            data_dir,
            cfg['img_size'],
            1,
            1,
            is_train=False,
            eval_resize=cfg["eval_resize"],
            split=0.8,
            python_multiprocessing=False)
    else:
        repeat = cfg['repeat']
        dataset_sink_mode = False
        per_print_times = 1
        train_dataset, valid_dataset = create_dataset(
            data_dir, repeat, batch_size, True, cross_valid_ind,
            run_distribute, cfg["crop"], cfg['img_size'])
    train_data_size = train_dataset.get_dataset_size()
    print("dataset length is:", train_data_size)
    ckpt_config = CheckpointConfig(
        save_checkpoint_steps=train_data_size,
        keep_checkpoint_max=cfg['keep_checkpoint_max'])
    ckpoint_cb = ModelCheckpoint(prefix='ckpt_{}_adam'.format(cfg['model']),
                                 directory='./ckpt_{}/'.format(device_id),
                                 config=ckpt_config)

    optimizer = nn.Adam(params=net.trainable_params(),
                        learning_rate=lr,
                        weight_decay=cfg['weight_decay'],
                        loss_scale=cfg['loss_scale'])

    loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(
        cfg['FixedLossScaleManager'], False)

    model = Model(net,
                  loss_fn=criterion,
                  loss_scale_manager=loss_scale_manager,
                  optimizer=optimizer,
                  amp_level="O3")

    print("============== Starting Training ==============")
    callbacks = [
        StepLossTimeMonitor(batch_size=batch_size,
                            per_print_times=per_print_times), ckpoint_cb
    ]
    if args_opt.run_eval:
        eval_model = Model(UnetEval(net, need_slice=need_slice),
                           loss_fn=TempLoss(),
                           metrics={"dice_coeff": dice_coeff(cfg_unet, False)})
        eval_param_dict = {
            "model": eval_model,
            "dataset": valid_dataset,
            "metrics_name": args_opt.eval_metrics
        }
        eval_cb = EvalCallBack(apply_eval,
                               eval_param_dict,
                               interval=args_opt.eval_interval,
                               eval_start_epoch=args_opt.eval_start_epoch,
                               save_best_ckpt=True,
                               ckpt_directory='./ckpt_{}/'.format(device_id),
                               besk_ckpt_name="best.ckpt",
                               metrics_name=args_opt.eval_metrics)
        callbacks.append(eval_cb)
    model.train(int(epochs / repeat),
                train_dataset,
                callbacks=callbacks,
                dataset_sink_mode=dataset_sink_mode)
    print("============== End Training ==============")
Exemplo n.º 29
0
def train():
    args = parse_args()

    if args.device_target == "CPU":
        context.set_context(mode=context.GRAPH_MODE,
                            save_graphs=False,
                            device_target="CPU")
    else:
        context.set_context(mode=context.GRAPH_MODE,
                            enable_auto_mixed_precision=True,
                            save_graphs=False,
                            device_target="Ascend",
                            device_id=int(os.getenv('DEVICE_ID')))

    # init multicards training
    if args.is_distributed:
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()

        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                          gradients_mean=True,
                                          device_num=args.group_size)

    # dataset
    dataset = data_generator.SegDataset(image_mean=args.image_mean,
                                        image_std=args.image_std,
                                        data_file=args.data_file,
                                        batch_size=args.batch_size,
                                        crop_size=args.crop_size,
                                        max_scale=args.max_scale,
                                        min_scale=args.min_scale,
                                        ignore_label=args.ignore_label,
                                        num_classes=args.num_classes,
                                        num_readers=2,
                                        num_parallel_calls=4,
                                        shard_id=args.rank,
                                        shard_num=args.group_size)
    dataset = dataset.get_dataset(repeat=1)

    # network
    if args.model == 'deeplab_v3_s16':
        network = net_factory.nets_map[args.model]('train', args.num_classes,
                                                   16, args.freeze_bn)
    elif args.model == 'deeplab_v3_s8':
        network = net_factory.nets_map[args.model]('train', args.num_classes,
                                                   8, args.freeze_bn)
    else:
        raise NotImplementedError('model [{:s}] not recognized'.format(
            args.model))

    # loss
    loss_ = loss.SoftmaxCrossEntropyLoss(args.num_classes, args.ignore_label)
    loss_.add_flags_recursive(fp32=True)
    train_net = BuildTrainNetwork(network, loss_)

    # load pretrained model
    if args.ckpt_pre_trained:
        param_dict = load_checkpoint(args.ckpt_pre_trained)
        if args.filter_weight:
            for key in list(param_dict.keys()):
                if key in [
                        "network.aspp.conv2.weight", "network.aspp.conv2.bias"
                ]:
                    print('filter {}'.format(key))
                    del param_dict[key]
        load_param_into_net(train_net, param_dict)
        print('load_model {} success'.format(args.ckpt_pre_trained))

    # optimizer
    iters_per_epoch = dataset.get_dataset_size()
    total_train_steps = iters_per_epoch * args.train_epochs
    if args.lr_type == 'cos':
        lr_iter = learning_rates.cosine_lr(args.base_lr, total_train_steps,
                                           total_train_steps)
    elif args.lr_type == 'poly':
        lr_iter = learning_rates.poly_lr(args.base_lr,
                                         total_train_steps,
                                         total_train_steps,
                                         end_lr=0.0,
                                         power=0.9)
    elif args.lr_type == 'exp':
        lr_iter = learning_rates.exponential_lr(args.base_lr,
                                                args.lr_decay_step,
                                                args.lr_decay_rate,
                                                total_train_steps,
                                                staircase=True)
    else:
        raise ValueError('unknown learning rate type')
    opt = nn.Momentum(params=train_net.trainable_params(),
                      learning_rate=lr_iter,
                      momentum=0.9,
                      weight_decay=0.0001,
                      loss_scale=args.loss_scale)

    # loss scale
    manager_loss_scale = FixedLossScaleManager(args.loss_scale,
                                               drop_overflow_update=False)
    amp_level = "O0" if args.device_target == "CPU" else "O3"
    model = Model(train_net,
                  optimizer=opt,
                  amp_level=amp_level,
                  loss_scale_manager=manager_loss_scale)

    # callback for saving ckpts
    time_cb = TimeMonitor(data_size=iters_per_epoch)
    loss_cb = LossMonitor()
    cbs = [time_cb, loss_cb]

    if args.rank == 0:
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args.save_steps,
            keep_checkpoint_max=args.keep_checkpoint_max)
        ckpoint_cb = ModelCheckpoint(prefix=args.model,
                                     directory=args.train_dir,
                                     config=config_ck)
        cbs.append(ckpoint_cb)

    model.train(args.train_epochs,
                dataset,
                callbacks=cbs,
                dataset_sink_mode=(args.device_target != "CPU"))
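
learning_rates.poly_lr is not shown in this snippet; assuming it follows the usual polynomial-decay schedule (a common definition, not necessarily the model-zoo implementation), it can be sketched like this, with base_lr chosen arbitrarily for illustration.

# Assumed polynomial-decay schedule with power=0.9, as passed to poly_lr above.
# The real learning_rates.poly_lr implementation may differ.
def poly_decay_lr(base_lr, decay_steps, total_steps, end_lr=0.0, power=0.9):
    lrs = []
    for step in range(total_steps):
        frac = min(step, decay_steps) / decay_steps
        lrs.append((base_lr - end_lr) * (1.0 - frac) ** power + end_lr)
    return lrs

schedule = poly_decay_lr(base_lr=0.01, decay_steps=1000, total_steps=1000)
print(schedule[0], schedule[500], schedule[-1])   # 0.01, ~0.0054, ~0.00002
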
Exemplo n.º 30
0
def train(cloud_args=None):
    """training process"""
    args = parse_args(cloud_args)
    context.set_context(mode=context.GRAPH_MODE,
                        enable_auto_mixed_precision=True,
                        device_target=args.platform,
                        save_graphs=False)
    if os.getenv('DEVICE_ID', "not_set").isdigit():
        context.set_context(device_id=int(os.getenv('DEVICE_ID')))

    # init distributed
    if args.is_distributed:
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                          device_num=args.group_size,
                                          gradients_mean=True)
    # dataloader
    de_dataset = classification_dataset(args.data_dir,
                                        args.image_size,
                                        args.per_batch_size,
                                        1,
                                        args.rank,
                                        args.group_size,
                                        num_parallel_workers=8)
    de_dataset.map_model = 4  # !!!important
    args.steps_per_epoch = de_dataset.get_dataset_size()

    args.logger.save_args(args)

    # network
    args.logger.important_info('start create network')
    # get network and init
    network = get_network(args.backbone,
                          args.num_classes,
                          platform=args.platform)
    if network is None:
        raise NotImplementedError('backbone {} is not implemented'.format(args.backbone))

    load_pretrain_model(args.pretrained, network, args)

    # lr scheduler
    lr = get_lr(args)

    # optimizer
    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)

    # loss
    if not args.label_smooth:
        args.label_smooth_factor = 0.0
    loss = CrossEntropy(smooth_factor=args.label_smooth_factor,
                        num_classes=args.num_classes)

    if args.is_dynamic_loss_scale == 1:
        loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536,
                                                     scale_factor=2,
                                                     scale_window=2000)
    else:
        loss_scale_manager = FixedLossScaleManager(args.loss_scale,
                                                   drop_overflow_update=False)

    if args.platform == "Ascend":
        model = Model(network,
                      loss_fn=loss,
                      optimizer=opt,
                      loss_scale_manager=loss_scale_manager,
                      metrics={'acc'},
                      amp_level="O3")
    else:
        model = Model(network,
                      loss_fn=loss,
                      optimizer=opt,
                      loss_scale_manager=loss_scale_manager,
                      metrics={'acc'},
                      amp_level="O2")

    # checkpoint save
    progress_cb = ProgressMonitor(args)
    callbacks = [
        progress_cb,
    ]
    if args.rank_save_ckpt_flag:
        ckpt_config = CheckpointConfig(
            save_checkpoint_steps=args.ckpt_interval * args.steps_per_epoch,
            keep_checkpoint_max=args.ckpt_save_max)
        save_ckpt_path = os.path.join(args.outputs_dir,
                                      'ckpt_' + str(args.rank) + '/')
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=save_ckpt_path,
                                  prefix='{}'.format(args.rank))
        callbacks.append(ckpt_cb)

    model.train(args.max_epoch,
                de_dataset,
                callbacks=callbacks,
                dataset_sink_mode=True)
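
For reference, DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000) is commonly implemented with the update rule sketched below (a simplification, not MindSpore's exact code): shrink the scale when gradients overflow, and grow it again after scale_window consecutive clean steps.

# Simplified dynamic loss-scale update rule; not MindSpore's exact implementation.
class SimpleDynamicLossScale:
    def __init__(self, init_loss_scale=65536, scale_factor=2, scale_window=2000):
        self.scale = init_loss_scale
        self.factor = scale_factor
        self.window = scale_window
        self.good_steps = 0

    def update(self, overflow):
        if overflow:
            # Overflow detected: shrink the scale and restart the counter.
            self.scale = max(1, self.scale // self.factor)
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps >= self.window:
                # A long run without overflow: try a larger scale again.
                self.scale *= self.factor
                self.good_steps = 0
        return self.scale

manager = SimpleDynamicLossScale()
print(manager.update(overflow=True))    # 32768
print(manager.update(overflow=False))   # 32768 (counter at 1 of 2000)
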