# Module-level handle to the command-line flags parsed by tf.app.flags.
FLAGS = tf.app.flags.FLAGS


def launch_training():
    """Kick off model training via the `train` module."""
    train.train_model()


# def launch_inference():

#     inference.infer()


def main(argv=None):
    """Entry point: dispatch to the run mode selected by FLAGS.run.

    Args:
        argv: Unused; present for tf.app.run() compatibility.

    Raises:
        ValueError: If FLAGS.run names an unsupported mode.
    """
    # TODO: re-enable an 'inference' mode once inference.infer() is restored.
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if FLAGS.run not in ("train",):
        raise ValueError("Choose [train]")

    if FLAGS.run == 'train':
        launch_training()


if __name__ == '__main__':

    # Register command-line flags before tf.app.run() parses argv and
    # invokes main().
    flags.define_flags()
    tf.app.run()
Example #2
0
File: main.py  Project: minoring/dcgan
                    z_dim=FLAGS.z_dim,
                    dataset_name=FLAGS.dataset,
                    input_fname_pattern=FLAGS.input_fname_pattern,
                    crop=FLAGS.crop,
                    checkpoint_dir=FLAGS.checkpoint_dir,
                    sample_dir=FLAGS.sample_dir,
                    data_dir=FLAGS.data_dir,
                    out_dir=FLAGS.out_dir,
                    max_to_keep=FLAGS.max_to_keep)


def main(_):
  """Entry point: prepare output directories, persist flag settings, and run.

  Args:
    _: Unused positional argument supplied by absl's app.run().
  """
  set_default_flags(FLAGS)

  # makedirs(..., exist_ok=True) avoids the check-then-create race, and the
  # out_dir call fixes a latent bug: the original never created out_dir
  # before writing FLAGS.json into it.
  os.makedirs(FLAGS.checkpoint_dir, exist_ok=True)
  os.makedirs(FLAGS.sample_dir, exist_ok=True)
  os.makedirs(FLAGS.out_dir, exist_ok=True)

  # Save flag setting as json so the run configuration is reproducible.
  with open(os.path.join(FLAGS.out_dir, 'FLAGS.json'), 'w') as f:
    flags_dict = {k: FLAGS[k].value for k in FLAGS}
    json.dump(flags_dict, f, indent=4, sort_keys=True, ensure_ascii=False)

  run(FLAGS)


if __name__ == '__main__':
  # Define flags before app.run() parses argv and calls main().
  define_flags()
  app.run(main)
Example #3
0
import os
# Disable Tensorflow's INFO and WARNING messages
# See http://stackoverflow.com/questions/35911252
# (Must be set before tensorflow is imported below to take effect.)
if 'TF_CPP_MIN_LOG_LEVEL' not in os.environ:
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import flags
import tensorflow as tf
import train_wgan_GP

# Module-level handle to the command-line flags parsed by tf.app.flags.
FLAGS = tf.app.flags.FLAGS


def launch_training():
    """Run WGAN-GP model training."""
    train_wgan_GP.train_model()


def main(argv=None):
    """Entry point: dispatch to the run mode selected by FLAGS.run.

    Args:
        argv: Unused; present for tf.app.run() compatibility.

    Raises:
        NotImplementedError: If FLAGS.run == 'inference', which the assert
            accepts but no code path implements.
    """
    assert FLAGS.run in ["train", "inference"], "Choose [train|inference]"

    if FLAGS.run == 'train':
        launch_training()
    elif FLAGS.run == 'inference':
        # Previously 'inference' passed the assert but silently did nothing;
        # fail loudly until an inference path is implemented.
        raise NotImplementedError("inference mode is not implemented yet")


if __name__ == '__main__':
    # Register flags before tf.app.run() parses argv and invokes main().
    flags.define_flags()
    tf.app.run()
