Exemplo n.º 1
0
 def test_get_num_micro_batches_per_epoch(self):
     """get_num_micro_batches_per_epoch() returns 4 for datasets of 18 and 20.

     The config uses micro batch size 3, 2 replicas and a requested global
     batch size of 10. The dataset sizes are checked in order: 18, then 20.
     """
     config = BatchConfig(
         micro_batch_size=3, num_replicas=2, global_batch_size=10)
     for size in (18, 20):
         micro_batches = config.get_num_micro_batches_per_epoch(
             dataset_size=size)
         assert micro_batches == 4
Exemplo n.º 2
0
 def test_round_gradient_accumulation_count(self):
     """A requested global batch size of 10 comes back as 12.

     With micro batch size 3 and 2 replicas, the config reports a gradient
     accumulation count of 2 and a global batch size of 12.
     """
     config = BatchConfig(global_batch_size=10,
                          num_replicas=2,
                          micro_batch_size=3)
     accumulation = config.gradient_accumulation_count
     assert accumulation == 2
     rounded_global = config.global_batch_size
     assert rounded_global == 12
Exemplo n.º 3
0
 def test_calc_global_batch_size(self):
     """Without an explicit global batch size, the config reports 3 * 4 * 2."""
     micro, replicas, accumulation = 3, 4, 2
     config = BatchConfig(micro_batch_size=micro,
                          num_replicas=replicas,
                          gradient_accumulation_count=accumulation)
     # 3 * 4 * 2 == 24
     assert config.global_batch_size == micro * replicas * accumulation
Exemplo n.º 4
0
 def test_global_batch_size(self):
     """An explicitly requested global batch size of 5 is reported unchanged."""
     requested = 5
     config = BatchConfig(
         micro_batch_size=1, num_replicas=1, global_batch_size=requested)
     assert config.global_batch_size == requested
Exemplo n.º 5
0
 def test_gradient_accumulation_count(self):
     """An explicit gradient accumulation count of 2 is reported unchanged."""
     accumulation = 2
     config = BatchConfig(
         micro_batch_size=1, num_replicas=1,
         gradient_accumulation_count=accumulation)
     assert config.gradient_accumulation_count == accumulation
Exemplo n.º 6
0
 def test_num_micro_batches_per_weight_update(self):
     """With 4 replicas and 4 accumulation steps, the property reports 16."""
     replicas = 4
     accumulation = 4
     config = BatchConfig(micro_batch_size=1,
                          num_replicas=replicas,
                          gradient_accumulation_count=accumulation)
     expected = replicas * accumulation
     assert config.num_micro_batches_per_weight_update == expected
Exemplo n.º 7
0
 def test_num_replicas(self):
     """The replica count passed to the constructor (4) is reported back."""
     replicas = 4
     config = BatchConfig(
         micro_batch_size=1, num_replicas=replicas,
         gradient_accumulation_count=1)
     assert config.num_replicas == replicas
Exemplo n.º 8
0
 def test_micro_batch_size(self):
     """The micro batch size passed to the constructor (3) is reported back."""
     micro = 3
     config = BatchConfig(
         micro_batch_size=micro, num_replicas=1,
         gradient_accumulation_count=1)
     assert config.micro_batch_size == micro
Exemplo n.º 9
0
    if internal_exchange_optimization_target is not None:
        cfg.compilation_poplar_options[
            'opt.internalExchangeOptimisationTarget'] = internal_exchange_optimization_target

    if distributed_training:
        popdist.tensorflow.set_ipu_config(
            cfg, ipus_per_replica=num_ipus_per_replica, configure_device=True)
        hvd.init()
    else:
        cfg.auto_select_ipus = num_ipus_per_replica * num_replicas

    cfg.configure_ipu_system()

    set_seed(seed)

    batch_config = BatchConfig(micro_batch_size, num_replicas,
                               gradient_accumulation_count, global_batch_size)

    logging.info(f'micro batch size {batch_config.micro_batch_size}')
    logging.info(f'global batch size {batch_config.global_batch_size}')
    logging.info(
        f'gradient accumulation {batch_config.gradient_accumulation_count}')
    logging.info(f'num replicas {batch_config.num_replicas}')

    if validation:

        validation_num_replicas = validation_num_replicas or (
            num_replicas * num_ipus_per_replica)
        validation_batch_config = BatchConfig(
            micro_batch_size=validation_micro_batch_size,
            num_replicas=validation_num_replicas,
            gradient_accumulation_count=1,