def inv_gathernd(arrow: GatherNdArrow, port_attr: PortAttributes) -> Tuple[Arrow, PortMap]:
    """Build the parametric inverse of a GatherNd arrow.

    If the arrow's output is constant according to ``port_attr``, a fresh
    ``GatherNdArrow`` is returned together with the identity port map.

    Otherwise the inverse is a composite "InvGatherNd". An ``AddArrow`` sums
    two terms:

    * a ``ScatterNdArrow`` fed the original output, the index list and the
      constant source shape;
    * a ``MulArrow`` multiplying a free parameter by a constant float32 mask
      produced by ``complement_bool``.

    Presumably the scatter restores the gathered entries and the masked
    parameter fills the rest.

    Args:
        arrow: The forward GatherNd arrow.
        port_attr: Port attributes. Must provide the ``'shape'`` of
            ``arrow``'s input 0 and the ``'value'`` of its input 1 (the
            index list).

    Returns:
        ``(inverse, port_map)``. The inverse's in-ports are, in order: the
        original output, the parameter (marked as a param port) and the
        index list. Its single out-port is the sum. ``port_map`` is
        ``{0: 3, 1: 2, 2: 0}``, presumably mapping ``arrow``'s port indices
        to the inverse's.
    """
    if is_constant(arrow.out_ports()[0], port_attr):
        return GatherNdArrow(), {0: 0, 1: 1, 2: 2}

    full_shape = np.array(port_attr[arrow.in_ports()[0]]['shape'])
    known_indices = port_attr[arrow.in_ports()[1]]['value']
    compl_mask = complement_bool(known_indices, full_shape)
    # FIXME: materialising the complement like this could be huge.
    mask_src = SourceArrow(np.array(compl_mask, dtype=np.float32))
    shape_src = SourceArrow(full_shape)
    scatter, masker, combine = ScatterNdArrow(), MulArrow(), AddArrow()

    wiring = Bimap()
    connections = (
        (shape_src.out_port(0), scatter.in_port(2)),  # constant shape -> scatter
        (mask_src.out_port(0), masker.in_port(1)),    # constant mask -> mul
        (scatter.out_port(0), combine.in_port(0)),
        (masker.out_port(0), combine.in_port(1)),
    )
    for src, dst in connections:
        wiring.add(src, dst)

    # Inverse in-ports, in order: original output, parameter, index list.
    inverse = CompositeArrow(
        in_ports=[scatter.in_port(1), masker.in_port(0), scatter.in_port(0)],
        out_ports=[combine.out_port(0)],
        edges=wiring,
        name="InvGatherNd",
    )
    make_param_port(inverse.in_ports()[1])
    return inverse, {0: 3, 1: 2, 2: 0}
