def unparam(arrow: Arrow, nnet: Arrow = None):
    """Unparameterize an arrow by feeding its parametric inputs from a network.

    Every non-parametric in_port of `arrow` becomes an in_port of the new
    composite and is routed both to `arrow` and to `nnet`; the outputs of
    `nnet` drive the parametric in_ports of `arrow`, in order.

    Args:
        arrow: Y x Theta -> X
        nnet: Y -> Theta. When omitted, a TfArrow with one input per
            non-parametric in_port and one output per parametric in_port
            is created.

    Returns:
        Y -> X: a CompositeArrow named "<arrow.name>_unparam" whose outputs
        mirror the outputs of `arrow` (error flags and labels included).
    """
    wrapper = CompositeArrow(name="%s_unparam" % arrow.name)

    # Split the in_ports into ordinary and parametric ones, preserving order.
    data_ports = []
    theta_ports = []
    for port in arrow.in_ports():
        if is_param_port(port):
            theta_ports.append(port)
        else:
            data_ports.append(port)

    if nnet is None:
        nnet = TfArrow(n_in_ports=len(data_ports), n_out_ports=len(theta_ports))

    # Each ordinary input feeds both the wrapped arrow and the network.
    for idx, inner in enumerate(data_ports):
        outer = wrapper.add_port()
        make_in_port(outer)
        transfer_labels(inner, outer)
        wrapper.add_edge(outer, inner)
        wrapper.add_edge(outer, nnet.in_port(idx))

    # The network's outputs supply the parameters.
    for idx, inner in enumerate(theta_ports):
        wrapper.add_edge(nnet.out_port(idx), inner)

    # Expose every output of the wrapped arrow on the composite.
    for inner in arrow.out_ports():
        outer = wrapper.add_port()
        make_out_port(outer)
        if is_error_port(inner):
            make_error_port(outer)
        transfer_labels(inner, outer)
        wrapper.add_edge(inner, outer)

    assert wrapper.is_wired_correctly()
    return wrapper
def SumNArrow(ninputs: int):
    """Create an arrow computing f(x1, ..., xn) = x1 + ... + xn.

    The sum is built as a left-leaning chain of `ninputs - 1` AddArrows,
    i.e. ((x1 + x2) + x3) + ...

    Args:
        ninputs: number of inputs; must be greater than 1.

    Returns:
        CompositeArrow named "SumNArrow" with `ninputs` in_ports and one
        out_port.
    """
    assert ninputs > 1
    summer = CompositeArrow(name="SumNArrow")

    first = summer.add_port()
    make_in_port(first)
    partial_sum = first  # port carrying the sum accumulated so far

    for _ in range(ninputs - 1):
        adder = AddArrow()
        summer.add_edge(partial_sum, adder.in_port(0))
        operand = summer.add_port()
        make_in_port(operand)
        summer.add_edge(operand, adder.in_port(1))
        partial_sum = adder.out_port(0)

    result = summer.add_port()
    make_out_port(result)
    summer.add_edge(partial_sum, result)

    assert summer.is_wired_correctly()
    assert summer.num_in_ports() == ninputs
    return summer
def __init__(self, lmbda: float) -> None:
    """Build the quantile function (inverse CDF) of an exponential distribution.

    The resulting arrow maps a probability p to -ln(1 - p) / lmbda and is
    stored on `self.quantile`.

    Args:
        lmbda: rate parameter, used as the divisor above.
    """
    quantile = CompositeArrow(name="ExponentialRVQuantile")
    p_port = quantile.add_port()
    make_in_port(p_port)
    q_port = quantile.add_port()
    make_out_port(q_port)

    rate = SourceArrow(lmbda)
    one = SourceArrow(1.0)
    one_minus_p = SubArrow()
    log = LogArrow()
    neg = NegArrow()
    div = DivArrow()

    # 1 - p
    quantile.add_edge(one.out_ports()[0], one_minus_p.in_ports()[0])
    quantile.add_edge(p_port, one_minus_p.in_ports()[1])
    # ln(1 - p)
    quantile.add_edge(one_minus_p.out_ports()[0], log.in_ports()[0])
    # -ln(1 - p)
    quantile.add_edge(log.out_ports()[0], neg.in_ports()[0])
    # -ln(1 - p) / lmbda
    quantile.add_edge(neg.out_ports()[0], div.in_ports()[0])
    quantile.add_edge(rate.out_ports()[0], div.in_ports()[1])
    quantile.add_edge(div.out_ports()[0], q_port)

    assert quantile.is_wired_correctly()
    self.quantile = quantile
