예제 #1
0
def main(args):
    """Evaluate a model on the validation split of ``args.dataset``.

    Two modes are supported:

    * ``args.embed`` is truthy: nothing is restored up front. An IPython
      shell is opened with a ``go(checkpoint_path)`` helper in scope, so
      several checkpoints can be evaluated against the already-loaded data.
    * otherwise: ``args.checkpoint_path`` is restored and evaluated once.

    The model is stored in the module-level global ``model``.

    Raises:
        ValueError: if neither ``args.embed`` nor ``args.checkpoint_path``
            is set.
    """
    global model

    if not (args.embed or args.checkpoint_path):
        raise ValueError('checkpoint_path needed')

    # Build the model.
    model = load_model(args)
    # The local names below (B, data_sets, go) are kept as-is because
    # IPython's embed() exposes the caller's locals to the interactive user.
    B = model.batch_size
    log.infov('Batch Size : %d', B)

    # Load the validation split only.
    log.info("Data loading start !")
    data_sets = crc_input_data_seq.read_crc_data_sets(
        98, 98, 49, 49,
        np.float32,
        use_cache=False,
        parallel_jobs=10,
        dataset=args.dataset,
        split_modes=['valid'],
        # NOTE(review): the original author flagged this option as
        # surprising; its exact effect is defined by the reader.
        fixation_original_scale=True,
        # Marked by the original author as a shortcut for fast debugging
        # and testing.
        max_folders=500,
    )
    log.info("Data loading done")

    if not args.embed:
        # One-shot mode: restore the requested checkpoint and evaluate it.
        reload_checkpoint(model, args.checkpoint_path)
        out_dir = get_out_dir(args.dataset, args.checkpoint_path)
        run_evaluation(model, args, data_sets, out_dir)
        return

    def go(checkpoint_path):
        """Restore ``checkpoint_path`` into the model and evaluate it."""
        import gc
        reload_checkpoint(model, checkpoint_path)
        out_dir = get_out_dir(args.dataset, checkpoint_path)
        print('go at %s' % checkpoint_path)
        run_evaluation(model, args, data_sets, out_dir)
        # Force a collection after each run.
        gc.collect()

    log.infov('Usage: >>> go(checkpoint_path)')
    # Imported lazily so IPython is only needed in interactive mode.
    from IPython import embed
    embed()
# Script fragment: builds the Gaze_Prediction_Module graph, loads the CRC
# dataset, then creates the session, the checkpoint saver and the Adam
# training op. The hyper-parameters (batch_size, n_lstm_steps, dim_*,
# out_proj, learning_rate) and FLAGS are defined outside this fragment.
model = Gaze_Prediction_Module(batch_size, n_lstm_steps, dim_feature,
                               dim_hidden_u, dim_hidden_b, dim_sal,
                               dim_sal_proj, dim_cnn, dim_cnn_proj, out_proj)

# Earlier signature of build_model(), kept by the original author for
# reference (it additionally returned tf_saliency and tf_cnnmask):
#tf_loss, tf_saliency, tf_cnn, tf_cnnmask, tf_gazemap, tf_probs = model.build_model()
# Training graph: the four results are unpacked as loss, cnn, gazemap, probs.
tf_loss, tf_cnn, tf_gazemap, tf_probs = model.build_model()
# Generation graph: the three results are unpacked as cnn, gaze, pupil.
cnn_tf, gaze_output, pupil_output = model.build_generator()

print 'Prediction RNN Model builded '

print 'Dataset Loading'
# NOTE(review): this call passes tf.float32 as the dtype argument while the
# other read_crc_data_sets() callers in this file pass np.float32 -- confirm
# which one the reader expects.
crc_data_sets = crc_input_data_seq.read_crc_data_sets(FLAGS.image_height,
                                                      FLAGS.image_width,
                                                      FLAGS.gazemap_height,
                                                      FLAGS.gazemap_width,
                                                      tf.float32,
                                                      reload_=True,
                                                      dataset='crc')

