示例#1
0
 def valid(self,):
     """Score samples from the best training checkpoint on a fresh sequence.

     A new ``RNN_RNADE`` is built from ``self.state`` and loaded with
     ``best_params_train.pickle`` from ``self.output_folder``.  Samples drawn
     with ``model.sample_given_sequence`` are clamped to [0, 1] and compared
     against the generated sequence; the resulting error is appended to
     ``self.valid_costs``.  When it is lower than ``self.best_valid_cost``
     that attribute is updated and ``self.save_model`` is called with
     ``'best_params_valid.pickle'``.
     """
     # Single-argument print(...) behaves identically on Python 2 and 3.
     print('Performing validation.')
     state = self.state
     model = RNN_RNADE(state['n_visible'], state['n_hidden'],
                       state['n_recurrent'], state['n_components'],
                       hidden_act=state['hidden_act'], l2=state['l2'],
                       rec_mu=state['rec_mu'], rec_mix=state['rec_mix'],
                       rec_sigma=state['rec_sigma'], load=False,
                       load_dir=self.output_folder)
     model.load_model(self.output_folder, 'best_params_train.pickle')
     num_test_sequences = 1
     num_samples = 1
     error = []
     for _ in range(num_test_sequences):
         seq = b.bounce_vec(15, n=3, T=128)
         samples = model.sample_given_sequence(seq, num_samples)
         # make sure samples are between 0 and 1
         samples = numpy.maximum(numpy.minimum(samples, 1.), 0.)
         # presumably samples is (num_samples, T, n_visible): mean over
         # samples, sum over visible units, then mean over time -- TODO confirm
         sq_diff = ((samples - seq) ** 2).mean(axis=0).sum(axis=1)
         error.append(sq_diff.mean(axis=0))
     total_error = numpy.mean(error)
     # Two spaces keep the output of the former `print 'Validation error: ', x`.
     print('Validation error:  %s' % (total_error,))
     self.valid_costs.append(total_error)
     if total_error < self.best_valid_cost:
         print('Best validation params!')
         self.best_valid_cost = total_error
         self.save_model('best_params_valid.pickle')
示例#2
0
def generate_bouncing_ball_sample(batch_size, seq_length, shape, num_balls):
    """
    Creates sequences of bouncing balls.
    Args:
        batch_size: integer, number of sequences to generate
        seq_length: number of frames to generate for each sequence
        shape: [m, n, k] list where m and n are the frame height and width,
               and k is the number of channels ... *** note: m must = n!
        num_balls: number of balls to generate
    Returns:
        numpy array of shape (batch_size, seq_length, m, n, 3); when k == 1
        a grayscale array of shape (batch_size, seq_length, m, n, 1) instead.
    """
    height, width, channels = shape[0], shape[1], shape[2]
    dat = np.zeros((batch_size, seq_length, height, width, 3))
    # `range` exists on both Python 2 and 3 (`xrange` is Python-2 only).
    for i in range(batch_size):
        # Only the height is handed to the simulator, hence m must equal n.
        dat[i] = b.bounce_vec(height, num_balls, seq_length)
    if channels != 1:
        return dat
    # make grayscale: weighted sum of the three colour channels
    gray = np.zeros((batch_size, seq_length, height, width, 1))
    gray[..., 0] = np.dot(dat[..., :3], [0.299, 0.587, 0.114])
    return gray
示例#3
0
 def valid(self, ):
     """Run one validation pass and record its error.

     Builds an ``RNN_RNADE`` from the hyper-parameters in ``self.state``,
     loads ``best_params_train.pickle`` from ``self.output_folder`` into it,
     draws samples conditioned on a freshly generated bouncing-ball sequence
     and measures their squared error against that sequence.  The error is
     appended to ``self.valid_costs``; if it is below
     ``self.best_valid_cost`` the attribute is updated and
     ``self.save_model('best_params_valid.pickle')`` is called.
     """
     # print(...) with one argument is valid on both Python 2 and Python 3.
     print('Performing validation.')
     hyper = self.state
     model = RNN_RNADE(hyper['n_visible'],
                       hyper['n_hidden'],
                       hyper['n_recurrent'],
                       hyper['n_components'],
                       hidden_act=hyper['hidden_act'],
                       l2=hyper['l2'],
                       rec_mu=hyper['rec_mu'],
                       rec_mix=hyper['rec_mix'],
                       rec_sigma=hyper['rec_sigma'],
                       load=False,
                       load_dir=self.output_folder)
     model.load_model(self.output_folder, 'best_params_train.pickle')
     num_test_sequences = 1
     num_samples = 1
     error = []
     for _ in range(num_test_sequences):
         seq = b.bounce_vec(15, n=3, T=128)
         samples = model.sample_given_sequence(seq, num_samples)
         # make sure samples are between 0 and 1
         samples = numpy.minimum(samples, 1.)
         samples = numpy.maximum(samples, 0.)
         # assumes samples is (num_samples, T, n_visible) -- TODO confirm
         sq_diff = (samples - seq)**2
         per_step = sq_diff.mean(axis=0).sum(axis=1)
         error.append(per_step.mean(axis=0))
     total_error = numpy.mean(error)
     # The doubled space matches what `print 'Validation error: ', x` emitted.
     print('Validation error:  %s' % (total_error,))
     self.valid_costs.append(total_error)
     if total_error < self.best_valid_cost:
         print('Best validation params!')
         self.best_valid_cost = total_error
         self.save_model('best_params_valid.pickle')
