def train_wrapper(model):
    """Train ``model`` for ``FLAGS.max_iterations`` steps with scheduled sampling.

    Optionally restores weights from ``FLAGS.pretrained_model``. Iterations are
    numbered from 1 through ``FLAGS.max_iterations``. Every
    ``FLAGS.snapshot_interval`` steps the model is saved, and every
    ``FLAGS.test_interval`` steps ``trainer.test`` runs on the evaluation
    handle.

    NOTE(review): ``train_wrapper`` is defined again later in this module; that
    later ``def`` rebinds the name, so this version is shadowed.
    """
    if FLAGS.pretrained_model:
        model.load(FLAGS.pretrained_model)

    train_handle, test_handle = datasets_factory.data_provider(
        FLAGS.dataset_name,
        FLAGS.train_data_paths,
        FLAGS.valid_data_paths,
        FLAGS.batch_size * FLAGS.n_gpu,
        FLAGS.img_width,
        seq_length=FLAGS.total_length,
        is_training=True,
    )

    eta = FLAGS.sampling_start_value
    for step in range(1, FLAGS.max_iterations + 1):
        # Start a freshly shuffled pass whenever the current one is exhausted.
        if train_handle.no_batch_left():
            train_handle.begin(do_shuffle=True)

        batch = train_handle.get_batch()
        # The 'penn' provider yields a mapping; the frames live under 'frame'.
        if FLAGS.dataset_name == 'penn':
            batch = batch['frame']
        patches = preprocess.reshape_patch(batch, FLAGS.patch_size)

        eta, real_input_flag = schedule_sampling(eta, step)
        trainer.train(model, patches, real_input_flag, FLAGS, step)

        if step % FLAGS.snapshot_interval == 0:
            model.save(step)
        if step % FLAGS.test_interval == 0:
            trainer.test(model, test_handle, FLAGS, step)

        train_handle.next()
def test_wrapper(model):
    """Restore ``model`` from ``FLAGS.pretrained_model`` and evaluate it.

    Builds a data handle with ``is_training=False`` from the validation paths.
    It then calls ``trainer.test`` with the tag ``'test_result'``.

    NOTE(review): ``test_wrapper`` is defined again later in this module; that
    later ``def`` rebinds the name, so this version is shadowed.
    """
    model.load(FLAGS.pretrained_model)
    handle = datasets_factory.data_provider(
        FLAGS.dataset_name,
        FLAGS.train_data_paths,
        FLAGS.valid_data_paths,
        FLAGS.batch_size * FLAGS.n_gpu,
        FLAGS.img_width,
        is_training=False,
    )
    trainer.test(model, handle, FLAGS, 'test_result')
def train_wrapper(model, start_itr=2351, eval_interval=50, patience=10):
    """Train ``model`` with periodic checkpointing, validation and early stopping.

    Each iteration does the following:

    * draws one batch from the training handle;
    * advances the scheduled-sampling state via ``schedule_sampling``;
    * runs one ``trainer.train`` step, accumulating the value it returns
      (presumably the per-batch training loss; TODO confirm).

    A reporting step runs whenever the current pass has no batches left, or
    ``itr`` is a multiple of ``eval_interval``. It does the following:

    * saves the model;
    * prints the mean accumulated training value;
    * runs ``trainer.test`` on the validation handle (lower is treated as
      better);
    * starts a new shuffled pass.

    Note that this restarts the pass every ``eval_interval`` iterations even
    if batches remain.

    Args:
      model: object exposing ``load(path)`` and ``save(itr)``.
      start_itr: first iteration number (the loop runs through
        ``FLAGS.max_iterations`` inclusive). Defaults to 2351, the value that
        was previously hard-coded. NOTE(review): this looks like a resume
        point following a checkpoint saved at itr 2350; pass 1 for a fresh
        run.
      eval_interval: save/validate every this many iterations, in addition to
        the end of every pass.
      patience: training stops once the best validation cost plus
        ``patience - 1`` later non-improving costs have been recorded.
    """
    if FLAGS.pretrained_model:
        model.load(FLAGS.pretrained_model)

    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name,
        FLAGS.train_data_paths,
        FLAGS.valid_data_paths,
        FLAGS.batch_size * FLAGS.n_gpu,
        FLAGS.img_width,
        FLAGS.input_seq_length,
        FLAGS.output_seq_length,
        FLAGS.dimension_3D,
        is_training=True,
    )
    print('Data loaded.')

    eta = FLAGS.sampling_start_value
    tra_cost = 0.0  # sum of trainer.train results since the last report
    batch_id = 0    # number of batches trained since the last report
    # val_costs[0] is the best validation cost so far; any further entries are
    # later evaluations that failed to beat it (reset on every improvement).
    val_costs = [float('inf')]

    for itr in range(start_itr, FLAGS.max_iterations + 1):
        # Fires once the first iteration has completed (previously `itr == 2`,
        # which could never match when starting from 2351).
        if itr == start_itr + 1:
            print('training process started.')

        if train_input_handle.no_batch_left() or itr % eval_interval == 0:
            model.save(itr)
            print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                  'itr: ' + str(itr))
            # Guard: this branch can be hit before any batch was trained
            # (e.g. on the very first iteration), which used to raise
            # ZeroDivisionError.
            if batch_id:
                print('training loss: ' + str(tra_cost / batch_id))

            val_cost = trainer.test(model, test_input_handle, FLAGS, itr)
            if val_cost < val_costs[0]:
                val_costs = [val_cost]
            else:
                val_costs.append(val_cost)
                if len(val_costs) >= patience:
                    break

            train_input_handle.begin(do_shuffle=True)
            tra_cost = 0.0
            batch_id = 0

        ims = train_input_handle.get_batch()
        batch_id += 1
        eta, real_input_flag = schedule_sampling(eta, itr)
        tra_cost += trainer.train(model, ims, real_input_flag, FLAGS, itr)
        train_input_handle.next_batch()
def test_wrapper(model):
    """Restore ``model`` from ``FLAGS.pretrained_model`` and evaluate it.

    Builds a data handle with ``is_training=False`` over ``FLAGS.test_data_paths``.
    It then calls ``trainer.test`` with the tag ``'test_result'``.
    """
    model.load(FLAGS.pretrained_model)
    global_batch = FLAGS.batch_size * FLAGS.n_gpu
    handle = datasets_factory.data_provider(
        FLAGS.dataset_name,
        FLAGS.train_data_paths,
        # Evaluate on the held-out test paths, not training/validation data.
        FLAGS.test_data_paths,
        global_batch,
        FLAGS.img_width,
        FLAGS.input_seq_length,
        FLAGS.output_seq_length,
        FLAGS.dimension_3D,
        is_training=False,
    )
    trainer.test(model, handle, FLAGS, 'test_result')