예제 #1
0
def main(_):
  """Run experiment 'EXP OB 01': identify a system with Wiener and VN models.

  Steps (all driven by module-level constants defined outside this block):
    1. build the system under test with `define_system()`;
    2. generate training / validation / test sets from it;
    3. when exactly one lock order is configured, run `homogeneous_check`;
    4. build a Wiener model and identify it only when `WIENER_ON` is set;
    5. build one Volterra net per homo-strength in `NN_HOMO_STRS`, training
       each one only when `FLAGS.train` is set;
    6. compare all models on the test set with `verify`.

  Args:
    _: argv forwarded by the app runner; unused.
  """
  console.suppress_logging()
  FLAGS.train = TRAIN
  FLAGS.overwrite = OVERWRITE
  console.start('EXP OB 01')

  # Define system
  system = define_system()
  # Generate data
  training_set, validation_set, test_set = generate_data(system)
  if len(SYS_LOCK_ORDERS) == 1:
    # NOTE(review): `signls` is kept as-is; confirm it is the real attribute
    # name on the data-set class (it may be a typo for `signals`).
    homogeneous_check(system, SYS_LOCK_ORDERS[0], training_set.signls[0],
                      training_set.responses[0])

  # Identification
  # .. Wiener baseline: identified only when WIENER_ON is set, but passed to
  #    `verify` below either way
  wiener = Wiener(degree=WN_DEGREE, memory_depth=WN_MEN_DEPTH)
  if WIENER_ON:
    wiener.identify(training_set, validation_set)

  # .. one Volterra net per homo-strength, keyed (in insertion order) by it
  vns = collections.OrderedDict()
  for homo_str in NN_HOMO_STRS:
    console.show_status(
      'Volterra Net homo-strength = {:.2f}'.format(homo_str))
    vn = init_vn('vn_{:.2f}{}'.format(homo_str, POSTFIX), homo_str=homo_str)
    vns[homo_str] = vn
    if FLAGS.train:
      vn.identify(training_set, validation_set,
                  batch_size=50, print_cycle=100, epoch=EPOCH)

  # Verification
  verify(vns, wiener, system, test_set)
  # End
  console.end()
예제 #2
0
def main(_):
    """MNIST demo: train or evaluate the deep-conv classifier 'dc_000'.

    When FLAGS.train is set, the model is trained on the training split and
    validated on a 5000-sample validation split. Otherwise it is evaluated
    on the test split with `with_false` enabled.

    Args:
        _: argv forwarded by the app runner; unused.
    """
    console.suppress_logging()

    # Hard-coded run settings for this demo
    FLAGS.train, FLAGS.overwrite = False, True
    show_false_predictions = True
    flatten_images = False

    console.start('MNIST DEMO')

    model = models.deep_conv('dc_000')
    dataset = load_mnist('../../data/MNIST', flatten=flatten_images,
                         validation_size=5000, one_hot=True)

    if not FLAGS.train:
        model.evaluate_model(dataset[pedia.test],
                             with_false=show_false_predictions)
    else:
        model.train(training_set=dataset[pedia.training],
                    validation_set=dataset[pedia.validation],
                    epoch=30, batch_size=100, print_cycle=50)

    console.end()
예제 #3
0
def main(_):
  """MNIST DCGAN demo: train 'dcgan_002' or draw samples from it.

  Training uses unflattened MNIST images with no validation split
  (validation_size=0). Otherwise 16 samples are generated and plotted.

  Args:
    _: argv forwarded by the app runner; unused.
  """
  console.suppress_logging()
  FLAGS.train, FLAGS.overwrite = True, False

  console.start("MNIST DCGAN DEMO")
  gan = models.dcgan('dcgan_002')

  if not FLAGS.train:
    images = gan.generate(sample_num=16)
    console.show_status('{} samples generated'.format(images.shape[0]))
    imtool.gan_grid_plot(images, show=True)
  else:
    data = load_mnist('../../data/MNIST', flatten=False, validation_size=0,
                      one_hot=True)
    train_options = dict(epoch=10, batch_size=128, print_cycle=20,
                         snapshot_cycle=200, D_times=1, G_times=1)
    gan.train(training_set=data[pedia.training], **train_options)

  console.end()
