示例#1
0
def test_siamese_model(init_epoch=560):
    """Evaluate a saved SiamesePushModel on the test split.

    Sets ``network_params_siam['load_saved_model']`` to True, builds the test
    data source, and runs ``SiamesePushModel.test`` for
    ``network_params_siam['epochs']`` iterations.

    The data source is one of:
    - a ``BatchCreator`` over ``batch_params['test_data_file_indices']``, when
      ``network_params_siam['batch_params']`` is configured;
    - otherwise, the arrays returned by ``get_data('test')``.

    Args:
        init_epoch: forwarded as ``epoch`` to ``init_model``; presumably selects
            which saved checkpoint to load (TODO confirm). Defaults to 560,
            the previously hard-coded value.
    """
    network_params_siam['load_saved_model'] = True

    # Use shallow copies so that swapping in the test file indices does not
    # leak into the shared module-level config. Writing them in place meant a
    # later train_siamese_model() call in the same process would silently
    # train on the test files.
    params = dict(network_params_siam)

    test_data_x = None
    test_data_y = None
    batch_creator = None
    if params['batch_params'] is not None:
        batch_params = dict(params['batch_params'])
        batch_params['data_file_indices'] = batch_params['test_data_file_indices']
        params['batch_params'] = batch_params
        batch_creator = BatchCreator(batch_params)
    else:
        test_data_x, test_data_y = get_data('test')

    # The model sees the same settings as before, only through the copy.
    siamese_model = SiamesePushModel(sess=get_session(), network_params=params)
    siamese_model.init_model(epoch=init_epoch)
    siamese_model.configure_data(data_x=test_data_x, data_y=test_data_y, batch_creator=batch_creator)

    print("Got the data, gonna test the model...")

    epochs = params['epochs']
    siamese_model.test(iterations=epochs)
示例#2
0
def train_siamese_model(init_epoch=50):
    """Train a fresh SiamesePushModel and save it under the 'final' subscript.

    Sets ``network_params_siam['load_saved_model']`` to False and builds the
    training data source:
    - a ``BatchCreator`` over ``network_params_siam['batch_params']`` when
      that is configured;
    - otherwise, the arrays returned by ``get_data('train')``.

    It then trains with ``train2`` for ``network_params_siam['epochs']``
    iterations and saves the model. When ``write_summary`` is set, it finally
    launches tensorboard on the summary directory, blocking until tensorboard
    exits.

    Args:
        init_epoch: forwarded as ``epoch`` to ``init_model``. Defaults to 50,
            the previously hard-coded value.
    """
    network_params_siam['load_saved_model'] = False

    train_data_x = None
    train_data_y = None
    batch_creator = None
    if network_params_siam['batch_params'] is not None:
        batch_creator = BatchCreator(network_params_siam['batch_params'])
    else:
        train_data_x, train_data_y = get_data('train')

    siamese_model = SiamesePushModel(sess=get_session(), network_params=network_params_siam)
    siamese_model.init_model(epoch=init_epoch)
    siamese_model.configure_data(data_x=train_data_x, data_y=train_data_y, batch_creator=batch_creator)

    print("Got the data, gonna train the model...")

    epochs = network_params_siam['epochs']
    siamese_model.train2(iterations=epochs, chk_pnt_save_invl=network_params_siam['check_point_save_interval'])

    # Subscript for the final saved model; per the original note, the saved
    # model name then starts with "final_".
    siamese_model.save_model('final')

    if network_params_siam['write_summary']:
        import subprocess
        logdir = siamese_model._tf_sumry_wrtr._summary_dir
        # Pass an argument list instead of a shell string, so a log dir with
        # spaces or shell metacharacters is handled correctly. A missing
        # tensorboard binary is reported, not raised: the model is already
        # saved at this point.
        try:
            subprocess.call(['tensorboard', '--logdir=' + logdir])
        except OSError as err:
            print("Could not launch tensorboard: %s" % err)
