Ejemplo n.º 1
0
def train():
    """Train a DeepLabV3 segmentation model configured by ``parse_args()``.

    Steps:
      1. Optional data-parallel multi-card setup.
      2. Sharded dataset construction.
      3. DeepLabV3 network (output stride 16 or 8) wrapped with a softmax
         cross-entropy loss, plus optional checkpoint warm start.
      4. Momentum optimizer driven by a per-step learning-rate schedule.
      5. ``Model.train`` with time/loss monitors; rank 0 also saves
         checkpoints.

    Raises:
        NotImplementedError: ``args.model`` is not ``'deeplab_v3_s16'`` or
            ``'deeplab_v3_s8'``.
        ValueError: ``args.lr_type`` is not ``'cos'``, ``'poly'`` or ``'exp'``.

    Both checks run right after argument parsing, before any distributed
    init or dataset/network construction, so a bad option fails immediately.
    """
    args = parse_args()

    # Output stride of each supported DeepLabV3 variant.
    output_strides = {'deeplab_v3_s16': 16, 'deeplab_v3_s8': 8}

    # Per-step learning-rate schedules, keyed by args.lr_type. Each builder
    # takes the total number of training steps. That count is only known once
    # the dataset exists, so the lambdas defer evaluation until then.
    lr_builders = {
        'cos': lambda steps: learning_rates.cosine_lr(args.base_lr, steps,
                                                      steps),
        'poly': lambda steps: learning_rates.poly_lr(args.base_lr,
                                                     steps,
                                                     steps,
                                                     end_lr=0.0,
                                                     power=0.9),
        'exp': lambda steps: learning_rates.exponential_lr(
            args.base_lr,
            args.lr_decay_step,
            args.lr_decay_rate,
            steps,
            staircase=True),
    }

    # Fail fast on unsupported options. Previously these were only detected
    # after distributed init, dataset build and checkpoint load.
    if args.model not in output_strides:
        raise NotImplementedError('model [{:s}] not recognized'.format(
            args.model))
    if args.lr_type not in lr_builders:
        raise ValueError('unknown learning rate type')

    # init multicards training
    if args.is_distributed:
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()

        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                          gradients_mean=True,
                                          device_num=args.group_size)
    # NOTE(review): on the single-card path args.rank / args.group_size are
    # presumably defaulted by parse_args() (e.g. 0 and 1) — confirm.

    # dataset, sharded by (rank, group_size)
    dataset = data_generator.SegDataset(image_mean=args.image_mean,
                                        image_std=args.image_std,
                                        data_file=args.data_file,
                                        batch_size=args.batch_size,
                                        crop_size=args.crop_size,
                                        max_scale=args.max_scale,
                                        min_scale=args.min_scale,
                                        ignore_label=args.ignore_label,
                                        num_classes=args.num_classes,
                                        num_readers=2,
                                        num_parallel_calls=4,
                                        shard_id=args.rank,
                                        shard_num=args.group_size)
    dataset = dataset.get_dataset(repeat=1)

    # network: both variants use the same factory call; only the stride differs
    network = net_factory.nets_map[args.model]('train', args.num_classes,
                                               output_strides[args.model],
                                               args.freeze_bn)

    # loss. The fp32 flag presumably keeps the loss in float32 under
    # amp_level="O3" — confirm.
    loss_ = loss.SoftmaxCrossEntropyLoss(args.num_classes, args.ignore_label)
    loss_.add_flags_recursive(fp32=True)
    train_net = BuildTrainNetwork(network, loss_)

    # load pretrained model
    if args.ckpt_pre_trained:
        param_dict = load_checkpoint(args.ckpt_pre_trained)
        load_param_into_net(train_net, param_dict)

    # optimizer
    iters_per_epoch = dataset.get_dataset_size()
    total_train_steps = iters_per_epoch * args.train_epochs
    lr_iter = lr_builders[args.lr_type](total_train_steps)
    opt = nn.Momentum(params=train_net.trainable_params(),
                      learning_rate=lr_iter,
                      momentum=0.9,
                      weight_decay=0.0001,
                      loss_scale=args.loss_scale)

    # loss scale (fixed; overflowed steps are not dropped)
    manager_loss_scale = FixedLossScaleManager(args.loss_scale,
                                               drop_overflow_update=False)
    model = Model(train_net,
                  optimizer=opt,
                  amp_level="O3",
                  loss_scale_manager=manager_loss_scale)

    # callbacks: every rank monitors time/loss; only rank 0 saves checkpoints
    time_cb = TimeMonitor(data_size=iters_per_epoch)
    loss_cb = LossMonitor()
    cbs = [time_cb, loss_cb]

    if args.rank == 0:
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args.save_steps,
            keep_checkpoint_max=args.keep_checkpoint_max)
        ckpoint_cb = ModelCheckpoint(prefix=args.model,
                                     directory=args.train_dir,
                                     config=config_ck)
        cbs.append(ckpoint_cb)

    model.train(args.train_epochs, dataset, callbacks=cbs)
