Ejemplo n.º 1
0
def train():
    """Train ETSNet with MindSpore, optionally in data-parallel mode.

    Configuration comes from the module-level ``args`` and ``config``
    objects. When ``args.run_distribute`` is set, this:

    * configures data-parallel training over ``args.device_num`` devices with
      gradient averaging and parameter broadcast;
    * initializes the communication backend;
    * takes this process's rank from ``get_rank()``.

    The network is built from ``config`` and initialized from the
    checkpoint at ``args.pre_trained``. It is wrapped with a Dice loss and an
    SGD optimizer into a ``TrainOneStepCell``, then trained through
    ``Model.train``. ``config.TRAIN_REPEAT_NUM`` is the first argument,
    dataset-sink mode is on, and time, loss and checkpoint callbacks are
    attached. Checkpoints are written to ``./ckpt_<rank_id>`` with prefix
    ``ETSNet``.
    """
    # Rank of this process in the communication group; stays 0 when not
    # running distributed.
    rank_id = 0
    if args.run_distribute:
        context.set_auto_parallel_context(
            device_num=args.device_num,
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            parameter_broadcast=True)
        init()
        rank_id = get_rank()

    # dataset/network/criterion/optim
    # Shard the data by rank_id, the same id that names this worker's
    # checkpoint directory below, rather than args.device_id. The device id is
    # presumably a per-host device index, so it can repeat across hosts (two
    # workers reading the same shard). It can also be non-zero in a
    # single-device run, which would offset the shard.
    # NOTE(review): assumes train_dataset_creator's first argument is the shard
    # index -- confirm against its definition.
    ds = train_dataset_creator(rank_id, args.device_num)
    step_size = ds.get_dataset_size()
    print('Create dataset done!')

    config.INFERENCE = False
    net = ETSNet(config)
    net = net.set_train()
    # Loaded unconditionally: args.pre_trained must name a loadable checkpoint.
    param_dict = load_checkpoint(args.pre_trained)
    load_param_into_net(net, param_dict)
    print('Load Pretrained parameters done!')

    criterion = DiceLoss(batch_size=config.TRAIN_BATCH_SIZE)

    lrs = lr_generator(start_lr=1e-3,
                       lr_scale=0.1,
                       total_iters=config.TRAIN_TOTAL_ITER)
    opt = nn.SGD(params=net.trainable_params(),
                 learning_rate=lrs,
                 momentum=0.99,
                 weight_decay=5e-4)

    # Wrap the model: attach the loss, then a single-step training cell. In
    # distributed mode the cell also gets reduce_flag/mean/degree, presumably
    # so it reduces and averages gradients across args.device_num devices.
    net = WithLossCell(net, criterion)
    if args.run_distribute:
        net = TrainOneStepCell(net,
                               opt,
                               reduce_flag=True,
                               mean=True,
                               degree=args.device_num)
    else:
        net = TrainOneStepCell(net, opt)

    time_cb = TimeMonitor(data_size=step_size)
    loss_cb = LossCallBack(per_print_times=10)
    # Checkpoint every 1875 steps, keeping at most the 2 most recent files.
    # NOTE(review): the interval is fixed and independent of step_size;
    # consider deriving it from the dataset size.
    ckpoint_cf = CheckpointConfig(save_checkpoint_steps=1875,
                                  keep_checkpoint_max=2)
    ckpoint_cb = ModelCheckpoint(prefix="ETSNet",
                                 config=ckpoint_cf,
                                 directory="./ckpt_{}".format(rank_id))

    model = Model(net)
    model.train(config.TRAIN_REPEAT_NUM,
                ds,
                dataset_sink_mode=True,
                callbacks=[time_cb, loss_cb, ckpoint_cb])
Ejemplo n.º 2
0
                param_dict.pop(item)
        load_param_into_net(net, param_dict)
    else:
        if load_path != "":
            print("load pretrain ckpt {}".format(args_opt.pre_trained))
            param_dict = load_checkpoint(load_path)
            load_param_into_net(net, param_dict)
    loss = LossNet()
    lr = Tensor(dynamic_lr(training_cfg, dataset_size), mstype.float32)
    opt = Momentum(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,\
        weight_decay=config.weight_decay, loss_scale=config.loss_scale)
    net_with_loss = WithLossCell(net, loss)
    if args_opt.run_distribute:
        net = TrainOneStepCell(net_with_loss,
                               opt,
                               sens=config.loss_scale,
                               reduce_flag=True,
                               mean=True,
                               degree=device_num)
    else:
        net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale)

    time_cb = TimeMonitor(data_size=dataset_size)
    loss_cb = LossCallBack(rank_id=rank)
    cb = [time_cb, loss_cb]
    if config.save_checkpoint:
        ckptconfig = CheckpointConfig(
            save_checkpoint_steps=config.save_checkpoint_epochs * dataset_size,
            keep_checkpoint_max=config.keep_checkpoint_max)
        save_checkpoint_path = os.path.join(config.save_checkpoint_path,
                                            "ckpt_" + str(rank) + "/")
        ckpoint_cb = ModelCheckpoint(prefix='ctpn',
Ejemplo n.º 3
0
                           total_epochs=config.epochs,
                           epoch_stage=config.epoch_stage,
                           steps_per_epoch=train_dataset_batch_num,
                           mode=config.lr_mode,
                           warmup_epoch=config.warmup_epoch), mstype.float32)

    opt = SGD(params=net_with_loss.trainable_params(),
              learning_rate=lr,
              momentum=config.momentum,
              weight_decay=config.weight_decay,
              loss_scale=config.loss_scale)

    if device_num > 1:
        net = TrainOneStepCell(net_with_loss,
                               opt,
                               reduce_flag=True,
                               mean=True,
                               degree=device_num)
    else:
        net = TrainOneStepCell(net_with_loss, opt)

    loss_cb = LossCallBack(data_size=train_dataset_batch_num, logger=logger)

    cb = [loss_cb]

    if config.save_checkpoint:
        ckptconfig = CheckpointConfig(
            save_checkpoint_steps=train_dataset_batch_num,
            keep_checkpoint_max=config.keep_checkpoint_max)
        ckpoint_cb = ModelCheckpoint(prefix='AVA',
                                     directory=save_checkpoint_path,