# Alternative dataset ('hollywood2', without reload), left disabled:
#crc_data_sets = crc_input_data_seq.read_crc_data_sets(FLAGS.image_height,
#                                                      FLAGS.image_width,
#                                                      FLAGS.gazemap_height,
#                                                      FLAGS.gazemap_width,
#                                                      tf.float32,
#                                                      reload_ = False,
#                                                      dataset='hollywood2')
print 'Dataset setting finished'
sess = tf.InteractiveSession()
# NOTE(review): the Saver is constructed BEFORE the optimizer below. With no
# explicit var_list, a Saver presumably captures only the variables that
# exist at construction time, so variables created by minimize() would not be
# checkpointed. Do not reorder without checking compatibility with existing
# checkpoints -- TODO confirm.
saver = tf.train.Saver(max_to_keep=100)
train_op = tf.train.AdamOptimizer(learning_rate).minimize(tf_loss)
예제 #3
0
def self_test(args):
    """Train a ``SaliencyModel`` on ``args.dataset`` and evaluate it.

    Loads either the 'salicon' or the 'crc' dataset, builds a
    ``BaseModelConfig`` from built-in defaults plus command-line overrides,
    fits the model and finally calls ``generate_and_evaluate`` on the test
    split. The model and the datasets are stored in the module-level globals
    ``model`` and ``data_sets``.

    Raises:
        ValueError: if ``args.dataset`` is neither 'salicon' nor 'crc'.
    """
    global model, data_sets

    # NOTE(review): the original comment here said "NO GPU, USE CPU", but
    # device_count={'GPU': True} is the same as {'GPU': 1}, which (per the
    # ConfigProto docs) allows one GPU device. Half of the GPU memory is
    # reserved for this process.
    gpu_opts = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
    session_conf = tf.ConfigProto(
        gpu_options=gpu_opts,
        device_count={'GPU': True},
    )
    session = tf.Session(config=session_conf)

    # ---- data ------------------------------------------------------------
    log.warn('Loading %s input data ...', args.dataset)
    if args.dataset == 'crc':
        data_sets = crc_input_data.read_crc_data_sets(
            98, 98, 49, 49, np.float32, use_cache=True)
    elif args.dataset == 'salicon':
        data_sets = salicon_input_data.SaliconData(
            98, 98, 49, 49,
            np.float32,
            use_example=False,  # marked "only tens" by the original author
            use_val_split=True,
        ).build()
    else:
        raise ValueError('Unknown dataset : %s' % args.dataset)

    print('%s %s' % ('Train', data_sets.train))
    print('%s %s' % ('Validation', data_sets.valid))

    # ---- configuration ---------------------------------------------------
    log.warn('Building Model ...')
    config = BaseModelConfig()
    config.train_dir = args.train_dir
    if args.train_tag:
        config.train_tag = args.train_tag

    # Built-in defaults.
    config.batch_size = 128
    config.use_flip_batch = True
    config.initial_learning_rate = 0.00003
    config.optimization_method = 'adam'
    config.steps_per_evaluation = 7000  # debugging value per the author

    # Command-line overrides: (args attribute, config attribute, converter).
    cli_overrides = (
        ('learning_rate', 'initial_learning_rate', float),
        ('learning_rate_decay', 'learning_rate_decay', float),
        ('batch_size', 'batch_size', int),
    )
    for arg_name, config_attr, convert in cli_overrides:
        raw_value = getattr(args, arg_name)
        if raw_value is not None:
            setattr(config, config_attr, convert(raw_value))

    if args.max_steps:
        config.max_steps = int(args.max_steps)

    # CRC-specific settings are applied last, so they take precedence over
    # any batch size given on the command line.
    if args.dataset == 'crc':
        config.batch_size = 2  # original note: "because of T~=35"
        config.steps_per_evaluation = 200

    config.dump(sys.stdout)

    # ---- fit & evaluate --------------------------------------------------
    log.warn('Start Fitting Model ...')
    model = SaliencyModel(session, data_sets, config)
    print(model)

    model.fit()

    log.warn('Fitting Done. Evaluating!')
    model.generate_and_evaluate(data_sets.test)
