Skip to content

SQL Database

DisruptionPy uses the logbook SQL databases for convenience when retrieving data from MDSplus. Users may also use DisruptionPy to retrieve data directly from the disruption_warning tables of the logbook databases.

The disruption_warning table¤

The disruption_warning SQL tables for C-Mod and DIII-D contain important disruption parameters for a large number of shots.

CMod Dataset¤

The dataset contains unique plasma discharges from MIT's Alcator C-Mod tokamak, from the 2012 to 2016 experimental campaigns, plus additional discharges from 2005.

Available columns on CMod
'dbkey', 'shot', 'time', 'time_until_disrupt', 'ip_error', 'dip_dt',
'beta_p', 'beta_n', 'li', 'n_equal_1_normalized', 'z_error', 'v_z',
'z_times_v_z', 'kappa', 'pressure_peaking', 'H98', 'q0', 'qstar', 'q95',
'v_0', 'v_mid', 'v_edge', 'dn_dt', 'p_rad_slow', 'p_oh_slow', 'p_icrf',
'p_lh', 'radiated_fraction', 'power_supply_railed', 'v_loop_efit',
'r_dd', 'lower_gap', 'upper_gap', 'dbetap_dt', 'dli_dt', 'ip', 'zcur',
'n_e', 'dipprog_dt', 'v_loop', 'p_rad', 'p_oh', 'ssep', 'dWmhd_dt',
'dprad_dt', 'v_0_uncalibrated', 'Te_width', 'Greenwald_fraction',
'intentional_disruption', 'Te_width_ECE', 'Wmhd', 'n_over_ncrit',
'n_equal_1_mode', 'Mirnov', 'Mirnov_norm_btor', 'Mirnov_norm_bpol',
'Te_peaking', 'ne_peaking', 'Te_peaking_ECE', 'SXR_peaking',
'kappa_area', 'I_efc', 'SXR', 'H_alpha', 'Prad_peaking_CVA',
'commit_hash'

For more details on computed values please see parameter reference.

Retrieving data from the SQL database¤

Here is an example that uses DisruptionPy to get shot data from the disruption_warning table for eight shots from the disruption warning shotlist:

#!/usr/bin/env python3

"""Example usage of `get_shots_data` testing the connection to the SQL database."""

from disruption_py.workflow import get_database

# Open the database configured for the "cmod" tokamak.
database = get_database(tokamak="cmod")

# Keep the first eight shot numbers found in the disruption_warning table.
warning_shots = database.get_disruption_warning_shotlist()
first_shots = warning_shots["shot"][0:8].tolist()

# Fetch the disruption_warning rows that belong to those shots.
result = database.get_shots_data(first_shots, sql_table="disruption_warning")

Database Class Reference¤

Module for managing SQL database connections.

DummyDatabase ¤

Bases: ShotDatabase

A database class that does not require connecting to an SQL server and returns no data.

Source code in disruption_py/inout/sql.py
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
class DummyDatabase(ShotDatabase):
    """
    A database class that does not require connecting to an SQL server and returns
    no data.

    Every method accepts and ignores any positional or keyword arguments, so the
    class can be called exactly like `ShotDatabase` (for instance
    ``get_shots_data(shotlist, sql_table=...)``) without raising.
    """

    # pylint: disable-next=super-init-not-called
    def __init__(self, *_args, **kwargs):
        # Deliberately skip ShotDatabase.__init__: it would look up ODBC drivers
        # and build an engine, which is exactly what this class avoids.
        pass

    @classmethod
    def initializer(cls, *_args, **_kwargs):
        """Return a new dummy database, ignoring all arguments."""
        return cls()

    @property
    def conn(self):
        """Return a `DummyObject` standing in for a database connection."""
        return DummyObject()

    # pylint: disable-next=arguments-differ
    def query(self, *_args, **_kwargs):
        """Return an empty DataFrame instead of running a query."""
        return pd.DataFrame()

    # pylint: disable-next=arguments-differ
    def get_shots_data(self, *_args, **_kwargs):
        """Return an empty DataFrame instead of fetching shot data."""
        return pd.DataFrame()

    # pylint: disable-next=arguments-differ
    def get_disruption_time(self, *_args, **_kwargs):
        """Return None: no disruption time is known without a database."""
        return None

    # pylint: disable-next=arguments-differ
    def get_disruption_shotlist(self, *_args, **_kwargs):
        """Return an empty list of disruptive shots."""
        return []

    # pylint: disable-next=arguments-differ
    def get_disruption_warning_shotlist(self, *_args, **_kwargs):
        """Return an empty list of disruption_warning shots."""
        return []

