Ejemplo n.º 1
0
def optimizer_step(optimizer, func):
    """Run a single optimization step and return the loss converted with ``float``.

    :param optimizer:
        the optimizer whose gradients are cleared (``zero_grad``) before the
        loss is computed and which is stepped (``step``) afterwards
    :param func:
        a zero-argument callable returning the loss to backpropagate
    :returns:
        the loss mapped through ``float`` via ``smap``
    """
    optimizer.zero_grad()

    loss_value = func()
    loss_value.backward()

    optimizer.step()
    return smap(float, loss_value)
Ejemplo n.º 2
0
def call_torch_loader(
    module, dataloader, dtype=None, device=None, call_model=None, has_label=False
):
    """Call a torch function repeatedly with the result from a dataloader and return numpy objects.

    :param module:
        the model (or other object) forwarded as the first argument of
        ``call_model``
    :param dataloader:
        an iterable yielding batches; each batch is converted to torch with
        ``n2t`` before it is passed on
    :param dtype:
        the datatype to apply to each batch
    :param device:
        the device to assign each item to
    :param call_model:
        a callable invoked as ``call_model(module, x)`` to evaluate the model.
        If ``None``, ``default_call_model`` is used.
    :param has_label:
        If ``True``, the dataloader should return a tuple for each batch.
        Only the first item in the tuple will be passed to the model,
        the remaining items will be returned as-is.
        If ``False``, the full batch of the dataloader will be passed to
        the model
    :returns:
        the per-batch results converted to numpy with ``t2n``, combined
        across batches with ``szip`` and concatenated along axis 0
    """
    if call_model is None:
        call_model = default_call_model

    results = []

    for batch in dataloader:
        batch = n2t(batch, dtype=dtype, device=device)

        if has_label:
            # Only the input goes through the model; labels pass through unchanged.
            x, *y = batch
            result = (call_model(module, x), *y)
        else:
            result = call_model(module, batch)

        results.append(t2n(result))

    # NOTE(review): ``szip`` is assumed to turn the list of per-batch
    # structures into a structure of per-leaf lists; each leaf's batches are
    # then joined along the first axis. An empty dataloader yields an empty
    # list here — confirm how ``szip`` handles that.
    results = szip(results)
    return smap(lambda batch: np.concatenate(batch, axis=0), results)
Ejemplo n.º 3
0
def t2n(obj, dtype=None, sequences=default_sequences, mappings=default_mappings):
    """Torch to numpy: detach every tensor leaf of ``obj``, move it to the CPU and wrap it with ``np.asarray``.

    :param obj:
        a tensor or a nested structure of tensors
    :param dtype:
        the numpy dtype for the leaves; it is expanded to match ``obj`` with
        ``copy_structure`` (presumably one dtype per leaf — confirm)
    :param sequences:
        the sequence types treated as containers when traversing ``obj``
    :param mappings:
        the mapping types treated as containers when traversing ``obj``
    """
    structure_kwargs = dict(sequences=sequences, mappings=mappings)
    leaf_dtypes = copy_structure(obj, dtype, **structure_kwargs)

    def _to_numpy(tensor, leaf_dtype):
        detached = tensor.detach().cpu()
        return np.asarray(detached, dtype=leaf_dtype)

    return smap(_to_numpy, obj, leaf_dtypes, **structure_kwargs)
Ejemplo n.º 4
0
def n2t(obj, dtype=None, device=None):
    """Numpy to torch: convert every leaf of ``obj`` with ``torch.as_tensor``.

    :param dtype:
        a ``torch.dtype`` or the name of a ``torch`` attribute (e.g.
        ``"float32"``) applied to every leaf
    :param device:
        a ``torch.device`` or a device string applied to every leaf
    """
    torch_dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype
    torch_device = torch.device(device) if isinstance(device, str) else device

    def _to_tensor(leaf):
        return torch.as_tensor(leaf, dtype=torch_dtype, device=torch_device)

    return smap(_to_tensor, obj)
Ejemplo n.º 5
0
def n2t(
    obj, dtype=None, device=None, sequences=default_sequences, mappings=default_mappings
):
    """Numpy to torch: convert every leaf of ``obj`` with ``torch.as_tensor``.

    ``dtype`` and ``device`` are expanded to the structure of ``obj`` with
    ``copy_structure`` (presumably allowing per-leaf values — confirm), and
    are resolved individually for each leaf.

    :param dtype:
        a ``torch.dtype`` or the name of a ``torch`` attribute (e.g. ``"float32"``)
    :param device:
        a ``torch.device`` or a device string
    :param sequences:
        the sequence types treated as containers when traversing ``obj``
    :param mappings:
        the mapping types treated as containers when traversing ``obj``
    """
    structure_kwargs = dict(sequences=sequences, mappings=mappings)

    def _to_tensor(leaf, leaf_dtype, leaf_device):
        resolved_dtype = (
            getattr(torch, leaf_dtype) if isinstance(leaf_dtype, str) else leaf_dtype
        )
        resolved_device = (
            torch.device(leaf_device) if isinstance(leaf_device, str) else leaf_device
        )
        return torch.as_tensor(leaf, dtype=resolved_dtype, device=resolved_device)

    leaf_dtypes = copy_structure(obj, dtype, **structure_kwargs)
    leaf_devices = copy_structure(obj, device, **structure_kwargs)

    return smap(_to_tensor, obj, leaf_dtypes, leaf_devices, **structure_kwargs)
Ejemplo n.º 6
0
def call_torch(func, arg, *args, dtype=None, device=None, batch_size=64):
    """Call a torch function with numpy arguments and numpy results.

    The (possibly nested) arguments are flattened and sliced into consecutive
    batches of ``batch_size`` items along their first axis. Each batch is
    converted to torch with ``n2t``, passed to ``func``, and the result is
    converted back with ``t2n``. The per-batch results are concatenated
    along axis 0.

    :param func:
        the torch function to call; receives the converted arguments
        positionally
    :param arg:
        the first numpy argument (possibly a nested structure)
    :param args:
        further numpy arguments (possibly nested structures)
    :param dtype:
        the datatype applied to each batch by ``n2t``
    :param device:
        the device each batch is assigned to by ``n2t``
    :param batch_size:
        the number of items per batch; must be positive
    :raises ValueError:
        if ``batch_size`` is not positive
    """
    if batch_size <= 0:
        # A non-positive step never advances past the data and would loop forever.
        raise ValueError(f"batch_size must be positive, got {batch_size!r}")

    args = (arg, *args)
    index, values = flatten_with_index(args)

    # The item count comes from the first flattened leaf; presumably all
    # leaves share the same leading dimension — TODO confirm.
    num_items = len(values[0])

    result_batches = []
    for start in range(0, num_items, batch_size):
        end = start + batch_size
        batch = unflatten(index, (val[start:end] for val in values))
        batch = n2t(batch, dtype=dtype, device=device)
        result = func(*batch)
        result_batches.append(t2n(result))

    result, schema = szip(result_batches, return_schema=True)
    return smap(lambda _, r: np.concatenate(r, axis=0), schema, result)
Ejemplo n.º 7
0
def t2n(obj, dtype=None):
    """Torch to numpy: detach every tensor leaf of ``obj``, move it to the CPU and wrap it with ``np.asarray``.

    :param dtype:
        the numpy dtype applied to every leaf (``None`` keeps the tensor's dtype)
    """

    def _to_numpy(tensor):
        return np.asarray(tensor.detach().cpu(), dtype=dtype)

    return smap(_to_numpy, obj)