def __init__(self, num_steps, model_load_path, num_test_rec,
             full_height=120, full_width=160):
    """
    Initializes the Adversarial Video Generation Runner.

    @param num_steps: The number of training steps to run.
    @param model_load_path: The path from which to load a previously-saved model. Pass None to
                            start from freshly-initialized variables.
    @param num_test_rec: The number of recursive generations to produce when testing. Recursive
                         generations use previous generations as input to predict further into
                         the future.
    @param full_height: Height of full-size frames. Written to c.FULL_HEIGHT before the generator
                        is built. Default = 120 (the value previously hard-coded here).
    @param full_width: Width of full-size frames. Written to c.FULL_WIDTH before the generator
                       is built. Default = 160 (the value previously hard-coded here).
    """
    self.global_step = 0
    self.num_steps = num_steps
    self.num_test_rec = num_test_rec

    self.sess = tf.Session()
    # No `graph=` here: FileWriter serializes the graph at construction time, which would record
    # it before any model ops exist. The complete graph is added once the models are built.
    self.summary_writer = tf.summary.FileWriter(c.SUMMARY_SAVE_DIR)

    if c.ADVERSARIAL:
        print('Init discriminator...')
        self.d_model = DiscriminatorModel(self.sess,
                                          self.summary_writer,
                                          c.TRAIN_HEIGHT,
                                          c.TRAIN_WIDTH,
                                          c.SCALE_CONV_FMS_D,
                                          c.SCALE_KERNEL_SIZES_D,
                                          c.SCALE_FC_LAYER_SIZES_D)

    print('Init generator...')
    # NOTE: this overwrites the shared config module, so any code reading c.FULL_HEIGHT /
    # c.FULL_WIDTH after construction sees these values too.
    c.FULL_HEIGHT = full_height
    c.FULL_WIDTH = full_width
    self.g_model = GeneratorModel(self.sess,
                                  self.summary_writer,
                                  c.TRAIN_HEIGHT,
                                  c.TRAIN_WIDTH,
                                  c.FULL_HEIGHT,
                                  c.FULL_WIDTH,
                                  c.SCALE_FMS_GE,
                                  c.SCALE_FMS_GD,
                                  c.SCALE_KERNEL_SIZES_GE,
                                  c.SCALE_KERNEL_SIZES_GD)

    print('Init variables...')
    self.saver = tf.train.Saver(keep_checkpoint_every_n_hours=2)
    # Log the graph now that the model and saver ops exist.
    self.summary_writer.add_graph(self.sess.graph)
    self.sess.run(tf.global_variables_initializer())

    # if load path specified, load a saved model
    if model_load_path is not None:
        print('\n------------------------------------')
        print('loadpath: ' + model_load_path)
        print('------------------------------------\n')
        self.saver.restore(self.sess, model_load_path)
        print('Model restored from ' + model_load_path)
def __init__(self, num_steps, model_load_path=None):
    """
    Initializes the Adversarial Video Generation Runner.

    @param num_steps: The number of training steps to run.
    @param model_load_path: The path from which to load a previously-saved model. When None
                            (the default), variables keep their freshly-initialized values.
    """
    self.global_step = 0
    self.num_steps = num_steps

    # For device-placement debugging, construct the session with
    # config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True).
    self.sess = tf.Session()
    # No `graph=` here: FileWriter serializes the graph at construction time, so passing it now
    # would write an incomplete graph event in addition to the complete one added below.
    self.summary_writer = tf.summary.FileWriter(c.SUMMARY_SAVE_DIR)

    if c.ADVERSARIAL:
        print('Init discriminator...')
        self.d_model = DiscriminatorModel(self.sess,
                                          self.summary_writer,
                                          c.TRAIN_HEIGHT,
                                          c.TRAIN_WIDTH,
                                          c.SCALE_CONV_FMS_D,
                                          c.SCALE_KERNEL_SIZES_D,
                                          c.SCALE_FC_LAYER_SIZES_D)

    print('Init generator...')
    self.g_model = GeneratorModel(self.sess,
                                  self.summary_writer,
                                  c.TRAIN_HEIGHT,
                                  c.TRAIN_WIDTH,
                                  c.FULL_HEIGHT,
                                  c.FULL_WIDTH,
                                  c.SCALE_FMS_G,
                                  c.SCALE_KERNEL_SIZES_G)

    print('Init variables...')
    self.saver = tf.train.Saver(keep_checkpoint_every_n_hours=2)
    # Log the graph exactly once, after the model and saver ops have been built.
    self.summary_writer.add_graph(self.sess.graph)
    self.sess.run(tf.global_variables_initializer())

    # if load path specified, load a saved model
    if model_load_path is not None:
        self.saver.restore(self.sess, model_load_path)
        print('Model restored from ' + model_load_path)