Skip to content

SQL Database

DisruptionPy uses logbook SQL databases for convenience when retrieving data from MDSplus. Users may also use DisruptionPy to retrieve data directly from the disruption_warning tables of the logbook database.

The disruption_warning table¤

The disruption_warning SQL tables for CMod and DIII-D contain important disruption parameters for a large number of shots.

CMod Dataset¤

The dataset contains unique plasma discharges from MIT's Alcator C-Mod tokamak, from the 2012 to 2016 experimental campaigns, plus additional discharges from 2005.

Available columns on CMod
'dbkey', 'shot', 'time', 'time_until_disrupt', 'ip_error', 'dip_dt',
'beta_p', 'beta_n', 'li', 'n_equal_1_normalized', 'z_error', 'v_z',
'z_times_v_z', 'kappa', 'pressure_peaking', 'H98', 'q0', 'qstar', 'q95',
'v_0', 'v_mid', 'v_edge', 'dn_dt', 'p_rad_slow', 'p_oh_slow', 'p_icrf',
'p_lh', 'radiated_fraction', 'power_supply_railed', 'v_loop_efit',
'r_dd', 'lower_gap', 'upper_gap', 'dbetap_dt', 'dli_dt', 'ip', 'zcur',
'n_e', 'dipprog_dt', 'v_loop', 'p_rad', 'p_oh', 'ssep', 'dWmhd_dt',
'dprad_dt', 'v_0_uncalibrated', 'Te_width', 'Greenwald_fraction',
'intentional_disruption', 'Te_width_ECE', 'Wmhd', 'n_over_ncrit',
'n_equal_1_mode', 'Mirnov', 'Mirnov_norm_btor', 'Mirnov_norm_bpol',
'Te_peaking', 'ne_peaking', 'Te_peaking_ECE', 'SXR_peaking',
'kappa_area', 'I_efc', 'SXR', 'H_alpha', 'Prad_peaking_CVA'

For more details on computed values please see parameter reference.

Retrieving data from the SQL database¤

Here is an example of retrieving data from the disruption_warning or disruptions tables:

#!/usr/bin/env python3

"""
example module for SQL.
"""


from disruption_py.machine.tokamak import Tokamak, resolve_tokamak_from_environment
from disruption_py.workflow import get_database


def main():
    """
    Run a handful of shot-count queries as a smoke test of the DB connection.

    The tokamak is resolved from the environment, the matching database is
    opened, and each query result is printed and compared against the
    reference count hard-coded for that machine.

    Raises
    ------
    ValueError
        If the resolved tokamak has no reference counts defined here.
    AssertionError
        If a query returns a count different from its reference value.
    """

    distinct_shots = "select count(distinct shot) from disruption_warning"
    queries = [
        distinct_shots,
        distinct_shots + " where shot not in (select shot from disruptions)",
        distinct_shots + " where shot in (select shot from disruptions)",
        "select count(distinct shot) from disruptions",
    ]
    tokamak = resolve_tokamak_from_environment()
    db = get_database(tokamak=tokamak)

    # Reference counts, in the same order as `queries`.
    reference_counts = (
        (Tokamak.D3D, [13245, 8055, 5190, 24219]),
        (Tokamak.CMOD, [10435, 6640, 3795, 13785]),
        (Tokamak.EAST, [18568, 9875, 8693, 30482]),
    )
    for machine, counts in reference_counts:
        if tokamak is machine:
            vals = counts
            break
    else:
        raise ValueError(f"Unspecified or unsupported tokamak: {tokamak}.")

    print(f"Initialized DB: {db.user}@{db.host}/{db.db_name}")
    print("Version:", db.get_version())

    for position, query in enumerate(queries):
        print(">", query.strip(" "))

        out = db.query(query)
        print("=", out.shape)

        # A single-row result is shown as a Series for a more compact printout.
        if out.shape[0] == 1:
            shown = out.iloc[0]
        else:
            shown = out
        print(shown, "\n")

        # Only check queries that have a reference value.
        if position < len(vals):
            assert out.iloc[0, 0] == vals[position]


