def main():
    '''
    Entry point of this script.
    '''
    args = parse_args()

    input_path = os.path.abspath(args.input)
    assert_in(os.path.splitext(input_path)[1], ('.yaml', '.pkl'))

    input_dir, input_filename = os.path.split(input_path)
    input_basename, input_extension = os.path.splitext(input_filename)

    # dependency:
    # hyperparams -> dataset -> iter -> input_nodes -> model

    if os.path.splitext(input_path)[1] == '.yaml':
        print("Starting fresh training run with config {}.".format(input_path))

        with open(input_path, 'r') as yaml_file:
            yaml_dict = yaml.load(yaml_file)

        # Add library version info to yaml_dict
        if args.h5 is None:
            assert_not_in('versions', yaml_dict)
            yaml_dict['versions'] = {'simplelearn': simplelearn.__version__,
                                     'poselearn': poselearn.__version__,
                                     'theano': theano.__version__}
        else:
            def check_version(name, yaml_dict, module):
                '''
                Warn if resuming from a software version that's different
                from what was trained on.
                '''
                yaml_version = yaml_dict['versions'][name]
                current_version = module.__version__

                if yaml_version != current_version:
                    warnings.warn(
                        "resuming from a different version of {} than "
                        "what was trained on:\n"
                        "trained on {}\n"
                        "resuming on {}".format(name,
                                                yaml_version,
                                                current_version))

            check_version('simplelearn', yaml_dict, simplelearn)
            check_version('poselearn', yaml_dict, poselearn)
            check_version('theano', yaml_dict, theano)

        hyperparams_dict = yaml_dict['hyperparams']

        # Dataset & iterators

        training_set, validation_set = _get_datasets(yaml_dict)

        iterator_rng = numpy.random.RandomState(2352)
        training_iter = training_set.iterator(
            'random_vec3',
            batch_size=hyperparams_dict['training_batch_size'],
            rng=iterator_rng)
        validation_iter = validation_set.iterator(
            'random_vec3',
            batch_size=hyperparams_dict['validation_batch_size'],
            rng=iterator_rng)

        #
        # The DAG
        #

        input_nodes = training_iter.make_input_nodes()

        use_dropout = yaml_dict['hyperparams']['use_dropout']

        def make_model():
            numpy_rng = numpy.random.RandomState(2352)
            theano_rng = RandomStreams(231134) if use_dropout else None

            return IdAndCameraDirModelConv(input_nodes[0],
                                           yaml_dict,
                                           numpy_rng,
                                           theano_rng)

        model = make_model()

        loss_node = IdAndCameraDirConvLoss(model.id_layers[-1],
                                           model.cam_dir_layers[-1],
                                           input_nodes,
                                           blank_id=6)

        # Iterator and updater state

        # sic; no IdAndCameraDirConvTrainingState.
        training_state = IdAndCameraDirTrainingState(model,
                                                     loss_node,
                                                     yaml_dict,
                                                     training_iter,
                                                     validation_iter)

        if args.h5 is not None:
            print("Resuming old training run from {}.".format(args.h5))

            with h5py.File(args.h5, mode='r') as h5_file:
                model.load_from_h5(h5_file['model'])
                training_state.load_from_h5(h5_file['training_state'])

            output_dir = os.path.split(args.h5)[0]
        else:
            assert_false(os.path.isdir(args.output_dir))
            output_dir = args.output_dir
            os.mkdir(output_dir)
    else:
        print("Resuming old training run from {}.".format(args.input))

        with open(args.input, 'r') as pkl_file:
            pkl_dict = cPickle.load(pkl_file)

        yaml_dict = pkl_dict['yaml_dict']
        model = pkl_dict['model']
        loss_node = pkl_dict['loss_node']
        training_state = pkl_dict['training_state']

        assert_is(args.output_dir, None)
        output_dir = input_dir

        # Expect the .pkl file to end in _best.pkl or _state_#####.pkl, and
        # remove those parts from the input_basename.
        if input_basename.endswith('_best'):
            input_basename = input_basename[:-len('_best')]
        else:
            input_basename_parts = input_basename.split('_')
            assert_equal(input_basename_parts[-2], 'state')
            assert_equal(int(input_basename_parts[-1]),
                         training_state.num_epochs_seen)
            input_basename = '_'.join(input_basename_parts[:-2])

    if input_extension == '.yaml':
        # Write the (possibly modified) yaml file to output_dir/
        new_yaml_path = os.path.join(output_dir, input_filename)
        with open(new_yaml_path, 'w') as new_yaml_file:
            yaml.dump(yaml_dict, new_yaml_file)

    output_basepath = os.path.join(output_dir, input_basename)

    print("Compiling training function...")

    sgd = make_sgd(yaml_dict,
                   model,
                   loss_node,
                   training_state,
                   training_state.training_iter,
                   training_state.validation_iter,
                   output_basepath)

    print("Training...")

    sgd.train()
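# The sketch below is an assumption, not this script's actual parser: it only
# illustrates the command-line interface that main() above expects
# parse_args() to provide (an 'input' .yaml config or .pkl checkpoint, an
# optional --h5 checkpoint to resume from, and an --output-dir for fresh
# runs). The real parse_args() is defined elsewhere in this script.
def _example_parse_args():
    import argparse

    parser = argparse.ArgumentParser(
        description="Train (or resume training) an IdAndCameraDirModelConv.")
    parser.add_argument(
        'input',
        help="A .yaml config for a fresh run, or a *_best.pkl / "
             "*_state_#####.pkl checkpoint to resume from.")
    parser.add_argument(
        '--h5',
        default=None,
        help="Optional .h5 checkpoint to resume a .yaml-configured run from.")
    parser.add_argument(
        '--output-dir',
        default=None,
        help="Output directory for a fresh run; must not already exist.")
    return parser.parse_args()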
def main():
    '''
    Entry point: converts a trained IdAndCameraDirModel checkpoint into an
    equivalent IdAndCameraDirModelConv checkpoint, copying over the model
    params and training state.
    '''
    args = parse_args()

    def read_yaml(filepath):
        with open(filepath, 'r') as yaml_file:
            return yaml.load(yaml_file)

    yaml_dict = read_yaml(args.yaml)
    input_h5 = h5py.File(args.input_h5, mode='r')
    # output_yaml = read_yaml(args.output_yaml)
    output_h5 = h5py.File(args.output_h5, mode='w')

    numpy_seed = 1341
    theano_seed = 2353

    training_iter, validation_iter = get_iters(yaml_dict, numpy_seed)
    input_nodes = training_iter.make_input_nodes()
    image_node = input_nodes[0]

    # Build input model, load its params from input_h5.
    input_model = IdAndCameraDirModel(image_node,
                                      yaml_dict,
                                      RandomState(numpy_seed),
                                      RandomStreams(theano_seed))
    input_model.load_from_h5(input_h5['model'])

    # Build a TrainingState around input_model, load its momenta etc. from
    # input_h5.
    input_loss_node = IdAndCameraDirLoss(input_model.id_layers[-1],
                                         input_model.cam_dir_layers[-1],
                                         input_nodes,
                                         blank_id=6)
    input_state = IdAndCameraDirTrainingState(input_model,
                                              input_loss_node,
                                              yaml_dict,
                                              training_iter,
                                              validation_iter)

    # Need to disable the class name check, since we renamed the class from
    # TrainingState to IdAndCameraDirTrainingState at some point between
    # Sept 3 and Sept 12.
    input_state.load_from_h5(input_h5['training_state'],
                             disable_class_name_check=True)

    # Build the output (conv) model from the same yaml dict as the input
    # model, and copy the params from the input model.
    output_model = IdAndCameraDirModelConv(image_node,
                                           yaml_dict,
                                           RandomState(numpy_seed),
                                           RandomStreams(theano_seed))
    copy_model_params(input_model, output_model)
    output_model.save_to_h5(output_h5.create_group('model'))

    # Build a TrainingState around the output model, check that the model
    # params have been copied correctly by comparing against the input
    # TrainingState's params, and also copy over the momenta etc. from the
    # input TrainingState.
    output_loss_node = IdAndCameraDirConvLoss(output_model.id_layers[-1],
                                              output_model.cam_dir_layers[-1],
                                              input_nodes,
                                              blank_id=6)
    output_state = IdAndCameraDirTrainingState(output_model,
                                               output_loss_node,
                                               yaml_dict,
                                               training_iter,
                                               validation_iter)
    copy_state_except_params(input_state, output_state)
    output_state.save_to_h5(output_h5.create_group('training_state'))
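# The sketch below is an assumption, not this script's actual parser: it only
# illustrates the three arguments main() above reads from parse_args(): the
# training .yaml config, the input .h5 checkpoint of the trained
# IdAndCameraDirModel, and the output .h5 path for the converted conv model
# and its training state. The real parse_args() is defined elsewhere.
def _example_parse_args():
    import argparse

    parser = argparse.ArgumentParser(
        description="Convert an IdAndCameraDirModel .h5 checkpoint into an "
                    "IdAndCameraDirModelConv .h5 checkpoint.")
    parser.add_argument(
        '--yaml',
        required=True,
        help="The .yaml config the input model was trained with.")
    parser.add_argument(
        '--input-h5',
        required=True,
        help=".h5 checkpoint of the trained IdAndCameraDirModel.")
    parser.add_argument(
        '--output-h5',
        required=True,
        help="Path of the .h5 file to write the converted conv model and "
             "training state to.")
    return parser.parse_args()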