示例#4
0
def generate_bouncing_ball_sample(batch_size, seq_length, shape, num_balls):
    """Generate a batch of bouncing-ball sequences.

    Args:
        batch_size: number of sequences to generate.
        seq_length: number of frames per sequence.
        shape: side length of the square frames.
        num_balls: number of balls passed to ``b.bounce_vec``.

    Returns:
        numpy array of shape (batch_size, seq_length, shape, shape, 3).
    """
    dat = np.zeros((batch_size, seq_length, shape, shape, 3))
    for i in range(batch_size):
        # Simulate at the requested resolution; the former hard-coded 32
        # could only fill the buffer when ``shape`` happened to be 32.
        dat[i] = b.bounce_vec(shape, num_balls, seq_length)
    return dat
def generate_bouncing_ball_sample(batch_size, seq_length, shape, num_balls):
  """Generate a batch of bouncing-ball sequences.

  Args:
    batch_size: number of sequences to generate.
    seq_length: number of frames per sequence.
    shape: side length of the square frames.
    num_balls: number of balls passed to ``b.bounce_vec``.

  Returns:
    numpy array of shape (batch_size, seq_length, shape, shape, 3).
  """
  dat = np.zeros((batch_size, seq_length, shape, shape, 3))
  for i in range(batch_size):
    # Simulate at the requested resolution; the former hard-coded 64
    # could only fill the buffer when ``shape`` happened to be 64.
    dat[i] = b.bounce_vec(shape, num_balls, seq_length)
  return dat