예제 #4
0
def self_test(args):
    """Fit a ``GazePredictionGRCN`` on ``args.dataset`` and evaluate it.

    Builds a ``GRUModelConfig`` from built-in defaults plus command-line
    overrides, loads the dataset ('crc' or 'hollywood2'), optionally
    initialises the model from a pretrained ShallowNet checkpoint, fits it
    and calls ``evaluate`` on the test split. The model and the config are
    stored in the module-level globals ``model`` and ``config``.
    """
    global model, config

    assert 0.0 < args.gpu_fraction <= 1.0

    # NOTE(review): the original comment here said "NO GPU, USE CPU", but
    # device_count={'GPU': True} is the same as {'GPU': 1}, which (per the
    # ConfigProto docs) allows one GPU device.
    gpu_opts = tf.GPUOptions(
        per_process_gpu_memory_fraction=args.gpu_fraction)
    session_conf = tf.ConfigProto(
        gpu_options=gpu_opts,
        device_count={'GPU': True},
    )
    session = tf.Session(config=session_conf)

    # ---- configuration ---------------------------------------------------
    log.warn('Building Model ...')
    config = GRUModelConfig()

    # Default batch size; the original author notes CRC works well with 28.
    config.batch_size = 28

    config.train_dir = args.train_dir
    if args.train_tag:
        config.train_tag = args.train_tag
    config.initial_learning_rate = 0.0001
    config.max_grad_norm = 1.0

    # Command-line overrides: (args attribute, config attribute, converter).
    cli_overrides = (
        ('learning_rate', 'initial_learning_rate', float),
        ('learning_rate_decay', 'learning_rate_decay', float),
        ('batch_size', 'batch_size', int),
    )
    for arg_name, config_attr, convert in cli_overrides:
        raw_value = getattr(args, arg_name)
        if raw_value is not None:
            setattr(config, config_attr, convert(raw_value))

    config.steps_per_evaluation = 500
    config.steps_per_validation = 50  # value questioned by the author
    config.steps_per_checkpoint = 500

    if args.max_steps:
        config.max_steps = int(args.max_steps)
    if args.dataset == 'crc':
        config.steps_per_evaluation = 100

    config.dump(sys.stdout)

    # ---- data ------------------------------------------------------------
    log.warn('Dataset (%s) Loading ...', args.dataset)
    assert args.dataset in ('crc', 'hollywood2')
    reader_options = dict(use_cache=True, dataset=args.dataset)
    data_sets = crc_input_data_seq.read_crc_data_sets(
        CONSTANTS.image_height,
        CONSTANTS.image_width,
        CONSTANTS.gazemap_height,
        CONSTANTS.gazemap_width,
        np.float32,
        **reader_options)
    log.warn('Dataset Loading Finished ! (%d instances)', len(data_sets))

    # ---- fit & evaluate --------------------------------------------------
    log.warn('Start Fitting Model ...')
    model = GazePredictionGRCN(session, data_sets, config)
    print(model)

    pretrained_ckpt = args.shallownet_pretrain
    if pretrained_ckpt is not None:
        log.warn('Loading ShallowNet weights from checkpoint %s ...',
                 pretrained_ckpt)
        model.initialize_pretrained_shallownet(pretrained_ckpt)

    model.fit()

    log.warn('Fitting Done. Evaluating!')
    model.evaluate(data_sets.test)