DummyObject ¤

A dummy connection object.

Source code in disruption_py/inout/sql.py
471
472
473
474
475
476
477
478
479
480
481
482
class DummyObject:
    """
    A dummy connection object.

    Any attribute access and any call returns the object itself, so chains such
    as ``obj.cursor().execute(...).fetchall()`` never fail. The object can also
    be used in a ``with`` statement, like the connections and cursors it
    replaces.
    """

    def __getattr__(self, name):
        # Return self for any attribute or method call
        return self

    def __call__(self, *args, **kwargs):
        # Return self for any method call
        return self

    def __enter__(self):
        # The `with` statement looks special methods up on the type, which
        # bypasses __getattr__, so the protocol has to be spelled out.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Never swallow an exception raised inside the `with` body.
        return False

ShotDatabase ¤

Handles grabbing data from the SQL server.

Source code in disruption_py/inout/sql.py
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
class ShotDatabase:
    """
    Handles reading and writing shot data on the SQL server.

    Reads that return DataFrames go through a SQLAlchemy engine built on the
    ``mssql+pyodbc`` dialect; everything else uses one raw pyodbc connection
    per thread (see `conn`).
    """

    logger = logging.getLogger("disruption_py")

    def __init__(
        self,
        driver,
        host,
        port,
        db_name,
        user,
        passwd,
        protected_columns=None,
        write_database_table_name=None,
        **_kwargs,
    ):
        """
        Store the connection settings and create the SQLAlchemy engine.

        Parameters
        ----------
        driver : str
            Preferred ODBC driver name. If pyodbc does not list it, the first
            driver pyodbc reports is used instead and a warning is logged.
        host, port, db_name, user, passwd
            Server location and credentials used in the connection string.
        protected_columns : list, optional
            Columns that must not be altered or overwritten. Default is [].
        write_database_table_name : str, optional
            Table that the write methods operate on. Default is None, in which
            case the write methods raise ValueError.
        """

        if protected_columns is None:
            protected_columns = []

        self.logger.info("Database initialization: %s@%s/%s", user, host, db_name)
        drivers = pyodbc.drivers()
        if driver in drivers:
            self.driver = driver
        else:
            self.driver = drivers[0]
            self.logger.warning(
                "Database driver fallback: '%s' -> '%s'", driver, self.driver
            )
        self.host = host
        self.port = port
        self.db_name = db_name
        self.user = user
        self.passwd = passwd
        self.protected_columns = protected_columns
        self.write_database_table_name = write_database_table_name

        self.connection_string = self._get_connection_string(self.db_name)
        # One raw pyodbc connection per thread, created lazily by `conn`.
        self._thread_connections = {}
        quoted_connection_string = quote_plus(self.connection_string)
        self.engine = create_engine(
            f"mssql+pyodbc:///?odbc_connect={quoted_connection_string}"
        )

    @classmethod
    def from_config(cls, tokamak: Tokamak):
        """
        Initialize database from config file.
        """
        return cls._from_dict(config(tokamak).database)

    @classmethod
    def _from_dict(cls, database_dict: dict):
        """
        Initialize database from a dictionary of database settings.

        The user name and password are the last two whitespace-separated tokens
        of the file found at ``database_dict["profile_path"]``.
        """

        # read profile
        profile_path = database_dict["profile_path"]
        profile = os.path.expanduser(profile_path)
        with open(profile, "r", encoding="utf-8") as fio:
            db_user, db_pass = fio.read().split()[-2:]

        return SharedInstance(ShotDatabase).get_instance(
            driver=database_dict["driver"],
            host=database_dict["host"],
            port=database_dict["port"],
            db_name=database_dict["db_name"],
            user=db_user,
            passwd=db_pass,
            protected_columns=without_duplicates(database_dict["protected_columns"]),
            write_database_table_name=database_dict.get("write_database_table_name"),
        )

    def _get_connection_string(self, db_name):
        """
        Build the ODBC connection string for the database `db_name`.

        For drivers whose name contains "ODBC" the port is appended to the
        server as ``host,port`` and no separate PORT entry is emitted.
        """
        params = {
            "DRIVER": self.driver,
            "SERVER": self.host,
            "PORT": self.port,
            "DATABASE": db_name,
            "UID": self.user,
            "PWD": self.passwd,
            "TrustServerCertificate": "yes",
            "Connection Timeout": 60,
        }
        if "ODBC" in self.driver:
            params["SERVER"] += f",{params.pop('PORT')}"
        conn_str = ";".join([f"{k}={v}" for k, v in params.items()])
        return conn_str

    @property
    def conn(self):
        """
        Property returning a connection to sql database.

        If a connection exists for the given thread returns that connection,
        otherwise creates a new connection

        Returns
        -------
        pyodbc.Connection
            Database connection
        """
        current_thread = threading.current_thread()
        if current_thread not in self._thread_connections:
            self.logger.info("Connecting to database for thread %s", current_thread)
            self._thread_connections[current_thread] = pyodbc.connect(
                self.connection_string
            )
        return self._thread_connections[current_thread]

    def query(self, query: str, use_pandas=True):
        """
        query sql database

        Queries containing "alter" always run on a raw cursor, and are refused
        when they name one of the protected columns.

        Parameters
        ----------
        query : str
            The query string
        use_pandas : bool, optional
            Whether pd.read_sql_query should be used to run the query. Default value
            is true.

        Returns
        -------
        Any
            A DataFrame when pandas runs the query; otherwise the fetched rows
            for queries containing "select", or None for other queries and for
            queries that failed. 0 when an alter query was refused.
        """
        lowered = query.lower()
        if "alter" in lowered:
            # Break the query into identifier-like words so that a protected
            # column is matched as a whole name, not as a substring.
            words = set(
                "".join(
                    char if char.isalnum() or char == "_" else " " for char in lowered
                ).split()
            )
            blocked = [
                str(col) for col in self.protected_columns if str(col).lower() in words
            ]
            if blocked:
                self.logger.warning(
                    "Refusing to alter protected column(s): %s", ", ".join(blocked)
                )
                return 0
        elif use_pandas:
            return pd.read_sql_query(query, self.engine)
        curs = self.conn.cursor()
        output = None
        try:
            curs.execute(query)
            if "select" in lowered:
                output = curs.fetchall()
        except pyodbc.DatabaseError as e:
            self.logger.debug(e)
            self.logger.error("Query failed, returning None: %s", e)
        finally:
            # Release the cursor even when an unexpected exception propagates.
            curs.close()
        return output

    def add_shot_data(
        self,
        shot_id: int,
        shot_data: pd.DataFrame,
        update=False,
        override_columns: List[str] = None,
    ):
        """
        Upload shot to SQL database.

        Either inserts or updates shot data depending on whether a shot already exists
        in database. If shot exists, then the timebase of the shot data must match
        the timebase of the shot in the database.

        Parameters
        ----------
        shot_id : int
            Shot id of the shot being modified
        shot_data : pd.DataFrame
            Dataframe containing shot data for update
        update : bool
            Whether to update shot data if the shot already exists in database.
            Update will happen regardless of whether the column being updated is
            all nil. Default value is False.
        override_columns : List[str]
            List of protected columns that can still be updated. Update must be
            true for input values in the columns to be changed. Default value is [].

        Returns
        -------
        bool
            Result of the insert or update, or False when the timebase of
            `shot_data` does not match the one stored in the database.

        Raises
        ------
        ValueError
            If no write table is configured.
        """
        if self.write_database_table_name is None:
            raise ValueError(
                "specify write_database_table_name in the configuration before "
                + "adding shot data"
            )
        curr_df = pd.read_sql_query(
            f"select * from {self.write_database_table_name} where shot={shot_id} "
            + "order by time",
            self.engine,
        )

        if len(curr_df) == 0:
            return self._insert_shot_data(
                curr_df=curr_df,
                shot_data=shot_data,
                table_name=self.write_database_table_name,
            )
        # Compare the two timebases row by row. Dropping the index labels keeps
        # pandas from aligning on them, which would pair up the wrong rows (or
        # produce NaN) whenever shot_data does not carry a default index.
        if (
            len(curr_df) == len(shot_data)
            and (
                (
                    curr_df["time"].reset_index(drop=True)
                    - shot_data["time"].reset_index(drop=True)
                ).abs()
                < config().TIME_CONST
            ).all()
        ):
            return self._update_shot_data(
                shot_id=shot_id,
                curr_df=curr_df,
                shot_data=shot_data,
                update=update,
                table_name=self.write_database_table_name,
                override_columns=override_columns,
            )

        self.logger.error("Invalid timebase for data output")
        return False

    def _insert_shot_data(
        self,
        curr_df: pd.DataFrame,
        shot_data: pd.DataFrame,
        table_name: str,
    ):
        """
        Insert shot data into SQL table.

        Assumes that the shot id does not already exist in the database. Only
        columns present in both `curr_df` and `shot_data` are written, and
        identity columns of the table are skipped.
        """

        identity_column_names = self._get_identity_column_names(table_name)

        matching_columns_shot_data = pd.DataFrame()
        for column_name in curr_df.columns:
            if column_name in identity_column_names:
                continue

            if column_name in shot_data.columns:
                matching_columns_shot_data[column_name] = shot_data[column_name]

        # pyodbc will fill SQL with NULL for None, but not for np.nan
        matching_columns_shot_data = matching_columns_shot_data.replace({np.nan: None})

        column_names = matching_columns_shot_data.columns.tolist()
        sql_column_names = ", ".join(column_names)
        parameter_markers = "(" + ", ".join(["?"] * len(column_names)) + ")"
        with self.conn.cursor() as curs:
            data_tuples = list(
                matching_columns_shot_data.itertuples(index=False, name=None)
            )
            curs.executemany(
                f"insert into {table_name} ({sql_column_names}) values "
                + f"{parameter_markers}",
                data_tuples,
            )
        return True

    def _update_shot_data(
        self,
        shot_id: int,
        curr_df: pd.DataFrame,
        shot_data: pd.DataFrame,
        update: bool,
        table_name: str,
        override_columns: List[str] = None,
    ):
        """
        Update shot data into SQL table.

        Assumes that the shot id already exist in the database and the timebase of
        shot_data is the same as curr_df. Rows are matched by position: row i of
        shot_data updates the database row whose time is row i of curr_df.

        Parameters
        ----------
        shot_id : int
            Shot id of the shot being modified.
        curr_df : pd.DataFrame
            Data currently in sql database.
        shot_data : pd.DataFrame
            Dataframe containing shot data for update.
        update : bool
            Whether to update shot data if the shot already exists in database.
            Update will happen regardless of whether the column being updated is
            all nil.
        table_name : str
            Name of the table to update.
        override_columns : List[str]
            List of columns protected on this instance that should still be
            updated. Columns protected in the configuration are never updated.
            Default value is [].
        """
        override_columns = override_columns or []
        config_protected_columns = config().database.protected_columns

        update_columns_shot_data = pd.DataFrame()
        for column_name in curr_df.columns:
            if column_name in config_protected_columns or (
                column_name in self.protected_columns
                and column_name not in override_columns
            ):
                continue

            if (
                column_name in shot_data.columns
                and not shot_data[column_name].isna().all()
                and (update or curr_df[column_name].isna().all())
            ):
                update_columns_shot_data[column_name] = shot_data[column_name]
        # pyodbc will fill SQL with NULL for None, but not for np.nan
        update_columns_shot_data = update_columns_shot_data.replace({np.nan: None})

        update_column_names = list(update_columns_shot_data.columns)
        if not update_column_names:
            # No column qualifies, so there is no statement to run.
            return True

        # The statement is the same for every row; build it once.
        sql_set_string = ", ".join([f"{col} = ?" for col in update_column_names])
        sql_command = (
            f"UPDATE {table_name} SET {sql_set_string} "
            + "WHERE time = ? AND shot = ?;"
        )
        with self.conn.cursor() as curs:
            for index, row in enumerate(
                update_columns_shot_data.itertuples(index=False, name=None)
            ):
                # Positional lookup, consistent with the positional enumeration.
                curs.execute(
                    sql_command, row + (curr_df["time"].iloc[index], str(shot_id))
                )
        return True

    def _get_identity_column_names(self, table_name: str):
        """Get which column names are identity columns in table."""
        with self.conn.cursor() as curs:
            query = f"""\
            SELECT c.name AS ColumnName
            FROM sys.columns c
            INNER JOIN sys.tables t ON c.object_id = t.object_id
            LEFT JOIN sys.identity_columns ic ON ic.object_id = c.object_id AND ic.column_id = c.column_id
            WHERE t.name = '{table_name}' AND ic.object_id IS NOT NULL
            """
            curs.execute(query)
            return [row[0] for row in curs.fetchall()]

    def remove_shot_data(self, shot_id):
        """
        Remove shot from SQL table.

        Returns True after deleting the rows of `shot_id` from the write table,
        or False when the shot has no rows there. Raises ValueError when no
        write table is configured or when it is "disruption_warning".
        """
        if self.write_database_table_name is None:
            raise ValueError(
                "specify write_database_table_name in the configuration before "
                + "removing shot data"
            )
        if self.write_database_table_name == "disruption_warning":
            raise ValueError(
                "Please do not delete from the disruption_warning database"
            )
        data_df = pd.read_sql_query(
            f"""select * from {self.write_database_table_name} where shot = """
            + f"""{shot_id} order by time""",
            self.engine,
        )
        if len(data_df) == 0:
            self.logger.info("Shot %s does not exist in database", shot_id)
            return False
        with self.conn.cursor() as curs:
            curs.execute(
                f"delete from {self.write_database_table_name} where shot = {shot_id}"
            )
        return True

    def add_column(self, col_name, var_type="TEXT"):
        """
        Add column to SQL table without filling in data for column.

        Always returns True; raises ValueError when no write table is configured.
        """
        if self.write_database_table_name is None:
            raise ValueError(
                "specify write_database_table_name in the configuration before "
                + "adding a column"
            )
        self.query(
            f"alter table {self.write_database_table_name} add {col_name} {var_type};",
            use_pandas=False,
        )
        return True

    def remove_column(self, col_name):
        """
        Remove column from SQL table.

        Returns False without touching the table when `col_name` is protected,
        True otherwise. Raises ValueError when no write table is configured.
        """
        if self.write_database_table_name is None:
            raise ValueError(
                "specify write_database_table_name in the configuration before "
                + "removing a column"
            )
        if col_name in self.protected_columns:
            self.logger.error("Failed to drop protected column %s", col_name)
            return False
        self.query(
            f"alter table {self.write_database_table_name} drop column {col_name};",
            use_pandas=False,
        )
        return True

    def get_shots_data(
        self,
        shotlist: List[int],
        cols: List[str] = None,
        sql_table="disruption_warning",
    ):
        """
        get_shots_data retrieves columns from sql data for given shotlist

        Parameters
        ----------
        shotlist : List[int]
            List of shot ids to get data for. None retrieves every shot of the
            table; an empty list retrieves no rows.
        cols : List[str]
            List of columns to retrieve. Default value is ["*"], meaning all columns.
        sql_table : str, optional
            The sql_table to retrieve data from. Default value is "disruption_warning".

        Returns
        -------
        pd.Dataframe
            Dataframe containing queried data, with lower-cased column names
        """
        if cols is None:
            cols = ["*"]
        columns_sql = ", ".join(str(col) for col in cols)
        query = f"select {columns_sql} from {sql_table}"
        if shotlist is None:
            query += " order by time"
        else:
            shots_sql = ",".join(str(shot) for shot in shotlist)
            if shots_sql:
                query += f" where shot in ({shots_sql}) order by shot, time"
            else:
                # "in ()" is not valid SQL; select nothing but keep the columns.
                query += " where 1 = 0 order by shot, time"
        shot_df = pd.read_sql_query(query, self.engine)
        shot_df.columns = shot_df.columns.str.lower()
        return shot_df

    def get_disruption_time(self, shot_id):
        """
        Get disruption time for shot_id or None if there was no disruption.
        """
        with self.conn.cursor() as curs:
            curs.execute(f"select t_disrupt from disruptions where shot = {shot_id}")
            t_disrupt = curs.fetchall()
        if len(t_disrupt) == 0:
            return None
        t_disrupt = t_disrupt[0][0]
        return t_disrupt

    def get_disruption_shotlist(self):
        """
        Get Pandas DataFrame of all disruptive shots from the disruptions table.
        Can be used as a cross-reference to determine whether a given shot is
        disruptive or not (all shots in this table are disruptive and have a
        t_disrupt).
        """
        return self.query("select distinct shot from disruptions order by shot")

    def get_disruption_warning_shotlist(self):
        """
        Get Pandas DataFrame of all shots in the disruption_warning table. NOTE:
        The disruption_warning table contains ONLY a subset of all shots.
        """
        return self.query("select distinct shot from disruption_warning order by shot")

