mirror of
https://codeberg.org/forgejo/forgejo.git
synced 2025-05-31 11:52:10 +00:00
Refactor auth package (#17962)
This commit is contained in:
parent
e61b390d54
commit
de8e3948a5
87 changed files with 2880 additions and 2770 deletions
22
models/auth/main_test.go
Normal file
22
models/auth/main_test.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
unittest.MainTest(m, filepath.Join("..", ".."),
|
||||
"login_source.yml",
|
||||
"oauth2_application.yml",
|
||||
"oauth2_authorization_code.yml",
|
||||
"oauth2_grant.yml",
|
||||
"u2f_registration.yml",
|
||||
)
|
||||
}
|
564
models/auth/oauth2.go
Normal file
564
models/auth/oauth2.go
Normal file
|
@ -0,0 +1,564 @@
|
|||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/modules/secret"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
|
||||
uuid "github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"xorm.io/xorm"
|
||||
)
|
||||
|
||||
// OAuth2Application represents an OAuth2 client (RFC 6749)
|
||||
type OAuth2Application struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
UID int64 `xorm:"INDEX"`
|
||||
Name string
|
||||
ClientID string `xorm:"unique"`
|
||||
ClientSecret string
|
||||
RedirectURIs []string `xorm:"redirect_uris JSON TEXT"`
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
|
||||
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(OAuth2Application))
|
||||
db.RegisterModel(new(OAuth2AuthorizationCode))
|
||||
db.RegisterModel(new(OAuth2Grant))
|
||||
}
|
||||
|
||||
// TableName sets the table name to `oauth2_application`
|
||||
func (app *OAuth2Application) TableName() string {
|
||||
return "oauth2_application"
|
||||
}
|
||||
|
||||
// PrimaryRedirectURI returns the first redirect uri or an empty string if empty
|
||||
func (app *OAuth2Application) PrimaryRedirectURI() string {
|
||||
if len(app.RedirectURIs) == 0 {
|
||||
return ""
|
||||
}
|
||||
return app.RedirectURIs[0]
|
||||
}
|
||||
|
||||
// ContainsRedirectURI checks if redirectURI is allowed for app
|
||||
func (app *OAuth2Application) ContainsRedirectURI(redirectURI string) bool {
|
||||
return util.IsStringInSlice(redirectURI, app.RedirectURIs, true)
|
||||
}
|
||||
|
||||
// GenerateClientSecret will generate the client secret and returns the plaintext and saves the hash at the database
|
||||
func (app *OAuth2Application) GenerateClientSecret() (string, error) {
|
||||
clientSecret, err := secret.New()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(clientSecret), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
app.ClientSecret = string(hashedSecret)
|
||||
if _, err := db.GetEngine(db.DefaultContext).ID(app.ID).Cols("client_secret").Update(app); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return clientSecret, nil
|
||||
}
|
||||
|
||||
// ValidateClientSecret validates the given secret by the hash saved in database
|
||||
func (app *OAuth2Application) ValidateClientSecret(secret []byte) bool {
|
||||
return bcrypt.CompareHashAndPassword([]byte(app.ClientSecret), secret) == nil
|
||||
}
|
||||
|
||||
// GetGrantByUserID returns a OAuth2Grant by its user and application ID
|
||||
func (app *OAuth2Application) GetGrantByUserID(userID int64) (*OAuth2Grant, error) {
|
||||
return app.getGrantByUserID(db.GetEngine(db.DefaultContext), userID)
|
||||
}
|
||||
|
||||
func (app *OAuth2Application) getGrantByUserID(e db.Engine, userID int64) (grant *OAuth2Grant, err error) {
|
||||
grant = new(OAuth2Grant)
|
||||
if has, err := e.Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, nil
|
||||
}
|
||||
return grant, nil
|
||||
}
|
||||
|
||||
// CreateGrant generates a grant for an user
|
||||
func (app *OAuth2Application) CreateGrant(userID int64, scope string) (*OAuth2Grant, error) {
|
||||
return app.createGrant(db.GetEngine(db.DefaultContext), userID, scope)
|
||||
}
|
||||
|
||||
func (app *OAuth2Application) createGrant(e db.Engine, userID int64, scope string) (*OAuth2Grant, error) {
|
||||
grant := &OAuth2Grant{
|
||||
ApplicationID: app.ID,
|
||||
UserID: userID,
|
||||
Scope: scope,
|
||||
}
|
||||
_, err := e.Insert(grant)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return grant, nil
|
||||
}
|
||||
|
||||
// GetOAuth2ApplicationByClientID returns the oauth2 application with the given client_id. Returns an error if not found.
|
||||
func GetOAuth2ApplicationByClientID(clientID string) (app *OAuth2Application, err error) {
|
||||
return getOAuth2ApplicationByClientID(db.GetEngine(db.DefaultContext), clientID)
|
||||
}
|
||||
|
||||
func getOAuth2ApplicationByClientID(e db.Engine, clientID string) (app *OAuth2Application, err error) {
|
||||
app = new(OAuth2Application)
|
||||
has, err := e.Where("client_id = ?", clientID).Get(app)
|
||||
if !has {
|
||||
return nil, ErrOAuthClientIDInvalid{ClientID: clientID}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetOAuth2ApplicationByID returns the oauth2 application with the given id. Returns an error if not found.
|
||||
func GetOAuth2ApplicationByID(id int64) (app *OAuth2Application, err error) {
|
||||
return getOAuth2ApplicationByID(db.GetEngine(db.DefaultContext), id)
|
||||
}
|
||||
|
||||
func getOAuth2ApplicationByID(e db.Engine, id int64) (app *OAuth2Application, err error) {
|
||||
app = new(OAuth2Application)
|
||||
has, err := e.ID(id).Get(app)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !has {
|
||||
return nil, ErrOAuthApplicationNotFound{ID: id}
|
||||
}
|
||||
return app, nil
|
||||
}
|
||||
|
||||
// GetOAuth2ApplicationsByUserID returns all oauth2 applications owned by the user
|
||||
func GetOAuth2ApplicationsByUserID(userID int64) (apps []*OAuth2Application, err error) {
|
||||
return getOAuth2ApplicationsByUserID(db.GetEngine(db.DefaultContext), userID)
|
||||
}
|
||||
|
||||
func getOAuth2ApplicationsByUserID(e db.Engine, userID int64) (apps []*OAuth2Application, err error) {
|
||||
apps = make([]*OAuth2Application, 0)
|
||||
err = e.Where("uid = ?", userID).Find(&apps)
|
||||
return
|
||||
}
|
||||
|
||||
// CreateOAuth2ApplicationOptions holds options to create an oauth2 application
|
||||
type CreateOAuth2ApplicationOptions struct {
|
||||
Name string
|
||||
UserID int64
|
||||
RedirectURIs []string
|
||||
}
|
||||
|
||||
// CreateOAuth2Application inserts a new oauth2 application
|
||||
func CreateOAuth2Application(opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
|
||||
return createOAuth2Application(db.GetEngine(db.DefaultContext), opts)
|
||||
}
|
||||
|
||||
func createOAuth2Application(e db.Engine, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
|
||||
clientID := uuid.New().String()
|
||||
app := &OAuth2Application{
|
||||
UID: opts.UserID,
|
||||
Name: opts.Name,
|
||||
ClientID: clientID,
|
||||
RedirectURIs: opts.RedirectURIs,
|
||||
}
|
||||
if _, err := e.Insert(app); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return app, nil
|
||||
}
|
||||
|
||||
// UpdateOAuth2ApplicationOptions holds options to update an oauth2 application
|
||||
type UpdateOAuth2ApplicationOptions struct {
|
||||
ID int64
|
||||
Name string
|
||||
UserID int64
|
||||
RedirectURIs []string
|
||||
}
|
||||
|
||||
// UpdateOAuth2Application updates an oauth2 application
|
||||
func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Application, error) {
|
||||
ctx, committer, err := db.TxContext()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer committer.Close()
|
||||
sess := db.GetEngine(ctx)
|
||||
|
||||
app, err := getOAuth2ApplicationByID(sess, opts.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if app.UID != opts.UserID {
|
||||
return nil, fmt.Errorf("UID mismatch")
|
||||
}
|
||||
|
||||
app.Name = opts.Name
|
||||
app.RedirectURIs = opts.RedirectURIs
|
||||
|
||||
if err = updateOAuth2Application(sess, app); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
app.ClientSecret = ""
|
||||
|
||||
return app, committer.Commit()
|
||||
}
|
||||
|
||||
func updateOAuth2Application(e db.Engine, app *OAuth2Application) error {
|
||||
if _, err := e.ID(app.ID).Update(app); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func deleteOAuth2Application(sess db.Engine, id, userid int64) error {
|
||||
if deleted, err := sess.Delete(&OAuth2Application{ID: id, UID: userid}); err != nil {
|
||||
return err
|
||||
} else if deleted == 0 {
|
||||
return ErrOAuthApplicationNotFound{ID: id}
|
||||
}
|
||||
codes := make([]*OAuth2AuthorizationCode, 0)
|
||||
// delete correlating auth codes
|
||||
if err := sess.Join("INNER", "oauth2_grant",
|
||||
"oauth2_authorization_code.grant_id = oauth2_grant.id AND oauth2_grant.application_id = ?", id).Find(&codes); err != nil {
|
||||
return err
|
||||
}
|
||||
codeIDs := make([]int64, 0)
|
||||
for _, grant := range codes {
|
||||
codeIDs = append(codeIDs, grant.ID)
|
||||
}
|
||||
|
||||
if _, err := sess.In("id", codeIDs).Delete(new(OAuth2AuthorizationCode)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := sess.Where("application_id = ?", id).Delete(new(OAuth2Grant)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteOAuth2Application deletes the application with the given id and the grants and auth codes related to it. It checks if the userid was the creator of the app.
|
||||
func DeleteOAuth2Application(id, userid int64) error {
|
||||
ctx, committer, err := db.TxContext()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer committer.Close()
|
||||
if err := deleteOAuth2Application(db.GetEngine(ctx), id, userid); err != nil {
|
||||
return err
|
||||
}
|
||||
return committer.Commit()
|
||||
}
|
||||
|
||||
// ListOAuth2Applications returns a list of oauth2 applications belongs to given user.
|
||||
func ListOAuth2Applications(uid int64, listOptions db.ListOptions) ([]*OAuth2Application, int64, error) {
|
||||
sess := db.GetEngine(db.DefaultContext).
|
||||
Where("uid=?", uid).
|
||||
Desc("id")
|
||||
|
||||
if listOptions.Page != 0 {
|
||||
sess = db.SetSessionPagination(sess, &listOptions)
|
||||
|
||||
apps := make([]*OAuth2Application, 0, listOptions.PageSize)
|
||||
total, err := sess.FindAndCount(&apps)
|
||||
return apps, total, err
|
||||
}
|
||||
|
||||
apps := make([]*OAuth2Application, 0, 5)
|
||||
total, err := sess.FindAndCount(&apps)
|
||||
return apps, total, err
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////
|
||||
|
||||
// OAuth2AuthorizationCode is a code to obtain an access token in combination with the client secret once. It has a limited lifetime.
|
||||
type OAuth2AuthorizationCode struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
Grant *OAuth2Grant `xorm:"-"`
|
||||
GrantID int64
|
||||
Code string `xorm:"INDEX unique"`
|
||||
CodeChallenge string
|
||||
CodeChallengeMethod string
|
||||
RedirectURI string
|
||||
ValidUntil timeutil.TimeStamp `xorm:"index"`
|
||||
}
|
||||
|
||||
// TableName sets the table name to `oauth2_authorization_code`
|
||||
func (code *OAuth2AuthorizationCode) TableName() string {
|
||||
return "oauth2_authorization_code"
|
||||
}
|
||||
|
||||
// GenerateRedirectURI generates a redirect URI for a successful authorization request. State will be used if not empty.
|
||||
func (code *OAuth2AuthorizationCode) GenerateRedirectURI(state string) (redirect *url.URL, err error) {
|
||||
if redirect, err = url.Parse(code.RedirectURI); err != nil {
|
||||
return
|
||||
}
|
||||
q := redirect.Query()
|
||||
if state != "" {
|
||||
q.Set("state", state)
|
||||
}
|
||||
q.Set("code", code.Code)
|
||||
redirect.RawQuery = q.Encode()
|
||||
return
|
||||
}
|
||||
|
||||
// Invalidate deletes the auth code from the database to invalidate this code
|
||||
func (code *OAuth2AuthorizationCode) Invalidate() error {
|
||||
return code.invalidate(db.GetEngine(db.DefaultContext))
|
||||
}
|
||||
|
||||
func (code *OAuth2AuthorizationCode) invalidate(e db.Engine) error {
|
||||
_, err := e.Delete(code)
|
||||
return err
|
||||
}
|
||||
|
||||
// ValidateCodeChallenge validates the given verifier against the saved code challenge. This is part of the PKCE implementation.
|
||||
func (code *OAuth2AuthorizationCode) ValidateCodeChallenge(verifier string) bool {
|
||||
return code.validateCodeChallenge(verifier)
|
||||
}
|
||||
|
||||
func (code *OAuth2AuthorizationCode) validateCodeChallenge(verifier string) bool {
|
||||
switch code.CodeChallengeMethod {
|
||||
case "S256":
|
||||
// base64url(SHA256(verifier)) see https://tools.ietf.org/html/rfc7636#section-4.6
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
hashedVerifier := base64.RawURLEncoding.EncodeToString(h[:])
|
||||
return hashedVerifier == code.CodeChallenge
|
||||
case "plain":
|
||||
return verifier == code.CodeChallenge
|
||||
case "":
|
||||
return true
|
||||
default:
|
||||
// unsupported method -> return false
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// GetOAuth2AuthorizationByCode returns an authorization by its code
|
||||
func GetOAuth2AuthorizationByCode(code string) (*OAuth2AuthorizationCode, error) {
|
||||
return getOAuth2AuthorizationByCode(db.GetEngine(db.DefaultContext), code)
|
||||
}
|
||||
|
||||
func getOAuth2AuthorizationByCode(e db.Engine, code string) (auth *OAuth2AuthorizationCode, err error) {
|
||||
auth = new(OAuth2AuthorizationCode)
|
||||
if has, err := e.Where("code = ?", code).Get(auth); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, nil
|
||||
}
|
||||
auth.Grant = new(OAuth2Grant)
|
||||
if has, err := e.ID(auth.GrantID).Get(auth.Grant); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, nil
|
||||
}
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////
|
||||
|
||||
// OAuth2Grant represents the permission of an user for a specific application to access resources
|
||||
type OAuth2Grant struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
UserID int64 `xorm:"INDEX unique(user_application)"`
|
||||
Application *OAuth2Application `xorm:"-"`
|
||||
ApplicationID int64 `xorm:"INDEX unique(user_application)"`
|
||||
Counter int64 `xorm:"NOT NULL DEFAULT 1"`
|
||||
Scope string `xorm:"TEXT"`
|
||||
Nonce string `xorm:"TEXT"`
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"created"`
|
||||
UpdatedUnix timeutil.TimeStamp `xorm:"updated"`
|
||||
}
|
||||
|
||||
// TableName sets the table name to `oauth2_grant`
|
||||
func (grant *OAuth2Grant) TableName() string {
|
||||
return "oauth2_grant"
|
||||
}
|
||||
|
||||
// GenerateNewAuthorizationCode generates a new authorization code for a grant and saves it to the database
|
||||
func (grant *OAuth2Grant) GenerateNewAuthorizationCode(redirectURI, codeChallenge, codeChallengeMethod string) (*OAuth2AuthorizationCode, error) {
|
||||
return grant.generateNewAuthorizationCode(db.GetEngine(db.DefaultContext), redirectURI, codeChallenge, codeChallengeMethod)
|
||||
}
|
||||
|
||||
func (grant *OAuth2Grant) generateNewAuthorizationCode(e db.Engine, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) {
|
||||
var codeSecret string
|
||||
if codeSecret, err = secret.New(); err != nil {
|
||||
return &OAuth2AuthorizationCode{}, err
|
||||
}
|
||||
code = &OAuth2AuthorizationCode{
|
||||
Grant: grant,
|
||||
GrantID: grant.ID,
|
||||
RedirectURI: redirectURI,
|
||||
Code: codeSecret,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
}
|
||||
if _, err := e.Insert(code); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// IncreaseCounter increases the counter and updates the grant
|
||||
func (grant *OAuth2Grant) IncreaseCounter() error {
|
||||
return grant.increaseCount(db.GetEngine(db.DefaultContext))
|
||||
}
|
||||
|
||||
func (grant *OAuth2Grant) increaseCount(e db.Engine) error {
|
||||
_, err := e.ID(grant.ID).Incr("counter").Update(new(OAuth2Grant))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
updatedGrant, err := getOAuth2GrantByID(e, grant.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
grant.Counter = updatedGrant.Counter
|
||||
return nil
|
||||
}
|
||||
|
||||
// ScopeContains returns true if the grant scope contains the specified scope
|
||||
func (grant *OAuth2Grant) ScopeContains(scope string) bool {
|
||||
for _, currentScope := range strings.Split(grant.Scope, " ") {
|
||||
if scope == currentScope {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SetNonce updates the current nonce value of a grant
|
||||
func (grant *OAuth2Grant) SetNonce(nonce string) error {
|
||||
return grant.setNonce(db.GetEngine(db.DefaultContext), nonce)
|
||||
}
|
||||
|
||||
func (grant *OAuth2Grant) setNonce(e db.Engine, nonce string) error {
|
||||
grant.Nonce = nonce
|
||||
_, err := e.ID(grant.ID).Cols("nonce").Update(grant)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetOAuth2GrantByID returns the grant with the given ID
|
||||
func GetOAuth2GrantByID(id int64) (*OAuth2Grant, error) {
|
||||
return getOAuth2GrantByID(db.GetEngine(db.DefaultContext), id)
|
||||
}
|
||||
|
||||
func getOAuth2GrantByID(e db.Engine, id int64) (grant *OAuth2Grant, err error) {
|
||||
grant = new(OAuth2Grant)
|
||||
if has, err := e.ID(id).Get(grant); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetOAuth2GrantsByUserID lists all grants of a certain user
|
||||
func GetOAuth2GrantsByUserID(uid int64) ([]*OAuth2Grant, error) {
|
||||
return getOAuth2GrantsByUserID(db.GetEngine(db.DefaultContext), uid)
|
||||
}
|
||||
|
||||
func getOAuth2GrantsByUserID(e db.Engine, uid int64) ([]*OAuth2Grant, error) {
|
||||
type joinedOAuth2Grant struct {
|
||||
Grant *OAuth2Grant `xorm:"extends"`
|
||||
Application *OAuth2Application `xorm:"extends"`
|
||||
}
|
||||
var results *xorm.Rows
|
||||
var err error
|
||||
if results, err = e.
|
||||
Table("oauth2_grant").
|
||||
Where("user_id = ?", uid).
|
||||
Join("INNER", "oauth2_application", "application_id = oauth2_application.id").
|
||||
Rows(new(joinedOAuth2Grant)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer results.Close()
|
||||
grants := make([]*OAuth2Grant, 0)
|
||||
for results.Next() {
|
||||
joinedGrant := new(joinedOAuth2Grant)
|
||||
if err := results.Scan(joinedGrant); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
joinedGrant.Grant.Application = joinedGrant.Application
|
||||
grants = append(grants, joinedGrant.Grant)
|
||||
}
|
||||
return grants, nil
|
||||
}
|
||||
|
||||
// RevokeOAuth2Grant deletes the grant with grantID and userID
|
||||
func RevokeOAuth2Grant(grantID, userID int64) error {
|
||||
return revokeOAuth2Grant(db.GetEngine(db.DefaultContext), grantID, userID)
|
||||
}
|
||||
|
||||
func revokeOAuth2Grant(e db.Engine, grantID, userID int64) error {
|
||||
_, err := e.Delete(&OAuth2Grant{ID: grantID, UserID: userID})
|
||||
return err
|
||||
}
|
||||
|
||||
// ErrOAuthClientIDInvalid will be thrown if client id cannot be found
|
||||
type ErrOAuthClientIDInvalid struct {
|
||||
ClientID string
|
||||
}
|
||||
|
||||
// IsErrOauthClientIDInvalid checks if an error is a ErrReviewNotExist.
|
||||
func IsErrOauthClientIDInvalid(err error) bool {
|
||||
_, ok := err.(ErrOAuthClientIDInvalid)
|
||||
return ok
|
||||
}
|
||||
|
||||
// Error returns the error message
|
||||
func (err ErrOAuthClientIDInvalid) Error() string {
|
||||
return fmt.Sprintf("Client ID invalid [Client ID: %s]", err.ClientID)
|
||||
}
|
||||
|
||||
// ErrOAuthApplicationNotFound will be thrown if id cannot be found
|
||||
type ErrOAuthApplicationNotFound struct {
|
||||
ID int64
|
||||
}
|
||||
|
||||
// IsErrOAuthApplicationNotFound checks if an error is a ErrReviewNotExist.
|
||||
func IsErrOAuthApplicationNotFound(err error) bool {
|
||||
_, ok := err.(ErrOAuthApplicationNotFound)
|
||||
return ok
|
||||
}
|
||||
|
||||
// Error returns the error message
|
||||
func (err ErrOAuthApplicationNotFound) Error() string {
|
||||
return fmt.Sprintf("OAuth application not found [ID: %d]", err.ID)
|
||||
}
|
||||
|
||||
// GetActiveOAuth2ProviderSources returns all actived LoginOAuth2 sources
|
||||
func GetActiveOAuth2ProviderSources() ([]*Source, error) {
|
||||
sources := make([]*Source, 0, 1)
|
||||
if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, OAuth2).Find(&sources); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sources, nil
|
||||
}
|
||||
|
||||
// GetActiveOAuth2SourceByName returns a OAuth2 AuthSource based on the given name
|
||||
func GetActiveOAuth2SourceByName(name string) (*Source, error) {
|
||||
authSource := new(Source)
|
||||
has, err := db.GetEngine(db.DefaultContext).Where("name = ? and type = ? and is_active = ?", name, OAuth2, true).Get(authSource)
|
||||
if !has || err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return authSource, nil
|
||||
}
|
233
models/auth/oauth2_test.go
Normal file
233
models/auth/oauth2_test.go
Normal file
|
@ -0,0 +1,233 @@
|
|||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
//////////////////// Application
|
||||
|
||||
func TestOAuth2Application_GenerateClientSecret(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
app := unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
|
||||
secret, err := app.GenerateClientSecret()
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, len(secret) > 0)
|
||||
unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1, ClientSecret: app.ClientSecret})
|
||||
}
|
||||
|
||||
func BenchmarkOAuth2Application_GenerateClientSecret(b *testing.B) {
|
||||
assert.NoError(b, unittest.PrepareTestDatabase())
|
||||
app := unittest.AssertExistsAndLoadBean(b, &OAuth2Application{ID: 1}).(*OAuth2Application)
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = app.GenerateClientSecret()
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuth2Application_ContainsRedirectURI(t *testing.T) {
|
||||
app := &OAuth2Application{
|
||||
RedirectURIs: []string{"a", "b", "c"},
|
||||
}
|
||||
assert.True(t, app.ContainsRedirectURI("a"))
|
||||
assert.True(t, app.ContainsRedirectURI("b"))
|
||||
assert.True(t, app.ContainsRedirectURI("c"))
|
||||
assert.False(t, app.ContainsRedirectURI("d"))
|
||||
}
|
||||
|
||||
func TestOAuth2Application_ValidateClientSecret(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
app := unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
|
||||
secret, err := app.GenerateClientSecret()
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, app.ValidateClientSecret([]byte(secret)))
|
||||
assert.False(t, app.ValidateClientSecret([]byte("fewijfowejgfiowjeoifew")))
|
||||
}
|
||||
|
||||
func TestGetOAuth2ApplicationByClientID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
app, err := GetOAuth2ApplicationByClientID("da7da3ba-9a13-4167-856f-3899de0b0138")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "da7da3ba-9a13-4167-856f-3899de0b0138", app.ClientID)
|
||||
|
||||
app, err = GetOAuth2ApplicationByClientID("invalid client id")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, app)
|
||||
}
|
||||
|
||||
func TestCreateOAuth2Application(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
app, err := CreateOAuth2Application(CreateOAuth2ApplicationOptions{Name: "newapp", UserID: 1})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "newapp", app.Name)
|
||||
assert.Len(t, app.ClientID, 36)
|
||||
unittest.AssertExistsAndLoadBean(t, &OAuth2Application{Name: "newapp"})
|
||||
}
|
||||
|
||||
func TestOAuth2Application_TableName(t *testing.T) {
|
||||
assert.Equal(t, "oauth2_application", new(OAuth2Application).TableName())
|
||||
}
|
||||
|
||||
func TestOAuth2Application_GetGrantByUserID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
app := unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
|
||||
grant, err := app.GetGrantByUserID(1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), grant.UserID)
|
||||
|
||||
grant, err = app.GetGrantByUserID(34923458)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, grant)
|
||||
}
|
||||
|
||||
func TestOAuth2Application_CreateGrant(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
app := unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
|
||||
grant, err := app.CreateGrant(2, "")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, grant)
|
||||
assert.Equal(t, int64(2), grant.UserID)
|
||||
assert.Equal(t, int64(1), grant.ApplicationID)
|
||||
assert.Equal(t, "", grant.Scope)
|
||||
}
|
||||
|
||||
//////////////////// Grant
|
||||
|
||||
func TestGetOAuth2GrantByID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
grant, err := GetOAuth2GrantByID(1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), grant.ID)
|
||||
|
||||
grant, err = GetOAuth2GrantByID(34923458)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, grant)
|
||||
}
|
||||
|
||||
func TestOAuth2Grant_IncreaseCounter(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
grant := unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Counter: 1}).(*OAuth2Grant)
|
||||
assert.NoError(t, grant.IncreaseCounter())
|
||||
assert.Equal(t, int64(2), grant.Counter)
|
||||
unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Counter: 2})
|
||||
}
|
||||
|
||||
func TestOAuth2Grant_ScopeContains(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
grant := unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Scope: "openid profile"}).(*OAuth2Grant)
|
||||
assert.True(t, grant.ScopeContains("openid"))
|
||||
assert.True(t, grant.ScopeContains("profile"))
|
||||
assert.False(t, grant.ScopeContains("profil"))
|
||||
assert.False(t, grant.ScopeContains("profile2"))
|
||||
}
|
||||
|
||||
func TestOAuth2Grant_GenerateNewAuthorizationCode(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
grant := unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1}).(*OAuth2Grant)
|
||||
code, err := grant.GenerateNewAuthorizationCode("https://example2.com/callback", "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg", "S256")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, code)
|
||||
assert.True(t, len(code.Code) > 32) // secret length > 32
|
||||
}
|
||||
|
||||
func TestOAuth2Grant_TableName(t *testing.T) {
|
||||
assert.Equal(t, "oauth2_grant", new(OAuth2Grant).TableName())
|
||||
}
|
||||
|
||||
func TestGetOAuth2GrantsByUserID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
result, err := GetOAuth2GrantsByUserID(1)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, 1)
|
||||
assert.Equal(t, int64(1), result[0].ID)
|
||||
assert.Equal(t, result[0].ApplicationID, result[0].Application.ID)
|
||||
|
||||
result, err = GetOAuth2GrantsByUserID(34134)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestRevokeOAuth2Grant(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
assert.NoError(t, RevokeOAuth2Grant(1, 1))
|
||||
unittest.AssertNotExistsBean(t, &OAuth2Grant{ID: 1, UserID: 1})
|
||||
}
|
||||
|
||||
//////////////////// Authorization Code
|
||||
|
||||
func TestGetOAuth2AuthorizationByCode(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
code, err := GetOAuth2AuthorizationByCode("authcode")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, code)
|
||||
assert.Equal(t, "authcode", code.Code)
|
||||
assert.Equal(t, int64(1), code.ID)
|
||||
|
||||
code, err = GetOAuth2AuthorizationByCode("does not exist")
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, code)
|
||||
}
|
||||
|
||||
func TestOAuth2AuthorizationCode_ValidateCodeChallenge(t *testing.T) {
|
||||
// test plain
|
||||
code := &OAuth2AuthorizationCode{
|
||||
CodeChallengeMethod: "plain",
|
||||
CodeChallenge: "test123",
|
||||
}
|
||||
assert.True(t, code.ValidateCodeChallenge("test123"))
|
||||
assert.False(t, code.ValidateCodeChallenge("ierwgjoergjio"))
|
||||
|
||||
// test S256
|
||||
code = &OAuth2AuthorizationCode{
|
||||
CodeChallengeMethod: "S256",
|
||||
CodeChallenge: "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg",
|
||||
}
|
||||
assert.True(t, code.ValidateCodeChallenge("N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt"))
|
||||
assert.False(t, code.ValidateCodeChallenge("wiogjerogorewngoenrgoiuenorg"))
|
||||
|
||||
// test unknown
|
||||
code = &OAuth2AuthorizationCode{
|
||||
CodeChallengeMethod: "monkey",
|
||||
CodeChallenge: "foiwgjioriogeiogjerger",
|
||||
}
|
||||
assert.False(t, code.ValidateCodeChallenge("foiwgjioriogeiogjerger"))
|
||||
|
||||
// test no code challenge
|
||||
code = &OAuth2AuthorizationCode{
|
||||
CodeChallengeMethod: "",
|
||||
CodeChallenge: "foierjiogerogerg",
|
||||
}
|
||||
assert.True(t, code.ValidateCodeChallenge(""))
|
||||
}
|
||||
|
||||
func TestOAuth2AuthorizationCode_GenerateRedirectURI(t *testing.T) {
|
||||
code := &OAuth2AuthorizationCode{
|
||||
RedirectURI: "https://example.com/callback",
|
||||
Code: "thecode",
|
||||
}
|
||||
|
||||
redirect, err := code.GenerateRedirectURI("thestate")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "https://example.com/callback?code=thecode&state=thestate", redirect.String())
|
||||
|
||||
redirect, err = code.GenerateRedirectURI("")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "https://example.com/callback?code=thecode", redirect.String())
|
||||
}
|
||||
|
||||
func TestOAuth2AuthorizationCode_Invalidate(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
code := unittest.AssertExistsAndLoadBean(t, &OAuth2AuthorizationCode{Code: "authcode"}).(*OAuth2AuthorizationCode)
|
||||
assert.NoError(t, code.Invalidate())
|
||||
unittest.AssertNotExistsBean(t, &OAuth2AuthorizationCode{Code: "authcode"})
|
||||
}
|
||||
|
||||
func TestOAuth2AuthorizationCode_TableName(t *testing.T) {
|
||||
assert.Equal(t, "oauth2_authorization_code", new(OAuth2AuthorizationCode).TableName())
|
||||
}
|
126
models/auth/session.go
Normal file
126
models/auth/session.go
Normal file
|
@ -0,0 +1,126 @@
|
|||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
)
|
||||
|
||||
// Session represents a session compatible for go-chi session
|
||||
type Session struct {
|
||||
Key string `xorm:"pk CHAR(16)"` // has to be Key to match with go-chi/session
|
||||
Data []byte `xorm:"BLOB"` // on MySQL this has a maximum size of 64Kb - this may need to be increased
|
||||
Expiry timeutil.TimeStamp // has to be Expiry to match with go-chi/session
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(Session))
|
||||
}
|
||||
|
||||
// UpdateSession updates the session with provided id
|
||||
func UpdateSession(key string, data []byte) error {
|
||||
_, err := db.GetEngine(db.DefaultContext).ID(key).Update(&Session{
|
||||
Data: data,
|
||||
Expiry: timeutil.TimeStampNow(),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// ReadSession reads the data for the provided session
|
||||
func ReadSession(key string) (*Session, error) {
|
||||
session := Session{
|
||||
Key: key,
|
||||
}
|
||||
|
||||
ctx, committer, err := db.TxContext()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer committer.Close()
|
||||
|
||||
if has, err := db.GetByBean(ctx, &session); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
session.Expiry = timeutil.TimeStampNow()
|
||||
if err := db.Insert(ctx, &session); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &session, committer.Commit()
|
||||
}
|
||||
|
||||
// ExistSession checks if a session exists
|
||||
func ExistSession(key string) (bool, error) {
|
||||
session := Session{
|
||||
Key: key,
|
||||
}
|
||||
return db.GetEngine(db.DefaultContext).Get(&session)
|
||||
}
|
||||
|
||||
// DestroySession destroys a session
|
||||
func DestroySession(key string) error {
|
||||
_, err := db.GetEngine(db.DefaultContext).Delete(&Session{
|
||||
Key: key,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// RegenerateSession regenerates a session from the old id
|
||||
func RegenerateSession(oldKey, newKey string) (*Session, error) {
|
||||
ctx, committer, err := db.TxContext()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer committer.Close()
|
||||
|
||||
if has, err := db.GetByBean(ctx, &Session{
|
||||
Key: newKey,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
} else if has {
|
||||
return nil, fmt.Errorf("session Key: %s already exists", newKey)
|
||||
}
|
||||
|
||||
if has, err := db.GetByBean(ctx, &Session{
|
||||
Key: oldKey,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
if err := db.Insert(ctx, &Session{
|
||||
Key: oldKey,
|
||||
Expiry: timeutil.TimeStampNow(),
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := db.Exec(ctx, "UPDATE "+db.TableName(&Session{})+" SET `key` = ? WHERE `key`=?", newKey, oldKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s := Session{
|
||||
Key: newKey,
|
||||
}
|
||||
if _, err := db.GetByBean(ctx, &s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &s, committer.Commit()
|
||||
}
|
||||
|
||||
// CountSessions returns the number of sessions
|
||||
func CountSessions() (int64, error) {
|
||||
return db.GetEngine(db.DefaultContext).Count(&Session{})
|
||||
}
|
||||
|
||||
// CleanupSessions cleans up expired sessions
|
||||
func CleanupSessions(maxLifetime int64) error {
|
||||
_, err := db.GetEngine(db.DefaultContext).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{})
|
||||
return err
|
||||
}
|
397
models/auth/source.go
Normal file
397
models/auth/source.go
Normal file
|
@ -0,0 +1,397 @@
|
|||
// Copyright 2014 The Gogs Authors. All rights reserved.
|
||||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
|
||||
"xorm.io/xorm"
|
||||
"xorm.io/xorm/convert"
|
||||
)
|
||||
|
||||
// Type represents an login type.
|
||||
type Type int
|
||||
|
||||
// Note: new type must append to the end of list to maintain compatibility.
|
||||
const (
|
||||
NoType Type = iota
|
||||
Plain // 1
|
||||
LDAP // 2
|
||||
SMTP // 3
|
||||
PAM // 4
|
||||
DLDAP // 5
|
||||
OAuth2 // 6
|
||||
SSPI // 7
|
||||
)
|
||||
|
||||
// String returns the string name of the LoginType
|
||||
func (typ Type) String() string {
|
||||
return Names[typ]
|
||||
}
|
||||
|
||||
// Int returns the int value of the LoginType
|
||||
func (typ Type) Int() int {
|
||||
return int(typ)
|
||||
}
|
||||
|
||||
// Names contains the name of LoginType values.
|
||||
var Names = map[Type]string{
|
||||
LDAP: "LDAP (via BindDN)",
|
||||
DLDAP: "LDAP (simple auth)", // Via direct bind
|
||||
SMTP: "SMTP",
|
||||
PAM: "PAM",
|
||||
OAuth2: "OAuth2",
|
||||
SSPI: "SPNEGO with SSPI",
|
||||
}
|
||||
|
||||
// Config represents login config as far as the db is concerned
|
||||
type Config interface {
|
||||
convert.Conversion
|
||||
}
|
||||
|
||||
// SkipVerifiable configurations provide a IsSkipVerify to check if SkipVerify is set
|
||||
type SkipVerifiable interface {
|
||||
IsSkipVerify() bool
|
||||
}
|
||||
|
||||
// HasTLSer configurations provide a HasTLS to check if TLS can be enabled
|
||||
type HasTLSer interface {
|
||||
HasTLS() bool
|
||||
}
|
||||
|
||||
// UseTLSer configurations provide a HasTLS to check if TLS is enabled
|
||||
type UseTLSer interface {
|
||||
UseTLS() bool
|
||||
}
|
||||
|
||||
// SSHKeyProvider configurations provide ProvidesSSHKeys to check if they provide SSHKeys
|
||||
type SSHKeyProvider interface {
|
||||
ProvidesSSHKeys() bool
|
||||
}
|
||||
|
||||
// RegisterableSource configurations provide RegisterSource which needs to be run on creation
|
||||
type RegisterableSource interface {
|
||||
RegisterSource() error
|
||||
UnregisterSource() error
|
||||
}
|
||||
|
||||
var registeredConfigs = map[Type]func() Config{}
|
||||
|
||||
// RegisterTypeConfig register a config for a provided type
|
||||
func RegisterTypeConfig(typ Type, exemplar Config) {
|
||||
if reflect.TypeOf(exemplar).Kind() == reflect.Ptr {
|
||||
// Pointer:
|
||||
registeredConfigs[typ] = func() Config {
|
||||
return reflect.New(reflect.ValueOf(exemplar).Elem().Type()).Interface().(Config)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Not a Pointer
|
||||
registeredConfigs[typ] = func() Config {
|
||||
return reflect.New(reflect.TypeOf(exemplar)).Elem().Interface().(Config)
|
||||
}
|
||||
}
|
||||
|
||||
// SourceSettable configurations can have their authSource set on them
|
||||
type SourceSettable interface {
|
||||
SetAuthSource(*Source)
|
||||
}
|
||||
|
||||
// Source represents an external way for authorizing users.
|
||||
type Source struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
Type Type
|
||||
Name string `xorm:"UNIQUE"`
|
||||
IsActive bool `xorm:"INDEX NOT NULL DEFAULT false"`
|
||||
IsSyncEnabled bool `xorm:"INDEX NOT NULL DEFAULT false"`
|
||||
Cfg convert.Conversion `xorm:"TEXT"`
|
||||
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
|
||||
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
|
||||
}
|
||||
|
||||
// TableName xorm will read the table name from this method
|
||||
func (Source) TableName() string {
|
||||
return "login_source"
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(Source))
|
||||
}
|
||||
|
||||
// BeforeSet is invoked from XORM before setting the value of a field of this object.
|
||||
func (source *Source) BeforeSet(colName string, val xorm.Cell) {
|
||||
if colName == "type" {
|
||||
typ := Type(db.Cell2Int64(val))
|
||||
constructor, ok := registeredConfigs[typ]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
source.Cfg = constructor()
|
||||
if settable, ok := source.Cfg.(SourceSettable); ok {
|
||||
settable.SetAuthSource(source)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TypeName return name of this login source type.
|
||||
func (source *Source) TypeName() string {
|
||||
return Names[source.Type]
|
||||
}
|
||||
|
||||
// IsLDAP returns true of this source is of the LDAP type.
|
||||
func (source *Source) IsLDAP() bool {
|
||||
return source.Type == LDAP
|
||||
}
|
||||
|
||||
// IsDLDAP returns true of this source is of the DLDAP type.
|
||||
func (source *Source) IsDLDAP() bool {
|
||||
return source.Type == DLDAP
|
||||
}
|
||||
|
||||
// IsSMTP returns true of this source is of the SMTP type.
|
||||
func (source *Source) IsSMTP() bool {
|
||||
return source.Type == SMTP
|
||||
}
|
||||
|
||||
// IsPAM returns true of this source is of the PAM type.
|
||||
func (source *Source) IsPAM() bool {
|
||||
return source.Type == PAM
|
||||
}
|
||||
|
||||
// IsOAuth2 returns true of this source is of the OAuth2 type.
|
||||
func (source *Source) IsOAuth2() bool {
|
||||
return source.Type == OAuth2
|
||||
}
|
||||
|
||||
// IsSSPI returns true of this source is of the SSPI type.
|
||||
func (source *Source) IsSSPI() bool {
|
||||
return source.Type == SSPI
|
||||
}
|
||||
|
||||
// HasTLS returns true of this source supports TLS.
|
||||
func (source *Source) HasTLS() bool {
|
||||
hasTLSer, ok := source.Cfg.(HasTLSer)
|
||||
return ok && hasTLSer.HasTLS()
|
||||
}
|
||||
|
||||
// UseTLS returns true of this source is configured to use TLS.
|
||||
func (source *Source) UseTLS() bool {
|
||||
useTLSer, ok := source.Cfg.(UseTLSer)
|
||||
return ok && useTLSer.UseTLS()
|
||||
}
|
||||
|
||||
// SkipVerify returns true if this source is configured to skip SSL
|
||||
// verification.
|
||||
func (source *Source) SkipVerify() bool {
|
||||
skipVerifiable, ok := source.Cfg.(SkipVerifiable)
|
||||
return ok && skipVerifiable.IsSkipVerify()
|
||||
}
|
||||
|
||||
// CreateSource inserts a AuthSource in the DB if not already
|
||||
// existing with the given name.
|
||||
func CreateSource(source *Source) error {
|
||||
has, err := db.GetEngine(db.DefaultContext).Where("name=?", source.Name).Exist(new(Source))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if has {
|
||||
return ErrSourceAlreadyExist{source.Name}
|
||||
}
|
||||
// Synchronization is only available with LDAP for now
|
||||
if !source.IsLDAP() {
|
||||
source.IsSyncEnabled = false
|
||||
}
|
||||
|
||||
_, err = db.GetEngine(db.DefaultContext).Insert(source)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !source.IsActive {
|
||||
return nil
|
||||
}
|
||||
|
||||
if settable, ok := source.Cfg.(SourceSettable); ok {
|
||||
settable.SetAuthSource(source)
|
||||
}
|
||||
|
||||
registerableSource, ok := source.Cfg.(RegisterableSource)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = registerableSource.RegisterSource()
|
||||
if err != nil {
|
||||
// remove the AuthSource in case of errors while registering configuration
|
||||
if _, err := db.GetEngine(db.DefaultContext).Delete(source); err != nil {
|
||||
log.Error("CreateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Sources returns a slice of all login sources found in DB.
|
||||
func Sources() ([]*Source, error) {
|
||||
auths := make([]*Source, 0, 6)
|
||||
return auths, db.GetEngine(db.DefaultContext).Find(&auths)
|
||||
}
|
||||
|
||||
// SourcesByType returns all sources of the specified type
|
||||
func SourcesByType(loginType Type) ([]*Source, error) {
|
||||
sources := make([]*Source, 0, 1)
|
||||
if err := db.GetEngine(db.DefaultContext).Where("type = ?", loginType).Find(&sources); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sources, nil
|
||||
}
|
||||
|
||||
// AllActiveSources returns all active sources
|
||||
func AllActiveSources() ([]*Source, error) {
|
||||
sources := make([]*Source, 0, 5)
|
||||
if err := db.GetEngine(db.DefaultContext).Where("is_active = ?", true).Find(&sources); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sources, nil
|
||||
}
|
||||
|
||||
// ActiveSources returns all active sources of the specified type
|
||||
func ActiveSources(tp Type) ([]*Source, error) {
|
||||
sources := make([]*Source, 0, 1)
|
||||
if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, tp).Find(&sources); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sources, nil
|
||||
}
|
||||
|
||||
// IsSSPIEnabled returns true if there is at least one activated login
|
||||
// source of type LoginSSPI
|
||||
func IsSSPIEnabled() bool {
|
||||
if !db.HasEngine {
|
||||
return false
|
||||
}
|
||||
sources, err := ActiveSources(SSPI)
|
||||
if err != nil {
|
||||
log.Error("ActiveSources: %v", err)
|
||||
return false
|
||||
}
|
||||
return len(sources) > 0
|
||||
}
|
||||
|
||||
// GetSourceByID returns login source by given ID.
|
||||
func GetSourceByID(id int64) (*Source, error) {
|
||||
source := new(Source)
|
||||
if id == 0 {
|
||||
source.Cfg = registeredConfigs[NoType]()
|
||||
// Set this source to active
|
||||
// FIXME: allow disabling of db based password authentication in future
|
||||
source.IsActive = true
|
||||
return source, nil
|
||||
}
|
||||
|
||||
has, err := db.GetEngine(db.DefaultContext).ID(id).Get(source)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, ErrSourceNotExist{id}
|
||||
}
|
||||
return source, nil
|
||||
}
|
||||
|
||||
// UpdateSource updates a Source record in DB.
|
||||
func UpdateSource(source *Source) error {
|
||||
var originalSource *Source
|
||||
if source.IsOAuth2() {
|
||||
// keep track of the original values so we can restore in case of errors while registering OAuth2 providers
|
||||
var err error
|
||||
if originalSource, err = GetSourceByID(source.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
_, err := db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(source)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !source.IsActive {
|
||||
return nil
|
||||
}
|
||||
|
||||
if settable, ok := source.Cfg.(SourceSettable); ok {
|
||||
settable.SetAuthSource(source)
|
||||
}
|
||||
|
||||
registerableSource, ok := source.Cfg.(RegisterableSource)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = registerableSource.RegisterSource()
|
||||
if err != nil {
|
||||
// restore original values since we cannot update the provider it self
|
||||
if _, err := db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(originalSource); err != nil {
|
||||
log.Error("UpdateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// CountSources returns number of login sources.
|
||||
func CountSources() int64 {
|
||||
count, _ := db.GetEngine(db.DefaultContext).Count(new(Source))
|
||||
return count
|
||||
}
|
||||
|
||||
// ErrSourceNotExist represents a "SourceNotExist" kind of error.
|
||||
type ErrSourceNotExist struct {
|
||||
ID int64
|
||||
}
|
||||
|
||||
// IsErrSourceNotExist checks if an error is a ErrSourceNotExist.
|
||||
func IsErrSourceNotExist(err error) bool {
|
||||
_, ok := err.(ErrSourceNotExist)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrSourceNotExist) Error() string {
|
||||
return fmt.Sprintf("login source does not exist [id: %d]", err.ID)
|
||||
}
|
||||
|
||||
// ErrSourceAlreadyExist represents a "SourceAlreadyExist" kind of error.
|
||||
type ErrSourceAlreadyExist struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
// IsErrSourceAlreadyExist checks if an error is a ErrSourceAlreadyExist.
|
||||
func IsErrSourceAlreadyExist(err error) bool {
|
||||
_, ok := err.(ErrSourceAlreadyExist)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrSourceAlreadyExist) Error() string {
|
||||
return fmt.Sprintf("login source already exists [name: %s]", err.Name)
|
||||
}
|
||||
|
||||
// ErrSourceInUse represents a "SourceInUse" kind of error.
|
||||
type ErrSourceInUse struct {
|
||||
ID int64
|
||||
}
|
||||
|
||||
// IsErrSourceInUse checks if an error is a ErrSourceInUse.
|
||||
func IsErrSourceInUse(err error) bool {
|
||||
_, ok := err.(ErrSourceInUse)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrSourceInUse) Error() string {
|
||||
return fmt.Sprintf("login source is still used by some users [id: %d]", err.ID)
|
||||
}
|
60
models/auth/source_test.go
Normal file
60
models/auth/source_test.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
"code.gitea.io/gitea/modules/json"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"xorm.io/xorm/schemas"
|
||||
)
|
||||
|
||||
type TestSource struct {
|
||||
Provider string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
OpenIDConnectAutoDiscoveryURL string
|
||||
IconURL string
|
||||
}
|
||||
|
||||
// FromDB fills up a LDAPConfig from serialized format.
|
||||
func (source *TestSource) FromDB(bs []byte) error {
|
||||
return json.Unmarshal(bs, &source)
|
||||
}
|
||||
|
||||
// ToDB exports a LDAPConfig to a serialized format.
|
||||
func (source *TestSource) ToDB() ([]byte, error) {
|
||||
return json.Marshal(source)
|
||||
}
|
||||
|
||||
func TestDumpAuthSource(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
authSourceSchema, err := db.TableInfo(new(Source))
|
||||
assert.NoError(t, err)
|
||||
|
||||
RegisterTypeConfig(OAuth2, new(TestSource))
|
||||
|
||||
CreateSource(&Source{
|
||||
Type: OAuth2,
|
||||
Name: "TestSource",
|
||||
IsActive: false,
|
||||
Cfg: &TestSource{
|
||||
Provider: "ConvertibleSourceName",
|
||||
ClientID: "42",
|
||||
},
|
||||
})
|
||||
|
||||
sb := new(strings.Builder)
|
||||
|
||||
db.DumpTables([]*schemas.Table{authSourceSchema}, sb)
|
||||
|
||||
assert.Contains(t, sb.String(), `"Provider":"ConvertibleSourceName"`)
|
||||
}
|
156
models/auth/twofactor.go
Normal file
156
models/auth/twofactor.go
Normal file
|
@ -0,0 +1,156 @@
|
|||
// Copyright 2017 The Gitea Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/modules/secret"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
|
||||
"github.com/pquerna/otp/totp"
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
)
|
||||
|
||||
//
|
||||
// Two-factor authentication
|
||||
//
|
||||
|
||||
// ErrTwoFactorNotEnrolled indicates that a user is not enrolled in two-factor authentication.
|
||||
type ErrTwoFactorNotEnrolled struct {
|
||||
UID int64
|
||||
}
|
||||
|
||||
// IsErrTwoFactorNotEnrolled checks if an error is a ErrTwoFactorNotEnrolled.
|
||||
func IsErrTwoFactorNotEnrolled(err error) bool {
|
||||
_, ok := err.(ErrTwoFactorNotEnrolled)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrTwoFactorNotEnrolled) Error() string {
|
||||
return fmt.Sprintf("user not enrolled in 2FA [uid: %d]", err.UID)
|
||||
}
|
||||
|
||||
// TwoFactor represents a two-factor authentication token.
|
||||
type TwoFactor struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
UID int64 `xorm:"UNIQUE"`
|
||||
Secret string
|
||||
ScratchSalt string
|
||||
ScratchHash string
|
||||
LastUsedPasscode string `xorm:"VARCHAR(10)"`
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
|
||||
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(TwoFactor))
|
||||
}
|
||||
|
||||
// GenerateScratchToken recreates the scratch token the user is using.
|
||||
func (t *TwoFactor) GenerateScratchToken() (string, error) {
|
||||
token, err := util.RandomString(8)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
t.ScratchSalt, _ = util.RandomString(10)
|
||||
t.ScratchHash = HashToken(token, t.ScratchSalt)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// HashToken return the hashable salt
|
||||
func HashToken(token, salt string) string {
|
||||
tempHash := pbkdf2.Key([]byte(token), []byte(salt), 10000, 50, sha256.New)
|
||||
return fmt.Sprintf("%x", tempHash)
|
||||
}
|
||||
|
||||
// VerifyScratchToken verifies if the specified scratch token is valid.
|
||||
func (t *TwoFactor) VerifyScratchToken(token string) bool {
|
||||
if len(token) == 0 {
|
||||
return false
|
||||
}
|
||||
tempHash := HashToken(token, t.ScratchSalt)
|
||||
return subtle.ConstantTimeCompare([]byte(t.ScratchHash), []byte(tempHash)) == 1
|
||||
}
|
||||
|
||||
func (t *TwoFactor) getEncryptionKey() []byte {
|
||||
k := md5.Sum([]byte(setting.SecretKey))
|
||||
return k[:]
|
||||
}
|
||||
|
||||
// SetSecret sets the 2FA secret.
|
||||
func (t *TwoFactor) SetSecret(secretString string) error {
|
||||
secretBytes, err := secret.AesEncrypt(t.getEncryptionKey(), []byte(secretString))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.Secret = base64.StdEncoding.EncodeToString(secretBytes)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateTOTP validates the provided passcode.
|
||||
func (t *TwoFactor) ValidateTOTP(passcode string) (bool, error) {
|
||||
decodedStoredSecret, err := base64.StdEncoding.DecodeString(t.Secret)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
secretBytes, err := secret.AesDecrypt(t.getEncryptionKey(), decodedStoredSecret)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
secretStr := string(secretBytes)
|
||||
return totp.Validate(passcode, secretStr), nil
|
||||
}
|
||||
|
||||
// NewTwoFactor creates a new two-factor authentication token.
|
||||
func NewTwoFactor(t *TwoFactor) error {
|
||||
_, err := db.GetEngine(db.DefaultContext).Insert(t)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateTwoFactor updates a two-factor authentication token.
|
||||
func UpdateTwoFactor(t *TwoFactor) error {
|
||||
_, err := db.GetEngine(db.DefaultContext).ID(t.ID).AllCols().Update(t)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetTwoFactorByUID returns the two-factor authentication token associated with
|
||||
// the user, if any.
|
||||
func GetTwoFactorByUID(uid int64) (*TwoFactor, error) {
|
||||
twofa := &TwoFactor{}
|
||||
has, err := db.GetEngine(db.DefaultContext).Where("uid=?", uid).Get(twofa)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, ErrTwoFactorNotEnrolled{uid}
|
||||
}
|
||||
return twofa, nil
|
||||
}
|
||||
|
||||
// HasTwoFactorByUID returns the two-factor authentication token associated with
|
||||
// the user, if any.
|
||||
func HasTwoFactorByUID(uid int64) (bool, error) {
|
||||
return db.GetEngine(db.DefaultContext).Where("uid=?", uid).Exist(&TwoFactor{})
|
||||
}
|
||||
|
||||
// DeleteTwoFactorByID deletes two-factor authentication token by given ID.
|
||||
func DeleteTwoFactorByID(id, userID int64) error {
|
||||
cnt, err := db.GetEngine(db.DefaultContext).ID(id).Delete(&TwoFactor{
|
||||
UID: userID,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
} else if cnt != 1 {
|
||||
return ErrTwoFactorNotEnrolled{userID}
|
||||
}
|
||||
return nil
|
||||
}
|
154
models/auth/u2f.go
Normal file
154
models/auth/u2f.go
Normal file
|
@ -0,0 +1,154 @@
|
|||
// Copyright 2018 The Gitea Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
|
||||
"github.com/tstranex/u2f"
|
||||
)
|
||||
|
||||
// ____ ________________________________ .__ __ __ .__
|
||||
// | | \_____ \_ _____/\______ \ ____ ____ |__| _______/ |_____________ _/ |_|__| ____ ____
|
||||
// | | // ____/| __) | _// __ \ / ___\| |/ ___/\ __\_ __ \__ \\ __\ |/ _ \ / \
|
||||
// | | // \| \ | | \ ___// /_/ > |\___ \ | | | | \// __ \| | | ( <_> ) | \
|
||||
// |______/ \_______ \___ / |____|_ /\___ >___ /|__/____ > |__| |__| (____ /__| |__|\____/|___| /
|
||||
// \/ \/ \/ \/_____/ \/ \/ \/
|
||||
|
||||
// ErrU2FRegistrationNotExist represents a "ErrU2FRegistrationNotExist" kind of error.
|
||||
type ErrU2FRegistrationNotExist struct {
|
||||
ID int64
|
||||
}
|
||||
|
||||
func (err ErrU2FRegistrationNotExist) Error() string {
|
||||
return fmt.Sprintf("U2F registration does not exist [id: %d]", err.ID)
|
||||
}
|
||||
|
||||
// IsErrU2FRegistrationNotExist checks if an error is a ErrU2FRegistrationNotExist.
|
||||
func IsErrU2FRegistrationNotExist(err error) bool {
|
||||
_, ok := err.(ErrU2FRegistrationNotExist)
|
||||
return ok
|
||||
}
|
||||
|
||||
// U2FRegistration represents the registration data and counter of a security key
|
||||
type U2FRegistration struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
Name string
|
||||
UserID int64 `xorm:"INDEX"`
|
||||
Raw []byte
|
||||
Counter uint32 `xorm:"BIGINT"`
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
|
||||
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(U2FRegistration))
|
||||
}
|
||||
|
||||
// TableName returns a better table name for U2FRegistration
|
||||
func (reg U2FRegistration) TableName() string {
|
||||
return "u2f_registration"
|
||||
}
|
||||
|
||||
// Parse will convert the db entry U2FRegistration to an u2f.Registration struct
|
||||
func (reg *U2FRegistration) Parse() (*u2f.Registration, error) {
|
||||
r := new(u2f.Registration)
|
||||
return r, r.UnmarshalBinary(reg.Raw)
|
||||
}
|
||||
|
||||
func (reg *U2FRegistration) updateCounter(e db.Engine) error {
|
||||
_, err := e.ID(reg.ID).Cols("counter").Update(reg)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateCounter will update the database value of counter
|
||||
func (reg *U2FRegistration) UpdateCounter() error {
|
||||
return reg.updateCounter(db.GetEngine(db.DefaultContext))
|
||||
}
|
||||
|
||||
// U2FRegistrationList is a list of *U2FRegistration
|
||||
type U2FRegistrationList []*U2FRegistration
|
||||
|
||||
// ToRegistrations will convert all U2FRegistrations to u2f.Registrations
|
||||
func (list U2FRegistrationList) ToRegistrations() []u2f.Registration {
|
||||
regs := make([]u2f.Registration, 0, len(list))
|
||||
for _, reg := range list {
|
||||
r, err := reg.Parse()
|
||||
if err != nil {
|
||||
log.Error("parsing u2f registration: %v", err)
|
||||
continue
|
||||
}
|
||||
regs = append(regs, *r)
|
||||
}
|
||||
|
||||
return regs
|
||||
}
|
||||
|
||||
func getU2FRegistrationsByUID(e db.Engine, uid int64) (U2FRegistrationList, error) {
|
||||
regs := make(U2FRegistrationList, 0)
|
||||
return regs, e.Where("user_id = ?", uid).Find(®s)
|
||||
}
|
||||
|
||||
// GetU2FRegistrationByID returns U2F registration by id
|
||||
func GetU2FRegistrationByID(id int64) (*U2FRegistration, error) {
|
||||
return getU2FRegistrationByID(db.GetEngine(db.DefaultContext), id)
|
||||
}
|
||||
|
||||
func getU2FRegistrationByID(e db.Engine, id int64) (*U2FRegistration, error) {
|
||||
reg := new(U2FRegistration)
|
||||
if found, err := e.ID(id).Get(reg); err != nil {
|
||||
return nil, err
|
||||
} else if !found {
|
||||
return nil, ErrU2FRegistrationNotExist{ID: id}
|
||||
}
|
||||
return reg, nil
|
||||
}
|
||||
|
||||
// GetU2FRegistrationsByUID returns all U2F registrations of the given user
|
||||
func GetU2FRegistrationsByUID(uid int64) (U2FRegistrationList, error) {
|
||||
return getU2FRegistrationsByUID(db.GetEngine(db.DefaultContext), uid)
|
||||
}
|
||||
|
||||
// HasU2FRegistrationsByUID returns whether a given user has U2F registrations
|
||||
func HasU2FRegistrationsByUID(uid int64) (bool, error) {
|
||||
return db.GetEngine(db.DefaultContext).Where("user_id = ?", uid).Exist(&U2FRegistration{})
|
||||
}
|
||||
|
||||
func createRegistration(e db.Engine, userID int64, name string, reg *u2f.Registration) (*U2FRegistration, error) {
|
||||
raw, err := reg.MarshalBinary()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r := &U2FRegistration{
|
||||
UserID: userID,
|
||||
Name: name,
|
||||
Counter: 0,
|
||||
Raw: raw,
|
||||
}
|
||||
_, err = e.InsertOne(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// CreateRegistration will create a new U2FRegistration from the given Registration
|
||||
func CreateRegistration(userID int64, name string, reg *u2f.Registration) (*U2FRegistration, error) {
|
||||
return createRegistration(db.GetEngine(db.DefaultContext), userID, name, reg)
|
||||
}
|
||||
|
||||
// DeleteRegistration will delete U2FRegistration
|
||||
func DeleteRegistration(reg *U2FRegistration) error {
|
||||
return deleteRegistration(db.GetEngine(db.DefaultContext), reg)
|
||||
}
|
||||
|
||||
func deleteRegistration(e db.Engine, reg *U2FRegistration) error {
|
||||
_, err := e.Delete(reg)
|
||||
return err
|
||||
}
|
100
models/auth/u2f_test.go
Normal file
100
models/auth/u2f_test.go
Normal file
|
@ -0,0 +1,100 @@
|
|||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tstranex/u2f"
|
||||
)
|
||||
|
||||
func TestGetU2FRegistrationByID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
res, err := GetU2FRegistrationByID(1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "U2F Key", res.Name)
|
||||
|
||||
_, err = GetU2FRegistrationByID(342432)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, IsErrU2FRegistrationNotExist(err))
|
||||
}
|
||||
|
||||
func TestGetU2FRegistrationsByUID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
res, err := GetU2FRegistrationsByUID(32)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, res, 1)
|
||||
assert.Equal(t, "U2F Key", res[0].Name)
|
||||
}
|
||||
|
||||
func TestU2FRegistration_TableName(t *testing.T) {
|
||||
assert.Equal(t, "u2f_registration", U2FRegistration{}.TableName())
|
||||
}
|
||||
|
||||
func TestU2FRegistration_UpdateCounter(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
reg := unittest.AssertExistsAndLoadBean(t, &U2FRegistration{ID: 1}).(*U2FRegistration)
|
||||
reg.Counter = 1
|
||||
assert.NoError(t, reg.UpdateCounter())
|
||||
unittest.AssertExistsIf(t, true, &U2FRegistration{ID: 1, Counter: 1})
|
||||
}
|
||||
|
||||
func TestU2FRegistration_UpdateLargeCounter(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
reg := unittest.AssertExistsAndLoadBean(t, &U2FRegistration{ID: 1}).(*U2FRegistration)
|
||||
reg.Counter = 0xffffffff
|
||||
assert.NoError(t, reg.UpdateCounter())
|
||||
unittest.AssertExistsIf(t, true, &U2FRegistration{ID: 1, Counter: 0xffffffff})
|
||||
}
|
||||
|
||||
func TestCreateRegistration(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
res, err := CreateRegistration(1, "U2F Created Key", &u2f.Registration{Raw: []byte("Test")})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "U2F Created Key", res.Name)
|
||||
assert.Equal(t, []byte("Test"), res.Raw)
|
||||
|
||||
unittest.AssertExistsIf(t, true, &U2FRegistration{Name: "U2F Created Key", UserID: 1})
|
||||
}
|
||||
|
||||
func TestDeleteRegistration(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
reg := unittest.AssertExistsAndLoadBean(t, &U2FRegistration{ID: 1}).(*U2FRegistration)
|
||||
|
||||
assert.NoError(t, DeleteRegistration(reg))
|
||||
unittest.AssertNotExistsBean(t, &U2FRegistration{ID: 1})
|
||||
}
|
||||
|
||||
const validU2FRegistrationResponseHex = "0504b174bc49c7ca254b70d2e5c207cee9cf174820ebd77ea3c65508c26da51b657c1cc6b952f8621697936482da0a6d3d3826a59095daf6cd7c03e2e60385d2f6d9402a552dfdb7477ed65fd84133f86196010b2215b57da75d315b7b9e8fe2e3925a6019551bab61d16591659cbaf00b4950f7abfe6660e2e006f76868b772d70c253082013c3081e4a003020102020a47901280001155957352300a06082a8648ce3d0403023017311530130603550403130c476e756262792050696c6f74301e170d3132303831343138323933325a170d3133303831343138323933325a3031312f302d0603550403132650696c6f74476e756262792d302e342e312d34373930313238303030313135353935373335323059301306072a8648ce3d020106082a8648ce3d030107034200048d617e65c9508e64bcc5673ac82a6799da3c1446682c258c463fffdf58dfd2fa3e6c378b53d795c4a4dffb4199edd7862f23abaf0203b4b8911ba0569994e101300a06082a8648ce3d0403020347003044022060cdb6061e9c22262d1aac1d96d8c70829b2366531dda268832cb836bcd30dfa0220631b1459f09e6330055722c8d89b7f48883b9089b88d60d1d9795902b30410df304502201471899bcc3987e62e8202c9b39c33c19033f7340352dba80fcab017db9230e402210082677d673d891933ade6f617e5dbde2e247e70423fd5ad7804a6d3d3961ef871"
|
||||
|
||||
func TestToRegistrations_SkipInvalidItemsWithoutCrashing(t *testing.T) {
|
||||
regKeyRaw, _ := hex.DecodeString(validU2FRegistrationResponseHex)
|
||||
regs := U2FRegistrationList{
|
||||
&U2FRegistration{ID: 1},
|
||||
&U2FRegistration{ID: 2, Name: "U2F Key", UserID: 2, Counter: 0, Raw: regKeyRaw, CreatedUnix: 946684800, UpdatedUnix: 946684800},
|
||||
}
|
||||
|
||||
actual := regs.ToRegistrations()
|
||||
assert.Len(t, actual, 1)
|
||||
}
|
||||
|
||||
func TestToRegistrations(t *testing.T) {
|
||||
regKeyRaw, _ := hex.DecodeString(validU2FRegistrationResponseHex)
|
||||
regs := U2FRegistrationList{
|
||||
&U2FRegistration{ID: 1, Name: "U2F Key", UserID: 1, Counter: 0, Raw: regKeyRaw, CreatedUnix: 946684800, UpdatedUnix: 946684800},
|
||||
&U2FRegistration{ID: 2, Name: "U2F Key", UserID: 2, Counter: 0, Raw: regKeyRaw, CreatedUnix: 946684800, UpdatedUnix: 946684800},
|
||||
}
|
||||
|
||||
actual := regs.ToRegistrations()
|
||||
assert.Len(t, actual, 2)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue