Ejemplo n.º 1
0
            [107], "hccl_world_groupsum1")
        auto_parallel_context().set_all_reduce_fusion_split_indices(
            [27], "hccl_world_groupsum2")
        auto_parallel_context().set_all_reduce_fusion_split_indices(
            [27], "hccl_world_groupsum3")
        auto_parallel_context().set_all_reduce_fusion_split_indices(
            [27], "hccl_world_groupsum4")
        auto_parallel_context().set_all_reduce_fusion_split_indices(
            [27], "hccl_world_groupsum5")

        init()

    epoch_size = config.epoch_size
    damping = get_model_damping(0, 0.03, 0.87, 50, 5004)
    net = resnet50(class_num=config.class_num,
                   damping=damping,
                   loss_scale=config.loss_scale,
                   frequency=config.frequency)

    if not config.label_smooth:
        config.label_smooth_factor = 0.0
    loss = CrossEntropy(smooth_factor=config.label_smooth_factor,
                        num_classes=config.class_num)
    if args_opt.do_train:
        dataset = create_dataset(dataset_path=args_opt.dataset_path,
                                 do_train=True,
                                 repeat_num=epoch_size,
                                 batch_size=config.batch_size)
        step_size = dataset.get_dataset_size()

        loss_scale = FixedLossScaleManager(config.loss_scale,
                                           drop_overflow_update=False)
Ejemplo n.º 2
0
def create_network(name, *args, **kwargs):
    """Construct the network identified by *name*.

    The only recognised identifier is ``'resnet50_thor'``. For it, every
    positional and keyword argument is passed through unchanged to
    ``resnet50`` and that call's result is returned.

    Raises:
        NotImplementedError: when *name* is not a supported network.
    """
    # Reject unknown identifiers up front so the happy path reads last.
    if name != 'resnet50_thor':
        raise NotImplementedError(f"{name} is not implemented in the repo")
    return resnet50(*args, **kwargs)