def partition(comp_arrow: CompositeArrow) -> List[Set[Arrow]]:
    """Partitions the comp_arrow into sequential layers of its sub_arrows.

    Each layer is a set of sub-arrows whose inputs are all satisfied by the
    composite's inputs or by outputs of earlier layers, so layers can be
    evaluated in order.

    Args:
        comp_arrow: composite arrow to partition

    Returns:
        List of sets of sub-arrows, in dependency order
    """
    partition_arrows = []
    # Priority queue: arrow -> number of its inputs not yet satisfied
    arrow_colors = pqdict()
    for sub_arrow in comp_arrow.get_sub_arrows():
        arrow_colors[sub_arrow] = sub_arrow.num_in_ports()
    # The composite's own inputs satisfy the ports they feed
    for port in comp_arrow.in_ports():
        in_ports = comp_arrow.neigh_in_ports(port)
        for in_port in in_ports:
            assert in_port.arrow in arrow_colors, "sub_arrow not in arrow_colors"
            arrow_colors[in_port.arrow] -= 1
    while len(arrow_colors) > 0:
        arrow_layer = set()
        # topitem/popitem return the minimum-priority entry, so all arrows
        # with priority 0 (fully satisfied) are drained into this layer
        view_arrow, view_priority = arrow_colors.topitem()
        while view_priority == 0:
            sub_arrow, priority = arrow_colors.popitem()
            arrow_layer.add(sub_arrow)
            if len(arrow_colors) == 0:
                break
            view_arrow, view_priority = arrow_colors.topitem()
        partition_arrows.append(arrow_layer)
        # Outputs of this layer satisfy inputs of downstream arrows
        for arrow in arrow_layer:
            for out_port in arrow.out_ports():
                in_ports = comp_arrow.neigh_in_ports(out_port)
                for in_port in in_ports:
                    # Skip edges that leave the composite's boundary
                    if in_port.arrow != comp_arrow:
                        assert in_port.arrow in arrow_colors, "sub_arrow not in arrow_colors"
                        arrow_colors[in_port.arrow] -= 1
    return partition_arrows
def invert(comp_arrow: CompositeArrow, dispatch: Dict[Arrow, Callable] = default_dispatch) -> Arrow:
    """Build a parametric inverse of `comp_arrow`.

    Args:
        comp_arrow: Arrow to invert
        dispatch: Dict mapping comp_arrow class to invert function

    Returns:
        A (approximate) parametric inverse of `comp_arrow`
    """
    # Multi-edges must become explicit dupl arrows before inversion
    comp_arrow.duplify()
    # Propagate known port attributes, then invert using them
    known_attr = propagate(comp_arrow)
    inverse, _port_map = inner_invert(comp_arrow, known_attr, dispatch)
    return inverse
def conv(a: CompositeArrow, args: TensorVarList, state) -> Sequence[Tensor]:
    """Convert composite arrow `a` to tensors by interpreting it on `args`."""
    assert len(args) == a.num_in_ports()
    # Group all generated ops under this arrow's name scope
    with tf.name_scope(a.name):
        # import pdb; pdb.set_trace()
        # FIXME: A horrible horrible hack
        # port_grab is threaded through interpretation so values on selected
        # intermediate ports can be extracted by the caller
        port_grab = state['port_grab']
        return interpret(conv, a, args, state, port_grab)
def gen_arrow_inputs(comp_arrow: CompositeArrow, inputs: List, arrow_colors):
    """Map each sub-arrow to the input values it has received so far.

    A per-arrow dict keyed by in-port index is used because input values are
    not necessarily produced in port order.  The priority of every arrow that
    receives one of `inputs` is decremented in `arrow_colors`.
    """
    # One (possibly empty) mapping per arrow in the composition
    arrow_inputs = {a: {} for a in comp_arrow.get_all_arrows()}  # type: Dict[Arrow, MutableMapping[int, Any]]
    boundary = comp_arrow.in_ports()
    for i, value in enumerate(inputs):
        for dst_port in comp_arrow.edges[boundary[i]]:
            dst_arrow = dst_port.arrow
            # One more of dst_arrow's inputs is now known
            arrow_colors[dst_arrow] -= 1
            arrow_inputs[dst_arrow][dst_port.index] = value
    return arrow_inputs
def filter_arrows(fil, arrow: CompositeArrow, deep=True):
    """Return the set of sub-arrows of `arrow` for which `fil` holds.

    Args:
        fil: predicate Arrow -> bool
        arrow: composite whose sub-arrows are searched
        deep: when True, recurse into nested composite arrows as well
    """
    matches = set()
    for child in arrow.get_sub_arrows():
        if fil(child):
            matches.add(child)
        if deep and isinstance(child, CompositeArrow):
            matches |= filter_arrows(fil, child, deep=deep)
    return matches