def supervised_loss_arrow(arrow: Arrow,
                          DiffArrow=SquaredDifference) -> CompositeArrow:
    """Create an arrow that also computes |f(y) - x| for each regular output.

    Args:
        arrow: f: Y -> X, the arrow to wrap.
        DiffArrow: d: X x X -> R, arrow class used to compare values.

    Returns:
        A CompositeArrow named "<arrow.name>_supervised" with:
        - one in_port per in_port of `arrow` (parametric flag preserved);
        - for each non-error out_port of `arrow`: a pass-through out_port,
          an extra in_port labelled 'train_output', and an error out_port
          labelled 'supervised_error' carrying DiffArrow(train_output, f(y));
        - each error out_port of `arrow` passed through as an error port,
          with its labels.
    """
    sup = CompositeArrow(name="%s_supervised" % arrow.name)

    # Mirror every input of the wrapped arrow on the composite.
    for inner in arrow.in_ports():
        outer = sup.add_port()
        make_in_port(outer)
        if is_param_port(inner):
            make_param_port(outer)
        sup.add_edge(outer, inner)

    for inner in arrow.out_ports():
        if is_error_port(inner):
            # Existing error outputs are forwarded untouched.
            passthrough_err = sup.add_port()
            make_out_port(passthrough_err)
            make_error_port(passthrough_err)
            transfer_labels(inner, passthrough_err)
            sup.add_edge(inner, passthrough_err)
            continue

        # Regular output: forward it ...
        outer = sup.add_port()
        make_out_port(outer)
        sup.add_edge(inner, outer)

        # ... and compare it against a supplied training target.
        diff = DiffArrow()
        target = sup.add_port()
        make_in_port(target)
        add_port_label(target, "train_output")
        sup.add_edge(target, diff.in_port(0))
        sup.add_edge(inner, diff.in_port(1))
        loss = sup.add_port()
        make_out_port(loss)
        make_error_port(loss)
        add_port_label(loss, "supervised_error")
        sup.add_edge(diff.out_port(0), loss)

    assert sup.is_wired_correctly()
    return sup
