123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210 |
- package middleware
- import (
- "crypto/subtle"
- "errors"
- "net/http"
- "strings"
- "time"
- "github.com/labstack/echo"
- "github.com/labstack/gommon/random"
- )
- type (
- // CSRFConfig defines the config for CSRF middleware.
- CSRFConfig struct {
- // Skipper defines a function to skip middleware.
- Skipper Skipper
- // TokenLength is the length of the generated token.
- TokenLength uint8 `yaml:"token_length"`
- // Optional. Default value 32.
- // TokenLookup is a string in the form of "<source>:<key>" that is used
- // to extract token from the request.
- // Optional. Default value "header:X-CSRF-Token".
- // Possible values:
- // - "header:<name>"
- // - "form:<name>"
- // - "query:<name>"
- TokenLookup string `yaml:"token_lookup"`
- // Context key to store generated CSRF token into context.
- // Optional. Default value "csrf".
- ContextKey string `yaml:"context_key"`
- // Name of the CSRF cookie. This cookie will store CSRF token.
- // Optional. Default value "csrf".
- CookieName string `yaml:"cookie_name"`
- // Domain of the CSRF cookie.
- // Optional. Default value none.
- CookieDomain string `yaml:"cookie_domain"`
- // Path of the CSRF cookie.
- // Optional. Default value none.
- CookiePath string `yaml:"cookie_path"`
- // Max age (in seconds) of the CSRF cookie.
- // Optional. Default value 86400 (24hr).
- CookieMaxAge int `yaml:"cookie_max_age"`
- // Indicates if CSRF cookie is secure.
- // Optional. Default value false.
- CookieSecure bool `yaml:"cookie_secure"`
- // Indicates if CSRF cookie is HTTP only.
- // Optional. Default value false.
- CookieHTTPOnly bool `yaml:"cookie_http_only"`
- }
- // csrfTokenExtractor defines a function that takes `echo.Context` and returns
- // either a token or an error.
- csrfTokenExtractor func(echo.Context) (string, error)
- )
- var (
- // DefaultCSRFConfig is the default CSRF middleware config.
- DefaultCSRFConfig = CSRFConfig{
- Skipper: DefaultSkipper,
- TokenLength: 32,
- TokenLookup: "header:" + echo.HeaderXCSRFToken,
- ContextKey: "csrf",
- CookieName: "_csrf",
- CookieMaxAge: 86400,
- }
- )
- // CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
- // See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
- func CSRF() echo.MiddlewareFunc {
- c := DefaultCSRFConfig
- return CSRFWithConfig(c)
- }
- // CSRFWithConfig returns a CSRF middleware with config.
- // See `CSRF()`.
- func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
- // Defaults
- if config.Skipper == nil {
- config.Skipper = DefaultCSRFConfig.Skipper
- }
- if config.TokenLength == 0 {
- config.TokenLength = DefaultCSRFConfig.TokenLength
- }
- if config.TokenLookup == "" {
- config.TokenLookup = DefaultCSRFConfig.TokenLookup
- }
- if config.ContextKey == "" {
- config.ContextKey = DefaultCSRFConfig.ContextKey
- }
- if config.CookieName == "" {
- config.CookieName = DefaultCSRFConfig.CookieName
- }
- if config.CookieMaxAge == 0 {
- config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge
- }
- // Initialize
- parts := strings.Split(config.TokenLookup, ":")
- extractor := csrfTokenFromHeader(parts[1])
- switch parts[0] {
- case "form":
- extractor = csrfTokenFromForm(parts[1])
- case "query":
- extractor = csrfTokenFromQuery(parts[1])
- }
- return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c echo.Context) error {
- if config.Skipper(c) {
- return next(c)
- }
- req := c.Request()
- k, err := c.Cookie(config.CookieName)
- token := ""
- // Generate token
- if err != nil {
- token = random.String(config.TokenLength)
- } else {
- // Reuse token
- token = k.Value
- }
- switch req.Method {
- case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE:
- default:
- // Validate token only for requests which are not defined as 'safe' by RFC7231
- clientToken, err := extractor(c)
- if err != nil {
- return echo.NewHTTPError(http.StatusBadRequest, err.Error())
- }
- if !validateCSRFToken(token, clientToken) {
- return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
- }
- }
- // Set CSRF cookie
- cookie := new(http.Cookie)
- cookie.Name = config.CookieName
- cookie.Value = token
- if config.CookiePath != "" {
- cookie.Path = config.CookiePath
- }
- if config.CookieDomain != "" {
- cookie.Domain = config.CookieDomain
- }
- cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second)
- cookie.Secure = config.CookieSecure
- cookie.HttpOnly = config.CookieHTTPOnly
- c.SetCookie(cookie)
- // Store token in the context
- c.Set(config.ContextKey, token)
- // Protect clients from caching the response
- c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie)
- return next(c)
- }
- }
- }
- // csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
- // provided request header.
- func csrfTokenFromHeader(header string) csrfTokenExtractor {
- return func(c echo.Context) (string, error) {
- return c.Request().Header.Get(header), nil
- }
- }
- // csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
- // provided form parameter.
- func csrfTokenFromForm(param string) csrfTokenExtractor {
- return func(c echo.Context) (string, error) {
- token := c.FormValue(param)
- if token == "" {
- return "", errors.New("missing csrf token in the form parameter")
- }
- return token, nil
- }
- }
- // csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the
- // provided query parameter.
- func csrfTokenFromQuery(param string) csrfTokenExtractor {
- return func(c echo.Context) (string, error) {
- token := c.QueryParam(param)
- if token == "" {
- return "", errors.New("missing csrf token in the query string")
- }
- return token, nil
- }
- }
// validateCSRFToken reports whether clientToken matches the expected token.
//
// The comparison uses subtle.ConstantTimeCompare. For equal-length inputs,
// timing therefore does not reveal how many leading bytes matched; only a
// length mismatch is observable.
//
// Empty tokens never validate. Otherwise an empty stored token (for example,
// a cookie set to "") could be satisfied simply by omitting the client token.
func validateCSRFToken(token, clientToken string) bool {
	if token == "" || clientToken == "" {
		return false
	}
	return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
}
|