Ejemplo n.º 1
0
def jobman(state, channel):
    """Train a bidirectional RNN (``biRNN``) on ``ListSequences`` data.

    Builds a training set (exported to a dense, fixed-length format) and a
    validation set from the same path/PCA/subset options. It then sizes a
    ``biRNN`` from the training data and runs an ``SGD``-driven ``MainLoop``.

    Args:
        state: experiment configuration mapping. Keys read directly here:
            'path', 'pca', 'subset', 'seqlen', 'overlap', 'nhids', 'seed',
            'bs'. The whole mapping is also handed to ``SGD`` and
            ``MainLoop``.
        channel: forwarded unchanged to ``MainLoop`` (presumably the jobman
            communication channel; confirm against MainLoop).
    """
    # Reader options shared by the train and valid splits.
    reader_opts = dict(path=state['path'],
                       pca=state['pca'],
                       subset=state['subset'],
                       one_hot=False,
                       nbits=32)

    raw_train = ListSequences(which='train', **reader_opts)
    train_data = raw_train.export_dense_format(
        sequence_length=state['seqlen'],
        overlap=state['overlap'])

    # NOTE(review): unlike train_data, the validation split is NOT exported to
    # dense format before being handed to MainLoop. Confirm this is intended.
    valid_data = ListSequences(which='valid', **reader_opts)

    # Output width is max label + 1 (assumes integer labels starting at 0);
    # input width is the last axis of the dense training inputs.
    n_outputs = numpy.max(train_data.data_y) + 1
    n_inputs = train_data.data_x.shape[-1]
    model = biRNN(nhids=state['nhids'],
                  nouts=n_outputs,
                  nins=n_inputs,
                  activ=TT.nnet.sigmoid,
                  seed=state['seed'],
                  bs=state['bs'],
                  seqlen=state['seqlen'])

    trainer = SGD(model, state, train_data)

    loop = MainLoop(train_data, valid_data, None, model, trainer, state,
                    channel)
    loop.main()
Ejemplo n.º 2
0
def jobman(state, channel):
    """Train a ``DBMinpainting`` model on MNIST with ``natSGD``.

    Args:
        state: experiment configuration mapping. Keys read directly here:
            'seed', 'path', 'mbs', 'bs', 'samebatch'. The whole mapping is
            also passed to ``DBMinpainting``, ``natSGD`` and ``MainLoop``.
        channel: forwarded unchanged to ``MainLoop``.
    """
    rng = numpy.random.RandomState(state['seed'])
    model = DBMinpainting(state)

    # The data source receives the model's callback. What it is used for is
    # defined by DataMNIST and not visible here.
    loader_args = (state['path'], state['mbs'], state['bs'], rng)
    data = DataMNIST(*loader_args,
                     same_batch=state['samebatch'],
                     callback=model.callback)

    optimizer = natSGD(model, state, data)
    MainLoop(data, model, optimizer, state, channel).main()
Ejemplo n.º 3
0
def jobman(state, channel):
    """Train a ``convMat`` model on MNIST with SGD or natural-gradient SGD.

    Args:
        state: experiment configuration mapping. Keys read directly here:
            'seed', 'path', 'mbs', 'bs', 'unlabled' (sic — the key is spelled
            this way in the config), 'natSGD'. The whole mapping is also
            passed to ``convMat``, the trainer and ``MainLoop``.
        channel: forwarded unchanged to ``MainLoop``.
    """
    rng = numpy.random.RandomState(state['seed'])
    mnist = DataMNIST(state['path'], state['mbs'], state['bs'], rng,
                      state['unlabled'])
    net = convMat(state, mnist)

    # state['natSGD'] == 0 selects plain SGD; any other value selects natSGD.
    trainer_cls = SGD if state['natSGD'] == 0 else natSGD
    trainer = trainer_cls(net, state, mnist)

    MainLoop(mnist, net, trainer, state, channel).main()
