# Example #1
def produce_plots(experiment_name: str, gan_samples):

    epochs = pd.read_csv('_epoch_record.csv', header=None).values

    jsd_record = gan_error_all_species(gan_samples, ground_truth, "JSD", vector_solution=True)

    print(jsd_record.shape)

    # JSD boxplot per species
    ds = jsd_record.flatten()

    plt.figure(figsize=(18.5, 10.5))
    sns.boxplot(ds)
    mess = "; epoch "+str(epochs[-1])
    plt.title("JSD All Species"+mess)

    # Computing the outliers
    q1 = np.quantile(jsd_record, 0.25)
    q3 = np.quantile(jsd_record, 0.75)

    outliers = jsd_record > (q3 + 1.5*(q3-q1))

    filtered = jsd_record[outliers]

    outliers_species = np.array(find_indexes(jsd_record, filtered)) + 1

    outlier_mess = "; Outlier Species: " + str(outliers_species)

    # Annotate the axes with summary statistics and the outlier species
    plt.xlabel("STD "+str(np.round(ds.std(), 2))+outlier_mess)
    plt.ylabel("Mean "+str(np.round(ds.mean(), 2)))

    plt.savefig('n_jsd_boxplot.png', dpi=300)
    plt.close()
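
# `find_indexes` and `gan_error_all_species` are project helpers defined elsewhere. For
# reference, a minimal sketch of what `find_indexes` is assumed to do here (map the outlier
# JSD values back to their positions in `jsd_record`); this illustrates the assumption and
# is not necessarily the project's actual implementation:
#
#     def find_indexes(record, values):
#         # 0-based index of each outlier value in the flattened record
#         return [int(np.argwhere(record.flatten() == v)[0]) for v in values]
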
def train_gan(train_set, indices: List, samples_per_N: int, repetition_n: int, identifier: str, experiment_name: str, batch_size: int = 256, desired_epochs: int = 2000, use_bot=False):
    """
    The GAN is trained for `desired_epochs` epochs (2000 by default). If a set of 60k samples is
    trained with a batch size of 256, then one epoch equals roughly 234 iterations, so a budget
    of 100,000 iterations corresponds to roughly 427 epochs.
    """
    assert train_set.shape[0] > len(indices)

    print(train_set.shape)
    print(len(indices))

    my_ds = DataSetManager(train_set[indices])
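
    # DataSetManager is defined elsewhere in this project; the code below assumes roughly
    # this interface (inferred from how my_ds is used, not a definitive spec):
    #   my_ds.next_batch(batch_size) -> array of shape (batch_size, n_species)
    #   my_ds.num_examples           -> number of managed samples
    #   my_ds.epochs_completed       -> counter that increases as the data is cycled through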


    # print("Set number of iterations to train\n")
    v5 = (desired_epochs * train_set[indices].shape[0]) // batch_size + 1

    print("ITERS "+str(v5))
    print("SIZE "+str(train_set[indices].shape))


    # print("Use pretrained model? (0 means No, some number different to 0 means yes)\n")
    decision_number = 0 #int( input() )

    # print("Type a name to save the model with?\n")
    model_tag = str(round(samples_per_N)) +'_'+ str(repetition_n)
    

    storing_path = 'data/'+ experiment_name + "/" + model_tag + '_data/'
    model_path = storing_path+ model_tag + '.ckpt'
    
    # Recall that os.mkdir isn't recursive, so it only makes one directory at a time
    try:
        # Create target Directory
        os.mkdir(storing_path)
        print("Directory " , storing_path ,  " Created ") 
    except FileExistsError:
        print("Directory " , storing_path ,  " already exists")

    # ===> Auxiliary functions <===
    """
    ----------------8<-------------[ cut here ]------------------

    ------------------------------------------------
    """
    def save_history(files_prefix, gen_loss_record, disc_loss_record, jsd_error, current_epoch, epoch_record, my_ds, iter_, epochs, global_iters, BATCH_SIZE, low_lr, high_lr):
        # Save losses per epoch

        df = pd.DataFrame(np.array(gen_loss_record))
        with open(files_prefix+'_gen_loss.csv', 'w+') as f:
            df.to_csv(f, header=False, index=False)

        df = pd.DataFrame(np.array(disc_loss_record))
        with open(files_prefix+'_disc_loss.csv', 'w+') as f:
            df.to_csv(f, header=False, index=False)

        df = pd.DataFrame(np.array(epoch_record))
        with open(files_prefix+'_epoch_record.csv', 'w+') as f:
            df.to_csv(f, header=False, index=False)

        # Save current iter and epochs

        training_history = {'epochs': [epochs + my_ds.epochs_completed],
                            'iters':  [global_iters + iter_],
                            'Batch Size': [BATCH_SIZE],
                            'low LR': [low_lr],
                            'high LR': [high_lr]}
        df = pd.DataFrame(training_history) 

        with open(files_prefix+'_training.csv', 'w+') as f:
            df.to_csv(f,  index=False) #, header=False, index=False

        with open(files_prefix+'_jsd_error.csv', 'a') as csvFile:
            writer = csv.writer(csvFile)
            writer.writerow([current_epoch, jsd_error])         
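
        # Note: the loss/epoch/training files above are opened with 'w+' and rewritten on every
        # call, whereas the JSD error log is opened in append mode and accumulates one
        # (epoch, error) row per save.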

    def send_bot_message(bot,my_ds, iter_, ITERS, identifier ):
        """ 
        Not quite straightforward since the critic draws many more samples.

        """

        message = "\nEpochs ["+str(my_ds.epochs_completed)+"] Iter: "+str(iter_)+";\t"+str(np.round(100* iter_/ITERS,2))+"% "
        message = message + identifier
        print(message)
        bot.set_status(message)
        # Send update message
        if bot.verbose:
            bot.send_message(message)                

        print("\n")

    def save_gen_samples(gen_op, disc_op, sess, path, k, n=4):
        """
        k: the number of epochs used to train the generator
        n: the number of batches of samples to draw
        """

        suffix = '_gen_samples_'+str(k)+'_epochs_'+'.csv'

        # use a separate loop variable so the epoch count `k` in the filename isn't shadowed
        for _ in range(n):

            samples = sess.run(gen_op)
            df = pd.DataFrame(np.array(samples))
            with open(path+suffix, 'a') as f:
                df.to_csv(f, header=False, index=False)

            # Score the samples using the critic
            scores = sess.run(disc_op)
            df = pd.DataFrame(np.array(scores))
            with open(path+'scores_'+suffix, 'a') as f:
                df.to_csv(f, header=False, index=False)

    # ===> Model Parameters <=== 
    """
    ----------------8<-------------[ cut here ]------------------

    ------------------------------------------------
    """

    DIM = 512  # model dimensionality
    GEN_DIM = 100  # output dimension of the generator
    DIS_DIM = 1  # output dimension of the discriminator
    FIXED_GENERATOR = False  # whether to hold the generator fixed at real data plus Gaussian noise, as in the plots in the paper
    LAMBDA = .1  # smaller lambda makes things faster for toy tasks, but isn't necessary if you increase CRITIC_ITERS enough
    BATCH_SIZE = batch_size   # batch size
    ITERS = v5 #100000 # how many generator iterations to train for
    FREQ = 250  # sample frequency
    
    print("==>>Using batch size of "+str(BATCH_SIZE))
    CRITIC_ITERS = 5  # how many critic iterations per generator iteration


    def Generator_Softmax(n_samples,  name='gen'):

        with tf.variable_scope(name):
            noise = tf.random_normal([n_samples, GEN_DIM])
            output01 = tf_utils.linear(noise, 2*DIM, name='fc-1')
            output01 = tf_utils.relu(output01, name='relu-1')
            
            output02 = tf_utils.linear(output01, 2*DIM, name='fc-2')
            output02 = tf_utils.relu(output02, name='relu-2')
            
            output03 = tf_utils.linear(output02, 2*DIM, name='fc-3')
            output03 = tf_utils.relu(output03, name='relu-3')

            output04 = tf_utils.linear(output03, GEN_DIM, name='fc-4')

            # Reminder: a logit can be modeled as a linear function of the predictors
            output05 = tf.nn.softmax(output04, name='softmax-1')
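            # The softmax output lies on the probability simplex (each row sums to 1), which is
            # what the (commented-out) "Sum-Simplex condition" print in the training loop was
            # meant to verify.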

            return output05
            

    def Discriminator(inputs, is_reuse=True, name='disc'):
        with tf.variable_scope(name, reuse=is_reuse):
            print('is_reuse: {}'.format(is_reuse))
            output01 = tf_utils.linear(inputs, 2*DIM, name='fc-1')
            output01 = tf_utils.relu(output01, name='relu-1')

            output02 = tf_utils.linear(output01, 2*DIM, name='fc-2')
            output02 = tf_utils.relu(output02, name='relu-2')

            output03 = tf_utils.linear(output02, 2*DIM, name='fc-3')
            output03 = tf_utils.relu(output03, name='relu-3')

            output04 = tf_utils.linear(output03, DIS_DIM, name='fc-4')
            
            return output04
        
    real_data = tf.placeholder(tf.float32, shape=[None, GEN_DIM])
    fake_data = Generator_Softmax(BATCH_SIZE)

    disc_real = Discriminator(real_data, is_reuse=False)
    disc_fake = Discriminator(fake_data)

    disc_cost = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)
    gen_cost = - tf.reduce_mean(disc_fake)

    # WGAN gradient penalty parameters

    alpha = tf.random_uniform(shape=[BATCH_SIZE, 1], minval=0., maxval=1.)
    interpolates = alpha*real_data + (1.-alpha) * fake_data
    disc_interpolates = Discriminator(interpolates)
    gradients = tf.gradients(disc_interpolates, [interpolates])[0]
    slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1]))
    gradient_penalty = tf.reduce_mean((slopes - 1)**2)

    disc_cost += LAMBDA * gradient_penalty
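
    # This is the standard WGAN-GP critic objective (Gulrajani et al., 2017):
    #   L_D = E[D(x_fake)] - E[D(x_real)] + LAMBDA * E[(||grad_{x_hat} D(x_hat)||_2 - 1)^2]
    # where x_hat is sampled uniformly along straight lines between real and generated points.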
        
    disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='disc')
    gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='gen')



    disc_lr = tf.placeholder(tf.float32, shape=()) # 1e-4
    gen_lr = tf.placeholder(tf.float32, shape=()) # 1e-4

    disc_train_op = tf.train.AdamOptimizer(learning_rate=disc_lr, beta1=0.5, beta2=0.9).minimize(disc_cost, var_list=disc_vars)

    if len(gen_vars) > 0:
        gen_train_op = tf.train.AdamOptimizer(learning_rate=gen_lr, beta1=0.5, beta2=0.9).minimize(gen_cost, var_list=gen_vars)
    else:
        gen_train_op = tf.no_op()
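
    # beta1=0.5 and beta2=0.9 for Adam above follow the settings commonly used with WGAN-GP.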


    """
    ----------------8<-------------[ cut here ]------------------

    ------------------------------------------------
    """
    # ===> Model Parameters <=== 
 

  
    
    session_saver = tf.train.Saver()

    # files_prefix = 'model/'+ model_tag 

    if decision_number == 0:
        pre_trained  = False


        gen_loss_record = []  # type: List[float]
        disc_loss_record = []  # type: List[float]
        epoch_record = []  # type: List[float]

        epochs = 0
        global_iters = 0

        df = pd.DataFrame(np.array(indices))
        with open(storing_path+'training_indices.csv', 'w+') as f:
            df.to_csv(f, header=False, index=False)


    else:
        pre_trained  = True



        print(storing_path)
        print(storing_path+'training_indices.csv')
        _indices = pd.read_csv(storing_path+'training_indices.csv', header=None).values.tolist()

        
        print(len(_indices))
        print(train_set[indices].shape)
        print(train_set[_indices].squeeze().shape)
        assert train_set[_indices].squeeze().shape ==  train_set[indices].shape
        my_ds = DataSetManager(train_set[_indices].squeeze())

        # _training.csv is written with a header and five columns; the first two are epochs and iters
        temp = pd.read_csv(storing_path+'_training.csv').values

        epochs, global_iters = temp.flatten()[:2]

        my_ds.epochs_completed = epochs

        gen_loss_record = pd.read_csv(storing_path+'_gen_loss.csv', header=None).values.tolist()
        disc_loss_record = pd.read_csv(storing_path+'_disc_loss.csv', header=None).values.tolist()
        epoch_record = pd.read_csv(storing_path+'_epoch_record.csv', header=None).values.tolist()


        print("State has been restored")




    # Create a DLBot instance

    if use_bot:
        bot = DLBot(token=telegram_token, user_id=telegram_user_id)
        # Activate the bot
        bot.activate_bot()
        print("\nTelegram bot has been activated")


    iters_per_epoch = my_ds.num_examples/BATCH_SIZE

    total_iters = int(np.ceil((desired_epochs*iters_per_epoch)/CRITIC_ITERS))

    critic_iters = np.round((5/6)*total_iters)
    gen_iters = np.round((1/6)*total_iters)

    
    ITERS = total_iters
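
    # For example: with 5 critic steps per generator step, 60,000 samples and a batch size of 256,
    # iters_per_epoch ≈ 234.4, so 2000 desired epochs give ceil(2000 * 234.4 / 5) ≈ 93,750
    # generator iterations.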

    # Train loop
    with tf.Session() as sess:
        
        if not pre_trained:  # False by default
            sess.run(tf.global_variables_initializer())
        else:
            session_saver.restore(sess, model_path)

        # Duct-tape solution: make sure iter_ is defined even if the loop below never runs
        iter_ = 0

        """
        while my_ds.epochs_completed < desired_epochs:
            iter_ +=1
        """
        # r=10**-4.72, max_lr=10**-3.72,
        lr_multiplier: int = 1
        low_lr =  10**-5
        high_lr = 10**-4

        lr1 = low_lr # lr_multiplier*low_lr
        lr2 = low_lr #lr_multiplier*high_lr

        gen_lr_ = low_lr # CyclicLR(base_lr= lr1, max_lr= lr2, step_size=gen_iters)
        disc_lr_ = low_lr # CyclicLR(base_lr= lr1, max_lr= lr2, step_size=critic_iters)

        for iter_ in range(ITERS):
            batch_data, disc_cost_ = None, None
            
            previous_epoch =  my_ds.epochs_completed 

            # train critic
            for i_ in range(CRITIC_ITERS):
                batch_data =  my_ds.next_batch(BATCH_SIZE) # data_gen.__next__()
                disc_cost_, _ =  sess.run([disc_cost, disc_train_op], feed_dict={real_data: batch_data, disc_lr:disc_lr_ }) # .clr()
                # disc_lr_.on_batch_end()

            # train generator
            sess.run(gen_train_op, feed_dict={gen_lr : gen_lr_})   #  gen_lr_.clr()
            # gen_lr_.on_batch_end()

            gen_cost2 = sess.run(gen_cost)   

            current_epoch =  my_ds.epochs_completed 

            condition2 = current_epoch % 5 == 0
            if current_epoch > previous_epoch and condition2:
                disc_loss_record.append(disc_cost_)
                gen_loss_record.append(gen_cost2)
                epoch_record.append(my_ds.epochs_completed ) 
                # print("Diff "+str(current_epoch - previous_epoch))

            if (np.mod(iter_, FREQ) == 0) or (iter_+1 == ITERS):
                
                """
                print("===> Debugging")
                print(disc_loss_record)
                print(gen_loss_record)
                """
                if use_bot:
                    bot.loss_hist.append(disc_cost_)

                fake_samples = sess.run(fake_data) # , feed_dict={real_data: batch_data}
                # print("\n==> Sum-Simplex condition: " +str(np.sum(fake_samples, axis=1))) 
                fake_population = np.array([sess.run(fake_data) for _ in range(40)]).reshape(40*batch_size, train_set.shape[1])

                print(fake_population.shape)
                jsd_error = gan_error_all_species(fake_population, k3_test_set)

                print("JSD Error "+str(jsd_error))

                message = "\nEpochs ["+str(my_ds.epochs_completed)+"] Iter: "+str(iter_)+";\t"+str(np.round(100* iter_/ITERS,2))+"% "
                message = message + identifier
                print(message)

                if use_bot:
                    send_bot_message(bot,my_ds, iter_, ITERS, identifier)


                current_epoch = my_ds.epochs_completed

                session_saver.save(sess, model_path)
                save_history(storing_path, gen_loss_record,disc_loss_record, jsd_error, current_epoch, epoch_record, my_ds,iter_, epochs, global_iters, BATCH_SIZE, low_lr, high_lr)

                
                # save_gen_samples(fake_data, disc_fake ,sess, storing_path, k) # fake_data = Generator_Softmax(BATCH_SIZE)
                

            utils.tick()  #  _iter[0] += 1

        # range(ITERS) ends at ITERS - 1, so compare against that for the final checkpoint
        if iter_ == ITERS - 1:
            session_saver.save(sess, model_path)
        
        # Create gan samples
        n_samples = len(indices)

        k_iter = n_samples//BATCH_SIZE +1

        gan_samples_path = storing_path+"gan_samples_" +model_tag+'.csv'

        for k in range(k_iter):
            fake_samples = sess.run(fake_data)

            df = pd.DataFrame(fake_samples)
            with open(gan_samples_path, 'a') as f:
                df.to_csv(f, header=False, index=False)

    # Clear variable values

    tf.reset_default_graph()

    current_epoch = my_ds.epochs_completed
    save_history(storing_path, gen_loss_record,disc_loss_record, jsd_error, current_epoch, epoch_record, my_ds,iter_, epochs, global_iters, BATCH_SIZE, low_lr, high_lr)   
    if use_bot:
        bot.stop_bot()

    print("Training is done")

    # Duct-taping the size of the GAN sample set to avoid changing the TF graph

    temp1 = pd.read_csv(gan_samples_path, header=None).values
    temp1 = temp1[0:n_samples]
    df = pd.DataFrame(temp1)

    with open(gan_samples_path, 'w+') as f:
        df.to_csv(f, header=False, index=False)


    print("Training is done")