示例#1
0
def pi_supervised(options):
    """Train a neural-network-enhanced parametric inverse with supervision.

    Builds the render graph, hands its ``out_img`` tensor to
    ``tensor_to_sup_right_inv`` and trains the result by minimising the
    accumulated ``supervised_error`` port(s), feeding batches produced by
    ``train_test_model_net_40()``.

    Args:
        options: Run options.  ``options['batch_size']`` is read here; the
            whole object is also forwarded to ``render_gen_graph``,
            ``tensor_to_sup_right_inv`` and ``train_supervised``.

    Returns:
        None.  The parameter count is printed once training has returned.
    """
    tens = render_gen_graph(options)
    # Only ``out_img`` is used below; the other two tensors are fetched but
    # not needed by this function.
    _voxels, _gdotl_cube, out_img = getn(tens, 'voxels', 'gdotl_cube',
                                         'out_img')

    # Do the inversion
    data_right_inv = tensor_to_sup_right_inv([out_img], options)
    # Alternative: data_right_inv = right_inv_nnet([out_img], options)

    callbacks = []
    tf.reset_default_graph()
    # Predicates selecting which ports of the arrow to turn into tensors.
    grabs = {
        'input':
        lambda p: is_in_port(p) and not is_param_port(p) and
        not has_port_label(p, 'train_output'),
        'supervised_error':
        lambda p: has_port_label(p, 'supervised_error'),
        'sub_arrow_error':
        lambda p: has_port_label(p, 'sub_arrow_error'),
        'inv_fwd_error':
        lambda p: has_port_label(p, 'inv_fwd_error'),
    }
    # Not all arrows will have these ports
    optional = ['sub_arrow_error', 'inv_fwd_error']
    tensors = extract_tensors(data_right_inv, grabs=grabs, optional=optional)
    train_voxel_data, test_voxel_data = train_test_model_net_40()
    batch_size = options['batch_size']
    train_generators = infinite_batches(train_voxel_data,
                                        batch_size=batch_size)
    test_generators = infinite_batches(test_voxel_data, batch_size=batch_size)

    def gen_gen(gen):
        """Wrap a batch generator into a generator of feed dictionaries."""
        while True:
            data = next(gen)
            # Flatten every example so the batch is (batch_size, features)
            data = np.reshape(data, (batch_size, -1))
            yield {tensors['input'][0]: data}

    sess = tf.Session()
    try:
        num_params = get_tf_num_params(data_right_inv)
        # Only the supervised error is minimised.  Other grabbed keys that
        # could be added here: 'sub_arrow_error', 'inv_fwd_error'.
        to_min = ['supervised_error']
        losses = {a_min: accum(tensors[a_min]) for a_min in to_min}
        fetch = {'losses': losses}
        train_supervised(sess, losses, [gen_gen(train_generators)],
                         [gen_gen(test_generators)], callbacks, fetch, options)
    finally:
        # Release the session's resources even if training raises.
        sess.close()
    print("Number of params", num_params)
示例#2
0
文件: gan.py 项目: llwu/reverseflow
 def train_gen():
     """Yield feed dictionaries holding a data batch, noise and a permutation.

     Every yielded dictionary maps ``x_ten`` to a batch from ``model_net_40()``
     flattened to ``(batch_size, -1)``, ``z_ten`` to uniform noise of shape
     ``(batch_size, 1)`` and ``perm_ten`` to a shuffled ``arange(n_samples)``.
     ``batch_size``, ``n_samples`` and the three tensors come from the
     enclosing scope.
     """
     from wacacore.util.generators import infinite_batches
     from voxel_helpers import model_net_40
     batches = infinite_batches(model_net_40(), batch_size=batch_size)
     while True:
         flat_batch = next(batches).reshape(batch_size, -1)
         # Noise is drawn before the shuffle so the RNG is consumed in a
         # fixed order.
         noise = np.random.rand(batch_size, 1)
         order = np.arange(n_samples)
         np.random.shuffle(order)
         feed = {x_ten: flat_batch}
         feed[z_ten] = noise
         feed[perm_ten] = order
         yield feed