conn property ¤

conn

Property returning a connection to sql database.

If a connection exists for the given thread returns that connection, otherwise creates a new connection

RETURNS DESCRIPTION
pyodbc.Connection

Database connection

add_column ¤

add_column(col_name, var_type='TEXT')

Add column to SQL table without filling in data for column.

Source code in disruption_py/inout/sql.py
377
378
379
380
381
382
383
384
385
386
387
388
def add_column(self, col_name, var_type="TEXT"):
    """
    Add column to SQL table without filling in data for column.

    The column is added to the configured write table through an
    ``alter table ... add`` statement run without pandas. Always returns True;
    raises ValueError when no write table is configured.
    """
    if self.write_database_table_name is None:
        raise ValueError(
            "specify write_database_table_name in the configuration before "
            + "adding a column"
        )
    self.query(
        f"alter table {self.write_database_table_name} add {col_name} {var_type};",
        use_pandas=False,
    )
    return True

add_shot_data ¤

add_shot_data(
    shot_id: int,
    shot_data: pd.DataFrame,
    update=False,
    override_columns: List[str] = None,
)

Upload shot to SQL database.

Either inserts or updates shot data depending on whether a shot already exists in database. If shot exists, then the timebase of the shot data must match the timebase of the shot in the database.

PARAMETER DESCRIPTION
shot_id

