예제 #1
0
    def __init__(self, num_steps, model_load_path, num_test_rec):
        """
        Initializes the Adversarial Video Generation Runner.

        Creates a TensorFlow session and summary writer, builds the discriminator (only when
        c.ADVERSARIAL is set) and the generator, initializes all global variables, and, if a
        load path is given, restores a previously-saved checkpoint over them.

        @param num_steps: The number of training steps to run.
        @param model_load_path: The path from which to load a previously-saved model, or None to
                                keep the freshly-initialized variables. (This parameter has no
                                default value; callers must pass None explicitly.)
        @param num_test_rec: The number of recursive generations to produce when testing. Recursive
                             generations use previous generations as input to predict further into
                             the future.
        """
        self.global_step = 0
        self.num_steps = num_steps
        self.num_test_rec = num_test_rec

        self.sess = tf.Session()
        # The models need the writer at construction time, so it must exist now. But passing
        # graph= here would serialize the graph before the models are added to it; the graph is
        # therefore written explicitly below, once both models have been built.
        self.summary_writer = tf.summary.FileWriter(c.SUMMARY_SAVE_DIR)

        if c.ADVERSARIAL:
            print('Init discriminator...')
            self.d_model = DiscriminatorModel(self.sess,
                                              self.summary_writer,
                                              c.TRAIN_HEIGHT,
                                              c.TRAIN_WIDTH,
                                              c.SCALE_CONV_FMS_D,
                                              c.SCALE_KERNEL_SIZES_D,
                                              c.SCALE_FC_LAYER_SIZES_D)

        print('Init generator...')
        # NOTE(review): these assignments mutate the shared constants module `c`, so they affect
        # every later reader of c.FULL_HEIGHT / c.FULL_WIDTH, not just this generator. Kept to
        # preserve behavior; consider moving the values into the constants module itself.
        c.FULL_HEIGHT = 120
        c.FULL_WIDTH = 160
        self.g_model = GeneratorModel(self.sess,
                                      self.summary_writer,
                                      c.TRAIN_HEIGHT,
                                      c.TRAIN_WIDTH,
                                      c.FULL_HEIGHT,
                                      c.FULL_WIDTH,
                                      c.SCALE_FMS_GE,
                                      c.SCALE_FMS_GD,
                                      c.SCALE_KERNEL_SIZES_GE,
                                      c.SCALE_KERNEL_SIZES_GD)

        print('Init variables...')
        # Now that the discriminator and generator are in the graph, write it for TensorBoard.
        self.summary_writer.add_graph(self.sess.graph)
        self.saver = tf.train.Saver(keep_checkpoint_every_n_hours=2)
        self.sess.run(tf.global_variables_initializer())

        # if load path specified, load a saved model (overwrites the initialized variables)
        if model_load_path is not None:
            print('\n------------------------------------')
            print('loadpath: ' + model_load_path)
            print('------------------------------------\n')
            self.saver.restore(self.sess, model_load_path)
            print('Model restored from ' + model_load_path)
예제 #2
0
파일: avg_runner.py 프로젝트: qui3n/Code
    def __init__(self, num_steps, model_load_path=None):
        """
        Initializes the Adversarial Video Generation Runner.

        Creates a TensorFlow session and summary writer, builds the discriminator (only when
        c.ADVERSARIAL is set) and the generator, writes the built graph for TensorBoard,
        initializes all global variables, and, if a load path is given, restores a
        previously-saved checkpoint over them.

        @param num_steps: The number of training steps to run.
        @param model_load_path: The path from which to load a previously-saved model.
                                Default = None, which keeps the freshly-initialized variables.
        """
        self.global_step = 0
        self.num_steps = num_steps

        self.sess = tf.Session()

        # Created without graph=: the models are not built yet, so passing the graph here would
        # write a snapshot without them, and the add_graph() call below would then write a
        # second graph event to the same run.
        self.summary_writer = tf.summary.FileWriter(c.SUMMARY_SAVE_DIR)

        if c.ADVERSARIAL:
            print('Init discriminator...')
            self.d_model = DiscriminatorModel(self.sess, self.summary_writer,
                                              c.TRAIN_HEIGHT, c.TRAIN_WIDTH,
                                              c.SCALE_CONV_FMS_D,
                                              c.SCALE_KERNEL_SIZES_D,
                                              c.SCALE_FC_LAYER_SIZES_D)

        print('Init generator...')
        self.g_model = GeneratorModel(self.sess, self.summary_writer,
                                      c.TRAIN_HEIGHT, c.TRAIN_WIDTH,
                                      c.FULL_HEIGHT, c.FULL_WIDTH,
                                      c.SCALE_FMS_G, c.SCALE_KERNEL_SIZES_G)

        print('Init variables...')
        # Write the graph once, after both models have been added to it.
        self.summary_writer.add_graph(self.sess.graph)
        self.saver = tf.train.Saver(keep_checkpoint_every_n_hours=2)
        self.sess.run(tf.global_variables_initializer())

        # if load path specified, load a saved model (overwrites the initialized variables)
        if model_load_path is not None:
            self.saver.restore(self.sess, model_load_path)
            print('Model restored from ' + model_load_path)