Ejemplo n.º 2
0
def _vgg16_to_fcn8s_params(param_vgg):
    """Remap VGG16 backbone checkpoint entries onto FCN8s parameter names.

    For each of the five backbone stages (two conv sub-layers in stages 1-2,
    three in stages 3-5) this maps:

      ``vgg16_feature_extractor.conv{L}_{s+1}.0.weight`` -> ``conv{L}.{3*s}.weight``
      ``vgg16_feature_extractor.conv{L}_{s+1}.1.gamma``  -> ``conv{L}.{3*s+1}.gamma``
      ``vgg16_feature_extractor.conv{L}_{s+1}.1.beta``   -> ``conv{L}.{3*s+1}.beta``

    The FCN8s index stride of 3 presumably reflects a (conv, BatchNorm,
    activation) triple per sub-layer — TODO confirm against FCN8s.

    Args:
        param_vgg: mapping of VGG16 parameter name -> parameter, as returned
            by ``load_checkpoint``.

    Returns:
        dict: FCN8s parameter name -> the corresponding VGG16 parameter
        object. Values are shared, not copied.

    Raises:
        KeyError: if an expected VGG16 key is missing from ``param_vgg``.
    """
    param_dict = {}
    for layer_id in range(1, 6):
        sub_layer_num = 2 if layer_id < 3 else 3
        for sub_layer_id in range(sub_layer_num):
            # conv param
            y_weight = 'conv{}.{}.weight'.format(layer_id,
                                                 3 * sub_layer_id)
            x_weight = 'vgg16_feature_extractor.conv{}_{}.0.weight'.format(
                layer_id, sub_layer_id + 1)
            param_dict[y_weight] = param_vgg[x_weight]
            # BatchNorm param
            y_gamma = 'conv{}.{}.gamma'.format(layer_id,
                                               3 * sub_layer_id + 1)
            y_beta = 'conv{}.{}.beta'.format(layer_id,
                                             3 * sub_layer_id + 1)
            x_gamma = 'vgg16_feature_extractor.conv{}_{}.1.gamma'.format(
                layer_id, sub_layer_id + 1)
            x_beta = 'vgg16_feature_extractor.conv{}_{}.1.beta'.format(
                layer_id, sub_layer_id + 1)
            param_dict[y_gamma] = param_vgg[x_gamma]
            param_dict[y_beta] = param_vgg[x_beta]
    return param_dict


