Skip to content

Commit 2af7220

Browse files
authored
Add table-level constraint parsing for CREATE TABLE (#49)
1 parent cc4d039 commit 2af7220

File tree

3 files changed

+301
-13
lines changed

3 files changed

+301
-13
lines changed

parser/marshal.go

Lines changed: 299 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2476,14 +2476,38 @@ func (p *Parser) parseCreateTableStatement() (*ast.CreateTableStatement, error)
24762476

24772477
stmt.Definition = &ast.TableDefinition{}
24782478

2479-
// Parse column definitions
2479+
// Parse column definitions and table constraints
24802480
for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF {
2481-
colDef, err := p.parseColumnDefinition()
2482-
if err != nil {
2483-
p.skipToEndOfStatement()
2484-
return stmt, nil
2481+
upperLit := strings.ToUpper(p.curTok.Literal)
2482+
2483+
// Check for table-level constraints
2484+
if upperLit == "CONSTRAINT" {
2485+
constraint, err := p.parseNamedTableConstraint()
2486+
if err != nil {
2487+
p.skipToEndOfStatement()
2488+
return stmt, nil
2489+
}
2490+
if constraint != nil {
2491+
stmt.Definition.TableConstraints = append(stmt.Definition.TableConstraints, constraint)
2492+
}
2493+
} else if upperLit == "PRIMARY" || upperLit == "UNIQUE" || upperLit == "FOREIGN" || upperLit == "CHECK" {
2494+
constraint, err := p.parseUnnamedTableConstraint()
2495+
if err != nil {
2496+
p.skipToEndOfStatement()
2497+
return stmt, nil
2498+
}
2499+
if constraint != nil {
2500+
stmt.Definition.TableConstraints = append(stmt.Definition.TableConstraints, constraint)
2501+
}
2502+
} else {
2503+
// Parse column definition
2504+
colDef, err := p.parseColumnDefinition()
2505+
if err != nil {
2506+
p.skipToEndOfStatement()
2507+
return stmt, nil
2508+
}
2509+
stmt.Definition.ColumnDefinitions = append(stmt.Definition.ColumnDefinitions, colDef)
24852510
}
2486-
stmt.Definition.ColumnDefinitions = append(stmt.Definition.ColumnDefinitions, colDef)
24872511

24882512
if p.curTok.Type == TokenComma {
24892513
p.nextToken()
@@ -2685,6 +2709,265 @@ func (p *Parser) parseColumnDefinition() (*ast.ColumnDefinition, error) {
26852709
return col, nil
26862710
}
26872711

2712+
// parseNamedTableConstraint parses a CONSTRAINT name ... table constraint
2713+
func (p *Parser) parseNamedTableConstraint() (ast.TableConstraint, error) {
2714+
// Consume CONSTRAINT
2715+
p.nextToken()
2716+
2717+
// Parse constraint name
2718+
constraintName := p.parseIdentifier()
2719+
2720+
// Now parse the actual constraint type
2721+
upperLit := strings.ToUpper(p.curTok.Literal)
2722+
2723+
if upperLit == "PRIMARY" {
2724+
constraint, err := p.parsePrimaryKeyConstraint()
2725+
if err != nil {
2726+
return nil, err
2727+
}
2728+
constraint.ConstraintIdentifier = constraintName
2729+
return constraint, nil
2730+
} else if upperLit == "UNIQUE" {
2731+
constraint, err := p.parseUniqueConstraint()
2732+
if err != nil {
2733+
return nil, err
2734+
}
2735+
constraint.ConstraintIdentifier = constraintName
2736+
return constraint, nil
2737+
} else if upperLit == "FOREIGN" {
2738+
constraint, err := p.parseForeignKeyConstraint()
2739+
if err != nil {
2740+
return nil, err
2741+
}
2742+
constraint.ConstraintIdentifier = constraintName
2743+
return constraint, nil
2744+
} else if upperLit == "CHECK" {
2745+
constraint, err := p.parseCheckConstraint()
2746+
if err != nil {
2747+
return nil, err
2748+
}
2749+
constraint.ConstraintIdentifier = constraintName
2750+
return constraint, nil
2751+
}
2752+
2753+
return nil, nil
2754+
}
2755+
2756+
// parseUnnamedTableConstraint parses an unnamed table constraint (PRIMARY KEY, UNIQUE, FOREIGN KEY, CHECK)
2757+
func (p *Parser) parseUnnamedTableConstraint() (ast.TableConstraint, error) {
2758+
upperLit := strings.ToUpper(p.curTok.Literal)
2759+
2760+
if upperLit == "PRIMARY" {
2761+
return p.parsePrimaryKeyConstraint()
2762+
} else if upperLit == "UNIQUE" {
2763+
return p.parseUniqueConstraint()
2764+
} else if upperLit == "FOREIGN" {
2765+
return p.parseForeignKeyConstraint()
2766+
} else if upperLit == "CHECK" {
2767+
return p.parseCheckConstraint()
2768+
}
2769+
2770+
return nil, nil
2771+
}
2772+
2773+
// parsePrimaryKeyConstraint parses PRIMARY KEY CLUSTERED/NONCLUSTERED (columns)
2774+
func (p *Parser) parsePrimaryKeyConstraint() (*ast.UniqueConstraintDefinition, error) {
2775+
// Consume PRIMARY
2776+
p.nextToken()
2777+
if p.curTok.Type == TokenKey {
2778+
p.nextToken() // consume KEY
2779+
}
2780+
2781+
constraint := &ast.UniqueConstraintDefinition{
2782+
IsPrimaryKey: true,
2783+
}
2784+
2785+
// Parse optional CLUSTERED/NONCLUSTERED
2786+
if strings.ToUpper(p.curTok.Literal) == "CLUSTERED" {
2787+
constraint.Clustered = true
2788+
constraint.IndexType = &ast.IndexType{IndexTypeKind: "Clustered"}
2789+
p.nextToken()
2790+
} else if strings.ToUpper(p.curTok.Literal) == "NONCLUSTERED" {
2791+
constraint.Clustered = false
2792+
constraint.IndexType = &ast.IndexType{IndexTypeKind: "NonClustered"}
2793+
p.nextToken()
2794+
}
2795+
2796+
// Parse column list
2797+
if p.curTok.Type == TokenLParen {
2798+
p.nextToken() // consume (
2799+
for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF {
2800+
col := p.parseColumnWithSortOrder()
2801+
constraint.Columns = append(constraint.Columns, col)
2802+
2803+
if p.curTok.Type == TokenComma {
2804+
p.nextToken()
2805+
} else {
2806+
break
2807+
}
2808+
}
2809+
if p.curTok.Type == TokenRParen {
2810+
p.nextToken() // consume )
2811+
}
2812+
}
2813+
2814+
return constraint, nil
2815+
}
2816+
2817+
// parseUniqueConstraint parses UNIQUE CLUSTERED/NONCLUSTERED (columns)
2818+
func (p *Parser) parseUniqueConstraint() (*ast.UniqueConstraintDefinition, error) {
2819+
// Consume UNIQUE
2820+
p.nextToken()
2821+
2822+
constraint := &ast.UniqueConstraintDefinition{
2823+
IsPrimaryKey: false,
2824+
}
2825+
2826+
// Parse optional CLUSTERED/NONCLUSTERED
2827+
if strings.ToUpper(p.curTok.Literal) == "CLUSTERED" {
2828+
constraint.Clustered = true
2829+
constraint.IndexType = &ast.IndexType{IndexTypeKind: "Clustered"}
2830+
p.nextToken()
2831+
} else if strings.ToUpper(p.curTok.Literal) == "NONCLUSTERED" {
2832+
constraint.Clustered = false
2833+
constraint.IndexType = &ast.IndexType{IndexTypeKind: "NonClustered"}
2834+
p.nextToken()
2835+
}
2836+
2837+
// Parse column list
2838+
if p.curTok.Type == TokenLParen {
2839+
p.nextToken() // consume (
2840+
for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF {
2841+
col := p.parseColumnWithSortOrder()
2842+
constraint.Columns = append(constraint.Columns, col)
2843+
2844+
if p.curTok.Type == TokenComma {
2845+
p.nextToken()
2846+
} else {
2847+
break
2848+
}
2849+
}
2850+
if p.curTok.Type == TokenRParen {
2851+
p.nextToken() // consume )
2852+
}
2853+
}
2854+
2855+
return constraint, nil
2856+
}
2857+
2858+
// parseForeignKeyConstraint parses FOREIGN KEY (columns) REFERENCES table (columns)
2859+
func (p *Parser) parseForeignKeyConstraint() (*ast.ForeignKeyConstraintDefinition, error) {
2860+
// Consume FOREIGN
2861+
p.nextToken()
2862+
if p.curTok.Type == TokenKey {
2863+
p.nextToken() // consume KEY
2864+
}
2865+
2866+
constraint := &ast.ForeignKeyConstraintDefinition{}
2867+
2868+
// Parse column list
2869+
if p.curTok.Type == TokenLParen {
2870+
p.nextToken() // consume (
2871+
for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF {
2872+
ident := p.parseIdentifier()
2873+
constraint.Columns = append(constraint.Columns, ident)
2874+
2875+
if p.curTok.Type == TokenComma {
2876+
p.nextToken()
2877+
} else {
2878+
break
2879+
}
2880+
}
2881+
if p.curTok.Type == TokenRParen {
2882+
p.nextToken() // consume )
2883+
}
2884+
}
2885+
2886+
// Parse REFERENCES
2887+
if strings.ToUpper(p.curTok.Literal) == "REFERENCES" {
2888+
p.nextToken() // consume REFERENCES
2889+
2890+
// Parse reference table name
2891+
refTable, err := p.parseSchemaObjectName()
2892+
if err != nil {
2893+
return nil, err
2894+
}
2895+
constraint.ReferenceTableName = refTable
2896+
2897+
// Parse referenced column list
2898+
if p.curTok.Type == TokenLParen {
2899+
p.nextToken() // consume (
2900+
for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF {
2901+
ident := p.parseIdentifier()
2902+
constraint.ReferencedColumns = append(constraint.ReferencedColumns, ident)
2903+
2904+
if p.curTok.Type == TokenComma {
2905+
p.nextToken()
2906+
} else {
2907+
break
2908+
}
2909+
}
2910+
if p.curTok.Type == TokenRParen {
2911+
p.nextToken() // consume )
2912+
}
2913+
}
2914+
}
2915+
2916+
return constraint, nil
2917+
}
2918+
2919+
// parseCheckConstraint parses CHECK (expression)
2920+
func (p *Parser) parseCheckConstraint() (*ast.CheckConstraintDefinition, error) {
2921+
// Consume CHECK
2922+
p.nextToken()
2923+
2924+
constraint := &ast.CheckConstraintDefinition{}
2925+
2926+
// Parse condition
2927+
if p.curTok.Type == TokenLParen {
2928+
p.nextToken() // consume (
2929+
cond, err := p.parseBooleanExpression()
2930+
if err != nil {
2931+
return nil, err
2932+
}
2933+
constraint.CheckCondition = cond
2934+
if p.curTok.Type == TokenRParen {
2935+
p.nextToken() // consume )
2936+
}
2937+
}
2938+
2939+
return constraint, nil
2940+
}
2941+
2942+
// parseColumnWithSortOrder parses a column name with optional ASC/DESC sort order
2943+
func (p *Parser) parseColumnWithSortOrder() *ast.ColumnWithSortOrder {
2944+
col := &ast.ColumnWithSortOrder{
2945+
SortOrder: ast.SortOrderNotSpecified,
2946+
}
2947+
2948+
// Parse column name
2949+
ident := p.parseIdentifier()
2950+
col.Column = &ast.ColumnReferenceExpression{
2951+
ColumnType: "Regular",
2952+
MultiPartIdentifier: &ast.MultiPartIdentifier{
2953+
Count: 1,
2954+
Identifiers: []*ast.Identifier{ident},
2955+
},
2956+
}
2957+
2958+
// Parse optional ASC/DESC
2959+
upperLit := strings.ToUpper(p.curTok.Literal)
2960+
if upperLit == "ASC" {
2961+
col.SortOrder = ast.SortOrderAscending
2962+
p.nextToken()
2963+
} else if upperLit == "DESC" {
2964+
col.SortOrder = ast.SortOrderDescending
2965+
p.nextToken()
2966+
}
2967+
2968+
return col
2969+
}
2970+
26882971
func (p *Parser) parseGrantStatement() (*ast.GrantStatement, error) {
26892972
// Consume GRANT
26902973
p.nextToken()
@@ -3427,14 +3710,19 @@ func foreignKeyConstraintToJSON(c *ast.ForeignKeyConstraintDefinition) jsonNode
34273710
for i, col := range c.ReferencedColumns {
34283711
cols[i] = identifierToJSON(col)
34293712
}
3430-
node["ReferencedColumns"] = cols
3713+
node["ReferencedTableColumns"] = cols
34313714
}
3432-
if c.DeleteAction != "" {
3433-
node["DeleteAction"] = c.DeleteAction
3715+
// Always include DeleteAction and UpdateAction with default value
3716+
deleteAction := c.DeleteAction
3717+
if deleteAction == "" {
3718+
deleteAction = "NotSpecified"
34343719
}
3435-
if c.UpdateAction != "" {
3436-
node["UpdateAction"] = c.UpdateAction
3720+
node["DeleteAction"] = deleteAction
3721+
updateAction := c.UpdateAction
3722+
if updateAction == "" {
3723+
updateAction = "NotSpecified"
34373724
}
3725+
node["UpdateAction"] = updateAction
34383726
return node
34393727
}
34403728

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
{"todo": true}
1+
{}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
{"todo": true}
1+
{}

0 commit comments

Comments
 (0)