def inv_gather(arrow: GatherArrow, port_attr: PortAttributes) -> Tuple[Arrow, PortMap]:
    """Build the parametric inverse of a Gather arrow.

    If the arrow's output is constant according to ``port_attr``, a fresh
    ``GatherArrow`` is returned together with the identity port map.

    Otherwise the inverse is a composite "InvGather". An ``AddArrow`` sums
    the outputs of two ``SparseToDenseArrow``s, which share the constant
    shape of ``arrow``'s input 0 through a ``DuplArrow``:

    * ``std2`` receives the index list and the original output;
    * ``std1`` receives the constant complement of the index list (from
      ``complement``) and a free parameter.

    Args:
        arrow: The forward Gather arrow.
        port_attr: Port attributes. Must provide the ``'shape'`` of
            ``arrow``'s input 0 and the ``'value'`` of its input 1 (the
            index list).

    Returns:
        ``(inverse, port_map)``. The inverse's in-ports are, in order: the
        original output, the parameter (marked as a param port) and the
        index list. Its single out-port is the sum. ``port_map`` is
        ``{0: 3, 1: 2, 2: 0}``, presumably mapping ``arrow``'s port indices
        to the inverse's.
    """
    if is_constant(arrow.out_ports()[0], port_attr):
        return GatherArrow(), {0: 0, 1: 1, 2: 2}
    tensor_shape = port_attr[arrow.in_ports()[0]]['shape']
    if isinstance(tensor_shape, tuple):
        tensor_shape = list(tensor_shape)
    index_list_value = port_attr[arrow.in_ports()[1]]['value']
    index_list_compl = complement(index_list_value, tensor_shape)
    # SparseToDense in-ports as wired below: 0 = indices, 1 = shape,
    # 2 = values (matches tf.sparse_to_dense argument order -- assumed).
    std1 = SparseToDenseArrow()  # complement indices, parameter values
    std2 = SparseToDenseArrow()  # known indices, original-output values
    dupl1 = DuplArrow()  # fans the constant shape out to both std arrows
    # FIXME: don't do this, complement could be huge
    source_compl = SourceArrow(np.array(index_list_compl))
    source_tensor_shape = SourceArrow(np.array(tensor_shape))
    add = AddArrow()
    edges = Bimap()
    edges.add(source_compl.out_ports()[0], std1.in_ports()[0])
    edges.add(source_tensor_shape.out_ports()[0], dupl1.in_ports()[0])
    edges.add(dupl1.out_ports()[0], std1.in_ports()[1])
    edges.add(dupl1.out_ports()[1], std2.in_ports()[1])
    edges.add(std1.out_ports()[0], add.in_ports()[0])
    edges.add(std2.out_ports()[0], add.in_ports()[1])
    # orig_out_port, params, inp_list
    in_ports = [std2.in_ports()[2], std1.in_ports()[2], std2.in_ports()[0]]
    out_ports = [add.out_ports()[0]]
    op = CompositeArrow(in_ports=in_ports, out_ports=out_ports, edges=edges, name="InvGather")
    make_param_port(op.in_ports()[1])
    return op, {0: 3, 1: 2, 2: 0}
def supervised_loss_arrow(arrow: Arrow, DiffArrow=SquaredDifference) -> CompositeArrow:
    """Wrap ``arrow`` so that it also reports a supervised error per output.

    For every non-error output of ``arrow`` the composite gains:

    * an extra in-port labelled ``"train_output"`` that receives the target
      value for that output;
    * an extra error out-port labelled ``"supervised_error"`` carrying
      ``DiffArrow(target, output)``.

    Args:
        arrow: The arrow to wrap, f: Y -> X.
        DiffArrow: Arrow class computing a difference d: X x X -> R. One
            instance is created per non-error output; in-port 0 receives the
            target and in-port 1 the arrow's output. Defaults to
            ``SquaredDifference``.

    Returns:
        A CompositeArrow named ``"<arrow.name>_supervised"``. Its ports are
        added in this order:

        * one in-port per in-port of ``arrow``; param ports stay param ports;
        * then, for each out-port of ``arrow``:
          - if it is an error port, a pass-through error out-port with the
            same labels;
          - otherwise the pass-through out-port, the ``"train_output"``
            in-port and the ``"supervised_error"`` error out-port.

    Raises:
        AssertionError: if the resulting composite is not wired correctly.
    """
    c = CompositeArrow(name="%s_supervised" % arrow.name)
    # Expose every in-port of the wrapped arrow as an in-port of the composite.
    for in_port in arrow.in_ports():
        c_in_port = c.add_port()
        make_in_port(c_in_port)
        if is_param_port(in_port):
            make_param_port(c_in_port)
        c.add_edge(c_in_port, in_port)

    for out_port in arrow.out_ports():
        if is_error_port(out_port):
            # Existing error outputs are passed through, keeping their labels.
            error_port = c.add_port()
            make_out_port(error_port)
            make_error_port(error_port)
            transfer_labels(out_port, error_port)
            c.add_edge(out_port, error_port)
            continue

        # Regular output: pass it through unchanged ...
        c_out_port = c.add_port()
        make_out_port(c_out_port)
        c.add_edge(out_port, c_out_port)

        # ... and compare it against a supervised target.
        diff = DiffArrow()
        target_port = c.add_port()
        make_in_port(target_port)
        add_port_label(target_port, "train_output")
        c.add_edge(target_port, diff.in_port(0))
        c.add_edge(out_port, diff.in_port(1))
        error_port = c.add_port()
        make_out_port(error_port)
        make_error_port(error_port)
        add_port_label(error_port, "supervised_error")
        c.add_edge(diff.out_port(0), error_port)

    assert c.is_wired_correctly()
    return c
