Example #1
def make_batch_generators(mode, X_train=None, X_test=None, batch_size=8,
                          siamese=False):
    batch_gen_test = None

    if not siamese:
        batch_gen_train = batch_utils.gen_batch(X_train, batch_size,
                                                augment=(mode == 'train'),
                                                randomize=(mode == 'train'))
        if X_test is not None:
            batch_gen_test = batch_utils.gen_batch(X_test, batch_size,
                                                   augment=(mode == 'train'),
                                                   randomize=False)
    else:
        batch_gen_train = batch_utils.gen_siamese_batch(
            X_train, batch_size,
            augment=(mode == 'train'),
            write_examples=False,
            randomize=(mode == 'train'))
        if X_test is not None:
            batch_gen_test = batch_utils.gen_siamese_batch(
                X_test, batch_size,
                augment=(mode == 'train'),
                write_examples=False,
                randomize=False)
    return batch_gen_train, batch_gen_test
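A minimal usage sketch (the array shapes below, and the assumption that batch_utils.gen_batch returns an endlessly yielding batch generator, are hypothetical and not taken from the snippet):

import numpy as np

# Hypothetical data: 100 training and 20 test samples of shape 64x64x1
X_train = np.random.rand(100, 64, 64, 1)
X_test = np.random.rand(20, 64, 64, 1)

# In 'train' mode, augmentation and shuffling are switched on for the
# training generator; the test generator is never shuffled
train_gen, test_gen = make_batch_generators('train',
                                            X_train=X_train,
                                            X_test=X_test,
                                            batch_size=8)

# Batches are drawn lazily from the generator
first_batch = next(train_gen)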
Example #2
# Assumed imports for this example (not shown in the original snippet):
# NumPy, Matplotlib, the Keras optimizers, and the project-local modules
# `models` and `batch_utils`; `clip_weights`, `save_model_weights` and
# `plot_example_neuron` are assumed to be defined elsewhere in the module.
import numpy as np
import matplotlib.pyplot as plt
from keras.optimizers import Adagrad, RMSprop

def train_model(training_data=None,
                n_levels=3,
                n_nodes=[10, 20, 40],
                input_dim=100,
                n_epochs=25,
                batch_size=64,
                n_batch_per_epoch=100,
                d_iters=20,
                lr_discriminator=0.005,
                lr_generator=0.00005,
                weight_constraint=[-0.01, 0.01],
                rule='mgd',
                verbose=True):
    """
    Train the hierarchical model.

    Progressively generate trees with
    more and more nodes.

    Parameters
    ----------
    training_data: dict of dicts
        each inner dict maps keys like 'n20' to arrays
        'geometry': 3-d arrays (locations)
            n_samples x (n_nodes - 1) x 3
        'morphology': 2-d arrays
            n_samples x (n_nodes - 2) (Prufer sequences)
        example: training_data['geometry']['n20'][0:10, :, :]
                 gives the geometry for the first 10 neurons;
                 training_data['morphology']['n20'][0:10, :]
                 gives the Prufer sequences for the first 10 neurons.
                 Here, 'n20' is the key corresponding to
                 20-node downsampled neurons.
    n_levels: int
        number of levels in the hierarchy
    n_nodes: list of length n_levels
        specifies the number of nodes for each level.
        should be consistent with training data.
    input_dim: int
        dimensionality of noise input
    n_epochs: int
        number of epochs over the training data
    batch_size: int
        batch size
    n_batch_per_epoch: int
        number of batches per epoch
    d_iters: int
        number of iterations to train discriminator
    lr_discriminator: float
        learning rate for optimization of discriminator
    lr_generator: float
        learning rate for optimization of generator
    weight_constraint: list of two floats
        lower and upper bounds for clipping the discriminator weights
    rule: str
        conditioning rule passed to models.discriminator_on_generators
    verbose: bool
        print relevant progress throughout training

    Returns
    -------
    geom_model: list of keras model objects
        geometry generators for each level
    cond_geom_model: list of keras model objects
        conditional geometry generators for each level
    morph_model: list of keras model objects
        morphology generators for each level
    cond_morph_model: list of keras model objects
        conditional morphology generators for each level
    disc_model: list of keras model objects
        discriminators for each level
    gan_model: list of keras model objects
        discriminators stacked on generators for each level
    """
    # ###################################
    # Initialize models at all levels
    # ###################################
    geom_model = list()
    cond_geom_model = list()
    morph_model = list()
    cond_morph_model = list()
    disc_model = list()
    gan_model = list()

    for level in range(n_levels):
        # Discriminator
        d_model = models.discriminator(n_nodes_in=n_nodes[level])

        # Generators and GANs
        # At the first level there is no coarser level to condition on;
        # n_nodes_in is presumably ignored when use_context=False
        if level == 0:
            g_model, cg_model, m_model, cm_model = \
                models.generator(use_context=False,
                                 n_nodes_in=n_nodes[level-1],
                                 n_nodes_out=n_nodes[level])
            mgd_model = \
                models.discriminator_on_generators(g_model,
                                                   cg_model,
                                                   m_model,
                                                   cm_model,
                                                   d_model,
                                                   conditioning_rule=rule,
                                                   input_dim=input_dim,
                                                   n_nodes_in=n_nodes[level-1],
                                                   n_nodes_out=n_nodes[level],
                                                   use_context=False)
        # In subsequent levels, we need context
        else:
            g_model, cg_model, m_model, cm_model = \
                models.generator(use_context=True,
                                 n_nodes_in=n_nodes[level-1],
                                 n_nodes_out=n_nodes[level])
            mgd_model = \
                models.discriminator_on_generators(g_model,
                                                   cg_model,
                                                   m_model,
                                                   cm_model,
                                                   d_model,
                                                   conditioning_rule=rule,
                                                   input_dim=input_dim,
                                                   n_nodes_in=n_nodes[level-1],
                                                   n_nodes_out=n_nodes[level],
                                                   use_context=True)

        # Collect all models into a list
        disc_model.append(d_model)
        geom_model.append(g_model)
        cond_geom_model.append(cg_model)
        morph_model.append(m_model)
        cond_morph_model.append(cm_model)
        gan_model.append(mgd_model)

    # ###############
    # Optimizers
    # ###############
    # NOTE: lr_discriminator / lr_generator are only used by the
    # commented-out RMSprop alternatives; Adagrad runs with its defaults
    optim_d = Adagrad()  # RMSprop(lr=lr_discriminator)
    optim_g = Adagrad()  # RMSprop(lr=lr_generator)

    # ##############
    # Train
    # ##############
    for level in range(n_levels):
        # ---------------
        # Compile models
        # ---------------
        g_model = geom_model[level]
        m_model = morph_model[level]
        d_model = disc_model[level]
        mgd_model = gan_model[level]

        g_model.compile(loss='mse', optimizer=optim_g)
        m_model.compile(loss='mse', optimizer=optim_g)

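        # Freeze the discriminator while compiling the stacked model so
        # generator updates leave the discriminator weights untouched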
        d_model.trainable = False
        mgd_model.compile(loss=models.wasserstein_loss, optimizer=optim_g)

        d_model.trainable = True
        d_model.compile(loss=models.wasserstein_loss, optimizer=optim_d)

        if verbose:
            print("")
            print(20 * "=")
            print("Level #{0}".format(level))
            print(20 * "=")
        # -----------------
        # Loop over epochs
        # -----------------
        for e in range(n_epochs):
            batch_counter = 1
            g_iters = 0

            if verbose:
                print("")
                print("    Epoch #{0}".format(e))
                print("")

            while batch_counter < n_batch_per_epoch:
                list_d_loss = list()
                list_g_loss = list()
                # ----------------------------
                # Step 1: Train discriminator
                # ----------------------------
                for d_iter in range(d_iters):

                    # Clip discriminator weights
                    d_model = clip_weights(d_model, weight_constraint)

                    # Create a batch to feed the discriminator model
                    X_locations_real, X_prufer_real = \
                        batch_utils.get_batch(training_data=training_data,
                                              batch_size=batch_size,
                                              batch_counter=batch_counter,
                                              n_nodes=n_nodes[level])
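                    # Wasserstein/WGAN label convention: real samples are
                    # labeled -1 and generated samples +1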
                    y_real = -np.ones((X_locations_real.shape[0], 1, 1))

                    X_locations_gen, X_prufer_gen = \
                        batch_utils.gen_batch(batch_size=batch_size,
                                              n_nodes=n_nodes,
                                              level=level,
                                              input_dim=input_dim,
                                              geom_model=geom_model,
                                              cond_geom_model=cond_geom_model,
                                              morph_model=morph_model,
                                              cond_morph_model=cond_morph_model,
                                              conditioning_rule=rule)
                    y_gen = np.ones((X_locations_gen.shape[0], 1, 1))

                    X_locations = np.concatenate(
                        (X_locations_real, X_locations_gen), axis=0)

                    X_prufer = np.concatenate((X_prufer_real, X_prufer_gen),
                                              axis=0)

                    y = np.concatenate((y_real, y_gen), axis=0)

                    # Update the discriminator
                    disc_loss = \
                        d_model.train_on_batch([X_locations,
                                                X_prufer],
                                               y)

                    list_d_loss.append(disc_loss)

                if verbose:
                    print("    After {0} iterations".format(d_iters))
                    print("        Discriminator Loss \
                        = {0}".format(disc_loss))

                # -------------------------
                # Step 2: Train generators
                # -------------------------
                # Freeze the discriminator
                d_model.trainable = False

                if level > 0:
                    X_locations_prior_gen, X_prufer_prior_gen = \
                        batch_utils.gen_batch(batch_size=batch_size,
                                              n_nodes=n_nodes,
                                              level=level-1,
                                              input_dim=input_dim,
                                              geom_model=geom_model,
                                              cond_geom_model=cond_geom_model,
                                              morph_model=morph_model,
                                              cond_morph_model=cond_morph_model,
                                              conditioning_rule=rule)

                noise_input = np.random.randn(batch_size, 1, input_dim)

                if level == 0:
                    gen_loss = \
                        mgd_model.train_on_batch([noise_input],
                                                 y_real)
                else:
                    gen_loss = \
                        mgd_model.train_on_batch([X_locations_prior_gen,
                                                  X_prufer_prior_gen,
                                                  noise_input],
                                                 y_real)

                list_g_loss.append(gen_loss)
                if verbose:
                    print("")
                    print("    Generator_Loss: {0}".format(gen_loss))

                # Unfreeze the discriminator
                d_model.trainable = True

                # ---------------------
                # Step 3: Housekeeping
                # ---------------------
                g_iters += 1
                batch_counter += 1

                # Report progress and plot an example a few times per epoch
                # (the weight-saving call below is currently disabled)
                if batch_counter % 25 == 0:
                    # save_model_weights(g_model,
                    #                    m_model,
                    #                    level,
                    #                    e,
                    #                    batch_counter)
                    if verbose:
                        print("     Level #{0} Epoch #{1} Batch #{2}".format(
                            level, e, batch_counter))

                        neuron_object = \
                            plot_example_neuron(X_locations_gen[0, :, :],
                                                X_prufer_gen[0, :, :])
                        plt.show()
                # Display the discriminator loss trace (disabled by default)
                if False:
                    plt.figure(figsize=(3, 2))
                    plt.plot(list_d_loss)
                    plt.show()

                # Keep the per-level lists in sync with the trained models
                geom_model[level] = g_model
                cond_geom_model[level] = cg_model
                morph_model[level] = m_model
                cond_morph_model[level] = cm_model
                disc_model[level] = d_model
                gan_model[level] = mgd_model

    return geom_model, \
        cond_geom_model, \
        morph_model, \
        cond_morph_model, \
        disc_model, \
        gan_model
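Both train_model variants in this listing call a clip_weights helper that is not part of the snippets. A minimal sketch of the standard WGAN-style weight clipping it presumably performs, given a Keras model and a [lower, upper] constraint pair:

import numpy as np

def clip_weights(model, weight_constraint):
    # Clip every weight of every layer into [lower, upper]; this is the
    # weight-clipping trick from the original WGAN formulation
    lower, upper = weight_constraint
    for layer in model.layers:
        clipped = [np.clip(w, lower, upper) for w in layer.get_weights()]
        layer.set_weights(clipped)
    return model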
Example #3
# Same assumed imports as in Example #2, plus the project-local
# plot_utils module.
def train_model(training_data=None,
                n_nodes=20,
                input_dim=100,
                n_epochs=25,
                batch_size=32,
                n_batch_per_epoch=100,
                d_iters=20,
                lr_discriminator=0.005,
                lr_generator=0.00005,
                d_weight_constraint=[-.03, .03],
                g_weight_constraint=[-.03, .03],
                m_weight_constraint=[-.03, .03],
                rule='none',
                train_loss='wasserstein_loss',
                verbose=True):
    """
    Train the hierarchical model.

    Progressively generate trees with
    more and more nodes.

    Parameters
    ----------
    training_data: dict of dicts
        each inner dict maps keys like 'n20' to arrays
        'geometry': 3-d arrays (locations)
            n_samples x (n_nodes - 1) x 3
        'morphology': 2-d arrays
            n_samples x (n_nodes - 1) (parent sequences)
        example: training_data['geometry']['n20'][0:10, :, :]
                 gives the geometry for the first 10 neurons;
                 training_data['morphology']['n20'][0:10, :]
                 gives the parent sequences for the first 10 neurons.
                 Here, 'n20' is the key corresponding to
                 20-node downsampled neurons.
    n_nodes: int
        number of nodes in the generated trees;
        should be consistent with the training data.
    input_dim: int
        dimensionality of noise input
    n_epochs: int
        number of epochs over the training data
    batch_size: int
        batch size
    n_batch_per_epoch: int
        number of batches per epoch
    d_iters: int
        number of iterations to train discriminator
    lr_discriminator: float
        learning rate for optimization of discriminator
    lr_generator: float
        learning rate for optimization of generator
    d_weight_constraint, g_weight_constraint, m_weight_constraint: lists of two floats
        lower and upper bounds for clipping the discriminator,
        geometry-generator, and morphology-generator weights
    rule: str
        conditioning rule passed to models.discriminator_on_generators
    train_loss: str
        'wasserstein_loss' or 'binary_crossentropy'
    verbose: bool
        print relevant progress throughout training

    Returns
    -------
    geom_model: keras model object
        geometry generator
    morph_model: keras model object
        morphology generator
    disc_model: keras model object
        discriminator
    gan_model: keras model object
        discriminator stacked on the generators
    """
    # ###################################
    # Initialize models
    # ###################################
    geom_model = list()
    morph_model = list()
    disc_model = list()
    gan_model = list()

    # Discriminator
    d_model = models.discriminator(n_nodes=n_nodes,
                                   batch_size=batch_size,
                                   train_loss=train_loss)
    # Generators and GANs
    g_model, m_model = \
        models.generator(n_nodes=n_nodes,
                         batch_size=batch_size)
    stacked_model = \
        models.discriminator_on_generators(g_model,
                                           m_model,
                                           d_model,
                                           conditioning_rule=rule,
                                           input_dim=input_dim,
                                           n_nodes=n_nodes)

    # Collect all models into a list
    disc_model.append(d_model)
    geom_model.append(g_model)
    morph_model.append(m_model)
    gan_model.append(stacked_model)

    # ###############
    # Optimizers
    # ###############
    # NOTE: lr_discriminator / lr_generator are only used by the
    # commented-out RMSprop alternatives; Adagrad runs with its defaults
    optim_d = Adagrad()  # RMSprop(lr=lr_discriminator)
    optim_g = Adagrad()  # RMSprop(lr=lr_generator)

    # ##############
    # Train
    # ##############
    # ---------------
    # Compile models
    # ---------------

    g_model.compile(loss='mse', optimizer=optim_g)
    m_model.compile(loss='mse', optimizer=optim_g)

    d_model.trainable = False
    if train_loss == 'wasserstein_loss':
        stacked_model.compile(loss=models.wasserstein_loss, optimizer=optim_g)
    else:
        stacked_model.compile(loss='binary_crossentropy', optimizer=optim_g)

    d_model.trainable = True

    if train_loss == 'wasserstein_loss':
        d_model.compile(loss=models.wasserstein_loss, optimizer=optim_d)
    else:
        d_model.compile(loss='binary_crossentropy', optimizer=optim_d)

    if verbose:
        print("")
        print(20 * "=")
    # -----------------
    # Loop over epochs
    # -----------------
    for e in range(n_epochs):
        batch_counter = 1
        g_iters = 0

        if verbose:
            print("")
            print("Epoch #{0}".format(e))
            print("")

        while batch_counter < n_batch_per_epoch:
            list_d_loss = list()
            list_g_loss = list()
            # ----------------------------
            # Step 1: Train discriminator
            # ----------------------------
            for d_iter in range(d_iters):

                # Clip discriminator weights
                d_model = clip_weights(d_model, d_weight_constraint)

                # Create a batch to feed the discriminator model
                select = range((batch_counter - 1) * batch_size,
                               batch_counter * batch_size)
                X_locations_real = \
                    training_data['geometry']['n'+str(n_nodes)][select, :, :]
                X_locations_real = np.reshape(X_locations_real,
                                              [batch_size, (n_nodes - 1), 3])
                X_parent_cut = \
                    np.reshape(training_data['morphology']['n'+str(n_nodes)][select, :],
                               [1, (n_nodes - 1) * batch_size])
                X_parent_real = \
                    batch_utils.get_batch(X_parent_cut=X_parent_cut,
                                          batch_size=batch_size,
                                          n_nodes=n_nodes)

                if train_loss == 'wasserstein_loss':
                    y_real = -np.ones((X_locations_real.shape[0], 1, 1))
                else:
                    y_real = np.ones((X_locations_real.shape[0], 1, 1))

                X_locations_gen, X_parent_gen = \
                    batch_utils.gen_batch(batch_size=batch_size,
                                          n_nodes=n_nodes,
                                          input_dim=input_dim,
                                          geom_model=g_model,
                                          morph_model=m_model,
                                          conditioning_rule=rule)

                if train_loss == 'wasserstein_loss':
                    y_gen = np.ones((X_locations_gen.shape[0], 1, 1))
                else:
                    y_gen = np.zeros((X_locations_gen.shape[0], 1, 1))
                # Mix real and generated samples half-and-half into
                # two combined batches
                cutting = int(batch_size / 2)
                X_locations_real_first_half = np.append(
                    X_locations_real[:cutting, :, :],
                    X_locations_gen[:cutting, :, :],
                    axis=0)
                X_parent_real_first_half = np.append(
                    X_parent_real[:cutting, :, :],
                    X_parent_gen[:cutting, :, :],
                    axis=0)
                y_real_first_half = np.append(y_real[:cutting, :, :],
                                              y_gen[:cutting, :, :],
                                              axis=0)

                X_locations_real_second_half = np.append(
                    X_locations_real[cutting:, :, :],
                    X_locations_gen[cutting:, :, :],
                    axis=0)
                X_parent_real_second_half = np.append(
                    X_parent_real[cutting:, :, :],
                    X_parent_gen[cutting:, :, :],
                    axis=0)
                y_real_second_half = np.append(y_real[cutting:, :, :],
                                               y_gen[cutting:, :, :],
                                               axis=0)
                # Update the discriminator
                disc_loss = \
                    d_model.train_on_batch([X_locations_real_first_half,
                                            X_parent_real_first_half],
                                            y_real_first_half)
                list_d_loss.append(disc_loss)
                disc_loss = \
                    d_model.train_on_batch([X_locations_real_second_half,
                                            X_parent_real_second_half],
                                            y_real_second_half)
                list_d_loss.append(disc_loss)

            if verbose:
                print("    After {0} iterations".format(d_iters))
                print("        Discriminator Loss \
                    = {0}".format(disc_loss))

            # -------------------------------------
            # Step 2: Train generators alternately
            # -------------------------------------
            # Freeze the discriminator
            d_model.trainable = False

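            # Note: uniform noise here, whereas Example #2 draws Gaussian
            # noise via np.random.randn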
            noise_input = np.random.rand(batch_size, 1, input_dim)

            gen_loss = \
                stacked_model.train_on_batch([noise_input],
                                             y_real)
            # Clip generator weights
            g_model = clip_weights(g_model, g_weight_constraint)
            m_model = clip_weights(m_model, m_weight_constraint)

            list_g_loss.append(gen_loss)
            if verbose:
                print("")
                print("    Generator_Loss: {0}".format(gen_loss))

            # Unfreeze the discriminator
            d_model.trainable = True

            # ---------------------
            # Step 3: Housekeeping
            # ---------------------
            g_iters += 1
            batch_counter += 1

            # Report progress, plot an example, and save the model weights
            # every other batch
            if batch_counter % 2 == 0:
                if verbose:
                    print("     Level #{0} Epoch #{1} Batch #{2}".format(
                        1, e, batch_counter))

                    neuron_object = \
                        plot_utils.plot_example_neuron_from_parent(
                            X_locations_gen[0, :, :],
                            X_parent_gen[0, :, :])
                    plt.plot(np.squeeze(X_locations_gen[0, :, :]))

                    plot_utils.plot_adjacency(X_parent_real[0:2, :, :],
                                              X_parent_gen[0:2, :, :])

                    # Display the discriminator loss trace
                    plot_utils.plot_loss_trace(list_d_loss)

                    # save the models
                    save_model_weights(g_model,
                                       m_model,
                                       d_model,
                                       0,
                                       e,
                                       batch_counter,
                                       list_d_loss,
                                       model_path_root='../model_weights')

            # Keep handles to the trained models for returning
            geom_model = g_model
            morph_model = m_model
            disc_model = d_model
            gan_model = stacked_model

    return geom_model, \
        morph_model, \
        disc_model, \
        gan_model
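models.wasserstein_loss is referenced by both training loops but not shown. Consistent with the -1/+1 labels used above, it is presumably the standard Keras formulation, sketched here:

from keras import backend as K

def wasserstein_loss(y_true, y_pred):
    # With real samples labeled -1 and generated samples +1, minimizing
    # mean(y_true * y_pred) drives the critic scores for real and generated
    # samples apart, approximating the Wasserstein distance
    return K.mean(y_true * y_pred)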