예제 #1
0
def train():
    """Train ETSNet with a Dice loss, optionally in data-parallel mode.

    All settings come from the module-level ``args`` (command-line options)
    and ``config`` objects; the function takes no parameters and returns
    nothing.

    When ``args.run_distribute`` is set, the auto-parallel context is put in
    DATA_PARALLEL mode with gradient averaging across ``args.device_num``
    devices, the communication backend is initialised with ``init()``, and
    the global rank is taken from ``get_rank()``.  The same rank is used both
    to pick this process's dataset shard and to name its checkpoint
    directory (``./ckpt_<rank>``).
    """
    rank_id = 0
    if args.run_distribute:
        context.set_auto_parallel_context(
            device_num=args.device_num,
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            parameter_broadcast=True)
        init()
        rank_id = get_rank()

    # dataset/network/criterion/optim
    # Shard by the global rank, not args.device_id: device ids are local to a
    # host and repeat across hosts, so multi-host runs would otherwise read
    # duplicate shards.  In a single-device run rank_id is 0, matching the
    # one-shard case.  NOTE(review): assumes the first argument of
    # train_dataset_creator is the shard index — confirm against its definition.
    ds = train_dataset_creator(rank_id, args.device_num)
    step_size = ds.get_dataset_size()
    print('Create dataset done!')

    # The network is built with INFERENCE disabled (the flag is read by ETSNet;
    # presumably it switches off inference-only behaviour — TODO confirm).
    config.INFERENCE = False
    net = ETSNet(config)
    net = net.set_train()
    # The pretrained checkpoint is mandatory here: load_checkpoint is called
    # unconditionally with args.pre_trained.
    param_dict = load_checkpoint(args.pre_trained)
    load_param_into_net(net, param_dict)
    print('Load Pretrained parameters done!')

    criterion = DiceLoss(batch_size=config.TRAIN_BATCH_SIZE)

    # Learning-rate schedule over the whole run (config.TRAIN_TOTAL_ITER steps).
    lrs = lr_generator(start_lr=1e-3,
                       lr_scale=0.1,
                       total_iters=config.TRAIN_TOTAL_ITER)
    opt = nn.SGD(params=net.trainable_params(),
                 learning_rate=lrs,
                 momentum=0.99,
                 weight_decay=5e-4)

    # Wrap the model: attach the loss, then build the one-step training cell.
    net = WithLossCell(net, criterion)
    if args.run_distribute:
        # reduce_flag/mean/degree presumably enable averaged gradient
        # reduction across args.device_num devices — TODO confirm in
        # TrainOneStepCell.
        net = TrainOneStepCell(net,
                               opt,
                               reduce_flag=True,
                               mean=True,
                               degree=args.device_num)
    else:
        net = TrainOneStepCell(net, opt)

    time_cb = TimeMonitor(data_size=step_size)
    loss_cb = LossCallBack(per_print_times=10)
    # Checkpoint every 1875 steps, keeping at most 2 files.
    # NOTE(review): 1875 is a hard-coded step interval, independent of
    # step_size; consider deriving it from the dataset size.
    ckpoint_cf = CheckpointConfig(save_checkpoint_steps=1875,
                                  keep_checkpoint_max=2)
    # Per-rank directory so concurrent ranks never write the same files.
    ckpoint_cb = ModelCheckpoint(prefix="ETSNet",
                                 config=ckpoint_cf,
                                 directory="./ckpt_{}".format(rank_id))

    model = Model(net)
    model.train(config.TRAIN_REPEAT_NUM,
                ds,
                dataset_sink_mode=True,
                callbacks=[time_cb, loss_cb, ckpoint_cb])