예제 #4
0
def main(_):
    """TD Gomoku demo built around the 'mlp00_00' MLP model.

    Two instances with the same mark are built, each in its own tf.Graph: the
    model itself and an opponent, which serves as the training shadow and as
    the evaluation opponent. Then one of three things happens:
      - train the model against the shadow;
      - open the Tk board to play against the model;
      - run 100 compete rounds against the opponent.

    Args:
        _: argv forwarded by the app runner; unused.
    """
    FLAGS.overwrite, FLAGS.train = False, True
    interactive = True  # when not training: Tk board (True) vs. compete (False)

    console.suppress_logging()
    console.start('TD Gomoku - vanilla')

    def build_in_new_graph():
        # Keep each instance's variables in a graph of its own
        with tf.Graph().as_default():
            return models.mlp00('mlp00_00')

    model = build_in_new_graph()
    opponent = build_in_new_graph()

    game = Game()
    if FLAGS.train:
        model.train(game, episodes=500000, print_cycle=20,
                    snapshot_cycle=300, match_cycle=2000, rounds=5,
                    rate_thresh=1.0, shadow=opponent, save_cycle=200,
                    snapshot_function=game.snapshot)
    elif interactive:
        TkBoard(player=model).show()
    else:
        model.compete(game, rounds=100, opponent=opponent)

    console.end()
예제 #5
0
def main(_):
    """MNIST vanilla VAE demo: train 'vanilla_00' or draw samples from it.

    Training uses flattened MNIST images with no validation split
    (validation_size=0). Otherwise 16 samples are generated and plotted.

    Args:
        _: argv forwarded by the app runner; unused.
    """
    console.suppress_logging()
    FLAGS.train, FLAGS.overwrite = True, True

    console.start('MNIST VANILLA VAE')
    vae = models.vanilla('vanilla_00')

    if not FLAGS.train:
        drawn = vae.generate(sample_num=16)
        console.show_status('{} samples generated'.format(drawn.shape[0]))
        imtool.gan_grid_plot(drawn, show=True)
    else:
        data = load_mnist('../../data/MNIST', flatten=True,
                          validation_size=0, one_hot=True)
        vae.train(training_set=data[pedia.training], epoch=1000,
                  batch_size=128, print_cycle=50, snapshot_cycle=200)

    console.end()
예제 #6
0
def main(_):
    """Train or sample the vanilla MNIST model 'vanilla_nov9_02_h2_c'.

    The model is built with ``bn=False``. Training uses flattened MNIST
    images with no validation split and passes ``sample_num=25`` to
    ``train``. Otherwise 16 samples are generated and plotted.

    Args:
        _: argv forwarded by the app runner; unused.
    """
    console.suppress_logging()
    FLAGS.overwrite, FLAGS.train = True, True

    console.start()

    # Other builders that can be swapped in here:
    #   models.dcgan('dcgan_c00')
    #   models.vanilla_h3_rs_nbn('vanilla_nov9_01_h3_nbn_opdef')
    net = models.vanilla('vanilla_nov9_02_h2_c', bn=False)

    if FLAGS.train:
        train_kwargs = {
            'epoch': 1000,
            'batch_size': 128,
            'print_cycle': 20,
            'snapshot_cycle': 150,
            'sample_num': 25,
        }
        data = load_mnist('../../data/MNIST', flatten=True,
                          validation_size=0, one_hot=True)
        net.train(training_set=data[pedia.training], **train_kwargs)
    else:
        generated = net.generate(sample_num=16)
        console.show_status('{} samples generated'.format(generated.shape[0]))
        imtool.gan_grid_plot(generated, show=True)

    console.end()
예제 #7
0
def main(_):
    """CIFAR-10 conv demo: train or evaluate the deep-conv model '001_pre_bn'.

    The dataset (unflattened, 5000-sample validation split) is loaded in both
    modes. When FLAGS.train is set the model is trained. Otherwise it is
    evaluated on the test split with `with_false` enabled.

    Args:
        _: argv forwarded by the app runner; unused.
    """
    console.suppress_logging()

    # Hard-coded run settings
    FLAGS.train, FLAGS.overwrite = False, True
    with_false = True

    console.start('CIFAR-10 CONV DEMO')

    classifier = models.deep_conv('001_pre_bn')

    data = load_cifar10('../../data/CIFAR-10', flatten=False,
                        validation_size=5000, one_hot=True)
    if not FLAGS.train:
        classifier.evaluate_model(data[pedia.test], with_false=with_false)
    else:
        classifier.train(training_set=data[pedia.training],
                         validation_set=data[pedia.validation],
                         epoch=120, batch_size=64, print_cycle=100)

    console.end()
예제 #8
0
def main(_):
    """CIFAR-10 DCGAN demo: train 'dcgan_00' or draw samples from it.

    Training loads CIFAR-10 with no validation split (validation_size=0).
    Otherwise 16 samples are generated and plotted.

    Args:
        _: argv forwarded by the app runner; unused.
    """
    console.suppress_logging()
    FLAGS.train, FLAGS.overwrite = True, False

    console.start('CIFAR-10 DCGAN')
    gan = models.dcgan('dcgan_00')

    if not FLAGS.train:
        fakes = gan.generate(sample_num=16)
        console.show_status('{} samples generated'.format(fakes.shape[0]))
        imtool.gan_grid_plot(fakes, show=True)
    else:
        data = load_cifar10('../../data/CIFAR-10', validation_size=0,
                            one_hot=True)
        gan.train(training_set=data[pedia.training], epoch=20000,
                  batch_size=128, print_cycle=20, snapshot_cycle=2000)

    console.end()
