csrf.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. package middleware
  2. import (
  3. "crypto/subtle"
  4. "errors"
  5. "net/http"
  6. "strings"
  7. "time"
  8. "github.com/labstack/echo"
  9. "github.com/labstack/gommon/random"
  10. )
  11. type (
  12. // CSRFConfig defines the config for CSRF middleware.
  13. CSRFConfig struct {
  14. // Skipper defines a function to skip middleware.
  15. Skipper Skipper
  16. // TokenLength is the length of the generated token.
  17. TokenLength uint8 `yaml:"token_length"`
  18. // Optional. Default value 32.
  19. // TokenLookup is a string in the form of "<source>:<key>" that is used
  20. // to extract token from the request.
  21. // Optional. Default value "header:X-CSRF-Token".
  22. // Possible values:
  23. // - "header:<name>"
  24. // - "form:<name>"
  25. // - "query:<name>"
  26. TokenLookup string `yaml:"token_lookup"`
  27. // Context key to store generated CSRF token into context.
  28. // Optional. Default value "csrf".
  29. ContextKey string `yaml:"context_key"`
  30. // Name of the CSRF cookie. This cookie will store CSRF token.
  31. // Optional. Default value "csrf".
  32. CookieName string `yaml:"cookie_name"`
  33. // Domain of the CSRF cookie.
  34. // Optional. Default value none.
  35. CookieDomain string `yaml:"cookie_domain"`
  36. // Path of the CSRF cookie.
  37. // Optional. Default value none.
  38. CookiePath string `yaml:"cookie_path"`
  39. // Max age (in seconds) of the CSRF cookie.
  40. // Optional. Default value 86400 (24hr).
  41. CookieMaxAge int `yaml:"cookie_max_age"`
  42. // Indicates if CSRF cookie is secure.
  43. // Optional. Default value false.
  44. CookieSecure bool `yaml:"cookie_secure"`
  45. // Indicates if CSRF cookie is HTTP only.
  46. // Optional. Default value false.
  47. CookieHTTPOnly bool `yaml:"cookie_http_only"`
  48. }
  49. // csrfTokenExtractor defines a function that takes `echo.Context` and returns
  50. // either a token or an error.
  51. csrfTokenExtractor func(echo.Context) (string, error)
  52. )
  53. var (
  54. // DefaultCSRFConfig is the default CSRF middleware config.
  55. DefaultCSRFConfig = CSRFConfig{
  56. Skipper: DefaultSkipper,
  57. TokenLength: 32,
  58. TokenLookup: "header:" + echo.HeaderXCSRFToken,
  59. ContextKey: "csrf",
  60. CookieName: "_csrf",
  61. CookieMaxAge: 86400,
  62. }
  63. )
  64. // CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
  65. // See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
  66. func CSRF() echo.MiddlewareFunc {
  67. c := DefaultCSRFConfig
  68. return CSRFWithConfig(c)
  69. }
  70. // CSRFWithConfig returns a CSRF middleware with config.
  71. // See `CSRF()`.
  72. func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
  73. // Defaults
  74. if config.Skipper == nil {
  75. config.Skipper = DefaultCSRFConfig.Skipper
  76. }
  77. if config.TokenLength == 0 {
  78. config.TokenLength = DefaultCSRFConfig.TokenLength
  79. }
  80. if config.TokenLookup == "" {
  81. config.TokenLookup = DefaultCSRFConfig.TokenLookup
  82. }
  83. if config.ContextKey == "" {
  84. config.ContextKey = DefaultCSRFConfig.ContextKey
  85. }
  86. if config.CookieName == "" {
  87. config.CookieName = DefaultCSRFConfig.CookieName
  88. }
  89. if config.CookieMaxAge == 0 {
  90. config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge
  91. }
  92. // Initialize
  93. parts := strings.Split(config.TokenLookup, ":")
  94. extractor := csrfTokenFromHeader(parts[1])
  95. switch parts[0] {
  96. case "form":
  97. extractor = csrfTokenFromForm(parts[1])
  98. case "query":
  99. extractor = csrfTokenFromQuery(parts[1])
  100. }
  101. return func(next echo.HandlerFunc) echo.HandlerFunc {
  102. return func(c echo.Context) error {
  103. if config.Skipper(c) {
  104. return next(c)
  105. }
  106. req := c.Request()
  107. k, err := c.Cookie(config.CookieName)
  108. token := ""
  109. // Generate token
  110. if err != nil {
  111. token = random.String(config.TokenLength)
  112. } else {
  113. // Reuse token
  114. token = k.Value
  115. }
  116. switch req.Method {
  117. case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE:
  118. default:
  119. // Validate token only for requests which are not defined as 'safe' by RFC7231
  120. clientToken, err := extractor(c)
  121. if err != nil {
  122. return echo.NewHTTPError(http.StatusBadRequest, err.Error())
  123. }
  124. if !validateCSRFToken(token, clientToken) {
  125. return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
  126. }
  127. }
  128. // Set CSRF cookie
  129. cookie := new(http.Cookie)
  130. cookie.Name = config.CookieName
  131. cookie.Value = token
  132. if config.CookiePath != "" {
  133. cookie.Path = config.CookiePath
  134. }
  135. if config.CookieDomain != "" {
  136. cookie.Domain = config.CookieDomain
  137. }
  138. cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second)
  139. cookie.Secure = config.CookieSecure
  140. cookie.HttpOnly = config.CookieHTTPOnly
  141. c.SetCookie(cookie)
  142. // Store token in the context
  143. c.Set(config.ContextKey, token)
  144. // Protect clients from caching the response
  145. c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie)
  146. return next(c)
  147. }
  148. }
  149. }
  // csrfTokenFromHeader returns a `csrfTokenExtractor` that extracts token from the
// provided request header. A missing or empty header is reported as an error,
// consistent with the form and query extractors. CSRFWithConfig maps that error
// to 400 Bad Request.
func csrfTokenFromHeader(header string) csrfTokenExtractor {
	return func(c echo.Context) (string, error) {
		token := c.Request().Header.Get(header)
		if token == "" {
			return "", errors.New("missing csrf token in header")
		}
		return token, nil
	}
}
  // csrfTokenFromForm builds a `csrfTokenExtractor` that reads the token from the
// form parameter named param. An absent or empty value yields an error.
func csrfTokenFromForm(param string) csrfTokenExtractor {
	return func(ctx echo.Context) (string, error) {
		if v := ctx.FormValue(param); v != "" {
			return v, nil
		}
		return "", errors.New("missing csrf token in the form parameter")
	}
}
  // csrfTokenFromQuery builds a `csrfTokenExtractor` that reads the token from the
// URL query parameter named param. An absent or empty value yields an error.
func csrfTokenFromQuery(param string) csrfTokenExtractor {
	return func(ctx echo.Context) (string, error) {
		if v := ctx.QueryParam(param); v != "" {
			return v, nil
		}
		return "", errors.New("missing csrf token in the query string")
	}
}
  // validateCSRFToken reports whether clientToken matches the server-side token.
// For equal-length inputs the comparison time does not depend on the token
// contents, so it does not reveal how many leading bytes matched.
// subtle.ConstantTimeCompare returns early on a length mismatch, so lengths are
// not hidden. An empty token on either side never validates; this prevents an
// empty cookie value from being matched by an absent client token.
func validateCSRFToken(token, clientToken string) bool {
	if token == "" || clientToken == "" {
		return false
	}
	return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
}