def test_shift_gpu(self):
        """Check that ShiftActivationDevices moves fc4/fc5 params off their GPU.

        With shifts {0: 4, 1: 4, 2: 5, 3: 5}, the fc4/fc5 weight and bias init
        ops of gpu_0/gpu_1 must run on device 4 and the rest on device 5, while
        the fc1/fc2/fc3 params stay on the GPU named in their blob prefix.
        The net is only actually executed when at least 6 CUDA devices exist.
        """
        model = self.create_model()
        data_parallel_model_utils.ShiftActivationDevices(
            model,
            activations=["fc4", "fc5"],
            shifts={0: 4, 1: 4, 2: 5, 3: 5},
        )

        # Built once, not per output blob as before.
        shifted_params = {'fc4_w', 'fc5_w', 'fc4_b', 'fc5_b'}
        unshifted_params = {'fc1_w', 'fc2_w', 'fc3_b', 'fc3_w'}

        for op in model.param_init_net.Proto().op:
            for outp in op.output:
                # Blob names look like "gpu_<id>/.../<param>".
                parts = outp.split("/")
                prefix, blob_name = parts[0], parts[-1]
                if blob_name in shifted_params:
                    expected = 4 if prefix in ('gpu_0', 'gpu_1') else 5
                    self.assertEqual(op.device_option.cuda_gpu_id, expected)
                if blob_name in unshifted_params:
                    gpu_id = int(prefix.split("_")[-1])
                    self.assertEqual(gpu_id, op.device_option.cuda_gpu_id)

        # Test that we can run the net
        if workspace.NumCudaDevices() >= 6:
            workspace.RunNetOnce(model.param_init_net)
            workspace.CreateNet(model.net)
            workspace.RunNet(model.net.Proto().name)