def _gen_data_forward_arrow():
    """Build the forward arrow ``(x1, x2) -> |x1 - x2|`` used by ``gen_data``."""
    forward_arrow = CompositeArrow(name="forward")
    in_port1 = forward_arrow.add_port()
    make_in_port(in_port1)
    in_port2 = forward_arrow.add_port()
    make_in_port(in_port2)
    out_port = forward_arrow.add_port()
    make_out_port(out_port)
    subtraction = SubArrow()
    absolute = AbsArrow()
    forward_arrow.add_edge(in_port1, subtraction.in_ports()[0])
    forward_arrow.add_edge(in_port2, subtraction.in_ports()[1])
    forward_arrow.add_edge(subtraction.out_ports()[0], absolute.in_ports()[0])
    forward_arrow.add_edge(absolute.out_ports()[0], out_port)
    assert forward_arrow.is_wired_correctly()
    return forward_arrow


def _gen_data_inverse_arrow():
    """Build the inverse arrow ``(y, theta0, theta1) -> (x1, x2)`` used by ``gen_data``.

    ``y`` and ``theta0`` feed an ``InvAbsArrow``. Its output and ``theta1``
    feed an ``InvSubArrow``, whose two outputs are the composite's outputs.
    Both theta ports are param ports.
    """
    inverse_arrow = CompositeArrow(name="inverse")
    in_port = inverse_arrow.add_port()
    make_in_port(in_port)
    param_port_1 = inverse_arrow.add_port()
    make_in_port(param_port_1)
    make_param_port(param_port_1)
    param_port_2 = inverse_arrow.add_port()
    make_in_port(param_port_2)
    make_param_port(param_port_2)
    out_port_1 = inverse_arrow.add_port()
    make_out_port(out_port_1)
    out_port_2 = inverse_arrow.add_port()
    make_out_port(out_port_2)
    inv_sub = InvSubArrow()
    inv_abs = InvAbsArrow()
    inverse_arrow.add_edge(in_port, inv_abs.in_ports()[0])
    inverse_arrow.add_edge(param_port_1, inv_abs.in_ports()[1])
    inverse_arrow.add_edge(inv_abs.out_ports()[0], inv_sub.in_ports()[0])
    inverse_arrow.add_edge(param_port_2, inv_sub.in_ports()[1])
    inverse_arrow.add_edge(inv_sub.out_ports()[0], out_port_1)
    inverse_arrow.add_edge(inv_sub.out_ports()[1], out_port_2)
    assert inverse_arrow.is_wired_correctly()
    return inverse_arrow


def gen_data(lmbda: float, size: int):
    """Generate samples for the ``|x1 - x2|`` example together with its arrows.

    Args:
        lmbda: Parameter passed to ``ExponentialRV`` when sampling inputs
            (presumably the rate -- confirm against ExponentialRV).
        size: Number of samples to draw.

    Returns:
        ``(data, forward_arrow, inverse_arrow)``:

        * ``data`` is a list of ``(x, theta, y)`` tuples, where ``x`` is a
          pair of samples, ``y = |x[0] - x[1]|`` and
          ``theta = (sign(x[0] - x[1]), x[1])``;
        * ``forward_arrow`` computes ``y`` from ``x``;
        * ``inverse_arrow`` recovers ``x`` from ``y`` and ``theta``.

    Raises:
        AssertionError: if either arrow is mis-wired, or if applying the
            arrows deviates from the numpy reference by ``1e-5`` or more.
            These checks are ``assert`` statements and vanish under
            ``python -O``.
    """
    forward_arrow = _gen_data_forward_arrow()
    inverse_arrow = _gen_data_inverse_arrow()
    eps = 1e-5
    data = []
    dist = ExponentialRV(lmbda)
    for _ in range(size):
        x = tuple(dist.sample(shape=[2]))
        y = np.abs(x[0] - x[1])
        # Values fed to the inverse's two param ports.
        theta = (np.sign(x[0] - x[1]), x[1])
        # Self-check: both arrows must reproduce the numpy reference values.
        y_ = apply(forward_arrow, list(x))
        x_ = apply(inverse_arrow, [y, theta[0], theta[1]])
        assert np.abs(y_[0] - y) < eps
        assert np.abs(x_[0] - x[0]) < eps
        assert np.abs(x_[1] - x[1]) < eps
        data.append((x, theta, y))
    return data, forward_arrow, inverse_arrow
