def test_siamese_model(epoch=560):
    """Evaluate a saved SiamesePushModel on the test data.

    Sets ``network_params_siam['load_saved_model']`` to True, builds the test
    data source, initialises the model with ``epoch`` and runs
    ``SiamesePushModel.test`` for ``network_params_siam['epochs']`` iterations.

    The test data source is one of:
      * a ``BatchCreator`` over ``batch_params['test_data_file_indices']``,
        when ``network_params_siam['batch_params']`` is not None;
      * ``get_data('test')`` otherwise.

    Args:
        epoch: value passed to ``init_model(epoch=...)``; presumably selects
            which saved checkpoint to restore (TODO confirm). Defaults to 560,
            the value previously hard-coded here.

    NOTE: this mutates the module-level ``network_params_siam`` dict in
    place. It sets ``load_saved_model`` and overwrites
    ``batch_params['data_file_indices']`` with the test indices, so later
    calls in the same process observe these values.
    """
    network_params_siam['load_saved_model'] = True

    test_data_x = test_data_y = None
    batch_creator = None
    batch_params = network_params_siam['batch_params']
    if batch_params is not None:
        # Point the batch creator at the test files. This writes into the
        # shared dict (see NOTE above) so the model sees the same indices.
        batch_params['data_file_indices'] = batch_params['test_data_file_indices']
        batch_creator = BatchCreator(batch_params)
    else:
        test_data_x, test_data_y = get_data('test')

    siamese_model = SiamesePushModel(sess=get_session(),
                                     network_params=network_params_siam)
    siamese_model.init_model(epoch=epoch)
    siamese_model.configure_data(data_x=test_data_x, data_y=test_data_y,
                                 batch_creator=batch_creator)

    print("Got the data, gonna test the model...")
    siamese_model.test(iterations=network_params_siam['epochs'])
def train_siamese_model(epoch=50):
    """Train a SiamesePushModel and save the final weights.

    Sets ``network_params_siam['load_saved_model']`` to False and builds the
    training data source. That source is a ``BatchCreator`` over
    ``network_params_siam['batch_params']`` when that is not None, otherwise
    ``get_data('train')``.

    It then trains with ``train2`` for ``network_params_siam['epochs']``
    iterations, checkpointing every
    ``network_params_siam['check_point_save_interval']``. Finally it saves the
    model under the ``'final'`` subscript. When ``write_summary`` is set, it
    then launches TensorBoard on the model's summary directory.

    Args:
        epoch: value passed to ``init_model(epoch=...)``. Defaults to 50, the
            value previously hard-coded here. Presumably unused when
            ``load_saved_model`` is False (TODO confirm).

    NOTE: mutates ``network_params_siam['load_saved_model']`` on the shared
    module-level dict.
    """
    network_params_siam['load_saved_model'] = False

    train_data_x = train_data_y = None
    batch_creator = None
    if network_params_siam['batch_params'] is not None:
        batch_creator = BatchCreator(network_params_siam['batch_params'])
    else:
        train_data_x, train_data_y = get_data('train')

    siamese_model = SiamesePushModel(sess=get_session(),
                                     network_params=network_params_siam)
    siamese_model.init_model(epoch=epoch)
    siamese_model.configure_data(data_x=train_data_x, data_y=train_data_y,
                                 batch_creator=batch_creator)

    print("Got the data, gonna train the model...")
    epochs = network_params_siam['epochs']
    siamese_model.train2(
        iterations=epochs,
        chk_pnt_save_invl=network_params_siam['check_point_save_interval'])
    # The subscript is part of the saved model's name (final_<...>).
    siamese_model.save_model('final')

    if network_params_siam['write_summary']:
        import subprocess
        logdir = siamese_model._tf_sumry_wrtr._summary_dir
        # Pass an argument list (no shell) so a logdir containing spaces or
        # shell metacharacters is handed to tensorboard verbatim.
        try:
            subprocess.call(['tensorboard', '--logdir=' + logdir])
        except OSError as err:
            # The model is already saved; report a missing tensorboard
            # binary instead of crashing at the very end.
            print('Could not launch tensorboard: %s' % err)
