def apply_backwards(arrow: Arrow, outputs: List[np.ndarray], port_attr=None) -> dict:
    """Run `arrow` backwards: from out_port values to in_port values.

    The values for the non-error out_ports of `arrow` are supplied by the
    caller. Every error out_port is seeded with a zero placeholder: an array of
    its known 'shape', or the scalar 0 when the shape is unknown. The attributes
    are then propagated, and the values that reach the in_ports (including
    parameter ports) are read back.

    Args:
        arrow: The arrow to evaluate backwards.
        outputs: One value per non-error out_port. The order is
            `arrow.out_ports()` with error ports skipped.
        port_attr: Optional mapping from port to attribute dict, as produced by
            `propagate(arrow)`. It is computed when omitted.
            FIXME: when supplied, this mapping is mutated in place ('value'
            entries are written into it).

    Returns:
        Dict mapping each in_port of `arrow` that received a 'value' during
        propagation to that value. In_ports without a value are omitted.

    Raises:
        ValueError: if `len(outputs)` differs from the number of non-error
            out_ports of `arrow`.
    """
    import warnings

    out_ports = [port for port in arrow.out_ports() if not is_error_port(port)]
    if len(outputs) != len(out_ports):
        raise ValueError(
            "apply_backwards expected %d output values (one per non-error "
            "out_port), got %d" % (len(out_ports), len(outputs)))

    if port_attr is None:
        port_attr = propagate(arrow)

    # Seed the caller-supplied values on the non-error outputs.
    for out_port, value in zip(out_ports, outputs):
        port_attr.setdefault(out_port, {})['value'] = value

    # Error ports get no caller-supplied value. Seed them with zeros, using the
    # port's known shape when one is available.
    for out_port in arrow.out_ports():
        if not is_error_port(out_port):
            continue
        attrs = port_attr.setdefault(out_port, {})
        if 'shape' in attrs:
            attrs['value'] = np.zeros(attrs['shape'])
        else:
            # FIXME: there has to be a better way to do this
            warnings.warn("shape of error port unknown: %s" % (out_port,))
            attrs['value'] = 0

    port_attr = propagate(arrow, port_attr)  # , only_prop=set(['value']))
    vals = extract_attribute('value', port_attr)
    return {port: vals[port] for port in arrow.in_ports() if port in vals}
def inv_fwd_loss_arrow(arrow: Arrow,
                       inverse: Arrow,
                       DiffArrow=SquaredDifference) -> CompositeArrow:
    """Arrow which computes |f(f^-1(y)) - y|.

    The composition takes every input of `inverse` (including its parameter
    ports). It feeds the non-error outputs of `inverse` into `arrow` and
    compares each output of `arrow` against the corresponding non-parametric
    input y.

    Args:
        arrow: Forward function f.
        inverse: (Possibly parametric) inverse f^-1. It may have error out_ports.
        DiffArrow: Arrow class instantiated once per output of `arrow` to
            compute the difference between y and f(f^-1(y)).

    Returns:
        CompositeArrow whose outputs are, in order:
        - the non-error outputs of `inverse`;
        - one error port labelled "sub_arrow_error" per error output of
          `inverse`;
        - one error port labelled "inv_fwd_error" per output of `arrow`.
    """
    c = CompositeArrow(name="%s_inv_fwd_loss" % arrow.name)

    # Make all in_ports of inverse inputs to the composition. Remember the
    # non-parametric ones (the y values): those are what the forward outputs
    # are compared against below. Param ports are excluded so that they cannot
    # shift the pairing.
    y_in_ports = []
    for inv_in_port in inverse.in_ports():
        in_port = c.add_port()
        make_in_port(in_port)
        if is_param_port(inv_in_port):
            make_param_port(in_port)
        else:
            y_in_ports.append(in_port)
        c.add_edge(in_port, inv_in_port)

    # Connect the non-error out_ports of inverse to the in_ports of f, and
    # expose each one as an output of the composition. Indexing over the
    # filtered list keeps an error port that precedes a data port from
    # shifting the target in_port.
    inv_data_out_ports = [p for p in inverse.out_ports() if not is_error_port(p)]
    for i, out_port in enumerate(inv_data_out_ports):
        c.add_edge(out_port, arrow.in_port(i))
        c_out_port = c.add_port()
        make_out_port(c_out_port)
        c.add_edge(out_port, c_out_port)

    # Pass errors (if any) of the parametric inverse through as error_ports.
    for out_port in inverse.out_ports():
        if is_error_port(out_port):
            error_port = c.add_port()
            make_out_port(error_port)
            make_error_port(error_port)
            add_port_label(error_port, "sub_arrow_error")
            c.add_edge(out_port, error_port)

    # Difference between each y input and the matching output of f, exposed as
    # one error port per output.
    for i, out_port in enumerate(arrow.out_ports()):
        diff = DiffArrow()
        c.add_edge(y_in_ports[i], diff.in_port(0))
        c.add_edge(out_port, diff.in_port(1))
        error_port = c.add_port()
        make_out_port(error_port)
        make_error_port(error_port)
        add_port_label(error_port, "inv_fwd_error")
        c.add_edge(diff.out_port(0), error_port)

    assert c.is_wired_correctly()
    return c
