Example #1
0
def train(algorithm, learning_rate, clipping, momentum, layer_size, epochs,
          test_cost, experiment_path, initialization, init_width, weight_noise,
          z_prob, z_prob_states, z_prob_cells, drop_prob_igates,
          ogates_zoneout, batch_size, stoch_depth, share_mask, gaussian_drop,
          rnn_type, num_layers, norm_cost_coeff, penalty, testing, seq_len,
          decrease_lr_after_epoch, lr_decay, **kwargs):
    """Train a word-level Penn Treebank language model with zoneout RNNs.

    Builds a stack of `num_layers` recurrent layers (LSTM / GRU / simple
    RNN, selected by `rnn_type`) over one-hot word inputs, applies the
    zoneout masks sampled by the data stream, optionally adds a norm
    stabilizer penalty (`penalty`) and weight noise, and runs a Blocks
    ``MainLoop`` with train/dev (and optionally test) monitoring.

    Python 2 / Theano / Blocks / Fuel code; relies on module-level imports
    (np, numpy, theano, T, floatX, logger, and the Drop* bricks / stream
    helpers) that are not visible in this chunk.
    """

    print '.. PTB experiment'
    print '.. arguments:', ' '.join(sys.argv)
    t0 = time.time()

    ###########################################
    #
    # LOAD DATA
    #
    ###########################################

    def onehot(x, numclasses=None):
        """ Convert integer encoding for class-labels (starting with 0 !)
            to one-hot encoding.
            The output is an array whose shape is the shape of the input array
            plus an extra dimension, containing the 'one-hot'-encoded labels.
        """
        if x.shape == ():
            x = x[None]
        if numclasses is None:
            numclasses = x.max() + 1
        result = numpy.zeros(list(x.shape) + [numclasses], dtype="int")
        z = numpy.zeros(x.shape, dtype="int")
        for c in range(numclasses):
            z *= 0
            z[numpy.where(x == c)] = 1
            result[..., c] += z
        return result.astype(theano.config.floatX)  # cast for Theano inputs

    alphabetsize = 10000  # PTB word-level vocabulary size
    data = np.load('penntree_char_and_word.npz')
    trainset = data['train_words']
    validset = data['valid_words']
    testset = data['test_words']

    # Shrink the data for quick smoke tests.
    if testing:
        trainset = trainset[:3000]
        validset = validset[:3000]

    # Mask sharing: with share_mask, a single z_prob drives the cell masks
    # and the state masks are disabled; otherwise per-component probabilities
    # are used (defaulting to '1', i.e. keep everything).
    if share_mask:
        if not z_prob:
            raise ValueError('z_prob must be provided when using share_mask')
        if z_prob_cells or z_prob_states:
            raise ValueError(
                'z_prob_states and z_prob_cells must not be provided when using share_mask (use z_prob instead)'
            )
        z_prob_cells = z_prob
        # we don't want to actually use these masks, so this is to debug
        z_prob_states = None
    else:
        if z_prob:
            raise ValueError('z_prob is only used with share_mask')
        z_prob_cells = z_prob_cells or '1'
        z_prob_states = z_prob_states or '1'


#    rng = np.random.RandomState(seed)