def gen_arrow_colors(comp_arrow: CompositeArrow): """ Interpret a composite arrow on some inputs Args: comp_arrow: Composite Arrow Returns: arrow_colors: Priority Queue of arrows """ # priority is the number of inputs each arrrow has which have been 'seen' # seen inputs are inputs to the composition, or outputs of arrows that # have already been converted into arrow_colors = pqdict() # type: MutableMapping[Arrow, int] for sub_arrow in comp_arrow.get_sub_arrows(): arrow_colors[sub_arrow] = sub_arrow.num_in_ports() # TODO: Unify arrow_colors[comp_arrow] = comp_arrow.num_out_ports() return arrow_colors
def inner_interpret(conv: Callable,
                    comp_arrow: CompositeArrow,
                    inputs: List,
                    arrow_colors: MutableMapping[Arrow, int],
                    arrow_inputs: Sequence,
                    state: Dict,
                    port_grab: Dict[Port, Any]):
    """Convert an comp_arrow to a tensorflow graph and add to graph.

    Processes sub-arrows in dependency order (via the `arrow_colors`
    priority queue), calling `conv` on each once all its inputs are ready,
    and routing each output to the arrows it feeds.

    Returns:
        The values accumulated on comp_arrow's own out ports, in port order.
    """
    assert len(inputs) == comp_arrow.num_in_ports(), "wrong # inputs"
    emit_list = []  # NOTE(review): appears unused in this block
    while len(arrow_colors) > 0:
        # print_arrow_colors(arrow_colors)
        # print("Converting ", sub_arrow.name)
        sub_arrow, priority = arrow_colors.popitem()
        # The composite itself is in the queue too; only real sub-arrows are
        # converted
        if sub_arrow is not comp_arrow:
            assert priority == 0, "Must resolve {} more inputs to {} first".format(
                priority, sub_arrow)
            # inputs = [arrow_inputs[sub_arrow][i] for i in range(len(arrow_inputs[sub_arrow]))]
            # NOTE(review): rebinds the `inputs` parameter; gather received
            # values in in-port-index order
            inputs = [
                arrow_inputs[sub_arrow][i]
                for i in sorted(arrow_inputs[sub_arrow].keys())
            ]
            outputs = conv(sub_arrow, inputs, state)
            assert len(outputs) == len(
                sub_arrow.out_ports()), "diff num outputs"
            # Decrement the priority of each subarrow connected to this arrow
            # Unless of course it is connected to the outside word
            for i, out_port in enumerate(sub_arrow.out_ports()):
                neigh_in_ports = comp_arrow.neigh_in_ports(out_port)
                for neigh_in_port in neigh_in_ports:
                    neigh_arrow = neigh_in_port.arrow
                    arrow_colors[neigh_arrow] = arrow_colors[neigh_arrow] - 1
                    arrow_inputs[neigh_arrow][neigh_in_port.index] = outputs[i]
        # Extract some port, kind of a hack
        # NOTE(review): placement inside the while loop reconstructed from
        # collapsed source — grabs requested port values as they appear
        for port in port_grab:
            if port.arrow in arrow_inputs:
                if port.index in arrow_inputs[port.arrow]:
                    port_grab[port] = arrow_inputs[port.arrow][port.index]
    # Values routed to the composite itself are its outputs
    outputs_dict = arrow_inputs[comp_arrow]
    out_port_indices = sorted(list(outputs_dict.keys()))
    return [outputs_dict[i] for i in out_port_indices]
def propagate(comp_arrow: CompositeArrow,
              port_attr: PortAttributes=None,
              state=None,
              already_prop=None,
              only_prop=None) -> PortAttributes:
    """
    Propagate values around a composite arrow to determine knowns from unknowns
    The knowns should be determined by the knowns, otherwise an error throws

    Args:
        sub_propagate: an @overloaded function which propagates from each arrow
          sub_propagate(a: ArrowType, port_to_known:Dict[Port, T], state:Any)
        comp_arrow: Composite Arrow to propagate through
        port_attr: port->value map for inputs to composite arrow
        state: A value of any type that is passed around during propagation
               and can be updated by sub_propagate
        already_prop: set of (arrow, dispatch) pairs already applied
        only_prop: restriction on which attributes to propagate

    Returns:
        port->value map for all ports in composite arrow
    """
    already_prop = set() if already_prop is None else already_prop
    print("Propagating")
    # Copy port_attr to avoid affecting input
    port_attr = {} if port_attr is None else port_attr
    _port_attr = defaultdict(lambda: dict())
    for port, attr in port_attr.items():
        for attr_key, attr_value in attr.items():
            _port_attr[port][attr_key] = attr_value
    # if comp_arrow.parent is None:
    #     comp_arrow.toposort()
    # update port_attr with values stored on port
    extract_port_attr(comp_arrow, _port_attr)
    # Work-list fixpoint: every (nested) sub-arrow starts dirty
    updated = set(comp_arrow.get_sub_arrows_nested())
    update_neigh(_port_attr, _port_attr, comp_arrow, comp_arrow, only_prop, updated)
    while len(updated) > 0:
        # print(len(updated), " arrows updating in proapgation iteration")
        sub_arrow = updated.pop()
        # View of the global attribute map restricted to this arrow's ports
        sub_port_attr = {port: _port_attr[port]
                         for port in sub_arrow.ports()
                         if port in _port_attr}
        pred_dispatches = sub_arrow.get_dispatches()
        for pred, dispatch in pred_dispatches.items():
            # Apply each dispatch at most once per arrow (tracked in already_prop)
            if pred(sub_arrow, sub_port_attr) and (sub_arrow, dispatch) not in already_prop:
                new_sub_port_attr = dispatch(sub_arrow, sub_port_attr)
                update_neigh(new_sub_port_attr, _port_attr, sub_arrow.parent,
                             comp_arrow, only_prop, updated)
                already_prop.add((sub_arrow, dispatch))
        if isinstance(sub_arrow, CompositeArrow):
            # new_sub_port_attr = propagate(sub_arrow, sub_port_attr,
            #                               state, already_prop)
            # update_neigh(new_sub_port_attr, _port_attr, comp_arrow, updated)
            # Push attributes both into the nested composite and out to its
            # parent context
            update_neigh(sub_port_attr, _port_attr, sub_arrow, comp_arrow,
                         only_prop, updated)
            update_neigh(sub_port_attr, _port_attr, sub_arrow.parent,
                         comp_arrow, only_prop, updated)
    print("Done Propagating")
    return _port_attr