예제 #9
0
    def launch_model(self, overwrite=False):
        """Create the TF session for this agent's graph and initialize the model.

        Steps, in order:
          1. Optionally clear old outputs. This happens only when both
             ``overwrite`` and ``hub.overwrite`` are truthy, and each of the
             log / checkpoint / snapshot / note directories is cleared only
             if its matching ``hub`` switch is on.
          2. Create a session on ``self._graph``.
          3. Build the saver, plus a summary writer when ``hub.summary`` or
             ``hub.hp_tuning`` is set.
          4. Run the global variable initializer.
          5. Try ``self.load()``.
          6. If nothing was loaded, record the graph and model description.

        Args:
            overwrite: caller-side permission to clear previous outputs. It
                only takes effect when ``hub.overwrite`` is also truthy.

        Returns:
            The flag returned by ``self.load()``. A falsy value is treated
            below as "new model initiated".
        """
        if hub.suppress_logging: console.suppress_logging()
        # Before launch session, do some cleaning work
        if overwrite and hub.overwrite:
            paths = []
            if hub.summary: paths.append(self.log_dir)
            if hub.save_model: paths.append(self.ckpt_dir)
            if hub.snapshot: paths.append(self.snapshot_dir)
            if hub.export_note: paths.append(self.note_dir)
            clear_paths(paths)
        # NOTE(review): what _check_bash does is not visible here
        if hub.summary: self._check_bash()

        # Launch session on self.graph
        console.show_status('Launching session ...')
        config = tf.ConfigProto()
        # Cap per-process GPU memory unless hub.allow_growth is set. Note that
        # config.gpu_options.allow_growth itself is never written here, so in
        # the allow_growth case TF's default GPU memory behavior applies.
        if not hub.allow_growth:
            value = hub.gpu_memory_fraction
            config.gpu_options.per_process_gpu_memory_fraction = value
        self._session = tf.Session(graph=self._graph, config=config)
        console.show_status('Session launched')
        # Prepare some tools
        self._saver = tf.train.Saver(var_list=self._model.variable_to_save)
        # The writer exists whenever hub.summary is set, which is what guards
        # the add_graph call further down
        if hub.summary or hub.hp_tuning:
            self._summary_writer = tf.summary.FileWriter(self.log_dir)

        # Initialize all variables first; presumably a successful load()
        # then restores saved values over them (TODO confirm in load()).
        self._session.run(tf.global_variables_initializer())
        # Try to load exist model
        load_flag, self._model.counter = self.load()
        if not load_flag:
            # A fresh model must start from step counter 0
            # (note: assert is stripped when Python runs with -O)
            assert self._model.counter == 0
            # Add graph
            if hub.summary: self._summary_writer.add_graph(self._session.graph)
            # Write model description to file
            if hub.snapshot:
                description_path = os.path.join(self.snapshot_dir,
                                                'description.txt')
                write_file(description_path, self._model.description)
            # Show status
            console.show_status('New model initiated')

        self._model.launched = True
        self.take_notes('Model launched')
        return load_flag