示例#6
0
def train():
  """Train ring_net for a number of steps.

  Builds the model selected by ``FLAGS.model``, then runs ``FLAGS.max_step``
  training steps, each on a fresh batch from ``b.bounce_vec``.  Every 2000
  steps it writes summaries, two example images and a stddev plot; every
  1000 steps it saves a checkpoint under ``train_dir_save``.

  Raises:
    ValueError: if ``FLAGS.model`` is not "fully_connected", "conv" or
      "all_conv".
    AssertionError: if the loss becomes NaN.
  """
  with tf.Graph().as_default():
    # make inputs
    x = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 1])

    # possible dropout inside (default is 1.0)
    keep_prob = tf.placeholder("float")

    # make model
    if FLAGS.model=="fully_connected":
      mean, stddev, y_sampled, x_prime = arc.fully_connected_model(x, keep_prob)
    elif FLAGS.model=="conv":
      mean, stddev, y_sampled, x_prime = arc.conv_model(x, keep_prob)
    elif FLAGS.model=="all_conv":
      mean, stddev, y_sampled, x_prime = arc.all_conv_model(x, keep_prob)
    else:
      # Fail here instead of printing and crashing later on an unbound `mean`.
      raise ValueError("model requested not found: " + str(FLAGS.model))

    # calc loss stuff
    loss_vae, loss_reconstruction, loss, train_op = ls.loss(mean, stddev, x, x_prime)

    # Build a saver
    saver = tf.train.Saver(tf.all_variables())

    # Summary op
    summary_op = tf.merge_all_summaries()

    # Build an initialization operation to run below.
    init = tf.initialize_all_variables()

    # Start running operations on the Graph.
    sess = tf.Session()

    # always initialise from scratch; no checkpoint is restored here
    print("init network from scratch")
    sess.run(init)

    # Summary writer, including the graph definition
    graph_def = sess.graph.as_graph_def(add_shapes=True)
    summary_writer = tf.train.SummaryWriter(train_dir_save, graph_def=graph_def)

    for step in xrange(FLAGS.max_step):
      dat = b.bounce_vec(32, FLAGS.num_balls, FLAGS.batch_size)
      t = time.time()
      _, loss_r = sess.run([train_op, loss],feed_dict={x:dat, keep_prob:FLAGS.keep_prob})
      elapsed = time.time() - t

      if step%2000 == 0:
        # NOTE(review): train_op is in this fetch list too, so a second
        # update is applied on the same batch at logging steps -- confirm
        # this is intended.
        _ , loss_vae_r, loss_reconstruction_r, y_sampled_r, x_prime_r, stddev_r = sess.run([train_op, loss_vae, loss_reconstruction, y_sampled, x_prime, stddev],feed_dict={x:dat, keep_prob:FLAGS.keep_prob})
        summary_str = sess.run(summary_op, feed_dict={x:dat, keep_prob:FLAGS.keep_prob})
        summary_writer.add_summary(summary_str, step)
        print("loss vae value at " + str(loss_vae_r))
        print("loss reconstruction value at " + str(loss_reconstruction_r))
        print("time per batch is " + str(elapsed))
        # first element of the batch: input frame and its reconstruction
        cv2.imwrite("real_balls.jpg", np.uint8(dat[0, :, :, :]*255))
        cv2.imwrite("generated_balls.jpg", np.uint8(x_prime_r[0, :, :, :]*255))
        # sum over axis 0, then sort for plotting
        stddev_r = np.sort(np.sum(stddev_r, axis=0))
        plt.plot(stddev_r/FLAGS.batch_size, label="step " + str(step))
        plt.legend()
        plt.savefig('stddev_num_balls_' + str(FLAGS.num_balls) + '_beta_' + str(FLAGS.beta) + '.png')

      assert not np.isnan(loss_r), 'Model diverged with loss = NaN'

      if step%1000 == 0:
        checkpoint_path = os.path.join(train_dir_save, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
        print("saved to " + train_dir_save)
        print("step " + str(step))
def train():
    """Train ring_net for a number of steps.

    Builds a convolutional variational autoencoder inline (encoder, sampled
    latent, decoder), then runs ``FLAGS.max_step`` Adam steps, each on a
    fresh batch from ``b.bounce_vec``.  Every 500 steps it writes summaries,
    two example images and a stddev plot; every 1000 steps it saves a
    checkpoint under ``FLAGS.train_dir``.  Raises ``AssertionError`` when the
    loss becomes NaN.
    """
    with tf.Graph().as_default():
        # make inputs
        x = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3])

        # possible dropout inside
        keep_prob = tf.placeholder("float")

        # create network
        # encoding part first
        # NOTE(review): ld.conv_layer arguments are presumably
        # (input, kernel_size, stride, num_features, name) -- confirm in ld.
        # conv1
        conv1 = ld.conv_layer(x, 3, 2, 8, "encode_1")
        # conv2
        conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2")
        # conv3
        conv3 = ld.conv_layer(conv2, 3, 2, 8, "encode_3")
        # conv4
        conv4 = ld.conv_layer(conv3, 1, 1, 4, "encode_4")
        # fc5
        fc5 = ld.fc_layer(conv4, 128, "encode_5", True, False)
        # dropout maybe
        fc5_dropout = tf.nn.dropout(fc5, keep_prob)
        # y: twice the latent size, split below into the two halves
        y = ld.fc_layer(fc5_dropout, (FLAGS.hidden_size) * 2, "encode_6",
                        False, True)
        # pre-1.0 TensorFlow argument order: (split_dim, num_split, value)
        mean, stddev = tf.split(1, 2, y)
        # the second half is treated as a log-variance: stddev = sqrt(exp(.))
        stddev = tf.sqrt(tf.exp(stddev))
        # now decoding part
        # sample distribution (reparameterised: mean + noise * stddev)
        epsilon = tf.random_normal(mean.get_shape())
        y_sampled = mean + epsilon * stddev
        # fc7
        fc7 = ld.fc_layer(y_sampled, 128, "decode_7", False, False)
        # fc8
        fc8 = ld.fc_layer(fc7, 4 * 8 * 8, "decode_8", False, False)
        conv9 = tf.reshape(fc8, [-1, 8, 8, 4])
        # conv10
        conv10 = ld.transpose_conv_layer(conv9, 1, 1, 8, "decode_9")
        # conv11
        conv11 = ld.transpose_conv_layer(conv10, 3, 2, 8, "decode_10")
        # conv12
        conv12 = ld.transpose_conv_layer(conv11, 3, 1, 8, "decode_11")
        # conv13
        conv13 = ld.transpose_conv_layer(conv12, 3, 2, 3, "decode_12", True)
        # x_prime
        x_prime = conv13
        x_prime = tf.nn.sigmoid(x_prime)

        # now calc loss
        # `epsilon` is rebound from the noise tensor to a small constant that
        # keeps the logs below away from log(0)
        epsilon = 1e-8
        # calc loss from vae: elementwise KL term between N(mean, stddev^2)
        # and a unit normal
        kl_loss = 0.5 * (tf.square(mean) + tf.square(stddev) -
                         2.0 * tf.log(stddev + epsilon) - 1.0)
        loss_vae = FLAGS.beta * tf.reduce_sum(kl_loss)
        # log loss for reconstruction (binary cross-entropy, summed)
        loss_reconstruction = tf.reduce_sum(-x * tf.log(x_prime + epsilon) -
                                            (1.0 - x) *
                                            tf.log(1.0 - x_prime + epsilon))
        # save for tensorboard
        tf.scalar_summary('loss_vae', loss_vae)
        tf.scalar_summary('loss_reconstruction', loss_reconstruction)
        # calc total loss
        loss = tf.reduce_sum(loss_vae + loss_reconstruction)

        # training
        train_op = tf.train.AdamOptimizer(FLAGS.lr).minimize(loss)

        # List of all Variables (not used again in this function)
        variables = tf.all_variables()

        # Build a saver
        saver = tf.train.Saver(tf.all_variables())

        # Summary op
        summary_op = tf.merge_all_summaries()

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()

        # Start running operations on the Graph.
        sess = tf.Session()

        # always initialise from scratch; no checkpoint is restored here
        print("init network from scratch")
        sess.run(init)

        # Summary writer, including the graph definition
        graph_def = sess.graph.as_graph_def(add_shapes=True)
        summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
                                                graph_def=graph_def)

        for step in xrange(FLAGS.max_step):
            dat = b.bounce_vec(32, FLAGS.num_balls, FLAGS.batch_size)
            t = time.time()
            _, loss_r = sess.run([train_op, loss],
                                 feed_dict={
                                     x: dat,
                                     keep_prob: FLAGS.keep_prob
                                 })
            elapsed = time.time() - t
            #print(elapsed)

            if step % 500 == 0:
                # NOTE(review): train_op is in this fetch list too, so a
                # second update is applied on the same batch at logging
                # steps -- confirm this is intended.
                _, loss_vae_r, loss_reconstruction_r, y_sampled_r, x_prime_r, kl_loss_dis, stddev_r = sess.run(
                    [
                        train_op, loss_vae, loss_reconstruction, y_sampled,
                        x_prime, kl_loss, stddev
                    ],
                    feed_dict={
                        x: dat,
                        keep_prob: FLAGS.keep_prob
                    })
                summary_str = sess.run(summary_op,
                                       feed_dict={
                                           x: dat,
                                           keep_prob: FLAGS.keep_prob
                                       })
                summary_writer.add_summary(summary_str, step)
                print("loss vae value at " + str(loss_vae_r))
                print("loss reconstruction value at " +
                      str(loss_reconstruction_r))
                print("min sampled vector " + str(np.min(y_sampled_r)))
                print("max sampled vector " + str(np.max(y_sampled_r)))
                print("time per batch is " + str(elapsed))
                # first element of the batch: input frame and reconstruction
                cv2.imwrite("real_balls.jpg", np.uint8(dat[0, :, :, :] * 255))
                cv2.imwrite("generated_balls.jpg",
                            np.uint8(x_prime_r[0, :, :, :] * 255))
                # sum over axis 0, then sort for plotting
                kl_loss_dis = np.sort(np.sum(kl_loss_dis, axis=0))
                stddev_r = np.sort(np.sum(stddev_r, axis=0))
                #plt.plot(kl_loss_dis, label="step " + str(step))
                #plt.legend(loc = 'center left')
                #plt.savefig('kl_error_dis.png')
                plt.plot(stddev_r, label="step " + str(step))
                plt.legend(loc='center left')
                plt.savefig('stddev_r.png')

            assert not np.isnan(loss_r), 'Model diverged with loss = NaN'

            if step % 1000 == 0:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
                print("saved to " + FLAGS.train_dir)
                print("step " + str(step))
