# Example #1
# 0
def train_net(net,
              session_init,
              batch_size,
              num_epochs,
              train_dataflow,
              val_dataflow):
    """Launch multi-tower ImageNet training for `net` with tensorpack.

    Args:
        net: tensorpack model definition to train.
        session_init: session initializer (e.g. restored weights).
        batch_size: global batch size, split evenly across towers.
        num_epochs: total number of training epochs.
        train_dataflow: dataflow yielding training batches.
        val_dataflow: dataflow yielding validation batches.
    """
    # One tower per visible GPU; fall back to a single tower on CPU.
    towers = max(get_num_gpu(), 1)
    per_tower_batch = batch_size // towers
    logger.info("Running on {} towers. Batch size per tower: {}".format(towers, per_tower_batch))

    # ImageNet-1k training-set size; one step consumes one global batch.
    num_training_samples = 1281167
    steps_per_epoch = num_training_samples // batch_size
    # Anneal the LR linearly from 0.5 down to 0 over all but the last epoch.
    final_iter = (num_epochs - 1) * steps_per_epoch
    callbacks = [
        ModelSaver(),
        ScheduledHyperParamSetter(
            'learning_rate',
            [(0, 0.5), (final_iter, 0)],
            interp='linear',
            step_based=True),
        EstimatedTimeLeft(),
    ]

    error_metrics = [
        ClassificationError('wrong-top1', 'val-error-top1'),
        ClassificationError('wrong-top5', 'val-error-top5'),
    ]
    if towers > 1:
        # multi-GPU inference (with mandatory queue prefetch)
        callbacks.append(DataParallelInferenceRunner(
            input=val_dataflow,
            infs=error_metrics,
            gpus=list(range(towers))))
    else:
        # single-GPU inference with queue prefetch
        callbacks.append(InferenceRunner(
            input=QueueInput(val_dataflow),
            infs=error_metrics))

    launch_train_with_config(
        config=TrainConfig(
            dataflow=train_dataflow,
            model=net,
            callbacks=callbacks,
            session_init=session_init,
            steps_per_epoch=steps_per_epoch,
            max_epoch=num_epochs),
        trainer=SyncMultiGPUTrainerParameterServer(towers))
# Example #2
# 0
def get_config(model, conf):
    """Build a tensorpack TrainConfig for `model` from the `conf` namespace.

    When `conf.fake` is set, trains on synthetic uint8 FakeData (benchmark
    mode, no callbacks); otherwise it uses real train/val dataflows with
    checkpointing, a stepped LR schedule, and top-1/top-5 validation error.
    """
    nr_tower = max(get_nr_gpu(), 1)
    batch = conf.batch
    if not conf.fake:
        logger.info("Running on {} towers. Batch size per tower: {}".format(
            nr_tower, batch))
        dataset_train = get_data(conf.data_dir, 'train', batch)
        dataset_val = get_data(conf.data_dir, 'val', batch)
        # Stepped LR decay; the human setter allows live edits via file.
        lr_schedule = [(45, 1e-2), (60, 1e-3), (65, 1e-4), (70, 1e-5),
                       (75, 1e-6)]
        callbacks = [
            ModelSaver(),
            ScheduledHyperParamSetter('learning_rate', lr_schedule),
            HumanHyperParamSetter('learning_rate'),
        ]
        infs = [
            ClassificationError('wrong-top1', 'val-error-top1'),
            ClassificationError('wrong-top5', 'val-error-top5'),
        ]
        if nr_tower > 1:
            # multi-GPU inference (with mandatory queue prefetch)
            callbacks.append(
                DataParallelInferenceRunner(dataset_val, infs,
                                            list(range(nr_tower))))
        else:
            # single-GPU inference with queue prefetch
            callbacks.append(InferenceRunner(QueueInput(dataset_val), infs))
    else:
        logger.info("For benchmark, batch size is fixed to 64 per tower.")
        dataset_train = FakeData([[64, 224, 224, 3], [64]],
                                 1000,
                                 random=False,
                                 dtype='uint8')
        callbacks = []
    return TrainConfig(model=model,
                       dataflow=dataset_train,
                       callbacks=callbacks,
                       steps_per_epoch=5000,
                       max_epoch=80,
                       nr_tower=nr_tower)
