jwt.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. package middleware
  2. import (
  3. "fmt"
  4. "net/http"
  5. "reflect"
  6. "strings"
  7. "github.com/dgrijalva/jwt-go"
  8. "github.com/labstack/echo"
  9. )
  10. type (
  11. // JWTConfig defines the config for JWT middleware.
  12. JWTConfig struct {
  13. // Skipper defines a function to skip middleware.
  14. Skipper Skipper
  15. // BeforeFunc defines a function which is executed just before the middleware.
  16. BeforeFunc BeforeFunc
  17. // SuccessHandler defines a function which is executed for a valid token.
  18. SuccessHandler JWTSuccessHandler
  19. // ErrorHandler defines a function which is executed for an invalid token.
  20. // It may be used to define a custom JWT error.
  21. ErrorHandler JWTErrorHandler
  22. // Signing key to validate token.
  23. // Required.
  24. SigningKey interface{}
  25. // Signing method, used to check token signing method.
  26. // Optional. Default value HS256.
  27. SigningMethod string
  28. // Context key to store user information from the token into context.
  29. // Optional. Default value "user".
  30. ContextKey string
  31. // Claims are extendable claims data defining token content.
  32. // Optional. Default value jwt.MapClaims
  33. Claims jwt.Claims
  34. // TokenLookup is a string in the form of "<source>:<name>" that is used
  35. // to extract token from the request.
  36. // Optional. Default value "header:Authorization".
  37. // Possible values:
  38. // - "header:<name>"
  39. // - "query:<name>"
  40. // - "cookie:<name>"
  41. TokenLookup string
  42. // AuthScheme to be used in the Authorization header.
  43. // Optional. Default value "Bearer".
  44. AuthScheme string
  45. keyFunc jwt.Keyfunc
  46. }
  47. // JWTSuccessHandler defines a function which is executed for a valid token.
  48. JWTSuccessHandler func(echo.Context)
  49. // JWTErrorHandler defines a function which is executed for an invalid token.
  50. JWTErrorHandler func(error) error
  51. jwtExtractor func(echo.Context) (string, error)
  52. )
  // AlgorithmHS256 is the "HS256" signing-method name. It is the default
  // value of JWTConfig.SigningMethod.
  const AlgorithmHS256 = "HS256"
  57. // Errors
  58. var (
  59. ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt")
  60. )
  61. var (
  62. // DefaultJWTConfig is the default JWT auth middleware config.
  63. DefaultJWTConfig = JWTConfig{
  64. Skipper: DefaultSkipper,
  65. SigningMethod: AlgorithmHS256,
  66. ContextKey: "user",
  67. TokenLookup: "header:" + echo.HeaderAuthorization,
  68. AuthScheme: "Bearer",
  69. Claims: jwt.MapClaims{},
  70. }
  71. )
  72. // JWT returns a JSON Web Token (JWT) auth middleware.
  73. //
  74. // For valid token, it sets the user in context and calls next handler.
  75. // For invalid token, it returns "401 - Unauthorized" error.
  76. // For missing token, it returns "400 - Bad Request" error.
  77. //
  78. // See: https://jwt.io/introduction
  79. // See `JWTConfig.TokenLookup`
  80. func JWT(key interface{}) echo.MiddlewareFunc {
  81. c := DefaultJWTConfig
  82. c.SigningKey = key
  83. return JWTWithConfig(c)
  84. }
  85. // JWTWithConfig returns a JWT auth middleware with config.
  86. // See: `JWT()`.
  87. func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
  88. // Defaults
  89. if config.Skipper == nil {
  90. config.Skipper = DefaultJWTConfig.Skipper
  91. }
  92. if config.SigningKey == nil {
  93. panic("echo: jwt middleware requires signing key")
  94. }
  95. if config.SigningMethod == "" {
  96. config.SigningMethod = DefaultJWTConfig.SigningMethod
  97. }
  98. if config.ContextKey == "" {
  99. config.ContextKey = DefaultJWTConfig.ContextKey
  100. }
  101. if config.Claims == nil {
  102. config.Claims = DefaultJWTConfig.Claims
  103. }
  104. if config.TokenLookup == "" {
  105. config.TokenLookup = DefaultJWTConfig.TokenLookup
  106. }
  107. if config.AuthScheme == "" {
  108. config.AuthScheme = DefaultJWTConfig.AuthScheme
  109. }
  110. config.keyFunc = func(t *jwt.Token) (interface{}, error) {
  111. // Check the signing method
  112. if t.Method.Alg() != config.SigningMethod {
  113. return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
  114. }
  115. return config.SigningKey, nil
  116. }
  117. // Initialize
  118. parts := strings.Split(config.TokenLookup, ":")
  119. extractor := jwtFromHeader(parts[1], config.AuthScheme)
  120. switch parts[0] {
  121. case "query":
  122. extractor = jwtFromQuery(parts[1])
  123. case "cookie":
  124. extractor = jwtFromCookie(parts[1])
  125. }
  126. return func(next echo.HandlerFunc) echo.HandlerFunc {
  127. return func(c echo.Context) error {
  128. if config.Skipper(c) {
  129. return next(c)
  130. }
  131. if config.BeforeFunc != nil {
  132. config.BeforeFunc(c)
  133. }
  134. auth, err := extractor(c)
  135. if err != nil {
  136. if config.ErrorHandler != nil {
  137. return config.ErrorHandler(err)
  138. }
  139. return err
  140. }
  141. token := new(jwt.Token)
  142. // Issue #647, #656
  143. if _, ok := config.Claims.(jwt.MapClaims); ok {
  144. token, err = jwt.Parse(auth, config.keyFunc)
  145. } else {
  146. t := reflect.ValueOf(config.Claims).Type().Elem()
  147. claims := reflect.New(t).Interface().(jwt.Claims)
  148. token, err = jwt.ParseWithClaims(auth, claims, config.keyFunc)
  149. }
  150. if err == nil && token.Valid {
  151. // Store user information from token into context.
  152. c.Set(config.ContextKey, token)
  153. if config.SuccessHandler != nil {
  154. config.SuccessHandler(c)
  155. }
  156. return next(c)
  157. }
  158. if config.ErrorHandler != nil {
  159. return config.ErrorHandler(err)
  160. }
  161. return &echo.HTTPError{
  162. Code: http.StatusUnauthorized,
  163. Message: "invalid or expired jwt",
  164. Internal: err,
  165. }
  166. }
  167. }
  168. }
  169. // jwtFromHeader returns a `jwtExtractor` that extracts token from the request header.
  170. func jwtFromHeader(header string, authScheme string) jwtExtractor {
  171. return func(c echo.Context) (string, error) {
  172. auth := c.Request().Header.Get(header)
  173. l := len(authScheme)
  174. if len(auth) > l+1 && auth[:l] == authScheme {
  175. return auth[l+1:], nil
  176. }
  177. return "", ErrJWTMissing
  178. }
  179. }
  180. // jwtFromQuery returns a `jwtExtractor` that extracts token from the query string.
  181. func jwtFromQuery(param string) jwtExtractor {
  182. return func(c echo.Context) (string, error) {
  183. token := c.QueryParam(param)
  184. if token == "" {
  185. return "", ErrJWTMissing
  186. }
  187. return token, nil
  188. }
  189. }
  190. // jwtFromCookie returns a `jwtExtractor` that extracts token from the named cookie.
  191. func jwtFromCookie(name string) jwtExtractor {
  192. return func(c echo.Context) (string, error) {
  193. cookie, err := c.Cookie(name)
  194. if err != nil {
  195. return "", ErrJWTMissing
  196. }
  197. return cookie.Value, nil
  198. }
  199. }