Skip to content
113 changes: 103 additions & 10 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class Connection(metaclass=ConnectionMeta):
'_intro_query', '_reset_query', '_proxy',
'_stmt_exclusive_section', '_config', '_params', '_addr',
'_log_listeners', '_termination_listeners', '_cancellations',
'_source_traceback', '__weakref__')
'_source_traceback', '_query_loggers', '__weakref__')

def __init__(self, protocol, transport, loop,
addr,
Expand Down Expand Up @@ -84,6 +84,7 @@ def __init__(self, protocol, transport, loop,
self._log_listeners = set()
self._cancellations = set()
self._termination_listeners = set()
self._query_loggers = set()

settings = self._protocol.get_settings()
ver_string = settings.server_version
Expand Down Expand Up @@ -221,6 +222,30 @@ def remove_termination_listener(self, callback):
"""
self._termination_listeners.discard(_Callback.from_callable(callback))

def add_query_logger(self, callback):
"""Add a logger that will be called when queries are executed.

:param callable callback:
A callable or a coroutine function receiving two arguments:
**connection**: a Connection the callback is registered with.
**query**: a LoggedQuery containing the query, args, timeout, and
elapsed.

.. versionadded:: 0.29.0
"""
self._query_loggers.add(_Callback.from_callable(callback))

def remove_query_logger(self, callback):
"""Remove a query logger callback.

:param callable callback:
The callable or coroutine function that was passed to
:meth:`Connection.add_query_logger`.

.. versionadded:: 0.29.0
"""
self._query_loggers.discard(_Callback.from_callable(callback))

def get_server_pid(self):
"""Return the PID of the Postgres server the connection is bound to."""
return self._protocol.get_server_pid()
Expand Down Expand Up @@ -314,7 +339,10 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
self._check_open()

if not args:
return await self._protocol.query(query, timeout)
with utils.timer() as t:
result = await self._protocol.query(query, timeout)
self._log_query(query, args, timeout, t.elapsed)
return result

_, status, _ = await self._execute(
query,
Expand Down Expand Up @@ -1667,6 +1695,45 @@ async def _execute(
)
return result

def logger(self, callback):
"""Context manager that adds `callback` to the list of query loggers,
and removes it upon exit.

:param callable callback:
A callable or a coroutine function receiving two arguments:
**connection**: a Connection the callback is registered with.
**query**: a LoggedQuery containing the query, args, timeout, and
elapsed.

Example:

.. code-block:: pycon

>>> class QuerySaver:
def __init__(self):
self.queries = []
def __call__(self, conn, record):
self.queries.append(record.query)
>>> with con.logger(QuerySaver()) as log:
>>> await con.execute("SELECT 1")
>>> print(log.queries)
['SELECT 1']

.. versionadded:: 0.29.0
"""
return _LoggingContext(self, callback)

def _log_query(self, query, args, timeout, elapsed):
if not self._query_loggers:
return
con_ref = self._unwrap()
record = LoggedQuery(query, args, timeout, elapsed)
for cb in self._query_loggers:
if cb.is_async:
self._loop.create_task(cb.cb(con_ref, record))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing a con_ref is probably unnecessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I likely wouldn't use it, so happy to remove it, but I put it there so you could potentially log queries by host.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A concern here is potentially retaining references to free-d connections. Other callbacks take it, of course, but that's an API decision I've come to regret. Perhaps we can pass connection's addr and params instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Works for me!

else:
self._loop.call_soon(cb.cb, con_ref, record)

async def __execute(
self,
query,
Expand All @@ -1681,20 +1748,25 @@ async def __execute(
executor = lambda stmt, timeout: self._protocol.bind_execute(
stmt, args, '', limit, return_status, timeout)
timeout = self._protocol._get_timeout(timeout)
return await self._do_execute(
query,
executor,
timeout,
record_class=record_class,
ignore_custom_codec=ignore_custom_codec,
)
with utils.timer() as t:
result, stmt = await self._do_execute(
query,
executor,
timeout,
record_class=record_class,
ignore_custom_codec=ignore_custom_codec,
)
self._log_query(query, args, timeout, t.elapsed)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could go further and pull self._log_query into the context manager and that way you'll be able to log query errors too, e.g.

Suggested change
with utils.timer() as t:
result, stmt = await self._do_execute(
query,
executor,
timeout,
record_class=record_class,
ignore_custom_codec=ignore_custom_codec,
)
self._log_query(query, args, timeout, t.elapsed)
with self._time_and_log(query, args, timeout):
result, stmt = await self._do_execute(
query,
executor,
timeout,
record_class=record_class,
ignore_custom_codec=ignore_custom_codec,
)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great idea, I feel silly for not considering error logging as part of this.

return result, stmt

async def _executemany(self, query, args, timeout):
executor = lambda stmt, timeout: self._protocol.bind_execute_many(
stmt, args, '', timeout)
timeout = self._protocol._get_timeout(timeout)
with self._stmt_exclusive_section:
result, _ = await self._do_execute(query, executor, timeout)
with utils.timer() as t:
result, _ = await self._do_execute(query, executor, timeout)
self._log_query(query, args, timeout, t.elapsed)
return result

async def _do_execute(
Expand Down Expand Up @@ -2327,6 +2399,27 @@ class _ConnectionProxy:
__slots__ = ()


LoggedQuery = collections.namedtuple(
'LoggedQuery',
['query', 'args', 'timeout', 'elapsed'])
LoggedQuery.__doc__ = 'Log record of an executed query.'


class _LoggingContext:
__slots__ = ('_conn', '_cb')

def __init__(self, conn, callback):
self._conn = conn
self._cb = callback

def __enter__(self):
self._conn.add_query_logger(self._cb)
return self._cb

def __exit__(self, *exc_info):
self._conn.remove_query_logger(self._cb)


ServerCapabilities = collections.namedtuple(
'ServerCapabilities',
['advisory_locks', 'notifications', 'plpgsql', 'sql_reset',
Expand Down
26 changes: 26 additions & 0 deletions asyncpg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@


import re
import time


def _quote_ident(ident):
Expand Down Expand Up @@ -43,3 +44,28 @@ async def _mogrify(conn, query, args):
# Finally, replace $n references with text values.
return re.sub(
r'\$(\d+)\b', lambda m: textified[int(m.group(1)) - 1], query)


class timer:
__slots__ = ('start', 'elapsed')

def __init__(self):
self.start = time.monotonic()
self.elapsed = None

@property
def current(self):
return time.monotonic() - self.start

def restart(self):
self.start = time.monotonic()

def stop(self):
self.elapsed = self.current

def __enter__(self):
self.restart()
return self

def __exit__(self, *exc):
self.stop()
30 changes: 30 additions & 0 deletions tests/test_logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import asyncio

from asyncpg import _testbase as tb


class TestQueryLogging(tb.ConnectedTestCase):

async def test_logging_context(self):
queries = asyncio.Queue()

def query_saver(conn, record):
queries.put_nowait(record)

class QuerySaver:
def __init__(self):
self.queries = []

def __call__(self, conn, record):
self.queries.append(record.query)

with self.con.logger(query_saver):
self.assertEqual(len(self.con._query_loggers), 1)
with self.con.logger(QuerySaver()) as log:
self.assertEqual(len(self.con._query_loggers), 2)
await self.con.execute("SELECT 1")

record = await queries.get()
self.assertEqual(record.query, "SELECT 1")
self.assertEqual(log.queries, ["SELECT 1"])
self.assertEqual(len(self.con._query_loggers), 0)