def unparam(arrow: Arrow, nnet: Arrow = None):
    """Unparameterize an arrow by inserting a network that supplies its parameters.

    The non-parametric inputs of `arrow` become the inputs of the composition.
    Each one is fed both to `arrow` and to `nnet`. The outputs of `nnet` drive
    the parametric inputs of `arrow`.

    Args:
        arrow: Y x Theta -> X
        nnet: Y -> Theta. When None, a TfArrow is created with one in_port per
            non-parametric input and one out_port per parametric input.

    Returns:
        CompositeArrow Y -> X. Error outputs of `arrow` remain error outputs,
        and port labels are copied to the boundary ports.
    """
    c = CompositeArrow(name="%s_unparam" % arrow.name)

    data_ports = []
    theta_ports = []
    for port in arrow.in_ports():
        bucket = theta_ports if is_param_port(port) else data_ports
        bucket.append(port)

    if nnet is None:
        nnet = TfArrow(n_in_ports=len(data_ports), n_out_ports=len(theta_ports))

    # Each data input fans out to both the arrow and the parameter network.
    for idx, data_port in enumerate(data_ports):
        boundary = c.add_port()
        make_in_port(boundary)
        transfer_labels(data_port, boundary)
        c.add_edge(boundary, data_port)
        c.add_edge(boundary, nnet.in_port(idx))

    # The network's outputs fill the arrow's parameter slots, in order.
    for idx, theta_port in enumerate(theta_ports):
        c.add_edge(nnet.out_port(idx), theta_port)

    # Expose every output of the arrow, preserving error-port status.
    for src in arrow.out_ports():
        boundary = c.add_port()
        make_out_port(boundary)
        if is_error_port(src):
            make_error_port(boundary)
        transfer_labels(src, boundary)
        c.add_edge(src, boundary)

    assert c.is_wired_correctly()
    return c
def get_param_pairs(inv, voxel_grids, batch_size, n, port_attr=None,
                    pickle_to=None):
    """Sample (input, parameter) pairs by running `inv` backwards.

    In each of `n` rounds, `batch_size` row indices into `voxel_grids` are drawn
    independently with `np.random.randint`, so rows may repeat. The selected
    rows are reshaped to the shape of every non-error out_port of `inv` and cast
    to float32. `apply_backwards` then recovers the values of `inv`'s parameter
    in_ports.

    Args:
        inv: Arrow to run backwards.
        voxel_grids: Array whose first axis indexes examples.
        batch_size: Number of rows drawn per round.
        n: Number of rounds.
        port_attr: Optional result of `propagate(inv)`. Only its 'shape'
            entries are read here. `apply_backwards` is always called with
            `port_attr=None` (it propagates afresh each round), so this mapping
            is not modified by this function.
        pickle_to: If given, path to which `(inputs, params)` is pickled.

    Returns:
        Tuple `(inputs, params)` of two lists of length `n`:
        - `inputs[k]` holds one float32 array per non-error out_port;
        - `params[k]` holds the recovered value of each param in_port, in
          `inv.in_ports()` order.
    """
    attrs = propagate(inv) if port_attr is None else port_attr
    out_shapes = [
        attrs[p]['shape'] for p in inv.out_ports() if not is_error_port(p)
    ]

    inputs, params = [], []
    for _round in range(n):
        row_ids = np.random.randint(0, voxel_grids.shape[0], size=batch_size)
        batch = [
            voxel_grids[row_ids].reshape(shape).astype(np.float32)
            for shape in out_shapes
        ]
        inputs.append(batch)

        recovered = apply_backwards(inv, batch, port_attr=None)
        params.append(
            [recovered[p] for p in inv.in_ports() if is_param_port(p)])

    if pickle_to is not None:
        with open(pickle_to, 'wb') as fh:
            pickle.dump((inputs, params), fh)
    return inputs, params
def error_ports(self):
    """Return this Arrow's ErrorPorts in port order.

    Returns:
        List of the ports in `self._ports` for which `pa.is_error_port` is true.
    """
    return list(filter(pa.is_error_port, self._ports))