예제 #5
0
def train(args):
    """Train the gaze-prediction model selected by ``args.model``.

    Steps, in order: create a TensorFlow session, import the model class
    together with its ``CONSTANTS`` / ``GRUModelConfig`` from the module that
    matches ``args.model``, build the config (defaults plus command-line
    overrides), load the dataset, fit the model, and finally call
    ``generate_and_evaluate`` on the test split. The model and the config
    are stored in the module-level globals ``model`` and ``config``.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``gpu_fraction``, ``model``, ``dataset``, ``train_dir``,
            ``train_tag``, ``max_grad_norm``, ``learning_rate``,
            ``learning_rate_decay``, ``batch_size``, ``loss_type``,
            ``max_steps``, ``batch_norm``, ``shallownet_pretrain``.

    Raises:
        AssertionError: if ``args.gpu_fraction`` is not in (0, 1] or
            ``args.dataset`` is not one of 'crc', 'hollywood2', 'crcxh2'.
        NotImplementedError: if ``args.model`` is not a known model name.
    """
    global model, config

    assert 0.0 < args.gpu_fraction <= 1.0
    session = tf.Session(config=tf.ConfigProto(
        gpu_options=tf.GPUOptions(
            per_process_gpu_memory_fraction=args.gpu_fraction,
            allow_growth=True),
        # True == 1: at most one GPU device is made visible. (An earlier
        # comment here claimed CPU-only, which did not match the code.)
        device_count={'GPU': True},
    ))

    log.warn('Building Model ...')

    log.infov('MODEL   : %s', args.model)
    log.infov('DATASET : %s', args.dataset)
    # Each branch imports the model class and its CONSTANTS / config class
    # from the SAME module. Imports stay local so that only the selected
    # model's module is loaded.
    if args.model == 'gaze_grcn':
        from gaze_grcn import GazePredictionGRCN as TheModel
        from gaze_grcn import CONSTANTS, GRUModelConfig
    elif args.model == 'gaze_lstm':
        from gaze_lstm import GazePredictionLSTM as TheModel
        from gaze_lstm import CONSTANTS, GRUModelConfig
    elif args.model == 'gaze_grcn77':
        from gaze_grcn77 import GazePredictionGRCN as TheModel
        from gaze_grcn77 import CONSTANTS, GRUModelConfig
    elif args.model == 'gaze_rnn77':
        from gaze_rnn77 import GazePredictionGRU as TheModel
        # Fixed: this line previously imported from 'gaze_rcn77', the only
        # branch whose second import named a different module than the first.
        from gaze_rnn77 import CONSTANTS, GRUModelConfig
    elif args.model == 'gaze_rnn':
        from gaze_rnn import GazePredictionGRU as TheModel
        from gaze_rnn import CONSTANTS, GRUModelConfig
    elif args.model == 'gaze_c3d_conv':
        from gaze_c3d_conv import GazePredictionConv as TheModel
        from gaze_c3d_conv import CONSTANTS, GRUModelConfig
    elif args.model == 'gaze_shallownet_rnn':
        from gaze_shallownet_rnn import GazePredictionGRU as TheModel
        from gaze_shallownet_rnn import CONSTANTS, GRUModelConfig
    elif args.model == 'gaze_framewise_shallownet':
        from gaze_framewise_shallownet import FramewiseShallowNet as TheModel
        from gaze_framewise_shallownet import CONSTANTS, GRUModelConfig
    elif args.model == 'gaze_deeprnn':
        from gaze_rnn_deep import DEEPRNN as TheModel
        from gaze_rnn_deep import CONSTANTS, GRUModelConfig
    else:
        raise NotImplementedError(args.model)

    # Default configuration as of now.
    config = GRUModelConfig()

    # Default batch size; the original author notes CRC works well with 28.
    config.batch_size = 28

    config.train_dir = args.train_dir
    if args.train_tag:
        config.train_tag = args.train_tag
    config.initial_learning_rate = 0.0001
    config.max_grad_norm = 10.0
    config.use_flip_batch = True

    # Command-line overrides of the defaults above.
    # NOTE(review): max_grad_norm and loss_type are taken as-is, while the
    # other overrides are converted with float()/int().
    if args.max_grad_norm is not None:
        config.max_grad_norm = args.max_grad_norm
    if args.learning_rate is not None:
        config.initial_learning_rate = float(args.learning_rate)
    if args.learning_rate_decay is not None:
        config.learning_rate_decay = float(args.learning_rate_decay)
    if args.batch_size is not None:
        config.batch_size = int(args.batch_size)
    if args.loss_type is not None:
        config.loss_type = args.loss_type

    config.steps_per_evaluation = 100
    config.steps_per_validation = 20
    config.steps_per_checkpoint = 100

    if args.max_steps:
        config.max_steps = int(args.max_steps)

    config.dump(sys.stdout)

    log.warn('Dataset (%s) Loading ...', args.dataset)
    assert args.dataset in ('crc', 'hollywood2', 'crcxh2')
    data_sets = crc_input_data_seq.read_crc_data_sets(
        CONSTANTS.image_height,
        CONSTANTS.image_width,
        CONSTANTS.gazemap_height,
        CONSTANTS.gazemap_width,
        np.float32,
        use_cache=True,
        batch_norm=args.batch_norm,
        dataset=args.dataset)

    log.warn('Dataset Loading Finished ! (%d instances)', len(data_sets))

    log.warn('Start Fitting Model ...')
    model = TheModel(session, data_sets, config)

    print(model)

    if args.shallownet_pretrain is not None:
        log.warn('Loading ShallowNet weights from checkpoint %s ...',
                 args.shallownet_pretrain)
        model.initialize_pretrained_shallownet(args.shallownet_pretrain)

    model.fit()

    log.warn('Fitting Done. Evaluating!')
    # NOTE(review): the original author could not locate the definition of
    # generate_and_evaluate -- confirm every selectable model class provides
    # it before relying on this final step.
    model.generate_and_evaluate(
        data_sets.test,
        max_instances=None)
