# Source code for cirq.optimizers.merge_interactions

# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An optimization pass that merges runs of adjacent one- and two-qubit
operations acting on the same pair of qubits and re-synthesizes them from
their combined unitary."""

from typing import Callable, List, Optional, Sequence, Tuple, cast

import numpy as np

from cirq import circuits, ops, protocols
from cirq.optimizers import two_qubit_decompositions


class MergeInteractions(circuits.PointOptimizer):
    """Combines series of adjacent one and two-qubit gates operating on a pair
    of qubits."""

    def __init__(self,
                 tolerance: float = 1e-8,
                 allow_partial_czs: bool = True,
                 post_clean_up: Callable[
                     [Sequence[ops.Operation]], ops.OP_TREE
                 ] = lambda op_list: op_list) -> None:
        """Initializes the optimizer.

        Args:
            tolerance: Numerical tolerance forwarded to
                `two_qubit_matrix_to_operations` when re-synthesizing the
                merged matrix.
            allow_partial_czs: Forwarded to the decomposition. When False,
                any CZPowGate with an exponent other than 1 found in a scanned
                run forces that run to be rewritten.
            post_clean_up: Callback forwarded to `circuits.PointOptimizer`;
                presumably applied to the newly produced operations. Defaults
                to the identity.
        """
        super().__init__(post_clean_up=post_clean_up)
        self.tolerance = tolerance
        self.allow_partial_czs = allow_partial_czs

    def optimization_at(self,
                        circuit: circuits.Circuit,
                        index: int,
                        op: ops.Operation
                        ) -> Optional[circuits.PointOptimizationSummary]:
        """Tries to replace the run of operations that starts with `op`.

        Only two-qubit operations start a run. Starting at moment `index`,
        consecutive operations acting only on `op.qubits` are multiplied into
        one 4x4 matrix, which is then re-synthesized with
        `two_qubit_matrix_to_operations`.

        Args:
            circuit: The circuit being optimized.
            index: The moment index at which `op` occurs.
            op: The operation currently being visited.

        Returns:
            A summary that clears the scanned span on `op.qubits` and inserts
            the re-synthesized operations, or None to leave the run as is.
        """
        if len(op.qubits) != 2:
            return None

        old_operations, indices, matrix = (
            self._scan_two_qubit_ops_into_matrix(circuit, index, op.qubits))

        old_interaction_count = sum(
            1 for old_op in old_operations if len(old_op.qubits) == 2)

        # A run containing any two-qubit interaction that is not a
        # CZPowGate is always rewritten.
        switch_to_new = any(
            len(old_op.qubits) == 2 and
            not (isinstance(old_op, ops.GateOperation) and
                 isinstance(old_op.gate, ops.CZPowGate))
            for old_op in old_operations)

        if not self.allow_partial_czs:
            # Partial CZs are disallowed, so an existing CZ**t with t != 1
            # also forces a rewrite.
            switch_to_new |= any(
                isinstance(old_op, ops.GateOperation) and
                isinstance(old_op.gate, ops.CZPowGate) and
                old_op.gate.exponent != 1
                for old_op in old_operations)

        # Nothing forces a rewrite and there is at most one interaction:
        # this point cannot be optimized using this method. (This also
        # covers an empty scan, so `indices` is non-empty below.)
        if not switch_to_new and old_interaction_count <= 1:
            return None

        # Find a max-3-cz construction.
        new_operations = (
            two_qubit_decompositions.two_qubit_matrix_to_operations(
                op.qubits[0],
                op.qubits[1],
                matrix,
                self.allow_partial_czs,
                self.tolerance))
        new_interaction_count = sum(
            1 for new_op in new_operations if len(new_op.qubits) == 2)

        # Otherwise, only rewrite when it strictly reduces interactions.
        switch_to_new |= new_interaction_count < old_interaction_count

        if not switch_to_new:
            return None

        return circuits.PointOptimizationSummary(
            clear_span=max(indices) + 1 - index,
            clear_qubits=op.qubits,
            new_operations=new_operations)

    def _op_to_matrix(self,
                      op: Optional[ops.Operation],
                      qubits: Tuple[ops.QubitId, ...]
                      ) -> Optional[np.ndarray]:
        """Determines the effect of an operation on the given qubits.

        If the operation is a 1-qubit operation on one of the given qubits,
        or a 2-qubit operation on both of the given qubits, and also the
        operation has a known matrix, then a matrix is returned. Otherwise
        None is returned.

        Args:
            op: The operation to understand (None yields None).
            qubits: The qubits we care about. Order determines matrix tensor
                order.

        Returns:
            None, or else a 4x4 matrix equivalent to the effect of the
            operation on `qubits`.
        """
        q1, q2 = qubits
        if op is None:
            return None

        # Check the qubit layout first so that operations which would be
        # rejected anyway (e.g. ones touching other qubits) never have their
        # unitary computed.
        if op.qubits not in (qubits, (q2, q1), (q1,), (q2,)):
            return None

        matrix = protocols.unitary(op, None)
        if matrix is None:
            return None

        if op.qubits == qubits:
            return matrix
        if op.qubits == (q2, q1):
            return MergeInteractions._flip_kron_order(matrix)
        if op.qubits == (q1,):
            return np.kron(matrix, np.eye(2))
        # Only remaining layout: op.qubits == (q2,).
        return np.kron(np.eye(2), matrix)

    def _scan_two_qubit_ops_into_matrix(
            self,
            circuit: circuits.Circuit,
            index: Optional[int],
            qubits: Tuple[ops.QubitId, ...]
    ) -> Tuple[List[ops.Operation], List[int], np.ndarray]:
        """Accumulates operations affecting the given pair of qubits.

        The scan terminates when it hits the end of the circuit, finds an
        operation without a known matrix, or finds an operation that
        interacts the given qubits with other qubits.

        Args:
            circuit: The circuit to scan for operations.
            index: The index to start scanning forward from.
            qubits: The pair of qubits we care about.

        Returns:
            A tuple containing:
                0. The operations.
                1. The moment indices those operations were on.
                2. A matrix equivalent to the effect of the scanned
                    operations.
        """
        product = np.eye(4, dtype=np.complex128)
        all_operations = []  # type: List[ops.Operation]
        touched_indices = []  # type: List[int]

        while index is not None:
            # De-duplicate via a set: a two-qubit op is reported once per
            # qubit. Multiplication order within a moment presumably does
            # not matter because ops in one moment act on disjoint qubits.
            operations = list({circuit.operation_at(q, index) for q in qubits})
            present_ops = [op for op in operations if op is not None]
            op_data = [self._op_to_matrix(op, qubits) for op in present_ops]

            # Stop at any non-constant or non-local interaction.
            if any(e is None for e in op_data):
                break

            for op_mat in cast(List[np.ndarray], op_data):
                product = np.dot(op_mat, product)
            all_operations.extend(present_ops)
            touched_indices.append(index)

            index = circuit.next_moment_operating_on(qubits, index + 1)

        return all_operations, touched_indices, product

    @staticmethod
    def _flip_kron_order(mat4x4: np.ndarray) -> np.ndarray:
        """Given M = sum(kron(a_i, b_i)), returns M' = sum(kron(b_i, a_i))."""
        # Swapping the tensor factors swaps basis states |01> and |10>, i.e.
        # permutes rows and columns 1 <-> 2. The permutation is its own
        # inverse, so M'[i, j] == M[order[i], order[j]].
        order = [0, 2, 1, 3]
        return np.asarray(mat4x4, dtype=np.complex128)[np.ix_(order, order)]