def eliminate(arrow: CompositeArrow): """Eliminates redundant parameter Args: a: Parametric Arrow prime for eliminate! Returns: New Parameteric Arrow with fewer parameters""" # Warning: This is a huge hack # Get the shapes of param ports port_attr = propagate(arrow) symbt_ports = {} for port in arrow.in_ports(): if is_param_port(port): shape = get_port_shape(port, port_attr) symbt_ports[port] = {} # Create a symbolic tensor for each param port st = SymbolicTensor(shape=shape, name="port%s" % port.index, port=port) symbt_ports[port]['symbolic_tensor'] = st # repropagate port_attr = propagate(arrow, symbt_ports) # as a hack, just look on ports of duples to find symbolic tensors which # should be equivalent dupls = filter_arrows(lambda a: a.name in dupl_names, arrow) dupl_to_equiv = {} # Not all ports contain ports with symbolic tensor constraints valid_ports = set() for dupl in dupls: equiv = [] for p in dupl.ports(): if 'symbolic_tensor' in port_attr[p]: valid_ports.add(port_attr[p]['symbolic_tensor'].port) equiv.append(port_attr[p]['symbolic_tensor']) dupl_to_equiv[dupl] = equiv equiv_thetas = find_equivalent_thetas(dupl_to_equiv) return create_arrow(arrow, equiv_thetas, port_attr, valid_ports, symbt_ports)
def gen_data(lmbda: float, size: int):
    """Generate (x, theta, y) training triples for y = |x0 - x1|.

    Builds the forward arrow f(x0, x1) = |x0 - x1| and its parametric
    inverse (theta = (sign(x0 - x1), x1)), samples x from an exponential
    distribution, and sanity-checks both arrows on every sample.

    Args:
        lmbda: rate of the exponential sampling distribution
        size: number of samples to generate

    Returns:
        (data, forward_arrow, inverse_arrow) where data is a list of
        (x, theta, y) tuples
    """
    # Forward: (x0, x1) -> |x0 - x1|
    forward_arrow = CompositeArrow(name="forward")
    in_port1 = forward_arrow.add_port()
    make_in_port(in_port1)
    in_port2 = forward_arrow.add_port()
    make_in_port(in_port2)
    out_port = forward_arrow.add_port()
    make_out_port(out_port)
    subtraction = SubArrow()
    absolute = AbsArrow()
    forward_arrow.add_edge(in_port1, subtraction.in_ports()[0])
    forward_arrow.add_edge(in_port2, subtraction.in_ports()[1])
    forward_arrow.add_edge(subtraction.out_ports()[0], absolute.in_ports()[0])
    forward_arrow.add_edge(absolute.out_ports()[0], out_port)
    assert forward_arrow.is_wired_correctly()
    # Inverse: (y, theta1, theta2) -> (x0, x1)
    inverse_arrow = CompositeArrow(name="inverse")
    in_port = inverse_arrow.add_port()
    make_in_port(in_port)
    param_port_1 = inverse_arrow.add_port()
    make_in_port(param_port_1)
    make_param_port(param_port_1)
    param_port_2 = inverse_arrow.add_port()
    make_in_port(param_port_2)
    make_param_port(param_port_2)
    out_port_1 = inverse_arrow.add_port()
    make_out_port(out_port_1)
    out_port_2 = inverse_arrow.add_port()
    make_out_port(out_port_2)
    inv_sub = InvSubArrow()
    inv_abs = InvAbsArrow()
    inverse_arrow.add_edge(in_port, inv_abs.in_ports()[0])
    inverse_arrow.add_edge(param_port_1, inv_abs.in_ports()[1])
    inverse_arrow.add_edge(inv_abs.out_ports()[0], inv_sub.in_ports()[0])
    inverse_arrow.add_edge(param_port_2, inv_sub.in_ports()[1])
    inverse_arrow.add_edge(inv_sub.out_ports()[0], out_port_1)
    inverse_arrow.add_edge(inv_sub.out_ports()[1], out_port_2)
    assert inverse_arrow.is_wired_correctly()
    eps = 1e-5
    data = []
    dist = ExponentialRV(lmbda)
    for _ in range(size):
        x = tuple(dist.sample(shape=[2]))
        y = np.abs(x[0] - x[1])
        theta = (np.sign(x[0] - x[1]), x[1])
        # Check forward and inverse arrows reproduce the analytic values
        y_ = apply(forward_arrow, list(x))
        x_ = apply(inverse_arrow, [y, theta[0], theta[1]])
        assert np.abs(y_[0] - y) < eps
        assert np.abs(x_[0] - x[0]) < eps
        assert np.abs(x_[1] - x[1]) < eps
        data.append((x, theta, y))
    return data, forward_arrow, inverse_arrow
