Exemplo n.º 1
0
def _build(model, optimizer):
    """Apply ``optimizer`` to every optimization parameter of ``model``.

    Each parameter's device is looked up in a blob->device mapping built
    from ``core.InferBlobDevices`` over ``model.net`` and
    ``model.param_init_net``. If the parameter blob itself has no entry,
    the device of its gradient is used instead. For a ``core.GradientSlice``
    gradient the ``values`` blob is tried first, then ``indices``.
    ``optimizer`` is then called as
    ``optimizer(model.net, model.param_init_net, param_info)`` inside
    ``core.DeviceScope(device)``.

    Args:
        model: Object exposing ``net``, ``param_init_net`` and
            ``GetOptimizationParamInfo()``.
        optimizer: Callable invoked once per parameter as described above.

    Returns:
        The ``optimizer`` argument itself.

    Raises:
        AssertionError: If no device can be inferred for a parameter from
            either the parameter blob or its gradient blob(s).
    """
    # Infer blob devices by going through the net and param_init_net
    # ops and observing the device used to create or use the blob.
    # Entries from param_init_net override those from net on conflicts.
    param_to_device = core.InferBlobDevices(model.net)
    param_to_device.update(core.InferBlobDevices(model.param_init_net))

    for param_info in model.GetOptimizationParamInfo():
        param_name = str(param_info.blob)

        # Prefer the parameter's own inferred device. If there is none (this
        # can happen when the parameter is not output by any op but created
        # by a FetchBlob), fall back to the device of its gradient blob(s).
        device = None
        if param_name in param_to_device:
            device = param_to_device[param_name]
        else:
            grad = param_info.grad
            if isinstance(grad, core.GradientSlice):
                # Sparse gradient: try the values blob, then the indices.
                grad_names = (str(grad.values), str(grad.indices))
            else:
                grad_names = (str(grad),)
            for grad_name in grad_names:
                if grad_name in param_to_device:
                    device = param_to_device[grad_name]
                    break

        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would let `device = None` reach DeviceScope
        # below. AssertionError is kept so callers see the same exception.
        if device is None:
            raise AssertionError(
                "Cannot infer device for {}: no op creates it".format(param_name)
            )

        with core.DeviceScope(device):
            optimizer(model.net, model.param_init_net, param_info)
    return optimizer
Exemplo n.º 2
0
def _get_param_to_device(model):
    """Return a blob-name -> device mapping inferred from ``model``'s nets.

    The mapping is computed with ``core.InferBlobDevices`` over
    ``model.net`` and then ``model.param_init_net``. This examines the ops
    of each net and records the device used to create or use each blob.
    When a blob appears in both nets, the ``param_init_net`` entry wins.
    """
    blob_devices = core.InferBlobDevices(model.net)
    init_net_devices = core.InferBlobDevices(model.param_init_net)
    # dict.update: later (param_init_net) entries overwrite earlier ones.
    blob_devices.update(init_net_devices)
    return blob_devices