示例#3
0
def gen_scalar_field_adt(train_data,
                         test_data,
                         options,
                         encode_args={'n_steps': 10},
                         field_shape=(8, 8, 8),
                         voxel_grid_shape=(32, 32, 32),
                         batch_size=64,
                         s_args={},
                         decode_args={}):
    """Construct the scalar-field abstract data type and its data generators.

    Declares the ``Field``, ``Rotation`` and ``VoxelGrid`` types, the
    ``encode``, ``decode`` and ``rotate`` interfaces, and two axioms:

    * ``enc_dec``: decoding an encoded voxel grid gives back the voxel grid.
    * ``rotate``: decoding a rotated encoded field matches the voxel grid
      rotated directly with ``ndimage``.

    Args:
        train_data: Data handed to ``infinite_batches`` to produce voxel
            batches.
        test_data: Not referenced in this function; test generators reuse the
            training generators.
        options: Not referenced in this function.
        encode_args: Not referenced in this function.
        field_shape: Shape of the ``Field`` type; also passed to
            ``create_encode`` and ``create_decode``.
        voxel_grid_shape: Shape of the ``VoxelGrid`` type and of each entry of
            the rotated-voxel buffer.
        batch_size: Number of examples per generated batch.
        s_args: Not referenced in this function.
        decode_args: Not referenced in this function.

    Returns:
        Tuple ``(scalar_field_adt, scalar_field_pbt, extra_fetches)`` where
        ``extra_fetches`` maps names to the input voxel grid and the decoded
        voxel grid tensors.
    """
    extra_fetches = {}

    # Types
    # =====

    # Shape parameters
    sample_space_shape = (16, 16)  # only used by the disabled generator code
    rot_matrix_shape = (1, )

    Field = Type(field_shape, name="Field")
    Rotation = Type(rot_matrix_shape, name="Rotation")
    VoxelGrid = Type(voxel_grid_shape, name="VoxelGrid")
    # SampleSpace = Type(sample_space_shape, name="SampleSpace")
    # Bool = Type((1,), name="Bool")

    # Interfaces
    # ==========

    # A random variable over sample
    # generator = Interface([SampleSpace], [Field], 'generator', tf_interface=generator_net)
    # discriminator = Interface([Field], [Bool], 'discriminator', tf_interface=discriminator_net)
    rotate = Interface([Field, Rotation], [Field],
                       'rotate',
                       tf_interface=rotation_net)

    # Encode 2
    encode_interface = create_encode(field_shape)
    encode = Interface([VoxelGrid], [Field],
                       'encode',
                       tf_interface=encode_interface)

    # NOTE(review): decode is declared with the same [VoxelGrid] -> [Field]
    # signature as encode, yet below it is applied to fields and its output is
    # compared against voxel grids.  Confirm whether the declared types should
    # be swapped; left unchanged here.
    decode_interface = create_decode(field_shape)
    decode = Interface([VoxelGrid], [Field],
                       'decode',
                       tf_interface=decode_interface)

    interfaces = [encode, decode, rotate]

    # Variables
    # =========
    voxel_grid = ForAllVar(VoxelGrid, "voxel_grid")
    rot_voxel_grid = ForAllVar(VoxelGrid, "rot_voxel_grid")
    rot_matrix = ForAllVar(Rotation, "rotation")
    # sample_space = ForAllVar(SampleSpace, "sample_space")
    add_summary("voxel_input", voxel_grid.input_var)

    # rot_voxel_grid is fed by the generator below but is not listed here.
    forallvars = [
        voxel_grid,
        rot_matrix,
        # sample_space,
    ]

    # Train Generators
    # ================
    train_voxel_gen = infinite_batches(train_data, batch_size, shuffle=True)
    # sample_space_gen = infinite_samples(np.random.randn,
    #                                     (batch_size),
    #                                     sample_space_shape,
    #                                     add_batch=True)
    # Rotation angles are drawn uniformly from [0, 360)
    train_rot_gen = infinite_samples(lambda *shp: np.random.rand(*shp) * 360,
                                     batch_size,
                                     rot_matrix_shape,
                                     add_batch=True)

    def test_train_gen(voxel_gen, rot_gen):
        """Yield feed dicts of voxel grids, angles and rotated voxel grids."""
        # Sized from voxel_grid_shape so non-default grid shapes work too.
        # The same buffer is overwritten and yielded on every iteration.
        rot_vgrid = np.zeros((batch_size, ) + tuple(voxel_grid_shape))
        while True:
            sample_vgrids = next(voxel_gen)
            sample_rot_matrix = next(rot_gen)

            for i, vgrid in enumerate(sample_vgrids):
                # reshape=False keeps the rotated grid the same shape as vgrid
                rot_vgrid[i] = ndimage.interpolation.rotate(
                    vgrid, sample_rot_matrix[i], reshape=False)

            vals = {
                voxel_grid.input_var: sample_vgrids,
                rot_matrix.input_var: sample_rot_matrix,
                rot_voxel_grid.input_var: rot_vgrid
            }

            yield vals

    train_generators = [test_train_gen(train_voxel_gen, train_rot_gen)]

    # Test Generators
    # ================
    # The very same generator objects are used for testing, so test batches
    # are drawn from train_data.
    test_generators = train_generators

    # Axioms
    # ======
    axioms = []

    # Encode Decode
    (encoded_field, ) = encode(voxel_grid.input_var)
    add_summary("encoded_field", encoded_field)
    extra_fetches["voxel_grid.input_var"] = voxel_grid.input_var

    (decoded_vox_grid, ) = decode(encoded_field)
    add_summary("decoded_vox_grid", decoded_vox_grid)
    extra_fetches["decoded_vox_grid"] = decoded_vox_grid

    axiom_enc_dec = Axiom((decoded_vox_grid, ), (voxel_grid.input_var, ),
                          'enc_dec')
    axioms.append(axiom_enc_dec)
    # The 16x16 image layout is hard-coded, independent of field_shape.
    tf.summary.image("encoded_field", tf.reshape(encoded_field,
                                                 (-1, 16, 16, 1)))

    # rotation axioms
    (rotated, ) = rotate(encoded_field, rot_matrix.input_var)
    (dec_rotated_vgrid, ) = decode(rotated)
    axiom_rotate = Axiom((dec_rotated_vgrid, ), (rot_voxel_grid.input_var, ),
                         'rotate')
    axioms.append(axiom_rotate)

    tf.summary.image("rotate_field", tf.reshape(rotated, (-1, 16, 16, 1)))

    # Losses
    # ======
    #
    # # Other loss terms
    # data_sample = encoded_field
    # losses, adversarial_fetches = adversarial_losses(sample_space,
    #                             data_sample,
    #                             generator,
    #                             discriminator)
    #
    # # Make the encoder help the generator!!
    # losses[0].restrict_to.append(encode)
    # (generated_voxels, ) = decode(adversarial_fetches['generated_field'])
    # extra_fetches['generated_voxels'] = generated_voxels
    # extra_fetches.update(adversarial_fetches)
    #
    # add_summary("generated_field", adversarial_fetches['generated_field'])
    # add_summary("generated_voxels", generated_voxels)
    losses = []

    # Constants
    # =========
    consts = []

    # Data Types
    # ==========
    scalar_field_adt = AbstractDataType(interfaces=interfaces,
                                        consts=consts,
                                        forallvars=forallvars,
                                        axioms=axioms,
                                        losses=losses,
                                        name='scalar_field')

    scalar_field_pbt = ProbDataType(adt=scalar_field_adt,
                                    train_generators=train_generators,
                                    test_generators=test_generators,
                                    train_outs=[])
    return scalar_field_adt, scalar_field_pbt, extra_fetches
