def optimizer_step(optimizer, func):
    """Run one optimization step and return the loss converted to float(s).

    :param optimizer: the optimizer whose gradients are reset and which is
        stepped after the backward pass.
    :param func: a zero-argument callable returning the loss to minimize.
    :returns: the loss, mapped through ``float`` with ``smap``.
    """
    optimizer.zero_grad()

    current_loss = func()
    current_loss.backward()

    optimizer.step()
    return smap(float, current_loss)
def call_torch_loader(
    module, dataloader, dtype=None, device=None, call_model=None, has_label=False
):
    """Call a torch function repeatedly with the result from a dataloader and return numpy objects.

    :param module: the object handed to ``call_model`` together with each batch.
    :param dataloader: an iterable yielding batches.
    :param dtype: the datatype to apply to each batch
    :param device: the device to assign each item to
    :param call_model: callable invoked as ``call_model(module, x)``; when
        ``None``, ``default_call_model`` is used.
    :param has_label: If ``True``, the dataloader should return a tuple for
        each batch. Only the first item in the tuple will be passed to the
        model; the remaining items are passed through (converted to torch and
        back to numpy) alongside the model output. If ``False``, the full
        batch of the dataloader will be passed to the model.
    :returns: the per-batch results zipped with ``szip`` and concatenated
        along axis 0 with ``np.concatenate``.
    """
    invoke = default_call_model if call_model is None else call_model

    collected = []
    for raw_batch in dataloader:
        batch = n2t(raw_batch, dtype=dtype, device=device)

        if not has_label:
            output = invoke(module, batch)
        else:
            inputs, *extras = batch
            output = (invoke(module, inputs), *extras)

        collected.append(t2n(output))

    zipped = szip(collected)
    return smap(lambda parts: np.concatenate(parts, axis=0), zipped)
def t2n(obj, dtype=None, sequences=default_sequences, mappings=default_mappings):
    """Torch to numpy.

    Every leaf of ``obj`` is detached, moved to the CPU and wrapped with
    ``np.asarray``. ``dtype`` is passed through ``copy_structure`` against
    ``obj`` and then mapped leaf-by-leaf alongside it (presumably this lets a
    single dtype or a per-leaf dtype structure be given — confirm against
    ``copy_structure``).
    """
    # NOTE(review): a later ``t2n`` definition in this module rebinds this
    # name, so this version is shadowed at import time — confirm which one is
    # intended to be public.
    structure_kwargs = dict(sequences=sequences, mappings=mappings)
    leaf_dtypes = copy_structure(obj, dtype, **structure_kwargs)

    def _leaf_to_numpy(tensor, leaf_dtype):
        cpu_tensor = tensor.detach().cpu()
        return np.asarray(cpu_tensor, dtype=leaf_dtype)

    return smap(_leaf_to_numpy, obj, leaf_dtypes, **structure_kwargs)
def n2t(obj, dtype=None, device=None):
    """Numpy to torch.

    Every leaf of ``obj`` is converted with ``torch.as_tensor``. A string
    ``dtype`` is resolved as an attribute of the ``torch`` module and a string
    ``device`` is turned into a ``torch.device``; the same dtype/device is used
    for all leaves.
    """
    # NOTE(review): a later ``n2t`` definition in this module rebinds this
    # name, so this version is shadowed at import time — confirm which one is
    # intended to be public.
    torch_dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype
    torch_device = torch.device(device) if isinstance(device, str) else device

    def _to_tensor(leaf):
        return torch.as_tensor(leaf, dtype=torch_dtype, device=torch_device)

    return smap(_to_tensor, obj)
def n2t(
    obj, dtype=None, device=None, sequences=default_sequences, mappings=default_mappings
):
    """Numpy to torch.

    Every leaf of ``obj`` is converted with ``torch.as_tensor``. ``dtype`` and
    ``device`` are each passed through ``copy_structure`` against ``obj`` and
    mapped leaf-by-leaf alongside it (presumably allowing per-leaf values —
    confirm against ``copy_structure``). String dtypes are resolved as
    attributes of ``torch``; string devices become ``torch.device`` objects.
    """
    structure_kwargs = dict(sequences=sequences, mappings=mappings)

    def _leaf_to_tensor(leaf, leaf_dtype, leaf_device):
        resolved_dtype = (
            getattr(torch, leaf_dtype) if isinstance(leaf_dtype, str) else leaf_dtype
        )
        resolved_device = (
            torch.device(leaf_device) if isinstance(leaf_device, str) else leaf_device
        )
        return torch.as_tensor(leaf, dtype=resolved_dtype, device=resolved_device)

    leaf_dtypes = copy_structure(obj, dtype, **structure_kwargs)
    leaf_devices = copy_structure(obj, device, **structure_kwargs)
    return smap(_leaf_to_tensor, obj, leaf_dtypes, leaf_devices, **structure_kwargs)
def call_torch(func, arg, *args, dtype=None, device=None, batch_size=64):
    """Call a torch function with numpy arguments and numpy results.

    All positional arguments are flattened with ``flatten_with_index``. The
    leaves are sliced into chunks of ``batch_size`` rows; each chunk is
    rebuilt with ``unflatten``, converted with ``n2t`` and passed to ``func``.
    The per-chunk results are converted back with ``t2n``, zipped with
    ``szip`` and concatenated along axis 0.

    :param func: the torch function, called as ``func(*batch)``.
    :param arg: the first positional argument (at least one is required).
    :param args: further positional arguments.
    :param dtype: dtype forwarded to ``n2t`` for every batch.
    :param device: device forwarded to ``n2t`` for every batch.
    :param batch_size: number of rows per call; must be a positive integer.
    :raises ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        # With a non-positive step the start index never passes the end of the
        # data, so the batching loop would never terminate.
        raise ValueError(f"batch_size must be positive, got {batch_size!r}")

    all_args = (arg, *args)
    index, values = flatten_with_index(all_args)
    # NOTE(review): only the first leaf's length determines the number of
    # batches; all leaves are assumed to share it — confirm with callers.
    n_rows = len(values[0])

    result_batches = []
    for start in range(0, n_rows, batch_size):
        end = start + batch_size
        batch = unflatten(index, (val[start:end] for val in values))
        batch = n2t(batch, dtype=dtype, device=device)
        result_batches.append(t2n(func(*batch)))

    result, schema = szip(result_batches, return_schema=True)
    return smap(lambda _, r: np.concatenate(r, axis=0), schema, result)
def t2n(obj, dtype=None):
    """Torch to numpy.

    Every leaf of ``obj`` is detached, moved to the CPU and wrapped with
    ``np.asarray``; the same ``dtype`` is applied to all leaves.
    """

    def _to_numpy(tensor):
        cpu_tensor = tensor.detach().cpu()
        return np.asarray(cpu_tensor, dtype=dtype)

    return smap(_to_numpy, obj)