Example #1
    # Set up the optimizer from the training configuration.
    optimizer_conf = dict(param_interface['training']['optimizer'])
    optimizer_name = optimizer_conf['name']
    del optimizer_conf['name']
    # Optimize only those parameters that require gradient updates.
    optimizer = getattr(torch.optim,
                        optimizer_name)(filter(lambda p: p.requires_grad,
                                               model.parameters()),
                                        **optimizer_conf)

    # OK, finished loading the configuration.
    # Save the resulting configuration into a YAML settings file under log_dir.
    with open(log_dir + "training_configuration.yaml",
              'w') as yaml_backup_file:
        yaml.dump(param_interface.to_dict(),
                  yaml_backup_file,
                  default_flow_style=False)

    # Log the training configuration.
    conf_str = '\n' + '=' * 80 + '\n'
    conf_str += 'Final registry configuration for training {} on {}:\n'.format(
        model_name, task_name)
    conf_str += '=' * 80 + '\n'
    conf_str += yaml.safe_dump(param_interface.to_dict(),
                               default_flow_style=False)
    conf_str += '=' * 80 + '\n'
    logger.info(conf_str)

    # Ask for confirmation (optional).
    if FLAGS.confirm:
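
The key pattern in Example #1 is resolving the optimizer class at runtime from a configuration string, so any optimizer shipped with torch.optim can be selected without an if/else ladder. Below is a minimal self-contained sketch of the same idea; the model and the optimizer_conf dict are placeholders standing in for the objects the original snippet receives from elsewhere, not part of the original code.

    import torch
    import torch.nn as nn

    # Hypothetical stand-ins for the model and the parsed config.
    model = nn.Linear(10, 2)
    optimizer_conf = {'name': 'Adam', 'lr': 1e-3}

    # Pop the class name; every remaining key is forwarded to the constructor.
    optimizer_name = optimizer_conf.pop('name')

    # Resolve the optimizer class on torch.optim by name and build it
    # over the trainable parameters only.
    OptimizerClass = getattr(torch.optim, optimizer_name)
    optimizer = OptimizerClass(
        filter(lambda p: p.requires_grad, model.parameters()),
        **optimizer_conf)

    print(optimizer)

Because 'name' is removed first, every remaining key in the dict maps directly onto a constructor keyword argument, so adding, say, weight_decay to the config file requires no code change.
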
Example #2
    problem = ProblemFactory.build_problem(
        param_interface['testing']['problem'])

    # Create statistics collector.
    stat_col = StatisticsCollector()
    # Add model/problem dependent statistics.
    problem.add_statistics(stat_col)
    model.add_statistics(stat_col)

    # Create the test output CSV file.
    test_file = stat_col.initialize_csv_file(log_dir, 'testing.csv')

    # OK, finished loading the configuration.
    # Save the resulting configuration into a YAML settings file under log_dir.
    with open(log_dir + "testing_configuration.yaml", 'w') as yaml_backup_file:
        yaml.dump(param_interface.to_dict(),
                  yaml_backup_file, default_flow_style=False)

    # Run the test loop.
    with torch.no_grad():
        for episode, (data_tuple, aux_tuple) in enumerate(
                problem.return_generator()):

            if episode == param_interface["testing"]["problem"][
                    "max_test_episodes"]:
                break

            logits, loss = forward_step(
                model, problem, episode, stat_col, data_tuple, aux_tuple)

            # Log to logger.
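
Example #2's test loop has two parts worth noting: the whole loop runs under torch.no_grad() so autograd does no gradient bookkeeping, and enumerate() plus an equality check against max_test_episodes caps the number of evaluated batches from an otherwise unbounded generator. Here is a minimal sketch of the same loop structure, with a toy model and generator standing in for the problem/model objects built by the factories above (all names are illustrative):

    import torch
    import torch.nn as nn

    # Hypothetical toy model and batch generator.
    model = nn.Linear(4, 2)
    criterion = nn.CrossEntropyLoss()

    def episode_generator():
        # Yields (inputs, targets) batches indefinitely,
        # like problem.return_generator() in the original snippet.
        while True:
            yield torch.randn(8, 4), torch.randint(0, 2, (8,))

    max_test_episodes = 5  # plays the role of testing.problem.max_test_episodes

    model.eval()
    with torch.no_grad():  # disable gradient tracking during evaluation
        for episode, (inputs, targets) in enumerate(episode_generator()):
            if episode == max_test_episodes:
                break
            logits = model(inputs)
            loss = criterion(logits, targets)
            print("episode {}: loss = {:.4f}".format(episode, loss.item()))

Skipping gradient tracking during evaluation both speeds up the forward passes and avoids holding the intermediate activations that backpropagation would otherwise need.
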