def test_stddev(network_type):
    """Eval net to get stddev.

    Sets ``FLAGS.model``, ``FLAGS.num_balls`` and (for the non-"conv" models)
    ``FLAGS.hidden_size`` from ``network_type``, rebuilds the matching model,
    restores its weights from ``./checkpoints/train_store_<network_type>``
    and evaluates ``stddev`` on one batch from ``b.bounce_vec``.

    Args:
        network_type: checkpoint name such as
            "model_conv_num_balls_1_beta_0.1".  Names not listed below leave
            FLAGS unchanged.

    Returns:
        The stddev output summed over axis 0, sorted ascending and divided
        by ``FLAGS.batch_size``.

    Raises:
        ValueError: if ``FLAGS.model`` is not "fully_connected", "conv" or
            "all_conv".
    """
    # set parameters (quick and dirty code just to make graphs fast)
    if network_type in ("model_conv_num_balls_1_beta_0.1",
                        "model_conv_num_balls_1_beta_0.5",
                        "model_conv_num_balls_1_beta_1.0"):
        FLAGS.model = "conv"
        FLAGS.num_balls = 1
    elif network_type in ("model_conv_num_balls_2_beta_0.1",
                          "model_conv_num_balls_2_beta_0.5",
                          "model_conv_num_balls_2_beta_1.0"):
        FLAGS.model = "conv"
        FLAGS.num_balls = 2
    elif network_type in ("model_fully_connected_num_balls_1_beta_0.1",
                          "model_fully_connected_num_balls_1_beta_0.5",
                          "model_fully_connected_num_balls_1_beta_1.0"):
        FLAGS.model = "fully_connected"
        FLAGS.num_balls = 1
        FLAGS.hidden_size = 10
    elif network_type in ("model_fully_connected_num_balls_2_beta_0.1",
                          "model_fully_connected_num_balls_2_beta_0.5",
                          "model_fully_connected_num_balls_2_beta_1.0"):
        FLAGS.model = "fully_connected"
        FLAGS.num_balls = 2
        FLAGS.hidden_size = 10
    elif network_type in ("model_all_conv_num_balls_1_beta_0.1",
                          "model_all_conv_num_balls_1_beta_0.5",
                          "model_all_conv_num_balls_1_beta_1.0"):
        FLAGS.model = "all_conv"
        FLAGS.num_balls = 1
        FLAGS.hidden_size = 1
    elif network_type in ("model_all_conv_num_balls_2_beta_0.1",
                          "model_all_conv_num_balls_2_beta_0.5",
                          "model_all_conv_num_balls_2_beta_1.0"):
        FLAGS.model = "all_conv"
        FLAGS.num_balls = 2
        FLAGS.hidden_size = 1

    with tf.Graph().as_default():
        # make inputs
        x = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 1])

        # no dropout on testing
        keep_prob = 1.0

        # make model
        if FLAGS.model == "fully_connected":
            mean, stddev, y_sampled, x_prime = arc.fully_connected_model(
                x, keep_prob)
        elif FLAGS.model == "conv":
            mean, stddev, y_sampled, x_prime = arc.conv_model(x, keep_prob)
        elif FLAGS.model == "all_conv":
            mean, stddev, y_sampled, x_prime = arc.all_conv_model(x, keep_prob)
        else:
            # Fail here instead of printing and crashing later on an
            # unbound `stddev`.
            raise ValueError("model requested not found: " + str(FLAGS.model))

        # List of all Variables
        variables = tf.all_variables()

        # Load weights operator
        checkpoint_dir = './checkpoints/train_store_' + network_type
        print('save file is ' + checkpoint_dir)
        # NOTE(review): presumably None when no checkpoint exists, which
        # would surface as AttributeError below -- confirm.
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        weight_saver = tf.train.Saver(variables)

        # Start running operations on the Graph.
        sess = tf.Session()
        try:
            # restore the trained weights
            weight_saver.restore(sess, ckpt.model_checkpoint_path)
            print("restored from" + ckpt.model_checkpoint_path)

            dat = b.bounce_vec(32, FLAGS.num_balls, FLAGS.batch_size)
            stddev_r = np.sort(
                np.sum(sess.run([stddev], feed_dict={x: dat})[0], axis=0))
        finally:
            # release the session's resources even when evaluation fails
            sess.close()
        return stddev_r / FLAGS.batch_size