Exemple #2
0
def Train(args):
    """Train the model built by ``add_se_model`` from ``args.model_config``.

    Builds a data-parallel Caffe2 training net, which can optionally be
    distributed across shards via MPI, Redis or a file store. Also builds an
    optional test net and optionally restores a checkpoint. It then runs
    ``RunEpoch`` until ``args.num_epochs`` is reached. After every epoch the
    model is saved and the file saved for the previous epoch, if present, is
    deleted.

    Args:
        args: parsed command-line options; this function reads (among others)
            gpus/num_gpus, batch_size, epoch_size, num_shards, shard_id,
            distributed_interfaces, distributed_transport, redis_host,
            redis_port, run_id, file_store_path, model_config, dtype,
            train_data, test_data, load_model_path, num_epochs and
            save_model_name.  ``args.epoch_size`` is rewritten in place.
    """
    # Either use specified device list or generate one
    if args.gpus is not None:
        gpus = [int(x) for x in args.gpus.split(',')]
        num_gpus = len(gpus)
    else:
        gpus = list(range(args.num_gpus))
        num_gpus = args.num_gpus

    log.info("Running on GPUs: {}".format(gpus))

    # Verify valid batch size
    total_batch_size = args.batch_size
    batch_per_device = total_batch_size // num_gpus
    assert \
        total_batch_size % num_gpus == 0, \
        "Number of GPUs must divide batch size"

    # Round down epoch size to closest multiple of batch size across machines
    global_batch_size = total_batch_size * args.num_shards
    epoch_iters = int(args.epoch_size / global_batch_size)

    assert \
        epoch_iters > 0, \
        "Epoch size must be larger than batch size times shard count"

    args.epoch_size = epoch_iters * global_batch_size
    log.info("Using epoch size: {}".format(args.epoch_size))

    # Create ModelHelper object
    train_arg_scope = {
        'order': 'NCHW',
        'use_cudnn': True,
        'cudnn_exhaustive_search': True,
        'ws_nbytes_limit': (args.cudnn_workspace_limit_mb * 1024 * 1024),
    }
    train_model = model_helper.ModelHelper(
        name="ban-pc-resnet50", arg_scope=train_arg_scope
    )

    num_shards = args.num_shards
    shard_id = args.shard_id

    # Expect interfaces to be comma separated.
    # Use of multiple network interfaces is not yet complete,
    # so simply use the first one in the list.
    interfaces = args.distributed_interfaces.split(",")

    # Default: no rendezvous (single-machine training). Assigned up front so
    # that an mpirun launch with a single rank (MPI branch, num_shards == 1)
    # no longer leaves ``rendezvous`` unbound when Parallelize is called.
    rendezvous = None

    # Rendezvous using MPI when run with mpirun
    if os.getenv("OMPI_COMM_WORLD_SIZE") is not None:
        num_shards = int(os.getenv("OMPI_COMM_WORLD_SIZE", 1))
        shard_id = int(os.getenv("OMPI_COMM_WORLD_RANK", 0))
        if num_shards > 1:
            rendezvous = dict(
                kv_handler=None,
                num_shards=num_shards,
                shard_id=shard_id,
                engine="GLOO",
                transport=args.distributed_transport,
                interface=interfaces[0],
                mpi_rendezvous=True,
                exit_nets=None)

    elif num_shards > 1:
        # Create rendezvous for distributed computation
        store_handler = "store_handler"
        if args.redis_host is not None:
            # Use Redis for rendezvous if Redis host is specified
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "RedisStoreHandlerCreate", [], [store_handler],
                    host=args.redis_host,
                    port=args.redis_port,
                    prefix=args.run_id,
                )
            )
        else:
            # Use filesystem for rendezvous otherwise
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "FileStoreHandlerCreate", [], [store_handler],
                    path=args.file_store_path,
                    prefix=args.run_id,
                )
            )

        rendezvous = dict(
            kv_handler=store_handler,
            shard_id=shard_id,
            num_shards=num_shards,
            engine="GLOO",
            transport=args.distributed_transport,
            interface=interfaces[0],
            exit_nets=None)

    # Model configs for constructing model.
    # safe_load: the config is plain data. A bare yaml.load(f) without a
    # Loader is unsafe on untrusted files and raises TypeError on PyYAML >= 6.
    with open(args.model_config) as f:
        model_config = yaml.safe_load(f)

    # Model building functions
    def create_target_model_ops(model, loss_scale):
        """Build the forward pass, softmax loss and accuracy for one device.

        Returns a one-element list holding the loss from add_softmax_loss.
        NOTE(review): ``loss_scale`` is not applied here, unlike builders in
        sibling scripts that call ``model.Scale`` -- confirm add_softmax_loss
        does not need it. Also assumes add_softmax_loss produces a blob named
        'softmax' (used by brew.accuracy below) -- TODO confirm.
        """
        initializer = (PseudoFP16Initializer if args.dtype == 'float16'
                       else Initializer)
        with brew.arg_scope([brew.conv, brew.fc],
                            WeightInitializer=initializer,
                            BiasInitializer=initializer,
                            enable_tensor_core=args.enable_tensor_core,
                            float16_compute=args.float16_compute):
            pred = add_se_model(model, model_config, "data", is_test=False)

        if args.dtype == 'float16':
            # Compute the loss in fp32 even when the network runs in fp16.
            pred = model.net.HalfToFloat(pred, pred + '_fp32')

        loss = add_softmax_loss(model, pred, 'label')
        brew.accuracy(model, ['softmax', 'label'], 'accuracy')
        return [loss]

    def add_optimizer(model):
        """Attach weight decay and a multi-precision SGD optimizer.

        Except for the base learning rate (from ``args``), the SGD settings
        come from the ``solver`` section of the YAML model config.
        """
        optimizer.add_weight_decay(model, args.weight_decay)
        solver = model_config['solver']
        opt = optimizer.build_multi_precision_sgd(
            model,
            base_learning_rate=args.base_learning_rate,
            momentum=solver['momentum'],
            nesterov=solver['nesterov'],
            policy=solver['lr_policy'],
            power=solver['power'],
            max_iter=solver['max_iter'],
        )
        return opt

    # Define add_image_input function.
    # Depends on the "train_data" argument.
    # Note that the reader will be shared with between all GPUS.
    reader = train_model.CreateDB(
        "reader",
        db=args.train_data,
        db_type=args.db_type,
        num_shards=num_shards,
        shard_id=shard_id,
    )

    def add_image_input(model):
        """Feed ``model`` a per-device training batch from the shared reader."""
        AddImageInput(
            model,
            reader,
            batch_size=batch_per_device,
            img_size=args.image_size,
            dtype=args.dtype,
            is_test=False,
        )

    def add_post_sync_ops(model):
        """Add ops applied after initial parameter sync."""
        for param_info in model.GetOptimizationParamInfo(model.GetParams()):
            if param_info.blob_copy is not None:
                model.param_init_net.HalfToFloat(
                    param_info.blob,
                    param_info.blob_copy[core.DataType.FLOAT]
                )

    # Create parallelized model
    data_parallel_model.Parallelize(
        train_model,
        input_builder_fun=add_image_input,
        forward_pass_builder_fun=create_target_model_ops,
        optimizer_builder_fun=add_optimizer,
        post_sync_builder_fun=add_post_sync_ops,
        devices=gpus,
        rendezvous=rendezvous,
        optimize_gradient_memory=False,
        cpu_device=args.use_cpu,
        shared_model=args.use_cpu,
        combine_spatial_bn=args.use_cpu,
    )

    if args.model_parallel:
        # Shift half of the activations to another GPU
        assert workspace.NumCudaDevices() >= 2 * args.num_gpus
        activations = data_parallel_model_utils.GetActivationBlobs(train_model)
        data_parallel_model_utils.ShiftActivationDevices(
            train_model,
            activations=activations[len(activations) // 2:],
            shifts={g: args.num_gpus + g for g in range(args.num_gpus)},
        )

    data_parallel_model.OptimizeGradientMemory(train_model, {}, set(), False)

    workspace.RunNetOnce(train_model.param_init_net)
    workspace.CreateNet(train_model.net)

    # Add test model, if specified
    test_model = None
    if args.test_data is not None:
        log.info("----- Create test net ----")
        test_arg_scope = {
            'order': "NCHW",
            'use_cudnn': True,
            'cudnn_exhaustive_search': True,
        }
        test_model = model_helper.ModelHelper(
            name="ban-pc-resnet50_test", arg_scope=test_arg_scope, init_params=False
        )

        test_reader = test_model.CreateDB(
            "test_reader",
            db=args.test_data,
            db_type=args.db_type,
        )

        def test_input_fn(model):
            """Feed ``model`` a per-device batch from the test reader."""
            AddImageInput(
                model,
                test_reader,
                batch_size=batch_per_device,
                img_size=args.image_size,
                dtype=args.dtype,
                is_test=True,
            )

        data_parallel_model.Parallelize(
            test_model,
            input_builder_fun=test_input_fn,
            forward_pass_builder_fun=create_target_model_ops,
            post_sync_builder_fun=add_post_sync_ops,
            param_update_builder_fun=None,
            devices=gpus,
            cpu_device=args.use_cpu,
        )
        workspace.RunNetOnce(test_model.param_init_net)
        workspace.CreateNet(test_model.net)

    epoch = 0
    # load the pre-trained model and reset epoch
    if args.load_model_path is not None:
        LoadModel(args.load_model_path, train_model)

        # Sync the model params
        data_parallel_model.FinalizeAfterCheckpoint(train_model)

        # reset epoch. load_model_path should end with *_X.mdl,
        # where X is the epoch number
        last_str = args.load_model_path.split('_')[-1]
        if last_str.endswith('.mdl'):
            epoch = int(last_str[:-4])
            log.info("Reset epoch to {}".format(epoch))
        else:
            log.warning("The format of load_model_path doesn't match!")

    expname = "log/{}/resnet50_gpu{}_b{}_L{}_lr{:.2f}_v2".format(
        args.dataset_name,
        args.num_gpus,
        total_batch_size,
        args.num_labels,
        args.base_learning_rate,
    )
    explog = experiment_util.ModelTrainerLog(expname, args)

    # Load pretrained param_init_net
    load_init_net_multigpu(args)

    # Prefix of the per-epoch model files; the epoch number and ".mdl" follow.
    model_path = "%s/%s_" % (
        args.file_store_path,
        args.save_model_name
    )

    # Run the training one epoch a time
    best_accuracy = 0
    while epoch < args.num_epochs:
        epoch, best_accuracy = RunEpoch(
            args,
            epoch,
            train_model,
            test_model,
            total_batch_size,
            num_shards,
            expname,
            explog,
            best_accuracy,
        )

        # Save the model for each epoch
        SaveModel(args, train_model, epoch)

        # remove the saved model from the previous epoch if it exists
        prev_model_file = model_path + str(epoch - 1) + ".mdl"
        if os.path.isfile(prev_model_file):
            os.remove(prev_model_file)
Exemple #3
0
def Train(args):
    """Train ResNet-101 (``resnet.create_resnet101``) with Caffe2.

    Builds a data-parallel training net, which can optionally be distributed
    across shards via MPI, Redis or a file store. Also builds an optional test
    net and optionally restores a checkpoint. It then runs ``RunEpoch`` until
    ``args.num_epochs`` is reached and saves the model once at the end.

    Args:
        args: parsed command-line options; ``args.epoch_size`` is rewritten
            in place to a multiple of the global batch size.
    """
    # Either use specified device list or generate one
    if args.gpus is not None:
        gpus = [int(x) for x in args.gpus.split(',')]
        num_gpus = len(gpus)
    else:
        gpus = list(range(args.num_gpus))
        num_gpus = args.num_gpus

    log.info("Running on GPUs: {}".format(gpus))

    # Verify valid batch size (same checks as the sibling trainers): without
    # this, a non-divisible batch size was silently truncated per device.
    total_batch_size = args.batch_size
    batch_per_device = total_batch_size // num_gpus
    assert \
        total_batch_size % num_gpus == 0, \
        "Number of GPUs must divide batch size"

    # Round down epoch size to closest multiple of batch size across machines
    global_batch_size = total_batch_size * args.num_shards
    epoch_iters = int(args.epoch_size / global_batch_size)

    # An epoch of zero iterations would make epoch_size (and the LR step size
    # derived from it) zero.
    assert \
        epoch_iters > 0, \
        "Epoch size must be larger than batch size times shard count"

    args.epoch_size = epoch_iters * global_batch_size
    log.info("Using epoch size: {}".format(args.epoch_size))

    train_arg_scope = {
        'order': 'NCHW',
        'use_cudnn': True,
        'cudnn_exhaustive_search': True,
        'ws_nbytes_limit': (args.cudnn_workspace_limit_mb * 1024 * 1024),
    }
    train_model = model_helper.ModelHelper(name="resnet101",
                                           arg_scope=train_arg_scope)

    num_shards = args.num_shards
    shard_id = args.shard_id
    # Only the first of the comma-separated interfaces is used.
    interfaces = args.distributed_interfaces.split(",")

    # Default: no rendezvous (single-machine training). Assigned up front so
    # that an mpirun launch with a single rank (MPI branch, num_shards == 1)
    # no longer leaves ``rendezvous`` unbound when Parallelize is called.
    rendezvous = None

    # Rendezvous using MPI when run with mpirun
    if os.getenv("OMPI_COMM_WORLD_SIZE") is not None:
        num_shards = int(os.getenv("OMPI_COMM_WORLD_SIZE", 1))
        shard_id = int(os.getenv("OMPI_COMM_WORLD_RANK", 0))
        if num_shards > 1:
            rendezvous = dict(kv_handler=None,
                              num_shards=num_shards,
                              shard_id=shard_id,
                              engine="GLOO",
                              transport=args.distributed_transport,
                              interface=interfaces[0],
                              mpi_rendezvous=True,
                              exit_nets=None)

    elif num_shards > 1:
        # Redis store if a host is given, otherwise a filesystem store.
        store_handler = "store_handler"
        if args.redis_host is not None:
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "RedisStoreHandlerCreate",
                    [],
                    [store_handler],
                    host=args.redis_host,
                    port=args.redis_port,
                    prefix=args.run_id,
                ))
        else:
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "FileStoreHandlerCreate",
                    [],
                    [store_handler],
                    path=args.file_store_path,
                    prefix=args.run_id,
                ))

        rendezvous = dict(kv_handler=store_handler,
                          shard_id=shard_id,
                          num_shards=num_shards,
                          engine="GLOO",
                          transport=args.distributed_transport,
                          interface=interfaces[0],
                          exit_nets=None)

    def create_resnet101_model_ops(model, loss_scale):
        """Build ResNet-101, a softmax loss scaled by ``loss_scale`` and accuracy.

        Returns a one-element list holding the scaled loss blob.
        """
        # NOTE(review): sibling trainers use PseudoFP16Initializer; confirm
        # pFP16Initializer is the name imported by this file.
        initializer = (pFP16Initializer
                       if args.dtype == 'float16' else Initializer)

        with brew.arg_scope([brew.conv, brew.fc],
                            WeightInitializer=initializer,
                            BiasInitializer=initializer,
                            enable_tensor_core=args.enable_tensor_core,
                            float16_compute=args.float16_compute):
            pred = resnet.create_resnet101(
                model,
                "data",
                num_input_channels=args.num_channels,
                num_labels=args.num_labels,
                no_bias=True,
                no_loss=True,
            )

        if args.dtype == 'float16':
            # Compute the loss in fp32 even when the network runs in fp16.
            pred = model.net.HalfToFloat(pred, pred + '_fp32')

        softmax, loss = model.SoftmaxWithLoss([pred, 'label'],
                                              ['softmax', 'loss'])
        loss = model.Scale(loss, scale=loss_scale)
        brew.accuracy(model, [softmax, "label"], "accuracy")
        return [loss]

    def add_optimizer(model):
        """Attach SGD (momentum 0.9, Nesterov) with a "step" LR policy.

        The step size is the number of iterations in 30 epochs.
        """
        stepsz = int(30 * args.epoch_size / total_batch_size / num_shards)

        if args.float16_compute:
            # The fp16 builder takes weight decay directly.
            opt = optimizer.build_fp16_sgd(model,
                                           args.base_learning_rate,
                                           momentum=0.9,
                                           nesterov=1,
                                           weight_decay=args.weight_decay,
                                           policy="step",
                                           stepsize=stepsz,
                                           gamma=0.1)
        else:
            optimizer.add_weight_decay(model, args.weight_decay)
            opt = optimizer.build_multi_precision_sgd(model,
                                                      args.base_learning_rate,
                                                      momentum=0.9,
                                                      nesterov=1,
                                                      policy="step",
                                                      stepsize=stepsz,
                                                      gamma=0.1)
        return opt

    if args.train_data == "null":

        def add_image_input(model):
            """Feed ``model`` via AddNullInput (no DB reader)."""
            AddNullInput(
                model,
                None,
                batch_size=batch_per_device,
                img_size=args.image_size,
                dtype=args.dtype,
            )
    else:
        # One reader, shared by every device's input builder.
        reader = train_model.CreateDB(
            "reader",
            db=args.train_data,
            db_type=args.db_type,
            num_shards=num_shards,
            shard_id=shard_id,
        )

        def add_image_input(model):
            """Feed ``model`` a per-device training batch from ``reader``."""
            AddImageInput(
                model,
                reader,
                batch_size=batch_per_device,
                img_size=args.image_size,
                dtype=args.dtype,
                is_test=False,
            )

    def add_post_sync_ops(model):
        """Add ops applied after initial parameter sync."""
        for param_info in model.GetOptimizationParamInfo(model.GetParams()):
            if param_info.blob_copy is not None:
                model.param_init_net.HalfToFloat(
                    param_info.blob, param_info.blob_copy[core.DataType.FLOAT])

    data_parallel_model.Parallelize(
        train_model,
        input_builder_fun=add_image_input,
        forward_pass_builder_fun=create_resnet101_model_ops,
        optimizer_builder_fun=add_optimizer,
        post_sync_builder_fun=add_post_sync_ops,
        devices=gpus,
        rendezvous=rendezvous,
        optimize_gradient_memory=False,
        cpu_device=args.use_cpu,
        shared_model=args.use_cpu,
    )

    if args.model_parallel:
        # Shift half of the activations to another GPU; this needs a second
        # bank of args.num_gpus devices (same check as the sibling trainers).
        assert workspace.NumCudaDevices() >= 2 * args.num_gpus
        activations = data_parallel_model_utils.GetActivationBlobs(train_model)
        data_parallel_model_utils.ShiftActivationDevices(
            train_model,
            activations=activations[len(activations) // 2:],
            shifts={g: args.num_gpus + g
                    for g in range(args.num_gpus)},
        )

    data_parallel_model.OptimizeGradientMemory(train_model, {}, set(), False)

    workspace.RunNetOnce(train_model.param_init_net)
    workspace.CreateNet(train_model.net)

    test_model = None
    if args.test_data is not None:
        log.info("----- Create test net ----")
        test_arg_scope = {
            'order': "NCHW",
            'use_cudnn': True,
            'cudnn_exhaustive_search': True,
        }
        test_model = model_helper.ModelHelper(name="resnet101_test",
                                              arg_scope=test_arg_scope,
                                              init_params=False)

        test_reader = test_model.CreateDB(
            "test_reader",
            db=args.test_data,
            db_type=args.db_type,
        )

        def test_input_fn(model):
            """Feed ``model`` a per-device batch from the test reader."""
            AddImageInput(
                model,
                test_reader,
                batch_size=batch_per_device,
                img_size=args.image_size,
                dtype=args.dtype,
                is_test=True,
            )

        data_parallel_model.Parallelize(
            test_model,
            input_builder_fun=test_input_fn,
            forward_pass_builder_fun=create_resnet101_model_ops,
            post_sync_builder_fun=add_post_sync_ops,
            param_update_builder_fun=None,
            devices=gpus,
            cpu_device=args.use_cpu,
        )
        workspace.RunNetOnce(test_model.param_init_net)
        workspace.CreateNet(test_model.net)

    epoch = 0
    if args.load_model_path is not None:
        LoadModel(args.load_model_path, train_model)
        data_parallel_model.FinalizeAfterCheckpoint(train_model)
        # load_model_path is expected to end with "_<epoch>.mdl".
        last_str = args.load_model_path.split('_')[-1]
        if last_str.endswith('.mdl'):
            epoch = int(last_str[:-4])
            log.info("Reset epoch to {}".format(epoch))
        else:
            log.warning("The format of load_model_path doesn't match!")

    expname = "resnet101_gpu%d_b%d_L%d_lr%.2f_v2" % (
        args.num_gpus,
        total_batch_size,
        args.num_labels,
        args.base_learning_rate,
    )
    explog = experiment_util.ModelTrainerLog(expname, args)

    while epoch < args.num_epochs:
        epoch = RunEpoch(args, epoch, train_model, test_model,
                         total_batch_size, num_shards, expname, explog)
    # final save
    SaveModel(workspace, train_model)
Exemple #4
0
def Train(args):
    # Either use specified device list or generate one
    if args.gpus is not None:
        gpus = [int(x) for x in args.gpus.split(',')]
        num_gpus = len(gpus)
    else:
        gpus = list(range(args.num_gpus))
        num_gpus = args.num_gpus

    log.info("Running on GPUs: {}".format(gpus))

    # Verify valid batch size
    total_batch_size = args.batch_size
    batch_per_device = total_batch_size // num_gpus
    assert \
        total_batch_size % num_gpus == 0, \
        "Number of GPUs must divide batch size"

    # Round down epoch size to closest multiple of batch size across machines
    global_batch_size = total_batch_size * args.num_shards
    epoch_iters = int(args.epoch_size / global_batch_size)

    assert \
        epoch_iters > 0, \
        "Epoch size must be larger than batch size times shard count"

    args.epoch_size = epoch_iters * global_batch_size
    log.info("Using epoch size: {}".format(args.epoch_size))

    # Create ModelHelper object
    train_arg_scope = {
        'order': 'NCHW',
        'use_cudnn': True,
        'cudnn_exhaustive_search': True,
        'ws_nbytes_limit': (args.cudnn_workspace_limit_mb * 1024 * 1024),
    }
    train_model = model_helper.ModelHelper(name="resnet50",
                                           arg_scope=train_arg_scope)

    num_shards = args.num_shards
    shard_id = args.shard_id

    # Expect interfaces to be comma separated.
    # Use of multiple network interfaces is not yet complete,
    # so simply use the first one in the list.
    interfaces = args.distributed_interfaces.split(",")

    # Rendezvous using MPI when run with mpirun
    if os.getenv("OMPI_COMM_WORLD_SIZE") is not None:
        num_shards = int(os.getenv("OMPI_COMM_WORLD_SIZE", 1))
        shard_id = int(os.getenv("OMPI_COMM_WORLD_RANK", 0))
        if num_shards > 1:
            rendezvous = dict(kv_handler=None,
                              num_shards=num_shards,
                              shard_id=shard_id,
                              engine="GLOO",
                              transport=args.distributed_transport,
                              interface=interfaces[0],
                              mpi_rendezvous=True,
                              exit_nets=None)

    elif num_shards > 1:
        # Create rendezvous for distributed computation
        store_handler = "store_handler"
        if args.redis_host is not None:
            # Use Redis for rendezvous if Redis host is specified
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "RedisStoreHandlerCreate",
                    [],
                    [store_handler],
                    host=args.redis_host,
                    port=args.redis_port,
                    prefix=args.run_id,
                ))
        else:
            # Use filesystem for rendezvous otherwise
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "FileStoreHandlerCreate",
                    [],
                    [store_handler],
                    path=args.file_store_path,
                    prefix=args.run_id,
                ))

        rendezvous = dict(kv_handler=store_handler,
                          shard_id=shard_id,
                          num_shards=num_shards,
                          engine="GLOO",
                          transport=args.distributed_transport,
                          interface=interfaces[0],
                          exit_nets=None)

    else:
        rendezvous = None

    # Model building functions
    # def create_resnet50_model_ops(model, loss_scale):
    #     initializer = (PseudoFP16Initializer if args.dtype == 'float16'
    #                    else Initializer)

    #     with brew.arg_scope([brew.conv, brew.fc],
    #                         WeightInitializer=initializer,
    #                         BiasInitializer=initializer,
    #                         enable_tensor_core=args.enable_tensor_core,
    #                         float16_compute=args.float16_compute):
    #         pred = resnet.create_resnet50(
    #             #args.layers,
    #             model,
    #             "data",
    #             num_input_channels=args.num_channels,
    #             num_labels=args.num_labels,
    #             no_bias=True,
    #             no_loss=True,
    #         )

    #     if args.dtype == 'float16':
    #         pred = model.net.HalfToFloat(pred, pred + '_fp32')

    #     softmax, loss = model.SoftmaxWithLoss([pred, 'label'],
    #                                           ['softmax', 'loss'])
    #     loss = model.Scale(loss, scale=loss_scale)
    #     brew.accuracy(model, [softmax, "label"], "accuracy")
    #     return [loss]

    def create_model_ops(model, loss_scale):
        """Forward-pass builder for the training net (is_test=False)."""
        training_ops = create_model_ops_testable(
            model, loss_scale, is_test=False)
        return training_ops

    def create_model_ops_test(model, loss_scale):
        """Forward-pass builder for the evaluation net (is_test=True)."""
        eval_ops = create_model_ops_testable(
            model, loss_scale, is_test=True)
        return eval_ops

    # Model building functions
    def create_model_ops_testable(model, loss_scale, is_test=False):
        """Build the network named by ``args.model``, its loss and accuracy.

        Args:
            model: the per-device model helper supplied by
                ``data_parallel_model.Parallelize``.
            loss_scale: factor applied to the softmax loss via ``model.Scale``.
            is_test: forwarded to the builders that take it (resnet32x32,
                vgg, googlenet, alexnet, alexnetv0).

        Returns:
            A one-element list holding the scaled loss blob.

        Raises:
            NotImplementedError: if ``args.model`` is not a known network.
        """
        initializer = (PseudoFP16Initializer
                       if args.dtype == 'float16' else Initializer)

        with brew.arg_scope([brew.conv, brew.fc],
                            WeightInitializer=initializer,
                            BiasInitializer=initializer,
                            enable_tensor_core=args.enable_tensor_core,
                            float16_compute=args.float16_compute):

            # Mismatched input sizes only produce a warning, not an error.
            # Logger.warning is used; Logger.warn is a deprecated alias.
            if args.model == "cifar10":
                if args.image_size != 32:
                    log.warning("Cifar10 expects a 32x32 image.")
                pred = models.cifar10.create_cifar10(
                    model,
                    "data",
                    image_channels=args.num_channels,
                    num_classes=args.num_labels,
                    image_height=args.image_size,
                    image_width=args.image_size,
                )
            elif args.model == "resnet32x32":
                if args.image_size != 32:
                    log.warning("ResNet32x32 expects a 32x32 image.")
                pred = models.resnet.create_resnet32x32(
                    model,
                    "data",
                    num_layers=args.num_layers,
                    num_input_channels=args.num_channels,
                    num_labels=args.num_labels,
                    is_test=is_test)
            elif args.model == "resnet":
                if args.image_size != 224:
                    log.warning(
                        "ResNet expects a 224x224 image. input image = %d",
                        args.image_size)
                # NOTE(review): unlike the other branches, is_test is not
                # forwarded here, so the test net is built the same way as the
                # training net -- confirm whether create_resnet50 accepts it.
                pred = resnet.create_resnet50(
                    model,
                    "data",
                    num_input_channels=args.num_channels,
                    num_labels=args.num_labels,
                    no_bias=True,
                    no_loss=True,
                )
            elif args.model == "vgg":
                if args.image_size != 224:
                    log.warning("VGG expects a 224x224 image.")
                pred = vgg.create_vgg(model,
                                      "data",
                                      num_input_channels=args.num_channels,
                                      num_labels=args.num_labels,
                                      num_layers=args.num_layers,
                                      is_test=is_test)
            elif args.model == "googlenet":
                if args.image_size != 224:
                    log.warning("GoogLeNet expects a 224x224 image.")
                pred = googlenet.create_googlenet(
                    model,
                    "data",
                    num_input_channels=args.num_channels,
                    num_labels=args.num_labels,
                    is_test=is_test)
            elif args.model == "alexnet":
                if args.image_size != 224:
                    log.warning("Alexnet expects a 224x224 image.")
                pred = alexnet.create_alexnet(
                    model,
                    "data",
                    num_input_channels=args.num_channels,
                    num_labels=args.num_labels,
                    is_test=is_test)
            elif args.model == "alexnetv0":
                if args.image_size != 224:
                    log.warning("Alexnet v0 expects a 224x224 image.")
                pred = alexnet.create_alexnetv0(
                    model,
                    "data",
                    num_input_channels=args.num_channels,
                    num_labels=args.num_labels,
                    is_test=is_test)
            else:
                raise NotImplementedError("Network {} not found.".format(
                    args.model))

        if args.dtype == 'float16':
            # Compute the loss in fp32 even when the network runs in fp16.
            pred = model.net.HalfToFloat(pred, pred + '_fp32')

        softmax, loss = model.SoftmaxWithLoss([pred, 'label'],
                                              ['softmax', 'loss'])
        loss = model.Scale(loss, scale=loss_scale)
        brew.accuracy(model, [softmax, "label"], "accuracy")
        return [loss]

    def add_optimizer(model):
        """Attach SGD (momentum 0.9, Nesterov) with a "step" LR policy.

        With ``args.float16_compute`` the fp16 SGD builder is used, and weight
        decay is passed to it directly. Otherwise weight decay is added
        separately and multi-precision SGD is built. The chosen optimizer is
        logged in both cases; previously only one branch printed it, through a
        stray debug ``print``.

        Returns:
            The optimizer object returned by the builder.
        """
        # Iterations per epoch are epoch_size / (total_batch_size * num_shards),
        # so this step size spans 30 epochs.
        stepsz = int(30 * args.epoch_size / total_batch_size / num_shards)

        if args.float16_compute:
            # TODO: merge with multi-precision optimizer
            opt = optimizer.build_fp16_sgd(
                model,
                args.base_learning_rate,
                momentum=0.9,
                nesterov=1,
                weight_decay=args.weight_decay,  # weight decay included
                policy="step",
                stepsize=stepsz,
                gamma=0.1)
        else:
            optimizer.add_weight_decay(model, args.weight_decay)
            opt = optimizer.build_multi_precision_sgd(model,
                                                      args.base_learning_rate,
                                                      momentum=0.9,
                                                      nesterov=1,
                                                      policy="step",
                                                      stepsize=stepsz,
                                                      gamma=0.1)
        log.info("Using optimizer: %s", opt)
        return opt

    # Define add_image_input function.
    # Depends on the "train_data" argument.
    # Note that the reader will be shared with between all GPUS.
    if args.train_data == "null":

        def add_image_input(model):
            AddNullInput(
                model,
                None,
                batch_size=batch_per_device,
                img_size=args.image_size,
                dtype=args.dtype,
            )
    else:
        reader = train_model.CreateDB(
            "reader",
            db=args.train_data,
            db_type=args.db_type,
            num_shards=num_shards,
            shard_id=shard_id,
        )

        def add_image_input(model):
            """Add training image-input ops to ``model`` that read from ``reader``.

            ``reader`` is captured from the enclosing scope, so every call
            (one per device) uses the same reader object.
            """
            train_input_kwargs = dict(
                batch_size=batch_per_device,
                img_size=args.image_size,
                dtype=args.dtype,
                is_test=False,
            )
            AddImageInput(model, reader, **train_input_kwargs)

    def add_post_sync_ops(model):
        """Add ops applied after initial parameter sync.

        For each optimization parameter of ``model`` that has a
        ``blob_copy``, append a HalfToFloat op to ``model.param_init_net``
        that writes the parameter blob into its FLOAT-typed copy.
        Parameters without a copy are left untouched.
        """
        infos = model.GetOptimizationParamInfo(model.GetParams())
        for info in infos:
            copies = info.blob_copy
            if copies is None:
                continue
            float_copy = copies[core.DataType.FLOAT]
            model.param_init_net.HalfToFloat(info.blob, float_copy)

    # Create parallelized model
    data_parallel_model.Parallelize(train_model,
                                    input_builder_fun=add_image_input,
                                    forward_pass_builder_fun=create_model_ops,
                                    optimizer_builder_fun=add_optimizer,
                                    post_sync_builder_fun=add_post_sync_ops,
                                    devices=gpus,
                                    rendezvous=rendezvous,
                                    optimize_gradient_memory=False,
                                    cpu_device=args.use_cpu,
                                    shared_model=args.use_cpu,
                                    combine_spatial_bn=args.use_cpu,
                                    use_nccl=args.use_nccl)

    if args.model_parallel:
        # Shift half of the activations to another GPU
        assert workspace.NumCudaDevices() >= 2 * args.num_gpus
        activations = data_parallel_model_utils.GetActivationBlobs(train_model)
        data_parallel_model_utils.ShiftActivationDevices(
            train_model,
            activations=activations[len(activations) // 2:],
            shifts={g: args.num_gpus + g
                    for g in range(args.num_gpus)},
        )

    data_parallel_model.OptimizeGradientMemory(train_model, {}, set(), False)

    workspace.RunNetOnce(train_model.param_init_net)
    workspace.CreateNet(train_model.net)

    if "GLOO_ALGORITHM" in os.environ and os.environ[
            "GLOO_ALGORITHM"] == "PHUB":
        # I need to tell PHub which elements need aggregation, as well as
        # their sizes. At this stage, all I need is the key names and my
        # key ID.
        grad_names = list(reversed(train_model._grad_names))
        phubKeyNames = ["allreduce_{}_status".format(x) for x in grad_names]
        caffe2GradSizes = dict(
            zip([
                data_parallel_model.stripBlobName(name) + "_grad"
                for name in train_model._parameters_info.keys()
            ], [x.size for x in train_model._parameters_info.values()]))
        phubKeySizes = [str(caffe2GradSizes[x]) for x in grad_names]
        if rendezvous["shard_id"] == 0:
            #only id 0 needs to send to rendezvous.
            r = redis.StrictRedis()
            #foreach key, I need to assign an ID
            joinedStr = ",".join(phubKeyNames)
            r.set("[PLink]IntegrationKeys", joinedStr)
            joinedStr = ",".join(phubKeySizes)
            r.set("[PLink]IntegrationKeySizes", joinedStr)

    # Add test model, if specified
    test_model = None
    if (args.test_data is not None):
        log.info("----- Create test net ----")
        test_arg_scope = {
            'order': "NCHW",
            'use_cudnn': True,
            'cudnn_exhaustive_search': True,
        }
        test_model = model_helper.ModelHelper(name="resnet50_test",
                                              arg_scope=test_arg_scope,
                                              init_params=False)

        test_reader = test_model.CreateDB(
            "test_reader",
            db=args.test_data,
            db_type=args.db_type,
        )

        def test_input_fn(model):
            """Add evaluation image-input ops to ``model`` reading from ``test_reader``.

            Uses the same per-device batch size, image size and dtype as
            training, with ``is_test=True``.
            """
            eval_input_kwargs = dict(
                batch_size=batch_per_device,
                img_size=args.image_size,
                dtype=args.dtype,
                is_test=True,
            )
            AddImageInput(model, test_reader, **eval_input_kwargs)

        data_parallel_model.Parallelize(
            test_model,
            input_builder_fun=test_input_fn,
            forward_pass_builder_fun=create_model_ops_test,
            post_sync_builder_fun=add_post_sync_ops,
            param_update_builder_fun=None,
            devices=gpus,
            cpu_device=args.use_cpu,
        )
        workspace.RunNetOnce(test_model.param_init_net)
        workspace.CreateNet(test_model.net)

    epoch = 0
    # load the pre-trained model and reset epoch
    if args.load_model_path is not None:
        LoadModel(args.load_model_path, train_model)

        # Sync the model params
        data_parallel_model.FinalizeAfterCheckpoint(train_model)

        # reset epoch. load_model_path should end with *_X.mdl,
        # where X is the epoch number
        last_str = args.load_model_path.split('_')[-1]
        if last_str.endswith('.mdl'):
            epoch = int(last_str[:-4])
            log.info("Reset epoch to {}".format(epoch))
        else:
            log.warning("The format of load_model_path doesn't match!")

    expname = "resnet50_gpu%d_b%d_L%d_lr%.2f_v2" % (
        args.num_gpus,
        total_batch_size,
        args.num_labels,
        args.base_learning_rate,
    )

    explog = experiment_util.ModelTrainerLog(expname, args)

    # Run the training one epoch a time
    while epoch < args.num_epochs:
        epoch = RunEpoch(args, epoch, train_model, test_model,
                         total_batch_size, num_shards, expname, explog)

        # Save the model for each epoch
        SaveModel(args, train_model, epoch)

        model_path = "%s/%s_" % (args.file_store_path, args.save_model_name)
        # remove the saved model from the previous epoch if it exists
        if os.path.isfile(model_path + str(epoch - 1) + ".mdl"):
            os.remove(model_path + str(epoch - 1) + ".mdl")