예제 #2
0
        print("load backbone vgg16 ckpt {}".format(args_opt.pre_trained))
        param_dict = load_checkpoint(load_path)
        for item in list(param_dict.keys()):
            if not item.startswith('vgg16_feature_extractor'):
                param_dict.pop(item)
        load_param_into_net(net, param_dict)
    else:
        if load_path != "":
            print("load pretrain ckpt {}".format(args_opt.pre_trained))
            param_dict = load_checkpoint(load_path)
            load_param_into_net(net, param_dict)
    loss = LossNet()
    lr = Tensor(dynamic_lr(training_cfg, dataset_size), mstype.float32)
    opt = Momentum(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,\
        weight_decay=config.weight_decay, loss_scale=config.loss_scale)
    net_with_loss = WithLossCell(net, loss)
    if args_opt.run_distribute:
        net = TrainOneStepCell(net_with_loss,
                               opt,
                               sens=config.loss_scale,
                               reduce_flag=True,
                               mean=True,
                               degree=device_num)
    else:
        net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale)

    time_cb = TimeMonitor(data_size=dataset_size)
    loss_cb = LossCallBack(rank_id=rank)
    cb = [time_cb, loss_cb]
    if config.save_checkpoint:
        ckptconfig = CheckpointConfig(
예제 #3
0
                          training_mode=True,
                          use_MLP=config.use_MLP)
    elif config.net_work == 'resnet50':
        resnet = resnet50(low_dims=config.low_dims,
                          training_mode=True,
                          use_MLP=config.use_MLP)
    elif config.net_work == 'resnet101':
        resnet = resnet101(low_dims=config.low_dims,
                           training_mode=True,
                           use_MLP=config.use_MLP)
    else:
        raise ("net work config error!!!")

    loss = LossNet(temp=config.sigma)

    net_with_loss = WithLossCell(resnet, loss)

    if config.lr_schedule == "cosine_lr":
        lr = Tensor(
            cosine_lr(init_lr=config.base_lr,
                      total_epochs=config.epochs,
                      steps_per_epoch=train_dataset_batch_num,
                      mode=config.lr_mode,
                      warmup_epoch=config.warmup_epoch), mstype.float32)
    else:
        lr = Tensor(
            step_cosine_lr(init_lr=config.base_lr,
                           total_epochs=config.epochs,
                           epoch_stage=config.epoch_stage,
                           steps_per_epoch=train_dataset_batch_num,
                           mode=config.lr_mode,
예제 #4
0
파일: train.py 프로젝트: yrpang/mindspore
def main():
    """Entry point: train the simple-pose keypoint network.

    The steps run in this order:
      1. Parse command-line options and optionally override the batch size.
      2. Configure an Ascend graph-mode context and, when ``run_distribute`` is
         requested, data-parallel execution with gradient averaging.
      3. Build the keypoint dataset, the pose network and the MSE joint loss.
      4. Build the learning-rate schedule and an Adam optimiser.
      5. Train with timing and loss monitors.  A checkpoint callback is added
         only when a checkpoint path is given and this process is the writer.

    ``device_id`` and ``config`` are not defined in this function; they are
    presumably module-level globals.
    """
    # load parse and config
    print("loading parse...")
    opts = parse_args()
    if opts.batch_size:
        config.TRAIN.BATCH_SIZE = opts.batch_size
    print(f'batch size :{config.TRAIN.BATCH_SIZE}')

    # execution context
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        save_graphs=False,
                        device_id=device_id)

    if not opts.run_distribute:
        rank, device_num = 0, 1
    else:
        init()
        rank, device_num = get_rank(), get_group_size()
        context.set_auto_parallel_context(
            device_num=device_num,
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True)

    # Only one process writes checkpoints: rank 0, or the sole process of a
    # single-device run.
    is_writer = rank == 0 or device_num == 1

    # dataset, network and loss
    dataset, _ = keypoint_dataset(config,
                                  rank=rank,
                                  group_size=device_num,
                                  train_mode=True,
                                  num_parallel_workers=8)
    pose_net = get_pose_net(config, True, ckpt_path=config.MODEL.PRETRAINED)
    train_net = WithLossCell(pose_net, JointsMSELoss(use_target_weight=True))

    # learning-rate schedule and optimiser
    steps_per_epoch = dataset.get_dataset_size()
    lr_schedule = get_lr(config.TRAIN.BEGIN_EPOCH,
                         config.TRAIN.END_EPOCH,
                         steps_per_epoch,
                         lr_init=config.TRAIN.LR,
                         factor=config.TRAIN.LR_FACTOR,
                         epoch_number_to_drop=config.TRAIN.LR_STEP)
    lr_tensor = Tensor(lr_schedule)
    optimizer = Adam(pose_net.trainable_params(), learning_rate=lr_tensor)

    # callbacks
    callbacks = [TimeMonitor(data_size=steps_per_epoch), LossMonitor()]
    if opts.ckpt_path and is_writer:
        # Save every steps_per_epoch steps, keeping at most 20 checkpoint files.
        ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
                                       keep_checkpoint_max=20)
        callbacks.append(ModelCheckpoint(prefix="simplepose",
                                         directory=opts.ckpt_path,
                                         config=ckpt_config))

    # train
    model = Model(train_net, loss_fn=None, optimizer=optimizer, amp_level="O2")
    epoch_size = config.TRAIN.END_EPOCH - config.TRAIN.BEGIN_EPOCH
    print('start training, epoch size = %d' % epoch_size)
    model.train(epoch_size, dataset, callbacks=callbacks)