import bouncing_balls as b
import cv2
import numpy as np

# Render a bouncing-ball sequence and write it to "test.mov", using three
# consecutive frames as the three colour channels of each video frame.
res = 84
n_balls = 4
T = 200
dat = b.bounce_vec(res, n_balls, T)
b.show_V(dat)

print(dat.shape)

# one (res, res) frame per time step, scaled to 8-bit intensities
dat = dat.reshape(T, res, res)
dat = np.uint8(np.abs(dat * 255))

# OpenCV >= 3 exposes VideoWriter_fourcc; OpenCV 2.x only has cv2.cv.CV_FOURCC.
if hasattr(cv2, 'VideoWriter_fourcc'):
    fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
else:
    fourcc = cv2.cv.CV_FOURCC('m', 'p', '4', 'v')
video = cv2.VideoWriter()
# The frame size follows `res` instead of repeating the literal 84.
success = video.open("test.mov", fourcc, 4, (res, res), True)
if not success:
    # Previously the result was ignored and no video was produced silently.
    raise IOError("could not open test.mov for writing")

try:
    for i in range(T - 3):
        # stack frames i, i+1, i+2 along the last axis -> (res, res, 3)
        frame = np.transpose(dat[i:i + 3, :, :], (1, 2, 0))
        print(frame.shape)
        video.write(frame)
finally:
    # release the writer so the output file is finalised
    video.release()
示例#10
0
def generate_bouncing_ball_sample(batch_size, seq_length, shape, num_balls):
    """Generate a batch of bouncing-ball sequences as a torch tensor.

    Args:
        batch_size: number of sequences to generate.
        seq_length: number of frames per sequence.
        shape: side length of the square frames.
        num_balls: number of balls passed to ``b.bounce_vec``.

    Returns:
        float32 tensor of shape (batch_size, seq_length, 3, shape, shape),
        i.e. the channel axis is moved in front of the spatial axes.
    """
    dat = np.zeros((batch_size, seq_length, shape, shape, 3), dtype=np.float32)
    # `range` exists on both Python 2 and 3 (`xrange` is Python-2 only).
    for i in range(batch_size):
        # Simulate at the requested resolution; the former hard-coded 32
        # could only fill the buffer when ``shape`` happened to be 32.
        dat[i] = b.bounce_vec(shape, num_balls, seq_length)
    return torch.from_numpy(dat).permute(0, 1, 4, 2, 3)
