1
0
Fork 0
forked from forgejo/forgejo

add other session providers (#5963)

This commit is contained in:
techknowlogick 2019-02-05 11:52:51 -05:00 committed by GitHub
parent bf4badad1d
commit 9de871a0f8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
160 changed files with 37644 additions and 66 deletions

47
vendor/github.com/couchbase/goutils/LICENSE.md generated vendored Normal file
View file

@ -0,0 +1,47 @@
COUCHBASE INC. COMMUNITY EDITION LICENSE AGREEMENT
IMPORTANT-READ CAREFULLY: BY CLICKING THE "I ACCEPT" BOX OR INSTALLING,
DOWNLOADING OR OTHERWISE USING THIS SOFTWARE AND ANY ASSOCIATED
DOCUMENTATION, YOU, ON BEHALF OF YOURSELF OR AS AN AUTHORIZED
REPRESENTATIVE ON BEHALF OF AN ENTITY ("LICENSEE") AGREE TO ALL THE
TERMS OF THIS COMMUNITY EDITION LICENSE AGREEMENT (THE "AGREEMENT")
REGARDING YOUR USE OF THE SOFTWARE. YOU REPRESENT AND WARRANT THAT YOU
HAVE FULL LEGAL AUTHORITY TO BIND THE LICENSEE TO THIS AGREEMENT. IF YOU
DO NOT AGREE WITH ALL OF THESE TERMS, DO NOT SELECT THE "I ACCEPT" BOX
AND DO NOT INSTALL, DOWNLOAD OR OTHERWISE USE THE SOFTWARE. THE
EFFECTIVE DATE OF THIS AGREEMENT IS THE DATE ON WHICH YOU CLICK "I
ACCEPT" OR OTHERWISE INSTALL, DOWNLOAD OR USE THE SOFTWARE.
1. License Grant. Couchbase Inc. hereby grants Licensee, free of charge,
the non-exclusive right to use, copy, merge, publish, distribute,
sublicense, and/or sell copies of the Software, and to permit persons to
whom the Software is furnished to do so, subject to Licensee including
the following copyright notice in all copies or substantial portions of
the Software:
Couchbase (r) http://www.Couchbase.com Copyright 2016 Couchbase, Inc.
As used in this Agreement, "Software" means the object code version of
the applicable elastic data management server software provided by
Couchbase Inc.
2. Restrictions. Licensee will not reverse engineer, disassemble, or
decompile the Software (except to the extent such restrictions are
prohibited by law).
3. Support. Couchbase, Inc. will provide Licensee with access to, and
use of, the Couchbase, Inc. support forum available at the following
URL: http://www.couchbase.org/forums/. Couchbase, Inc. may, at its
discretion, modify, suspend or terminate support at any time upon notice
to Licensee.
4. Warranty Disclaimer and Limitation of Liability. THE SOFTWARE IS
PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
COUCHBASE INC. OR THE AUTHORS OR COPYRIGHT HOLDERS IN THE SOFTWARE BE
LIABLE FOR ANY CLAIM, DAMAGES (IINCLUDING, WITHOUT LIMITATION, DIRECT,
INDIRECT OR CONSEQUENTIAL DAMAGES) OR OTHER LIABILITY, WHETHER IN AN
ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

481
vendor/github.com/couchbase/goutils/logging/logger.go generated vendored Normal file
View file

@ -0,0 +1,481 @@
// Copyright (c) 2016 Couchbase, Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
// except in compliance with the License. You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the
// License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
// either express or implied. See the License for the specific language governing permissions
// and limitations under the License.
package logging
import (
"os"
"runtime"
"strings"
"sync"
)
type Level int
const (
NONE = Level(iota) // Disable all logging
FATAL // System is in severe error state and has to abort
SEVERE // System is in severe error state and cannot recover reliably
ERROR // System is in error state but can recover and continue reliably
WARN // System approaching error state, or is in a correct but undesirable state
INFO // System-level events and status, in correct states
REQUEST // Request-level events, with request-specific rlevel
TRACE // Trace detailed system execution, e.g. function entry / exit
DEBUG // Debug
)
type LogEntryFormatter int
const (
TEXTFORMATTER = LogEntryFormatter(iota)
JSONFORMATTER
KVFORMATTER
)
func (level Level) String() string {
return _LEVEL_NAMES[level]
}
var _LEVEL_NAMES = []string{
DEBUG: "DEBUG",
TRACE: "TRACE",
REQUEST: "REQUEST",
INFO: "INFO",
WARN: "WARN",
ERROR: "ERROR",
SEVERE: "SEVERE",
FATAL: "FATAL",
NONE: "NONE",
}
var _LEVEL_MAP = map[string]Level{
"debug": DEBUG,
"trace": TRACE,
"request": REQUEST,
"info": INFO,
"warn": WARN,
"error": ERROR,
"severe": SEVERE,
"fatal": FATAL,
"none": NONE,
}
func ParseLevel(name string) (level Level, ok bool) {
level, ok = _LEVEL_MAP[strings.ToLower(name)]
return
}
/*
Pair supports logging of key-value pairs. Keys beginning with _ are
reserved for the logger, e.g. _time, _level, _msg, and _rlevel. The
Pair APIs are designed to avoid heap allocation and garbage
collection.
*/
type Pairs []Pair
type Pair struct {
Name string
Value interface{}
}
/*
Map allows key-value pairs to be specified using map literals or data
structures. For example:
Errorm(msg, Map{...})
Map incurs heap allocation and garbage collection, so the Pair APIs
should be preferred.
*/
type Map map[string]interface{}
// Logger provides a common interface for logging libraries
type Logger interface {
/*
These APIs write all the given pairs in addition to standard logger keys.
*/
Logp(level Level, msg string, kv ...Pair)
Debugp(msg string, kv ...Pair)
Tracep(msg string, kv ...Pair)
Requestp(rlevel Level, msg string, kv ...Pair)
Infop(msg string, kv ...Pair)
Warnp(msg string, kv ...Pair)
Errorp(msg string, kv ...Pair)
Severep(msg string, kv ...Pair)
Fatalp(msg string, kv ...Pair)
/*
These APIs write the fields in the given kv Map in addition to standard logger keys.
*/
Logm(level Level, msg string, kv Map)
Debugm(msg string, kv Map)
Tracem(msg string, kv Map)
Requestm(rlevel Level, msg string, kv Map)
Infom(msg string, kv Map)
Warnm(msg string, kv Map)
Errorm(msg string, kv Map)
Severem(msg string, kv Map)
Fatalm(msg string, kv Map)
/*
These APIs only write _msg, _time, _level, and other logger keys. If
the msg contains other fields, use the Pair or Map APIs instead.
*/
Logf(level Level, fmt string, args ...interface{})
Debugf(fmt string, args ...interface{})
Tracef(fmt string, args ...interface{})
Requestf(rlevel Level, fmt string, args ...interface{})
Infof(fmt string, args ...interface{})
Warnf(fmt string, args ...interface{})
Errorf(fmt string, args ...interface{})
Severef(fmt string, args ...interface{})
Fatalf(fmt string, args ...interface{})
/*
These APIs control the logging level
*/
SetLevel(Level) // Set the logging level
Level() Level // Get the current logging level
}
var logger Logger = nil
var curLevel Level = DEBUG // initially set to never skip
var loggerMutex sync.RWMutex
// All the methods below first acquire the mutex (mostly in exclusive mode)
// and only then check if logging at the current level is enabled.
// This introduces a fair bottleneck for those log entries that should be
// skipped (the majority, at INFO or below levels)
// We try to predict here if we should lock the mutex at all by caching
// the current log level: while dynamically changing logger, there might
// be the odd entry skipped as the new level is cached.
// Since we seem to never change the logger, this is not an issue.
func skipLogging(level Level) bool {
if logger == nil {
return true
}
return level > curLevel
}
func SetLogger(newLogger Logger) {
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger = newLogger
if logger == nil {
curLevel = NONE
} else {
curLevel = newLogger.Level()
}
}
func Logp(level Level, msg string, kv ...Pair) {
if skipLogging(level) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Logp(level, msg, kv...)
}
func Debugp(msg string, kv ...Pair) {
if skipLogging(DEBUG) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Debugp(msg, kv...)
}
func Tracep(msg string, kv ...Pair) {
if skipLogging(TRACE) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Tracep(msg, kv...)
}
func Requestp(rlevel Level, msg string, kv ...Pair) {
if skipLogging(REQUEST) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Requestp(rlevel, msg, kv...)
}
func Infop(msg string, kv ...Pair) {
if skipLogging(INFO) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Infop(msg, kv...)
}
func Warnp(msg string, kv ...Pair) {
if skipLogging(WARN) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Warnp(msg, kv...)
}
func Errorp(msg string, kv ...Pair) {
if skipLogging(ERROR) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Errorp(msg, kv...)
}
func Severep(msg string, kv ...Pair) {
if skipLogging(SEVERE) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Severep(msg, kv...)
}
func Fatalp(msg string, kv ...Pair) {
if skipLogging(FATAL) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Fatalp(msg, kv...)
}
func Logm(level Level, msg string, kv Map) {
if skipLogging(level) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Logm(level, msg, kv)
}
func Debugm(msg string, kv Map) {
if skipLogging(DEBUG) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Debugm(msg, kv)
}
func Tracem(msg string, kv Map) {
if skipLogging(TRACE) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Tracem(msg, kv)
}
func Requestm(rlevel Level, msg string, kv Map) {
if skipLogging(REQUEST) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Requestm(rlevel, msg, kv)
}
func Infom(msg string, kv Map) {
if skipLogging(INFO) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Infom(msg, kv)
}
func Warnm(msg string, kv Map) {
if skipLogging(WARN) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Warnm(msg, kv)
}
func Errorm(msg string, kv Map) {
if skipLogging(ERROR) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Errorm(msg, kv)
}
func Severem(msg string, kv Map) {
if skipLogging(SEVERE) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Severem(msg, kv)
}
func Fatalm(msg string, kv Map) {
if skipLogging(FATAL) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Fatalm(msg, kv)
}
func Logf(level Level, fmt string, args ...interface{}) {
if skipLogging(level) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Logf(level, fmt, args...)
}
func Debugf(fmt string, args ...interface{}) {
if skipLogging(DEBUG) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Debugf(fmt, args...)
}
func Tracef(fmt string, args ...interface{}) {
if skipLogging(TRACE) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Tracef(fmt, args...)
}
func Requestf(rlevel Level, fmt string, args ...interface{}) {
if skipLogging(REQUEST) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Requestf(rlevel, fmt, args...)
}
func Infof(fmt string, args ...interface{}) {
if skipLogging(INFO) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Infof(fmt, args...)
}
func Warnf(fmt string, args ...interface{}) {
if skipLogging(WARN) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Warnf(fmt, args...)
}
func Errorf(fmt string, args ...interface{}) {
if skipLogging(ERROR) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Errorf(fmt, args...)
}
func Severef(fmt string, args ...interface{}) {
if skipLogging(SEVERE) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Severef(fmt, args...)
}
func Fatalf(fmt string, args ...interface{}) {
if skipLogging(FATAL) {
return
}
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Fatalf(fmt, args...)
}
func SetLevel(level Level) {
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.SetLevel(level)
curLevel = level
}
func LogLevel() Level {
loggerMutex.RLock()
defer loggerMutex.RUnlock()
return logger.Level()
}
func Stackf(level Level, fmt string, args ...interface{}) {
if skipLogging(level) {
return
}
buf := make([]byte, 1<<16)
n := runtime.Stack(buf, false)
s := string(buf[0:n])
loggerMutex.Lock()
defer loggerMutex.Unlock()
logger.Logf(level, fmt, args...)
logger.Logf(level, s)
}
func init() {
logger = NewLogger(os.Stderr, INFO, TEXTFORMATTER)
SetLogger(logger)
}

View file

@ -0,0 +1,318 @@
// Copyright (c) 2016 Couchbase, Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
// except in compliance with the License. You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the
// License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
// either express or implied. See the License for the specific language governing permissions
// and limitations under the License.
package logging
import (
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"time"
)
type goLogger struct {
logger *log.Logger
level Level
entryFormatter formatter
}
const (
_LEVEL = "_level"
_MSG = "_msg"
_TIME = "_time"
_RLEVEL = "_rlevel"
)
func NewLogger(out io.Writer, lvl Level, fmtLogging LogEntryFormatter) *goLogger {
logger := &goLogger{
logger: log.New(out, "", 0),
level: lvl,
}
if fmtLogging == JSONFORMATTER {
logger.entryFormatter = &jsonFormatter{}
} else if fmtLogging == KVFORMATTER {
logger.entryFormatter = &keyvalueFormatter{}
} else {
logger.entryFormatter = &textFormatter{}
}
return logger
}
func (gl *goLogger) Logp(level Level, msg string, kv ...Pair) {
if gl.logger == nil {
return
}
if level <= gl.level {
e := newLogEntry(msg, level)
copyPairs(e, kv)
gl.log(e)
}
}
func (gl *goLogger) Debugp(msg string, kv ...Pair) {
gl.Logp(DEBUG, msg, kv...)
}
func (gl *goLogger) Tracep(msg string, kv ...Pair) {
gl.Logp(TRACE, msg, kv...)
}
func (gl *goLogger) Requestp(rlevel Level, msg string, kv ...Pair) {
if gl.logger == nil {
return
}
if REQUEST <= gl.level {
e := newLogEntry(msg, REQUEST)
e.Rlevel = rlevel
copyPairs(e, kv)
gl.log(e)
}
}
func (gl *goLogger) Infop(msg string, kv ...Pair) {
gl.Logp(INFO, msg, kv...)
}
func (gl *goLogger) Warnp(msg string, kv ...Pair) {
gl.Logp(WARN, msg, kv...)
}
func (gl *goLogger) Errorp(msg string, kv ...Pair) {
gl.Logp(ERROR, msg, kv...)
}
func (gl *goLogger) Severep(msg string, kv ...Pair) {
gl.Logp(SEVERE, msg, kv...)
}
func (gl *goLogger) Fatalp(msg string, kv ...Pair) {
gl.Logp(FATAL, msg, kv...)
}
func (gl *goLogger) Logm(level Level, msg string, kv Map) {
if gl.logger == nil {
return
}
if level <= gl.level {
e := newLogEntry(msg, level)
e.Data = kv
gl.log(e)
}
}
func (gl *goLogger) Debugm(msg string, kv Map) {
gl.Logm(DEBUG, msg, kv)
}
func (gl *goLogger) Tracem(msg string, kv Map) {
gl.Logm(TRACE, msg, kv)
}
func (gl *goLogger) Requestm(rlevel Level, msg string, kv Map) {
if gl.logger == nil {
return
}
if REQUEST <= gl.level {
e := newLogEntry(msg, REQUEST)
e.Rlevel = rlevel
e.Data = kv
gl.log(e)
}
}
func (gl *goLogger) Infom(msg string, kv Map) {
gl.Logm(INFO, msg, kv)
}
func (gl *goLogger) Warnm(msg string, kv Map) {
gl.Logm(WARN, msg, kv)
}
func (gl *goLogger) Errorm(msg string, kv Map) {
gl.Logm(ERROR, msg, kv)
}
func (gl *goLogger) Severem(msg string, kv Map) {
gl.Logm(SEVERE, msg, kv)
}
func (gl *goLogger) Fatalm(msg string, kv Map) {
gl.Logm(FATAL, msg, kv)
}
func (gl *goLogger) Logf(level Level, format string, args ...interface{}) {
if gl.logger == nil {
return
}
if level <= gl.level {
e := newLogEntry(fmt.Sprintf(format, args...), level)
gl.log(e)
}
}
func (gl *goLogger) Debugf(format string, args ...interface{}) {
gl.Logf(DEBUG, format, args...)
}
func (gl *goLogger) Tracef(format string, args ...interface{}) {
gl.Logf(TRACE, format, args...)
}
func (gl *goLogger) Requestf(rlevel Level, format string, args ...interface{}) {
if gl.logger == nil {
return
}
if REQUEST <= gl.level {
e := newLogEntry(fmt.Sprintf(format, args...), REQUEST)
e.Rlevel = rlevel
gl.log(e)
}
}
func (gl *goLogger) Infof(format string, args ...interface{}) {
gl.Logf(INFO, format, args...)
}
func (gl *goLogger) Warnf(format string, args ...interface{}) {
gl.Logf(WARN, format, args...)
}
func (gl *goLogger) Errorf(format string, args ...interface{}) {
gl.Logf(ERROR, format, args...)
}
func (gl *goLogger) Severef(format string, args ...interface{}) {
gl.Logf(SEVERE, format, args...)
}
func (gl *goLogger) Fatalf(format string, args ...interface{}) {
gl.Logf(FATAL, format, args...)
}
func (gl *goLogger) Level() Level {
return gl.level
}
func (gl *goLogger) SetLevel(level Level) {
gl.level = level
}
func (gl *goLogger) log(newEntry *logEntry) {
s := gl.entryFormatter.format(newEntry)
gl.logger.Print(s)
}
type logEntry struct {
Time string
Level Level
Rlevel Level
Message string
Data Map
}
func newLogEntry(msg string, level Level) *logEntry {
return &logEntry{
Time: time.Now().Format("2006-01-02T15:04:05.000-07:00"), // time.RFC3339 with milliseconds
Level: level,
Rlevel: NONE,
Message: msg,
}
}
func copyPairs(newEntry *logEntry, pairs []Pair) {
newEntry.Data = make(Map, len(pairs))
for _, p := range pairs {
newEntry.Data[p.Name] = p.Value
}
}
type formatter interface {
format(*logEntry) string
}
type textFormatter struct {
}
// ex. 2016-02-10T09:15:25.498-08:00 [INFO] This is a message from test in text format
func (*textFormatter) format(newEntry *logEntry) string {
b := &bytes.Buffer{}
appendValue(b, newEntry.Time)
if newEntry.Rlevel != NONE {
fmt.Fprintf(b, "[%s,%s] ", newEntry.Level.String(), newEntry.Rlevel.String())
} else {
fmt.Fprintf(b, "[%s] ", newEntry.Level.String())
}
appendValue(b, newEntry.Message)
for key, value := range newEntry.Data {
appendKeyValue(b, key, value)
}
b.WriteByte('\n')
s := bytes.NewBuffer(b.Bytes())
return s.String()
}
func appendValue(b *bytes.Buffer, value interface{}) {
if _, ok := value.(string); ok {
fmt.Fprintf(b, "%s ", value)
} else {
fmt.Fprintf(b, "%v ", value)
}
}
type keyvalueFormatter struct {
}
// ex. _time=2016-02-10T09:15:25.498-08:00 _level=INFO _msg=This is a message from test in key-value format
func (*keyvalueFormatter) format(newEntry *logEntry) string {
b := &bytes.Buffer{}
appendKeyValue(b, _TIME, newEntry.Time)
appendKeyValue(b, _LEVEL, newEntry.Level.String())
if newEntry.Rlevel != NONE {
appendKeyValue(b, _RLEVEL, newEntry.Rlevel.String())
}
appendKeyValue(b, _MSG, newEntry.Message)
for key, value := range newEntry.Data {
appendKeyValue(b, key, value)
}
b.WriteByte('\n')
s := bytes.NewBuffer(b.Bytes())
return s.String()
}
func appendKeyValue(b *bytes.Buffer, key, value interface{}) {
if _, ok := value.(string); ok {
fmt.Fprintf(b, "%v=%s ", key, value)
} else {
fmt.Fprintf(b, "%v=%v ", key, value)
}
}
type jsonFormatter struct {
}
// ex. {"_level":"INFO","_msg":"This is a message from test in json format","_time":"2016-02-10T09:12:59.518-08:00"}
func (*jsonFormatter) format(newEntry *logEntry) string {
if newEntry.Data == nil {
newEntry.Data = make(Map, 5)
}
newEntry.Data[_TIME] = newEntry.Time
newEntry.Data[_LEVEL] = newEntry.Level.String()
if newEntry.Rlevel != NONE {
newEntry.Data[_RLEVEL] = newEntry.Rlevel.String()
}
newEntry.Data[_MSG] = newEntry.Message
serialized, _ := json.Marshal(newEntry.Data)
s := bytes.NewBuffer(append(serialized, '\n'))
return s.String()
}

View file

@ -0,0 +1,207 @@
// @author Couchbase <info@couchbase.com>
// @copyright 2018 Couchbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package scramsha provides implementation of client side SCRAM-SHA
// according to https://tools.ietf.org/html/rfc5802
package scramsha
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"encoding/base64"
"fmt"
"github.com/pkg/errors"
"golang.org/x/crypto/pbkdf2"
"hash"
"strconv"
"strings"
)
func hmacHash(message []byte, secret []byte, hashFunc func() hash.Hash) []byte {
h := hmac.New(hashFunc, secret)
h.Write(message)
return h.Sum(nil)
}
func shaHash(message []byte, hashFunc func() hash.Hash) []byte {
h := hashFunc()
h.Write(message)
return h.Sum(nil)
}
func generateClientNonce(size int) (string, error) {
randomBytes := make([]byte, size)
_, err := rand.Read(randomBytes)
if err != nil {
return "", errors.Wrap(err, "Unable to generate nonce")
}
return base64.StdEncoding.EncodeToString(randomBytes), nil
}
// ScramSha provides context for SCRAM-SHA handling
type ScramSha struct {
hashSize int
hashFunc func() hash.Hash
clientNonce string
serverNonce string
salt []byte
i int
saltedPassword []byte
authMessage string
}
var knownMethods = []string{"SCRAM-SHA512", "SCRAM-SHA256", "SCRAM-SHA1"}
// BestMethod returns SCRAM-SHA method we consider the best out of suggested
// by server
func BestMethod(methods string) (string, error) {
for _, m := range knownMethods {
if strings.Index(methods, m) != -1 {
return m, nil
}
}
return "", errors.Errorf(
"None of the server suggested methods [%s] are supported",
methods)
}
// NewScramSha creates context for SCRAM-SHA handling
func NewScramSha(method string) (*ScramSha, error) {
s := &ScramSha{}
if method == knownMethods[0] {
s.hashFunc = sha512.New
s.hashSize = 64
} else if method == knownMethods[1] {
s.hashFunc = sha256.New
s.hashSize = 32
} else if method == knownMethods[2] {
s.hashFunc = sha1.New
s.hashSize = 20
} else {
return nil, errors.Errorf("Unsupported method %s", method)
}
return s, nil
}
// GetStartRequest builds start SCRAM-SHA request to be sent to server
func (s *ScramSha) GetStartRequest(user string) (string, error) {
var err error
s.clientNonce, err = generateClientNonce(24)
if err != nil {
return "", errors.Wrapf(err, "Unable to generate SCRAM-SHA "+
"start request for user %s", user)
}
message := fmt.Sprintf("n,,n=%s,r=%s", user, s.clientNonce)
s.authMessage = message[3:]
return message, nil
}
// HandleStartResponse handles server response on start SCRAM-SHA request
func (s *ScramSha) HandleStartResponse(response string) error {
parts := strings.Split(response, ",")
if len(parts) != 3 {
return errors.Errorf("expected 3 fields in first SCRAM-SHA-1 "+
"server message %s", response)
}
if !strings.HasPrefix(parts[0], "r=") || len(parts[0]) < 3 {
return errors.Errorf("Server sent an invalid nonce %s",
parts[0])
}
if !strings.HasPrefix(parts[1], "s=") || len(parts[1]) < 3 {
return errors.Errorf("Server sent an invalid salt %s", parts[1])
}
if !strings.HasPrefix(parts[2], "i=") || len(parts[2]) < 3 {
return errors.Errorf("Server sent an invalid iteration count %s",
parts[2])
}
s.serverNonce = parts[0][2:]
encodedSalt := parts[1][2:]
var err error
s.i, err = strconv.Atoi(parts[2][2:])
if err != nil {
return errors.Errorf("Iteration count %s must be integer.",
parts[2][2:])
}
if s.i < 1 {
return errors.New("Iteration count should be positive")
}
if !strings.HasPrefix(s.serverNonce, s.clientNonce) {
return errors.Errorf("Server nonce %s doesn't contain client"+
" nonce %s", s.serverNonce, s.clientNonce)
}
s.salt, err = base64.StdEncoding.DecodeString(encodedSalt)
if err != nil {
return errors.Wrapf(err, "Unable to decode salt %s",
encodedSalt)
}
s.authMessage = s.authMessage + "," + response
return nil
}
// GetFinalRequest builds final SCRAM-SHA request to be sent to server
func (s *ScramSha) GetFinalRequest(pass string) string {
clientFinalMessageBare := "c=biws,r=" + s.serverNonce
s.authMessage = s.authMessage + "," + clientFinalMessageBare
s.saltedPassword = pbkdf2.Key([]byte(pass), s.salt, s.i,
s.hashSize, s.hashFunc)
clientKey := hmacHash([]byte("Client Key"), s.saltedPassword, s.hashFunc)
storedKey := shaHash(clientKey, s.hashFunc)
clientSignature := hmacHash([]byte(s.authMessage), storedKey, s.hashFunc)
clientProof := make([]byte, len(clientSignature))
for i := 0; i < len(clientSignature); i++ {
clientProof[i] = clientKey[i] ^ clientSignature[i]
}
return clientFinalMessageBare + ",p=" +
base64.StdEncoding.EncodeToString(clientProof)
}
// HandleFinalResponse handles server's response on final SCRAM-SHA request
func (s *ScramSha) HandleFinalResponse(response string) error {
if strings.Contains(response, ",") ||
!strings.HasPrefix(response, "v=") {
return errors.Errorf("Server sent an invalid final message %s",
response)
}
decodedMessage, err := base64.StdEncoding.DecodeString(response[2:])
if err != nil {
return errors.Wrapf(err, "Unable to decode server message %s",
response[2:])
}
serverKey := hmacHash([]byte("Server Key"), s.saltedPassword,
s.hashFunc)
serverSignature := hmacHash([]byte(s.authMessage), serverKey,
s.hashFunc)
if string(decodedMessage) != string(serverSignature) {
return errors.Errorf("Server proof %s doesn't match "+
"the expected: %s",
string(decodedMessage), string(serverSignature))
}
return nil
}