def comp(fwd: Arrow, right_inv: Arrow, DiffArrow=SquaredDifference):
    """Composition: pipe the output of a forward model into its right inverse.

    Each output of `fwd` feeds the in_port of `right_inv` at the same position.
    For every input x of the composition, an extra error output (labelled
    "supervised_error") measures the difference between x and the matching
    non-error output of right_inv(fwd(x)).

    Args:
        fwd: X -> Y
        right_inv: Y -> X x Error
        DiffArrow: Arrow class instantiated once per input to compute the
            difference.

    Returns:
        CompositeArrow X -> X x Error
    """
    c = CompositeArrow(name="fwd_to_right_inv")

    # Left boundary feeds fwd.
    for src in fwd.in_ports():
        boundary = c.add_port()
        make_in_port(boundary)
        c.add_edge(boundary, src)
        transfer_labels(src, boundary)

    # fwd feeds right_inv, position by position.
    for idx, fwd_out in enumerate(fwd.out_ports()):
        c.add_edge(fwd_out, right_inv.in_port(idx))

    # Every right_inv output reaches the right boundary; errors stay errors.
    for inv_out in right_inv.out_ports():
        boundary = c.add_port()
        make_out_port(boundary)
        if is_error_port(inv_out):
            make_error_port(boundary)
        c.add_edge(inv_out, boundary)
        transfer_labels(inv_out, boundary)

    # One reconstructed value per element of X.
    reconstructed = [p for p in right_inv.out_ports() if not is_error_port(p)]
    assert len(reconstructed) == len(c.in_ports())

    # Compare each x against its reconstruction.
    for idx, x_port in enumerate(c.in_ports()):
        diff = DiffArrow()
        c.add_edge(x_port, diff.in_port(0))
        c.add_edge(reconstructed[idx], diff.in_port(1))
        err = c.add_port()
        make_out_port(err)
        make_error_port(err)
        add_port_label(err, "supervised_error")
        c.add_edge(diff.out_port(0), err)

    assert c.is_wired_correctly()
    return c
def test_apply_backwards():
    """Run apply_backwards on the inverse of `test_twoxyplusx()` with random data.

    One 2x2 standard-normal array is drawn per non-error out_port of the
    inverted arrow.

    Returns:
        Tuple `(orig, arrow, outputs, in_vals)`: the forward arrow, its
        inverse, the sampled outputs and the result of apply_backwards.
    """
    orig = test_twoxyplusx()
    arrow = invert(orig)
    n_data_outputs = sum(
        1 for port in arrow.out_ports() if not is_error_port(port))
    outputs = [np.random.randn(2, 2) for _ in range(n_data_outputs)]
    in_vals = apply_backwards(arrow, outputs)
    return orig, arrow, outputs, in_vals
def default_grans():
    """Return the default port predicates used to choose which tensors to grab.

    Keys, in order, and the ports each predicate accepts:
        'input'  - in_ports that are not param ports
        'param'  - param ports
        'error'  - error ports
        'output' - out ports
    """
    return dict(
        input=lambda p: is_in_port(p) and not is_param_port(p),
        param=lambda p: is_param_port(p),
        error=lambda p: is_error_port(p),
        output=lambda p: is_out_port(p),
    )
def supervised_loss_arrow(arrow: Arrow,
                          DiffArrow=SquaredDifference) -> CompositeArrow:
    """Wrap `arrow` so that it also reports a supervised error |f(y) - x|.

    The result has every input and output of `arrow`, with param and error
    status preserved. For each non-error output of `arrow`, it also has:
    - an extra in_port labelled "train_output", which takes the target value;
    - an extra error out_port labelled "supervised_error", which carries
      DiffArrow(target, output).

    Args:
        arrow: f: Y -> X, the arrow to wrap.
        DiffArrow: Arrow class (X x X -> R) instantiated once per non-error
            output to compute the difference.

    Returns:
        CompositeArrow as described above.
    """
    c = CompositeArrow(name="%s_supervised" % arrow.name)

    # Mirror every input of the arrow; param ports stay param ports.
    for src in arrow.in_ports():
        boundary = c.add_port()
        make_in_port(boundary)
        if is_param_port(src):
            make_param_port(boundary)
        c.add_edge(boundary, src)

    for src in arrow.out_ports():
        if is_error_port(src):
            # Existing error outputs pass straight through with their labels.
            passthrough = c.add_port()
            make_out_port(passthrough)
            make_error_port(passthrough)
            transfer_labels(src, passthrough)
            c.add_edge(src, passthrough)
            continue

        # Ordinary output: expose it unchanged...
        exposed = c.add_port()
        make_out_port(exposed)
        c.add_edge(src, exposed)

        # ...then compare it against a new "train_output" input.
        diff = DiffArrow()
        target = c.add_port()
        make_in_port(target)
        add_port_label(target, "train_output")
        c.add_edge(target, diff.in_port(0))
        c.add_edge(src, diff.in_port(1))

        loss = c.add_port()
        make_out_port(loss)
        make_error_port(loss)
        add_port_label(loss, "supervised_error")
        c.add_edge(diff.out_port(0), loss)

    assert c.is_wired_correctly()
    return c
def wrap(a: Arrow):
    """Wrap an arrow in a composite arrow with the same name and port layout.

    Each port of `a` gets a boundary port on the composite. The boundary port
    has the same in/out/param/error roles and the same labels, and is connected
    to the inner port (inward for in_ports, outward for out_ports).
    """
    c = CompositeArrow(name=a.name)

    def mirror(inner):
        # Create and wire the boundary counterpart of `inner`.
        outer = c.add_port()
        if is_in_port(inner):
            make_in_port(outer)
            c.add_edge(outer, inner)
        if is_param_port(inner):
            make_param_port(outer)
        if is_out_port(inner):
            make_out_port(outer)
            c.add_edge(inner, outer)
        if is_error_port(inner):
            make_error_port(outer)
        transfer_labels(inner, outer)

    for inner in a.ports():
        mirror(inner)

    assert c.is_wired_correctly()
    return c