###########################################
#
# MAKE STREAMS
#
###########################################

    def prep_dataset(dataset):
        # Trim so the data reshapes evenly into (n_batches, batch_size,
        # seq_len), then wrap it in a stream that also samples zoneout masks.
        dataset = dataset[:(len(dataset) - (len(dataset) %
                                            (seq_len * batch_size)))]
        dataset = dataset.reshape(batch_size, -1, seq_len).transpose((1, 0, 2))

        stream = DataStream(
            IndexableDataset(indexables=OrderedDict([('data', dataset)])),
            iteration_scheme=SequentialExampleScheme(dataset.shape[0]))
        stream = Transpose(stream, [(1, 0)])
        stream = SampleDropsNPWord(stream, z_prob_states, z_prob_cells,
                                   drop_prob_igates, layer_size, num_layers,
                                   False, stoch_depth, share_mask,
                                   gaussian_drop, alphabetsize)
        # NOTE(review): 'data' is listed three times before the transformer's
        # own sources — presumably to line up with what SampleDropsNPWord
        # emits; confirm against that transformer's implementation.
        stream.sources = ('data', ) * 3 + stream.sources + (
            'zoneouts_states', 'zoneouts_cells', 'zoneouts_igates')
        return (stream, )

    train_stream, = prep_dataset(trainset)
    valid_stream, = prep_dataset(validset)
    test_stream, = prep_dataset(testset)

    ####################

    # Pull one batch up front to supply Theano test values below (Python 2
    # iterator protocol: .next()).
    data = train_stream.get_epoch_iterator(as_dict=True).next()

    ####################

    ###########################################
    #
    # BUILD MODEL
    #
    ###########################################
    print '.. building model'

    x = T.tensor3('data')
    # Targets equal the inputs: the one-step shift for next-word prediction
    # happens at the output, where states are delayed by prepending
    # states_init and dropping the last step (see hid_to_out.apply below).
    y = x
    zoneouts_states = T.tensor3('zoneouts_states')
    zoneouts_cells = T.tensor3('zoneouts_cells')
    zoneouts_igates = T.tensor3('zoneouts_igates')

    x.tag.test_value = data['data']
    zoneouts_states.tag.test_value = data['zoneouts_states']
    zoneouts_cells.tag.test_value = data['zoneouts_cells']
    zoneouts_igates.tag.test_value = data['zoneouts_igates']

    if init_width and not initialization == 'uniform':
        raise ValueError('Width is only for uniform init, whassup?')

    if initialization == 'glorot':
        weights_init = NormalizedInitialization()
    elif initialization == 'uniform':
        weights_init = Uniform(width=init_width)
    elif initialization == 'ortho':
        weights_init = OrthogonalInitialization()
    else:
        raise ValueError('No such initialization')

    # Input-to-hidden projections and recurrent bricks, one pair per layer.
    # The first layer consumes the one-hot input (alphabetsize wide); deeper
    # layers consume the states of the layer below.  LSTM/GRU projections are
    # 4x/3x wide for their gates.
    if rnn_type.lower() == 'lstm':
        in_to_hids = [
            Linear(layer_size if l > 0 else alphabetsize,
                   layer_size * 4,
                   name='in_to_hid%d' % l,
                   weights_init=weights_init,
                   biases_init=Constant(0.0)) for l in range(num_layers)
        ]
        recurrent_layers = [
            DropLSTM(dim=layer_size,
                     weights_init=weights_init,
                     activation=Tanh(),
                     model_type=6,
                     name='rnn%d' % l,
                     ogates_zoneout=ogates_zoneout) for l in range(num_layers)
        ]
    elif rnn_type.lower() == 'gru':
        in_to_hids = [
            Linear(layer_size if l > 0 else alphabetsize,
                   layer_size * 3,
                   name='in_to_hid%d' % l,
                   weights_init=weights_init,
                   biases_init=Constant(0.0)) for l in range(num_layers)
        ]
        recurrent_layers = [
            DropGRU(dim=layer_size,
                    weights_init=weights_init,
                    activation=Tanh(),
                    name='rnn%d' % l) for l in range(num_layers)
        ]
    elif rnn_type.lower() == 'srnn':  # FIXME!!! make ReLU
        in_to_hids = [
            Linear(layer_size if l > 0 else alphabetsize,
                   layer_size,
                   name='in_to_hid%d' % l,
                   weights_init=weights_init,
                   biases_init=Constant(0.0)) for l in range(num_layers)
        ]
        recurrent_layers = [
            DropSimpleRecurrent(dim=layer_size,
                                weights_init=weights_init,
                                activation=Rectifier(),
                                name='rnn%d' % l) for l in range(num_layers)
        ]
    else:
        raise NotImplementedError

    hid_to_out = Linear(layer_size,
                        alphabetsize,
                        name='hid_to_out',
                        weights_init=weights_init,
                        biases_init=Constant(0.0))

    for layer in in_to_hids:
        layer.initialize()
    for layer in recurrent_layers:
        layer.initialize()
    hid_to_out.initialize()

    layer_input = x  #in_to_hid.apply(x)

    # Stack the layers; shared variables carry the final hidden (and cell)
    # states across batches so training is stateful between updates.  Each
    # layer's slice of the zoneout mask tensors is selected by layer index.
    init_updates = OrderedDict()
    for l, (in_to_hid, layer) in enumerate(zip(in_to_hids, recurrent_layers)):
        rnn_embedding = in_to_hid.apply(layer_input)
        if rnn_type.lower() == 'lstm':
            states_init = theano.shared(
                np.zeros((batch_size, layer_size), dtype=floatX))
            cells_init = theano.shared(
                np.zeros((batch_size, layer_size), dtype=floatX))
            states_init.name, cells_init.name = "states_init", "cells_init"
            states, cells = layer.apply(
                rnn_embedding,
                zoneouts_states[:, :, l * layer_size:(l + 1) * layer_size],
                zoneouts_cells[:, :, l * layer_size:(l + 1) * layer_size],
                zoneouts_igates[:, :, l * layer_size:(l + 1) * layer_size],
                states_init, cells_init)
            init_updates.update([(states_init, states[-1]),
                                 (cells_init, cells[-1])])
        elif rnn_type.lower() in ['gru', 'srnn']:
            # untested!
            states_init = theano.shared(
                np.zeros((batch_size, layer_size), dtype=floatX))
            states_init.name = "states_init"
            states = layer.apply(rnn_embedding, zoneouts_states,
                                 zoneouts_igates, states_init)
            init_updates.update([(states_init, states[-1])])
        else:
            raise NotImplementedError
        layer_input = states

    # Delay the top-layer states by one step (prepend the initial state,
    # drop the last) so the prediction at time t only sees history < t.
    y_hat_pre_softmax = hid_to_out.apply(T.join(0, [states_init], states[:-1]))
    shape_ = y_hat_pre_softmax.shape
    y_hat = Softmax().apply(y_hat_pre_softmax.reshape((-1, alphabetsize)))

    ####################

    ###########################################
    #
    # SET UP COSTS AND MONITORS
    #
    ###########################################

    cost = CategoricalCrossEntropy().apply(y.reshape((-1, alphabetsize)),
                                           y_hat).copy('cost')

    # NOTE(review): the name 'bpr' looks like a typo for 'bpc' (bits/char).
    bpc = (cost / np.log(2.0)).copy(name='bpr')
    perp = T.exp(cost).copy(name='perp')

    cost_train = cost.copy(name='train_cost')
    cg_train = ComputationGraph([cost_train])

    ###########################################
    #
    # NORM STABILIZER
    #
    ###########################################
    norm_cost = 0.

    def _magnitude(x, axis=-1):
        # L2 norm along `axis`, floored at the dtype's tiny value so the
        # sqrt gradient stays finite.
        return T.sqrt(
            T.maximum(T.sqr(x).sum(axis=axis),
                      numpy.finfo(x.dtype).tiny))

    # Norm stabilizer (penalizes step-to-step changes in the activation
    # norms) on either the memory cells or the hidden states.
    if penalty == 'cells':
        assert VariableFilter(roles=[MEMORY_CELL])(cg_train.variables)
        for cell in VariableFilter(roles=[MEMORY_CELL])(cg_train.variables):
            norms = _magnitude(cell)
            norm_cost += T.mean(
                T.sum((norms[1:] - norms[:-1])**2, axis=0) / (seq_len - 1))
    elif penalty == 'hids':
        for l in range(num_layers):
            assert 'rnn%d_apply_states' % l in [
                o.name
                for o in VariableFilter(roles=[OUTPUT])(cg_train.variables)
            ]
        for output in VariableFilter(roles=[OUTPUT])(cg_train.variables):
            for l in range(num_layers):
                if output.name == 'rnn%d_apply_states' % l:
                    norms = _magnitude(output)
                    norm_cost += T.mean(
                        T.sum((norms[1:] - norms[:-1])**2, axis=0) /
                        (seq_len - 1))

    # NOTE(review): if penalty is neither 'cells' nor 'hids', norm_cost is
    # still the Python float 0. and this attribute assignment raises
    # AttributeError.
    norm_cost.name = 'norm_cost'
    #cost_valid = cost_train
    cost_train += norm_cost_coeff * norm_cost
    cost_train = cost_train.copy(
        'cost_train')  #should this be cost_train.outputs[0]? no.

    cg_train = ComputationGraph([cost_train])

    ###########################################
    #
    # WEIGHT NOISE
    #
    ###########################################

    if weight_noise > 0:
        weights = VariableFilter(roles=[WEIGHT])(cg_train.variables)
        cg_train = apply_noise(cg_train, weights, weight_noise)
        cost_train = cg_train.outputs[0].copy(name='cost_train')

    model = Model(cost_train)

    # Build the step rule; `learning_rate` is rebound to the rule's shared
    # variable so LearningRateSchedule (below) can mutate it, and `clipping`
    # is rebound from the threshold value to the StepClipping rule.
    learning_rate = float(learning_rate)
    clipping = StepClipping(threshold=np.cast[floatX](clipping))
    if algorithm == 'adam':
        adam = Adam(learning_rate=learning_rate)
        learning_rate = adam.learning_rate
        step_rule = CompositeRule([adam, clipping])
    elif algorithm == 'rms_prop':
        rms_prop = RMSProp(learning_rate=learning_rate)
        learning_rate = rms_prop.learning_rate
        step_rule = CompositeRule([clipping, rms_prop])
    elif algorithm == 'momentum':
        sgd_momentum = Momentum(learning_rate=learning_rate, momentum=momentum)
        learning_rate = sgd_momentum.learning_rate
        step_rule = CompositeRule([clipping, sgd_momentum])
    elif algorithm == 'sgd':
        sgd = Scale(learning_rate=learning_rate)
        learning_rate = sgd.learning_rate
        step_rule = CompositeRule([clipping, sgd])
    else:
        raise NotImplementedError
    algorithm = GradientDescent(step_rule=step_rule,
                                cost=cost_train,
                                parameters=cg_train.parameters)
    # theano_func_kwargs={"mode": theano.compile.MonitorMode(post_func=detect_nan)})

    # Carry recurrent states across batches as part of each training update.
    algorithm.add_updates(init_updates)

    def cond_number(x):
        # Condition number of a matrix: ratio of largest to smallest
        # singular value.
        _, _, sing_vals = T.nlinalg.svd(x, True, True)
        sing_mags = abs(sing_vals)
        return T.max(sing_mags) / T.min(sing_mags)

    def rms(x):
        # Root-mean-square of all elements.
        return (x * x).mean().sqrt()

    # Diagnostic quantities ("why does it explode?") per shared init state
    # and per parameter.
    # NOTE(review): p.get_value().shape is a tuple, so 'shape == 2' is
    # always False and cond_number is never actually collected;
    # 'len(p.get_value().shape) == 2' was probably intended.  The local
    # `v` is also unused.
    whysplode_cond = []
    whysplode_rms = []
    for i, p in enumerate(init_updates):
        v = p.get_value()
        if p.get_value().shape == 2:
            whysplode_cond.append(
                cond_number(p).copy(
                    'ini%d:%s_cond(%s)' %
                    (i, p.name, "x".join(map(str,
                                             p.get_value().shape)))))
        whysplode_rms.append(
            rms(p).copy('ini%d:%s_rms(%s)' %
                        (i, p.name, "x".join(map(str,
                                                 p.get_value().shape)))))
    for i, p in enumerate(cg_train.parameters):
        v = p.get_value()
        if p.get_value().shape == 2:
            whysplode_cond.append(
                cond_number(p).copy(
                    'ini%d:%s_cond(%s)' %
                    (i, p.name, "x".join(map(str,
                                             p.get_value().shape)))))
        whysplode_rms.append(
            rms(p).copy('ini%d:%s_rms(%s)' %
                        (i, p.name, "x".join(map(str,
                                                 p.get_value().shape)))))

    observed_vars = [
        cost_train, cost, bpc, perp, learning_rate,
        aggregation.mean(
            algorithm.total_gradient_norm).copy("gradient_norm_mean")
    ]  # + whysplode_rms

    # Also monitor parameter and gradient norms (Python 2 dict iteration).
    parameters = model.get_parameter_dict()
    for name, param in parameters.iteritems():
        observed_vars.append(param.norm(2).copy(name=name + "_norm"))
        observed_vars.append(
            algorithm.gradients[param].norm(2).copy(name=name + "_grad_norm"))

    train_monitor = TrainingDataMonitoring(variables=observed_vars,
                                           prefix="train",
                                           after_epoch=True)

    # Clone the recurrent initial-state shared variables for validation so
    # evaluating on dev does not clobber the training-time carried states.
    dev_inits = [p.clone() for p in init_updates]
    cg_dev = ComputationGraph([cost, bpc, perp] +
                              init_updates.values()).replace(
                                  zip(init_updates.keys(), dev_inits))
    dev_cost, dev_bpc, dev_perp = cg_dev.outputs[:3]
    dev_init_updates = OrderedDict(zip(dev_inits, cg_dev.outputs[3:]))

    dev_monitor = DataStreamMonitoring(variables=[dev_cost, dev_bpc, dev_perp],
                                       data_stream=valid_stream,
                                       prefix="dev",
                                       updates=dev_init_updates)

    # noone does this
    # Optionally warm-start parameters from a saved .npz; names are matched
    # after stripping the leading brick-path component, and only
    # shape-compatible parameters are loaded.
    if 'load_path' in kwargs:
        with open(kwargs['load_path']) as f:
            loaded = np.load(f)
            model = Model(cost_train)
            params_dicts = model.get_parameter_dict()
            params_names = params_dicts.keys()
            for param_name in params_names:
                param = params_dicts[param_name]
                # '/f_6_.W' --> 'f_6_.W'
                slash_index = param_name.find('/')
                param_name = param_name[slash_index + 1:]
                if param.get_value().shape == loaded[param_name].shape:
                    print 'Found: ' + param_name
                    param.set_value(loaded[param_name])
                else:
                    print 'Not found: ' + param_name

    extensions = []
    extensions.extend(
        [FinishAfter(after_n_epochs=epochs), train_monitor, dev_monitor])
    if test_cost:
        # Same shared-state cloning trick as for dev, but over the test set.
        test_inits = [p.clone() for p in init_updates]
        cg_test = ComputationGraph([cost, bpc, perp] +
                                   init_updates.values()).replace(
                                       zip(init_updates.keys(), test_inits))
        test_cost, test_bpc, test_perp = cg_test.outputs[:3]
        test_init_updates = OrderedDict(zip(test_inits, cg_test.outputs[3:]))

        test_monitor = DataStreamMonitoring(
            variables=[test_cost, test_bpc, test_perp],
            data_stream=test_stream,
            prefix="test",
            updates=test_init_updates)
        extensions.extend([test_monitor])

    # Log to <experiment_path>/log.txt in addition to the default handlers.
    if not os.path.exists(experiment_path):
        os.makedirs(experiment_path)
    log_path = os.path.join(experiment_path, 'log.txt')
    fh = logging.FileHandler(filename=log_path)
    fh.setLevel(logging.DEBUG)
    logger.addHandler(fh)

    extensions.append(
        SaveParams('dev_cost', model, experiment_path, every_n_epochs=1))
    extensions.append(SaveLog(every_n_epochs=1))
    extensions.append(ProgressBar())
    extensions.append(Printing())

    class RollsExtension(TrainingExtension):
        """ rolls the cell and state activations between epochs so that first batch gets correct initial activations """
        def __init__(self, shvars):
            self.shvars = shvars

        def before_epoch(self):
            for v in self.shvars:
                v.set_value(np.roll(v.get_value(), 1, 0))

    extensions.append(
        RollsExtension(init_updates.keys() + dev_init_updates.keys() +
                       (test_init_updates.keys() if test_cost else [])))

    class LearningRateSchedule(TrainingExtension):
        """ Lets you set a number to divide learning rate by each epoch + when to start doing that """
        def __init__(self):
            self.epoch_number = 0

        def after_epoch(self):
            self.epoch_number += 1
            if self.epoch_number > decrease_lr_after_epoch:
                # learning_rate is the step rule's shared variable here.
                learning_rate.set_value(learning_rate.get_value() / lr_decay)

    if bool(lr_decay) != bool(decrease_lr_after_epoch):
        raise ValueError(
            'Need to define both lr_decay and decrease_lr_after_epoch')
    if lr_decay and decrease_lr_after_epoch:
        extensions.append(LearningRateSchedule())

    main_loop = MainLoop(model=model,
                         data_stream=train_stream,
                         algorithm=algorithm,
                         extensions=extensions)
    t1 = time.time()
    print "Building time: %f" % (t1 - t0)

    main_loop.run()
    print "Execution time: %f" % (time.time() - t1)