# Script fragment: training setup for Gaze_Prediction_Module. The
# hyper-parameters (batch_size, n_lstm_steps, dim_*, out_proj,
# learning_rate), FLAGS, from_dir and model_name are defined outside this
# fragment.

# When True, weights are restored from from_dir + model_name (see below).
reload_ = False
n_epochs = 1000

model = Gaze_Prediction_Module(batch_size, n_lstm_steps, dim_feature,
                               dim_hidden_u, dim_hidden_b, dim_sal,
                               dim_sal_proj, dim_cnn, dim_cnn_proj, out_proj)

# Earlier signature of build_model(), kept by the original author for
# reference (it additionally returned tf_saliency and tf_cnnmask):
#tf_loss, tf_saliency, tf_cnn, tf_cnnmask, tf_gazemap, tf_probs = model.build_model()
tf_loss, tf_cnn, tf_gazemap, tf_probs = model.build_model()

print 'Prediction RNN Model builded '

print 'Dataset Loading'
# NOTE(review): the trailing positional True depends on the parameter order
# of read_crc_data_sets(), which is not visible here -- confirm which option
# it sets. This call also passes tf.float32 where other callers in this file
# pass np.float32.
data_sets = crc_input_data_seq.read_crc_data_sets(FLAGS.image_height,
                                                  FLAGS.image_width,
                                                  FLAGS.gazemap_height,
                                                  FLAGS.gazemap_width,
                                                  tf.float32, True)
print 'Dataset setting finished'

sess = tf.InteractiveSession()
# NOTE(review): the Saver is constructed BEFORE the optimizer below. With no
# explicit var_list, a Saver presumably captures only the variables that
# exist at construction time, so variables created by minimize() would not be
# checkpointed. Do not reorder without checking compatibility with existing
# checkpoints -- TODO confirm.
saver = tf.train.Saver(max_to_keep=100)
train_op = tf.train.AdamOptimizer(learning_rate).minimize(tf_loss)

#init = tf.initialize_all_variables()
# .run() is called without an explicit session; presumably it relies on the
# InteractiveSession created above being the default session.
# NOTE(review): tf.initialize_all_variables() was later deprecated in favour
# of tf.global_variables_initializer() -- confirm the TensorFlow version in
# use before changing it.
tf.initialize_all_variables().run()

