Example #1
    def realize(self, values: typing.Dict[str, typing.Any], **sensitives):
        train_pwl, test_pwl = self.pwl_func()

        network = NaturalRNN.create(
            str(values['nonlinearity']), test_pwl.input_dim, int(values['hidden_size']), test_pwl.output_dim,
            input_weights=wi.OrthogonalWeightInitializer(float(values['inp_stddev']), 0),
            input_biases=wi.ZerosWeightInitializer(),
            hidden_weights=wi.SompolinskySmoothedFixedGainWeightInitializer(
                float(values['dt']), float(values['g'])),
            hidden_biases=wi.GaussianWeightInitializer(
                mean=0, vari=float(values['hidden_bias_vari']), normalize_dim=0),
            output_weights=wi.GaussianWeightInitializer(
                mean=0, vari=float(values['output_weight_vari']), normalize_dim=0),
            output_biases=wi.ZerosWeightInitializer()
        )

        trainer = tnr.GenericTrainer(
            train_pwl=train_pwl,
            test_pwl=test_pwl,
            teacher=RNNTeacher(recurrent_times=int(values['recurrent_times']), input_times=1),
            batch_size=int(sensitives['batch_size']),
            learning_rate=float(sensitives['learning_rate']),
            optimizer=torch.optim.RMSprop(
                [p for p in network.parameters() if p.requires_grad],
                lr=0.001, alpha=float(values['alpha'])),
            criterion=torch.nn.CrossEntropyLoss()
        )

        (trainer
         .reg(tnr.EpochsTracker())
         .reg(tnr.EpochsStopper(150))
         .reg(tnr.DecayTracker())
         .reg(tnr.DecayStopper(8))
         .reg(tnr.LRMultiplicativeDecayer(factor=values['lr_factor']))
         .reg(tnr.DecayOnPlateau())
         .reg(tnr.AccuracyTracker(5, 1000, True))
        )

        return trainer, network
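Example #1 is a realize method on some experiment object: it builds a NaturalRNN and a GenericTrainer from a dictionary of tuned values plus a couple of "sensitive" keyword settings, and returns both. A minimal driver might look like the sketch below; the owning object spec and every numeric value are illustrative assumptions, and only the key names, realize(), and trainer.train() come from the examples on this page.

# Hypothetical driver sketch; spec and all values below are illustrative only.
values = {
    'nonlinearity': 'tanh', 'hidden_size': 200, 'inp_stddev': 0.03,
    'dt': 0.001, 'g': 20, 'hidden_bias_vari': 0.3, 'output_weight_vari': 0.3,
    'recurrent_times': 10, 'alpha': 0.9, 'lr_factor': 0.5,
}
trainer, network = spec.realize(values, batch_size=30, learning_rate=0.003)
trainer.train(network)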
Example #2
def main():
    """Entry point"""
    pwl = GaussianSpheresPWLP(epoch_size=1000,
                              input_dim=2,
                              output_dim=2,
                              clusters=[
                                  PointWithLabel(point=torch.tensor(
                                      (-1, 0), dtype=torch.double),
                                                 label=0),
                                  PointWithLabel(point=torch.tensor(
                                      (1, 0), dtype=torch.double),
                                                 label=1)
                              ],
                              std_dev=0.4,
                              mean=0)

    layers = [(50, True, False)]
    layer_names = ['input', 'hidden', 'output']

    network = FeedforwardLarge.create(input_dim=2,
                                      output_dim=2,
                                      weights=wi.GaussianWeightInitializer(
                                          mean=0, vari=0.1, normalize_dim=1),
                                      biases=wi.ZerosWeightInitializer(),
                                      layer_sizes=layers,
                                      nonlinearity='linear',
                                      train_readout_weights=False,
                                      train_readout_bias=False)

    trainer = tnr.GenericTrainer(
        train_pwl=pwl,
        test_pwl=pwl,
        teacher=FFTeacher(),
        batch_size=1,
        learning_rate=0.003,
        optimizer=torch.optim.Adam(
            [p for p in network.parameters() if p.requires_grad], lr=0.003),
        criterion=mycrits.create_meansqerr_regul(
            noise_strength=0.5)  #torch.nn.CrossEntropyLoss()
    )

    pca3d_throughtrain.FRAMES_PER_TRAIN = 1
    pca3d_throughtrain.SKIP_TRAINS = 0
    pca3d_throughtrain.NUM_FRAME_WORKERS = 4

    dig = npmp.NPDigestor('train_one', 5)
    #pca_3d.plot_ff(pca_ff.find_trajectory(network, pwl, 3), os.path.join(SAVEDIR, 'pca_3d_start'), True,
    #               digestor=dig, frame_time=FRAME_TIME, layer_names=layer_names)
    dtt_training_dir = os.path.join(SAVEDIR, 'dtt')
    pca_training_dir = os.path.join(SAVEDIR, 'pca')
    pr_training_dir = os.path.join(SAVEDIR, 'pr')
    svm_training_dir = os.path.join(SAVEDIR, 'svm')
    satur_training_dir = os.path.join(SAVEDIR, 'saturation')
    pca_throughtrain_dir = os.path.join(SAVEDIR, 'pca_throughtrain')
    (trainer
     .reg(tnr.EpochsTracker())
     .reg(tnr.EpochsStopper(10))
     .reg(tnr.DecayTracker())
     #.reg(tnr.DecayStopper(8))
     #.reg(tnr.LRMultiplicativeDecayer())
     .reg(tnr.DecayOnPlateau())
     .reg(tnr.AccuracyTracker(5, 1000, True))
     #.reg(tnr.WeightNoiser(
     #    wi.GaussianWeightInitializer(mean=0, vari=0.02, normalize_dim=None),
     #    lambda ctxt: ctxt.model.layers[-1].weight.data))
     .reg(tnr.OnEpochCaller.create_every(
         satur.during_training(satur_training_dir, True, dig), skip=10))
     .reg(tnr.OnEpochCaller.create_every(
         dtt.during_training_ff(dtt_training_dir, True, dig), skip=10))
     .reg(tnr.OnEpochCaller.create_every(
         pca_ff.during_training(pca_training_dir, True, dig, alpha=0.8), skip=10))
     .reg(tnr.OnEpochCaller.create_every(
         pr.during_training_ff(pr_training_dir, True, dig), skip=1000))
     .reg(tnr.OnEpochCaller.create_every(
         svm.during_training_ff(svm_training_dir, True, dig), skip=1000))
     #.reg(pca3d_throughtrain.PCAThroughTrain(pca_throughtrain_dir, layer_names, True))
     .reg(tnr.OnFinishCaller(lambda *args, **kwargs: dig.join()))
     .reg(tnr.ZipDirOnFinish(dtt_training_dir))
     .reg(tnr.ZipDirOnFinish(pca_training_dir))
     .reg(tnr.ZipDirOnFinish(pr_training_dir))
     .reg(tnr.ZipDirOnFinish(svm_training_dir))
     .reg(tnr.ZipDirOnFinish(satur_training_dir))
    )
    trainer.train(network)
    #pca_3d.plot_ff(pca_ff.find_trajectory(network, pwl, 3), os.path.join(SAVEDIR, 'pca_3d_end'), True,
    #               digestor=dig, frame_time=FRAME_TIME, layer_names=layer_names)
    dig.archive_raw_inputs(os.path.join(SAVEDIR, 'raw_digestor.zip'))
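The main() above wires a two-cluster Gaussian toy problem into a linear feedforward network, then registers per-epoch callbacks (what appear to be saturation, distance-through-time, PCA, participation-ratio, and SVM analyses) that hand their plotting work to an NPDigestor worker pool; the OnFinishCaller joins that pool once training ends. As a script it would presumably be run behind the usual entry-point guard, which is not part of the scraped snippet:

if __name__ == '__main__':
    main()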
Example #3
def train_with_noise(vari, rep, ignoreme):  # pylint: disable=unused-argument
    """Entry point"""
    train_pwl = MNISTData.load_train().to_pwl().restrict_to(set(
        range(10))).rescale()
    test_pwl = MNISTData.load_test().to_pwl().restrict_to(set(
        range(10))).rescale()

    layers_and_nonlins = (
        (90, 'tanh'),
        (90, 'tanh'),
        (90, 'tanh'),
        (90, 'tanh'),
        (90, 'tanh'),
    )

    layers = [lyr[0] for lyr in layers_and_nonlins]
    nonlins = [lyr[1] for lyr in layers_and_nonlins]
    nonlins.append('tanh')  # output
    #layer_names = [f'{lyr[1]} (layer {idx})' for idx, lyr in enumerate(layers_and_nonlins)]
    layer_names = [
        f'Layer {idx+1}' for idx, lyr in enumerate(layers_and_nonlins)
    ]
    layer_names.insert(0, 'Input')
    layer_names.append('Output')

    network = FeedforwardLarge.create(input_dim=train_pwl.input_dim,
                                      output_dim=train_pwl.output_dim,
                                      weights=wi.GaussianWeightInitializer(
                                          mean=0, vari=0.3, normalize_dim=0),
                                      biases=wi.ZerosWeightInitializer(),
                                      layer_sizes=layers,
                                      nonlinearity=nonlins
                                      #layer_sizes=[500, 200]
                                      )

    _lr = 0.1
    trainer = tnr.GenericTrainer(
        train_pwl=train_pwl,
        test_pwl=test_pwl,
        teacher=FFTeacher(),
        batch_size=30,
        learning_rate=_lr,
        optimizer=torch.optim.SGD(
            [p for p in network.parameters() if p.requires_grad], lr=_lr
        ),  #torch.optim.Adam([p for p in network.parameters() if p.requires_grad], lr=0.003),
        criterion=mycrits.meansqerr  #torch.nn.CrossEntropyLoss()#
    )

    #pca3d_throughtrain.FRAMES_PER_TRAIN = 4
    #pca3d_throughtrain.SKIP_TRAINS = 0
    #pca3d_throughtrain.NUM_FRAME_WORKERS = 6

    dig = npmp.NPDigestor(f'TRMCN_{rep}_{vari}', 8)

    savedir = os.path.join(SAVEDIR, f'variance_{vari}', f'repeat_{rep}')

    dtt_training_dir = os.path.join(savedir, 'dtt')
    pca_training_dir = os.path.join(savedir, 'pca')
    pca3d_training_dir = os.path.join(savedir, 'pca3d')
    pr_training_dir = os.path.join(savedir, 'pr')
    svm_training_dir = os.path.join(savedir, 'svm')
    satur_training_dir = os.path.join(savedir, 'saturation')
    trained_net_dir = os.path.join(savedir, 'trained_model')
    pca_throughtrain_dir = os.path.join(savedir, 'pca_throughtrain')
    logpath = os.path.join(savedir, 'log.txt')
    (trainer
     .reg(tnr.EpochsTracker())
     .reg(tnr.EpochsStopper(0.2))
     .reg(tnr.DecayTracker())
     .reg(tnr.DecayStopper(5))
     .reg(tnr.LRMultiplicativeDecayer())
     #.reg(tnr.DecayOnPlateau())
     #.reg(tnr.DecayEvery(5))
     .reg(tnr.AccuracyTracker(1, 1000, True))
     .reg(tnr.WeightNoiser(
         wi.GaussianWeightInitializer(mean=0, vari=vari),
         (lambda ctx: ctx.model.layers[-1].weight.data.detach()), 'scale',
         (lambda noise: wi.GaussianWeightInitializer(0, noise.vari * 0.5))))
     #.reg(tnr.OnEpochCaller.create_every(dtt.during_training_ff(dtt_training_dir, True, dig), skip=100))
     #.reg(tnr.OnEpochCaller.create_every(pca_3d.during_training(pca3d_training_dir, True, dig, plot_kwargs={'layer_names': layer_names}), start=500, skip=100))
     #.reg(tnr.OnEpochCaller.create_every(pca_ff.during_training(pca_training_dir, True, dig), skip=100))
     .reg(tnr.OnEpochCaller.create_every(
         pr.during_training_ff(pr_training_dir, True, dig), skip=1))
     #.reg(tnr.OnEpochCaller.create_every(svm.during_training_ff(svm_training_dir, True, dig), skip=100))
     #.reg(tnr.OnEpochCaller.create_every(satur.during_training(satur_training_dir, True, dig), skip=100))
     .reg(tnr.OnEpochCaller.create_every(
         tnr.save_model(trained_net_dir), skip=100))
     #.reg(pca3d_throughtrain.PCAThroughTrain(pca_throughtrain_dir, layer_names, True))
     .reg(tnr.OnFinishCaller(lambda *args, **kwargs: dig.join()))
     .reg(tnr.CopyLogOnFinish(logpath))
     .reg(tnr.ZipDirOnFinish(dtt_training_dir))
     .reg(tnr.ZipDirOnFinish(pca_training_dir))
     .reg(tnr.ZipDirOnFinish(pca3d_training_dir))
     .reg(tnr.ZipDirOnFinish(pr_training_dir))
     .reg(tnr.ZipDirOnFinish(svm_training_dir))
     .reg(tnr.ZipDirOnFinish(satur_training_dir))
     .reg(tnr.ZipDirOnFinish(trained_net_dir))
    )

    trainer.train(network)
    dig.archive_raw_inputs(os.path.join(savedir, 'digestor_raw.zip'))
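train_with_noise takes the noise variance and a repeat index as arguments (the third argument is ignored), and the per-run save directory and digestor name both encode vari and rep, which suggests it is meant to be called over a grid of settings. A serial driver could look like the sketch below; the specific variances and repeat count are assumptions, not values from the source.

def sweep():
    """Hypothetical sweep over noise variances; values are illustrative only."""
    for vari in (0.01, 0.05, 0.1):   # assumed variance grid
        for rep in range(3):         # assumed number of repeats
            train_with_noise(vari, rep, None)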
Example #4
def main():
    """Entry point"""
    pwl = GaussianSpheresPWLP.create(epoch_size=2700,
                                     input_dim=INPUT_DIM,
                                     output_dim=OUTPUT_DIM,
                                     cube_half_side_len=2,
                                     num_clusters=10,
                                     std_dev=0.5,
                                     mean=0,
                                     min_sep=1,
                                     force_split=True)

    layers_and_nonlins = (
        (100, 'tanh'),
        #(100, 'linear'),
        #(25, 'linear'),
        #(90, 'tanh'),
        #(90, 'tanh'),
        #(90, 'linear'),
        #(25, 'linear'),
    )
    layers = [lyr[0] for lyr in layers_and_nonlins]
    nonlins = [lyr[1] for lyr in layers_and_nonlins]
    nonlins.append('tanh')  # output
    layer_names = [
        f'{lyr[1]} ({idx})' for idx, lyr in enumerate(layers_and_nonlins)
    ]
    layer_names.insert(0, 'input')
    layer_names.append('output')

    network = FeedforwardLarge.create(input_dim=INPUT_DIM,
                                      output_dim=OUTPUT_DIM,
                                      weights=wi.GaussianWeightInitializer(
                                          mean=0, vari=0.3, normalize_dim=1),
                                      biases=wi.ZerosWeightInitializer(),
                                      layer_sizes=layers,
                                      nonlinearity=nonlins)

    trainer = tnr.GenericTrainer(
        train_pwl=pwl,
        test_pwl=pwl,
        teacher=FFTeacher(),
        batch_size=20,
        learning_rate=0.001,
        optimizer=torch.optim.Adam(
            [p for p in network.parameters() if p.requires_grad], lr=0.001),
        criterion=mycrits.meansqerr  #torch.nn.CrossEntropyLoss()
    )

    pca3d_throughtrain.FRAMES_PER_TRAIN = 1
    pca3d_throughtrain.SKIP_TRAINS = 4
    pca3d_throughtrain.NUM_FRAME_WORKERS = 6

    dig = npmp.NPDigestor('train_one', 35)
    #pca_3d.plot_ff(pca_ff.find_trajectory(network, pwl, 3), os.path.join(SAVEDIR, 'pca_3d_start'), True,
    #               digestor=dig, frame_time=FRAME_TIME, layer_names=layer_names)
    dtt_training_dir = os.path.join(SAVEDIR, 'dtt')
    pca_training_dir = os.path.join(SAVEDIR, 'pca')
    pr_training_dir = os.path.join(SAVEDIR, 'pr')
    svm_training_dir = os.path.join(SAVEDIR, 'svm')
    satur_training_dir = os.path.join(SAVEDIR, 'saturation')
    pca_throughtrain_dir = os.path.join(SAVEDIR, 'pca_throughtrain')
    (trainer
     .reg(tnr.EpochsTracker())
     .reg(tnr.EpochsStopper(100))
     .reg(tnr.InfOrNANDetecter())
     .reg(tnr.DecayTracker())
     .reg(tnr.DecayStopper(8))
     .reg(tnr.LRMultiplicativeDecayer())
     .reg(tnr.DecayOnPlateau())
     .reg(tnr.AccuracyTracker(5, 1000, True))
     #.reg(tnr.WeightNoiser(
     #    wi.GaussianWeightInitializer(mean=0, vari=0.02, normalize_dim=None),
     #    lambda ctxt: ctxt.model.layers[-1].weight.data))
     #.reg(tnr.OnEpochCaller.create_every(satur.during_training(satur_training_dir, True, dig), skip=1000))
     #.reg(tnr.OnEpochCaller.create_every(dtt.during_training_ff(dtt_training_dir, True, dig), skip=1000))
     .reg(tnr.OnEpochCaller.create_every(
         pca_ff.during_training(pca_training_dir, True, dig), skip=1000))
     #.reg(tnr.OnEpochCaller.create_every(pr.during_training_ff(pr_training_dir, True, dig), skip=1000))
     #.reg(tnr.OnEpochCaller.create_every(svm.during_training_ff(svm_training_dir, True, dig), skip=1000))
     #.reg(pca3d_throughtrain.PCAThroughTrain(pca_throughtrain_dir, layer_names, True))
     .reg(tnr.OnFinishCaller(lambda *args, **kwargs: dig.join()))
     .reg(tnr.ZipDirOnFinish(dtt_training_dir))
     .reg(tnr.ZipDirOnFinish(pca_training_dir))
     .reg(tnr.ZipDirOnFinish(pr_training_dir))
     .reg(tnr.ZipDirOnFinish(svm_training_dir))
     .reg(tnr.ZipDirOnFinish(satur_training_dir))
    )
    trainer.train(network)
    #pca_3d.plot_ff(pca_ff.find_trajectory(network, pwl, 3), os.path.join(SAVEDIR, 'pca_3d_end'), True,
    #               digestor=dig, frame_time=FRAME_TIME, layer_names=layer_names)
    dig.archive_raw_inputs(os.path.join(SAVEDIR, 'raw_digestor.zip'))
Example #5
def main():
    """Entry point"""
    train_pwl = MNISTData.load_train().to_pwl().restrict_to(set(range(10))).rescale()
    test_pwl = MNISTData.load_test().to_pwl().restrict_to(set(range(10))).rescale()
    network = NaturalRNN.create(
        'tanh', train_pwl.input_dim, 200, train_pwl.output_dim,
        input_weights=wi.OrthogonalWeightInitializer(0.03, 0),
        input_biases=wi.ZerosWeightInitializer(), #
        hidden_weights=wi.SompolinskySmoothedFixedGainWeightInitializer(0.001, 20),
        hidden_biases=wi.GaussianWeightInitializer(mean=0, vari=0.3, normalize_dim=0),
        output_weights=wi.GaussianWeightInitializer(mean=0, vari=0.3, normalize_dim=0),
        output_biases=wi.ZerosWeightInitializer()
    )

    trainer = tnr.GenericTrainer(
        train_pwl=train_pwl,
        test_pwl=test_pwl,
        teacher=RNNTeacher(recurrent_times=10, input_times=1),
        batch_size=30,
        learning_rate=0.0001,
        optimizer=torch.optim.RMSprop([p for p in network.parameters() if p.requires_grad], lr=0.0001, alpha=0.9),
        criterion=torch.nn.CrossEntropyLoss()
    )

    (trainer
     .reg(tnr.EpochsTracker())
     .reg(tnr.EpochsStopper(150))
     .reg(tnr.InfOrNANDetecter())
     .reg(tnr.DecayTracker())
     .reg(tnr.DecayStopper(5))
     .reg(tnr.LRMultiplicativeDecayer())
     .reg(tnr.DecayOnPlateau())
     .reg(tnr.AccuracyTracker(5, 1000, True))
    )

    print('--saving pcs before training--')
    traj = pca.find_trajectory(network, train_pwl, 10, 2)
    savepath = os.path.join(SAVEDIR, 'pca_before_train')
    pca.plot_trajectory(traj, savepath, exist_ok=True)
    traj = pca.find_trajectory(network, test_pwl, 10, 2)
    savepath = os.path.join(SAVEDIR, 'pca_before_test')
    pca.plot_trajectory(traj, savepath, exist_ok=True)
    del traj

    # print('--saving distance through time before training--')
    # savepath = os.path.join(SAVEDIR, 'dtt_before_train')
    # dtt.measure_dtt(network, train_pwl, 10, savepath, verbose=True, exist_ok=True)
    # savepath = os.path.join(SAVEDIR, 'dtt_before_test')
    # dtt.measure_dtt(network, test_pwl, 10, savepath, verbose=True, exist_ok=True)


    print('--training--')
    result = trainer.train(network)
    print('--finished training--')
    print(result)

    print('--saving pcs after training--')
    traj = pca.find_trajectory(network, train_pwl, 10, 2)
    savepath = os.path.join(SAVEDIR, 'pca_after_train')
    pca.plot_trajectory(traj, savepath, exist_ok=True)
    traj = pca.find_trajectory(network, test_pwl, 10, 2)
    savepath = os.path.join(SAVEDIR, 'pca_after_test')
    pca.plot_trajectory(traj, savepath, exist_ok=True)
    del traj

    # print('--saving distance through time after training--')
    # savepath = os.path.join(SAVEDIR, 'dtt_after_train')
    # dtt.measure_dtt(network, train_pwl, 10, savepath, verbose=True, exist_ok=True)
    # savepath = os.path.join(SAVEDIR, 'dtt_after_test')
    # dtt.measure_dtt(network, test_pwl, 10, savepath, verbose=True, exist_ok=True)

    print('--saving 3d pca plots after training--')
    layer_names = ['Input']
    for i in range(1, trainer.teacher.recurrent_times + 1):
        layer_names.append(f'Timestep {i}')
    dig = npmp.NPDigestor('mnist_train_one_rnn', 2)
    nha = mutils.get_hidacts_rnn(network, train_pwl, trainer.teacher.recurrent_times)
    nha.torch()
    traj = pca_ff.to_trajectory(nha.sample_labels, nha.hid_acts, 3)
    pca_3d.plot_ff(traj, os.path.join(SAVEDIR, 'pca3d_after_train'), False, digestor=dig,
                   layer_names=layer_names)

    nha = mutils.get_hidacts_rnn(network, test_pwl, trainer.teacher.recurrent_times)
    nha.torch()
    traj = pca_ff.to_trajectory(nha.sample_labels, nha.hid_acts, 3)
    pca_3d.plot_ff(traj, os.path.join(SAVEDIR, 'pca3d_after_test'), False, digestor=dig,
                   layer_names=layer_names)

    print('--saving model--')
    torch.save(network, os.path.join(SAVEDIR, 'model.pt'))

    dig.join()
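This example pickles the entire network with torch.save rather than saving a state_dict. Reloading it later is standard PyTorch; map_location is included only as a precaution in case the checkpoint was written on a GPU host (an assumption, since the snippet never moves the model to a device), and eval() assumes NaturalRNN is an ordinary torch.nn.Module, as its use of parameters() suggests.

network = torch.load(os.path.join(SAVEDIR, 'model.pt'), map_location='cpu')
network.eval()  # switch off any training-only behaviour before further analysis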
Example #6
def main():
    """Entry point"""
    pwl = GaussianSpheresPWLP.create(epoch_size=1800,
                                     input_dim=200,
                                     output_dim=2,
                                     cube_half_side_len=2,
                                     num_clusters=60,
                                     std_dev=0.04,
                                     mean=0,
                                     min_sep=0.1)

    network = NaturalRNN.create(
        'tanh',
        pwl.input_dim,
        200,
        pwl.output_dim,
        input_weights=wi.OrthogonalWeightInitializer(0.03, 0),
        input_biases=wi.ZerosWeightInitializer(),  #
        hidden_weights=wi.SompolinskySmoothedFixedGainWeightInitializer(
            0.001, 20),
        hidden_biases=wi.GaussianWeightInitializer(mean=0,
                                                   vari=0.3,
                                                   normalize_dim=0),
        output_weights=wi.GaussianWeightInitializer(mean=0,
                                                    vari=0.3,
                                                    normalize_dim=0),
        output_biases=wi.ZerosWeightInitializer())

    trainer = tnr.GenericTrainer(
        train_pwl=pwl,
        test_pwl=pwl,
        teacher=RNNTeacher(recurrent_times=10, input_times=1),
        batch_size=30,
        learning_rate=0.001,
        optimizer=torch.optim.RMSprop(
            [p for p in network.parameters() if p.requires_grad],
            lr=0.001,
            alpha=0.9),
        criterion=torch.nn.CrossEntropyLoss())

    (trainer
     .reg(tnr.EpochsTracker())
     .reg(tnr.EpochsStopper(150))
     .reg(tnr.InfOrNANDetecter())
     .reg(tnr.DecayTracker())
     .reg(tnr.DecayStopper(8))
     .reg(tnr.LRMultiplicativeDecayer())
     .reg(tnr.DecayOnPlateau())
     .reg(tnr.AccuracyTracker(5, 1000, True))
    )

    print('--saving pcs before training--')
    traj = pca.find_trajectory(network, pwl, 10, 2)

    print('--saving distance through time before training--')
    savepath = os.path.join(SAVEDIR, 'dtt_before')
    dtt.measure_dtt(network, pwl, 10, savepath, verbose=True, exist_ok=True)

    savepath = os.path.join(SAVEDIR, 'pca_before')
    pca.plot_trajectory(traj, savepath, exist_ok=True)
    del traj

    print('--training--')
    result = trainer.train(network)
    print('--finished training--')
    print(result)

    print('--saving distance through time after training--')
    savepath = os.path.join(SAVEDIR, 'dtt_after')
    dtt.measure_dtt(network, pwl, 10, savepath, verbose=True, exist_ok=True)

    print('--saving pcs after training--')
    traj = pca.find_trajectory(network, pwl, 10, 2)
    savepath = os.path.join(SAVEDIR, 'pca_after')
    pca.plot_trajectory(traj, savepath, exist_ok=True)

    print('--saving pr after training--')
    savepath = os.path.join(SAVEDIR, 'pr_after')
Example #7
def main():
    """Entry point"""

    cu.DEFAULT_LINEAR_BIAS_INIT = wi.ZerosWeightInitializer()
    cu.DEFAULT_LINEAR_WEIGHT_INIT = wi.GaussianWeightInitializer(
        mean=0, vari=0.3, normalize_dim=0)

    nets = cu.FluentShape(32 * 32 * 3).verbose()
    network = FeedforwardComplex(INPUT_DIM, OUTPUT_DIM, [
        nets.linear_(32 * 32 * 6),
        nets.nonlin('isrlu'),
        nets.linear_(500),
        nets.nonlin('tanh'),
        nets.linear_(250),
        nets.nonlin('tanh'),
        nets.linear_(250),
        nets.nonlin('tanh'),
        nets.linear_(100),
        nets.tanh(),
        nets.linear_(100),
        nets.tanh(),
        nets.linear_(100),
        nets.tanh(),
        nets.linear_(OUTPUT_DIM),
        nets.nonlin('isrlu'),
    ])

    train_pwl = CIFARData.load_train().to_pwl().restrict_to(set(
        range(10))).rescale()
    test_pwl = CIFARData.load_test().to_pwl().restrict_to(set(
        range(10))).rescale()

    layer_names = ('input', 'FC -> 32*32*6 (ISRLU)', 'FC -> 500 (tanh)',
                   'FC -> 250 (tanh)', 'FC -> 250 (tanh)', 'FC -> 100 (tanh)',
                   'FC -> 100 (tanh)', 'FC -> 100 (tanh)',
                   f'FC -> {OUTPUT_DIM} (ISRLU)')
    plot_layers = tuple(i for i in range(2, len(layer_names) - 1))
    trainer = tnr.GenericTrainer(
        train_pwl=train_pwl,
        test_pwl=test_pwl,
        teacher=FFTeacher(),
        batch_size=45,
        learning_rate=0.001,
        optimizer=torch.optim.Adam(
            [p for p in network.parameters() if p.requires_grad], lr=0.001),
        criterion=torch.nn.CrossEntropyLoss())

    pca3d_throughtrain.FRAMES_PER_TRAIN = 1
    pca3d_throughtrain.SKIP_TRAINS = 16
    pca3d_throughtrain.NUM_FRAME_WORKERS = 1

    dig = npmp.NPDigestor('train_one_complex', 5)

    dtt_training_dir = os.path.join(SAVEDIR, 'dtt')
    pca_training_dir = os.path.join(SAVEDIR, 'pca')
    pca3d_training_dir = os.path.join(SAVEDIR, 'pca3d')
    pr_training_dir = os.path.join(SAVEDIR, 'pr')
    svm_training_dir = os.path.join(SAVEDIR, 'svm')
    satur_training_dir = os.path.join(SAVEDIR, 'saturation')
    trained_net_dir = os.path.join(SAVEDIR, 'trained_model')
    pca_throughtrain_dir = os.path.join(SAVEDIR, 'pca_throughtrain')
    logpath = os.path.join(SAVEDIR, 'log.txt')
    (trainer
     .reg(tnr.EpochsTracker())
     .reg(tnr.EpochsStopper(STOP_EPOCH))
     .reg(tnr.DecayTracker())
     .reg(tnr.DecayStopper(8))
     .reg(tnr.EpochProgress(print_every=120, hint_end_epoch=STOP_EPOCH))
     .reg(tnr.LRMultiplicativeDecayer())
     .reg(tnr.DecayOnPlateau(patience=3))
     .reg(tnr.AccuracyTracker(1, 1000, True))
     .reg(tnr.OnEpochCaller.create_every(
         dtt.during_training_ff(dtt_training_dir, True, dig), skip=5))
     .reg(tnr.OnEpochCaller.create_every(
         pca_3d.during_training(pca3d_training_dir, True, dig,
                                plot_kwargs={'layer_names': layer_names}),
         start=10, skip=100))
     .reg(tnr.OnEpochCaller.create_every(
         pca_ff.during_training(pca_training_dir, True, dig), skip=5))
     .reg(tnr.OnEpochCaller.create_every(
         pr.during_training_ff(pr_training_dir, True, dig, labels=False), skip=5))
     .reg(tnr.OnEpochCaller.create_every(
         svm.during_training_ff(svm_training_dir, True, dig), skip=5))
     .reg(tnr.OnEpochCaller.create_every(
         satur.during_training(satur_training_dir, True, dig), skip=5))
     .reg(tnr.OnEpochCaller.create_every(
         tnr.save_model(trained_net_dir), skip=5))
     #.reg(pca3d_throughtrain.PCAThroughTrain(pca_throughtrain_dir, layer_names, True, layer_indices=plot_layers))
     .reg(tnr.OnFinishCaller(lambda *args, **kwargs: dig.join()))
     .reg(tnr.ZipDirOnFinish(dtt_training_dir))
     .reg(tnr.ZipDirOnFinish(pca_training_dir))
     .reg(tnr.ZipDirOnFinish(pca3d_training_dir))
     .reg(tnr.ZipDirOnFinish(pr_training_dir))
     .reg(tnr.ZipDirOnFinish(svm_training_dir))
     .reg(tnr.ZipDirOnFinish(satur_training_dir))
     .reg(tnr.ZipDirOnFinish(trained_net_dir))
     .reg(tnr.CopyLogOnFinish(logpath))
    )

    trainer.train(network)
    dig.archive_raw_inputs(os.path.join(SAVEDIR, 'digestor_raw.zip'))
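Example #7 builds a fairly deep fully connected stack through cu.FluentShape before launching a long CIFAR run. A quick sanity check of the trainable parameter count uses only standard PyTorch and is not part of the original example:

n_params = sum(p.numel() for p in network.parameters() if p.requires_grad)
print(f'trainable parameters: {n_params:,}')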