Example #2
0
def train(step_rule, layer_size, epochs, seed, experiment_path, initialization,
          weight_noise, to_watch, patience, z_prob, z_prob_states,
          z_prob_cells, drop_igates, ogates_zoneout, batch_size, stoch_depth,
          share_mask, gaussian_drop, rnn_type, num_layers, norm_cost_coeff,
          penalty, seq_len, input_drop, **kwargs):
    """Train a character-level Penn Treebank language model with zoneout RNNs.

    Builds a single recurrent layer (Zoneout LSTM / GRU / simple RNN,
    selected by `rnn_type`) over 50-dimensional one-hot character inputs,
    applies zoneout masks from the data stream, optionally adds a norm
    stabilizer penalty and weight noise, and runs a Blocks ``MainLoop``.

    Python 2 / Theano / Blocks / Fuel code; relies on module-level imports
    and helpers (np, numpy, theano, T, floatX, logger, get_ptb_stream,
    Zoneout* bricks) not visible in this chunk.
    """

    print '.. CharPTB experiment'
    print '.. arguments:', ' '.join(sys.argv)
    t0 = time.time()

    def numpy_rng(random_seed=None):
        # Seeded numpy RNG factory with a fixed fallback seed.
        if random_seed == None:
            random_seed = 1223
        return numpy.random.RandomState(random_seed)

    ###########################################
    #
    # MAKE STREAMS
    #
    ###########################################
    rng = np.random.RandomState(seed)
    # NOTE(review): stream_args is never used, and pool_size /
    # maximum_frames / pretrain_alignment / uniform_alignment /
    # window_features are not defined in this function — this raises
    # NameError unless they exist at module level; looks like a leftover
    # from a speech-experiment template.
    stream_args = dict(rng=rng,
                       pool_size=pool_size,
                       maximum_frames=maximum_frames,
                       pretrain_alignment=pretrain_alignment,
                       uniform_alignment=uniform_alignment,
                       window_features=window_features)
    if share_mask:
        z_prob_cells = z_prob
        # we don't want to actually use these masks, so this is to debug
        z_prob_states = None

    print '.. initializing iterators'

    # NOTE(review): z_prob_igates is not a parameter of this function
    # (the signature has drop_igates) — NameError unless defined at module
    # level; drop_igates was probably intended.
    train_stream = get_ptb_stream('train', batch_size, seq_len, z_prob_states,
                                  z_prob_cells, z_prob_igates, layer_size,
                                  False)
    train_stream_evaluation = get_ptb_stream('train', batch_size, seq_len,
                                             z_prob_states, z_prob_cells,
                                             z_prob_igates, layer_size, True)
    dev_stream = get_ptb_stream('valid', batch_size, seq_len, z_prob_states,
                                z_prob_cells, z_prob_igates, layer_size, True)

    # Pull one batch up front for the Theano test values below (Python 2
    # iterator protocol: .next()).
    data = train_stream.get_epoch_iterator(as_dict=True).next()

    ###########################################
    #
    # BUILD MODEL
    #
    ###########################################

    print '.. building model'

    x = T.tensor3('features', dtype=floatX)
    # Next-character prediction: inputs are steps [0, T-1), targets are the
    # same sequence shifted by one step.
    x, y = x[:-1], x[1:]
    drops_states = T.tensor3('drops_states')
    drops_cells = T.tensor3('drops_cells')
    drops_igates = T.tensor3('drops_igates')

    x.tag.test_value = data['features']
    drops_states.tag.test_value = data['drops_states']
    drops_cells.tag.test_value = data['drops_cells']
    drops_igates.tag.test_value = data['drops_igates']

    if initialization == 'glorot':
        weights_init = NormalizedInitialization()
    elif initialization == 'uniform':
        weights_init = Uniform(width=.2)
    elif initialization == 'ortho':
        weights_init = OrthogonalInitialization()
    else:
        raise ValueError('No such initialization')

    # One recurrent layer over a 50-way character alphabet; the input
    # projection is 4x/3x wide for LSTM/GRU gates.
    if rnn_type.lower() == 'lstm':
        in_to_hid = Linear(50,
                           layer_size * 4,
                           name='in_to_hid',
                           weights_init=weights_init,
                           biases_init=Constant(0.0))
        recurrent_layer = ZoneoutLSTM(dim=layer_size,
                                      weights_init=weights_init,
                                      activation=Tanh(),
                                      model_type=6,
                                      name='rnn',
                                      ogates_zoneout=ogates_zoneout)
    elif rnn_type.lower() == 'gru':
        in_to_hid = Linear(50,
                           layer_size * 3,
                           name='in_to_hid',
                           weights_init=weights_init,
                           biases_init=Constant(0.0))
        recurrent_layer = ZoneoutGRU(dim=layer_size,
                                     weights_init=weights_init,
                                     activation=Tanh(),
                                     name='rnn')
    elif rnn_type.lower() == 'srnn':  #FIXME!!! make ReLU
        in_to_hid = Linear(50,
                           layer_size,
                           name='in_to_hid',
                           weights_init=weights_init,
                           biases_init=Constant(0.0))
        recurrent_layer = ZoneoutSimpleRecurrent(dim=layer_size,
                                                 weights_init=weights_init,
                                                 activation=Rectifier(),
                                                 name='rnn')
    else:
        raise NotImplementedError

    hid_to_out = Linear(layer_size,
                        50,
                        name='hid_to_out',
                        weights_init=weights_init,
                        biases_init=Constant(0.0))

    in_to_hid.initialize()
    recurrent_layer.initialize()
    hid_to_out.initialize()

    h = in_to_hid.apply(x)

    # LSTM apply returns (states, cells); only the states feed the output.
    if rnn_type.lower() == 'lstm':
        yh = recurrent_layer.apply(h, drops_states, drops_cells,
                                   drops_igates)[0]
    else:
        yh = recurrent_layer.apply(h, drops_states, drops_cells, drops_igates)

    y_hat_pre_softmax = hid_to_out.apply(yh)
    shape_ = y_hat_pre_softmax.shape

    # y_hat = Softmax().apply(
    #     y_hat_pre_softmax.reshape((-1, shape_[-1])))# .reshape(shape_)

    ####################

    ###########################################
    #
    # SET UP COSTS AND MONITORS
    #
    ###########################################

    def crossentropy_lastaxes(yhat, y):
        # for sequence of distributions/targets
        return -(y * T.log(yhat)).sum(axis=yhat.ndim - 1)

    def softmax_lastaxis(x):
        # for sequence of distributions
        return T.nnet.softmax(x.reshape((-1, x.shape[-1]))).reshape(x.shape)

    yhat = softmax_lastaxis(y_hat_pre_softmax)
    cross_entropies = crossentropy_lastaxes(yhat, y)
    cross_entropy = cross_entropies.mean().copy(name="cross_entropy")
    cost = cross_entropy.copy(name="cost")

    batch_cost = cost.copy(name='batch_cost')
    nll_cost = cost.copy(name='nll_cost')
    # NOTE(review): the name 'bpr' looks like a typo for 'bpc' (bits/char).
    bpc = (nll_cost / np.log(2.0)).copy(name='bpr')

    #nll_cost = aggregation.mean(batch_cost, batch_size).copy(name='nll_cost')

    cost_monitor = aggregation.mean(
        batch_cost, batch_size).copy(name='sequence_cost_monitor')
    cost_per_character = aggregation.mean(
        batch_cost, (seq_len - 1) * batch_size).copy(name='character_cost')
    cost_train = cost.copy(name='train_batch_cost')
    cost_train_monitor = cost_monitor.copy('train_batch_cost_monitor')
    cg_train = ComputationGraph([cost_train, cost_train_monitor])

    ###########################################
    #
    # NORM STABILIZER
    #
    ###########################################
    norm_cost = 0.

    def _magnitude(x, axis=-1):
        # L2 norm along `axis`, floored at the dtype's tiny value so the
        # sqrt gradient stays finite.
        return T.sqrt(
            T.maximum(T.sqr(x).sum(axis=axis),
                      numpy.finfo(x.dtype).tiny))

    # Norm stabilizer (penalizes step-to-step changes in activation norms)
    # on either the memory cells or the hidden states.
    if penalty == 'cells':
        assert VariableFilter(roles=[MEMORY_CELL])(cg_train.variables)
        for cell in VariableFilter(roles=[MEMORY_CELL])(cg_train.variables):
            norms = _magnitude(cell)
            norm_cost += T.mean(
                T.sum((norms[1:] - norms[:-1])**2, axis=0) / (seq_len - 1))

    elif penalty == 'hids':
        assert 'rnn_apply_states' in [
            o.name for o in VariableFilter(roles=[OUTPUT])(cg_train.variables)
        ]
        for output in VariableFilter(roles=[OUTPUT])(cg_train.variables):
            if output.name == 'rnn_apply_states':
                norms = _magnitude(output)
                norm_cost += T.mean(
                    T.sum((norms[1:] - norms[:-1])**2, axis=0) / (seq_len - 1))

    # NOTE(review): if penalty is neither 'cells' nor 'hids', norm_cost is
    # still the Python float 0. and this attribute assignment raises
    # AttributeError.
    norm_cost.name = 'norm_cost'
    #cost_valid = cost_train
    cost_train += norm_cost_coeff * norm_cost
    cost_train = cost_train.copy('cost_train')

    cg_train = ComputationGraph([cost_train,
                                 cost_train_monitor])  #, norm_cost])

    ###########################################
    #
    # WEIGHT NOISE
    #
    ###########################################

    if weight_noise > 0:
        weights = VariableFilter(roles=[WEIGHT])(cg_train.variables)
        cg_train = apply_noise(cg_train, weights, weight_noise)
        cost_train = cg_train.outputs[0].copy(name='cost_train')
        cost_train_monitor = cg_train.outputs[1].copy(
            'train_batch_cost_monitor')

    ###########################################
    #
    # MAKE MODEL
    #
    ###########################################

    model = Model(cost_train)
    train_cost_per_character = aggregation.mean(
        cost_train_monitor,
        (seq_len - 1) * batch_size).copy(name='train_character_cost')

    algorithm = GradientDescent(step_rule=step_rule,
                                cost=cost_train,
                                parameters=cg_train.parameters)

    observed_vars = [
        cost_train, cost_train_monitor, train_cost_per_character,
        aggregation.mean(algorithm.total_gradient_norm)
    ]
    train_monitor = TrainingDataMonitoring(variables=observed_vars,
                                           prefix="train",
                                           after_epoch=True)

    dev_monitor = DataStreamMonitoring(variables=[nll_cost, bpc],
                                       data_stream=dev_stream,
                                       prefix="dev")

    extensions = []

    ###########################################
    #
    # LOADING PRETRAINED MODELS (Mohammad Pezeshki)
    #
    ###########################################
    # Optionally warm-start parameters from a saved .npz; names are matched
    # after stripping the leading brick-path component, and only
    # shape-compatible parameters are loaded.
    if 'load_path' in kwargs:
        with open(kwargs['load_path']) as f:
            loaded = np.load(f)
            model = Model(cost_train)
            params_dicts = model.get_parameter_dict()
            params_names = params_dicts.keys()
            for param_name in params_names:
                param = params_dicts[param_name]
                # '/f_6_.W' --> 'f_6_.W'
                slash_index = param_name.find('/')
                param_name = param_name[slash_index + 1:]
                if param.get_value().shape == loaded[param_name].shape:
                    print 'Found: ' + param_name
                    param.set_value(loaded[param_name])
                else:
                    print 'Not found: ' + param_name

    ###########################################
    #
    # MOAR EXTENSIONS
    #
    ###########################################
    extensions.extend(
        [FinishAfter(after_n_epochs=epochs), train_monitor, dev_monitor])
    #train_ctc_monitor,
    #dev_ctc_monitor])

    # NOTE(review): neither test_cost nor test_stream is defined in this
    # function (they belong to the word-level variant) — this branch raises
    # NameError unless they exist at module level.
    if test_cost:
        test_monitor = DataStreamMonitoring(
            variables=[cost_monitor, cost_per_character],
            data_stream=test_stream,
            prefix="test")
        extensions.append(test_monitor)

    # Log to <experiment_path>/log.txt in addition to the default handlers.
    if not os.path.exists(experiment_path):
        os.makedirs(experiment_path)
    log_path = os.path.join(experiment_path, 'log.txt')
    fh = logging.FileHandler(filename=log_path)
    fh.setLevel(logging.DEBUG)
    logger.addHandler(fh)

    extensions.append(
        SaveParams('dev_nll_cost', model, experiment_path, every_n_epochs=1))
    extensions.append(SaveLog(every_n_epochs=1))
    extensions.append(ProgressBar())
    extensions.append(Printing())

    ###########################################
    #
    # MAIN LOOP
    #
    ###########################################
    main_loop = MainLoop(model=model,
                         data_stream=train_stream,
                         algorithm=algorithm,
                         extensions=extensions)
    t1 = time.time()
    print "Building time: %f" % (t1 - t0)

    main_loop.run()
    print "Execution time: %f" % (time.time() - t1)