offset = 0
# Not taken as written, because reload_ is set to False at the top of this
# fragment. When enabled, the restore runs after the initialisation above.
if reload_:
    print 'Realod ' + from_dir + model_name
    saver.restore(sess, from_dir + model_name)
    #labels_dict_test = OrderedDict()
    ## PROBS MAKE TEST LONGER THAN 1 ....
    #test_folders = labels_dict_val.keys[-1]
    #labels_dict_test[labels_dict_val.keys[-1]] = labels_dict_val[labels_dict_val.keys[-1]]

    #import pdb; pdb.set_trace()
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--attention', action='store_true', default=False)
    args = parser.parse_args()
    dataset = crc_input_data_seq.read_crc_data_sets(
        98,
        98,
        49,
        49,
        np.float32,
        use_cache=False,
        parallel_jobs=10,
        dataset='hollywood2',
        with_attention=args.attention,
        fixation_original_scale=False,  # WTF
        # XXX FOR FAST DEBUGGING AND TESTING
    )

    model = load_model()

    ### loading train data
    gazemap_pred, gazemap_gt, frame, labels, c3d, clipnames = load_data(
        model, dataset.train, labels_dict_train)
    if args.attention:
        attention = '_attention'
    else:
def visualize_outputs_wrapper(checkpoint_path,
                              session=None,
                              split_mode='valid',
                              dataset='hollywood2',
                              model_class=GazePredictionGRU,
                              max_instances=100):
    """Rebuild a model, optionally restore a checkpoint, and visualize outputs.

    Args:
        checkpoint_path: checkpoint file to restore, or None. When given, the
            config is loaded from ``../config.json`` relative to the
            checkpoint's directory; when None a default ``GRUModelConfig``
            is used and no weights are restored.
        session: TensorFlow session to use; a new ``tf.InteractiveSession``
            is created when None.
        split_mode: 'valid' or 'train'; selects the split that is loaded and
            visualized.
        dataset: 'crc' or 'hollywood2'.
        model_class: class instantiated as
            ``model_class(session, data_sets, config)``.
        max_instances: forwarded to ``visualize_outputs_of``.

    Returns:
        Tuple ``(model, model_outputs)``.

    Raises:
        AssertionError: if ``dataset`` is not supported, or if
            ``checkpoint_path`` is given but is not an existing file.
        ValueError: if ``split_mode`` is neither 'valid' nor 'train'.
            NOTE(review): this is only detected after the data has been
            loaded and the model has been built.
    """
    if session is None:
        session = tf.InteractiveSession()

    restore_requested = checkpoint_path is not None

    # ---- configuration ---------------------------------------------------
    if restore_requested:
        # The JSON config lives one level above the checkpoint directory.
        # (An older pickle-based config loader was disabled by the author.)
        checkpoint_dir = os.path.dirname(checkpoint_path)
        config_path = os.path.join(checkpoint_dir, '../config.json')
        config = BaseModelConfig.load(config_path)
        log.info('Loaded config from %s', config_path)
    else:
        # No checkpoint: fall back to the default configuration.
        config = GRUModelConfig()

    # Clear train_dir so that the original training directory is not
    # polluted by this run.
    config.train_dir = None
    config.dump(sys.stdout)

    # ---- data ------------------------------------------------------------
    log.warn('Dataset (%s) Loading ...', dataset)
    assert dataset in ('crc', 'hollywood2')

    # Image/gazemap sizes are taken from the gaze_rnn model module (flagged
    # as a hack by the original author).
    from models.gaze_rnn import CONSTANTS
    reader_options = dict(use_cache=True,
                          dataset=dataset,
                          split_modes=[split_mode])
    data_sets = crc_input_data_seq.read_crc_data_sets(
        CONSTANTS.image_height,
        CONSTANTS.image_width,
        CONSTANTS.gazemap_height,
        CONSTANTS.gazemap_width,
        np.float32,
        **reader_options)

    # ---- model -----------------------------------------------------------
    # The model is built in the current default graph; the author left open
    # TODOs about graph handling and removing hard-coded values.
    model = model_class(session, data_sets, config)

    if restore_requested:
        assert os.path.isfile(checkpoint_path)
        model.load_model_from_checkpoint_file(checkpoint_path)

    # ---- run -------------------------------------------------------------
    if split_mode == 'valid':
        selected_split = data_sets.valid
    elif split_mode == 'train':
        selected_split = data_sets.train
    else:
        raise ValueError(split_mode)

    model_outputs = visualize_outputs_of(model, selected_split,
                                         max_instances)
    return model, model_outputs