def create_arrow(arrow: CompositeArrow, equiv_thetas, port_attr, valid_ports,
                 symbt_ports):
    """Wrap `arrow` so that equivalent parameters share one slim parameter.

    Every in_port in `valid_ports` is driven by gathering entries of a flat
    parameter vector of `nclasses` elements and reshaping the result to that
    port's shape. All other in_ports, and all out_ports, are passed through.

    Args:
        arrow: arrow whose parameters are being merged.
        equiv_thetas: maps each symbol to an equivalence-class id. The ids are
            used as gather indices into the flat slim parameter, so they
            presumably lie in range(nclasses).
        port_attr: port attributes, used to look up port shapes.
        valid_ports: in_ports of `arrow` to be driven by the slim parameter.
        symbt_ports: maps each valid port to a dict whose 'symbolic_tensor'
            entry has a `symbols` sequence of keys of `equiv_thetas`.

    Returns:
        CompositeArrow named "<arrow.name>_elim". Its last port is a new
        parametric in_port. That port's shape is
        (batch_size, nclasses // batch_size) when some valid port has rank > 1
        (batch_size being that port's first dimension), otherwise (nclasses,).
    """
    # New parameter space should have nclasses elements
    nclasses = num_unique_elem(equiv_thetas)
    new_arrow = CompositeArrow(name="%s_elim" % arrow.name)
    for out_port in arrow.out_ports():
        c_out_port = new_arrow.add_port()
        make_out_port(c_out_port)
        transfer_labels(out_port, c_out_port)
        if is_error_port(out_port):
            make_error_port(c_out_port)
        new_arrow.add_edge(out_port, c_out_port)

    # Reshape the slim parameter into a flat vector of nclasses elements
    flat_shape = SourceArrow(np.array([nclasses], dtype=np.int32))
    flatten = ReshapeArrow()
    new_arrow.add_edge(flat_shape.out_port(0), flatten.in_port(1))
    slim_param_flat = flatten.out_port(0)

    batch_size = None
    for in_port in arrow.in_ports():
        if in_port in valid_ports:
            symbt = symbt_ports[in_port]['symbolic_tensor']
            indices = [equiv_thetas[theta] for theta in symbt.symbols]
            shape = get_port_shape(in_port, port_attr)
            if len(shape) > 1:
                # All batched ports must agree on the leading dimension
                if batch_size is not None:
                    assert shape[0] == batch_size
                batch_size = shape[0]
            # in_port <- reshape(gather(slim_param_flat, indices), shape)
            gather = GatherArrow()
            src = SourceArrow(np.array(indices, dtype=np.int32))
            shape_shape = SourceArrow(np.array(shape, dtype=np.int32))
            reshape = ReshapeArrow()
            new_arrow.add_edge(slim_param_flat, gather.in_port(0))
            new_arrow.add_edge(src.out_port(0), gather.in_port(1))
            new_arrow.add_edge(gather.out_port(0), reshape.in_port(0))
            new_arrow.add_edge(shape_shape.out_port(0), reshape.in_port(1))
            new_arrow.add_edge(reshape.out_port(0), in_port)
        else:
            new_in_port = new_arrow.add_port()
            make_in_port(new_in_port)
            if is_param_port(in_port):
                make_param_port(new_in_port)
            transfer_labels(in_port, new_in_port)
            new_arrow.add_edge(new_in_port, in_port)

    if batch_size is None:
        # No valid port has a leading batch dimension (including the case of
        # no valid ports at all). Without this branch, `nclasses % None`
        # raised a TypeError; use the flat shape instead.
        slim_shape = (nclasses, )
    else:
        assert nclasses % batch_size == 0
        slim_shape = (batch_size, nclasses // batch_size)

    slim_param = new_arrow.add_port()
    make_in_port(slim_param)
    make_param_port(slim_param)
    new_arrow.add_edge(slim_param, flatten.in_port(0))
    set_port_shape(slim_param, slim_shape)

    assert new_arrow.is_wired_correctly()
    return new_arrow
def inv_fwd_loss_arrow(arrow: Arrow,
                       inverse: Arrow,
                       DiffArrow=SquaredDifference) -> CompositeArrow:
    """Arrow which computes |f(f^-1(y)) - y|.

    Args:
        arrow: forward function f.
        inverse: (parametric) inverse of f. Its non-error outputs feed the
            in_ports of `arrow` in order. Its non-parametric inputs are the y
            values compared, in order, against the outputs of `arrow`.
        DiffArrow: arrow class computing the difference of two values.

    Returns:
        CompositeArrow named "<arrow.name>_inv_fwd_loss". It has one in_port
        per in_port of `inverse` (parametric flag preserved). Its out_ports
        are, in order:
        - every non-error output of `inverse`;
        - every error output of `inverse`, labelled 'sub_arrow_error';
        - one error port per output of `arrow`, labelled 'inv_fwd_error'.
    """
    c = CompositeArrow(name="%s_inv_fwd_loss" % arrow.name)

    # Make all in_ports of inverse inputs to the composition, remembering the
    # non-parametric ones (the y values) for the loss below.
    y_in_ports = []
    for inv_in_port in inverse.in_ports():
        in_port = c.add_port()
        make_in_port(in_port)
        if is_param_port(inv_in_port):
            make_param_port(in_port)
        else:
            y_in_ports.append(in_port)
        c.add_edge(in_port, inv_in_port)

    # Connect the non-error out_ports of inverse to the in_ports of f and
    # expose them as outputs. Only non-error ports are counted, so an error
    # port placed before a regular output does not shift the mapping.
    inv_out_ports = list(inverse.out_ports())
    fwd_in_index = 0
    for out_port in inv_out_ports:
        if is_error_port(out_port):
            continue
        c.add_edge(out_port, arrow.in_port(fwd_in_index))
        fwd_in_index += 1
        # add edge from inverse output to composition output
        c_out_port = c.add_port()
        make_out_port(c_out_port)
        c.add_edge(out_port, c_out_port)

    # Pass errors (if any) of parametric inverse through as error_ports
    for out_port in inv_out_ports:
        if is_error_port(out_port):
            error_port = c.add_port()
            make_out_port(error_port)
            make_error_port(error_port)
            add_port_label(error_port, "sub_arrow_error")
            c.add_edge(out_port, error_port)

    # Difference between each y and the matching output f(f^-1(y));
    # one error port per output of f
    for i, out_port in enumerate(arrow.out_ports()):
        diff = DiffArrow()
        c.add_edge(y_in_ports[i], diff.in_port(0))
        c.add_edge(out_port, diff.in_port(1))
        error_port = c.add_port()
        make_out_port(error_port)
        make_error_port(error_port)
        add_port_label(error_port, "inv_fwd_error")
        c.add_edge(diff.out_port(0), error_port)

    assert c.is_wired_correctly()
    return c
def comp(fwd: Arrow, right_inv: Arrow, DiffArrow=SquaredDifference):
    """Composition: pipe the output of a forward model into its right inverse.

    Args:
        fwd: X -> Y
        right_inv: Y -> X x Error
        DiffArrow: arrow class comparing two values.

    Returns:
        CompositeArrow "fwd_to_right_inv": X -> X x Error x SupervisedError.
        Each 'supervised_error' port carries DiffArrow(x, right_inv(fwd(x)))
        for the corresponding input x.
    """
    composed = CompositeArrow(name="fwd_to_right_inv")

    # Left boundary -> fwd
    for fwd_in in fwd.in_ports():
        boundary_in = composed.add_port()
        make_in_port(boundary_in)
        composed.add_edge(boundary_in, fwd_in)
        transfer_labels(fwd_in, boundary_in)

    # fwd -> right_inv, outputs matched to inputs by position
    for idx, fwd_out in enumerate(fwd.out_ports()):
        composed.add_edge(fwd_out, right_inv.in_port(idx))

    # right_inv -> right boundary
    for inv_out in right_inv.out_ports():
        boundary_out = composed.add_port()
        make_out_port(boundary_out)
        if is_error_port(inv_out):
            make_error_port(boundary_out)
        composed.add_edge(inv_out, boundary_out)
        transfer_labels(inv_out, boundary_out)

    # Compare each x with its reconstruction right_inv(fwd(x))
    reconstructions = [p for p in right_inv.out_ports() if not is_error_port(p)]
    x_ports = composed.in_ports()
    assert len(reconstructions) == len(x_ports)  # one reconstruction per x
    for idx, x_port in enumerate(x_ports):
        diff = DiffArrow()
        composed.add_edge(x_port, diff.in_port(0))
        composed.add_edge(reconstructions[idx], diff.in_port(1))
        err = composed.add_port()
        make_out_port(err)
        make_error_port(err)
        add_port_label(err, "supervised_error")
        composed.add_edge(diff.out_port(0), err)

    assert composed.is_wired_correctly()
    return composed
def g_from_g_theta(inv: Arrow, g_theta: Arrow):
    """Construct an amortized random variable g from a parametric inverse.

    `g_theta` receives the y inputs (shared with `inv`) plus one noise input
    z, and its outputs become the parameters of `inv`.

    Args:
        inv: parametric inverse whose parameters are produced by `g_theta`.
        g_theta: Y x Y ... Y x Z -> Theta

    Returns:
        CompositeArrow "<inv.name>_g_theta":
        Y x Y x .. x Z -> X x ... X x Error x ... Error
    """
    g = CompositeArrow(name="%s_g_theta" % inv.name)
    inv_in_ports, inv_param_ports, inv_out_ports, inv_error_ports = \
        split_ports(inv)
    _, _, g_theta_out_ports, _ = split_ports(g_theta)
    assert len(g_theta_out_ports) == len(inv_param_ports)

    # Each y feeds both g_theta and the inverse
    for idx, inv_y in enumerate(inv_in_ports):
        y = g.add_port()
        make_in_port(y)
        g.add_edge(y, g_theta.in_port(idx))
        g.add_edge(y, inv_y)

    # The noise input z is g_theta's in_port right after the y inputs
    z = g.add_port()
    make_in_port(z)
    g.add_edge(z, g_theta.in_port(len(inv_in_ports)))

    # g_theta's outputs become the inverse's parameters
    for idx in range(len(g_theta_out_ports)):
        g.add_edge(g_theta.out_port(idx), inv_param_ports[idx])

    for inv_x in inv_out_ports:
        x = g.add_port()
        make_out_port(x)
        g.add_edge(inv_x, x)

    for inv_err in inv_error_ports:
        err = g.add_port()
        make_out_port(err)
        make_error_port(err)
        g.add_edge(inv_err, err)

    assert g.is_wired_correctly()
    return g
def _gen_data_forward_arrow():
    """Build the forward arrow "forward": (x1, x2) -> |x1 - x2|."""
    fwd = CompositeArrow(name="forward")
    x1_port = fwd.add_port()
    make_in_port(x1_port)
    x2_port = fwd.add_port()
    make_in_port(x2_port)
    y_port = fwd.add_port()
    make_out_port(y_port)

    sub = SubArrow()
    absolute = AbsArrow()
    fwd.add_edge(x1_port, sub.in_ports()[0])
    fwd.add_edge(x2_port, sub.in_ports()[1])
    fwd.add_edge(sub.out_ports()[0], absolute.in_ports()[0])
    fwd.add_edge(absolute.out_ports()[0], y_port)
    assert fwd.is_wired_correctly()
    return fwd


def _gen_data_inverse_arrow():
    """Build the parametric inverse "inverse": (y, theta1, theta2) -> (x1, x2).

    theta1 is fed to InvAbsArrow and theta2 to InvSubArrow; both are
    parametric in_ports.
    """
    inv = CompositeArrow(name="inverse")
    y_port = inv.add_port()
    make_in_port(y_port)
    theta1_port = inv.add_port()
    make_in_port(theta1_port)
    make_param_port(theta1_port)
    theta2_port = inv.add_port()
    make_in_port(theta2_port)
    make_param_port(theta2_port)
    x1_port = inv.add_port()
    make_out_port(x1_port)
    x2_port = inv.add_port()
    make_out_port(x2_port)

    inv_sub = InvSubArrow()
    inv_abs = InvAbsArrow()
    inv.add_edge(y_port, inv_abs.in_ports()[0])
    inv.add_edge(theta1_port, inv_abs.in_ports()[1])
    inv.add_edge(inv_abs.out_ports()[0], inv_sub.in_ports()[0])
    inv.add_edge(theta2_port, inv_sub.in_ports()[1])
    inv.add_edge(inv_sub.out_ports()[0], x1_port)
    inv.add_edge(inv_sub.out_ports()[1], x2_port)
    assert inv.is_wired_correctly()
    return inv


def gen_data(lmbda: float, size: int):
    """Sample `size` examples for the |x1 - x2| problem and check round trips.

    For each example, x = (x1, x2) is drawn from ExponentialRV(lmbda). The
    values y = |x1 - x2| and theta = (sign(x1 - x2), x2) are computed. The
    sample is then verified by running the forward and inverse arrows (to
    within 1e-5).

    Returns:
        (data, forward_arrow, inverse_arrow), where data is a list of
        (x, theta, y) tuples.
    """
    forward_arrow = _gen_data_forward_arrow()
    inverse_arrow = _gen_data_inverse_arrow()

    eps = 1e-5
    dist = ExponentialRV(lmbda)
    data = []
    for _ in range(size):
        x = tuple(dist.sample(shape=[2]))
        x1, x2 = x[0], x[1]
        y = np.abs(x1 - x2)
        theta = (np.sign(x1 - x2), x2)

        y_ = apply(forward_arrow, list(x))
        x_ = apply(inverse_arrow, [y, *theta])
        assert np.abs(y_[0] - y) < eps
        assert np.abs(x_[0] - x1) < eps
        assert np.abs(x_[1] - x2) < eps

        data.append((x, theta, y))
    return data, forward_arrow, inverse_arrow
def reparam(comp_arrow: CompositeArrow, phi_shape: Tuple, nn_takes_input=True):
    """Reparameterize an arrow so all parametric inputs are a function of phi.

    A TfArrow `nn` maps phi (and, if `nn_takes_input`, the non-parametric
    in_ports) to one value per parametric port of `comp_arrow`. Every
    non-parametric port of `comp_arrow` is exposed on the new arrow with its
    in/out/error flags and labels.

    Args:
        comp_arrow: Arrow to reparameterize
        phi_shape: Shape of the new parameter input phi
        nn_takes_input: if True, also route the non-parametric inputs to `nn`

    Returns:
        CompositeArrow "<comp_arrow.name>_reparam" whose first port is phi,
        a parametric in_port.
    """
    wrapped = CompositeArrow(name="%s_reparam" % comp_arrow.name)
    phi = wrapped.add_port()
    set_port_shape(phi, phi_shape)
    make_in_port(phi)
    make_param_port(phi)

    n_params = comp_arrow.num_param_ports()
    extra_inputs = comp_arrow.num_in_ports() - n_params if nn_takes_input else 0
    nn = TfArrow(n_in_ports=1 + extra_inputs, n_out_ports=n_params)
    wrapped.add_edge(phi, nn.in_port(0))

    next_nn_out = 0  # next nn output to assign to a parametric port
    next_nn_in = 1   # next nn input for a forwarded input (0 is phi)
    for port in comp_arrow.ports():
        if is_param_port(port):
            wrapped.add_edge(nn.out_port(next_nn_out), port)
            next_nn_out += 1
            continue

        outer = wrapped.add_port()
        if is_out_port(port):
            make_out_port(outer)
            wrapped.add_edge(port, outer)
        if is_in_port(port):
            make_in_port(outer)
            wrapped.add_edge(outer, port)
            if nn_takes_input:
                wrapped.add_edge(outer, nn.in_port(next_nn_in))
                next_nn_in += 1
        if is_error_port(port):
            make_error_port(outer)
        for label in get_port_labels(port):
            add_port_label(outer, label)

    assert wrapped.is_wired_correctly()
    return wrapped
def ConcatShuffleArrow(n_inputs: int, ndims: int):
    """Stack `n_inputs` arrays along a new leading axis, shuffle, then rotate axes.

    The composite has `n_inputs` array in_ports followed by one int32
    permutation-vector in_port (the last in_port). The stacked tensor is
    gathered with the permutation. Gather presumably reorders along axis 0
    (see FIXME below). The result is then transposed with axes
    [1, ..., ndims, 0].

    Args:
        n_inputs: number of arrays to stack.
        ndims: presumably the rank of each input array; the stacked tensor
            then has ndims + 1 axes, matching the transpose permutation.

    Returns:
        CompositeArrow named "ConcatShuffle" with one out_port.
    """
    shuffler = CompositeArrow(name="ConcatShuffle")
    stack = StackArrow(n_inputs, axis=0)
    for slot in range(n_inputs):
        src = shuffler.add_port()
        make_in_port(src)
        shuffler.add_edge(src, stack.in_port(slot))

    # Permutation vector
    perm = shuffler.add_port()
    make_in_port(perm)
    set_port_dtype(perm, 'int32')

    gather = GatherArrow()
    shuffler.add_edge(stack.out_port(0), gather.in_port(0))
    shuffler.add_edge(perm, gather.in_port(1))

    # Rotate the axes so the stacking axis (0) moves to the end:
    # permutation [1, 2, ..., ndims, 0] (a rotation, not a first/last swap).
    # FIXME: done only because gather seems to affect only the first
    # dimension; gather_nd would probably avoid the transpose.
    axes = list(range(ndims + 1))
    transpose = TransposeArrow(axes[1:] + [axes[0]])
    shuffler.add_edge(gather.out_port(0), transpose.in_port(0))

    result = shuffler.add_port()
    make_out_port(result)
    shuffler.add_edge(transpose.out_port(0), result)

    assert shuffler.is_wired_correctly()
    return shuffler
def set_gan_arrow(arrow: Arrow,
                  cond_gen: Arrow,
                  disc: Arrow,
                  n_fake_samples: int,
                  ndims: int,
                  x_shapes=None,
                  z_shape=None) -> CompositeArrow:
    """Arrow which computes the loss for an amortized random variable via a set GAN.

    Args:
        arrow: Forward function f.
        cond_gen: Y x Z -> X - Conditional generator; deep-copied
            `n_fake_samples` times. Must have exactly two in_ports, the last
            being the noise input.
        disc: X^n -> {0,1}^n - Discriminator over the shuffled set of samples.
        n_fake_samples: n, number of generated samples seen by the
            discriminator at once (alongside the one real sample).
        ndims: passed to ConcatShuffleArrow for the transpose permutation.
        x_shapes: optional shapes for the x in_ports, one per in_port of
            `arrow`.
        z_shape: optional shape for every noise in_port.

    Returns:
        CompositeArrow "<arrow.name>_set_gan" with
        inputs:  X (one per in_port of `arrow`) x Z_1 x ... x Z_n x RAND_PERM
        outputs: d_Loss x g_Loss
                 x (non-error outputs of each generator copy)
                 x (error outputs of each generator copy, as error ports)
    """
    # TODO: Assumes that f has single in_port and single out_port, generalize
    c = CompositeArrow(name="%s_set_gan" % arrow.name)
    assert cond_gen.num_in_ports(
    ) == 2, "don't handle case of more than one Y input"
    cond_gens = [deepcopy(cond_gen) for _ in range(n_fake_samples)]

    comp_in_ports = []
    # Connect x to arrow inputs
    for i in range(arrow.num_in_ports()):
        in_port = c.add_port()
        make_in_port(in_port)
        comp_in_ports.append(in_port)
        c.add_edge(in_port, arrow.in_port(i))
        if x_shapes is not None:
            set_port_shape(in_port, x_shapes[i])

    # Connect f(x) to every generator copy
    for i in range(n_fake_samples):
        for j in range(arrow.num_out_ports()):
            c.add_edge(arrow.out_port(j), cond_gens[i].in_port(j))

    # Connect a separate noise input to each generator's last in_port
    for i in range(n_fake_samples):
        noise_in_port = c.add_port()
        make_in_port(noise_in_port)
        if z_shape is not None:
            set_port_shape(noise_in_port, z_shape)
        cg_noise_in_port_id = cond_gens[i].num_in_ports() - 1
        c.add_edge(noise_in_port, cond_gens[i].in_port(cg_noise_in_port_id))

    stack_shuffles = []
    rand_perm_in_port = c.add_port()
    make_in_port(rand_perm_in_port)
    set_port_shape(rand_perm_in_port, (n_fake_samples + 1, ))
    set_port_dtype(rand_perm_in_port, 'int32')

    # For every output of g, i.e. x and y if f(x, y) = z,
    # stack all the Xs from the different samples together and shuffle
    cond_gen_non_error_out_ports = cond_gen.num_out_ports(
    ) - cond_gen.num_error_ports()
    for i in range(cond_gen_non_error_out_ports):
        stack_shuffle = ConcatShuffleArrow(n_fake_samples + 1, ndims)
        stack_shuffles.append(stack_shuffle)
        # Add each output from generator to shuffle set
        for j in range(n_fake_samples):
            c.add_edge(cond_gens[j].out_port(i), stack_shuffle.in_port(j))

        # Add the posterior sample x to the shuffle set
        c.add_edge(comp_in_ports[i], stack_shuffle.in_port(n_fake_samples))
        c.add_edge(rand_perm_in_port,
                   stack_shuffle.in_port(n_fake_samples + 1))

    gan_loss_arrow = GanLossArrow(n_fake_samples + 1)

    # Connect output of each stack shuffle to discriminator
    for i in range(cond_gen_non_error_out_ports):
        c.add_edge(stack_shuffles[i].out_port(0), disc.in_port(i))

    c.add_edge(disc.out_port(0), gan_loss_arrow.in_port(0))
    c.add_edge(rand_perm_in_port, gan_loss_arrow.in_port(1))

    # Add generator and discriminator loss ports
    loss_d_port = c.add_port()
    make_out_port(loss_d_port)
    make_error_port(loss_d_port)
    c.add_edge(gan_loss_arrow.out_port(0), loss_d_port)

    loss_g_port = c.add_port()
    make_out_port(loss_g_port)
    make_error_port(loss_g_port)
    c.add_edge(gan_loss_arrow.out_port(1), loss_g_port)

    # Connect fake samples to output of composition
    for i in range(n_fake_samples):
        for j in range(cond_gen_non_error_out_ports):
            sample = c.add_port()
            make_out_port(sample)
            c.add_edge(cond_gens[i].out_port(j), sample)

    # Pipe up the error ports of every generator copy. They must stay error
    # ports, as in every other pass-through of sub-arrow errors. A distinct
    # loop name avoids shadowing the `cond_gen` parameter.
    for gen in cond_gens:
        for i in range(cond_gen_non_error_out_ports, gen.num_out_ports()):
            error_port = c.add_port()
            make_out_port(error_port)
            make_error_port(error_port)
            c.add_edge(gen.out_port(i), error_port)

    assert c.is_wired_correctly()
    return c
def inner_invert(comp_arrow: CompositeArrow, port_attr: PortAttributes,
                 dispatch: Dict[Arrow, Callable]):
    """Construct a parametric inverse of a composite arrow.

    Args:
        comp_arrow: CompositeArrow to invert
        port_attr: port attributes, consulted through `is_constant` and for
            the 'shape' of each port of `comp_arrow`; also passed to
            `invert_sub_arrow`
        dispatch: Dict mapping arrow class to invert function, passed to
            `invert_sub_arrow`

    Returns:
        (inv_comp_arrow, comp_port_map):
        - inv_comp_arrow: an (approximate) parametric inverse of `comp_arrow`.
          The ith port of `comp_arrow` corresponds to the ith port of the
          inverse. Non-constant in_ports become out_ports and non-constant
          out_ports become in_ports, while constant ports keep their
          direction. Ports for the parameters and errors of the inverted
          sub-arrows are appended after them.
        - comp_port_map: the identity map {i: i} over the ports of
          `comp_arrow`.
    """
    # Empty composition for inverse
    inv_comp_arrow = CompositeArrow(name="%s_inv" % comp_arrow.name)

    # Add a port on inverse arrow for every port on arrow, flipping the
    # direction of every non-constant port
    for port in comp_arrow.ports():
        inv_port = inv_comp_arrow.add_port()
        if is_in_port(port):
            if is_constant(port, port_attr):
                make_in_port(inv_port)
            else:
                make_out_port(inv_port)
        elif is_out_port(port):
            if is_constant(port, port_attr):
                make_out_port(inv_port)
            else:
                make_in_port(inv_port)
        # Transfer port information; only the shape is copied, and only
        # when port_attr records one
        # FIXME: What port_attr go transferred from port to inv_port
        if 'shape' not in port_attr[port]:
            print('WARNING: shape unknown for %s' % port)
        else:
            set_port_shape(inv_port, get_port_shape(port, port_attr))

    # Invert each sub_arrow independently; port_map relates the ports of the
    # sub_arrow to those of its inverse (consumed by get_inv_port below)
    arrow_to_inv = dict()
    arrow_to_port_map = dict()
    for sub_arrow in comp_arrow.get_sub_arrows():
        inv_sub_arrow, port_map = invert_sub_arrow(sub_arrow, port_attr,
                                                   dispatch)
        assert sub_arrow is not None
        # The inverse must not already belong to a composition
        assert inv_sub_arrow.parent is None
        arrow_to_port_map[sub_arrow] = port_map
        arrow_to_inv[sub_arrow] = inv_sub_arrow

    # Add comp_arrow to inv so edges touching the boundary resolve to
    # inv_comp_arrow's ports (same index, via the identity map)
    assert comp_arrow is not None
    arrow_to_inv[comp_arrow] = inv_comp_arrow
    comp_port_map = {i: i for i in range(comp_arrow.num_ports())}
    arrow_to_port_map[comp_arrow] = comp_port_map

    # Then, rewire up all the edges. Of the two inverse ports matching each
    # original edge, exactly one must act as the source (projecting) and one
    # as the destination (receiving) within inv_comp_arrow.
    for out_port, in_port in comp_arrow.edges.items():
        left_inv_port = get_inv_port(out_port, arrow_to_port_map, arrow_to_inv)
        right_inv_port = get_inv_port(in_port, arrow_to_port_map, arrow_to_inv)
        both = [left_inv_port, right_inv_port]
        projecting = list(
            filter(lambda x: would_project(x, inv_comp_arrow), both))
        receiving = list(
            filter(lambda x: would_receive(x, inv_comp_arrow), both))
        assert len(projecting) == 1, "Should be only 1 projecting"
        assert len(receiving) == 1, "Should be only 1 receiving"
        inv_comp_arrow.add_edge(projecting[0], receiving[0])

    # NOTE(review): `transforms` is not defined in this function; presumably a
    # module-level sequence of callables that rewrite the inverse in place.
    for transform in transforms:
        transform(inv_comp_arrow)

    # Create new ports on the inverse composition for parametric and error
    # ports of its sub-arrows; these sub-arrow ports must still be unconnected
    for sub_arrow in inv_comp_arrow.get_sub_arrows():
        for port in sub_arrow.ports():
            if is_param_port(port):
                assert port not in inv_comp_arrow.edges.keys()
                assert port not in inv_comp_arrow.edges.values()
                param_port = inv_comp_arrow.add_port()
                inv_comp_arrow.add_edge(param_port, port)
                make_in_port(param_port)
                make_param_port(param_port)
            elif is_error_port(port):
                assert port not in inv_comp_arrow.edges.keys()
                assert port not in inv_comp_arrow.edges.values()
                error_port = inv_comp_arrow.add_port()
                inv_comp_arrow.add_edge(port, error_port)
                make_out_port(error_port)
                make_error_port(error_port)

    return inv_comp_arrow, comp_port_map