Shot id of the shot being modified

TYPE: int

shot_data

Dataframe containing shot data for update

TYPE: DataFrame

update

Whether to update shot data if the shot already exists in database. Update will happen regardless of whether the column being updated is all nil. Default value is False.

TYPE: bool DEFAULT: False

override_columns

List of protected columns that can still be updated. Update must be true for input values in the columns to be changed. Default value is [].

TYPE: List[str] DEFAULT: None

Source code in disruption_py/inout/sql.py
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
def add_shot_data(
    self,
    shot_id: int,
    shot_data: pd.DataFrame,
    update=False,
    override_columns: List[str] = None,
):
    """
    Upload shot to SQL database.

    Either inserts or updates shot data depending on whether a shot already exists
    in database. If shot exists, then the timebase of the shot data must match
    the timebase of the shot in the database.

    Parameters
    ----------
    shot_id : int
        Shot id of the shot being modified
    shot_data : pd.DataFrame
        Dataframe containing shot data for update
    update : bool
        Whether to update shot data if the shot already exists in database.
        Update will happen regardless of whether the column being updated is
        all nil. Default value is False.
    override_columns : List[str]
        List of protected columns that can still be updated. Update must be
        true for input values in the columns to be changed. Default value is [].

    Returns
    -------
    bool
        Result of the insert or update, or False when the timebase of
        `shot_data` does not match the one stored in the database.

    Raises
    ------
    ValueError
        If no write table is configured.
    """
    if self.write_database_table_name is None:
        raise ValueError(
            "specify write_database_table_name in the configuration before "
            + "adding shot data"
        )
    curr_df = pd.read_sql_query(
        f"select * from {self.write_database_table_name} where shot={shot_id} "
        + "order by time",
        self.engine,
    )

    if len(curr_df) == 0:
        return self._insert_shot_data(
            curr_df=curr_df,
            shot_data=shot_data,
            table_name=self.write_database_table_name,
        )
    # Compare the two timebases row by row. Dropping the index labels keeps
    # pandas from aligning on them, which would pair up the wrong rows (or
    # produce NaN) whenever shot_data does not carry a default index.
    if (
        len(curr_df) == len(shot_data)
        and (
            (
                curr_df["time"].reset_index(drop=True)
                - shot_data["time"].reset_index(drop=True)
            ).abs()
            < config().TIME_CONST
        ).all()
    ):
        return self._update_shot_data(
            shot_id=shot_id,
            curr_df=curr_df,
            shot_data=shot_data,
            update=update,
            table_name=self.write_database_table_name,
            override_columns=override_columns,
        )

    self.logger.error("Invalid timebase for data output")
    return False