def inner_invert(comp_arrow: CompositeArrow,
                 port_attr: PortAttributes,
                 dispatch: Dict[Arrow, Callable]):
    """Construct a parametric inverse of arrow
    Args:
        arrow: Arrow to invert
        dispatch: Dict mapping arrow class to invert function
    Returns:
        A (approximate) parametric inverse of `arrow`
        The ith in_port of comp_arrow will be corresponding ith out_port
        error_ports and param_ports will follow"""
    # Empty compositon for inverse
    inv_comp_arrow = CompositeArrow(name="%s_inv" % comp_arrow.name)
    # import pdb; pdb.set_trace()
    # Add a port on inverse arrow for every port on arrow
    for port in comp_arrow.ports():
        inv_port = inv_comp_arrow.add_port()
        # Port direction flips under inversion, except for constant ports,
        # which keep their direction
        if is_in_port(port):
            if is_constant(port, port_attr):
                make_in_port(inv_port)
            else:
                make_out_port(inv_port)
        elif is_out_port(port):
            if is_constant(port, port_attr):
                make_out_port(inv_port)
            else:
                make_in_port(inv_port)
        # Transfer port information
        # FIXME: What port_attr go transfered from port to inv_port
        if 'shape' not in port_attr[port]:
            print('WARNING: shape unknown for %s' % port)
        else:
            set_port_shape(inv_port, get_port_shape(port, port_attr))
    # invert each sub_arrow
    arrow_to_inv = dict()
    arrow_to_port_map = dict()
    for sub_arrow in comp_arrow.get_sub_arrows():
        inv_sub_arrow, port_map = invert_sub_arrow(sub_arrow,
                                                   port_attr,
                                                   dispatch)
        assert sub_arrow is not None
        assert inv_sub_arrow.parent is None
        arrow_to_port_map[sub_arrow] = port_map
        arrow_to_inv[sub_arrow] = inv_sub_arrow
    # Add comp_arrow to inv; its ports map to themselves (identity)
    assert comp_arrow is not None
    arrow_to_inv[comp_arrow] = inv_comp_arrow
    comp_port_map = {i: i for i in range(comp_arrow.num_ports())}
    arrow_to_port_map[comp_arrow] = comp_port_map
    # Then, rewire up all the edges
    for out_port, in_port in comp_arrow.edges.items():
        left_inv_port = get_inv_port(out_port, arrow_to_port_map, arrow_to_inv)
        right_inv_port = get_inv_port(in_port, arrow_to_port_map, arrow_to_inv)
        both = [left_inv_port, right_inv_port]
        # Exactly one endpoint must produce and one must consume the value
        projecting = list(
            filter(lambda x: would_project(x, inv_comp_arrow), both))
        receiving = list(
            filter(lambda x: would_receive(x, inv_comp_arrow), both))
        assert len(projecting) == 1, "Should be only 1 projecting"
        assert len(receiving) == 1, "Should be only 1 receiving"
        inv_comp_arrow.add_edge(projecting[0], receiving[0])
    for transform in transforms:
        transform(inv_comp_arrow)
    # Create new ports on inverse compositions for parametric and error ports
    for sub_arrow in inv_comp_arrow.get_sub_arrows():
        for port in sub_arrow.ports():
            if is_param_port(port):
                # Param port must be unwired; expose it on the boundary
                assert port not in inv_comp_arrow.edges.keys()
                assert port not in inv_comp_arrow.edges.values()
                param_port = inv_comp_arrow.add_port()
                inv_comp_arrow.add_edge(param_port, port)
                make_in_port(param_port)
                make_param_port(param_port)
            elif is_error_port(port):
                # Error port must be unwired; expose it on the boundary
                assert port not in inv_comp_arrow.edges.keys()
                assert port not in inv_comp_arrow.edges.values()
                error_port = inv_comp_arrow.add_port()
                inv_comp_arrow.add_edge(port, error_port)
                make_out_port(error_port)
                make_error_port(error_port)
    return inv_comp_arrow, comp_port_map
def reparam(comp_arrow: CompositeArrow, phi_shape: Tuple, nn_takes_input=True):
    """Reparameterize an arrow. All parametric inputs now function of phi

    Args:
        comp_arrow: Arrow to reparameterize
        phi_shape: Shape of parameter input
        nn_takes_input: also feed the non-parametric inputs to the network
    """
    wrapper = CompositeArrow(name="%s_reparam" % comp_arrow.name)
    phi = wrapper.add_port()
    set_port_shape(phi, phi_shape)
    make_in_port(phi)
    make_param_port(phi)
    # The network maps phi (and optionally all ordinary inputs) to one
    # output per parametric port of comp_arrow
    n_in_ports = 1
    if nn_takes_input:
        n_in_ports += comp_arrow.num_in_ports() - comp_arrow.num_param_ports()
    nn = TfArrow(n_in_ports=n_in_ports, n_out_ports=comp_arrow.num_param_ports())
    wrapper.add_edge(phi, nn.in_port(0))
    nn_out_idx = 0
    nn_in_idx = 1
    for port in comp_arrow.ports():
        if is_param_port(port):
            # Parametric input is now driven by the network
            wrapper.add_edge(nn.out_port(nn_out_idx), port)
            nn_out_idx += 1
        else:
            re_port = wrapper.add_port()
            if is_out_port(port):
                make_out_port(re_port)
                wrapper.add_edge(port, re_port)
            if is_in_port(port):
                make_in_port(re_port)
                wrapper.add_edge(re_port, port)
                if nn_takes_input:
                    wrapper.add_edge(re_port, nn.in_port(nn_in_idx))
                    nn_in_idx += 1
            if is_error_port(port):
                make_error_port(re_port)
            for label in get_port_labels(port):
                add_port_label(re_port, label)
    assert wrapper.is_wired_correctly()
    return wrapper
def g_from_g_theta(inv: Arrow, g_theta: Arrow):
    """
    Compose a parametric inverse `inv` with a parameter predictor `g_theta`
    into an amortized random variable g.

    Args:
        g_theta: Y x Y ... Y x Z -> Theta

    Returns:
        Y x Y x .. x Z -> X x ... X x Error x ... Error
    """
    comp = CompositeArrow(name="%s_g_theta" % inv.name)
    inv_in_ports, inv_param_ports, inv_out_ports, inv_error_ports = split_ports(
        inv)
    _gt_in, _gt_param, g_theta_out_ports, _gt_error = split_ports(g_theta)
    assert len(g_theta_out_ports) == len(inv_param_ports)
    # Each y feeds both the parameter predictor and the inverse
    for idx, inv_in in enumerate(inv_in_ports):
        y_in_port = comp.add_port()
        make_in_port(y_in_port)
        comp.add_edge(y_in_port, g_theta.in_port(idx))
        comp.add_edge(y_in_port, inv_in)
    # Noise z is g_theta's final input
    z_in_port = comp.add_port()
    make_in_port(z_in_port)
    comp.add_edge(z_in_port, g_theta.in_port(len(inv_in_ports)))
    # Predicted thetas drive the inverse's parametric ports
    for idx in range(len(g_theta_out_ports)):
        comp.add_edge(g_theta.out_port(idx), inv_param_ports[idx])
    for inv_out_port in inv_out_ports:
        out_port = comp.add_port()
        make_out_port(out_port)
        comp.add_edge(inv_out_port, out_port)
    for inv_error_port in inv_error_ports:
        error_port = comp.add_port()
        make_out_port(error_port)
        make_error_port(error_port)
        comp.add_edge(inv_error_port, error_port)
    assert comp.is_wired_correctly()
    return comp