Example #4
0
    def __init__(self, customized_params=None, log_to_tensorboard=True):
        """Set up a full experiment: flags, data pipelines, model, and logging.

        Args:
            customized_params: Optional overrides merged into the parsed
                flags by flags2params (exact semantics defined there).
            log_to_tensorboard: When truthy, attach a tf.summary.FileWriter
                on the session graph that flushes every 20 seconds.

        Raises:
            AssertionError: If a directory for the generated run_name
                already exists under params.log_dir.
        """
        # Parse parameters
        flags = define_flags()
        params = flags2params(flags, customized_params)

        self.logger = logging.getLogger(__name__)
        self.logger.info("Task: " + str(params.task))
        self.logger.info("Language: " + str(params.language))
        self.task = params.task
        self.language = params.language
        # Microsecond timestamp makes the run name effectively unique.
        self.run_name = "%s_%s_%s_%s" % (
            params.tag, self.task.name, self.language.name,
            datetime.datetime.now().strftime('%b-%d_%H-%M-%S-%f'))
        # Refuse to clobber an existing run's logs.
        # NOTE(review): `assert` is stripped under -O; consider raising instead.
        assert not (params.log_dir / self.run_name).is_dir(
        ), "The run %s has existed in Log Path %s" % (self.run_name,
                                                      params.log_dir)
        self.checkpoint_dir = params.checkpoint_dir / self.run_name / "model"
        self.output_dir = params.output_dir

        # Loading dataset
        train_path, test_path, vocab = prepare_data_and_vocab(
            vocab=params.vocab,
            store_folder=params.embedding_dir,
            data_dir=params.data_dir,
            language=params.language)

        # split training set into train and dev sets
        # (fixed random_state keeps the split reproducible across runs)
        self.raw_train, self.raw_dev = model_selection.train_test_split(
            json.load(train_path.open()),
            test_size=params.dev_ratio,
            random_state=params.random_seed)

        self.raw_test = json.load(test_path.open())

        # Tokenize/encode each split; results are cached under cache_dir
        # keyed by the per-split name.
        train_dataset = process_raw_data(self.raw_train,
                                         vocab=vocab,
                                         max_len=params.max_len,
                                         cache_dir=params.cache_dir,
                                         is_train=True,
                                         name="train_%s" % params.language)

        dev_dataset = process_raw_data(self.raw_dev,
                                       vocab=vocab,
                                       max_len=params.max_len,
                                       cache_dir=params.cache_dir,
                                       is_train=False,
                                       name="dev_%s" % params.language)

        test_dataset = process_raw_data(self.raw_test,
                                        vocab=vocab,
                                        max_len=params.max_len,
                                        cache_dir=params.cache_dir,
                                        is_train=False,
                                        name="test_%s" % params.language)

        # Build batched iterator ops; only the train iterator shuffles
        # (is_train=True) — presumably, confirm in build_dataset_op.
        pad_idx = vocab.pad_idx
        self.train_iterator = build_dataset_op(train_dataset,
                                               pad_idx,
                                               params.batch_size,
                                               is_train=True)
        self.train_batch = self.train_iterator.get_next()
        self.dev_iterator = build_dataset_op(dev_dataset,
                                             pad_idx,
                                             params.batch_size,
                                             is_train=False)
        self.dev_batch = self.dev_iterator.get_next()
        self.test_iterator = build_dataset_op(test_dataset,
                                              pad_idx,
                                              params.batch_size,
                                              is_train=False)
        self.test_batch = self.test_iterator.get_next()

        # allow_soft_placement lets TF fall back to CPU for ops without a
        # GPU kernel.
        config = tf.ConfigProto(allow_soft_placement=True)
        sess = tf.Session(config=config)

        self.model = model.Model(vocab.weight, self.task, params, session=sess)
        self.inference_mode = False
        self.num_epoch = params.num_epoch

        # Optionally resume from a checkpoint; infer_test switches the
        # experiment into inference-only mode.
        if params.resume_dir:
            self.model.load_model(params.resume_dir)
            if params.infer_test:
                self.inference_mode = True
            self.logger.info("Inference_mode: On")

        if log_to_tensorboard:
            self.log_to_tensorboard = log_to_tensorboard
            self.log_writer = tf.summary.FileWriter(str(params.log_dir /
                                                        self.run_name),
                                                    sess.graph,
                                                    flush_secs=20)
Example #5
0
            trial_dir=trial_dir,
            train_steps=FLAGS.train_steps,
            mode=FLAGS.mode,
            strategy=strategy,
            metrics=metrics,
            hparams=hparams)


def main(program_flag_names):
    """Log run info, pick a trial directory, and launch the experiment."""
    logging.info('Starting Uncertainty Baselines experiment %s',
                 FLAGS.experiment_name)

    trial_dir = None
    if FLAGS.output_dir:
        logging.info(
            '\n\nRun the following command to view outputs in tensorboard.dev:\n\n'
            'tensorboard dev upload --logdir %s --plugins scalars,graphs,hparams\n\n',
            FLAGS.output_dir)

        # TODO(znado): when open sourced tuning is supported, change this to include
        # the trial number.
        trial_dir = os.path.join(FLAGS.output_dir, '0')

    flag_values = {name: FLAGS[name].value for name in program_flag_names}
    run(trial_dir, flags_lib.serialize_flags(flag_values))


if __name__ == '__main__':
    # Capture the flag names defined for this program so main() can
    # serialize exactly those values; app.run parses argv first.
    defined_flag_names = flags_lib.define_flags()
    app.run(lambda _: main(defined_flag_names))