method_override.go 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. package middleware
  2. import "github.com/labstack/echo"
  3. type (
  4. // MethodOverrideConfig defines the config for MethodOverride middleware.
  5. MethodOverrideConfig struct {
  6. // Skipper defines a function to skip middleware.
  7. Skipper Skipper
  8. // Getter is a function that gets overridden method from the request.
  9. // Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride).
  10. Getter MethodOverrideGetter
  11. }
  12. // MethodOverrideGetter is a function that gets overridden method from the request
  13. MethodOverrideGetter func(echo.Context) string
  14. )
  15. var (
  16. // DefaultMethodOverrideConfig is the default MethodOverride middleware config.
  17. DefaultMethodOverrideConfig = MethodOverrideConfig{
  18. Skipper: DefaultSkipper,
  19. Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride),
  20. }
  21. )
  22. // MethodOverride returns a MethodOverride middleware.
  23. // MethodOverride middleware checks for the overridden method from the request and
  24. // uses it instead of the original method.
  25. //
  26. // For security reasons, only `POST` method can be overridden.
  27. func MethodOverride() echo.MiddlewareFunc {
  28. return MethodOverrideWithConfig(DefaultMethodOverrideConfig)
  29. }
  30. // MethodOverrideWithConfig returns a MethodOverride middleware with config.
  31. // See: `MethodOverride()`.
  32. func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
  33. // Defaults
  34. if config.Skipper == nil {
  35. config.Skipper = DefaultMethodOverrideConfig.Skipper
  36. }
  37. if config.Getter == nil {
  38. config.Getter = DefaultMethodOverrideConfig.Getter
  39. }
  40. return func(next echo.HandlerFunc) echo.HandlerFunc {
  41. return func(c echo.Context) error {
  42. if config.Skipper(c) {
  43. return next(c)
  44. }
  45. req := c.Request()
  46. if req.Method == echo.POST {
  47. m := config.Getter(c)
  48. if m != "" {
  49. req.Method = m
  50. }
  51. }
  52. return next(c)
  53. }
  54. }
  55. }
  56. // MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from
  57. // the request header.
  58. func MethodFromHeader(header string) MethodOverrideGetter {
  59. return func(c echo.Context) string {
  60. return c.Request().Header.Get(header)
  61. }
  62. }
  63. // MethodFromForm is a `MethodOverrideGetter` that gets overridden method from the
  64. // form parameter.
  65. func MethodFromForm(param string) MethodOverrideGetter {
  66. return func(c echo.Context) string {
  67. return c.FormValue(param)
  68. }
  69. }
  70. // MethodFromQuery is a `MethodOverrideGetter` that gets overridden method from
  71. // the query parameter.
  72. func MethodFromQuery(param string) MethodOverrideGetter {
  73. return func(c echo.Context) string {
  74. return c.QueryParam(param)
  75. }
  76. }