diff --git a/internal/photos/api/account.go b/internal/photos/api/account.go index 332e266..72c4c9a 100644 --- a/internal/photos/api/account.go +++ b/internal/photos/api/account.go @@ -70,9 +70,13 @@ func (s *Service) Login(c *gin.Context) { } func (s *Service) Logout(c *gin.Context) { - var sess *models.Session = c.MustGet("session").(*models.Session) - if err := s.DB.Delete(sess).Error; err != nil { - s.Error(c, http.StatusInternalServerError, err) + res := s.DB.Where("token = ?", c.GetString("token")).Delete(&models.Session{}) + if res.Error != nil { + s.Error(c, http.StatusInternalServerError, res.Error) + return + } + if res.RowsAffected == 0 { + s.Error(c, http.StatusNotFound, ErrSessionNotFound) return } c.JSON(http.StatusOK, gin.H{ diff --git a/internal/photos/api/main.go b/internal/photos/api/main.go index e9f3359..5ee5ae8 100644 --- a/internal/photos/api/main.go +++ b/internal/photos/api/main.go @@ -53,7 +53,7 @@ func (s *Service) SetupRoutes() { ac := s.Gin.Group("/account") ac.POST("/signup", s.Signup) ac.POST("/login", s.Login) - ac.GET("/logout", s.RequireSession, s.Logout) + ac.GET("/logout", s.RequireAuthToken, s.Logout) s.Gin.NoRoute(func(c *gin.Context) { s.Error(c, http.StatusNotFound, ErrReqNotFound) diff --git a/internal/photos/api/session.go b/internal/photos/api/session.go index f3d0e5b..3fc2e6e 100644 --- a/internal/photos/api/session.go +++ b/internal/photos/api/session.go @@ -10,7 +10,7 @@ import ( "gorm.io/gorm" ) -func (s *Service) RequireSession(c *gin.Context) { +func (s *Service) RequireAuthToken(c *gin.Context) { token := c.GetHeader("Authorization") if !strings.HasPrefix(token, "Private ") { s.Error(c, http.StatusForbidden, ErrTokenMissing) @@ -18,9 +18,16 @@ func (s *Service) RequireSession(c *gin.Context) { } token = token[8:] c.Set("token", token) +} + +func (s *Service) RequireSession(c *gin.Context) { + s.RequireAuthToken(c) + if c.IsAborted() { + return + } sess := &models.Session{} - if err := s.DB.Preload("Account").Where("token = ?", token).First(sess).Error; err != nil { + if err := s.DB.Preload("Account").Where("token = ?", c.GetString("token")).First(sess).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { s.Error(c, http.StatusForbidden, ErrSessionNotFound) } else {