Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion google/cloud/sqlalchemy_spanner/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def isolation_level(self):

@property
def sequences(self):
return exclusions.closed()
return exclusions.open()

@property
def temporary_tables(self):
Expand Down
102 changes: 101 additions & 1 deletion google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
)
from sqlalchemy.sql.default_comparator import operator_lookup
from sqlalchemy.sql.operators import json_getitem_op
from sqlalchemy.sql import expression

from google.cloud.spanner_v1.data_types import JsonObject
from google.cloud import spanner_dbapi
Expand Down Expand Up @@ -173,6 +174,16 @@ def pre_exec(self):
if priority is not None:
self._dbapi_connection.connection.request_priority = priority

def fire_sequence(self, seq, type_):
"""Builds a statement for fetching next value of the sequence."""
return self._execute_scalar(
(
"SELECT GET_NEXT_SEQUENCE_VALUE(SEQUENCE %s)"
% self.identifier_preparer.format_sequence(seq)
),
type_,
)


class SpannerIdentifierPreparer(IdentifierPreparer):
"""Identifiers compiler.
Expand Down Expand Up @@ -343,6 +354,20 @@ def limit_clause(self, select, **kw):
text += " OFFSET " + self.process(select._offset_clause, **kw)
return text

def returning_clause(self, stmt, returning_cols, **kw):
columns = [
self._label_select_column(None, c, True, False, {})
for c in expression._select_iterables(returning_cols)
]

return "THEN RETURN " + ", ".join(columns)

def visit_sequence(self, seq, **kw):
"""Builds a statement for fetching next value of the sequence."""
return " GET_NEXT_SEQUENCE_VALUE(SEQUENCE %s)" % self.preparer.format_sequence(
seq
)


class SpannerDDLCompiler(DDLCompiler):
"""Spanner DDL statements compiler."""
Expand Down Expand Up @@ -457,6 +482,24 @@ def post_create_table(self, table):

return post_cmds

def get_identity_options(self, identity_options):
text = ["sequence_kind = 'bit_reversed_positive'"]
if identity_options.start is not None:
text.append("start_with_counter = %d" % identity_options.start)
return ", ".join(text)

def visit_create_sequence(self, create, prefix=None, **kw):
"""Builds a ``CREATE SEQUENCE`` statement for the sequence."""
text = "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element)
options = self.get_identity_options(create.element)
if options:
text += " OPTIONS (" + options + ")"
return text

def visit_drop_sequence(self, drop, **kw):
"""Builds a ``DROP SEQUENCE`` statement for the sequence."""
return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)


class SpannerTypeCompiler(GenericTypeCompiler):
"""Spanner types compiler.
Expand Down Expand Up @@ -531,7 +574,8 @@ class SpannerDialect(DefaultDialect):
supports_sane_rowcount = False
supports_sane_multi_rowcount = False
supports_default_values = False
supports_sequences = False
supports_sequences = True
sequences_optional = False
supports_native_enum = True
supports_native_boolean = True
supports_native_decimal = True
Expand Down Expand Up @@ -694,6 +738,36 @@ def get_view_names(self, connection, schema=None, **kw):

return all_views

@engine_to_connection
def get_sequence_names(self, connection, schema=None, **kw):
"""
Return a list of all sequence names available in the database.

The method is used by SQLAlchemy introspection systems.

Args:
connection (sqlalchemy.engine.base.Connection):
SQLAlchemy connection or engine object.
schema (str): Optional. Schema name

Returns:
list: List of sequence names.
"""
sql = """
SELECT name
FROM information_schema.sequences
WHERE SCHEMA='{}'
""".format(
schema or ""
)
all_sequences = []
with connection.connection.database.snapshot() as snap:
rows = list(snap.execute_sql(sql))
for seq in rows:
all_sequences.append(seq[0])

return all_sequences

@engine_to_connection
def get_view_definition(self, connection, view_name, schema=None, **kw):
"""
Expand Down Expand Up @@ -1294,6 +1368,32 @@ def has_table(self, connection, table_name, schema=None, **kw):

return False

@engine_to_connection
def has_sequence(self, connection, sequence_name, schema=None, **kw):
"""Check the existence of a particular sequence in the database.

Given a :class:`_engine.Connection` object and a string
`sequence_name`, return True if the given sequence exists in
the database, False otherwise.
"""

with connection.connection.database.snapshot() as snap:
rows = snap.execute_sql(
"""
SELECT true
FROM INFORMATION_SCHEMA.SEQUENCES
WHERE NAME="{sequence_name}"
LIMIT 1
""".format(
sequence_name=sequence_name
)
)

for _ in rows:
return True

return False

def set_isolation_level(self, conn_proxy, level):
"""Set the connection isolation level.

