# Example 1
def main(_):
    """Parses configs, builds the multi-task SimCLR model, and runs training/eval."""
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
    params = train_utils.parse_configuration(FLAGS)
    model_dir = FLAGS.model_dir

    # Only jobs that train serialize the config yaml. Eval-only jobs skip it so
    # a continuous eval job never races the trainer over the same file.
    if 'train' in FLAGS.mode:
        train_utils.serialize_config(params, model_dir)

    # Mixed precision ('mixed_float16' / 'mixed_bfloat16') can significantly
    # speed the model up by using float16 on GPUs and bfloat16 on TPUs; the
    # loss_scale setting only matters for the float16 case.
    if params.runtime.mixed_precision_dtype:
        performance.set_mixed_precision_policy(
            params.runtime.mixed_precision_dtype)

    strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=params.runtime.distribution_strategy,
        num_gpus=params.runtime.num_gpus,
        all_reduce_alg=params.runtime.all_reduce_alg,
        tpu_address=params.runtime.tpu)

    # Model and task objects must be created under the strategy scope so their
    # variables are placed/replicated correctly.
    with strategy.scope():
        multi_tasks = multitask.MultiTask.from_config(params.task)
        mt_model = multitask_model.SimCLRMTModel(params.task.model)

    train_lib.run_experiment(
        distribution_strategy=strategy,
        task=multi_tasks,
        model=mt_model,
        mode=FLAGS.mode,
        params=params,
        model_dir=model_dir)

    train_utils.save_gin_config(FLAGS.mode, model_dir)
 def test_initialize_model_success(self):
   """Checks sub-task exposure and backbone restoration via initialize()."""
   checkpoint_dir = self.get_temp_dir()
   # One pretrain head and one finetune head, each with its own task name.
   heads = (
       multitask_config.SimCLRMTHeadConfig(
           mode=simclr_model.PRETRAIN, task_name='pretrain_simclr'),
       multitask_config.SimCLRMTHeadConfig(
           mode=simclr_model.FINETUNE, task_name='finetune_simclr'))
   config = multitask_config.SimCLRMTModelConfig(
       input_size=[64, 64, 3], heads=heads)
   model = multitask_model.SimCLRMTModel(config)
   self.assertIn('pretrain_simclr', model.sub_tasks)
   self.assertIn('finetune_simclr', model.sub_tasks)
   # Save only the backbone, then let the model restore it via initialize().
   checkpoint = tf.train.Checkpoint(backbone=model._backbone)
   checkpoint.save(os.path.join(checkpoint_dir, 'ckpt'))
   model.initialize()