def train(step_rule, state_dim, epochs, seed, experiment_path, initialization,
          to_watch, patience, static_mask, batch_size, rnn_type, num_layers,
          augment, seq_len, drop_prob, drop_prob_states, drop_prob_cells,
          drop_prob_igates, ogates_zoneout, stoch_depth, share_mask,
          gaussian_drop, weight_noise, norm_cost_coeff, penalty, input_drop,
          **kwargs):

    print '.. cPTB experiment'
    print '.. arguments:', ' '.join(sys.argv)
    t0 = time.time()

    def numpy_rng(random_seed=None):
        if random_seed == None:
            random_seed = 1223
        return numpy.random.RandomState(random_seed)

    ###########################################
    #
    # MAKE DATA STREAMS
    #
    ###########################################
    rng = np.random.RandomState(seed)

    if share_mask:
        drop_prob_cells = drop_prob
        # we don't want to actually use these masks, so this is to debug
        drop_prob_states = None

    print '.. initializing iterators'

    if static_mask:
        train_stream = get_static_mask_ptb_stream('train',
                                                  batch_size,
                                                  seq_len,
                                                  drop_prob_states,
                                                  drop_prob_cells,
                                                  drop_prob_igates,
                                                  state_dim,
                                                  False,
                                                  augment=augment)
        train_stream_evaluation = get_static_mask_ptb_stream('train',
                                                             batch_size,
                                                             seq_len,
                                                             drop_prob_states,
                                                             drop_prob_cells,
                                                             drop_prob_igates,
                                                             state_dim,
                                                             True,
                                                             augment=augment)
        dev_stream = get_static_mask_ptb_stream('valid',
                                                batch_size,
                                                seq_len,
                                                drop_prob_states,
                                                drop_prob_cells,
                                                drop_prob_igates,
                                                state_dim,
                                                True,
                                                augment=augment)
    else:
        train_stream = get_ptb_stream('train',
                                      batch_size,
                                      seq_len,
                                      drop_prob_states,
                                      drop_prob_cells,
                                      drop_prob_igates,
                                      state_dim,
                                      False,
                                      augment=augment)
        train_stream_evaluation = get_ptb_stream('train',
                                                 batch_size,
                                                 seq_len,
                                                 drop_prob_states,
                                                 drop_prob_cells,
                                                 drop_prob_igates,
                                                 state_dim,
                                                 True,
                                                 augment=augment)
        dev_stream = get_ptb_stream('valid',
                                    batch_size,
                                    seq_len,
                                    drop_prob_states,
                                    drop_prob_cells,
                                    drop_prob_igates,
                                    state_dim,
                                    True,
                                    augment=augment)

    data = train_stream.get_epoch_iterator(as_dict=True).next()
    #import ipdb; ipdb.set_trace()

    ###########################################
    #
    # BUILD MODEL
    #
    ###########################################

    print '.. building model'

    x = T.tensor3('features', dtype=floatX)
    x, y = x[:-1], x[1:]
    drops_states = T.tensor3('drops_states')
    drops_cells = T.tensor3('drops_cells')
    drops_igates = T.tensor3('drops_igates')

    x.tag.test_value = data['features']
    #y.tag.test_value = data['outputs']
    drops_states.tag.test_value = data['drops_states']
    drops_cells.tag.test_value = data['drops_cells']
    drops_igates.tag.test_value = data['drops_igates']

    if initialization == 'glorot':
        weights_init = NormalizedInitialization()
    elif initialization == 'uniform':
        weights_init = Uniform(width=.2)
    elif initialization == 'ortho':
        weights_init = OrthogonalInitialization()
    else:
        raise ValueError('No such initialization')

    if rnn_type.lower() == 'lstm':
        in_to_hid = Linear(50,
                           state_dim * 4,
                           name='in_to_hid',
                           weights_init=weights_init,
                           biases_init=Constant(0.0))
        recurrent_layer = ZoneoutLSTM(dim=state_dim,
                                      weights_init=weights_init,
                                      activation=Tanh(),
                                      model_type=6,
                                      name='rnn',
                                      ogates_zoneout=ogates_zoneout)
    elif rnn_type.lower() == 'gru':
        in_to_hid = Linear(50,
                           state_dim * 3,
                           name='in_to_hid',
                           weights_init=weights_init,
                           biases_init=Constant(0.0))
        recurrent_layer = ZoneoutGRU(dim=state_dim,
                                     weights_init=weights_init,
                                     activation=Tanh(),
                                     name='rnn')
    elif rnn_type.lower() == 'srnn':
        in_to_hid = Linear(50,
                           state_dim,
                           name='in_to_hid',
                           weights_init=weights_init,
                           biases_init=Constant(0.0))
        recurrent_layer = ZoneoutSimpleRecurrent(dim=state_dim,
                                                 weights_init=weights_init,
                                                 activation=Rectifier(),
                                                 name='rnn')
    else:
        raise NotImplementedError

    hid_to_out = Linear(state_dim,
                        50,
                        name='hid_to_out',
                        weights_init=weights_init,
                        biases_init=Constant(0.0))

    in_to_hid.initialize()
    recurrent_layer.initialize()
    hid_to_out.initialize()

    h = in_to_hid.apply(x)

    if rnn_type.lower() == 'lstm':
        yh = recurrent_layer.apply(h, drops_states, drops_cells,
                                   drops_igates)[0]
    else:
        yh = recurrent_layer.apply(h, drops_states, drops_cells, drops_igates)

    y_hat_pre_softmax = hid_to_out.apply(yh)
    shape_ = y_hat_pre_softmax.shape

    # y_hat = Softmax().apply(
    #     y_hat_pre_softmax.reshape((-1, shape_[-1])))# .reshape(shape_)

    ###########################################
    #
    # SET UP COSTS, MONITORS, and REGULARIZATION
    #
    ###########################################

    # cost = CategoricalCrossEntropy().apply(y.flatten().astype('int64'), y_hat)

    def crossentropy_lastaxes(yhat, y):
        # for sequence of distributions/targets
        return -(y * T.log(yhat)).sum(axis=yhat.ndim - 1)

    def softmax_lastaxis(x):
        # for sequence of distributions
        return T.nnet.softmax(x.reshape((-1, x.shape[-1]))).reshape(x.shape)

    yhat = softmax_lastaxis(y_hat_pre_softmax)
    cross_entropies = crossentropy_lastaxes(yhat, y)
    cross_entropy = cross_entropies.mean().copy(name="cross_entropy")
    cost = cross_entropy.copy(name="cost")

    batch_cost = cost.copy(name='batch_cost')
    nll_cost = cost.copy(name='nll_cost')
    bpc = (nll_cost / np.log(2.0)).copy(name='bpr')

    #nll_cost = aggregation.mean(batch_cost, batch_size).copy(name='nll_cost')

    cost_monitor = aggregation.mean(
        batch_cost, batch_size).copy(name='sequence_cost_monitor')
    cost_per_character = aggregation.mean(
        batch_cost, (seq_len - 1) * batch_size).copy(name='character_cost')
    cost_train = cost.copy(name='train_batch_cost')
    cost_train_monitor = cost_monitor.copy('train_batch_cost_monitor')
    cg_train = ComputationGraph([cost_train, cost_train_monitor])

    ##################
    # NORM STABILIZER
    ##################

    norm_cost = 0.

    def _magnitude(x, axis=-1):
        return T.sqrt(
            T.maximum(T.sqr(x).sum(axis=axis),
                      numpy.finfo(x.dtype).tiny))

    if penalty == 'cells':
        assert VariableFilter(roles=[MEMORY_CELL])(cg_train.variables)
        for cell in VariableFilter(roles=[MEMORY_CELL])(cg_train.variables):
            norms = _magnitude(cell)
            norm_cost += T.mean(
                T.sum((norms[1:] - norms[:-1])**2, axis=0) / (seq_len - 1))
            ## debugging nans stuff
            #gr = T.grad(norm_cost, cg_train.parameters, disconnected_inputs='ignore')
            #grf = theano.function([x, input_mask], gr)
            #grz = grf(x.tag.test_value, input_mask.tag.test_value)
            #params = cg_train.parameters
            #mynanz = [(pp, np.sum(gg)) for pp,gg in zip(params, grz) if np.isnan(np.sum(gg))]
            #for mm in mynanz: print mm
            ##import ipdb; ipdb.set_trace()
    elif penalty == 'hids':
        assert 'rnn_apply_states' in [
            o.name for o in VariableFilter(roles=[OUTPUT])(cg_train.variables)
        ]
        for output in VariableFilter(roles=[OUTPUT])(cg_train.variables):
            if output.name == 'rnn_apply_states':
                norms = _magnitude(output)
                norm_cost += T.mean(
                    T.sum((norms[1:] - norms[:-1])**2, axis=0) / (seq_len - 1))

    norm_cost.name = 'norm_cost'

    cost_train += norm_cost_coeff * norm_cost
    cost_train = cost_train.copy(
        'cost_train')  #should this be cost_train.outputs[0]?

    cg_train = ComputationGraph([cost_train,
                                 cost_train_monitor])  #, norm_cost])

    ##################
    # WEIGHT NOISE
    ##################

    if weight_noise > 0:
        weights = VariableFilter(roles=[WEIGHT])(cg_train.variables)
        cg_train = apply_noise(cg_train, weights, weight_noise)
        cost_train = cg_train.outputs[0].copy(name='cost_train')
        cost_train_monitor = cg_train.outputs[1].copy(
            'train_batch_cost_monitor')

    # if 'l2regularization' in kwargs:
    #     weights = VariableFilter(roles=[WEIGHT])(cg_train.variables)
    #     cost_train += kwargs['l2regularization'] * sum([
    #         (weight ** 2).sum() for weight in weights])
    #     cost_train.name = 'cost_train'
    #     cg_train = ComputationGraph(cost_train)

    model = Model(cost_train)
    train_cost_per_character = aggregation.mean(
        cost_train_monitor,
        (seq_len - 1) * batch_size).copy(name='train_character_cost')

    algorithm = GradientDescent(step_rule=step_rule,
                                cost=cost_train,
                                parameters=cg_train.parameters)

    observed_vars = [
        cost_train, cost_train_monitor, train_cost_per_character,
        aggregation.mean(algorithm.total_gradient_norm)
    ]
    # parameters = model.get_parameter_dict()
    # for name, param in parameters.iteritems():
    #     observed_vars.append(param.norm(2).copy(name=name + "_norm"))
    #     observed_vars.append(
    #         algorithm.gradients[param].norm(2).copy(name=name + "_grad_norm"))
    train_monitor = TrainingDataMonitoring(variables=observed_vars,
                                           prefix="train",
                                           after_epoch=True)

    dev_monitor = DataStreamMonitoring(variables=[nll_cost, bpc],
                                       data_stream=dev_stream,
                                       prefix="dev")

    extensions = []
    if 'load_path' in kwargs:
        with open(kwargs['load_path']) as f:
            loaded = np.load(f)
            model = Model(cost_train)
            params_dicts = model.get_parameter_dict()
            params_names = params_dicts.keys()
            for param_name in params_names:
                param = params_dicts[param_name]
                # '/f_6_.W' --> 'f_6_.W'
                slash_index = param_name.find('/')
                param_name = param_name[slash_index + 1:]
                if param.get_value().shape == loaded[param_name].shape:
                    print 'Found: ' + param_name
                    param.set_value(loaded[param_name])
                else:
                    print 'Not found: ' + param_name

    extensions.extend(
        [FinishAfter(after_n_epochs=epochs), train_monitor, dev_monitor])

    if not os.path.exists(experiment_path):
        os.makedirs(experiment_path)
    log_path = os.path.join(experiment_path, 'log.txt')
    fh = logging.FileHandler(filename=log_path)
    fh.setLevel(logging.DEBUG)
    logger.addHandler(fh)

    extensions.append(
        SaveParams('dev_nll_cost', model, experiment_path, every_n_epochs=1))
    extensions.append(SaveLog(every_n_epochs=1))
    extensions.append(ProgressBar())
    extensions.append(Printing())

    ###########################################
    #
    # MAIN LOOOOOOOOOOOP
    #
    ###########################################

    main_loop = MainLoop(model=model,
                         data_stream=train_stream,
                         algorithm=algorithm,
                         extensions=extensions)
    t1 = time.time()
    print "Building time: %f" % (t1 - t0)
    # if write_predictions:
    #     with open('predicted.txt', 'w') as f_pred:
    #         with open('targets.txt', 'w') as f_targets:
    #             evaluator = CTCEvaluator(
    #                 eol_symbol, x, input_mask, y_hat, phoneme_dict, black_list)
    #             evaluator.evaluate(dev_stream, file_pred=f_pred,
    #                                file_targets=f_targets)
    #     return
    main_loop.run()
    print "Execution time: %f" % (time.time() - t1)