def ConcatShuffleArrow(n_inputs: int, ndims: int):
    """Stack `n_inputs` arrays and permute them along the stacking axis.

    The first n_inputs inputs are the arrays to shuffle; the final input is
    an int32 permutation vector.

    Args:
        n_inputs: number of arrays to stack and shuffle
        ndims: dimensionality of each input array
    """
    comp = CompositeArrow(name="ConcatShuffle")
    stack = StackArrow(n_inputs, axis=0)
    for idx in range(n_inputs):
        array_in = comp.add_port()
        make_in_port(array_in)
        comp.add_edge(array_in, stack.in_port(idx))
    # Permutation vector input
    perm = comp.add_port()
    make_in_port(perm)
    set_port_dtype(perm, 'int32')
    gather = GatherArrow()
    comp.add_edge(stack.out_port(0), gather.in_port(0))
    comp.add_edge(perm, gather.in_port(1))
    # Rotate the leading (stacking) dimension to the back
    # FIXME: I do this only because gather seems to affect only first dimension
    # There's a better way, probably using gather_nd
    tp_perm = list(range(1, ndims + 1)) + [0]
    transpose = TransposeArrow(tp_perm)
    comp.add_edge(gather.out_port(0), transpose.in_port(0))
    out = comp.add_port()
    make_out_port(out)
    comp.add_edge(transpose.out_port(0), out)
    assert comp.is_wired_correctly()
    return comp
def set_gan_arrow(arrow: Arrow,
                  cond_gen: Arrow,
                  disc: Arrow,
                  n_fake_samples: int,
                  ndims: int,
                  x_shapes=None,
                  z_shape=None) -> CompositeArrow:
    """
    Arrow wihch computes loss for amortized random variable using set gan.
    Args:
        arrow: Forward function
        cond_gen: Y x Z -> X - Conditional Generators
        disc: X^n -> {0,1}^n
        n_fake_samples: n, number of samples seen by discriminator at once
        ndims: dimensionality of dims
        x_shapes: optional shapes for the x inputs
        z_shape: optional shape for the noise inputs
    Returns:
        CompositeArrow: X x Z x ... Z x RAND_PERM -> d_Loss x g_Loss x Y x ... Y
    """
    # TODO: Assumes that f has single in_port and single out_port, generalize
    c = CompositeArrow(name="%s_set_gan" % arrow.name)
    assert cond_gen.num_in_ports(
    ) == 2, "don't handle case of more than one Y input"
    # One independent generator copy per fake sample
    cond_gens = [deepcopy(cond_gen) for i in range(n_fake_samples)]
    comp_in_ports = []
    # Connect x to arrow in puts
    for i in range(arrow.num_in_ports()):
        in_port = c.add_port()
        make_in_port(in_port)
        comp_in_ports.append(in_port)
        c.add_edge(in_port, arrow.in_port(i))
        if x_shapes is not None:
            set_port_shape(in_port, x_shapes[i])
    # Connect f(x) to generator
    for i in range(n_fake_samples):
        for j in range(arrow.num_out_ports()):
            c.add_edge(arrow.out_port(j), cond_gens[i].in_port(j))
    # Connect noise input to generator second inport
    for i in range(n_fake_samples):
        noise_in_port = c.add_port()
        make_in_port(noise_in_port)
        if z_shape is not None:
            set_port_shape(noise_in_port, z_shape)
        cg_noise_in_port_id = cond_gens[i].num_in_ports() - 1
        c.add_edge(noise_in_port, cond_gens[i].in_port(cg_noise_in_port_id))
    stack_shuffles = []
    # One shared permutation over the n_fake_samples + 1 set elements
    rand_perm_in_port = c.add_port()
    make_in_port(rand_perm_in_port)
    set_port_shape(rand_perm_in_port, (n_fake_samples + 1, ))
    set_port_dtype(rand_perm_in_port, 'int32')
    # For every output of g, i.e. x and y if f(x, y) = z
    # Stack all the Xs from the differet samples together and shuffle
    cond_gen_non_error_out_ports = cond_gen.num_out_ports(
    ) - cond_gen.num_error_ports()
    for i in range(cond_gen_non_error_out_ports):
        stack_shuffle = ConcatShuffleArrow(n_fake_samples + 1, ndims)
        stack_shuffles.append(stack_shuffle)
        # Add each output from generator to shuffle set
        for j in range(n_fake_samples):
            c.add_edge(cond_gens[j].out_port(i), stack_shuffle.in_port(j))
        # Add the posterior sample x to the shuffle set
        c.add_edge(comp_in_ports[i], stack_shuffle.in_port(n_fake_samples))
        c.add_edge(rand_perm_in_port,
                   stack_shuffle.in_port(n_fake_samples + 1))
    gan_loss_arrow = GanLossArrow(n_fake_samples + 1)
    # Connect output of each stack shuffle to discriminator
    for i in range(cond_gen_non_error_out_ports):
        c.add_edge(stack_shuffles[i].out_port(0), disc.in_port(i))
    c.add_edge(disc.out_port(0), gan_loss_arrow.in_port(0))
    c.add_edge(rand_perm_in_port, gan_loss_arrow.in_port(1))
    # Add generator and discriminator loss ports
    loss_d_port = c.add_port()
    make_out_port(loss_d_port)
    make_error_port(loss_d_port)
    c.add_edge(gan_loss_arrow.out_port(0), loss_d_port)
    loss_g_port = c.add_port()
    make_out_port(loss_g_port)
    make_error_port(loss_g_port)
    c.add_edge(gan_loss_arrow.out_port(1), loss_g_port)
    # Connect fake samples to output of composition
    for i in range(n_fake_samples):
        for j in range(cond_gen_non_error_out_ports):
            sample = c.add_port()
            make_out_port(sample)
            c.add_edge(cond_gens[i].out_port(j), sample)
    # Pipe up error ports
    for cond_gen in cond_gens:
        for i in range(cond_gen_non_error_out_ports, cond_gen.num_out_ports()):
            error_port = c.add_port()
            make_out_port(error_port)
            c.add_edge(cond_gen.out_port(i), error_port)
    assert c.is_wired_correctly()
    return c