def main():
    """Compare two consecutive BatchCreator batches position by position.

    Builds a BatchCreator over data files 1-9 and fetches two batches. For x
    and, separately, y, it reports whether any index holds identical entries
    in both batches.

    Only entries at the same index are compared. Comparisons stop at the
    shortest of the four batch sequences, because all four are zipped
    together. Entries are assumed to support element-wise ``==`` followed
    by ``.all()`` (numpy arrays, presumably — TODO confirm).
    """
    creator = BatchCreator({
        'buffer_size': 45,
        'batch_size': 20,
        'data_file_indices': range(1, 10),
        'model_type': 'siam',
        'use_random_batches': False,
        'files_per_read': 10,
        'load_pre_processd_data': True,
    })

    def next_batch():
        # Keep asking until get_batch returns something other than both-None.
        # This is intended for a future multi-threaded BatchCreator that may
        # not have a batch ready yet.
        while True:
            xs, ys, _ = creator.get_batch()
            if xs is not None or ys is not None:
                return xs, ys

    first_x, first_y = next_batch()
    second_x, second_y = next_batch()

    # One (x_equal, y_equal) pair per aligned index, x evaluated before y.
    per_index = [((xa == xb).all(), (ya == yb).all())
                 for xa, xb, ya, yb in zip(first_x, second_x, first_y, second_y)]
    x_hits = [x_eq for x_eq, _ in per_index]
    y_hits = [y_eq for _, y_eq in per_index]

    print("Found a match in x_batch1 and x_batch2" if any(x_hits)
          else "No matches found in x_batch1 and x_batch2")
    print("Found a match in y_batch1 and y_batch2" if any(y_hits)
          else "No matches found in y_batch1 and y_batch2")
def train_fwd_model(epochs=100, loss_plot_skip=100):
    """Train an MDNPushFwdModel from scratch, save it, then show progress.

    The session type depends on ``network_params_fwd['write_summary']``:
      * set: a ``tf.InteractiveSession``; after training, TensorBoard is
        launched on the model's summary directory;
      * not set: a plain ``tf.Session``; after training, the loss history
        returned by ``train`` is plotted with matplotlib.

    Training data comes from a ``BatchCreator`` when
    ``network_params_fwd['batch_params']`` is not None, otherwise from
    ``get_data('train')``.

    Args:
        epochs: passed to ``forward_model.train(epochs=...)``. Defaults to
            100, the value previously hard-coded here.
        loss_plot_skip: number of leading loss entries left out of the plot.
            If the loss history is not longer than this, the whole history is
            plotted. Previously the first 100 entries were always dropped,
            which with ``epochs=100`` produced an empty plot.

    NOTE: sets ``network_params_fwd['load_saved_model'] = False`` on the
    shared module-level dict.
    """
    if network_params_fwd['write_summary']:
        sess = tf.InteractiveSession()
    else:
        sess = tf.Session()

    try:
        train_data_x = train_data_y = None
        batch_creator = None
        if network_params_fwd['batch_params'] is not None:
            batch_creator = BatchCreator(network_params_fwd['batch_params'])
        else:
            train_data_x, train_data_y = get_data('train')

        network_params_fwd['load_saved_model'] = False
        forward_model = MDNPushFwdModel(sess=sess,
                                        network_params=network_params_fwd)
        forward_model.init_model()
        forward_model.configure_data(data_x=train_data_x, data_y=train_data_y,
                                     batch_creator=batch_creator)

        print("Got the data, gonna train the model...")
        loss = forward_model.train(epochs=epochs)
        forward_model.save_model()
    finally:
        # Release the session once the model is saved (or on failure). The
        # code below only reads Python attributes and plots; it runs no
        # TensorFlow ops.
        sess.close()

    if network_params_fwd['write_summary']:
        import subprocess
        logdir = forward_model._tf_sumry_wrtr._summary_dir
        # Argument list, no shell: a logdir with spaces is passed verbatim.
        try:
            subprocess.call(['tensorboard', '--logdir=' + logdir])
        except OSError as err:
            print('Could not launch tensorboard: %s' % err)
    else:
        # Derive the x-axis from the loss history itself so both series always
        # have the same length. Fall back to plotting everything when the
        # history is too short to skip anything.
        skip = loss_plot_skip if len(loss) > loss_plot_skip else 0
        plt.figure(figsize=(8, 8))
        plt.plot(np.arange(skip, len(loss)), loss[skip:], 'r-')
        plt.show()