Example #1
0
File: train.py  Project: dm-mch/nnpix
def train(train_cfg, model_cfg, common_cfg, dataflow):
    """Build a tensorpack TrainConfig from the given configs and run training.

    Args:
        train_cfg: training config; must provide truthy ``epochs`` and
            ``epoch_size`` attributes.
        model_cfg: model config forwarded to ``ModelWrapSingle``.
        common_cfg: shared config forwarded to ``ModelWrapSingle``.
        dataflow: a tensorpack DataFlow yielding the training data.

    Raises:
        ValueError: if ``train_cfg.epochs`` or ``train_cfg.epoch_size``
            is missing/falsy.
    """
    # Validate up front with real exceptions: `assert` would be silently
    # stripped when Python runs with -O.
    epochs = train_cfg.epochs
    if not epochs:
        raise ValueError("train_cfg.epochs must be a positive value, got %r" % (epochs,))
    epoch_size = train_cfg.epoch_size
    if not epoch_size:
        raise ValueError("train_cfg.epoch_size must be a positive value, got %r" % (epoch_size,))

    config = TrainConfig(
        model=ModelWrapSingle(train_cfg, model_cfg, common_cfg),
        dataflow=dataflow,
        #data=my_inputsource, # alternatively, use a customized InputSource
        #callbacks=[...],    # some default callbacks are automatically applied
        # some default monitors are automatically applied
        steps_per_epoch=epoch_size,  # default to the size of your InputSource/DataFlow
        max_epoch=epochs)
    print("Create trainer")
    trainer = SimpleTrainer()
    print("Run train")
    launch_train_with_config(config, trainer)
Example #2
0
    # NOTE(review): this chunk is the tail of a function whose `def` line is
    # outside this view; `save_dir`, `mult_decay`, `lr_base`, `NUM_UNITS` and
    # `args` come from that unseen enclosing scope — confirm against the full file.

    # Pick the directory for logs/checkpoints: auto-generate one when the
    # caller did not specify a path.
    if save_dir is None:
        logger.auto_set_dir()
    else:
        logger.set_logger_dir(save_dir)

    # Build the CIFAR train/test dataflows (presumably via a local get_data helper).
    dataset_train = get_data('train')
    dataset_test = get_data('test')

    config = TrainConfig(
        # ResNet sized by NUM_UNITS; initial LR is a tenth of the base rate
        # (the schedule below bumps it to lr_base after epoch 1).
        model=CifarResNet(n=NUM_UNITS,
                          mult_decay=mult_decay,
                          lr_init=lr_base * 0.1),
        dataflow=dataset_train,
        callbacks=[
            ModelSaver(),  # checkpoint after every epoch
            # Evaluate cost and classification error on the test set each epoch.
            InferenceRunner(
                dataset_test,
                [ScalarStats('cost'),
                 ClassificationError('wrong_vector')]),
            # Step-decay LR schedule: (epoch, value) pairs — the classic
            # ResNet-on-CIFAR drops at epochs 82/123/164.
            ScheduledHyperParamSetter('learning_rate',
                                      [(1, lr_base), (82, lr_base * 0.1),
                                       (123, lr_base * 0.01),
                                       (164, lr_base * 0.002)])
        ],
        max_epoch=200,
        # Restore weights from args.load when given; SmartInit(None) is a no-op.
        session_init=SmartInit(args.load),
    )
    # Train data-parallel across all visible GPUs; fall back to 1 tower on CPU-only hosts.
    num_gpu = max(get_num_gpu(), 1)
    launch_train_with_config(config,
                             SyncMultiGPUTrainerParameterServer(num_gpu))