def unparam(arrow: Arrow, nnet: Arrow = None):
    """Unparameterize `arrow` by inserting a TfArrow that predicts its
    parametric inputs from its ordinary inputs.

    Args:
        arrow: Y x Theta -> X
        nnet: Y -> Theta (a fresh TfArrow is created when None)

    Returns:
        Y -> X
    """
    comp = CompositeArrow(name="%s_unparam" % arrow.name)
    plain_ports = [p for p in arrow.in_ports() if not is_param_port(p)]
    theta_ports = [p for p in arrow.in_ports() if is_param_port(p)]
    if nnet is None:
        nnet = TfArrow(n_in_ports=len(plain_ports), n_out_ports=len(theta_ports))
    # Each ordinary input feeds both the arrow and the network
    for idx, plain_port in enumerate(plain_ports):
        outer_in = comp.add_port()
        make_in_port(outer_in)
        transfer_labels(plain_port, outer_in)
        comp.add_edge(outer_in, plain_port)
        comp.add_edge(outer_in, nnet.in_port(idx))
    # Network outputs drive the parametric inputs
    for idx, theta_port in enumerate(theta_ports):
        comp.add_edge(nnet.out_port(idx), theta_port)
    for out_port in arrow.out_ports():
        outer_out = comp.add_port()
        make_out_port(outer_out)
        if is_error_port(out_port):
            make_error_port(outer_out)
        transfer_labels(out_port, outer_out)
        comp.add_edge(out_port, outer_out)
    assert comp.is_wired_correctly()
    return comp
