示例#1
0
def predict_sino8v2(load_step=None, filenames=None, **kwargs):
    """Run the 2x training op of an ``SRSino8`` net on ``Sinograms2`` batches.

    Each iteration fetches ``net.loss2x`` and ``net.train_2x`` for one batch
    of the training dataset and reports the loss to a ``ProgressTimer``.
    ``net.summary_op`` is written every 100 iterations and the net is saved
    every 1000 iterations, plus once more after the loop.

    Args:
        load_step: Accepted but not read by this function.
        filenames: Passed as ``filenames=`` to both ``SRSino8`` and
            ``Sinograms2``. ``None`` (the default) means an empty list.
        **kwargs: Forwarded to the ``SRSino8`` constructor.

    NOTE(review): ``total_step`` is neither a parameter nor a local here; it
    is looked up as a global at call time. Confirm it is defined at module
    level, otherwise this function raises ``NameError``.
    """
    # Build a fresh list per call instead of sharing one default instance.
    if filenames is None:
        filenames = []
    net = SRSino8(filenames=filenames, **kwargs)
    net.build()
    with Sinograms2(filenames=filenames) as dataset_train:
        # The test dataset is opened for the duration of the loop but no
        # batch is ever read from it.
        with Sinograms2(filenames=filenames, mode='test') as dataset_test:
            pt = ProgressTimer(total_step)
            for i in range(total_step):
                ss = next(dataset_train)
                # Same feed for the training run and the summary run.
                feed = {
                    net.ip: ss[0],
                    net.ll: ss[1][0],
                    net.lr: ss[1][1],
                }
                loss_v, _ = net.sess.run([net.loss2x, net.train_2x],
                                         feed_dict=feed)
                pt.event(i, msg='loss %f.' % loss_v)
                if i % 100 == 0:
                    summ = net.sess.run(net.summary_op, feed_dict=feed)
                    net.sw.add_summary(summ, net.sess.run(net.global_step))
                if i % 1000 == 0:
                    net.save()
        # Final save once the test dataset has been released.
        net.save()
示例#2
0
def train_sino8v3_pet(load_step=None, total_step=None, filenames=None, **kwargs):
    """Train an ``SRSino8v3`` net on ``SinogramsPETRebin`` batches.

    Runs ``total_step`` calls of ``net.train`` on batches of the training
    dataset. Summaries (one training batch, one test batch) are written when
    more than 120 seconds have passed since the previous summary, and the
    net is saved when more than 600 seconds have passed since the previous
    save. The net is saved once more after the loop.

    Args:
        load_step: When not ``None``, ``net.load(load_step=load_step)`` is
            called after the net is built.
        total_step: Number of training iterations. Required.
        filenames: Passed as ``filenames=`` to the net and to both datasets.
            ``None`` (the default) means an empty list.
        **kwargs: Forwarded to the ``SRSino8v3`` constructor.

    Raises:
        TypeError: If ``total_step`` is ``None``.
    """
    # Fail before building the net; ``range(None)`` would raise the same
    # exception type, only later and with a less helpful message.
    if total_step is None:
        raise TypeError("total_step must be an integer, got None")
    # Build a fresh list per call instead of sharing one default instance.
    if filenames is None:
        filenames = []

    print("TRAINGING v3 net on PET data.")
    net = SRSino8v3(filenames=filenames, **kwargs)
    net.build()
    if load_step is not None:
        net.load(load_step=load_step)

    summary_interval = 120  # seconds between summaries
    save_interval = 600  # seconds between checkpoints
    pre_sum = time.time()
    pre_save = time.time()
    with SinogramsPETRebin(filenames=filenames) as dataset_train:
        with SinogramsPETRebin(filenames=filenames,
                               mode='test') as dataset_test:
            pt = ProgressTimer(total_step)
            for i in range(total_step):
                ss = next(dataset_train)
                loss_v, _ = net.train(ss)
                pt.event(i, msg='loss %e.' % loss_v)
                now = time.time()
                if now - pre_sum > summary_interval:
                    # Summaries use fresh batches, not the one just trained on.
                    net.summary(next(dataset_train), True)
                    net.summary(next(dataset_test), False)
                    pre_sum = now
                if now - pre_save > save_interval:
                    net.save()
                    pre_save = now
        # Final save once the test dataset has been released.
        net.save()
