コード例 #1
0
ファイル: evaluate.py プロジェクト: xiehuiyuan/ShapeWorld
    else:
        with open(args.hyperparams_file, 'r') as filehandle:
            parameters = json.load(fp=filehandle)

    # restore
    iteration_start = 1
    if args.report_file:
        with open(args.report_file, 'r') as filehandle:
            for line in filehandle:
                value = line.split(',')[0]
        if value != 'iteration':
            iteration_start = int(value) + 1

    with Model(name=args.model,
               learning_rate=parameters.pop('learning_rate', 1e-3),
               weight_decay=parameters.pop('weight_decay', None),
               clip_gradients=parameters.pop('clip_gradients', None),
               model_directory=args.model_dir) as model:
        parameters.pop('dropout_rate', None)

        module = import_module('models.{}.{}'.format(args.type, args.model))
        module.model(model=model,
                     inputs=dict(),
                     dataset_parameters=dataset_parameters,
                     **parameters
                     )  # no input tensors, hence None for placeholder creation
        model.finalize(restore=(args.model_dir is not None))

        if args.verbosity >= 1:
            sys.stdout.write('         parameters: {:,}\n'.format(
                model.num_parameters))
コード例 #2
0
            report_file_dir = os.path.dirname(args.report_file)
            if report_file_dir and not os.path.isdir(report_file_dir):
                os.makedirs(report_file_dir)
            with open(args.report_file, 'w') as filehandle:
                filehandle.write('iteration,saved')
                for name in query:
                    filehandle.write(',train ' + name)
                if not args.tf_records:
                    for name in query:
                        filehandle.write(',validation ' + name)
                filehandle.write('\n')
    iteration_end = iteration_start + args.iterations - 1

    with Model(name=args.model,
               learning_rate=parameters.pop('learning_rate'),
               weight_decay=parameters.pop('weight_decay', 0.0),
               model_directory=args.model_dir,
               summary_directory=args.summary_dir) as model:
        dropout = parameters.pop('dropout_rate', 0.0)

        module = import_module('models.{}.{}'.format(args.type, args.model))
        if args.tf_records:
            module.model(model=model,
                         inputs=tf_util.batch_records(
                             dataset=dataset,
                             batch_size=args.batch_size,
                             noise_range=args.pixel_noise),
                         **parameters)
        else:
            module.model(
                model=model, inputs=dict(), **parameters