Ejemplo n.º 4
0
def jobman(state, channel):
    """Train a ``DBMinpainting`` model on MNIST with ``natSGD``.

    Args:
        state: experiment configuration mapping. Keys read directly here:
            'seed', 'path', 'mbs', 'bs', 'samebatch'. The whole mapping is
            also passed to ``DBMinpainting``, ``natSGD`` and ``MainLoop``.
        channel: forwarded unchanged to ``MainLoop``.
    """
    random_state = numpy.random.RandomState(state['seed'])
    inpainter = DBMinpainting(state)

    # Hand the model's callback to the data source. How DataMNIST invokes it
    # is not visible in this file.
    mnist = DataMNIST(
        state['path'], state['mbs'], state['bs'], random_state,
        same_batch=state['samebatch'], callback=inpainter.callback)

    natural_sgd = natSGD(inpainter, state, mnist)
    loop = MainLoop(mnist, inpainter, natural_sgd, state, channel)
    loop.main()
Ejemplo n.º 5
0
def jobman(state, channel):
    """Train an ``mlp`` model on MNIST with SGD or natural-gradient SGD.

    Args:
        state: experiment configuration mapping. Keys read directly here:
            'seed', 'path', 'mbs', 'bs', 'unlabled' (sic — the key is spelled
            this way in the config), 'natSGD'. The whole mapping is also
            passed to ``mlp``, the trainer and ``MainLoop``.
        channel: forwarded unchanged to ``MainLoop``.
    """
    generator = numpy.random.RandomState(state['seed'])
    dataset = DataMNIST(
        state['path'],
        state['mbs'],
        state['bs'],
        generator,
        state['unlabled'])
    network = mlp(state, dataset)

    # Compare with == 0 exactly: values such as None or "" must fall through
    # to natSGD, as they would with an explicit equality test.
    use_plain_sgd = state['natSGD'] == 0
    optimizer_cls = SGD if use_plain_sgd else natSGD
    optimizer = optimizer_cls(network, state, dataset)

    runner = MainLoop(dataset, network, optimizer, state, channel)
    runner.main()
Ejemplo n.º 6
0
def main():
    """Train an RNN encoder-decoder model configured from the command line.

    The experiment state is assembled in layers, later layers overriding
    earlier ones:

    1. the prototype returned by ``experiments.nmt.<args.proto>()``;
    2. ``args.state`` if given. A path ending in ``.py`` is read and
       ``eval``'d (expected to yield a dict). Any other path is loaded as a
       pickle.
    3. each string in ``args.changes``, evaluated as ``dict(<change>)``
       (e.g. ``"bs=64, seed=1"``).

    Logging is then configured from ``state['level']``. The encoder-decoder
    and its LM model are built, and a trainer is created from
    ``state['algo']``. Training is driven by ``MainLoop``: it is optionally
    reloaded first (``state['reload']``) and only run when
    ``state['loopIters'] > 0``.
    """
    args = parse_args()

    state = getattr(experiments.nmt, args.proto)()
    if args.state:
        if args.state.endswith(".py"):
            # Use a context manager so the file handle is closed promptly.
            with open(args.state) as src:
                state_source = src.read()
            # NOTE(security): eval executes arbitrary code from this file;
            # only pass trusted state files.
            state.update(eval(state_source))
        else:
            # Pickles are binary data and must be read in binary mode. Text
            # mode can corrupt them (e.g. newline translation on Windows).
            with open(args.state, 'rb') as src:
                state.update(cPickle.load(src))
    for change in args.changes:
        # NOTE(security): each override is eval'd as keyword arguments to
        # dict(); never feed untrusted command-line input here.
        state.update(eval("dict({})".format(change)))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logger.debug("State:\n{}".format(pprint.pformat(state)))

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, args.skip_init)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()

    logger.debug("Load data")
    train_data = get_batch_iterator(state)
    logger.debug("Compile trainer")
    # NOTE(security): state['algo'] is eval'd to resolve the trainer by
    # name, so it must come from a trusted source.
    algo = eval(state['algo'])(lm_model, state, train_data)
    logger.debug("Run training")
    # Named main_loop (not `main`) to avoid shadowing this function's name.
    main_loop = MainLoop(
        train_data, None, None, lm_model, algo, state, None,
        reset=state['reset'],
        # A negative hookFreq installs no sample-printing hook.
        hooks=[RandomSamplePrinter(state, lm_model, train_data)]
        if state['hookFreq'] >= 0
        else None,
        valid=validate_translation)
    if state['reload']:
        main_loop.load()
    if state['loopIters'] > 0:
        main_loop.main()