Example #1
0
def train(filenames=None, settings=None, **kwargs):
    """Run adversarial + auto-encoder training of an AAE1D network on MNIST.

    Builds the MNIST dataset and the AAE1D network from ``filenames``, prints
    both objects' settings, restores previously saved state with
    ``net.load()``, then runs ``settings['nb_batches'] // 3`` iterations. Each
    iteration performs one critic step (``train_cri``), one generator step
    (``train_gen``) and one auto-encoder step (``train_ae``), and reports all
    three losses to a single progress timer. The network is saved under the
    name ``'net'`` at the end.

    Args:
        filenames: Configuration file(s) forwarded to both ``MNIST`` and
            ``AAE1D``.
        settings: Mapping that must contain the key ``'nb_batches'``.
        **kwargs: Accepted for interface compatibility; not used.

    Raises:
        TypeError: If ``settings`` is None.
        KeyError: If ``settings`` has no ``'nb_batches'`` entry.
    """
    # Fail fast, before the (potentially expensive) dataset/network setup.
    if settings is None:
        raise TypeError("train() requires a settings mapping with 'nb_batches'")

    dataset = MNIST(filenames=filenames)
    print("=" * 30)
    print("DATASET SETTINGS:")
    print(dataset.pretty_settings())
    print("=" * 30)

    net = AAE1D(filenames=filenames)
    net.define_net()
    print("=" * 30)
    print("NETWORK SETTINGS:")
    print(net.pretty_settings())
    print("=" * 30)

    nb_batches = settings['nb_batches']
    net.load()
    # Created once; it must persist across iterations to track progress.
    # NOTE(review): the timer is sized for nb_batches, but only
    # nb_batches // 3 iterations run and events are keyed by net.step
    # (which may be non-zero after net.load()) -- confirm the intended total.
    ptp = ProgressTimer(nb_batches)
    for _ in range(nb_batches // 3):
        loss_cri = train_cri(net, dataset)
        msg = '|T:Cri, loss=%05f|' % (loss_cri)
        loss_gen = train_gen(net, dataset)
        msg += '|T:Gen, loss=%05f|' % (loss_gen)
        loss_ae = train_ae(net, dataset)
        # Append (not overwrite) so the report contains all three losses.
        msg += '|T:AuE, loss=%05f|' % (loss_ae)
        ptp.event(net.step, msg)
    net.save('net', is_print=True)
Example #2
0
def train(nb_batches, cfs):
    """Warm up the 'Cri' model, then alternate 'Cri'/'WGan' updates on MNIST.

    ``cfs`` is forwarded as ``filenames`` to both ``MNIST`` and ``LSGAN``.
    After printing the dataset and network settings, the 'Cri' model
    (presumably the critic) is trained alone for ``net.pre_train`` batches.
    The main loop then runs ``nb_batches`` iterations. Iterations where
    ``step % net.gen_freq`` is zero train 'WGan'; all others train 'Cri'.
    Both running losses are reported after every step. The network is saved
    as 'AutoEncoder' whenever ``step % 1000 == 0`` (including step 0), and
    once more, with ``is_print=True``, after the loop.

    Args:
        nb_batches: Number of iterations in the alternating phase.
        cfs: Configuration file(s) for the dataset and the network.
    """
    rule = "=" * 30

    dataset = MNIST(filenames=cfs)
    print(rule)
    print("DATASET SETTINGS:")
    print(dataset.pretty_settings())
    print(rule)

    net = LSGAN(filenames=cfs)
    net.define_net()
    print(rule)
    print("NETWORK SETTINGS:")
    print(net.pretty_settings())
    print(rule)

    # Phase 1: 'Cri'-only warm-up. Only the first element of each batch is
    # fed to the model (presumably the images -- confirm against MNIST).
    warmup_timer = ProgressTimer(net.pre_train)
    for step in range(net.pre_train):
        batch = next(dataset)
        latent = net.gen_latent()
        loss_c = net.train_on_batch('Cri', [batch[0], latent], [])
        warmup_timer.event(step, 'loss_c= %f' % loss_c)

    # Phase 2: alternate updates. The loss of the model not trained on a given
    # step keeps its last value (NaN until it is first trained).
    timer = ProgressTimer(nb_batches)
    loss_c = loss_g = np.nan
    for step in range(nb_batches):
        batch = next(dataset)
        latent = net.gen_latent()
        inputs = [batch[0], latent]
        if step % net.gen_freq > 0:
            loss_c = net.train_on_batch('Cri', inputs, [])
            template = 'c_step, loss_c= %f; loss_g= %f'
        else:
            loss_g = net.train_on_batch('WGan', inputs, [])
            template = 'g_step, loss_c= %f; loss_g= %f'
        timer.event(step, template % (loss_c, loss_g))
        if step % 1000 == 0:
            net.save('AutoEncoder')
    net.save('AutoEncoder', is_print=True)