# Standard-library and third-party imports used below. The import path for
# RandomStreams is an assumption; the project-specific classes and helpers
# (IdAndCameraDirModel*, IdAndCameraDir*Loss, IdAndCameraDirTrainingState,
# parse_args, _get_datasets, get_iters, make_sgd, copy_model_params,
# copy_state_except_params) are defined elsewhere in this project.
import os
import warnings
import cPickle

import h5py
import numpy
from numpy.random import RandomState
import yaml
from nose.tools import (assert_equal,
                        assert_false,
                        assert_in,
                        assert_is,
                        assert_not_in)
import theano
from theano.tensor.shared_randomstreams import RandomStreams

import simplelearn
import poselearn


def main():
    '''
    Entry point of the training script.
    '''

    args = parse_args()
    input_path = os.path.abspath(args.input)
    assert_in(os.path.splitext(input_path)[1], ('.yaml', '.pkl'))
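    # A .yaml input starts a fresh training run; a .pkl input resumes a
    # checkpointed one.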

    input_dir, input_filename = os.path.split(input_path)
    input_basename, input_extension = os.path.splitext(input_filename)

    # Construction-order dependency:
    # hyperparams -> dataset -> iterator -> input_nodes -> model
    if input_extension == '.yaml':
        print("Starting fresh training run with config {}.".format(input_path))

        with open(input_path, 'r') as yaml_file:
            # safe_load assumes the config contains only plain YAML types.
            yaml_dict = yaml.safe_load(yaml_file)

        # Add library version info to yaml_dict
        if args.h5 is None:
            assert_not_in('versions', yaml_dict)
            yaml_dict['versions'] = {'simplelearn': simplelearn.__version__,
                                     'poselearn': poselearn.__version__,
                                     'theano': theano.__version__}
        else:
            def check_version(name, yaml_dict, module):
                '''
                Warn if resuming from a software version that's
                different from what was trained on.
                '''
                yaml_version = yaml_dict['versions'][name]
                current_version = module.__version__

                if yaml_version != current_version:
                    warnings.warn(
                        "resuming from a different version of {} than "
                        "what was trained on:\n"
                        "trained on {}\n"
                        "resuming on {}".format(
                            name, yaml_version, current_version))

            check_version('simplelearn', yaml_dict, simplelearn)
            check_version('poselearn', yaml_dict, poselearn)
            check_version('theano', yaml_dict, theano)

        hyperparams_dict = yaml_dict['hyperparams']

        # Dataset & iterators
        training_set, validation_set = _get_datasets(yaml_dict)
        iterator_rng = numpy.random.RandomState(2352)

        training_iter = training_set.iterator(
            'random_vec3',
            batch_size=hyperparams_dict['training_batch_size'],
            rng=iterator_rng)
        validation_iter = validation_set.iterator(
            'random_vec3',
            batch_size=hyperparams_dict['validation_batch_size'],
            rng=iterator_rng)
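
        # Note: both iterators are handed the same RandomState object, so
        # their random draws interleave rather than being independent.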

        #
        # The DAG
        #

        input_nodes = training_iter.make_input_nodes()

        use_dropout = hyperparams_dict['use_dropout']

        def make_model():
            numpy_rng = numpy.random.RandomState(2352)
            theano_rng = (RandomStreams(231134) if use_dropout
                          else None)
            return IdAndCameraDirModelConv(input_nodes[0],
                                           yaml_dict,
                                           numpy_rng,
                                           theano_rng)

        model = make_model()
        loss_node = IdAndCameraDirConvLoss(model.id_layers[-1],
                                           model.cam_dir_layers[-1],
                                           input_nodes,
                                           blank_id=6)

        # Iterator and updater state.
        # (sic: there is no IdAndCameraDirConvTrainingState; the non-conv
        # IdAndCameraDirTrainingState is reused for the conv model.)
        training_state = IdAndCameraDirTrainingState(model,
                                                     loss_node,
                                                     yaml_dict,
                                                     training_iter,
                                                     validation_iter)

        if args.h5 is not None:
            print("Resuming old training run from {}.".format(args.h5))
            with h5py.File(args.h5, mode='r') as h5_file:
                model.load_from_h5(h5_file['model'])
                training_state.load_from_h5(h5_file['training_state'])

            output_dir = os.path.split(args.h5)[0]
        else:
            assert_false(os.path.isdir(args.output_dir))
            output_dir = args.output_dir
            os.mkdir(output_dir)

    else:
        print("Resuming old training run from {}.".format(args.input))

        # Pickle files must be opened in binary mode.
        with open(args.input, 'rb') as pkl_file:
            pkl_dict = cPickle.load(pkl_file)

        yaml_dict = pkl_dict['yaml_dict']
        model = pkl_dict['model']
        loss_node = pkl_dict['loss_node']
        training_state = pkl_dict['training_state']

        assert_is(args.output_dir, None)
        output_dir = input_dir

        # Expect the .pkl filename to end in '_best.pkl' or
        # '_state_#####.pkl'; strip that suffix from input_basename.
        if input_basename.endswith('_best'):
            input_basename = input_basename[:-len('_best')]
        else:
            input_basename_parts = input_basename.split('_')
            assert_equal(input_basename_parts[-2], 'state')
            assert_equal(int(input_basename_parts[-1]),
                         training_state.num_epochs_seen)
            input_basename = '_'.join(input_basename_parts[:-2])

    if input_extension == '.yaml':
        # write modified yaml file to output_dir/
        new_yaml_path = os.path.join(output_dir, input_filename)
        with open(new_yaml_path, 'w') as new_yaml_file:
            yaml.dump(yaml_dict, new_yaml_file)
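
        # The archived yaml carries the 'versions' entry, so a later resume
        # can run check_version() against the recorded library versions.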

    output_basepath = os.path.join(output_dir, input_basename)

    print("Compiling training function...")
    sgd = make_sgd(yaml_dict,
                   model,
                   loss_node,
                   training_state,
                   training_state.training_iter,
                   training_state.validation_iter,
                   output_basepath)

    print("Training...")
    sgd.train()


# Note: this second entry point appears to come from a separate
# model-conversion script. The name convert_main is an assumption; defining
# a second main() here would shadow the training entry point above.
def convert_main():
    '''
    Build a conv version (IdAndCameraDirModelConv) of a trained
    IdAndCameraDirModel, copying over its parameters and training state.
    '''
    args = parse_args()

    def read_yaml(filepath):
        with open(filepath, 'r') as yaml_file:
            # safe_load assumes the yaml contains only plain YAML types.
            return yaml.safe_load(yaml_file)

    yaml_dict = read_yaml(args.yaml)
    input_h5 = h5py.File(args.input_h5, mode='r')
    # output_yaml = read_yaml(args.output_yaml)
    output_h5 = h5py.File(args.output_h5, mode='w')

    numpy_seed = 1341
    theano_seed = 2353

    training_iter, validation_iter = get_iters(yaml_dict, numpy_seed)

    input_nodes = training_iter.make_input_nodes()
    image_node = input_nodes[0]

    # Build input model, load params from input_h5
    input_model = IdAndCameraDirModel(image_node,
                                      yaml_dict,
                                      RandomState(numpy_seed),
                                      RandomStreams(theano_seed))
    input_model.load_from_h5(input_h5['model'])

    # Build a TrainingState around input_model; load its momenta etc. from
    # input_h5.
    input_loss_node = IdAndCameraDirLoss(input_model.id_layers[-1],
                                         input_model.cam_dir_layers[-1],
                                         input_nodes,
                                         blank_id=6)

    input_state = IdAndCameraDirTrainingState(input_model,
                                              input_loss_node,
                                              yaml_dict,
                                              training_iter,
                                              validation_iter)

    # Disable the class-name check: the class was renamed from TrainingState
    # to IdAndCameraDirTrainingState at some point between Sept 3 and Sept 12,
    # so .h5 files saved before the rename record the old class name.
    input_state.load_from_h5(input_h5['training_state'],
                             disable_class_name_check=True)

    # Build the output (conv) model from the same yaml dict as the input
    # model, then copy the params over from the input model.
    output_model = IdAndCameraDirModelConv(image_node,
                                           yaml_dict,
                                           RandomState(numpy_seed),
                                           RandomStreams(theano_seed))
    copy_model_params(input_model, output_model)
    output_model.save_to_h5(output_h5.create_group('model'))
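
    # For illustration only: a hypothetical copy_model_params might walk the
    # two models' parameter lists in parallel, copying each shared variable's
    # value (reshaping dense weights into conv filter banks where needed),
    # roughly:
    #
    #     for src, dst in zip(input_model.parameters, output_model.parameters):
    #         dst.set_value(src.get_value())
    #
    # The real helper is defined elsewhere in this project; the .parameters
    # attribute above is an assumption.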

    # Build a TrainingState around the output model; check that the model
    # params were copied correctly by comparing against the input
    # TrainingState's params, and copy over the momenta etc. from the input
    # TrainingState.
    output_loss_node = IdAndCameraDirConvLoss(output_model.id_layers[-1],
                                              output_model.cam_dir_layers[-1],
                                              input_nodes,
                                              blank_id=6)

    output_state = IdAndCameraDirTrainingState(output_model,
                                               output_loss_node,
                                               yaml_dict,
                                               training_iter,
                                               validation_iter)
    copy_state_except_params(input_state, output_state)

    output_state.save_to_h5(output_h5.create_group('training_state'))

    # Close the HDF5 files; closing output_h5 flushes it to disk.
    output_h5.close()
    input_h5.close()
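

# A minimal entry-point guard, absent from this excerpt. Which entry point to
# run by default is an assumption; here it is the training script's main():
if __name__ == '__main__':
    main()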