slash.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. package middleware
  2. import (
  3. "github.com/labstack/echo"
  4. )
  5. type (
  6. // TrailingSlashConfig defines the config for TrailingSlash middleware.
  7. TrailingSlashConfig struct {
  8. // Skipper defines a function to skip middleware.
  9. Skipper Skipper
  10. // Status code to be used when redirecting the request.
  11. // Optional, but when provided the request is redirected using this code.
  12. RedirectCode int `yaml:"redirect_code"`
  13. }
  14. )
  15. var (
  16. // DefaultTrailingSlashConfig is the default TrailingSlash middleware config.
  17. DefaultTrailingSlashConfig = TrailingSlashConfig{
  18. Skipper: DefaultSkipper,
  19. }
  20. )
  21. // AddTrailingSlash returns a root level (before router) middleware which adds a
  22. // trailing slash to the request `URL#Path`.
  23. //
  24. // Usage `Echo#Pre(AddTrailingSlash())`
  25. func AddTrailingSlash() echo.MiddlewareFunc {
  26. return AddTrailingSlashWithConfig(DefaultTrailingSlashConfig)
  27. }
  28. // AddTrailingSlashWithConfig returns a AddTrailingSlash middleware with config.
  29. // See `AddTrailingSlash()`.
  30. func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc {
  31. // Defaults
  32. if config.Skipper == nil {
  33. config.Skipper = DefaultTrailingSlashConfig.Skipper
  34. }
  35. return func(next echo.HandlerFunc) echo.HandlerFunc {
  36. return func(c echo.Context) error {
  37. if config.Skipper(c) {
  38. return next(c)
  39. }
  40. req := c.Request()
  41. url := req.URL
  42. path := url.Path
  43. qs := c.QueryString()
  44. if path != "/" && path[len(path)-1] != '/' {
  45. path += "/"
  46. uri := path
  47. if qs != "" {
  48. uri += "?" + qs
  49. }
  50. // Redirect
  51. if config.RedirectCode != 0 {
  52. return c.Redirect(config.RedirectCode, uri)
  53. }
  54. // Forward
  55. req.RequestURI = uri
  56. url.Path = path
  57. }
  58. return next(c)
  59. }
  60. }
  61. }
  62. // RemoveTrailingSlash returns a root level (before router) middleware which removes
  63. // a trailing slash from the request URI.
  64. //
  65. // Usage `Echo#Pre(RemoveTrailingSlash())`
  66. func RemoveTrailingSlash() echo.MiddlewareFunc {
  67. return RemoveTrailingSlashWithConfig(TrailingSlashConfig{})
  68. }
  69. // RemoveTrailingSlashWithConfig returns a RemoveTrailingSlash middleware with config.
  70. // See `RemoveTrailingSlash()`.
  71. func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc {
  72. // Defaults
  73. if config.Skipper == nil {
  74. config.Skipper = DefaultTrailingSlashConfig.Skipper
  75. }
  76. return func(next echo.HandlerFunc) echo.HandlerFunc {
  77. return func(c echo.Context) error {
  78. if config.Skipper(c) {
  79. return next(c)
  80. }
  81. req := c.Request()
  82. url := req.URL
  83. path := url.Path
  84. qs := c.QueryString()
  85. l := len(path) - 1
  86. if l >= 0 && path != "/" && path[l] == '/' {
  87. path = path[:l]
  88. uri := path
  89. if qs != "" {
  90. uri += "?" + qs
  91. }
  92. // Redirect
  93. if config.RedirectCode != 0 {
  94. return c.Redirect(config.RedirectCode, uri)
  95. }
  96. // Forward
  97. req.RequestURI = uri
  98. url.Path = path
  99. }
  100. return next(c)
  101. }
  102. }
  103. }