def predict_sino8v2(load_step=None, filenames=None, total_step=None, **kwargs):
    """Run the 2x training loop of an ``SRSino8`` network.

    Despite the name, this function trains: each step runs ``net.loss2x``
    and ``net.train_2x`` on one batch drawn from the training split. A
    summary is written every 100 steps, and a checkpoint is saved every
    1000 steps and once more after the last step.

    Args:
        load_step: Not used by this function; kept for interface
            compatibility. NOTE(review): no checkpoint is restored here,
            unlike ``train_sino8v3_pet`` -- confirm whether that is intended.
        filenames: Config file name(s) passed to ``SRSino8`` and to both
            ``Sinograms2`` datasets. ``None`` is treated as ``[]``.
        total_step: Number of training steps to run (required). The loop
            previously referenced this name without it being defined, so
            every call failed with ``NameError``.
        **kwargs: Forwarded to the ``SRSino8`` constructor.

    Raises:
        ValueError: If ``total_step`` is not provided.
    """
    if total_step is None:
        raise ValueError("predict_sino8v2: total_step must be provided.")
    if filenames is None:
        filenames = []  # fresh list per call instead of a shared default
    net = SRSino8(filenames=filenames, **kwargs)
    net.build()
    with Sinograms2(filenames=filenames) as dataset_train:
        # The test split is opened (as before) but not read in this loop.
        with Sinograms2(filenames=filenames, mode='test') as _dataset_test:
            pt = ProgressTimer(total_step)
            for i in range(total_step):
                ss = next(dataset_train)
                # Assumed batch layout: ss[0] -> net.ip, ss[1] is a pair fed
                # to net.ll / net.lr (TODO confirm against Sinograms2).
                feed = {net.ip: ss[0], net.ll: ss[1][0], net.lr: ss[1][1]}
                loss_v, _ = net.sess.run([net.loss2x, net.train_2x],
                                         feed_dict=feed)
                pt.event(i, msg='loss %f.' % loss_v)
                if i % 100 == 0:
                    summ = net.sess.run(net.summary_op, feed_dict=feed)
                    net.sw.add_summary(summ, net.sess.run(net.global_step))
                if i % 1000 == 0:
                    net.save()
            net.save()
def train_sino8v3_pet(load_step=None, total_step=None, filenames=None,
                      summary_freq=120, save_freq=600, **kwargs):
    """Train an ``SRSino8v3`` network on rebinned PET sinograms.

    After each step the wall clock is checked. A summary is written for
    one training batch and one test batch when more than ``summary_freq``
    seconds have passed since the last summary. A checkpoint is saved when
    more than ``save_freq`` seconds have passed since the last save. A
    final checkpoint is saved after the last step.

    Args:
        load_step: If not ``None``, restore the checkpoint at this step
            before training.
        total_step: Number of training steps to run (required).
        filenames: Config file name(s) for the net and both datasets.
            ``None`` is treated as ``[]``.
        summary_freq: Minimum wall-clock seconds between summaries; the
            default 120 matches the former hard-coded value.
        save_freq: Minimum wall-clock seconds between checkpoints; the
            default 600 matches the former hard-coded value.
        **kwargs: Forwarded to the ``SRSino8v3`` constructor.

    Raises:
        ValueError: If ``total_step`` is not provided.
    """
    if total_step is None:
        raise ValueError("train_sino8v3_pet: total_step must be provided.")
    if filenames is None:
        filenames = []  # fresh list per call instead of a shared default
    print("TRAINING v3 net on PET data.")
    net = SRSino8v3(filenames=filenames, **kwargs)
    net.build()
    if load_step is not None:
        net.load(load_step=load_step)
    pre_sum = time.time()
    pre_save = time.time()
    with SinogramsPETRebin(filenames=filenames) as dataset_train:
        with SinogramsPETRebin(filenames=filenames, mode='test') as dataset_test:
            pt = ProgressTimer(total_step)
            for i in range(total_step):
                ss = next(dataset_train)
                loss_v, _ = net.train(ss)
                pt.event(i, msg='loss %e.' % loss_v)
                now = time.time()
                if now - pre_sum > summary_freq:
                    # Fresh batches so the summary is not the one just trained on.
                    net.summary(next(dataset_train), True)
                    net.summary(next(dataset_test), False)
                    pre_sum = now
                if now - pre_save > save_freq:
                    net.save()
                    pre_save = now
            net.save()