def create_image(network_type):
  """Eval net to get stddev and plot a latent-sweep heat map.

  Sets ``FLAGS.model``, ``FLAGS.num_balls`` and (for the non-"conv" models)
  ``FLAGS.hidden_size`` from ``network_type``, restores the model from
  ``./checkpoints/train_store_<network_type>`` and draws a 10x5 grid of
  decoder outputs.  Each row picks the latent unit with the smallest
  remaining stddev value; each column sets that unit of the first sample to
  -3.0, -1.5, 0.0, 1.5 or 3.0.  The figure is saved to
  ``figures/heat_map_<network_type>.png``.

  Args:
    network_type: checkpoint name such as "model_conv_num_balls_1_beta_0.1".
      Names not listed below leave FLAGS unchanged.

  Raises:
    ValueError: if ``FLAGS.model`` is not "fully_connected", "conv" or
      "all_conv".
  """
  # set parameters
  if network_type in ("model_conv_num_balls_1_beta_0.1", "model_conv_num_balls_1_beta_0.5", "model_conv_num_balls_1_beta_1.0"):
    FLAGS.model="conv"
    FLAGS.num_balls=1
  elif network_type in ("model_conv_num_balls_2_beta_0.1", "model_conv_num_balls_2_beta_0.5", "model_conv_num_balls_2_beta_1.0"):
    FLAGS.model="conv"
    FLAGS.num_balls=2
  elif network_type in ("model_fully_connected_num_balls_1_beta_0.1", "model_fully_connected_num_balls_1_beta_0.5", "model_fully_connected_num_balls_1_beta_1.0"):
    FLAGS.model="fully_connected"
    FLAGS.num_balls=1
    FLAGS.hidden_size=10
  elif network_type in ("model_fully_connected_num_balls_2_beta_0.1", "model_fully_connected_num_balls_2_beta_0.5", "model_fully_connected_num_balls_2_beta_1.0"):
    FLAGS.model="fully_connected"
    FLAGS.num_balls=2
    FLAGS.hidden_size=10
  elif network_type in ("model_all_conv_num_balls_1_beta_0.1", "model_all_conv_num_balls_1_beta_0.5", "model_all_conv_num_balls_1_beta_1.0"):
    FLAGS.model="all_conv"
    FLAGS.num_balls=1
    FLAGS.hidden_size=1
  elif network_type in ("model_all_conv_num_balls_2_beta_0.1", "model_all_conv_num_balls_2_beta_0.5", "model_all_conv_num_balls_2_beta_1.0"):
    FLAGS.model="all_conv"
    FLAGS.num_balls=2
    FLAGS.hidden_size=1

  with tf.Graph().as_default():
    # make inputs
    x = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 1])

    # no dropout on testing
    keep_prob = 1.0

    # make model
    if FLAGS.model=="fully_connected":
      mean, stddev, y_sampled, x_prime = arc.fully_connected_model(x, keep_prob)
    elif FLAGS.model=="conv":
      mean, stddev, y_sampled, x_prime = arc.conv_model(x, keep_prob)
    elif FLAGS.model=="all_conv":
      mean, stddev, y_sampled, x_prime = arc.all_conv_model(x, keep_prob)
    else:
      # Fail here instead of printing and crashing later on an unbound name.
      raise ValueError("model requested not found: " + str(FLAGS.model))

    # List of all Variables
    variables = tf.all_variables()

    # Load weights operator
    checkpoint_dir = './checkpoints/train_store_' + network_type
    print('save file is ' + checkpoint_dir)
    # NOTE(review): presumably None when no checkpoint exists, which would
    # surface as AttributeError below -- confirm.
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    weight_saver = tf.train.Saver(variables)

    # Start running operations on the Graph.
    sess = tf.Session()
    try:
      # restore the trained weights
      weight_saver.restore(sess, ckpt.model_checkpoint_path)
      print("restored from" + ckpt.model_checkpoint_path)

      dat = b.bounce_vec(32, FLAGS.num_balls, FLAGS.batch_size)
      # stddev summed over axis 0 and divided by the batch size
      stddev_r = np.sum(sess.run([stddev],feed_dict={x:dat})[0], axis=0)/FLAGS.batch_size
      sample_y = sess.run([y_sampled],feed_dict={x:dat})[0]
      print(sample_y[0])

      # create grid
      y_p, x_p = np.mgrid[0:1:32j, 0:1:32j]

      for i in range(10):
        index = np.argmin(stddev_r)
        for j in range(5):
          sweep_value = 1.5*j - 3.0
          z_f = np.copy(sample_y)
          z_f[0,index] = sweep_value
          print(sample_y[0])
          print(z_f[0])
          plt.subplot(10,5,j+5*i + 1)
          # feed the edited latent directly and plot the first decoded image
          plt.pcolor(x_p, y_p, sess.run([x_prime],feed_dict={y_sampled:z_f})[0][0,:,:,0])
          if i == 9:
            plt.xlabel(str(sweep_value))
          if j == 0:
            plt.ylabel("{0:.2f}".format(stddev_r[index]))
          plt.tick_params(axis='both', which='both', bottom='off', top='off', left='off', right='off', labelbottom='off', labelleft='off')
        # just make it big to get out of the way
        stddev_r[index] = 2.0
      plt.savefig("figures/heat_map_" + network_type + ".png")
    finally:
      # release the session's resources even when plotting fails
      sess.close()
        pylab.axis('off')
    CreateMovie(filename, plotter, numframes, fps)


