ctxutil.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. // Copyright 2016 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package sql
  5. import (
  6. "context"
  7. "database/sql/driver"
  8. "errors"
  9. )
  // ctxDriverPrepare prepares query on the driver connection ci.
  //
  // A connection implementing driver.ConnPrepareContext has ctx passed straight
  // through. Otherwise the context-unaware Prepare is called. If Prepare fails,
  // its results are returned unchanged. If it succeeds but ctx is already done
  // by the time it returns, the new statement is closed (the Close error is
  // ignored) and ctx.Err() is returned in place of the statement.
  func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver.Stmt, error) {
  	prepCtx, ok := ci.(driver.ConnPrepareContext)
  	if ok {
  		return prepCtx.PrepareContext(ctx, query)
  	}

  	stmt, err := ci.Prepare(query)
  	if err != nil {
  		return stmt, err
  	}

  	// Prepare could not observe ctx, so check it afterwards and discard the
  	// statement if the caller has already given up.
  	select {
  	case <-ctx.Done():
  		stmt.Close() // best effort: the caller only sees ctx.Err()
  		return nil, ctx.Err()
  	default:
  		return stmt, nil
  	}
  }
  25. func ctxDriverExec(ctx context.Context, execerCtx driver.ExecerContext, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) {
  26. if execerCtx != nil {
  27. return execerCtx.ExecContext(ctx, query, nvdargs)
  28. }
  29. dargs, err := namedValueToValue(nvdargs)
  30. if err != nil {
  31. return nil, err
  32. }
  33. select {
  34. default:
  35. case <-ctx.Done():
  36. return nil, ctx.Err()
  37. }
  38. return execer.Exec(query, dargs)
  39. }
  40. func ctxDriverQuery(ctx context.Context, queryerCtx driver.QueryerContext, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) {
  41. if queryerCtx != nil {
  42. return queryerCtx.QueryContext(ctx, query, nvdargs)
  43. }
  44. dargs, err := namedValueToValue(nvdargs)
  45. if err != nil {
  46. return nil, err
  47. }
  48. select {
  49. default:
  50. case <-ctx.Done():
  51. return nil, ctx.Err()
  52. }
  53. return queryer.Query(query, dargs)
  54. }
  55. func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) {
  56. if siCtx, is := si.(driver.StmtExecContext); is {
  57. return siCtx.ExecContext(ctx, nvdargs)
  58. }
  59. dargs, err := namedValueToValue(nvdargs)
  60. if err != nil {
  61. return nil, err
  62. }
  63. select {
  64. default:
  65. case <-ctx.Done():
  66. return nil, ctx.Err()
  67. }
  68. return si.Exec(dargs)
  69. }
  70. func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) {
  71. if siCtx, is := si.(driver.StmtQueryContext); is {
  72. return siCtx.QueryContext(ctx, nvdargs)
  73. }
  74. dargs, err := namedValueToValue(nvdargs)
  75. if err != nil {
  76. return nil, err
  77. }
  78. select {
  79. default:
  80. case <-ctx.Done():
  81. return nil, ctx.Err()
  82. }
  83. return si.Query(dargs)
  84. }
  85. func ctxDriverBegin(ctx context.Context, opts *TxOptions, ci driver.Conn) (driver.Tx, error) {
  86. if ciCtx, is := ci.(driver.ConnBeginTx); is {
  87. dopts := driver.TxOptions{}
  88. if opts != nil {
  89. dopts.Isolation = driver.IsolationLevel(opts.Isolation)
  90. dopts.ReadOnly = opts.ReadOnly
  91. }
  92. return ciCtx.BeginTx(ctx, dopts)
  93. }
  94. if opts != nil {
  95. // Check the transaction level. If the transaction level is non-default
  96. // then return an error here as the BeginTx driver value is not supported.
  97. if opts.Isolation != LevelDefault {
  98. return nil, errors.New("sql: driver does not support non-default isolation level")
  99. }
  100. // If a read-only transaction is requested return an error as the
  101. // BeginTx driver value is not supported.
  102. if opts.ReadOnly {
  103. return nil, errors.New("sql: driver does not support read-only transactions")
  104. }
  105. }
  106. if ctx.Done() == nil {
  107. return ci.Begin()
  108. }
  109. txi, err := ci.Begin()
  110. if err == nil {
  111. select {
  112. default:
  113. case <-ctx.Done():
  114. txi.Rollback()
  115. return nil, ctx.Err()
  116. }
  117. }
  118. return txi, err
  119. }
  // namedValueToValue converts named to a positional []driver.Value, keeping
  // the argument order.
  //
  // Any element with a non-empty Name causes an error, because the bare
  // driver.Value form has nowhere to carry a parameter name. The returned slice
  // is non-nil even when named is empty. In this file it is used only on the
  // fallback paths taken for context-unaware driver interfaces.
  func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
  	values := make([]driver.Value, 0, len(named))
  	for _, nv := range named {
  		if nv.Name != "" {
  			return nil, errors.New("sql: driver does not support the use of Named Parameters")
  		}
  		values = append(values, nv.Value)
  	}
  	return values, nil
  }