In [None]:
https://shorturl.at/nXuEx

Jasp, a dynamic Pythonic low-level IR
-------------------------------------

Within this notebook we demonstrate the latest feature of the Jax Integration.

We introduce Jasp, a new IR that represents hybrid programs embedded into the Jaxpr IR.

Creating a Jasp program is simple:

In [27]:
from qrisp import *
from qrisp.jasp import *
from jax import make_jaxpr


def main(i):
    qf = QuantumFloat(i)
    h(qf[0])
    cx(qf[0], qf[1])

    meas_float = measure(qf)

    return meas_float
    

jaspr = make_jaspr(main)(5)

print(jaspr)

{ lambda ; a:QuantumCircuit b:i64[]. let
    c:QuantumCircuit d:QubitArray = jasp.create_qubits a b
    e:Qubit = jasp.get_qubit d 0
    f:QuantumCircuit = jasp.h c e
    g:Qubit = jasp.get_qubit d 1
    h:QuantumCircuit = jasp.cx f e g
    i:QuantumCircuit j:i64[] = jasp.measure h d
    k:QuantumCircuit = jasp.reset i d
    l:QuantumCircuit = jasp.delete_qubits k d
  in (l, j) }


Jasp programs can be executed with the Jasp interpreter by calling them like a function:

In [36]:
print(jaspr(5))

0


A quicker way to do this is to use the ``jaspify`` decorator. This decorator automatically transforms the function into a Jaspr and calls the simulator.

In [37]:
@jaspify
def main(i):
    qf = QuantumFloat(i)
    h(qf[0])
    cx(qf[0], qf[1])

    meas_float = measure(qf)

    return meas_float

print(main(5))

3


Jasp programs can be compiled to QIR, which is one of the most popular low-level representations for quantum computers. For that you need Catalyst installed (only on Mac & Linux).

In [None]:
try:
    import catalyst
except ImportError:
    !pip install pennylane-catalyst

In [38]:
qir_string = jaspr.to_qir()
print(qir_string[:2500])

; ModuleID = 'LLVMDialectModule'
source_filename = "LLVMDialectModule"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

@"{'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}" = internal constant [66 x i8] c"{'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}\00"
@LightningSimulator = internal constant [19 x i8] c"LightningSimulator\00"
@"/home/positr0nium/miniforge3/envs/qrisp/lib/python3.10/site-packages/pennylane_lightning/liblightning_qubit_catalyst.so" = internal constant [120 x i8] c"/home/positr0nium/miniforge3/envs/qrisp/lib/python3.10/site-packages/pennylane_lightning/liblightning_qubit_catalyst.so\00"
@__constant_1024xi64 = private constant [1024 x i64] zeroinitializer
@__constant_30xi64 = private constant [30 x i64] [i64 30, i64 29, i64 28, i64 27, i64 26, i64 25, i64 24, i64 23, i64 22, i64 21, i64 20, i64 19, i64 18, i64 17, i64 16, i64 15, i64 14, i6

The Qache decorator
-------------------

One of the most powerful features of this IR is that it is fully dynamic, allowing many functions to be cached and reused. For this we have the ``qache`` decorator. Qached functions are only executed once (per calling signature) and otherwise retrieved from the cache.

In [39]:
import time

@qache
def inner_function(qv, i):
    cx(qv[0], qv[1])
    h(qv[i])
    # Complicated compilation, that takes a lot of time
    time.sleep(1)

def main(i):
    qv = QuantumFloat(i)

    inner_function(qv, 0)
    inner_function(qv, 1)
    inner_function(qv, 2)

    return measure(qv)


t0 = time.time()
jaspr = make_jaspr(main)(5)
print(time.time()- t0)

1.0131170749664307


If a cached function is called with an argument of a different type (classical or quantum), it is not retrieved from the cache but retraced. If it is called with the same signature, the appropriate implementation is retrieved from the cache.

In [40]:
@qache
def inner_function(qv):
    x(qv)
    time.sleep(1)

def main():
    qf = QuantumFloat(5)
    qbl = QuantumBool()

    inner_function(qf)
    inner_function(qf)
    inner_function(qbl)
    inner_function(qbl)

    return measure(qf)

t0 = time.time()
jaspr = make_jaspr(main)()
print(time.time()- t0)

2.0223705768585205


We see 2 seconds now because the ``inner_function`` has been traced twice: Once for the ``QuantumFloat`` and once for the ``QuantumBool``.
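
The caching behavior can be mimicked in plain Python. The following is a minimal sketch, not Qrisp's implementation: the decorator ``cache_by_signature`` and the two placeholder classes are hypothetical names introduced only for illustration, and we assume that the cache key is simply the tuple of argument types.

```python
def cache_by_signature(trace_fn):
    """Hypothetical stand-in for a tracing cache keyed on argument types."""
    cache = {}

    def wrapper(*args):
        signature = tuple(type(arg).__name__ for arg in args)
        if signature not in cache:
            # Cache miss: "trace" the function once for this signature
            wrapper.trace_count += 1
            cache[signature] = trace_fn(*args)
        # Cache hit: reuse the previously traced result
        return cache[signature]

    wrapper.trace_count = 0
    return wrapper


# Placeholder types (not the real Qrisp classes)
class FakeQuantumFloat:
    pass


class FakeQuantumBool:
    pass


@cache_by_signature
def inner_function(qv):
    # Stands in for an expensive tracing step
    return f"traced for {type(qv).__name__}"


qf, qbl = FakeQuantumFloat(), FakeQuantumBool()
for arg in (qf, qf, qbl, qbl):
    inner_function(arg)

print(inner_function.trace_count)  # 2
```

Four calls lead to only two traces, one per argument type, which mirrors the two seconds measured above.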

Another important concept is dynamic values. Dynamic values are values that are only known at runtime (i.e. when the program is actually executed), for example because the value comes from a quantum measurement. Every QuantumVariable and its ``.size`` attribute are dynamic. Classical values can also be dynamic. To check whether a classical value is dynamic, we can use Python's native ``isinstance`` check against the ``jax.core.Tracer`` class. Note that even though ``QuantumVariables`` behave dynamically, they are not tracers themselves.

In [41]:
from jax.core import Tracer

def main(i):
    print("i is dynamic?: ", isinstance(i, Tracer))
    
    qf = QuantumFloat(5)
    j = qf.size
    print("j is dynamic?: ", isinstance(j, Tracer))
    
    h(qf)
    k = measure(qf)
    print("k is dynamic?: ", isinstance(k, Tracer))

    # Regular Python integers are not dynamic
    l = 5
    print("l is dynamic?: ", isinstance(l, Tracer))

    # Arbitrary Python objects can be used within Jasp
    # but they are not dynamic
    import networkx as nx
    G = nx.DiGraph()
    G.add_edge(1,2)
    print("G is dynamic?: ", isinstance(G, Tracer))
    
    return k

jaspr = make_jaspr(main)(5)


i is dynamic?:  True
j is dynamic?:  True
k is dynamic?:  True
l is dynamic?:  False
G is dynamic?:  False


What is the advantage of dynamic values? Dynamic code is scale invariant: the traced program does not need to grow with the problem size. For loops, we can use the ``jrange`` iterator, which allows you to execute a dynamic number of loop iterations. Some restrictions apply, however (check the docs to see which).

In [42]:
@jaspify
def main(k):

    a = QuantumFloat(k)
    b = QuantumFloat(k)

    # Brings a into uniform superposition via Hadamard
    h(a)

    c = measure(a)

    # Executes c iterations (i.e. depending on the measurement outcome)
    for i in jrange(c):

        # Performs a quantum incrementation on b based on the measurement outcome
        b += c//5

    return measure(b)

print(main(5))

28
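
To see which results are possible, the loop can be modeled classically. This is a sketch under the assumption that additions on a ``QuantumFloat(k)`` wrap around modulo ``2**k``; the helper ``expected_result`` is a hypothetical name and not part of Qrisp.

```python
def expected_result(c, k=5):
    """Classical model of the loop: add c//5 to b exactly c times."""
    b = 0
    for _ in range(c):
        # Assumption: the k-qubit register wraps around modulo 2**k
        b = (b + c // 5) % 2**k
    return b


# a is measured in uniform superposition, so c can be any value in 0..31
possible = {expected_result(c) for c in range(2**5)}

print(expected_result(14))  # 28
print(28 in possible)       # True
```

The printed result 28 is therefore consistent with, for instance, a measurement outcome of ``c = 14``.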


It is possible to execute a multi-controlled X gate with a dynamic number of controls.

In [43]:
@jaspify
def main(i, j, k):

    a = QuantumFloat(5)
    a[:] = i
    
    qbl = QuantumBool()

    # a[:j] is a dynamic amount of controls
    mcx(a[:j], qbl[0], ctrl_state = k)

    return measure(qbl)

This function encodes the integer ``i`` into a ``QuantumFloat`` and subsequently performs an MCX gate with control state ``k``. Therefore, we expect the function to return ``True`` if ``i == k`` and ``j >= 5`` (i.e. all five qubits of ``a`` act as controls).

In [44]:
print(main(1, 6, 1))
print(main(3, 6, 1))
print(main(2, 1, 1))

True
False
False
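
The three results can be reproduced with a small classical model. This is only a sketch: the helper ``expected_mcx`` is a hypothetical name, and we assume that the slice ``a[:j]`` is clipped to the register size and that bit ``m`` of ``k`` is the required state of control qubit ``m``.

```python
def expected_mcx(i, j, k, size=5):
    """Classical model of the MCX outcome for the function above."""
    # Assumption: a[:j] contains at most `size` qubits
    num_controls = min(j, size)
    # The controls see the lowest num_controls bits of i
    control_bits = i % 2**num_controls
    # The target flips only if the controls match the control state k
    return control_bits == k


for args in [(1, 6, 1), (3, 6, 1), (2, 1, 1)]:
    print(args, expected_mcx(*args))
```

Under these assumptions the model yields ``True``, ``False``, ``False``, in agreement with the simulator output.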


Classical control flow
----------------------

Jasp code can be conditioned on classically known values. For that we simply use the ``control`` feature from base Qrisp, but with dynamic classical bools. Some restrictions apply (check the docs for more details).

In [49]:
@jaspify
def main():

    qf = QuantumFloat(3)
    h(qf)

    # This is a classical, dynamical int
    meas_res = measure(qf)

    # This is a classical, dynamical bool
    ctrl_bl = meas_res >= 4
    
    with control(ctrl_bl):
        qf -= 4

    return measure(qf)

for i in range(5):
    print(main())

1
1
3
0
2
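
Why do all results lie between 0 and 3? A classical model of the control flow makes this explicit. The helper ``expected_outcome`` is a hypothetical name used only for this sketch; it assumes that the first measurement collapses ``qf`` to the measured value.

```python
def expected_outcome(meas_res):
    """Classical model: subtract 4 only if the first measurement is >= 4."""
    ctrl_bl = meas_res >= 4
    if ctrl_bl:
        return meas_res - 4
    return meas_res


# The first measurement of a 3-qubit register can yield 0..7
outcomes = sorted({expected_outcome(m) for m in range(8)})
print(outcomes)  # [0, 1, 2, 3]
```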


The RUS decorator
-----------------

RUS stands for Repeat-Until-Success, an essential component of many quantum algorithms such as HHL or LCU. As the name suggests, the RUS component repeats a certain subroutine until a measurement yields ``True``. The RUS decorator should be applied to a ``trial_function`` that returns a classical bool as its first return value, followed by arbitrary other values. The trial function is repeated until the classical bool is ``True``.

To demonstrate the RUS behavior, we initialize a GHZ state 

$\ket{\psi} = \frac{1}{\sqrt{2}} (\ket{00000} + \ket{11111})$

and measure the first qubit into a boolean value, which decides whether the repetition is canceled. The measurement collapses the GHZ state into either $\ket{00000}$ (which causes a new repetition) or $\ket{11111} = \ket{31}$ (which cancels the loop). After the repetition is canceled, we are therefore guaranteed to have the latter state.


In [51]:
from qrisp.jasp import RUS, make_jaspr, jaspify, jrange
from qrisp import QuantumFloat, h, cx, measure

def init_GHZ(qf):
    h(qf[0])
    for i in jrange(1, qf.size):
        cx(qf[0], qf[i])

@RUS
def rus_trial_function():
    qf = QuantumFloat(5)

    init_GHZ(qf)
    
    cancelation_bool = measure(qf[0])
    
    return cancelation_bool, qf

@jaspify
def main():

    qf = rus_trial_function()

    return measure(qf)

print(main())

31.0
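
The repeat-until-success pattern itself can be sketched in plain Python. This is a classical toy model and not how the ``RUS`` decorator is implemented; the names ``ghz_trial`` and ``repeat_until_success`` are hypothetical.

```python
import random


def ghz_trial(rng):
    """Toy model of one trial: measuring the first qubit of a 5-qubit GHZ state."""
    # Both outcomes are equally likely, and all qubits collapse to the same bit
    bit = rng.random() < 0.5
    value = 31 if bit else 0
    return bit, value


def repeat_until_success(trial_fn, rng):
    """Repeat the trial until its first return value is True."""
    attempts = 0
    while True:
        attempts += 1
        success, value = trial_fn(rng)
        if success:
            return value, attempts


rng = random.Random(0)
value, attempts = repeat_until_success(ghz_trial, rng)
print(value)  # 31
```

Regardless of how many attempts are needed, the returned value is always 31, matching the guarantee described above.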


Terminal sampling
-----------------

The ``jaspify`` decorator executes one "shot". Many quantum algorithms, however, need the distribution over many shots. In principle we could execute a batch of "jaspified" function calls, but this does not scale well. For this situation we have the ``terminal_sampling`` decorator. To use this decorator we need a function that returns a ``QuantumVariable`` (instead of a classical measurement result). The decorator then performs a (hybrid) simulation of the given script and subsequently samples from the distribution at the end.

In [56]:

@RUS
def rus_trial_function():
    qf = QuantumFloat(5)

    init_GHZ(qf)
    
    cancelation_bool = measure(qf[0])
    
    return cancelation_bool, qf

@terminal_sampling(shots = 50)
def main():

    qf = rus_trial_function()
    h(qf[0])
    qf_2 = QuantumFloat(5)
    qf_2[:] = 5

    return qf, qf_2

print(main())

{(31.0, 5.0): 31, (30.0, 5.0): 19}


The ``terminal_sampling`` decorator requires some care, however. Remember that it only samples from the distribution at the end of the algorithm. This distribution can depend on random events (such as mid-circuit measurement outcomes) that occurred during the execution. We demonstrate faulty use in the following example.

In [57]:
from qrisp import QuantumBool, measure, control

@terminal_sampling
def main():

    qbl = QuantumBool()
    qf = QuantumFloat(4)

    # Bring qbl into superposition
    h(qbl)

    # Perform a measure
    cl_bl = measure(qbl)

    # Perform a conditional operation based on the measurement outcome
    with control(cl_bl):
        qf[:] = 1
        h(qf[2])

    return qf

for i in range(5):
    print(main())
# Yields either {0.0: 1.0} or {1.0: 0.5, 5.0: 0.5} (with a 50/50 probability)

{1.0: 0.5, 5.0: 0.5}
{1.0: 0.5, 5.0: 0.5}
{1.0: 0.5, 5.0: 0.5}
{0.0: 1.0}
{0.0: 1.0}
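
The difference between what a single call reports and the distribution over independent runs can be worked out classically. The following sketch uses hypothetical helper names (``final_distribution``, ``one_call``) and models only the probabilities, not the simulator.

```python
import random


def final_distribution(cl_bl):
    """Distribution of qf at the end, given the mid-circuit measurement result."""
    if cl_bl:
        # qf is set to 1, then qubit 2 is put into superposition: 1 or 5
        return {1: 0.5, 5: 0.5}
    return {0: 1.0}


def one_call(rng):
    """One call: the measurement is drawn once, then only the end is sampled."""
    cl_bl = rng.random() < 0.5
    return final_distribution(cl_bl)


# Distribution over independent runs: weight each branch with probability 1/2
full = {}
for cl_bl in (False, True):
    for value, prob in final_distribution(cl_bl).items():
        full[value] = full.get(value, 0.0) + 0.5 * prob

print(one_call(random.Random(0)))
print(full)  # {0: 0.5, 1: 0.25, 5: 0.25}
```

A single call only ever reports one of the two conditional distributions, never the mixture computed in ``full``.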


Boolean simulation
------------------

The tight Jax integration of Jasp enables some powerful features such as a highly performant simulator of purely boolean circuits. This simulator works by transforming Jaspr objects that contain only X, CX, MCX etc. into boolean Jax logic. Subsequently this is inserted into the Jax pipeline, which yields a highly scalable simulator for purely classical Jasp functions.

To call this simulator, we simply use the ``boolean_simulation`` decorator like we did with the ``jaspify`` decorator.

In [67]:
from qrisp import *
from qrisp.jasp import *

def quantum_mult(a, b):
    return a*b

@boolean_simulation(bit_array_padding = 2**10)
def main(i, j, iterations):

    a = QuantumFloat(10)
    b = QuantumFloat(10)

    a[:] = i
    b[:] = j

    c = QuantumFloat(30)

    for i in jrange(iterations): 

        # Compute the quantum product
        temp = quantum_mult(a,b)

        # add into c
        c += temp

        # Uncompute the quantum product
        with invert():
            # The << operator "injects" the quantum variable into
            # the function. This means that the quantum_mult
            # function, which was originally out-of-place, is
            # now an in-place function operating on temp.

            # It can therefore be used for uncomputation
            # Automatic uncomputation is not yet available within Jasp.
            (temp << quantum_mult)(a, b)

        # Delete temp
        temp.delete()

    return measure(c)


The first call needs some time for compilation.

In [68]:
import time
t0 = time.time()
main(1, 2, 5)
print(time.time()-t0)

7.430185317993164


Any subsequent call is super fast.

In [60]:
t0 = time.time()
print(main(3, 4, 120)) # Expected to be 3*4*120 = 1440
print(f"Took {time.time()-t0} to simulate 120 iterations")

1440.0
Took 0.006090641021728516 to simulate 120 iterations


Compile and simulate A MILLION QFLOPs!

In [61]:
print(main(532, 233, 1000000))

475690240.0
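
Note that the printed value is not the exact product. A quick check in plain Python, assuming that the 30-qubit register ``c`` wraps around modulo ``2**30``, reproduces the number:

```python
# Exact classical value of 532 * 233 accumulated one million times
exact = 532 * 233 * 1_000_000

# Assumption: c = QuantumFloat(30) stores results modulo 2**30
wrapped = exact % 2**30

print(exact)    # 123956000000
print(wrapped)  # 475690240
```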


Letting a classical, neural network decide when to stop
-------------------------------------------------------

The following example showcases how a simple neural network can decide (in real time) whether to continue or break the RUS iteration. For that we create a simple binary classifier and train it on dummy data (disclaimer: ML code by ChatGPT). This code is not really useful in any way, and the classifier is classifying random data, but it shows how such an algorithm can be constructed and evaluated.

In [62]:
import jax
import jax.numpy as jnp
from jax import grad, jit
import optax

# Define the model
def model(params, x):
    W, b = params
    return jax.nn.sigmoid(jnp.dot(x, W) + b)

# Define the loss function (binary cross-entropy)
def loss_fn(params, x, y):
    preds = model(params, x)
    return -jnp.mean(y * jnp.log(preds) + (1 - y) * jnp.log(1 - preds))

# Initialize parameters
key = jax.random.PRNGKey(0)
W = jax.random.normal(key, (2, 1))
b = jax.random.normal(key, (1,))
params = (W, b)

# Create optimizer
optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(params)

# Define training step
@jit
def train_step(params, opt_state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# Generate some dummy data
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (1000, 2))
y = jnp.sum(X > 0, axis=1) % 2

# Training loop
for epoch in range(100):
    params, opt_state, loss = train_step(params, opt_state, X, y)
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss}")

# Make predictions
predictions = model(params, X)
accuracy = jnp.mean((predictions > 0.5) == y)
print(f"Final accuracy: {accuracy}")


Epoch 0, Loss: 1.1255793726499572
Epoch 10, Loss: 1.075287382286886
Epoch 20, Loss: 1.0277242824464026
Epoch 30, Loss: 0.9834605224633858
Epoch 40, Loss: 0.9429307856413472
Epoch 50, Loss: 0.9063988210301136
Epoch 60, Loss: 0.8739494586871209
Epoch 70, Loss: 0.845496268788587
Epoch 80, Loss: 0.8208042884256824
Epoch 90, Loss: 0.7995303839215936
Final accuracy: 0.49931600689888


We can now use the ``model`` function to evaluate the classifier. Since this function is Jax-based, it integrates seamlessly into Jasp.

In [63]:
from qrisp.jasp import *
from qrisp import *
   
@RUS
def rus_trial_function(params):

    # Sample data from two QuantumFloats.
    # This is a placeholder for an arbitrary quantum algorithm.
    qf_0 = QuantumFloat(5)
    h(qf_0)

    qf_1 = QuantumFloat(5)
    h(qf_1)

    meas_res_0 = measure(qf_0)
    meas_res_1 = measure(qf_1)

    # Turn the data into a Jax array
    X = jnp.array([meas_res_0,meas_res_1])/2**qf_0.size

    # Evaluate the model
    model_res = model(params, X)

    # Determine the cancelation
    cancelation_bool = (model_res > 0.5)[0]
    
    return cancelation_bool, qf_0

@jaspify
def main(params):

    qf = rus_trial_function(params)
    h(qf[0])

    return measure(qf)

print(main(params))

6.0