if __name__ == "__main__":
    main()

Database Class Reference¤

Module for managing SQL database connections.

DummyDatabase ¤

Bases: ShotDatabase

A database class that does not require connecting to an SQL server and returns no data.

Source code in disruption_py/inout/sql.py
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
class DummyDatabase(ShotDatabase):
    """
    A database class that does not require connecting to an SQL server and returns
    no data.

    The data-access methods ignore every argument they receive, positional or
    keyword, so the class can stand in for ShotDatabase regardless of how the
    caller passes its arguments.
    """

    # pylint: disable-next=super-init-not-called
    def __init__(self, **kwargs):
        # The parent initializer is intentionally not called: nothing to set up.
        pass

    @classmethod
    # pylint: disable-next=missing-function-docstring
    def initializer(cls, **_kwargs):
        return cls()

    @property
    def conn(self):
        """
        Return a placeholder connection that accepts any attribute access or call.
        """
        return DummyObject()

    # pylint: disable-next=arguments-differ
    def query(self, *_args, **_kwargs):
        """
        Ignore the query and return an empty dataframe.
        """
        return pd.DataFrame()

    # pylint: disable-next=arguments-differ
    def get_shots_data(self, *_args, **_kwargs):
        """
        Ignore the request and return an empty dataframe.
        """
        return pd.DataFrame()

    # pylint: disable-next=arguments-differ
    def get_disruption_time(self, *_args, **_kwargs):
        """
        Ignore the shot and return None, i.e. no known disruption time.
        """
        return None

DummyObject ¤

A dummy connection object.

Source code in disruption_py/inout/sql.py
236
237
238
239
240
241
242
243
244
245
246
247
class DummyObject:
    """
    A dummy connection object.

    Calling the object, or looking up any attribute that is not otherwise
    defined on it, yields the object itself, so arbitrarily long chains of
    attribute accesses and calls all resolve to the same instance.
    """

    def __call__(self, *args, **kwargs):
        # Calling the object (or any "method" obtained from it) is a no-op
        # that hands back the same instance.
        return self

    def __getattr__(self, name):
        # Only invoked for names not found through normal lookup; every such
        # name resolves to the object itself.
        return self

ShotDatabase ¤

Handles grabbing data from MySQL server.