# Example #3
# 0
def train(checkpoint_dir,
          model_name,
          dataset,
          num_epochs,
          quant_type,
          batch_size_per_gpu,
          lr=None,
          post_quantize_only=False):
    """Train (and quantise) `model_name` on `dataset` over all visible GPUs.

    Args:
        checkpoint_dir: directory for saving checkpoints (and restoring them
            in post-quantize mode).
        model_name: model key understood by `get_model_func`.
        dataset: key into `datasets.DATASETS`.
        num_epochs: number of epochs to run; if None, taken from the epoch of
            the last LR-schedule entry.
        quant_type: quantisation scheme forwarded to `get_model_func`.
        batch_size_per_gpu: per-tower batch size.
        lr: None (use the built-in default schedule), a float (constant LR),
            a list of ``(epoch, lr)`` pairs, or a string literal of either.
        post_quantize_only: if True, restore weights from `checkpoint_dir`
            and start quantising at epoch 0 instead of training from scratch.
    """
    train_data, test_data, (img_shape,
                            label_shape) = datasets.DATASETS[dataset]()

    num_gpus = max(gpu.get_num_gpu(), 1)
    train_data = BatchData(train_data, batch_size_per_gpu)
    test_data = BatchData(test_data, batch_size_per_gpu, remainder=True)
    # Each tower consumes one batch per step, so an epoch shrinks by num_gpus.
    steps_per_epoch = len(train_data) // num_gpus

    # Normalise `lr` into a [(epoch, value), ...] schedule.
    if lr:
        if isinstance(lr, str):
            lr = ast.literal_eval(lr)
        if isinstance(lr, float):
            lr_schedule = [(0, lr)]
        else:
            lr_schedule = lr
    else:
        lr_schedule = [(0, 0.005), (8, 0.1), (25, 0.005), (30, 0)]

    if num_epochs is None:
        num_epochs = lr_schedule[-1][0]
    if post_quantize_only:
        start_quantising_at_epoch = 0
    else:
        # Begin quantising at the second-to-last schedule point; with a
        # single-entry schedule, fall back to the last 5 epochs.
        start_quantising_at_epoch = lr_schedule[-2][0] if len(
            lr_schedule) > 1 else max(0, num_epochs - 5)

    logger.info(f"Training with LR schedule: {str(lr_schedule)}")
    logger.info(f"Quantising at epoch {start_quantising_at_epoch}")

    model_func, input_spec, output_spec = get_model_func(
        "train",
        model_name,
        quant_type,
        img_shape,
        num_classes=label_shape[0],
        quant_delay=steps_per_epoch * start_quantising_at_epoch)
    target_spec = [
        tf.TensorSpec(t.shape, t.dtype, name=t.name.split("/")[-1] + "_target")
        for t in output_spec
    ]
    model = KerasModel(get_model=model_func,
                       input_signature=input_spec,
                       target_signature=target_spec,
                       input=train_data,
                       trainer=SyncMultiGPUTrainerParameterServer(
                           num_gpus, ps_device='gpu'))

    # Use a distinct name so the TF variable does not shadow the `lr` arg.
    lr_var = tf.get_variable('learning_rate',
                             initializer=lr_schedule[0][1],
                             trainable=False)
    tf.summary.scalar('learning_rate-summary', lr_var)
    model.compile(optimizer=tf.train.MomentumOptimizer(learning_rate=lr_var,
                                                       momentum=0.9),
                  loss="categorical_crossentropy",
                  metrics=["categorical_accuracy"])

    model.fit(steps_per_epoch=steps_per_epoch,
              max_epoch=num_epochs,
              callbacks=[
                  ModelSaver(max_to_keep=1, checkpoint_dir=checkpoint_dir),
                  DataParallelInferenceRunner(
                      test_data, ScalarStats(model._stats_to_inference),
                      num_gpus),
                  ScheduledHyperParamSetter('learning_rate',
                                            lr_schedule,
                                            interp="linear"),
                  StatMonitorParamSetter('learning_rate',
                                         'validation_categorical_accuracy',
                                         lambda x: x / 2,
                                         threshold=0.001,
                                         last_k=10,
                                         reverse=True)
              ],
              session_init=SaverRestore(checkpoint_dir + "/checkpoint")
              if post_quantize_only else None)