def train(filenames=None, settings=None, **kwargs):
    """Train an ``AAE1D`` adversarial auto-encoder on MNIST.

    Prints the dataset and network settings, restores the network via
    ``net.load()``, and then runs ``settings['nb_batches'] // 3`` iterations.
    Each iteration does one critic step (``train_cri``), one generator step
    (``train_gen``) and one auto-encoder step (``train_ae``). It reports all
    three losses on one progress line. Finally the network is saved as
    ``'net'``.

    Args:
        filenames: Config file name(s) for ``MNIST`` and ``AAE1D``.
        settings: Mapping that must contain ``'nb_batches'`` (required).
        **kwargs: Accepted for interface compatibility; unused.

    Raises:
        ValueError: If ``settings`` is not provided.
    """
    if settings is None:
        raise ValueError("train: settings must be provided (needs 'nb_batches').")
    dataset = MNIST(filenames=filenames)
    print("=" * 30)
    print("DATASET SETTINGS:")
    print(dataset.pretty_settings())
    print("=" * 30)
    net = AAE1D(filenames=filenames)
    net.define_net()
    print("=" * 30)
    print("NETWORK SETTINGS:")
    print(net.pretty_settings())
    print("=" * 30)
    nb_batches = settings['nb_batches']
    net.load()
    # NOTE(review): the timer is sized by nb_batches, but events are keyed by
    # net.step and only nb_batches // 3 iterations run -- confirm the
    # intended total.
    ptp = ProgressTimer(nb_batches)
    for _ in range(nb_batches // 3):
        loss_cri = train_cri(net, dataset)
        msg = '|T:Cri, loss=%05f|' % (loss_cri)
        loss_gen = train_gen(net, dataset)
        msg += '|T:Gen, loss=%05f|' % (loss_gen)
        loss_ae = train_ae(net, dataset)
        # Append (previously `=`, which discarded the Cri/Gen losses).
        msg += '|T:AuE, loss=%05f|' % (loss_ae)
        ptp.event(net.step, msg)
    net.save('net', is_print=True)
def train2(cfg='train.json', **kwargs):
    """Train a network whose classes and schedule come from a JSON config.

    Required config keys:
        ``dataset`` -- attribute name looked up in the ``datasets`` module.
        ``net`` -- attribute name looked up in the ``nets2`` module.
        ``lrs`` -- learning rates; each is used for ``total_step`` steps.
        ``total_step``, ``summary_freq`` and ``save_freq`` -- the last two
        are wall-clock seconds between summaries and between checkpoints.
    Optional keys: ``filenames`` (passed to net and datasets) and
    ``load_step`` (checkpoint to restore before training).

    Args:
        cfg: Path to the JSON config file.
        **kwargs: Forwarded to the network constructor.
    """
    with open(cfg, 'r') as config_file:
        config = json.load(config_file)
    pp_json('TRAIN 2 CONFIGS:', config)
    dataset_cls = getattr(datasets, config['dataset'])
    net_cls = getattr(nets2, config['net'])
    filenames = config.get('filenames')
    load_step = config.get('load_step')

    net = net_cls(filenames=filenames, **kwargs)
    net.build()
    if load_step is not None:
        net.load(load_step=load_step)
    last_summary = last_save = time.time()

    lrs = config['lrs']
    total_step = config['total_step']
    summary_freq = config['summary_freq']
    save_freq = config['save_freq']

    with dataset_cls(filenames=filenames) as dataset_train, \
            dataset_cls(filenames=filenames, mode='test') as dataset_test:
        timer = ProgressTimer(total_step * len(lrs))
        for stage, lr_value in enumerate(lrs):
            net.learning_rate_value = lr_value
            for offset in range(total_step):
                loss_value, _ = net.train(next(dataset_train))
                # Global step index across all learning-rate stages.
                timer.event(stage * total_step + offset,
                            msg='loss %e.' % loss_value)
                now = time.time()
                if now - last_summary > summary_freq:
                    net.summary(next(dataset_train), True)
                    net.summary(next(dataset_test), False)
                    last_summary = now
                if now - last_save > save_freq:
                    net.save()
                    last_save = now
        net.save()
def train_sino8v2(load_step=-1, total_step=None, step8=1, step4=1, step2=1,
                  sumf=5, savef=100, filenames=None, **kwargs):
    """Train the 8x, 4x and 2x sub-nets of ``SRSino8v2`` in round-robin.

    Each round runs ``step8`` steps of ``net_8x``, then ``step4`` steps of
    ``net_4x``, then ``step2`` steps of ``net_2x``. That gives
    ``total_step // (step8 + step4 + step2)`` rounds. Every ``sumf`` rounds,
    each sub-net gets a summary on one training batch and one test batch.
    Every ``savef`` rounds a checkpoint is saved, and a final one is saved
    at the end.

    Args:
        load_step: If > 0, restore the checkpoint at this step first.
        total_step: Total training-step budget (required).
        step8, step4, step2: Steps per round for each sub-net.
        sumf: Summary period, in rounds.
        savef: Checkpoint period, in rounds.
        filenames: Ignored. The network config is fixed to
            ``'srsino8v2.json'`` and the datasets to the
            ``'sino2_shep{8,4,2}x.json'`` files.
        **kwargs: Forwarded to the ``SRSino8v2`` constructor.

    Raises:
        ValueError: If ``total_step`` is not provided.
        ZeroDivisionError: If ``step8 + step4 + step2 == 0``.
    """
    if total_step is None:
        raise ValueError("train_sino8v2: total_step must be provided.")
    click.echo("START TRAINING!!!!")
    net = SRSino8v2(filenames='srsino8v2.json', **kwargs)
    net.build()
    if load_step > 0:
        net.load(load_step=load_step)
    ds8x_tr = Sinograms2(filenames='sino2_shep8x.json', mode='train')
    ds4x_tr = Sinograms2(filenames='sino2_shep4x.json', mode='train')
    ds2x_tr = Sinograms2(filenames='sino2_shep2x.json', mode='train')
    ds8x_te = Sinograms2(filenames='sino2_shep8x.json', mode='test')
    ds4x_te = Sinograms2(filenames='sino2_shep4x.json', mode='test')
    ds2x_te = Sinograms2(filenames='sino2_shep2x.json', mode='test')
    # Not named ``datasets`` so the module of that name is not shadowed.
    all_datasets = [ds8x_tr, ds4x_tr, ds2x_tr, ds8x_te, ds4x_te, ds2x_te]
    train_schedule = [('net_8x', ds8x_tr, step8),
                      ('net_4x', ds4x_tr, step4),
                      ('net_2x', ds2x_tr, step2)]
    test_sets = [('net_8x', ds8x_te), ('net_4x', ds4x_te), ('net_2x', ds2x_te)]
    steps_per_round = step8 + step4 + step2
    nb_rounds = total_step // steps_per_round

    opened = []
    try:
        for ds in all_datasets:
            ds.init()
            opened.append(ds)
        # Size the timer by the steps actually executed. The previous
        # ``total_step * 3`` was 3x too large with the default step counts.
        pt = ProgressTimer(nb_rounds * steps_per_round)
        cstp = 0
        for i in range(nb_rounds):
            for name, ds_train, nb_steps in train_schedule:
                for _ in range(nb_steps):
                    loss_v, _train_op = net.train(name, next(ds_train))
                    pt.event(cstp, msg='train %s, loss %e.' % (name, loss_v))
                    cstp += 1
            if i % sumf == 0:
                for name, ds_train, _nb in train_schedule:
                    net.summary(name, next(ds_train), True)
                for name, ds_test in test_sets:
                    net.summary(name, next(ds_test), False)
            if i % savef == 0:
                net.save()
        net.save()
    finally:
        # Release every dataset that was initialised, even if training failed.
        for ds in opened:
            ds.close()
def train_sr_d(dataset_name, net_name, epochs, steps_per_epoch, load_step=-1, filenames=[], **kwargs):
    """Train a super-resolution network selected by name.

    ``dataset_name`` and ``net_name`` are attribute names looked up in
    ``xlearn.datasets`` and ``xlearn.nets``. If ``load_step`` is not
    ``None`` it is passed to the net as ``init_step``. When it is also
    positive, weights for that step are force-loaded first. Training runs
    ``epochs * steps_per_epoch`` batches, with a checkpoint after every
    epoch and ``net.dump_loss()`` at the end.
    """
    # Only dsc/netc and the arguments may exist before locals() is captured.
    dsc = getattr(xlearn.datasets, dataset_name)
    netc = getattr(xlearn.nets, net_name)
    print_pretty_args(train_sr_d, locals())
    with dsc(filenames=filenames) as dataset:
        init_kwargs = {'filenames': filenames}
        if load_step is not None:
            init_kwargs['init_step'] = load_step
        net = netc(**init_kwargs)
        net.define_net()
        # Unpacked as before; the values are not used further here.
        cpx, cpy = net.crop_size
        if load_step is not None and load_step > 0:
            banner = 'LOAD PRE TRAIN WEIGHTS OF {0:7d} STEPS.'.format(load_step)
            click.secho('=' * 5 + banner + '=' * 5, fg='yellow')
            net.load(step=load_step, is_force=True)
            net.global_step = load_step
        click.secho(net.pretty_settings())
        timer = ProgressTimer(epochs * steps_per_epoch)
        for _epoch in range(epochs):
            for _step in range(steps_per_epoch):
                batch = next(dataset)
                loss = net.train_on_batch(inputs=batch[0], outputs=batch[1])
                timer.event(msg="model:{0:5s}, loss={1:10e}, gs={2:7d}.".format(
                    net._scadule_model(), loss, net.global_step))
            net.save(step=net.global_step)
        net.dump_loss()
def train(nb_batches, cfs):
    """Train an ``LSGAN`` on MNIST.

    First runs ``net.pre_train`` critic-only steps. It then runs
    ``nb_batches`` steps: a generator (``'WGan'``) step whenever the step
    index is a multiple of ``net.gen_freq``, and a critic (``'Cri'``) step
    otherwise. The model is saved as ``'AutoEncoder'`` every 1000 steps and
    once more at the end.

    Args:
        nb_batches: Number of main-phase steps.
        cfs: Config file name(s) for ``MNIST`` and ``LSGAN``.
    """
    dataset = MNIST(filenames=cfs)
    print("=" * 30)
    print("DATASET SETTINGS:")
    print(dataset.pretty_settings())
    print("=" * 30)
    net = LSGAN(filenames=cfs)
    net.define_net()
    print("=" * 30)
    print("NETWORK SETTINGS:")
    print(net.pretty_settings())
    print("=" * 30)

    # Critic warm-up phase.
    warmup_timer = ProgressTimer(net.pre_train)
    for step in range(net.pre_train):
        batch = next(dataset)
        latent = net.gen_latent()
        loss_c = net.train_on_batch('Cri', [batch[0], latent], [])
        warmup_timer.event(step, 'loss_c= %f' % loss_c)

    # Alternating phase; a loss not yet computed is reported as nan.
    timer = ProgressTimer(nb_batches)
    loss_c = np.nan
    loss_g = np.nan
    for step in range(nb_batches):
        batch = next(dataset)
        latent = net.gen_latent()
        if step % net.gen_freq > 0:
            loss_c = net.train_on_batch('Cri', [batch[0], latent], [])
            phase = 'c_step'
        else:
            loss_g = net.train_on_batch('WGan', [batch[0], latent], [])
            phase = 'g_step'
        timer.event(step, '%s, loss_c= %f; loss_g= %f' % (phase, loss_c, loss_g))
        if step % 1000 == 0:
            net.save('AutoEncoder')
    net.save('AutoEncoder', is_print=True)