コード例 #1
0
def evaluator_thread(cnn_file, hp, inputs, outputs):
    """Worker thread: restore a trained light-field CNN and evaluate batches.

    Restores network weights from *cnn_file*, then consumes batch dicts from
    the *inputs* queue until the empty-tuple sentinel ``()`` arrives.  For
    each batch it runs the decoder named by ``batch['decoder_path']`` and
    puts ``(result_dict, batch)`` on the *outputs* queue.

    Parameters
    ----------
    cnn_file : checkpoint path passed to ``tf.train.Saver.restore``.
    hp       : hyperparameter object (reads ``hp.config``).
    inputs   : queue of batch dicts; ``task_done()`` is called per item.
    outputs  : queue receiving ``(out, batch)`` tuples.
    """
    # TF/numpy are imported lazily so the thread owns its own graph/session.
    import numpy as np
    import tensorflow as tf

    session_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=hp.config['log_device_placement'])
    sess = tf.InteractiveSession(config=session_config)

    # build the network graph and restore the trained weights
    from cnn_autoencoder_v9 import create_cnn
    cnn = create_cnn(hp)
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess, cnn_file)

    print('lf cnn evaluator waiting for inputs')

    while True:
        batch = inputs.get()
        if batch == ():  # empty tuple is the shutdown sentinel
            inputs.task_done()
            break

        out = dict()
        # default params for network input; evaluation runs noise-free
        net_in = cnn.prepare_net_input(batch)
        net_in[cnn.noise_sigma] = 0.0

        # FOR SPECIFIC DECODER PATH (todo: make less of a hack)
        decoder_path = batch['decoder_path']
        if decoder_path == 'depth':
            # 2D decoder: single result volume
            decoder = cnn.decoders_2D[decoder_path]
            (sv, loss) = sess.run([decoder['upconv'], decoder['loss']],
                                  feed_dict=net_in)
        else:
            # 3D decoder: evaluate vertical and horizontal stacks separately
            decoder = cnn.decoders_3D[decoder_path]
            (sv_v, loss_v) = sess.run(
                [decoder['upconv_v'], decoder['loss']], feed_dict=net_in)
            (sv_h, loss_h) = sess.run(
                [decoder['upconv_h'], decoder['loss']], feed_dict=net_in)
            # swap the two spatial axes so the horizontal stack matches
            # the vertical layout
            sv_h = np.transpose(sv_h, [0, 1, 3, 2, 4])
            # original did sv_v[:, :, :, :, :] — a no-op view; assign directly
            sv = sv_v
            out['cv_v'] = sv_v
            out['cv_h'] = sv_h

        out['cv'] = sv
        outputs.put((out, batch))
        inputs.task_done()
コード例 #2
0
def trainer_thread(model_path, hp, inputs):
    """Worker thread: train the light-field autoencoder CNN.

    Builds the network, restores a checkpoint from *model_path* when one
    exists, then consumes batch dicts from the *inputs* queue until the
    empty-tuple sentinel ``()`` arrives.  Depending on batch flags it logs
    per-decoder losses (CSV + TensorBoard), periodically checkpoints the
    model, and runs each optimizer whose required channels are present.

    Parameters
    ----------
    model_path : directory prefix for checkpoint and CSV log files.
    hp         : hyperparameter object (reads ``hp.config``, ``hp.training``,
                 ``hp.tf_log_path``, ``hp.network_model``).
    inputs     : queue of batch dicts; ``task_done()`` is called per item.
    """
    import csv  # hoisted: was re-imported on every logging batch

    # start tensorflow session
    session_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=hp.config['log_device_placement'])
    sess = tf.InteractiveSession(config=session_config)

    # import network
    from cnn_autoencoder_v9 import create_cnn
    cnn = create_cnn(hp)

    # add optimizers (will be saved with the network)
    cnn.add_training_ops()
    print('  initialising TF session')
    sess.run(tf.global_variables_initializer())
    print('  ... done')

    # restore a previous model state if a checkpoint exists
    print('  checking for model ' + model_path)
    if os.path.exists(model_path + 'model.ckpt.index'):
        print('  restoring model ' + model_path)
        tft.optimistic_restore(sess, model_path + 'model.ckpt')
        print('  ... done.')
    else:
        print('  ... not found.')

    writerTensorboard = tf.summary.FileWriter(
        hp.tf_log_path + hp.network_model, sess.graph)
    # new saver object with complete network
    saver = tf.train.Saver()

    # statistics
    count = 0.0
    print('lf cnn trainer waiting for inputs')

    while True:
        batch = inputs.get()
        if batch == ():  # empty tuple is the shutdown sentinel
            inputs.task_done()
            break

        niter = batch['niter']
        ep = batch['epoch']

        # default params for network input
        net_in = cnn.prepare_net_input(batch)

        # evaluate current network performance on mini-batch
        if batch['logging']:
            print()
            sys.stdout.write('  dataset(%d:%s) ep(%d) batch(%d) : ' %
                             (batch['nfeed'], batch['feed_id'], ep, niter))

            count = count + 1.0
            fields = [
                '%s' % (datetime.datetime.now()), batch['feed_id'],
                batch['nfeed'], niter, ep
            ]

            # per-decoder losses ('id' renamed: shadowed the builtin)
            for dec_id in cnn.decoders_2D:
                loss = sess.run(cnn.decoders_2D[dec_id]['loss'],
                                feed_dict=net_in)
                sys.stdout.write('  %s %g   ' % (dec_id, loss))
                fields.append(dec_id)
                fields.append(loss)

            # append one CSV row per logging step
            with open(model_path + batch['logfile'], 'a+') as f:
                csv.writer(f).writerow(fields)

            summary = sess.run(cnn.merged, feed_dict=net_in)
            writerTensorboard.add_summary(summary, niter)
            print('')

        # epochs take too long: checkpoint every save_interval iterations
        # (only on the first feed, only while training)
        if (batch['niter'] % hp.training['save_interval'] == 0
                and niter != 0 and batch['nfeed'] == 0
                and batch['training']):
            save_path = saver.save(sess, model_path + 'model.ckpt')
            print('NEXT EPOCH')
            print("  model saved in file: %s" % save_path)
            # reset per-epoch statistics
            count = 0.0

        # run training step
        if batch['training']:
            net_in[cnn.phase] = True
            sys.stdout.write('.')
            for min_id in cnn.minimizers:
                # run only if all channels this minimizer requires
                # are present in the batch
                requires = cnn.minimizers[min_id]['requires']
                if all(r in batch for r in requires):
                    sys.stdout.write(cnn.minimizers[min_id]['id'] + ' ')
                    sess.run(cnn.minimizers[min_id]['train_step'],
                             feed_dict=net_in)
            sys.stdout.flush()

        inputs.task_done()
