Example #1
def train():
    rank_id = 0
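    # in distributed runs, set the data-parallel context, init the communication backend, and query this process's rank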
    if args.run_distribute:
        context.set_auto_parallel_context(
            device_num=args.device_num,
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            parameter_broadcast=True)
        init()
        rank_id = get_rank()

    # dataset/network/criterion/optim
    ds = train_dataset_creator(args.device_id, args.device_num)
    step_size = ds.get_dataset_size()
    print('Create dataset done!')

    config.INFERENCE = False
    net = ETSNet(config)
    net.set_train()
    param_dict = load_checkpoint(args.pre_trained)
    load_param_into_net(net, param_dict)
    print('Load Pretrained parameters done!')

    criterion = DiceLoss(batch_size=config.TRAIN_BATCH_SIZE)

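    # per-iteration learning-rate schedule and SGD optimizer with momentum and weight decay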
    lrs = lr_generator(start_lr=1e-3,
                       lr_scale=0.1,
                       total_iters=config.TRAIN_TOTAL_ITER)
    opt = nn.SGD(params=net.trainable_params(),
                 learning_rate=lrs,
                 momentum=0.99,
                 weight_decay=5e-4)

    # wrap the network with the loss and the training step
    net = WithLossCell(net, criterion)
    if args.run_distribute:
        net = TrainOneStepCell(net,
                               opt,
                               reduce_flag=True,
                               mean=True,
                               degree=args.device_num)
    else:
        net = TrainOneStepCell(net, opt)

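    # callbacks: step timing, loss logging every 10 steps, and periodic checkpointing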
    time_cb = TimeMonitor(data_size=step_size)
    loss_cb = LossCallBack(per_print_times=10)
    # checkpoint config: save every 1875 steps, keep at most 2 checkpoints under ./ckpt_<rank_id>
    ckpoint_cf = CheckpointConfig(save_checkpoint_steps=1875,
                                  keep_checkpoint_max=2)
    ckpoint_cb = ModelCheckpoint(prefix="ETSNet",
                                 config=ckpoint_cf,
                                 directory="./ckpt_{}".format(rank_id))

    model = Model(net)
    model.train(config.TRAIN_REPEAT_NUM,
                ds,
                dataset_sink_mode=True,
                callbacks=[time_cb, loss_cb, ckpoint_cb])
Example #2
    lr = Tensor(dynamic_lr(training_cfg, dataset_size), mstype.float32)
    opt = Momentum(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,
                   weight_decay=config.weight_decay, loss_scale=config.loss_scale)
    net_with_loss = WithLossCell(net, loss)
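    # when run_distribute is set, this project-specific TrainOneStepCell also averages gradients across device_num devices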
    if args_opt.run_distribute:
        net = TrainOneStepCell(net_with_loss,
                               opt,
                               sens=config.loss_scale,
                               reduce_flag=True,
                               mean=True,
                               degree=device_num)
    else:
        net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale)

    time_cb = TimeMonitor(data_size=dataset_size)
    loss_cb = LossCallBack(rank_id=rank)
    cb = [time_cb, loss_cb]
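    # optionally save checkpoints every save_checkpoint_epochs epochs into a per-rank directory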
    if config.save_checkpoint:
        ckptconfig = CheckpointConfig(
            save_checkpoint_steps=config.save_checkpoint_epochs * dataset_size,
            keep_checkpoint_max=config.keep_checkpoint_max)
        save_checkpoint_path = os.path.join(config.save_checkpoint_path,
                                            "ckpt_" + str(rank) + "/")
        ckpoint_cb = ModelCheckpoint(prefix='ctpn',
                                     directory=save_checkpoint_path,
                                     config=ckptconfig)
        cb += [ckpoint_cb]

    model = Model(net)
    model.train(training_cfg.total_epoch,
                dataset,
                callbacks=cb)
Example #3
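            # with no pretraining epochs configured, keep only the backbone and rcnn_mask weights from the checkpoint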
            if config.pretrain_epoch_size == 0:
                for item in list(param_dict.keys()):
                    if not (item.startswith('backbone') or item.startswith('rcnn_mask')):
                        param_dict.pop(item)
            load_param_into_net(net, param_dict)

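        # loss, a dynamic learning-rate schedule offset by any pretraining steps, and SGD with loss scaling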
        loss = LossNet()
        lr = Tensor(dynamic_lr(config, rank_size=device_num, start_steps=config.pretrain_epoch_size * dataset_size),
                    mstype.float32)
        opt = SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,
                  weight_decay=config.weight_decay, loss_scale=config.loss_scale)

        net_with_loss = WithLossCell(net, loss)
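        # note: this TrainOneStepCell is the example's own cell (it also takes the backbone network),
        # not the stock mindspore.nn.TrainOneStepCell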
        if args_opt.run_distribute:
            net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale, reduce_flag=True,
                                   mean=True, degree=device_num)
        else:
            net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale)

        time_cb = TimeMonitor(data_size=dataset_size)
        loss_cb = LossCallBack()
        cb = [time_cb, loss_cb]
        if config.save_checkpoint:
            ckptconfig = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * dataset_size,
                                          keep_checkpoint_max=config.keep_checkpoint_max)
            ckpoint_cb = ModelCheckpoint(prefix='mask_rcnn', directory=config.save_checkpoint_path, config=ckptconfig)
            cb += [ckpoint_cb]

        model = Model(net)
        model.train(config.epoch_size, dataset, callbacks=cb)
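
All three examples follow the same MindSpore pattern: wrap the network and its loss in WithLossCell, wrap that together with an optimizer in a TrainOneStepCell, and hand the result to Model.train with timing, loss-logging and checkpoint callbacks. Below is a minimal, self-contained sketch of that pattern using the stock mindspore.nn cells (the examples above use project-specific TrainOneStepCell variants that also handle distributed gradient reduction); backbone, loss_fn and dataset are hypothetical placeholders supplied by the caller.

import mindspore.nn as nn
from mindspore import Model
from mindspore.train.callback import TimeMonitor, LossMonitor


def build_and_train(backbone, loss_fn, dataset, epochs, lr=1e-3):
    # wrap the backbone and the loss, then add the optimizer step
    net_with_loss = nn.WithLossCell(backbone, loss_fn)
    opt = nn.SGD(backbone.trainable_params(), learning_rate=lr, momentum=0.9)
    train_net = nn.TrainOneStepCell(net_with_loss, opt)
    # Model drives the training loop; callbacks report step time and loss
    model = Model(train_net)
    model.train(epochs, dataset,
                callbacks=[TimeMonitor(data_size=dataset.get_dataset_size()),
                           LossMonitor()],
                dataset_sink_mode=True)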