示例#3
0
def train(filenames=None, settings=None, **kwargs):
    """Train an ``AAE1D`` net on ``MNIST`` in rounds of three training calls.

    Prints the dataset and network settings, loads the net, then runs
    ``nb_batches // 3`` rounds. Each round calls ``train_cri``, ``train_gen``
    and ``train_ae`` once, in that order, and reports all three losses in a
    single progress message. The net is saved under the name ``'net'`` after
    the loop.

    Args:
        filenames: Passed as ``filenames=`` to both ``MNIST`` and ``AAE1D``.
        settings: Mapping that must contain the key ``'nb_batches'``.
        **kwargs: Accepted but not used.
    """
    dataset = MNIST(filenames=filenames)
    print("=" * 30)
    print("DATASET SETTINGS:")
    print(dataset.pretty_settings())
    print("=" * 30)
    net = AAE1D(filenames=filenames)
    net.define_net()
    print("=" * 30)
    print("NETWORK SETTINGS:")
    print(net.pretty_settings())
    print("=" * 30)
    nb_batches = settings['nb_batches']
    net.load()
    # One timer for the whole run, so elapsed time accumulates across rounds.
    ptp = ProgressTimer(nb_batches)
    for _ in range(nb_batches // 3):
        loss_cri = train_cri(net, dataset)
        msg = '|T:Cri, loss=%05f|' % (loss_cri)
        loss_gen = train_gen(net, dataset)
        msg += '|T:Gen, loss=%05f|' % (loss_gen)
        loss_ae = train_ae(net, dataset)
        # Append rather than assign, otherwise the two losses above are lost.
        msg += '|T:AuE, loss=%05f|' % (loss_ae)
        ptp.event(net.step, msg)
    net.save('net', is_print=True)
示例#4
0
def train2(cfg='train.json', **kwargs):
    """Train a net described by a JSON config file.

    The config is read from ``cfg`` and must provide ``dataset`` (attribute
    name looked up on ``datasets``), ``net`` (attribute name looked up on
    ``nets2``), ``lrs``, ``total_step``, ``summary_freq`` and ``save_freq``.
    ``filenames`` and ``load_step`` are optional.

    For every value in ``lrs`` the net's ``learning_rate_value`` is set and
    ``total_step`` training calls are made. Summaries and checkpoints are
    time based: they fire when more than ``summary_freq`` / ``save_freq``
    seconds have passed since the previous one. The net is saved once more
    after both datasets are closed.

    Args:
        cfg: Path of the JSON config file.
        **kwargs: Forwarded to the net constructor.
    """
    with open(cfg, 'r') as cfg_file:
        cfgs = json.load(cfg_file)
        pp_json('TRAIN 2 CONFIGS:', cfgs)
    make_dataset = getattr(datasets, cfgs['dataset'])
    make_net = getattr(nets2, cfgs['net'])
    config_files = cfgs.get('filenames')
    resume_step = cfgs.get('load_step')

    model = make_net(filenames=config_files, **kwargs)
    model.build()
    if resume_step is not None:
        model.load(load_step=resume_step)

    last_summary = time.time()
    last_save = time.time()
    rates = cfgs['lrs']
    steps_per_rate = cfgs['total_step']
    summary_period = cfgs['summary_freq']
    save_period = cfgs['save_freq']

    with make_dataset(filenames=config_files) as train_data, \
            make_dataset(filenames=config_files, mode='test') as test_data:
        timer = ProgressTimer(steps_per_rate * len(rates))
        for rate_index, rate in enumerate(rates):
            model.learning_rate_value = rate
            for inner in range(steps_per_rate):
                # Running step index across all learning rates.
                overall = rate_index * steps_per_rate + inner
                batch = next(train_data)
                loss, _ = model.train(batch)
                timer.event(overall, msg='loss %e.' % loss)
                tick = time.time()
                if tick - last_summary > summary_period:
                    # One fresh training batch, then one test batch.
                    batch = next(train_data)
                    model.summary(batch, True)
                    batch = next(test_data)
                    model.summary(batch, False)
                    last_summary = tick
                if tick - last_save > save_period:
                    model.save()
                    last_save = tick
    model.save()
示例#5
0
def train_sino8v2(load_step=-1,
                  total_step=None,
                  step8=1,
                  step4=1,
                  step2=1,
                  sumf=5,
                  savef=100,
                  filenames=None,
                  **kwargs):
    """Train the three sub-nets of an ``SRSino8v2`` net in interleaved rounds.

    One round trains ``'net_8x'`` for ``step8`` batches, ``'net_4x'`` for
    ``step4`` batches and ``'net_2x'`` for ``step2`` batches, each on its own
    ``Sinograms2`` training dataset. ``total_step // (step8 + step4 + step2)``
    rounds are run. The net is saved once more after the last round, and all
    datasets are closed even when training raises.

    Args:
        load_step: Weights are loaded with ``net.load(load_step=...)`` only
            when this is greater than 0.
        total_step: Total training-step budget. Required.
        step8: Training steps per round for ``'net_8x'``.
        step4: Training steps per round for ``'net_4x'``.
        step2: Training steps per round for ``'net_2x'``.
        sumf: Summaries are written on rounds where ``round % sumf == 0``.
        savef: The net is saved on rounds where ``round % savef == 0``.
        filenames: Accepted but not used; the net and dataset config names
            are hard-coded below.
        **kwargs: Forwarded to the ``SRSino8v2`` constructor.

    Raises:
        TypeError: If ``total_step`` is ``None``.
        ZeroDivisionError: If ``step8 + step4 + step2`` is 0.
    """
    if total_step is None:
        raise TypeError("total_step must be an integer, got None")

    click.echo("START TRAINING!!!!")
    net = SRSino8v2(filenames='srsino8v2.json', **kwargs)
    net.build()
    if load_step is not None and load_step > 0:
        net.load(load_step=load_step)

    # (sub-net name, dataset config file, training steps per round)
    specs = (
        ('net_8x', 'sino2_shep8x.json', step8),
        ('net_4x', 'sino2_shep4x.json', step4),
        ('net_2x', 'sino2_shep2x.json', step2),
    )
    train_sets = [Sinograms2(filenames=spec[1], mode='train') for spec in specs]
    test_sets = [Sinograms2(filenames=spec[1], mode='test') for spec in specs]

    round_len = step8 + step4 + step2
    nb_rounds = total_step // round_len

    opened = []
    try:
        for ds in train_sets + test_sets:
            ds.init()
            opened.append(ds)

        # The timer total is the number of events actually emitted.
        pt = ProgressTimer(nb_rounds * round_len)
        cstp = 0
        for i in range(nb_rounds):
            for (name, _cfg, nb_steps), ds in zip(specs, train_sets):
                for _ in range(nb_steps):
                    loss_v, _ = net.train(name, next(ds))
                    pt.event(cstp, msg='train %s, loss %e.' % (name, loss_v))
                    cstp += 1
            if i % sumf == 0:
                # All training summaries first, then all test summaries.
                for spec, ds in zip(specs, train_sets):
                    net.summary(spec[0], next(ds), True)
                for spec, ds in zip(specs, test_sets):
                    net.summary(spec[0], next(ds), False)
            if i % savef == 0:
                net.save()
        net.save()
    finally:
        # Only close what was successfully initialised.
        for ds in opened:
            ds.close()
示例#6
0
def train_sr_d(dataset_name,
               net_name,
               epochs,
               steps_per_epoch,
               load_step=-1,
               filenames=None,
               **kwargs):
    """Train a super resolution net.

    The dataset class is looked up by name on ``xlearn.datasets`` and the net
    class on ``xlearn.nets``. The net is trained for
    ``epochs * steps_per_epoch`` batches with ``net.train_on_batch``, then
    saved at its ``global_step`` and its loss is dumped with
    ``net.dump_loss()``.

    Args:
        dataset_name: Attribute name of the dataset class.
        net_name: Attribute name of the net class.
        epochs: Number of outer iterations.
        steps_per_epoch: Number of training batches per outer iteration.
        load_step: When not ``None`` it is passed to the net as
            ``init_step``. When also greater than 0, weights are loaded from
            that step and ``net.global_step`` is set to it.
        filenames: Passed as ``filenames=`` to the dataset and to the net.
            ``None`` (the default) means an empty list.
        **kwargs: Not read directly; only visible to ``print_pretty_args``
            through ``locals()``.
    """
    # Build a fresh list per call instead of sharing one default instance.
    # Rebinding the parameter adds no extra name to ``locals()`` below.
    if filenames is None:
        filenames = []
    dsc = getattr(xlearn.datasets, dataset_name)
    netc = getattr(xlearn.nets, net_name)
    print_pretty_args(train_sr_d, locals())
    with dsc(filenames=filenames) as dataset:
        net_settings = {'filenames': filenames}
        if load_step is not None:
            net_settings['init_step'] = load_step
        net = netc(**net_settings)
        net.define_net()
        # NOTE(review): cpx/cpy are unused below; kept so nets without a
        # two-item crop_size still fail here. Drop once confirmed unnecessary.
        cpx, cpy = net.crop_size
        if load_step is not None and load_step > 0:
            click.secho(
                '=' * 5 +
                'LOAD PRE TRAIN WEIGHTS OF {0:7d} STEPS.'.format(load_step)
                + '=' * 5,
                fg='yellow')
            net.load(step=load_step, is_force=True)
            net.global_step = load_step
        click.secho(net.pretty_settings())
        pt = ProgressTimer(epochs * steps_per_epoch)
        for _epoch in range(epochs):
            for _step in range(steps_per_epoch):
                s = next(dataset)
                loss = net.train_on_batch(inputs=s[0], outputs=s[1])
                msg = "model:{0:5s}, loss={1:10e}, gs={2:7d}.".format(
                    net._scadule_model(), loss, net.global_step)
                pt.event(msg=msg)
        net.save(step=net.global_step)
        net.dump_loss()
示例#7
0
def train(nb_batches, cfs):
    """Train an ``LSGAN`` net on ``MNIST``.

    Prints the dataset and network settings, then runs two phases:

    1. ``net.pre_train`` steps that train only the ``'Cri'`` model.
    2. ``nb_batches`` steps that train ``'WGan'`` whenever
       ``step % net.gen_freq`` is not greater than 0 and ``'Cri'`` otherwise.

    During the second phase the net is saved as ``'AutoEncoder'`` every 1000
    steps, and once more (with ``is_print=True``) after the loop.

    Args:
        nb_batches: Number of steps in the second phase.
        cfs: Passed as ``filenames=`` to both ``MNIST`` and ``LSGAN``.
    """
    rule = "=" * 30

    dataset = MNIST(filenames=cfs)
    print(rule)
    print("DATASET SETTINGS:")
    print(dataset.pretty_settings())
    print(rule)

    gan = LSGAN(filenames=cfs)
    gan.define_net()
    print(rule)
    print("NETWORK SETTINGS:")
    print(gan.pretty_settings())
    print(rule)

    # Phase 1: only the 'Cri' model is trained.
    warmup_timer = ProgressTimer(gan.pre_train)
    for step in range(gan.pre_train):
        batch = next(dataset)
        latent = gan.gen_latent()
        warmup_loss = gan.train_on_batch('Cri', [batch[0], latent], [])
        warmup_timer.event(step, 'loss_c= %f' % warmup_loss)

    # Phase 2: the loss of the model not trained on this step is reported
    # with its most recent value (NaN until it has been trained once).
    main_timer = ProgressTimer(nb_batches)
    critic_loss = np.nan
    generator_loss = np.nan
    for step in range(nb_batches):
        batch = next(dataset)
        latent = gan.gen_latent()
        if not (step % gan.gen_freq > 0):
            generator_loss = gan.train_on_batch('WGan', [batch[0], latent], [])
            report = 'g_step, loss_c= %f; loss_g= %f' % (critic_loss,
                                                         generator_loss)
        else:
            critic_loss = gan.train_on_batch('Cri', [batch[0], latent], [])
            report = 'c_step, loss_c= %f; loss_g= %f' % (critic_loss,
                                                         generator_loss)
        main_timer.event(step, report)
        if step % 1000 == 0:
            gan.save('AutoEncoder')
    gan.save('AutoEncoder', is_print=True)