예제 #1
0
파일: dl-conv.py 프로젝트: Raysor/deepnl
def create_trainer(args, converter, labels):
    """
    Creates or loads a neural network according to the specified args.

    :param args: parsed options. Reads ``load``, ``learning_rate``,
        ``threads`` (only when loading), ``window``, ``hidden`` and
        ``verbose`` (only when creating), plus ``model`` and ``output``.
    :param converter: passed unchanged to ``ConvTrainer`` when a new
        network is created.
    :param labels: dict of labels.
    :return: the trainer, with its ``saver`` attribute set.
    """

    logger = logging.getLogger("Logger")

    if args.load:
        logger.info("Loading provided network...")
        trainer = ConvTrainer.load(args.load)
        # The command-line values override whatever was stored with the model.
        trainer.learning_rate = args.learning_rate
        trainer.threads = args.threads
    else:
        logger.info('Creating new network...')
        # Floor division keeps the context size an integer on Python 3 too;
        # with an int window this matches what `/` did on Python 2.
        # NOTE(review): assumes args.window is an int -- confirm in the parser.
        context = args.window // 2
        trainer = ConvTrainer(converter, args.learning_rate,
                              context, context,
                              args.hidden, labels, args.verbose)

    # `saver` is defined elsewhere in this file; presumably it builds the
    # callback used to persist the model -- verify against its definition.
    trainer.saver = saver(args.model, args.output)

    logger.info("... with the following parameters:")
    logger.info(trainer.nn.description())

    return trainer
예제 #2
0
파일: dl-conv.py 프로젝트: cerisara/deepnl
def create_trainer(args, converter, labels):
    """
    Creates or loads a neural network according to the specified args.

    :param args: parsed options. Reads ``load``, ``learning_rate``,
        ``threads`` (only when loading), ``window``, ``hidden``,
        ``hidden2`` and ``verbose`` (only when creating), plus ``model``,
        ``vectors`` and ``variant``.
    :param converter: feature converter; its ``size()`` is used to
        dimension the network input and it is handed to ``ConvTrainer``.
    :param labels: list of labels; ``len(labels)`` is passed to
        ``ConvolutionalNetwork``.
    :return: the trainer, with its ``saver`` attribute set.
    """

    logger = logging.getLogger("Logger")

    if args.load:
        logger.info("Loading provided network...")
        trainer = ConvTrainer.load(args.load)
        # change learning rate
        trainer.learning_rate = args.learning_rate
        trainer.threads = args.threads
    else:
        logger.info("Creating new network...")
        # sum the number of features in all extractors' tables
        feat_size = converter.size()
        pool_size = args.window
        nn = ConvolutionalNetwork(feat_size * pool_size, args.hidden,
                                  args.hidden2, len(labels), pool_size)
        # Floor division keeps the context size an integer on Python 3 too;
        # with an int window this matches what `/` did on Python 2.
        # NOTE(review): assumes args.window is an int -- confirm in the parser.
        context = args.window // 2
        options = {
            "learning_rate": args.learning_rate,
            "verbose": args.verbose,
            "left_context": context,
            "right_context": context,
        }
        trainer = ConvTrainer(nn, converter, labels, options)

    # `saver` is defined elsewhere in this file; presumably it builds the
    # callback used to persist the model -- verify against its definition.
    trainer.saver = saver(args.model, args.vectors, args.variant)

    logger.info("... with the following parameters:")
    logger.info(trainer.nn.description())

    return trainer
예제 #3
0
def create_trainer(args, converter, labels):
    """
    Creates or loads a neural network according to the specified args.

    :param args: parsed options. Reads ``load``, ``learning_rate``,
        ``threads`` (only when loading), ``window``, ``hidden`` and
        ``verbose`` (only when creating), plus ``model`` and ``output``.
    :param converter: passed unchanged to ``ConvTrainer`` when a new
        network is created.
    :param labels: dict of labels.
    :return: the trainer, with its ``saver`` attribute set.
    """

    logger = logging.getLogger("Logger")

    if args.load:
        logger.info("Loading provided network...")
        trainer = ConvTrainer.load(args.load)
        # Apply the requested settings on top of the loaded trainer.
        trainer.learning_rate = args.learning_rate
        trainer.threads = args.threads
    else:
        logger.info('Creating new network...')
        # Use `//` so the half-window stays an integer under Python 3 as
        # well; for an int window Python 2's `/` gave the same result.
        # NOTE(review): assumes args.window is an int -- confirm in the parser.
        half_window = args.window // 2
        trainer = ConvTrainer(converter, args.learning_rate, half_window,
                              half_window, args.hidden, labels,
                              args.verbose)

    # `saver` comes from elsewhere in this file; presumably it returns the
    # model-saving callback -- verify against its definition.
    trainer.saver = saver(args.model, args.output)

    logger.info("... with the following parameters:")
    logger.info(trainer.nn.description())

    return trainer
예제 #4
0
def create_trainer(args, converter, labels):
    """
    Build a fresh ConvTrainer, or restore one from ``args.load``.

    :param args: parsed options. Reads ``load``, ``learning_rate``,
        ``threads`` (only when loading), ``window``, ``hidden``,
        ``hidden2``, ``eps`` and ``verbose`` (only when creating), plus
        ``model``, ``vectors`` and ``variant``.
    :param converter: feature converter; its ``size()`` dimensions the
        network input and it is handed to ``ConvTrainer``.
    :param labels: list of labels.
    :return: the trainer, with its ``saver`` attribute set.
    """

    log = logging.getLogger("Logger")

    if not args.load:
        log.info('Creating new network...')
        # Per the original note: size() sums the number of features in all
        # extractors' tables.
        features = converter.size()
        # The pool spans `window` positions on each side plus the centre one.
        span = args.window * 2 + 1
        input_size = features * span
        network = ConvolutionalNetwork(input_size, args.hidden,
                                       args.hidden2, len(labels), span)
        settings = dict(
            learning_rate=args.learning_rate,
            eps=args.eps,
            verbose=args.verbose,
            left_context=args.window,
            right_context=args.window,
        )
        trainer = ConvTrainer(network, converter, labels, settings)
    else:
        log.info("Loading provided network...")
        trainer = ConvTrainer.load(args.load)
        # The command-line values replace the ones on the loaded trainer.
        trainer.learning_rate = args.learning_rate
        trainer.threads = args.threads

    trainer.saver = saver(args.model, args.vectors, args.variant)

    log.info("... with the following parameters:")
    log.info(trainer.nn.description())

    return trainer