from_config classmethod ¤

from_config(tokamak: Tokamak)

Initialize database from config file.

Source code in disruption_py/inout/sql.py
71
72
73
74
75
76
@classmethod
def from_config(cls, tokamak: Tokamak):
    """
    Build the database from the ``database`` section of the configuration
    loaded for `tokamak`.
    """
    database_settings = config(tokamak).database
    return cls._from_dict(database_settings)

get_disruption_shotlist ¤

get_disruption_shotlist()

Get a Pandas DataFrame of all disruptive shots from the disruptions table. It can be used as a cross-reference to determine whether a given shot is disruptive: every shot in that table is disruptive and has a t_disrupt.

Source code in disruption_py/inout/sql.py
454
455
456
457
458
459
460
461
def get_disruption_shotlist(self):
    """
    Return the distinct shot numbers stored in the disruptions table, sorted
    by shot.

    The result is whatever `query` returns for that statement. It can serve as
    a cross-reference for deciding whether a shot is disruptive: every shot in
    the disruptions table is.
    """
    statement = "select distinct shot from disruptions order by shot"
    return self.query(statement)

get_disruption_time ¤

get_disruption_time(shot_id)

Get disruption time for shot_id or None if there was no disruption.

Source code in disruption_py/inout/sql.py
442
443
444
445
446
447
448
449
450
451
452
def get_disruption_time(self, shot_id):
    """
    Look up t_disrupt for `shot_id` in the disruptions table.

    Returns the value from the first matching row, or None when the shot has
    no row in that table.
    """
    statement = f"select t_disrupt from disruptions where shot = {shot_id}"
    with self.conn.cursor() as cursor:
        cursor.execute(statement)
        rows = cursor.fetchall()
    if not rows:
        return None
    first_row = rows[0]
    return first_row[0]