コード例 #3
0
def trainer_thread(model_path, hp, inputs):
    """Worker thread: train the LSTM light-field CNN with optional GAN.

    Builds the network, restores a checkpoint from *model_path* when one
    exists, then consumes batch dicts from the *inputs* queue until the
    empty-tuple sentinel ``()`` arrives.  When ``hp.discriminator`` is
    non-empty, training alternates ``hp.iterGan`` discriminator (GAN)
    steps with one round of generator/LSTM steps.

    Parameters
    ----------
    model_path : directory prefix for checkpoint and CSV log files.
    hp         : hyperparameter object (reads ``hp.config``, ``hp.training``,
                 ``hp.iterGan``, ``hp.discriminator``, ``hp.tf_log_path``,
                 ``hp.network_model``).
    inputs     : queue of batch dicts; ``task_done()`` is called per item.
    """
    import csv  # hoisted: was re-imported on every logging batch

    # start tensorflow session
    session_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=hp.config['log_device_placement'])
    sess = tf.InteractiveSession(config=session_config)

    # import network
    from cnn_autoencoder_v9 import create_cnn
    cnn = create_cnn(hp)

    # add optimizers (will be saved with the network)
    cnn.add_training_ops()
    print('  initialising TF session')
    sess.run(tf.global_variables_initializer())

    # restore a previous model state if a checkpoint exists
    print('  checking for model ' + model_path)
    if os.path.exists(model_path + 'model.ckpt.index'):
        print('  restoring model ' + model_path)
        tft.optimistic_restore(sess, model_path + 'model.ckpt')
        print('  ... done.')
    else:
        print('  ... not found.')

    writerTensorboard = tf.summary.FileWriter(
        hp.tf_log_path + hp.network_model, sess.graph)
    # new saver object with complete network
    saver = tf.train.Saver()

    # statistics
    count = 0.0
    print('lf cnn trainer waiting for inputs')

    # GAN scheduling: train the discriminator for iterGan batches, then the
    # generator/LSTM once, and repeat — only when a discriminator exists
    counter = 0
    iterGan = hp.iterGan
    use_gan = len(hp.discriminator) > 0
    train_gan = use_gan

    while True:
        batch = inputs.get()
        if batch == ():  # empty tuple is the shutdown sentinel
            inputs.task_done()
            break

        niter = batch['niter']
        ep = batch['epoch']

        # default params for network input
        net_in = cnn.prepare_net_input(batch)

        # evaluate current network performance on mini-batch
        if batch['logging']:
            summary_image = sess.run(cnn.merged_images, feed_dict=net_in)
            writerTensorboard.add_summary(summary_image, niter)

            print()
            sys.stdout.write('  dataset(%d:%s) ep(%d) batch(%d) : ' %
                             (batch['nfeed'], batch['feed_id'], ep, niter))

            count = count + 1.0
            fields = [
                '%s' % (datetime.datetime.now()), batch['feed_id'],
                batch['nfeed'], niter, ep
            ]

            loss = sess.run(cnn.LFLSTM['losses'], feed_dict=net_in)
            print('adv_percep_loss: %s' % str(
                sess.run(cnn.LFLSTM['loss_percep_adv'], feed_dict=net_in)))
            sys.stdout.write('  %s %g   ' % ('LFLSTM', loss))
            fields.append('LFLSTM')
            fields.append(loss)

            # one summary per minimizer.  Dead "ok = True / if ok" flag
            # removed; the GAN minimizer gets its own merged summary and no
            # longer wastes an extra merged_lstm evaluation that the
            # original immediately overwrote.
            for min_id in cnn.minimizers:
                if min_id == 'GAN':
                    summary = sess.run(cnn.merged_gan, feed_dict=net_in)
                else:
                    summary = sess.run(cnn.merged_lstm, feed_dict=net_in)
                writerTensorboard.add_summary(summary, niter)

            # append one CSV row per logging step
            with open(model_path + batch['logfile'], 'a+') as f:
                csv.writer(f).writerow(fields)

            print('')

        # epochs take too long: checkpoint every save_interval iterations
        # (only on the first feed, only while training)
        if (batch['niter'] % hp.training['save_interval'] == 0
                and niter != 0 and batch['nfeed'] == 0
                and batch['training']):
            save_path = saver.save(sess, model_path + 'model.ckpt')
            print('NEXT EPOCH')
            print("  model saved in file: %s" % save_path)
            # reset per-epoch statistics
            count = 0.0

        # run training step
        if batch['training']:
            net_in[cnn.phase] = True
            sys.stdout.write('.')
            if train_gan:
                # discriminator step, only if every required channel
                # ('lf_patches' + suffix) is present in the batch
                requires = cnn.minimizers['GAN']['requires']
                if all('lf_patches' + r in batch for r in requires):
                    sys.stdout.write(cnn.minimizers['GAN']['id'] + ' ')
                    sess.run(cnn.minimizers['GAN']['train_step'],
                             feed_dict=net_in)
                counter += 1
                # >= (was ==) so a misconfigured iterGan < 1 cannot lock
                # training into discriminator-only mode forever
                if counter >= iterGan:
                    train_gan = False
            else:
                # generator / LSTM steps
                for min_id in cnn.minimizers:
                    if min_id == 'GAN':
                        continue
                    requires = cnn.minimizers[min_id]['requires']
                    if all('lf_patches' + r in batch for r in requires):
                        sys.stdout.write(cnn.minimizers[min_id]['id'] + ' ')
                        sess.run(cnn.minimizers[min_id]['train_step'],
                                 feed_dict=net_in)
                if use_gan:
                    # hand control back to the discriminator
                    train_gan = True
                    counter = 0
            sys.stdout.flush()

        inputs.task_done()
