Example #1
0
def train_wrapper(model):
    """Train ``model`` with scheduled sampling, saving and testing periodically.

    If ``FLAGS.pretrained_model`` is set, the model is first restored from it.
    Iterations run from 1 to ``FLAGS.max_iterations`` inclusive. The training
    handle is restarted with shuffling whenever it has no batch left. A
    snapshot is saved every ``FLAGS.snapshot_interval`` iterations, and the
    model is tested every ``FLAGS.test_interval`` iterations.
    """
    if FLAGS.pretrained_model:
        model.load(FLAGS.pretrained_model)

    # A single provider call yields both the training and the test handle;
    # the effective batch covers every GPU.
    total_batch = FLAGS.batch_size * FLAGS.n_gpu
    train_handle, test_handle = datasets_factory.data_provider(
        FLAGS.dataset_name,
        FLAGS.train_data_paths,
        FLAGS.valid_data_paths,
        total_batch,
        FLAGS.img_width,
        seq_length=FLAGS.total_length,
        is_training=True,
    )

    eta = FLAGS.sampling_start_value

    for step in range(1, FLAGS.max_iterations + 1):
        # Start a fresh shuffled pass once the current one is exhausted.
        if train_handle.no_batch_left():
            train_handle.begin(do_shuffle=True)

        batch = train_handle.get_batch()
        # For the 'penn' dataset the batch is indexed by the 'frame' key.
        if FLAGS.dataset_name == 'penn':
            batch = batch['frame']
        patches = preprocess.reshape_patch(batch, FLAGS.patch_size)

        eta, real_input_flag = schedule_sampling(eta, step)
        trainer.train(model, patches, real_input_flag, FLAGS, step)

        if not step % FLAGS.snapshot_interval:
            model.save(step)
        if not step % FLAGS.test_interval:
            trainer.test(model, test_handle, FLAGS, step)

        train_handle.next()
Example #2
0
def test_wrapper(model):
    """Restore ``model`` from ``FLAGS.pretrained_model`` and run the test routine.

    The data provider is built with ``is_training=False``, and the results are
    tagged with ``'test_result'``.
    """
    model.load(FLAGS.pretrained_model)
    provider_args = (
        FLAGS.dataset_name,
        FLAGS.train_data_paths,
        FLAGS.valid_data_paths,
        FLAGS.batch_size * FLAGS.n_gpu,
        FLAGS.img_width,
    )
    test_handle = datasets_factory.data_provider(*provider_args, is_training=False)
    trainer.test(model, test_handle, FLAGS, 'test_result')
Example #3
0
def train_wrapper(model, start_itr=2351):
    """Train ``model`` with scheduled sampling, periodic validation and early stopping.

    A checkpoint round runs every 50 iterations, and also whenever the
    training handle has no batch left. Each round:
      * saves the model,
      * prints a timestamp and the mean training loss since the previous round,
      * evaluates on the handle built from ``FLAGS.valid_data_paths``,
      * restarts the training handle with shuffling.
    Training stops once 9 consecutive evaluations fail to improve on the best
    validation cost seen so far.

    Args:
        model: object with ``load``/``save``; also passed to ``trainer.train``
            and ``trainer.test``.
        start_itr: first iteration number of the loop. The default of 2351
            keeps the previously hard-coded value. NOTE(review): it looks like
            a resume point for a specific checkpoint; pass 1 for a fresh run.
    """
    if FLAGS.pretrained_model:
        model.load(FLAGS.pretrained_model)

    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name,
        FLAGS.train_data_paths,
        FLAGS.valid_data_paths,
        FLAGS.batch_size * FLAGS.n_gpu,
        FLAGS.img_width,
        FLAGS.input_seq_length,
        FLAGS.output_seq_length,
        FLAGS.dimension_3D,
        is_training=True)
    print('Data loaded.')
    eta = FLAGS.sampling_start_value

    eval_every = 50  # iterations between checkpoint/validation rounds
    patience = 9     # consecutive non-improving evaluations before stopping

    tra_cost = 0.0   # summed training loss since the last round
    batch_id = 0     # number of batches trained since the last round
    best_val_cost = float('inf')
    num_bad_evals = 0

    for itr in range(start_itr, FLAGS.max_iterations + 1):
        # Announce the first iteration that actually runs.
        if itr == start_itr:
            print('training process started.')

        if train_input_handle.no_batch_left() or itr % eval_every == 0:
            model.save(itr)
            print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                  'itr: ' + str(itr))
            # Skip the average when nothing was trained since the last round;
            # dividing by batch_id == 0 would raise ZeroDivisionError.
            if batch_id > 0:
                print('training loss: ' + str(tra_cost / batch_id))

            val_cost = trainer.test(model, test_input_handle, FLAGS, itr)
            if val_cost < best_val_cost:
                best_val_cost = val_cost
                num_bad_evals = 0
            else:
                num_bad_evals += 1
            if num_bad_evals >= patience:
                break

            # NOTE(review): this restarts the pass at every 50-iteration round,
            # even when batches remain. Confirm that this is intended.
            train_input_handle.begin(do_shuffle=True)
            tra_cost = 0.0
            batch_id = 0

        ims = train_input_handle.get_batch()
        batch_id += 1

        eta, real_input_flag = schedule_sampling(eta, itr)

        tra_cost += trainer.train(model, ims, real_input_flag, FLAGS, itr)

        train_input_handle.next_batch()
Example #4
0
def test_wrapper(model):
    """Restore ``model`` from ``FLAGS.pretrained_model`` and evaluate it on the test split.

    The data provider receives ``FLAGS.test_data_paths`` as its second path
    argument and ``is_training=False``. The results are tagged with
    ``'test_result'``.
    """
    model.load(FLAGS.pretrained_model)
    total_batch = FLAGS.batch_size * FLAGS.n_gpu
    test_handle = datasets_factory.data_provider(
        FLAGS.dataset_name,
        FLAGS.train_data_paths,
        # Evaluate on held-out test data, not on the training or validation data.
        FLAGS.test_data_paths,
        total_batch,
        FLAGS.img_width,
        FLAGS.input_seq_length,
        FLAGS.output_seq_length,
        FLAGS.dimension_3D,
        is_training=False,
    )
    trainer.test(model, test_handle, FLAGS, 'test_result')