get_disruption_warning_shotlist ¤

get_disruption_warning_shotlist()

Get a Pandas DataFrame of all shots in the disruption_warning table. NOTE: The disruption_warning table contains ONLY a subset of all shots.

Source code in disruption_py/inout/sql.py
463
464
465
466
467
468
def get_disruption_warning_shotlist(self):
    """
    Return the distinct shot numbers stored in the disruption_warning table,
    sorted by shot.

    The result is whatever `query` returns for that statement. NOTE: the
    disruption_warning table holds ONLY a subset of all shots.
    """
    statement = "select distinct shot from disruption_warning order by shot"
    return self.query(statement)

get_shots_data ¤

get_shots_data(
    shotlist: List[int],
    cols: List[str] = None,
    sql_table="disruption_warning",
)

get_shots_data retrieves columns from sql data for given shotlist

PARAMETER DESCRIPTION
shotlist

List of shot ids to get data for.

TYPE: List[int]

cols

List of columns to retrieve. Default value is ["*"], meaning all columns.

TYPE: List[str] DEFAULT: None

sql_table

The sql_table to retrieve data from. Default value is "disruption_warning".

TYPE: str DEFAULT: 'disruption_warning'

RETURNS DESCRIPTION
Dataframe

Dataframe containing queried data

Source code in disruption_py/inout/sql.py
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
def get_shots_data(
    self,
    shotlist: List[int],
    cols: List[str] = None,
    sql_table="disruption_warning",
):
    """
    get_shots_data retrieves columns from sql data for given shotlist

    Parameters
    ----------
    shotlist : List[int]
        List of shot ids to get data for. None retrieves every shot of the
        table; an empty list retrieves no rows.
    cols : List[str]
        List of columns to retrieve. Default value is ["*"], meaning all columns.
    sql_table : str, optional
        The sql_table to retrieve data from. Default value is "disruption_warning".

    Returns
    -------
    pd.Dataframe
        Dataframe containing queried data, with lower-cased column names
    """
    if cols is None:
        cols = ["*"]
    columns_sql = ", ".join(str(col) for col in cols)
    query = f"select {columns_sql} from {sql_table}"
    if shotlist is None:
        # Must be tested before joining: joining None raises TypeError.
        query += " order by time"
    else:
        shots_sql = ",".join(str(shot) for shot in shotlist)
        if shots_sql:
            query += f" where shot in ({shots_sql}) order by shot, time"
        else:
            # "in ()" is not valid SQL; select nothing but keep the columns.
            query += " where 1 = 0 order by shot, time"
    shot_df = pd.read_sql_query(query, self.engine)
    shot_df.columns = shot_df.columns.str.lower()
    return shot_df

