From 87caf478dfa4fbe0b4899d2c0751b81fcd747d6c Mon Sep 17 00:00:00 2001 From: James Griffin Date: Thu, 5 Mar 2026 16:20:23 -0400 Subject: [PATCH] Conform Go code to project conventions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Propagate context.Context through all exported store/service methods that perform I/O; use QueryContext/ExecContext/QueryRowContext throughout - Add package-level sentinel errors (ErrNotFound, ErrAlreadyCheckedIn, ErrNotCheckedIn) and replace nil,nil returns with explicit errors - Update handlers to use errors.Is() instead of nil checks, with correct HTTP status codes per error type - Fix SQLite datetime('now') → MySQL NOW() in volunteer, schedule, timeoff, and checkin stores - Refactor db.Migrate to execute schema statements individually (MySQL driver does not support multi-statement Exec) - Fix import grouping in handler files (stdlib, external, internal) Co-Authored-By: Claude Sonnet 4.6 --- cmd/server/main.go | 3 +- internal/auth/auth.go | 5 +-- internal/checkin/checkin.go | 51 +++++++++++++++------------ internal/checkin/handler.go | 19 +++++++--- internal/db/db.go | 11 ++++-- internal/db/schema.go | 28 +++++++-------- internal/notification/handler.go | 15 ++++---- internal/notification/notification.go | 31 ++++++++-------- internal/schedule/handler.go | 19 +++++----- internal/schedule/schedule.go | 37 ++++++++++--------- internal/timeoff/handler.go | 17 ++++----- internal/timeoff/timeoff.go | 27 +++++++------- internal/volunteer/handler.go | 29 +++++++-------- internal/volunteer/volunteer.go | 33 +++++++++-------- 14 files changed, 180 insertions(+), 145 deletions(-) diff --git a/cmd/server/main.go b/cmd/server/main.go index 0b91e5c..f39e148 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "log" "net/http" @@ -29,7 +30,7 @@ func main() { } defer database.Close() - if err := db.Migrate(database); err != nil { + if err := db.Migrate(context.Background(), database); err != nil { log.Fatalf("migrate database: %v", err) } diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 600e860..8b980fb 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -1,6 +1,7 @@ package auth import ( + "context" "database/sql" "errors" "fmt" @@ -27,10 +28,10 @@ func NewService(db *sql.DB, secret string) *Service { return &Service{db: db, jwtSecret: []byte(secret)} } -func (s *Service) Login(email, password string) (string, error) { +func (s *Service) Login(ctx context.Context, email, password string) (string, error) { var id int64 var hash, role string - err := s.db.QueryRow( + err := s.db.QueryRowContext(ctx, `SELECT id, password, role FROM volunteers WHERE email = ? AND active = 1`, email, ).Scan(&id, &hash, &role) diff --git a/internal/checkin/checkin.go b/internal/checkin/checkin.go index 411671b..c30c657 100644 --- a/internal/checkin/checkin.go +++ b/internal/checkin/checkin.go @@ -1,19 +1,26 @@ package checkin import ( + "context" "database/sql" "errors" "fmt" "time" ) +var ( + ErrNotFound = fmt.Errorf("check-in not found") + ErrAlreadyCheckedIn = fmt.Errorf("already checked in") + ErrNotCheckedIn = fmt.Errorf("not checked in") +) + type CheckIn struct { - ID int64 `json:"id"` - VolunteerID int64 `json:"volunteer_id"` - ScheduleID *int64 `json:"schedule_id,omitempty"` - CheckedInAt time.Time `json:"checked_in_at"` - CheckedOutAt *time.Time `json:"checked_out_at,omitempty"` - Notes string `json:"notes,omitempty"` + ID int64 `json:"id"` + VolunteerID int64 `json:"volunteer_id"` + ScheduleID *int64 `json:"schedule_id,omitempty"` + CheckedInAt time.Time `json:"checked_in_at"` + CheckedOutAt *time.Time `json:"checked_out_at,omitempty"` + Notes string `json:"notes,omitempty"` } type CheckInInput struct { @@ -33,17 +40,17 @@ func NewStore(db *sql.DB) *Store { return &Store{db: db} } -func (s *Store) CheckIn(volunteerID int64, in CheckInInput) (*CheckIn, error) { +func (s *Store) CheckIn(ctx context.Context, volunteerID int64, in CheckInInput) (*CheckIn, error) { // Ensure no active check-in exists var count int - s.db.QueryRow( + s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM checkins WHERE volunteer_id = ? AND checked_out_at IS NULL`, volunteerID, ).Scan(&count) if count > 0 { - return nil, fmt.Errorf("already checked in") + return nil, ErrAlreadyCheckedIn } - res, err := s.db.Exec( + res, err := s.db.ExecContext(ctx, `INSERT INTO checkins (volunteer_id, schedule_id, notes) VALUES (?, ?, ?)`, volunteerID, in.ScheduleID, in.Notes, ) @@ -51,43 +58,43 @@ func (s *Store) CheckIn(volunteerID int64, in CheckInInput) (*CheckIn, error) { return nil, fmt.Errorf("insert checkin: %w", err) } id, _ := res.LastInsertId() - return s.GetByID(id) + return s.GetByID(ctx, id) } -func (s *Store) CheckOut(volunteerID int64, in CheckOutInput) (*CheckIn, error) { +func (s *Store) CheckOut(ctx context.Context, volunteerID int64, in CheckOutInput) (*CheckIn, error) { var id int64 - err := s.db.QueryRow( + err := s.db.QueryRowContext(ctx, `SELECT id FROM checkins WHERE volunteer_id = ? AND checked_out_at IS NULL ORDER BY checked_in_at DESC LIMIT 1`, volunteerID, ).Scan(&id) if errors.Is(err, sql.ErrNoRows) { - return nil, fmt.Errorf("not checked in") + return nil, ErrNotCheckedIn } if err != nil { return nil, fmt.Errorf("find active checkin: %w", err) } - _, err = s.db.Exec( - `UPDATE checkins SET checked_out_at=datetime('now'), notes=COALESCE(NULLIF(?, ''), notes) WHERE id=?`, + _, err = s.db.ExecContext(ctx, + `UPDATE checkins SET checked_out_at=NOW(), notes=COALESCE(NULLIF(?, ''), notes) WHERE id=?`, in.Notes, id, ) if err != nil { return nil, fmt.Errorf("checkout: %w", err) } - return s.GetByID(id) + return s.GetByID(ctx, id) } -func (s *Store) GetByID(id int64) (*CheckIn, error) { +func (s *Store) GetByID(ctx context.Context, id int64) (*CheckIn, error) { ci := &CheckIn{} var checkedInAt string var checkedOutAt sql.NullString var scheduleID sql.NullInt64 var notes sql.NullString - err := s.db.QueryRow( + err := s.db.QueryRowContext(ctx, `SELECT id, volunteer_id, schedule_id, checked_in_at, checked_out_at, notes FROM checkins WHERE id = ?`, id, ).Scan(&ci.ID, &ci.VolunteerID, &scheduleID, &checkedInAt, &checkedOutAt, ¬es) if errors.Is(err, sql.ErrNoRows) { - return nil, nil + return nil, ErrNotFound } if err != nil { return nil, fmt.Errorf("get checkin: %w", err) @@ -106,7 +113,7 @@ func (s *Store) GetByID(id int64) (*CheckIn, error) { return ci, nil } -func (s *Store) History(volunteerID int64) ([]CheckIn, error) { +func (s *Store) History(ctx context.Context, volunteerID int64) ([]CheckIn, error) { query := `SELECT id, volunteer_id, schedule_id, checked_in_at, checked_out_at, notes FROM checkins` args := []any{} if volunteerID > 0 { @@ -115,7 +122,7 @@ func (s *Store) History(volunteerID int64) ([]CheckIn, error) { } query += ` ORDER BY checked_in_at DESC` - rows, err := s.db.Query(query, args...) + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, fmt.Errorf("list checkins: %w", err) } diff --git a/internal/checkin/handler.go b/internal/checkin/handler.go index 968ca3b..c4fa470 100644 --- a/internal/checkin/handler.go +++ b/internal/checkin/handler.go @@ -2,6 +2,7 @@ package checkin import ( "encoding/json" + "errors" "net/http" "git.unsupervised.ca/walkies/internal/respond" @@ -24,11 +25,15 @@ func (h *Handler) CheckIn(w http.ResponseWriter, r *http.Request) { respond.Error(w, http.StatusBadRequest, "invalid request body") return } - ci, err := h.store.CheckIn(claims.VolunteerID, in) - if err != nil { + ci, err := h.store.CheckIn(r.Context(), claims.VolunteerID, in) + if errors.Is(err, ErrAlreadyCheckedIn) { respond.Error(w, http.StatusConflict, err.Error()) return } + if err != nil { + respond.Error(w, http.StatusInternalServerError, "could not check in") + return + } respond.JSON(w, http.StatusCreated, ci) } @@ -37,11 +42,15 @@ func (h *Handler) CheckOut(w http.ResponseWriter, r *http.Request) { claims := middleware.ClaimsFromContext(r.Context()) var in CheckOutInput json.NewDecoder(r.Body).Decode(&in) - ci, err := h.store.CheckOut(claims.VolunteerID, in) - if err != nil { + ci, err := h.store.CheckOut(r.Context(), claims.VolunteerID, in) + if errors.Is(err, ErrNotCheckedIn) { respond.Error(w, http.StatusConflict, err.Error()) return } + if err != nil { + respond.Error(w, http.StatusInternalServerError, "could not check out") + return + } respond.JSON(w, http.StatusOK, ci) } @@ -52,7 +61,7 @@ func (h *Handler) History(w http.ResponseWriter, r *http.Request) { if claims.Role != "admin" { volunteerID = claims.VolunteerID } - history, err := h.store.History(volunteerID) + history, err := h.store.History(r.Context(), volunteerID) if err != nil { respond.Error(w, http.StatusInternalServerError, "could not get history") return diff --git a/internal/db/db.go b/internal/db/db.go index cf73d99..06c9a98 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -1,6 +1,7 @@ package db import ( + "context" "database/sql" "fmt" @@ -18,7 +19,11 @@ func Open(dsn string) (*sql.DB, error) { return db, nil } -func Migrate(db *sql.DB) error { - _, err := db.Exec(schema) - return err +func Migrate(ctx context.Context, db *sql.DB) error { + for _, stmt := range statements { + if _, err := db.ExecContext(ctx, stmt); err != nil { + return fmt.Errorf("migrate: %w", err) + } + } + return nil } diff --git a/internal/db/schema.go b/internal/db/schema.go index ac7e6bf..7cf814f 100644 --- a/internal/db/schema.go +++ b/internal/db/schema.go @@ -1,7 +1,7 @@ package db -const schema = ` -CREATE TABLE IF NOT EXISTS volunteers ( +var statements = []string{ + `CREATE TABLE IF NOT EXISTS volunteers ( id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255) NOT NULL, email VARCHAR(255) NOT NULL UNIQUE, @@ -10,9 +10,8 @@ CREATE TABLE IF NOT EXISTS volunteers ( active TINYINT NOT NULL DEFAULT 1, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - -CREATE TABLE IF NOT EXISTS schedules ( +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci`, + `CREATE TABLE IF NOT EXISTS schedules ( id INT AUTO_INCREMENT PRIMARY KEY, volunteer_id INT NOT NULL, title VARCHAR(255) NOT NULL, @@ -23,9 +22,8 @@ CREATE TABLE IF NOT EXISTS schedules ( updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, FOREIGN KEY (volunteer_id) REFERENCES volunteers(id) ON DELETE CASCADE, INDEX idx_volunteer_id (volunteer_id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - -CREATE TABLE IF NOT EXISTS time_off_requests ( +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci`, + `CREATE TABLE IF NOT EXISTS time_off_requests ( id INT AUTO_INCREMENT PRIMARY KEY, volunteer_id INT NOT NULL, starts_at DATETIME NOT NULL, @@ -40,9 +38,8 @@ CREATE TABLE IF NOT EXISTS time_off_requests ( FOREIGN KEY (reviewed_by) REFERENCES volunteers(id) ON DELETE SET NULL, INDEX idx_volunteer_id (volunteer_id), INDEX idx_status (status) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - -CREATE TABLE IF NOT EXISTS checkins ( +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci`, + `CREATE TABLE IF NOT EXISTS checkins ( id INT AUTO_INCREMENT PRIMARY KEY, volunteer_id INT NOT NULL, schedule_id INT, @@ -53,9 +50,8 @@ CREATE TABLE IF NOT EXISTS checkins ( FOREIGN KEY (schedule_id) REFERENCES schedules(id) ON DELETE SET NULL, INDEX idx_volunteer_id (volunteer_id), INDEX idx_schedule_id (schedule_id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; - -CREATE TABLE IF NOT EXISTS notifications ( +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci`, + `CREATE TABLE IF NOT EXISTS notifications ( id INT AUTO_INCREMENT PRIMARY KEY, volunteer_id INT NOT NULL, message TEXT NOT NULL, @@ -64,5 +60,5 @@ CREATE TABLE IF NOT EXISTS notifications ( FOREIGN KEY (volunteer_id) REFERENCES volunteers(id) ON DELETE CASCADE, INDEX idx_volunteer_id (volunteer_id), INDEX idx_read (read) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; -` +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci`, +} diff --git a/internal/notification/handler.go b/internal/notification/handler.go index f16942d..2223601 100644 --- a/internal/notification/handler.go +++ b/internal/notification/handler.go @@ -1,12 +1,13 @@ package notification import ( + "errors" "net/http" "strconv" - "github.com/go-chi/chi/v5" "git.unsupervised.ca/walkies/internal/respond" "git.unsupervised.ca/walkies/internal/server/middleware" + "github.com/go-chi/chi/v5" ) type Handler struct { @@ -20,7 +21,7 @@ func NewHandler(store *Store) *Handler { // GET /api/v1/notifications func (h *Handler) List(w http.ResponseWriter, r *http.Request) { claims := middleware.ClaimsFromContext(r.Context()) - notifications, err := h.store.ListForVolunteer(claims.VolunteerID) + notifications, err := h.store.ListForVolunteer(r.Context(), claims.VolunteerID) if err != nil { respond.Error(w, http.StatusInternalServerError, "could not list notifications") return @@ -39,14 +40,14 @@ func (h *Handler) MarkRead(w http.ResponseWriter, r *http.Request) { respond.Error(w, http.StatusBadRequest, "invalid id") return } - n, err := h.store.MarkRead(id, claims.VolunteerID) + n, err := h.store.MarkRead(r.Context(), id, claims.VolunteerID) + if errors.Is(err, ErrNotFound) { + respond.Error(w, http.StatusNotFound, "notification not found") + return + } if err != nil { respond.Error(w, http.StatusInternalServerError, "could not mark notification as read") return } - if n == nil { - respond.Error(w, http.StatusNotFound, "notification not found") - return - } respond.JSON(w, http.StatusOK, n) } diff --git a/internal/notification/notification.go b/internal/notification/notification.go index 62950e8..d805583 100644 --- a/internal/notification/notification.go +++ b/internal/notification/notification.go @@ -1,12 +1,15 @@ package notification import ( + "context" "database/sql" "errors" "fmt" "time" ) +var ErrNotFound = fmt.Errorf("notification not found") + type Notification struct { ID int64 `json:"id"` VolunteerID int64 `json:"volunteer_id"` @@ -23,8 +26,8 @@ func NewStore(db *sql.DB) *Store { return &Store{db: db} } -func (s *Store) Create(volunteerID int64, message string) (*Notification, error) { - res, err := s.db.Exec( +func (s *Store) Create(ctx context.Context, volunteerID int64, message string) (*Notification, error) { + res, err := s.db.ExecContext(ctx, `INSERT INTO notifications (volunteer_id, message) VALUES (?, ?)`, volunteerID, message, ) @@ -32,17 +35,17 @@ func (s *Store) Create(volunteerID int64, message string) (*Notification, error) return nil, fmt.Errorf("insert notification: %w", err) } id, _ := res.LastInsertId() - return s.GetByID(id) + return s.GetByID(ctx, id) } -func (s *Store) GetByID(id int64) (*Notification, error) { +func (s *Store) GetByID(ctx context.Context, id int64) (*Notification, error) { n := &Notification{} var createdAt string - err := s.db.QueryRow( + err := s.db.QueryRowContext(ctx, `SELECT id, volunteer_id, message, read, created_at FROM notifications WHERE id = ?`, id, ).Scan(&n.ID, &n.VolunteerID, &n.Message, &n.Read, &createdAt) if errors.Is(err, sql.ErrNoRows) { - return nil, nil + return nil, ErrNotFound } if err != nil { return nil, fmt.Errorf("get notification: %w", err) @@ -51,8 +54,8 @@ func (s *Store) GetByID(id int64) (*Notification, error) { return n, nil } -func (s *Store) ListForVolunteer(volunteerID int64) ([]Notification, error) { - rows, err := s.db.Query( +func (s *Store) ListForVolunteer(ctx context.Context, volunteerID int64) ([]Notification, error) { + rows, err := s.db.QueryContext(ctx, `SELECT id, volunteer_id, message, read, created_at FROM notifications WHERE volunteer_id = ? ORDER BY created_at DESC`, volunteerID, ) @@ -74,17 +77,17 @@ func (s *Store) ListForVolunteer(volunteerID int64) ([]Notification, error) { return notifications, rows.Err() } -func (s *Store) MarkRead(id, volunteerID int64) (*Notification, error) { - result, err := s.db.Exec( +func (s *Store) MarkRead(ctx context.Context, id, volunteerID int64) (*Notification, error) { + result, err := s.db.ExecContext(ctx, `UPDATE notifications SET read = 1 WHERE id = ? AND volunteer_id = ?`, id, volunteerID, ) if err != nil { return nil, fmt.Errorf("mark read: %w", err) } - rows, _ := result.RowsAffected() - if rows == 0 { - return nil, nil + affected, _ := result.RowsAffected() + if affected == 0 { + return nil, ErrNotFound } - return s.GetByID(id) + return s.GetByID(ctx, id) } diff --git a/internal/schedule/handler.go b/internal/schedule/handler.go index 5029727..0a0272d 100644 --- a/internal/schedule/handler.go +++ b/internal/schedule/handler.go @@ -2,12 +2,13 @@ package schedule import ( "encoding/json" + "errors" "net/http" "strconv" - "github.com/go-chi/chi/v5" "git.unsupervised.ca/walkies/internal/respond" "git.unsupervised.ca/walkies/internal/server/middleware" + "github.com/go-chi/chi/v5" ) type Handler struct { @@ -25,7 +26,7 @@ func (h *Handler) List(w http.ResponseWriter, r *http.Request) { if claims.Role != "admin" { volunteerID = claims.VolunteerID } - schedules, err := h.store.List(volunteerID) + schedules, err := h.store.List(r.Context(), volunteerID) if err != nil { respond.Error(w, http.StatusInternalServerError, "could not list schedules") return @@ -51,7 +52,7 @@ func (h *Handler) Create(w http.ResponseWriter, r *http.Request) { respond.Error(w, http.StatusBadRequest, "title, starts_at, and ends_at are required") return } - sc, err := h.store.Create(in) + sc, err := h.store.Create(r.Context(), in) if err != nil { respond.Error(w, http.StatusInternalServerError, "could not create schedule") return @@ -71,13 +72,13 @@ func (h *Handler) Update(w http.ResponseWriter, r *http.Request) { respond.Error(w, http.StatusBadRequest, "invalid request body") return } - sc, err := h.store.Update(id, in) - if err != nil { - respond.Error(w, http.StatusInternalServerError, "could not update schedule") + sc, err := h.store.Update(r.Context(), id, in) + if errors.Is(err, ErrNotFound) { + respond.Error(w, http.StatusNotFound, "schedule not found") return } - if sc == nil { - respond.Error(w, http.StatusNotFound, "schedule not found") + if err != nil { + respond.Error(w, http.StatusInternalServerError, "could not update schedule") return } respond.JSON(w, http.StatusOK, sc) @@ -90,7 +91,7 @@ func (h *Handler) Delete(w http.ResponseWriter, r *http.Request) { respond.Error(w, http.StatusBadRequest, "invalid id") return } - if err := h.store.Delete(id); err != nil { + if err := h.store.Delete(r.Context(), id); err != nil { respond.Error(w, http.StatusInternalServerError, "could not delete schedule") return } diff --git a/internal/schedule/schedule.go b/internal/schedule/schedule.go index c500721..4cba542 100644 --- a/internal/schedule/schedule.go +++ b/internal/schedule/schedule.go @@ -1,12 +1,15 @@ package schedule import ( + "context" "database/sql" "errors" "fmt" "time" ) +var ErrNotFound = fmt.Errorf("schedule not found") + type Schedule struct { ID int64 `json:"id"` VolunteerID int64 `json:"volunteer_id"` @@ -43,8 +46,8 @@ func NewStore(db *sql.DB) *Store { return &Store{db: db} } -func (s *Store) Create(in CreateInput) (*Schedule, error) { - res, err := s.db.Exec( +func (s *Store) Create(ctx context.Context, in CreateInput) (*Schedule, error) { + res, err := s.db.ExecContext(ctx, `INSERT INTO schedules (volunteer_id, title, starts_at, ends_at, notes) VALUES (?, ?, ?, ?, ?)`, in.VolunteerID, in.Title, in.StartsAt, in.EndsAt, in.Notes, ) @@ -52,18 +55,18 @@ func (s *Store) Create(in CreateInput) (*Schedule, error) { return nil, fmt.Errorf("insert schedule: %w", err) } id, _ := res.LastInsertId() - return s.GetByID(id) + return s.GetByID(ctx, id) } -func (s *Store) GetByID(id int64) (*Schedule, error) { +func (s *Store) GetByID(ctx context.Context, id int64) (*Schedule, error) { sc := &Schedule{} var startsAt, endsAt, createdAt, updatedAt string var notes sql.NullString - err := s.db.QueryRow( + err := s.db.QueryRowContext(ctx, `SELECT id, volunteer_id, title, starts_at, ends_at, notes, created_at, updated_at FROM schedules WHERE id = ?`, id, ).Scan(&sc.ID, &sc.VolunteerID, &sc.Title, &startsAt, &endsAt, ¬es, &createdAt, &updatedAt) if errors.Is(err, sql.ErrNoRows) { - return nil, nil + return nil, ErrNotFound } if err != nil { return nil, fmt.Errorf("get schedule: %w", err) @@ -78,7 +81,7 @@ func (s *Store) GetByID(id int64) (*Schedule, error) { return sc, nil } -func (s *Store) List(volunteerID int64) ([]Schedule, error) { +func (s *Store) List(ctx context.Context, volunteerID int64) ([]Schedule, error) { query := `SELECT id, volunteer_id, title, starts_at, ends_at, notes, created_at, updated_at FROM schedules` args := []any{} if volunteerID > 0 { @@ -87,7 +90,7 @@ func (s *Store) List(volunteerID int64) ([]Schedule, error) { } query += ` ORDER BY starts_at` - rows, err := s.db.Query(query, args...) + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, fmt.Errorf("list schedules: %w", err) } @@ -113,10 +116,10 @@ func (s *Store) List(volunteerID int64) ([]Schedule, error) { return schedules, rows.Err() } -func (s *Store) Update(id int64, in UpdateInput) (*Schedule, error) { - sc, err := s.GetByID(id) - if err != nil || sc == nil { - return sc, err +func (s *Store) Update(ctx context.Context, id int64, in UpdateInput) (*Schedule, error) { + sc, err := s.GetByID(ctx, id) + if err != nil { + return nil, err } title := sc.Title startsAt := sc.StartsAt.Format("2006-01-02 15:04:05") @@ -135,17 +138,17 @@ func (s *Store) Update(id int64, in UpdateInput) (*Schedule, error) { if in.Notes != nil { notes = *in.Notes } - _, err = s.db.Exec( - `UPDATE schedules SET title=?, starts_at=?, ends_at=?, notes=?, updated_at=datetime('now') WHERE id=?`, + _, err = s.db.ExecContext(ctx, + `UPDATE schedules SET title=?, starts_at=?, ends_at=?, notes=?, updated_at=NOW() WHERE id=?`, title, startsAt, endsAt, notes, id, ) if err != nil { return nil, fmt.Errorf("update schedule: %w", err) } - return s.GetByID(id) + return s.GetByID(ctx, id) } -func (s *Store) Delete(id int64) error { - _, err := s.db.Exec(`DELETE FROM schedules WHERE id = ?`, id) +func (s *Store) Delete(ctx context.Context, id int64) error { + _, err := s.db.ExecContext(ctx, `DELETE FROM schedules WHERE id = ?`, id) return err } diff --git a/internal/timeoff/handler.go b/internal/timeoff/handler.go index ed32f71..0d4365a 100644 --- a/internal/timeoff/handler.go +++ b/internal/timeoff/handler.go @@ -2,12 +2,13 @@ package timeoff import ( "encoding/json" + "errors" "net/http" "strconv" - "github.com/go-chi/chi/v5" "git.unsupervised.ca/walkies/internal/respond" "git.unsupervised.ca/walkies/internal/server/middleware" + "github.com/go-chi/chi/v5" ) type Handler struct { @@ -25,7 +26,7 @@ func (h *Handler) List(w http.ResponseWriter, r *http.Request) { if claims.Role != "admin" { volunteerID = claims.VolunteerID } - requests, err := h.store.List(volunteerID) + requests, err := h.store.List(r.Context(), volunteerID) if err != nil { respond.Error(w, http.StatusInternalServerError, "could not list time off requests") return @@ -48,7 +49,7 @@ func (h *Handler) Create(w http.ResponseWriter, r *http.Request) { respond.Error(w, http.StatusBadRequest, "starts_at and ends_at are required") return } - req, err := h.store.Create(claims.VolunteerID, in) + req, err := h.store.Create(r.Context(), claims.VolunteerID, in) if err != nil { respond.Error(w, http.StatusInternalServerError, "could not create time off request") return @@ -73,14 +74,14 @@ func (h *Handler) Review(w http.ResponseWriter, r *http.Request) { respond.Error(w, http.StatusBadRequest, "status must be 'approved' or 'rejected'") return } - req, err := h.store.Review(id, claims.VolunteerID, in.Status) + req, err := h.store.Review(r.Context(), id, claims.VolunteerID, in.Status) + if errors.Is(err, ErrNotFound) { + respond.Error(w, http.StatusNotFound, "time off request not found") + return + } if err != nil { respond.Error(w, http.StatusInternalServerError, "could not review time off request") return } - if req == nil { - respond.Error(w, http.StatusNotFound, "time off request not found") - return - } respond.JSON(w, http.StatusOK, req) } diff --git a/internal/timeoff/timeoff.go b/internal/timeoff/timeoff.go index f0eac10..2ddc4f2 100644 --- a/internal/timeoff/timeoff.go +++ b/internal/timeoff/timeoff.go @@ -1,12 +1,15 @@ package timeoff import ( + "context" "database/sql" "errors" "fmt" "time" ) +var ErrNotFound = fmt.Errorf("time off request not found") + type Request struct { ID int64 `json:"id"` VolunteerID int64 `json:"volunteer_id"` @@ -38,8 +41,8 @@ func NewStore(db *sql.DB) *Store { return &Store{db: db} } -func (s *Store) Create(volunteerID int64, in CreateInput) (*Request, error) { - res, err := s.db.Exec( +func (s *Store) Create(ctx context.Context, volunteerID int64, in CreateInput) (*Request, error) { + res, err := s.db.ExecContext(ctx, `INSERT INTO time_off_requests (volunteer_id, starts_at, ends_at, reason) VALUES (?, ?, ?, ?)`, volunteerID, in.StartsAt, in.EndsAt, in.Reason, ) @@ -47,22 +50,22 @@ func (s *Store) Create(volunteerID int64, in CreateInput) (*Request, error) { return nil, fmt.Errorf("insert time off request: %w", err) } id, _ := res.LastInsertId() - return s.GetByID(id) + return s.GetByID(ctx, id) } -func (s *Store) GetByID(id int64) (*Request, error) { +func (s *Store) GetByID(ctx context.Context, id int64) (*Request, error) { req := &Request{} var startsAt, endsAt, createdAt, updatedAt string var reason sql.NullString var reviewedBy sql.NullInt64 var reviewedAt sql.NullString - err := s.db.QueryRow( + err := s.db.QueryRowContext(ctx, `SELECT id, volunteer_id, starts_at, ends_at, reason, status, reviewed_by, reviewed_at, created_at, updated_at FROM time_off_requests WHERE id = ?`, id, ).Scan(&req.ID, &req.VolunteerID, &startsAt, &endsAt, &reason, &req.Status, &reviewedBy, &reviewedAt, &createdAt, &updatedAt) if errors.Is(err, sql.ErrNoRows) { - return nil, nil + return nil, ErrNotFound } if err != nil { return nil, fmt.Errorf("get time off request: %w", err) @@ -84,7 +87,7 @@ func (s *Store) GetByID(id int64) (*Request, error) { return req, nil } -func (s *Store) List(volunteerID int64) ([]Request, error) { +func (s *Store) List(ctx context.Context, volunteerID int64) ([]Request, error) { query := `SELECT id, volunteer_id, starts_at, ends_at, reason, status, reviewed_by, reviewed_at, created_at, updated_at FROM time_off_requests` args := []any{} if volunteerID > 0 { @@ -93,7 +96,7 @@ func (s *Store) List(volunteerID int64) ([]Request, error) { } query += ` ORDER BY starts_at DESC` - rows, err := s.db.Query(query, args...) + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, fmt.Errorf("list time off requests: %w", err) } @@ -128,13 +131,13 @@ func (s *Store) List(volunteerID int64) ([]Request, error) { return requests, rows.Err() } -func (s *Store) Review(id, reviewerID int64, status string) (*Request, error) { - _, err := s.db.Exec( - `UPDATE time_off_requests SET status=?, reviewed_by=?, reviewed_at=datetime('now'), updated_at=datetime('now') WHERE id=?`, +func (s *Store) Review(ctx context.Context, id, reviewerID int64, status string) (*Request, error) { + _, err := s.db.ExecContext(ctx, + `UPDATE time_off_requests SET status=?, reviewed_by=?, reviewed_at=NOW(), updated_at=NOW() WHERE id=?`, status, reviewerID, id, ) if err != nil { return nil, fmt.Errorf("review time off request: %w", err) } - return s.GetByID(id) + return s.GetByID(ctx, id) } diff --git a/internal/volunteer/handler.go b/internal/volunteer/handler.go index 2db0125..9533f91 100644 --- a/internal/volunteer/handler.go +++ b/internal/volunteer/handler.go @@ -2,12 +2,13 @@ package volunteer import ( "encoding/json" + "errors" "net/http" "strconv" - "github.com/go-chi/chi/v5" "git.unsupervised.ca/walkies/internal/auth" "git.unsupervised.ca/walkies/internal/respond" + "github.com/go-chi/chi/v5" ) type Handler struct { @@ -38,7 +39,7 @@ func (h *Handler) Register(w http.ResponseWriter, r *http.Request) { respond.Error(w, http.StatusInternalServerError, "could not hash password") return } - v, err := h.store.Create(in.Name, in.Email, hash, in.Role) + v, err := h.store.Create(r.Context(), in.Name, in.Email, hash, in.Role) if err != nil { respond.Error(w, http.StatusConflict, "email already in use") return @@ -56,7 +57,7 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { respond.Error(w, http.StatusBadRequest, "invalid request body") return } - token, err := h.authSvc.Login(body.Email, body.Password) + token, err := h.authSvc.Login(r.Context(), body.Email, body.Password) if err != nil { respond.Error(w, http.StatusUnauthorized, "invalid credentials") return @@ -66,7 +67,7 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { // GET /api/v1/volunteers func (h *Handler) List(w http.ResponseWriter, r *http.Request) { - volunteers, err := h.store.List(true) + volunteers, err := h.store.List(r.Context(), true) if err != nil { respond.Error(w, http.StatusInternalServerError, "could not list volunteers") return @@ -84,13 +85,13 @@ func (h *Handler) Get(w http.ResponseWriter, r *http.Request) { respond.Error(w, http.StatusBadRequest, "invalid id") return } - v, err := h.store.GetByID(id) - if err != nil { - respond.Error(w, http.StatusInternalServerError, "could not get volunteer") + v, err := h.store.GetByID(r.Context(), id) + if errors.Is(err, ErrNotFound) { + respond.Error(w, http.StatusNotFound, "volunteer not found") return } - if v == nil { - respond.Error(w, http.StatusNotFound, "volunteer not found") + if err != nil { + respond.Error(w, http.StatusInternalServerError, "could not get volunteer") return } respond.JSON(w, http.StatusOK, v) @@ -108,14 +109,14 @@ func (h *Handler) Update(w http.ResponseWriter, r *http.Request) { respond.Error(w, http.StatusBadRequest, "invalid request body") return } - v, err := h.store.Update(id, in) + v, err := h.store.Update(r.Context(), id, in) + if errors.Is(err, ErrNotFound) { + respond.Error(w, http.StatusNotFound, "volunteer not found") + return + } if err != nil { respond.Error(w, http.StatusInternalServerError, "could not update volunteer") return } - if v == nil { - respond.Error(w, http.StatusNotFound, "volunteer not found") - return - } respond.JSON(w, http.StatusOK, v) } diff --git a/internal/volunteer/volunteer.go b/internal/volunteer/volunteer.go index 023ab40..8440929 100644 --- a/internal/volunteer/volunteer.go +++ b/internal/volunteer/volunteer.go @@ -1,12 +1,15 @@ package volunteer import ( + "context" "database/sql" "errors" "fmt" "time" ) +var ErrNotFound = fmt.Errorf("volunteer not found") + type Volunteer struct { ID int64 `json:"id"` Name string `json:"name"` @@ -39,8 +42,8 @@ func NewStore(db *sql.DB) *Store { return &Store{db: db} } -func (s *Store) Create(name, email, hashedPassword, role string) (*Volunteer, error) { - res, err := s.db.Exec( +func (s *Store) Create(ctx context.Context, name, email, hashedPassword, role string) (*Volunteer, error) { + res, err := s.db.ExecContext(ctx, `INSERT INTO volunteers (name, email, password, role) VALUES (?, ?, ?, ?)`, name, email, hashedPassword, role, ) @@ -48,17 +51,17 @@ func (s *Store) Create(name, email, hashedPassword, role string) (*Volunteer, er return nil, fmt.Errorf("insert volunteer: %w", err) } id, _ := res.LastInsertId() - return s.GetByID(id) + return s.GetByID(ctx, id) } -func (s *Store) GetByID(id int64) (*Volunteer, error) { +func (s *Store) GetByID(ctx context.Context, id int64) (*Volunteer, error) { v := &Volunteer{} var createdAt, updatedAt string - err := s.db.QueryRow( + err := s.db.QueryRowContext(ctx, `SELECT id, name, email, role, active, created_at, updated_at FROM volunteers WHERE id = ?`, id, ).Scan(&v.ID, &v.Name, &v.Email, &v.Role, &v.Active, &createdAt, &updatedAt) if errors.Is(err, sql.ErrNoRows) { - return nil, nil + return nil, ErrNotFound } if err != nil { return nil, fmt.Errorf("get volunteer: %w", err) @@ -68,14 +71,14 @@ func (s *Store) GetByID(id int64) (*Volunteer, error) { return v, nil } -func (s *Store) List(activeOnly bool) ([]Volunteer, error) { +func (s *Store) List(ctx context.Context, activeOnly bool) ([]Volunteer, error) { query := `SELECT id, name, email, role, active, created_at, updated_at FROM volunteers` if activeOnly { query += ` WHERE active = 1` } query += ` ORDER BY name` - rows, err := s.db.Query(query) + rows, err := s.db.QueryContext(ctx, query) if err != nil { return nil, fmt.Errorf("list volunteers: %w", err) } @@ -95,10 +98,10 @@ func (s *Store) List(activeOnly bool) ([]Volunteer, error) { return volunteers, rows.Err() } -func (s *Store) Update(id int64, in UpdateInput) (*Volunteer, error) { - v, err := s.GetByID(id) - if err != nil || v == nil { - return v, err +func (s *Store) Update(ctx context.Context, id int64, in UpdateInput) (*Volunteer, error) { + v, err := s.GetByID(ctx, id) + if err != nil { + return nil, err } if in.Name != nil { v.Name = *in.Name @@ -116,12 +119,12 @@ func (s *Store) Update(id int64, in UpdateInput) (*Volunteer, error) { if v.Active { activeInt = 1 } - _, err = s.db.Exec( - `UPDATE volunteers SET name=?, email=?, role=?, active=?, updated_at=datetime('now') WHERE id=?`, + _, err = s.db.ExecContext(ctx, + `UPDATE volunteers SET name=?, email=?, role=?, active=?, updated_at=NOW() WHERE id=?`, v.Name, v.Email, v.Role, activeInt, id, ) if err != nil { return nil, fmt.Errorf("update volunteer: %w", err) } - return s.GetByID(id) + return s.GetByID(ctx, id) }