예제 #10
0
파일: agent.py 프로젝트: rscv5/tframe
  def launch_model(self, overwrite=False):
    """Create the TF session for this agent's graph and initialize the model.

    Steps, in order:
      1. Optionally clear old outputs. This happens only when both
         ``overwrite`` and ``hub.overwrite`` are truthy.
      2. Restrict visible GPUs via ``CUDA_VISIBLE_DEVICES`` when
         ``hub.visible_gpu_id`` is set.
      3. Create the session.
      4. Reset the saver and optionally build a summary writer.
      5. Initialize variables, and record pruner initial values if pruning
         is on.
      6. Try ``self.load()``.
      7. Either record a new model's graph/description, or switch the mark to
         a branch when ``hub.branch_suffix`` is non-empty.
      8. Finally call ``self._model.handle_structure_detail()``.

    Args:
      overwrite: caller-side permission to clear previous outputs. It only
        takes effect when ``hub.overwrite`` is also truthy.

    Returns:
      The flag returned by ``self.load()``. A falsy value is treated below as
      "new model initiated".

    Raises:
      TypeError: ``hub.visible_gpu_id`` is neither an int nor a str.
      AssertionError: pruning is on with ``hub.pruning_iterations > 0`` but
        no existing model was loaded.
    """
    if hub.suppress_logging: console.suppress_logging()
    # Before launch session, do some cleaning work
    if overwrite and hub.overwrite:
      paths = []
      if hub.summary: paths.append(self.log_dir)
      if hub.save_model: paths.append(self.ckpt_dir)
      if hub.snapshot: paths.append(self.snapshot_dir)
      if hub.export_note: paths.append(self.note_dir)
      clear_paths(paths)
    # NOTE(review): what _check_bash does is not visible here
    if hub.summary: self._check_bash()

    # Launch session on self.graph
    console.show_status('Launching session ...')
    config = tf.ConfigProto()
    # NOTE(review): CUDA_VISIBLE_DEVICES is normally only honored if it is set
    # before CUDA is first initialized in this process. If an earlier session
    # already touched the GPU, this may have no effect; confirm.
    if hub.visible_gpu_id is not None:
      gpu_id = hub.visible_gpu_id
      # Ints are converted to their decimal string. Note that bool is an int
      # subclass, so True/False would become 'True'/'False' here.
      if isinstance(gpu_id, int): gpu_id = '{}'.format(gpu_id)
      elif not isinstance(gpu_id, str): raise TypeError(
        '!! Visible GPU id provided must be an integer or a string')
      os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id
    # Cap per-process GPU memory unless hub.allow_growth is set. Note that
    # config.gpu_options.allow_growth itself is never written here.
    if not hub.allow_growth:
      value = hub.gpu_memory_fraction
      config.gpu_options.per_process_gpu_memory_fraction = value
    self._session = tf.Session(graph=self._graph, config=config)
    console.show_status('Session launched')
    # Prepare some tools
    self.reset_saver()
    # The writer exists whenever hub.summary is set, which is what guards
    # the add_graph call further down
    if hub.summary or hub.hp_tuning:
      self._summary_writer = tf.summary.FileWriter(self.log_dir)

    # Initialize all variables
    self._session.run(tf.global_variables_initializer())
    # Set init_val for pruner if necessary
    # .. if existed model is loaded, variables will be overwritten
    # NOTE(review): presumably this records the freshly initialized weights
    # for later rewinding (lottery-ticket style, per the method name); confirm.
    if hub.prune_on: context.pruner.set_init_val_lottery18()

    # Try to load exist model
    load_flag, self._model.counter, self._model.rounds = self.load()
    # Sanity check: later pruning iterations must resume from a saved model
    if hub.prune_on and hub.pruning_iterations > 0:
      if not load_flag: raise AssertionError(
        '!! Model {} should be initialized'.format(self._model.mark))

    if not load_flag:
      # A fresh model must start from step counter 0
      # (note: assert is stripped when Python runs with -O)
      assert self._model.counter == 0
      # Add graph
      if hub.summary: self._summary_writer.add_graph(self._session.graph)
      # Write model description to file
      if hub.snapshot:
        description_path = os.path.join(self.snapshot_dir, 'description.txt')
        write_file(description_path, self._model.description)
      # Show status
      console.show_status('New model initiated')
    # Loaded model plus a non-empty branch suffix: extend the mark (both in hub
    # and on the model) so later outputs use the branch name
    elif hub.branch_suffix not in [None, '']:
      hub.mark += hub.branch_suffix
      self._model.mark = hub.mark
      console.show_status('Checkpoint switched to branch `{}`'.format(hub.mark))

    self._model.launched = True
    self.take_notes('Model launched')

    # Handle structure detail here
    self._model.handle_structure_detail()

    return load_flag
예제 #11
0
    # NOTE(review): this is the tail of a function whose header is not visible
    # here. `prefix`, `suffix`, `model_name` and `summ_name` come from that
    # unseen part; `th` looks like a hyper-parameter/config object (confirm).
    # 3. trainer setup
    # ---------------------------------------------------------------------------
    th.epoch = 1000
    th.batch_size = 64
    # NOTE(review): presumably the number of validations per training round
    th.validation_per_round = 5

    # The optimizer class itself (not an instance) is stored here
    th.optimizer = tf.train.AdamOptimizer
    th.learning_rate = 0.001

    # NOTE(review): presumably early-stopping patience; confirm against trainer
    th.patience = 5

    # ---------------------------------------------------------------------------
    # 4. summary and note setup
    # ---------------------------------------------------------------------------
    th.train = True
    th.save_model = True
    th.overwrite = True

    # ---------------------------------------------------------------------------
    # 5. other stuff and activate
    # ---------------------------------------------------------------------------
    tail = suffix
    # Run mark: <prefix><model_name>(<num_layers>)<suffix>
    th.mark = prefix + '{}({}){}'.format(model_name, th.num_layers, tail)
    # Gathered-summary file name: <prefix><summ_name><suffix>.sum
    th.gather_summ_name = prefix + summ_name + tail + '.sum'
    core.activate(True)


if __name__ == '__main__':
    # Silence logging before anything else runs. tf.app.run() then parses the
    # command-line flags and, given no explicit `main`, calls this module's
    # top-level `main(argv)`.
    console.suppress_logging()
    tf.app.run()