View file

@ -0,0 +1,252 @@
// @author Couchbase <info@couchbase.com>
// @copyright 2018 Couchbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package scramsha provides implementation of client side SCRAM-SHA
// via Http according to https://tools.ietf.org/html/rfc7804
package scramsha
import (
"encoding/base64"
"github.com/pkg/errors"
"io"
"io/ioutil"
"net/http"
"strings"
)
// consts used to parse scramsha response from target
const (
WWWAuthenticate = "WWW-Authenticate"
AuthenticationInfo = "Authentication-Info"
Authorization = "Authorization"
DataPrefix = "data="
SidPrefix = "sid="
)
// Request provides implementation of http request that can be retried
type Request struct {
body io.ReadSeeker
// Embed an HTTP request directly. This makes a *Request act exactly
// like an *http.Request so that all meta methods are supported.
*http.Request
}
type lenReader interface {
Len() int
}
// NewRequest creates http request that can be retried
func NewRequest(method, url string, body io.ReadSeeker) (*Request, error) {
// Wrap the body in a noop ReadCloser if non-nil. This prevents the
// reader from being closed by the HTTP client.
var rcBody io.ReadCloser
if body != nil {
rcBody = ioutil.NopCloser(body)
}
// Make the request with the noop-closer for the body.
httpReq, err := http.NewRequest(method, url, rcBody)
if err != nil {
return nil, err
}
// Check if we can set the Content-Length automatically.
if lr, ok := body.(lenReader); ok {
httpReq.ContentLength = int64(lr.Len())
}
return &Request{body, httpReq}, nil
}
func encode(str string) string {
return base64.StdEncoding.EncodeToString([]byte(str))
}
func decode(str string) (string, error) {
bytes, err := base64.StdEncoding.DecodeString(str)
if err != nil {
return "", errors.Errorf("Cannot base64 decode %s",
str)
}
return string(bytes), err
}
func trimPrefix(s, prefix string) (string, error) {
l := len(s)
trimmed := strings.TrimPrefix(s, prefix)
if l == len(trimmed) {
return trimmed, errors.Errorf("Prefix %s not found in %s",
prefix, s)
}
return trimmed, nil
}
func drainBody(resp *http.Response) {
defer resp.Body.Close()
io.Copy(ioutil.Discard, resp.Body)
}
// DoScramSha performs SCRAM-SHA handshake via Http
func DoScramSha(req *Request,
username string,
password string,
client *http.Client) (*http.Response, error) {
method := "SCRAM-SHA-512"
s, err := NewScramSha("SCRAM-SHA512")
if err != nil {
return nil, errors.Wrap(err,
"Unable to initialize SCRAM-SHA handler")
}
message, err := s.GetStartRequest(username)
if err != nil {
return nil, err
}
encodedMessage := method + " " + DataPrefix + encode(message)
req.Header.Set(Authorization, encodedMessage)
res, err := client.Do(req.Request)
if err != nil {
return nil, errors.Wrap(err, "Problem sending SCRAM-SHA start"+
"request")
}
if res.StatusCode != http.StatusUnauthorized {
return res, nil
}
authHeader := res.Header.Get(WWWAuthenticate)
if authHeader == "" {
drainBody(res)
return nil, errors.Errorf("Header %s is not populated in "+
"SCRAM-SHA start response", WWWAuthenticate)
}
authHeader, err = trimPrefix(authHeader, method+" ")
if err != nil {
if strings.HasPrefix(authHeader, "Basic ") {
// user not found
return res, nil
}
drainBody(res)
return nil, errors.Wrapf(err, "Error while parsing SCRAM-SHA "+
"start response %s", authHeader)
}
drainBody(res)
sid, response, err := parseSidAndData(authHeader)
if err != nil {
return nil, errors.Wrapf(err, "Error while parsing SCRAM-SHA "+
"start response %s", authHeader)
}
err = s.HandleStartResponse(response)
if err != nil {
return nil, errors.Wrapf(err, "Error parsing SCRAM-SHA start "+
"response %s", response)
}
message = s.GetFinalRequest(password)
encodedMessage = method + " " + SidPrefix + sid + "," + DataPrefix +
encode(message)
req.Header.Set(Authorization, encodedMessage)
// rewind request body so it can be resent again
if req.body != nil {
if _, err = req.body.Seek(0, 0); err != nil {
return nil, errors.Errorf("Failed to seek body: %v",
err)
}
}
res, err = client.Do(req.Request)
if err != nil {
return nil, errors.Wrap(err, "Problem sending SCRAM-SHA final"+
"request")
}
if res.StatusCode == http.StatusUnauthorized {
// TODO retrieve and return error
return res, nil
}
if res.StatusCode >= http.StatusInternalServerError {
// in this case we cannot expect server to set headers properly
return res, nil
}
authHeader = res.Header.Get(AuthenticationInfo)
if authHeader == "" {
drainBody(res)
return nil, errors.Errorf("Header %s is not populated in "+
"SCRAM-SHA final response", AuthenticationInfo)
}
finalSid, response, err := parseSidAndData(authHeader)
if err != nil {
drainBody(res)
return nil, errors.Wrapf(err, "Error while parsing SCRAM-SHA "+
"final response %s", authHeader)
}
if finalSid != sid {
drainBody(res)
return nil, errors.Errorf("Sid %s returned by server "+
"doesn't match the original sid %s", finalSid, sid)
}
err = s.HandleFinalResponse(response)
if err != nil {
drainBody(res)
return nil, errors.Wrapf(err,
"Error handling SCRAM-SHA final server response %s",
response)
}
return res, nil
}
func parseSidAndData(authHeader string) (string, string, error) {
sidIndex := strings.Index(authHeader, SidPrefix)
if sidIndex < 0 {
return "", "", errors.Errorf("Cannot find %s in %s",
SidPrefix, authHeader)
}
sidEndIndex := strings.Index(authHeader, ",")
if sidEndIndex < 0 {
return "", "", errors.Errorf("Cannot find ',' in %s",
authHeader)
}
sid := authHeader[sidIndex+len(SidPrefix) : sidEndIndex]
dataIndex := strings.Index(authHeader, DataPrefix)
if dataIndex < 0 {
return "", "", errors.Errorf("Cannot find %s in %s",
DataPrefix, authHeader)
}
data, err := decode(authHeader[dataIndex+len(DataPrefix):])
if err != nil {
return "", "", err
}
return sid, data, nil
}