コード例 #4
0
ファイル: thread_train_v9.py プロジェクト: wps1215/LFSR
def trainer_thread(model_path, hp, inputs):
    """Worker thread: train the two-stage (s2/s4) light-field SR CNN.

    Builds the network, restores a checkpoint from *model_path* when one
    exists, then consumes batch dicts from the *inputs* queue until the
    empty-tuple sentinel ``()`` arrives.  When ``hp.discriminator`` is
    non-empty, training alternates ``hp.iterGan`` discriminator (GAN)
    steps with one round of generator steps.

    Parameters
    ----------
    model_path : directory prefix for checkpoint and CSV log files.
    hp         : hyperparameter object (reads ``hp.config``, ``hp.training``,
                 ``hp.iterGan``, ``hp.discriminator``, ``hp.tf_log_path``,
                 ``hp.network_model``).
    inputs     : queue of batch dicts; ``task_done()`` is called per item.
    """
    import csv  # hoisted: was re-imported on every logging batch

    # start tensorflow session
    session_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=hp.config['log_device_placement'])
    sess = tf.InteractiveSession(config=session_config)

    # import network
    from cnn_autoencoder_v9 import create_cnn
    cnn = create_cnn(hp)

    # add optimizers (will be saved with the network)
    cnn.add_training_ops()
    print('  initialising TF session')
    sess.run(tf.global_variables_initializer())
    print('... restoring resnet50 ')
    # NOTE(review): the actual resnet50 restore is currently disabled — only
    # the checkpoint path is kept.  Re-enable with a scoped saver, e.g.:
    #   tf.train.Saver(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES,
    #                                    scope='resnet_v1_50')
    #                  ).restore(sess, ckpt_resnet)
    # (vgg19 / mobilenet / inception variants were sketched here as well.)
    ckpt_resnet = 'resnet_v1_50.ckpt'
    print('  ... done')

    # restore a previous model state if a checkpoint exists
    print('  checking for model ' + model_path)
    if os.path.exists(model_path + 'model.ckpt.index'):
        print('  restoring model ' + model_path)
        tft.optimistic_restore(sess, model_path + 'model.ckpt')
        print('  ... done.')
    else:
        print('  ... not found.')

    writerTensorboard = tf.summary.FileWriter(
        hp.tf_log_path + hp.network_model, sess.graph)
    # new saver object with complete network
    saver = tf.train.Saver()

    # statistics
    count = 0.0
    print('lf cnn trainer waiting for inputs')

    # GAN scheduling: train the discriminator for iterGan batches, then the
    # generator once, and repeat — only when a discriminator exists
    counter = 0
    iterGan = hp.iterGan
    use_gan = len(hp.discriminator) > 0
    train_gan = use_gan

    while True:
        batch = inputs.get()
        if batch == ():  # empty tuple is the shutdown sentinel
            inputs.task_done()
            break

        niter = batch['niter']
        ep = batch['epoch']

        # default params for network input
        net_in = cnn.prepare_net_input(batch)

        # evaluate current network performance on mini-batch
        if batch['logging']:
            summary_image = sess.run(cnn.merged_images, feed_dict=net_in)
            writerTensorboard.add_summary(summary_image, niter)

            print()
            sys.stdout.write('  dataset(%d:%s) ep(%d) batch(%d) : ' %
                             (batch['nfeed'], batch['feed_id'], ep, niter))

            count = count + 1.0
            fields = [
                '%s' % (datetime.datetime.now()), batch['feed_id'],
                batch['nfeed'], niter, ep
            ]

            # per-decoder stage-2 losses ('id' renamed: shadowed builtin)
            for dec_id in cnn.decoders_3D:
                loss = sess.run(cnn.decoders_3D[dec_id]['loss_s2'],
                                feed_dict=net_in)
                sys.stdout.write('  %s %g   ' % (dec_id, loss))
                fields.append(dec_id)
                fields.append(loss)

            # one summary per runnable minimizer.  BUGFIX: the original
            # called add_summary unconditionally, so a minimizer id matching
            # none of s2/s4/GAN referenced an unbound (or stale) 'summary';
            # now such minimizers are skipped.
            for min_id in cnn.minimizers:
                requires = cnn.minimizers[min_id]['requires']
                if not all('stacks_v_' + r in batch for r in requires):
                    continue
                if min_id == 'GAN':
                    summary = sess.run(cnn.merged_gan, feed_dict=net_in)
                elif min_id.endswith('s2'):
                    summary = sess.run(cnn.merged_s2, feed_dict=net_in)
                elif min_id.endswith('s4'):
                    summary = sess.run(cnn.merged_s4, feed_dict=net_in)
                else:
                    continue  # no summary associated with this minimizer
                writerTensorboard.add_summary(summary, niter)

            # append one CSV row per logging step
            with open(model_path + batch['logfile'], 'a+') as f:
                csv.writer(f).writerow(fields)

            print('')

        # epochs take too long: checkpoint every save_interval iterations
        # (only on the first feed, only while training); this variant keeps
        # one numbered checkpoint per save
        if (batch['niter'] % hp.training['save_interval'] == 0
                and niter != 0 and batch['nfeed'] == 0
                and batch['training']):
            save_path = saver.save(
                sess, model_path + 'model_' + str(niter) + '.ckpt')
            print('NEXT EPOCH')
            print("  model saved in file: %s" % save_path)
            # reset per-epoch statistics
            count = 0.0

        # run training step
        if batch['training']:
            net_in[cnn.phase] = True
            sys.stdout.write('.')
            if train_gan:
                # discriminator step, only if every required channel
                # ('stacks_v_' + suffix) is present in the batch
                requires = cnn.minimizers['GAN']['requires']
                if all('stacks_v_' + r in batch for r in requires):
                    sys.stdout.write(cnn.minimizers['GAN']['id'] + ' ')
                    sess.run(cnn.minimizers['GAN']['train_step'],
                             feed_dict=net_in)
                counter += 1
                # >= (was ==) so a misconfigured iterGan < 1 cannot lock
                # training into discriminator-only mode forever
                if counter >= iterGan:
                    train_gan = False
            else:
                # generator steps
                for min_id in cnn.minimizers:
                    if min_id == 'GAN':
                        continue
                    requires = cnn.minimizers[min_id]['requires']
                    if all('stacks_v_' + r in batch for r in requires):
                        sys.stdout.write(cnn.minimizers[min_id]['id'] + ' ')
                        sess.run(cnn.minimizers[min_id]['train_step'],
                                 feed_dict=net_in)
                if use_gan:
                    # hand control back to the discriminator
                    train_gan = True
                    counter = 0
            sys.stdout.flush()

        inputs.task_done()