Beispiel #4
0
def train(step_rule, input_dim, state_dim, label_dim, layers, epochs, seed,
          pretrain_alignment, uniform_alignment, dropout, beam_search,
          test_cost, experiment_path, window_features, features, pool_size,
          maximum_frames, initialization, weight_noise, to_watch, patience,
          plot, write_predictions, static_mask, drop_prob, drop_prob_states,
          drop_prob_cells, drop_prob_igates, ogates_zoneout, batch_size,
          stoch_depth, share_mask, gaussian_drop, rnn_type, num_layers,
          norm_cost_coeff, penalty, seq_len, input_drop, augment, **kwargs):

    print '.. PTB experiment'
    print '.. arguments:', ' '.join(sys.argv)
    t0 = time.time()

    ###########################################
    #
    # LOAD DATA
    #
    ###########################################

    def numpy_rng(random_seed=None):
        if random_seed == None:
            random_seed = 1223
        return numpy.random.RandomState(random_seed)

    #from utilities import onehot, unhot, vec2chars
    # from http://www.iro.umontreal.ca/~memisevr/code/logreg.py
    #def onehot(x,numclasses=None):
    #""" Convert integer encoding for class-labels (starting with 0 !)
    #to one-hot encoding.
    #The output is an array who's shape is the shape of the input array plus
    #an extra dimension, containing the 'one-hot'-encoded labels.
    #"""
    #if x.shape==():
    #x = x[None]
    #if numclasses is None:
    #numclasses = x.max() + 1
    #result = numpy.zeros(list(x.shape) + [numclasses], dtype="int")
    #z = numpy.zeros(x.shape, dtype="int")
    #for c in range(numclasses):
    #z *= 0
    #z[numpy.where(x==c)] = 1
    #result[...,c] += z
    #return result.astype(theano.config.floatX)

    #framelen = 1
    #50 = 50
    ##data = np.load(os.path.join(os.environ['FUEL_DATA_PATH'], 'PennTreebankCorpus/char_level_penntree.npz'))#pentree_char_and_word.npz')
    #data = np.load('char_level_penntree.npz')
    #trainset = data['train']
    #validset = data['valid']

    #allletters = " etanoisrhludcmfpkgybw<>\nvN.'xj$-qz&0193#285\\764/*"
    #dictionary = dict(zip(list(set(allletters)), range(50)))
    #invdict = {v: k for k, v in dictionary.items()}

    #numtrain = len(trainset) / seq_len * seq_len
    #numvalid = len(validset) / seq_len * seq_len
    #trainset = trainset[:numtrain]
    #validset = validset[:numvalid]
    ##if testing:
    ##    train_features_numpy = train_features_numpy[:32 * 5]
    ##    valid_features_numpy = valid_features_numpy[:100]
    #train_targets = trainset.reshape(-1, seq_len*framelen)[:,1:]
    #valid_targets = validset.reshape(-1, seq_len*framelen)[:,1:]
    ## still only 2d (b, t*n)
    #train_features_numpy = onehot(trainset).reshape(-1, 50*seq_len*framelen)[:,:-50]
    #valid_features_numpy = onehot(validset).reshape(-1, 50*seq_len*framelen)[:,:-50]
    #del trainset, validset
    #data_loaded = True
    #print '... done'
    #test_value = train_features_numpy[:32]

    ####################

    ###########################################
    #
    # MAKE STREAMS
    #
    ###########################################
    rng = np.random.RandomState(seed)
    stream_args = dict(rng=rng,
                       pool_size=pool_size,
                       maximum_frames=maximum_frames,
                       pretrain_alignment=pretrain_alignment,
                       uniform_alignment=uniform_alignment,
                       window_features=window_features)
    if share_mask:
        drop_prob_cells = drop_prob
        # we don't want to actually use these masks, so this is to debug
        drop_prob_states = None

    # the threes in here are because the number of layers is hardcoded to 3 atm. NIPS!
    print '.. initializing iterators'

    # train_stream, valid_stream = get_seq_mnist_streams(
    #    h_dim, batch_size, update_prob)
    if static_mask:
        train_stream = get_static_mask_ptb_stream('train',
                                                  batch_size,
                                                  seq_len,
                                                  drop_prob_states,
                                                  drop_prob_cells,
                                                  drop_prob_igates,
                                                  state_dim,
                                                  False,
                                                  augment=augment)
        train_stream_evaluation = get_static_mask_ptb_stream('train',
                                                             batch_size,
                                                             seq_len,
                                                             drop_prob_states,
                                                             drop_prob_cells,
                                                             drop_prob_igates,
                                                             state_dim,
                                                             True,
                                                             augment=augment)
        dev_stream = get_static_mask_ptb_stream('valid',
                                                batch_size,
                                                seq_len,
                                                drop_prob_states,
                                                drop_prob_cells,
                                                drop_prob_igates,
                                                state_dim,
                                                True,
                                                augment=augment)
    else:
        train_stream = get_ptb_stream('train',
                                      batch_size,
                                      seq_len,
                                      drop_prob_states,
                                      drop_prob_cells,
                                      drop_prob_igates,
                                      state_dim,
                                      False,
                                      augment=augment)
        train_stream_evaluation = get_ptb_stream('train',
                                                 batch_size,
                                                 seq_len,
                                                 drop_prob_states,
                                                 drop_prob_cells,
                                                 drop_prob_igates,
                                                 state_dim,
                                                 True,
                                                 augment=augment)
        dev_stream = get_ptb_stream('valid',
                                    batch_size,
                                    seq_len,
                                    drop_prob_states,
                                    drop_prob_cells,
                                    drop_prob_igates,
                                    state_dim,
                                    True,
                                    augment=augment)

    #train_dataset = Timit('train', features=features)
    # assert (train_features_numpy[:,-50:].sum(axis=-2)==1).all()
    #train_features_numpy = train_features_numpy.reshape(-1, seq_len-1, 50)#BTN for shuffled dataset?
    #train_dataset = IndexableDataset(indexables=OrderedDict(
    #[('features', train_features_numpy),
    #('outputs', train_targets)]))

    #train_stream = construct_stream_np(train_dataset, state_dim, batch_size, len(train_targets),
    #drop_prob_states, drop_prob_cells, drop_prob_igates,
    #num_layers=num_layers,
    #is_for_test=False, stoch_depth=stoch_depth, share_mask=share_mask,
    #gaussian_drop=gaussian_drop, input_drop=input_drop, **stream_args)
    ##dev_dataset = Timit('dev', features=features)
    #valid_features_numpy = valid_features_numpy.reshape(-1, seq_len-1,  50)
    #dev_dataset = IndexableDataset(indexables=OrderedDict(
    #[('features', valid_features_numpy),
    #('outputs', valid_targets)]))
    #dev_stream = construct_stream_np(dev_dataset, state_dim, batch_size, len(valid_targets),
    #drop_prob_states, drop_prob_cells, drop_prob_igates,
    #num_layers=num_layers,
    #is_for_test=True, stoch_depth=stoch_depth, share_mask=share_mask,
    #gaussian_drop=gaussian_drop, input_drop=input_drop, **stream_args)
    ##test_dataset = Timit('test', features=features)
    ##test_stream = construct_stream(test_dataset, state_dim, drop_prob_states, drop_prob_cells, drop_prob_igates,  3,
    ##                               is_for_test=True, stoch_depth=stoch_depth, share_mask=share_mask,
    ##                               gaussian_drop=gaussian_drop, **stream_args)
    data = train_stream.get_epoch_iterator(as_dict=True).next()
    #import ipdb; ipdb.set_trace()

    #phone_dict = train_dataset.get_phoneme_dict()
    #phoneme_dict = {k: phone_to_phoneme_dict[v]
    #                if v in phone_to_phoneme_dict else v
    #                for k, v in phone_dict.iteritems()}
    #ind_to_phoneme = {v: k for k, v in phoneme_dict.iteritems()}
    #eol_symbol = ind_to_phoneme['<STOP>']

    ####################

    ###########################################
    #
    # BUILD MODEL
    #
    ###########################################

    print '.. building model'

    x = T.tensor3('features', dtype=floatX)
    x, y = x[:-1], x[1:]  #T.lmatrix('outputs')# phonemes')
    drops_states = T.tensor3('drops_states')
    drops_cells = T.tensor3('drops_cells')
    drops_igates = T.tensor3('drops_igates')

    x.tag.test_value = data['features']
    #y.tag.test_value = data['outputs']
    drops_states.tag.test_value = data['drops_states']
    drops_cells.tag.test_value = data['drops_cells']
    drops_igates.tag.test_value = data['drops_igates']

    if initialization == 'glorot':
        weights_init = NormalizedInitialization()
    elif initialization == 'uniform':
        weights_init = Uniform(width=.2)
    elif initialization == 'ortho':
        weights_init = OrthogonalInitialization()
    else:
        raise ValueError('No such initialization')

    if rnn_type.lower() == 'lstm':
        in_to_hid = Linear(50,
                           state_dim * 4,
                           name='in_to_hid',
                           weights_init=weights_init,
                           biases_init=Constant(0.0))
        recurrent_layer = DropLSTM(dim=state_dim,
                                   weights_init=weights_init,
                                   activation=Tanh(),
                                   model_type=6,
                                   name='rnn',
                                   ogates_zoneout=ogates_zoneout)
    elif rnn_type.lower() == 'gru':
        in_to_hid = Linear(50,
                           state_dim * 3,
                           name='in_to_hid',
                           weights_init=weights_init,
                           biases_init=Constant(0.0))
        recurrent_layer = DropGRU(dim=state_dim,
                                  weights_init=weights_init,
                                  activation=Tanh(),
                                  name='rnn')
    elif rnn_type.lower() == 'srnn':  #FIXME!!! make ReLU
        in_to_hid = Linear(50,
                           state_dim,
                           name='in_to_hid',
                           weights_init=weights_init,
                           biases_init=Constant(0.0))
        recurrent_layer = DropSimpleRecurrent(dim=state_dim,
                                              weights_init=weights_init,
                                              activation=Rectifier(),
                                              name='rnn')
    else:
        raise NotImplementedError

    #lstm2 = DropLSTM(dim=state_dim, activation=Tanh(), model_type=6)

    #lstm3 = DropLSTM(dim=state_dim, activation=Tanh(), model_type=6)

    #encoder = DropMultiLayerEncoder(weights_init=weights_init,
    #biases_init=Constant(.0),
    #networks=[lstm1, lstm2, bidir3],
    #dims=[input_dim * window_features,
    #state_dim,
    #state_dim,
    #state_dim,
    #label_dim + 1])
    #encoder.initialize()
    #drops_states = [drops_forw_states, drops_back_states]
    #drops_cells = [drops_forw_cells, drops_back_cells]
    #drops_igates = [drops_forw_igates, drops_back_igates]
    hid_to_out = Linear(state_dim,
                        50,
                        name='hid_to_out',
                        weights_init=weights_init,
                        biases_init=Constant(0.0))

    in_to_hid.initialize()
    recurrent_layer.initialize()
    hid_to_out.initialize()

    h = in_to_hid.apply(x)

    if rnn_type.lower() == 'lstm':
        yh = recurrent_layer.apply(h, drops_states, drops_cells,
                                   drops_igates)[0]
    else:
        yh = recurrent_layer.apply(h, drops_states, drops_cells, drops_igates)

    y_hat_pre_softmax = hid_to_out.apply(yh)
    shape_ = y_hat_pre_softmax.shape

    # y_hat = Softmax().apply(
    #     y_hat_pre_softmax.reshape((-1, shape_[-1])))# .reshape(shape_)

    ####################

    ###########################################
    #
    # SET UP COSTS AND MONITORS
    #
    ###########################################

    # cost = CategoricalCrossEntropy().apply(y.flatten().astype('int64'), y_hat)

    def crossentropy_lastaxes(yhat, y):
        # for sequence of distributions/targets
        return -(y * T.log(yhat)).sum(axis=yhat.ndim - 1)

    def softmax_lastaxis(x):
        # for sequence of distributions
        return T.nnet.softmax(x.reshape((-1, x.shape[-1]))).reshape(x.shape)

    yhat = softmax_lastaxis(y_hat_pre_softmax)
    cross_entropies = crossentropy_lastaxes(yhat, y)
    cross_entropy = cross_entropies.mean().copy(name="cross_entropy")
    cost = cross_entropy.copy(name="cost")

    batch_cost = cost.copy(name='batch_cost')
    nll_cost = cost.copy(name='nll_cost')
    bpc = (nll_cost / np.log(2.0)).copy(name='bpr')

    #nll_cost = aggregation.mean(batch_cost, batch_size).copy(name='nll_cost')

    cost_monitor = aggregation.mean(
        batch_cost, batch_size).copy(name='sequence_cost_monitor')
    cost_per_character = aggregation.mean(
        batch_cost, (seq_len - 1) * batch_size).copy(name='character_cost')
    cost_train = cost.copy(name='train_batch_cost')
    cost_train_monitor = cost_monitor.copy('train_batch_cost_monitor')
    cg_train = ComputationGraph([cost_train, cost_train_monitor])

    ##################### DK ADD COST ########################
    ##################### DK ADD COST ########################
    ##################### DK ADD COST ########################
    ##################### DK ADD COST ########################
    ##################### DK ADD COST ########################
    ##################### DK ADD COST ########################
    ##################### DK ADD COST ########################
    ##################### DK ADD COST ########################
    norm_cost = 0.

    def _magnitude(x, axis=-1):
        return T.sqrt(
            T.maximum(T.sqr(x).sum(axis=axis),
                      numpy.finfo(x.dtype).tiny))

    # Norm-stabilizer penalty: penalize changes in the hidden-activation
    # norm between consecutive time steps, averaged over the minibatch.
    # Applied either to LSTM memory cells or to the RNN hidden states,
    # depending on `penalty`.
    if penalty == 'cells':
        # There must be at least one memory-cell variable in the graph.
        assert VariableFilter(roles=[MEMORY_CELL])(cg_train.variables)
        for cell in VariableFilter(roles=[MEMORY_CELL])(cg_train.variables):
            # norms: per-step activation magnitudes; assumes axis 0 is
            # time — TODO confirm against the recurrent apply outputs.
            norms = _magnitude(cell)
            # Squared successive differences summed over time, normalized
            # by the number of transitions (seq_len - 1), then averaged
            # over the batch.
            norm_cost += T.mean(
                T.sum((norms[1:] - norms[:-1])**2, axis=0) / (seq_len - 1))
            ## debugging nans stuff
            #gr = T.grad(norm_cost, cg_train.parameters, disconnected_inputs='ignore')
            #grf = theano.function([x, input_mask], gr)
            #grz = grf(x.tag.test_value, input_mask.tag.test_value)
            #params = cg_train.parameters
            #mynanz = [(pp, np.sum(gg)) for pp,gg in zip(params, grz) if np.isnan(np.sum(gg))]
            #for mm in mynanz: print mm
            ##import ipdb; ipdb.set_trace()
    elif penalty == 'hids':
        # Sanity check: the graph must expose the recurrent state output.
        assert 'rnn_apply_states' in [
            o.name for o in VariableFilter(roles=[OUTPUT])(cg_train.variables)
        ]
        for output in VariableFilter(roles=[OUTPUT])(cg_train.variables):
            if output.name == 'rnn_apply_states':
                norms = _magnitude(output)
                # Same norm-difference penalty as the 'cells' branch,
                # applied to the hidden states instead.
                norm_cost += T.mean(
                    T.sum((norms[1:] - norms[:-1])**2, axis=0) / (seq_len - 1))
                ## debugging nans stuff
                #gr = T.grad(norm_cost, cg_train.parameters, disconnected_inputs='ignore')
                #grf = theano.function([x, input_mask], gr)
                #grz = grf(x.tag.test_value, input_mask.tag.test_value)
                #params = cg_train.parameters
                #mynanz = [(pp, np.sum(gg)) for pp,gg in zip(params, grz) if np.isnan(np.sum(gg))]
                #for mm in mynanz: print mm
                ##import ipdb; ipdb.set_trace()

    norm_cost.name = 'norm_cost'
    #cost_valid = cost_train
    # Fold the (weighted) penalty into the training cost and rebuild the
    # computation graph so downstream transforms see the combined cost.
    cost_train += norm_cost_coeff * norm_cost
    cost_train = cost_train.copy(
        'cost_train')  #should this be cost_train.outputs[0]?

    ##################### DK ADD COST ########################
    ##################### DK ADD COST ########################
    ##################### DK ADD COST ########################
    ##################### DK ADD COST ########################

    # Optionally regularize by adding Gaussian noise to the weights during
    # training. apply_noise returns a transformed graph, so the cost
    # variables must be re-extracted from its outputs (output order:
    # [cost_train, cost_train_monitor], matching the graph built above).
    if weight_noise > 0:
        weights = VariableFilter(roles=[WEIGHT])(cg_train.variables)
        cg_train = apply_noise(cg_train, weights, weight_noise)
        cost_train = cg_train.outputs[0].copy(name='cost_train')
        cost_train_monitor = cg_train.outputs[1].copy(
            'train_batch_cost_monitor')

    # if 'l2regularization' in kwargs:
    #     weights = VariableFilter(roles=[WEIGHT])(cg_train.variables)
    #     cost_train += kwargs['l2regularization'] * sum([
    #         (weight ** 2).sum() for weight in weights])
    #     cost_train.name = 'cost_train'
    #     cg_train = ComputationGraph(cost_train)

    model = Model(cost_train)
    # Per-character cost: mean of the batch cost normalized by the number
    # of predicted symbols, (seq_len - 1) targets per sequence.
    train_cost_per_character = aggregation.mean(
        cost_train_monitor,
        (seq_len - 1) * batch_size).copy(name='train_character_cost')

    # NOTE(review): this rebinds the `algorithm` argument of train();
    # presumably the original value was consumed earlier when `step_rule`
    # was built — confirm, and consider a distinct local name.
    algorithm = GradientDescent(step_rule=step_rule,
                                cost=cost_train,
                                parameters=cg_train.parameters)

    # Quantities logged every epoch on the training stream: raw cost,
    # its monitoring copy, the per-character cost, and the mean total
    # gradient norm (useful for spotting exploding gradients).
    observed_vars = [
        cost_train, cost_train_monitor, train_cost_per_character,
        aggregation.mean(algorithm.total_gradient_norm)
    ]
    # parameters = model.get_parameter_dict()
    # for name, param in parameters.iteritems():
    #     observed_vars.append(param.norm(2).copy(name=name + "_norm"))
    #     observed_vars.append(
    #         algorithm.gradients[param].norm(2).copy(name=name + "_grad_norm"))
    train_monitor = TrainingDataMonitoring(variables=observed_vars,
                                           prefix="train",
                                           after_epoch=True)

    # Validation: negative log-likelihood and bits-per-character on the
    # dev stream (both defined earlier in this function).
    dev_monitor = DataStreamMonitoring(variables=[nll_cost, bpc],
                                       data_stream=dev_stream,
                                       prefix="dev")
    #train_ctc_monitor = CTCMonitoring(
    #x, input_mask,
    #drops_forw_states, drops_forw_cells, drops_forw_igates,
    #drops_back_states, drops_back_cells, drops_back_igates,
    #y_hat, eol_symbol, train_stream,
    #prefix='train', every_n_epochs=1,
    #before_training=True,
    #phoneme_dict=phoneme_dict,
    #black_list=black_list, train=True)
    #dev_ctc_monitor = CTCMonitoring(
    #x, input_mask,
    #drops_forw_states, drops_forw_cells, drops_forw_igates,
    #drops_back_states, drops_back_cells, drops_back_igates,
    #y_hat, eol_symbol, dev_stream,
    #prefix='dev', every_n_epochs=1,
    #phoneme_dict=phoneme_dict,
    #black_list=black_list)

    extensions = []
    # /u/pezeshki/speech_project/five_layer_timit/trained_params_best.npz
    # Optional warm start: copy pretrained parameter values into the model
    # by name, skipping any whose shapes do not match.
    if 'load_path' in kwargs:
        # NOTE(review): text-mode open() of an .npz works under Python 2
        # but would break under Python 3 — np.load(path) directly would be
        # safer. Left as-is.
        with open(kwargs['load_path']) as f:
            loaded = np.load(f)
            # NOTE(review): Model(cost_train) was already built above;
            # this rebinds `model` to a fresh, equivalent wrapper.
            model = Model(cost_train)
            params_dicts = model.get_parameter_dict()
            params_names = params_dicts.keys()
            for param_name in params_names:
                param = params_dicts[param_name]
                # Strip the leading brick path component so names match
                # the keys stored in the archive:
                # '/f_6_.W' --> 'f_6_.W'
                slash_index = param_name.find('/')
                param_name = param_name[slash_index + 1:]
                if param.get_value().shape == loaded[param_name].shape:
                    print 'Found: ' + param_name
                    param.set_value(loaded[param_name])
                else:
                    print 'Not found: ' + param_name

        #_evaluator = CTCEvaluator(eol_symbol, x, input_mask, y_hat,
        #phoneme_dict=phoneme_dict,
        #black_list=black_list)

        #logger.info("CTC monitoring on TEST data started")
        #value_dict = _evaluator.evaluate(test_stream, False)
        #print value_dict.items()
        #logger.info("CTC monitoring on TEST data finished")

        #logger.info("CTC monitoring on TRAIN data started")
        #value_dict = _evaluator.evaluate(train_stream, True)
        #print value_dict.items()
        #logger.info("CTC monitoring on TRAIN data finished")

        #logger.info("CTC monitoring on DEV data started")
        #value_dict = _evaluator.evaluate(dev_stream, False)
        #print value_dict.items()
        #logger.info("CTC monitoring on DEV data finished")

    # Core extensions: stop after `epochs` epochs, plus the two monitors.
    extensions.extend(
        [FinishAfter(after_n_epochs=epochs), train_monitor, dev_monitor])
    #train_ctc_monitor,
    #dev_ctc_monitor])

    if test_cost:
        # NOTE(review): `cost_monitor` and `cost_per_character` are not
        # defined anywhere in this section (the visible names are
        # `cost_train_monitor` / `train_cost_per_character`) — confirm
        # they are created earlier in the function, otherwise enabling
        # test_cost raises NameError.
        test_monitor = DataStreamMonitoring(
            variables=[cost_monitor, cost_per_character],
            data_stream=test_stream,
            prefix="test")
        extensions.append(test_monitor)

    # Mirror the run's log into <experiment_path>/log.txt in addition to
    # whatever handlers `logger` already has.
    if not os.path.exists(experiment_path):
        os.makedirs(experiment_path)
    log_path = os.path.join(experiment_path, 'log.txt')
    fh = logging.FileHandler(filename=log_path)
    fh.setLevel(logging.DEBUG)
    logger.addHandler(fh)

    # Checkpoint parameters every epoch, tracked by the dev NLL, and add
    # log persistence plus console progress reporting.
    extensions.append(
        SaveParams('dev_nll_cost', model, experiment_path, every_n_epochs=1))
    extensions.append(SaveLog(every_n_epochs=1))
    extensions.append(ProgressBar())
    extensions.append(Printing())

    # Assemble and run the Blocks training loop; timings for graph
    # construction and for the training run itself are printed separately.
    main_loop = MainLoop(model=model,
                         data_stream=train_stream,
                         algorithm=algorithm,
                         extensions=extensions)
    t1 = time.time()
    print "Building time: %f" % (t1 - t0)
    # if write_predictions:
    #     with open('predicted.txt', 'w') as f_pred:
    #         with open('targets.txt', 'w') as f_targets:
    #             evaluator = CTCEvaluator(
    #                 eol_symbol, x, input_mask, y_hat, phoneme_dict, black_list)
    #             evaluator.evaluate(dev_stream, file_pred=f_pred,
    #                                file_targets=f_targets)
    #     return
    main_loop.run()
    print "Execution time: %f" % (time.time() - t1)