query ¤

query(query: str, use_pandas=True)

query sql database

PARAMETER DESCRIPTION
query

The query string

TYPE: str

use_pandas

Whether pd.read_sql_query should be used to run the query. Default value is true.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
Any

Result of query

Source code in disruption_py/inout/sql.py
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
def query(self, query: str, use_pandas=True):
    """
    query sql database

    Queries containing "alter" always run on a raw cursor, and are refused
    when they name one of the protected columns.

    Parameters
    ----------
    query : str
        The query string
    use_pandas : bool, optional
        Whether pd.read_sql_query should be used to run the query. Default value
        is true.

    Returns
    -------
    Any
        A DataFrame when pandas runs the query; otherwise the fetched rows
        for queries containing "select", or None for other queries and for
        queries that failed. 0 when an alter query was refused.
    """
    lowered = query.lower()
    if "alter" in lowered:
        # Break the query into identifier-like words so that a protected
        # column is matched as a whole name, not as a substring.
        words = set(
            "".join(
                char if char.isalnum() or char == "_" else " " for char in lowered
            ).split()
        )
        blocked = [
            str(col) for col in self.protected_columns if str(col).lower() in words
        ]
        if blocked:
            self.logger.warning(
                "Refusing to alter protected column(s): %s", ", ".join(blocked)
            )
            return 0
    elif use_pandas:
        return pd.read_sql_query(query, self.engine)
    curs = self.conn.cursor()
    output = None
    try:
        curs.execute(query)
        if "select" in lowered:
            output = curs.fetchall()
    except pyodbc.DatabaseError as e:
        self.logger.debug(e)
        self.logger.error("Query failed, returning None: %s", e)
    finally:
        # Release the cursor even when an unexpected exception propagates.
        curs.close()
    return output

