Source code for cirq.optimizers.eject_z

# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An optimization pass that pushes Z gates later and later in the circuit."""

from typing import Optional, cast, TYPE_CHECKING, Iterable

from collections import defaultdict

from cirq import circuits, ops, value, protocols
from cirq.optimizers import decompositions

if TYPE_CHECKING:
    # pylint: disable=unused-import
    from typing import Dict, List, Tuple


[docs]class EjectZ(circuits.OptimizationPass): """Pushes Z gates towards the end of the circuit. As the Z gates get pushed they may absorb other Z gates, get absorbed into measurements, cross CZ gates, cross W gates (by phasing them), etc. """
[docs] def __init__(self, tolerance: float = 0.0) -> None: """ Args: tolerance: Maximum absolute error tolerance. The optimization is permitted to simply drop negligible combinations of Z gates, with a threshold determined by this tolerance. """ self.tolerance = tolerance
[docs] def optimize_circuit(self, circuit: circuits.Circuit): # Tracks qubit phases (in half turns; multiply by pi to get radians). qubit_phase = defaultdict(lambda: 0) # type: Dict[ops.QubitId, float] def dump_tracked_phase(qubits: Iterable[ops.QubitId], index: int) -> None: """Zeroes qubit_phase entries by emitting Z gates.""" for q in qubits: p = qubit_phase[q] if not decompositions.is_negligible_turn(p, self.tolerance): dump_op = ops.Z(q)**(p * 2) insertions.append((index, dump_op)) qubit_phase[q] = 0 deletions = [] # type: List[Tuple[int, ops.Operation]] inline_intos = [] # type: List[Tuple[int, ops.Operation]] insertions = [] # type: List[Tuple[int, ops.Operation]] for moment_index, moment in enumerate(circuit): for op in moment.operations: # Move Z gates into tracked qubit phases. h = _try_get_known_z_half_turns(op) if h is not None: q = op.qubits[0] qubit_phase[q] += h / 2 deletions.append((moment_index, op)) continue # Z gate before measurement is a no-op. Drop tracked phase. if ops.MeasurementGate.is_measurement(op): for q in op.qubits: qubit_phase[q] = 0 # If there's no tracked phase, we can move on. phases = [qubit_phase[q] for q in op.qubits] if all(decompositions.is_negligible_turn(p, self.tolerance) for p in phases): continue # Try to move the tracked phasing over the operation. phased_op = op for i, p in enumerate(phases): if not decompositions.is_negligible_turn(p, self.tolerance): phased_op = protocols.phase_by(phased_op, -p, i, default=None) if phased_op is not None: deletions.append((moment_index, op)) inline_intos.append((moment_index, cast(ops.Operation, phased_op))) else: dump_tracked_phase(op.qubits, moment_index) dump_tracked_phase(qubit_phase.keys(), len(circuit)) circuit.batch_remove(deletions) circuit.batch_insert_into(inline_intos) circuit.batch_insert(insertions)
def _try_get_known_z_half_turns(op: ops.Operation) -> Optional[float]: if not isinstance(op, ops.GateOperation): return None if not isinstance(op.gate, ops.ZPowGate): return None h = op.gate.exponent if isinstance(h, value.Symbol): return None return h