Expand Down
127 changes: 126 additions & 1 deletion test/test_suite_13.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from sqlalchemy.testing import config
from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import is_instance_of
from sqlalchemy.testing import provide_metadata, emits_warning
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_true
Expand Down Expand Up @@ -73,7 +74,10 @@
from sqlalchemy.testing.suite.test_reflection import * # noqa: F401, F403
from sqlalchemy.testing.suite.test_results import * # noqa: F401, F403
from sqlalchemy.testing.suite.test_select import * # noqa: F401, F403
from sqlalchemy.testing.suite.test_sequence import * # noqa: F401, F403
from sqlalchemy.testing.suite.test_sequence import (
SequenceTest as _SequenceTest,
HasSequenceTest as _HasSequenceTest,
) # noqa: F401, F403
from sqlalchemy.testing.suite.test_update_delete import * # noqa: F401, F403

from sqlalchemy.testing.suite.test_cte import CTETest as _CTETest
Expand Down Expand Up @@ -2059,3 +2063,124 @@ def test_create_engine_wo_database(self):
engine = create_engine(get_db_url().split("/database")[0])
with engine.connect() as connection:
assert connection.connection.database is None


@pytest.mark.skipif(
bool(os.environ.get("SPANNER_EMULATOR_HOST")), reason="Skipped on emulator"
)
class SequenceTest(_SequenceTest):
@classmethod
def define_tables(cls, metadata):
Table(
"seq_pk",
metadata,
Column(
"id",
Integer,
sqlalchemy.Sequence("tab_id_seq"),
primary_key=True,
),
Column("data", String(50)),
)

Table(
"seq_opt_pk",
metadata,
Column(
"id",
Integer,
sqlalchemy.Sequence("tab_id_seq_opt", data_type=Integer, optional=True),
primary_key=True,
),
Column("data", String(50)),
)

Table(
"seq_no_returning",
metadata,
Column(
"id",
Integer,
sqlalchemy.Sequence("noret_id_seq"),
primary_key=True,
),
Column("data", String(50)),
implicit_returning=False,
)

def test_insert_lastrowid(self, connection):
r = connection.execute(self.tables.seq_pk.insert(), dict(data="some data"))
assert len(r.inserted_primary_key) == 1
is_instance_of(r.inserted_primary_key[0], int)

def test_nextval_direct(self, connection):
r = connection.execute(self.tables.seq_pk.c.id.default)
is_instance_of(r, int)

def _assert_round_trip(self, table, conn):
row = conn.execute(table.select()).first()
id, name = row
is_instance_of(id, int)
eq_(name, "some data")

@testing.combinations((True,), (False,), argnames="implicit_returning")
@testing.requires.schemas
@pytest.mark.skip("Not supported by Cloud Spanner")
def test_insert_roundtrip_translate(self, connection, implicit_returning):
pass

@testing.requires.schemas
@pytest.mark.skip("Not supported by Cloud Spanner")
def test_nextval_direct_schema_translate(self, connection):
pass


@pytest.mark.skipif(
bool(os.environ.get("SPANNER_EMULATOR_HOST")), reason="Skipped on emulator"
)
class HasSequenceTest(_HasSequenceTest):
@classmethod
def define_tables(cls, metadata):
sqlalchemy.Sequence("user_id_seq", metadata=metadata)
sqlalchemy.Sequence(
"other_seq", metadata=metadata, nomaxvalue=True, nominvalue=True
)
Table(
"user_id_table",
metadata,
Column("id", Integer, primary_key=True),
)

@pytest.mark.skip("Not supported by Cloud Spanner")
def test_has_sequence_cache(self, connection, metadata):
pass

@testing.requires.schemas
@pytest.mark.skip("Not supported by Cloud Spanner")
def test_has_sequence_schema(self, connection):
pass

@testing.requires.schemas
@pytest.mark.skip("Not supported by Cloud Spanner")
def test_has_sequence_schemas_neg(self, connection):
pass

@testing.requires.schemas
@pytest.mark.skip("Not supported by Cloud Spanner")
def test_has_sequence_default_not_in_remote(self, connection):
pass

@testing.requires.schemas
@pytest.mark.skip("Not supported by Cloud Spanner")
def test_has_sequence_remote_not_in_default(self, connection):
pass

@testing.requires.schemas
@pytest.mark.skip("Not supported by Cloud Spanner")
def test_get_sequence_names_no_sequence_schema(self, connection):
pass

@testing.requires.schemas
@pytest.mark.skip("Not supported by Cloud Spanner")
def test_get_sequence_names_sequences_schema(self, connection):
pass
Loading