# Generate `numcases` bouncing-ball movies and save them as one .npy array.
# Single-argument print(...) behaves identically on Python 2 and 3.
print('making data...')

numballs = 1
numframes = 20
patchsize = 28


numcases = 70000
# one flattened (patchsize * patchsize) frame per time step
trainmovies = numpy.empty((numcases, numframes, patchsize**2), dtype=numpy.float32)
for i in range(numcases):
    #print i
    trainmovies[i] = bouncing_balls.bounce_vec(patchsize, numballs, numframes)

# trainmovies already has shape (numcases, numframes, patchsize**2), so the
# former reshape to (-1, numframes, patchsize**2) was a no-op and is dropped.

# Use the same `numpy` name as above: the former `np.save` needed a second
# alias and could raise NameError only after the whole loop had run.
numpy.save('/data/lisatmp2/kruegerd/bouncing_balls/bouncing_ball', trainmovies)
#trainmovies = numpy.load("100000bouncingballmovies16x16.npy")

#numtest = 10000
#testmovies = numpy.empty((numtest,numframes,patchsize**2), dtype=numpy.float32)
#for i in range(numtest):
    #print i
#    testmovies[i] = bouncing_balls.bounce_vec(patchsize, numballs, numframes)

#testmovies = testmovies.reshape(-1, numframes, patchsize**2)

print('... done')
示例#13
0
    def train(self,
              valid_set=False,
              learning_rate=0.1,
              num_updates=500,
              save=False,
              output_folder=None,
              lr_update=None,
              mom_rate=0.9,
              update_type='linear',
              start=2,
              batch_size=100,
              filename=None):
        """Run the training loop on freshly generated bouncing-ball sequences.

        Every tenth update index (``u % 10 == 0``) only calls
        ``self.valid()``.  All other iterations generate one sequence with
        ``b.bounce_vec``, check ``self.calc_cost`` for NaN, apply one update
        through ``self.f`` and print the running mean of all costs so far.

        Args:
            valid_set: stored on ``self.valid_set``.
            learning_rate: stored as ``self.init_lr`` and, wrapped in a
                numpy array, as ``self.lr``, which is passed to ``self.f``.
            num_updates: number of loop iterations.
            save: when truthy, ``best_params_train.pickle`` is saved
                whenever the running mean cost improves.
            output_folder: stored on ``self.output_folder``; the NaN input
                sequence is dumped there.
            lr_update: when truthy, ``self.update_lr`` is called after each
                training update.
            mom_rate: passed to ``self.f`` when ``self.momentum`` is truthy.
            update_type: forwarded to ``self.update_lr``.
            start: forwarded to ``self.update_lr``.
            batch_size: stored on ``self.batch_size``.
            filename: stored on ``self.filename``.
        """
        self.best_train_cost = numpy.inf
        self.best_valid_cost = numpy.inf
        self.init_lr = learning_rate
        self.lr = numpy.array(learning_rate)
        self.mom_rate = mom_rate
        self.output_folder = output_folder
        self.valid_set = valid_set
        self.save = save
        self.lr_update = lr_update
        self.stop_train = False
        self.train_costs = []
        self.valid_costs = []
        self.num_updates = num_updates
        self.batch_size = batch_size
        self.update_type = update_type
        self.start = start
        self.filename = filename
        self.valid_frequency = 1000
        # Single-argument print(...) behaves identically on Python 2 and 3.
        try:
            cost = []
            for u in range(num_updates):
                if u % 10 == 0:
                    self.valid()
                    continue

                batch_data = b.bounce_vec(15, n=3, T=128)
                # copy into a freshly allocated default-dtype array
                fixed_array = numpy.zeros(batch_data.shape)
                fixed_array[:] = batch_data
                inputs = [fixed_array, self.lr]
                if self.momentum:
                    inputs.append(self.mom_rate)
                # evaluate the cost before applying the update
                no_update_cost = self.calc_cost(fixed_array)[0]
                if numpy.isnan(no_update_cost):
                    print('Training cost is NaN.')
                    print('Breaking from training early, the last saved set of parameters is still usable!')
                    print('Saving broken model for analysis.')
                    self.save_model('params_NaN.pickle')
                    print('Saving input sequence')
                    path = os.path.join(self.output_folder, 'nan_seq.pkl')
                    # Binary mode plus a context manager: the handle was
                    # previously opened in text mode and never closed.
                    with open(path, 'wb') as handle:
                        pickle.dump(batch_data, handle)
                    break

                cost.append(self.f(*inputs))
                # running mean over every update so far
                mean_costs = numpy.mean(cost, axis=0)

                print('  Update %i   ' % (u + 1))
                print('***Train Results***')
                for i in range(self.num_costs):
                    print("Cost %i: %f" % (i, mean_costs[i]))
                self.train_costs.append(mean_costs)
                this_cost = mean_costs[0]
                if u > 0:  #Because cost is not stable at the start of joint training of RNN_RNADE
                    if this_cost < self.best_train_cost:
                        self.best_train_cost = this_cost
                        print('Best Params!')
                        if save:
                            self.save_model('best_params_train.pickle')
                    sys.stdout.flush()

                if self.stop_train:
                    print('Stopping training early.')
                    break

                if lr_update:
                    self.update_lr(u + 1,
                                   update_type=self.update_type,
                                   start=self.start,
                                   num_iterations=self.num_updates)
            # also reached after an early `break`
            print('Training completed!')

        except KeyboardInterrupt:
            print('Training interrupted.')
