key_auth.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. package middleware
  2. import (
  3. "errors"
  4. "net/http"
  5. "strings"
  6. "github.com/labstack/echo"
  7. )
  8. type (
  9. // KeyAuthConfig defines the config for KeyAuth middleware.
  10. KeyAuthConfig struct {
  11. // Skipper defines a function to skip middleware.
  12. Skipper Skipper
  13. // KeyLookup is a string in the form of "<source>:<name>" that is used
  14. // to extract key from the request.
  15. // Optional. Default value "header:Authorization".
  16. // Possible values:
  17. // - "header:<name>"
  18. // - "query:<name>"
  19. // - "form:<name>"
  20. KeyLookup string `yaml:"key_lookup"`
  21. // AuthScheme to be used in the Authorization header.
  22. // Optional. Default value "Bearer".
  23. AuthScheme string
  24. // Validator is a function to validate key.
  25. // Required.
  26. Validator KeyAuthValidator
  27. }
  28. // KeyAuthValidator defines a function to validate KeyAuth credentials.
  29. KeyAuthValidator func(string, echo.Context) (bool, error)
  30. keyExtractor func(echo.Context) (string, error)
  31. )
  32. var (
  33. // DefaultKeyAuthConfig is the default KeyAuth middleware config.
  34. DefaultKeyAuthConfig = KeyAuthConfig{
  35. Skipper: DefaultSkipper,
  36. KeyLookup: "header:" + echo.HeaderAuthorization,
  37. AuthScheme: "Bearer",
  38. }
  39. )
  40. // KeyAuth returns an KeyAuth middleware.
  41. //
  42. // For valid key it calls the next handler.
  43. // For invalid key, it sends "401 - Unauthorized" response.
  44. // For missing key, it sends "400 - Bad Request" response.
  45. func KeyAuth(fn KeyAuthValidator) echo.MiddlewareFunc {
  46. c := DefaultKeyAuthConfig
  47. c.Validator = fn
  48. return KeyAuthWithConfig(c)
  49. }
  50. // KeyAuthWithConfig returns an KeyAuth middleware with config.
  51. // See `KeyAuth()`.
  52. func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
  53. // Defaults
  54. if config.Skipper == nil {
  55. config.Skipper = DefaultKeyAuthConfig.Skipper
  56. }
  57. // Defaults
  58. if config.AuthScheme == "" {
  59. config.AuthScheme = DefaultKeyAuthConfig.AuthScheme
  60. }
  61. if config.KeyLookup == "" {
  62. config.KeyLookup = DefaultKeyAuthConfig.KeyLookup
  63. }
  64. if config.Validator == nil {
  65. panic("echo: key-auth middleware requires a validator function")
  66. }
  67. // Initialize
  68. parts := strings.Split(config.KeyLookup, ":")
  69. extractor := keyFromHeader(parts[1], config.AuthScheme)
  70. switch parts[0] {
  71. case "query":
  72. extractor = keyFromQuery(parts[1])
  73. case "form":
  74. extractor = keyFromForm(parts[1])
  75. }
  76. return func(next echo.HandlerFunc) echo.HandlerFunc {
  77. return func(c echo.Context) error {
  78. if config.Skipper(c) {
  79. return next(c)
  80. }
  81. // Extract and verify key
  82. key, err := extractor(c)
  83. if err != nil {
  84. return echo.NewHTTPError(http.StatusBadRequest, err.Error())
  85. }
  86. valid, err := config.Validator(key, c)
  87. if err != nil {
  88. return err
  89. } else if valid {
  90. return next(c)
  91. }
  92. return echo.ErrUnauthorized
  93. }
  94. }
  95. }
  96. // keyFromHeader returns a `keyExtractor` that extracts key from the request header.
  97. func keyFromHeader(header string, authScheme string) keyExtractor {
  98. return func(c echo.Context) (string, error) {
  99. auth := c.Request().Header.Get(header)
  100. if auth == "" {
  101. return "", errors.New("missing key in request header")
  102. }
  103. if header == echo.HeaderAuthorization {
  104. l := len(authScheme)
  105. if len(auth) > l+1 && auth[:l] == authScheme {
  106. return auth[l+1:], nil
  107. }
  108. return "", errors.New("invalid key in the request header")
  109. }
  110. return auth, nil
  111. }
  112. }
  113. // keyFromQuery returns a `keyExtractor` that extracts key from the query string.
  114. func keyFromQuery(param string) keyExtractor {
  115. return func(c echo.Context) (string, error) {
  116. key := c.QueryParam(param)
  117. if key == "" {
  118. return "", errors.New("missing key in the query string")
  119. }
  120. return key, nil
  121. }
  122. }
  123. // keyFromForm returns a `keyExtractor` that extracts key from the form.
  124. func keyFromForm(param string) keyExtractor {
  125. return func(c echo.Context) (string, error) {
  126. key := c.FormValue(param)
  127. if key == "" {
  128. return "", errors.New("missing key in the form")
  129. }
  130. return key, nil
  131. }
  132. }