Ejemplo n.º 1
0
def transformer_factory(request):
    """Install the transformer factory chosen on the command line and yield it.

    The transformer name is read from the ``--transformer`` option through
    ``request.config.getoption``. A factory is built for that name, installed
    with ``ngt.set_transformer_factory`` and yielded to the consumer. When
    the generator is resumed, the installed factory is reset to the one
    built for ``"ngcpu"``.

    NOTE(review): this looks like a pytest yield-fixture, but the decorator
    is not visible here -- confirm against the surrounding file.

    Arguments:
        request: object exposing ``config.getoption``.

    Yields:
        The factory returned by ``ngt.make_transformer_factory`` for the
        selected transformer name.
    """
    selected_name = request.config.getoption("--transformer")

    # Build the factory first, then install it, then hand it out.
    active_factory = ngt.make_transformer_factory(selected_name)
    ngt.set_transformer_factory(active_factory)

    yield active_factory

    # Teardown: put the "ngcpu" factory back in place.
    default_factory = ngt.make_transformer_factory("ngcpu")
    ngt.set_transformer_factory(default_factory)
Ejemplo n.º 2
0
        self.ngraph_backend = "INTERPRETER"
        super(PybindINTERPRETERTransformer, self).__init__(**kwargs)


class PybindGPUTransformer(PybindTransformer):
    """
    Transformer for ngraph c++ with gpu backend.

    Sets ``ngraph_backend`` to ``"GPU"`` and otherwise defers entirely to
    ``PybindTransformer``.

    """
    # Name passed to make_transformer_factory to select this transformer
    # (see the module-level registration, which uses the CPU class's
    # transformer_name the same way).
    transformer_name = "nggpu"

    def __init__(self, **kwargs):
        """Select the GPU backend, then run the base initializer.

        Arguments:
            **kwargs: forwarded unchanged to ``PybindTransformer.__init__``.
        """
        # Assigned before the base initializer runs; presumably the base
        # class reads ngraph_backend during __init__ -- TODO confirm.
        self.ngraph_backend = "GPU"
        super(PybindGPUTransformer, self).__init__(**kwargs)


class PybindARGONTransformer(PybindTransformer):
    """
    Transformer for ngraph c++ with argon backend.

    Sets ``ngraph_backend`` to ``"ARGON"`` and otherwise defers entirely to
    ``PybindTransformer``.

    """
    # Name passed to make_transformer_factory to select this transformer
    # (see the module-level registration, which uses the CPU class's
    # transformer_name the same way).
    transformer_name = "ngargon"

    def __init__(self, **kwargs):
        """Select the ARGON backend, then run the base initializer.

        Arguments:
            **kwargs: forwarded unchanged to ``PybindTransformer.__init__``.
        """
        # Assigned before the base initializer runs; presumably the base
        # class reads ngraph_backend during __init__ -- TODO confirm.
        self.ngraph_backend = "ARGON"
        super(PybindARGONTransformer, self).__init__(**kwargs)


# Module-level side effect: when this module is executed, build a factory
# from the CPU transformer's name and install it via set_transformer_factory.
set_transformer_factory(
    make_transformer_factory(PybindCPUTransformer.transformer_name))
Ejemplo n.º 3
0
        restore_eval_function = transformer.add_computation(eval_computation)
        weight_saver.setup_restore(transformer=transformer,
                                   computation=eval_computation,
                                   filename=args.inference)
        # Restore weight
        weight_saver.restore()
        # Calculate losses
        eval_losses = loop_eval(valid_set, input_ph, metric_names,
                                restore_eval_function, en_top5)
        # Print statistics
        print("From restored weights: Test Avg loss:{tcost}".format(
            tcost=eval_losses))
        exit()

# Training the network by calling transformer
with closing(ngt.make_transformer_factory(args.backend)()) as transformer:
    # Trainer
    train_function = transformer.add_computation(train_computation)
    # Inference
    eval_function = transformer.add_computation(eval_computation)
    # Set Saver for saving weights
    weight_saver.setup_save(transformer=transformer,
                            computation=train_computation)
    # Resume weights for training from a checkpoint
    if args.resume is not None:
        weight_saver.setup_restore(transformer=transformer,
                                   computation=train_computation,
                                   filename=args.resume)
        weight_saver.restore()
    # Progress bar
    tpbar = tqdm(unit="batches", ncols=100, total=args.num_iterations)
Ejemplo n.º 4
0
 def make_and_set_transformer_factory(self, args):
     """Build a transformer factory for ``args.backend`` and install it.

     Arguments:
         args: object with a ``backend`` attribute naming the transformer.
     """
     # Create and register in one step; the factory is not kept locally.
     ngt.set_transformer_factory(ngt.make_transformer_factory(args.backend))
Ejemplo n.º 5
0
 def set_and_get_factory(transformer_name):
     """Create the factory for ``transformer_name``, install it, and return it.

     Arguments:
         transformer_name: name handed to ``ngt.make_transformer_factory``.

     Returns:
         The factory object that was passed to
         ``ngt.set_transformer_factory``.
     """
     new_factory = ngt.make_transformer_factory(transformer_name)
     # Register the factory before handing it back to the caller.
     ngt.set_transformer_factory(new_factory)
     return new_factory