示例#14
0
    def train(self,valid_set=False,learning_rate=0.1,num_updates=500,save=False,output_folder=None,lr_update=None,
              mom_rate=0.9,update_type='linear',start=2,batch_size=100,filename=None):
        """Train on generated bouncing-ball sequences for ``num_updates`` steps.

        Iterations whose index is a multiple of 10 only call
        ``self.valid()``.  Every other iteration draws one sequence from
        ``b.bounce_vec``, aborts if ``self.calc_cost`` reports NaN (saving
        the model and the offending sequence), otherwise applies one update
        via ``self.f`` and prints the running mean of all costs so far.

        ``learning_rate`` is kept as ``self.init_lr`` and, as a numpy array,
        ``self.lr``; ``mom_rate`` is only passed to ``self.f`` when
        ``self.momentum`` is truthy; ``save`` enables writing
        ``best_params_train.pickle`` on improvement; a truthy ``lr_update``
        triggers ``self.update_lr`` after each training update.  The
        remaining arguments are stored on ``self`` under the same names.
        """
        self.best_train_cost = numpy.inf
        self.best_valid_cost = numpy.inf
        self.init_lr = learning_rate
        self.lr = numpy.array(learning_rate)
        self.mom_rate = mom_rate
        self.output_folder = output_folder
        self.valid_set = valid_set
        self.save = save
        self.lr_update = lr_update
        self.stop_train = False
        self.train_costs = []
        self.valid_costs = []
        self.num_updates = num_updates
        self.batch_size = batch_size
        self.update_type = update_type
        self.start = start
        self.filename = filename
        self.valid_frequency = 1000
        # print(...) with one argument is valid on both Python 2 and Python 3.
        try:
            cost = []
            for u in range(num_updates):
                if u%10 == 0:
                    self.valid()
                else:
                    batch_data = b.bounce_vec(15,n=3,T=128)
                    # copy into a freshly allocated default-dtype array
                    fixed_array = numpy.zeros(batch_data.shape)
                    fixed_array[:] = batch_data
                    inputs = [fixed_array] + [self.lr]
                    if self.momentum:
                        inputs = inputs + [self.mom_rate]
                    # evaluate the cost before applying the update
                    no_update_cost = self.calc_cost(fixed_array)[0]
                    if numpy.isnan(no_update_cost):
                        print('Training cost is NaN.')
                        print('Breaking from training early, the last saved set of parameters is still usable!')
                        print('Saving broken model for analysis.')
                        self.save_model('params_NaN.pickle')
                        print('Saving input sequence')
                        path = os.path.join(self.output_folder,'nan_seq.pkl')
                        # Open in binary mode and close deterministically;
                        # the old text-mode handle was never closed.
                        with open(path,'wb') as nan_file:
                            pickle.dump(batch_data,nan_file)
                        break
                    else:
                        cost.append(self.f(*inputs))
                        # running mean over every update so far
                        mean_costs = numpy.mean(cost,axis=0)

                    print('  Update %i   ' %(u+1))
                    print('***Train Results***')
                    for i in range(self.num_costs):
                        print("Cost %i: %f"%(i,mean_costs[i]))
                    self.train_costs.append(mean_costs)
                    # reuse the mean computed above instead of recomputing it
                    this_cost = mean_costs[0]
                    if u > 0: #Because cost is not stable at the start of joint training of RNN_RNADE
                        if this_cost < self.best_train_cost:
                            self.best_train_cost = this_cost
                            print('Best Params!')
                            if save:
                                self.save_model('best_params_train.pickle')
                        sys.stdout.flush()

                    if self.stop_train:
                        print('Stopping training early.')
                        break

                    if lr_update:
                        self.update_lr(u+1,update_type=self.update_type,start=self.start,num_iterations=self.num_updates)
            # also reached after an early `break`
            print('Training completed!')

        except KeyboardInterrupt:
            print('Training interrupted.')