def train():
    """Train FCN8s on the configuration in ``FCN8s_VOC2012_cfg``.

    Runs in GRAPH_MODE on an Ascend device (``args.device_id``) with
    auto mixed precision enabled. Data-parallel training is enabled when
    the ``DEVICE_NUM`` environment variable (default 1) is greater than 1.

    Weights are initialised from ``cfg.ckpt_vgg16`` (VGG16 backbone
    remapped to FCN8s names) when set. Otherwise they come from
    ``cfg.ckpt_pre_trained`` when set. The VGG16 checkpoint takes
    precedence if both are given.

    Training uses a momentum optimizer with a per-step cosine-annealed
    learning rate and a fixed loss scale. Only rank 0 saves checkpoints.
    """
    args = parse_args()
    cfg = FCN8s_VOC2012_cfg
    device_num = int(os.environ.get("DEVICE_NUM", 1))
    context.set_context(mode=context.GRAPH_MODE,
                        enable_auto_mixed_precision=True,
                        save_graphs=False,
                        device_target="Ascend",
                        device_id=args.device_id)
    # init multicards training; single-card defaults are rank 0 of 1
    args.rank = 0
    args.group_size = 1
    if device_num > 1:
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                          gradients_mean=True,
                                          device_num=device_num)
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()

    # dataset, sharded by (rank, group_size)
    dataset = data_generator.SegDataset(image_mean=cfg.image_mean,
                                        image_std=cfg.image_std,
                                        data_file=cfg.data_file,
                                        batch_size=cfg.batch_size,
                                        crop_size=cfg.crop_size,
                                        max_scale=cfg.max_scale,
                                        min_scale=cfg.min_scale,
                                        ignore_label=cfg.ignore_label,
                                        num_classes=cfg.num_classes,
                                        num_readers=2,
                                        num_parallel_calls=4,
                                        shard_id=args.rank,
                                        shard_num=args.group_size)
    dataset = dataset.get_dataset(repeat=1)

    net = FCN8s(n_class=cfg.num_classes)
    loss_ = loss.SoftmaxCrossEntropyLoss(cfg.num_classes, cfg.ignore_label)

    # load pretrained vgg16 parameters to init FCN8s (takes precedence)
    if cfg.ckpt_vgg16:
        param_vgg = load_checkpoint(cfg.ckpt_vgg16)
        load_param_into_net(net, _vgg16_to_fcn8s_params(param_vgg))
    # load pretrained FCN8s
    elif cfg.ckpt_pre_trained:
        param_dict = load_checkpoint(cfg.ckpt_pre_trained)
        load_param_into_net(net, param_dict)

    # optimizer
    iters_per_epoch = dataset.get_dataset_size()

    # NOTE(review): CosineAnnealingLR's positional order is presumably
    # (base_lr, T_max, steps_per_epoch, max_epoch) — confirm against its
    # definition.
    lr_scheduler = CosineAnnealingLR(cfg.base_lr,
                                     cfg.train_epochs,
                                     iters_per_epoch,
                                     cfg.train_epochs,
                                     warmup_epochs=0,
                                     eta_min=0)
    lr = Tensor(lr_scheduler.get_lr())

    # loss scale (fixed; overflowed steps are not dropped)
    manager_loss_scale = FixedLossScaleManager(cfg.loss_scale,
                                               drop_overflow_update=False)

    optimizer = nn.Momentum(params=net.trainable_params(),
                            learning_rate=lr,
                            momentum=0.9,
                            weight_decay=0.0001,
                            loss_scale=cfg.loss_scale)

    model = Model(net,
                  loss_fn=loss_,
                  loss_scale_manager=manager_loss_scale,
                  optimizer=optimizer,
                  amp_level="O3")

    # callbacks: every rank monitors time/loss; only rank 0 saves checkpoints
    time_cb = TimeMonitor(data_size=iters_per_epoch)
    loss_cb = LossMonitor()
    cbs = [time_cb, loss_cb]

    if args.rank == 0:
        config_ck = CheckpointConfig(
            save_checkpoint_steps=cfg.save_steps,
            keep_checkpoint_max=cfg.keep_checkpoint_max)
        ckpoint_cb = ModelCheckpoint(prefix=cfg.model,
                                     directory=cfg.ckpt_dir,
                                     config=config_ck)
        cbs.append(ckpoint_cb)

    model.train(cfg.train_epochs, dataset, callbacks=cbs)