def inv_fwd_loss_arrow(arrow: Arrow, inverse: Arrow, DiffArrow=SquaredDifference) -> CompositeArrow:
    """Build an arrow which computes the round-trip error |f(f^-1(y)) - y|.

    The in-ports of ``inverse`` become the in-ports of the composite. The
    non-error out-ports of ``inverse``, in order, feed the in-ports of
    ``arrow`` and are also exposed as out-ports of the composite. Each output
    of ``arrow`` is then compared with the composite in-port of the same
    index.

    Args:
        arrow: Forward arrow f.
        inverse: (Parametric) inverse f^-1 of ``arrow``. Its first
            ``len(arrow.out_ports())`` in-ports are assumed to carry the
            values y that ``arrow``'s outputs are compared against.
        DiffArrow: Arrow class computing a difference. One instance is
            created per out-port of ``arrow``; in-port 0 receives y and
            in-port 1 receives f(f^-1(y)). Defaults to ``SquaredDifference``.

    Returns:
        A CompositeArrow named ``"<arrow.name>_inv_fwd_loss"``. Its ports are
        added in this order:

        * one in-port per in-port of ``inverse``; param ports stay param
          ports;
        * one out-port per non-error out-port of ``inverse``;
        * one error out-port labelled ``"sub_arrow_error"`` per error
          out-port of ``inverse``;
        * one error out-port labelled ``"inv_fwd_error"`` per out-port of
          ``arrow``.

    Raises:
        AssertionError: if the resulting composite is not wired correctly.
    """
    c = CompositeArrow(name="%s_inv_fwd_loss" % arrow.name)
    # Make all in_ports of inverse inputs to composition
    for inv_in_port in inverse.in_ports():
        in_port = c.add_port()
        make_in_port(in_port)
        if is_param_port(inv_in_port):
            make_param_port(in_port)
        c.add_edge(in_port, inv_in_port)

    # Connect the non-error out_ports of inverse, in order, to the in_ports
    # of f. A dedicated counter keeps the pairing correct even if error
    # ports are interleaved with regular outputs.
    next_fwd_in = 0
    for out_port in inverse.out_ports():
        if is_error_port(out_port):
            continue
        c.add_edge(out_port, arrow.in_port(next_fwd_in))
        next_fwd_in += 1
        # Also expose the inverse's output as an output of the composition.
        c_out_port = c.add_port()
        make_out_port(c_out_port)
        c.add_edge(out_port, c_out_port)

    # Pass errors (if any) of the parametric inverse through as error ports.
    for out_port in inverse.out_ports():
        if is_error_port(out_port):
            error_port = c.add_port()
            make_out_port(error_port)
            make_error_port(error_port)
            add_port_label(error_port, "sub_arrow_error")
            c.add_edge(out_port, error_port)

    # Difference between the inputs of the inverse and the outputs of f,
    # one error port each.
    for i, out_port in enumerate(arrow.out_ports()):
        diff = DiffArrow()
        c.add_edge(c.in_port(i), diff.in_port(0))
        c.add_edge(out_port, diff.in_port(1))
        error_port = c.add_port()
        make_out_port(error_port)
        make_error_port(error_port)
        add_port_label(error_port, "inv_fwd_error")
        c.add_edge(diff.out_port(0), error_port)

    assert c.is_wired_correctly()
    return c