Source code for iq_readout.three_state_classifiers.gmda

from __future__ import annotations
import warnings
from typing import Dict

import numpy as np
from scipy.optimize import curve_fit

from ..classifiers import ThreeStateClassifier
from ..utils import check_2d_input, histogram_2d, reshape_histogram_2d, FIT_KARGS
from ..pdfs import simple_2d_gaussian_triple_mixture


class GaussMixClassifier(ThreeStateClassifier):
    """
    Read `gmda.md` and `ThreeStateClassifier` documentation
    """

    _pdf_func_0 = simple_2d_gaussian_triple_mixture
    _pdf_func_1 = simple_2d_gaussian_triple_mixture
    _pdf_func_2 = simple_2d_gaussian_triple_mixture
    # parameter name ordering must match the ordering in the pdf functions
    _names = [
        "mu_0_x",
        "mu_0_y",
        "mu_1_x",
        "mu_1_y",
        "mu_2_x",
        "mu_2_y",
        "sigma",
        "angle1",
        "angle2",
    ]
    _param_names = {
        0: _names,
        1: _names,
        2: _names,
    }

    @property
    def statistics(self) -> Dict[str, np.ndarray]:
        """
        Returns dictionary with general statistical data:

        * ``mu_0``: ``np.array([float, float])``
        * ``mu_1``: ``np.array([float, float])``
        * ``mu_2``: ``np.array([float, float])``
        * ``cov_0``: ``np.array([[float, float], [float, float]])``
        * ``cov_1``: ``np.array([[float, float], [float, float]])``
        * ``cov_2``: ``np.array([[float, float], [float, float]])``

        NB: this property is used for plotting and for storing useful
        information in the YAML file.
        """
        statistics = {}
        # Each entry is read from the parameter set of its own state so that
        # the means and covariances are indexed consistently.
        for state in range(3):
            state_params = self.params[state]
            statistics[f"mu_{state}"] = np.array(
                [state_params[f"mu_{state}_x"], state_params[f"mu_{state}_y"]]
            )
        for state in range(3):
            # a single shared ``sigma`` gives an isotropic covariance matrix
            statistics[f"cov_{state}"] = self.params[state]["sigma"] ** 2 * np.eye(2)
        return statistics

    @staticmethod
    def _is_poorly_constrained(popt, pcov, rel_tol=0.1) -> bool:
        """
        Return True if any fitted parameter has a relative standard error
        above ``rel_tol``.

        The standard errors are the square roots of the diagonal of ``pcov``.
        The comparison is done against ``abs(popt)`` so that negative fitted
        values (e.g. blob centres with negative coordinates) are checked too,
        and it is written without a division so that a parameter equal to
        zero does not trigger a divide-by-zero.
        """
        perr = np.sqrt(np.diag(pcov))
        return bool((perr > rel_tol * np.abs(popt)).any())

    @staticmethod
    def _fit_state_angles(pdf_func, shots, fixed_params, guess, n_bins):
        """
        Fit ``angle1`` and ``angle2`` of ``pdf_func`` to the 2d histogram of
        ``shots`` while keeping all other parameters fixed.

        Parameters
        ----------
        pdf_func
            Mixture pdf called as ``pdf_func(z, *fixed_params, angle1, angle2)``.
        shots: np.ndarray(N, 2)
            IQ data for a single prepared state.
        fixed_params
            Parameters preceding the two angles in the pdf signature.
        guess
            Initial values ``[angle1, angle2]``.
        n_bins
            Number of bins forwarded to ``histogram_2d``.

        Returns
        -------
        Fitted ``(angle1, angle2)`` as returned by ``curve_fit``.
        """

        def log_pdf(z, angle1, angle2):
            return np.log10(pdf_func(z, *fixed_params, angle1, angle2))

        counts, zz = reshape_histogram_2d(*histogram_2d(shots, n_bins=n_bins))
        # The fit is done in log scale, so empty bins must be dropped
        # (np.log10(0) = -inf).
        populated = counts != 0
        zz, counts = zz[populated], counts[populated]

        bounds = ((0, 0), (np.pi / 2, np.pi / 2))
        popt, pcov = curve_fit(
            log_pdf, zz, np.log10(counts), p0=guess, bounds=bounds, **FIT_KARGS
        )
        if GaussMixClassifier._is_poorly_constrained(popt, pcov):
            # NOTE(review): the message mentions means and covariances although
            # only the two angles are fitted here; the text is kept unchanged
            # so that existing warning filters keep matching.
            # stacklevel=2 attributes the warning to the call site in ``fit``,
            # keeping one distinct warning location per state.
            warnings.warn(
                "Fitted means and covariances may not be accurate", stacklevel=2
            )
        return popt

    @classmethod
    def fit(
        cls: type[GaussMixClassifier],
        shots_0: np.ndarray,
        shots_1: np.ndarray,
        shots_2: np.ndarray,
        n_bins: list | None = None,
    ) -> GaussMixClassifier:
        """
        Fits the given data to extract the best parameters for classification.

        Parameters
        ----------
        shots_0: np.ndarray(N, 2)
            IQ data when preparing state 0
        shots_1: np.ndarray(M, 2)
            IQ data when preparing state 1
        shots_2: np.ndarray(P, 2)
            IQ data when preparing state 2
        n_bins: (nx_bins, ny_bins)
            Number of bins for the first and second coordinate used
            in the 2d histograms. Defaults to ``[100, 100]``.

        Returns
        -------
        `GaussMixClassifier` containing the fitted parameters

        Warns
        -----
        UserWarning
            If the relative standard error of any fitted parameter
            exceeds 10%.
        """
        if n_bins is None:
            # built per call to avoid sharing a mutable default between calls
            n_bins = [100, 100]

        check_2d_input(shots_0, axis=1)
        check_2d_input(shots_1, axis=1)
        check_2d_input(shots_2, axis=1)

        all_shots = np.concatenate([shots_0, shots_1, shots_2])
        counts, zz = reshape_histogram_2d(*histogram_2d(all_shots, n_bins=n_bins))

        # in the first fit the shots_i are concatenated
        # to extract the means and covariance matrices,
        # thus the Gaussian weights are approx. 1/3.
        # (0.7854 ~ pi/4 and 0.9553 ~ arccos(1/sqrt(3)); presumably these are
        # the angles giving equal weights -- see the pdf definition.)
        guess = [
            *np.average(shots_0, axis=0),
            *np.average(shots_1, axis=0),
            *np.average(shots_2, axis=0),
            np.average(np.std(shots_0, axis=0)),
            0.7854,
            0.9553,
        ]
        # NOTE(review): the upper bound of sigma is the largest coordinate
        # value, which depends on the origin of the IQ plane and is below the
        # lower bound if all coordinates are negative -- confirm if intended.
        bounds = (
            (
                *np.min(shots_0, axis=0),
                *np.min(shots_1, axis=0),
                *np.min(shots_2, axis=0),
                1e-10,
                0,
                0,
            ),
            (
                *np.max(shots_0, axis=0),
                *np.max(shots_1, axis=0),
                *np.max(shots_2, axis=0),
                np.max(all_shots),
                np.pi / 2,
                np.pi / 2,
            ),
        )

        popt_comb, pcov = curve_fit(
            cls._pdf_func_0,  # it is the same for all states
            zz,
            counts,
            p0=guess,
            bounds=bounds,
            **FIT_KARGS,
        )
        if cls._is_poorly_constrained(popt_comb, pcov):
            warnings.warn("Fitted means and covariances may not be accurate")

        # The means and sigma are shared by the three states; only the two
        # angles (last two parameters) are fitted per state below.
        fixed_params = popt_comb[:-2]
        shared = dict(zip(cls._names[:-2], fixed_params))
        params = {state: dict(shared) for state in range(3)}

        # get amplitudes of Gaussians for each state
        # Note: fitting in log scale improves the results, however there is the
        # problem of having counts=0 (np.log(0) = -inf) due to undersampling,
        # which is why empty bins are discarded in `_fit_state_angles`.

        # PDF state 0
        params[0]["angle1"], params[0]["angle2"] = cls._fit_state_angles(
            cls._pdf_func_0,
            shots_0,
            fixed_params,
            [0.1, np.pi / 2 - 0.25],  # avoid getting stuck in max bound
            n_bins,
        )

        # PDF state 1
        params[1]["angle1"], params[1]["angle2"] = cls._fit_state_angles(
            cls._pdf_func_1,
            shots_1,
            fixed_params,
            [1.4706, np.pi / 2 - 0.25],  # avoid getting stuck in max bound
            n_bins,
        )

        # PDF state 2
        params[2]["angle1"], params[2]["angle2"] = cls._fit_state_angles(
            cls._pdf_func_2,
            shots_2,
            fixed_params,
            [np.pi / 4, 0.2255],
            n_bins,
        )

        return cls(params)