Source code in disruption_py/inout/sql.py
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
class ShotDatabase:
    """
    Handles grabbing data from an SQL server (MySQL or MS SQL) through pyodbc.

    Two access paths are kept side by side: an SQLAlchemy engine, used by the
    pandas-based queries, and one raw pyodbc connection per thread, used by the
    cursor-based queries.
    """

    def __init__(
        self,
        driver,
        host,
        port,
        db_name,
        user,
        passwd,
        **_kwargs,
    ):
        """
        Select an ODBC driver and prepare the connection string and engine.

        Parameters
        ----------
        driver : str
            Requested ODBC driver name. If it is not installed, the first
            installed driver whose name starts with it is used; failing that,
            the first installed driver is used.
        host : str
            Database server host.
        port : int or str
            Database server port.
        db_name : str
            Name of the database to connect to.
        user : str
            Login name.
        passwd : str
            Login password.
        """

        logger.debug(
            "Database initialization: {user}@{host}/{db_name}",
            user=user,
            host=host,
            db_name=db_name,
        )
        drivers = pyodbc.drivers()
        if driver in drivers:
            # exact driver
            self.driver = driver
        elif any(d.startswith(driver) for d in drivers):
            # fuzzy driver
            self.driver = next(d for d in drivers if d.startswith(driver))
            logger.info(
                "Database driver fallback: '{driver}' -> '{class_driver}'",
                driver=driver,
                class_driver=self.driver,
            )
        else:
            # first driver
            self.driver = drivers[0]
            logger.warning(
                "Database driver fallback: '{driver}' -> '{class_driver}'",
                driver=driver,
                class_driver=self.driver,
            )
        self.host = host
        self.port = port
        self.db_name = db_name
        self.user = user
        self.passwd = passwd

        # Any driver whose name does not mention MySQL is treated as MS SQL.
        self.dialect = "mysql" if "mysql" in self.driver.lower() else "mssql"
        self.connection_string = self._get_connection_string(self.db_name)
        # One raw pyodbc connection per thread, created lazily by `conn`.
        self._thread_connections = {}
        quoted_connection_string = quote_plus(self.connection_string)
        self.engine = create_engine(
            f"{self.dialect}+pyodbc:///?odbc_connect={quoted_connection_string}"
        )

    @classmethod
    def from_config(cls, tokamak: Tokamak):
        """
        Initialize database from config.

        When the configuration lacks the user name or the password, both are
        read from `~/<db_name>.sybase_login` (lower-case name first, then
        upper-case), taking the last two whitespace-separated tokens of the
        file as user and password.

        Raises
        ------
        ValueError
            If credentials are missing and no login file exists.
        """

        db_conf = config(tokamak).inout.sql

        # read sybase login
        if any(f"db_{key}" not in db_conf for key in ["user", "pass"]):
            db_name = db_conf["db_name"]
            for name in [db_name.lower(), db_name.upper()]:
                profile = os.path.expanduser(f"~/{name}.sybase_login")
                if not os.path.exists(profile):
                    continue
                with open(profile, "r", encoding="utf-8") as fio:
                    db_conf["db_user"], db_conf["db_pass"] = fio.read().split()[-2:]
                break
            else:
                raise ValueError("could not read DB username and password.")

        return SharedInstance(ShotDatabase).get_instance(
            driver=db_conf["driver"],
            host=db_conf["host"],
            port=db_conf["port"],
            db_name=db_conf["db_name"],
            user=db_conf["db_user"],
            passwd=db_conf["db_pass"],
        )

    def _get_connection_string(self, db_name):
        """
        Build the `KEY=value;...` ODBC connection string for `db_name`.
        """
        params = {
            "DRIVER": self.driver,
            "SERVER": self.host,
            "PORT": self.port,
            "DATABASE": db_name,
            "UID": self.user,
            "PWD": self.passwd,
            "TrustServerCertificate": "yes",
            "Connection Timeout": 60,
        }
        # Drivers named "ODBC ..." get the port appended to the server as
        # "host,port" instead of a separate PORT entry.
        if self.driver.lower().startswith("odbc"):
            params["SERVER"] += f",{params.pop('PORT')}"
        return ";".join([f"{k}={v}" for k, v in params.items()])

    @property
    def conn(self):
        """
        Property returning a connection to sql database.

        If a connection exists for the given thread returns that connection,
        otherwise creates a new connection

        Returns
        -------
        pyodbc.Connection
            Database connection
        """
        current_thread = threading.current_thread()
        if current_thread not in self._thread_connections:
            logger.debug(
                "PID #{pid} | Connecting to SQL database: {server}",
                pid=threading.get_native_id(),
                server=self.host,
            )
            self._thread_connections[current_thread] = pyodbc.connect(
                self.connection_string
            )
        return self._thread_connections[current_thread]

    def query(self, query: str, use_pandas=True):
        """
        query sql database

        Parameters
        ----------
        query : str
            The query string
        use_pandas : bool, optional
            Whether pd.read_sql_query should be used to run the query. Default value
            is true.

        Returns
        -------
        Any
            Result of query: a dataframe when `use_pandas` is true, otherwise
            the rows fetched from the cursor, or None if the query failed with
            a pyodbc.DatabaseError (the error is logged, not raised).
        """
        if use_pandas:
            return pd.read_sql_query(query, self.engine)
        curs = self.conn.cursor()
        output = None
        try:
            curs.execute(query)
            output = curs.fetchall()
        except pyodbc.DatabaseError as e:
            logger.error("Query failed: {e}", e=repr(e))
            logger.opt(exception=True).debug(e)
        finally:
            # Release the cursor even when an exception other than
            # DatabaseError propagates to the caller.
            curs.close()
        return output

    def get_version(self):
        """
        Query the version of the SQL database.
        """
        mysql = "mysql" in self.driver.lower()
        query = "select " + ("version()" if mysql else "@@version")
        version = self.query(query, use_pandas=False)
        return version[0][0]

    def get_shots_data(
        self,
        shotlist: List[int],
        cols: List[str] = None,
        sql_table="disruption_warning",
    ):
        """
        get_shots_data retrieves columns from sql data for given shotlist

        Parameters
        ----------
        shotlist : List[int]
            List of shot ids to get data for. If None, the rows of all shots
            are returned, ordered by time.
        cols : List[str]
            List of columns to retrieve. Default value is ["*"], meaning all columns.
        sql_table : str, optional
            The sql_table to retrieve data from. Default value is "disruption_warning".

        Returns
        -------
        pd.DataFrame
            Dataframe containing queried data, with lower-cased column names
        """
        if cols is None:
            cols = ["*"]
        cols = ", ".join(str(col) for col in cols)
        query = f"select {cols} from {sql_table}"
        # Test for None before joining: the joined string is never None, and
        # joining None itself would raise.
        if shotlist is None:
            query += " order by time"
        else:
            shots = ",".join(str(shot) for shot in shotlist)
            query += f" where shot in ({shots}) order by shot, time"
        shot_df = pd.read_sql_query(query, self.engine)
        shot_df.columns = shot_df.columns.str.lower()
        return shot_df

    def get_disruption_time(self, shot_id):
        """
        Get disruption time for shot_id or None if there was no disruption.
        """
        with self.conn.cursor() as curs:
            curs.execute(f"select t_disrupt from disruptions where shot = {shot_id}")
            t_disrupt = curs.fetchall()
        if len(t_disrupt) == 0:
            return None
        t_disrupt = t_disrupt[0][0]
        return t_disrupt