示例#4
0
def eqqueue_adt(train_data,
                options,
                eqqueue_shape=(28, 28, 1),
                push_args={},
                pop_args={},
                empty_eqqueue_args={},
                item_shape=(28, 28, 1),
                batch_size=512,
                nitems=3):
    """Construct an eqqueue abstract data type.

    Besides the item axioms (items come out in the order they were pushed),
    this variant adds queue-equivalence axioms relating the queue left after
    popping to a queue built independently.

    Args:
        train_data: Data handed to ``infinite_batches``; one generator is
            created per item variable.
        options: Not referenced in this function.
        eqqueue_shape: Shape of the ``Eqqueue`` type.
        push_args: Extra keyword arguments for the ``push`` interface.
        pop_args: Extra keyword arguments for the ``pop`` interface.
        empty_eqqueue_args: Extra keyword arguments for the empty-queue
            constant.
        item_shape: Shape of the ``Item`` type.
        batch_size: Batch size of the generators and of the empty-queue
            constant.
        nitems: Number of item variables, i.e. the maximum queue length the
            axioms cover.

    Returns:
        Tuple ``(eqqueue_adt, eqqueue_pdt)``.
    """
    # The dict defaults above are only unpacked with ``**`` and never mutated.

    # Types - a Eqqueue of Item
    Eqqueue = Type(eqqueue_shape, 'Eqqueue')
    Item = Type(item_shape, 'Item')

    # Interface

    # Push an Item onto a eqqueue to create a new eqqueue
    push = Interface([Eqqueue, Item], [Eqqueue], 'push', **push_args)

    # Pop an Item from a eqqueue, returning a new eqqueue and the item
    pop = Interface([Eqqueue], [Eqqueue, Item], 'pop', **pop_args)
    interfaces = [push, pop]

    # train_outs
    train_outs = []
    gen_to_inputs = identity

    # Consts
    # The empty eqqueue is the eqqueue with no items
    empty_eqqueue = Const(Eqqueue, 'empty_eqqueue', batch_size,
                          **empty_eqqueue_args)
    consts = [empty_eqqueue]

    # Vars
    items = [ForAllVar(Item, str(i)) for i in range(nitems)]
    forallvars = items

    generators = [
        infinite_batches(train_data, batch_size, shuffle=True)
        for _ in range(nitems)
    ]

    # Axioms
    # When N items are pushed onto an eqqueue and N items are then popped off,
    # the items must come back in the same order they were pushed.
    axioms = []
    eqqueue = empty_eqqueue.batch_input_var
    for i in range(nitems):
        orig_eqqueue = eqqueue
        (push_eqqueue, ) = push(
            orig_eqqueue,
            items[i].input_var)  # pushed the item onto the eqqueue
        pop_eqqueue = push_eqqueue
        # Reference queue: the queue *before* item i was pushed, popped in
        # lock-step with pop_eqqueue so both sides have had the same number of
        # pops when they are compared.
        test_pop_eqqueue = orig_eqqueue

        for j in range(i + 1):
            # Item equivalence
            (pop_eqqueue,
             pop_item) = pop(pop_eqqueue)  # when you pop item from queue
            axiom = Axiom((pop_item, ), (items[j].input_var, ),
                          'item-eq%s-%s' % (i, j))
            axioms.append(axiom)

            # Eqqueue equivalence, Case 1: every item has been popped again,
            # so the remaining queue must equal the empty queue
            if i == j:
                axiom = Axiom((pop_eqqueue, ),
                              (empty_eqqueue.batch_input_var, ),
                              'eqqueue-eq%s-%s' % (i, j))
                axioms.append(axiom)

            # Eqqueue equivalence, Case 2: items remain.  Popping j+1 items
            # after pushing item i must equal popping j+1 items first and
            # pushing item i afterwards:
            #   queue.push(i)[0].pop()[0] == queue.pop()[0].push(i)[0]
            # applied j+1 times.
            else:
                (test_pop_eqqueue, _) = pop(test_pop_eqqueue)
                (test_push_eqqueue, ) = push(test_pop_eqqueue,
                                             items[i].input_var)
                axiom = Axiom((pop_eqqueue, ), (test_push_eqqueue, ),
                              'eqqueue-eq%s-%s' % (i, j))
                axioms.append(axiom)

        # Set next queue to support one more item
        eqqueue = push_eqqueue

    #FIXME: Remove train_fn and call_fns from datastructure
    train_fn, call_fns = None, None
    eqqueue_adt = AbstractDataType(interfaces,
                                   consts,
                                   forallvars,
                                   axioms,
                                   name='eqqueue')
    eqqueue_pdt = ProbDataType(eqqueue_adt, train_fn, call_fns, generators,
                               gen_to_inputs, train_outs)
    return eqqueue_adt, eqqueue_pdt
