# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pushes 180 degree rotations around axes in the XY plane later in the circuit.
"""
from typing import Optional, cast, TYPE_CHECKING, Iterable, Tuple
from cirq import circuits, ops, value, protocols
from cirq.optimizers import decompositions
if TYPE_CHECKING:
# pylint: disable=unused-import
from typing import Dict, List
class _OptimizerState:
def __init__(self):
# The phases of the W gates currently being pushed along each qubit.
self.held_w_phases = {} # type: Dict[ops.QubitId, Optional[float]]
# Accumulated commands to batch-apply to the circuit later.
self.deletions = [] # type: List[Tuple[int, ops.Operation]]
self.inline_intos = [] # type: List[Tuple[int, ops.Operation]]
self.insertions = [] # type: List[Tuple[int, ops.Operation]]
[docs]class EjectPhasedPaulis(circuits.OptimizationPass):
"""Pushes X, Y, and PhasedX gates towards the end of the circuit.
As the gates get pushed, they may absorb Z gates, cancel against other
X, Y, or PhasedX gates with exponent=1, get merged into measurements (as
output bit flips), and cause phase kickback operations across CZs (which can
then be removed by the EjectZ optimization).
"""
[docs] def __init__(self, tolerance: float = 1e-8) -> None:
"""
Args:
tolerance: Maximum absolute error tolerance. The optimization is
permitted to simply drop negligible combinations of Z gates,
with a threshold determined by this tolerance.
"""
self.tolerance = tolerance
[docs] def optimize_circuit(self, circuit: circuits.Circuit):
state = _OptimizerState()
for moment_index, moment in enumerate(circuit):
for op in moment.operations:
affected = [q for q in op.qubits
if state.held_w_phases.get(q) is not None]
# Collect, phase, and merge Ws.
w = _try_get_known_phased_pauli(op)
if w is not None:
if decompositions.is_negligible_turn(
w[0] - 1,
self.tolerance):
_potential_cross_whole_w(moment_index,
op,
self.tolerance,
state)
else:
_potential_cross_partial_w(moment_index, op, state)
continue
if not affected:
continue
# Absorb Z rotations.
t = _try_get_known_z_half_turns(op)
if t is not None:
_absorb_z_into_w(moment_index, op, state)
continue
# Dump coherent flips into measurement bit flips.
if ops.MeasurementGate.is_measurement(op):
_dump_into_measurement(moment_index, op, state)
# Cross CZs using kickback.
if _try_get_known_cz_half_turns(op) is not None:
if len(affected) == 1:
_single_cross_over_cz(moment_index,
op,
affected[0],
state)
else:
_double_cross_over_cz(op, state)
continue
# Don't know how to handle this situation. Dump the gates.
_dump_held(op.qubits, moment_index, state)
# Put anything that's still held at the end of the circuit.
_dump_held(state.held_w_phases.keys(), len(circuit), state)
circuit.batch_remove(state.deletions)
circuit.batch_insert_into(state.inline_intos)
circuit.batch_insert(state.insertions)
def _absorb_z_into_w(moment_index: int,
op: ops.Operation,
state: _OptimizerState) -> None:
"""Absorbs a Z^t gate into a W(a) flip.
[Where W(a) is shorthand for PhasedX(phase_exponent=a).]
Uses the following identity:
───W(a)───Z^t───
≡ ───W(a)───────────Z^t/2──────────Z^t/2─── (split Z)
≡ ───W(a)───W(a)───Z^-t/2───W(a)───Z^t/2─── (flip Z)
≡ ───W(a)───W(a)──────────W(a+t/2)───────── (phase W)
≡ ────────────────────────W(a+t/2)───────── (cancel Ws)
≡ ───W(a+t/2)───
"""
t = cast(float, _try_get_known_z_half_turns(op))
q = op.qubits[0]
state.held_w_phases[q] = cast(float, state.held_w_phases[q]) + t / 2
state.deletions.append((moment_index, op))
def _dump_held(qubits: Iterable[ops.QubitId],
moment_index: int,
state: _OptimizerState):
# Note: sorting is to avoid non-determinism in the insertion order.
for q in sorted(qubits):
p = state.held_w_phases.get(q)
if p is not None:
dump_op = ops.PhasedXPowGate(phase_exponent=p).on(q)
state.insertions.append((moment_index, dump_op))
state.held_w_phases[q] = None
def _dump_into_measurement(moment_index: int,
op: ops.Operation,
state: _OptimizerState) -> None:
measurement = cast(ops.MeasurementGate, cast(ops.GateOperation, op).gate)
new_measurement = measurement.with_bits_flipped(
*[i
for i, q in enumerate(op.qubits)
if state.held_w_phases.get(q) is not None]
).on(*op.qubits)
for q in op.qubits:
state.held_w_phases[q] = None
state.deletions.append((moment_index, op))
state.inline_intos.append((moment_index, new_measurement))
def _potential_cross_whole_w(moment_index: int,
op: ops.Operation,
tolerance: float,
state: _OptimizerState) -> None:
"""Grabs or cancels a held W gate against an existing W gate.
[Where W(a) is shorthand for PhasedX(phase_exponent=a).]
Uses the following identity:
───W(a)───W(b)───
≡ ───Z^-a───X───Z^a───Z^-b───X───Z^b───
≡ ───Z^-a───Z^-a───Z^b───X───X───Z^b───
≡ ───Z^-a───Z^-a───Z^b───Z^b───
≡ ───Z^2(b-a)───
"""
state.deletions.append((moment_index, op))
_, phase_exponent = cast(Tuple[float, float],
_try_get_known_phased_pauli(op))
q = op.qubits[0]
a = state.held_w_phases.get(q)
b = phase_exponent
if a is None:
# Collect the gate.
state.held_w_phases[q] = b
else:
# Cancel the gate.
state.held_w_phases[q] = None
t = 2*(b - a)
if not decompositions.is_negligible_turn(t / 2, tolerance):
leftover_phase = ops.Z(q)**t
state.inline_intos.append((moment_index, leftover_phase))
def _potential_cross_partial_w(moment_index: int,
op: ops.Operation,
state: _OptimizerState) -> None:
"""Cross the held W over a partial W gate.
[Where W(a) is shorthand for PhasedX(phase_exponent=a).]
Uses the following identity:
───W(a)───W(b)^t───
≡ ───Z^-a───X───Z^a───W(b)^t────── (expand W(a))
≡ ───Z^-a───X───W(b-a)^t───Z^a──── (move Z^a across, phasing axis)
≡ ───Z^-a───W(a-b)^t───X───Z^a──── (move X across, negating axis angle)
≡ ───W(2a-b)^t───Z^-a───X───Z^a─── (move Z^-a across, phasing axis)
≡ ───W(2a-b)^t───W(a)───
"""
a = state.held_w_phases.get(op.qubits[0])
if a is None:
return
exponent, phase_exponent = cast(Tuple[float, float],
_try_get_known_phased_pauli(op))
new_op = ops.PhasedXPowGate(
exponent=exponent,
phase_exponent=2 * a - phase_exponent).on(op.qubits[0])
state.deletions.append((moment_index, op))
state.inline_intos.append((moment_index, new_op))
def _single_cross_over_cz(moment_index: int,
op: ops.Operation,
qubit_with_w: ops.QubitId,
state: _OptimizerState) -> None:
"""Crosses exactly one W flip over a partial CZ.
[Where W(a) is shorthand for PhasedX(phase_exponent=a).]
Uses the following identity:
──────────@─────
│
───W(a)───@^t───
≡ ───@──────O──────@────────────────────
| | │ (split into on/off cases)
───W(a)───W(a)───@^t──────────────────
≡ ───@─────────────@─────────────O──────
| │ | (off doesn't interact with on)
───W(a)──────────@^t───────────W(a)───
≡ ───────────Z^t───@──────@──────O──────
│ | | (crossing causes kickback)
─────────────────@^-t───W(a)───W(a)─── (X Z^t X Z^-t = exp(pi t) I)
≡ ───────────Z^t───@────────────────────
│ (merge on/off cases)
─────────────────@^-t───W(a)──────────
≡ ───Z^t───@──────────────
│
─────────@^-t───W(a)────
"""
t = cast(float, _try_get_known_cz_half_turns(op))
other_qubit = op.qubits[0] if qubit_with_w == op.qubits[1] else op.qubits[1]
negated_cz = ops.CZ(*op.qubits)**-t
kickback = ops.Z(other_qubit)**t
state.deletions.append((moment_index, op))
state.inline_intos.append((moment_index, negated_cz))
state.insertions.append((moment_index, kickback))
def _double_cross_over_cz(op: ops.Operation,
state: _OptimizerState) -> None:
"""Crosses two W flips over a partial CZ.
[Where W(a) is shorthand for PhasedX(phase_exponent=a).]
Uses the following identity:
───W(a)───@─────
│
───W(b)───@^t───
≡ ──────────@────────────W(a)───
│ (single-cross top W over CZ)
───W(b)───@^-t─────────Z^t────
≡ ──────────@─────Z^-t───W(a)───
│ (single-cross bottom W over CZ)
──────────@^t───W(b)───Z^t────
≡ ──────────@─────W(a)───Z^t────
│ (flip over Z^-t)
──────────@^t───W(b)───Z^t────
≡ ──────────@─────W(a+t/2)──────
│ (absorb Zs into Ws)
──────────@^t───W(b+t/2)──────
≡ ───@─────W(a+t/2)───
│
───@^t───W(b+t/2)───
"""
t = cast(float, _try_get_known_cz_half_turns(op))
for q in op.qubits:
state.held_w_phases[q] = cast(float, state.held_w_phases[q]) + t / 2
def _try_get_known_cz_half_turns(op: ops.Operation) -> Optional[float]:
if (not isinstance(op, ops.GateOperation) or
not isinstance(op.gate, ops.CZPowGate)):
return None
h = op.gate.exponent
if isinstance(h, value.Symbol):
return None
return h
def _try_get_known_phased_pauli(op: ops.Operation
) -> Optional[Tuple[float, float]]:
if protocols.is_parameterized(op) or not isinstance(op, ops.GateOperation):
return None
gate = op.gate
if isinstance(gate, ops.PhasedXPowGate):
e = gate.exponent
p = gate.phase_exponent
elif isinstance(gate, ops.YPowGate):
e = gate.exponent
p = 0.5
elif isinstance(gate, ops.XPowGate):
e = gate.exponent
p = 0.0
else:
return None
return cast(Tuple[float, float],
(value.canonicalize_half_turns(e),
value.canonicalize_half_turns(p)))
def _try_get_known_z_half_turns(op: ops.Operation) -> Optional[float]:
if (not isinstance(op, ops.GateOperation) or
not isinstance(op.gate, ops.ZPowGate)):
return None
h = op.gate.exponent
if isinstance(h, value.Symbol):
return None
return h