conn property ¤

conn

Property returning a connection to sql database.

If a connection exists for the given thread returns that connection, otherwise creates a new connection

RETURNS DESCRIPTION
pyodbc.Connection

Database connection

from_config classmethod ¤

from_config(tokamak: Tokamak)

Initialize database from config.

Source code in disruption_py/inout/sql.py
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
@classmethod
def from_config(cls, tokamak: Tokamak):
    """
    Build the shared database instance described by the tokamak's configuration.

    If the configuration lacks the user name or the password, both are taken
    from `~/<db_name>.sybase_login`, trying the lower-case database name first
    and the upper-case one second. The last two whitespace-separated tokens of
    that file are used as user name and password.

    Raises
    ------
    ValueError
        If credentials are missing and neither login file exists.
    """

    db_conf = config(tokamak).inout.sql

    credentials_present = all(f"db_{key}" in db_conf for key in ["user", "pass"])
    if not credentials_present:
        db_name = db_conf["db_name"]
        # Lazily expand the candidate paths so the second one is only
        # computed when the first does not exist.
        candidates = (
            os.path.expanduser(f"~/{name}.sybase_login")
            for name in [db_name.lower(), db_name.upper()]
        )
        profile = next((path for path in candidates if os.path.exists(path)), None)
        if profile is None:
            raise ValueError("could not read DB username and password.")
        with open(profile, "r", encoding="utf-8") as fio:
            tokens = fio.read().split()
        db_conf["db_user"], db_conf["db_pass"] = tokens[-2:]

    settings = {
        "driver": db_conf["driver"],
        "host": db_conf["host"],
        "port": db_conf["port"],
        "db_name": db_conf["db_name"],
        "user": db_conf["db_user"],
        "passwd": db_conf["db_pass"],
    }
    return SharedInstance(ShotDatabase).get_instance(**settings)

get_disruption_time ¤

get_disruption_time(shot_id)

Get disruption time for shot_id or None if there was no disruption.