def create_arrow(arrow: CompositeArrow, equiv_thetas, port_attr, valid_ports,
                 symbt_ports):
    """Rebuild `arrow` with its redundant parameters collapsed.

    Parameters whose symbols fall in the same equivalence class (per
    `equiv_thetas`) are replaced by gathers from a single flat parameter
    vector of `nclasses` elements.
    """
    # New parameter space should have nclasses elements
    nclasses = num_unique_elem(equiv_thetas)
    new_arrow = CompositeArrow(name="%s_elim" % arrow.name)
    # Mirror every output of the original arrow on the new boundary
    for out_port in arrow.out_ports():
        c_out_port = new_arrow.add_port()
        make_out_port(c_out_port)
        transfer_labels(out_port, c_out_port)
        if is_error_port(out_port):
            make_error_port(c_out_port)
        new_arrow.add_edge(out_port, c_out_port)
    # The slim parameter is reshaped to a flat vector of nclasses entries
    flat_shape = SourceArrow(np.array([nclasses], dtype=np.int32))
    flatten = ReshapeArrow()
    new_arrow.add_edge(flat_shape.out_port(0), flatten.in_port(1))
    slim_param_flat = flatten.out_port(0)
    batch_size = None
    for in_port in arrow.in_ports():
        if in_port in valid_ports:
            # Replace this parametric input by a gather/reshape from the
            # flat slim parameter, indexed by each symbol's class id
            symbt = symbt_ports[in_port]['symbolic_tensor']
            indices = []
            for theta in symbt.symbols:
                setid = equiv_thetas[theta]
                indices.append(setid)
            shape = get_port_shape(in_port, port_attr)
            if len(shape) > 1:
                # NOTE(review): nesting reconstructed from collapsed source —
                # batch_size appears to be taken (and checked consistent)
                # from the leading dim of every multi-dim param port
                if batch_size is not None:
                    assert shape[0] == batch_size
                batch_size = shape[0]
            gather = GatherArrow()
            src = SourceArrow(np.array(indices, dtype=np.int32))
            shape_shape = SourceArrow(np.array(shape, dtype=np.int32))
            reshape = ReshapeArrow()
            new_arrow.add_edge(slim_param_flat, gather.in_port(0))
            new_arrow.add_edge(src.out_port(0), gather.in_port(1))
            new_arrow.add_edge(gather.out_port(0), reshape.in_port(0))
            new_arrow.add_edge(shape_shape.out_port(0), reshape.in_port(1))
            new_arrow.add_edge(reshape.out_port(0), in_port)
        else:
            # Non-eliminated inputs pass through unchanged
            new_in_port = new_arrow.add_port()
            make_in_port(new_in_port)
            if is_param_port(in_port):
                make_param_port(new_in_port)
            transfer_labels(in_port, new_in_port)
            new_arrow.add_edge(new_in_port, in_port)
    assert nclasses % batch_size == 0
    # Expose the single slim parameter port, shaped (batch, nclasses/batch)
    slim_param = new_arrow.add_port()
    make_in_port(slim_param)
    make_param_port(slim_param)
    new_arrow.add_edge(slim_param, flatten.in_port(0))
    set_port_shape(slim_param, (batch_size, nclasses // batch_size))
    assert new_arrow.is_wired_correctly()
    return new_arrow
def comp(fwd: Arrow, right_inv: Arrow, DiffArrow=SquaredDifference):
    """Compositon: Pipe output of forward model into input of right inverse

    Args:
        fwd: X -> Y
        right_inv: Y -> X x Error

    Returns:
        X -> X
    """
    cc = CompositeArrow(name="fwd_to_right_inv")
    # Left boundary feeds fwd
    for fwd_in in fwd.in_ports():
        outer_in = cc.add_port()
        make_in_port(outer_in)
        cc.add_edge(outer_in, fwd_in)
        transfer_labels(fwd_in, outer_in)
    # fwd output feeds right_inv
    for idx, fwd_out in enumerate(fwd.out_ports()):
        cc.add_edge(fwd_out, right_inv.in_port(idx))
    # right_inv feeds the right boundary
    for ri_out in right_inv.out_ports():
        outer_out = cc.add_port()
        make_out_port(outer_out)
        if is_error_port(ri_out):
            make_error_port(outer_out)
        cc.add_edge(ri_out, outer_out)
        transfer_labels(ri_out, outer_out)
    # Penalize the difference between x and right_inv(fwd(x))
    reconstructions = [p for p in right_inv.out_ports()
                       if not is_error_port(p)]  # len(X)
    assert len(reconstructions) == len(cc.in_ports())
    for idx, outer_in in enumerate(cc.in_ports()):
        diff = DiffArrow()
        cc.add_edge(outer_in, diff.in_port(0))
        cc.add_edge(reconstructions[idx], diff.in_port(1))
        err_port = cc.add_port()
        make_out_port(err_port)
        make_error_port(err_port)
        add_port_label(err_port, "supervised_error")
        cc.add_edge(diff.out_port(0), err_port)
    assert cc.is_wired_correctly()
    return cc
def eliminate_gathernd(arrow: CompositeArrow):
    """Eliminates redundant parameters in GatherNd

    Mutates `arrow` in place: for each dupl, the thetas of its InvGatherNd
    neighbours are replaced by a single shared slim parameter combined (via
    UpdateArrow) with values already determined by the gathered indices.

    Args:
        a: Parametric Arrow prime for eliminate!
    Returns:
        New Parameteric Arrow with fewer parameters"""
    dupls = filter_arrows(lambda a: a.name in dupl_names, arrow)
    for dupl in dupls:
        slim_param_arrow = UpdateArrow()
        constraints = None
        free = None
        shape = None
        shape_source = None
        for p in dupl.in_ports():
            inv = arrow.neigh_ports(p)[0].arrow
            if inv.name == 'InvGatherNd':
                out, theta, indices = inv.in_ports()
                # theta now comes from the shared slim parameter instead
                make_not_param_port(theta)
                arrow.add_edge(slim_param_arrow.out_port(0), theta)
                indices_val = get_port_value(indices)
                # All InvGatherNds on this dupl must agree on output shape
                if shape is not None:
                    assert np.array_equal(
                        shape, np.array(get_port_shape(inv.out_port(0))))
                else:
                    shape = np.array(get_port_shape(inv.out_port(0)))
                    shape_source = SourceArrow(shape)
                out = arrow.neigh_ports(out)[0]
                # unset: positions not covered by these indices
                # unique: positions covered exactly once
                unset, unique = complement_bool_list(indices_val, shape)
                free = unset if free is None else np.logical_and(
                    unset, free)
                # FIXME: don't really need bool anymore
                unique_source = SourceArrow(unique)
                unique_inds = SourceArrow(indices_val[tuple(
                    np.transpose(unique))])
                # add in a constraint for these new known values
                unique_upds = GatherNdArrow()
                arrow.add_edge(out, unique_upds.in_port(0))
                arrow.add_edge(unique_source.out_port(0),
                               unique_upds.in_port(1))
                tmp = UpdateArrow()
                if constraints is None:
                    constraints = SourceArrow(np.zeros(shape,
                                                       dtype=np.float32))
                # Fold this InvGatherNd's known values into the constraints
                arrow.add_edge(constraints.out_port(0), tmp.in_port(0))
                arrow.add_edge(unique_inds.out_port(0), tmp.in_port(1))
                arrow.add_edge(unique_upds.out_port(0), tmp.in_port(2))
                arrow.add_edge(shape_source.out_port(0), tmp.in_port(3))
                constraints = tmp
        if constraints is not None:
            # put in knowns
            arrow.add_edge(constraints.out_port(0),
                           slim_param_arrow.in_port(0))
            # put in params
            make_param_port(slim_param_arrow.in_port(2))
            arrow.add_edge(shape_source.out_port(0),
                           slim_param_arrow.in_port(3))
            # Positions never fixed by any indices remain free parameters
            inds = np.transpose(np.nonzero(free))
            if (free == free[0]).all():
                print("Assuming batched input")
                inds = np.array(np.split(inds, shape[0]))
            else:
                print(
                    "WARNING: Unbatched input, haven't designed for this case")
            inds_source = SourceArrow(inds)
            arrow.add_edge(inds_source.out_port(0),
                           slim_param_arrow.in_port(1))
def supervised_loss_arrow(arrow: Arrow,
                          DiffArrow=SquaredDifference) -> CompositeArrow:
    """
    Wrap `arrow` so it also emits |f(y) - x|.

    Args:
        Arrow: f: Y -> X - The arrow to modify
        DiffArrow: d: X x X - R - Arrow for computing difference

    Returns:
        f: Y/Theta x .. Y/Theta x X -> |f^{-1}(y) - X| x X
        Arrow with same input and output as arrow except that it takes an
        addition input with label 'train_output' that should contain examples
        in Y, and it returns an additional error output labelled
        'supervised_error' which is the |f(y) - x|
    """
    comp = CompositeArrow(name="%s_supervised" % arrow.name)
    # Mirror every input of `arrow` on the composition boundary
    for inner_in in arrow.in_ports():
        outer_in = comp.add_port()
        make_in_port(outer_in)
        if is_param_port(inner_in):
            make_param_port(outer_in)
        comp.add_edge(outer_in, inner_in)
    for inner_out in arrow.out_ports():
        if is_error_port(inner_out):
            # Existing error outputs pass straight through
            passthrough = comp.add_port()
            make_out_port(passthrough)
            make_error_port(passthrough)
            transfer_labels(inner_out, passthrough)
            comp.add_edge(inner_out, passthrough)
        else:
            # Normal outputs pass through...
            outer_out = comp.add_port()
            make_out_port(outer_out)
            comp.add_edge(inner_out, outer_out)
            # ...and are also compared against a supplied training target
            diff = DiffArrow()
            target_in = comp.add_port()
            make_in_port(target_in)
            add_port_label(target_in, "train_output")
            comp.add_edge(target_in, diff.in_port(0))
            comp.add_edge(inner_out, diff.in_port(1))
            err_out = comp.add_port()
            make_out_port(err_out)
            make_error_port(err_out)
            add_port_label(err_out, "supervised_error")
            comp.add_edge(diff.out_port(0), err_out)
    assert comp.is_wired_correctly()
    return comp
def __init__(self, lmbda: float) -> None:
    """Build the quantile (inverse CDF) arrow for Exponential(lmbda):
    Q(p) = -ln(1 - p) / lmbda, stored in self.quantile."""
    comp_arrow = CompositeArrow(name="ExponentialRVQuantile")
    in_port = comp_arrow.add_port()
    make_in_port(in_port)
    out_port = comp_arrow.add_port()
    make_out_port(out_port)
    lmbda_source = SourceArrow(lmbda)
    one_source = SourceArrow(1.0)
    # 1 - p
    one_minus_p = SubArrow()
    comp_arrow.add_edge(one_source.out_ports()[0], one_minus_p.in_ports()[0])
    comp_arrow.add_edge(in_port, one_minus_p.in_ports()[1])
    # ln(1 - p)
    ln = LogArrow()
    comp_arrow.add_edge(one_minus_p.out_ports()[0], ln.in_ports()[0])
    # -ln(1 - p)
    negate = NegArrow()
    comp_arrow.add_edge(ln.out_ports()[0], negate.in_ports()[0])
    # -ln(1 - p) / lmbda
    div_lmbda = DivArrow()
    comp_arrow.add_edge(negate.out_ports()[0], div_lmbda.in_ports()[0])
    comp_arrow.add_edge(lmbda_source.out_ports()[0], div_lmbda.in_ports()[1])
    comp_arrow.add_edge(div_lmbda.out_ports()[0], out_port)
    assert comp_arrow.is_wired_correctly()
    self.quantile = comp_arrow
def SumNArrow(ninputs: int):
    """
    Arrow computing f(x1, ..., xn) = x1 + ... + xn via a chain of adds.

    Args:
        ninputs: number of inputs (must be >= 2)

    Returns:
        Arrow with `ninputs` in-ports and one out-port
    """
    assert ninputs > 1
    comp = CompositeArrow(name="SumNArrow")
    # Running total starts at the first input and threads through each adder
    acc_port = comp.add_port()
    make_in_port(acc_port)
    for _ in range(ninputs - 1):
        adder = AddArrow()
        comp.add_edge(acc_port, adder.in_port(0))
        next_in = comp.add_port()
        make_in_port(next_in)
        comp.add_edge(next_in, adder.in_port(1))
        acc_port = adder.out_port(0)
    out_port = comp.add_port()
    make_out_port(out_port)
    comp.add_edge(acc_port, out_port)
    assert comp.is_wired_correctly()
    assert comp.num_in_ports() == ninputs
    return comp
def inv_fwd_loss_arrow(arrow: Arrow,
                       inverse: Arrow,
                       DiffArrow=SquaredDifference) -> CompositeArrow:
    """
    Arrow wihch computes |f(f^-1(y)) - y|

    Args:
        arrow: Forward function f
        inverse: Parametric inverse of f

    Returns:
        CompositeArrow exposing the inverse's outputs, its error ports, and
        one 'inv_fwd_error' per forward output
    """
    comp = CompositeArrow(name="%s_inv_fwd_loss" % arrow.name)
    # Boundary inputs mirror the inverse's inputs (param ports included)
    for inv_in_port in inverse.in_ports():
        outer_in = comp.add_port()
        make_in_port(outer_in)
        if is_param_port(inv_in_port):
            make_param_port(outer_in)
        comp.add_edge(outer_in, inv_in_port)
    # Non-error inverse outputs feed the forward arrow and the boundary
    for idx, inv_out in enumerate(inverse.out_ports()):
        if not is_error_port(inv_out):
            comp.add_edge(inv_out, arrow.in_port(idx))
            outer_out = comp.add_port()
            make_out_port(outer_out)
            comp.add_edge(inv_out, outer_out)
    # Errors of the parametric inverse pass through as error ports
    for inv_out in inverse.out_ports():
        if is_error_port(inv_out):
            sub_err = comp.add_port()
            make_out_port(sub_err)
            make_error_port(sub_err)
            add_port_label(sub_err, "sub_arrow_error")
            comp.add_edge(inv_out, sub_err)
    # Compare each forward output against the corresponding boundary input
    for idx, fwd_out in enumerate(arrow.out_ports()):
        diff = DiffArrow()
        comp.add_edge(comp.in_port(idx), diff.in_port(0))
        comp.add_edge(fwd_out, diff.in_port(1))
        fwd_err = comp.add_port()
        make_out_port(fwd_err)
        make_error_port(fwd_err)
        add_port_label(fwd_err, "inv_fwd_error")
        comp.add_edge(diff.out_port(0), fwd_err)
    assert comp.is_wired_correctly()
    return comp