예제 #1
0
    def gen_additional_blobs(
        predict_net: core.Net,
        probability_out,
        model_out: torch.Tensor,
        output_name: str,
        label_names: List[str],
    ) -> List[core.BlobReference]:
        """
        Add per-label, human readable score blobs to ``predict_net``.

        The ops are appended to ``predict_net`` in this order: a ``Log`` over
        ``probability_out``, a ``Split`` of that result along the last axis of
        ``model_out`` (``label_names`` is handed to ``Split`` as its second
        argument), and one ``Copy`` per label.

        Args:
            predict_net: net that receives the new operators.
            probability_out: blob fed to the ``Log`` op.
            model_out: tensor used only for its number of dimensions, which
                selects the split axis (``model_out.dim() - 1``).
            output_name: prefix of every generated blob name.
            label_names: label names; each one yields a blob called
                ``"<output_name>:<label>"``.

        Returns:
            The results of the ``Copy`` calls, one per (label, split output)
            pair, in ``label_names`` order.
        """
        log_scores = predict_net.Log(probability_out)
        split_result = predict_net.Split(
            log_scores, label_names, axis=model_out.dim() - 1
        )

        # Split may hand back a lone result instead of a tuple; wrap it so the
        # pairing with label_names below works in both cases.
        per_label_scores = (
            split_result if isinstance(split_result, tuple) else (split_result, )
        )

        # zip stops at the shorter sequence, so surplus labels or surplus
        # split outputs are silently ignored.
        return [
            predict_net.Copy(score_blob, "{}:{}".format(output_name, label))
            for label, score_blob in zip(label_names, per_label_scores)
        ]
예제 #2
0
def _prepare_gradient_while_ops(
        fwd_op, input_names, output_names, loop_grad_net, workspace_blob,
        init_grad_map, loop_grad_map):
    """Assemble the operators that make up the gradient of a While op.

    The gradient operator starts as a copy of ``fwd_op``. Its name gets a
    ``"_grad"`` suffix when non-empty. Its arguments, control inputs, inputs
    and outputs are then replaced as described below.

    Args:
        fwd_op: forward operator definition that is copied as the starting
            point.
        input_names: input blob names placed after the condition blob in the
            gradient operator's inputs.
        output_names: output blob names of the gradient operator.
        loop_grad_net: net definition stored in the ``"loop_net"`` argument.
        workspace_blob: blob passed to the ``HasScope`` ops of both condition
            nets.
        init_grad_map: mapping of blob -> initial gradient blob.
        loop_grad_map: mapping of blob name -> gradient blob name used by the
            loop net.

    Returns:
        A list holding the operators of the condition-init net followed by
        the gradient While operator definition.
    """
    # Function-scope import, kept local as in the original implementation.
    from caffe2.python.core import Net, BlobReference

    grad_op = caffe2_pb2.OperatorDef()
    grad_op.CopyFrom(fwd_op)
    if grad_op.name:
        grad_op.name += "_grad"

    # Condition net: HasScope checks whether forward workspaces are still
    # left. The init net writes the same condition blob and its ops are
    # returned to the caller ahead of the gradient op.
    cond_net = Net('gradient_loop_cond_net')
    cond_init_net = Net('gradient_loop_cond_net_init')
    cond_blob = cond_net.NextScopedBlob(cond_net.Name() + '/cond')
    cond_init_net.HasScope(workspace_blob, cond_blob)
    cond_net.HasScope(workspace_blob, cond_blob)

    for fwd_blob, init_grad in init_grad_map.items():
        fwd_blob_name = str(fwd_blob)
        if fwd_blob_name not in loop_grad_map:
            continue
        loop_grad_name = loop_grad_map[fwd_blob_name]
        # Only blobs whose loop gradient name differs from the initial
        # gradient name get Copy ops: loop -> init in the condition net,
        # init -> loop in the init net.
        if loop_grad_name != str(init_grad):
            cond_net.Copy(BlobReference(loop_grad_name), init_grad)
            cond_init_net.Copy(init_grad, BlobReference(loop_grad_name))

    loop_net_arg = caffe2_pb2.Argument()
    loop_net_arg.name = "loop_net"
    loop_net_arg.n.CopyFrom(loop_grad_net)

    cond_net_arg = caffe2_pb2.Argument()
    cond_net_arg.name = "cond_net"
    cond_net_arg.n.CopyFrom(cond_net.Proto())

    # Drop everything inherited from the forward op before refilling.
    del grad_op.arg[:]
    grad_op.arg.extend([loop_net_arg, cond_net_arg])

    del grad_op.control_input[:]

    # The condition blob always comes first among the inputs.
    grad_inputs = [str(cond_blob).encode('utf-8')]
    grad_inputs.extend(input_names)
    del grad_op.input[:]
    grad_op.input.extend(grad_inputs)

    del grad_op.output[:]
    grad_op.output.extend(output_names)
    grad_op.is_gradient_op = True

    return list(cond_init_net.Proto().op) + [grad_op]