remove_column ¤

remove_column(col_name)

Remove column from SQL table

Source code in disruption_py/inout/sql.py
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
def remove_column(self, col_name):
    """
    Remove column from SQL table.

    Returns False without touching the table when `col_name` is protected,
    True otherwise. Raises ValueError when no write table is configured.
    """
    if self.write_database_table_name is None:
        raise ValueError(
            "specify write_database_table_name in the configuration before "
            + "removing a column"
        )
    if col_name in self.protected_columns:
        self.logger.error("Failed to drop protected column %s", col_name)
        return False
    self.query(
        f"alter table {self.write_database_table_name} drop column {col_name};",
        use_pandas=False,
    )
    return True

remove_shot_data ¤

remove_shot_data(shot_id)

Remove shot from SQL table.

Source code in disruption_py/inout/sql.py
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
def remove_shot_data(self, shot_id):
    """
    Remove shot from SQL table.

    Returns True after deleting the rows of `shot_id` from the write table,
    or False when the shot has no rows there. Raises ValueError when no
    write table is configured or when it is "disruption_warning".
    """
    if self.write_database_table_name is None:
        raise ValueError(
            "specify write_database_table_name in the configuration before "
            + "removing shot data"
        )
    if self.write_database_table_name == "disruption_warning":
        raise ValueError(
            "Please do not delete from the disruption_warning database"
        )
    # Check first that the shot has rows, so a missing shot is reported.
    data_df = pd.read_sql_query(
        f"""select * from {self.write_database_table_name} where shot = """
        + f"""{shot_id} order by time""",
        self.engine,
    )
    if len(data_df) == 0:
        self.logger.info("Shot %s does not exist in database", shot_id)
        return False
    with self.conn.cursor() as curs:
        curs.execute(
            f"delete from {self.write_database_table_name} where shot = {shot_id}"
        )
    return True