Source code in disruption_py/inout/sql.py
223
224
225
226
227
228
229
230
231
232
233
def get_disruption_time(self, shot_id):
    """
    Look up the disruption time recorded for a shot.

    Returns the `t_disrupt` value of the first matching row of the
    `disruptions` table, or None when the shot has no row there.
    """
    with self.conn.cursor() as cursor:
        cursor.execute(f"select t_disrupt from disruptions where shot = {shot_id}")
        rows = cursor.fetchall()
    # No row means no disruption was recorded for this shot.
    if len(rows) != 0:
        first_row = rows[0]
        return first_row[0]
    return None

get_shots_data ¤

get_shots_data(
    shotlist: List[int],
    cols: List[str] = None,
    sql_table="disruption_warning",
)

get_shots_data retrieves columns from sql data for given shotlist

PARAMETER DESCRIPTION
shotlist

List of shot ids to get data for.

TYPE: List[int]

cols

List of columns to retrieve. Default value is ["*"], meaning all columns.

TYPE: List[str] DEFAULT: None

sql_table

The sql_table to retrieve data from. Default value is "disruption_warning".

TYPE: str DEFAULT: 'disruption_warning'

RETURNS DESCRIPTION
Dataframe

Dataframe containing queried data

Source code in disruption_py/inout/sql.py
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
def get_shots_data(
    self,
    shotlist: List[int],
    cols: List[str] = None,
    sql_table="disruption_warning",
):
    """
    get_shots_data retrieves columns from sql data for given shotlist

    Parameters
    ----------
    shotlist : List[int]
        List of shot ids to get data for. If None, the rows of all shots are
        returned, ordered by time.
    cols : List[str]
        List of columns to retrieve. Default value is ["*"], meaning all columns.
    sql_table : str, optional
        The sql_table to retrieve data from. Default value is "disruption_warning".

    Returns
    -------
    pd.DataFrame
        Dataframe containing queried data, with lower-cased column names
    """
    if cols is None:
        cols = ["*"]
    cols = ", ".join(str(col) for col in cols)
    query = f"select {cols} from {sql_table}"
    # Test for None before joining: the joined string is never None, and
    # joining None itself would raise.
    if shotlist is None:
        query += " order by time"
    else:
        shots = ",".join(str(shot) for shot in shotlist)
        query += f" where shot in ({shots}) order by shot, time"
    shot_df = pd.read_sql_query(query, self.engine)
    shot_df.columns = shot_df.columns.str.lower()
    return shot_df

get_version ¤

get_version()

Query the version of the SQL database.

Source code in disruption_py/inout/sql.py
178
179
180
181
182
183
184
185
def get_version(self):
    """
    Ask the server for its version string.

    MySQL drivers are queried with `version()`, every other driver with
    `@@version`; the first field of the first returned row is the result.
    """
    if "mysql" in self.driver.lower():
        expression = "version()"
    else:
        expression = "@@version"
    # The cursor path is used so the raw rows come back, not a dataframe.
    rows = self.query(f"select {expression}", use_pandas=False)
    first_row = rows[0]
    return first_row[0]

query ¤

query(query: str, use_pandas=True)

query sql database

PARAMETER DESCRIPTION
query

The query string

TYPE: str

use_pandas

Whether pd.read_sql_query should be used to run the query. Default value is true.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
Any

Result of query

Source code in disruption_py/inout/sql.py
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
def query(self, query: str, use_pandas=True):
    """
    query sql database

    Parameters
    ----------
    query : str
        The query string
    use_pandas : bool, optional
        Whether pd.read_sql_query should be used to run the query. Default value
        is true.

    Returns
    -------
    Any
        Result of query: a dataframe when `use_pandas` is true, otherwise the
        rows fetched from the cursor, or None if the query failed with a
        pyodbc.DatabaseError (the error is logged, not raised).
    """
    if use_pandas:
        return pd.read_sql_query(query, self.engine)
    curs = self.conn.cursor()
    output = None
    try:
        curs.execute(query)
        output = curs.fetchall()
    except pyodbc.DatabaseError as e:
        logger.error("Query failed: {e}", e=repr(e))
        logger.opt(exception=True).debug(e)
    finally:
        # Release the cursor even when an exception other than DatabaseError
        # propagates to the caller.
        curs.close()
    return output