Source code for hamil_clever_sim.components.statevector_display

from itertools import product
from typing import Tuple

from rich.text import Text
from textual import on
from textual.app import ComposeResult
from textual.reactive import Reactive, reactive
from textual.widgets import (
    DataTable,
    Static,
)
from textual.widgets.data_table import ColumnKey

from hamil_clever_sim.components.sparkline_controllable import Sparkline
from hamil_clever_sim.hamil_runner import SimulationRunnerResult

# Column headers for the statevector table, in display order: the basis-state
# ket, the real and imaginary parts of its amplitude, and its probability.
COLS = (
    "|𝜓⟩",
    "amplitude (ℝ)",
    "amplitude (ℂ)",
    "probability (%)",
)


class StatevectorDisplay(Static):
    """Widget showing a simulated statevector as a sparkline plus a table.

    The sparkline plots ``abs(amplitude)`` for every computational basis state
    of the same bit width as the labels in :attr:`data`. The table lists each
    state present in :attr:`data` with the real part and imaginary part of its
    amplitude and its probability as a percentage. Highlighting a table row
    marks the corresponding position on the sparkline.

    NOTE(review): ``data`` is read like a mapping from a basis-state label
    (e.g. ``"01"``) to an ``(amplitude, probability)`` pair; this is inferred
    from the ``.items()`` / ``.get()`` usage below -- confirm against
    ``SimulationRunnerResult.Data``.
    """

    # Result to display; None until a simulation result is assigned.
    data: Reactive[SimulationRunnerResult.Data | None] = reactive(None, layout=True)
    # Formatted table rows (computed from ``data``): (ket, Re, Im, probability).
    # A ``list`` factory gives each instance its own default instead of one
    # shared list object.
    rows: Reactive[list[tuple[str, Text, str, Text]]] = reactive(list)
    # Keys of the table columns; rebound (not mutated) in ``on_mount``.
    column_keys: list[ColumnKey] = []
    # abs(amplitude) per entry of ``all_states`` (computed); None when there is
    # nothing to plot.
    spark_data: Reactive[list[float] | None] = reactive(list)
    # Every bitstring as wide as the labels in ``data`` (computed).
    all_states: Reactive[list[str]] = reactive(list)
    # Index into ``all_states`` of the highlighted row; -1 when none is marked.
    marked_index = reactive(-1)

    def on_mount(self) -> None:
        """Add the column headers to the table once the widget is mounted."""
        table = self.query_one(DataTable)
        self.column_keys = table.add_columns(*COLS)

    def compute_rows(self) -> list[tuple[str, Text, str, Text]]:
        """Format one table row per entry in ``data``.

        Each row is ``(ket label, real part, imaginary part, probability %)``;
        the numeric cells are right-justified ``Text`` objects or strings.

        Returns:
            The formatted rows, or an empty list when ``data`` is None.
        """
        if self.data is None:
            return []
        return [
            (
                f"|{label}⟩",
                Text(f"{entry[0].real: .5f}", justify="right"),
                f"{entry[0].imag:+.5f}𝕚",
                Text(f"{entry[1] * 100:.2f}%", justify="right"),
            )
            for label, entry in self.data.items()
        ]

    def watch_data(self, update: SimulationRunnerResult.Data | None) -> None:
        """Called by Textual when ``data`` changes; intentionally a no-op.

        The table and sparkline are refreshed from the computed ``rows``
        reactive in :meth:`watch_rows`.
        """

    def watch_rows(self, update: list[tuple[str, Text, str, Text]]) -> None:
        """Replace the table contents with ``update`` and refresh the sparkline.

        Args:
            update: The newly computed table rows.
        """
        table = self.query_one(DataTable)
        # Drop rows from any previous result (columns are kept) so a new result
        # replaces the old one instead of being appended to it, and so
        # resetting ``data`` to None empties the table.
        table.clear(columns=False)
        for row in update:
            table.add_row(*row)
        spark = self.query_one(Sparkline)
        spark.data = self.spark_data

    def compute_all_states(self) -> list[str]:
        """List every bitstring as wide as the labels in ``data``.

        States come out in lexicographic order ("00", "01", "10", "11", ...).
        The width is taken from the first label only, so all labels are
        presumably equal-length bitstrings -- TODO confirm.

        Returns:
            The bitstrings, or an empty list when ``data`` is None or empty.
        """
        if self.data is None:
            return []
        # ``next`` with a default avoids StopIteration on an empty result.
        first_label = next(iter(self.data.keys()), None)
        if first_label is None:
            return []
        return ["".join(bits) for bits in product("01", repeat=len(first_label))]

    def compute_spark_data(self) -> list[float] | None:
        """Return ``abs(amplitude)`` for each state of ``all_states``, in order.

        States absent from ``data`` contribute 0.

        Returns:
            The magnitudes, or None when there are no states to plot.
        """
        states = self.all_states
        data = self.data
        if not states or data is None:
            return None
        # A missing (or empty) entry falls back to amplitude 0.
        return [abs((data.get(state) or (0,))[0]) for state in states]

    @on(DataTable.RowHighlighted)
    def handle_row_select(self, highlight: DataTable.RowHighlighted) -> None:
        """Mark the highlighted row's basis state on the sparkline.

        Rows whose state is not in ``all_states`` are ignored instead of
        raising ``ValueError``.

        Args:
            highlight: The DataTable row-highlight message.
        """
        states = self.all_states
        ket = highlight.data_table.get_row_at(highlight.cursor_row)[0]
        state = ket[1:-1]  # strip the leading "|" and trailing "⟩" of the ket
        try:
            marked = states.index(state)
        except ValueError:
            return  # nothing on the sparkline corresponds to this row
        self.marked_index = marked
        spark = self.query_one(Sparkline)
        spark.marked_index = marked
        spark.refresh()

    def compose(self) -> ComposeResult:
        """Yield the sparkline followed by the row-cursor statevector table."""
        yield Sparkline(self.spark_data)
        yield DataTable(cursor_type="row", classes="statevector-display-table")