Code example #1
0
def get_iters(yaml_dict, seed):
    '''
    Build the training and validation iterators described by yaml_dict.

    Both iterators draw 'random_vec3' batches. Each iterator receives
    its own RandomState, both initialized from the same seed.

    Parameters
    ----------
    yaml_dict : dict
      Parsed YAML config; must contain a 'hyperparams' entry with
      'training_batch_size' and 'validation_batch_size' keys.
    seed : int
      Seed for the iterators' RNGs.

    Returns
    -------
    rval : tuple
      (training_iter, validation_iter)
    '''
    assert_is_instance(yaml_dict, dict)
    assert_integer(seed)

    training_set, validation_set = _get_datasets(yaml_dict)
    hyperparams = yaml_dict['hyperparams']

    def make_iter(dataset, batch_size_key):
        # Fresh RandomState per iterator, both seeded identically.
        return dataset.iterator('random_vec3',
                                hyperparams[batch_size_key],
                                rng=RandomState(seed))

    return (make_iter(training_set, 'training_batch_size'),
            make_iter(validation_set, 'validation_batch_size'))
Code example #2
0
def main():
    '''
    Entry point of this script.

    Three modes, chosen by the input file's extension and the --h5 flag:

    * .yaml input, no --h5: start a fresh training run from the config.
    * .yaml input with --h5: rebuild the model/DAG from the config, then
      load weights and training state from the HDF5 checkpoint.
    * .pkl input: resume entirely from a pickled checkpoint.

    In all cases the function then compiles the SGD trainer and runs it.
    '''

    args = parse_args()
    input_path = os.path.abspath(args.input)
    assert_in(os.path.splitext(input_path)[1], ('.yaml', '.pkl'))

    input_dir, input_filename = os.path.split(input_path)
    input_basename, input_extension = os.path.splitext(input_filename)

    # dependency:
    # hyperparams->dataset->iter->input_nodes->model
    if os.path.splitext(input_path)[1] == '.yaml':
        print("Starting fresh training run with config {}.".format(input_path))

        with open(input_path, 'r') as yaml_file:
            # SECURITY NOTE(review): yaml.load without an explicit Loader
            # can construct arbitrary Python objects from the config file.
            # Prefer yaml.safe_load unless custom YAML tags are required
            # -- confirm whether configs rely on such tags.
            yaml_dict = yaml.load(yaml_file)

        # Add library version info to yaml_dict
        if args.h5 is None:
            assert_not_in('versions', yaml_dict)
            yaml_dict['versions'] = {'simplelearn': simplelearn.__version__,
                                     'poselearn': poselearn.__version__,
                                     'theano': theano.__version__}
        else:
            def check_version(name, yaml_dict, module):
                '''
                Warn if resuming from a software version that's
                different from what was trained on.
                '''
                yaml_version = yaml_dict['versions'][name]
                current_version = module.__version__

                if yaml_version != current_version:
                    warnings.warn(
                        "resuming from a different version of {} than "
                        "what was trained on:\n"
                        "trained on {}\n"
                        "resuming on {}".format(
                            name,
                            yaml_dict['versions'][name],
                            module.__version__))

            check_version('simplelearn', yaml_dict, simplelearn)
            check_version('poselearn', yaml_dict, poselearn)
            check_version('theano', yaml_dict, theano)

        hyperparams_dict = yaml_dict['hyperparams']

        # Dataset & iterators.
        # Note: unlike get_iters(), both iterators here SHARE one
        # RandomState, so their random streams are interleaved.
        training_set, validation_set = _get_datasets(yaml_dict)
        iterator_rng = numpy.random.RandomState(2352)

        training_iter = training_set.iterator(
            'random_vec3',
            batch_size=hyperparams_dict['training_batch_size'],
            rng=iterator_rng)
        validation_iter = validation_set.iterator(
            'random_vec3',
            batch_size=hyperparams_dict['validation_batch_size'],
            rng=iterator_rng)

        #
        # The DAG
        #

        input_nodes = training_iter.make_input_nodes()

        use_dropout = yaml_dict['hyperparams']['use_dropout']

        def make_model():
            '''
            Construct the conv model with fixed RNG seeds so that a
            fresh run is reproducible.
            '''
            numpy_rng = numpy.random.RandomState(2352)
            theano_rng = (RandomStreams(231134) if use_dropout
                          else None)
            return IdAndCameraDirModelConv(input_nodes[0],
                                           yaml_dict,
                                           numpy_rng,
                                           theano_rng)

        model = make_model()
        # blank_id=6 is a magic value -- presumably the id label used
        # for "no person"; confirm against the dataset's label scheme.
        loss_node = IdAndCameraDirConvLoss(model.id_layers[-1],
                                           model.cam_dir_layers[-1],
                                           input_nodes,
                                           blank_id=6)

        # Iterator and updater state
        # sic; no IdAndCameraDirConvTrainingState.
        training_state = IdAndCameraDirTrainingState(model,
                                                     loss_node,
                                                     yaml_dict,
                                                     training_iter,
                                                     validation_iter)

        if args.h5 is not None:
            print("Resuming old training run from {}.".format(args.h5))
            with h5py.File(args.h5, mode='r') as h5_file:
                model.load_from_h5(h5_file['model'])
                training_state.load_from_h5(h5_file['training_state'])

            # Keep writing outputs next to the checkpoint we resumed from.
            output_dir = os.path.split(args.h5)[0]
        else:
            assert_false(os.path.isdir(args.output_dir))
            output_dir = args.output_dir
            os.mkdir(output_dir)

    else:
        print("Resuming old training run from {}.".format(args.input))

        # Pickle data is binary; open in 'rb' (text mode 'r' corrupts
        # the stream on Windows and fails under Python 3).
        with open(args.input, 'rb') as pkl_file:
            pkl_dict = cPickle.load(pkl_file)

        yaml_dict = pkl_dict['yaml_dict']
        model = pkl_dict['model']
        loss_node = pkl_dict['loss_node']
        training_state = pkl_dict['training_state']

        assert_is(args.output_dir, None)
        output_dir = input_dir

        # Expect .pkl file to end in _best.pkl or _state_#####.pkl, and
        # remove those parts from the input_basename.
        if input_basename.endswith('_best'):
            input_basename = input_basename[:-len('_best')]
        else:
            input_basename_parts = input_basename.split('_')
            assert_equal(input_basename_parts[-2], 'state')
            # The numeric suffix must match the checkpoint's epoch count.
            assert_equal(int(input_basename_parts[-1]),
                         training_state.num_epochs_seen)
            input_basename = '_'.join(input_basename_parts[:-2])

    if input_extension == '.yaml':
        # write modified yaml file to output_dir/
        new_yaml_path = os.path.join(output_dir, input_filename)
        with open(new_yaml_path, 'w') as new_yaml_file:
            yaml.dump(yaml_dict, new_yaml_file)

    output_basepath = os.path.join(output_dir, input_basename)

    print("Compiling training function...")
    sgd = make_sgd(yaml_dict,
                   model,
                   loss_node,
                   training_state,
                   training_state.training_iter,
                   training_state.validation_iter,
                   output_basepath)

    print("Training...")
    sgd.train()