示例#3
0
def main():
    """Report whether two consecutive BatchCreator batches coincide.

    Builds a BatchCreator over data files 1-9 ('siam' model type,
    non-random batches) and pulls two batches. For x and y separately, it
    prints whether any position ``i`` holds an identical item in both
    batches.

    Only same-position items are compared; item ``i`` of batch 1 is not
    compared against item ``j != i`` of batch 2.
    """
    batch_params = {
        'buffer_size': 45,
        'batch_size': 20,
        'data_file_indices': range(1, 10),
        'model_type': 'siam',
        'use_random_batches': False,
        'files_per_read': 10,
        'load_pre_processd_data': True,
    }

    batch_creator = BatchCreator(batch_params)

    def wait_for_batch():
        # Poll get_batch until both x and y are available. The loop exists for
        # a possibly multi-threaded BatchCreator that may return None while
        # its buffer fills. The condition must be `or`: the previous `and`
        # stopped as soon as either half was non-None, which could hand a
        # None batch to zip() below.
        # NOTE: busy-waits forever if get_batch never yields both halves.
        x_batch, y_batch = None, None
        while x_batch is None or y_batch is None:
            x_batch, y_batch, _ = batch_creator.get_batch()
        return x_batch, y_batch

    x_batch1, y_batch1 = wait_for_batch()
    x_batch2, y_batch2 = wait_for_batch()

    # (a == b).all() collapses an element-wise comparison to a single bool.
    x_match = any((a == b).all() for a, b in zip(x_batch1, x_batch2))
    y_match = any((a == b).all() for a, b in zip(y_batch1, y_batch2))

    if x_match:
        print("Found a match in x_batch1 and x_batch2")
    else:
        print("No matches found in x_batch1 and x_batch2")

    if y_match:
        print("Found a match in y_batch1 and y_batch2")
    else:
        print("No matches found in y_batch1 and y_batch2")
示例#4
0
def train_fwd_model(epochs=100):
    """Train an MDNPushFwdModel, save it, and show its training progress.

    Uses a ``tf.InteractiveSession`` when ``network_params_fwd['write_summary']``
    is set, otherwise a plain ``tf.Session``. The session is always closed on
    exit.

    Training data comes from one of two sources:
    - a ``BatchCreator`` over ``network_params_fwd['batch_params']`` when that
      is configured;
    - otherwise, the arrays returned by ``get_data('train')``.

    After training and saving, one of two things happens:
    - with ``write_summary`` set, tensorboard is launched on the summary
      directory (blocking until it exits);
    - otherwise, the loss curve is plotted.

    Args:
        epochs: number of epochs passed to ``forward_model.train``. Defaults
            to 100, the previously hard-coded value.
    """
    if network_params_fwd['write_summary']:
        sess = tf.InteractiveSession()
    else:
        sess = tf.Session()

    try:
        train_data_x = None
        train_data_y = None
        batch_creator = None
        if network_params_fwd['batch_params'] is not None:
            batch_creator = BatchCreator(network_params_fwd['batch_params'])
        else:
            train_data_x, train_data_y = get_data('train')

        network_params_fwd['load_saved_model'] = False

        forward_model = MDNPushFwdModel(sess=sess,
                                        network_params=network_params_fwd)
        forward_model.init_model()
        forward_model.configure_data(data_x=train_data_x,
                                     data_y=train_data_y,
                                     batch_creator=batch_creator)

        print("Got the data, gonna train the model...")

        loss = forward_model.train(epochs=epochs)
        forward_model.save_model()

        if network_params_fwd['write_summary']:
            import subprocess
            logdir = forward_model._tf_sumry_wrtr._summary_dir
            # Argument list instead of a shell string: a log dir with spaces
            # or shell metacharacters is passed through intact.
            try:
                subprocess.call(['tensorboard', '--logdir=' + logdir])
            except OSError as err:
                print("Could not launch tensorboard: %s" % err)
        else:
            # Skip the first 100 entries (early, large losses) only when
            # there is data beyond them. With the old hard-coded
            # arange(100, epochs) and epochs == 100, both x and y were empty
            # and the plot was blank. Deriving x from len(loss) also keeps
            # x and y the same length.
            burn_in = 100 if len(loss) > 100 else 0
            plt.figure(figsize=(8, 8))
            plt.plot(np.arange(burn_in, len(loss)), loss[burn_in:], 'r-')
            plt.show()
    finally:
        sess.close()