示例#5
0
def gen_queue_adt(train_data,
                  options,
                  queue_shape=(28, 28, 1),
                  enqueue_args={},
                  dequeue_args={},
                  empty_queue_args={},
                  item_shape=(28, 28, 1),
                  batch_size=512,
                  nitems=3):
    """Build the queue abstract data type and its probabilistic wrapper.

    The axioms state first-in-first-out behaviour: after enqueuing the first
    ``k`` items onto the empty queue, ``k`` successive dequeues return those
    items in insertion order.

    Args:
        train_data: Data handed to ``infinite_batches``; one generator is
            created per item variable.
        options: Forwarded to ``compile_fns``.
        queue_shape: Shape of the ``Queue`` type.
        enqueue_args: Extra keyword arguments for the ``enqueue`` interface.
        dequeue_args: Extra keyword arguments for the ``dequeue`` interface.
        empty_queue_args: Extra keyword arguments for the empty-queue constant.
        item_shape: Shape of the ``Item`` type.
        batch_size: Batch size of the generators and of the empty-queue
            constant.
        nitems: Number of item variables.

    Returns:
        Tuple ``(queue_adt, queue_pdt)``.
    """
    # --- Types: a Queue holding Items ---
    queue_type = Type(queue_shape, 'Queue')
    item_type = Type(item_shape, 'Item')

    # --- Interfaces ---
    # enqueue: (Queue, Item) -> Queue
    enqueue = Interface([queue_type, item_type], [queue_type], 'enqueue',
                        **enqueue_args)
    # dequeue: Queue -> (Queue, Item)
    dequeue = Interface([queue_type], [queue_type, item_type], 'dequeue',
                        **dequeue_args)
    interfaces = [enqueue, dequeue]

    train_outs = []
    gen_to_inputs = identity

    # --- Constants: the queue containing no items ---
    empty_queue = Const(queue_type, 'empty_queue', batch_size,
                        **empty_queue_args)
    consts = [empty_queue]

    # --- Universally quantified item variables, named "0", "1", ... ---
    items = []
    for index in range(nitems):
        items.append(ForAllVar(item_type, str(index)))
    forallvars = items

    # One independent batch generator per item variable
    generators = []
    for _ in range(nitems):
        generators.append(
            infinite_batches(train_data, batch_size, shuffle=True))

    # --- Axioms ---
    axioms = []
    filled = empty_queue.batch_input_var
    for count, newest in enumerate(items, start=1):
        (filled, ) = enqueue(filled, newest.input_var)
        # Drain a copy of the current queue and check each item that comes out
        draining = filled
        for expected in items[:count]:
            (draining, front) = dequeue(draining)
            axioms.append(Axiom((front, ), (expected.input_var, )))

    train_fn, call_fns = compile_fns(interfaces, consts, forallvars, axioms,
                                     train_outs, options)
    queue_adt = AbstractDataType(interfaces,
                                 consts,
                                 forallvars,
                                 axioms,
                                 name='queue')
    queue_pdt = ProbDataType(queue_adt, train_fn, call_fns, generators,
                             gen_to_inputs, train_outs)
    return queue_adt, queue_pdt