From 86fb258cc0362cd4a98e9cbeb235bc373abcc0d3 Mon Sep 17 00:00:00 2001 From: shawnps Date: Sat, 7 Feb 2015 07:32:38 -0800 Subject: [PATCH] add Godeps, closes #10 --- Godeps/Godeps.json | 15 + Godeps/Readme | 5 + Godeps/_workspace/.gitignore | 2 + Godeps/_workspace/src/gopkg.in/mgo.v2/LICENSE | 25 + .../_workspace/src/gopkg.in/mgo.v2/Makefile | 5 + .../_workspace/src/gopkg.in/mgo.v2/README.md | 4 + Godeps/_workspace/src/gopkg.in/mgo.v2/auth.go | 467 ++ .../src/gopkg.in/mgo.v2/auth_test.go | 1181 +++++ .../src/gopkg.in/mgo.v2/bson/LICENSE | 25 + .../src/gopkg.in/mgo.v2/bson/bson.go | 698 +++ .../src/gopkg.in/mgo.v2/bson/bson_test.go | 1543 +++++++ .../src/gopkg.in/mgo.v2/bson/decode.go | 820 ++++ .../src/gopkg.in/mgo.v2/bson/encode.go | 489 +++ Godeps/_workspace/src/gopkg.in/mgo.v2/bulk.go | 71 + .../src/gopkg.in/mgo.v2/bulk_test.go | 93 + .../_workspace/src/gopkg.in/mgo.v2/cluster.go | 621 +++ .../src/gopkg.in/mgo.v2/cluster_test.go | 1596 +++++++ Godeps/_workspace/src/gopkg.in/mgo.v2/doc.go | 31 + .../src/gopkg.in/mgo.v2/export_test.go | 33 + .../_workspace/src/gopkg.in/mgo.v2/gridfs.go | 754 ++++ .../src/gopkg.in/mgo.v2/gridfs_test.go | 680 +++ .../gopkg.in/mgo.v2/internal/scram/scram.go | 266 ++ .../mgo.v2/internal/scram/scram_test.go | 67 + Godeps/_workspace/src/gopkg.in/mgo.v2/log.go | 133 + .../_workspace/src/gopkg.in/mgo.v2/queue.go | 91 + .../src/gopkg.in/mgo.v2/queue_test.go | 101 + .../_workspace/src/gopkg.in/mgo.v2/raceoff.go | 5 + .../_workspace/src/gopkg.in/mgo.v2/raceon.go | 5 + .../src/gopkg.in/mgo.v2/sasl/sasl.c | 77 + .../src/gopkg.in/mgo.v2/sasl/sasl.go | 138 + .../src/gopkg.in/mgo.v2/sasl/sasl_windows.c | 118 + .../src/gopkg.in/mgo.v2/sasl/sasl_windows.go | 140 + .../src/gopkg.in/mgo.v2/sasl/sasl_windows.h | 7 + .../src/gopkg.in/mgo.v2/sasl/sspi_windows.c | 96 + .../src/gopkg.in/mgo.v2/sasl/sspi_windows.h | 70 + .../src/gopkg.in/mgo.v2/saslimpl.go | 11 + .../src/gopkg.in/mgo.v2/saslstub.go | 11 + .../_workspace/src/gopkg.in/mgo.v2/server.go | 447 ++ .../_workspace/src/gopkg.in/mgo.v2/session.go | 3867 +++++++++++++++++ .../src/gopkg.in/mgo.v2/session_test.go | 3484 +++++++++++++++ .../_workspace/src/gopkg.in/mgo.v2/socket.go | 675 +++ .../_workspace/src/gopkg.in/mgo.v2/stats.go | 147 + .../src/gopkg.in/mgo.v2/suite_test.go | 254 ++ .../src/gopkg.in/mgo.v2/syscall_test.go | 15 + .../gopkg.in/mgo.v2/syscall_windows_test.go | 11 + .../src/gopkg.in/mgo.v2/testdb/client.pem | 44 + .../src/gopkg.in/mgo.v2/testdb/dropall.js | 52 + .../src/gopkg.in/mgo.v2/testdb/init.js | 110 + .../src/gopkg.in/mgo.v2/testdb/server.pem | 33 + .../src/gopkg.in/mgo.v2/testdb/setup.sh | 58 + .../gopkg.in/mgo.v2/testdb/supervisord.conf | 65 + .../src/gopkg.in/mgo.v2/testdb/wait.js | 58 + .../src/gopkg.in/mgo.v2/txn/chaos.go | 68 + .../src/gopkg.in/mgo.v2/txn/debug.go | 109 + .../src/gopkg.in/mgo.v2/txn/dockey_test.go | 205 + .../src/gopkg.in/mgo.v2/txn/flusher.go | 968 +++++ .../src/gopkg.in/mgo.v2/txn/mgo_test.go | 101 + .../src/gopkg.in/mgo.v2/txn/sim_test.go | 389 ++ .../src/gopkg.in/mgo.v2/txn/tarjan.go | 94 + .../src/gopkg.in/mgo.v2/txn/tarjan_test.go | 44 + .../_workspace/src/gopkg.in/mgo.v2/txn/txn.go | 609 +++ .../src/gopkg.in/mgo.v2/txn/txn_test.go | 627 +++ .../src/labix.org/v2/mgo/bson/LICENSE | 25 + .../src/labix.org/v2/mgo/bson/bson.go | 682 +++ .../src/labix.org/v2/mgo/bson/bson_test.go | 1452 +++++++ .../src/labix.org/v2/mgo/bson/decode.go | 795 ++++ .../src/labix.org/v2/mgo/bson/encode.go | 462 ++ 67 files changed, 26444 insertions(+) create mode 100644 Godeps/Godeps.json create mode 100644 Godeps/Readme create mode 100644 Godeps/_workspace/.gitignore create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/LICENSE create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/Makefile create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/README.md create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/auth.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/auth_test.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/bson/LICENSE create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/bson/bson.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/bson/bson_test.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/bson/decode.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/bson/encode.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/bulk.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/bulk_test.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/cluster.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/cluster_test.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/doc.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/export_test.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/gridfs.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/gridfs_test.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/internal/scram/scram.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/internal/scram/scram_test.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/log.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/queue.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/queue_test.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/raceoff.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/raceon.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sasl.c create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sasl.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sasl_windows.c create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sasl_windows.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sasl_windows.h create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sspi_windows.c create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sspi_windows.h create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/saslimpl.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/saslstub.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/server.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/session.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/session_test.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/socket.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/stats.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/suite_test.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/syscall_test.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/syscall_windows_test.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/client.pem create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/dropall.js create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/init.js create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/server.pem create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/setup.sh create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/supervisord.conf create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/wait.js create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/txn/chaos.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/txn/debug.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/txn/dockey_test.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/txn/flusher.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/txn/mgo_test.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/txn/sim_test.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/txn/tarjan.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/txn/tarjan_test.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/txn/txn.go create mode 100644 Godeps/_workspace/src/gopkg.in/mgo.v2/txn/txn_test.go create mode 100644 Godeps/_workspace/src/labix.org/v2/mgo/bson/LICENSE create mode 100644 Godeps/_workspace/src/labix.org/v2/mgo/bson/bson.go create mode 100644 Godeps/_workspace/src/labix.org/v2/mgo/bson/bson_test.go create mode 100644 Godeps/_workspace/src/labix.org/v2/mgo/bson/decode.go create mode 100644 Godeps/_workspace/src/labix.org/v2/mgo/bson/encode.go diff --git a/Godeps/Godeps.json b/Godeps/Godeps.json new file mode 100644 index 0000000..51d626a --- /dev/null +++ b/Godeps/Godeps.json @@ -0,0 +1,15 @@ +{ + "ImportPath": "github.com/gojp/goreportcard", + "GoVersion": "go1.4.1", + "Deps": [ + { + "ImportPath": "gopkg.in/mgo.v2", + "Rev": "8c1ecfe7d8a0b5bc49a809f4c15955e514f77a80" + }, + { + "ImportPath": "labix.org/v2/mgo/bson", + "Comment": "287", + "Rev": "gustavo@niemeyer.net-20140501184651-975yyw9ipld92pji" + } + ] +} diff --git a/Godeps/Readme b/Godeps/Readme new file mode 100644 index 0000000..4cdaa53 --- /dev/null +++ b/Godeps/Readme @@ -0,0 +1,5 @@ +This directory tree is generated automatically by godep. + +Please do not edit. + +See https://github.com/tools/godep for more information. diff --git a/Godeps/_workspace/.gitignore b/Godeps/_workspace/.gitignore new file mode 100644 index 0000000..f037d68 --- /dev/null +++ b/Godeps/_workspace/.gitignore @@ -0,0 +1,2 @@ +/pkg +/bin diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/LICENSE b/Godeps/_workspace/src/gopkg.in/mgo.v2/LICENSE new file mode 100644 index 0000000..770c767 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/LICENSE @@ -0,0 +1,25 @@ +mgo - MongoDB driver for Go + +Copyright (c) 2010-2013 - Gustavo Niemeyer + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/Makefile b/Godeps/_workspace/src/gopkg.in/mgo.v2/Makefile new file mode 100644 index 0000000..51bee73 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/Makefile @@ -0,0 +1,5 @@ +startdb: + @testdb/setup.sh start + +stopdb: + @testdb/setup.sh stop diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/README.md b/Godeps/_workspace/src/gopkg.in/mgo.v2/README.md new file mode 100644 index 0000000..f4e452c --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/README.md @@ -0,0 +1,4 @@ +The MongoDB driver for Go +------------------------- + +Please go to [http://labix.org/mgo](http://labix.org/mgo) for all project details. diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/auth.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/auth.go new file mode 100644 index 0000000..1761d0d --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/auth.go @@ -0,0 +1,467 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +import ( + "crypto/md5" + "crypto/sha1" + "encoding/hex" + "errors" + "fmt" + "sync" + + "gopkg.in/mgo.v2/bson" + "gopkg.in/mgo.v2/internal/scram" +) + +type authCmd struct { + Authenticate int + + Nonce string + User string + Key string +} + +type startSaslCmd struct { + StartSASL int `bson:"startSasl"` +} + +type authResult struct { + ErrMsg string + Ok bool +} + +type getNonceCmd struct { + GetNonce int +} + +type getNonceResult struct { + Nonce string + Err string "$err" + Code int +} + +type logoutCmd struct { + Logout int +} + +type saslCmd struct { + Start int `bson:"saslStart,omitempty"` + Continue int `bson:"saslContinue,omitempty"` + ConversationId int `bson:"conversationId,omitempty"` + Mechanism string `bson:"mechanism,omitempty"` + Payload []byte +} + +type saslResult struct { + Ok bool `bson:"ok"` + NotOk bool `bson:"code"` // Server <= 2.3.2 returns ok=1 & code>0 on errors (WTF?) + Done bool + + ConversationId int `bson:"conversationId"` + Payload []byte + ErrMsg string +} + +type saslStepper interface { + Step(serverData []byte) (clientData []byte, done bool, err error) + Close() +} + +func (socket *mongoSocket) getNonce() (nonce string, err error) { + socket.Lock() + for socket.cachedNonce == "" && socket.dead == nil { + debugf("Socket %p to %s: waiting for nonce", socket, socket.addr) + socket.gotNonce.Wait() + } + if socket.cachedNonce == "mongos" { + socket.Unlock() + return "", errors.New("Can't authenticate with mongos; see http://j.mp/mongos-auth") + } + debugf("Socket %p to %s: got nonce", socket, socket.addr) + nonce, err = socket.cachedNonce, socket.dead + socket.cachedNonce = "" + socket.Unlock() + if err != nil { + nonce = "" + } + return +} + +func (socket *mongoSocket) resetNonce() { + debugf("Socket %p to %s: requesting a new nonce", socket, socket.addr) + op := &queryOp{} + op.query = &getNonceCmd{GetNonce: 1} + op.collection = "admin.$cmd" + op.limit = -1 + op.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) { + if err != nil { + socket.kill(errors.New("getNonce: "+err.Error()), true) + return + } + result := &getNonceResult{} + err = bson.Unmarshal(docData, &result) + if err != nil { + socket.kill(errors.New("Failed to unmarshal nonce: "+err.Error()), true) + return + } + debugf("Socket %p to %s: nonce unmarshalled: %#v", socket, socket.addr, result) + if result.Code == 13390 { + // mongos doesn't yet support auth (see http://j.mp/mongos-auth) + result.Nonce = "mongos" + } else if result.Nonce == "" { + var msg string + if result.Err != "" { + msg = fmt.Sprintf("Got an empty nonce: %s (%d)", result.Err, result.Code) + } else { + msg = "Got an empty nonce" + } + socket.kill(errors.New(msg), true) + return + } + socket.Lock() + if socket.cachedNonce != "" { + socket.Unlock() + panic("resetNonce: nonce already cached") + } + socket.cachedNonce = result.Nonce + socket.gotNonce.Signal() + socket.Unlock() + } + err := socket.Query(op) + if err != nil { + socket.kill(errors.New("resetNonce: "+err.Error()), true) + } +} + +func (socket *mongoSocket) Login(cred Credential) error { + socket.Lock() + if cred.Mechanism == "" && socket.serverInfo.MaxWireVersion >= 3 { + cred.Mechanism = "SCRAM-SHA-1" + } + for _, sockCred := range socket.creds { + if sockCred == cred { + debugf("Socket %p to %s: login: db=%q user=%q (already logged in)", socket, socket.addr, cred.Source, cred.Username) + socket.Unlock() + return nil + } + } + if socket.dropLogout(cred) { + debugf("Socket %p to %s: login: db=%q user=%q (cached)", socket, socket.addr, cred.Source, cred.Username) + socket.creds = append(socket.creds, cred) + socket.Unlock() + return nil + } + socket.Unlock() + + debugf("Socket %p to %s: login: db=%q user=%q", socket, socket.addr, cred.Source, cred.Username) + + var err error + switch cred.Mechanism { + case "", "MONGODB-CR", "MONGO-CR": // Name changed to MONGODB-CR in SERVER-8501. + err = socket.loginClassic(cred) + case "PLAIN": + err = socket.loginPlain(cred) + case "MONGODB-X509": + err = socket.loginX509(cred) + default: + // Try SASL for everything else, if it is available. + err = socket.loginSASL(cred) + } + + if err != nil { + debugf("Socket %p to %s: login error: %s", socket, socket.addr, err) + } else { + debugf("Socket %p to %s: login successful", socket, socket.addr) + } + return err +} + +func (socket *mongoSocket) loginClassic(cred Credential) error { + // Note that this only works properly because this function is + // synchronous, which means the nonce won't get reset while we're + // using it and any other login requests will block waiting for a + // new nonce provided in the defer call below. + nonce, err := socket.getNonce() + if err != nil { + return err + } + defer socket.resetNonce() + + psum := md5.New() + psum.Write([]byte(cred.Username + ":mongo:" + cred.Password)) + + ksum := md5.New() + ksum.Write([]byte(nonce + cred.Username)) + ksum.Write([]byte(hex.EncodeToString(psum.Sum(nil)))) + + key := hex.EncodeToString(ksum.Sum(nil)) + + cmd := authCmd{Authenticate: 1, User: cred.Username, Nonce: nonce, Key: key} + res := authResult{} + return socket.loginRun(cred.Source, &cmd, &res, func() error { + if !res.Ok { + return errors.New(res.ErrMsg) + } + socket.Lock() + socket.dropAuth(cred.Source) + socket.creds = append(socket.creds, cred) + socket.Unlock() + return nil + }) +} + +type authX509Cmd struct { + Authenticate int + User string + Mechanism string +} + +func (socket *mongoSocket) loginX509(cred Credential) error { + cmd := authX509Cmd{Authenticate: 1, User: cred.Username, Mechanism: "MONGODB-X509"} + res := authResult{} + return socket.loginRun(cred.Source, &cmd, &res, func() error { + if !res.Ok { + return errors.New(res.ErrMsg) + } + socket.Lock() + socket.dropAuth(cred.Source) + socket.creds = append(socket.creds, cred) + socket.Unlock() + return nil + }) +} + +func (socket *mongoSocket) loginPlain(cred Credential) error { + cmd := saslCmd{Start: 1, Mechanism: "PLAIN", Payload: []byte("\x00" + cred.Username + "\x00" + cred.Password)} + res := authResult{} + return socket.loginRun(cred.Source, &cmd, &res, func() error { + if !res.Ok { + return errors.New(res.ErrMsg) + } + socket.Lock() + socket.dropAuth(cred.Source) + socket.creds = append(socket.creds, cred) + socket.Unlock() + return nil + }) +} + +func (socket *mongoSocket) loginSASL(cred Credential) error { + var sasl saslStepper + var err error + if cred.Mechanism == "SCRAM-SHA-1" { + // SCRAM is handled without external libraries. + sasl = saslNewScram(cred) + } else if len(cred.ServiceHost) > 0 { + sasl, err = saslNew(cred, cred.ServiceHost) + } else { + sasl, err = saslNew(cred, socket.Server().Addr) + } + if err != nil { + return err + } + defer sasl.Close() + + // The goal of this logic is to carry a locked socket until the + // local SASL step confirms the auth is valid; the socket needs to be + // locked so that concurrent action doesn't leave the socket in an + // auth state that doesn't reflect the operations that took place. + // As a simple case, imagine inverting login=>logout to logout=>login. + // + // The logic below works because the lock func isn't called concurrently. + locked := false + lock := func(b bool) { + if locked != b { + locked = b + if b { + socket.Lock() + } else { + socket.Unlock() + } + } + } + + lock(true) + defer lock(false) + + start := 1 + cmd := saslCmd{} + res := saslResult{} + for { + payload, done, err := sasl.Step(res.Payload) + if err != nil { + return err + } + if done && res.Done { + socket.dropAuth(cred.Source) + socket.creds = append(socket.creds, cred) + break + } + lock(false) + + cmd = saslCmd{ + Start: start, + Continue: 1 - start, + ConversationId: res.ConversationId, + Mechanism: cred.Mechanism, + Payload: payload, + } + start = 0 + err = socket.loginRun(cred.Source, &cmd, &res, func() error { + // See the comment on lock for why this is necessary. + lock(true) + if !res.Ok || res.NotOk { + return fmt.Errorf("server returned error on SASL authentication step: %s", res.ErrMsg) + } + return nil + }) + if err != nil { + return err + } + if done && res.Done { + socket.dropAuth(cred.Source) + socket.creds = append(socket.creds, cred) + break + } + } + + return nil +} + +func saslNewScram(cred Credential) *saslScram { + credsum := md5.New() + credsum.Write([]byte(cred.Username + ":mongo:" + cred.Password)) + client := scram.NewClient(sha1.New, cred.Username, hex.EncodeToString(credsum.Sum(nil))) + return &saslScram{cred: cred, client: client} +} + +type saslScram struct { + cred Credential + client *scram.Client +} + +func (s *saslScram) Close() {} + +func (s *saslScram) Step(serverData []byte) (clientData []byte, done bool, err error) { + more := s.client.Step(serverData) + return s.client.Out(), !more, s.client.Err() +} + +func (socket *mongoSocket) loginRun(db string, query, result interface{}, f func() error) error { + var mutex sync.Mutex + var replyErr error + mutex.Lock() + + op := queryOp{} + op.query = query + op.collection = db + ".$cmd" + op.limit = -1 + op.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) { + defer mutex.Unlock() + + if err != nil { + replyErr = err + return + } + + err = bson.Unmarshal(docData, result) + if err != nil { + replyErr = err + } else { + // Must handle this within the read loop for the socket, so + // that concurrent login requests are properly ordered. + replyErr = f() + } + } + + err := socket.Query(&op) + if err != nil { + return err + } + mutex.Lock() // Wait. + return replyErr +} + +func (socket *mongoSocket) Logout(db string) { + socket.Lock() + cred, found := socket.dropAuth(db) + if found { + debugf("Socket %p to %s: logout: db=%q (flagged)", socket, socket.addr, db) + socket.logout = append(socket.logout, cred) + } + socket.Unlock() +} + +func (socket *mongoSocket) LogoutAll() { + socket.Lock() + if l := len(socket.creds); l > 0 { + debugf("Socket %p to %s: logout all (flagged %d)", socket, socket.addr, l) + socket.logout = append(socket.logout, socket.creds...) + socket.creds = socket.creds[0:0] + } + socket.Unlock() +} + +func (socket *mongoSocket) flushLogout() (ops []interface{}) { + socket.Lock() + if l := len(socket.logout); l > 0 { + debugf("Socket %p to %s: logout all (flushing %d)", socket, socket.addr, l) + for i := 0; i != l; i++ { + op := queryOp{} + op.query = &logoutCmd{1} + op.collection = socket.logout[i].Source + ".$cmd" + op.limit = -1 + ops = append(ops, &op) + } + socket.logout = socket.logout[0:0] + } + socket.Unlock() + return +} + +func (socket *mongoSocket) dropAuth(db string) (cred Credential, found bool) { + for i, sockCred := range socket.creds { + if sockCred.Source == db { + copy(socket.creds[i:], socket.creds[i+1:]) + socket.creds = socket.creds[:len(socket.creds)-1] + return sockCred, true + } + } + return cred, false +} + +func (socket *mongoSocket) dropLogout(cred Credential) (found bool) { + for i, sockCred := range socket.logout { + if sockCred == cred { + copy(socket.logout[i:], socket.logout[i+1:]) + socket.logout = socket.logout[:len(socket.logout)-1] + return true + } + } + return false +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/auth_test.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/auth_test.go new file mode 100644 index 0000000..a9c0b27 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/auth_test.go @@ -0,0 +1,1181 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo_test + +import ( + "crypto/tls" + "flag" + "fmt" + "io/ioutil" + "net" + "net/url" + "os" + "runtime" + "sync" + "time" + + . "gopkg.in/check.v1" + "gopkg.in/mgo.v2" +) + +func (s *S) TestAuthLoginDatabase(c *C) { + // Test both with a normal database and with an authenticated shard. + for _, addr := range []string{"localhost:40002", "localhost:40203"} { + session, err := mgo.Dial(addr) + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|need to login|not authorized .*") + + admindb := session.DB("admin") + + err = admindb.Login("root", "wrong") + c.Assert(err, ErrorMatches, "auth fail(s|ed)|.*Authentication failed.") + + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) + } +} + +func (s *S) TestAuthLoginSession(c *C) { + // Test both with a normal database and with an authenticated shard. + for _, addr := range []string{"localhost:40002", "localhost:40203"} { + session, err := mgo.Dial(addr) + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|need to login|not authorized .*") + + cred := mgo.Credential{ + Username: "root", + Password: "wrong", + } + err = session.Login(&cred) + c.Assert(err, ErrorMatches, "auth fail(s|ed)|.*Authentication failed.") + + cred.Password = "rapadura" + + err = session.Login(&cred) + c.Assert(err, IsNil) + + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) + } +} + +func (s *S) TestAuthLoginLogout(c *C) { + // Test both with a normal database and with an authenticated shard. + for _, addr := range []string{"localhost:40002", "localhost:40203"} { + session, err := mgo.Dial(addr) + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + admindb.Logout() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|need to login|not authorized .*") + + // Must have dropped auth from the session too. + session = session.Copy() + defer session.Close() + + coll = session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|need to login|not authorized .*") + } +} + +func (s *S) TestAuthLoginLogoutAll(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + session.LogoutAll() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|need to login|not authorized .*") + + // Must have dropped auth from the session too. + session = session.Copy() + defer session.Close() + + coll = session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|need to login|not authorized .*") +} + +func (s *S) TestAuthUpsertUserErrors(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + mydb := session.DB("mydb") + + err = mydb.UpsertUser(&mgo.User{}) + c.Assert(err, ErrorMatches, "user has no Username") + + err = mydb.UpsertUser(&mgo.User{Username: "user", Password: "pass", UserSource: "source"}) + c.Assert(err, ErrorMatches, "user has both Password/PasswordHash and UserSource set") + + err = mydb.UpsertUser(&mgo.User{Username: "user", Password: "pass", OtherDBRoles: map[string][]mgo.Role{"db": nil}}) + c.Assert(err, ErrorMatches, "user with OtherDBRoles is only supported in the admin or \\$external databases") +} + +func (s *S) TestAuthUpsertUser(c *C) { + if !s.versionAtLeast(2, 4) { + c.Skip("UpsertUser only works on 2.4+") + } + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + mydb := session.DB("mydb") + + ruser := &mgo.User{ + Username: "myruser", + Password: "mypass", + Roles: []mgo.Role{mgo.RoleRead}, + } + rwuser := &mgo.User{ + Username: "myrwuser", + Password: "mypass", + Roles: []mgo.Role{mgo.RoleReadWrite}, + } + + err = mydb.UpsertUser(ruser) + c.Assert(err, IsNil) + err = mydb.UpsertUser(rwuser) + c.Assert(err, IsNil) + + err = mydb.Login("myruser", "mypass") + c.Assert(err, IsNil) + + admindb.Logout() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + + err = mydb.Login("myrwuser", "mypass") + c.Assert(err, IsNil) + + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) + + myotherdb := session.DB("myotherdb") + + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + // Test UserSource. + rwuserother := &mgo.User{ + Username: "myrwuser", + UserSource: "mydb", + Roles: []mgo.Role{mgo.RoleRead}, + } + + err = myotherdb.UpsertUser(rwuserother) + if s.versionAtLeast(2, 6) { + c.Assert(err, ErrorMatches, `MongoDB 2.6\+ does not support the UserSource setting`) + return + } + c.Assert(err, IsNil) + + admindb.Logout() + + // Test indirection via UserSource: we can't write to it, because + // the roles for myrwuser are different there. + othercoll := myotherdb.C("myothercoll") + err = othercoll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + + // Reading works, though. + err = othercoll.Find(nil).One(nil) + c.Assert(err, Equals, mgo.ErrNotFound) + + // Can't login directly into the database using UserSource, though. + err = myotherdb.Login("myrwuser", "mypass") + c.Assert(err, ErrorMatches, "auth fail(s|ed)|.*Authentication failed.") +} + +func (s *S) TestAuthUpsertUserOtherDBRoles(c *C) { + if !s.versionAtLeast(2, 4) { + c.Skip("UpsertUser only works on 2.4+") + } + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + ruser := &mgo.User{ + Username: "myruser", + Password: "mypass", + OtherDBRoles: map[string][]mgo.Role{"mydb": []mgo.Role{mgo.RoleRead}}, + } + + err = admindb.UpsertUser(ruser) + c.Assert(err, IsNil) + defer admindb.RemoveUser("myruser") + + admindb.Logout() + err = admindb.Login("myruser", "mypass") + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + + err = coll.Find(nil).One(nil) + c.Assert(err, Equals, mgo.ErrNotFound) +} + +func (s *S) TestAuthUpsertUserUpdates(c *C) { + if !s.versionAtLeast(2, 4) { + c.Skip("UpsertUser only works on 2.4+") + } + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + mydb := session.DB("mydb") + + // Insert a user that can read. + user := &mgo.User{ + Username: "myruser", + Password: "mypass", + Roles: []mgo.Role{mgo.RoleRead}, + } + err = mydb.UpsertUser(user) + c.Assert(err, IsNil) + + // Now update the user password. + user = &mgo.User{ + Username: "myruser", + Password: "mynewpass", + } + err = mydb.UpsertUser(user) + c.Assert(err, IsNil) + + // Login with the new user. + usession, err := mgo.Dial("myruser:mynewpass@localhost:40002/mydb") + c.Assert(err, IsNil) + defer usession.Close() + + // Can read, but not write. + err = usession.DB("mydb").C("mycoll").Find(nil).One(nil) + c.Assert(err, Equals, mgo.ErrNotFound) + err = usession.DB("mydb").C("mycoll").Insert(M{"ok": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + + // Update the user role. + user = &mgo.User{ + Username: "myruser", + Roles: []mgo.Role{mgo.RoleReadWrite}, + } + err = mydb.UpsertUser(user) + c.Assert(err, IsNil) + + // Dial again to ensure the password hasn't changed. + usession, err = mgo.Dial("myruser:mynewpass@localhost:40002/mydb") + c.Assert(err, IsNil) + defer usession.Close() + + // Now it can write. + err = usession.DB("mydb").C("mycoll").Insert(M{"ok": 1}) + c.Assert(err, IsNil) +} + +func (s *S) TestAuthAddUser(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + mydb := session.DB("mydb") + err = mydb.AddUser("myruser", "mypass", true) + c.Assert(err, IsNil) + err = mydb.AddUser("mywuser", "mypass", false) + c.Assert(err, IsNil) + + err = mydb.Login("myruser", "mypass") + c.Assert(err, IsNil) + + admindb.Logout() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + + err = mydb.Login("mywuser", "mypass") + c.Assert(err, IsNil) + + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) +} + +func (s *S) TestAuthAddUserReplaces(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + mydb := session.DB("mydb") + err = mydb.AddUser("myuser", "myoldpass", false) + c.Assert(err, IsNil) + err = mydb.AddUser("myuser", "mynewpass", true) + c.Assert(err, IsNil) + + admindb.Logout() + + err = mydb.Login("myuser", "myoldpass") + c.Assert(err, ErrorMatches, "auth fail(s|ed)|.*Authentication failed.") + err = mydb.Login("myuser", "mynewpass") + c.Assert(err, IsNil) + + // ReadOnly flag was changed too. + err = mydb.C("mycoll").Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") +} + +func (s *S) TestAuthRemoveUser(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + mydb := session.DB("mydb") + err = mydb.AddUser("myuser", "mypass", true) + c.Assert(err, IsNil) + err = mydb.RemoveUser("myuser") + c.Assert(err, IsNil) + err = mydb.RemoveUser("myuser") + c.Assert(err, Equals, mgo.ErrNotFound) + + err = mydb.Login("myuser", "mypass") + c.Assert(err, ErrorMatches, "auth fail(s|ed)|.*Authentication failed.") +} + +func (s *S) TestAuthLoginTwiceDoesNothing(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + oldStats := mgo.GetStats() + + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + newStats := mgo.GetStats() + c.Assert(newStats.SentOps, Equals, oldStats.SentOps) +} + +func (s *S) TestAuthLoginLogoutLoginDoesNothing(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + oldStats := mgo.GetStats() + + admindb.Logout() + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + newStats := mgo.GetStats() + c.Assert(newStats.SentOps, Equals, oldStats.SentOps) +} + +func (s *S) TestAuthLoginSwitchUser(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) + + err = admindb.Login("reader", "rapadura") + c.Assert(err, IsNil) + + // Can't write. + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + + // But can read. + result := struct{ N int }{} + err = coll.Find(nil).One(&result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 1) +} + +func (s *S) TestAuthLoginChangePassword(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + mydb := session.DB("mydb") + err = mydb.AddUser("myuser", "myoldpass", false) + c.Assert(err, IsNil) + + err = mydb.Login("myuser", "myoldpass") + c.Assert(err, IsNil) + + err = mydb.AddUser("myuser", "mynewpass", true) + c.Assert(err, IsNil) + + err = mydb.Login("myuser", "mynewpass") + c.Assert(err, IsNil) + + admindb.Logout() + + // The second login must be in effect, which means read-only. + err = mydb.C("mycoll").Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") +} + +func (s *S) TestAuthLoginCachingWithSessionRefresh(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + session.Refresh() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) +} + +func (s *S) TestAuthLoginCachingWithSessionCopy(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + session = session.Copy() + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) +} + +func (s *S) TestAuthLoginCachingWithSessionClone(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + session = session.Clone() + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) +} + +func (s *S) TestAuthLoginCachingWithNewSession(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + session = session.New() + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|need to login|not authorized for .*") +} + +func (s *S) TestAuthLoginCachingAcrossPool(c *C) { + // Logins are cached even when the conenction goes back + // into the pool. + + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + // Add another user to test the logout case at the same time. + mydb := session.DB("mydb") + err = mydb.AddUser("myuser", "mypass", false) + c.Assert(err, IsNil) + + err = mydb.Login("myuser", "mypass") + c.Assert(err, IsNil) + + // Logout root explicitly, to test both cases. + admindb.Logout() + + // Give socket back to pool. + session.Refresh() + + // Brand new session, should use socket from the pool. + other := session.New() + defer other.Close() + + oldStats := mgo.GetStats() + + err = other.DB("admin").Login("root", "rapadura") + c.Assert(err, IsNil) + err = other.DB("mydb").Login("myuser", "mypass") + c.Assert(err, IsNil) + + // Both logins were cached, so no ops. + newStats := mgo.GetStats() + c.Assert(newStats.SentOps, Equals, oldStats.SentOps) + + // And they actually worked. + err = other.DB("mydb").C("mycoll").Insert(M{"n": 1}) + c.Assert(err, IsNil) + + other.DB("admin").Logout() + + err = other.DB("mydb").C("mycoll").Insert(M{"n": 1}) + c.Assert(err, IsNil) +} + +func (s *S) TestAuthLoginCachingAcrossPoolWithLogout(c *C) { + // Now verify that logouts are properly flushed if they + // are not revalidated after leaving the pool. + + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + // Add another user to test the logout case at the same time. + mydb := session.DB("mydb") + err = mydb.AddUser("myuser", "mypass", true) + c.Assert(err, IsNil) + + err = mydb.Login("myuser", "mypass") + c.Assert(err, IsNil) + + // Just some data to query later. + err = session.DB("mydb").C("mycoll").Insert(M{"n": 1}) + c.Assert(err, IsNil) + + // Give socket back to pool. + session.Refresh() + + // Brand new session, should use socket from the pool. + other := session.New() + defer other.Close() + + oldStats := mgo.GetStats() + + err = other.DB("mydb").Login("myuser", "mypass") + c.Assert(err, IsNil) + + // Login was cached, so no ops. + newStats := mgo.GetStats() + c.Assert(newStats.SentOps, Equals, oldStats.SentOps) + + // Can't write, since root has been implicitly logged out + // when the collection went into the pool, and not revalidated. + err = other.DB("mydb").C("mycoll").Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + + // But can read due to the revalidated myuser login. + result := struct{ N int }{} + err = other.DB("mydb").C("mycoll").Find(nil).One(&result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 1) +} + +func (s *S) TestAuthEventual(c *C) { + // Eventual sessions don't keep sockets around, so they are + // an interesting test case. + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + err = session.DB("mydb").C("mycoll").Insert(M{"n": 1}) + c.Assert(err, IsNil) + + var wg sync.WaitGroup + wg.Add(20) + + for i := 0; i != 10; i++ { + go func() { + defer wg.Done() + var result struct{ N int } + err := session.DB("mydb").C("mycoll").Find(nil).One(&result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 1) + }() + } + + for i := 0; i != 10; i++ { + go func() { + defer wg.Done() + err := session.DB("mydb").C("mycoll").Insert(M{"n": 1}) + c.Assert(err, IsNil) + }() + } + + wg.Wait() +} + +func (s *S) TestAuthURL(c *C) { + session, err := mgo.Dial("mongodb://root:rapadura@localhost:40002/") + c.Assert(err, IsNil) + defer session.Close() + + err = session.DB("mydb").C("mycoll").Insert(M{"n": 1}) + c.Assert(err, IsNil) +} + +func (s *S) TestAuthURLWrongCredentials(c *C) { + session, err := mgo.Dial("mongodb://root:wrong@localhost:40002/") + if session != nil { + session.Close() + } + c.Assert(err, ErrorMatches, "auth fail(s|ed)|.*Authentication failed.") + c.Assert(session, IsNil) +} + +func (s *S) TestAuthURLWithNewSession(c *C) { + // When authentication is in the URL, the new session will + // actually carry it on as well, even if logged out explicitly. + session, err := mgo.Dial("mongodb://root:rapadura@localhost:40002/") + c.Assert(err, IsNil) + defer session.Close() + + session.DB("admin").Logout() + + // Do it twice to ensure it passes the needed data on. + session = session.New() + defer session.Close() + session = session.New() + defer session.Close() + + err = session.DB("mydb").C("mycoll").Insert(M{"n": 1}) + c.Assert(err, IsNil) +} + +func (s *S) TestAuthURLWithDatabase(c *C) { + session, err := mgo.Dial("mongodb://root:rapadura@localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + mydb := session.DB("mydb") + err = mydb.AddUser("myruser", "mypass", true) + c.Assert(err, IsNil) + + // Test once with database, and once with source. + for i := 0; i < 2; i++ { + var url string + if i == 0 { + url = "mongodb://myruser:mypass@localhost:40002/mydb" + } else { + url = "mongodb://myruser:mypass@localhost:40002/admin?authSource=mydb" + } + usession, err := mgo.Dial(url) + c.Assert(err, IsNil) + defer usession.Close() + + ucoll := usession.DB("mydb").C("mycoll") + err = ucoll.FindId(0).One(nil) + c.Assert(err, Equals, mgo.ErrNotFound) + err = ucoll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + } +} + +func (s *S) TestDefaultDatabase(c *C) { + tests := []struct{ url, db string }{ + {"mongodb://root:rapadura@localhost:40002", "test"}, + {"mongodb://root:rapadura@localhost:40002/admin", "admin"}, + {"mongodb://localhost:40001", "test"}, + {"mongodb://localhost:40001/", "test"}, + {"mongodb://localhost:40001/mydb", "mydb"}, + } + + for _, test := range tests { + session, err := mgo.Dial(test.url) + c.Assert(err, IsNil) + defer session.Close() + + c.Logf("test: %#v", test) + c.Assert(session.DB("").Name, Equals, test.db) + + scopy := session.Copy() + c.Check(scopy.DB("").Name, Equals, test.db) + scopy.Close() + } +} + +func (s *S) TestAuthDirect(c *C) { + // Direct connections must work to the master and slaves. + for _, port := range []string{"40031", "40032", "40033"} { + url := fmt.Sprintf("mongodb://root:rapadura@localhost:%s/?connect=direct", port) + session, err := mgo.Dial(url) + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Monotonic, true) + + var result struct{} + err = session.DB("mydb").C("mycoll").Find(nil).One(&result) + c.Assert(err, Equals, mgo.ErrNotFound) + } +} + +func (s *S) TestAuthDirectWithLogin(c *C) { + // Direct connections must work to the master and slaves. + for _, port := range []string{"40031", "40032", "40033"} { + url := fmt.Sprintf("mongodb://localhost:%s/?connect=direct", port) + session, err := mgo.Dial(url) + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Monotonic, true) + session.SetSyncTimeout(3 * time.Second) + + err = session.DB("admin").Login("root", "rapadura") + c.Assert(err, IsNil) + + var result struct{} + err = session.DB("mydb").C("mycoll").Find(nil).One(&result) + c.Assert(err, Equals, mgo.ErrNotFound) + } +} + +func (s *S) TestAuthScramSha1Cred(c *C) { + if !s.versionAtLeast(2, 7, 7) { + c.Skip("SCRAM-SHA-1 tests depend on 2.7.7") + } + cred := &mgo.Credential{ + Username: "root", + Password: "rapadura", + Mechanism: "SCRAM-SHA-1", + Source: "admin", + } + host := "localhost:40002" + c.Logf("Connecting to %s...", host) + session, err := mgo.Dial(host) + c.Assert(err, IsNil) + defer session.Close() + + mycoll := session.DB("admin").C("mycoll") + + c.Logf("Connected! Testing the need for authentication...") + err = mycoll.Find(nil).One(nil) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + + c.Logf("Authenticating...") + err = session.Login(cred) + c.Assert(err, IsNil) + c.Logf("Authenticated!") + + c.Logf("Connected! Testing the need for authentication...") + err = mycoll.Find(nil).One(nil) + c.Assert(err, Equals, mgo.ErrNotFound) +} + +func (s *S) TestAuthScramSha1URL(c *C) { + if !s.versionAtLeast(2, 7, 7) { + c.Skip("SCRAM-SHA-1 tests depend on 2.7.7") + } + host := "localhost:40002" + c.Logf("Connecting to %s...", host) + session, err := mgo.Dial(fmt.Sprintf("root:rapadura@%s?authMechanism=SCRAM-SHA-1", host)) + c.Assert(err, IsNil) + defer session.Close() + + mycoll := session.DB("admin").C("mycoll") + + c.Logf("Connected! Testing the need for authentication...") + err = mycoll.Find(nil).One(nil) + c.Assert(err, Equals, mgo.ErrNotFound) +} + +func (s *S) TestAuthX509Cred(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + binfo, err := session.BuildInfo() + c.Assert(err, IsNil) + if binfo.OpenSSLVersion == "" { + c.Skip("server does not support SSL") + } + + clientCertPEM, err := ioutil.ReadFile("testdb/client.pem") + c.Assert(err, IsNil) + + clientCert, err := tls.X509KeyPair(clientCertPEM, clientCertPEM) + c.Assert(err, IsNil) + + tlsConfig := &tls.Config{ + // Isolating tests to client certs, don't care about server validation. + InsecureSkipVerify: true, + Certificates: []tls.Certificate{clientCert}, + } + + var host = "localhost:40003" + c.Logf("Connecting to %s...", host) + session, err = mgo.DialWithInfo(&mgo.DialInfo{ + Addrs: []string{host}, + DialServer: func(addr *mgo.ServerAddr) (net.Conn, error) { + return tls.Dial("tcp", addr.String(), tlsConfig) + }, + }) + c.Assert(err, IsNil) + defer session.Close() + + err = session.Login(&mgo.Credential{Username: "root", Password: "rapadura"}) + c.Assert(err, IsNil) + + // This needs to be kept in sync with client.pem + x509Subject := "CN=localhost,OU=Client,O=MGO,L=MGO,ST=MGO,C=GO" + + externalDB := session.DB("$external") + var x509User mgo.User = mgo.User{ + Username: x509Subject, + OtherDBRoles: map[string][]mgo.Role{"admin": []mgo.Role{mgo.RoleRoot}}, + } + err = externalDB.UpsertUser(&x509User) + c.Assert(err, IsNil) + + session.LogoutAll() + + c.Logf("Connected! Ensuring authentication is required...") + names, err := session.DatabaseNames() + c.Assert(err, ErrorMatches, "not authorized .*") + + cred := &mgo.Credential{ + Username: x509Subject, + Mechanism: "MONGODB-X509", + Source: "$external", + } + + c.Logf("Authenticating...") + err = session.Login(cred) + c.Assert(err, IsNil) + c.Logf("Authenticated!") + + names, err = session.DatabaseNames() + c.Assert(err, IsNil) + c.Assert(len(names) > 0, Equals, true) +} + +var ( + plainFlag = flag.String("plain", "", "Host to test PLAIN authentication against (depends on custom environment)") + plainUser = "einstein" + plainPass = "password" +) + +func (s *S) TestAuthPlainCred(c *C) { + if *plainFlag == "" { + c.Skip("no -plain") + } + cred := &mgo.Credential{ + Username: plainUser, + Password: plainPass, + Source: "$external", + Mechanism: "PLAIN", + } + c.Logf("Connecting to %s...", *plainFlag) + session, err := mgo.Dial(*plainFlag) + c.Assert(err, IsNil) + defer session.Close() + + records := session.DB("records").C("records") + + c.Logf("Connected! Testing the need for authentication...") + err = records.Find(nil).One(nil) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + + c.Logf("Authenticating...") + err = session.Login(cred) + c.Assert(err, IsNil) + c.Logf("Authenticated!") + + c.Logf("Connected! Testing the need for authentication...") + err = records.Find(nil).One(nil) + c.Assert(err, Equals, mgo.ErrNotFound) +} + +func (s *S) TestAuthPlainURL(c *C) { + if *plainFlag == "" { + c.Skip("no -plain") + } + c.Logf("Connecting to %s...", *plainFlag) + session, err := mgo.Dial(fmt.Sprintf("%s:%s@%s?authMechanism=PLAIN", url.QueryEscape(plainUser), url.QueryEscape(plainPass), *plainFlag)) + c.Assert(err, IsNil) + defer session.Close() + + c.Logf("Connected! Testing the need for authentication...") + err = session.DB("records").C("records").Find(nil).One(nil) + c.Assert(err, Equals, mgo.ErrNotFound) +} + +var ( + kerberosFlag = flag.Bool("kerberos", false, "Test Kerberos authentication (depends on custom environment)") + kerberosHost = "ldaptest.10gen.cc" + kerberosUser = "drivers@LDAPTEST.10GEN.CC" + + winKerberosPasswordEnv = "MGO_KERBEROS_PASSWORD" +) + + +// Kerberos has its own suite because it talks to a remote server +// that is prepared to authenticate against a kerberos deployment. +type KerberosSuite struct{} + +var _ = Suite(&KerberosSuite{}) + +func (kerberosSuite *KerberosSuite) SetUpSuite(c *C) { + mgo.SetDebug(true) + mgo.SetStats(true) +} + +func (kerberosSuite *KerberosSuite) TearDownSuite(c *C) { + mgo.SetDebug(false) + mgo.SetStats(false) +} + +func (kerberosSuite *KerberosSuite) SetUpTest(c *C) { + mgo.SetLogger((*cLogger)(c)) + mgo.ResetStats() +} + +func (kerberosSuite *KerberosSuite) TearDownTest(c *C) { + mgo.SetLogger(nil) +} + +func (kerberosSuite *KerberosSuite) TestAuthKerberosCred(c *C) { + if !*kerberosFlag { + c.Skip("no -kerberos") + } + cred := &mgo.Credential{ + Username: kerberosUser, + Mechanism: "GSSAPI", + } + windowsAppendPasswordToCredential(cred) + c.Logf("Connecting to %s...", kerberosHost) + session, err := mgo.Dial(kerberosHost) + c.Assert(err, IsNil) + defer session.Close() + + c.Logf("Connected! Testing the need for authentication...") + n, err := session.DB("kerberos").C("test").Find(M{}).Count() + c.Assert(err, ErrorMatches, ".*authorized.*") + + c.Logf("Authenticating...") + err = session.Login(cred) + c.Assert(err, IsNil) + c.Logf("Authenticated!") + + n, err = session.DB("kerberos").C("test").Find(M{}).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 1) +} + +func (kerberosSuite *KerberosSuite) TestAuthKerberosURL(c *C) { + if !*kerberosFlag { + c.Skip("no -kerberos") + } + c.Logf("Connecting to %s...", kerberosHost) + connectUri := url.QueryEscape(kerberosUser) + "@" + kerberosHost + "?authMechanism=GSSAPI" + if runtime.GOOS == "windows" { + connectUri = url.QueryEscape(kerberosUser) + ":" + url.QueryEscape(getWindowsKerberosPassword()) + "@" + kerberosHost + "?authMechanism=GSSAPI" + } + session, err := mgo.Dial(connectUri) + c.Assert(err, IsNil) + defer session.Close() + n, err := session.DB("kerberos").C("test").Find(M{}).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 1) +} + +func (kerberosSuite *KerberosSuite) TestAuthKerberosServiceName(c *C) { + if !*kerberosFlag { + c.Skip("no -kerberos") + } + + wrongServiceName := "wrong" + rightServiceName := "mongodb" + + cred := &mgo.Credential{ + Username: kerberosUser, + Mechanism: "GSSAPI", + Service: wrongServiceName, + } + windowsAppendPasswordToCredential(cred) + + c.Logf("Connecting to %s...", kerberosHost) + session, err := mgo.Dial(kerberosHost) + c.Assert(err, IsNil) + defer session.Close() + + c.Logf("Authenticating with incorrect service name...") + err = session.Login(cred) + c.Assert(err, ErrorMatches, ".*@LDAPTEST.10GEN.CC not found.*") + + cred.Service = rightServiceName + c.Logf("Authenticating with correct service name...") + err = session.Login(cred) + c.Assert(err, IsNil) + c.Logf("Authenticated!") + + n, err := session.DB("kerberos").C("test").Find(M{}).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 1) +} + +func (kerberosSuite *KerberosSuite) TestAuthKerberosServiceHost(c *C) { + if !*kerberosFlag { + c.Skip("no -kerberos") + } + + wrongServiceHost := "eggs.bacon.tk" + rightServiceHost := kerberosHost + + cred := &mgo.Credential{ + Username: kerberosUser, + Mechanism: "GSSAPI", + ServiceHost: wrongServiceHost, + } + windowsAppendPasswordToCredential(cred) + + c.Logf("Connecting to %s...", kerberosHost) + session, err := mgo.Dial(kerberosHost) + c.Assert(err, IsNil) + defer session.Close() + + c.Logf("Authenticating with incorrect service host...") + err = session.Login(cred) + c.Assert(err, ErrorMatches, ".*@LDAPTEST.10GEN.CC not found.*") + + cred.ServiceHost = rightServiceHost + c.Logf("Authenticating with correct service host...") + err = session.Login(cred) + c.Assert(err, IsNil) + c.Logf("Authenticated!") + + n, err := session.DB("kerberos").C("test").Find(M{}).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 1) +} + +// No kinit on SSPI-style Kerberos, so we need to provide a password. In order +// to avoid inlining password, require it to be set as an environment variable, +// for instance: `SET MGO_KERBEROS_PASSWORD=this_isnt_the_password` +func getWindowsKerberosPassword() string { + pw := os.Getenv(winKerberosPasswordEnv) + if pw == "" { + panic(fmt.Sprintf("Need to set %v environment variable to run Kerberos tests on Windows", winKerberosPasswordEnv)) + } + return pw +} + +func windowsAppendPasswordToCredential(cred *mgo.Credential) { + if runtime.GOOS == "windows" { + cred.Password = getWindowsKerberosPassword() + } +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/bson/LICENSE b/Godeps/_workspace/src/gopkg.in/mgo.v2/bson/LICENSE new file mode 100644 index 0000000..8903260 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/bson/LICENSE @@ -0,0 +1,25 @@ +BSON library for Go + +Copyright (c) 2010-2012 - Gustavo Niemeyer + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/bson/bson.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/bson/bson.go new file mode 100644 index 0000000..68e932f --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/bson/bson.go @@ -0,0 +1,698 @@ +// BSON library for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Package bson is an implementation of the BSON specification for Go: +// +// http://bsonspec.org +// +// It was created as part of the mgo MongoDB driver for Go, but is standalone +// and may be used on its own without the driver. +package bson + +import ( + "crypto/md5" + "crypto/rand" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io" + "os" + "reflect" + "runtime" + "strings" + "sync" + "sync/atomic" + "time" +) + +// -------------------------------------------------------------------------- +// The public API. + +// A value implementing the bson.Getter interface will have its GetBSON +// method called when the given value has to be marshalled, and the result +// of this method will be marshaled in place of the actual object. +// +// If GetBSON returns return a non-nil error, the marshalling procedure +// will stop and error out with the provided value. +type Getter interface { + GetBSON() (interface{}, error) +} + +// A value implementing the bson.Setter interface will receive the BSON +// value via the SetBSON method during unmarshaling, and the object +// itself will not be changed as usual. +// +// If setting the value works, the method should return nil or alternatively +// bson.SetZero to set the respective field to its zero value (nil for +// pointer types). If SetBSON returns a value of type bson.TypeError, the +// BSON value will be omitted from a map or slice being decoded and the +// unmarshalling will continue. If it returns any other non-nil error, the +// unmarshalling procedure will stop and error out with the provided value. +// +// This interface is generally useful in pointer receivers, since the method +// will want to change the receiver. A type field that implements the Setter +// interface doesn't have to be a pointer, though. +// +// Unlike the usual behavior, unmarshalling onto a value that implements a +// Setter interface will NOT reset the value to its zero state. This allows +// the value to decide by itself how to be unmarshalled. +// +// For example: +// +// type MyString string +// +// func (s *MyString) SetBSON(raw bson.Raw) error { +// return raw.Unmarshal(s) +// } +// +type Setter interface { + SetBSON(raw Raw) error +} + +// SetZero may be returned from a SetBSON method to have the value set to +// its respective zero value. When used in pointer values, this will set the +// field to nil rather than to the pre-allocated value. +var SetZero = errors.New("set to zero") + +// M is a convenient alias for a map[string]interface{} map, useful for +// dealing with BSON in a native way. For instance: +// +// bson.M{"a": 1, "b": true} +// +// There's no special handling for this type in addition to what's done anyway +// for an equivalent map type. Elements in the map will be dumped in an +// undefined ordered. See also the bson.D type for an ordered alternative. +type M map[string]interface{} + +// D represents a BSON document containing ordered elements. For example: +// +// bson.D{{"a", 1}, {"b", true}} +// +// In some situations, such as when creating indexes for MongoDB, the order in +// which the elements are defined is important. If the order is not important, +// using a map is generally more comfortable. See bson.M and bson.RawD. +type D []DocElem + +// DocElem is an element of the bson.D document representation. +type DocElem struct { + Name string + Value interface{} +} + +// Map returns a map out of the ordered element name/value pairs in d. +func (d D) Map() (m M) { + m = make(M, len(d)) + for _, item := range d { + m[item.Name] = item.Value + } + return m +} + +// The Raw type represents raw unprocessed BSON documents and elements. +// Kind is the kind of element as defined per the BSON specification, and +// Data is the raw unprocessed data for the respective element. +// Using this type it is possible to unmarshal or marshal values partially. +// +// Relevant documentation: +// +// http://bsonspec.org/#/specification +// +type Raw struct { + Kind byte + Data []byte +} + +// RawD represents a BSON document containing raw unprocessed elements. +// This low-level representation may be useful when lazily processing +// documents of uncertain content, or when manipulating the raw content +// documents in general. +type RawD []RawDocElem + +// See the RawD type. +type RawDocElem struct { + Name string + Value Raw +} + +// ObjectId is a unique ID identifying a BSON value. It must be exactly 12 bytes +// long. MongoDB objects by default have such a property set in their "_id" +// property. +// +// http://www.mongodb.org/display/DOCS/Object+IDs +type ObjectId string + +// ObjectIdHex returns an ObjectId from the provided hex representation. +// Calling this function with an invalid hex representation will +// cause a runtime panic. See the IsObjectIdHex function. +func ObjectIdHex(s string) ObjectId { + d, err := hex.DecodeString(s) + if err != nil || len(d) != 12 { + panic(fmt.Sprintf("Invalid input to ObjectIdHex: %q", s)) + } + return ObjectId(d) +} + +// IsObjectIdHex returns whether s is a valid hex representation of +// an ObjectId. See the ObjectIdHex function. +func IsObjectIdHex(s string) bool { + if len(s) != 24 { + return false + } + _, err := hex.DecodeString(s) + return err == nil +} + +// objectIdCounter is atomically incremented when generating a new ObjectId +// using NewObjectId() function. It's used as a counter part of an id. +var objectIdCounter uint32 = 0 + +// machineId stores machine id generated once and used in subsequent calls +// to NewObjectId function. +var machineId = readMachineId() + +// readMachineId generates machine id and puts it into the machineId global +// variable. If this function fails to get the hostname, it will cause +// a runtime error. +func readMachineId() []byte { + var sum [3]byte + id := sum[:] + hostname, err1 := os.Hostname() + if err1 != nil { + _, err2 := io.ReadFull(rand.Reader, id) + if err2 != nil { + panic(fmt.Errorf("cannot get hostname: %v; %v", err1, err2)) + } + return id + } + hw := md5.New() + hw.Write([]byte(hostname)) + copy(id, hw.Sum(nil)) + return id +} + +// NewObjectId returns a new unique ObjectId. +func NewObjectId() ObjectId { + var b [12]byte + // Timestamp, 4 bytes, big endian + binary.BigEndian.PutUint32(b[:], uint32(time.Now().Unix())) + // Machine, first 3 bytes of md5(hostname) + b[4] = machineId[0] + b[5] = machineId[1] + b[6] = machineId[2] + // Pid, 2 bytes, specs don't specify endianness, but we use big endian. + pid := os.Getpid() + b[7] = byte(pid >> 8) + b[8] = byte(pid) + // Increment, 3 bytes, big endian + i := atomic.AddUint32(&objectIdCounter, 1) + b[9] = byte(i >> 16) + b[10] = byte(i >> 8) + b[11] = byte(i) + return ObjectId(b[:]) +} + +// NewObjectIdWithTime returns a dummy ObjectId with the timestamp part filled +// with the provided number of seconds from epoch UTC, and all other parts +// filled with zeroes. It's not safe to insert a document with an id generated +// by this method, it is useful only for queries to find documents with ids +// generated before or after the specified timestamp. +func NewObjectIdWithTime(t time.Time) ObjectId { + var b [12]byte + binary.BigEndian.PutUint32(b[:4], uint32(t.Unix())) + return ObjectId(string(b[:])) +} + +// String returns a hex string representation of the id. +// Example: ObjectIdHex("4d88e15b60f486e428412dc9"). +func (id ObjectId) String() string { + return fmt.Sprintf(`ObjectIdHex("%x")`, string(id)) +} + +// Hex returns a hex representation of the ObjectId. +func (id ObjectId) Hex() string { + return hex.EncodeToString([]byte(id)) +} + +// MarshalJSON turns a bson.ObjectId into a json.Marshaller. +func (id ObjectId) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf(`"%x"`, string(id))), nil +} + +// UnmarshalJSON turns *bson.ObjectId into a json.Unmarshaller. +func (id *ObjectId) UnmarshalJSON(data []byte) error { + if len(data) != 26 || data[0] != '"' || data[25] != '"' { + return errors.New(fmt.Sprintf("Invalid ObjectId in JSON: %s", string(data))) + } + var buf [12]byte + _, err := hex.Decode(buf[:], data[1:25]) + if err != nil { + return errors.New(fmt.Sprintf("Invalid ObjectId in JSON: %s (%s)", string(data), err)) + } + *id = ObjectId(string(buf[:])) + return nil +} + +// Valid returns true if id is valid. A valid id must contain exactly 12 bytes. +func (id ObjectId) Valid() bool { + return len(id) == 12 +} + +// byteSlice returns byte slice of id from start to end. +// Calling this function with an invalid id will cause a runtime panic. +func (id ObjectId) byteSlice(start, end int) []byte { + if len(id) != 12 { + panic(fmt.Sprintf("Invalid ObjectId: %q", string(id))) + } + return []byte(string(id)[start:end]) +} + +// Time returns the timestamp part of the id. +// It's a runtime error to call this method with an invalid id. +func (id ObjectId) Time() time.Time { + // First 4 bytes of ObjectId is 32-bit big-endian seconds from epoch. + secs := int64(binary.BigEndian.Uint32(id.byteSlice(0, 4))) + return time.Unix(secs, 0) +} + +// Machine returns the 3-byte machine id part of the id. +// It's a runtime error to call this method with an invalid id. +func (id ObjectId) Machine() []byte { + return id.byteSlice(4, 7) +} + +// Pid returns the process id part of the id. +// It's a runtime error to call this method with an invalid id. +func (id ObjectId) Pid() uint16 { + return binary.BigEndian.Uint16(id.byteSlice(7, 9)) +} + +// Counter returns the incrementing value part of the id. +// It's a runtime error to call this method with an invalid id. +func (id ObjectId) Counter() int32 { + b := id.byteSlice(9, 12) + // Counter is stored as big-endian 3-byte value + return int32(uint32(b[0])<<16 | uint32(b[1])<<8 | uint32(b[2])) +} + +// The Symbol type is similar to a string and is used in languages with a +// distinct symbol type. +type Symbol string + +// Now returns the current time with millisecond precision. MongoDB stores +// timestamps with the same precision, so a Time returned from this method +// will not change after a roundtrip to the database. That's the only reason +// why this function exists. Using the time.Now function also works fine +// otherwise. +func Now() time.Time { + return time.Unix(0, time.Now().UnixNano()/1e6*1e6) +} + +// MongoTimestamp is a special internal type used by MongoDB that for some +// strange reason has its own datatype defined in BSON. +type MongoTimestamp int64 + +type orderKey int64 + +// MaxKey is a special value that compares higher than all other possible BSON +// values in a MongoDB database. +var MaxKey = orderKey(1<<63 - 1) + +// MinKey is a special value that compares lower than all other possible BSON +// values in a MongoDB database. +var MinKey = orderKey(-1 << 63) + +type undefined struct{} + +// Undefined represents the undefined BSON value. +var Undefined undefined + +// Binary is a representation for non-standard binary values. Any kind should +// work, but the following are known as of this writing: +// +// 0x00 - Generic. This is decoded as []byte(data), not Binary{0x00, data}. +// 0x01 - Function (!?) +// 0x02 - Obsolete generic. +// 0x03 - UUID +// 0x05 - MD5 +// 0x80 - User defined. +// +type Binary struct { + Kind byte + Data []byte +} + +// RegEx represents a regular expression. The Options field may contain +// individual characters defining the way in which the pattern should be +// applied, and must be sorted. Valid options as of this writing are 'i' for +// case insensitive matching, 'm' for multi-line matching, 'x' for verbose +// mode, 'l' to make \w, \W, and similar be locale-dependent, 's' for dot-all +// mode (a '.' matches everything), and 'u' to make \w, \W, and similar match +// unicode. The value of the Options parameter is not verified before being +// marshaled into the BSON format. +type RegEx struct { + Pattern string + Options string +} + +// JavaScript is a type that holds JavaScript code. If Scope is non-nil, it +// will be marshaled as a mapping from identifiers to values that may be +// used when evaluating the provided Code. +type JavaScript struct { + Code string + Scope interface{} +} + +// DBPointer refers to a document id in a namespace. +// +// This type is deprecated in the BSON specification and should not be used +// except for backwards compatibility with ancient applications. +type DBPointer struct { + Namespace string + Id ObjectId +} + +const initialBufferSize = 64 + +func handleErr(err *error) { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } else if _, ok := r.(externalPanic); ok { + panic(r) + } else if s, ok := r.(string); ok { + *err = errors.New(s) + } else if e, ok := r.(error); ok { + *err = e + } else { + panic(r) + } + } +} + +// Marshal serializes the in value, which may be a map or a struct value. +// In the case of struct values, only exported fields will be serialized. +// The lowercased field name is used as the key for each exported field, +// but this behavior may be changed using the respective field tag. +// The tag may also contain flags to tweak the marshalling behavior for +// the field. The tag formats accepted are: +// +// "[][,[,]]" +// +// `(...) bson:"[][,[,]]" (...)` +// +// The following flags are currently supported: +// +// omitempty Only include the field if it's not set to the zero +// value for the type or to empty slices or maps. +// +// minsize Marshal an int64 value as an int32, if that's feasible +// while preserving the numeric value. +// +// inline Inline the field, which must be a struct or a map, +// causing all of its fields or keys to be processed as if +// they were part of the outer struct. For maps, keys must +// not conflict with the bson keys of other struct fields. +// +// Some examples: +// +// type T struct { +// A bool +// B int "myb" +// C string "myc,omitempty" +// D string `bson:",omitempty" json:"jsonkey"` +// E int64 ",minsize" +// F int64 "myf,omitempty,minsize" +// } +// +func Marshal(in interface{}) (out []byte, err error) { + defer handleErr(&err) + e := &encoder{make([]byte, 0, initialBufferSize)} + e.addDoc(reflect.ValueOf(in)) + return e.out, nil +} + +// Unmarshal deserializes data from in into the out value. The out value +// must be a map, a pointer to a struct, or a pointer to a bson.D value. +// The lowercased field name is used as the key for each exported field, +// but this behavior may be changed using the respective field tag. +// The tag may also contain flags to tweak the marshalling behavior for +// the field. The tag formats accepted are: +// +// "[][,[,]]" +// +// `(...) bson:"[][,[,]]" (...)` +// +// The following flags are currently supported during unmarshal (see the +// Marshal method for other flags): +// +// inline Inline the field, which must be a struct or a map. +// Inlined structs are handled as if its fields were part +// of the outer struct. An inlined map causes keys that do +// not match any other struct field to be inserted in the +// map rather than being discarded as usual. +// +// The target field or element types of out may not necessarily match +// the BSON values of the provided data. The following conversions are +// made automatically: +// +// - Numeric types are converted if at least the integer part of the +// value would be preserved correctly +// - Bools are converted to numeric types as 1 or 0 +// - Numeric types are converted to bools as true if not 0 or false otherwise +// - Binary and string BSON data is converted to a string, array or byte slice +// +// If the value would not fit the type and cannot be converted, it's +// silently skipped. +// +// Pointer values are initialized when necessary. +func Unmarshal(in []byte, out interface{}) (err error) { + if raw, ok := out.(*Raw); ok { + raw.Kind = 3 + raw.Data = in + return nil + } + defer handleErr(&err) + v := reflect.ValueOf(out) + switch v.Kind() { + case reflect.Ptr: + fallthrough + case reflect.Map: + d := newDecoder(in) + d.readDocTo(v) + case reflect.Struct: + return errors.New("Unmarshal can't deal with struct values. Use a pointer.") + default: + return errors.New("Unmarshal needs a map or a pointer to a struct.") + } + return nil +} + +// Unmarshal deserializes raw into the out value. If the out value type +// is not compatible with raw, a *bson.TypeError is returned. +// +// See the Unmarshal function documentation for more details on the +// unmarshalling process. +func (raw Raw) Unmarshal(out interface{}) (err error) { + defer handleErr(&err) + v := reflect.ValueOf(out) + switch v.Kind() { + case reflect.Ptr: + v = v.Elem() + fallthrough + case reflect.Map: + d := newDecoder(raw.Data) + good := d.readElemTo(v, raw.Kind) + if !good { + return &TypeError{v.Type(), raw.Kind} + } + case reflect.Struct: + return errors.New("Raw Unmarshal can't deal with struct values. Use a pointer.") + default: + return errors.New("Raw Unmarshal needs a map or a valid pointer.") + } + return nil +} + +type TypeError struct { + Type reflect.Type + Kind byte +} + +func (e *TypeError) Error() string { + return fmt.Sprintf("BSON kind 0x%02x isn't compatible with type %s", e.Kind, e.Type.String()) +} + +// -------------------------------------------------------------------------- +// Maintain a mapping of keys to structure field indexes + +type structInfo struct { + FieldsMap map[string]fieldInfo + FieldsList []fieldInfo + InlineMap int + Zero reflect.Value +} + +type fieldInfo struct { + Key string + Num int + OmitEmpty bool + MinSize bool + Inline []int +} + +var structMap = make(map[reflect.Type]*structInfo) +var structMapMutex sync.RWMutex + +type externalPanic string + +func (e externalPanic) String() string { + return string(e) +} + +func getStructInfo(st reflect.Type) (*structInfo, error) { + structMapMutex.RLock() + sinfo, found := structMap[st] + structMapMutex.RUnlock() + if found { + return sinfo, nil + } + n := st.NumField() + fieldsMap := make(map[string]fieldInfo) + fieldsList := make([]fieldInfo, 0, n) + inlineMap := -1 + for i := 0; i != n; i++ { + field := st.Field(i) + if field.PkgPath != "" { + continue // Private field + } + + info := fieldInfo{Num: i} + + tag := field.Tag.Get("bson") + if tag == "" && strings.Index(string(field.Tag), ":") < 0 { + tag = string(field.Tag) + } + if tag == "-" { + continue + } + + // XXX Drop this after a few releases. + if s := strings.Index(tag, "/"); s >= 0 { + recommend := tag[:s] + for _, c := range tag[s+1:] { + switch c { + case 'c': + recommend += ",omitempty" + case 's': + recommend += ",minsize" + default: + msg := fmt.Sprintf("Unsupported flag %q in tag %q of type %s", string([]byte{uint8(c)}), tag, st) + panic(externalPanic(msg)) + } + } + msg := fmt.Sprintf("Replace tag %q in field %s of type %s by %q", tag, field.Name, st, recommend) + panic(externalPanic(msg)) + } + + inline := false + fields := strings.Split(tag, ",") + if len(fields) > 1 { + for _, flag := range fields[1:] { + switch flag { + case "omitempty": + info.OmitEmpty = true + case "minsize": + info.MinSize = true + case "inline": + inline = true + default: + msg := fmt.Sprintf("Unsupported flag %q in tag %q of type %s", flag, tag, st) + panic(externalPanic(msg)) + } + } + tag = fields[0] + } + + if inline { + switch field.Type.Kind() { + case reflect.Map: + if inlineMap >= 0 { + return nil, errors.New("Multiple ,inline maps in struct " + st.String()) + } + if field.Type.Key() != reflect.TypeOf("") { + return nil, errors.New("Option ,inline needs a map with string keys in struct " + st.String()) + } + inlineMap = info.Num + case reflect.Struct: + sinfo, err := getStructInfo(field.Type) + if err != nil { + return nil, err + } + for _, finfo := range sinfo.FieldsList { + if _, found := fieldsMap[finfo.Key]; found { + msg := "Duplicated key '" + finfo.Key + "' in struct " + st.String() + return nil, errors.New(msg) + } + if finfo.Inline == nil { + finfo.Inline = []int{i, finfo.Num} + } else { + finfo.Inline = append([]int{i}, finfo.Inline...) + } + fieldsMap[finfo.Key] = finfo + fieldsList = append(fieldsList, finfo) + } + default: + panic("Option ,inline needs a struct value or map field") + } + continue + } + + if tag != "" { + info.Key = tag + } else { + info.Key = strings.ToLower(field.Name) + } + + if _, found = fieldsMap[info.Key]; found { + msg := "Duplicated key '" + info.Key + "' in struct " + st.String() + return nil, errors.New(msg) + } + + fieldsList = append(fieldsList, info) + fieldsMap[info.Key] = info + } + sinfo = &structInfo{ + fieldsMap, + fieldsList, + inlineMap, + reflect.New(st).Elem(), + } + structMapMutex.Lock() + structMap[st] = sinfo + structMapMutex.Unlock() + return sinfo, nil +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/bson/bson_test.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/bson/bson_test.go new file mode 100644 index 0000000..0606c49 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/bson/bson_test.go @@ -0,0 +1,1543 @@ +// BSON library for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// gobson - BSON library for Go. + +package bson_test + +import ( + "encoding/binary" + "encoding/json" + "errors" + "net/url" + "reflect" + "testing" + "time" + + . "gopkg.in/check.v1" + "gopkg.in/mgo.v2/bson" +) + +func TestAll(t *testing.T) { + TestingT(t) +} + +type S struct{} + +var _ = Suite(&S{}) + +// Wrap up the document elements contained in data, prepending the int32 +// length of the data, and appending the '\x00' value closing the document. +func wrapInDoc(data string) string { + result := make([]byte, len(data)+5) + binary.LittleEndian.PutUint32(result, uint32(len(result))) + copy(result[4:], []byte(data)) + return string(result) +} + +func makeZeroDoc(value interface{}) (zero interface{}) { + v := reflect.ValueOf(value) + t := v.Type() + switch t.Kind() { + case reflect.Map: + mv := reflect.MakeMap(t) + zero = mv.Interface() + case reflect.Ptr: + pv := reflect.New(v.Type().Elem()) + zero = pv.Interface() + case reflect.Slice, reflect.Int: + zero = reflect.New(t).Interface() + default: + panic("unsupported doc type") + } + return zero +} + +func testUnmarshal(c *C, data string, obj interface{}) { + zero := makeZeroDoc(obj) + err := bson.Unmarshal([]byte(data), zero) + c.Assert(err, IsNil) + c.Assert(zero, DeepEquals, obj) +} + +type testItemType struct { + obj interface{} + data string +} + +// -------------------------------------------------------------------------- +// Samples from bsonspec.org: + +var sampleItems = []testItemType{ + {bson.M{"hello": "world"}, + "\x16\x00\x00\x00\x02hello\x00\x06\x00\x00\x00world\x00\x00"}, + {bson.M{"BSON": []interface{}{"awesome", float64(5.05), 1986}}, + "1\x00\x00\x00\x04BSON\x00&\x00\x00\x00\x020\x00\x08\x00\x00\x00" + + "awesome\x00\x011\x00333333\x14@\x102\x00\xc2\x07\x00\x00\x00\x00"}, +} + +func (s *S) TestMarshalSampleItems(c *C) { + for i, item := range sampleItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, item.data, Commentf("Failed on item %d", i)) + } +} + +func (s *S) TestUnmarshalSampleItems(c *C) { + for i, item := range sampleItems { + value := bson.M{} + err := bson.Unmarshal([]byte(item.data), value) + c.Assert(err, IsNil) + c.Assert(value, DeepEquals, item.obj, Commentf("Failed on item %d", i)) + } +} + +// -------------------------------------------------------------------------- +// Every type, ordered by the type flag. These are not wrapped with the +// length and last \x00 from the document. wrapInDoc() computes them. +// Note that all of them should be supported as two-way conversions. + +var allItems = []testItemType{ + {bson.M{}, + ""}, + {bson.M{"_": float64(5.05)}, + "\x01_\x00333333\x14@"}, + {bson.M{"_": "yo"}, + "\x02_\x00\x03\x00\x00\x00yo\x00"}, + {bson.M{"_": bson.M{"a": true}}, + "\x03_\x00\x09\x00\x00\x00\x08a\x00\x01\x00"}, + {bson.M{"_": []interface{}{true, false}}, + "\x04_\x00\r\x00\x00\x00\x080\x00\x01\x081\x00\x00\x00"}, + {bson.M{"_": []byte("yo")}, + "\x05_\x00\x02\x00\x00\x00\x00yo"}, + {bson.M{"_": bson.Binary{0x80, []byte("udef")}}, + "\x05_\x00\x04\x00\x00\x00\x80udef"}, + {bson.M{"_": bson.Undefined}, // Obsolete, but still seen in the wild. + "\x06_\x00"}, + {bson.M{"_": bson.ObjectId("0123456789ab")}, + "\x07_\x000123456789ab"}, + {bson.M{"_": bson.DBPointer{"testnamespace", bson.ObjectId("0123456789ab")}}, + "\x0C_\x00\x0e\x00\x00\x00testnamespace\x000123456789ab"}, + {bson.M{"_": false}, + "\x08_\x00\x00"}, + {bson.M{"_": true}, + "\x08_\x00\x01"}, + {bson.M{"_": time.Unix(0, 258e6)}, // Note the NS <=> MS conversion. + "\x09_\x00\x02\x01\x00\x00\x00\x00\x00\x00"}, + {bson.M{"_": nil}, + "\x0A_\x00"}, + {bson.M{"_": bson.RegEx{"ab", "cd"}}, + "\x0B_\x00ab\x00cd\x00"}, + {bson.M{"_": bson.JavaScript{"code", nil}}, + "\x0D_\x00\x05\x00\x00\x00code\x00"}, + {bson.M{"_": bson.Symbol("sym")}, + "\x0E_\x00\x04\x00\x00\x00sym\x00"}, + {bson.M{"_": bson.JavaScript{"code", bson.M{"": nil}}}, + "\x0F_\x00\x14\x00\x00\x00\x05\x00\x00\x00code\x00" + + "\x07\x00\x00\x00\x0A\x00\x00"}, + {bson.M{"_": 258}, + "\x10_\x00\x02\x01\x00\x00"}, + {bson.M{"_": bson.MongoTimestamp(258)}, + "\x11_\x00\x02\x01\x00\x00\x00\x00\x00\x00"}, + {bson.M{"_": int64(258)}, + "\x12_\x00\x02\x01\x00\x00\x00\x00\x00\x00"}, + {bson.M{"_": int64(258 << 32)}, + "\x12_\x00\x00\x00\x00\x00\x02\x01\x00\x00"}, + {bson.M{"_": bson.MaxKey}, + "\x7F_\x00"}, + {bson.M{"_": bson.MinKey}, + "\xFF_\x00"}, +} + +func (s *S) TestMarshalAllItems(c *C) { + for i, item := range allItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, wrapInDoc(item.data), Commentf("Failed on item %d: %#v", i, item)) + } +} + +func (s *S) TestUnmarshalAllItems(c *C) { + for i, item := range allItems { + value := bson.M{} + err := bson.Unmarshal([]byte(wrapInDoc(item.data)), value) + c.Assert(err, IsNil) + c.Assert(value, DeepEquals, item.obj, Commentf("Failed on item %d: %#v", i, item)) + } +} + +func (s *S) TestUnmarshalRawAllItems(c *C) { + for i, item := range allItems { + if len(item.data) == 0 { + continue + } + value := item.obj.(bson.M)["_"] + if value == nil { + continue + } + pv := reflect.New(reflect.ValueOf(value).Type()) + raw := bson.Raw{item.data[0], []byte(item.data[3:])} + c.Logf("Unmarshal raw: %#v, %#v", raw, pv.Interface()) + err := raw.Unmarshal(pv.Interface()) + c.Assert(err, IsNil) + c.Assert(pv.Elem().Interface(), DeepEquals, value, Commentf("Failed on item %d: %#v", i, item)) + } +} + +func (s *S) TestUnmarshalRawIncompatible(c *C) { + raw := bson.Raw{0x08, []byte{0x01}} // true + err := raw.Unmarshal(&struct{}{}) + c.Assert(err, ErrorMatches, "BSON kind 0x08 isn't compatible with type struct \\{\\}") +} + +func (s *S) TestUnmarshalZeroesStruct(c *C) { + data, err := bson.Marshal(bson.M{"b": 2}) + c.Assert(err, IsNil) + type T struct{ A, B int } + v := T{A: 1} + err = bson.Unmarshal(data, &v) + c.Assert(err, IsNil) + c.Assert(v.A, Equals, 0) + c.Assert(v.B, Equals, 2) +} + +func (s *S) TestUnmarshalZeroesMap(c *C) { + data, err := bson.Marshal(bson.M{"b": 2}) + c.Assert(err, IsNil) + m := bson.M{"a": 1} + err = bson.Unmarshal(data, &m) + c.Assert(err, IsNil) + c.Assert(m, DeepEquals, bson.M{"b": 2}) +} + +func (s *S) TestUnmarshalNonNilInterface(c *C) { + data, err := bson.Marshal(bson.M{"b": 2}) + c.Assert(err, IsNil) + m := bson.M{"a": 1} + var i interface{} + i = m + err = bson.Unmarshal(data, &i) + c.Assert(err, IsNil) + c.Assert(i, DeepEquals, bson.M{"b": 2}) + c.Assert(m, DeepEquals, bson.M{"a": 1}) +} + +// -------------------------------------------------------------------------- +// Some one way marshaling operations which would unmarshal differently. + +var oneWayMarshalItems = []testItemType{ + // These are being passed as pointers, and will unmarshal as values. + {bson.M{"": &bson.Binary{0x02, []byte("old")}}, + "\x05\x00\x07\x00\x00\x00\x02\x03\x00\x00\x00old"}, + {bson.M{"": &bson.Binary{0x80, []byte("udef")}}, + "\x05\x00\x04\x00\x00\x00\x80udef"}, + {bson.M{"": &bson.RegEx{"ab", "cd"}}, + "\x0B\x00ab\x00cd\x00"}, + {bson.M{"": &bson.JavaScript{"code", nil}}, + "\x0D\x00\x05\x00\x00\x00code\x00"}, + {bson.M{"": &bson.JavaScript{"code", bson.M{"": nil}}}, + "\x0F\x00\x14\x00\x00\x00\x05\x00\x00\x00code\x00" + + "\x07\x00\x00\x00\x0A\x00\x00"}, + + // There's no float32 type in BSON. Will encode as a float64. + {bson.M{"": float32(5.05)}, + "\x01\x00\x00\x00\x00@33\x14@"}, + + // The array will be unmarshaled as a slice instead. + {bson.M{"": [2]bool{true, false}}, + "\x04\x00\r\x00\x00\x00\x080\x00\x01\x081\x00\x00\x00"}, + + // The typed slice will be unmarshaled as []interface{}. + {bson.M{"": []bool{true, false}}, + "\x04\x00\r\x00\x00\x00\x080\x00\x01\x081\x00\x00\x00"}, + + // Will unmarshal as a []byte. + {bson.M{"": bson.Binary{0x00, []byte("yo")}}, + "\x05\x00\x02\x00\x00\x00\x00yo"}, + {bson.M{"": bson.Binary{0x02, []byte("old")}}, + "\x05\x00\x07\x00\x00\x00\x02\x03\x00\x00\x00old"}, + + // No way to preserve the type information here. We might encode as a zero + // value, but this would mean that pointer values in structs wouldn't be + // able to correctly distinguish between unset and set to the zero value. + {bson.M{"": (*byte)(nil)}, + "\x0A\x00"}, + + // No int types smaller than int32 in BSON. Could encode this as a char, + // but it would still be ambiguous, take more, and be awkward in Go when + // loaded without typing information. + {bson.M{"": byte(8)}, + "\x10\x00\x08\x00\x00\x00"}, + + // There are no unsigned types in BSON. Will unmarshal as int32 or int64. + {bson.M{"": uint32(258)}, + "\x10\x00\x02\x01\x00\x00"}, + {bson.M{"": uint64(258)}, + "\x12\x00\x02\x01\x00\x00\x00\x00\x00\x00"}, + {bson.M{"": uint64(258 << 32)}, + "\x12\x00\x00\x00\x00\x00\x02\x01\x00\x00"}, + + // This will unmarshal as int. + {bson.M{"": int32(258)}, + "\x10\x00\x02\x01\x00\x00"}, + + // That's a special case. The unsigned value is too large for an int32, + // so an int64 is used instead. + {bson.M{"": uint32(1<<32 - 1)}, + "\x12\x00\xFF\xFF\xFF\xFF\x00\x00\x00\x00"}, + {bson.M{"": uint(1<<32 - 1)}, + "\x12\x00\xFF\xFF\xFF\xFF\x00\x00\x00\x00"}, +} + +func (s *S) TestOneWayMarshalItems(c *C) { + for i, item := range oneWayMarshalItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, wrapInDoc(item.data), + Commentf("Failed on item %d", i)) + } +} + +// -------------------------------------------------------------------------- +// Two-way tests for user-defined structures using the samples +// from bsonspec.org. + +type specSample1 struct { + Hello string +} + +type specSample2 struct { + BSON []interface{} "BSON" +} + +var structSampleItems = []testItemType{ + {&specSample1{"world"}, + "\x16\x00\x00\x00\x02hello\x00\x06\x00\x00\x00world\x00\x00"}, + {&specSample2{[]interface{}{"awesome", float64(5.05), 1986}}, + "1\x00\x00\x00\x04BSON\x00&\x00\x00\x00\x020\x00\x08\x00\x00\x00" + + "awesome\x00\x011\x00333333\x14@\x102\x00\xc2\x07\x00\x00\x00\x00"}, +} + +func (s *S) TestMarshalStructSampleItems(c *C) { + for i, item := range structSampleItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, item.data, + Commentf("Failed on item %d", i)) + } +} + +func (s *S) TestUnmarshalStructSampleItems(c *C) { + for _, item := range structSampleItems { + testUnmarshal(c, item.data, item.obj) + } +} + +func (s *S) Test64bitInt(c *C) { + var i int64 = (1 << 31) + if int(i) > 0 { + data, err := bson.Marshal(bson.M{"i": int(i)}) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, wrapInDoc("\x12i\x00\x00\x00\x00\x80\x00\x00\x00\x00")) + + var result struct{ I int } + err = bson.Unmarshal(data, &result) + c.Assert(err, IsNil) + c.Assert(int64(result.I), Equals, i) + } +} + +// -------------------------------------------------------------------------- +// Generic two-way struct marshaling tests. + +var bytevar = byte(8) +var byteptr = &bytevar + +var structItems = []testItemType{ + {&struct{ Ptr *byte }{nil}, + "\x0Aptr\x00"}, + {&struct{ Ptr *byte }{&bytevar}, + "\x10ptr\x00\x08\x00\x00\x00"}, + {&struct{ Ptr **byte }{&byteptr}, + "\x10ptr\x00\x08\x00\x00\x00"}, + {&struct{ Byte byte }{8}, + "\x10byte\x00\x08\x00\x00\x00"}, + {&struct{ Byte byte }{0}, + "\x10byte\x00\x00\x00\x00\x00"}, + {&struct { + V byte "Tag" + }{8}, + "\x10Tag\x00\x08\x00\x00\x00"}, + {&struct { + V *struct { + Byte byte + } + }{&struct{ Byte byte }{8}}, + "\x03v\x00" + "\x0f\x00\x00\x00\x10byte\x00\b\x00\x00\x00\x00"}, + {&struct{ priv byte }{}, ""}, + + // The order of the dumped fields should be the same in the struct. + {&struct{ A, C, B, D, F, E *byte }{}, + "\x0Aa\x00\x0Ac\x00\x0Ab\x00\x0Ad\x00\x0Af\x00\x0Ae\x00"}, + + {&struct{ V bson.Raw }{bson.Raw{0x03, []byte("\x0f\x00\x00\x00\x10byte\x00\b\x00\x00\x00\x00")}}, + "\x03v\x00" + "\x0f\x00\x00\x00\x10byte\x00\b\x00\x00\x00\x00"}, + {&struct{ V bson.Raw }{bson.Raw{0x10, []byte("\x00\x00\x00\x00")}}, + "\x10v\x00" + "\x00\x00\x00\x00"}, + + // Byte arrays. + {&struct{ V [2]byte }{[2]byte{'y', 'o'}}, + "\x05v\x00\x02\x00\x00\x00\x00yo"}, +} + +func (s *S) TestMarshalStructItems(c *C) { + for i, item := range structItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, wrapInDoc(item.data), + Commentf("Failed on item %d", i)) + } +} + +func (s *S) TestUnmarshalStructItems(c *C) { + for _, item := range structItems { + testUnmarshal(c, wrapInDoc(item.data), item.obj) + } +} + +func (s *S) TestUnmarshalRawStructItems(c *C) { + for i, item := range structItems { + raw := bson.Raw{0x03, []byte(wrapInDoc(item.data))} + zero := makeZeroDoc(item.obj) + err := raw.Unmarshal(zero) + c.Assert(err, IsNil) + c.Assert(zero, DeepEquals, item.obj, Commentf("Failed on item %d: %#v", i, item)) + } +} + +func (s *S) TestUnmarshalRawNil(c *C) { + // Regression test: shouldn't try to nil out the pointer itself, + // as it's not settable. + raw := bson.Raw{0x0A, []byte{}} + err := raw.Unmarshal(&struct{}{}) + c.Assert(err, IsNil) +} + +// -------------------------------------------------------------------------- +// One-way marshaling tests. + +type dOnIface struct { + D interface{} +} + +type ignoreField struct { + Before string + Ignore string `bson:"-"` + After string +} + +var marshalItems = []testItemType{ + // Ordered document dump. Will unmarshal as a dictionary by default. + {bson.D{{"a", nil}, {"c", nil}, {"b", nil}, {"d", nil}, {"f", nil}, {"e", true}}, + "\x0Aa\x00\x0Ac\x00\x0Ab\x00\x0Ad\x00\x0Af\x00\x08e\x00\x01"}, + {MyD{{"a", nil}, {"c", nil}, {"b", nil}, {"d", nil}, {"f", nil}, {"e", true}}, + "\x0Aa\x00\x0Ac\x00\x0Ab\x00\x0Ad\x00\x0Af\x00\x08e\x00\x01"}, + {&dOnIface{bson.D{{"a", nil}, {"c", nil}, {"b", nil}, {"d", true}}}, + "\x03d\x00" + wrapInDoc("\x0Aa\x00\x0Ac\x00\x0Ab\x00\x08d\x00\x01")}, + + {bson.RawD{{"a", bson.Raw{0x0A, nil}}, {"c", bson.Raw{0x0A, nil}}, {"b", bson.Raw{0x08, []byte{0x01}}}}, + "\x0Aa\x00" + "\x0Ac\x00" + "\x08b\x00\x01"}, + {MyRawD{{"a", bson.Raw{0x0A, nil}}, {"c", bson.Raw{0x0A, nil}}, {"b", bson.Raw{0x08, []byte{0x01}}}}, + "\x0Aa\x00" + "\x0Ac\x00" + "\x08b\x00\x01"}, + {&dOnIface{bson.RawD{{"a", bson.Raw{0x0A, nil}}, {"c", bson.Raw{0x0A, nil}}, {"b", bson.Raw{0x08, []byte{0x01}}}}}, + "\x03d\x00" + wrapInDoc("\x0Aa\x00"+"\x0Ac\x00"+"\x08b\x00\x01")}, + + {&ignoreField{"before", "ignore", "after"}, + "\x02before\x00\a\x00\x00\x00before\x00\x02after\x00\x06\x00\x00\x00after\x00"}, + + // Marshalling a Raw document does nothing. + {bson.Raw{0x03, []byte(wrapInDoc("anything"))}, + "anything"}, + {bson.Raw{Data: []byte(wrapInDoc("anything"))}, + "anything"}, +} + +func (s *S) TestMarshalOneWayItems(c *C) { + for _, item := range marshalItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, wrapInDoc(item.data)) + } +} + +// -------------------------------------------------------------------------- +// One-way unmarshaling tests. + +var unmarshalItems = []testItemType{ + // Field is private. Should not attempt to unmarshal it. + {&struct{ priv byte }{}, + "\x10priv\x00\x08\x00\x00\x00"}, + + // Wrong casing. Field names are lowercased. + {&struct{ Byte byte }{}, + "\x10Byte\x00\x08\x00\x00\x00"}, + + // Ignore non-existing field. + {&struct{ Byte byte }{9}, + "\x10boot\x00\x08\x00\x00\x00" + "\x10byte\x00\x09\x00\x00\x00"}, + + // Do not unmarshal on ignored field. + {&ignoreField{"before", "", "after"}, + "\x02before\x00\a\x00\x00\x00before\x00" + + "\x02-\x00\a\x00\x00\x00ignore\x00" + + "\x02after\x00\x06\x00\x00\x00after\x00"}, + + // Ignore unsuitable types silently. + {map[string]string{"str": "s"}, + "\x02str\x00\x02\x00\x00\x00s\x00" + "\x10int\x00\x01\x00\x00\x00"}, + {map[string][]int{"array": []int{5, 9}}, + "\x04array\x00" + wrapInDoc("\x100\x00\x05\x00\x00\x00"+"\x021\x00\x02\x00\x00\x00s\x00"+"\x102\x00\x09\x00\x00\x00")}, + + // Wrong type. Shouldn't init pointer. + {&struct{ Str *byte }{}, + "\x02str\x00\x02\x00\x00\x00s\x00"}, + {&struct{ Str *struct{ Str string } }{}, + "\x02str\x00\x02\x00\x00\x00s\x00"}, + + // Ordered document. + {&struct{ bson.D }{bson.D{{"a", nil}, {"c", nil}, {"b", nil}, {"d", true}}}, + "\x03d\x00" + wrapInDoc("\x0Aa\x00\x0Ac\x00\x0Ab\x00\x08d\x00\x01")}, + + // Raw document. + {&bson.Raw{0x03, []byte(wrapInDoc("\x10byte\x00\x08\x00\x00\x00"))}, + "\x10byte\x00\x08\x00\x00\x00"}, + + // RawD document. + {&struct{ bson.RawD }{bson.RawD{{"a", bson.Raw{0x0A, []byte{}}}, {"c", bson.Raw{0x0A, []byte{}}}, {"b", bson.Raw{0x08, []byte{0x01}}}}}, + "\x03rawd\x00" + wrapInDoc("\x0Aa\x00\x0Ac\x00\x08b\x00\x01")}, + + // Decode old binary. + {bson.M{"_": []byte("old")}, + "\x05_\x00\x07\x00\x00\x00\x02\x03\x00\x00\x00old"}, + + // Decode old binary without length. According to the spec, this shouldn't happen. + {bson.M{"_": []byte("old")}, + "\x05_\x00\x03\x00\x00\x00\x02old"}, +} + +func (s *S) TestUnmarshalOneWayItems(c *C) { + for _, item := range unmarshalItems { + testUnmarshal(c, wrapInDoc(item.data), item.obj) + } +} + +func (s *S) TestUnmarshalNilInStruct(c *C) { + // Nil is the default value, so we need to ensure it's indeed being set. + b := byte(1) + v := &struct{ Ptr *byte }{&b} + err := bson.Unmarshal([]byte(wrapInDoc("\x0Aptr\x00")), v) + c.Assert(err, IsNil) + c.Assert(v, DeepEquals, &struct{ Ptr *byte }{nil}) +} + +// -------------------------------------------------------------------------- +// Marshalling error cases. + +type structWithDupKeys struct { + Name byte + Other byte "name" // Tag should precede. +} + +var marshalErrorItems = []testItemType{ + {bson.M{"": uint64(1 << 63)}, + "BSON has no uint64 type, and value is too large to fit correctly in an int64"}, + {bson.M{"": bson.ObjectId("tooshort")}, + "ObjectIDs must be exactly 12 bytes long \\(got 8\\)"}, + {int64(123), + "Can't marshal int64 as a BSON document"}, + {bson.M{"": 1i}, + "Can't marshal complex128 in a BSON document"}, + {&structWithDupKeys{}, + "Duplicated key 'name' in struct bson_test.structWithDupKeys"}, + {bson.Raw{0x0A, []byte{}}, + "Attempted to unmarshal Raw kind 10 as a document"}, + {&inlineCantPtr{&struct{ A, B int }{1, 2}}, + "Option ,inline needs a struct value or map field"}, + {&inlineDupName{1, struct{ A, B int }{2, 3}}, + "Duplicated key 'a' in struct bson_test.inlineDupName"}, + {&inlineDupMap{}, + "Multiple ,inline maps in struct bson_test.inlineDupMap"}, + {&inlineBadKeyMap{}, + "Option ,inline needs a map with string keys in struct bson_test.inlineBadKeyMap"}, + {&inlineMap{A: 1, M: map[string]interface{}{"a": 1}}, + `Can't have key "a" in inlined map; conflicts with struct field`}, +} + +func (s *S) TestMarshalErrorItems(c *C) { + for _, item := range marshalErrorItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, ErrorMatches, item.data) + c.Assert(data, IsNil) + } +} + +// -------------------------------------------------------------------------- +// Unmarshalling error cases. + +type unmarshalErrorType struct { + obj interface{} + data string + error string +} + +var unmarshalErrorItems = []unmarshalErrorType{ + // Tag name conflicts with existing parameter. + {&structWithDupKeys{}, + "\x10name\x00\x08\x00\x00\x00", + "Duplicated key 'name' in struct bson_test.structWithDupKeys"}, + + // Non-string map key. + {map[int]interface{}{}, + "\x10name\x00\x08\x00\x00\x00", + "BSON map must have string keys. Got: map\\[int\\]interface \\{\\}"}, + + {nil, + "\xEEname\x00", + "Unknown element kind \\(0xEE\\)"}, + + {struct{ Name bool }{}, + "\x10name\x00\x08\x00\x00\x00", + "Unmarshal can't deal with struct values. Use a pointer."}, + + {123, + "\x10name\x00\x08\x00\x00\x00", + "Unmarshal needs a map or a pointer to a struct."}, +} + +func (s *S) TestUnmarshalErrorItems(c *C) { + for _, item := range unmarshalErrorItems { + data := []byte(wrapInDoc(item.data)) + var value interface{} + switch reflect.ValueOf(item.obj).Kind() { + case reflect.Map, reflect.Ptr: + value = makeZeroDoc(item.obj) + case reflect.Invalid: + value = bson.M{} + default: + value = item.obj + } + err := bson.Unmarshal(data, value) + c.Assert(err, ErrorMatches, item.error) + } +} + +type unmarshalRawErrorType struct { + obj interface{} + raw bson.Raw + error string +} + +var unmarshalRawErrorItems = []unmarshalRawErrorType{ + // Tag name conflicts with existing parameter. + {&structWithDupKeys{}, + bson.Raw{0x03, []byte("\x10byte\x00\x08\x00\x00\x00")}, + "Duplicated key 'name' in struct bson_test.structWithDupKeys"}, + + {&struct{}{}, + bson.Raw{0xEE, []byte{}}, + "Unknown element kind \\(0xEE\\)"}, + + {struct{ Name bool }{}, + bson.Raw{0x10, []byte("\x08\x00\x00\x00")}, + "Raw Unmarshal can't deal with struct values. Use a pointer."}, + + {123, + bson.Raw{0x10, []byte("\x08\x00\x00\x00")}, + "Raw Unmarshal needs a map or a valid pointer."}, +} + +func (s *S) TestUnmarshalRawErrorItems(c *C) { + for i, item := range unmarshalRawErrorItems { + err := item.raw.Unmarshal(item.obj) + c.Assert(err, ErrorMatches, item.error, Commentf("Failed on item %d: %#v\n", i, item)) + } +} + +var corruptedData = []string{ + "\x04\x00\x00\x00\x00", // Shorter than minimum + "\x06\x00\x00\x00\x00", // Not enough data + "\x05\x00\x00", // Broken length + "\x05\x00\x00\x00\xff", // Corrupted termination + "\x0A\x00\x00\x00\x0Aooop\x00", // Unfinished C string + + // Array end past end of string (s[2]=0x07 is correct) + wrapInDoc("\x04\x00\x09\x00\x00\x00\x0A\x00\x00"), + + // Array end within string, but past acceptable. + wrapInDoc("\x04\x00\x08\x00\x00\x00\x0A\x00\x00"), + + // Document end within string, but past acceptable. + wrapInDoc("\x03\x00\x08\x00\x00\x00\x0A\x00\x00"), + + // String with corrupted end. + wrapInDoc("\x02\x00\x03\x00\x00\x00yo\xFF"), +} + +func (s *S) TestUnmarshalMapDocumentTooShort(c *C) { + for _, data := range corruptedData { + err := bson.Unmarshal([]byte(data), bson.M{}) + c.Assert(err, ErrorMatches, "Document is corrupted") + + err = bson.Unmarshal([]byte(data), &struct{}{}) + c.Assert(err, ErrorMatches, "Document is corrupted") + } +} + +// -------------------------------------------------------------------------- +// Setter test cases. + +var setterResult = map[string]error{} + +type setterType struct { + received interface{} +} + +func (o *setterType) SetBSON(raw bson.Raw) error { + err := raw.Unmarshal(&o.received) + if err != nil { + panic("The panic:" + err.Error()) + } + if s, ok := o.received.(string); ok { + if result, ok := setterResult[s]; ok { + return result + } + } + return nil +} + +type ptrSetterDoc struct { + Field *setterType "_" +} + +type valSetterDoc struct { + Field setterType "_" +} + +func (s *S) TestUnmarshalAllItemsWithPtrSetter(c *C) { + for _, item := range allItems { + for i := 0; i != 2; i++ { + var field *setterType + if i == 0 { + obj := &ptrSetterDoc{} + err := bson.Unmarshal([]byte(wrapInDoc(item.data)), obj) + c.Assert(err, IsNil) + field = obj.Field + } else { + obj := &valSetterDoc{} + err := bson.Unmarshal([]byte(wrapInDoc(item.data)), obj) + c.Assert(err, IsNil) + field = &obj.Field + } + if item.data == "" { + // Nothing to unmarshal. Should be untouched. + if i == 0 { + c.Assert(field, IsNil) + } else { + c.Assert(field.received, IsNil) + } + } else { + expected := item.obj.(bson.M)["_"] + c.Assert(field, NotNil, Commentf("Pointer not initialized (%#v)", expected)) + c.Assert(field.received, DeepEquals, expected) + } + } + } +} + +func (s *S) TestUnmarshalWholeDocumentWithSetter(c *C) { + obj := &setterType{} + err := bson.Unmarshal([]byte(sampleItems[0].data), obj) + c.Assert(err, IsNil) + c.Assert(obj.received, DeepEquals, bson.M{"hello": "world"}) +} + +func (s *S) TestUnmarshalSetterOmits(c *C) { + setterResult["2"] = &bson.TypeError{} + setterResult["4"] = &bson.TypeError{} + defer func() { + delete(setterResult, "2") + delete(setterResult, "4") + }() + + m := map[string]*setterType{} + data := wrapInDoc("\x02abc\x00\x02\x00\x00\x001\x00" + + "\x02def\x00\x02\x00\x00\x002\x00" + + "\x02ghi\x00\x02\x00\x00\x003\x00" + + "\x02jkl\x00\x02\x00\x00\x004\x00") + err := bson.Unmarshal([]byte(data), m) + c.Assert(err, IsNil) + c.Assert(m["abc"], NotNil) + c.Assert(m["def"], IsNil) + c.Assert(m["ghi"], NotNil) + c.Assert(m["jkl"], IsNil) + + c.Assert(m["abc"].received, Equals, "1") + c.Assert(m["ghi"].received, Equals, "3") +} + +func (s *S) TestUnmarshalSetterErrors(c *C) { + boom := errors.New("BOOM") + setterResult["2"] = boom + defer delete(setterResult, "2") + + m := map[string]*setterType{} + data := wrapInDoc("\x02abc\x00\x02\x00\x00\x001\x00" + + "\x02def\x00\x02\x00\x00\x002\x00" + + "\x02ghi\x00\x02\x00\x00\x003\x00") + err := bson.Unmarshal([]byte(data), m) + c.Assert(err, Equals, boom) + c.Assert(m["abc"], NotNil) + c.Assert(m["def"], IsNil) + c.Assert(m["ghi"], IsNil) + + c.Assert(m["abc"].received, Equals, "1") +} + +func (s *S) TestDMap(c *C) { + d := bson.D{{"a", 1}, {"b", 2}} + c.Assert(d.Map(), DeepEquals, bson.M{"a": 1, "b": 2}) +} + +func (s *S) TestUnmarshalSetterSetZero(c *C) { + setterResult["foo"] = bson.SetZero + defer delete(setterResult, "field") + + data, err := bson.Marshal(bson.M{"field": "foo"}) + c.Assert(err, IsNil) + + m := map[string]*setterType{} + err = bson.Unmarshal([]byte(data), m) + c.Assert(err, IsNil) + + value, ok := m["field"] + c.Assert(ok, Equals, true) + c.Assert(value, IsNil) +} + +// -------------------------------------------------------------------------- +// Getter test cases. + +type typeWithGetter struct { + result interface{} + err error +} + +func (t *typeWithGetter) GetBSON() (interface{}, error) { + if t == nil { + return "", nil + } + return t.result, t.err +} + +type docWithGetterField struct { + Field *typeWithGetter "_" +} + +func (s *S) TestMarshalAllItemsWithGetter(c *C) { + for i, item := range allItems { + if item.data == "" { + continue + } + obj := &docWithGetterField{} + obj.Field = &typeWithGetter{result: item.obj.(bson.M)["_"]} + data, err := bson.Marshal(obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, wrapInDoc(item.data), + Commentf("Failed on item #%d", i)) + } +} + +func (s *S) TestMarshalWholeDocumentWithGetter(c *C) { + obj := &typeWithGetter{result: sampleItems[0].obj} + data, err := bson.Marshal(obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, sampleItems[0].data) +} + +func (s *S) TestGetterErrors(c *C) { + e := errors.New("oops") + + obj1 := &docWithGetterField{} + obj1.Field = &typeWithGetter{sampleItems[0].obj, e} + data, err := bson.Marshal(obj1) + c.Assert(err, ErrorMatches, "oops") + c.Assert(data, IsNil) + + obj2 := &typeWithGetter{sampleItems[0].obj, e} + data, err = bson.Marshal(obj2) + c.Assert(err, ErrorMatches, "oops") + c.Assert(data, IsNil) +} + +type intGetter int64 + +func (t intGetter) GetBSON() (interface{}, error) { + return int64(t), nil +} + +type typeWithIntGetter struct { + V intGetter ",minsize" +} + +func (s *S) TestMarshalShortWithGetter(c *C) { + obj := typeWithIntGetter{42} + data, err := bson.Marshal(obj) + c.Assert(err, IsNil) + m := bson.M{} + err = bson.Unmarshal(data, m) + c.Assert(err, IsNil) + c.Assert(m["v"], Equals, 42) +} + +func (s *S) TestMarshalWithGetterNil(c *C) { + obj := docWithGetterField{} + data, err := bson.Marshal(obj) + c.Assert(err, IsNil) + m := bson.M{} + err = bson.Unmarshal(data, m) + c.Assert(err, IsNil) + c.Assert(m, DeepEquals, bson.M{"_": ""}) +} + +// -------------------------------------------------------------------------- +// Cross-type conversion tests. + +type crossTypeItem struct { + obj1 interface{} + obj2 interface{} +} + +type condStr struct { + V string ",omitempty" +} +type condStrNS struct { + V string `a:"A" bson:",omitempty" b:"B"` +} +type condBool struct { + V bool ",omitempty" +} +type condInt struct { + V int ",omitempty" +} +type condUInt struct { + V uint ",omitempty" +} +type condFloat struct { + V float64 ",omitempty" +} +type condIface struct { + V interface{} ",omitempty" +} +type condPtr struct { + V *bool ",omitempty" +} +type condSlice struct { + V []string ",omitempty" +} +type condMap struct { + V map[string]int ",omitempty" +} +type namedCondStr struct { + V string "myv,omitempty" +} +type condTime struct { + V time.Time ",omitempty" +} +type condStruct struct { + V struct{ A []int } ",omitempty" +} + +type shortInt struct { + V int64 ",minsize" +} +type shortUint struct { + V uint64 ",minsize" +} +type shortIface struct { + V interface{} ",minsize" +} +type shortPtr struct { + V *int64 ",minsize" +} +type shortNonEmptyInt struct { + V int64 ",minsize,omitempty" +} + +type inlineInt struct { + V struct{ A, B int } ",inline" +} +type inlineCantPtr struct { + V *struct{ A, B int } ",inline" +} +type inlineDupName struct { + A int + V struct{ A, B int } ",inline" +} +type inlineMap struct { + A int + M map[string]interface{} ",inline" +} +type inlineMapInt struct { + A int + M map[string]int ",inline" +} +type inlineMapMyM struct { + A int + M MyM ",inline" +} +type inlineDupMap struct { + M1 map[string]interface{} ",inline" + M2 map[string]interface{} ",inline" +} +type inlineBadKeyMap struct { + M map[int]int ",inline" +} + +type getterSetterD bson.D + +func (s getterSetterD) GetBSON() (interface{}, error) { + if len(s) == 0 { + return bson.D{}, nil + } + return bson.D(s[:len(s)-1]), nil +} + +func (s *getterSetterD) SetBSON(raw bson.Raw) error { + var doc bson.D + err := raw.Unmarshal(&doc) + doc = append(doc, bson.DocElem{"suffix", true}) + *s = getterSetterD(doc) + return err +} + +type getterSetterInt int + +func (i getterSetterInt) GetBSON() (interface{}, error) { + return bson.D{{"a", int(i)}}, nil +} + +func (i *getterSetterInt) SetBSON(raw bson.Raw) error { + var doc struct{ A int } + err := raw.Unmarshal(&doc) + *i = getterSetterInt(doc.A) + return err +} + +type ( + MyString string + MyBytes []byte + MyBool bool + MyD []bson.DocElem + MyRawD []bson.RawDocElem + MyM map[string]interface{} +) + +var ( + truevar = true + falsevar = false + + int64var = int64(42) + int64ptr = &int64var + intvar = int(42) + intptr = &intvar + + gsintvar = getterSetterInt(42) +) + +func parseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil { + panic(err) + } + return u +} + +// That's a pretty fun test. It will dump the first item, generate a zero +// value equivalent to the second one, load the dumped data onto it, and then +// verify that the resulting value is deep-equal to the untouched second value. +// Then, it will do the same in the *opposite* direction! +var twoWayCrossItems = []crossTypeItem{ + // int<=>int + {&struct{ I int }{42}, &struct{ I int8 }{42}}, + {&struct{ I int }{42}, &struct{ I int32 }{42}}, + {&struct{ I int }{42}, &struct{ I int64 }{42}}, + {&struct{ I int8 }{42}, &struct{ I int32 }{42}}, + {&struct{ I int8 }{42}, &struct{ I int64 }{42}}, + {&struct{ I int32 }{42}, &struct{ I int64 }{42}}, + + // uint<=>uint + {&struct{ I uint }{42}, &struct{ I uint8 }{42}}, + {&struct{ I uint }{42}, &struct{ I uint32 }{42}}, + {&struct{ I uint }{42}, &struct{ I uint64 }{42}}, + {&struct{ I uint8 }{42}, &struct{ I uint32 }{42}}, + {&struct{ I uint8 }{42}, &struct{ I uint64 }{42}}, + {&struct{ I uint32 }{42}, &struct{ I uint64 }{42}}, + + // float32<=>float64 + {&struct{ I float32 }{42}, &struct{ I float64 }{42}}, + + // int<=>uint + {&struct{ I uint }{42}, &struct{ I int }{42}}, + {&struct{ I uint }{42}, &struct{ I int8 }{42}}, + {&struct{ I uint }{42}, &struct{ I int32 }{42}}, + {&struct{ I uint }{42}, &struct{ I int64 }{42}}, + {&struct{ I uint8 }{42}, &struct{ I int }{42}}, + {&struct{ I uint8 }{42}, &struct{ I int8 }{42}}, + {&struct{ I uint8 }{42}, &struct{ I int32 }{42}}, + {&struct{ I uint8 }{42}, &struct{ I int64 }{42}}, + {&struct{ I uint32 }{42}, &struct{ I int }{42}}, + {&struct{ I uint32 }{42}, &struct{ I int8 }{42}}, + {&struct{ I uint32 }{42}, &struct{ I int32 }{42}}, + {&struct{ I uint32 }{42}, &struct{ I int64 }{42}}, + {&struct{ I uint64 }{42}, &struct{ I int }{42}}, + {&struct{ I uint64 }{42}, &struct{ I int8 }{42}}, + {&struct{ I uint64 }{42}, &struct{ I int32 }{42}}, + {&struct{ I uint64 }{42}, &struct{ I int64 }{42}}, + + // int <=> float + {&struct{ I int }{42}, &struct{ I float64 }{42}}, + + // int <=> bool + {&struct{ I int }{1}, &struct{ I bool }{true}}, + {&struct{ I int }{0}, &struct{ I bool }{false}}, + + // uint <=> float64 + {&struct{ I uint }{42}, &struct{ I float64 }{42}}, + + // uint <=> bool + {&struct{ I uint }{1}, &struct{ I bool }{true}}, + {&struct{ I uint }{0}, &struct{ I bool }{false}}, + + // float64 <=> bool + {&struct{ I float64 }{1}, &struct{ I bool }{true}}, + {&struct{ I float64 }{0}, &struct{ I bool }{false}}, + + // string <=> string and string <=> []byte + {&struct{ S []byte }{[]byte("abc")}, &struct{ S string }{"abc"}}, + {&struct{ S []byte }{[]byte("def")}, &struct{ S bson.Symbol }{"def"}}, + {&struct{ S string }{"ghi"}, &struct{ S bson.Symbol }{"ghi"}}, + + // map <=> struct + {&struct { + A struct { + B, C int + } + }{struct{ B, C int }{1, 2}}, + map[string]map[string]int{"a": map[string]int{"b": 1, "c": 2}}}, + + {&struct{ A bson.Symbol }{"abc"}, map[string]string{"a": "abc"}}, + {&struct{ A bson.Symbol }{"abc"}, map[string][]byte{"a": []byte("abc")}}, + {&struct{ A []byte }{[]byte("abc")}, map[string]string{"a": "abc"}}, + {&struct{ A uint }{42}, map[string]int{"a": 42}}, + {&struct{ A uint }{42}, map[string]float64{"a": 42}}, + {&struct{ A uint }{1}, map[string]bool{"a": true}}, + {&struct{ A int }{42}, map[string]uint{"a": 42}}, + {&struct{ A int }{42}, map[string]float64{"a": 42}}, + {&struct{ A int }{1}, map[string]bool{"a": true}}, + {&struct{ A float64 }{42}, map[string]float32{"a": 42}}, + {&struct{ A float64 }{42}, map[string]int{"a": 42}}, + {&struct{ A float64 }{42}, map[string]uint{"a": 42}}, + {&struct{ A float64 }{1}, map[string]bool{"a": true}}, + {&struct{ A bool }{true}, map[string]int{"a": 1}}, + {&struct{ A bool }{true}, map[string]uint{"a": 1}}, + {&struct{ A bool }{true}, map[string]float64{"a": 1}}, + {&struct{ A **byte }{&byteptr}, map[string]byte{"a": 8}}, + + // url.URL <=> string + {&struct{ URL *url.URL }{parseURL("h://e.c/p")}, map[string]string{"url": "h://e.c/p"}}, + {&struct{ URL url.URL }{*parseURL("h://e.c/p")}, map[string]string{"url": "h://e.c/p"}}, + + // Slices + {&struct{ S []int }{[]int{1, 2, 3}}, map[string][]int{"s": []int{1, 2, 3}}}, + {&struct{ S *[]int }{&[]int{1, 2, 3}}, map[string][]int{"s": []int{1, 2, 3}}}, + + // Conditionals + {&condBool{true}, map[string]bool{"v": true}}, + {&condBool{}, map[string]bool{}}, + {&condInt{1}, map[string]int{"v": 1}}, + {&condInt{}, map[string]int{}}, + {&condUInt{1}, map[string]uint{"v": 1}}, + {&condUInt{}, map[string]uint{}}, + {&condFloat{}, map[string]int{}}, + {&condStr{"yo"}, map[string]string{"v": "yo"}}, + {&condStr{}, map[string]string{}}, + {&condStrNS{"yo"}, map[string]string{"v": "yo"}}, + {&condStrNS{}, map[string]string{}}, + {&condSlice{[]string{"yo"}}, map[string][]string{"v": []string{"yo"}}}, + {&condSlice{}, map[string][]string{}}, + {&condMap{map[string]int{"k": 1}}, bson.M{"v": bson.M{"k": 1}}}, + {&condMap{}, map[string][]string{}}, + {&condIface{"yo"}, map[string]string{"v": "yo"}}, + {&condIface{""}, map[string]string{"v": ""}}, + {&condIface{}, map[string]string{}}, + {&condPtr{&truevar}, map[string]bool{"v": true}}, + {&condPtr{&falsevar}, map[string]bool{"v": false}}, + {&condPtr{}, map[string]string{}}, + + {&condTime{time.Unix(123456789, 123e6)}, map[string]time.Time{"v": time.Unix(123456789, 123e6)}}, + {&condTime{}, map[string]string{}}, + + {&condStruct{struct{ A []int }{[]int{1}}}, bson.M{"v": bson.M{"a": []interface{}{1}}}}, + {&condStruct{struct{ A []int }{}}, bson.M{}}, + + {&namedCondStr{"yo"}, map[string]string{"myv": "yo"}}, + {&namedCondStr{}, map[string]string{}}, + + {&shortInt{1}, map[string]interface{}{"v": 1}}, + {&shortInt{1 << 30}, map[string]interface{}{"v": 1 << 30}}, + {&shortInt{1 << 31}, map[string]interface{}{"v": int64(1 << 31)}}, + {&shortUint{1 << 30}, map[string]interface{}{"v": 1 << 30}}, + {&shortUint{1 << 31}, map[string]interface{}{"v": int64(1 << 31)}}, + {&shortIface{int64(1) << 31}, map[string]interface{}{"v": int64(1 << 31)}}, + {&shortPtr{int64ptr}, map[string]interface{}{"v": intvar}}, + + {&shortNonEmptyInt{1}, map[string]interface{}{"v": 1}}, + {&shortNonEmptyInt{1 << 31}, map[string]interface{}{"v": int64(1 << 31)}}, + {&shortNonEmptyInt{}, map[string]interface{}{}}, + + {&inlineInt{struct{ A, B int }{1, 2}}, map[string]interface{}{"a": 1, "b": 2}}, + {&inlineMap{A: 1, M: map[string]interface{}{"b": 2}}, map[string]interface{}{"a": 1, "b": 2}}, + {&inlineMap{A: 1, M: nil}, map[string]interface{}{"a": 1}}, + {&inlineMapInt{A: 1, M: map[string]int{"b": 2}}, map[string]int{"a": 1, "b": 2}}, + {&inlineMapInt{A: 1, M: nil}, map[string]int{"a": 1}}, + {&inlineMapMyM{A: 1, M: MyM{"b": MyM{"c": 3}}}, map[string]interface{}{"a": 1, "b": map[string]interface{}{"c": 3}}}, + + // []byte <=> MyBytes + {&struct{ B MyBytes }{[]byte("abc")}, map[string]string{"b": "abc"}}, + {&struct{ B MyBytes }{[]byte{}}, map[string]string{"b": ""}}, + {&struct{ B MyBytes }{}, map[string]bool{}}, + {&struct{ B []byte }{[]byte("abc")}, map[string]MyBytes{"b": []byte("abc")}}, + + // bool <=> MyBool + {&struct{ B MyBool }{true}, map[string]bool{"b": true}}, + {&struct{ B MyBool }{}, map[string]bool{"b": false}}, + {&struct{ B MyBool }{}, map[string]string{}}, + {&struct{ B bool }{}, map[string]MyBool{"b": false}}, + + // arrays + {&struct{ V [2]int }{[...]int{1, 2}}, map[string][2]int{"v": [2]int{1, 2}}}, + + // zero time + {&struct{ V time.Time }{}, map[string]interface{}{"v": time.Time{}}}, + + // zero time + 1 second + 1 millisecond; overflows int64 as nanoseconds + {&struct{ V time.Time }{time.Unix(-62135596799, 1e6).Local()}, + map[string]interface{}{"v": time.Unix(-62135596799, 1e6).Local()}}, + + // bson.D <=> []DocElem + {&bson.D{{"a", bson.D{{"b", 1}, {"c", 2}}}}, &bson.D{{"a", bson.D{{"b", 1}, {"c", 2}}}}}, + {&bson.D{{"a", bson.D{{"b", 1}, {"c", 2}}}}, &MyD{{"a", MyD{{"b", 1}, {"c", 2}}}}}, + {&struct{ V MyD }{MyD{{"a", 1}}}, &bson.D{{"v", bson.D{{"a", 1}}}}}, + + // bson.RawD <=> []RawDocElem + {&bson.RawD{{"a", bson.Raw{0x08, []byte{0x01}}}}, &bson.RawD{{"a", bson.Raw{0x08, []byte{0x01}}}}}, + {&bson.RawD{{"a", bson.Raw{0x08, []byte{0x01}}}}, &MyRawD{{"a", bson.Raw{0x08, []byte{0x01}}}}}, + + // bson.M <=> map + {bson.M{"a": bson.M{"b": 1, "c": 2}}, MyM{"a": MyM{"b": 1, "c": 2}}}, + {bson.M{"a": bson.M{"b": 1, "c": 2}}, map[string]interface{}{"a": map[string]interface{}{"b": 1, "c": 2}}}, + + // bson.M <=> map[MyString] + {bson.M{"a": bson.M{"b": 1, "c": 2}}, map[MyString]interface{}{"a": map[MyString]interface{}{"b": 1, "c": 2}}}, + + // json.Number <=> int64, float64 + {&struct{ N json.Number }{"5"}, map[string]interface{}{"n": int64(5)}}, + {&struct{ N json.Number }{"5.05"}, map[string]interface{}{"n": 5.05}}, + {&struct{ N json.Number }{"9223372036854776000"}, map[string]interface{}{"n": float64(1 << 63)}}, + + // bson.D <=> non-struct getter/setter + {&bson.D{{"a", 1}}, &getterSetterD{{"a", 1}, {"suffix", true}}}, + {&bson.D{{"a", 42}}, &gsintvar}, +} + +// Same thing, but only one way (obj1 => obj2). +var oneWayCrossItems = []crossTypeItem{ + // map <=> struct + {map[string]interface{}{"a": 1, "b": "2", "c": 3}, map[string]int{"a": 1, "c": 3}}, + + // inline map elides badly typed values + {map[string]interface{}{"a": 1, "b": "2", "c": 3}, &inlineMapInt{A: 1, M: map[string]int{"c": 3}}}, + + // Can't decode int into struct. + {bson.M{"a": bson.M{"b": 2}}, &struct{ A bool }{}}, + + // Would get decoded into a int32 too in the opposite direction. + {&shortIface{int64(1) << 30}, map[string]interface{}{"v": 1 << 30}}, +} + +func testCrossPair(c *C, dump interface{}, load interface{}) { + c.Logf("Dump: %#v", dump) + c.Logf("Load: %#v", load) + zero := makeZeroDoc(load) + data, err := bson.Marshal(dump) + c.Assert(err, IsNil) + c.Logf("Dumped: %#v", string(data)) + err = bson.Unmarshal(data, zero) + c.Assert(err, IsNil) + c.Logf("Loaded: %#v", zero) + c.Assert(zero, DeepEquals, load) +} + +func (s *S) TestTwoWayCrossPairs(c *C) { + for _, item := range twoWayCrossItems { + testCrossPair(c, item.obj1, item.obj2) + testCrossPair(c, item.obj2, item.obj1) + } +} + +func (s *S) TestOneWayCrossPairs(c *C) { + for _, item := range oneWayCrossItems { + testCrossPair(c, item.obj1, item.obj2) + } +} + +// -------------------------------------------------------------------------- +// ObjectId hex representation test. + +func (s *S) TestObjectIdHex(c *C) { + id := bson.ObjectIdHex("4d88e15b60f486e428412dc9") + c.Assert(id.String(), Equals, `ObjectIdHex("4d88e15b60f486e428412dc9")`) + c.Assert(id.Hex(), Equals, "4d88e15b60f486e428412dc9") +} + +func (s *S) TestIsObjectIdHex(c *C) { + test := []struct { + id string + valid bool + }{ + {"4d88e15b60f486e428412dc9", true}, + {"4d88e15b60f486e428412dc", false}, + {"4d88e15b60f486e428412dc9e", false}, + {"4d88e15b60f486e428412dcx", false}, + } + for _, t := range test { + c.Assert(bson.IsObjectIdHex(t.id), Equals, t.valid) + } +} + +// -------------------------------------------------------------------------- +// ObjectId parts extraction tests. + +type objectIdParts struct { + id bson.ObjectId + timestamp int64 + machine []byte + pid uint16 + counter int32 +} + +var objectIds = []objectIdParts{ + objectIdParts{ + bson.ObjectIdHex("4d88e15b60f486e428412dc9"), + 1300816219, + []byte{0x60, 0xf4, 0x86}, + 0xe428, + 4271561, + }, + objectIdParts{ + bson.ObjectIdHex("000000000000000000000000"), + 0, + []byte{0x00, 0x00, 0x00}, + 0x0000, + 0, + }, + objectIdParts{ + bson.ObjectIdHex("00000000aabbccddee000001"), + 0, + []byte{0xaa, 0xbb, 0xcc}, + 0xddee, + 1, + }, +} + +func (s *S) TestObjectIdPartsExtraction(c *C) { + for i, v := range objectIds { + t := time.Unix(v.timestamp, 0) + c.Assert(v.id.Time(), Equals, t, Commentf("#%d Wrong timestamp value", i)) + c.Assert(v.id.Machine(), DeepEquals, v.machine, Commentf("#%d Wrong machine id value", i)) + c.Assert(v.id.Pid(), Equals, v.pid, Commentf("#%d Wrong pid value", i)) + c.Assert(v.id.Counter(), Equals, v.counter, Commentf("#%d Wrong counter value", i)) + } +} + +func (s *S) TestNow(c *C) { + before := time.Now() + time.Sleep(1e6) + now := bson.Now() + time.Sleep(1e6) + after := time.Now() + c.Assert(now.After(before) && now.Before(after), Equals, true, Commentf("now=%s, before=%s, after=%s", now, before, after)) +} + +// -------------------------------------------------------------------------- +// ObjectId generation tests. + +func (s *S) TestNewObjectId(c *C) { + // Generate 10 ids + ids := make([]bson.ObjectId, 10) + for i := 0; i < 10; i++ { + ids[i] = bson.NewObjectId() + } + for i := 1; i < 10; i++ { + prevId := ids[i-1] + id := ids[i] + // Test for uniqueness among all other 9 generated ids + for j, tid := range ids { + if j != i { + c.Assert(id, Not(Equals), tid, Commentf("Generated ObjectId is not unique")) + } + } + // Check that timestamp was incremented and is within 30 seconds of the previous one + secs := id.Time().Sub(prevId.Time()).Seconds() + c.Assert((secs >= 0 && secs <= 30), Equals, true, Commentf("Wrong timestamp in generated ObjectId")) + // Check that machine ids are the same + c.Assert(id.Machine(), DeepEquals, prevId.Machine()) + // Check that pids are the same + c.Assert(id.Pid(), Equals, prevId.Pid()) + // Test for proper increment + delta := int(id.Counter() - prevId.Counter()) + c.Assert(delta, Equals, 1, Commentf("Wrong increment in generated ObjectId")) + } +} + +func (s *S) TestNewObjectIdWithTime(c *C) { + t := time.Unix(12345678, 0) + id := bson.NewObjectIdWithTime(t) + c.Assert(id.Time(), Equals, t) + c.Assert(id.Machine(), DeepEquals, []byte{0x00, 0x00, 0x00}) + c.Assert(int(id.Pid()), Equals, 0) + c.Assert(int(id.Counter()), Equals, 0) +} + +// -------------------------------------------------------------------------- +// ObjectId JSON marshalling. + +type jsonType struct { + Id *bson.ObjectId +} + +func (s *S) TestObjectIdJSONMarshaling(c *C) { + id := bson.ObjectIdHex("4d88e15b60f486e428412dc9") + v := jsonType{Id: &id} + data, err := json.Marshal(&v) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, `{"Id":"4d88e15b60f486e428412dc9"}`) +} + +func (s *S) TestObjectIdJSONUnmarshaling(c *C) { + data := []byte(`{"Id":"4d88e15b60f486e428412dc9"}`) + v := jsonType{} + err := json.Unmarshal(data, &v) + c.Assert(err, IsNil) + c.Assert(*v.Id, Equals, bson.ObjectIdHex("4d88e15b60f486e428412dc9")) +} + +func (s *S) TestObjectIdJSONUnmarshalingError(c *C) { + v := jsonType{} + err := json.Unmarshal([]byte(`{"Id":"4d88e15b60f486e428412dc9A"}`), &v) + c.Assert(err, ErrorMatches, `Invalid ObjectId in JSON: "4d88e15b60f486e428412dc9A"`) + err = json.Unmarshal([]byte(`{"Id":"4d88e15b60f486e428412dcZ"}`), &v) + c.Assert(err, ErrorMatches, `Invalid ObjectId in JSON: "4d88e15b60f486e428412dcZ" .*`) +} + +// -------------------------------------------------------------------------- +// Some simple benchmarks. + +type BenchT struct { + A, B, C, D, E, F string +} + +type BenchRawT struct { + A string + B int + C bson.M + D []float64 +} + +func (s *S) BenchmarkUnmarhsalStruct(c *C) { + v := BenchT{A: "A", D: "D", E: "E"} + data, err := bson.Marshal(&v) + if err != nil { + panic(err) + } + c.ResetTimer() + for i := 0; i < c.N; i++ { + err = bson.Unmarshal(data, &v) + } + if err != nil { + panic(err) + } +} + +func (s *S) BenchmarkUnmarhsalMap(c *C) { + m := bson.M{"a": "a", "d": "d", "e": "e"} + data, err := bson.Marshal(&m) + if err != nil { + panic(err) + } + c.ResetTimer() + for i := 0; i < c.N; i++ { + err = bson.Unmarshal(data, &m) + } + if err != nil { + panic(err) + } +} + +func (s *S) BenchmarkUnmarshalRaw(c *C) { + var err error + m := BenchRawT{ + A: "test_string", + B: 123, + C: bson.M{ + "subdoc_int": 12312, + "subdoc_doc": bson.M{"1": 1}, + }, + D: []float64{0.0, 1.3333, -99.9997, 3.1415}, + } + data, err := bson.Marshal(&m) + if err != nil { + panic(err) + } + raw := bson.Raw{} + c.ResetTimer() + for i := 0; i < c.N; i++ { + err = bson.Unmarshal(data, &raw) + } + if err != nil { + panic(err) + } +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/bson/decode.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/bson/decode.go new file mode 100644 index 0000000..782e933 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/bson/decode.go @@ -0,0 +1,820 @@ +// BSON library for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// gobson - BSON library for Go. + +package bson + +import ( + "fmt" + "math" + "net/url" + "reflect" + "strconv" + "sync" + "time" +) + +type decoder struct { + in []byte + i int + docType reflect.Type +} + +var typeM = reflect.TypeOf(M{}) + +func newDecoder(in []byte) *decoder { + return &decoder{in, 0, typeM} +} + +// -------------------------------------------------------------------------- +// Some helper functions. + +func corrupted() { + panic("Document is corrupted") +} + +func settableValueOf(i interface{}) reflect.Value { + v := reflect.ValueOf(i) + sv := reflect.New(v.Type()).Elem() + sv.Set(v) + return sv +} + +// -------------------------------------------------------------------------- +// Unmarshaling of documents. + +const ( + setterUnknown = iota + setterNone + setterType + setterAddr +) + +var setterStyles map[reflect.Type]int +var setterIface reflect.Type +var setterMutex sync.RWMutex + +func init() { + var iface Setter + setterIface = reflect.TypeOf(&iface).Elem() + setterStyles = make(map[reflect.Type]int) +} + +func setterStyle(outt reflect.Type) int { + setterMutex.RLock() + style := setterStyles[outt] + setterMutex.RUnlock() + if style == setterUnknown { + setterMutex.Lock() + defer setterMutex.Unlock() + if outt.Implements(setterIface) { + setterStyles[outt] = setterType + } else if reflect.PtrTo(outt).Implements(setterIface) { + setterStyles[outt] = setterAddr + } else { + setterStyles[outt] = setterNone + } + style = setterStyles[outt] + } + return style +} + +func getSetter(outt reflect.Type, out reflect.Value) Setter { + style := setterStyle(outt) + if style == setterNone { + return nil + } + if style == setterAddr { + if !out.CanAddr() { + return nil + } + out = out.Addr() + } else if outt.Kind() == reflect.Ptr && out.IsNil() { + out.Set(reflect.New(outt.Elem())) + } + return out.Interface().(Setter) +} + +func clearMap(m reflect.Value) { + var none reflect.Value + for _, k := range m.MapKeys() { + m.SetMapIndex(k, none) + } +} + +func (d *decoder) readDocTo(out reflect.Value) { + var elemType reflect.Type + outt := out.Type() + outk := outt.Kind() + + for { + if outk == reflect.Ptr && out.IsNil() { + out.Set(reflect.New(outt.Elem())) + } + if setter := getSetter(outt, out); setter != nil { + var raw Raw + d.readDocTo(reflect.ValueOf(&raw)) + err := setter.SetBSON(raw) + if _, ok := err.(*TypeError); err != nil && !ok { + panic(err) + } + return + } + if outk == reflect.Ptr { + out = out.Elem() + outt = out.Type() + outk = out.Kind() + continue + } + break + } + + var fieldsMap map[string]fieldInfo + var inlineMap reflect.Value + start := d.i + + origout := out + if outk == reflect.Interface { + if d.docType.Kind() == reflect.Map { + mv := reflect.MakeMap(d.docType) + out.Set(mv) + out = mv + } else { + dv := reflect.New(d.docType).Elem() + out.Set(dv) + out = dv + } + outt = out.Type() + outk = outt.Kind() + } + + docType := d.docType + keyType := typeString + convertKey := false + switch outk { + case reflect.Map: + keyType = outt.Key() + if keyType.Kind() != reflect.String { + panic("BSON map must have string keys. Got: " + outt.String()) + } + if keyType != typeString { + convertKey = true + } + elemType = outt.Elem() + if elemType == typeIface { + d.docType = outt + } + if out.IsNil() { + out.Set(reflect.MakeMap(out.Type())) + } else if out.Len() > 0 { + clearMap(out) + } + case reflect.Struct: + if outt != typeRaw { + sinfo, err := getStructInfo(out.Type()) + if err != nil { + panic(err) + } + fieldsMap = sinfo.FieldsMap + out.Set(sinfo.Zero) + if sinfo.InlineMap != -1 { + inlineMap = out.Field(sinfo.InlineMap) + if !inlineMap.IsNil() && inlineMap.Len() > 0 { + clearMap(inlineMap) + } + elemType = inlineMap.Type().Elem() + if elemType == typeIface { + d.docType = inlineMap.Type() + } + } + } + case reflect.Slice: + switch outt.Elem() { + case typeDocElem: + origout.Set(d.readDocElems(outt)) + return + case typeRawDocElem: + origout.Set(d.readRawDocElems(outt)) + return + } + fallthrough + default: + panic("Unsupported document type for unmarshalling: " + out.Type().String()) + } + + end := int(d.readInt32()) + end += d.i - 4 + if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { + corrupted() + } + for d.in[d.i] != '\x00' { + kind := d.readByte() + name := d.readCStr() + if d.i >= end { + corrupted() + } + + switch outk { + case reflect.Map: + e := reflect.New(elemType).Elem() + if d.readElemTo(e, kind) { + k := reflect.ValueOf(name) + if convertKey { + k = k.Convert(keyType) + } + out.SetMapIndex(k, e) + } + case reflect.Struct: + if outt == typeRaw { + d.dropElem(kind) + } else { + if info, ok := fieldsMap[name]; ok { + if info.Inline == nil { + d.readElemTo(out.Field(info.Num), kind) + } else { + d.readElemTo(out.FieldByIndex(info.Inline), kind) + } + } else if inlineMap.IsValid() { + if inlineMap.IsNil() { + inlineMap.Set(reflect.MakeMap(inlineMap.Type())) + } + e := reflect.New(elemType).Elem() + if d.readElemTo(e, kind) { + inlineMap.SetMapIndex(reflect.ValueOf(name), e) + } + } else { + d.dropElem(kind) + } + } + case reflect.Slice: + } + + if d.i >= end { + corrupted() + } + } + d.i++ // '\x00' + if d.i != end { + corrupted() + } + d.docType = docType + + if outt == typeRaw { + out.Set(reflect.ValueOf(Raw{0x03, d.in[start:d.i]})) + } +} + +func (d *decoder) readArrayDocTo(out reflect.Value) { + end := int(d.readInt32()) + end += d.i - 4 + if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { + corrupted() + } + i := 0 + l := out.Len() + for d.in[d.i] != '\x00' { + if i >= l { + panic("Length mismatch on array field") + } + kind := d.readByte() + for d.i < end && d.in[d.i] != '\x00' { + d.i++ + } + if d.i >= end { + corrupted() + } + d.i++ + d.readElemTo(out.Index(i), kind) + if d.i >= end { + corrupted() + } + i++ + } + if i != l { + panic("Length mismatch on array field") + } + d.i++ // '\x00' + if d.i != end { + corrupted() + } +} + +func (d *decoder) readSliceDoc(t reflect.Type) interface{} { + tmp := make([]reflect.Value, 0, 8) + elemType := t.Elem() + + end := int(d.readInt32()) + end += d.i - 4 + if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { + corrupted() + } + for d.in[d.i] != '\x00' { + kind := d.readByte() + for d.i < end && d.in[d.i] != '\x00' { + d.i++ + } + if d.i >= end { + corrupted() + } + d.i++ + e := reflect.New(elemType).Elem() + if d.readElemTo(e, kind) { + tmp = append(tmp, e) + } + if d.i >= end { + corrupted() + } + } + d.i++ // '\x00' + if d.i != end { + corrupted() + } + + n := len(tmp) + slice := reflect.MakeSlice(t, n, n) + for i := 0; i != n; i++ { + slice.Index(i).Set(tmp[i]) + } + return slice.Interface() +} + +var typeSlice = reflect.TypeOf([]interface{}{}) +var typeIface = typeSlice.Elem() + +func (d *decoder) readDocElems(typ reflect.Type) reflect.Value { + docType := d.docType + d.docType = typ + slice := make([]DocElem, 0, 8) + d.readDocWith(func(kind byte, name string) { + e := DocElem{Name: name} + v := reflect.ValueOf(&e.Value) + if d.readElemTo(v.Elem(), kind) { + slice = append(slice, e) + } + }) + slicev := reflect.New(typ).Elem() + slicev.Set(reflect.ValueOf(slice)) + d.docType = docType + return slicev +} + +func (d *decoder) readRawDocElems(typ reflect.Type) reflect.Value { + docType := d.docType + d.docType = typ + slice := make([]RawDocElem, 0, 8) + d.readDocWith(func(kind byte, name string) { + e := RawDocElem{Name: name} + v := reflect.ValueOf(&e.Value) + if d.readElemTo(v.Elem(), kind) { + slice = append(slice, e) + } + }) + slicev := reflect.New(typ).Elem() + slicev.Set(reflect.ValueOf(slice)) + d.docType = docType + return slicev +} + +func (d *decoder) readDocWith(f func(kind byte, name string)) { + end := int(d.readInt32()) + end += d.i - 4 + if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { + corrupted() + } + for d.in[d.i] != '\x00' { + kind := d.readByte() + name := d.readCStr() + if d.i >= end { + corrupted() + } + f(kind, name) + if d.i >= end { + corrupted() + } + } + d.i++ // '\x00' + if d.i != end { + corrupted() + } +} + +// -------------------------------------------------------------------------- +// Unmarshaling of individual elements within a document. + +var blackHole = settableValueOf(struct{}{}) + +func (d *decoder) dropElem(kind byte) { + d.readElemTo(blackHole, kind) +} + +// Attempt to decode an element from the document and put it into out. +// If the types are not compatible, the returned ok value will be +// false and out will be unchanged. +func (d *decoder) readElemTo(out reflect.Value, kind byte) (good bool) { + + start := d.i + + if kind == '\x03' { + // Delegate unmarshaling of documents. + outt := out.Type() + outk := out.Kind() + switch outk { + case reflect.Interface, reflect.Ptr, reflect.Struct, reflect.Map: + d.readDocTo(out) + return true + } + if setterStyle(outt) != setterNone { + d.readDocTo(out) + return true + } + if outk == reflect.Slice { + switch outt.Elem() { + case typeDocElem: + out.Set(d.readDocElems(outt)) + case typeRawDocElem: + out.Set(d.readRawDocElems(outt)) + } + return true + } + d.readDocTo(blackHole) + return true + } + + var in interface{} + + switch kind { + case 0x01: // Float64 + in = d.readFloat64() + case 0x02: // UTF-8 string + in = d.readStr() + case 0x03: // Document + panic("Can't happen. Handled above.") + case 0x04: // Array + outt := out.Type() + for outt.Kind() == reflect.Ptr { + outt = outt.Elem() + } + switch outt.Kind() { + case reflect.Array: + d.readArrayDocTo(out) + return true + case reflect.Slice: + in = d.readSliceDoc(outt) + default: + in = d.readSliceDoc(typeSlice) + } + case 0x05: // Binary + b := d.readBinary() + if b.Kind == 0x00 || b.Kind == 0x02 { + in = b.Data + } else { + in = b + } + case 0x06: // Undefined (obsolete, but still seen in the wild) + in = Undefined + case 0x07: // ObjectId + in = ObjectId(d.readBytes(12)) + case 0x08: // Bool + in = d.readBool() + case 0x09: // Timestamp + // MongoDB handles timestamps as milliseconds. + i := d.readInt64() + if i == -62135596800000 { + in = time.Time{} // In UTC for convenience. + } else { + in = time.Unix(i/1e3, i%1e3*1e6) + } + case 0x0A: // Nil + in = nil + case 0x0B: // RegEx + in = d.readRegEx() + case 0x0C: + in = DBPointer{Namespace: d.readStr(), Id: ObjectId(d.readBytes(12))} + case 0x0D: // JavaScript without scope + in = JavaScript{Code: d.readStr()} + case 0x0E: // Symbol + in = Symbol(d.readStr()) + case 0x0F: // JavaScript with scope + d.i += 4 // Skip length + js := JavaScript{d.readStr(), make(M)} + d.readDocTo(reflect.ValueOf(js.Scope)) + in = js + case 0x10: // Int32 + in = int(d.readInt32()) + case 0x11: // Mongo-specific timestamp + in = MongoTimestamp(d.readInt64()) + case 0x12: // Int64 + in = d.readInt64() + case 0x7F: // Max key + in = MaxKey + case 0xFF: // Min key + in = MinKey + default: + panic(fmt.Sprintf("Unknown element kind (0x%02X)", kind)) + } + + outt := out.Type() + + if outt == typeRaw { + out.Set(reflect.ValueOf(Raw{kind, d.in[start:d.i]})) + return true + } + + if setter := getSetter(outt, out); setter != nil { + err := setter.SetBSON(Raw{kind, d.in[start:d.i]}) + if err == SetZero { + out.Set(reflect.Zero(outt)) + return true + } + if err == nil { + return true + } + if _, ok := err.(*TypeError); !ok { + panic(err) + } + return false + } + + if in == nil { + out.Set(reflect.Zero(outt)) + return true + } + + outk := outt.Kind() + + // Dereference and initialize pointer if necessary. + first := true + for outk == reflect.Ptr { + if !out.IsNil() { + out = out.Elem() + } else { + elem := reflect.New(outt.Elem()) + if first { + // Only set if value is compatible. + first = false + defer func(out, elem reflect.Value) { + if good { + out.Set(elem) + } + }(out, elem) + } else { + out.Set(elem) + } + out = elem + } + outt = out.Type() + outk = outt.Kind() + } + + inv := reflect.ValueOf(in) + if outt == inv.Type() { + out.Set(inv) + return true + } + + switch outk { + case reflect.Interface: + out.Set(inv) + return true + case reflect.String: + switch inv.Kind() { + case reflect.String: + out.SetString(inv.String()) + return true + case reflect.Slice: + if b, ok := in.([]byte); ok { + out.SetString(string(b)) + return true + } + case reflect.Int, reflect.Int64: + if outt == typeJSONNumber { + out.SetString(strconv.FormatInt(inv.Int(), 10)) + return true + } + case reflect.Float64: + if outt == typeJSONNumber { + out.SetString(strconv.FormatFloat(inv.Float(), 'f', -1, 64)) + return true + } + } + case reflect.Slice, reflect.Array: + // Remember, array (0x04) slices are built with the correct + // element type. If we are here, must be a cross BSON kind + // conversion (e.g. 0x05 unmarshalling on string). + if outt.Elem().Kind() != reflect.Uint8 { + break + } + switch inv.Kind() { + case reflect.String: + slice := []byte(inv.String()) + out.Set(reflect.ValueOf(slice)) + return true + case reflect.Slice: + switch outt.Kind() { + case reflect.Array: + reflect.Copy(out, inv) + case reflect.Slice: + out.SetBytes(inv.Bytes()) + } + return true + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + switch inv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + out.SetInt(inv.Int()) + return true + case reflect.Float32, reflect.Float64: + out.SetInt(int64(inv.Float())) + return true + case reflect.Bool: + if inv.Bool() { + out.SetInt(1) + } else { + out.SetInt(0) + } + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + panic("can't happen: no uint types in BSON (!?)") + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + switch inv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + out.SetUint(uint64(inv.Int())) + return true + case reflect.Float32, reflect.Float64: + out.SetUint(uint64(inv.Float())) + return true + case reflect.Bool: + if inv.Bool() { + out.SetUint(1) + } else { + out.SetUint(0) + } + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + panic("Can't happen. No uint types in BSON.") + } + case reflect.Float32, reflect.Float64: + switch inv.Kind() { + case reflect.Float32, reflect.Float64: + out.SetFloat(inv.Float()) + return true + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + out.SetFloat(float64(inv.Int())) + return true + case reflect.Bool: + if inv.Bool() { + out.SetFloat(1) + } else { + out.SetFloat(0) + } + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + panic("Can't happen. No uint types in BSON?") + } + case reflect.Bool: + switch inv.Kind() { + case reflect.Bool: + out.SetBool(inv.Bool()) + return true + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + out.SetBool(inv.Int() != 0) + return true + case reflect.Float32, reflect.Float64: + out.SetBool(inv.Float() != 0) + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + panic("Can't happen. No uint types in BSON?") + } + case reflect.Struct: + if outt == typeURL && inv.Kind() == reflect.String { + u, err := url.Parse(inv.String()) + if err != nil { + panic(err) + } + out.Set(reflect.ValueOf(u).Elem()) + return true + } + } + + return false +} + +// -------------------------------------------------------------------------- +// Parsers of basic types. + +func (d *decoder) readRegEx() RegEx { + re := RegEx{} + re.Pattern = d.readCStr() + re.Options = d.readCStr() + return re +} + +func (d *decoder) readBinary() Binary { + l := d.readInt32() + b := Binary{} + b.Kind = d.readByte() + b.Data = d.readBytes(l) + if b.Kind == 0x02 && len(b.Data) >= 4 { + // Weird obsolete format with redundant length. + b.Data = b.Data[4:] + } + return b +} + +func (d *decoder) readStr() string { + l := d.readInt32() + b := d.readBytes(l - 1) + if d.readByte() != '\x00' { + corrupted() + } + return string(b) +} + +func (d *decoder) readCStr() string { + start := d.i + end := start + l := len(d.in) + for ; end != l; end++ { + if d.in[end] == '\x00' { + break + } + } + d.i = end + 1 + if d.i > l { + corrupted() + } + return string(d.in[start:end]) +} + +func (d *decoder) readBool() bool { + if d.readByte() == 1 { + return true + } + return false +} + +func (d *decoder) readFloat64() float64 { + return math.Float64frombits(uint64(d.readInt64())) +} + +func (d *decoder) readInt32() int32 { + b := d.readBytes(4) + return int32((uint32(b[0]) << 0) | + (uint32(b[1]) << 8) | + (uint32(b[2]) << 16) | + (uint32(b[3]) << 24)) +} + +func (d *decoder) readInt64() int64 { + b := d.readBytes(8) + return int64((uint64(b[0]) << 0) | + (uint64(b[1]) << 8) | + (uint64(b[2]) << 16) | + (uint64(b[3]) << 24) | + (uint64(b[4]) << 32) | + (uint64(b[5]) << 40) | + (uint64(b[6]) << 48) | + (uint64(b[7]) << 56)) +} + +func (d *decoder) readByte() byte { + i := d.i + d.i++ + if d.i > len(d.in) { + corrupted() + } + return d.in[i] +} + +func (d *decoder) readBytes(length int32) []byte { + start := d.i + d.i += int(length) + if d.i > len(d.in) { + corrupted() + } + return d.in[start : start+int(length)] +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/bson/encode.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/bson/encode.go new file mode 100644 index 0000000..03a1548 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/bson/encode.go @@ -0,0 +1,489 @@ +// BSON library for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// gobson - BSON library for Go. + +package bson + +import ( + "encoding/json" + "fmt" + "math" + "net/url" + "reflect" + "strconv" + "time" +) + +// -------------------------------------------------------------------------- +// Some internal infrastructure. + +var ( + typeBinary = reflect.TypeOf(Binary{}) + typeObjectId = reflect.TypeOf(ObjectId("")) + typeDBPointer = reflect.TypeOf(DBPointer{"", ObjectId("")}) + typeSymbol = reflect.TypeOf(Symbol("")) + typeMongoTimestamp = reflect.TypeOf(MongoTimestamp(0)) + typeOrderKey = reflect.TypeOf(MinKey) + typeDocElem = reflect.TypeOf(DocElem{}) + typeRawDocElem = reflect.TypeOf(RawDocElem{}) + typeRaw = reflect.TypeOf(Raw{}) + typeURL = reflect.TypeOf(url.URL{}) + typeTime = reflect.TypeOf(time.Time{}) + typeString = reflect.TypeOf("") + typeJSONNumber = reflect.TypeOf(json.Number("")) +) + +const itoaCacheSize = 32 + +var itoaCache []string + +func init() { + itoaCache = make([]string, itoaCacheSize) + for i := 0; i != itoaCacheSize; i++ { + itoaCache[i] = strconv.Itoa(i) + } +} + +func itoa(i int) string { + if i < itoaCacheSize { + return itoaCache[i] + } + return strconv.Itoa(i) +} + +// -------------------------------------------------------------------------- +// Marshaling of the document value itself. + +type encoder struct { + out []byte +} + +func (e *encoder) addDoc(v reflect.Value) { + for { + if vi, ok := v.Interface().(Getter); ok { + getv, err := vi.GetBSON() + if err != nil { + panic(err) + } + v = reflect.ValueOf(getv) + continue + } + if v.Kind() == reflect.Ptr { + v = v.Elem() + continue + } + break + } + + if v.Type() == typeRaw { + raw := v.Interface().(Raw) + if raw.Kind != 0x03 && raw.Kind != 0x00 { + panic("Attempted to unmarshal Raw kind " + strconv.Itoa(int(raw.Kind)) + " as a document") + } + e.addBytes(raw.Data...) + return + } + + start := e.reserveInt32() + + switch v.Kind() { + case reflect.Map: + e.addMap(v) + case reflect.Struct: + e.addStruct(v) + case reflect.Array, reflect.Slice: + e.addSlice(v) + default: + panic("Can't marshal " + v.Type().String() + " as a BSON document") + } + + e.addBytes(0) + e.setInt32(start, int32(len(e.out)-start)) +} + +func (e *encoder) addMap(v reflect.Value) { + for _, k := range v.MapKeys() { + e.addElem(k.String(), v.MapIndex(k), false) + } +} + +func (e *encoder) addStruct(v reflect.Value) { + sinfo, err := getStructInfo(v.Type()) + if err != nil { + panic(err) + } + var value reflect.Value + if sinfo.InlineMap >= 0 { + m := v.Field(sinfo.InlineMap) + if m.Len() > 0 { + for _, k := range m.MapKeys() { + ks := k.String() + if _, found := sinfo.FieldsMap[ks]; found { + panic(fmt.Sprintf("Can't have key %q in inlined map; conflicts with struct field", ks)) + } + e.addElem(ks, m.MapIndex(k), false) + } + } + } + for _, info := range sinfo.FieldsList { + if info.Inline == nil { + value = v.Field(info.Num) + } else { + value = v.FieldByIndex(info.Inline) + } + if info.OmitEmpty && isZero(value) { + continue + } + e.addElem(info.Key, value, info.MinSize) + } +} + +func isZero(v reflect.Value) bool { + switch v.Kind() { + case reflect.String: + return len(v.String()) == 0 + case reflect.Ptr, reflect.Interface: + return v.IsNil() + case reflect.Slice: + return v.Len() == 0 + case reflect.Map: + return v.Len() == 0 + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Struct: + if v.Type() == typeTime { + return v.Interface().(time.Time).IsZero() + } + for i := v.NumField()-1; i >= 0; i-- { + if !isZero(v.Field(i)) { + return false + } + } + return true + } + return false +} + +func (e *encoder) addSlice(v reflect.Value) { + vi := v.Interface() + if d, ok := vi.(D); ok { + for _, elem := range d { + e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + } + return + } + if d, ok := vi.(RawD); ok { + for _, elem := range d { + e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + } + return + } + l := v.Len() + et := v.Type().Elem() + if et == typeDocElem { + for i := 0; i < l; i++ { + elem := v.Index(i).Interface().(DocElem) + e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + } + return + } + if et == typeRawDocElem { + for i := 0; i < l; i++ { + elem := v.Index(i).Interface().(RawDocElem) + e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + } + return + } + for i := 0; i < l; i++ { + e.addElem(itoa(i), v.Index(i), false) + } +} + +// -------------------------------------------------------------------------- +// Marshaling of elements in a document. + +func (e *encoder) addElemName(kind byte, name string) { + e.addBytes(kind) + e.addBytes([]byte(name)...) + e.addBytes(0) +} + +func (e *encoder) addElem(name string, v reflect.Value, minSize bool) { + + if !v.IsValid() { + e.addElemName('\x0A', name) + return + } + + if getter, ok := v.Interface().(Getter); ok { + getv, err := getter.GetBSON() + if err != nil { + panic(err) + } + e.addElem(name, reflect.ValueOf(getv), minSize) + return + } + + switch v.Kind() { + + case reflect.Interface: + e.addElem(name, v.Elem(), minSize) + + case reflect.Ptr: + e.addElem(name, v.Elem(), minSize) + + case reflect.String: + s := v.String() + switch v.Type() { + case typeObjectId: + if len(s) != 12 { + panic("ObjectIDs must be exactly 12 bytes long (got " + + strconv.Itoa(len(s)) + ")") + } + e.addElemName('\x07', name) + e.addBytes([]byte(s)...) + case typeSymbol: + e.addElemName('\x0E', name) + e.addStr(s) + case typeJSONNumber: + n := v.Interface().(json.Number) + if i, err := n.Int64(); err == nil { + e.addElemName('\x12', name) + e.addInt64(i) + } else if f, err := n.Float64(); err == nil { + e.addElemName('\x01', name) + e.addFloat64(f) + } else { + panic("failed to convert json.Number to a number: " + s) + } + default: + e.addElemName('\x02', name) + e.addStr(s) + } + + case reflect.Float32, reflect.Float64: + e.addElemName('\x01', name) + e.addFloat64(v.Float()) + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + u := v.Uint() + if int64(u) < 0 { + panic("BSON has no uint64 type, and value is too large to fit correctly in an int64") + } else if u <= math.MaxInt32 && (minSize || v.Kind() <= reflect.Uint32) { + e.addElemName('\x10', name) + e.addInt32(int32(u)) + } else { + e.addElemName('\x12', name) + e.addInt64(int64(u)) + } + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + switch v.Type() { + case typeMongoTimestamp: + e.addElemName('\x11', name) + e.addInt64(v.Int()) + + case typeOrderKey: + if v.Int() == int64(MaxKey) { + e.addElemName('\x7F', name) + } else { + e.addElemName('\xFF', name) + } + + default: + i := v.Int() + if (minSize || v.Type().Kind() != reflect.Int64) && i >= math.MinInt32 && i <= math.MaxInt32 { + // It fits into an int32, encode as such. + e.addElemName('\x10', name) + e.addInt32(int32(i)) + } else { + e.addElemName('\x12', name) + e.addInt64(i) + } + } + + case reflect.Bool: + e.addElemName('\x08', name) + if v.Bool() { + e.addBytes(1) + } else { + e.addBytes(0) + } + + case reflect.Map: + e.addElemName('\x03', name) + e.addDoc(v) + + case reflect.Slice: + vt := v.Type() + et := vt.Elem() + if et.Kind() == reflect.Uint8 { + e.addElemName('\x05', name) + e.addBinary('\x00', v.Bytes()) + } else if et == typeDocElem || et == typeRawDocElem { + e.addElemName('\x03', name) + e.addDoc(v) + } else { + e.addElemName('\x04', name) + e.addDoc(v) + } + + case reflect.Array: + et := v.Type().Elem() + if et.Kind() == reflect.Uint8 { + e.addElemName('\x05', name) + e.addBinary('\x00', v.Slice(0, v.Len()).Interface().([]byte)) + } else { + e.addElemName('\x04', name) + e.addDoc(v) + } + + case reflect.Struct: + switch s := v.Interface().(type) { + + case Raw: + kind := s.Kind + if kind == 0x00 { + kind = 0x03 + } + e.addElemName(kind, name) + e.addBytes(s.Data...) + + case Binary: + e.addElemName('\x05', name) + e.addBinary(s.Kind, s.Data) + + case DBPointer: + e.addElemName('\x0C', name) + e.addStr(s.Namespace) + if len(s.Id) != 12 { + panic("ObjectIDs must be exactly 12 bytes long (got " + + strconv.Itoa(len(s.Id)) + ")") + } + e.addBytes([]byte(s.Id)...) + + case RegEx: + e.addElemName('\x0B', name) + e.addCStr(s.Pattern) + e.addCStr(s.Options) + + case JavaScript: + if s.Scope == nil { + e.addElemName('\x0D', name) + e.addStr(s.Code) + } else { + e.addElemName('\x0F', name) + start := e.reserveInt32() + e.addStr(s.Code) + e.addDoc(reflect.ValueOf(s.Scope)) + e.setInt32(start, int32(len(e.out)-start)) + } + + case time.Time: + // MongoDB handles timestamps as milliseconds. + e.addElemName('\x09', name) + e.addInt64(s.Unix() * 1000 + int64(s.Nanosecond() / 1e6)) + + case url.URL: + e.addElemName('\x02', name) + e.addStr(s.String()) + + case undefined: + e.addElemName('\x06', name) + + default: + e.addElemName('\x03', name) + e.addDoc(v) + } + + default: + panic("Can't marshal " + v.Type().String() + " in a BSON document") + } +} + +// -------------------------------------------------------------------------- +// Marshaling of base types. + +func (e *encoder) addBinary(subtype byte, v []byte) { + if subtype == 0x02 { + // Wonder how that brilliant idea came to life. Obsolete, luckily. + e.addInt32(int32(len(v) + 4)) + e.addBytes(subtype) + e.addInt32(int32(len(v))) + } else { + e.addInt32(int32(len(v))) + e.addBytes(subtype) + } + e.addBytes(v...) +} + +func (e *encoder) addStr(v string) { + e.addInt32(int32(len(v) + 1)) + e.addCStr(v) +} + +func (e *encoder) addCStr(v string) { + e.addBytes([]byte(v)...) + e.addBytes(0) +} + +func (e *encoder) reserveInt32() (pos int) { + pos = len(e.out) + e.addBytes(0, 0, 0, 0) + return pos +} + +func (e *encoder) setInt32(pos int, v int32) { + e.out[pos+0] = byte(v) + e.out[pos+1] = byte(v >> 8) + e.out[pos+2] = byte(v >> 16) + e.out[pos+3] = byte(v >> 24) +} + +func (e *encoder) addInt32(v int32) { + u := uint32(v) + e.addBytes(byte(u), byte(u>>8), byte(u>>16), byte(u>>24)) +} + +func (e *encoder) addInt64(v int64) { + u := uint64(v) + e.addBytes(byte(u), byte(u>>8), byte(u>>16), byte(u>>24), + byte(u>>32), byte(u>>40), byte(u>>48), byte(u>>56)) +} + +func (e *encoder) addFloat64(v float64) { + e.addInt64(int64(math.Float64bits(v))) +} + +func (e *encoder) addBytes(v ...byte) { + e.out = append(e.out, v...) +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/bulk.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/bulk.go new file mode 100644 index 0000000..23f4508 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/bulk.go @@ -0,0 +1,71 @@ +package mgo + +// Bulk represents an operation that can be prepared with several +// orthogonal changes before being delivered to the server. +// +// WARNING: This API is still experimental. +// +// Relevant documentation: +// +// http://blog.mongodb.org/post/84922794768/mongodbs-new-bulk-api +// +type Bulk struct { + c *Collection + ordered bool + inserts []interface{} +} + +// BulkError holds an error returned from running a Bulk operation. +// +// TODO: This is private for the moment, until we understand exactly how +// to report these multi-errors in a useful and convenient way. +type bulkError struct { + err error +} + +// BulkResult holds the results for a bulk operation. +type BulkResult struct { + // Be conservative while we understand exactly how to report these + // results in a useful and convenient way, and also how to emulate + // them with prior servers. + private bool +} + +func (e *bulkError) Error() string { + return e.err.Error() +} + +// Bulk returns a value to prepare the execution of a bulk operation. +// +// WARNING: This API is still experimental. +// +func (c *Collection) Bulk() *Bulk { + return &Bulk{c: c, ordered: true} +} + +// Unordered puts the bulk operation in unordered mode. +// +// In unordered mode the indvidual operations may be sent +// out of order, which means latter operations may proceed +// even if prior ones have failed. +func (b *Bulk) Unordered() { + b.ordered = false +} + +// Insert queues up the provided documents for insertion. +func (b *Bulk) Insert(docs ...interface{}) { + b.inserts = append(b.inserts, docs...) +} + +// Run runs all the operations queued up. +func (b *Bulk) Run() (*BulkResult, error) { + op := &insertOp{b.c.FullName, b.inserts, 0} + if !b.ordered { + op.flags = 1 // ContinueOnError + } + _, err := b.c.writeQuery(op) + if err != nil { + return nil, &bulkError{err} + } + return &BulkResult{}, nil +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/bulk_test.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/bulk_test.go new file mode 100644 index 0000000..24af1b1 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/bulk_test.go @@ -0,0 +1,93 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2014 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo_test + +import ( + . "gopkg.in/check.v1" + "gopkg.in/mgo.v2" +) + +func (s *S) TestBulkInsert(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + bulk := coll.Bulk() + bulk.Insert(M{"n": 1}) + bulk.Insert(M{"n": 2}, M{"n": 3}) + r, err := bulk.Run() + c.Assert(err, IsNil) + c.Assert(r, FitsTypeOf, &mgo.BulkResult{}) + + type doc struct{ N int } + var res []doc + err = coll.Find(nil).Sort("n").All(&res) + c.Assert(err, IsNil) + c.Assert(res, DeepEquals, []doc{{1}, {2}, {3}}) +} + +func (s *S) TestBulkInsertError(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + bulk := coll.Bulk() + bulk.Insert(M{"_id": 1}, M{"_id": 2}, M{"_id": 2}, M{"n": 3}) + _, err = bulk.Run() + c.Assert(err, ErrorMatches, ".*duplicate key.*") + + type doc struct { + N int `_id` + } + var res []doc + err = coll.Find(nil).Sort("_id").All(&res) + c.Assert(err, IsNil) + c.Assert(res, DeepEquals, []doc{{1}, {2}}) +} + +func (s *S) TestBulkInsertErrorUnordered(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + bulk := coll.Bulk() + bulk.Unordered() + bulk.Insert(M{"_id": 1}, M{"_id": 2}, M{"_id": 2}, M{"_id": 3}) + _, err = bulk.Run() + c.Assert(err, ErrorMatches, ".*duplicate key.*") + + type doc struct { + N int `_id` + } + var res []doc + err = coll.Find(nil).Sort("_id").All(&res) + c.Assert(err, IsNil) + c.Assert(res, DeepEquals, []doc{{1}, {2}, {3}}) +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/cluster.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/cluster.go new file mode 100644 index 0000000..104dd39 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/cluster.go @@ -0,0 +1,621 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +import ( + "errors" + "net" + "sync" + "time" + + "gopkg.in/mgo.v2/bson" +) + +// --------------------------------------------------------------------------- +// Mongo cluster encapsulation. +// +// A cluster enables the communication with one or more servers participating +// in a mongo cluster. This works with individual servers, a replica set, +// a replica pair, one or multiple mongos routers, etc. + +type mongoCluster struct { + sync.RWMutex + serverSynced sync.Cond + userSeeds []string + dynaSeeds []string + servers mongoServers + masters mongoServers + references int + syncing bool + direct bool + failFast bool + syncCount uint + cachedIndex map[string]bool + sync chan bool + dial dialer +} + +func newCluster(userSeeds []string, direct, failFast bool, dial dialer) *mongoCluster { + cluster := &mongoCluster{ + userSeeds: userSeeds, + references: 1, + direct: direct, + failFast: failFast, + dial: dial, + } + cluster.serverSynced.L = cluster.RWMutex.RLocker() + cluster.sync = make(chan bool, 1) + stats.cluster(+1) + go cluster.syncServersLoop() + return cluster +} + +// Acquire increases the reference count for the cluster. +func (cluster *mongoCluster) Acquire() { + cluster.Lock() + cluster.references++ + debugf("Cluster %p acquired (refs=%d)", cluster, cluster.references) + cluster.Unlock() +} + +// Release decreases the reference count for the cluster. Once +// it reaches zero, all servers will be closed. +func (cluster *mongoCluster) Release() { + cluster.Lock() + if cluster.references == 0 { + panic("cluster.Release() with references == 0") + } + cluster.references-- + debugf("Cluster %p released (refs=%d)", cluster, cluster.references) + if cluster.references == 0 { + for _, server := range cluster.servers.Slice() { + server.Close() + } + // Wake up the sync loop so it can die. + cluster.syncServers() + stats.cluster(-1) + } + cluster.Unlock() +} + +func (cluster *mongoCluster) LiveServers() (servers []string) { + cluster.RLock() + for _, serv := range cluster.servers.Slice() { + servers = append(servers, serv.Addr) + } + cluster.RUnlock() + return servers +} + +func (cluster *mongoCluster) removeServer(server *mongoServer) { + cluster.Lock() + cluster.masters.Remove(server) + other := cluster.servers.Remove(server) + cluster.Unlock() + if other != nil { + other.Close() + log("Removed server ", server.Addr, " from cluster.") + } + server.Close() +} + +type isMasterResult struct { + IsMaster bool + Secondary bool + Primary string + Hosts []string + Passives []string + Tags bson.D + Msg string + MaxWireVersion int `bson:"maxWireVersion"` +} + +func (cluster *mongoCluster) isMaster(socket *mongoSocket, result *isMasterResult) error { + // Monotonic let's it talk to a slave and still hold the socket. + session := newSession(Monotonic, cluster, 10*time.Second) + session.setSocket(socket) + err := session.Run("ismaster", result) + session.Close() + return err +} + +type possibleTimeout interface { + Timeout() bool +} + +var syncSocketTimeout = 5 * time.Second + +func (cluster *mongoCluster) syncServer(server *mongoServer) (info *mongoServerInfo, hosts []string, err error) { + var syncTimeout time.Duration + if raceDetector { + // This variable is only ever touched by tests. + globalMutex.Lock() + syncTimeout = syncSocketTimeout + globalMutex.Unlock() + } else { + syncTimeout = syncSocketTimeout + } + + addr := server.Addr + log("SYNC Processing ", addr, "...") + + // Retry a few times to avoid knocking a server down for a hiccup. + var result isMasterResult + var tryerr error + for retry := 0; ; retry++ { + if retry == 3 || retry == 1 && cluster.failFast { + return nil, nil, tryerr + } + if retry > 0 { + // Don't abuse the server needlessly if there's something actually wrong. + if err, ok := tryerr.(possibleTimeout); ok && err.Timeout() { + // Give a chance for waiters to timeout as well. + cluster.serverSynced.Broadcast() + } + time.Sleep(syncShortDelay) + } + + // It's not clear what would be a good timeout here. Is it + // better to wait longer or to retry? + socket, _, err := server.AcquireSocket(0, syncTimeout) + if err != nil { + tryerr = err + logf("SYNC Failed to get socket to %s: %v", addr, err) + continue + } + err = cluster.isMaster(socket, &result) + socket.Release() + if err != nil { + tryerr = err + logf("SYNC Command 'ismaster' to %s failed: %v", addr, err) + continue + } + debugf("SYNC Result of 'ismaster' from %s: %#v", addr, result) + break + } + + if result.IsMaster { + debugf("SYNC %s is a master.", addr) + // Made an incorrect assumption above, so fix stats. + stats.conn(-1, false) + stats.conn(+1, true) + } else if result.Secondary { + debugf("SYNC %s is a slave.", addr) + } else if cluster.direct { + logf("SYNC %s in unknown state. Pretending it's a slave due to direct connection.", addr) + } else { + logf("SYNC %s is neither a master nor a slave.", addr) + // Made an incorrect assumption above, so fix stats. + stats.conn(-1, false) + return nil, nil, errors.New(addr + " is not a master nor slave") + } + + info = &mongoServerInfo{ + Master: result.IsMaster, + Mongos: result.Msg == "isdbgrid", + Tags: result.Tags, + MaxWireVersion: result.MaxWireVersion, + } + + hosts = make([]string, 0, 1+len(result.Hosts)+len(result.Passives)) + if result.Primary != "" { + // First in the list to speed up master discovery. + hosts = append(hosts, result.Primary) + } + hosts = append(hosts, result.Hosts...) + hosts = append(hosts, result.Passives...) + + debugf("SYNC %s knows about the following peers: %#v", addr, hosts) + return info, hosts, nil +} + +type syncKind bool + +const ( + completeSync syncKind = true + partialSync syncKind = false +) + +func (cluster *mongoCluster) addServer(server *mongoServer, info *mongoServerInfo, syncKind syncKind) { + cluster.Lock() + current := cluster.servers.Search(server.ResolvedAddr) + if current == nil { + if syncKind == partialSync { + cluster.Unlock() + server.Close() + log("SYNC Discarding unknown server ", server.Addr, " due to partial sync.") + return + } + cluster.servers.Add(server) + if info.Master { + cluster.masters.Add(server) + log("SYNC Adding ", server.Addr, " to cluster as a master.") + } else { + log("SYNC Adding ", server.Addr, " to cluster as a slave.") + } + } else { + if server != current { + panic("addServer attempting to add duplicated server") + } + if server.Info().Master != info.Master { + if info.Master { + log("SYNC Server ", server.Addr, " is now a master.") + cluster.masters.Add(server) + } else { + log("SYNC Server ", server.Addr, " is now a slave.") + cluster.masters.Remove(server) + } + } + } + server.SetInfo(info) + debugf("SYNC Broadcasting availability of server %s", server.Addr) + cluster.serverSynced.Broadcast() + cluster.Unlock() +} + +func (cluster *mongoCluster) getKnownAddrs() []string { + cluster.RLock() + max := len(cluster.userSeeds) + len(cluster.dynaSeeds) + cluster.servers.Len() + seen := make(map[string]bool, max) + known := make([]string, 0, max) + + add := func(addr string) { + if _, found := seen[addr]; !found { + seen[addr] = true + known = append(known, addr) + } + } + + for _, addr := range cluster.userSeeds { + add(addr) + } + for _, addr := range cluster.dynaSeeds { + add(addr) + } + for _, serv := range cluster.servers.Slice() { + add(serv.Addr) + } + cluster.RUnlock() + + return known +} + +// syncServers injects a value into the cluster.sync channel to force +// an iteration of the syncServersLoop function. +func (cluster *mongoCluster) syncServers() { + select { + case cluster.sync <- true: + default: + } +} + +// How long to wait for a checkup of the cluster topology if nothing +// else kicks a synchronization before that. +const syncServersDelay = 30 * time.Second +const syncShortDelay = 500 * time.Millisecond + +// syncServersLoop loops while the cluster is alive to keep its idea of +// the server topology up-to-date. It must be called just once from +// newCluster. The loop iterates once syncServersDelay has passed, or +// if somebody injects a value into the cluster.sync channel to force a +// synchronization. A loop iteration will contact all servers in +// parallel, ask them about known peers and their own role within the +// cluster, and then attempt to do the same with all the peers +// retrieved. +func (cluster *mongoCluster) syncServersLoop() { + for { + debugf("SYNC Cluster %p is starting a sync loop iteration.", cluster) + + cluster.Lock() + if cluster.references == 0 { + cluster.Unlock() + break + } + cluster.references++ // Keep alive while syncing. + direct := cluster.direct + cluster.Unlock() + + cluster.syncServersIteration(direct) + + // We just synchronized, so consume any outstanding requests. + select { + case <-cluster.sync: + default: + } + + cluster.Release() + + // Hold off before allowing another sync. No point in + // burning CPU looking for down servers. + if !cluster.failFast { + time.Sleep(syncShortDelay) + } + + cluster.Lock() + if cluster.references == 0 { + cluster.Unlock() + break + } + cluster.syncCount++ + // Poke all waiters so they have a chance to timeout or + // restart syncing if they wish to. + cluster.serverSynced.Broadcast() + // Check if we have to restart immediately either way. + restart := !direct && cluster.masters.Empty() || cluster.servers.Empty() + cluster.Unlock() + + if restart { + log("SYNC No masters found. Will synchronize again.") + time.Sleep(syncShortDelay) + continue + } + + debugf("SYNC Cluster %p waiting for next requested or scheduled sync.", cluster) + + // Hold off until somebody explicitly requests a synchronization + // or it's time to check for a cluster topology change again. + select { + case <-cluster.sync: + case <-time.After(syncServersDelay): + } + } + debugf("SYNC Cluster %p is stopping its sync loop.", cluster) +} + +func (cluster *mongoCluster) server(addr string, tcpaddr *net.TCPAddr) *mongoServer { + cluster.RLock() + server := cluster.servers.Search(tcpaddr.String()) + cluster.RUnlock() + if server != nil { + return server + } + return newServer(addr, tcpaddr, cluster.sync, cluster.dial) +} + +func resolveAddr(addr string) (*net.TCPAddr, error) { + // This hack allows having a timeout on resolution. + conn, err := net.DialTimeout("udp", addr, 10*time.Second) + if err != nil { + log("SYNC Failed to resolve server address: ", addr) + return nil, errors.New("failed to resolve server address: " + addr) + } + tcpaddr := (*net.TCPAddr)(conn.RemoteAddr().(*net.UDPAddr)) + conn.Close() + if tcpaddr.String() != addr { + debug("SYNC Address ", addr, " resolved as ", tcpaddr.String()) + } + return tcpaddr, nil +} + +type pendingAdd struct { + server *mongoServer + info *mongoServerInfo +} + +func (cluster *mongoCluster) syncServersIteration(direct bool) { + log("SYNC Starting full topology synchronization...") + + var wg sync.WaitGroup + var m sync.Mutex + notYetAdded := make(map[string]pendingAdd) + addIfFound := make(map[string]bool) + seen := make(map[string]bool) + syncKind := partialSync + + var spawnSync func(addr string, byMaster bool) + spawnSync = func(addr string, byMaster bool) { + wg.Add(1) + go func() { + defer wg.Done() + + tcpaddr, err := resolveAddr(addr) + if err != nil { + log("SYNC Failed to start sync of ", addr, ": ", err.Error()) + return + } + resolvedAddr := tcpaddr.String() + + m.Lock() + if byMaster { + if pending, ok := notYetAdded[resolvedAddr]; ok { + delete(notYetAdded, resolvedAddr) + m.Unlock() + cluster.addServer(pending.server, pending.info, completeSync) + return + } + addIfFound[resolvedAddr] = true + } + if seen[resolvedAddr] { + m.Unlock() + return + } + seen[resolvedAddr] = true + m.Unlock() + + server := cluster.server(addr, tcpaddr) + info, hosts, err := cluster.syncServer(server) + if err != nil { + cluster.removeServer(server) + return + } + + m.Lock() + add := direct || info.Master || addIfFound[resolvedAddr] + if add { + syncKind = completeSync + } else { + notYetAdded[resolvedAddr] = pendingAdd{server, info} + } + m.Unlock() + if add { + cluster.addServer(server, info, completeSync) + } + if !direct { + for _, addr := range hosts { + spawnSync(addr, info.Master) + } + } + }() + } + + knownAddrs := cluster.getKnownAddrs() + for _, addr := range knownAddrs { + spawnSync(addr, false) + } + wg.Wait() + + if syncKind == completeSync { + logf("SYNC Synchronization was complete (got data from primary).") + for _, pending := range notYetAdded { + cluster.removeServer(pending.server) + } + } else { + logf("SYNC Synchronization was partial (cannot talk to primary).") + for _, pending := range notYetAdded { + cluster.addServer(pending.server, pending.info, partialSync) + } + } + + cluster.Lock() + ml := cluster.masters.Len() + logf("SYNC Synchronization completed: %d master(s) and %d slave(s) alive.", ml, cluster.servers.Len()-ml) + + // Update dynamic seeds, but only if we have any good servers. Otherwise, + // leave them alone for better chances of a successful sync in the future. + if syncKind == completeSync { + dynaSeeds := make([]string, cluster.servers.Len()) + for i, server := range cluster.servers.Slice() { + dynaSeeds[i] = server.Addr + } + cluster.dynaSeeds = dynaSeeds + debugf("SYNC New dynamic seeds: %#v\n", dynaSeeds) + } + cluster.Unlock() +} + +// AcquireSocket returns a socket to a server in the cluster. If slaveOk is +// true, it will attempt to return a socket to a slave server. If it is +// false, the socket will necessarily be to a master server. +func (cluster *mongoCluster) AcquireSocket(slaveOk bool, syncTimeout time.Duration, socketTimeout time.Duration, serverTags []bson.D, poolLimit int) (s *mongoSocket, err error) { + var started time.Time + var syncCount uint + warnedLimit := false + for { + cluster.RLock() + for { + ml := cluster.masters.Len() + sl := cluster.servers.Len() + debugf("Cluster has %d known masters and %d known slaves.", ml, sl-ml) + if ml > 0 || slaveOk && sl > 0 { + break + } + if started.IsZero() { + // Initialize after fast path above. + started = time.Now() + syncCount = cluster.syncCount + } else if syncTimeout != 0 && started.Before(time.Now().Add(-syncTimeout)) || cluster.failFast && cluster.syncCount != syncCount { + cluster.RUnlock() + return nil, errors.New("no reachable servers") + } + log("Waiting for servers to synchronize...") + cluster.syncServers() + + // Remember: this will release and reacquire the lock. + cluster.serverSynced.Wait() + } + + var server *mongoServer + if slaveOk { + server = cluster.servers.BestFit(serverTags) + } else { + server = cluster.masters.BestFit(nil) + } + cluster.RUnlock() + + if server == nil { + // Must have failed the requested tags. Sleep to avoid spinning. + time.Sleep(1e8) + continue + } + + s, abended, err := server.AcquireSocket(poolLimit, socketTimeout) + if err == errPoolLimit { + if !warnedLimit { + warnedLimit = true + log("WARNING: Per-server connection limit reached.") + } + time.Sleep(100 * time.Millisecond) + continue + } + if err != nil { + cluster.removeServer(server) + cluster.syncServers() + continue + } + if abended && !slaveOk { + var result isMasterResult + err := cluster.isMaster(s, &result) + if err != nil || !result.IsMaster { + logf("Cannot confirm server %s as master (%v)", server.Addr, err) + s.Release() + cluster.syncServers() + time.Sleep(100 * time.Millisecond) + continue + } + } + return s, nil + } + panic("unreached") +} + +func (cluster *mongoCluster) CacheIndex(cacheKey string, exists bool) { + cluster.Lock() + if cluster.cachedIndex == nil { + cluster.cachedIndex = make(map[string]bool) + } + if exists { + cluster.cachedIndex[cacheKey] = true + } else { + delete(cluster.cachedIndex, cacheKey) + } + cluster.Unlock() +} + +func (cluster *mongoCluster) HasCachedIndex(cacheKey string) (result bool) { + cluster.RLock() + if cluster.cachedIndex != nil { + result = cluster.cachedIndex[cacheKey] + } + cluster.RUnlock() + return +} + +func (cluster *mongoCluster) ResetIndexCache() { + cluster.Lock() + cluster.cachedIndex = make(map[string]bool) + cluster.Unlock() +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/cluster_test.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/cluster_test.go new file mode 100644 index 0000000..a29bd6f --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/cluster_test.go @@ -0,0 +1,1596 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo_test + +import ( + "fmt" + "io" + "net" + "strings" + "sync" + "time" + + . "gopkg.in/check.v1" + "gopkg.in/mgo.v2" + "gopkg.in/mgo.v2/bson" +) + +func (s *S) TestNewSession(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + // Do a dummy operation to wait for connection. + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + + // Tweak safety and query settings to ensure other has copied those. + session.SetSafe(nil) + session.SetBatch(-1) + other := session.New() + defer other.Close() + session.SetSafe(&mgo.Safe{}) + + // Clone was copied while session was unsafe, so no errors. + otherColl := other.DB("mydb").C("mycoll") + err = otherColl.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + + // Original session was made safe again. + err = coll.Insert(M{"_id": 1}) + c.Assert(err, NotNil) + + // With New(), each session has its own socket now. + stats := mgo.GetStats() + c.Assert(stats.MasterConns, Equals, 2) + c.Assert(stats.SocketsInUse, Equals, 2) + + // Ensure query parameters were cloned. + err = otherColl.Insert(M{"_id": 2}) + c.Assert(err, IsNil) + + // Ping the database to ensure the nonce has been received already. + c.Assert(other.Ping(), IsNil) + + mgo.ResetStats() + + iter := otherColl.Find(M{}).Iter() + c.Assert(err, IsNil) + + m := M{} + ok := iter.Next(m) + c.Assert(ok, Equals, true) + err = iter.Close() + c.Assert(err, IsNil) + + // If Batch(-1) is in effect, a single document must have been received. + stats = mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, 1) +} + +func (s *S) TestCloneSession(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + // Do a dummy operation to wait for connection. + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + + // Tweak safety and query settings to ensure clone is copying those. + session.SetSafe(nil) + session.SetBatch(-1) + clone := session.Clone() + defer clone.Close() + session.SetSafe(&mgo.Safe{}) + + // Clone was copied while session was unsafe, so no errors. + cloneColl := clone.DB("mydb").C("mycoll") + err = cloneColl.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + + // Original session was made safe again. + err = coll.Insert(M{"_id": 1}) + c.Assert(err, NotNil) + + // With Clone(), same socket is shared between sessions now. + stats := mgo.GetStats() + c.Assert(stats.SocketsInUse, Equals, 1) + c.Assert(stats.SocketRefs, Equals, 2) + + // Refreshing one of them should let the original socket go, + // while preserving the safety settings. + clone.Refresh() + err = cloneColl.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + + // Must have used another connection now. + stats = mgo.GetStats() + c.Assert(stats.SocketsInUse, Equals, 2) + c.Assert(stats.SocketRefs, Equals, 2) + + // Ensure query parameters were cloned. + err = cloneColl.Insert(M{"_id": 2}) + c.Assert(err, IsNil) + + // Ping the database to ensure the nonce has been received already. + c.Assert(clone.Ping(), IsNil) + + mgo.ResetStats() + + iter := cloneColl.Find(M{}).Iter() + c.Assert(err, IsNil) + + m := M{} + ok := iter.Next(m) + c.Assert(ok, Equals, true) + err = iter.Close() + c.Assert(err, IsNil) + + // If Batch(-1) is in effect, a single document must have been received. + stats = mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, 1) +} + +func (s *S) TestSetModeStrong(c *C) { + session, err := mgo.Dial("localhost:40012") + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Monotonic, false) + session.SetMode(mgo.Strong, false) + + c.Assert(session.Mode(), Equals, mgo.Strong) + + result := M{} + cmd := session.DB("admin").C("$cmd") + err = cmd.Find(M{"ismaster": 1}).One(&result) + c.Assert(err, IsNil) + c.Assert(result["ismaster"], Equals, true) + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1}) + c.Assert(err, IsNil) + + // Wait since the sync also uses sockets. + for len(session.LiveServers()) != 3 { + c.Log("Waiting for cluster sync to finish...") + time.Sleep(5e8) + } + + stats := mgo.GetStats() + c.Assert(stats.MasterConns, Equals, 1) + c.Assert(stats.SlaveConns, Equals, 2) + c.Assert(stats.SocketsInUse, Equals, 1) + + session.SetMode(mgo.Strong, true) + + stats = mgo.GetStats() + c.Assert(stats.SocketsInUse, Equals, 0) +} + +func (s *S) TestSetModeMonotonic(c *C) { + // Must necessarily connect to a slave, otherwise the + // master connection will be available first. + session, err := mgo.Dial("localhost:40012") + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Monotonic, false) + + c.Assert(session.Mode(), Equals, mgo.Monotonic) + + result := M{} + cmd := session.DB("admin").C("$cmd") + err = cmd.Find(M{"ismaster": 1}).One(&result) + c.Assert(err, IsNil) + c.Assert(result["ismaster"], Equals, false) + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1}) + c.Assert(err, IsNil) + + result = M{} + err = cmd.Find(M{"ismaster": 1}).One(&result) + c.Assert(err, IsNil) + c.Assert(result["ismaster"], Equals, true) + + // Wait since the sync also uses sockets. + for len(session.LiveServers()) != 3 { + c.Log("Waiting for cluster sync to finish...") + time.Sleep(5e8) + } + + stats := mgo.GetStats() + c.Assert(stats.MasterConns, Equals, 1) + c.Assert(stats.SlaveConns, Equals, 2) + c.Assert(stats.SocketsInUse, Equals, 2) + + session.SetMode(mgo.Monotonic, true) + + stats = mgo.GetStats() + c.Assert(stats.SocketsInUse, Equals, 0) +} + +func (s *S) TestSetModeMonotonicAfterStrong(c *C) { + // Test that a strong session shifting to a monotonic + // one preserves the socket untouched. + + session, err := mgo.Dial("localhost:40012") + c.Assert(err, IsNil) + defer session.Close() + + // Insert something to force a connection to the master. + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1}) + c.Assert(err, IsNil) + + session.SetMode(mgo.Monotonic, false) + + // Wait since the sync also uses sockets. + for len(session.LiveServers()) != 3 { + c.Log("Waiting for cluster sync to finish...") + time.Sleep(5e8) + } + + // Master socket should still be reserved. + stats := mgo.GetStats() + c.Assert(stats.SocketsInUse, Equals, 1) + + // Confirm it's the master even though it's Monotonic by now. + result := M{} + cmd := session.DB("admin").C("$cmd") + err = cmd.Find(M{"ismaster": 1}).One(&result) + c.Assert(err, IsNil) + c.Assert(result["ismaster"], Equals, true) +} + +func (s *S) TestSetModeStrongAfterMonotonic(c *C) { + // Test that shifting from Monotonic to Strong while + // using a slave socket will keep the socket reserved + // until the master socket is necessary, so that no + // switch over occurs unless it's actually necessary. + + // Must necessarily connect to a slave, otherwise the + // master connection will be available first. + session, err := mgo.Dial("localhost:40012") + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Monotonic, false) + + // Ensure we're talking to a slave, and reserve the socket. + result := M{} + err = session.Run("ismaster", &result) + c.Assert(err, IsNil) + c.Assert(result["ismaster"], Equals, false) + + // Switch to a Strong session. + session.SetMode(mgo.Strong, false) + + // Wait since the sync also uses sockets. + for len(session.LiveServers()) != 3 { + c.Log("Waiting for cluster sync to finish...") + time.Sleep(5e8) + } + + // Slave socket should still be reserved. + stats := mgo.GetStats() + c.Assert(stats.SocketsInUse, Equals, 1) + + // But any operation will switch it to the master. + result = M{} + err = session.Run("ismaster", &result) + c.Assert(err, IsNil) + c.Assert(result["ismaster"], Equals, true) +} + +func (s *S) TestSetModeMonotonicWriteOnIteration(c *C) { + // Must necessarily connect to a slave, otherwise the + // master connection will be available first. + session, err := mgo.Dial("localhost:40012") + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Monotonic, false) + + c.Assert(session.Mode(), Equals, mgo.Monotonic) + + coll1 := session.DB("mydb").C("mycoll1") + coll2 := session.DB("mydb").C("mycoll2") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err := coll1.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + // Release master so we can grab a slave again. + session.Refresh() + + // Wait until synchronization is done. + for { + n, err := coll1.Count() + c.Assert(err, IsNil) + if n == len(ns) { + break + } + } + + iter := coll1.Find(nil).Batch(2).Iter() + i := 0 + m := M{} + for iter.Next(&m) { + i++ + if i > 3 { + err := coll2.Insert(M{"n": 47 + i}) + c.Assert(err, IsNil) + } + } + c.Assert(i, Equals, len(ns)) +} + +func (s *S) TestSetModeEventual(c *C) { + // Must necessarily connect to a slave, otherwise the + // master connection will be available first. + session, err := mgo.Dial("localhost:40012") + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Eventual, false) + + c.Assert(session.Mode(), Equals, mgo.Eventual) + + result := M{} + err = session.Run("ismaster", &result) + c.Assert(err, IsNil) + c.Assert(result["ismaster"], Equals, false) + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1}) + c.Assert(err, IsNil) + + result = M{} + err = session.Run("ismaster", &result) + c.Assert(err, IsNil) + c.Assert(result["ismaster"], Equals, false) + + // Wait since the sync also uses sockets. + for len(session.LiveServers()) != 3 { + c.Log("Waiting for cluster sync to finish...") + time.Sleep(5e8) + } + + stats := mgo.GetStats() + c.Assert(stats.MasterConns, Equals, 1) + c.Assert(stats.SlaveConns, Equals, 2) + c.Assert(stats.SocketsInUse, Equals, 0) +} + +func (s *S) TestSetModeEventualAfterStrong(c *C) { + // Test that a strong session shifting to an eventual + // one preserves the socket untouched. + + session, err := mgo.Dial("localhost:40012") + c.Assert(err, IsNil) + defer session.Close() + + // Insert something to force a connection to the master. + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1}) + c.Assert(err, IsNil) + + session.SetMode(mgo.Eventual, false) + + // Wait since the sync also uses sockets. + for len(session.LiveServers()) != 3 { + c.Log("Waiting for cluster sync to finish...") + time.Sleep(5e8) + } + + // Master socket should still be reserved. + stats := mgo.GetStats() + c.Assert(stats.SocketsInUse, Equals, 1) + + // Confirm it's the master even though it's Eventual by now. + result := M{} + cmd := session.DB("admin").C("$cmd") + err = cmd.Find(M{"ismaster": 1}).One(&result) + c.Assert(err, IsNil) + c.Assert(result["ismaster"], Equals, true) + + session.SetMode(mgo.Eventual, true) + + stats = mgo.GetStats() + c.Assert(stats.SocketsInUse, Equals, 0) +} + +func (s *S) TestPrimaryShutdownStrong(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40021") + c.Assert(err, IsNil) + defer session.Close() + + // With strong consistency, this will open a socket to the master. + result := &struct{ Host string }{} + err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + + // Kill the master. + host := result.Host + s.Stop(host) + + // This must fail, since the connection was broken. + err = session.Run("serverStatus", result) + c.Assert(err, Equals, io.EOF) + + // With strong consistency, it fails again until reset. + err = session.Run("serverStatus", result) + c.Assert(err, Equals, io.EOF) + + session.Refresh() + + // Now we should be able to talk to the new master. + // Increase the timeout since this may take quite a while. + session.SetSyncTimeout(3 * time.Minute) + + err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + c.Assert(result.Host, Not(Equals), host) + + // Insert some data to confirm it's indeed a master. + err = session.DB("mydb").C("mycoll").Insert(M{"n": 42}) + c.Assert(err, IsNil) +} + +func (s *S) TestPrimaryHiccup(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40021") + c.Assert(err, IsNil) + defer session.Close() + + // With strong consistency, this will open a socket to the master. + result := &struct{ Host string }{} + err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + + // Establish a few extra sessions to create spare sockets to + // the master. This increases a bit the chances of getting an + // incorrect cached socket. + var sessions []*mgo.Session + for i := 0; i < 20; i++ { + sessions = append(sessions, session.Copy()) + err = sessions[len(sessions)-1].Run("serverStatus", result) + c.Assert(err, IsNil) + } + for i := range sessions { + sessions[i].Close() + } + + // Kill the master, but bring it back immediatelly. + host := result.Host + s.Stop(host) + s.StartAll() + + // This must fail, since the connection was broken. + err = session.Run("serverStatus", result) + c.Assert(err, Equals, io.EOF) + + // With strong consistency, it fails again until reset. + err = session.Run("serverStatus", result) + c.Assert(err, Equals, io.EOF) + + session.Refresh() + + // Now we should be able to talk to the new master. + // Increase the timeout since this may take quite a while. + session.SetSyncTimeout(3 * time.Minute) + + // Insert some data to confirm it's indeed a master. + err = session.DB("mydb").C("mycoll").Insert(M{"n": 42}) + c.Assert(err, IsNil) +} + +func (s *S) TestPrimaryShutdownMonotonic(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40021") + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Monotonic, true) + + // Insert something to force a switch to the master. + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1}) + c.Assert(err, IsNil) + + // Wait a bit for this to be synchronized to slaves. + time.Sleep(3 * time.Second) + + result := &struct{ Host string }{} + err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + + // Kill the master. + host := result.Host + s.Stop(host) + + // This must fail, since the connection was broken. + err = session.Run("serverStatus", result) + c.Assert(err, Equals, io.EOF) + + // With monotonic consistency, it fails again until reset. + err = session.Run("serverStatus", result) + c.Assert(err, Equals, io.EOF) + + session.Refresh() + + // Now we should be able to talk to the new master. + err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + c.Assert(result.Host, Not(Equals), host) +} + +func (s *S) TestPrimaryShutdownMonotonicWithSlave(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40021") + c.Assert(err, IsNil) + defer session.Close() + + ssresult := &struct{ Host string }{} + imresult := &struct{ IsMaster bool }{} + + // Figure the master while still using the strong session. + err = session.Run("serverStatus", ssresult) + c.Assert(err, IsNil) + err = session.Run("isMaster", imresult) + c.Assert(err, IsNil) + master := ssresult.Host + c.Assert(imresult.IsMaster, Equals, true, Commentf("%s is not the master", master)) + + // Create new monotonic session with an explicit address to ensure + // a slave is synchronized before the master, otherwise a connection + // with the master may be used below for lack of other options. + var addr string + switch { + case strings.HasSuffix(ssresult.Host, ":40021"): + addr = "localhost:40022" + case strings.HasSuffix(ssresult.Host, ":40022"): + addr = "localhost:40021" + case strings.HasSuffix(ssresult.Host, ":40023"): + addr = "localhost:40021" + default: + c.Fatal("Unknown host: ", ssresult.Host) + } + + session, err = mgo.Dial(addr) + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Monotonic, true) + + // Check the address of the socket associated with the monotonic session. + c.Log("Running serverStatus and isMaster with monotonic session") + err = session.Run("serverStatus", ssresult) + c.Assert(err, IsNil) + err = session.Run("isMaster", imresult) + c.Assert(err, IsNil) + slave := ssresult.Host + c.Assert(imresult.IsMaster, Equals, false, Commentf("%s is not a slave", slave)) + + c.Assert(master, Not(Equals), slave) + + // Kill the master. + s.Stop(master) + + // Session must still be good, since we were talking to a slave. + err = session.Run("serverStatus", ssresult) + c.Assert(err, IsNil) + + c.Assert(ssresult.Host, Equals, slave, + Commentf("Monotonic session moved from %s to %s", slave, ssresult.Host)) + + // If we try to insert something, it'll have to hold until the new + // master is available to move the connection, and work correctly. + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1}) + c.Assert(err, IsNil) + + // Must now be talking to the new master. + err = session.Run("serverStatus", ssresult) + c.Assert(err, IsNil) + err = session.Run("isMaster", imresult) + c.Assert(err, IsNil) + c.Assert(imresult.IsMaster, Equals, true, Commentf("%s is not the master", master)) + + // ... which is not the old one, since it's still dead. + c.Assert(ssresult.Host, Not(Equals), master) +} + +func (s *S) TestPrimaryShutdownEventual(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40021") + c.Assert(err, IsNil) + defer session.Close() + + result := &struct{ Host string }{} + err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + master := result.Host + + session.SetMode(mgo.Eventual, true) + + // Should connect to the master when needed. + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1}) + c.Assert(err, IsNil) + + // Wait a bit for this to be synchronized to slaves. + time.Sleep(3 * time.Second) + + // Kill the master. + s.Stop(master) + + // Should still work, with the new master now. + coll = session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1}) + c.Assert(err, IsNil) + + err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + c.Assert(result.Host, Not(Equals), master) +} + +func (s *S) TestPreserveSocketCountOnSync(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + stats := mgo.GetStats() + for stats.MasterConns+stats.SlaveConns != 3 { + stats = mgo.GetStats() + c.Log("Waiting for all connections to be established...") + time.Sleep(5e8) + } + + c.Assert(stats.SocketsAlive, Equals, 3) + + // Kill the master (with rs1, 'a' is always the master). + s.Stop("localhost:40011") + + // Wait for the logic to run for a bit and bring it back. + startedAll := make(chan bool) + go func() { + time.Sleep(5e9) + s.StartAll() + startedAll <- true + }() + + // Do not allow the test to return before the goroutine above is done. + defer func() { + <-startedAll + }() + + // Do an action to kick the resync logic in, and also to + // wait until the cluster recognizes the server is back. + result := struct{ Ok bool }{} + err = session.Run("getLastError", &result) + c.Assert(err, IsNil) + c.Assert(result.Ok, Equals, true) + + for i := 0; i != 20; i++ { + stats = mgo.GetStats() + if stats.SocketsAlive == 3 { + break + } + c.Logf("Waiting for 3 sockets alive, have %d", stats.SocketsAlive) + time.Sleep(5e8) + } + + // Ensure the number of sockets is preserved after syncing. + stats = mgo.GetStats() + c.Assert(stats.SocketsAlive, Equals, 3) + c.Assert(stats.SocketsInUse, Equals, 1) + c.Assert(stats.SocketRefs, Equals, 1) +} + +// Connect to the master of a deployment with a single server, +// run an insert, and then ensure the insert worked and that a +// single connection was established. +func (s *S) TestTopologySyncWithSingleMaster(c *C) { + // Use hostname here rather than IP, to make things trickier. + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1, "b": 2}) + c.Assert(err, IsNil) + + // One connection used for discovery. Master socket recycled for + // insert. Socket is reserved after insert. + stats := mgo.GetStats() + c.Assert(stats.MasterConns, Equals, 1) + c.Assert(stats.SlaveConns, Equals, 0) + c.Assert(stats.SocketsInUse, Equals, 1) + + // Refresh session and socket must be released. + session.Refresh() + stats = mgo.GetStats() + c.Assert(stats.SocketsInUse, Equals, 0) +} + +func (s *S) TestTopologySyncWithSlaveSeed(c *C) { + // That's supposed to be a slave. Must run discovery + // and find out master to insert successfully. + session, err := mgo.Dial("localhost:40012") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + coll.Insert(M{"a": 1, "b": 2}) + + result := struct{ Ok bool }{} + err = session.Run("getLastError", &result) + c.Assert(err, IsNil) + c.Assert(result.Ok, Equals, true) + + // One connection to each during discovery. Master + // socket recycled for insert. + stats := mgo.GetStats() + c.Assert(stats.MasterConns, Equals, 1) + c.Assert(stats.SlaveConns, Equals, 2) + + // Only one socket reference alive, in the master socket owned + // by the above session. + c.Assert(stats.SocketsInUse, Equals, 1) + + // Refresh it, and it must be gone. + session.Refresh() + stats = mgo.GetStats() + c.Assert(stats.SocketsInUse, Equals, 0) +} + +func (s *S) TestSyncTimeout(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + s.Stop("localhost:40001") + + timeout := 3 * time.Second + session.SetSyncTimeout(timeout) + started := time.Now() + + // Do something. + result := struct{ Ok bool }{} + err = session.Run("getLastError", &result) + c.Assert(err, ErrorMatches, "no reachable servers") + c.Assert(started.Before(time.Now().Add(-timeout)), Equals, true) + c.Assert(started.After(time.Now().Add(-timeout*2)), Equals, true) +} + +func (s *S) TestDialWithTimeout(c *C) { + if *fast { + c.Skip("-fast") + } + + timeout := 2 * time.Second + started := time.Now() + + // 40009 isn't used by the test servers. + session, err := mgo.DialWithTimeout("localhost:40009", timeout) + if session != nil { + session.Close() + } + c.Assert(err, ErrorMatches, "no reachable servers") + c.Assert(session, IsNil) + c.Assert(started.Before(time.Now().Add(-timeout)), Equals, true) + c.Assert(started.After(time.Now().Add(-timeout*2)), Equals, true) +} + +func (s *S) TestSocketTimeout(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + s.Freeze("localhost:40001") + + timeout := 3 * time.Second + session.SetSocketTimeout(timeout) + started := time.Now() + + // Do something. + result := struct{ Ok bool }{} + err = session.Run("getLastError", &result) + c.Assert(err, ErrorMatches, ".*: i/o timeout") + c.Assert(started.Before(time.Now().Add(-timeout)), Equals, true) + c.Assert(started.After(time.Now().Add(-timeout*2)), Equals, true) +} + +func (s *S) TestSocketTimeoutOnDial(c *C) { + if *fast { + c.Skip("-fast") + } + + timeout := 1 * time.Second + + defer mgo.HackSyncSocketTimeout(timeout)() + + s.Freeze("localhost:40001") + + started := time.Now() + + session, err := mgo.DialWithTimeout("localhost:40001", timeout) + c.Assert(err, ErrorMatches, "no reachable servers") + c.Assert(session, IsNil) + + c.Assert(started.Before(time.Now().Add(-timeout)), Equals, true) + c.Assert(started.After(time.Now().Add(-20*time.Second)), Equals, true) +} + +func (s *S) TestSocketTimeoutOnInactiveSocket(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + timeout := 2 * time.Second + session.SetSocketTimeout(timeout) + + // Do something that relies on the timeout and works. + c.Assert(session.Ping(), IsNil) + + // Freeze and wait for the timeout to go by. + s.Freeze("localhost:40001") + time.Sleep(timeout + 500*time.Millisecond) + s.Thaw("localhost:40001") + + // Do something again. The timeout above should not have killed + // the socket as there was nothing to be done. + c.Assert(session.Ping(), IsNil) +} + +func (s *S) TestDirect(c *C) { + session, err := mgo.Dial("localhost:40012?connect=direct") + c.Assert(err, IsNil) + defer session.Close() + + // We know that server is a slave. + session.SetMode(mgo.Monotonic, true) + + result := &struct{ Host string }{} + err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + c.Assert(strings.HasSuffix(result.Host, ":40012"), Equals, true) + + stats := mgo.GetStats() + c.Assert(stats.SocketsAlive, Equals, 1) + c.Assert(stats.SocketsInUse, Equals, 1) + c.Assert(stats.SocketRefs, Equals, 1) + + // We've got no master, so it'll timeout. + session.SetSyncTimeout(5e8 * time.Nanosecond) + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"test": 1}) + c.Assert(err, ErrorMatches, "no reachable servers") + + // Writing to the local database is okay. + coll = session.DB("local").C("mycoll") + defer coll.RemoveAll(nil) + id := bson.NewObjectId() + err = coll.Insert(M{"_id": id}) + c.Assert(err, IsNil) + + // Data was stored in the right server. + n, err := coll.Find(M{"_id": id}).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 1) + + // Server hasn't changed. + result.Host = "" + err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + c.Assert(strings.HasSuffix(result.Host, ":40012"), Equals, true) +} + +func (s *S) TestDirectToUnknownStateMember(c *C) { + session, err := mgo.Dial("localhost:40041?connect=direct") + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Monotonic, true) + + result := &struct{ Host string }{} + err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + c.Assert(strings.HasSuffix(result.Host, ":40041"), Equals, true) + + // We've got no master, so it'll timeout. + session.SetSyncTimeout(5e8 * time.Nanosecond) + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"test": 1}) + c.Assert(err, ErrorMatches, "no reachable servers") + + // Slave is still reachable. + result.Host = "" + err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + c.Assert(strings.HasSuffix(result.Host, ":40041"), Equals, true) +} + +func (s *S) TestFailFast(c *C) { + info := mgo.DialInfo{ + Addrs: []string{"localhost:99999"}, + Timeout: 5 * time.Second, + FailFast: true, + } + + started := time.Now() + + _, err := mgo.DialWithInfo(&info) + c.Assert(err, ErrorMatches, "no reachable servers") + + c.Assert(started.After(time.Now().Add(-time.Second)), Equals, true) +} + +type OpCounters struct { + Insert int + Query int + Update int + Delete int + GetMore int + Command int +} + +func getOpCounters(server string) (c *OpCounters, err error) { + session, err := mgo.Dial(server + "?connect=direct") + if err != nil { + return nil, err + } + defer session.Close() + session.SetMode(mgo.Monotonic, true) + result := struct{ OpCounters }{} + err = session.Run("serverStatus", &result) + return &result.OpCounters, err +} + +func (s *S) TestMonotonicSlaveOkFlagWithMongos(c *C) { + session, err := mgo.Dial("localhost:40021") + c.Assert(err, IsNil) + defer session.Close() + + ssresult := &struct{ Host string }{} + imresult := &struct{ IsMaster bool }{} + + // Figure the master while still using the strong session. + err = session.Run("serverStatus", ssresult) + c.Assert(err, IsNil) + err = session.Run("isMaster", imresult) + c.Assert(err, IsNil) + master := ssresult.Host + c.Assert(imresult.IsMaster, Equals, true, Commentf("%s is not the master", master)) + + // Collect op counters for everyone. + opc21a, err := getOpCounters("localhost:40021") + c.Assert(err, IsNil) + opc22a, err := getOpCounters("localhost:40022") + c.Assert(err, IsNil) + opc23a, err := getOpCounters("localhost:40023") + c.Assert(err, IsNil) + + // Do a SlaveOk query through MongoS + + mongos, err := mgo.Dial("localhost:40202") + c.Assert(err, IsNil) + defer mongos.Close() + + mongos.SetMode(mgo.Monotonic, true) + + coll := mongos.DB("mydb").C("mycoll") + result := &struct{}{} + for i := 0; i != 5; i++ { + err := coll.Find(nil).One(result) + c.Assert(err, Equals, mgo.ErrNotFound) + } + + // Collect op counters for everyone again. + opc21b, err := getOpCounters("localhost:40021") + c.Assert(err, IsNil) + opc22b, err := getOpCounters("localhost:40022") + c.Assert(err, IsNil) + opc23b, err := getOpCounters("localhost:40023") + c.Assert(err, IsNil) + + masterPort := master[strings.Index(master, ":")+1:] + + var masterDelta, slaveDelta int + switch masterPort { + case "40021": + masterDelta = opc21b.Query - opc21a.Query + slaveDelta = (opc22b.Query - opc22a.Query) + (opc23b.Query - opc23a.Query) + case "40022": + masterDelta = opc22b.Query - opc22a.Query + slaveDelta = (opc21b.Query - opc21a.Query) + (opc23b.Query - opc23a.Query) + case "40023": + masterDelta = opc23b.Query - opc23a.Query + slaveDelta = (opc21b.Query - opc21a.Query) + (opc22b.Query - opc22a.Query) + default: + c.Fatal("Uh?") + } + + c.Check(masterDelta, Equals, 0) // Just the counting itself. + c.Check(slaveDelta, Equals, 5) // The counting for both, plus 5 queries above. +} + +func (s *S) TestRemovalOfClusterMember(c *C) { + if *fast { + c.Skip("-fast") + } + + master, err := mgo.Dial("localhost:40021") + c.Assert(err, IsNil) + defer master.Close() + + // Wait for cluster to fully sync up. + for i := 0; i < 10; i++ { + if len(master.LiveServers()) == 3 { + break + } + time.Sleep(5e8) + } + if len(master.LiveServers()) != 3 { + c.Fatalf("Test started with bad cluster state: %v", master.LiveServers()) + } + + result := &struct { + IsMaster bool + Me string + }{} + slave := master.Copy() + slave.SetMode(mgo.Monotonic, true) // Monotonic can hold a non-master socket persistently. + err = slave.Run("isMaster", result) + c.Assert(err, IsNil) + c.Assert(result.IsMaster, Equals, false) + slaveAddr := result.Me + + defer func() { + master.Refresh() + master.Run(bson.D{{"$eval", `rs.add("` + slaveAddr + `")`}}, nil) + master.Close() + slave.Close() + }() + + c.Logf("========== Removing slave: %s ==========", slaveAddr) + + master.Run(bson.D{{"$eval", `rs.remove("` + slaveAddr + `")`}}, nil) + err = master.Ping() + c.Assert(err, Equals, io.EOF) + + master.Refresh() + + // Give the cluster a moment to catch up by doing a roundtrip to the master. + err = master.Ping() + c.Assert(err, IsNil) + + time.Sleep(3e9) + + // This must fail since the slave has been taken off the cluster. + err = slave.Ping() + c.Assert(err, NotNil) + + for i := 0; i < 15; i++ { + if len(master.LiveServers()) == 2 { + break + } + time.Sleep(time.Second) + } + live := master.LiveServers() + if len(live) != 2 { + c.Errorf("Removed server still considered live: %#s", live) + } + + c.Log("========== Test succeeded. ==========") +} + +func (s *S) TestPoolLimitSimple(c *C) { + for test := 0; test < 2; test++ { + var session *mgo.Session + var err error + if test == 0 { + session, err = mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + session.SetPoolLimit(1) + } else { + session, err = mgo.Dial("localhost:40001?maxPoolSize=1") + c.Assert(err, IsNil) + } + defer session.Close() + + // Put one socket in use. + c.Assert(session.Ping(), IsNil) + + done := make(chan time.Duration) + + // Now block trying to get another one due to the pool limit. + go func() { + copy := session.Copy() + defer copy.Close() + started := time.Now() + c.Check(copy.Ping(), IsNil) + done <- time.Now().Sub(started) + }() + + time.Sleep(300 * time.Millisecond) + + // Put the one socket back in the pool, freeing it for the copy. + session.Refresh() + delay := <-done + c.Assert(delay > 300*time.Millisecond, Equals, true, Commentf("Delay: %s", delay)) + } +} + +func (s *S) TestPoolLimitMany(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + stats := mgo.GetStats() + for stats.MasterConns+stats.SlaveConns != 3 { + stats = mgo.GetStats() + c.Log("Waiting for all connections to be established...") + time.Sleep(500 * time.Millisecond) + } + c.Assert(stats.SocketsAlive, Equals, 3) + + const poolLimit = 64 + session.SetPoolLimit(poolLimit) + + // Consume the whole limit for the master. + var master []*mgo.Session + for i := 0; i < poolLimit; i++ { + s := session.Copy() + defer s.Close() + c.Assert(s.Ping(), IsNil) + master = append(master, s) + } + + before := time.Now() + go func() { + time.Sleep(3e9) + master[0].Refresh() + }() + + // Then, a single ping must block, since it would need another + // connection to the master, over the limit. Once the goroutine + // above releases its socket, it should move on. + session.Ping() + delay := time.Now().Sub(before) + c.Assert(delay > 3e9, Equals, true) + c.Assert(delay < 6e9, Equals, true) +} + +func (s *S) TestSetModeEventualIterBug(c *C) { + session1, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session1.Close() + + session1.SetMode(mgo.Eventual, false) + + coll1 := session1.DB("mydb").C("mycoll") + + const N = 100 + for i := 0; i < N; i++ { + err = coll1.Insert(M{"_id": i}) + c.Assert(err, IsNil) + } + + c.Logf("Waiting until secondary syncs") + for { + n, err := coll1.Count() + c.Assert(err, IsNil) + if n == N { + c.Logf("Found all") + break + } + } + + session2, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session2.Close() + + session2.SetMode(mgo.Eventual, false) + + coll2 := session2.DB("mydb").C("mycoll") + + i := 0 + iter := coll2.Find(nil).Batch(10).Iter() + var result struct{} + for iter.Next(&result) { + i++ + } + c.Assert(iter.Close(), Equals, nil) + c.Assert(i, Equals, N) +} + +func (s *S) TestCustomDialOld(c *C) { + dials := make(chan bool, 16) + dial := func(addr net.Addr) (net.Conn, error) { + tcpaddr, ok := addr.(*net.TCPAddr) + if !ok { + return nil, fmt.Errorf("unexpected address type: %T", addr) + } + dials <- true + return net.DialTCP("tcp", nil, tcpaddr) + } + info := mgo.DialInfo{ + Addrs: []string{"localhost:40012"}, + Dial: dial, + } + + // Use hostname here rather than IP, to make things trickier. + session, err := mgo.DialWithInfo(&info) + c.Assert(err, IsNil) + defer session.Close() + + const N = 3 + for i := 0; i < N; i++ { + select { + case <-dials: + case <-time.After(5 * time.Second): + c.Fatalf("expected %d dials, got %d", N, i) + } + } + select { + case <-dials: + c.Fatalf("got more dials than expected") + case <-time.After(100 * time.Millisecond): + } +} + +func (s *S) TestCustomDialNew(c *C) { + dials := make(chan bool, 16) + dial := func(addr *mgo.ServerAddr) (net.Conn, error) { + dials <- true + if addr.TCPAddr().Port == 40012 { + c.Check(addr.String(), Equals, "localhost:40012") + } + return net.DialTCP("tcp", nil, addr.TCPAddr()) + } + info := mgo.DialInfo{ + Addrs: []string{"localhost:40012"}, + DialServer: dial, + } + + // Use hostname here rather than IP, to make things trickier. + session, err := mgo.DialWithInfo(&info) + c.Assert(err, IsNil) + defer session.Close() + + const N = 3 + for i := 0; i < N; i++ { + select { + case <-dials: + case <-time.After(5 * time.Second): + c.Fatalf("expected %d dials, got %d", N, i) + } + } + select { + case <-dials: + c.Fatalf("got more dials than expected") + case <-time.After(100 * time.Millisecond): + } +} + +func (s *S) TestPrimaryShutdownOnAuthShard(c *C) { + if *fast { + c.Skip("-fast") + } + + // Dial the shard. + session, err := mgo.Dial("localhost:40203") + c.Assert(err, IsNil) + defer session.Close() + + // Login and insert something to make it more realistic. + session.DB("admin").Login("root", "rapadura") + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(bson.M{"n": 1}) + c.Assert(err, IsNil) + + // Dial the replica set to figure the master out. + rs, err := mgo.Dial("root:rapadura@localhost:40031") + c.Assert(err, IsNil) + defer rs.Close() + + // With strong consistency, this will open a socket to the master. + result := &struct{ Host string }{} + err = rs.Run("serverStatus", result) + c.Assert(err, IsNil) + + // Kill the master. + host := result.Host + s.Stop(host) + + // This must fail, since the connection was broken. + err = rs.Run("serverStatus", result) + c.Assert(err, Equals, io.EOF) + + // This won't work because the master just died. + err = coll.Insert(bson.M{"n": 2}) + c.Assert(err, NotNil) + + // Refresh session and wait for re-election. + session.Refresh() + for i := 0; i < 60; i++ { + err = coll.Insert(bson.M{"n": 3}) + if err == nil { + break + } + c.Logf("Waiting for replica set to elect a new master. Last error: %v", err) + time.Sleep(500 * time.Millisecond) + } + c.Assert(err, IsNil) + + count, err := coll.Count() + c.Assert(count > 1, Equals, true) +} + +func (s *S) TestNearestSecondary(c *C) { + defer mgo.HackPingDelay(3 * time.Second)() + + rs1a := "127.0.0.1:40011" + rs1b := "127.0.0.1:40012" + rs1c := "127.0.0.1:40013" + s.Freeze(rs1b) + + session, err := mgo.Dial(rs1a) + c.Assert(err, IsNil) + defer session.Close() + + // Wait for the sync up to run through the first couple of servers. + for len(session.LiveServers()) != 2 { + c.Log("Waiting for two servers to be alive...") + time.Sleep(100 * time.Millisecond) + } + + // Extra delay to ensure the third server gets penalized. + time.Sleep(500 * time.Millisecond) + + // Release third server. + s.Thaw(rs1b) + + // Wait for it to come up. + for len(session.LiveServers()) != 3 { + c.Log("Waiting for all servers to be alive...") + time.Sleep(100 * time.Millisecond) + } + + session.SetMode(mgo.Monotonic, true) + var result struct{ Host string } + + // See which slave picks the line, several times to avoid chance. + for i := 0; i < 10; i++ { + session.Refresh() + err = session.Run("serverStatus", &result) + c.Assert(err, IsNil) + c.Assert(hostPort(result.Host), Equals, hostPort(rs1c)) + } + + if *fast { + // Don't hold back for several seconds. + return + } + + // Now hold the other server for long enough to penalize it. + s.Freeze(rs1c) + time.Sleep(5 * time.Second) + s.Thaw(rs1c) + + // Wait for the ping to be processed. + time.Sleep(500 * time.Millisecond) + + // Repeating the test should now pick the former server consistently. + for i := 0; i < 10; i++ { + session.Refresh() + err = session.Run("serverStatus", &result) + c.Assert(err, IsNil) + c.Assert(hostPort(result.Host), Equals, hostPort(rs1b)) + } +} + +func (s *S) TestConnectCloseConcurrency(c *C) { + restore := mgo.HackPingDelay(500 * time.Millisecond) + defer restore() + var wg sync.WaitGroup + const n = 500 + wg.Add(n) + for i := 0; i < n; i++ { + go func() { + defer wg.Done() + session, err := mgo.Dial("localhost:40001") + if err != nil { + c.Fatal(err) + } + time.Sleep(1) + session.Close() + }() + } + wg.Wait() +} + +func (s *S) TestSelectServers(c *C) { + if !s.versionAtLeast(2, 2) { + c.Skip("read preferences introduced in 2.2") + } + + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Eventual, true) + + var result struct{ Host string } + + session.Refresh() + session.SelectServers(bson.D{{"rs1", "b"}}) + err = session.Run("serverStatus", &result) + c.Assert(err, IsNil) + c.Assert(hostPort(result.Host), Equals, "40012") + + session.Refresh() + session.SelectServers(bson.D{{"rs1", "c"}}) + err = session.Run("serverStatus", &result) + c.Assert(err, IsNil) + c.Assert(hostPort(result.Host), Equals, "40013") +} + +func (s *S) TestSelectServersWithMongos(c *C) { + if !s.versionAtLeast(2, 2) { + c.Skip("read preferences introduced in 2.2") + } + + session, err := mgo.Dial("localhost:40021") + c.Assert(err, IsNil) + defer session.Close() + + ssresult := &struct{ Host string }{} + imresult := &struct{ IsMaster bool }{} + + // Figure the master while still using the strong session. + err = session.Run("serverStatus", ssresult) + c.Assert(err, IsNil) + err = session.Run("isMaster", imresult) + c.Assert(err, IsNil) + master := ssresult.Host + c.Assert(imresult.IsMaster, Equals, true, Commentf("%s is not the master", master)) + + var slave1, slave2 string + switch hostPort(master) { + case "40021": + slave1, slave2 = "b", "c" + case "40022": + slave1, slave2 = "a", "c" + case "40023": + slave1, slave2 = "a", "b" + } + + // Collect op counters for everyone. + opc21a, err := getOpCounters("localhost:40021") + c.Assert(err, IsNil) + opc22a, err := getOpCounters("localhost:40022") + c.Assert(err, IsNil) + opc23a, err := getOpCounters("localhost:40023") + c.Assert(err, IsNil) + + // Do a SlaveOk query through MongoS + mongos, err := mgo.Dial("localhost:40202") + c.Assert(err, IsNil) + defer mongos.Close() + + mongos.SetMode(mgo.Monotonic, true) + + mongos.Refresh() + mongos.SelectServers(bson.D{{"rs2", slave1}}) + coll := mongos.DB("mydb").C("mycoll") + result := &struct{}{} + for i := 0; i != 5; i++ { + err := coll.Find(nil).One(result) + c.Assert(err, Equals, mgo.ErrNotFound) + } + + mongos.Refresh() + mongos.SelectServers(bson.D{{"rs2", slave2}}) + coll = mongos.DB("mydb").C("mycoll") + for i := 0; i != 7; i++ { + err := coll.Find(nil).One(result) + c.Assert(err, Equals, mgo.ErrNotFound) + } + + // Collect op counters for everyone again. + opc21b, err := getOpCounters("localhost:40021") + c.Assert(err, IsNil) + opc22b, err := getOpCounters("localhost:40022") + c.Assert(err, IsNil) + opc23b, err := getOpCounters("localhost:40023") + c.Assert(err, IsNil) + + switch hostPort(master) { + case "40021": + c.Check(opc21b.Query-opc21a.Query, Equals, 0) + c.Check(opc22b.Query-opc22a.Query, Equals, 5) + c.Check(opc23b.Query-opc23a.Query, Equals, 7) + case "40022": + c.Check(opc21b.Query-opc21a.Query, Equals, 5) + c.Check(opc22b.Query-opc22a.Query, Equals, 0) + c.Check(opc23b.Query-opc23a.Query, Equals, 7) + case "40023": + c.Check(opc21b.Query-opc21a.Query, Equals, 5) + c.Check(opc22b.Query-opc22a.Query, Equals, 7) + c.Check(opc23b.Query-opc23a.Query, Equals, 0) + default: + c.Fatal("Uh?") + } +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/doc.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/doc.go new file mode 100644 index 0000000..9316c55 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/doc.go @@ -0,0 +1,31 @@ +// Package mgo offers a rich MongoDB driver for Go. +// +// Details about the mgo project (pronounced as "mango") are found +// in its web page: +// +// http://labix.org/mgo +// +// Usage of the driver revolves around the concept of sessions. To +// get started, obtain a session using the Dial function: +// +// session, err := mgo.Dial(url) +// +// This will establish one or more connections with the cluster of +// servers defined by the url parameter. From then on, the cluster +// may be queried with multiple consistency rules (see SetMode) and +// documents retrieved with statements such as: +// +// c := session.DB(database).C(collection) +// err := c.Find(query).One(&result) +// +// New sessions are typically created by calling session.Copy on the +// initial session obtained at dial time. These new sessions will share +// the same cluster information and connection cache, and may be easily +// handed into other methods and functions for organizing logic. +// Every session created must have its Close method called at the end +// of its life time, so its resources may be put back in the pool or +// collected, depending on the case. +// +// For more details, see the documentation for the types and methods. +// +package mgo diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/export_test.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/export_test.go new file mode 100644 index 0000000..690f84d --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/export_test.go @@ -0,0 +1,33 @@ +package mgo + +import ( + "time" +) + +func HackPingDelay(newDelay time.Duration) (restore func()) { + globalMutex.Lock() + defer globalMutex.Unlock() + + oldDelay := pingDelay + restore = func() { + globalMutex.Lock() + pingDelay = oldDelay + globalMutex.Unlock() + } + pingDelay = newDelay + return +} + +func HackSyncSocketTimeout(newTimeout time.Duration) (restore func()) { + globalMutex.Lock() + defer globalMutex.Unlock() + + oldTimeout := syncSocketTimeout + restore = func() { + globalMutex.Lock() + syncSocketTimeout = oldTimeout + globalMutex.Unlock() + } + syncSocketTimeout = newTimeout + return +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/gridfs.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/gridfs.go new file mode 100644 index 0000000..3439462 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/gridfs.go @@ -0,0 +1,754 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +import ( + "crypto/md5" + "encoding/hex" + "errors" + "hash" + "io" + "os" + "sync" + "time" + + "gopkg.in/mgo.v2/bson" +) + +type GridFS struct { + Files *Collection + Chunks *Collection +} + +type gfsFileMode int + +const ( + gfsClosed gfsFileMode = 0 + gfsReading gfsFileMode = 1 + gfsWriting gfsFileMode = 2 +) + +type GridFile struct { + m sync.Mutex + c sync.Cond + gfs *GridFS + mode gfsFileMode + err error + + chunk int + offset int64 + + wpending int + wbuf []byte + wsum hash.Hash + + rbuf []byte + rcache *gfsCachedChunk + + doc gfsFile +} + +type gfsFile struct { + Id interface{} "_id" + ChunkSize int "chunkSize" + UploadDate time.Time "uploadDate" + Length int64 ",minsize" + MD5 string + Filename string ",omitempty" + ContentType string "contentType,omitempty" + Metadata *bson.Raw ",omitempty" +} + +type gfsChunk struct { + Id interface{} "_id" + FilesId interface{} "files_id" + N int + Data []byte +} + +type gfsCachedChunk struct { + wait sync.Mutex + n int + data []byte + err error +} + +func newGridFS(db *Database, prefix string) *GridFS { + return &GridFS{db.C(prefix + ".files"), db.C(prefix + ".chunks")} +} + +func (gfs *GridFS) newFile() *GridFile { + file := &GridFile{gfs: gfs} + file.c.L = &file.m + //runtime.SetFinalizer(file, finalizeFile) + return file +} + +func finalizeFile(file *GridFile) { + file.Close() +} + +// Create creates a new file with the provided name in the GridFS. If the file +// name already exists, a new version will be inserted with an up-to-date +// uploadDate that will cause it to be atomically visible to the Open and +// OpenId methods. If the file name is not important, an empty name may be +// provided and the file Id used instead. +// +// It's important to Close files whether they are being written to +// or read from, and to check the err result to ensure the operation +// completed successfully. +// +// A simple example inserting a new file: +// +// func check(err error) { +// if err != nil { +// panic(err.String()) +// } +// } +// file, err := db.GridFS("fs").Create("myfile.txt") +// check(err) +// n, err := file.Write([]byte("Hello world!") +// check(err) +// err = file.Close() +// check(err) +// fmt.Printf("%d bytes written\n", n) +// +// The io.Writer interface is implemented by *GridFile and may be used to +// help on the file creation. For example: +// +// file, err := db.GridFS("fs").Create("myfile.txt") +// check(err) +// messages, err := os.Open("/var/log/messages") +// check(err) +// defer messages.Close() +// err = io.Copy(file, messages) +// check(err) +// err = file.Close() +// check(err) +// +func (gfs *GridFS) Create(name string) (file *GridFile, err error) { + file = gfs.newFile() + file.mode = gfsWriting + file.wsum = md5.New() + file.doc = gfsFile{Id: bson.NewObjectId(), ChunkSize: 255 * 1024, Filename: name} + return +} + +// OpenId returns the file with the provided id, for reading. +// If the file isn't found, err will be set to mgo.ErrNotFound. +// +// It's important to Close files whether they are being written to +// or read from, and to check the err result to ensure the operation +// completed successfully. +// +// The following example will print the first 8192 bytes from the file: +// +// func check(err error) { +// if err != nil { +// panic(err.String()) +// } +// } +// file, err := db.GridFS("fs").OpenId(objid) +// check(err) +// b := make([]byte, 8192) +// n, err := file.Read(b) +// check(err) +// fmt.Println(string(b)) +// check(err) +// err = file.Close() +// check(err) +// fmt.Printf("%d bytes read\n", n) +// +// The io.Reader interface is implemented by *GridFile and may be used to +// deal with it. As an example, the following snippet will dump the whole +// file into the standard output: +// +// file, err := db.GridFS("fs").OpenId(objid) +// check(err) +// err = io.Copy(os.Stdout, file) +// check(err) +// err = file.Close() +// check(err) +// +func (gfs *GridFS) OpenId(id interface{}) (file *GridFile, err error) { + var doc gfsFile + err = gfs.Files.Find(bson.M{"_id": id}).One(&doc) + if err != nil { + return + } + file = gfs.newFile() + file.mode = gfsReading + file.doc = doc + return +} + +// Open returns the most recently uploaded file with the provided +// name, for reading. If the file isn't found, err will be set +// to mgo.ErrNotFound. +// +// It's important to Close files whether they are being written to +// or read from, and to check the err result to ensure the operation +// completed successfully. +// +// The following example will print the first 8192 bytes from the file: +// +// file, err := db.GridFS("fs").Open("myfile.txt") +// check(err) +// b := make([]byte, 8192) +// n, err := file.Read(b) +// check(err) +// fmt.Println(string(b)) +// check(err) +// err = file.Close() +// check(err) +// fmt.Printf("%d bytes read\n", n) +// +// The io.Reader interface is implemented by *GridFile and may be used to +// deal with it. As an example, the following snippet will dump the whole +// file into the standard output: +// +// file, err := db.GridFS("fs").Open("myfile.txt") +// check(err) +// err = io.Copy(os.Stdout, file) +// check(err) +// err = file.Close() +// check(err) +// +func (gfs *GridFS) Open(name string) (file *GridFile, err error) { + var doc gfsFile + err = gfs.Files.Find(bson.M{"filename": name}).Sort("-uploadDate").One(&doc) + if err != nil { + return + } + file = gfs.newFile() + file.mode = gfsReading + file.doc = doc + return +} + +// OpenNext opens the next file from iter for reading, sets *file to it, +// and returns true on the success case. If no more documents are available +// on iter or an error occurred, *file is set to nil and the result is false. +// Errors will be available via iter.Err(). +// +// The iter parameter must be an iterator on the GridFS files collection. +// Using the GridFS.Find method is an easy way to obtain such an iterator, +// but any iterator on the collection will work. +// +// If the provided *file is non-nil, OpenNext will close it before attempting +// to iterate to the next element. This means that in a loop one only +// has to worry about closing files when breaking out of the loop early +// (break, return, or panic). +// +// For example: +// +// gfs := db.GridFS("fs") +// query := gfs.Find(nil).Sort("filename") +// iter := query.Iter() +// var f *mgo.GridFile +// for gfs.OpenNext(iter, &f) { +// fmt.Printf("Filename: %s\n", f.Name()) +// } +// if iter.Close() != nil { +// panic(iter.Close()) +// } +// +func (gfs *GridFS) OpenNext(iter *Iter, file **GridFile) bool { + if *file != nil { + // Ignoring the error here shouldn't be a big deal + // as we're reading the file and the loop iteration + // for this file is finished. + _ = (*file).Close() + } + var doc gfsFile + if !iter.Next(&doc) { + *file = nil + return false + } + f := gfs.newFile() + f.mode = gfsReading + f.doc = doc + *file = f + return true +} + +// Find runs query on GridFS's files collection and returns +// the resulting Query. +// +// This logic: +// +// gfs := db.GridFS("fs") +// iter := gfs.Find(nil).Iter() +// +// Is equivalent to: +// +// files := db.C("fs" + ".files") +// iter := files.Find(nil).Iter() +// +func (gfs *GridFS) Find(query interface{}) *Query { + return gfs.Files.Find(query) +} + +// RemoveId deletes the file with the provided id from the GridFS. +func (gfs *GridFS) RemoveId(id interface{}) error { + err := gfs.Files.Remove(bson.M{"_id": id}) + if err != nil { + return err + } + _, err = gfs.Chunks.RemoveAll(bson.D{{"files_id", id}}) + return err +} + +type gfsDocId struct { + Id interface{} "_id" +} + +// Remove deletes all files with the provided name from the GridFS. +func (gfs *GridFS) Remove(name string) (err error) { + iter := gfs.Files.Find(bson.M{"filename": name}).Select(bson.M{"_id": 1}).Iter() + var doc gfsDocId + for iter.Next(&doc) { + if e := gfs.RemoveId(doc.Id); e != nil { + err = e + } + } + if err == nil { + err = iter.Close() + } + return err +} + +func (file *GridFile) assertMode(mode gfsFileMode) { + switch file.mode { + case mode: + return + case gfsWriting: + panic("GridFile is open for writing") + case gfsReading: + panic("GridFile is open for reading") + case gfsClosed: + panic("GridFile is closed") + default: + panic("internal error: missing GridFile mode") + } +} + +// SetChunkSize sets size of saved chunks. Once the file is written to, it +// will be split in blocks of that size and each block saved into an +// independent chunk document. The default chunk size is 256kb. +// +// It is a runtime error to call this function once the file has started +// being written to. +func (file *GridFile) SetChunkSize(bytes int) { + file.assertMode(gfsWriting) + debugf("GridFile %p: setting chunk size to %d", file, bytes) + file.m.Lock() + file.doc.ChunkSize = bytes + file.m.Unlock() +} + +// Id returns the current file Id. +func (file *GridFile) Id() interface{} { + return file.doc.Id +} + +// SetId changes the current file Id. +// +// It is a runtime error to call this function once the file has started +// being written to, or when the file is not open for writing. +func (file *GridFile) SetId(id interface{}) { + file.assertMode(gfsWriting) + file.m.Lock() + file.doc.Id = id + file.m.Unlock() +} + +// Name returns the optional file name. An empty string will be returned +// in case it is unset. +func (file *GridFile) Name() string { + return file.doc.Filename +} + +// SetName changes the optional file name. An empty string may be used to +// unset it. +// +// It is a runtime error to call this function when the file is not open +// for writing. +func (file *GridFile) SetName(name string) { + file.assertMode(gfsWriting) + file.m.Lock() + file.doc.Filename = name + file.m.Unlock() +} + +// ContentType returns the optional file content type. An empty string will be +// returned in case it is unset. +func (file *GridFile) ContentType() string { + return file.doc.ContentType +} + +// ContentType changes the optional file content type. An empty string may be +// used to unset it. +// +// It is a runtime error to call this function when the file is not open +// for writing. +func (file *GridFile) SetContentType(ctype string) { + file.assertMode(gfsWriting) + file.m.Lock() + file.doc.ContentType = ctype + file.m.Unlock() +} + +// GetMeta unmarshals the optional "metadata" field associated with the +// file into the result parameter. The meaning of keys under that field +// is user-defined. For example: +// +// result := struct{ INode int }{} +// err = file.GetMeta(&result) +// if err != nil { +// panic(err.String()) +// } +// fmt.Printf("inode: %d\n", result.INode) +// +func (file *GridFile) GetMeta(result interface{}) (err error) { + file.m.Lock() + if file.doc.Metadata != nil { + err = bson.Unmarshal(file.doc.Metadata.Data, result) + } + file.m.Unlock() + return +} + +// SetMeta changes the optional "metadata" field associated with the +// file. The meaning of keys under that field is user-defined. +// For example: +// +// file.SetMeta(bson.M{"inode": inode}) +// +// It is a runtime error to call this function when the file is not open +// for writing. +func (file *GridFile) SetMeta(metadata interface{}) { + file.assertMode(gfsWriting) + data, err := bson.Marshal(metadata) + file.m.Lock() + if err != nil && file.err == nil { + file.err = err + } else { + file.doc.Metadata = &bson.Raw{Data: data} + } + file.m.Unlock() +} + +// Size returns the file size in bytes. +func (file *GridFile) Size() (bytes int64) { + file.m.Lock() + bytes = file.doc.Length + file.m.Unlock() + return +} + +// MD5 returns the file MD5 as a hex-encoded string. +func (file *GridFile) MD5() (md5 string) { + return file.doc.MD5 +} + +// UploadDate returns the file upload time. +func (file *GridFile) UploadDate() time.Time { + return file.doc.UploadDate +} + +// SetUploadDate changes the file upload time. +// +// It is a runtime error to call this function when the file is not open +// for writing. +func (file *GridFile) SetUploadDate(t time.Time) { + file.assertMode(gfsWriting) + file.m.Lock() + file.doc.UploadDate = t + file.m.Unlock() +} + +// Close flushes any pending changes in case the file is being written +// to, waits for any background operations to finish, and closes the file. +// +// It's important to Close files whether they are being written to +// or read from, and to check the err result to ensure the operation +// completed successfully. +func (file *GridFile) Close() (err error) { + file.m.Lock() + defer file.m.Unlock() + if file.mode == gfsWriting { + if len(file.wbuf) > 0 && file.err == nil { + file.insertChunk(file.wbuf) + file.wbuf = file.wbuf[0:0] + } + file.completeWrite() + } else if file.mode == gfsReading && file.rcache != nil { + file.rcache.wait.Lock() + file.rcache = nil + } + file.mode = gfsClosed + debugf("GridFile %p: closed", file) + return file.err +} + +func (file *GridFile) completeWrite() { + for file.wpending > 0 { + debugf("GridFile %p: waiting for %d pending chunks to complete file write", file, file.wpending) + file.c.Wait() + } + if file.err != nil { + file.gfs.Chunks.RemoveAll(bson.D{{"files_id", file.doc.Id}}) + return + } + hexsum := hex.EncodeToString(file.wsum.Sum(nil)) + if file.doc.UploadDate.IsZero() { + file.doc.UploadDate = bson.Now() + } + file.doc.MD5 = hexsum + file.err = file.gfs.Files.Insert(file.doc) + file.gfs.Chunks.EnsureIndexKey("files_id", "n") +} + +// Abort cancels an in-progress write, preventing the file from being +// automically created and ensuring previously written chunks are +// removed when the file is closed. +// +// It is a runtime error to call Abort when the file was not opened +// for writing. +func (file *GridFile) Abort() { + if file.mode != gfsWriting { + panic("file.Abort must be called on file opened for writing") + } + file.err = errors.New("write aborted") +} + +// Write writes the provided data to the file and returns the +// number of bytes written and an error in case something +// wrong happened. +// +// The file will internally cache the data so that all but the last +// chunk sent to the database have the size defined by SetChunkSize. +// This also means that errors may be deferred until a future call +// to Write or Close. +// +// The parameters and behavior of this function turn the file +// into an io.Writer. +func (file *GridFile) Write(data []byte) (n int, err error) { + file.assertMode(gfsWriting) + file.m.Lock() + debugf("GridFile %p: writing %d bytes", file, len(data)) + defer file.m.Unlock() + + if file.err != nil { + return 0, file.err + } + + n = len(data) + file.doc.Length += int64(n) + chunkSize := file.doc.ChunkSize + + if len(file.wbuf)+len(data) < chunkSize { + file.wbuf = append(file.wbuf, data...) + return + } + + // First, flush file.wbuf complementing with data. + if len(file.wbuf) > 0 { + missing := chunkSize - len(file.wbuf) + if missing > len(data) { + missing = len(data) + } + file.wbuf = append(file.wbuf, data[:missing]...) + data = data[missing:] + file.insertChunk(file.wbuf) + file.wbuf = file.wbuf[0:0] + } + + // Then, flush all chunks from data without copying. + for len(data) > chunkSize { + size := chunkSize + if size > len(data) { + size = len(data) + } + file.insertChunk(data[:size]) + data = data[size:] + } + + // And append the rest for a future call. + file.wbuf = append(file.wbuf, data...) + + return n, file.err +} + +func (file *GridFile) insertChunk(data []byte) { + n := file.chunk + file.chunk++ + debugf("GridFile %p: adding to checksum: %q", file, string(data)) + file.wsum.Write(data) + + for file.doc.ChunkSize*file.wpending >= 1024*1024 { + // Hold on.. we got a MB pending. + file.c.Wait() + if file.err != nil { + return + } + } + + file.wpending++ + + debugf("GridFile %p: inserting chunk %d with %d bytes", file, n, len(data)) + + // We may not own the memory of data, so rather than + // simply copying it, we'll marshal the document ahead of time. + data, err := bson.Marshal(gfsChunk{bson.NewObjectId(), file.doc.Id, n, data}) + if err != nil { + file.err = err + return + } + + go func() { + err := file.gfs.Chunks.Insert(bson.Raw{Data: data}) + file.m.Lock() + file.wpending-- + if err != nil && file.err == nil { + file.err = err + } + file.c.Broadcast() + file.m.Unlock() + }() +} + +// Seek sets the offset for the next Read or Write on file to +// offset, interpreted according to whence: 0 means relative to +// the origin of the file, 1 means relative to the current offset, +// and 2 means relative to the end. It returns the new offset and +// an error, if any. +func (file *GridFile) Seek(offset int64, whence int) (pos int64, err error) { + file.m.Lock() + debugf("GridFile %p: seeking for %s (whence=%d)", file, offset, whence) + defer file.m.Unlock() + switch whence { + case os.SEEK_SET: + case os.SEEK_CUR: + offset += file.offset + case os.SEEK_END: + offset += file.doc.Length + default: + panic("unsupported whence value") + } + if offset > file.doc.Length { + return file.offset, errors.New("seek past end of file") + } + if offset == file.doc.Length { + // If we're seeking to the end of the file, + // no need to read anything. This enables + // a client to find the size of the file using only the + // io.ReadSeeker interface with low overhead. + file.offset = offset + return file.offset, nil + } + chunk := int(offset / int64(file.doc.ChunkSize)) + if chunk+1 == file.chunk && offset >= file.offset { + file.rbuf = file.rbuf[int(offset-file.offset):] + file.offset = offset + return file.offset, nil + } + file.offset = offset + file.chunk = chunk + file.rbuf = nil + file.rbuf, err = file.getChunk() + if err == nil { + file.rbuf = file.rbuf[int(file.offset-int64(chunk)*int64(file.doc.ChunkSize)):] + } + return file.offset, err +} + +// Read reads into b the next available data from the file and +// returns the number of bytes written and an error in case +// something wrong happened. At the end of the file, n will +// be zero and err will be set to os.EOF. +// +// The parameters and behavior of this function turn the file +// into an io.Reader. +func (file *GridFile) Read(b []byte) (n int, err error) { + file.assertMode(gfsReading) + file.m.Lock() + debugf("GridFile %p: reading at offset %d into buffer of length %d", file, file.offset, len(b)) + defer file.m.Unlock() + if file.offset == file.doc.Length { + return 0, io.EOF + } + for err == nil { + i := copy(b, file.rbuf) + n += i + file.offset += int64(i) + file.rbuf = file.rbuf[i:] + if i == len(b) || file.offset == file.doc.Length { + break + } + b = b[i:] + file.rbuf, err = file.getChunk() + } + return n, err +} + +func (file *GridFile) getChunk() (data []byte, err error) { + cache := file.rcache + file.rcache = nil + if cache != nil && cache.n == file.chunk { + debugf("GridFile %p: Getting chunk %d from cache", file, file.chunk) + cache.wait.Lock() + data, err = cache.data, cache.err + } else { + debugf("GridFile %p: Fetching chunk %d", file, file.chunk) + var doc gfsChunk + err = file.gfs.Chunks.Find(bson.D{{"files_id", file.doc.Id}, {"n", file.chunk}}).One(&doc) + data = doc.Data + } + file.chunk++ + if int64(file.chunk)*int64(file.doc.ChunkSize) < file.doc.Length { + // Read the next one in background. + cache = &gfsCachedChunk{n: file.chunk} + cache.wait.Lock() + debugf("GridFile %p: Scheduling chunk %d for background caching", file, file.chunk) + // Clone the session to avoid having it closed in between. + chunks := file.gfs.Chunks + session := chunks.Database.Session.Clone() + go func(id interface{}, n int) { + defer session.Close() + chunks = chunks.With(session) + var doc gfsChunk + cache.err = chunks.Find(bson.D{{"files_id", id}, {"n", n}}).One(&doc) + cache.data = doc.Data + cache.wait.Unlock() + }(file.doc.Id, file.chunk) + file.rcache = cache + } + debugf("Returning err: %#v", err) + return +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/gridfs_test.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/gridfs_test.go new file mode 100644 index 0000000..9afd245 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/gridfs_test.go @@ -0,0 +1,680 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo_test + +import ( + "io" + "os" + "time" + + . "gopkg.in/check.v1" + "gopkg.in/mgo.v2" + "gopkg.in/mgo.v2/bson" +) + +func (s *S) TestGridFSCreate(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + before := bson.Now() + + gfs := db.GridFS("fs") + file, err := gfs.Create("") + c.Assert(err, IsNil) + + n, err := file.Write([]byte("some data")) + c.Assert(err, IsNil) + c.Assert(n, Equals, 9) + + err = file.Close() + c.Assert(err, IsNil) + + after := bson.Now() + + // Check the file information. + result := M{} + err = db.C("fs.files").Find(nil).One(result) + c.Assert(err, IsNil) + + fileId, ok := result["_id"].(bson.ObjectId) + c.Assert(ok, Equals, true) + c.Assert(fileId.Valid(), Equals, true) + result["_id"] = "" + + ud, ok := result["uploadDate"].(time.Time) + c.Assert(ok, Equals, true) + c.Assert(ud.After(before) && ud.Before(after), Equals, true) + result["uploadDate"] = "" + + expected := M{ + "_id": "", + "length": 9, + "chunkSize": 255 * 1024, + "uploadDate": "", + "md5": "1e50210a0202497fb79bc38b6ade6c34", + } + c.Assert(result, DeepEquals, expected) + + // Check the chunk. + result = M{} + err = db.C("fs.chunks").Find(nil).One(result) + c.Assert(err, IsNil) + + chunkId, ok := result["_id"].(bson.ObjectId) + c.Assert(ok, Equals, true) + c.Assert(chunkId.Valid(), Equals, true) + result["_id"] = "" + + expected = M{ + "_id": "", + "files_id": fileId, + "n": 0, + "data": []byte("some data"), + } + c.Assert(result, DeepEquals, expected) + + // Check that an index was created. + indexes, err := db.C("fs.chunks").Indexes() + c.Assert(err, IsNil) + c.Assert(len(indexes), Equals, 2) + c.Assert(indexes[1].Key, DeepEquals, []string{"files_id", "n"}) +} + +func (s *S) TestGridFSFileDetails(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + gfs := db.GridFS("fs") + + file, err := gfs.Create("myfile1.txt") + c.Assert(err, IsNil) + + n, err := file.Write([]byte("some")) + c.Assert(err, IsNil) + c.Assert(n, Equals, 4) + + c.Assert(file.Size(), Equals, int64(4)) + + n, err = file.Write([]byte(" data")) + c.Assert(err, IsNil) + c.Assert(n, Equals, 5) + + c.Assert(file.Size(), Equals, int64(9)) + + id, _ := file.Id().(bson.ObjectId) + c.Assert(id.Valid(), Equals, true) + c.Assert(file.Name(), Equals, "myfile1.txt") + c.Assert(file.ContentType(), Equals, "") + + var info interface{} + err = file.GetMeta(&info) + c.Assert(err, IsNil) + c.Assert(info, IsNil) + + file.SetId("myid") + file.SetName("myfile2.txt") + file.SetContentType("text/plain") + file.SetMeta(M{"any": "thing"}) + + c.Assert(file.Id(), Equals, "myid") + c.Assert(file.Name(), Equals, "myfile2.txt") + c.Assert(file.ContentType(), Equals, "text/plain") + + err = file.GetMeta(&info) + c.Assert(err, IsNil) + c.Assert(info, DeepEquals, bson.M{"any": "thing"}) + + err = file.Close() + c.Assert(err, IsNil) + + c.Assert(file.MD5(), Equals, "1e50210a0202497fb79bc38b6ade6c34") + + ud := file.UploadDate() + now := time.Now() + c.Assert(ud.Before(now), Equals, true) + c.Assert(ud.After(now.Add(-3*time.Second)), Equals, true) + + result := M{} + err = db.C("fs.files").Find(nil).One(result) + c.Assert(err, IsNil) + + result["uploadDate"] = "" + + expected := M{ + "_id": "myid", + "length": 9, + "chunkSize": 255 * 1024, + "uploadDate": "", + "md5": "1e50210a0202497fb79bc38b6ade6c34", + "filename": "myfile2.txt", + "contentType": "text/plain", + "metadata": M{"any": "thing"}, + } + c.Assert(result, DeepEquals, expected) +} + +func (s *S) TestGridFSSetUploadDate(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + gfs := db.GridFS("fs") + file, err := gfs.Create("") + c.Assert(err, IsNil) + + t := time.Date(2014, 1, 1, 1, 1, 1, 0, time.Local) + file.SetUploadDate(t) + + err = file.Close() + c.Assert(err, IsNil) + + // Check the file information. + result := M{} + err = db.C("fs.files").Find(nil).One(result) + c.Assert(err, IsNil) + + ud := result["uploadDate"].(time.Time) + if !ud.Equal(t) { + c.Fatalf("want upload date %s, got %s", t, ud) + } +} + +func (s *S) TestGridFSCreateWithChunking(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + gfs := db.GridFS("fs") + + file, err := gfs.Create("") + c.Assert(err, IsNil) + + file.SetChunkSize(5) + + // Smaller than the chunk size. + n, err := file.Write([]byte("abc")) + c.Assert(err, IsNil) + c.Assert(n, Equals, 3) + + // Boundary in the middle. + n, err = file.Write([]byte("defg")) + c.Assert(err, IsNil) + c.Assert(n, Equals, 4) + + // Boundary at the end. + n, err = file.Write([]byte("hij")) + c.Assert(err, IsNil) + c.Assert(n, Equals, 3) + + // Larger than the chunk size, with 3 chunks. + n, err = file.Write([]byte("klmnopqrstuv")) + c.Assert(err, IsNil) + c.Assert(n, Equals, 12) + + err = file.Close() + c.Assert(err, IsNil) + + // Check the file information. + result := M{} + err = db.C("fs.files").Find(nil).One(result) + c.Assert(err, IsNil) + + fileId, _ := result["_id"].(bson.ObjectId) + c.Assert(fileId.Valid(), Equals, true) + result["_id"] = "" + result["uploadDate"] = "" + + expected := M{ + "_id": "", + "length": 22, + "chunkSize": 5, + "uploadDate": "", + "md5": "44a66044834cbe55040089cabfc102d5", + } + c.Assert(result, DeepEquals, expected) + + // Check the chunks. + iter := db.C("fs.chunks").Find(nil).Sort("n").Iter() + dataChunks := []string{"abcde", "fghij", "klmno", "pqrst", "uv"} + for i := 0; ; i++ { + result = M{} + if !iter.Next(result) { + if i != 5 { + c.Fatalf("Expected 5 chunks, got %d", i) + } + break + } + c.Assert(iter.Close(), IsNil) + + result["_id"] = "" + + expected = M{ + "_id": "", + "files_id": fileId, + "n": i, + "data": []byte(dataChunks[i]), + } + c.Assert(result, DeepEquals, expected) + } +} + +func (s *S) TestGridFSAbort(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + gfs := db.GridFS("fs") + file, err := gfs.Create("") + c.Assert(err, IsNil) + + file.SetChunkSize(5) + + n, err := file.Write([]byte("some data")) + c.Assert(err, IsNil) + c.Assert(n, Equals, 9) + + var count int + for i := 0; i < 10; i++ { + count, err = db.C("fs.chunks").Count() + if count > 0 || err != nil { + break + } + } + c.Assert(err, IsNil) + c.Assert(count, Equals, 1) + + file.Abort() + + err = file.Close() + c.Assert(err, ErrorMatches, "write aborted") + + count, err = db.C("fs.chunks").Count() + c.Assert(err, IsNil) + c.Assert(count, Equals, 0) +} + +func (s *S) TestGridFSOpenNotFound(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + gfs := db.GridFS("fs") + file, err := gfs.OpenId("non-existent") + c.Assert(err == mgo.ErrNotFound, Equals, true) + c.Assert(file, IsNil) + + file, err = gfs.Open("non-existent") + c.Assert(err == mgo.ErrNotFound, Equals, true) + c.Assert(file, IsNil) +} + +func (s *S) TestGridFSReadAll(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + gfs := db.GridFS("fs") + file, err := gfs.Create("") + c.Assert(err, IsNil) + id := file.Id() + + file.SetChunkSize(5) + + n, err := file.Write([]byte("abcdefghijklmnopqrstuv")) + c.Assert(err, IsNil) + c.Assert(n, Equals, 22) + + err = file.Close() + c.Assert(err, IsNil) + + file, err = gfs.OpenId(id) + c.Assert(err, IsNil) + + b := make([]byte, 30) + n, err = file.Read(b) + c.Assert(n, Equals, 22) + c.Assert(err, IsNil) + + n, err = file.Read(b) + c.Assert(n, Equals, 0) + c.Assert(err == io.EOF, Equals, true) + + err = file.Close() + c.Assert(err, IsNil) +} + +func (s *S) TestGridFSReadChunking(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + gfs := db.GridFS("fs") + + file, err := gfs.Create("") + c.Assert(err, IsNil) + + id := file.Id() + + file.SetChunkSize(5) + + n, err := file.Write([]byte("abcdefghijklmnopqrstuv")) + c.Assert(err, IsNil) + c.Assert(n, Equals, 22) + + err = file.Close() + c.Assert(err, IsNil) + + file, err = gfs.OpenId(id) + c.Assert(err, IsNil) + + b := make([]byte, 30) + + // Smaller than the chunk size. + n, err = file.Read(b[:3]) + c.Assert(err, IsNil) + c.Assert(n, Equals, 3) + c.Assert(b[:3], DeepEquals, []byte("abc")) + + // Boundary in the middle. + n, err = file.Read(b[:4]) + c.Assert(err, IsNil) + c.Assert(n, Equals, 4) + c.Assert(b[:4], DeepEquals, []byte("defg")) + + // Boundary at the end. + n, err = file.Read(b[:3]) + c.Assert(err, IsNil) + c.Assert(n, Equals, 3) + c.Assert(b[:3], DeepEquals, []byte("hij")) + + // Larger than the chunk size, with 3 chunks. + n, err = file.Read(b) + c.Assert(err, IsNil) + c.Assert(n, Equals, 12) + c.Assert(b[:12], DeepEquals, []byte("klmnopqrstuv")) + + n, err = file.Read(b) + c.Assert(n, Equals, 0) + c.Assert(err == io.EOF, Equals, true) + + err = file.Close() + c.Assert(err, IsNil) +} + +func (s *S) TestGridFSOpen(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + gfs := db.GridFS("fs") + + file, err := gfs.Create("myfile.txt") + c.Assert(err, IsNil) + file.Write([]byte{'1'}) + file.Close() + + file, err = gfs.Create("myfile.txt") + c.Assert(err, IsNil) + file.Write([]byte{'2'}) + file.Close() + + file, err = gfs.Open("myfile.txt") + c.Assert(err, IsNil) + defer file.Close() + + var b [1]byte + + _, err = file.Read(b[:]) + c.Assert(err, IsNil) + c.Assert(string(b[:]), Equals, "2") +} + +func (s *S) TestGridFSSeek(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + gfs := db.GridFS("fs") + file, err := gfs.Create("") + c.Assert(err, IsNil) + id := file.Id() + + file.SetChunkSize(5) + + n, err := file.Write([]byte("abcdefghijklmnopqrstuv")) + c.Assert(err, IsNil) + c.Assert(n, Equals, 22) + + err = file.Close() + c.Assert(err, IsNil) + + b := make([]byte, 5) + + file, err = gfs.OpenId(id) + c.Assert(err, IsNil) + + o, err := file.Seek(3, os.SEEK_SET) + c.Assert(err, IsNil) + c.Assert(o, Equals, int64(3)) + _, err = file.Read(b) + c.Assert(err, IsNil) + c.Assert(b, DeepEquals, []byte("defgh")) + + o, err = file.Seek(5, os.SEEK_CUR) + c.Assert(err, IsNil) + c.Assert(o, Equals, int64(13)) + _, err = file.Read(b) + c.Assert(err, IsNil) + c.Assert(b, DeepEquals, []byte("nopqr")) + + o, err = file.Seek(0, os.SEEK_END) + c.Assert(err, IsNil) + c.Assert(o, Equals, int64(22)) + n, err = file.Read(b) + c.Assert(err, Equals, io.EOF) + c.Assert(n, Equals, 0) + + o, err = file.Seek(-10, os.SEEK_END) + c.Assert(err, IsNil) + c.Assert(o, Equals, int64(12)) + _, err = file.Read(b) + c.Assert(err, IsNil) + c.Assert(b, DeepEquals, []byte("mnopq")) + + o, err = file.Seek(8, os.SEEK_SET) + c.Assert(err, IsNil) + c.Assert(o, Equals, int64(8)) + _, err = file.Read(b) + c.Assert(err, IsNil) + c.Assert(b, DeepEquals, []byte("ijklm")) + + // Trivial seek forward within same chunk. Already + // got the data, shouldn't touch the database. + sent := mgo.GetStats().SentOps + o, err = file.Seek(1, os.SEEK_CUR) + c.Assert(err, IsNil) + c.Assert(o, Equals, int64(14)) + c.Assert(mgo.GetStats().SentOps, Equals, sent) + _, err = file.Read(b) + c.Assert(err, IsNil) + c.Assert(b, DeepEquals, []byte("opqrs")) + + // Try seeking past end of file. + file.Seek(3, os.SEEK_SET) + o, err = file.Seek(23, os.SEEK_SET) + c.Assert(err, ErrorMatches, "seek past end of file") + c.Assert(o, Equals, int64(3)) +} + +func (s *S) TestGridFSRemoveId(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + gfs := db.GridFS("fs") + + file, err := gfs.Create("myfile.txt") + c.Assert(err, IsNil) + file.Write([]byte{'1'}) + file.Close() + + file, err = gfs.Create("myfile.txt") + c.Assert(err, IsNil) + file.Write([]byte{'2'}) + id := file.Id() + file.Close() + + err = gfs.RemoveId(id) + c.Assert(err, IsNil) + + file, err = gfs.Open("myfile.txt") + c.Assert(err, IsNil) + defer file.Close() + + var b [1]byte + + _, err = file.Read(b[:]) + c.Assert(err, IsNil) + c.Assert(string(b[:]), Equals, "1") + + n, err := db.C("fs.chunks").Find(M{"files_id": id}).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 0) +} + +func (s *S) TestGridFSRemove(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + gfs := db.GridFS("fs") + + file, err := gfs.Create("myfile.txt") + c.Assert(err, IsNil) + file.Write([]byte{'1'}) + file.Close() + + file, err = gfs.Create("myfile.txt") + c.Assert(err, IsNil) + file.Write([]byte{'2'}) + file.Close() + + err = gfs.Remove("myfile.txt") + c.Assert(err, IsNil) + + _, err = gfs.Open("myfile.txt") + c.Assert(err == mgo.ErrNotFound, Equals, true) + + n, err := db.C("fs.chunks").Find(nil).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 0) +} + +func (s *S) TestGridFSOpenNext(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + gfs := db.GridFS("fs") + + file, err := gfs.Create("myfile1.txt") + c.Assert(err, IsNil) + file.Write([]byte{'1'}) + file.Close() + + file, err = gfs.Create("myfile2.txt") + c.Assert(err, IsNil) + file.Write([]byte{'2'}) + file.Close() + + var f *mgo.GridFile + var b [1]byte + + iter := gfs.Find(nil).Sort("-filename").Iter() + + ok := gfs.OpenNext(iter, &f) + c.Assert(ok, Equals, true) + c.Check(f.Name(), Equals, "myfile2.txt") + + _, err = f.Read(b[:]) + c.Assert(err, IsNil) + c.Assert(string(b[:]), Equals, "2") + + ok = gfs.OpenNext(iter, &f) + c.Assert(ok, Equals, true) + c.Check(f.Name(), Equals, "myfile1.txt") + + _, err = f.Read(b[:]) + c.Assert(err, IsNil) + c.Assert(string(b[:]), Equals, "1") + + ok = gfs.OpenNext(iter, &f) + c.Assert(ok, Equals, false) + c.Assert(iter.Close(), IsNil) + c.Assert(f, IsNil) + + // Do it again with a more restrictive query to make sure + // it's actually taken into account. + iter = gfs.Find(bson.M{"filename": "myfile1.txt"}).Iter() + + ok = gfs.OpenNext(iter, &f) + c.Assert(ok, Equals, true) + c.Check(f.Name(), Equals, "myfile1.txt") + + ok = gfs.OpenNext(iter, &f) + c.Assert(ok, Equals, false) + c.Assert(iter.Close(), IsNil) + c.Assert(f, IsNil) +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/internal/scram/scram.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/internal/scram/scram.go new file mode 100644 index 0000000..80cda91 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/internal/scram/scram.go @@ -0,0 +1,266 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2014 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Pacakage scram implements a SCRAM-{SHA-1,etc} client per RFC5802. +// +// http://tools.ietf.org/html/rfc5802 +// +package scram + +import ( + "bytes" + "crypto/hmac" + "crypto/rand" + "encoding/base64" + "fmt" + "hash" + "strconv" + "strings" +) + +// Client implements a SCRAM-* client (SCRAM-SHA-1, SCRAM-SHA-256, etc). +// +// A Client may be used within a SASL conversation with logic resembling: +// +// var in []byte +// var client = scram.NewClient(sha1.New, user, pass) +// for client.Step(in) { +// out := client.Out() +// // send out to server +// in := serverOut +// } +// if client.Err() != nil { +// // auth failed +// } +// +type Client struct { + newHash func() hash.Hash + + user string + pass string + step int + out bytes.Buffer + err error + + clientNonce []byte + serverNonce []byte + saltedPass []byte + authMsg bytes.Buffer +} + +// NewClient returns a new SCRAM-* client with the provided hash algorithm. +// +// For SCRAM-SHA-1, for example, use: +// +// client := scram.NewClient(sha1.New, user, pass) +// +func NewClient(newHash func() hash.Hash, user, pass string) *Client { + c := &Client{ + newHash: newHash, + user: user, + pass: pass, + } + c.out.Grow(256) + c.authMsg.Grow(256) + return c +} + +// Out returns the data to be sent to the server in the current step. +func (c *Client) Out() []byte { + if c.out.Len() == 0 { + return nil + } + return c.out.Bytes() +} + +// Err returns the error that ocurred, or nil if there were no errors. +func (c *Client) Err() error { + return c.err +} + +// SetNonce sets the client nonce to the provided value. +// If not set, the nonce is generated automatically out of crypto/rand on the first step. +func (c *Client) SetNonce(nonce []byte) { + c.clientNonce = nonce +} + +var escaper = strings.NewReplacer("=", "=3D", ",", "=2C") + +// Step processes the incoming data from the server and makes the +// next round of data for the server available via Client.Out. +// Step returns false if there are no errors and more data is +// still expected. +func (c *Client) Step(in []byte) bool { + c.out.Reset() + if c.step > 2 || c.err != nil { + return false + } + c.step++ + switch c.step { + case 1: + c.err = c.step1(in) + case 2: + c.err = c.step2(in) + case 3: + c.err = c.step3(in) + } + return c.step > 2 || c.err != nil +} + +func (c *Client) step1(in []byte) error { + if len(c.clientNonce) == 0 { + const nonceLen = 6 + buf := make([]byte, nonceLen + b64.EncodedLen(nonceLen)) + if _, err := rand.Read(buf[:nonceLen]); err != nil { + return fmt.Errorf("cannot read random SCRAM-SHA-1 nonce from operating system: %v", err) + } + c.clientNonce = buf[nonceLen:] + b64.Encode(c.clientNonce, buf[:nonceLen]) + } + c.authMsg.WriteString("n=") + escaper.WriteString(&c.authMsg, c.user) + c.authMsg.WriteString(",r=") + c.authMsg.Write(c.clientNonce) + + c.out.WriteString("n,,") + c.out.Write(c.authMsg.Bytes()) + return nil +} + +var b64 = base64.StdEncoding + +func (c *Client) step2(in []byte) error { + c.authMsg.WriteByte(',') + c.authMsg.Write(in) + + fields := bytes.Split(in, []byte(",")) + if len(fields) != 3 { + return fmt.Errorf("expected 3 fields in first SCRAM-SHA-1 server message, got %d: %q", len(fields), in) + } + if !bytes.HasPrefix(fields[0], []byte("r=")) || len(fields[0]) < 2 { + return fmt.Errorf("server sent an invalid SCRAM-SHA-1 nonce: %q", fields[0]) + } + if !bytes.HasPrefix(fields[1], []byte("s=")) || len(fields[1]) < 6 { + return fmt.Errorf("server sent an invalid SCRAM-SHA-1 salt: %q", fields[1]) + } + if !bytes.HasPrefix(fields[2], []byte("i=")) || len(fields[2]) < 6 { + return fmt.Errorf("server sent an invalid SCRAM-SHA-1 iteration count: %q", fields[2]) + } + + c.serverNonce = fields[0][2:] + if !bytes.HasPrefix(c.serverNonce, c.clientNonce) { + return fmt.Errorf("server SCRAM-SHA-1 nonce is not prefixed by client nonce: got %q, want %q+\"...\"", c.serverNonce, c.clientNonce) + } + + salt := make([]byte, b64.DecodedLen(len(fields[1][2:]))) + n, err := b64.Decode(salt, fields[1][2:]) + if err != nil { + return fmt.Errorf("cannot decode SCRAM-SHA-1 salt sent by server: %q", fields[1]) + } + salt = salt[:n] + iterCount, err := strconv.Atoi(string(fields[2][2:])) + if err != nil { + return fmt.Errorf("server sent an invalid SCRAM-SHA-1 iteration count: %q", fields[2]) + } + c.saltPassword(salt, iterCount) + + c.authMsg.WriteString(",c=biws,r=") + c.authMsg.Write(c.serverNonce) + + c.out.WriteString("c=biws,r=") + c.out.Write(c.serverNonce) + c.out.WriteString(",p=") + c.out.Write(c.clientProof()) + return nil +} + +func (c *Client) step3(in []byte) error { + var isv, ise bool + var fields = bytes.Split(in, []byte(",")) + if len(fields) == 1 { + isv = bytes.HasPrefix(fields[0], []byte("v=")) + ise = bytes.HasPrefix(fields[0], []byte("e=")) + } + if ise { + return fmt.Errorf("SCRAM-SHA-1 authentication error: %s", fields[0][2:]) + } else if !isv { + return fmt.Errorf("unsupported SCRAM-SHA-1 final message from server: %q", in) + } + if !bytes.Equal(c.serverSignature(), fields[0][2:]) { + return fmt.Errorf("cannot authenticate SCRAM-SHA-1 server signature: %q", fields[0][2:]) + } + return nil +} + +func (c *Client) saltPassword(salt []byte, iterCount int) { + mac := hmac.New(c.newHash, []byte(c.pass)) + mac.Write(salt) + mac.Write([]byte{0, 0, 0, 1}) + ui := mac.Sum(nil) + hi := make([]byte, len(ui)) + copy(hi, ui) + for i := 1; i < iterCount; i++ { + mac.Reset() + mac.Write(ui) + mac.Sum(ui[:0]) + for j, b := range ui { + hi[j] ^= b + } + } + c.saltedPass = hi +} + +func (c *Client) clientProof() []byte { + mac := hmac.New(c.newHash, c.saltedPass) + mac.Write([]byte("Client Key")) + clientKey := mac.Sum(nil) + hash := c.newHash() + hash.Write(clientKey) + storedKey := hash.Sum(nil) + mac = hmac.New(c.newHash, storedKey) + mac.Write(c.authMsg.Bytes()) + clientProof := mac.Sum(nil) + for i, b := range clientKey { + clientProof[i] ^= b + } + clientProof64 := make([]byte, b64.EncodedLen(len(clientProof))) + b64.Encode(clientProof64, clientProof) + return clientProof64 +} + +func (c *Client) serverSignature() []byte { + mac := hmac.New(c.newHash, c.saltedPass) + mac.Write([]byte("Server Key")) + serverKey := mac.Sum(nil) + + mac = hmac.New(c.newHash, serverKey) + mac.Write(c.authMsg.Bytes()) + serverSignature := mac.Sum(nil) + + encoded := make([]byte, b64.EncodedLen(len(serverSignature))) + b64.Encode(encoded, serverSignature) + return encoded +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/internal/scram/scram_test.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/internal/scram/scram_test.go new file mode 100644 index 0000000..9c20fdf --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/internal/scram/scram_test.go @@ -0,0 +1,67 @@ +package scram_test + +import ( + "crypto/sha1" + "testing" + + . "gopkg.in/check.v1" + "gopkg.in/mgo.v2/internal/scram" + "strings" +) + +var _ = Suite(&S{}) + +func Test(t *testing.T) { TestingT(t) } + +type S struct{} + +var tests = [][]string{{ + "U: user pencil", + "N: fyko+d2lbbFgONRv9qkxdawL", + "C: n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL", + "S: r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096", + "C: c=biws,r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,p=v0X8v3Bz2T0CJGbJQyF0X+HI4Ts=", + "S: v=rmF9pqV8S7suAoZWja4dJRkFsKQ=", +}, { + "U: root fe8c89e308ec08763df36333cbf5d3a2", + "N: OTcxNDk5NjM2MzE5", + "C: n,,n=root,r=OTcxNDk5NjM2MzE5", + "S: r=OTcxNDk5NjM2MzE581Ra3provgG0iDsMkDiIAlrh4532dDLp,s=XRDkVrFC9JuL7/F4tG0acQ==,i=10000", + "C: c=biws,r=OTcxNDk5NjM2MzE581Ra3provgG0iDsMkDiIAlrh4532dDLp,p=6y1jp9R7ETyouTXS9fW9k5UHdBc=", + "S: v=LBnd9dUJRxdqZiEq91NKP3z/bHA=", +}} + +func (s *S) TestExamples(c *C) { + for _, steps := range tests { + if len(steps) < 2 || len(steps[0]) < 3 || !strings.HasPrefix(steps[0], "U: ") { + c.Fatalf("Invalid test: %#v", steps) + } + auth := strings.Fields(steps[0][3:]) + client := scram.NewClient(sha1.New, auth[0], auth[1]) + first, done := true, false + c.Logf("-----") + c.Logf("%s", steps[0]) + for _, step := range steps[1:] { + c.Logf("%s", step) + switch step[:3] { + case "N: ": + client.SetNonce([]byte(step[3:])) + case "C: ": + if first { + first = false + done = client.Step(nil) + } + c.Assert(done, Equals, false) + c.Assert(client.Err(), IsNil) + c.Assert(string(client.Out()), Equals, step[3:]) + case "S: ": + first = false + done = client.Step([]byte(step[3:])) + default: + panic("invalid test line: " + step) + } + } + c.Assert(done, Equals, true) + c.Assert(client.Err(), IsNil) + } +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/log.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/log.go new file mode 100644 index 0000000..53eb423 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/log.go @@ -0,0 +1,133 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +import ( + "fmt" + "sync" +) + +// --------------------------------------------------------------------------- +// Logging integration. + +// Avoid importing the log type information unnecessarily. There's a small cost +// associated with using an interface rather than the type. Depending on how +// often the logger is plugged in, it would be worth using the type instead. +type log_Logger interface { + Output(calldepth int, s string) error +} + +var ( + globalLogger log_Logger + globalDebug bool + globalMutex sync.Mutex +) + +// RACE WARNING: There are known data races when logging, which are manually +// silenced when the race detector is in use. These data races won't be +// observed in typical use, because logging is supposed to be set up once when +// the application starts. Having raceDetector as a constant, the compiler +// should elide the locks altogether in actual use. + +// Specify the *log.Logger object where log messages should be sent to. +func SetLogger(logger log_Logger) { + if raceDetector { + globalMutex.Lock() + defer globalMutex.Unlock() + } + globalLogger = logger +} + +// Enable the delivery of debug messages to the logger. Only meaningful +// if a logger is also set. +func SetDebug(debug bool) { + if raceDetector { + globalMutex.Lock() + defer globalMutex.Unlock() + } + globalDebug = debug +} + +func log(v ...interface{}) { + if raceDetector { + globalMutex.Lock() + defer globalMutex.Unlock() + } + if globalLogger != nil { + globalLogger.Output(2, fmt.Sprint(v...)) + } +} + +func logln(v ...interface{}) { + if raceDetector { + globalMutex.Lock() + defer globalMutex.Unlock() + } + if globalLogger != nil { + globalLogger.Output(2, fmt.Sprintln(v...)) + } +} + +func logf(format string, v ...interface{}) { + if raceDetector { + globalMutex.Lock() + defer globalMutex.Unlock() + } + if globalLogger != nil { + globalLogger.Output(2, fmt.Sprintf(format, v...)) + } +} + +func debug(v ...interface{}) { + if raceDetector { + globalMutex.Lock() + defer globalMutex.Unlock() + } + if globalDebug && globalLogger != nil { + globalLogger.Output(2, fmt.Sprint(v...)) + } +} + +func debugln(v ...interface{}) { + if raceDetector { + globalMutex.Lock() + defer globalMutex.Unlock() + } + if globalDebug && globalLogger != nil { + globalLogger.Output(2, fmt.Sprintln(v...)) + } +} + +func debugf(format string, v ...interface{}) { + if raceDetector { + globalMutex.Lock() + defer globalMutex.Unlock() + } + if globalDebug && globalLogger != nil { + globalLogger.Output(2, fmt.Sprintf(format, v...)) + } +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/queue.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/queue.go new file mode 100644 index 0000000..e9245de --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/queue.go @@ -0,0 +1,91 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +type queue struct { + elems []interface{} + nelems, popi, pushi int +} + +func (q *queue) Len() int { + return q.nelems +} + +func (q *queue) Push(elem interface{}) { + //debugf("Pushing(pushi=%d popi=%d cap=%d): %#v\n", + // q.pushi, q.popi, len(q.elems), elem) + if q.nelems == len(q.elems) { + q.expand() + } + q.elems[q.pushi] = elem + q.nelems++ + q.pushi = (q.pushi + 1) % len(q.elems) + //debugf(" Pushed(pushi=%d popi=%d cap=%d): %#v\n", + // q.pushi, q.popi, len(q.elems), elem) +} + +func (q *queue) Pop() (elem interface{}) { + //debugf("Popping(pushi=%d popi=%d cap=%d)\n", + // q.pushi, q.popi, len(q.elems)) + if q.nelems == 0 { + return nil + } + elem = q.elems[q.popi] + q.elems[q.popi] = nil // Help GC. + q.nelems-- + q.popi = (q.popi + 1) % len(q.elems) + //debugf(" Popped(pushi=%d popi=%d cap=%d): %#v\n", + // q.pushi, q.popi, len(q.elems), elem) + return elem +} + +func (q *queue) expand() { + curcap := len(q.elems) + var newcap int + if curcap == 0 { + newcap = 8 + } else if curcap < 1024 { + newcap = curcap * 2 + } else { + newcap = curcap + (curcap / 4) + } + elems := make([]interface{}, newcap) + + if q.popi == 0 { + copy(elems, q.elems) + q.pushi = curcap + } else { + newpopi := newcap - (curcap - q.popi) + copy(elems, q.elems[:q.popi]) + copy(elems[newpopi:], q.elems[q.popi:]) + q.popi = newpopi + } + for i := range q.elems { + q.elems[i] = nil // Help GC. + } + q.elems = elems +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/queue_test.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/queue_test.go new file mode 100644 index 0000000..bd0ab55 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/queue_test.go @@ -0,0 +1,101 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +import ( + . "gopkg.in/check.v1" +) + +type QS struct{} + +var _ = Suite(&QS{}) + +func (s *QS) TestSequentialGrowth(c *C) { + q := queue{} + n := 2048 + for i := 0; i != n; i++ { + q.Push(i) + } + for i := 0; i != n; i++ { + c.Assert(q.Pop(), Equals, i) + } +} + +var queueTestLists = [][]int{ + // {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, + + // {8, 9, 10, 11, ... 2, 3, 4, 5, 6, 7} + {0, 1, 2, 3, 4, 5, 6, 7, -1, -1, 8, 9, 10, 11}, + + // {8, 9, 10, 11, ... 2, 3, 4, 5, 6, 7} + {0, 1, 2, 3, -1, -1, 4, 5, 6, 7, 8, 9, 10, 11}, + + // {0, 1, 2, 3, 4, 5, 6, 7, 8} + {0, 1, 2, 3, 4, 5, 6, 7, 8, + -1, -1, -1, -1, -1, -1, -1, -1, -1, + 0, 1, 2, 3, 4, 5, 6, 7, 8}, +} + +func (s *QS) TestQueueTestLists(c *C) { + test := []int{} + testi := 0 + reset := func() { + test = test[0:0] + testi = 0 + } + push := func(i int) { + test = append(test, i) + } + pop := func() (i int) { + if testi == len(test) { + return -1 + } + i = test[testi] + testi++ + return + } + + for _, list := range queueTestLists { + reset() + q := queue{} + for _, n := range list { + if n == -1 { + c.Assert(q.Pop(), Equals, pop(), Commentf("With list %#v", list)) + } else { + q.Push(n) + push(n) + } + } + + for n := pop(); n != -1; n = pop() { + c.Assert(q.Pop(), Equals, n, Commentf("With list %#v", list)) + } + + c.Assert(q.Pop(), Equals, nil, Commentf("With list %#v", list)) + } +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/raceoff.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/raceoff.go new file mode 100644 index 0000000..e60b141 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/raceoff.go @@ -0,0 +1,5 @@ +// +build !race + +package mgo + +const raceDetector = false diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/raceon.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/raceon.go new file mode 100644 index 0000000..737b08e --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/raceon.go @@ -0,0 +1,5 @@ +// +build race + +package mgo + +const raceDetector = true diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sasl.c b/Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sasl.c new file mode 100644 index 0000000..8be0bc4 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sasl.c @@ -0,0 +1,77 @@ +// +build !windows + +#include +#include +#include +#include + +static int mgo_sasl_simple(void *context, int id, const char **result, unsigned int *len) +{ + if (!result) { + return SASL_BADPARAM; + } + switch (id) { + case SASL_CB_USER: + *result = (char *)context; + break; + case SASL_CB_AUTHNAME: + *result = (char *)context; + break; + case SASL_CB_LANGUAGE: + *result = NULL; + break; + default: + return SASL_BADPARAM; + } + if (len) { + *len = *result ? strlen(*result) : 0; + } + return SASL_OK; +} + +typedef int (*callback)(void); + +static int mgo_sasl_secret(sasl_conn_t *conn, void *context, int id, sasl_secret_t **result) +{ + if (!conn || !result || id != SASL_CB_PASS) { + return SASL_BADPARAM; + } + *result = (sasl_secret_t *)context; + return SASL_OK; +} + +sasl_callback_t *mgo_sasl_callbacks(const char *username, const char *password) +{ + sasl_callback_t *cb = malloc(4 * sizeof(sasl_callback_t)); + int n = 0; + + size_t len = strlen(password); + sasl_secret_t *secret = (sasl_secret_t*)malloc(sizeof(sasl_secret_t) + len); + if (!secret) { + free(cb); + return NULL; + } + strcpy((char *)secret->data, password); + secret->len = len; + + cb[n].id = SASL_CB_PASS; + cb[n].proc = (callback)&mgo_sasl_secret; + cb[n].context = secret; + n++; + + cb[n].id = SASL_CB_USER; + cb[n].proc = (callback)&mgo_sasl_simple; + cb[n].context = (char*)username; + n++; + + cb[n].id = SASL_CB_AUTHNAME; + cb[n].proc = (callback)&mgo_sasl_simple; + cb[n].context = (char*)username; + n++; + + cb[n].id = SASL_CB_LIST_END; + cb[n].proc = NULL; + cb[n].context = NULL; + + return cb; +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sasl.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sasl.go new file mode 100644 index 0000000..8375ddd --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sasl.go @@ -0,0 +1,138 @@ +// Package sasl is an implementation detail of the mgo package. +// +// This package is not meant to be used by itself. +// + +// +build !windows + +package sasl + +// #cgo LDFLAGS: -lsasl2 +// +// struct sasl_conn {}; +// +// #include +// #include +// +// sasl_callback_t *mgo_sasl_callbacks(const char *username, const char *password); +// +import "C" + +import ( + "fmt" + "strings" + "sync" + "unsafe" +) + +type saslStepper interface { + Step(serverData []byte) (clientData []byte, done bool, err error) + Close() +} + +type saslSession struct { + conn *C.sasl_conn_t + step int + mech string + + cstrings []*C.char + callbacks *C.sasl_callback_t +} + +var initError error +var initOnce sync.Once + +func initSASL() { + rc := C.sasl_client_init(nil) + if rc != C.SASL_OK { + initError = saslError(rc, nil, "cannot initialize SASL library") + } +} + +func New(username, password, mechanism, service, host string) (saslStepper, error) { + initOnce.Do(initSASL) + if initError != nil { + return nil, initError + } + + ss := &saslSession{mech: mechanism} + if service == "" { + service = "mongodb" + } + if i := strings.Index(host, ":"); i >= 0 { + host = host[:i] + } + ss.callbacks = C.mgo_sasl_callbacks(ss.cstr(username), ss.cstr(password)) + rc := C.sasl_client_new(ss.cstr(service), ss.cstr(host), nil, nil, ss.callbacks, 0, &ss.conn) + if rc != C.SASL_OK { + ss.Close() + return nil, saslError(rc, nil, "cannot create new SASL client") + } + return ss, nil +} + +func (ss *saslSession) cstr(s string) *C.char { + cstr := C.CString(s) + ss.cstrings = append(ss.cstrings, cstr) + return cstr +} + +func (ss *saslSession) Close() { + for _, cstr := range ss.cstrings { + C.free(unsafe.Pointer(cstr)) + } + ss.cstrings = nil + + if ss.callbacks != nil { + C.free(unsafe.Pointer(ss.callbacks)) + } + + // The documentation of SASL dispose makes it clear that this should only + // be done when the connection is done, not when the authentication phase + // is done, because an encryption layer may have been negotiated. + // Even then, we'll do this for now, because it's simpler and prevents + // keeping track of this state for every socket. If it breaks, we'll fix it. + C.sasl_dispose(&ss.conn) +} + +func (ss *saslSession) Step(serverData []byte) (clientData []byte, done bool, err error) { + ss.step++ + if ss.step > 10 { + return nil, false, fmt.Errorf("too many SASL steps without authentication") + } + var cclientData *C.char + var cclientDataLen C.uint + var rc C.int + if ss.step == 1 { + var mechanism *C.char // ignored - must match cred + rc = C.sasl_client_start(ss.conn, ss.cstr(ss.mech), nil, &cclientData, &cclientDataLen, &mechanism) + } else { + var cserverData *C.char + var cserverDataLen C.uint + if len(serverData) > 0 { + cserverData = (*C.char)(unsafe.Pointer(&serverData[0])) + cserverDataLen = C.uint(len(serverData)) + } + rc = C.sasl_client_step(ss.conn, cserverData, cserverDataLen, nil, &cclientData, &cclientDataLen) + } + if cclientData != nil && cclientDataLen > 0 { + clientData = C.GoBytes(unsafe.Pointer(cclientData), C.int(cclientDataLen)) + } + if rc == C.SASL_OK { + return clientData, true, nil + } + if rc == C.SASL_CONTINUE { + return clientData, false, nil + } + return nil, false, saslError(rc, ss.conn, "cannot establish SASL session") +} + +func saslError(rc C.int, conn *C.sasl_conn_t, msg string) error { + var detail string + if conn == nil { + detail = C.GoString(C.sasl_errstring(rc, nil, nil)) + } else { + detail = C.GoString(C.sasl_errdetail(conn)) + } + return fmt.Errorf(msg + ": " + detail) +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sasl_windows.c b/Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sasl_windows.c new file mode 100644 index 0000000..dd6a88a --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sasl_windows.c @@ -0,0 +1,118 @@ +#include "sasl_windows.h" + +static const LPSTR SSPI_PACKAGE_NAME = "kerberos"; + +SECURITY_STATUS SEC_ENTRY sspi_acquire_credentials_handle(CredHandle *cred_handle, char *username, char *password, char *domain) +{ + SEC_WINNT_AUTH_IDENTITY auth_identity; + SECURITY_INTEGER ignored; + + auth_identity.Flags = SEC_WINNT_AUTH_IDENTITY_ANSI; + auth_identity.User = (LPSTR) username; + auth_identity.UserLength = strlen(username); + auth_identity.Password = (LPSTR) password; + auth_identity.PasswordLength = strlen(password); + auth_identity.Domain = (LPSTR) domain; + auth_identity.DomainLength = strlen(domain); + return call_sspi_acquire_credentials_handle(NULL, SSPI_PACKAGE_NAME, SECPKG_CRED_OUTBOUND, NULL, &auth_identity, NULL, NULL, cred_handle, &ignored); +} + +int sspi_step(CredHandle *cred_handle, int has_context, CtxtHandle *context, PVOID *buffer, ULONG *buffer_length, char *target) +{ + SecBufferDesc inbuf; + SecBuffer in_bufs[1]; + SecBufferDesc outbuf; + SecBuffer out_bufs[1]; + + if (has_context > 0) { + // If we already have a context, we now have data to send. + // Put this data in an inbuf. + inbuf.ulVersion = SECBUFFER_VERSION; + inbuf.cBuffers = 1; + inbuf.pBuffers = in_bufs; + in_bufs[0].pvBuffer = *buffer; + in_bufs[0].cbBuffer = *buffer_length; + in_bufs[0].BufferType = SECBUFFER_TOKEN; + } + + outbuf.ulVersion = SECBUFFER_VERSION; + outbuf.cBuffers = 1; + outbuf.pBuffers = out_bufs; + out_bufs[0].pvBuffer = NULL; + out_bufs[0].cbBuffer = 0; + out_bufs[0].BufferType = SECBUFFER_TOKEN; + + ULONG context_attr = 0; + + int ret = call_sspi_initialize_security_context(cred_handle, + has_context > 0 ? context : NULL, + (LPSTR) target, + ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_MUTUAL_AUTH, + 0, + SECURITY_NETWORK_DREP, + has_context > 0 ? &inbuf : NULL, + 0, + context, + &outbuf, + &context_attr, + NULL); + + *buffer = malloc(out_bufs[0].cbBuffer); + *buffer_length = out_bufs[0].cbBuffer; + memcpy(*buffer, out_bufs[0].pvBuffer, *buffer_length); + + return ret; +} + +int sspi_send_client_authz_id(CtxtHandle *context, PVOID *buffer, ULONG *buffer_length, char *user_plus_realm) +{ + SecPkgContext_Sizes sizes; + SECURITY_STATUS status = call_sspi_query_context_attributes(context, SECPKG_ATTR_SIZES, &sizes); + + if (status != SEC_E_OK) { + return status; + } + + size_t user_plus_realm_length = strlen(user_plus_realm); + int msgSize = 4 + user_plus_realm_length; + char *msg = malloc((sizes.cbSecurityTrailer + msgSize + sizes.cbBlockSize) * sizeof(char)); + msg[sizes.cbSecurityTrailer + 0] = 1; + msg[sizes.cbSecurityTrailer + 1] = 0; + msg[sizes.cbSecurityTrailer + 2] = 0; + msg[sizes.cbSecurityTrailer + 3] = 0; + memcpy(&msg[sizes.cbSecurityTrailer + 4], user_plus_realm, user_plus_realm_length); + + SecBuffer wrapBufs[3]; + SecBufferDesc wrapBufDesc; + wrapBufDesc.cBuffers = 3; + wrapBufDesc.pBuffers = wrapBufs; + wrapBufDesc.ulVersion = SECBUFFER_VERSION; + + wrapBufs[0].cbBuffer = sizes.cbSecurityTrailer; + wrapBufs[0].BufferType = SECBUFFER_TOKEN; + wrapBufs[0].pvBuffer = msg; + + wrapBufs[1].cbBuffer = msgSize; + wrapBufs[1].BufferType = SECBUFFER_DATA; + wrapBufs[1].pvBuffer = msg + sizes.cbSecurityTrailer; + + wrapBufs[2].cbBuffer = sizes.cbBlockSize; + wrapBufs[2].BufferType = SECBUFFER_PADDING; + wrapBufs[2].pvBuffer = msg + sizes.cbSecurityTrailer + msgSize; + + status = call_sspi_encrypt_message(context, SECQOP_WRAP_NO_ENCRYPT, &wrapBufDesc, 0); + if (status != SEC_E_OK) { + free(msg); + return status; + } + + *buffer_length = wrapBufs[0].cbBuffer + wrapBufs[1].cbBuffer + wrapBufs[2].cbBuffer; + *buffer = malloc(*buffer_length); + + memcpy(*buffer, wrapBufs[0].pvBuffer, wrapBufs[0].cbBuffer); + memcpy(*buffer + wrapBufs[0].cbBuffer, wrapBufs[1].pvBuffer, wrapBufs[1].cbBuffer); + memcpy(*buffer + wrapBufs[0].cbBuffer + wrapBufs[1].cbBuffer, wrapBufs[2].pvBuffer, wrapBufs[2].cbBuffer); + + free(msg); + return SEC_E_OK; +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sasl_windows.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sasl_windows.go new file mode 100644 index 0000000..3302cfe --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sasl_windows.go @@ -0,0 +1,140 @@ +package sasl + +// #include "sasl_windows.h" +import "C" + +import ( + "fmt" + "strings" + "sync" + "unsafe" +) + +type saslStepper interface { + Step(serverData []byte) (clientData []byte, done bool, err error) + Close() +} + +type saslSession struct { + // Credentials + mech string + service string + host string + userPlusRealm string + target string + domain string + + // Internal state + authComplete bool + errored bool + step int + + // C internal state + credHandle C.CredHandle + context C.CtxtHandle + hasContext C.int + + // Keep track of pointers we need to explicitly free + stringsToFree []*C.char +} + +var initError error +var initOnce sync.Once + +func initSSPI() { + rc := C.load_secur32_dll() + if rc != 0 { + initError = fmt.Errorf("Error loading libraries: %v", rc) + } +} + +func New(username, password, mechanism, service, host string) (saslStepper, error) { + initOnce.Do(initSSPI) + ss := &saslSession{mech: mechanism, hasContext: 0, userPlusRealm: username} + if service == "" { + service = "mongodb" + } + if i := strings.Index(host, ":"); i >= 0 { + host = host[:i] + } + ss.service = service + ss.host = host + + usernameComponents := strings.Split(username, "@") + if len(usernameComponents) < 2 { + return nil, fmt.Errorf("Username '%v' doesn't contain a realm!", username) + } + user := usernameComponents[0] + ss.domain = usernameComponents[1] + ss.target = fmt.Sprintf("%s/%s", ss.service, ss.host) + + var status C.SECURITY_STATUS + // Step 0: call AcquireCredentialsHandle to get a nice SSPI CredHandle + if len(password) > 0 { + status = C.sspi_acquire_credentials_handle(&ss.credHandle, ss.cstr(user), ss.cstr(password), ss.cstr(ss.domain)) + } else { + status = C.sspi_acquire_credentials_handle(&ss.credHandle, ss.cstr(user), nil, ss.cstr(ss.domain)) + } + if status != C.SEC_E_OK { + ss.errored = true + return nil, fmt.Errorf("Couldn't create new SSPI client, error code %v", status) + } + return ss, nil +} + +func (ss *saslSession) cstr(s string) *C.char { + cstr := C.CString(s) + ss.stringsToFree = append(ss.stringsToFree, cstr) + return cstr +} + +func (ss *saslSession) Close() { + for _, cstr := range ss.stringsToFree { + C.free(unsafe.Pointer(cstr)) + } +} + +func (ss *saslSession) Step(serverData []byte) (clientData []byte, done bool, err error) { + ss.step++ + if ss.step > 10 { + return nil, false, fmt.Errorf("too many SSPI steps without authentication") + } + var buffer C.PVOID + var bufferLength C.ULONG + if len(serverData) > 0 { + buffer = (C.PVOID)(unsafe.Pointer(&serverData[0])) + bufferLength = C.ULONG(len(serverData)) + } + var status C.int + if ss.authComplete { + // Step 3: last bit of magic to use the correct server credentials + status = C.sspi_send_client_authz_id(&ss.context, &buffer, &bufferLength, ss.cstr(ss.userPlusRealm)) + } else { + // Step 1 + Step 2: set up security context with the server and TGT + status = C.sspi_step(&ss.credHandle, ss.hasContext, &ss.context, &buffer, &bufferLength, ss.cstr(ss.target)) + } + if buffer != C.PVOID(nil) { + defer C.free(unsafe.Pointer(buffer)) + } + if status != C.SEC_E_OK && status != C.SEC_I_CONTINUE_NEEDED { + ss.errored = true + return nil, false, ss.handleSSPIErrorCode(status) + } + + clientData = C.GoBytes(unsafe.Pointer(buffer), C.int(bufferLength)) + if status == C.SEC_E_OK { + ss.authComplete = true + return clientData, true, nil + } else { + ss.hasContext = 1 + return clientData, false, nil + } +} + +func (ss *saslSession) handleSSPIErrorCode(code C.int) error { + switch { + case code == C.SEC_E_TARGET_UNKNOWN: + return fmt.Errorf("Target %v@%v not found", ss.target, ss.domain) + } + return fmt.Errorf("Unknown error doing step %v, error code %v", ss.step, code) +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sasl_windows.h b/Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sasl_windows.h new file mode 100644 index 0000000..94321b2 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sasl_windows.h @@ -0,0 +1,7 @@ +#include + +#include "sspi_windows.h" + +SECURITY_STATUS SEC_ENTRY sspi_acquire_credentials_handle(CredHandle* cred_handle, char* username, char* password, char* domain); +int sspi_step(CredHandle* cred_handle, int has_context, CtxtHandle* context, PVOID* buffer, ULONG* buffer_length, char* target); +int sspi_send_client_authz_id(CtxtHandle* context, PVOID* buffer, ULONG* buffer_length, char* user_plus_realm); diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sspi_windows.c b/Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sspi_windows.c new file mode 100644 index 0000000..63f9a6f --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sspi_windows.c @@ -0,0 +1,96 @@ +// Code adapted from the NodeJS kerberos library: +// +// https://github.com/christkv/kerberos/tree/master/lib/win32/kerberos_sspi.c +// +// Under the terms of the Apache License, Version 2.0: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +#include + +#include "sspi_windows.h" + +static HINSTANCE sspi_secur32_dll = NULL; + +int load_secur32_dll() +{ + sspi_secur32_dll = LoadLibrary("secur32.dll"); + if (sspi_secur32_dll == NULL) { + return GetLastError(); + } + return 0; +} + +SECURITY_STATUS SEC_ENTRY call_sspi_encrypt_message(PCtxtHandle phContext, unsigned long fQOP, PSecBufferDesc pMessage, unsigned long MessageSeqNo) +{ + if (sspi_secur32_dll == NULL) { + return -1; + } + encryptMessage_fn pfn_encryptMessage = (encryptMessage_fn) GetProcAddress(sspi_secur32_dll, "EncryptMessage"); + if (!pfn_encryptMessage) { + return -2; + } + return (*pfn_encryptMessage)(phContext, fQOP, pMessage, MessageSeqNo); +} + +SECURITY_STATUS SEC_ENTRY call_sspi_acquire_credentials_handle( + LPSTR pszPrincipal, LPSTR pszPackage, unsigned long fCredentialUse, + void *pvLogonId, void *pAuthData, SEC_GET_KEY_FN pGetKeyFn, void *pvGetKeyArgument, + PCredHandle phCredential, PTimeStamp ptsExpiry) +{ + if (sspi_secur32_dll == NULL) { + return -1; + } + acquireCredentialsHandle_fn pfn_acquireCredentialsHandle; +#ifdef _UNICODE + pfn_acquireCredentialsHandle = (acquireCredentialsHandle_fn) GetProcAddress(sspi_secur32_dll, "AcquireCredentialsHandleW"); +#else + pfn_acquireCredentialsHandle = (acquireCredentialsHandle_fn) GetProcAddress(sspi_secur32_dll, "AcquireCredentialsHandleA"); +#endif + if (!pfn_acquireCredentialsHandle) { + return -2; + } + return (*pfn_acquireCredentialsHandle)( + pszPrincipal, pszPackage, fCredentialUse, pvLogonId, pAuthData, + pGetKeyFn, pvGetKeyArgument, phCredential, ptsExpiry); +} + +SECURITY_STATUS SEC_ENTRY call_sspi_initialize_security_context( + PCredHandle phCredential, PCtxtHandle phContext, LPSTR pszTargetName, + unsigned long fContextReq, unsigned long Reserved1, unsigned long TargetDataRep, + PSecBufferDesc pInput, unsigned long Reserved2, PCtxtHandle phNewContext, + PSecBufferDesc pOutput, unsigned long *pfContextAttr, PTimeStamp ptsExpiry) +{ + if (sspi_secur32_dll == NULL) { + return -1; + } + initializeSecurityContext_fn pfn_initializeSecurityContext; +#ifdef _UNICODE + pfn_initializeSecurityContext = (initializeSecurityContext_fn) GetProcAddress(sspi_secur32_dll, "InitializeSecurityContextW"); +#else + pfn_initializeSecurityContext = (initializeSecurityContext_fn) GetProcAddress(sspi_secur32_dll, "InitializeSecurityContextA"); +#endif + if (!pfn_initializeSecurityContext) { + return -2; + } + return (*pfn_initializeSecurityContext)( + phCredential, phContext, pszTargetName, fContextReq, Reserved1, TargetDataRep, + pInput, Reserved2, phNewContext, pOutput, pfContextAttr, ptsExpiry); +} + +SECURITY_STATUS SEC_ENTRY call_sspi_query_context_attributes(PCtxtHandle phContext, unsigned long ulAttribute, void *pBuffer) +{ + if (sspi_secur32_dll == NULL) { + return -1; + } + queryContextAttributes_fn pfn_queryContextAttributes; +#ifdef _UNICODE + pfn_queryContextAttributes = (queryContextAttributes_fn) GetProcAddress(sspi_secur32_dll, "QueryContextAttributesW"); +#else + pfn_queryContextAttributes = (queryContextAttributes_fn) GetProcAddress(sspi_secur32_dll, "QueryContextAttributesA"); +#endif + if (!pfn_queryContextAttributes) { + return -2; + } + return (*pfn_queryContextAttributes)(phContext, ulAttribute, pBuffer); +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sspi_windows.h b/Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sspi_windows.h new file mode 100644 index 0000000..d283270 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/sasl/sspi_windows.h @@ -0,0 +1,70 @@ +// Code adapted from the NodeJS kerberos library: +// +// https://github.com/christkv/kerberos/tree/master/lib/win32/kerberos_sspi.h +// +// Under the terms of the Apache License, Version 2.0: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +#ifndef SSPI_WINDOWS_H +#define SSPI_WINDOWS_H + +#define SECURITY_WIN32 1 + +#include +#include + +int load_secur32_dll(); + +SECURITY_STATUS SEC_ENTRY call_sspi_encrypt_message(PCtxtHandle phContext, unsigned long fQOP, PSecBufferDesc pMessage, unsigned long MessageSeqNo); + +typedef DWORD (WINAPI *encryptMessage_fn)(PCtxtHandle phContext, ULONG fQOP, PSecBufferDesc pMessage, ULONG MessageSeqNo); + +SECURITY_STATUS SEC_ENTRY call_sspi_acquire_credentials_handle( + LPSTR pszPrincipal, // Name of principal + LPSTR pszPackage, // Name of package + unsigned long fCredentialUse, // Flags indicating use + void *pvLogonId, // Pointer to logon ID + void *pAuthData, // Package specific data + SEC_GET_KEY_FN pGetKeyFn, // Pointer to GetKey() func + void *pvGetKeyArgument, // Value to pass to GetKey() + PCredHandle phCredential, // (out) Cred Handle + PTimeStamp ptsExpiry // (out) Lifetime (optional) +); + +typedef DWORD (WINAPI *acquireCredentialsHandle_fn)( + LPSTR pszPrincipal, LPSTR pszPackage, unsigned long fCredentialUse, + void *pvLogonId, void *pAuthData, SEC_GET_KEY_FN pGetKeyFn, void *pvGetKeyArgument, + PCredHandle phCredential, PTimeStamp ptsExpiry +); + +SECURITY_STATUS SEC_ENTRY call_sspi_initialize_security_context( + PCredHandle phCredential, // Cred to base context + PCtxtHandle phContext, // Existing context (OPT) + LPSTR pszTargetName, // Name of target + unsigned long fContextReq, // Context Requirements + unsigned long Reserved1, // Reserved, MBZ + unsigned long TargetDataRep, // Data rep of target + PSecBufferDesc pInput, // Input Buffers + unsigned long Reserved2, // Reserved, MBZ + PCtxtHandle phNewContext, // (out) New Context handle + PSecBufferDesc pOutput, // (inout) Output Buffers + unsigned long *pfContextAttr, // (out) Context attrs + PTimeStamp ptsExpiry // (out) Life span (OPT) +); + +typedef DWORD (WINAPI *initializeSecurityContext_fn)( + PCredHandle phCredential, PCtxtHandle phContext, LPSTR pszTargetName, unsigned long fContextReq, + unsigned long Reserved1, unsigned long TargetDataRep, PSecBufferDesc pInput, unsigned long Reserved2, + PCtxtHandle phNewContext, PSecBufferDesc pOutput, unsigned long *pfContextAttr, PTimeStamp ptsExpiry); + +SECURITY_STATUS SEC_ENTRY call_sspi_query_context_attributes( + PCtxtHandle phContext, // Context to query + unsigned long ulAttribute, // Attribute to query + void *pBuffer // Buffer for attributes +); + +typedef DWORD (WINAPI *queryContextAttributes_fn)( + PCtxtHandle phContext, unsigned long ulAttribute, void *pBuffer); + +#endif // SSPI_WINDOWS_H diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/saslimpl.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/saslimpl.go new file mode 100644 index 0000000..58c0891 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/saslimpl.go @@ -0,0 +1,11 @@ +//+build sasl + +package mgo + +import ( + "gopkg.in/mgo.v2/sasl" +) + +func saslNew(cred Credential, host string) (saslStepper, error) { + return sasl.New(cred.Username, cred.Password, cred.Mechanism, cred.Service, host) +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/saslstub.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/saslstub.go new file mode 100644 index 0000000..6e9e309 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/saslstub.go @@ -0,0 +1,11 @@ +//+build !sasl + +package mgo + +import ( + "fmt" +) + +func saslNew(cred Credential, host string) (saslStepper, error) { + return nil, fmt.Errorf("SASL support not enabled during build (-tags sasl)") +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/server.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/server.go new file mode 100644 index 0000000..8c130be --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/server.go @@ -0,0 +1,447 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +import ( + "errors" + "net" + "sort" + "sync" + "time" + + "gopkg.in/mgo.v2/bson" +) + +// --------------------------------------------------------------------------- +// Mongo server encapsulation. + +type mongoServer struct { + sync.RWMutex + Addr string + ResolvedAddr string + tcpaddr *net.TCPAddr + unusedSockets []*mongoSocket + liveSockets []*mongoSocket + closed bool + abended bool + sync chan bool + dial dialer + pingValue time.Duration + pingIndex int + pingCount uint32 + pingWindow [6]time.Duration + info *mongoServerInfo +} + +type dialer struct { + old func(addr net.Addr) (net.Conn, error) + new func(addr *ServerAddr) (net.Conn, error) +} + +func (dial dialer) isSet() bool { + return dial.old != nil || dial.new != nil +} + +type mongoServerInfo struct { + Master bool + Mongos bool + Tags bson.D + MaxWireVersion int +} + +var defaultServerInfo mongoServerInfo + +func newServer(addr string, tcpaddr *net.TCPAddr, sync chan bool, dial dialer) *mongoServer { + server := &mongoServer{ + Addr: addr, + ResolvedAddr: tcpaddr.String(), + tcpaddr: tcpaddr, + sync: sync, + dial: dial, + info: &defaultServerInfo, + } + // Once so the server gets a ping value, then loop in background. + server.pinger(false) + go server.pinger(true) + return server +} + +var errPoolLimit = errors.New("per-server connection limit reached") +var errServerClosed = errors.New("server was closed") + +// AcquireSocket returns a socket for communicating with the server. +// This will attempt to reuse an old connection, if one is available. Otherwise, +// it will establish a new one. The returned socket is owned by the call site, +// and will return to the cache when the socket has its Release method called +// the same number of times as AcquireSocket + Acquire were called for it. +// If the poolLimit argument is greater than zero and the number of sockets in +// use in this server is greater than the provided limit, errPoolLimit is +// returned. +func (server *mongoServer) AcquireSocket(poolLimit int, timeout time.Duration) (socket *mongoSocket, abended bool, err error) { + for { + server.Lock() + abended = server.abended + if server.closed { + server.Unlock() + return nil, abended, errServerClosed + } + n := len(server.unusedSockets) + if poolLimit > 0 && len(server.liveSockets)-n >= poolLimit { + server.Unlock() + return nil, false, errPoolLimit + } + if n > 0 { + socket = server.unusedSockets[n-1] + server.unusedSockets[n-1] = nil // Help GC. + server.unusedSockets = server.unusedSockets[:n-1] + info := server.info + server.Unlock() + err = socket.InitialAcquire(info, timeout) + if err != nil { + continue + } + } else { + server.Unlock() + socket, err = server.Connect(timeout) + if err == nil { + server.Lock() + // We've waited for the Connect, see if we got + // closed in the meantime + if server.closed { + server.Unlock() + socket.Release() + socket.Close() + return nil, abended, errServerClosed + } + server.liveSockets = append(server.liveSockets, socket) + server.Unlock() + } + } + return + } + panic("unreachable") +} + +// Connect establishes a new connection to the server. This should +// generally be done through server.AcquireSocket(). +func (server *mongoServer) Connect(timeout time.Duration) (*mongoSocket, error) { + server.RLock() + master := server.info.Master + dial := server.dial + server.RUnlock() + + logf("Establishing new connection to %s (timeout=%s)...", server.Addr, timeout) + var conn net.Conn + var err error + switch { + case !dial.isSet(): + // Cannot do this because it lacks timeout support. :-( + //conn, err = net.DialTCP("tcp", nil, server.tcpaddr) + conn, err = net.DialTimeout("tcp", server.ResolvedAddr, timeout) + case dial.old != nil: + conn, err = dial.old(server.tcpaddr) + case dial.new != nil: + conn, err = dial.new(&ServerAddr{server.Addr, server.tcpaddr}) + default: + panic("dialer is set, but both dial.old and dial.new are nil") + } + if err != nil { + logf("Connection to %s failed: %v", server.Addr, err.Error()) + return nil, err + } + logf("Connection to %s established.", server.Addr) + + stats.conn(+1, master) + return newSocket(server, conn, timeout), nil +} + +// Close forces closing all sockets that are alive, whether +// they're currently in use or not. +func (server *mongoServer) Close() { + server.Lock() + server.closed = true + liveSockets := server.liveSockets + unusedSockets := server.unusedSockets + server.liveSockets = nil + server.unusedSockets = nil + server.Unlock() + logf("Connections to %s closing (%d live sockets).", server.Addr, len(liveSockets)) + for i, s := range liveSockets { + s.Close() + liveSockets[i] = nil + } + for i := range unusedSockets { + unusedSockets[i] = nil + } +} + +// RecycleSocket puts socket back into the unused cache. +func (server *mongoServer) RecycleSocket(socket *mongoSocket) { + server.Lock() + if !server.closed { + server.unusedSockets = append(server.unusedSockets, socket) + } + server.Unlock() +} + +func removeSocket(sockets []*mongoSocket, socket *mongoSocket) []*mongoSocket { + for i, s := range sockets { + if s == socket { + copy(sockets[i:], sockets[i+1:]) + n := len(sockets) - 1 + sockets[n] = nil + sockets = sockets[:n] + break + } + } + return sockets +} + +// AbendSocket notifies the server that the given socket has terminated +// abnormally, and thus should be discarded rather than cached. +func (server *mongoServer) AbendSocket(socket *mongoSocket) { + server.Lock() + server.abended = true + if server.closed { + server.Unlock() + return + } + server.liveSockets = removeSocket(server.liveSockets, socket) + server.unusedSockets = removeSocket(server.unusedSockets, socket) + server.Unlock() + // Maybe just a timeout, but suggest a cluster sync up just in case. + select { + case server.sync <- true: + default: + } +} + +func (server *mongoServer) SetInfo(info *mongoServerInfo) { + server.Lock() + server.info = info + server.Unlock() +} + +func (server *mongoServer) Info() *mongoServerInfo { + server.Lock() + info := server.info + server.Unlock() + return info +} + +func (server *mongoServer) hasTags(serverTags []bson.D) bool { +NextTagSet: + for _, tags := range serverTags { + NextReqTag: + for _, req := range tags { + for _, has := range server.info.Tags { + if req.Name == has.Name { + if req.Value == has.Value { + continue NextReqTag + } + continue NextTagSet + } + } + continue NextTagSet + } + return true + } + return false +} + +var pingDelay = 5 * time.Second + +func (server *mongoServer) pinger(loop bool) { + var delay time.Duration + if raceDetector { + // This variable is only ever touched by tests. + globalMutex.Lock() + delay = pingDelay + globalMutex.Unlock() + } else { + delay = pingDelay + } + op := queryOp{ + collection: "admin.$cmd", + query: bson.D{{"ping", 1}}, + flags: flagSlaveOk, + limit: -1, + } + for { + if loop { + time.Sleep(delay) + } + op := op + socket, _, err := server.AcquireSocket(0, 3*delay) + if err == nil { + start := time.Now() + _, _ = socket.SimpleQuery(&op) + delay := time.Now().Sub(start) + + server.pingWindow[server.pingIndex] = delay + server.pingIndex = (server.pingIndex + 1) % len(server.pingWindow) + server.pingCount++ + var max time.Duration + for i := 0; i < len(server.pingWindow) && uint32(i) < server.pingCount; i++ { + if server.pingWindow[i] > max { + max = server.pingWindow[i] + } + } + socket.Release() + server.Lock() + if server.closed { + loop = false + } + server.pingValue = max + server.Unlock() + logf("Ping for %s is %d ms", server.Addr, max/time.Millisecond) + } else if err == errServerClosed { + return + } + if !loop { + return + } + } +} + +type mongoServerSlice []*mongoServer + +func (s mongoServerSlice) Len() int { + return len(s) +} + +func (s mongoServerSlice) Less(i, j int) bool { + return s[i].ResolvedAddr < s[j].ResolvedAddr +} + +func (s mongoServerSlice) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} + +func (s mongoServerSlice) Sort() { + sort.Sort(s) +} + +func (s mongoServerSlice) Search(resolvedAddr string) (i int, ok bool) { + n := len(s) + i = sort.Search(n, func(i int) bool { + return s[i].ResolvedAddr >= resolvedAddr + }) + return i, i != n && s[i].ResolvedAddr == resolvedAddr +} + +type mongoServers struct { + slice mongoServerSlice +} + +func (servers *mongoServers) Search(resolvedAddr string) (server *mongoServer) { + if i, ok := servers.slice.Search(resolvedAddr); ok { + return servers.slice[i] + } + return nil +} + +func (servers *mongoServers) Add(server *mongoServer) { + servers.slice = append(servers.slice, server) + servers.slice.Sort() +} + +func (servers *mongoServers) Remove(other *mongoServer) (server *mongoServer) { + if i, found := servers.slice.Search(other.ResolvedAddr); found { + server = servers.slice[i] + copy(servers.slice[i:], servers.slice[i+1:]) + n := len(servers.slice) - 1 + servers.slice[n] = nil // Help GC. + servers.slice = servers.slice[:n] + } + return +} + +func (servers *mongoServers) Slice() []*mongoServer { + return ([]*mongoServer)(servers.slice) +} + +func (servers *mongoServers) Get(i int) *mongoServer { + return servers.slice[i] +} + +func (servers *mongoServers) Len() int { + return len(servers.slice) +} + +func (servers *mongoServers) Empty() bool { + return len(servers.slice) == 0 +} + +// BestFit returns the best guess of what would be the most interesting +// server to perform operations on at this point in time. +func (servers *mongoServers) BestFit(serverTags []bson.D) *mongoServer { + var best *mongoServer + for _, next := range servers.slice { + if best == nil { + best = next + best.RLock() + if serverTags != nil && !next.info.Mongos && !best.hasTags(serverTags) { + best.RUnlock() + best = nil + } + continue + } + next.RLock() + swap := false + switch { + case serverTags != nil && !next.info.Mongos && !next.hasTags(serverTags): + // Must have requested tags. + case next.info.Master != best.info.Master: + // Prefer slaves. + swap = best.info.Master + case absDuration(next.pingValue-best.pingValue) > 15*time.Millisecond: + // Prefer nearest server. + swap = next.pingValue < best.pingValue + case len(next.liveSockets)-len(next.unusedSockets) < len(best.liveSockets)-len(best.unusedSockets): + // Prefer servers with less connections. + swap = true + } + if swap { + best.RUnlock() + best = next + } else { + next.RUnlock() + } + } + if best != nil { + best.RUnlock() + } + return best +} + +func absDuration(d time.Duration) time.Duration { + if d < 0 { + return -d + } + return d +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/session.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/session.go new file mode 100644 index 0000000..73305b2 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/session.go @@ -0,0 +1,3867 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +import ( + "crypto/md5" + "encoding/hex" + "errors" + "fmt" + "math" + "net" + "net/url" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "time" + + "gopkg.in/mgo.v2/bson" +) + +type mode int + +const ( + Eventual mode = 0 + Monotonic mode = 1 + Strong mode = 2 +) + +// When changing the Session type, check if newSession and copySession +// need to be updated too. + +type Session struct { + m sync.RWMutex + cluster_ *mongoCluster + slaveSocket *mongoSocket + masterSocket *mongoSocket + slaveOk bool + consistency mode + queryConfig query + safeOp *queryOp + syncTimeout time.Duration + sockTimeout time.Duration + defaultdb string + sourcedb string + dialCred *Credential + creds []Credential + poolLimit int +} + +type Database struct { + Session *Session + Name string +} + +type Collection struct { + Database *Database + Name string // "collection" + FullName string // "db.collection" +} + +type Query struct { + m sync.Mutex + session *Session + query // Enables default settings in session. +} + +type query struct { + op queryOp + prefetch float64 + limit int32 +} + +type getLastError struct { + CmdName int "getLastError" + W interface{} "w,omitempty" + WTimeout int "wtimeout,omitempty" + FSync bool "fsync,omitempty" + J bool "j,omitempty" +} + +type Iter struct { + m sync.Mutex + gotReply sync.Cond + session *Session + server *mongoServer + docData queue + err error + op getMoreOp + prefetch float64 + limit int32 + docsToReceive int + docsBeforeMore int + timeout time.Duration + timedout bool +} + +var ErrNotFound = errors.New("not found") + +const defaultPrefetch = 0.25 + +// Dial establishes a new session to the cluster identified by the given seed +// server(s). The session will enable communication with all of the servers in +// the cluster, so the seed servers are used only to find out about the cluster +// topology. +// +// Dial will timeout after 10 seconds if a server isn't reached. The returned +// session will timeout operations after one minute by default if servers +// aren't available. To customize the timeout, see DialWithTimeout, +// SetSyncTimeout, and SetSocketTimeout. +// +// This method is generally called just once for a given cluster. Further +// sessions to the same cluster are then established using the New or Copy +// methods on the obtained session. This will make them share the underlying +// cluster, and manage the pool of connections appropriately. +// +// Once the session is not useful anymore, Close must be called to release the +// resources appropriately. +// +// The seed servers must be provided in the following format: +// +// [mongodb://][user:pass@]host1[:port1][,host2[:port2],...][/database][?options] +// +// For example, it may be as simple as: +// +// localhost +// +// Or more involved like: +// +// mongodb://myuser:mypass@localhost:40001,otherhost:40001/mydb +// +// If the port number is not provided for a server, it defaults to 27017. +// +// The username and password provided in the URL will be used to authenticate +// into the database named after the slash at the end of the host names, or +// into the "admin" database if none is provided. The authentication information +// will persist in sessions obtained through the New method as well. +// +// The following connection options are supported after the question mark: +// +// connect=direct +// +// Disables the automatic replica set server discovery logic, and +// forces the use of servers provided only (even if secondaries). +// Note that to talk to a secondary the consistency requirements +// must be relaxed to Monotonic or Eventual via SetMode. +// +// +// authSource= +// +// Informs the database used to establish credentials and privileges +// with a MongoDB server. Defaults to the database name provided via +// the URL path, and "admin" if that's unset. +// +// +// authMechanism= +// +// Defines the protocol for credential negotiation. Defaults to "MONGODB-CR", +// which is the default username/password challenge-response mechanism. +// +// +// gssapiServiceName= +// +// Defines the service name to use when authenticating with the GSSAPI +// mechanism. Defaults to "mongodb". +// +// maxPoolSize= +// +// Defines the per-server socket pool limit. Defaults to 4096. +// See Session.SetPoolLimit for details. +// +// +// Relevant documentation: +// +// http://docs.mongodb.org/manual/reference/connection-string/ +// +func Dial(url string) (*Session, error) { + session, err := DialWithTimeout(url, 10*time.Second) + if err == nil { + session.SetSyncTimeout(1 * time.Minute) + session.SetSocketTimeout(1 * time.Minute) + } + return session, err +} + +// DialWithTimeout works like Dial, but uses timeout as the amount of time to +// wait for a server to respond when first connecting and also on follow up +// operations in the session. If timeout is zero, the call may block +// forever waiting for a connection to be made. +// +// See SetSyncTimeout for customizing the timeout for the session. +func DialWithTimeout(url string, timeout time.Duration) (*Session, error) { + uinfo, err := parseURL(url) + if err != nil { + return nil, err + } + direct := false + mechanism := "" + service := "" + source := "" + poolLimit := 0 + for k, v := range uinfo.options { + switch k { + case "authSource": + source = v + case "authMechanism": + mechanism = v + case "gssapiServiceName": + service = v + case "maxPoolSize": + poolLimit, err = strconv.Atoi(v) + if err != nil { + return nil, errors.New("bad value for maxPoolSize: " + v) + } + case "connect": + if v == "direct" { + direct = true + break + } + if v == "replicaSet" { + break + } + fallthrough + default: + return nil, errors.New("unsupported connection URL option: " + k + "=" + v) + } + } + info := DialInfo{ + Addrs: uinfo.addrs, + Direct: direct, + Timeout: timeout, + Database: uinfo.db, + Username: uinfo.user, + Password: uinfo.pass, + Mechanism: mechanism, + Service: service, + Source: source, + PoolLimit: poolLimit, + } + return DialWithInfo(&info) +} + +// DialInfo holds options for establishing a session with a MongoDB cluster. +// To use a URL, see the Dial function. +type DialInfo struct { + // Addrs holds the addresses for the seed servers. + Addrs []string + + // Direct informs whether to establish connections only with the + // specified seed servers, or to obtain information for the whole + // cluster and establish connections with further servers too. + Direct bool + + // Timeout is the amount of time to wait for a server to respond when + // first connecting and on follow up operations in the session. If + // timeout is zero, the call may block forever waiting for a connection + // to be established. + Timeout time.Duration + + // FailFast will cause connection and query attempts to fail faster when + // the server is unavailable, instead of retrying until the configured + // timeout period. Note that an unavailable server may silently drop + // packets instead of rejecting them, in which case it's impossible to + // distinguish it from a slow server, so the timeout stays relevant. + FailFast bool + + // Database is the default database name used when the Session.DB method + // is called with an empty name, and is also used during the intial + // authenticatoin if Source is unset. + Database string + + // Source is the database used to establish credentials and privileges + // with a MongoDB server. Defaults to the value of Database, if that is + // set, or "admin" otherwise. + Source string + + // Service defines the service name to use when authenticating with the GSSAPI + // mechanism. Defaults to "mongodb". + Service string + + // ServiceHost defines which hostname to use when authenticating + // with the GSSAPI mechanism. If not specified, defaults to the MongoDB + // server's address. + ServiceHost string + + // Mechanism defines the protocol for credential negotiation. + // Defaults to "MONGODB-CR". + Mechanism string + + // Username and Password inform the credentials for the initial authentication + // done on the database defined by the Source field. See Session.Login. + Username string + Password string + + // PoolLimit defines the per-server socket pool limit. Defaults to 4096. + // See Session.SetPoolLimit for details. + PoolLimit int + + // DialServer optionally specifies the dial function for establishing + // connections with the MongoDB servers. + DialServer func(addr *ServerAddr) (net.Conn, error) + + // WARNING: This field is obsolete. See DialServer above. + Dial func(addr net.Addr) (net.Conn, error) +} + +// ServerAddr represents the address for establishing a connection to an +// individual MongoDB server. +type ServerAddr struct { + str string + tcp *net.TCPAddr +} + +// String returns the address that was provided for the server before resolution. +func (addr *ServerAddr) String() string { + return addr.str +} + +// TCPAddr returns the resolved TCP address for the server. +func (addr *ServerAddr) TCPAddr() *net.TCPAddr { + return addr.tcp +} + +// DialWithInfo establishes a new session to the cluster identified by info. +func DialWithInfo(info *DialInfo) (*Session, error) { + addrs := make([]string, len(info.Addrs)) + for i, addr := range info.Addrs { + p := strings.LastIndexAny(addr, "]:") + if p == -1 || addr[p] != ':' { + // XXX This is untested. The test suite doesn't use the standard port. + addr += ":27017" + } + addrs[i] = addr + } + cluster := newCluster(addrs, info.Direct, info.FailFast, dialer{info.Dial, info.DialServer}) + session := newSession(Eventual, cluster, info.Timeout) + session.defaultdb = info.Database + if session.defaultdb == "" { + session.defaultdb = "test" + } + session.sourcedb = info.Source + if session.sourcedb == "" { + session.sourcedb = info.Database + if session.sourcedb == "" { + session.sourcedb = "admin" + } + } + if info.Username != "" { + source := session.sourcedb + if info.Source == "" && + (info.Mechanism == "GSSAPI" || info.Mechanism == "PLAIN" || info.Mechanism == "MONGODB-X509") { + source = "$external" + } + session.dialCred = &Credential{ + Username: info.Username, + Password: info.Password, + Mechanism: info.Mechanism, + Service: info.Service, + ServiceHost: info.ServiceHost, + Source: source, + } + session.creds = []Credential{*session.dialCred} + } + if info.PoolLimit > 0 { + session.poolLimit = info.PoolLimit + } + cluster.Release() + + // People get confused when we return a session that is not actually + // established to any servers yet (e.g. what if url was wrong). So, + // ping the server to ensure there's someone there, and abort if it + // fails. + if err := session.Ping(); err != nil { + session.Close() + return nil, err + } + session.SetMode(Strong, true) + return session, nil +} + +func isOptSep(c rune) bool { + return c == ';' || c == '&' +} + +type urlInfo struct { + addrs []string + user string + pass string + db string + options map[string]string +} + +func parseURL(s string) (*urlInfo, error) { + if strings.HasPrefix(s, "mongodb://") { + s = s[10:] + } + info := &urlInfo{options: make(map[string]string)} + if c := strings.Index(s, "?"); c != -1 { + for _, pair := range strings.FieldsFunc(s[c+1:], isOptSep) { + l := strings.SplitN(pair, "=", 2) + if len(l) != 2 || l[0] == "" || l[1] == "" { + return nil, errors.New("connection option must be key=value: " + pair) + } + info.options[l[0]] = l[1] + } + s = s[:c] + } + if c := strings.Index(s, "@"); c != -1 { + pair := strings.SplitN(s[:c], ":", 2) + if len(pair) > 2 || pair[0] == "" { + return nil, errors.New("credentials must be provided as user:pass@host") + } + var err error + info.user, err = url.QueryUnescape(pair[0]) + if err != nil { + return nil, fmt.Errorf("cannot unescape username in URL: %q", pair[0]) + } + if len(pair) > 1 { + info.pass, err = url.QueryUnescape(pair[1]) + if err != nil { + return nil, fmt.Errorf("cannot unescape password in URL") + } + } + s = s[c+1:] + } + if c := strings.Index(s, "/"); c != -1 { + info.db = s[c+1:] + s = s[:c] + } + info.addrs = strings.Split(s, ",") + return info, nil +} + +func newSession(consistency mode, cluster *mongoCluster, timeout time.Duration) (session *Session) { + cluster.Acquire() + session = &Session{ + cluster_: cluster, + syncTimeout: timeout, + sockTimeout: timeout, + poolLimit: 4096, + } + debugf("New session %p on cluster %p", session, cluster) + session.SetMode(consistency, true) + session.SetSafe(&Safe{}) + session.queryConfig.prefetch = defaultPrefetch + return session +} + +func copySession(session *Session, keepCreds bool) (s *Session) { + cluster := session.cluster() + cluster.Acquire() + if session.masterSocket != nil { + session.masterSocket.Acquire() + } + if session.slaveSocket != nil { + session.slaveSocket.Acquire() + } + var creds []Credential + if keepCreds { + creds = make([]Credential, len(session.creds)) + copy(creds, session.creds) + } else if session.dialCred != nil { + creds = []Credential{*session.dialCred} + } + scopy := *session + scopy.m = sync.RWMutex{} + scopy.creds = creds + s = &scopy + debugf("New session %p on cluster %p (copy from %p)", s, cluster, session) + return s +} + +// LiveServers returns a list of server addresses which are +// currently known to be alive. +func (s *Session) LiveServers() (addrs []string) { + s.m.RLock() + addrs = s.cluster().LiveServers() + s.m.RUnlock() + return addrs +} + +// DB returns a value representing the named database. If name +// is empty, the database name provided in the dialed URL is +// used instead. If that is also empty, "test" is used as a +// fallback in a way equivalent to the mongo shell. +// +// Creating this value is a very lightweight operation, and +// involves no network communication. +func (s *Session) DB(name string) *Database { + if name == "" { + name = s.defaultdb + } + return &Database{s, name} +} + +// C returns a value representing the named collection. +// +// Creating this value is a very lightweight operation, and +// involves no network communication. +func (db *Database) C(name string) *Collection { + return &Collection{db, name, db.Name + "." + name} +} + +// With returns a copy of db that uses session s. +func (db *Database) With(s *Session) *Database { + newdb := *db + newdb.Session = s + return &newdb +} + +// With returns a copy of c that uses session s. +func (c *Collection) With(s *Session) *Collection { + newdb := *c.Database + newdb.Session = s + newc := *c + newc.Database = &newdb + return &newc +} + +// GridFS returns a GridFS value representing collections in db that +// follow the standard GridFS specification. +// The provided prefix (sometimes known as root) will determine which +// collections to use, and is usually set to "fs" when there is a +// single GridFS in the database. +// +// See the GridFS Create, Open, and OpenId methods for more details. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/GridFS +// http://www.mongodb.org/display/DOCS/GridFS+Tools +// http://www.mongodb.org/display/DOCS/GridFS+Specification +// +func (db *Database) GridFS(prefix string) *GridFS { + return newGridFS(db, prefix) +} + +// Run issues the provided command on the db database and unmarshals +// its result in the respective argument. The cmd argument may be either +// a string with the command name itself, in which case an empty document of +// the form bson.M{cmd: 1} will be used, or it may be a full command document. +// +// Note that MongoDB considers the first marshalled key as the command +// name, so when providing a command with options, it's important to +// use an ordering-preserving document, such as a struct value or an +// instance of bson.D. For instance: +// +// db.Run(bson.D{{"create", "mycollection"}, {"size", 1024}}) +// +// For privilleged commands typically run on the "admin" database, see +// the Run method in the Session type. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Commands +// http://www.mongodb.org/display/DOCS/List+of+Database+CommandSkips +// +func (db *Database) Run(cmd interface{}, result interface{}) error { + if name, ok := cmd.(string); ok { + cmd = bson.D{{name, 1}} + } + return db.C("$cmd").Find(cmd).One(result) +} + +// Credential holds details to authenticate with a MongoDB server. +type Credential struct { + // Username and Password hold the basic details for authentication. + // Password is optional with some authentication mechanisms. + Username string + Password string + + // Source is the database used to establish credentials and privileges + // with a MongoDB server. Defaults to the default database provided + // during dial, or "admin" if that was unset. + Source string + + // Service defines the service name to use when authenticating with the GSSAPI + // mechanism. Defaults to "mongodb". + Service string + + // ServiceHost defines which hostname to use when authenticating + // with the GSSAPI mechanism. If not specified, defaults to the MongoDB + // server's address. + ServiceHost string + + // Mechanism defines the protocol for credential negotiation. + // Defaults to "MONGODB-CR". + Mechanism string +} + +// Login authenticates with MongoDB using the provided credential. The +// authentication is valid for the whole session and will stay valid until +// Logout is explicitly called for the same database, or the session is +// closed. +func (db *Database) Login(user, pass string) error { + return db.Session.Login(&Credential{Username: user, Password: pass, Source: db.Name}) +} + +// Login authenticates with MongoDB using the provided credential. The +// authentication is valid for the whole session and will stay valid until +// Logout is explicitly called for the same database, or the session is +// closed. +func (s *Session) Login(cred *Credential) error { + socket, err := s.acquireSocket(true) + if err != nil { + return err + } + defer socket.Release() + + credCopy := *cred + if cred.Source == "" { + if cred.Mechanism == "GSSAPI" { + credCopy.Source = "$external" + } else { + credCopy.Source = s.sourcedb + } + } + err = socket.Login(credCopy) + if err != nil { + return err + } + + s.m.Lock() + s.creds = append(s.creds, credCopy) + s.m.Unlock() + return nil +} + +func (s *Session) socketLogin(socket *mongoSocket) error { + for _, cred := range s.creds { + if err := socket.Login(cred); err != nil { + return err + } + } + return nil +} + +// Logout removes any established authentication credentials for the database. +func (db *Database) Logout() { + session := db.Session + dbname := db.Name + session.m.Lock() + found := false + for i, cred := range session.creds { + if cred.Source == dbname { + copy(session.creds[i:], session.creds[i+1:]) + session.creds = session.creds[:len(session.creds)-1] + found = true + break + } + } + if found { + if session.masterSocket != nil { + session.masterSocket.Logout(dbname) + } + if session.slaveSocket != nil { + session.slaveSocket.Logout(dbname) + } + } + session.m.Unlock() +} + +// LogoutAll removes all established authentication credentials for the session. +func (s *Session) LogoutAll() { + s.m.Lock() + for _, cred := range s.creds { + if s.masterSocket != nil { + s.masterSocket.Logout(cred.Source) + } + if s.slaveSocket != nil { + s.slaveSocket.Logout(cred.Source) + } + } + s.creds = s.creds[0:0] + s.m.Unlock() +} + +// User represents a MongoDB user. +// +// Relevant documentation: +// +// http://docs.mongodb.org/manual/reference/privilege-documents/ +// http://docs.mongodb.org/manual/reference/user-privileges/ +// +type User struct { + // Username is how the user identifies itself to the system. + Username string `bson:"user"` + + // Password is the plaintext password for the user. If set, + // the UpsertUser method will hash it into PasswordHash and + // unset it before the user is added to the database. + Password string `bson:",omitempty"` + + // PasswordHash is the MD5 hash of Username+":mongo:"+Password. + PasswordHash string `bson:"pwd,omitempty"` + + // CustomData holds arbitrary data admins decide to associate + // with this user, such as the full name or employee id. + CustomData interface{} `bson:"customData,omitempty"` + + // Roles indicates the set of roles the user will be provided. + // See the Role constants. + Roles []Role `bson:"roles"` + + // OtherDBRoles allows assigning roles in other databases from + // user documents inserted in the admin database. This field + // only works in the admin database. + OtherDBRoles map[string][]Role `bson:"otherDBRoles,omitempty"` + + // UserSource indicates where to look for this user's credentials. + // It may be set to a database name, or to "$external" for + // consulting an external resource such as Kerberos. UserSource + // must not be set if Password or PasswordHash are present. + // + // WARNING: This setting was only ever supported in MongoDB 2.4, + // and is now obsolete. + UserSource string `bson:"userSource,omitempty"` +} + +type Role string + +const ( + // Relevant documentation: + // + // http://docs.mongodb.org/manual/reference/user-privileges/ + // + RoleRoot Role = "root" + RoleRead Role = "read" + RoleReadAny Role = "readAnyDatabase" + RoleReadWrite Role = "readWrite" + RoleReadWriteAny Role = "readWriteAnyDatabase" + RoleDBAdmin Role = "dbAdmin" + RoleDBAdminAny Role = "dbAdminAnyDatabase" + RoleUserAdmin Role = "userAdmin" + RoleUserAdminAny Role = "userAdminAnyDatabase" + RoleClusterAdmin Role = "clusterAdmin" +) + +// UpsertUser updates the authentication credentials and the roles for +// a MongoDB user within the db database. If the named user doesn't exist +// it will be created. +// +// This method should only be used from MongoDB 2.4 and on. For older +// MongoDB releases, use the obsolete AddUser method instead. +// +// Relevant documentation: +// +// http://docs.mongodb.org/manual/reference/user-privileges/ +// http://docs.mongodb.org/manual/reference/privilege-documents/ +// +func (db *Database) UpsertUser(user *User) error { + if user.Username == "" { + return fmt.Errorf("user has no Username") + } + if (user.Password != "" || user.PasswordHash != "") && user.UserSource != "" { + return fmt.Errorf("user has both Password/PasswordHash and UserSource set") + } + if len(user.OtherDBRoles) > 0 && db.Name != "admin" && db.Name != "$external" { + return fmt.Errorf("user with OtherDBRoles is only supported in the admin or $external databases") + } + + // Attempt to run this using 2.6+ commands. + rundb := db + if user.UserSource != "" { + // Compatibility logic for the userSource field of MongoDB <= 2.4.X + rundb = db.Session.DB(user.UserSource) + } + err := rundb.runUserCmd("updateUser", user) + // retry with createUser when isAuthError in order to enable the "localhost exception" + if isNotFound(err) || isAuthError(err) { + return rundb.runUserCmd("createUser", user) + } + if !isNoCmd(err) { + return err + } + + // Command does not exist. Fallback to pre-2.6 behavior. + var set, unset bson.D + if user.Password != "" { + psum := md5.New() + psum.Write([]byte(user.Username + ":mongo:" + user.Password)) + set = append(set, bson.DocElem{"pwd", hex.EncodeToString(psum.Sum(nil))}) + unset = append(unset, bson.DocElem{"userSource", 1}) + } else if user.PasswordHash != "" { + set = append(set, bson.DocElem{"pwd", user.PasswordHash}) + unset = append(unset, bson.DocElem{"userSource", 1}) + } + if user.UserSource != "" { + set = append(set, bson.DocElem{"userSource", user.UserSource}) + unset = append(unset, bson.DocElem{"pwd", 1}) + } + if user.Roles != nil || user.OtherDBRoles != nil { + set = append(set, bson.DocElem{"roles", user.Roles}) + if len(user.OtherDBRoles) > 0 { + set = append(set, bson.DocElem{"otherDBRoles", user.OtherDBRoles}) + } else { + unset = append(unset, bson.DocElem{"otherDBRoles", 1}) + } + } + users := db.C("system.users") + err = users.Update(bson.D{{"user", user.Username}}, bson.D{{"$unset", unset}, {"$set", set}}) + if err == ErrNotFound { + set = append(set, bson.DocElem{"user", user.Username}) + if user.Roles == nil && user.OtherDBRoles == nil { + // Roles must be sent, as it's the way MongoDB distinguishes + // old-style documents from new-style documents in pre-2.6. + set = append(set, bson.DocElem{"roles", user.Roles}) + } + err = users.Insert(set) + } + return err +} + +func isNoCmd(err error) bool { + e, ok := err.(*QueryError) + return ok && strings.HasPrefix(e.Message, "no such cmd:") +} + +func isNotFound(err error) bool { + e, ok := err.(*QueryError) + return ok && e.Code == 11 +} + +func isAuthError(err error) bool { + e, ok := err.(*QueryError) + return ok && e.Code == 13 +} + +func (db *Database) runUserCmd(cmdName string, user *User) error { + cmd := make(bson.D, 0, 16) + cmd = append(cmd, bson.DocElem{cmdName, user.Username}) + if user.Password != "" { + cmd = append(cmd, bson.DocElem{"pwd", user.Password}) + } + var roles []interface{} + for _, role := range user.Roles { + roles = append(roles, role) + } + for db, dbroles := range user.OtherDBRoles { + for _, role := range dbroles { + roles = append(roles, bson.D{{"role", role}, {"db", db}}) + } + } + if roles != nil || user.Roles != nil || cmdName == "createUser" { + cmd = append(cmd, bson.DocElem{"roles", roles}) + } + err := db.Run(cmd, nil) + if !isNoCmd(err) && user.UserSource != "" && (user.UserSource != "$external" || db.Name != "$external") { + return fmt.Errorf("MongoDB 2.6+ does not support the UserSource setting") + } + return err +} + +// AddUser creates or updates the authentication credentials of user within +// the db database. +// +// WARNING: This method is obsolete and should only be used with MongoDB 2.2 +// or earlier. For MongoDB 2.4 and on, use UpsertUser instead. +func (db *Database) AddUser(username, password string, readOnly bool) error { + // Try to emulate the old behavior on 2.6+ + user := &User{Username: username, Password: password} + if db.Name == "admin" { + if readOnly { + user.Roles = []Role{RoleReadAny} + } else { + user.Roles = []Role{RoleReadWriteAny} + } + } else { + if readOnly { + user.Roles = []Role{RoleRead} + } else { + user.Roles = []Role{RoleReadWrite} + } + } + err := db.runUserCmd("updateUser", user) + if isNotFound(err) { + return db.runUserCmd("createUser", user) + } + if !isNoCmd(err) { + return err + } + + // Command doesn't exist. Fallback to pre-2.6 behavior. + psum := md5.New() + psum.Write([]byte(username + ":mongo:" + password)) + digest := hex.EncodeToString(psum.Sum(nil)) + c := db.C("system.users") + _, err = c.Upsert(bson.M{"user": username}, bson.M{"$set": bson.M{"user": username, "pwd": digest, "readOnly": readOnly}}) + return err +} + +// RemoveUser removes the authentication credentials of user from the database. +func (db *Database) RemoveUser(user string) error { + err := db.Run(bson.D{{"dropUser", user}}, nil) + if isNoCmd(err) { + users := db.C("system.users") + return users.Remove(bson.M{"user": user}) + } + if isNotFound(err) { + return ErrNotFound + } + return err +} + +type indexSpec struct { + Name, NS string + Key bson.D + Unique bool ",omitempty" + DropDups bool "dropDups,omitempty" + Background bool ",omitempty" + Sparse bool ",omitempty" + Bits, Min, Max int ",omitempty" + ExpireAfter int "expireAfterSeconds,omitempty" + Weights bson.D ",omitempty" + DefaultLanguage string "default_language,omitempty" + LanguageOverride string "language_override,omitempty" +} + +type Index struct { + Key []string // Index key fields; prefix name with dash (-) for descending order + Unique bool // Prevent two documents from having the same index key + DropDups bool // Drop documents with the same index key as a previously indexed one + Background bool // Build index in background and return immediately + Sparse bool // Only index documents containing the Key fields + + // If ExpireAfter is defined the server will periodically delete + // documents with indexed time.Time older than the provided delta. + ExpireAfter time.Duration + + // Index name computed by EnsureIndex during creation. + Name string + + // Properties for spatial indexes. + Bits, Min, Max int + + // Properties for text indexes. + DefaultLanguage string + LanguageOverride string +} + +type indexKeyInfo struct { + name string + key bson.D + weights bson.D +} + +func parseIndexKey(key []string) (*indexKeyInfo, error) { + var keyInfo indexKeyInfo + isText := false + var order interface{} + for _, field := range key { + raw := field + if keyInfo.name != "" { + keyInfo.name += "_" + } + var kind string + if field != "" { + if field[0] == '$' { + if c := strings.Index(field, ":"); c > 1 && c < len(field)-1 { + kind = field[1:c] + field = field[c+1:] + keyInfo.name += field + "_" + kind + } + } + switch field[0] { + case '$': + // Logic above failed. Reset and error. + field = "" + case '@': + order = "2d" + field = field[1:] + // The shell used to render this field as key_ instead of key_2d, + // and mgo followed suit. This has been fixed in recent server + // releases, and mgo followed as well. + keyInfo.name += field + "_2d" + case '-': + order = -1 + field = field[1:] + keyInfo.name += field + "_-1" + case '+': + field = field[1:] + fallthrough + default: + if kind == "" { + order = 1 + keyInfo.name += field + "_1" + } else { + order = kind + } + } + } + if field == "" || kind != "" && order != kind { + return nil, fmt.Errorf(`invalid index key: want "[$:][-]", got %q`, raw) + } + if kind == "text" { + if !isText { + keyInfo.key = append(keyInfo.key, bson.DocElem{"_fts", "text"}, bson.DocElem{"_ftsx", 1}) + isText = true + } + keyInfo.weights = append(keyInfo.weights, bson.DocElem{field, 1}) + } else { + keyInfo.key = append(keyInfo.key, bson.DocElem{field, order}) + } + } + if keyInfo.name == "" { + return nil, errors.New("invalid index key: no fields provided") + } + return &keyInfo, nil +} + +// EnsureIndexKey ensures an index with the given key exists, creating it +// if necessary. +// +// This example: +// +// err := collection.EnsureIndexKey("a", "b") +// +// Is equivalent to: +// +// err := collection.EnsureIndex(mgo.Index{Key: []string{"a", "b"}}) +// +// See the EnsureIndex method for more details. +func (c *Collection) EnsureIndexKey(key ...string) error { + return c.EnsureIndex(Index{Key: key}) +} + +// EnsureIndex ensures an index with the given key exists, creating it with +// the provided parameters if necessary. +// +// Once EnsureIndex returns successfully, following requests for the same index +// will not contact the server unless Collection.DropIndex is used to drop the +// same index, or Session.ResetIndexCache is called. +// +// For example: +// +// index := Index{ +// Key: []string{"lastname", "firstname"}, +// Unique: true, +// DropDups: true, +// Background: true, // See notes. +// Sparse: true, +// } +// err := collection.EnsureIndex(index) +// +// The Key value determines which fields compose the index. The index ordering +// will be ascending by default. To obtain an index with a descending order, +// the field name should be prefixed by a dash (e.g. []string{"-time"}). +// +// If Unique is true, the index must necessarily contain only a single +// document per Key. With DropDups set to true, documents with the same key +// as a previously indexed one will be dropped rather than an error returned. +// +// If Background is true, other connections will be allowed to proceed using +// the collection without the index while it's being built. Note that the +// session executing EnsureIndex will be blocked for as long as it takes for +// the index to be built. +// +// If Sparse is true, only documents containing the provided Key fields will be +// included in the index. When using a sparse index for sorting, only indexed +// documents will be returned. +// +// If ExpireAfter is non-zero, the server will periodically scan the collection +// and remove documents containing an indexed time.Time field with a value +// older than ExpireAfter. See the documentation for details: +// +// http://docs.mongodb.org/manual/tutorial/expire-data +// +// Other kinds of indexes are also supported through that API. Here is an example: +// +// index := Index{ +// Key: []string{"$2d:loc"}, +// Bits: 26, +// } +// err := collection.EnsureIndex(index) +// +// The example above requests the creation of a "2d" index for the "loc" field. +// +// The 2D index bounds may be changed using the Min and Max attributes of the +// Index value. The default bound setting of (-180, 180) is suitable for +// latitude/longitude pairs. +// +// The Bits parameter sets the precision of the 2D geohash values. If not +// provided, 26 bits are used, which is roughly equivalent to 1 foot of +// precision for the default (-180, 180) index bounds. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Indexes +// http://www.mongodb.org/display/DOCS/Indexing+Advice+and+FAQ +// http://www.mongodb.org/display/DOCS/Indexing+as+a+Background+Operation +// http://www.mongodb.org/display/DOCS/Geospatial+Indexing +// http://www.mongodb.org/display/DOCS/Multikeys +// +func (c *Collection) EnsureIndex(index Index) error { + keyInfo, err := parseIndexKey(index.Key) + if err != nil { + return err + } + + session := c.Database.Session + cacheKey := c.FullName + "\x00" + keyInfo.name + if session.cluster().HasCachedIndex(cacheKey) { + return nil + } + + spec := indexSpec{ + Name: keyInfo.name, + NS: c.FullName, + Key: keyInfo.key, + Unique: index.Unique, + DropDups: index.DropDups, + Background: index.Background, + Sparse: index.Sparse, + Bits: index.Bits, + Min: index.Min, + Max: index.Max, + ExpireAfter: int(index.ExpireAfter / time.Second), + Weights: keyInfo.weights, + DefaultLanguage: index.DefaultLanguage, + LanguageOverride: index.LanguageOverride, + } + + session = session.Clone() + defer session.Close() + session.SetMode(Strong, false) + session.EnsureSafe(&Safe{}) + + db := c.Database.With(session) + err = db.C("system.indexes").Insert(&spec) + if err == nil { + session.cluster().CacheIndex(cacheKey, true) + } + session.Close() + return err +} + +// DropIndex removes the index with key from the collection. +// +// The key value determines which fields compose the index. The index ordering +// will be ascending by default. To obtain an index with a descending order, +// the field name should be prefixed by a dash (e.g. []string{"-time"}). +// +// For example: +// +// err := collection.DropIndex("lastname", "firstname") +// +// See the EnsureIndex method for more details on indexes. +func (c *Collection) DropIndex(key ...string) error { + keyInfo, err := parseIndexKey(key) + if err != nil { + return err + } + + session := c.Database.Session + cacheKey := c.FullName + "\x00" + keyInfo.name + session.cluster().CacheIndex(cacheKey, false) + + session = session.Clone() + defer session.Close() + session.SetMode(Strong, false) + + db := c.Database.With(session) + result := struct { + ErrMsg string + Ok bool + }{} + err = db.Run(bson.D{{"dropIndexes", c.Name}, {"index", keyInfo.name}}, &result) + if err != nil { + return err + } + if !result.Ok { + return errors.New(result.ErrMsg) + } + return nil +} + +// Indexes returns a list of all indexes for the collection. +// +// For example, this snippet would drop all available indexes: +// +// indexes, err := collection.Indexes() +// if err != nil { +// return err +// } +// for _, index := range indexes { +// err = collection.DropIndex(index.Key...) +// if err != nil { +// return err +// } +// } +// +// See the EnsureIndex method for more details on indexes. +func (c *Collection) Indexes() (indexes []Index, err error) { + // Try with a command. + var cmdResult struct { + Indexes []indexSpec + } + err = c.Database.Run(bson.D{{"listIndexes", c.Name}}, &cmdResult) + if err == nil { + for _, spec := range cmdResult.Indexes { + indexes = append(indexes, indexFromSpec(spec)) + } + sort.Sort(indexSlice(indexes)) + return indexes, nil + } + if err != nil && !isNoCmd(err) { + return nil, err + } + + // Command not yet supported. Query the database instead. + query := c.Database.C("system.indexes").Find(bson.M{"ns": c.FullName}) + iter := query.Sort("name").Iter() + for { + var spec indexSpec + if !iter.Next(&spec) { + break + } + indexes = append(indexes, indexFromSpec(spec)) + } + err = iter.Close() + sort.Sort(indexSlice(indexes)) + return indexes, nil +} + +func indexFromSpec(spec indexSpec) Index { + return Index{ + Name: spec.Name, + Key: simpleIndexKey(spec.Key), + Unique: spec.Unique, + DropDups: spec.DropDups, + Background: spec.Background, + Sparse: spec.Sparse, + ExpireAfter: time.Duration(spec.ExpireAfter) * time.Second, + } +} + +type indexSlice []Index + +func (idxs indexSlice) Len() int { return len(idxs) } +func (idxs indexSlice) Less(i, j int) bool { return idxs[i].Name < idxs[j].Name } +func (idxs indexSlice) Swap(i, j int) { idxs[i], idxs[j] = idxs[j], idxs[i] } + +func simpleIndexKey(realKey bson.D) (key []string) { + for i := range realKey { + field := realKey[i].Name + vi, ok := realKey[i].Value.(int) + if !ok { + vf, _ := realKey[i].Value.(float64) + vi = int(vf) + } + if vi == 1 { + key = append(key, field) + continue + } + if vi == -1 { + key = append(key, "-"+field) + continue + } + if vs, ok := realKey[i].Value.(string); ok { + key = append(key, "$"+vs+":"+field) + continue + } + panic("Got unknown index key type for field " + field) + } + return +} + +// ResetIndexCache() clears the cache of previously ensured indexes. +// Following requests to EnsureIndex will contact the server. +func (s *Session) ResetIndexCache() { + s.cluster().ResetIndexCache() +} + +// New creates a new session with the same parameters as the original +// session, including consistency, batch size, prefetching, safety mode, +// etc. The returned session will use sockets from the pool, so there's +// a chance that writes just performed in another session may not yet +// be visible. +// +// Login information from the original session will not be copied over +// into the new session unless it was provided through the initial URL +// for the Dial function. +// +// See the Copy and Clone methods. +// +func (s *Session) New() *Session { + s.m.Lock() + scopy := copySession(s, false) + s.m.Unlock() + scopy.Refresh() + return scopy +} + +// Copy works just like New, but preserves the exact authentication +// information from the original session. +func (s *Session) Copy() *Session { + s.m.Lock() + scopy := copySession(s, true) + s.m.Unlock() + scopy.Refresh() + return scopy +} + +// Clone works just like Copy, but also reuses the same socket as the original +// session, in case it had already reserved one due to its consistency +// guarantees. This behavior ensures that writes performed in the old session +// are necessarily observed when using the new session, as long as it was a +// strong or monotonic session. That said, it also means that long operations +// may cause other goroutines using the original session to wait. +func (s *Session) Clone() *Session { + s.m.Lock() + scopy := copySession(s, true) + s.m.Unlock() + return scopy +} + +// Close terminates the session. It's a runtime error to use a session +// after it has been closed. +func (s *Session) Close() { + s.m.Lock() + if s.cluster_ != nil { + debugf("Closing session %p", s) + s.unsetSocket() + s.cluster_.Release() + s.cluster_ = nil + } + s.m.Unlock() +} + +func (s *Session) cluster() *mongoCluster { + if s.cluster_ == nil { + panic("Session already closed") + } + return s.cluster_ +} + +// Refresh puts back any reserved sockets in use and restarts the consistency +// guarantees according to the current consistency setting for the session. +func (s *Session) Refresh() { + s.m.Lock() + s.slaveOk = s.consistency != Strong + s.unsetSocket() + s.m.Unlock() +} + +// SetMode changes the consistency mode for the session. +// +// In the Strong consistency mode reads and writes will always be made to +// the primary server using a unique connection so that reads and writes are +// fully consistent, ordered, and observing the most up-to-date data. +// This offers the least benefits in terms of distributing load, but the +// most guarantees. See also Monotonic and Eventual. +// +// In the Monotonic consistency mode reads may not be entirely up-to-date, +// but they will always see the history of changes moving forward, the data +// read will be consistent across sequential queries in the same session, +// and modifications made within the session will be observed in following +// queries (read-your-writes). +// +// In practice, the Monotonic mode is obtained by performing initial reads +// on a unique connection to an arbitrary secondary, if one is available, +// and once the first write happens, the session connection is switched over +// to the primary server. This manages to distribute some of the reading +// load with secondaries, while maintaining some useful guarantees. +// +// In the Eventual consistency mode reads will be made to any secondary in the +// cluster, if one is available, and sequential reads will not necessarily +// be made with the same connection. This means that data may be observed +// out of order. Writes will of course be issued to the primary, but +// independent writes in the same Eventual session may also be made with +// independent connections, so there are also no guarantees in terms of +// write ordering (no read-your-writes guarantees either). +// +// The Eventual mode is the fastest and most resource-friendly, but is +// also the one offering the least guarantees about ordering of the data +// read and written. +// +// If refresh is true, in addition to ensuring the session is in the given +// consistency mode, the consistency guarantees will also be reset (e.g. +// a Monotonic session will be allowed to read from secondaries again). +// This is equivalent to calling the Refresh function. +// +// Shifting between Monotonic and Strong modes will keep a previously +// reserved connection for the session unless refresh is true or the +// connection is unsuitable (to a secondary server in a Strong session). +func (s *Session) SetMode(consistency mode, refresh bool) { + s.m.Lock() + debugf("Session %p: setting mode %d with refresh=%v (master=%p, slave=%p)", s, consistency, refresh, s.masterSocket, s.slaveSocket) + s.consistency = consistency + if refresh { + s.slaveOk = s.consistency != Strong + s.unsetSocket() + } else if s.consistency == Strong { + s.slaveOk = false + } else if s.masterSocket == nil { + s.slaveOk = true + } + s.m.Unlock() +} + +// Mode returns the current consistency mode for the session. +func (s *Session) Mode() mode { + s.m.RLock() + mode := s.consistency + s.m.RUnlock() + return mode +} + +// SetSyncTimeout sets the amount of time an operation with this session +// will wait before returning an error in case a connection to a usable +// server can't be established. Set it to zero to wait forever. The +// default value is 7 seconds. +func (s *Session) SetSyncTimeout(d time.Duration) { + s.m.Lock() + s.syncTimeout = d + s.m.Unlock() +} + +// SetSocketTimeout sets the amount of time to wait for a non-responding +// socket to the database before it is forcefully closed. +func (s *Session) SetSocketTimeout(d time.Duration) { + s.m.Lock() + s.sockTimeout = d + if s.masterSocket != nil { + s.masterSocket.SetTimeout(d) + } + if s.slaveSocket != nil { + s.slaveSocket.SetTimeout(d) + } + s.m.Unlock() +} + +// SetCursorTimeout changes the standard timeout period that the server +// enforces on created cursors. The only supported value right now is +// 0, which disables the timeout. The standard server timeout is 10 minutes. +func (s *Session) SetCursorTimeout(d time.Duration) { + s.m.Lock() + if d == 0 { + s.queryConfig.op.flags |= flagNoCursorTimeout + } else { + panic("SetCursorTimeout: only 0 (disable timeout) supported for now") + } + s.m.Unlock() +} + +// SetPoolLimit sets the maximum number of sockets in use in a single server +// before this session will block waiting for a socket to be available. +// The default limit is 4096. +// +// This limit must be set to cover more than any expected workload of the +// application. It is a bad practice and an unsupported use case to use the +// database driver to define the concurrency limit of an application. Prevent +// such concurrency "at the door" instead, by properly restricting the amount +// of used resources and number of goroutines before they are created. +func (s *Session) SetPoolLimit(limit int) { + s.m.Lock() + s.poolLimit = limit + s.m.Unlock() +} + +// SetBatch sets the default batch size used when fetching documents from the +// database. It's possible to change this setting on a per-query basis as +// well, using the Query.Batch method. +// +// The default batch size is defined by the database itself. As of this +// writing, MongoDB will use an initial size of min(100 docs, 4MB) on the +// first batch, and 4MB on remaining ones. +func (s *Session) SetBatch(n int) { + if n == 1 { + // Server interprets 1 as -1 and closes the cursor (!?) + n = 2 + } + s.m.Lock() + s.queryConfig.op.limit = int32(n) + s.m.Unlock() +} + +// SetPrefetch sets the default point at which the next batch of results will be +// requested. When there are p*batch_size remaining documents cached in an +// Iter, the next batch will be requested in background. For instance, when +// using this: +// +// session.SetBatch(200) +// session.SetPrefetch(0.25) +// +// and there are only 50 documents cached in the Iter to be processed, the +// next batch of 200 will be requested. It's possible to change this setting on +// a per-query basis as well, using the Prefetch method of Query. +// +// The default prefetch value is 0.25. +func (s *Session) SetPrefetch(p float64) { + s.m.Lock() + s.queryConfig.prefetch = p + s.m.Unlock() +} + +// See SetSafe for details on the Safe type. +type Safe struct { + W int // Min # of servers to ack before success + WMode string // Write mode for MongoDB 2.0+ (e.g. "majority") + WTimeout int // Milliseconds to wait for W before timing out + FSync bool // Should servers sync to disk before returning success + J bool // Wait for next group commit if journaling; no effect otherwise +} + +// Safe returns the current safety mode for the session. +func (s *Session) Safe() (safe *Safe) { + s.m.Lock() + defer s.m.Unlock() + if s.safeOp != nil { + cmd := s.safeOp.query.(*getLastError) + safe = &Safe{WTimeout: cmd.WTimeout, FSync: cmd.FSync, J: cmd.J} + switch w := cmd.W.(type) { + case string: + safe.WMode = w + case int: + safe.W = w + } + } + return +} + +// SetSafe changes the session safety mode. +// +// If the safe parameter is nil, the session is put in unsafe mode, and writes +// become fire-and-forget, without error checking. The unsafe mode is faster +// since operations won't hold on waiting for a confirmation. +// +// If the safe parameter is not nil, any changing query (insert, update, ...) +// will be followed by a getLastError command with the specified parameters, +// to ensure the request was correctly processed. +// +// The safe.W parameter determines how many servers should confirm a write +// before the operation is considered successful. If set to 0 or 1, the +// command will return as soon as the primary is done with the request. +// If safe.WTimeout is greater than zero, it determines how many milliseconds +// to wait for the safe.W servers to respond before returning an error. +// +// Starting with MongoDB 2.0.0 the safe.WMode parameter can be used instead +// of W to request for richer semantics. If set to "majority" the server will +// wait for a majority of members from the replica set to respond before +// returning. Custom modes may also be defined within the server to create +// very detailed placement schemas. See the data awareness documentation in +// the links below for more details (note that MongoDB internally reuses the +// "w" field name for WMode). +// +// If safe.FSync is true and journaling is disabled, the servers will be +// forced to sync all files to disk immediately before returning. If the +// same option is true but journaling is enabled, the server will instead +// await for the next group commit before returning. +// +// Since MongoDB 2.0.0, the safe.J option can also be used instead of FSync +// to force the server to wait for a group commit in case journaling is +// enabled. The option has no effect if the server has journaling disabled. +// +// For example, the following statement will make the session check for +// errors, without imposing further constraints: +// +// session.SetSafe(&mgo.Safe{}) +// +// The following statement will force the server to wait for a majority of +// members of a replica set to return (MongoDB 2.0+ only): +// +// session.SetSafe(&mgo.Safe{WMode: "majority"}) +// +// The following statement, on the other hand, ensures that at least two +// servers have flushed the change to disk before confirming the success +// of operations: +// +// session.EnsureSafe(&mgo.Safe{W: 2, FSync: true}) +// +// The following statement, on the other hand, disables the verification +// of errors entirely: +// +// session.SetSafe(nil) +// +// See also the EnsureSafe method. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/getLastError+Command +// http://www.mongodb.org/display/DOCS/Verifying+Propagation+of+Writes+with+getLastError +// http://www.mongodb.org/display/DOCS/Data+Center+Awareness +// +func (s *Session) SetSafe(safe *Safe) { + s.m.Lock() + s.safeOp = nil + s.ensureSafe(safe) + s.m.Unlock() +} + +// EnsureSafe compares the provided safety parameters with the ones +// currently in use by the session and picks the most conservative +// choice for each setting. +// +// That is: +// +// - safe.WMode is always used if set. +// - safe.W is used if larger than the current W and WMode is empty. +// - safe.FSync is always used if true. +// - safe.J is used if FSync is false. +// - safe.WTimeout is used if set and smaller than the current WTimeout. +// +// For example, the following statement will ensure the session is +// at least checking for errors, without enforcing further constraints. +// If a more conservative SetSafe or EnsureSafe call was previously done, +// the following call will be ignored. +// +// session.EnsureSafe(&mgo.Safe{}) +// +// See also the SetSafe method for details on what each option means. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/getLastError+Command +// http://www.mongodb.org/display/DOCS/Verifying+Propagation+of+Writes+with+getLastError +// http://www.mongodb.org/display/DOCS/Data+Center+Awareness +// +func (s *Session) EnsureSafe(safe *Safe) { + s.m.Lock() + s.ensureSafe(safe) + s.m.Unlock() +} + +func (s *Session) ensureSafe(safe *Safe) { + if safe == nil { + return + } + + var w interface{} + if safe.WMode != "" { + w = safe.WMode + } else if safe.W > 0 { + w = safe.W + } + + var cmd getLastError + if s.safeOp == nil { + cmd = getLastError{1, w, safe.WTimeout, safe.FSync, safe.J} + } else { + // Copy. We don't want to mutate the existing query. + cmd = *(s.safeOp.query.(*getLastError)) + if cmd.W == nil { + cmd.W = w + } else if safe.WMode != "" { + cmd.W = safe.WMode + } else if i, ok := cmd.W.(int); ok && safe.W > i { + cmd.W = safe.W + } + if safe.WTimeout > 0 && safe.WTimeout < cmd.WTimeout { + cmd.WTimeout = safe.WTimeout + } + if safe.FSync { + cmd.FSync = true + cmd.J = false + } else if safe.J && !cmd.FSync { + cmd.J = true + } + } + s.safeOp = &queryOp{ + query: &cmd, + collection: "admin.$cmd", + limit: -1, + } +} + +// Run issues the provided command on the "admin" database and +// and unmarshals its result in the respective argument. The cmd +// argument may be either a string with the command name itself, in +// which case an empty document of the form bson.M{cmd: 1} will be used, +// or it may be a full command document. +// +// Note that MongoDB considers the first marshalled key as the command +// name, so when providing a command with options, it's important to +// use an ordering-preserving document, such as a struct value or an +// instance of bson.D. For instance: +// +// db.Run(bson.D{{"create", "mycollection"}, {"size", 1024}}) +// +// For commands on arbitrary databases, see the Run method in +// the Database type. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Commands +// http://www.mongodb.org/display/DOCS/List+of+Database+CommandSkips +// +func (s *Session) Run(cmd interface{}, result interface{}) error { + return s.DB("admin").Run(cmd, result) +} + +// SelectServers restricts communication to servers configured with the +// given tags. For example, the following statement restricts servers +// used for reading operations to those with both tag "disk" set to +// "ssd" and tag "rack" set to 1: +// +// session.SelectSlaves(bson.D{{"disk", "ssd"}, {"rack", 1}}) +// +// Multiple sets of tags may be provided, in which case the used server +// must match all tags within any one set. +// +// If a connection was previously assigned to the session due to the +// current session mode (see Session.SetMode), the tag selection will +// only be enforced after the session is refreshed. +// +// Relevant documentation: +// +// http://docs.mongodb.org/manual/tutorial/configure-replica-set-tag-sets +// +func (s *Session) SelectServers(tags ...bson.D) { + s.m.Lock() + s.queryConfig.op.serverTags = tags + s.m.Unlock() +} + +// Ping runs a trivial ping command just to get in touch with the server. +func (s *Session) Ping() error { + return s.Run("ping", nil) +} + +// Fsync flushes in-memory writes to disk on the server the session +// is established with. If async is true, the call returns immediately, +// otherwise it returns after the flush has been made. +func (s *Session) Fsync(async bool) error { + return s.Run(bson.D{{"fsync", 1}, {"async", async}}, nil) +} + +// FsyncLock locks all writes in the specific server the session is +// established with and returns. Any writes attempted to the server +// after it is successfully locked will block until FsyncUnlock is +// called for the same server. +// +// This method works on secondaries as well, preventing the oplog from +// being flushed while the server is locked, but since only the server +// connected to is locked, for locking specific secondaries it may be +// necessary to establish a connection directly to the secondary (see +// Dial's connect=direct option). +// +// As an important caveat, note that once a write is attempted and +// blocks, follow up reads will block as well due to the way the +// lock is internally implemented in the server. More details at: +// +// https://jira.mongodb.org/browse/SERVER-4243 +// +// FsyncLock is often used for performing consistent backups of +// the database files on disk. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/fsync+Command +// http://www.mongodb.org/display/DOCS/Backups +// +func (s *Session) FsyncLock() error { + return s.Run(bson.D{{"fsync", 1}, {"lock", true}}, nil) +} + +// FsyncUnlock releases the server for writes. See FsyncLock for details. +func (s *Session) FsyncUnlock() error { + return s.DB("admin").C("$cmd.sys.unlock").Find(nil).One(nil) // WTF? +} + +// Find prepares a query using the provided document. The document may be a +// map or a struct value capable of being marshalled with bson. The map +// may be a generic one using interface{} for its key and/or values, such as +// bson.M, or it may be a properly typed map. Providing nil as the document +// is equivalent to providing an empty document such as bson.M{}. +// +// Further details of the query may be tweaked using the resulting Query value, +// and then executed to retrieve results using methods such as One, For, +// Iter, or Tail. +// +// In case the resulting document includes a field named $err or errmsg, which +// are standard ways for MongoDB to return query errors, the returned err will +// be set to a *QueryError value including the Err message and the Code. In +// those cases, the result argument is still unmarshalled into with the +// received document so that any other custom values may be obtained if +// desired. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Querying +// http://www.mongodb.org/display/DOCS/Advanced+Queries +// +func (c *Collection) Find(query interface{}) *Query { + session := c.Database.Session + session.m.RLock() + q := &Query{session: session, query: session.queryConfig} + session.m.RUnlock() + q.op.query = query + q.op.collection = c.FullName + return q +} + +type repairCmd struct { + RepairCursor string `bson:"repairCursor"` + Cursor *repairCmdCursor ",omitempty" +} + +type repairCmdCursor struct { + BatchSize int `bson:"batchSize,omitempty"` +} + +// Repair returns an iterator that goes over all recovered documents in the +// collection, in a best-effort manner. This is most useful when there are +// damaged data files. Multiple copies of the same document may be returned +// by the iterator. +// +// Repair is supported in MongoDB 2.7.8 and later. +func (c *Collection) Repair() *Iter { + // Clone session and set it to strong mode so that the server + // used for the query may be safely obtained afterwards, if + // necessary for iteration when a cursor is received. + session := c.Database.Session + session.m.Lock() + batchSize := int(session.queryConfig.op.limit) + session.m.Unlock() + cloned := session.Clone() + cloned.SetMode(Strong, false) + defer cloned.Close() + c = c.With(cloned) + + iter := &Iter{ + session: session, + timeout: -1, + } + iter.gotReply.L = &iter.m + + var result struct { + Cursor struct { + FirstBatch []bson.Raw "firstBatch" + Id int64 + } + } + + cmd := repairCmd{ + RepairCursor: c.Name, + Cursor: &repairCmdCursor{batchSize}, + } + iter.err = c.Database.Run(cmd, &result) + if iter.err != nil { + return iter + } + docs := result.Cursor.FirstBatch + for i := range docs { + iter.docData.Push(docs[i].Data) + } + if result.Cursor.Id != 0 { + socket, err := cloned.acquireSocket(true) + if err != nil { + // Cloned session is in strong mode, and the query + // above succeeded. Should have a reserved socket. + panic("internal error: " + err.Error()) + } + iter.server = socket.Server() + socket.Release() + iter.op.cursorId = result.Cursor.Id + iter.op.collection = c.FullName + iter.op.replyFunc = iter.replyFunc() + } + return iter +} + +// FindId is a convenience helper equivalent to: +// +// query := collection.Find(bson.M{"_id": id}) +// +// See the Find method for more details. +func (c *Collection) FindId(id interface{}) *Query { + return c.Find(bson.D{{"_id", id}}) +} + +type Pipe struct { + session *Session + collection *Collection + pipeline interface{} + allowDisk bool + batchSize int +} + +type pipeCmd struct { + Aggregate string + Pipeline interface{} + Cursor *pipeCmdCursor ",omitempty" + Explain bool ",omitempty" + AllowDisk bool "allowDiskUse,omitempty" +} + +type pipeCmdCursor struct { + BatchSize int `bson:"batchSize,omitempty"` +} + +// Pipe prepares a pipeline to aggregate. The pipeline document +// must be a slice built in terms of the aggregation framework language. +// +// For example: +// +// pipe := collection.Pipe([]bson.M{{"$match": bson.M{"name": "Otavio"}}}) +// iter := pipe.Iter() +// +// Relevant documentation: +// +// http://docs.mongodb.org/manual/reference/aggregation +// http://docs.mongodb.org/manual/applications/aggregation +// http://docs.mongodb.org/manual/tutorial/aggregation-examples +// +func (c *Collection) Pipe(pipeline interface{}) *Pipe { + session := c.Database.Session + session.m.Lock() + batchSize := int(session.queryConfig.op.limit) + session.m.Unlock() + return &Pipe{ + session: session, + collection: c, + pipeline: pipeline, + batchSize: batchSize, + } +} + +// Iter executes the pipeline and returns an iterator capable of going +// over all the generated results. +func (p *Pipe) Iter() *Iter { + + // Clone session and set it to strong mode so that the server + // used for the query may be safely obtained afterwards, if + // necessary for iteration when a cursor is received. + cloned := p.session.Clone() + cloned.SetMode(Strong, false) + defer cloned.Close() + c := p.collection.With(cloned) + + iter := &Iter{ + session: p.session, + timeout: -1, + } + iter.gotReply.L = &iter.m + + var result struct { + // 2.4, no cursors. + Result []bson.Raw + + // 2.6+, with cursors. + Cursor struct { + FirstBatch []bson.Raw "firstBatch" + Id int64 + } + } + + cmd := pipeCmd{ + Aggregate: c.Name, + Pipeline: p.pipeline, + AllowDisk: p.allowDisk, + Cursor: &pipeCmdCursor{p.batchSize}, + } + iter.err = c.Database.Run(cmd, &result) + if e, ok := iter.err.(*QueryError); ok && e.Message == `unrecognized field "cursor` { + cmd.Cursor = nil + cmd.AllowDisk = false + iter.err = c.Database.Run(cmd, &result) + } + if iter.err != nil { + return iter + } + docs := result.Result + if docs == nil { + docs = result.Cursor.FirstBatch + } + for i := range docs { + iter.docData.Push(docs[i].Data) + } + if result.Cursor.Id != 0 { + socket, err := cloned.acquireSocket(true) + if err != nil { + // Cloned session is in strong mode, and the query + // above succeeded. Should have a reserved socket. + panic("internal error: " + err.Error()) + } + iter.server = socket.Server() + socket.Release() + iter.op.cursorId = result.Cursor.Id + iter.op.collection = c.FullName + iter.op.replyFunc = iter.replyFunc() + } + return iter +} + +// All works like Iter.All. +func (p *Pipe) All(result interface{}) error { + return p.Iter().All(result) +} + +// One executes the pipeline and unmarshals the first item from the +// result set into the result parameter. +// It returns ErrNotFound if no items are generated by the pipeline. +func (p *Pipe) One(result interface{}) error { + iter := p.Iter() + if iter.Next(result) { + return nil + } + if err := iter.Err(); err != nil { + return err + } + return ErrNotFound +} + +// Explain returns a number of details about how the MongoDB server would +// execute the requested pipeline, such as the number of objects examined, +// the number of times the read lock was yielded to allow writes to go in, +// and so on. +// +// For example: +// +// var m bson.M +// err := collection.Pipe(pipeline).Explain(&m) +// if err == nil { +// fmt.Printf("Explain: %#v\n", m) +// } +// +func (p *Pipe) Explain(result interface{}) error { + c := p.collection + cmd := pipeCmd{ + Aggregate: c.Name, + Pipeline: p.pipeline, + AllowDisk: p.allowDisk, + Explain: true, + } + return c.Database.Run(cmd, result) +} + +// AllowDiskUse enables writing to the "/_tmp" server directory so +// that aggregation pipelines do not have to be held entirely in memory. +func (p *Pipe) AllowDiskUse() *Pipe { + p.allowDisk = true + return p +} + +// Batch sets the batch size used when fetching documents from the database. +// It's possible to change this setting on a per-session basis as well, using +// the Batch method of Session. +// +// The default batch size is defined by the database server. +func (p *Pipe) Batch(n int) *Pipe { + p.batchSize = n + return p +} + +type LastError struct { + Err string + Code, N, Waited int + FSyncFiles int `bson:"fsyncFiles"` + WTimeout bool + UpdatedExisting bool `bson:"updatedExisting"` + UpsertedId interface{} `bson:"upserted"` +} + +func (err *LastError) Error() string { + return err.Err +} + +type queryError struct { + Err string "$err" + ErrMsg string + Assertion string + Code int + AssertionCode int "assertionCode" + LastError *LastError "lastErrorObject" +} + +type QueryError struct { + Code int + Message string + Assertion bool +} + +func (err *QueryError) Error() string { + return err.Message +} + +// IsDup returns whether err informs of a duplicate key error because +// a primary key index or a secondary unique index already has an entry +// with the given value. +func IsDup(err error) bool { + // Besides being handy, helps with MongoDB bugs SERVER-7164 and SERVER-11493. + // What follows makes me sad. Hopefully conventions will be more clear over time. + switch e := err.(type) { + case *LastError: + return e.Code == 11000 || e.Code == 11001 || e.Code == 12582 || e.Code == 16460 && strings.Contains(e.Err, " E11000 ") + case *QueryError: + return e.Code == 11000 || e.Code == 11001 || e.Code == 12582 + } + return false +} + +// Insert inserts one or more documents in the respective collection. In +// case the session is in safe mode (see the SetSafe method) and an error +// happens while inserting the provided documents, the returned error will +// be of type *LastError. +func (c *Collection) Insert(docs ...interface{}) error { + _, err := c.writeQuery(&insertOp{c.FullName, docs, 0}) + return err +} + +// Update finds a single document matching the provided selector document +// and modifies it according to the update document. +// If the session is in safe mode (see SetSafe) a ErrNotFound error is +// returned if a document isn't found, or a value of type *LastError +// when some other error is detected. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Updating +// http://www.mongodb.org/display/DOCS/Atomic+Operations +// +func (c *Collection) Update(selector interface{}, update interface{}) error { + lerr, err := c.writeQuery(&updateOp{c.FullName, selector, update, 0}) + if err == nil && lerr != nil && !lerr.UpdatedExisting { + return ErrNotFound + } + return err +} + +// UpdateId is a convenience helper equivalent to: +// +// err := collection.Update(bson.M{"_id": id}, update) +// +// See the Update method for more details. +func (c *Collection) UpdateId(id interface{}, update interface{}) error { + return c.Update(bson.D{{"_id", id}}, update) +} + +// ChangeInfo holds details about the outcome of an update operation. +type ChangeInfo struct { + Updated int // Number of existing documents updated + Removed int // Number of documents removed + UpsertedId interface{} // Upserted _id field, when not explicitly provided +} + +// UpdateAll finds all documents matching the provided selector document +// and modifies them according to the update document. +// If the session is in safe mode (see SetSafe) details of the executed +// operation are returned in info or an error of type *LastError when +// some problem is detected. It is not an error for the update to not be +// applied on any documents because the selector doesn't match. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Updating +// http://www.mongodb.org/display/DOCS/Atomic+Operations +// +func (c *Collection) UpdateAll(selector interface{}, update interface{}) (info *ChangeInfo, err error) { + lerr, err := c.writeQuery(&updateOp{c.FullName, selector, update, 2}) + if err == nil && lerr != nil { + info = &ChangeInfo{Updated: lerr.N} + } + return info, err +} + +// Upsert finds a single document matching the provided selector document +// and modifies it according to the update document. If no document matching +// the selector is found, the update document is applied to the selector +// document and the result is inserted in the collection. +// If the session is in safe mode (see SetSafe) details of the executed +// operation are returned in info, or an error of type *LastError when +// some problem is detected. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Updating +// http://www.mongodb.org/display/DOCS/Atomic+Operations +// +func (c *Collection) Upsert(selector interface{}, update interface{}) (info *ChangeInfo, err error) { + lerr, err := c.writeQuery(&updateOp{c.FullName, selector, update, 1}) + if err == nil && lerr != nil { + info = &ChangeInfo{} + if lerr.UpdatedExisting { + info.Updated = lerr.N + } else { + info.UpsertedId = lerr.UpsertedId + } + } + return info, err +} + +// UpsertId is a convenience helper equivalent to: +// +// info, err := collection.Upsert(bson.M{"_id": id}, update) +// +// See the Upsert method for more details. +func (c *Collection) UpsertId(id interface{}, update interface{}) (info *ChangeInfo, err error) { + return c.Upsert(bson.D{{"_id", id}}, update) +} + +// Remove finds a single document matching the provided selector document +// and removes it from the database. +// If the session is in safe mode (see SetSafe) a ErrNotFound error is +// returned if a document isn't found, or a value of type *LastError +// when some other error is detected. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Removing +// +func (c *Collection) Remove(selector interface{}) error { + lerr, err := c.writeQuery(&deleteOp{c.FullName, selector, 1}) + if err == nil && lerr != nil && lerr.N == 0 { + return ErrNotFound + } + return err +} + +// RemoveId is a convenience helper equivalent to: +// +// err := collection.Remove(bson.M{"_id": id}) +// +// See the Remove method for more details. +func (c *Collection) RemoveId(id interface{}) error { + return c.Remove(bson.D{{"_id", id}}) +} + +// RemoveAll finds all documents matching the provided selector document +// and removes them from the database. In case the session is in safe mode +// (see the SetSafe method) and an error happens when attempting the change, +// the returned error will be of type *LastError. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Removing +// +func (c *Collection) RemoveAll(selector interface{}) (info *ChangeInfo, err error) { + lerr, err := c.writeQuery(&deleteOp{c.FullName, selector, 0}) + if err == nil && lerr != nil { + info = &ChangeInfo{Removed: lerr.N} + } + return info, err +} + +// DropDatabase removes the entire database including all of its collections. +func (db *Database) DropDatabase() error { + return db.Run(bson.D{{"dropDatabase", 1}}, nil) +} + +// DropCollection removes the entire collection including all of its documents. +func (c *Collection) DropCollection() error { + return c.Database.Run(bson.D{{"drop", c.Name}}, nil) +} + +// The CollectionInfo type holds metadata about a collection. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/createCollection+Command +// http://www.mongodb.org/display/DOCS/Capped+Collections +// +type CollectionInfo struct { + // DisableIdIndex prevents the automatic creation of the index + // on the _id field for the collection. + DisableIdIndex bool + + // ForceIdIndex enforces the automatic creation of the index + // on the _id field for the collection. Capped collections, + // for example, do not have such an index by default. + ForceIdIndex bool + + // If Capped is true new documents will replace old ones when + // the collection is full. MaxBytes must necessarily be set + // to define the size when the collection wraps around. + // MaxDocs optionally defines the number of documents when it + // wraps, but MaxBytes still needs to be set. + Capped bool + MaxBytes int + MaxDocs int +} + +// Create explicitly creates the c collection with details of info. +// MongoDB creates collections automatically on use, so this method +// is only necessary when creating collection with non-default +// characteristics, such as capped collections. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/createCollection+Command +// http://www.mongodb.org/display/DOCS/Capped+Collections +// +func (c *Collection) Create(info *CollectionInfo) error { + cmd := make(bson.D, 0, 4) + cmd = append(cmd, bson.DocElem{"create", c.Name}) + if info.Capped { + if info.MaxBytes < 1 { + return fmt.Errorf("Collection.Create: with Capped, MaxBytes must also be set") + } + cmd = append(cmd, bson.DocElem{"capped", true}) + cmd = append(cmd, bson.DocElem{"size", info.MaxBytes}) + if info.MaxDocs > 0 { + cmd = append(cmd, bson.DocElem{"max", info.MaxDocs}) + } + } + if info.DisableIdIndex { + cmd = append(cmd, bson.DocElem{"autoIndexId", false}) + } + if info.ForceIdIndex { + cmd = append(cmd, bson.DocElem{"autoIndexId", true}) + } + return c.Database.Run(cmd, nil) +} + +// Batch sets the batch size used when fetching documents from the database. +// It's possible to change this setting on a per-session basis as well, using +// the Batch method of Session. +// +// The default batch size is defined by the database itself. As of this +// writing, MongoDB will use an initial size of min(100 docs, 4MB) on the +// first batch, and 4MB on remaining ones. +func (q *Query) Batch(n int) *Query { + if n == 1 { + // Server interprets 1 as -1 and closes the cursor (!?) + n = 2 + } + q.m.Lock() + q.op.limit = int32(n) + q.m.Unlock() + return q +} + +// Prefetch sets the point at which the next batch of results will be requested. +// When there are p*batch_size remaining documents cached in an Iter, the next +// batch will be requested in background. For instance, when using this: +// +// query.Batch(200).Prefetch(0.25) +// +// and there are only 50 documents cached in the Iter to be processed, the +// next batch of 200 will be requested. It's possible to change this setting on +// a per-session basis as well, using the SetPrefetch method of Session. +// +// The default prefetch value is 0.25. +func (q *Query) Prefetch(p float64) *Query { + q.m.Lock() + q.prefetch = p + q.m.Unlock() + return q +} + +// Skip skips over the n initial documents from the query results. Note that +// this only makes sense with capped collections where documents are naturally +// ordered by insertion time, or with sorted results. +func (q *Query) Skip(n int) *Query { + q.m.Lock() + q.op.skip = int32(n) + q.m.Unlock() + return q +} + +// Limit restricts the maximum number of documents retrieved to n, and also +// changes the batch size to the same value. Once n documents have been +// returned by Next, the following call will return ErrNotFound. +func (q *Query) Limit(n int) *Query { + q.m.Lock() + switch { + case n == 1: + q.limit = 1 + q.op.limit = -1 + case n == math.MinInt32: // -MinInt32 == -MinInt32 + q.limit = math.MaxInt32 + q.op.limit = math.MinInt32 + 1 + case n < 0: + q.limit = int32(-n) + q.op.limit = int32(n) + default: + q.limit = int32(n) + q.op.limit = int32(n) + } + q.m.Unlock() + return q +} + +// Select enables selecting which fields should be retrieved for the results +// found. For example, the following query would only retrieve the name field: +// +// err := collection.Find(nil).Select(bson.M{"name": 1}).One(&result) +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Retrieving+a+Subset+of+Fields +// +func (q *Query) Select(selector interface{}) *Query { + q.m.Lock() + q.op.selector = selector + q.m.Unlock() + return q +} + +// Sort asks the database to order returned documents according to the +// provided field names. A field name may be prefixed by - (minus) for +// it to be sorted in reverse order. +// +// For example: +// +// query1 := collection.Find(nil).Sort("firstname", "lastname") +// query2 := collection.Find(nil).Sort("-age") +// query3 := collection.Find(nil).Sort("$natural") +// query4 := collection.Find(nil).Select(bson.M{"score": bson.M{"$meta": "textScore"}}).Sort("$textScore:score") +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Sorting+and+Natural+Order +// +func (q *Query) Sort(fields ...string) *Query { + q.m.Lock() + var order bson.D + for _, field := range fields { + n := 1 + var kind string + if field != "" { + if field[0] == '$' { + if c := strings.Index(field, ":"); c > 1 && c < len(field)-1 { + kind = field[1:c] + field = field[c+1:] + } + } + switch field[0] { + case '+': + field = field[1:] + case '-': + n = -1 + field = field[1:] + } + } + if field == "" { + panic("Sort: empty field name") + } + if kind == "textScore" { + order = append(order, bson.DocElem{field, bson.M{"$meta": kind}}) + } else { + order = append(order, bson.DocElem{field, n}) + } + } + q.op.options.OrderBy = order + q.op.hasOptions = true + q.m.Unlock() + return q +} + +// Explain returns a number of details about how the MongoDB server would +// execute the requested query, such as the number of objects examined, +// the number of times the read lock was yielded to allow writes to go in, +// and so on. +// +// For example: +// +// m := bson.M{} +// err := collection.Find(bson.M{"filename": name}).Explain(m) +// if err == nil { +// fmt.Printf("Explain: %#v\n", m) +// } +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Optimization +// http://www.mongodb.org/display/DOCS/Query+Optimizer +// +func (q *Query) Explain(result interface{}) error { + q.m.Lock() + clone := &Query{session: q.session, query: q.query} + q.m.Unlock() + clone.op.options.Explain = true + clone.op.hasOptions = true + if clone.op.limit > 0 { + clone.op.limit = -q.op.limit + } + iter := clone.Iter() + if iter.Next(result) { + return nil + } + return iter.Close() +} + +// Hint will include an explicit "hint" in the query to force the server +// to use a specified index, potentially improving performance in some +// situations. The provided parameters are the fields that compose the +// key of the index to be used. For details on how the indexKey may be +// built, see the EnsureIndex method. +// +// For example: +// +// query := collection.Find(bson.M{"firstname": "Joe", "lastname": "Winter"}) +// query.Hint("lastname", "firstname") +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Optimization +// http://www.mongodb.org/display/DOCS/Query+Optimizer +// +func (q *Query) Hint(indexKey ...string) *Query { + q.m.Lock() + keyInfo, err := parseIndexKey(indexKey) + q.op.options.Hint = keyInfo.key + q.op.hasOptions = true + q.m.Unlock() + if err != nil { + panic(err) + } + return q +} + +// SetMaxScan constrains the query to stop after scanning the specified +// number of documents. +// +// This modifier is generally used to prevent potentially long running +// queries from disrupting performance by scanning through too much data. +func (q *Query) SetMaxScan(n int) *Query { + q.m.Lock() + q.op.options.MaxScan = n + q.op.hasOptions = true + q.m.Unlock() + return q +} + +// Snapshot will force the performed query to make use of an available +// index on the _id field to prevent the same document from being returned +// more than once in a single iteration. This might happen without this +// setting in situations when the document changes in size and thus has to +// be moved while the iteration is running. +// +// Because snapshot mode traverses the _id index, it may not be used with +// sorting or explicit hints. It also cannot use any other index for the +// query. +// +// Even with snapshot mode, items inserted or deleted during the query may +// or may not be returned; that is, this mode is not a true point-in-time +// snapshot. +// +// The same effect of Snapshot may be obtained by using any unique index on +// field(s) that will not be modified (best to use Hint explicitly too). +// A non-unique index (such as creation time) may be made unique by +// appending _id to the index when creating it. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/How+to+do+Snapshotted+Queries+in+the+Mongo+Database +// +func (q *Query) Snapshot() *Query { + q.m.Lock() + q.op.options.Snapshot = true + q.op.hasOptions = true + q.m.Unlock() + return q +} + +// LogReplay enables an option that optimizes queries that are typically +// made on the MongoDB oplog for replaying it. This is an internal +// implementation aspect and most likely uninteresting for other uses. +// It has seen at least one use case, though, so it's exposed via the API. +func (q *Query) LogReplay() *Query { + q.m.Lock() + q.op.flags |= flagLogReplay + q.m.Unlock() + return q +} + +func checkQueryError(fullname string, d []byte) error { + l := len(d) + if l < 16 { + return nil + } + if d[5] == '$' && d[6] == 'e' && d[7] == 'r' && d[8] == 'r' && d[9] == '\x00' && d[4] == '\x02' { + goto Error + } + if len(fullname) < 5 || fullname[len(fullname)-5:] != ".$cmd" { + return nil + } + for i := 0; i+8 < l; i++ { + if d[i] == '\x02' && d[i+1] == 'e' && d[i+2] == 'r' && d[i+3] == 'r' && d[i+4] == 'm' && d[i+5] == 's' && d[i+6] == 'g' && d[i+7] == '\x00' { + goto Error + } + } + return nil + +Error: + result := &queryError{} + bson.Unmarshal(d, result) + if result.LastError != nil { + return result.LastError + } + if result.Err == "" && result.ErrMsg == "" { + return nil + } + if result.AssertionCode != 0 && result.Assertion != "" { + return &QueryError{Code: result.AssertionCode, Message: result.Assertion, Assertion: true} + } + if result.Err != "" { + return &QueryError{Code: result.Code, Message: result.Err} + } + return &QueryError{Code: result.Code, Message: result.ErrMsg} +} + +// One executes the query and unmarshals the first obtained document into the +// result argument. The result must be a struct or map value capable of being +// unmarshalled into by gobson. This function blocks until either a result +// is available or an error happens. For example: +// +// err := collection.Find(bson.M{"a", 1}).One(&result) +// +// In case the resulting document includes a field named $err or errmsg, which +// are standard ways for MongoDB to return query errors, the returned err will +// be set to a *QueryError value including the Err message and the Code. In +// those cases, the result argument is still unmarshalled into with the +// received document so that any other custom values may be obtained if +// desired. +// +func (q *Query) One(result interface{}) (err error) { + q.m.Lock() + session := q.session + op := q.op // Copy. + q.m.Unlock() + + socket, err := session.acquireSocket(true) + if err != nil { + return err + } + defer socket.Release() + + op.flags |= session.slaveOkFlag() + op.limit = -1 + + data, err := socket.SimpleQuery(&op) + if err != nil { + return err + } + if data == nil { + return ErrNotFound + } + if result != nil { + err = bson.Unmarshal(data, result) + if err == nil { + debugf("Query %p document unmarshaled: %#v", q, result) + } else { + debugf("Query %p document unmarshaling failed: %#v", q, err) + return err + } + } + return checkQueryError(op.collection, data) +} + +// The DBRef type implements support for the database reference MongoDB +// convention as supported by multiple drivers. This convention enables +// cross-referencing documents between collections and databases using +// a structure which includes a collection name, a document id, and +// optionally a database name. +// +// See the FindRef methods on Session and on Database. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Database+References +// +type DBRef struct { + Collection string `bson:"$ref"` + Id interface{} `bson:"$id"` + Database string `bson:"$db,omitempty"` +} + +// NOTE: Order of fields for DBRef above does matter, per documentation. + +// FindRef returns a query that looks for the document in the provided +// reference. If the reference includes the DB field, the document will +// be retrieved from the respective database. +// +// See also the DBRef type and the FindRef method on Session. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Database+References +// +func (db *Database) FindRef(ref *DBRef) *Query { + var c *Collection + if ref.Database == "" { + c = db.C(ref.Collection) + } else { + c = db.Session.DB(ref.Database).C(ref.Collection) + } + return c.FindId(ref.Id) +} + +// FindRef returns a query that looks for the document in the provided +// reference. For a DBRef to be resolved correctly at the session level +// it must necessarily have the optional DB field defined. +// +// See also the DBRef type and the FindRef method on Database. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Database+References +// +func (s *Session) FindRef(ref *DBRef) *Query { + if ref.Database == "" { + panic(errors.New(fmt.Sprintf("Can't resolve database for %#v", ref))) + } + c := s.DB(ref.Database).C(ref.Collection) + return c.FindId(ref.Id) +} + +// CollectionNames returns the collection names present in the db database. +func (db *Database) CollectionNames() (names []string, err error) { + // Try with a command. + var cmdResult struct { + Collections []struct { + Name string + } + } + err = db.Run(bson.D{{"listCollections", 1}}, &cmdResult) + if err == nil { + for _, coll := range cmdResult.Collections { + names = append(names, coll.Name) + } + sort.Strings(names) + return names, err + } + if err != nil && !isNoCmd(err) { + return nil, err + } + + // Command not yet supported. Query the database instead. + nameIndex := len(db.Name) + 1 + iter := db.C("system.namespaces").Find(nil).Iter() + var result *struct{ Name string } + for iter.Next(&result) { + if strings.Index(result.Name, "$") < 0 || strings.Index(result.Name, ".oplog.$") >= 0 { + names = append(names, result.Name[nameIndex:]) + } + } + if err := iter.Close(); err != nil { + return nil, err + } + sort.Strings(names) + return names, nil +} + +type dbNames struct { + Databases []struct { + Name string + Empty bool + } +} + +// DatabaseNames returns the names of non-empty databases present in the cluster. +func (s *Session) DatabaseNames() (names []string, err error) { + var result dbNames + err = s.Run("listDatabases", &result) + if err != nil { + return nil, err + } + for _, db := range result.Databases { + if !db.Empty { + names = append(names, db.Name) + } + } + sort.Strings(names) + return names, nil +} + +// Iter executes the query and returns an iterator capable of going over all +// the results. Results will be returned in batches of configurable +// size (see the Batch method) and more documents will be requested when a +// configurable number of documents is iterated over (see the Prefetch method). +func (q *Query) Iter() *Iter { + q.m.Lock() + session := q.session + op := q.op + prefetch := q.prefetch + limit := q.limit + q.m.Unlock() + + iter := &Iter{ + session: session, + prefetch: prefetch, + limit: limit, + timeout: -1, + } + iter.gotReply.L = &iter.m + iter.op.collection = op.collection + iter.op.limit = op.limit + iter.op.replyFunc = iter.replyFunc() + iter.docsToReceive++ + op.replyFunc = iter.op.replyFunc + op.flags |= session.slaveOkFlag() + + socket, err := session.acquireSocket(true) + if err != nil { + iter.err = err + } else { + iter.server = socket.Server() + err = socket.Query(&op) + if err != nil { + // Must lock as the query above may call replyFunc. + iter.m.Lock() + iter.err = err + iter.m.Unlock() + } + socket.Release() + } + return iter +} + +// Tail returns a tailable iterator. Unlike a normal iterator, a +// tailable iterator may wait for new values to be inserted in the +// collection once the end of the current result set is reached, +// A tailable iterator may only be used with capped collections. +// +// The timeout parameter indicates how long Next will block waiting +// for a result before timing out. If set to -1, Next will not +// timeout, and will continue waiting for a result for as long as +// the cursor is valid and the session is not closed. If set to 0, +// Next times out as soon as it reaches the end of the result set. +// Otherwise, Next will wait for at least the given number of +// seconds for a new document to be available before timing out. +// +// On timeouts, Next will unblock and return false, and the Timeout +// method will return true if called. In these cases, Next may still +// be called again on the same iterator to check if a new value is +// available at the current cursor position, and again it will block +// according to the specified timeoutSecs. If the cursor becomes +// invalid, though, both Next and Timeout will return false and +// the query must be restarted. +// +// The following example demonstrates timeout handling and query +// restarting: +// +// iter := collection.Find(nil).Sort("$natural").Tail(5 * time.Second) +// for { +// for iter.Next(&result) { +// fmt.Println(result.Id) +// lastId = result.Id +// } +// if iter.Err() != nil { +// return iter.Close() +// } +// if iter.Timeout() { +// continue +// } +// query := collection.Find(bson.M{"_id": bson.M{"$gt": lastId}}) +// iter = query.Sort("$natural").Tail(5 * time.Second) +// } +// iter.Close() +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Tailable+Cursors +// http://www.mongodb.org/display/DOCS/Capped+Collections +// http://www.mongodb.org/display/DOCS/Sorting+and+Natural+Order +// +func (q *Query) Tail(timeout time.Duration) *Iter { + q.m.Lock() + session := q.session + op := q.op + prefetch := q.prefetch + q.m.Unlock() + + iter := &Iter{session: session, prefetch: prefetch} + iter.gotReply.L = &iter.m + iter.timeout = timeout + iter.op.collection = op.collection + iter.op.limit = op.limit + iter.op.replyFunc = iter.replyFunc() + iter.docsToReceive++ + op.replyFunc = iter.op.replyFunc + op.flags |= flagTailable | flagAwaitData | session.slaveOkFlag() + + socket, err := session.acquireSocket(true) + if err != nil { + iter.err = err + } else { + iter.server = socket.Server() + err = socket.Query(&op) + if err != nil { + // Must lock as the query above may call replyFunc. + iter.m.Lock() + iter.err = err + iter.m.Unlock() + } + socket.Release() + } + return iter +} + +func (s *Session) slaveOkFlag() (flag queryOpFlags) { + s.m.RLock() + if s.slaveOk { + flag = flagSlaveOk + } + s.m.RUnlock() + return +} + +// Err returns nil if no errors happened during iteration, or the actual +// error otherwise. +// +// In case a resulting document included a field named $err or errmsg, which are +// standard ways for MongoDB to report an improper query, the returned value has +// a *QueryError type, and includes the Err message and the Code. +func (iter *Iter) Err() error { + iter.m.Lock() + err := iter.err + iter.m.Unlock() + if err == ErrNotFound { + return nil + } + return err +} + +// Close kills the server cursor used by the iterator, if any, and returns +// nil if no errors happened during iteration, or the actual error otherwise. +// +// Server cursors are automatically closed at the end of an iteration, which +// means close will do nothing unless the iteration was interrupted before +// the server finished sending results to the driver. If Close is not called +// in such a situation, the cursor will remain available at the server until +// the default cursor timeout period is reached. No further problems arise. +// +// Close is idempotent. That means it can be called repeatedly and will +// return the same result every time. +// +// In case a resulting document included a field named $err or errmsg, which are +// standard ways for MongoDB to report an improper query, the returned value has +// a *QueryError type. +func (iter *Iter) Close() error { + iter.m.Lock() + iter.killCursor() + err := iter.err + iter.m.Unlock() + if err == ErrNotFound { + return nil + } + return err +} + +func (iter *Iter) killCursor() error { + if iter.op.cursorId != 0 { + socket, err := iter.acquireSocket() + if err == nil { + // TODO Batch kills. + err = socket.Query(&killCursorsOp{[]int64{iter.op.cursorId}}) + socket.Release() + } + if err != nil && (iter.err == nil || iter.err == ErrNotFound) { + iter.err = err + } + iter.op.cursorId = 0 + return err + } + return nil +} + +// Timeout returns true if Next returned false due to a timeout of +// a tailable cursor. In those cases, Next may be called again to continue +// the iteration at the previous cursor position. +func (iter *Iter) Timeout() bool { + iter.m.Lock() + result := iter.timedout + iter.m.Unlock() + return result +} + +// Next retrieves the next document from the result set, blocking if necessary. +// This method will also automatically retrieve another batch of documents from +// the server when the current one is exhausted, or before that in background +// if pre-fetching is enabled (see the Query.Prefetch and Session.SetPrefetch +// methods). +// +// Next returns true if a document was successfully unmarshalled onto result, +// and false at the end of the result set or if an error happened. +// When Next returns false, the Err method should be called to verify if +// there was an error during iteration. +// +// For example: +// +// iter := collection.Find(nil).Iter() +// for iter.Next(&result) { +// fmt.Printf("Result: %v\n", result.Id) +// } +// if err := iter.Close(); err != nil { +// return err +// } +// +func (iter *Iter) Next(result interface{}) bool { + iter.m.Lock() + iter.timedout = false + timeout := time.Time{} + for iter.err == nil && iter.docData.Len() == 0 && (iter.docsToReceive > 0 || iter.op.cursorId != 0) { + if iter.docsToReceive == 0 { + if iter.timeout >= 0 { + if timeout.IsZero() { + timeout = time.Now().Add(iter.timeout) + } + if time.Now().After(timeout) { + iter.timedout = true + iter.m.Unlock() + return false + } + } + iter.getMore() + if iter.err != nil { + break + } + } + iter.gotReply.Wait() + } + + // Exhaust available data before reporting any errors. + if docData, ok := iter.docData.Pop().([]byte); ok { + if iter.limit > 0 { + iter.limit-- + if iter.limit == 0 { + if iter.docData.Len() > 0 { + iter.m.Unlock() + panic(fmt.Errorf("data remains after limit exhausted: %d", iter.docData.Len())) + } + iter.err = ErrNotFound + if iter.killCursor() != nil { + iter.m.Unlock() + return false + } + } + } + if iter.op.cursorId != 0 && iter.err == nil { + if iter.docsBeforeMore == 0 { + iter.getMore() + } + iter.docsBeforeMore-- // Goes negative. + } + iter.m.Unlock() + err := bson.Unmarshal(docData, result) + if err != nil { + debugf("Iter %p document unmarshaling failed: %#v", iter, err) + iter.m.Lock() + if iter.err == nil { + iter.err = err + } + iter.m.Unlock() + return false + } + debugf("Iter %p document unmarshaled: %#v", iter, result) + // XXX Only have to check first document for a query error? + err = checkQueryError(iter.op.collection, docData) + if err != nil { + iter.m.Lock() + if iter.err == nil { + iter.err = err + } + iter.m.Unlock() + return false + } + return true + } else if iter.err != nil { + debugf("Iter %p returning false: %s", iter, iter.err) + iter.m.Unlock() + return false + } else if iter.op.cursorId == 0 { + iter.err = ErrNotFound + debugf("Iter %p exhausted with cursor=0", iter) + iter.m.Unlock() + return false + } + + panic("unreachable") +} + +// All retrieves all documents from the result set into the provided slice +// and closes the iterator. +// +// The result argument must necessarily be the address for a slice. The slice +// may be nil or previously allocated. +// +// WARNING: Obviously, All must not be used with result sets that may be +// potentially large, since it may consume all memory until the system +// crashes. Consider building the query with a Limit clause to ensure the +// result size is bounded. +// +// For instance: +// +// var result []struct{ Value int } +// iter := collection.Find(nil).Limit(100).Iter() +// err := iter.All(&result) +// if err != nil { +// return err +// } +// +func (iter *Iter) All(result interface{}) error { + resultv := reflect.ValueOf(result) + if resultv.Kind() != reflect.Ptr || resultv.Elem().Kind() != reflect.Slice { + panic("result argument must be a slice address") + } + slicev := resultv.Elem() + slicev = slicev.Slice(0, slicev.Cap()) + elemt := slicev.Type().Elem() + i := 0 + for { + if slicev.Len() == i { + elemp := reflect.New(elemt) + if !iter.Next(elemp.Interface()) { + break + } + slicev = reflect.Append(slicev, elemp.Elem()) + slicev = slicev.Slice(0, slicev.Cap()) + } else { + if !iter.Next(slicev.Index(i).Addr().Interface()) { + break + } + } + i++ + } + resultv.Elem().Set(slicev.Slice(0, i)) + return iter.Close() +} + +// All works like Iter.All. +func (q *Query) All(result interface{}) error { + return q.Iter().All(result) +} + +// The For method is obsolete and will be removed in a future release. +// See Iter as an elegant replacement. +func (q *Query) For(result interface{}, f func() error) error { + return q.Iter().For(result, f) +} + +// The For method is obsolete and will be removed in a future release. +// See Iter as an elegant replacement. +func (iter *Iter) For(result interface{}, f func() error) (err error) { + valid := false + v := reflect.ValueOf(result) + if v.Kind() == reflect.Ptr { + v = v.Elem() + switch v.Kind() { + case reflect.Map, reflect.Ptr, reflect.Interface, reflect.Slice: + valid = v.IsNil() + } + } + if !valid { + panic("For needs a pointer to nil reference value. See the documentation.") + } + zero := reflect.Zero(v.Type()) + for { + v.Set(zero) + if !iter.Next(result) { + break + } + err = f() + if err != nil { + return err + } + } + return iter.Err() +} + +func (iter *Iter) acquireSocket() (*mongoSocket, error) { + socket, err := iter.session.acquireSocket(true) + if err != nil { + return nil, err + } + if socket.Server() != iter.server { + // Socket server changed during iteration. This may happen + // with Eventual sessions, if a Refresh is done, or if a + // monotonic session gets a write and shifts from secondary + // to primary. Our cursor is in a specific server, though. + iter.session.m.Lock() + sockTimeout := iter.session.sockTimeout + iter.session.m.Unlock() + socket.Release() + socket, _, err = iter.server.AcquireSocket(0, sockTimeout) + if err != nil { + return nil, err + } + err := iter.session.socketLogin(socket) + if err != nil { + socket.Release() + return nil, err + } + } + return socket, nil +} + +func (iter *Iter) getMore() { + socket, err := iter.acquireSocket() + if err != nil { + iter.err = err + return + } + defer socket.Release() + + debugf("Iter %p requesting more documents", iter) + if iter.limit > 0 { + limit := iter.limit - int32(iter.docsToReceive) - int32(iter.docData.Len()) + if limit < iter.op.limit { + iter.op.limit = limit + } + } + if err := socket.Query(&iter.op); err != nil { + iter.err = err + } + iter.docsToReceive++ +} + +type countCmd struct { + Count string + Query interface{} + Limit int32 ",omitempty" + Skip int32 ",omitempty" +} + +// Count returns the total number of documents in the result set. +func (q *Query) Count() (n int, err error) { + q.m.Lock() + session := q.session + op := q.op + limit := q.limit + q.m.Unlock() + + c := strings.Index(op.collection, ".") + if c < 0 { + return 0, errors.New("Bad collection name: " + op.collection) + } + + dbname := op.collection[:c] + cname := op.collection[c+1:] + query := op.query + if query == nil { + query = bson.D{} + } + result := struct{ N int }{} + err = session.DB(dbname).Run(countCmd{cname, query, limit, op.skip}, &result) + return result.N, err +} + +// Count returns the total number of documents in the collection. +func (c *Collection) Count() (n int, err error) { + return c.Find(nil).Count() +} + +type distinctCmd struct { + Collection string "distinct" + Key string + Query interface{} ",omitempty" +} + +// Distinct returns a list of distinct values for the given key within +// the result set. The list of distinct values will be unmarshalled +// in the "values" key of the provided result parameter. +// +// For example: +// +// var result []int +// err := collection.Find(bson.M{"gender": "F"}).Distinct("age", &result) +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Aggregation +// +func (q *Query) Distinct(key string, result interface{}) error { + q.m.Lock() + session := q.session + op := q.op // Copy. + q.m.Unlock() + + c := strings.Index(op.collection, ".") + if c < 0 { + return errors.New("Bad collection name: " + op.collection) + } + + dbname := op.collection[:c] + cname := op.collection[c+1:] + + var doc struct{ Values bson.Raw } + err := session.DB(dbname).Run(distinctCmd{cname, key, op.query}, &doc) + if err != nil { + return err + } + return doc.Values.Unmarshal(result) +} + +type mapReduceCmd struct { + Collection string "mapreduce" + Map string ",omitempty" + Reduce string ",omitempty" + Finalize string ",omitempty" + Limit int32 ",omitempty" + Out interface{} + Query interface{} ",omitempty" + Sort interface{} ",omitempty" + Scope interface{} ",omitempty" + Verbose bool ",omitempty" +} + +type mapReduceResult struct { + Results bson.Raw + Result bson.Raw + TimeMillis int64 "timeMillis" + Counts struct{ Input, Emit, Output int } + Ok bool + Err string + Timing *MapReduceTime +} + +type MapReduce struct { + Map string // Map Javascript function code (required) + Reduce string // Reduce Javascript function code (required) + Finalize string // Finalize Javascript function code (optional) + Out interface{} // Output collection name or document. If nil, results are inlined into the result parameter. + Scope interface{} // Optional global scope for Javascript functions + Verbose bool +} + +type MapReduceInfo struct { + InputCount int // Number of documents mapped + EmitCount int // Number of times reduce called emit + OutputCount int // Number of documents in resulting collection + Database string // Output database, if results are not inlined + Collection string // Output collection, if results are not inlined + Time int64 // Time to run the job, in nanoseconds + VerboseTime *MapReduceTime // Only defined if Verbose was true +} + +type MapReduceTime struct { + Total int64 // Total time, in nanoseconds + Map int64 "mapTime" // Time within map function, in nanoseconds + EmitLoop int64 "emitLoop" // Time within the emit/map loop, in nanoseconds +} + +// MapReduce executes a map/reduce job for documents covered by the query. +// That kind of job is suitable for very flexible bulk aggregation of data +// performed at the server side via Javascript functions. +// +// Results from the job may be returned as a result of the query itself +// through the result parameter in case they'll certainly fit in memory +// and in a single document. If there's the possibility that the amount +// of data might be too large, results must be stored back in an alternative +// collection or even a separate database, by setting the Out field of the +// provided MapReduce job. In that case, provide nil as the result parameter. +// +// These are some of the ways to set Out: +// +// nil +// Inline results into the result parameter. +// +// bson.M{"replace": "mycollection"} +// The output will be inserted into a collection which replaces any +// existing collection with the same name. +// +// bson.M{"merge": "mycollection"} +// This option will merge new data into the old output collection. In +// other words, if the same key exists in both the result set and the +// old collection, the new key will overwrite the old one. +// +// bson.M{"reduce": "mycollection"} +// If documents exist for a given key in the result set and in the old +// collection, then a reduce operation (using the specified reduce +// function) will be performed on the two values and the result will be +// written to the output collection. If a finalize function was +// provided, this will be run after the reduce as well. +// +// bson.M{...., "db": "mydb"} +// Any of the above options can have the "db" key included for doing +// the respective action in a separate database. +// +// The following is a trivial example which will count the number of +// occurrences of a field named n on each document in a collection, and +// will return results inline: +// +// job := &mgo.MapReduce{ +// Map: "function() { emit(this.n, 1) }", +// Reduce: "function(key, values) { return Array.sum(values) }", +// } +// var result []struct { Id int "_id"; Value int } +// _, err := collection.Find(nil).MapReduce(job, &result) +// if err != nil { +// return err +// } +// for _, item := range result { +// fmt.Println(item.Value) +// } +// +// This function is compatible with MongoDB 1.7.4+. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/MapReduce +// +func (q *Query) MapReduce(job *MapReduce, result interface{}) (info *MapReduceInfo, err error) { + q.m.Lock() + session := q.session + op := q.op // Copy. + limit := q.limit + q.m.Unlock() + + c := strings.Index(op.collection, ".") + if c < 0 { + return nil, errors.New("Bad collection name: " + op.collection) + } + + dbname := op.collection[:c] + cname := op.collection[c+1:] + + cmd := mapReduceCmd{ + Collection: cname, + Map: job.Map, + Reduce: job.Reduce, + Finalize: job.Finalize, + Out: fixMROut(job.Out), + Scope: job.Scope, + Verbose: job.Verbose, + Query: op.query, + Sort: op.options.OrderBy, + Limit: limit, + } + + if cmd.Out == nil { + cmd.Out = bson.D{{"inline", 1}} + } + + var doc mapReduceResult + err = session.DB(dbname).Run(&cmd, &doc) + if err != nil { + return nil, err + } + if doc.Err != "" { + return nil, errors.New(doc.Err) + } + + info = &MapReduceInfo{ + InputCount: doc.Counts.Input, + EmitCount: doc.Counts.Emit, + OutputCount: doc.Counts.Output, + Time: doc.TimeMillis * 1e6, + } + + if doc.Result.Kind == 0x02 { + err = doc.Result.Unmarshal(&info.Collection) + info.Database = dbname + } else if doc.Result.Kind == 0x03 { + var v struct{ Collection, Db string } + err = doc.Result.Unmarshal(&v) + info.Collection = v.Collection + info.Database = v.Db + } + + if doc.Timing != nil { + info.VerboseTime = doc.Timing + info.VerboseTime.Total *= 1e6 + info.VerboseTime.Map *= 1e6 + info.VerboseTime.EmitLoop *= 1e6 + } + + if err != nil { + return nil, err + } + if result != nil { + return info, doc.Results.Unmarshal(result) + } + return info, nil +} + +// The "out" option in the MapReduce command must be ordered. This was +// found after the implementation was accepting maps for a long time, +// so rather than breaking the API, we'll fix the order if necessary. +// Details about the order requirement may be seen in MongoDB's code: +// +// http://goo.gl/L8jwJX +// +func fixMROut(out interface{}) interface{} { + outv := reflect.ValueOf(out) + if outv.Kind() != reflect.Map || outv.Type().Key() != reflect.TypeOf("") { + return out + } + outs := make(bson.D, outv.Len()) + + outTypeIndex := -1 + for i, k := range outv.MapKeys() { + ks := k.String() + outs[i].Name = ks + outs[i].Value = outv.MapIndex(k).Interface() + switch ks { + case "normal", "replace", "merge", "reduce", "inline": + outTypeIndex = i + } + } + if outTypeIndex > 0 { + outs[0], outs[outTypeIndex] = outs[outTypeIndex], outs[0] + } + return outs +} + +// Change holds fields for running a findAndModify MongoDB command via +// the Query.Apply method. +type Change struct { + Update interface{} // The update document + Upsert bool // Whether to insert in case the document isn't found + Remove bool // Whether to remove the document found rather than updating + ReturnNew bool // Should the modified document be returned rather than the old one +} + +type findModifyCmd struct { + Collection string "findAndModify" + Query, Update, Sort, Fields interface{} ",omitempty" + Upsert, Remove, New bool ",omitempty" +} + +type valueResult struct { + Value bson.Raw + LastError LastError "lastErrorObject" +} + +// Apply runs the findAndModify MongoDB command, which allows updating, upserting +// or removing a document matching a query and atomically returning either the old +// version (the default) or the new version of the document (when ReturnNew is true). +// If no objects are found Apply returns ErrNotFound. +// +// The Sort and Select query methods affect the result of Apply. In case +// multiple documents match the query, Sort enables selecting which document to +// act upon by ordering it first. Select enables retrieving only a selection +// of fields of the new or old document. +// +// This simple example increments a counter and prints its new value: +// +// change := mgo.Change{ +// Update: bson.M{"$inc": bson.M{"n": 1}}, +// ReturnNew: true, +// } +// info, err = col.Find(M{"_id": id}).Apply(change, &doc) +// fmt.Println(doc.N) +// +// This method depends on MongoDB >= 2.0 to work properly. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/findAndModify+Command +// http://www.mongodb.org/display/DOCS/Updating +// http://www.mongodb.org/display/DOCS/Atomic+Operations +// +func (q *Query) Apply(change Change, result interface{}) (info *ChangeInfo, err error) { + q.m.Lock() + session := q.session + op := q.op // Copy. + q.m.Unlock() + + c := strings.Index(op.collection, ".") + if c < 0 { + return nil, errors.New("bad collection name: " + op.collection) + } + + dbname := op.collection[:c] + cname := op.collection[c+1:] + + cmd := findModifyCmd{ + Collection: cname, + Update: change.Update, + Upsert: change.Upsert, + Remove: change.Remove, + New: change.ReturnNew, + Query: op.query, + Sort: op.options.OrderBy, + Fields: op.selector, + } + + session = session.Clone() + defer session.Close() + session.SetMode(Strong, false) + + var doc valueResult + err = session.DB(dbname).Run(&cmd, &doc) + if err != nil { + if qerr, ok := err.(*QueryError); ok && qerr.Message == "No matching object found" { + return nil, ErrNotFound + } + return nil, err + } + if doc.LastError.N == 0 { + return nil, ErrNotFound + } + if doc.Value.Kind != 0x0A { + err = doc.Value.Unmarshal(result) + if err != nil { + return nil, err + } + } + info = &ChangeInfo{} + lerr := &doc.LastError + if lerr.UpdatedExisting { + info.Updated = lerr.N + } else if change.Remove { + info.Removed = lerr.N + } else if change.Upsert { + info.UpsertedId = lerr.UpsertedId + } + return info, nil +} + +// The BuildInfo type encapsulates details about the running MongoDB server. +// +// Note that the VersionArray field was introduced in MongoDB 2.0+, but it is +// internally assembled from the Version information for previous versions. +// In both cases, VersionArray is guaranteed to have at least 4 entries. +type BuildInfo struct { + Version string + VersionArray []int `bson:"versionArray"` // On MongoDB 2.0+; assembled from Version otherwise + GitVersion string `bson:"gitVersion"` + OpenSSLVersion string `bson:"OpenSSLVersion"` + SysInfo string `bson:"sysInfo"` + Bits int + Debug bool + MaxObjectSize int `bson:"maxBsonObjectSize"` +} + +// VersionAtLeast returns whether the BuildInfo version is greater than or +// equal to the provided version number. If more than one number is +// provided, numbers will be considered as major, minor, and so on. +func (bi *BuildInfo) VersionAtLeast(version ...int) bool { + for i := range version { + if i == len(bi.VersionArray) { + return false + } + if bi.VersionArray[i] < version[i] { + return false + } + } + return true +} + +// BuildInfo retrieves the version and other details about the +// running MongoDB server. +func (s *Session) BuildInfo() (info BuildInfo, err error) { + err = s.Run(bson.D{{"buildInfo", "1"}}, &info) + if len(info.VersionArray) == 0 { + for _, a := range strings.Split(info.Version, ".") { + i, err := strconv.Atoi(a) + if err != nil { + break + } + info.VersionArray = append(info.VersionArray, i) + } + } + for len(info.VersionArray) < 4 { + info.VersionArray = append(info.VersionArray, 0) + } + if i := strings.IndexByte(info.GitVersion, ' '); i >= 0 { + // Strip off the " modules: enterprise" suffix. This is a _git version_. + // That information may be moved to another field if people need it. + info.GitVersion = info.GitVersion[:i] + } + return +} + +// --------------------------------------------------------------------------- +// Internal session handling helpers. + +func (s *Session) acquireSocket(slaveOk bool) (*mongoSocket, error) { + + // Read-only lock to check for previously reserved socket. + s.m.RLock() + if s.masterSocket != nil { + socket := s.masterSocket + socket.Acquire() + s.m.RUnlock() + return socket, nil + } + if s.slaveSocket != nil && s.slaveOk && slaveOk { + socket := s.slaveSocket + socket.Acquire() + s.m.RUnlock() + return socket, nil + } + s.m.RUnlock() + + // No go. We may have to request a new socket and change the session, + // so try again but with an exclusive lock now. + s.m.Lock() + defer s.m.Unlock() + + if s.masterSocket != nil { + s.masterSocket.Acquire() + return s.masterSocket, nil + } + if s.slaveSocket != nil && s.slaveOk && slaveOk { + s.slaveSocket.Acquire() + return s.slaveSocket, nil + } + + // Still not good. We need a new socket. + sock, err := s.cluster().AcquireSocket(slaveOk && s.slaveOk, s.syncTimeout, s.sockTimeout, s.queryConfig.op.serverTags, s.poolLimit) + if err != nil { + return nil, err + } + + // Authenticate the new socket. + if err = s.socketLogin(sock); err != nil { + sock.Release() + return nil, err + } + + // Keep track of the new socket, if necessary. + // Note that, as a special case, if the Eventual session was + // not refreshed (s.slaveSocket != nil), it means the developer + // asked to preserve an existing reserved socket, so we'll + // keep a master one around too before a Refresh happens. + if s.consistency != Eventual || s.slaveSocket != nil { + s.setSocket(sock) + } + + // Switch over a Monotonic session to the master. + if !slaveOk && s.consistency == Monotonic { + s.slaveOk = false + } + + return sock, nil +} + +// setSocket binds socket to this section. +func (s *Session) setSocket(socket *mongoSocket) { + info := socket.Acquire() + if info.Master { + if s.masterSocket != nil { + panic("setSocket(master) with existing master socket reserved") + } + s.masterSocket = socket + } else { + if s.slaveSocket != nil { + panic("setSocket(slave) with existing slave socket reserved") + } + s.slaveSocket = socket + } +} + +// unsetSocket releases any slave and/or master sockets reserved. +func (s *Session) unsetSocket() { + if s.masterSocket != nil { + s.masterSocket.Release() + } + if s.slaveSocket != nil { + s.slaveSocket.Release() + } + s.masterSocket = nil + s.slaveSocket = nil +} + +func (iter *Iter) replyFunc() replyFunc { + return func(err error, op *replyOp, docNum int, docData []byte) { + iter.m.Lock() + iter.docsToReceive-- + if err != nil { + iter.err = err + debugf("Iter %p received an error: %s", iter, err.Error()) + } else if docNum == -1 { + debugf("Iter %p received no documents (cursor=%d).", iter, op.cursorId) + if op != nil && op.cursorId != 0 { + // It's a tailable cursor. + iter.op.cursorId = op.cursorId + } else { + iter.err = ErrNotFound + } + } else { + rdocs := int(op.replyDocs) + if docNum == 0 { + iter.docsToReceive += rdocs - 1 + docsToProcess := iter.docData.Len() + rdocs + if iter.limit == 0 || int32(docsToProcess) < iter.limit { + iter.docsBeforeMore = docsToProcess - int(iter.prefetch*float64(rdocs)) + } else { + iter.docsBeforeMore = -1 + } + iter.op.cursorId = op.cursorId + } + // XXX Handle errors and flags. + debugf("Iter %p received reply document %d/%d (cursor=%d)", iter, docNum+1, rdocs, op.cursorId) + iter.docData.Push(docData) + } + iter.gotReply.Broadcast() + iter.m.Unlock() + } +} + +// writeQuery runs the given modifying operation, potentially followed up +// by a getLastError command in case the session is in safe mode. The +// LastError result is made available in lerr, and if lerr.Err is set it +// will also be returned as err. +func (c *Collection) writeQuery(op interface{}) (lerr *LastError, err error) { + s := c.Database.Session + dbname := c.Database.Name + socket, err := s.acquireSocket(dbname == "local") + if err != nil { + return nil, err + } + defer socket.Release() + + s.m.RLock() + safeOp := s.safeOp + s.m.RUnlock() + + if safeOp == nil { + return nil, socket.Query(op) + } else { + var mutex sync.Mutex + var replyData []byte + var replyErr error + mutex.Lock() + query := *safeOp // Copy the data. + query.collection = dbname + ".$cmd" + query.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) { + replyData = docData + replyErr = err + mutex.Unlock() + } + err = socket.Query(op, &query) + if err != nil { + return nil, err + } + mutex.Lock() // Wait. + if replyErr != nil { + return nil, replyErr // XXX TESTME + } + if hasErrMsg(replyData) { + // Looks like getLastError itself failed. + err = checkQueryError(query.collection, replyData) + if err != nil { + return nil, err + } + } + result := &LastError{} + bson.Unmarshal(replyData, &result) + debugf("Result from writing query: %#v", result) + if result.Err != "" { + return result, result + } + return result, nil + } + panic("unreachable") +} + +func hasErrMsg(d []byte) bool { + l := len(d) + for i := 0; i+8 < l; i++ { + if d[i] == '\x02' && d[i+1] == 'e' && d[i+2] == 'r' && d[i+3] == 'r' && d[i+4] == 'm' && d[i+5] == 's' && d[i+6] == 'g' && d[i+7] == '\x00' { + return true + } + } + return false +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/session_test.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/session_test.go new file mode 100644 index 0000000..87ae849 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/session_test.go @@ -0,0 +1,3484 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo_test + +import ( + "flag" + "fmt" + "math" + "reflect" + "runtime" + "sort" + "strconv" + "strings" + "time" + + . "gopkg.in/check.v1" + "gopkg.in/mgo.v2" + "gopkg.in/mgo.v2/bson" +) + +func (s *S) TestRunString(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + result := struct{ Ok int }{} + err = session.Run("ping", &result) + c.Assert(err, IsNil) + c.Assert(result.Ok, Equals, 1) +} + +func (s *S) TestRunValue(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + result := struct{ Ok int }{} + err = session.Run(M{"ping": 1}, &result) + c.Assert(err, IsNil) + c.Assert(result.Ok, Equals, 1) +} + +func (s *S) TestPing(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + // Just ensure the nonce has been received. + result := struct{}{} + err = session.Run("ping", &result) + + mgo.ResetStats() + + err = session.Ping() + c.Assert(err, IsNil) + + // Pretty boring. + stats := mgo.GetStats() + c.Assert(stats.SentOps, Equals, 1) + c.Assert(stats.ReceivedOps, Equals, 1) +} + +func (s *S) TestURLSingle(c *C) { + session, err := mgo.Dial("mongodb://localhost:40001/") + c.Assert(err, IsNil) + defer session.Close() + + result := struct{ Ok int }{} + err = session.Run("ping", &result) + c.Assert(err, IsNil) + c.Assert(result.Ok, Equals, 1) +} + +func (s *S) TestURLMany(c *C) { + session, err := mgo.Dial("mongodb://localhost:40011,localhost:40012/") + c.Assert(err, IsNil) + defer session.Close() + + result := struct{ Ok int }{} + err = session.Run("ping", &result) + c.Assert(err, IsNil) + c.Assert(result.Ok, Equals, 1) +} + +func (s *S) TestURLParsing(c *C) { + urls := []string{ + "localhost:40001?foo=1&bar=2", + "localhost:40001?foo=1;bar=2", + } + for _, url := range urls { + session, err := mgo.Dial(url) + if session != nil { + session.Close() + } + c.Assert(err, ErrorMatches, "unsupported connection URL option: (foo=1|bar=2)") + } +} + +func (s *S) TestInsertFindOne(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1, "b": 2}) + c.Assert(err, IsNil) + err = coll.Insert(M{"a": 1, "b": 3}) + c.Assert(err, IsNil) + + result := struct{ A, B int }{} + + err = coll.Find(M{"a": 1}).Sort("b").One(&result) + c.Assert(err, IsNil) + c.Assert(result.A, Equals, 1) + c.Assert(result.B, Equals, 2) +} + +func (s *S) TestInsertFindOneNil(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Find(nil).One(nil) + c.Assert(err, ErrorMatches, "unauthorized.*|not authorized.*") +} + +func (s *S) TestInsertFindOneMap(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1, "b": 2}) + c.Assert(err, IsNil) + result := make(M) + err = coll.Find(M{"a": 1}).One(result) + c.Assert(err, IsNil) + c.Assert(result["a"], Equals, 1) + c.Assert(result["b"], Equals, 2) +} + +func (s *S) TestInsertFindAll(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1, "b": 2}) + c.Assert(err, IsNil) + err = coll.Insert(M{"a": 3, "b": 4}) + c.Assert(err, IsNil) + + type R struct{ A, B int } + var result []R + + assertResult := func() { + c.Assert(len(result), Equals, 2) + c.Assert(result[0].A, Equals, 1) + c.Assert(result[0].B, Equals, 2) + c.Assert(result[1].A, Equals, 3) + c.Assert(result[1].B, Equals, 4) + } + + // nil slice + err = coll.Find(nil).Sort("a").All(&result) + c.Assert(err, IsNil) + assertResult() + + // Previously allocated slice + allocd := make([]R, 5) + result = allocd + err = coll.Find(nil).Sort("a").All(&result) + c.Assert(err, IsNil) + assertResult() + + // Ensure result is backed by the originally allocated array + c.Assert(&result[0], Equals, &allocd[0]) + + // Non-pointer slice error + f := func() { coll.Find(nil).All(result) } + c.Assert(f, Panics, "result argument must be a slice address") + + // Non-slice error + f = func() { coll.Find(nil).All(new(int)) } + c.Assert(f, Panics, "result argument must be a slice address") +} + +func (s *S) TestFindRef(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + db1 := session.DB("db1") + db1col1 := db1.C("col1") + + db2 := session.DB("db2") + db2col1 := db2.C("col1") + + err = db1col1.Insert(M{"_id": 1, "n": 1}) + c.Assert(err, IsNil) + err = db1col1.Insert(M{"_id": 2, "n": 2}) + c.Assert(err, IsNil) + err = db2col1.Insert(M{"_id": 2, "n": 3}) + c.Assert(err, IsNil) + + result := struct{ N int }{} + + ref1 := &mgo.DBRef{Collection: "col1", Id: 1} + ref2 := &mgo.DBRef{Collection: "col1", Id: 2, Database: "db2"} + + err = db1.FindRef(ref1).One(&result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 1) + + err = db1.FindRef(ref2).One(&result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 3) + + err = db2.FindRef(ref1).One(&result) + c.Assert(err, Equals, mgo.ErrNotFound) + + err = db2.FindRef(ref2).One(&result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 3) + + err = session.FindRef(ref2).One(&result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 3) + + f := func() { session.FindRef(ref1).One(&result) } + c.Assert(f, PanicMatches, "Can't resolve database for &mgo.DBRef{Collection:\"col1\", Id:1, Database:\"\"}") +} + +func (s *S) TestDatabaseAndCollectionNames(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + db1 := session.DB("db1") + db1col1 := db1.C("col1") + db1col2 := db1.C("col2") + + db2 := session.DB("db2") + db2col1 := db2.C("col3") + + err = db1col1.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + err = db1col2.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + err = db2col1.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + + names, err := session.DatabaseNames() + c.Assert(err, IsNil) + if !reflect.DeepEqual(names, []string{"db1", "db2"}) { + // 2.4+ has "local" as well. + c.Assert(names, DeepEquals, []string{"db1", "db2", "local"}) + } + + names, err = db1.CollectionNames() + c.Assert(err, IsNil) + c.Assert(names, DeepEquals, []string{"col1", "col2", "system.indexes"}) + + names, err = db2.CollectionNames() + c.Assert(err, IsNil) + c.Assert(names, DeepEquals, []string{"col3", "system.indexes"}) +} + +func (s *S) TestSelect(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + coll.Insert(M{"a": 1, "b": 2}) + + result := struct{ A, B int }{} + + err = coll.Find(M{"a": 1}).Select(M{"b": 1}).One(&result) + c.Assert(err, IsNil) + c.Assert(result.A, Equals, 0) + c.Assert(result.B, Equals, 2) +} + +func (s *S) TestInlineMap(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + var v, result1 struct { + A int + M map[string]int ",inline" + } + + v.A = 1 + v.M = map[string]int{"b": 2} + err = coll.Insert(v) + c.Assert(err, IsNil) + + noId := M{"_id": 0} + + err = coll.Find(nil).Select(noId).One(&result1) + c.Assert(err, IsNil) + c.Assert(result1.A, Equals, 1) + c.Assert(result1.M, DeepEquals, map[string]int{"b": 2}) + + var result2 M + err = coll.Find(nil).Select(noId).One(&result2) + c.Assert(err, IsNil) + c.Assert(result2, DeepEquals, M{"a": 1, "b": 2}) + +} + +func (s *S) TestUpdate(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err := coll.Insert(M{"k": n, "n": n}) + c.Assert(err, IsNil) + } + + err = coll.Update(M{"k": 42}, M{"$inc": M{"n": 1}}) + c.Assert(err, IsNil) + + result := make(M) + err = coll.Find(M{"k": 42}).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 43) + + err = coll.Update(M{"k": 47}, M{"k": 47, "n": 47}) + c.Assert(err, Equals, mgo.ErrNotFound) + + err = coll.Find(M{"k": 47}).One(result) + c.Assert(err, Equals, mgo.ErrNotFound) +} + +func (s *S) TestUpdateId(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err := coll.Insert(M{"_id": n, "n": n}) + c.Assert(err, IsNil) + } + + err = coll.UpdateId(42, M{"$inc": M{"n": 1}}) + c.Assert(err, IsNil) + + result := make(M) + err = coll.FindId(42).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 43) + + err = coll.UpdateId(47, M{"k": 47, "n": 47}) + c.Assert(err, Equals, mgo.ErrNotFound) + + err = coll.FindId(47).One(result) + c.Assert(err, Equals, mgo.ErrNotFound) +} + +func (s *S) TestUpdateNil(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.Insert(M{"k": 42, "n": 42}) + c.Assert(err, IsNil) + err = coll.Update(nil, M{"$inc": M{"n": 1}}) + c.Assert(err, IsNil) + + result := make(M) + err = coll.Find(M{"k": 42}).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 43) + + err = coll.Insert(M{"k": 45, "n": 45}) + c.Assert(err, IsNil) + _, err = coll.UpdateAll(nil, M{"$inc": M{"n": 1}}) + c.Assert(err, IsNil) + + err = coll.Find(M{"k": 42}).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 44) + err = coll.Find(M{"k": 45}).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 46) + +} + +func (s *S) TestUpsert(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err := coll.Insert(M{"k": n, "n": n}) + c.Assert(err, IsNil) + } + + info, err := coll.Upsert(M{"k": 42}, M{"k": 42, "n": 24}) + c.Assert(err, IsNil) + c.Assert(info.Updated, Equals, 1) + c.Assert(info.UpsertedId, IsNil) + + result := M{} + err = coll.Find(M{"k": 42}).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 24) + + // Insert with internally created id. + info, err = coll.Upsert(M{"k": 47}, M{"k": 47, "n": 47}) + c.Assert(err, IsNil) + c.Assert(info.Updated, Equals, 0) + c.Assert(info.UpsertedId, NotNil) + + err = coll.Find(M{"k": 47}).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 47) + + result = M{} + err = coll.Find(M{"_id": info.UpsertedId}).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 47) + + // Insert with provided id. + info, err = coll.Upsert(M{"k": 48}, M{"k": 48, "n": 48, "_id": 48}) + c.Assert(err, IsNil) + c.Assert(info.Updated, Equals, 0) + if s.versionAtLeast(2, 6) { + c.Assert(info.UpsertedId, Equals, 48) + } else { + c.Assert(info.UpsertedId, IsNil) // Unfortunate, but that's what Mongo gave us. + } + + err = coll.Find(M{"k": 48}).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 48) +} + +func (s *S) TestUpsertId(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err := coll.Insert(M{"_id": n, "n": n}) + c.Assert(err, IsNil) + } + + info, err := coll.UpsertId(42, M{"n": 24}) + c.Assert(err, IsNil) + c.Assert(info.Updated, Equals, 1) + c.Assert(info.UpsertedId, IsNil) + + result := M{} + err = coll.FindId(42).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 24) + + info, err = coll.UpsertId(47, M{"_id": 47, "n": 47}) + c.Assert(err, IsNil) + c.Assert(info.Updated, Equals, 0) + if s.versionAtLeast(2, 6) { + c.Assert(info.UpsertedId, Equals, 47) + } else { + c.Assert(info.UpsertedId, IsNil) + } + + err = coll.FindId(47).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 47) +} + +func (s *S) TestUpdateAll(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err := coll.Insert(M{"k": n, "n": n}) + c.Assert(err, IsNil) + } + + info, err := coll.UpdateAll(M{"k": M{"$gt": 42}}, M{"$inc": M{"n": 1}}) + c.Assert(err, IsNil) + c.Assert(info.Updated, Equals, 4) + + result := make(M) + err = coll.Find(M{"k": 42}).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 42) + + err = coll.Find(M{"k": 43}).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 44) + + err = coll.Find(M{"k": 44}).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 45) + + if !s.versionAtLeast(2, 6) { + // 2.6 made this invalid. + info, err = coll.UpdateAll(M{"k": 47}, M{"k": 47, "n": 47}) + c.Assert(err, Equals, nil) + c.Assert(info.Updated, Equals, 0) + } +} + +func (s *S) TestRemove(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + err = coll.Remove(M{"n": M{"$gt": 42}}) + c.Assert(err, IsNil) + + result := &struct{ N int }{} + err = coll.Find(M{"n": 42}).One(result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 42) + + err = coll.Find(M{"n": 43}).One(result) + c.Assert(err, Equals, mgo.ErrNotFound) + + err = coll.Find(M{"n": 44}).One(result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 44) +} + +func (s *S) TestRemoveId(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.Insert(M{"_id": 40}, M{"_id": 41}, M{"_id": 42}) + c.Assert(err, IsNil) + + err = coll.RemoveId(41) + c.Assert(err, IsNil) + + c.Assert(coll.FindId(40).One(nil), IsNil) + c.Assert(coll.FindId(41).One(nil), Equals, mgo.ErrNotFound) + c.Assert(coll.FindId(42).One(nil), IsNil) +} + +func (s *S) TestRemoveAll(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + info, err := coll.RemoveAll(M{"n": M{"$gt": 42}}) + c.Assert(err, IsNil) + c.Assert(info.Updated, Equals, 0) + c.Assert(info.Removed, Equals, 4) + c.Assert(info.UpsertedId, IsNil) + + result := &struct{ N int }{} + err = coll.Find(M{"n": 42}).One(result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 42) + + err = coll.Find(M{"n": 43}).One(result) + c.Assert(err, Equals, mgo.ErrNotFound) + + err = coll.Find(M{"n": 44}).One(result) + c.Assert(err, Equals, mgo.ErrNotFound) +} + +func (s *S) TestDropDatabase(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + db1 := session.DB("db1") + db1.C("col").Insert(M{"_id": 1}) + + db2 := session.DB("db2") + db2.C("col").Insert(M{"_id": 1}) + + err = db1.DropDatabase() + c.Assert(err, IsNil) + + names, err := session.DatabaseNames() + c.Assert(err, IsNil) + if !reflect.DeepEqual(names, []string{"db2"}) { + // 2.4+ has "local" as well. + c.Assert(names, DeepEquals, []string{"db2", "local"}) + } + + err = db2.DropDatabase() + c.Assert(err, IsNil) + + names, err = session.DatabaseNames() + c.Assert(err, IsNil) + if !reflect.DeepEqual(names, []string(nil)) { + // 2.4+ has "local" as well. + c.Assert(names, DeepEquals, []string{"local"}) + } +} + +func (s *S) TestDropCollection(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("db1") + db.C("col1").Insert(M{"_id": 1}) + db.C("col2").Insert(M{"_id": 1}) + + err = db.C("col1").DropCollection() + c.Assert(err, IsNil) + + names, err := db.CollectionNames() + c.Assert(err, IsNil) + c.Assert(names, DeepEquals, []string{"col2", "system.indexes"}) + + err = db.C("col2").DropCollection() + c.Assert(err, IsNil) + + names, err = db.CollectionNames() + c.Assert(err, IsNil) + c.Assert(names, DeepEquals, []string{"system.indexes"}) +} + +func (s *S) TestCreateCollectionCapped(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + info := &mgo.CollectionInfo{ + Capped: true, + MaxBytes: 1024, + MaxDocs: 3, + } + err = coll.Create(info) + c.Assert(err, IsNil) + + ns := []int{1, 2, 3, 4, 5} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + n, err := coll.Find(nil).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 3) +} + +func (s *S) TestCreateCollectionNoIndex(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + info := &mgo.CollectionInfo{ + DisableIdIndex: true, + } + err = coll.Create(info) + c.Assert(err, IsNil) + + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) + + indexes, err := coll.Indexes() + c.Assert(indexes, HasLen, 0) +} + +func (s *S) TestCreateCollectionForceIndex(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + info := &mgo.CollectionInfo{ + ForceIdIndex: true, + Capped: true, + MaxBytes: 1024, + } + err = coll.Create(info) + c.Assert(err, IsNil) + + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) + + indexes, err := coll.Indexes() + c.Assert(indexes, HasLen, 1) +} + +func (s *S) TestIsDupValues(c *C) { + c.Assert(mgo.IsDup(nil), Equals, false) + c.Assert(mgo.IsDup(&mgo.LastError{Code: 1}), Equals, false) + c.Assert(mgo.IsDup(&mgo.QueryError{Code: 1}), Equals, false) + c.Assert(mgo.IsDup(&mgo.LastError{Code: 11000}), Equals, true) + c.Assert(mgo.IsDup(&mgo.QueryError{Code: 11000}), Equals, true) + c.Assert(mgo.IsDup(&mgo.LastError{Code: 11001}), Equals, true) + c.Assert(mgo.IsDup(&mgo.QueryError{Code: 11001}), Equals, true) + c.Assert(mgo.IsDup(&mgo.LastError{Code: 12582}), Equals, true) + c.Assert(mgo.IsDup(&mgo.QueryError{Code: 12582}), Equals, true) + lerr := &mgo.LastError{Code: 16460, Err: "error inserting 1 documents to shard ... caused by :: E11000 duplicate key error index: ..."} + c.Assert(mgo.IsDup(lerr), Equals, true) +} + +func (s *S) TestIsDupPrimary(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + err = coll.Insert(M{"_id": 1}) + c.Assert(err, ErrorMatches, ".*duplicate key error.*") + c.Assert(mgo.IsDup(err), Equals, true) +} + +func (s *S) TestIsDupUnique(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + index := mgo.Index{ + Key: []string{"a", "b"}, + Unique: true, + } + + coll := session.DB("mydb").C("mycoll") + + err = coll.EnsureIndex(index) + c.Assert(err, IsNil) + + err = coll.Insert(M{"a": 1, "b": 1}) + c.Assert(err, IsNil) + err = coll.Insert(M{"a": 1, "b": 1}) + c.Assert(err, ErrorMatches, ".*duplicate key error.*") + c.Assert(mgo.IsDup(err), Equals, true) +} + +func (s *S) TestIsDupCapped(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + info := &mgo.CollectionInfo{ + ForceIdIndex: true, + Capped: true, + MaxBytes: 1024, + } + err = coll.Create(info) + c.Assert(err, IsNil) + + err = coll.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + err = coll.Insert(M{"_id": 1}) + // The error was different for capped collections before 2.6. + c.Assert(err, ErrorMatches, ".*duplicate key.*") + // The issue is reduced by using IsDup. + c.Assert(mgo.IsDup(err), Equals, true) +} + +func (s *S) TestIsDupFindAndModify(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.EnsureIndex(mgo.Index{Key: []string{"n"}, Unique: true}) + c.Assert(err, IsNil) + + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) + err = coll.Insert(M{"n": 2}) + c.Assert(err, IsNil) + _, err = coll.Find(M{"n": 1}).Apply(mgo.Change{Update: M{"$inc": M{"n": 1}}}, bson.M{}) + c.Assert(err, ErrorMatches, ".*duplicate key error.*") + c.Assert(mgo.IsDup(err), Equals, true) +} + +func (s *S) TestFindAndModify(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.Insert(M{"n": 42}) + + session.SetMode(mgo.Monotonic, true) + + result := M{} + info, err := coll.Find(M{"n": 42}).Apply(mgo.Change{Update: M{"$inc": M{"n": 1}}}, result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 42) + c.Assert(info.Updated, Equals, 1) + c.Assert(info.Removed, Equals, 0) + c.Assert(info.UpsertedId, IsNil) + + result = M{} + info, err = coll.Find(M{"n": 43}).Apply(mgo.Change{Update: M{"$inc": M{"n": 1}}, ReturnNew: true}, result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 44) + c.Assert(info.Updated, Equals, 1) + c.Assert(info.Removed, Equals, 0) + c.Assert(info.UpsertedId, IsNil) + + result = M{} + info, err = coll.Find(M{"n": 50}).Apply(mgo.Change{Upsert: true, Update: M{"n": 51, "o": 52}}, result) + c.Assert(err, IsNil) + c.Assert(result["n"], IsNil) + c.Assert(info.Updated, Equals, 0) + c.Assert(info.Removed, Equals, 0) + c.Assert(info.UpsertedId, NotNil) + + result = M{} + info, err = coll.Find(nil).Sort("-n").Apply(mgo.Change{Update: M{"$inc": M{"n": 1}}, ReturnNew: true}, result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 52) + c.Assert(info.Updated, Equals, 1) + c.Assert(info.Removed, Equals, 0) + c.Assert(info.UpsertedId, IsNil) + + result = M{} + info, err = coll.Find(M{"n": 52}).Select(M{"o": 1}).Apply(mgo.Change{Remove: true}, result) + c.Assert(err, IsNil) + c.Assert(result["n"], IsNil) + c.Assert(result["o"], Equals, 52) + c.Assert(info.Updated, Equals, 0) + c.Assert(info.Removed, Equals, 1) + c.Assert(info.UpsertedId, IsNil) + + result = M{} + info, err = coll.Find(M{"n": 60}).Apply(mgo.Change{Remove: true}, result) + c.Assert(err, Equals, mgo.ErrNotFound) + c.Assert(len(result), Equals, 0) + c.Assert(info, IsNil) +} + +func (s *S) TestFindAndModifyBug997828(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.Insert(M{"n": "not-a-number"}) + + result := make(M) + _, err = coll.Find(M{"n": "not-a-number"}).Apply(mgo.Change{Update: M{"$inc": M{"n": 1}}}, result) + c.Assert(err, ErrorMatches, `(exception: )?Cannot apply \$inc .*`) + if s.versionAtLeast(2, 1) { + qerr, _ := err.(*mgo.QueryError) + c.Assert(qerr, NotNil, Commentf("err: %#v", err)) + if s.versionAtLeast(2, 6) { + // Oh, the dance of error codes. :-( + c.Assert(qerr.Code, Equals, 16837) + } else { + c.Assert(qerr.Code, Equals, 10140) + } + } else { + lerr, _ := err.(*mgo.LastError) + c.Assert(lerr, NotNil, Commentf("err: %#v", err)) + c.Assert(lerr.Code, Equals, 10140) + } +} + +func (s *S) TestCountCollection(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + n, err := coll.Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 3) +} + +func (s *S) TestCountQuery(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + n, err := coll.Find(M{"n": M{"$gt": 40}}).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 2) +} + +func (s *S) TestCountQuerySorted(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + n, err := coll.Find(M{"n": M{"$gt": 40}}).Sort("n").Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 2) +} + +func (s *S) TestCountSkipLimit(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + n, err := coll.Find(nil).Skip(1).Limit(3).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 3) + + n, err = coll.Find(nil).Skip(1).Limit(5).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 4) +} + +func (s *S) TestQueryExplain(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + m := M{} + query := coll.Find(nil).Limit(2) + err = query.Explain(m) + c.Assert(err, IsNil) + c.Assert(m["cursor"], Equals, "BasicCursor") + c.Assert(m["nscanned"], Equals, 2) + c.Assert(m["n"], Equals, 2) + + n := 0 + var result M + iter := query.Iter() + for iter.Next(&result) { + n++ + } + c.Assert(iter.Close(), IsNil) + c.Assert(n, Equals, 2) +} + +func (s *S) TestQueryMaxScan(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + query := coll.Find(nil).SetMaxScan(2) + var result []M + err = query.All(&result) + c.Assert(err, IsNil) + c.Assert(result, HasLen, 2) +} + +func (s *S) TestQueryHint(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + coll.EnsureIndexKey("a") + + m := M{} + err = coll.Find(nil).Hint("a").Explain(m) + c.Assert(err, IsNil) + c.Assert(m["indexBounds"], NotNil) + c.Assert(m["indexBounds"].(M)["a"], NotNil) +} + +func (s *S) TestFindOneNotFound(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + result := struct{ A, B int }{} + err = coll.Find(M{"a": 1}).One(&result) + c.Assert(err, Equals, mgo.ErrNotFound) + c.Assert(err, ErrorMatches, "not found") + c.Assert(err == mgo.ErrNotFound, Equals, true) +} + +func (s *S) TestFindNil(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) + + result := struct{ N int }{} + + err = coll.Find(nil).One(&result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 1) +} + +func (s *S) TestFindId(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"_id": 41, "n": 41}) + c.Assert(err, IsNil) + err = coll.Insert(M{"_id": 42, "n": 42}) + c.Assert(err, IsNil) + + result := struct{ N int }{} + + err = coll.FindId(42).One(&result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 42) +} + +func (s *S) TestFindIterAll(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + session.Refresh() // Release socket. + + mgo.ResetStats() + + iter := coll.Find(M{"n": M{"$gte": 42}}).Sort("$natural").Prefetch(0).Batch(2).Iter() + result := struct{ N int }{} + for i := 2; i < 7; i++ { + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(result.N, Equals, ns[i]) + if i == 1 { + stats := mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, 2) + } + } + + ok := iter.Next(&result) + c.Assert(ok, Equals, false) + c.Assert(iter.Close(), IsNil) + + session.Refresh() // Release socket. + + stats := mgo.GetStats() + c.Assert(stats.SentOps, Equals, 3) // 1*QUERY_OP + 2*GET_MORE_OP + c.Assert(stats.ReceivedOps, Equals, 3) // and their REPLY_OPs. + c.Assert(stats.ReceivedDocs, Equals, 5) + c.Assert(stats.SocketsInUse, Equals, 0) +} + +func (s *S) TestFindIterTwiceWithSameQuery(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for i := 40; i != 47; i++ { + coll.Insert(M{"n": i}) + } + + query := coll.Find(M{}).Sort("n") + + result1 := query.Skip(1).Iter() + result2 := query.Skip(2).Iter() + + result := struct{ N int }{} + ok := result2.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(result.N, Equals, 42) + ok = result1.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(result.N, Equals, 41) +} + +func (s *S) TestFindIterWithoutResults(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + coll.Insert(M{"n": 42}) + + iter := coll.Find(M{"n": 0}).Iter() + + result := struct{ N int }{} + ok := iter.Next(&result) + c.Assert(ok, Equals, false) + c.Assert(iter.Close(), IsNil) + c.Assert(result.N, Equals, 0) +} + +func (s *S) TestFindIterLimit(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + session.Refresh() // Release socket. + + mgo.ResetStats() + + query := coll.Find(M{"n": M{"$gte": 42}}).Sort("$natural").Limit(3) + iter := query.Iter() + + result := struct{ N int }{} + for i := 2; i < 5; i++ { + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(result.N, Equals, ns[i]) + } + + ok := iter.Next(&result) + c.Assert(ok, Equals, false) + c.Assert(iter.Close(), IsNil) + + session.Refresh() // Release socket. + + stats := mgo.GetStats() + c.Assert(stats.SentOps, Equals, 2) // 1*QUERY_OP + 1*KILL_CURSORS_OP + c.Assert(stats.ReceivedOps, Equals, 1) // and its REPLY_OP + c.Assert(stats.ReceivedDocs, Equals, 3) + c.Assert(stats.SocketsInUse, Equals, 0) +} + +func (s *S) TestTooManyItemsLimitBug(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(runtime.NumCPU())) + + mgo.SetDebug(false) + coll := session.DB("mydb").C("mycoll") + words := strings.Split("foo bar baz", " ") + for i := 0; i < 5; i++ { + words = append(words, words...) + } + doc := bson.D{{"words", words}} + inserts := 10000 + limit := 5000 + iters := 0 + c.Assert(inserts > limit, Equals, true) + for i := 0; i < inserts; i++ { + err := coll.Insert(&doc) + c.Assert(err, IsNil) + } + iter := coll.Find(nil).Limit(limit).Iter() + for iter.Next(&doc) { + if iters%100 == 0 { + c.Logf("Seen %d docments", iters) + } + iters++ + } + c.Assert(iter.Close(), IsNil) + c.Assert(iters, Equals, limit) +} + +func serverCursorsOpen(session *mgo.Session) int { + var result struct { + Cursors struct { + TotalOpen int `bson:"totalOpen"` + TimedOut int `bson:"timedOut"` + } + } + err := session.Run("serverStatus", &result) + if err != nil { + panic(err) + } + return result.Cursors.TotalOpen +} + +func (s *S) TestFindIterLimitWithMore(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + // Insane amounts of logging otherwise due to the + // amount of data being shuffled. + mgo.SetDebug(false) + defer mgo.SetDebug(true) + + // Should amount to more than 4MB bson payload, + // the default limit per result chunk. + const total = 4096 + var d struct{ A [1024]byte } + docs := make([]interface{}, total) + for i := 0; i < total; i++ { + docs[i] = &d + } + err = coll.Insert(docs...) + c.Assert(err, IsNil) + + n, err := coll.Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, total) + + // First, try restricting to a single chunk with a negative limit. + nresults := 0 + iter := coll.Find(nil).Limit(-total).Iter() + var discard struct{} + for iter.Next(&discard) { + nresults++ + } + if nresults < total/2 || nresults >= total { + c.Fatalf("Bad result size with negative limit: %d", nresults) + } + + cursorsOpen := serverCursorsOpen(session) + + // Try again, with a positive limit. Should reach the end now, + // using multiple chunks. + nresults = 0 + iter = coll.Find(nil).Limit(total).Iter() + for iter.Next(&discard) { + nresults++ + } + c.Assert(nresults, Equals, total) + + // Ensure the cursor used is properly killed. + c.Assert(serverCursorsOpen(session), Equals, cursorsOpen) + + // Edge case, -MinInt == -MinInt. + nresults = 0 + iter = coll.Find(nil).Limit(math.MinInt32).Iter() + for iter.Next(&discard) { + nresults++ + } + if nresults < total/2 || nresults >= total { + c.Fatalf("Bad result size with MinInt32 limit: %d", nresults) + } +} + +func (s *S) TestFindIterLimitWithBatch(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + // Ping the database to ensure the nonce has been received already. + c.Assert(session.Ping(), IsNil) + + session.Refresh() // Release socket. + + mgo.ResetStats() + + query := coll.Find(M{"n": M{"$gte": 42}}).Sort("$natural").Limit(3).Batch(2) + iter := query.Iter() + result := struct{ N int }{} + for i := 2; i < 5; i++ { + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(result.N, Equals, ns[i]) + if i == 3 { + stats := mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, 2) + } + } + + ok := iter.Next(&result) + c.Assert(ok, Equals, false) + c.Assert(iter.Close(), IsNil) + + session.Refresh() // Release socket. + + stats := mgo.GetStats() + c.Assert(stats.SentOps, Equals, 3) // 1*QUERY_OP + 1*GET_MORE_OP + 1*KILL_CURSORS_OP + c.Assert(stats.ReceivedOps, Equals, 2) // and its REPLY_OPs + c.Assert(stats.ReceivedDocs, Equals, 3) + c.Assert(stats.SocketsInUse, Equals, 0) +} + +func (s *S) TestFindIterSortWithBatch(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + // Without this, the logic above breaks because Mongo refuses to + // return a cursor with an in-memory sort. + coll.EnsureIndexKey("n") + + // Ping the database to ensure the nonce has been received already. + c.Assert(session.Ping(), IsNil) + + session.Refresh() // Release socket. + + mgo.ResetStats() + + query := coll.Find(M{"n": M{"$lte": 44}}).Sort("-n").Batch(2) + iter := query.Iter() + ns = []int{46, 45, 44, 43, 42, 41, 40} + result := struct{ N int }{} + for i := 2; i < len(ns); i++ { + c.Logf("i=%d", i) + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(result.N, Equals, ns[i]) + if i == 3 { + stats := mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, 2) + } + } + + ok := iter.Next(&result) + c.Assert(ok, Equals, false) + c.Assert(iter.Close(), IsNil) + + session.Refresh() // Release socket. + + stats := mgo.GetStats() + c.Assert(stats.SentOps, Equals, 3) // 1*QUERY_OP + 2*GET_MORE_OP + c.Assert(stats.ReceivedOps, Equals, 3) // and its REPLY_OPs + c.Assert(stats.ReceivedDocs, Equals, 5) + c.Assert(stats.SocketsInUse, Equals, 0) +} + +// Test tailable cursors in a situation where Next has to sleep to +// respect the timeout requested on Tail. +func (s *S) TestFindTailTimeoutWithSleep(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + cresult := struct{ ErrMsg string }{} + + db := session.DB("mydb") + err = db.Run(bson.D{{"create", "mycoll"}, {"capped", true}, {"size", 1024}}, &cresult) + c.Assert(err, IsNil) + c.Assert(cresult.ErrMsg, Equals, "") + coll := db.C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + session.Refresh() // Release socket. + + mgo.ResetStats() + + timeout := 3 * time.Second + + query := coll.Find(M{"n": M{"$gte": 42}}).Sort("$natural").Prefetch(0).Batch(2) + iter := query.Tail(timeout) + + n := len(ns) + result := struct{ N int }{} + for i := 2; i != n; i++ { + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(iter.Err(), IsNil) + c.Assert(iter.Timeout(), Equals, false) + c.Assert(result.N, Equals, ns[i]) + if i == 3 { // The batch boundary. + stats := mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, 2) + } + } + + mgo.ResetStats() + + // The following call to Next will block. + go func() { + // The internal AwaitData timing of MongoDB is around 2 seconds, + // so this should force mgo to sleep at least once by itself to + // respect the requested timeout. + time.Sleep(timeout + 5e8*time.Nanosecond) + session := session.New() + defer session.Close() + coll := session.DB("mydb").C("mycoll") + coll.Insert(M{"n": 47}) + }() + + c.Log("Will wait for Next with N=47...") + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(iter.Err(), IsNil) + c.Assert(iter.Timeout(), Equals, false) + c.Assert(result.N, Equals, 47) + c.Log("Got Next with N=47!") + + // The following may break because it depends a bit on the internal + // timing used by MongoDB's AwaitData logic. If it does, the problem + // will be observed as more GET_MORE_OPs than predicted: + // 1*QUERY for nonce + 1*GET_MORE_OP on Next + 1*GET_MORE_OP on Next after sleep + + // 1*INSERT_OP + 1*QUERY_OP for getLastError on insert of 47 + stats := mgo.GetStats() + c.Assert(stats.SentOps, Equals, 5) + c.Assert(stats.ReceivedOps, Equals, 4) // REPLY_OPs for 1*QUERY_OP for nonce + 2*GET_MORE_OPs + 1*QUERY_OP + c.Assert(stats.ReceivedDocs, Equals, 3) // nonce + N=47 result + getLastError response + + c.Log("Will wait for a result which will never come...") + + started := time.Now() + ok = iter.Next(&result) + c.Assert(ok, Equals, false) + c.Assert(iter.Err(), IsNil) + c.Assert(iter.Timeout(), Equals, true) + c.Assert(started.Before(time.Now().Add(-timeout)), Equals, true) + + c.Log("Will now reuse the timed out tail cursor...") + + coll.Insert(M{"n": 48}) + ok = iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(iter.Close(), IsNil) + c.Assert(iter.Timeout(), Equals, false) + c.Assert(result.N, Equals, 48) +} + +// Test tailable cursors in a situation where Next never gets to sleep once +// to respect the timeout requested on Tail. +func (s *S) TestFindTailTimeoutNoSleep(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + cresult := struct{ ErrMsg string }{} + + db := session.DB("mydb") + err = db.Run(bson.D{{"create", "mycoll"}, {"capped", true}, {"size", 1024}}, &cresult) + c.Assert(err, IsNil) + c.Assert(cresult.ErrMsg, Equals, "") + coll := db.C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + session.Refresh() // Release socket. + + mgo.ResetStats() + + timeout := 1 * time.Second + + query := coll.Find(M{"n": M{"$gte": 42}}).Sort("$natural").Prefetch(0).Batch(2) + iter := query.Tail(timeout) + + n := len(ns) + result := struct{ N int }{} + for i := 2; i != n; i++ { + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(iter.Err(), IsNil) + c.Assert(iter.Timeout(), Equals, false) + c.Assert(result.N, Equals, ns[i]) + if i == 3 { // The batch boundary. + stats := mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, 2) + } + } + + mgo.ResetStats() + + // The following call to Next will block. + go func() { + // The internal AwaitData timing of MongoDB is around 2 seconds, + // so this item should arrive within the AwaitData threshold. + time.Sleep(5e8) + session := session.New() + defer session.Close() + coll := session.DB("mydb").C("mycoll") + coll.Insert(M{"n": 47}) + }() + + c.Log("Will wait for Next with N=47...") + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(iter.Err(), IsNil) + c.Assert(iter.Timeout(), Equals, false) + c.Assert(result.N, Equals, 47) + c.Log("Got Next with N=47!") + + // The following may break because it depends a bit on the internal + // timing used by MongoDB's AwaitData logic. If it does, the problem + // will be observed as more GET_MORE_OPs than predicted: + // 1*QUERY_OP for nonce + 1*GET_MORE_OP on Next + + // 1*INSERT_OP + 1*QUERY_OP for getLastError on insert of 47 + stats := mgo.GetStats() + c.Assert(stats.SentOps, Equals, 4) + c.Assert(stats.ReceivedOps, Equals, 3) // REPLY_OPs for 1*QUERY_OP for nonce + 1*GET_MORE_OPs and 1*QUERY_OP + c.Assert(stats.ReceivedDocs, Equals, 3) // nonce + N=47 result + getLastError response + + c.Log("Will wait for a result which will never come...") + + started := time.Now() + ok = iter.Next(&result) + c.Assert(ok, Equals, false) + c.Assert(iter.Err(), IsNil) + c.Assert(iter.Timeout(), Equals, true) + c.Assert(started.Before(time.Now().Add(-timeout)), Equals, true) + + c.Log("Will now reuse the timed out tail cursor...") + + coll.Insert(M{"n": 48}) + ok = iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(iter.Close(), IsNil) + c.Assert(iter.Timeout(), Equals, false) + c.Assert(result.N, Equals, 48) +} + +// Test tailable cursors in a situation where Next never gets to sleep once +// to respect the timeout requested on Tail. +func (s *S) TestFindTailNoTimeout(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + cresult := struct{ ErrMsg string }{} + + db := session.DB("mydb") + err = db.Run(bson.D{{"create", "mycoll"}, {"capped", true}, {"size", 1024}}, &cresult) + c.Assert(err, IsNil) + c.Assert(cresult.ErrMsg, Equals, "") + coll := db.C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + session.Refresh() // Release socket. + + mgo.ResetStats() + + query := coll.Find(M{"n": M{"$gte": 42}}).Sort("$natural").Prefetch(0).Batch(2) + iter := query.Tail(-1) + c.Assert(err, IsNil) + + n := len(ns) + result := struct{ N int }{} + for i := 2; i != n; i++ { + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(result.N, Equals, ns[i]) + if i == 3 { // The batch boundary. + stats := mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, 2) + } + } + + mgo.ResetStats() + + // The following call to Next will block. + go func() { + time.Sleep(5e8) + session := session.New() + defer session.Close() + coll := session.DB("mydb").C("mycoll") + coll.Insert(M{"n": 47}) + }() + + c.Log("Will wait for Next with N=47...") + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(iter.Err(), IsNil) + c.Assert(iter.Timeout(), Equals, false) + c.Assert(result.N, Equals, 47) + c.Log("Got Next with N=47!") + + // The following may break because it depends a bit on the internal + // timing used by MongoDB's AwaitData logic. If it does, the problem + // will be observed as more GET_MORE_OPs than predicted: + // 1*QUERY_OP for nonce + 1*GET_MORE_OP on Next + + // 1*INSERT_OP + 1*QUERY_OP for getLastError on insert of 47 + stats := mgo.GetStats() + c.Assert(stats.SentOps, Equals, 4) + c.Assert(stats.ReceivedOps, Equals, 3) // REPLY_OPs for 1*QUERY_OP for nonce + 1*GET_MORE_OPs and 1*QUERY_OP + c.Assert(stats.ReceivedDocs, Equals, 3) // nonce + N=47 result + getLastError response + + c.Log("Will wait for a result which will never come...") + + gotNext := make(chan bool) + go func() { + ok := iter.Next(&result) + gotNext <- ok + }() + + select { + case ok := <-gotNext: + c.Fatalf("Next returned: %v", ok) + case <-time.After(3e9): + // Good. Should still be sleeping at that point. + } + + // Closing the session should cause Next to return. + session.Close() + + select { + case ok := <-gotNext: + c.Assert(ok, Equals, false) + c.Assert(iter.Err(), ErrorMatches, "Closed explicitly") + c.Assert(iter.Timeout(), Equals, false) + case <-time.After(1e9): + c.Fatal("Closing the session did not unblock Next") + } +} + +func (s *S) TestIterNextResetsResult(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{1, 2, 3} + for _, n := range ns { + coll.Insert(M{"n" + strconv.Itoa(n): n}) + } + + query := coll.Find(nil).Sort("$natural") + + i := 0 + var sresult *struct{ N1, N2, N3 int } + iter := query.Iter() + for iter.Next(&sresult) { + switch i { + case 0: + c.Assert(sresult.N1, Equals, 1) + c.Assert(sresult.N2+sresult.N3, Equals, 0) + case 1: + c.Assert(sresult.N2, Equals, 2) + c.Assert(sresult.N1+sresult.N3, Equals, 0) + case 2: + c.Assert(sresult.N3, Equals, 3) + c.Assert(sresult.N1+sresult.N2, Equals, 0) + } + i++ + } + c.Assert(iter.Close(), IsNil) + + i = 0 + var mresult M + iter = query.Iter() + for iter.Next(&mresult) { + delete(mresult, "_id") + switch i { + case 0: + c.Assert(mresult, DeepEquals, M{"n1": 1}) + case 1: + c.Assert(mresult, DeepEquals, M{"n2": 2}) + case 2: + c.Assert(mresult, DeepEquals, M{"n3": 3}) + } + i++ + } + c.Assert(iter.Close(), IsNil) + + i = 0 + var iresult interface{} + iter = query.Iter() + for iter.Next(&iresult) { + mresult, ok := iresult.(bson.M) + c.Assert(ok, Equals, true, Commentf("%#v", iresult)) + delete(mresult, "_id") + switch i { + case 0: + c.Assert(mresult, DeepEquals, bson.M{"n1": 1}) + case 1: + c.Assert(mresult, DeepEquals, bson.M{"n2": 2}) + case 2: + c.Assert(mresult, DeepEquals, bson.M{"n3": 3}) + } + i++ + } + c.Assert(iter.Close(), IsNil) +} + +func (s *S) TestFindForOnIter(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + session.Refresh() // Release socket. + + mgo.ResetStats() + + query := coll.Find(M{"n": M{"$gte": 42}}).Sort("$natural").Prefetch(0).Batch(2) + iter := query.Iter() + + i := 2 + var result *struct{ N int } + err = iter.For(&result, func() error { + c.Assert(i < 7, Equals, true) + c.Assert(result.N, Equals, ns[i]) + if i == 1 { + stats := mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, 2) + } + i++ + return nil + }) + c.Assert(err, IsNil) + + session.Refresh() // Release socket. + + stats := mgo.GetStats() + c.Assert(stats.SentOps, Equals, 3) // 1*QUERY_OP + 2*GET_MORE_OP + c.Assert(stats.ReceivedOps, Equals, 3) // and their REPLY_OPs. + c.Assert(stats.ReceivedDocs, Equals, 5) + c.Assert(stats.SocketsInUse, Equals, 0) +} + +func (s *S) TestFindFor(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + session.Refresh() // Release socket. + + mgo.ResetStats() + + query := coll.Find(M{"n": M{"$gte": 42}}).Sort("$natural").Prefetch(0).Batch(2) + + i := 2 + var result *struct{ N int } + err = query.For(&result, func() error { + c.Assert(i < 7, Equals, true) + c.Assert(result.N, Equals, ns[i]) + if i == 1 { + stats := mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, 2) + } + i++ + return nil + }) + c.Assert(err, IsNil) + + session.Refresh() // Release socket. + + stats := mgo.GetStats() + c.Assert(stats.SentOps, Equals, 3) // 1*QUERY_OP + 2*GET_MORE_OP + c.Assert(stats.ReceivedOps, Equals, 3) // and their REPLY_OPs. + c.Assert(stats.ReceivedDocs, Equals, 5) + c.Assert(stats.SocketsInUse, Equals, 0) +} + +func (s *S) TestFindForStopOnError(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + query := coll.Find(M{"n": M{"$gte": 42}}) + i := 2 + var result *struct{ N int } + err = query.For(&result, func() error { + c.Assert(i < 4, Equals, true) + c.Assert(result.N, Equals, ns[i]) + if i == 3 { + return fmt.Errorf("stop!") + } + i++ + return nil + }) + c.Assert(err, ErrorMatches, "stop!") +} + +func (s *S) TestFindForResetsResult(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{1, 2, 3} + for _, n := range ns { + coll.Insert(M{"n" + strconv.Itoa(n): n}) + } + + query := coll.Find(nil).Sort("$natural") + + i := 0 + var sresult *struct{ N1, N2, N3 int } + err = query.For(&sresult, func() error { + switch i { + case 0: + c.Assert(sresult.N1, Equals, 1) + c.Assert(sresult.N2+sresult.N3, Equals, 0) + case 1: + c.Assert(sresult.N2, Equals, 2) + c.Assert(sresult.N1+sresult.N3, Equals, 0) + case 2: + c.Assert(sresult.N3, Equals, 3) + c.Assert(sresult.N1+sresult.N2, Equals, 0) + } + i++ + return nil + }) + c.Assert(err, IsNil) + + i = 0 + var mresult M + err = query.For(&mresult, func() error { + delete(mresult, "_id") + switch i { + case 0: + c.Assert(mresult, DeepEquals, M{"n1": 1}) + case 1: + c.Assert(mresult, DeepEquals, M{"n2": 2}) + case 2: + c.Assert(mresult, DeepEquals, M{"n3": 3}) + } + i++ + return nil + }) + c.Assert(err, IsNil) + + i = 0 + var iresult interface{} + err = query.For(&iresult, func() error { + mresult, ok := iresult.(bson.M) + c.Assert(ok, Equals, true, Commentf("%#v", iresult)) + delete(mresult, "_id") + switch i { + case 0: + c.Assert(mresult, DeepEquals, bson.M{"n1": 1}) + case 1: + c.Assert(mresult, DeepEquals, bson.M{"n2": 2}) + case 2: + c.Assert(mresult, DeepEquals, bson.M{"n3": 3}) + } + i++ + return nil + }) + c.Assert(err, IsNil) +} + +func (s *S) TestFindIterSnapshot(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + // Insane amounts of logging otherwise due to the + // amount of data being shuffled. + mgo.SetDebug(false) + defer mgo.SetDebug(true) + + coll := session.DB("mydb").C("mycoll") + + var a [1024000]byte + + for n := 0; n < 10; n++ { + err := coll.Insert(M{"_id": n, "n": n, "a1": &a}) + c.Assert(err, IsNil) + } + + query := coll.Find(M{"n": M{"$gt": -1}}).Batch(2).Prefetch(0) + query.Snapshot() + iter := query.Iter() + + seen := map[int]bool{} + result := struct { + Id int "_id" + }{} + for iter.Next(&result) { + if len(seen) == 2 { + // Grow all entries so that they have to move. + // Backwards so that the order is inverted. + for n := 10; n >= 0; n-- { + _, err := coll.Upsert(M{"_id": n}, M{"$set": M{"a2": &a}}) + c.Assert(err, IsNil) + } + } + if seen[result.Id] { + c.Fatalf("seen duplicated key: %d", result.Id) + } + seen[result.Id] = true + } + c.Assert(iter.Close(), IsNil) +} + +func (s *S) TestSort(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + coll.Insert(M{"a": 1, "b": 1}) + coll.Insert(M{"a": 2, "b": 2}) + coll.Insert(M{"a": 2, "b": 1}) + coll.Insert(M{"a": 0, "b": 1}) + coll.Insert(M{"a": 2, "b": 0}) + coll.Insert(M{"a": 0, "b": 2}) + coll.Insert(M{"a": 1, "b": 2}) + coll.Insert(M{"a": 0, "b": 0}) + coll.Insert(M{"a": 1, "b": 0}) + + query := coll.Find(M{}) + query.Sort("-a") // Should be ignored. + query.Sort("-b", "a") + iter := query.Iter() + + l := make([]int, 18) + r := struct{ A, B int }{} + for i := 0; i != len(l); i += 2 { + ok := iter.Next(&r) + c.Assert(ok, Equals, true) + c.Assert(err, IsNil) + l[i] = r.A + l[i+1] = r.B + } + + c.Assert(l, DeepEquals, []int{0, 2, 1, 2, 2, 2, 0, 1, 1, 1, 2, 1, 0, 0, 1, 0, 2, 0}) +} + +func (s *S) TestSortWithBadArgs(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + f1 := func() { coll.Find(nil).Sort("") } + f2 := func() { coll.Find(nil).Sort("+") } + f3 := func() { coll.Find(nil).Sort("foo", "-") } + + for _, f := range []func(){f1, f2, f3} { + c.Assert(f, PanicMatches, "Sort: empty field name") + } +} + +func (s *S) TestSortScoreText(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.EnsureIndex(mgo.Index{ + Key: []string{"$text:a", "$text:b"}, + }) + c.Assert(err, IsNil) + + err = coll.Insert(M{ + "a": "none", + "b": "twice: foo foo", + }) + c.Assert(err, IsNil) + err = coll.Insert(M{ + "a": "just once: foo", + "b": "none", + }) + c.Assert(err, IsNil) + err = coll.Insert(M{ + "a": "many: foo foo foo", + "b": "none", + }) + c.Assert(err, IsNil) + err = coll.Insert(M{ + "a": "none", + "b": "none", + "c": "ignore: foo", + }) + c.Assert(err, IsNil) + + query := coll.Find(M{"$text": M{"$search": "foo"}}) + query.Select(M{"score": M{"$meta": "textScore"}}) + query.Sort("$textScore:score") + iter := query.Iter() + + var r struct{ A, B string } + var results []string + for iter.Next(&r) { + results = append(results, r.A, r.B) + } + + c.Assert(results, DeepEquals, []string{ + "many: foo foo foo", "none", + "none", "twice: foo foo", + "just once: foo", "none", + }) +} + +func (s *S) TestPrefetching(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + mgo.SetDebug(false) + docs := make([]interface{}, 800) + for i := 0; i != 600; i++ { + docs[i] = bson.D{{"n", i}} + } + coll.Insert(docs...) + + for testi := 0; testi < 5; testi++ { + mgo.ResetStats() + + var iter *mgo.Iter + var beforeMore int + + switch testi { + case 0: // The default session value. + session.SetBatch(100) + iter = coll.Find(M{}).Iter() + beforeMore = 75 + + case 2: // Changing the session value. + session.SetBatch(100) + session.SetPrefetch(0.27) + iter = coll.Find(M{}).Iter() + beforeMore = 73 + + case 1: // Changing via query methods. + iter = coll.Find(M{}).Prefetch(0.27).Batch(100).Iter() + beforeMore = 73 + + case 3: // With prefetch on first document. + iter = coll.Find(M{}).Prefetch(1.0).Batch(100).Iter() + beforeMore = 0 + + case 4: // Without prefetch. + iter = coll.Find(M{}).Prefetch(0).Batch(100).Iter() + beforeMore = 100 + } + + pings := 0 + for batchi := 0; batchi < len(docs)/100-1; batchi++ { + c.Logf("Iterating over %d documents on batch %d", beforeMore, batchi) + var result struct{ N int } + for i := 0; i < beforeMore; i++ { + ok := iter.Next(&result) + c.Assert(ok, Equals, true, Commentf("iter.Err: %v", iter.Err())) + } + beforeMore = 99 + c.Logf("Done iterating.") + + session.Run("ping", nil) // Roundtrip to settle down. + pings++ + + stats := mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, (batchi+1)*100+pings) + + c.Logf("Iterating over one more document on batch %d", batchi) + ok := iter.Next(&result) + c.Assert(ok, Equals, true, Commentf("iter.Err: %v", iter.Err())) + c.Logf("Done iterating.") + + session.Run("ping", nil) // Roundtrip to settle down. + pings++ + + stats = mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, (batchi+2)*100+pings) + } + } +} + +func (s *S) TestSafeSetting(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + // Check the default + safe := session.Safe() + c.Assert(safe.W, Equals, 0) + c.Assert(safe.WMode, Equals, "") + c.Assert(safe.WTimeout, Equals, 0) + c.Assert(safe.FSync, Equals, false) + c.Assert(safe.J, Equals, false) + + // Tweak it + session.SetSafe(&mgo.Safe{W: 1, WTimeout: 2, FSync: true}) + safe = session.Safe() + c.Assert(safe.W, Equals, 1) + c.Assert(safe.WMode, Equals, "") + c.Assert(safe.WTimeout, Equals, 2) + c.Assert(safe.FSync, Equals, true) + c.Assert(safe.J, Equals, false) + + // Reset it again. + session.SetSafe(&mgo.Safe{}) + safe = session.Safe() + c.Assert(safe.W, Equals, 0) + c.Assert(safe.WMode, Equals, "") + c.Assert(safe.WTimeout, Equals, 0) + c.Assert(safe.FSync, Equals, false) + c.Assert(safe.J, Equals, false) + + // Ensure safety to something more conservative. + session.SetSafe(&mgo.Safe{W: 5, WTimeout: 6, J: true}) + safe = session.Safe() + c.Assert(safe.W, Equals, 5) + c.Assert(safe.WMode, Equals, "") + c.Assert(safe.WTimeout, Equals, 6) + c.Assert(safe.FSync, Equals, false) + c.Assert(safe.J, Equals, true) + + // Ensure safety to something less conservative won't change it. + session.EnsureSafe(&mgo.Safe{W: 4, WTimeout: 7}) + safe = session.Safe() + c.Assert(safe.W, Equals, 5) + c.Assert(safe.WMode, Equals, "") + c.Assert(safe.WTimeout, Equals, 6) + c.Assert(safe.FSync, Equals, false) + c.Assert(safe.J, Equals, true) + + // But to something more conservative will. + session.EnsureSafe(&mgo.Safe{W: 6, WTimeout: 4, FSync: true}) + safe = session.Safe() + c.Assert(safe.W, Equals, 6) + c.Assert(safe.WMode, Equals, "") + c.Assert(safe.WTimeout, Equals, 4) + c.Assert(safe.FSync, Equals, true) + c.Assert(safe.J, Equals, false) + + // Even more conservative. + session.EnsureSafe(&mgo.Safe{WMode: "majority", WTimeout: 2}) + safe = session.Safe() + c.Assert(safe.W, Equals, 0) + c.Assert(safe.WMode, Equals, "majority") + c.Assert(safe.WTimeout, Equals, 2) + c.Assert(safe.FSync, Equals, true) + c.Assert(safe.J, Equals, false) + + // WMode always overrides, whatever it is, but J doesn't. + session.EnsureSafe(&mgo.Safe{WMode: "something", J: true}) + safe = session.Safe() + c.Assert(safe.W, Equals, 0) + c.Assert(safe.WMode, Equals, "something") + c.Assert(safe.WTimeout, Equals, 2) + c.Assert(safe.FSync, Equals, true) + c.Assert(safe.J, Equals, false) + + // EnsureSafe with nil does nothing. + session.EnsureSafe(nil) + safe = session.Safe() + c.Assert(safe.W, Equals, 0) + c.Assert(safe.WMode, Equals, "something") + c.Assert(safe.WTimeout, Equals, 2) + c.Assert(safe.FSync, Equals, true) + c.Assert(safe.J, Equals, false) + + // Changing the safety of a cloned session doesn't touch the original. + clone := session.Clone() + defer clone.Close() + clone.EnsureSafe(&mgo.Safe{WMode: "foo"}) + safe = session.Safe() + c.Assert(safe.WMode, Equals, "something") +} + +func (s *S) TestSafeInsert(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + // Insert an element with a predefined key. + err = coll.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + + mgo.ResetStats() + + // Session should be safe by default, so inserting it again must fail. + err = coll.Insert(M{"_id": 1}) + c.Assert(err, ErrorMatches, ".*E11000 duplicate.*") + c.Assert(err.(*mgo.LastError).Code, Equals, 11000) + + // It must have sent two operations (INSERT_OP + getLastError QUERY_OP) + stats := mgo.GetStats() + c.Assert(stats.SentOps, Equals, 2) + + mgo.ResetStats() + + // If we disable safety, though, it won't complain. + session.SetSafe(nil) + err = coll.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + + // Must have sent a single operation this time (just the INSERT_OP) + stats = mgo.GetStats() + c.Assert(stats.SentOps, Equals, 1) +} + +func (s *S) TestSafeParameters(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + // Tweak the safety parameters to something unachievable. + session.SetSafe(&mgo.Safe{W: 4, WTimeout: 100}) + err = coll.Insert(M{"_id": 1}) + c.Assert(err, ErrorMatches, "timeout|timed out waiting for slaves") + if !s.versionAtLeast(2, 6) { + // 2.6 turned it into a query error. + c.Assert(err.(*mgo.LastError).WTimeout, Equals, true) + } +} + +func (s *S) TestQueryErrorOne(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + result := struct { + Err string "$err" + }{} + + err = coll.Find(M{"a": 1}).Select(M{"a": M{"b": 1}}).One(&result) + c.Assert(err, ErrorMatches, ".*Unsupported projection option:.*") + c.Assert(err.(*mgo.QueryError).Message, Matches, ".*Unsupported projection option:.*") + if s.versionAtLeast(2, 6) { + // Oh, the dance of error codes. :-( + c.Assert(err.(*mgo.QueryError).Code, Equals, 17287) + } else { + c.Assert(err.(*mgo.QueryError).Code, Equals, 13097) + } + + // The result should be properly unmarshalled with QueryError + c.Assert(result.Err, Matches, ".*Unsupported projection option:.*") +} + +func (s *S) TestQueryErrorNext(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + result := struct { + Err string "$err" + }{} + + iter := coll.Find(M{"a": 1}).Select(M{"a": M{"b": 1}}).Iter() + + ok := iter.Next(&result) + c.Assert(ok, Equals, false) + + err = iter.Close() + c.Assert(err, ErrorMatches, ".*Unsupported projection option:.*") + c.Assert(err.(*mgo.QueryError).Message, Matches, ".*Unsupported projection option:.*") + if s.versionAtLeast(2, 6) { + // Oh, the dance of error codes. :-( + c.Assert(err.(*mgo.QueryError).Code, Equals, 17287) + } else { + c.Assert(err.(*mgo.QueryError).Code, Equals, 13097) + } + c.Assert(iter.Err(), Equals, err) + + // The result should be properly unmarshalled with QueryError + c.Assert(result.Err, Matches, ".*Unsupported projection option:.*") +} + +func (s *S) TestEnsureIndex(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + index1 := mgo.Index{ + Key: []string{"a"}, + Background: true, + } + + index2 := mgo.Index{ + Key: []string{"a", "-b"}, + Unique: true, + DropDups: true, + } + + // Obsolete: + index3 := mgo.Index{ + Key: []string{"@loc_old"}, + Min: -500, + Max: 500, + Bits: 32, + } + + index4 := mgo.Index{ + Key: []string{"$2d:loc"}, + Min: -500, + Max: 500, + Bits: 32, + } + + index5 := mgo.Index{ + Key: []string{"$text:a", "$text:b"}, + } + + index6 := mgo.Index{ + Key: []string{"$text:a"}, + DefaultLanguage: "portuguese", + LanguageOverride: "idioma", + } + + coll1 := session.DB("mydb").C("mycoll1") + coll2 := session.DB("mydb").C("mycoll2") + + for _, index := range []mgo.Index{index1, index2, index3, index4, index5} { + err = coll1.EnsureIndex(index) + c.Assert(err, IsNil) + } + + // Cannot have multiple text indexes on the same collection. + err = coll2.EnsureIndex(index6) + c.Assert(err, IsNil) + + sysidx := session.DB("mydb").C("system.indexes") + + result1 := M{} + err = sysidx.Find(M{"name": "a_1"}).One(result1) + c.Assert(err, IsNil) + + result2 := M{} + err = sysidx.Find(M{"name": "a_1_b_-1"}).One(result2) + c.Assert(err, IsNil) + + result3 := M{} + err = sysidx.Find(M{"name": "loc_old_2d"}).One(result3) + c.Assert(err, IsNil) + + result4 := M{} + err = sysidx.Find(M{"name": "loc_2d"}).One(result4) + c.Assert(err, IsNil) + + result5 := M{} + err = sysidx.Find(M{"name": "a_text_b_text"}).One(result5) + c.Assert(err, IsNil) + + result6 := M{} + err = sysidx.Find(M{"name": "a_text"}).One(result6) + c.Assert(err, IsNil) + + delete(result1, "v") + expected1 := M{ + "name": "a_1", + "key": M{"a": 1}, + "ns": "mydb.mycoll1", + "background": true, + } + c.Assert(result1, DeepEquals, expected1) + + delete(result2, "v") + expected2 := M{ + "name": "a_1_b_-1", + "key": M{"a": 1, "b": -1}, + "ns": "mydb.mycoll1", + "unique": true, + "dropDups": true, + } + if s.versionAtLeast(2, 7) { + // Was deprecated in 2.6, and not being reported by 2.7+. + delete(expected2, "dropDups") + } + c.Assert(result2, DeepEquals, expected2) + + delete(result3, "v") + expected3 := M{ + "name": "loc_old_2d", + "key": M{"loc_old": "2d"}, + "ns": "mydb.mycoll1", + "min": -500, + "max": 500, + "bits": 32, + } + c.Assert(result3, DeepEquals, expected3) + + delete(result4, "v") + expected4 := M{ + "name": "loc_2d", + "key": M{"loc": "2d"}, + "ns": "mydb.mycoll1", + "min": -500, + "max": 500, + "bits": 32, + } + c.Assert(result4, DeepEquals, expected4) + + delete(result5, "v") + expected5 := M{ + "name": "a_text_b_text", + "key": M{"_fts": "text", "_ftsx": 1}, + "ns": "mydb.mycoll1", + "weights": M{"a": 1, "b": 1}, + "default_language": "english", + "language_override": "language", + "textIndexVersion": 2, + } + c.Assert(result5, DeepEquals, expected5) + + delete(result6, "v") + expected6 := M{ + "name": "a_text", + "key": M{"_fts": "text", "_ftsx": 1}, + "ns": "mydb.mycoll2", + "weights": M{"a": 1}, + "default_language": "portuguese", + "language_override": "idioma", + "textIndexVersion": 2, + } + c.Assert(result6, DeepEquals, expected6) + + // Ensure the index actually works for real. + err = coll1.Insert(M{"a": 1, "b": 1}) + c.Assert(err, IsNil) + err = coll1.Insert(M{"a": 1, "b": 1}) + c.Assert(err, ErrorMatches, ".*duplicate key error.*") + c.Assert(mgo.IsDup(err), Equals, true) +} + +func (s *S) TestEnsureIndexWithBadInfo(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.EnsureIndex(mgo.Index{}) + c.Assert(err, ErrorMatches, "invalid index key:.*") + + err = coll.EnsureIndex(mgo.Index{Key: []string{""}}) + c.Assert(err, ErrorMatches, "invalid index key:.*") +} + +func (s *S) TestEnsureIndexWithUnsafeSession(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + session.SetSafe(nil) + + coll := session.DB("mydb").C("mycoll") + + err = coll.Insert(M{"a": 1}) + c.Assert(err, IsNil) + + err = coll.Insert(M{"a": 1}) + c.Assert(err, IsNil) + + // Should fail since there are duplicated entries. + index := mgo.Index{ + Key: []string{"a"}, + Unique: true, + } + + err = coll.EnsureIndex(index) + c.Assert(err, ErrorMatches, ".*duplicate key error.*") +} + +func (s *S) TestEnsureIndexKey(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.EnsureIndexKey("a") + c.Assert(err, IsNil) + + err = coll.EnsureIndexKey("a", "-b") + c.Assert(err, IsNil) + + sysidx := session.DB("mydb").C("system.indexes") + + result1 := M{} + err = sysidx.Find(M{"name": "a_1"}).One(result1) + c.Assert(err, IsNil) + + result2 := M{} + err = sysidx.Find(M{"name": "a_1_b_-1"}).One(result2) + c.Assert(err, IsNil) + + delete(result1, "v") + expected1 := M{ + "name": "a_1", + "key": M{"a": 1}, + "ns": "mydb.mycoll", + } + c.Assert(result1, DeepEquals, expected1) + + delete(result2, "v") + expected2 := M{ + "name": "a_1_b_-1", + "key": M{"a": 1, "b": -1}, + "ns": "mydb.mycoll", + } + c.Assert(result2, DeepEquals, expected2) +} + +func (s *S) TestEnsureIndexDropIndex(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.EnsureIndexKey("a") + c.Assert(err, IsNil) + + err = coll.EnsureIndexKey("-b") + c.Assert(err, IsNil) + + err = coll.DropIndex("-b") + c.Assert(err, IsNil) + + sysidx := session.DB("mydb").C("system.indexes") + dummy := &struct{}{} + + err = sysidx.Find(M{"name": "a_1"}).One(dummy) + c.Assert(err, IsNil) + + err = sysidx.Find(M{"name": "b_1"}).One(dummy) + c.Assert(err, Equals, mgo.ErrNotFound) + + err = coll.DropIndex("a") + c.Assert(err, IsNil) + + err = sysidx.Find(M{"name": "a_1"}).One(dummy) + c.Assert(err, Equals, mgo.ErrNotFound) + + err = coll.DropIndex("a") + c.Assert(err, ErrorMatches, "index not found.*") +} + +func (s *S) TestEnsureIndexCaching(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.EnsureIndexKey("a") + c.Assert(err, IsNil) + + mgo.ResetStats() + + // Second EnsureIndex should be cached and do nothing. + err = coll.EnsureIndexKey("a") + c.Assert(err, IsNil) + + stats := mgo.GetStats() + c.Assert(stats.SentOps, Equals, 0) + + // Resetting the cache should make it contact the server again. + session.ResetIndexCache() + + err = coll.EnsureIndexKey("a") + c.Assert(err, IsNil) + + stats = mgo.GetStats() + c.Assert(stats.SentOps, Equals, 2) + + // Dropping the index should also drop the cached index key. + err = coll.DropIndex("a") + c.Assert(err, IsNil) + + mgo.ResetStats() + + err = coll.EnsureIndexKey("a") + c.Assert(err, IsNil) + + stats = mgo.GetStats() + c.Assert(stats.SentOps, Equals, 2) +} + +func (s *S) TestEnsureIndexGetIndexes(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.EnsureIndexKey("-b") + c.Assert(err, IsNil) + + err = coll.EnsureIndexKey("a") + c.Assert(err, IsNil) + + // Obsolete. + err = coll.EnsureIndexKey("@c") + c.Assert(err, IsNil) + + err = coll.EnsureIndexKey("$2d:d") + c.Assert(err, IsNil) + + indexes, err := coll.Indexes() + c.Assert(err, IsNil) + + c.Assert(indexes[0].Name, Equals, "_id_") + c.Assert(indexes[1].Name, Equals, "a_1") + c.Assert(indexes[1].Key, DeepEquals, []string{"a"}) + c.Assert(indexes[2].Name, Equals, "b_-1") + c.Assert(indexes[2].Key, DeepEquals, []string{"-b"}) + c.Assert(indexes[3].Name, Equals, "c_2d") + c.Assert(indexes[3].Key, DeepEquals, []string{"$2d:c"}) + c.Assert(indexes[4].Name, Equals, "d_2d") + c.Assert(indexes[4].Key, DeepEquals, []string{"$2d:d"}) +} + +func (s *S) TestEnsureIndexEvalGetIndexes(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = session.Run(bson.D{{"eval", "db.getSiblingDB('mydb').mycoll.ensureIndex({b: -1})"}}, nil) + c.Assert(err, IsNil) + err = session.Run(bson.D{{"eval", "db.getSiblingDB('mydb').mycoll.ensureIndex({a: 1})"}}, nil) + c.Assert(err, IsNil) + err = session.Run(bson.D{{"eval", "db.getSiblingDB('mydb').mycoll.ensureIndex({c: '2d'})"}}, nil) + c.Assert(err, IsNil) + err = session.Run(bson.D{{"eval", "db.getSiblingDB('mydb').mycoll.ensureIndex({d: -1, e: 1})"}}, nil) + c.Assert(err, IsNil) + + indexes, err := coll.Indexes() + c.Assert(err, IsNil) + + c.Assert(indexes[0].Name, Equals, "_id_") + c.Assert(indexes[1].Name, Equals, "a_1") + c.Assert(indexes[1].Key, DeepEquals, []string{"a"}) + c.Assert(indexes[2].Name, Equals, "b_-1") + c.Assert(indexes[2].Key, DeepEquals, []string{"-b"}) + c.Assert(indexes[3].Name, Equals, "c_2d") + c.Assert(indexes[3].Key, DeepEquals, []string{"$2d:c"}) + c.Assert(indexes[4].Name, Equals, "d_-1_e_1") + c.Assert(indexes[4].Key, DeepEquals, []string{"-d", "e"}) +} + +var testTTL = flag.Bool("test-ttl", false, "test TTL collections (may take 1 minute)") + +func (s *S) TestEnsureIndexExpireAfter(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + session.SetSafe(nil) + + coll := session.DB("mydb").C("mycoll") + + err = coll.Insert(M{"n": 1, "t": time.Now().Add(-120 * time.Second)}) + c.Assert(err, IsNil) + err = coll.Insert(M{"n": 2, "t": time.Now()}) + c.Assert(err, IsNil) + + // Should fail since there are duplicated entries. + index := mgo.Index{ + Key: []string{"t"}, + ExpireAfter: 1 * time.Minute, + } + + err = coll.EnsureIndex(index) + c.Assert(err, IsNil) + + indexes, err := coll.Indexes() + c.Assert(err, IsNil) + c.Assert(indexes[1].Name, Equals, "t_1") + c.Assert(indexes[1].ExpireAfter, Equals, 1*time.Minute) + + if *testTTL { + worked := false + stop := time.Now().Add(70 * time.Second) + for time.Now().Before(stop) { + n, err := coll.Count() + c.Assert(err, IsNil) + if n == 1 { + worked = true + break + } + c.Assert(n, Equals, 2) + c.Logf("Still has 2 entries...") + time.Sleep(1 * time.Second) + } + if !worked { + c.Fatalf("TTL index didn't work") + } + } +} + +func (s *S) TestDistinct(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for _, i := range []int{1, 4, 6, 2, 2, 3, 4} { + coll.Insert(M{"n": i}) + } + + var result []int + err = coll.Find(M{"n": M{"$gt": 2}}).Sort("n").Distinct("n", &result) + + sort.IntSlice(result).Sort() + c.Assert(result, DeepEquals, []int{3, 4, 6}) +} + +func (s *S) TestMapReduce(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for _, i := range []int{1, 4, 6, 2, 2, 3, 4} { + coll.Insert(M{"n": i}) + } + + job := &mgo.MapReduce{ + Map: "function() { emit(this.n, 1); }", + Reduce: "function(key, values) { return Array.sum(values); }", + } + var result []struct { + Id int "_id" + Value int + } + + info, err := coll.Find(M{"n": M{"$gt": 2}}).MapReduce(job, &result) + c.Assert(err, IsNil) + c.Assert(info.InputCount, Equals, 4) + c.Assert(info.EmitCount, Equals, 4) + c.Assert(info.OutputCount, Equals, 3) + c.Assert(info.VerboseTime, IsNil) + + expected := map[int]int{3: 1, 4: 2, 6: 1} + for _, item := range result { + c.Logf("Item: %#v", &item) + c.Assert(item.Value, Equals, expected[item.Id]) + expected[item.Id] = -1 + } +} + +func (s *S) TestMapReduceFinalize(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for _, i := range []int{1, 4, 6, 2, 2, 3, 4} { + coll.Insert(M{"n": i}) + } + + job := &mgo.MapReduce{ + Map: "function() { emit(this.n, 1) }", + Reduce: "function(key, values) { return Array.sum(values) }", + Finalize: "function(key, count) { return {count: count} }", + } + var result []struct { + Id int "_id" + Value struct{ Count int } + } + _, err = coll.Find(nil).MapReduce(job, &result) + c.Assert(err, IsNil) + + expected := map[int]int{1: 1, 2: 2, 3: 1, 4: 2, 6: 1} + for _, item := range result { + c.Logf("Item: %#v", &item) + c.Assert(item.Value.Count, Equals, expected[item.Id]) + expected[item.Id] = -1 + } +} + +func (s *S) TestMapReduceToCollection(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for _, i := range []int{1, 4, 6, 2, 2, 3, 4} { + coll.Insert(M{"n": i}) + } + + job := &mgo.MapReduce{ + Map: "function() { emit(this.n, 1); }", + Reduce: "function(key, values) { return Array.sum(values); }", + Out: "mr", + } + + info, err := coll.Find(nil).MapReduce(job, nil) + c.Assert(err, IsNil) + c.Assert(info.InputCount, Equals, 7) + c.Assert(info.EmitCount, Equals, 7) + c.Assert(info.OutputCount, Equals, 5) + c.Assert(info.Collection, Equals, "mr") + c.Assert(info.Database, Equals, "mydb") + + expected := map[int]int{1: 1, 2: 2, 3: 1, 4: 2, 6: 1} + var item *struct { + Id int "_id" + Value int + } + mr := session.DB("mydb").C("mr") + iter := mr.Find(nil).Iter() + for iter.Next(&item) { + c.Logf("Item: %#v", &item) + c.Assert(item.Value, Equals, expected[item.Id]) + expected[item.Id] = -1 + } + c.Assert(iter.Close(), IsNil) +} + +func (s *S) TestMapReduceToOtherDb(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for _, i := range []int{1, 4, 6, 2, 2, 3, 4} { + coll.Insert(M{"n": i}) + } + + job := &mgo.MapReduce{ + Map: "function() { emit(this.n, 1); }", + Reduce: "function(key, values) { return Array.sum(values); }", + Out: bson.D{{"replace", "mr"}, {"db", "otherdb"}}, + } + + info, err := coll.Find(nil).MapReduce(job, nil) + c.Assert(err, IsNil) + c.Assert(info.InputCount, Equals, 7) + c.Assert(info.EmitCount, Equals, 7) + c.Assert(info.OutputCount, Equals, 5) + c.Assert(info.Collection, Equals, "mr") + c.Assert(info.Database, Equals, "otherdb") + + expected := map[int]int{1: 1, 2: 2, 3: 1, 4: 2, 6: 1} + var item *struct { + Id int "_id" + Value int + } + mr := session.DB("otherdb").C("mr") + iter := mr.Find(nil).Iter() + for iter.Next(&item) { + c.Logf("Item: %#v", &item) + c.Assert(item.Value, Equals, expected[item.Id]) + expected[item.Id] = -1 + } + c.Assert(iter.Close(), IsNil) +} + +func (s *S) TestMapReduceOutOfOrder(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for _, i := range []int{1, 4, 6, 2, 2, 3, 4} { + coll.Insert(M{"n": i}) + } + + job := &mgo.MapReduce{ + Map: "function() { emit(this.n, 1); }", + Reduce: "function(key, values) { return Array.sum(values); }", + Out: bson.M{"a": "a", "z": "z", "replace": "mr", "db": "otherdb", "b": "b", "y": "y"}, + } + + info, err := coll.Find(nil).MapReduce(job, nil) + c.Assert(err, IsNil) + c.Assert(info.Collection, Equals, "mr") + c.Assert(info.Database, Equals, "otherdb") +} + +func (s *S) TestMapReduceScope(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + coll.Insert(M{"n": 1}) + + job := &mgo.MapReduce{ + Map: "function() { emit(this.n, x); }", + Reduce: "function(key, values) { return Array.sum(values); }", + Scope: M{"x": 42}, + } + + var result []bson.M + _, err = coll.Find(nil).MapReduce(job, &result) + c.Assert(len(result), Equals, 1) + c.Assert(result[0]["value"], Equals, 42.0) +} + +func (s *S) TestMapReduceVerbose(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for i := 0; i < 100; i++ { + err = coll.Insert(M{"n": i}) + c.Assert(err, IsNil) + } + + job := &mgo.MapReduce{ + Map: "function() { emit(this.n, 1); }", + Reduce: "function(key, values) { return Array.sum(values); }", + Verbose: true, + } + + info, err := coll.Find(nil).MapReduce(job, nil) + c.Assert(err, IsNil) + c.Assert(info.VerboseTime, NotNil) +} + +func (s *S) TestMapReduceLimit(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for _, i := range []int{1, 4, 6, 2, 2, 3, 4} { + coll.Insert(M{"n": i}) + } + + job := &mgo.MapReduce{ + Map: "function() { emit(this.n, 1); }", + Reduce: "function(key, values) { return Array.sum(values); }", + } + + var result []bson.M + _, err = coll.Find(nil).Limit(3).MapReduce(job, &result) + c.Assert(err, IsNil) + c.Assert(len(result), Equals, 3) +} + +func (s *S) TestBuildInfo(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + info, err := session.BuildInfo() + c.Assert(err, IsNil) + + var v []int + for i, a := range strings.Split(info.Version, ".") { + for _, token := range []string{"-rc", "-pre"} { + if i == 2 && strings.Contains(a, token) { + a = a[:strings.Index(a, token)] + info.VersionArray[len(info.VersionArray)-1] = 0 + } + } + n, err := strconv.Atoi(a) + c.Assert(err, IsNil) + v = append(v, n) + } + for len(v) < 4 { + v = append(v, 0) + } + + c.Assert(info.VersionArray, DeepEquals, v) + c.Assert(info.GitVersion, Matches, "[a-z0-9]+") + c.Assert(info.SysInfo, Matches, ".*[0-9:]+.*") + if info.Bits != 32 && info.Bits != 64 { + c.Fatalf("info.Bits is %d", info.Bits) + } + if info.MaxObjectSize < 8192 { + c.Fatalf("info.MaxObjectSize seems too small: %d", info.MaxObjectSize) + } +} + +func (s *S) TestZeroTimeRoundtrip(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + var d struct{ T time.Time } + conn := session.DB("mydb").C("mycoll") + err = conn.Insert(d) + c.Assert(err, IsNil) + + var result bson.M + err = conn.Find(nil).One(&result) + c.Assert(err, IsNil) + t, isTime := result["t"].(time.Time) + c.Assert(isTime, Equals, true) + c.Assert(t, Equals, time.Time{}) +} + +func (s *S) TestFsyncLock(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + clone := session.Clone() + defer clone.Close() + + err = session.FsyncLock() + c.Assert(err, IsNil) + + done := make(chan time.Time) + go func() { + time.Sleep(3e9) + now := time.Now() + err := session.FsyncUnlock() + c.Check(err, IsNil) + done <- now + }() + + err = clone.DB("mydb").C("mycoll").Insert(bson.M{"n": 1}) + unlocked := time.Now() + unlocking := <-done + c.Assert(err, IsNil) + + c.Assert(unlocked.After(unlocking), Equals, true) + c.Assert(unlocked.Sub(unlocking) < 1e9, Equals, true) +} + +func (s *S) TestFsync(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + // Not much to do here. Just a smoke check. + err = session.Fsync(false) + c.Assert(err, IsNil) + err = session.Fsync(true) + c.Assert(err, IsNil) +} + +func (s *S) TestRepairCursor(c *C) { + if !s.versionAtLeast(2, 7) { + c.Skip("RepairCursor only works on 2.7+") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + session.SetBatch(2) + + coll := session.DB("mydb").C("mycoll3") + err = coll.DropCollection() + + ns := []int{0, 10, 20, 30, 40, 50} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + repairIter := coll.Repair() + + c.Assert(repairIter.Err(), IsNil) + + result := struct{ N int }{} + resultCounts := map[int]int{} + for repairIter.Next(&result) { + resultCounts[result.N]++ + } + + c.Assert(repairIter.Next(&result), Equals, false) + c.Assert(repairIter.Err(), IsNil) + c.Assert(repairIter.Close(), IsNil) + + // Verify that the results of the repair cursor are valid. + // The repair cursor can return multiple copies + // of the same document, so to check correctness we only + // need to verify that at least 1 of each document was returned. + + for _, key := range ns { + c.Assert(resultCounts[key] > 0, Equals, true) + } +} + +func (s *S) TestPipeIter(c *C) { + if !s.versionAtLeast(2, 1) { + c.Skip("Pipe only works on 2.1+") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + pipe := coll.Pipe([]M{{"$match": M{"n": M{"$gte": 42}}}}) + + // Ensure cursor logic is working by forcing a small batch. + pipe.Batch(2) + + // Smoke test for AllowDiskUse. + pipe.AllowDiskUse() + + iter := pipe.Iter() + result := struct{ N int }{} + for i := 2; i < 7; i++ { + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(result.N, Equals, ns[i]) + } + + c.Assert(iter.Next(&result), Equals, false) + c.Assert(iter.Close(), IsNil) +} + +func (s *S) TestPipeAll(c *C) { + if !s.versionAtLeast(2, 1) { + c.Skip("Pipe only works on 2.1+") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + var result []struct{ N int } + err = coll.Pipe([]M{{"$match": M{"n": M{"$gte": 42}}}}).All(&result) + c.Assert(err, IsNil) + for i := 2; i < 7; i++ { + c.Assert(result[i-2].N, Equals, ns[i]) + } +} + +func (s *S) TestPipeOne(c *C) { + if !s.versionAtLeast(2, 1) { + c.Skip("Pipe only works on 2.1+") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + coll.Insert(M{"a": 1, "b": 2}) + + result := struct{ A, B int }{} + + pipe := coll.Pipe([]M{{"$project": M{"a": 1, "b": M{"$add": []interface{}{"$b", 1}}}}}) + err = pipe.One(&result) + c.Assert(err, IsNil) + c.Assert(result.A, Equals, 1) + c.Assert(result.B, Equals, 3) + + pipe = coll.Pipe([]M{{"$match": M{"a": 2}}}) + err = pipe.One(&result) + c.Assert(err, Equals, mgo.ErrNotFound) +} + +func (s *S) TestPipeExplain(c *C) { + if !s.versionAtLeast(2, 1) { + c.Skip("Pipe only works on 2.1+") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + coll.Insert(M{"a": 1, "b": 2}) + + pipe := coll.Pipe([]M{{"$project": M{"a": 1, "b": M{"$add": []interface{}{"$b", 1}}}}}) + + // The explain command result changes across versions. + var result struct{ Ok int } + err = pipe.Explain(&result) + c.Assert(err, IsNil) + c.Assert(result.Ok, Equals, 1) +} + +func (s *S) TestBatch1Bug(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for i := 0; i < 3; i++ { + err := coll.Insert(M{"n": i}) + c.Assert(err, IsNil) + } + + var ns []struct{ N int } + err = coll.Find(nil).Batch(1).All(&ns) + c.Assert(err, IsNil) + c.Assert(len(ns), Equals, 3) + + session.SetBatch(1) + err = coll.Find(nil).All(&ns) + c.Assert(err, IsNil) + c.Assert(len(ns), Equals, 3) +} + +func (s *S) TestInterfaceIterBug(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for i := 0; i < 3; i++ { + err := coll.Insert(M{"n": i}) + c.Assert(err, IsNil) + } + + var result interface{} + + i := 0 + iter := coll.Find(nil).Sort("n").Iter() + for iter.Next(&result) { + c.Assert(result.(bson.M)["n"], Equals, i) + i++ + } + c.Assert(iter.Close(), IsNil) +} + +func (s *S) TestFindIterCloseKillsCursor(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + cursors := serverCursorsOpen(session) + + coll := session.DB("mydb").C("mycoll") + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err = coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + iter := coll.Find(nil).Batch(2).Iter() + c.Assert(iter.Next(bson.M{}), Equals, true) + + c.Assert(iter.Close(), IsNil) + c.Assert(serverCursorsOpen(session), Equals, cursors) +} + +func (s *S) TestLogReplay(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + for i := 0; i < 5; i++ { + err = coll.Insert(M{"ts": time.Now()}) + c.Assert(err, IsNil) + } + + iter := coll.Find(nil).LogReplay().Iter() + if s.versionAtLeast(2, 6) { + // This used to fail in 2.4. Now it's just a smoke test. + c.Assert(iter.Err(), IsNil) + } else { + c.Assert(iter.Next(bson.M{}), Equals, false) + c.Assert(iter.Err(), ErrorMatches, "no ts field in query") + } +} + +func (s *S) TestSetCursorTimeout(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 42}) + + // This is just a smoke test. Won't wait 10 minutes for an actual timeout. + + session.SetCursorTimeout(0) + + var result struct{ N int } + iter := coll.Find(nil).Iter() + c.Assert(iter.Next(&result), Equals, true) + c.Assert(result.N, Equals, 42) + c.Assert(iter.Next(&result), Equals, false) +} + +// -------------------------------------------------------------------------- +// Some benchmarks that require a running database. + +func (s *S) BenchmarkFindIterRaw(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + doc := bson.D{ + {"f2", "a short string"}, + {"f3", bson.D{{"1", "one"}, {"2", 2.0}}}, + {"f4", []string{"a", "b", "c", "d", "e", "f", "g"}}, + } + + for i := 0; i < c.N+1; i++ { + err := coll.Insert(doc) + c.Assert(err, IsNil) + } + + session.SetBatch(c.N) + + var raw bson.Raw + iter := coll.Find(nil).Iter() + iter.Next(&raw) + c.ResetTimer() + i := 0 + for iter.Next(&raw) { + i++ + } + c.StopTimer() + c.Assert(iter.Err(), IsNil) + c.Assert(i, Equals, c.N) +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/socket.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/socket.go new file mode 100644 index 0000000..1fb0dff --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/socket.go @@ -0,0 +1,675 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +import ( + "errors" + "net" + "sync" + "time" + + "gopkg.in/mgo.v2/bson" +) + +type replyFunc func(err error, reply *replyOp, docNum int, docData []byte) + +type mongoSocket struct { + sync.Mutex + server *mongoServer // nil when cached + conn net.Conn + timeout time.Duration + addr string // For debugging only. + nextRequestId uint32 + replyFuncs map[uint32]replyFunc + references int + creds []Credential + logout []Credential + cachedNonce string + gotNonce sync.Cond + dead error + serverInfo *mongoServerInfo +} + +type queryOpFlags uint32 + +const ( + _ queryOpFlags = 1 << iota + flagTailable + flagSlaveOk + flagLogReplay + flagNoCursorTimeout + flagAwaitData +) + +type queryOp struct { + collection string + query interface{} + skip int32 + limit int32 + selector interface{} + flags queryOpFlags + replyFunc replyFunc + + options queryWrapper + hasOptions bool + serverTags []bson.D +} + +type queryWrapper struct { + Query interface{} "$query" + OrderBy interface{} "$orderby,omitempty" + Hint interface{} "$hint,omitempty" + Explain bool "$explain,omitempty" + Snapshot bool "$snapshot,omitempty" + ReadPreference bson.D "$readPreference,omitempty" + MaxScan int "$maxScan,omitempty" +} + +func (op *queryOp) finalQuery(socket *mongoSocket) interface{} { + if op.flags&flagSlaveOk != 0 && len(op.serverTags) > 0 && socket.ServerInfo().Mongos { + op.hasOptions = true + op.options.ReadPreference = bson.D{{"mode", "secondaryPreferred"}, {"tags", op.serverTags}} + } + if op.hasOptions { + if op.query == nil { + var empty bson.D + op.options.Query = empty + } else { + op.options.Query = op.query + } + debugf("final query is %#v\n", &op.options) + return &op.options + } + return op.query +} + +type getMoreOp struct { + collection string + limit int32 + cursorId int64 + replyFunc replyFunc +} + +type replyOp struct { + flags uint32 + cursorId int64 + firstDoc int32 + replyDocs int32 +} + +type insertOp struct { + collection string // "database.collection" + documents []interface{} // One or more documents to insert + flags uint32 +} + +type updateOp struct { + collection string // "database.collection" + selector interface{} + update interface{} + flags uint32 +} + +type deleteOp struct { + collection string // "database.collection" + selector interface{} + flags uint32 +} + +type killCursorsOp struct { + cursorIds []int64 +} + +type requestInfo struct { + bufferPos int + replyFunc replyFunc +} + +func newSocket(server *mongoServer, conn net.Conn, timeout time.Duration) *mongoSocket { + socket := &mongoSocket{ + conn: conn, + addr: server.Addr, + server: server, + replyFuncs: make(map[uint32]replyFunc), + } + socket.gotNonce.L = &socket.Mutex + if err := socket.InitialAcquire(server.Info(), timeout); err != nil { + panic("newSocket: InitialAcquire returned error: " + err.Error()) + } + stats.socketsAlive(+1) + debugf("Socket %p to %s: initialized", socket, socket.addr) + socket.resetNonce() + go socket.readLoop() + return socket +} + +// Server returns the server that the socket is associated with. +// It returns nil while the socket is cached in its respective server. +func (socket *mongoSocket) Server() *mongoServer { + socket.Lock() + server := socket.server + socket.Unlock() + return server +} + +// ServerInfo returns details for the server at the time the socket +// was initially acquired. +func (socket *mongoSocket) ServerInfo() *mongoServerInfo { + socket.Lock() + serverInfo := socket.serverInfo + socket.Unlock() + return serverInfo +} + +// InitialAcquire obtains the first reference to the socket, either +// right after the connection is made or once a recycled socket is +// being put back in use. +func (socket *mongoSocket) InitialAcquire(serverInfo *mongoServerInfo, timeout time.Duration) error { + socket.Lock() + if socket.references > 0 { + panic("Socket acquired out of cache with references") + } + if socket.dead != nil { + dead := socket.dead + socket.Unlock() + return dead + } + socket.references++ + socket.serverInfo = serverInfo + socket.timeout = timeout + stats.socketsInUse(+1) + stats.socketRefs(+1) + socket.Unlock() + return nil +} + +// Acquire obtains an additional reference to the socket. +// The socket will only be recycled when it's released as many +// times as it's been acquired. +func (socket *mongoSocket) Acquire() (info *mongoServerInfo) { + socket.Lock() + if socket.references == 0 { + panic("Socket got non-initial acquire with references == 0") + } + // We'll track references to dead sockets as well. + // Caller is still supposed to release the socket. + socket.references++ + stats.socketRefs(+1) + serverInfo := socket.serverInfo + socket.Unlock() + return serverInfo +} + +// Release decrements a socket reference. The socket will be +// recycled once its released as many times as it's been acquired. +func (socket *mongoSocket) Release() { + socket.Lock() + if socket.references == 0 { + panic("socket.Release() with references == 0") + } + socket.references-- + stats.socketRefs(-1) + if socket.references == 0 { + stats.socketsInUse(-1) + server := socket.server + socket.Unlock() + socket.LogoutAll() + // If the socket is dead server is nil. + if server != nil { + server.RecycleSocket(socket) + } + } else { + socket.Unlock() + } +} + +// SetTimeout changes the timeout used on socket operations. +func (socket *mongoSocket) SetTimeout(d time.Duration) { + socket.Lock() + socket.timeout = d + socket.Unlock() +} + +type deadlineType int + +const ( + readDeadline deadlineType = 1 + writeDeadline deadlineType = 2 +) + +func (socket *mongoSocket) updateDeadline(which deadlineType) { + var when time.Time + if socket.timeout > 0 { + when = time.Now().Add(socket.timeout) + } + whichstr := "" + switch which { + case readDeadline | writeDeadline: + whichstr = "read/write" + socket.conn.SetDeadline(when) + case readDeadline: + whichstr = "read" + socket.conn.SetReadDeadline(when) + case writeDeadline: + whichstr = "write" + socket.conn.SetWriteDeadline(when) + default: + panic("invalid parameter to updateDeadline") + } + debugf("Socket %p to %s: updated %s deadline to %s ahead (%s)", socket, socket.addr, whichstr, socket.timeout, when) +} + +// Close terminates the socket use. +func (socket *mongoSocket) Close() { + socket.kill(errors.New("Closed explicitly"), false) +} + +func (socket *mongoSocket) kill(err error, abend bool) { + socket.Lock() + if socket.dead != nil { + debugf("Socket %p to %s: killed again: %s (previously: %s)", socket, socket.addr, err.Error(), socket.dead.Error()) + socket.Unlock() + return + } + logf("Socket %p to %s: closing: %s (abend=%v)", socket, socket.addr, err.Error(), abend) + socket.dead = err + socket.conn.Close() + stats.socketsAlive(-1) + replyFuncs := socket.replyFuncs + socket.replyFuncs = make(map[uint32]replyFunc) + server := socket.server + socket.server = nil + socket.gotNonce.Broadcast() + socket.Unlock() + for _, replyFunc := range replyFuncs { + logf("Socket %p to %s: notifying replyFunc of closed socket: %s", socket, socket.addr, err.Error()) + replyFunc(err, nil, -1, nil) + } + if abend { + server.AbendSocket(socket) + } +} + +func (socket *mongoSocket) SimpleQuery(op *queryOp) (data []byte, err error) { + var wait, change sync.Mutex + var replyDone bool + var replyData []byte + var replyErr error + wait.Lock() + op.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) { + change.Lock() + if !replyDone { + replyDone = true + replyErr = err + if err == nil { + replyData = docData + } + } + change.Unlock() + wait.Unlock() + } + err = socket.Query(op) + if err != nil { + return nil, err + } + wait.Lock() + change.Lock() + data = replyData + err = replyErr + change.Unlock() + return data, err +} + +func (socket *mongoSocket) Query(ops ...interface{}) (err error) { + + if lops := socket.flushLogout(); len(lops) > 0 { + ops = append(lops, ops...) + } + + buf := make([]byte, 0, 256) + + // Serialize operations synchronously to avoid interrupting + // other goroutines while we can't really be sending data. + // Also, record id positions so that we can compute request + // ids at once later with the lock already held. + requests := make([]requestInfo, len(ops)) + requestCount := 0 + + for _, op := range ops { + debugf("Socket %p to %s: serializing op: %#v", socket, socket.addr, op) + start := len(buf) + var replyFunc replyFunc + switch op := op.(type) { + + case *updateOp: + buf = addHeader(buf, 2001) + buf = addInt32(buf, 0) // Reserved + buf = addCString(buf, op.collection) + buf = addInt32(buf, int32(op.flags)) + debugf("Socket %p to %s: serializing selector document: %#v", socket, socket.addr, op.selector) + buf, err = addBSON(buf, op.selector) + if err != nil { + return err + } + debugf("Socket %p to %s: serializing update document: %#v", socket, socket.addr, op.update) + buf, err = addBSON(buf, op.update) + if err != nil { + return err + } + + case *insertOp: + buf = addHeader(buf, 2002) + buf = addInt32(buf, int32(op.flags)) + buf = addCString(buf, op.collection) + for _, doc := range op.documents { + debugf("Socket %p to %s: serializing document for insertion: %#v", socket, socket.addr, doc) + buf, err = addBSON(buf, doc) + if err != nil { + return err + } + } + + case *queryOp: + buf = addHeader(buf, 2004) + buf = addInt32(buf, int32(op.flags)) + buf = addCString(buf, op.collection) + buf = addInt32(buf, op.skip) + buf = addInt32(buf, op.limit) + buf, err = addBSON(buf, op.finalQuery(socket)) + if err != nil { + return err + } + if op.selector != nil { + buf, err = addBSON(buf, op.selector) + if err != nil { + return err + } + } + replyFunc = op.replyFunc + + case *getMoreOp: + buf = addHeader(buf, 2005) + buf = addInt32(buf, 0) // Reserved + buf = addCString(buf, op.collection) + buf = addInt32(buf, op.limit) + buf = addInt64(buf, op.cursorId) + replyFunc = op.replyFunc + + case *deleteOp: + buf = addHeader(buf, 2006) + buf = addInt32(buf, 0) // Reserved + buf = addCString(buf, op.collection) + buf = addInt32(buf, int32(op.flags)) + debugf("Socket %p to %s: serializing selector document: %#v", socket, socket.addr, op.selector) + buf, err = addBSON(buf, op.selector) + if err != nil { + return err + } + + case *killCursorsOp: + buf = addHeader(buf, 2007) + buf = addInt32(buf, 0) // Reserved + buf = addInt32(buf, int32(len(op.cursorIds))) + for _, cursorId := range op.cursorIds { + buf = addInt64(buf, cursorId) + } + + default: + panic("internal error: unknown operation type") + } + + setInt32(buf, start, int32(len(buf)-start)) + + if replyFunc != nil { + request := &requests[requestCount] + request.replyFunc = replyFunc + request.bufferPos = start + requestCount++ + } + } + + // Buffer is ready for the pipe. Lock, allocate ids, and enqueue. + + socket.Lock() + if socket.dead != nil { + dead := socket.dead + socket.Unlock() + debugf("Socket %p to %s: failing query, already closed: %s", socket, socket.addr, socket.dead.Error()) + // XXX This seems necessary in case the session is closed concurrently + // with a query being performed, but it's not yet tested: + for i := 0; i != requestCount; i++ { + request := &requests[i] + if request.replyFunc != nil { + request.replyFunc(dead, nil, -1, nil) + } + } + return dead + } + + wasWaiting := len(socket.replyFuncs) > 0 + + // Reserve id 0 for requests which should have no responses. + requestId := socket.nextRequestId + 1 + if requestId == 0 { + requestId++ + } + socket.nextRequestId = requestId + uint32(requestCount) + for i := 0; i != requestCount; i++ { + request := &requests[i] + setInt32(buf, request.bufferPos+4, int32(requestId)) + socket.replyFuncs[requestId] = request.replyFunc + requestId++ + } + + debugf("Socket %p to %s: sending %d op(s) (%d bytes)", socket, socket.addr, len(ops), len(buf)) + stats.sentOps(len(ops)) + + socket.updateDeadline(writeDeadline) + _, err = socket.conn.Write(buf) + if !wasWaiting && requestCount > 0 { + socket.updateDeadline(readDeadline) + } + socket.Unlock() + return err +} + +func fill(r net.Conn, b []byte) error { + l := len(b) + n, err := r.Read(b) + for n != l && err == nil { + var ni int + ni, err = r.Read(b[n:]) + n += ni + } + return err +} + +// Estimated minimum cost per socket: 1 goroutine + memory for the largest +// document ever seen. +func (socket *mongoSocket) readLoop() { + p := make([]byte, 36) // 16 from header + 20 from OP_REPLY fixed fields + s := make([]byte, 4) + conn := socket.conn // No locking, conn never changes. + for { + // XXX Handle timeouts, , etc + err := fill(conn, p) + if err != nil { + socket.kill(err, true) + return + } + + totalLen := getInt32(p, 0) + responseTo := getInt32(p, 8) + opCode := getInt32(p, 12) + + // Don't use socket.server.Addr here. socket is not + // locked and socket.server may go away. + debugf("Socket %p to %s: got reply (%d bytes)", socket, socket.addr, totalLen) + + _ = totalLen + + if opCode != 1 { + socket.kill(errors.New("opcode != 1, corrupted data?"), true) + return + } + + reply := replyOp{ + flags: uint32(getInt32(p, 16)), + cursorId: getInt64(p, 20), + firstDoc: getInt32(p, 28), + replyDocs: getInt32(p, 32), + } + + stats.receivedOps(+1) + stats.receivedDocs(int(reply.replyDocs)) + + socket.Lock() + replyFunc, ok := socket.replyFuncs[uint32(responseTo)] + if ok { + delete(socket.replyFuncs, uint32(responseTo)) + } + socket.Unlock() + + if replyFunc != nil && reply.replyDocs == 0 { + replyFunc(nil, &reply, -1, nil) + } else { + for i := 0; i != int(reply.replyDocs); i++ { + err := fill(conn, s) + if err != nil { + if replyFunc != nil { + replyFunc(err, nil, -1, nil) + } + socket.kill(err, true) + return + } + + b := make([]byte, int(getInt32(s, 0))) + + // copy(b, s) in an efficient way. + b[0] = s[0] + b[1] = s[1] + b[2] = s[2] + b[3] = s[3] + + err = fill(conn, b[4:]) + if err != nil { + if replyFunc != nil { + replyFunc(err, nil, -1, nil) + } + socket.kill(err, true) + return + } + + if globalDebug && globalLogger != nil { + m := bson.M{} + if err := bson.Unmarshal(b, m); err == nil { + debugf("Socket %p to %s: received document: %#v", socket, socket.addr, m) + } + } + + if replyFunc != nil { + replyFunc(nil, &reply, i, b) + } + + // XXX Do bound checking against totalLen. + } + } + + socket.Lock() + if len(socket.replyFuncs) == 0 { + // Nothing else to read for now. Disable deadline. + socket.conn.SetReadDeadline(time.Time{}) + } else { + socket.updateDeadline(readDeadline) + } + socket.Unlock() + + // XXX Do bound checking against totalLen. + } +} + +var emptyHeader = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + +func addHeader(b []byte, opcode int) []byte { + i := len(b) + b = append(b, emptyHeader...) + // Enough for current opcodes. + b[i+12] = byte(opcode) + b[i+13] = byte(opcode >> 8) + return b +} + +func addInt32(b []byte, i int32) []byte { + return append(b, byte(i), byte(i>>8), byte(i>>16), byte(i>>24)) +} + +func addInt64(b []byte, i int64) []byte { + return append(b, byte(i), byte(i>>8), byte(i>>16), byte(i>>24), + byte(i>>32), byte(i>>40), byte(i>>48), byte(i>>56)) +} + +func addCString(b []byte, s string) []byte { + b = append(b, []byte(s)...) + b = append(b, 0) + return b +} + +func addBSON(b []byte, doc interface{}) ([]byte, error) { + if doc == nil { + return append(b, 5, 0, 0, 0, 0), nil + } + data, err := bson.Marshal(doc) + if err != nil { + return b, err + } + return append(b, data...), nil +} + +func setInt32(b []byte, pos int, i int32) { + b[pos] = byte(i) + b[pos+1] = byte(i >> 8) + b[pos+2] = byte(i >> 16) + b[pos+3] = byte(i >> 24) +} + +func getInt32(b []byte, pos int) int32 { + return (int32(b[pos+0])) | + (int32(b[pos+1]) << 8) | + (int32(b[pos+2]) << 16) | + (int32(b[pos+3]) << 24) +} + +func getInt64(b []byte, pos int) int64 { + return (int64(b[pos+0])) | + (int64(b[pos+1]) << 8) | + (int64(b[pos+2]) << 16) | + (int64(b[pos+3]) << 24) | + (int64(b[pos+4]) << 32) | + (int64(b[pos+5]) << 40) | + (int64(b[pos+6]) << 48) | + (int64(b[pos+7]) << 56) +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/stats.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/stats.go new file mode 100644 index 0000000..59723e6 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/stats.go @@ -0,0 +1,147 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +import ( + "sync" +) + +var stats *Stats +var statsMutex sync.Mutex + +func SetStats(enabled bool) { + statsMutex.Lock() + if enabled { + if stats == nil { + stats = &Stats{} + } + } else { + stats = nil + } + statsMutex.Unlock() +} + +func GetStats() (snapshot Stats) { + statsMutex.Lock() + snapshot = *stats + statsMutex.Unlock() + return +} + +func ResetStats() { + statsMutex.Lock() + debug("Resetting stats") + old := stats + stats = &Stats{} + // These are absolute values: + stats.Clusters = old.Clusters + stats.SocketsInUse = old.SocketsInUse + stats.SocketsAlive = old.SocketsAlive + stats.SocketRefs = old.SocketRefs + statsMutex.Unlock() + return +} + +type Stats struct { + Clusters int + MasterConns int + SlaveConns int + SentOps int + ReceivedOps int + ReceivedDocs int + SocketsAlive int + SocketsInUse int + SocketRefs int +} + +func (stats *Stats) cluster(delta int) { + if stats != nil { + statsMutex.Lock() + stats.Clusters += delta + statsMutex.Unlock() + } +} + +func (stats *Stats) conn(delta int, master bool) { + if stats != nil { + statsMutex.Lock() + if master { + stats.MasterConns += delta + } else { + stats.SlaveConns += delta + } + statsMutex.Unlock() + } +} + +func (stats *Stats) sentOps(delta int) { + if stats != nil { + statsMutex.Lock() + stats.SentOps += delta + statsMutex.Unlock() + } +} + +func (stats *Stats) receivedOps(delta int) { + if stats != nil { + statsMutex.Lock() + stats.ReceivedOps += delta + statsMutex.Unlock() + } +} + +func (stats *Stats) receivedDocs(delta int) { + if stats != nil { + statsMutex.Lock() + stats.ReceivedDocs += delta + statsMutex.Unlock() + } +} + +func (stats *Stats) socketsInUse(delta int) { + if stats != nil { + statsMutex.Lock() + stats.SocketsInUse += delta + statsMutex.Unlock() + } +} + +func (stats *Stats) socketsAlive(delta int) { + if stats != nil { + statsMutex.Lock() + stats.SocketsAlive += delta + statsMutex.Unlock() + } +} + +func (stats *Stats) socketRefs(delta int) { + if stats != nil { + statsMutex.Lock() + stats.SocketRefs += delta + statsMutex.Unlock() + } +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/suite_test.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/suite_test.go new file mode 100644 index 0000000..334407e --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/suite_test.go @@ -0,0 +1,254 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo_test + +import ( + "errors" + "flag" + "fmt" + "net" + "os/exec" + "runtime" + "strconv" + "testing" + "time" + + . "gopkg.in/check.v1" + "gopkg.in/mgo.v2" + "gopkg.in/mgo.v2/bson" +) + +var fast = flag.Bool("fast", false, "Skip slow tests") + +type M bson.M + +type cLogger C + +func (c *cLogger) Output(calldepth int, s string) error { + ns := time.Now().UnixNano() + t := float64(ns%100e9) / 1e9 + ((*C)(c)).Logf("[LOG] %.05f %s", t, s) + return nil +} + +func TestAll(t *testing.T) { + TestingT(t) +} + +type S struct { + session *mgo.Session + stopped bool + build mgo.BuildInfo + frozen []string +} + +func (s *S) versionAtLeast(v ...int) bool { + for i := range v { + if i == len(s.build.VersionArray) { + return false + } + if s.build.VersionArray[i] < v[i] { + return false + } + } + return true +} + +var _ = Suite(&S{}) + +func (s *S) SetUpSuite(c *C) { + mgo.SetDebug(true) + mgo.SetStats(true) + s.StartAll() + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + s.build, err = session.BuildInfo() + c.Check(err, IsNil) + session.Close() +} + +func (s *S) SetUpTest(c *C) { + err := run("mongo --nodb testdb/dropall.js") + if err != nil { + panic(err.Error()) + } + mgo.SetLogger((*cLogger)(c)) + mgo.ResetStats() +} + +func (s *S) TearDownTest(c *C) { + if s.stopped { + s.StartAll() + } + for _, host := range s.frozen { + if host != "" { + s.Thaw(host) + } + } + var stats mgo.Stats + for i := 0; ; i++ { + stats = mgo.GetStats() + if stats.SocketsInUse == 0 && stats.SocketsAlive == 0 { + break + } + if i == 20 { + c.Fatal("Test left sockets in a dirty state") + } + c.Logf("Waiting for sockets to die: %d in use, %d alive", stats.SocketsInUse, stats.SocketsAlive) + time.Sleep(500 * time.Millisecond) + } + for i := 0; ; i++ { + stats = mgo.GetStats() + if stats.Clusters == 0 { + break + } + if i == 60 { + c.Fatal("Test left clusters alive") + } + c.Logf("Waiting for clusters to die: %d alive", stats.Clusters) + time.Sleep(1 * time.Second) + } +} + +func (s *S) Stop(host string) { + // Give a moment for slaves to sync and avoid getting rollback issues. + panicOnWindows() + time.Sleep(2 * time.Second) + err := run("cd _testdb && supervisorctl stop " + supvName(host)) + if err != nil { + panic(err) + } + s.stopped = true +} + +func (s *S) pid(host string) int { + output, err := exec.Command("lsof", "-iTCP:"+hostPort(host), "-sTCP:LISTEN", "-Fp").CombinedOutput() + if err != nil { + panic(err) + } + pidstr := string(output[1 : len(output)-1]) + pid, err := strconv.Atoi(pidstr) + if err != nil { + panic("cannot convert pid to int: " + pidstr) + } + return pid +} + +func (s *S) Freeze(host string) { + err := stop(s.pid(host)) + if err != nil { + panic(err) + } + s.frozen = append(s.frozen, host) +} + +func (s *S) Thaw(host string) { + err := cont(s.pid(host)) + if err != nil { + panic(err) + } + for i, frozen := range s.frozen { + if frozen == host { + s.frozen[i] = "" + } + } +} + +func (s *S) StartAll() { + // Restart any stopped nodes. + run("cd _testdb && supervisorctl start all") + err := run("cd testdb && mongo --nodb wait.js") + if err != nil { + panic(err) + } + s.stopped = false +} + +func run(command string) error { + var output []byte + var err error + if runtime.GOOS == "windows" { + output, err = exec.Command("cmd", "/C", command).CombinedOutput() + } else { + output, err = exec.Command("/bin/sh", "-c", command).CombinedOutput() + } + + if err != nil { + msg := fmt.Sprintf("Failed to execute: %s: %s\n%s", command, err.Error(), string(output)) + return errors.New(msg) + } + return nil +} + +var supvNames = map[string]string{ + "40001": "db1", + "40002": "db2", + "40011": "rs1a", + "40012": "rs1b", + "40013": "rs1c", + "40021": "rs2a", + "40022": "rs2b", + "40023": "rs2c", + "40031": "rs3a", + "40032": "rs3b", + "40033": "rs3c", + "40041": "rs4a", + "40101": "cfg1", + "40102": "cfg2", + "40103": "cfg3", + "40201": "s1", + "40202": "s2", + "40203": "s3", +} + +// supvName returns the supervisord name for the given host address. +func supvName(host string) string { + host, port, err := net.SplitHostPort(host) + if err != nil { + panic(err) + } + name, ok := supvNames[port] + if !ok { + panic("Unknown host: " + host) + } + return name +} + +func hostPort(host string) string { + _, port, err := net.SplitHostPort(host) + if err != nil { + panic(err) + } + return port +} + +func panicOnWindows() { + if runtime.GOOS == "windows" { + panic("the test suite is not yet fully supported on Windows") + } +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/syscall_test.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/syscall_test.go new file mode 100644 index 0000000..b8bbd7b --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/syscall_test.go @@ -0,0 +1,15 @@ +// +build !windows + +package mgo_test + +import ( + "syscall" +) + +func stop(pid int) (err error) { + return syscall.Kill(pid, syscall.SIGSTOP) +} + +func cont(pid int) (err error) { + return syscall.Kill(pid, syscall.SIGCONT) +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/syscall_windows_test.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/syscall_windows_test.go new file mode 100644 index 0000000..f2deaca --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/syscall_windows_test.go @@ -0,0 +1,11 @@ +package mgo_test + +func stop(pid int) (err error) { + panicOnWindows() // Always does. + return nil +} + +func cont(pid int) (err error) { + panicOnWindows() // Always does. + return nil +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/client.pem b/Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/client.pem new file mode 100644 index 0000000..cc57eec --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/client.pem @@ -0,0 +1,44 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAwE2sl8YeTTSetwo9kykJ5mCZ/FtfPtn/0X4nOlTM2Qc/uWzA +sjSYoSV4UkuOiWjKQQH2EDeXaltshOo7F0oCY5ozVeQe+phe987iKTvLtf7NoXJD +KqNqR4Kb4ylbCrEky7+Xvw6yrrqw8qgWy+9VsrilR3q8LsETE9SBMtfp3BUaaNQp +peNm+iAhx3uZSv3mdzSLFSA/o61kAyG0scLExYDjo/7xyMNQoloLvNmx4Io160+y +lOz077/qqU620tmuDLRz1QdxK/bptmXTnsBCRxl+U8nzbwVZgWFENhXplbcN+SjN +LhdnvTiU2qFhgZmc7ZtCKdPIpx3W6pH9bx7kTwIDAQABAoIBAQCOQygyo8NY9FuS +J8ZDrvF+9+oS8fm1QorpDT2x/ngI+j7fSyAG9bgQRusLXpAVAWvWyb+iYa3nZbkT +X0DVys+XpcTifr+YPc7L3sYbIPxkKBsxm5kq2vfN7Uart7V9ZG1HOfblxdbUQpKT +AVzUA7vPWqATEC5VHEqjuerWlTqRr9YLZE/nkE7ICLISqdl4WDYfUYJwoXWfYkXQ +Lfl5Qh2leyri9S3urvDrhnURTQ1lM182IbTRA+9rUiFzsRW+9U4HPY7Ao2Itp8dr +GRP4rcq4TP+NcF0Ky64cNfKXCWmwqTBRFYAlTD6gwjN/s2BzvWD/2nlnc0DYAXrB +TgFCPk7xAoGBAOwuHICwwTxtzrdWjuRGU3RxL4eLEXedtL8yon/yPci3e+8eploX +1Fp0rEK2gIGDp/X8DiOtrKXih8XPusCwE/I3EvjHdI0RylLZXTPOp1Ei21dXRsiV +YxcF+d5s11q5tJtF+5ISUeIz2iSc9Z2LBnb8JDK1jcCRa5Q212q3ZWW5AoGBANBw +9CoMbxINLG1q0NvOOSwMKDk2OB+9JbQ5lwF4ijZl2I6qRoOCzQ3lBs0Qv/AeBjNR +SerDs2+eWnIBUbgSdiqcOKnXAI/Qbl1IkVFYV/2g9m6dgu1fNWNBv8NIYDHCLfDx +W3fpO5JMf+iE5XC4XqCfSBIME2yxPSGQjal6tB5HAoGAddYDzolhv/6hVoPPQ0F7 +PeuC5UOTcXSzy3k97kw0W0KAiStnoCengYIYuChKMVQ4ptgdTdvG+fTt/NnJuX2g +Vgb4ZjtNgVzQ70kX4VNH04lqmkcnP8iY6dHHexwezls9KwNdouGVDSEFw6K0QOgu +T4s5nDtNADkNzaMXE11xL7ECgYBoML3rstFmTY1ymB0Uck3jtaP5jR+axdpt7weL +Zax4qooILhcXL6++DUhMAt5ecTOaPTzci7xKw/Xj3MLzZs8IV5R/WQhf2sj/+gEh +jy5UijwEaNmEO74dAkWPoMLsvGpocMzO8JeldnXNTXi+0noCgfvtgXnIMAQlnfMh +z0LviwKBgQCg5KR9JC4iuKses7Kfv2YelcO8vOZkRzBu3NdRWMsiJQC+qfetgd57 +RjRjlRWd1WCHJ5Kmx3hkUaZZOrX5knqfsRW3Nl0I74xgWl7Bli2eSJ9VWl59bcd6 +DqphhY7/gcW+QZlhXpnqbf0W8jB2gPhTYERyCBoS9LfhZWZu/11wuQ== +-----END RSA PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIICyTCCAjKgAwIBAgIBATANBgkqhkiG9w0BAQUFADBcMQswCQYDVQQGEwJHTzEM +MAoGA1UECBMDTUdPMQwwCgYDVQQHEwNNR08xDDAKBgNVBAoTA01HTzEPMA0GA1UE +CxMGU2VydmVyMRIwEAYDVQQDEwlsb2NhbGhvc3QwHhcNMTQwOTI0MTQwMzUzWhcN +MTUwOTI0MTQwMzUzWjBcMQswCQYDVQQGEwJHTzEMMAoGA1UECBMDTUdPMQwwCgYD +VQQHEwNNR08xDDAKBgNVBAoTA01HTzEPMA0GA1UECxMGQ2xpZW50MRIwEAYDVQQD +Ewlsb2NhbGhvc3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDATayX +xh5NNJ63Cj2TKQnmYJn8W18+2f/Rfic6VMzZBz+5bMCyNJihJXhSS46JaMpBAfYQ +N5dqW2yE6jsXSgJjmjNV5B76mF73zuIpO8u1/s2hckMqo2pHgpvjKVsKsSTLv5e/ +DrKuurDyqBbL71WyuKVHerwuwRMT1IEy1+ncFRpo1Cml42b6ICHHe5lK/eZ3NIsV +ID+jrWQDIbSxwsTFgOOj/vHIw1CiWgu82bHgijXrT7KU7PTvv+qpTrbS2a4MtHPV +B3Er9um2ZdOewEJHGX5TyfNvBVmBYUQ2FemVtw35KM0uF2e9OJTaoWGBmZztm0Ip +08inHdbqkf1vHuRPAgMBAAGjFzAVMBMGA1UdJQQMMAoGCCsGAQUFBwMCMA0GCSqG +SIb3DQEBBQUAA4GBAJZD7idSIRzhGlJYARPKWnX2CxD4VVB0F5cH5Mlc2YnoUSU/ +rKuPZFuOYND3awKqez6K3rNb3+tQmNitmoOT8ImmX1uJKBo5w9tuo4B2MmLQcPMk +3fhPePuQCjtlArSmKVrNTrYPkyB9NwKS6q0+FzseFTw9ZJUIKiO9sSjMe+HP +-----END CERTIFICATE----- diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/dropall.js b/Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/dropall.js new file mode 100644 index 0000000..232eca3 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/dropall.js @@ -0,0 +1,52 @@ + +var ports = [40001, 40002, 40011, 40012, 40013, 40021, 40022, 40023, 40041, 40101, 40102, 40103, 40201, 40202, 40203] +var auth = [40002, 40103, 40203, 40031] +var db1 = new Mongo("localhost:40001") + +if (db1.getDB("admin").serverBuildInfo().OpenSSLVersion != "") { + ports.push(40003) + auth.push(40003) +} + +for (var i in ports) { + var port = ports[i] + var server = "localhost:" + port + var mongo = new Mongo("localhost:" + port) + var admin = mongo.getDB("admin") + + for (var j in auth) { + if (auth[j] == port) { + admin.auth("root", "rapadura") + admin.system.users.find().forEach(function(u) { + if (u.user == "root" || u.user == "reader") { + return; + } + if (typeof admin.dropUser == "function") { + mongo.getDB(u.db).dropUser(u.user); + } else { + admin.removeUser(u.user); + } + }) + break + } + } + var result = admin.runCommand({"listDatabases": 1}) + // Why is the command returning undefined!? + while (typeof result.databases == "undefined") { + result = admin.runCommand({"listDatabases": 1}) + } + var dbs = result.databases + for (var j = 0; j != dbs.length; j++) { + var db = dbs[j] + switch (db.name) { + case "admin": + case "local": + case "config": + break + default: + mongo.getDB(db.name).dropDatabase() + } + } +} + +// vim:ts=4:sw=4:et diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/init.js b/Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/init.js new file mode 100644 index 0000000..7deb67e --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/init.js @@ -0,0 +1,110 @@ +//var settings = {heartbeatSleep: 0.05, heartbeatTimeout: 0.5} +var settings = {}; + +// We know the master of the first set (pri=1), but not of the second. +var rs1cfg = {_id: "rs1", + members: [{_id: 1, host: "127.0.0.1:40011", priority: 1, tags: {rs1: "a"}}, + {_id: 2, host: "127.0.0.1:40012", priority: 0, tags: {rs1: "b"}}, + {_id: 3, host: "127.0.0.1:40013", priority: 0, tags: {rs1: "c"}}], + settings: settings} +var rs2cfg = {_id: "rs2", + members: [{_id: 1, host: "127.0.0.1:40021", priority: 1, tags: {rs2: "a"}}, + {_id: 2, host: "127.0.0.1:40022", priority: 1, tags: {rs2: "b"}}, + {_id: 3, host: "127.0.0.1:40023", priority: 1, tags: {rs2: "c"}}], + settings: settings} +var rs3cfg = {_id: "rs3", + members: [{_id: 1, host: "127.0.0.1:40031", priority: 1, tags: {rs3: "a"}}, + {_id: 2, host: "127.0.0.1:40032", priority: 1, tags: {rs3: "b"}}, + {_id: 3, host: "127.0.0.1:40033", priority: 1, tags: {rs3: "c"}}], + settings: settings} + +for (var i = 0; i != 60; i++) { + try { + db1 = new Mongo("127.0.0.1:40001").getDB("admin") + db2 = new Mongo("127.0.0.1:40002").getDB("admin") + rs1a = new Mongo("127.0.0.1:40011").getDB("admin") + rs2a = new Mongo("127.0.0.1:40021").getDB("admin") + rs3a = new Mongo("127.0.0.1:40031").getDB("admin") + break + } catch(err) { + print("Can't connect yet...") + } + sleep(1000) +} + +function hasSSL() { + return db1.serverBuildInfo().OpenSSLVersion != "" +} + +rs1a.runCommand({replSetInitiate: rs1cfg}) +rs2a.runCommand({replSetInitiate: rs2cfg}) +rs3a.runCommand({replSetInitiate: rs3cfg}) + +function configShards() { + cfg1 = new Mongo("127.0.0.1:40201").getDB("admin") + cfg1.runCommand({addshard: "127.0.0.1:40001"}) + cfg1.runCommand({addshard: "rs1/127.0.0.1:40011"}) + + cfg2 = new Mongo("127.0.0.1:40202").getDB("admin") + cfg2.runCommand({addshard: "rs2/127.0.0.1:40021"}) + + cfg3 = new Mongo("127.0.0.1:40203").getDB("admin") + cfg3.runCommand({addshard: "rs3/127.0.0.1:40031"}) +} + +function configAuth() { + var addrs = ["127.0.0.1:40002", "127.0.0.1:40203", "127.0.0.1:40031"] + if (hasSSL()) { + addrs.push("127.0.0.1:40003") + } + for (var i in addrs) { + var db = new Mongo(addrs[i]).getDB("admin") + var v = db.serverBuildInfo().versionArray + if (v < [2, 5]) { + db.addUser("root", "rapadura") + } else { + db.createUser({user: "root", pwd: "rapadura", roles: ["root"]}) + } + db.auth("root", "rapadura") + if (v >= [2, 6]) { + db.createUser({user: "reader", pwd: "rapadura", roles: ["readAnyDatabase"]}) + } else if (v >= [2, 4]) { + db.addUser({user: "reader", pwd: "rapadura", roles: ["readAnyDatabase"]}) + } else { + db.addUser("reader", "rapadura", true) + } + } +} + +function countHealthy(rs) { + var status = rs.runCommand({replSetGetStatus: 1}) + var count = 0 + if (typeof status.members != "undefined") { + for (var i = 0; i != status.members.length; i++) { + var m = status.members[i] + if (m.health == 1 && (m.state == 1 || m.state == 2)) { + count += 1 + } + } + } + return count +} + +var totalRSMembers = rs1cfg.members.length + rs2cfg.members.length + rs3cfg.members.length + +for (var i = 0; i != 60; i++) { + var count = countHealthy(rs1a) + countHealthy(rs2a) + countHealthy(rs3a) + print("Replica sets have", count, "healthy nodes.") + if (count == totalRSMembers) { + sleep(2000) + configShards() + configAuth() + quit(0) + } + sleep(1000) +} + +print("Replica sets didn't sync up properly.") +quit(12) + +// vim:ts=4:sw=4:et diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/server.pem b/Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/server.pem new file mode 100644 index 0000000..16fbef1 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/server.pem @@ -0,0 +1,33 @@ +-----BEGIN CERTIFICATE----- +MIIC+DCCAmGgAwIBAgIJAJ5pBAq2HXAsMA0GCSqGSIb3DQEBBQUAMFwxCzAJBgNV +BAYTAkdPMQwwCgYDVQQIEwNNR08xDDAKBgNVBAcTA01HTzEMMAoGA1UEChMDTUdP +MQ8wDQYDVQQLEwZTZXJ2ZXIxEjAQBgNVBAMTCWxvY2FsaG9zdDAeFw0xNDA5MjQx +MzUxMTBaFw0xNTA5MjQxMzUxMTBaMFwxCzAJBgNVBAYTAkdPMQwwCgYDVQQIEwNN +R08xDDAKBgNVBAcTA01HTzEMMAoGA1UEChMDTUdPMQ8wDQYDVQQLEwZTZXJ2ZXIx +EjAQBgNVBAMTCWxvY2FsaG9zdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA +pQ5wO2L23xMI4PzpVt/Ftvez82IvA9amwr3fUd7RjlYwiFsFeMnG24a4CUoOeKF0 +fpQWc9rmCs0EeP5ofZ2otOsfxoVWXZAZWdgauuwlYB6EeFaAMH3fxVH3IiH+21RR +q2w9sH/s4fqh5stavUfyPdVmCcb8NW0jD8jlqniJL0kCAwEAAaOBwTCBvjAdBgNV +HQ4EFgQUjyVWGMHBrmPDGwCY5VusHsKIpzIwgY4GA1UdIwSBhjCBg4AUjyVWGMHB +rmPDGwCY5VusHsKIpzKhYKReMFwxCzAJBgNVBAYTAkdPMQwwCgYDVQQIEwNNR08x +DDAKBgNVBAcTA01HTzEMMAoGA1UEChMDTUdPMQ8wDQYDVQQLEwZTZXJ2ZXIxEjAQ +BgNVBAMTCWxvY2FsaG9zdIIJAJ5pBAq2HXAsMAwGA1UdEwQFMAMBAf8wDQYJKoZI +hvcNAQEFBQADgYEAa65TgDKp3SRUDNAILSuQOCEbenWh/DMPL4vTVgo/Dxd4emoO +7i8/4HMTa0XeYIVbAsxO+dqtxqt32IcV7DurmQozdUZ7q0ueJRXon6APnCN0IqPC +sF71w63xXfpmnvTAfQXi7x6TUAyAQ2nScHExAjzc000DF1dO/6+nIINqNQE= +-----END CERTIFICATE----- +-----BEGIN RSA PRIVATE KEY----- +MIICWwIBAAKBgQClDnA7YvbfEwjg/OlW38W297PzYi8D1qbCvd9R3tGOVjCIWwV4 +ycbbhrgJSg54oXR+lBZz2uYKzQR4/mh9nai06x/GhVZdkBlZ2Bq67CVgHoR4VoAw +fd/FUfciIf7bVFGrbD2wf+zh+qHmy1q9R/I91WYJxvw1bSMPyOWqeIkvSQIDAQAB +AoGABA9S22MXx2zkbwRJiQWAC3wURQxJM8L33xpkf9MHPIUKNJBolgwAhC3QIQpd +SMJP5z0lQDxGJEXesksvrsdN+vsgbleRfQsAIcY/rEhr9h8m6auM08f+69oIX32o +aTOWJJRofjbgzE5c/RijqhIaYGdq54a0EE9mAaODwZoa2/ECQQDRGrIRI5L3pdRA +yifDKNjvAFOk6TbdGe+J9zHFw4F7bA2In/b+rno9vrj+EanOevD8LRLzeFshzXrG +WQFzZ69/AkEAyhLSY7WNiQTeJWCwXawVnoSl5AMSRYFA/A2sEUokfORR5BS7gqvL +mmEKmvslnZp5qlMtM4AyrW2OaoGvE6sFNwJACB3xK5kl61cUli9Cu+CqCx0IIi6r +YonPMpvV4sdkD1ZycAtFmz1KoXr102b8IHfFQwS855aUcwt26Jwr4j70IQJAXv9+ +PTXq9hF9xiCwiTkPaNh/jLQM8PQU8uoSjIZIpRZJkWpVxNay/z7D15xeULuAmxxD +UcThDjtFCrkw75Qk/QJAFfcM+5r31R1RrBGM1QPKwDqkFTGsFKnMWuS/pXyLTTOv +I+In9ZJyA/R5zKeJZjM7xtZs0ANU9HpOpgespq6CvA== +-----END RSA PRIVATE KEY----- diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/setup.sh b/Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/setup.sh new file mode 100644 index 0000000..317e8e5 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/setup.sh @@ -0,0 +1,58 @@ +#!/bin/sh -e + +start() { + mkdir _testdb + cd _testdb + mkdir db1 db2 db3 rs1a rs1b rs1c rs2a rs2b rs2c rs3a rs3b rs3c rs4a cfg1 cfg2 cfg3 + cp ../testdb/supervisord.conf supervisord.conf + cp ../testdb/server.pem server.pem + echo keyfile > keyfile + chmod 600 keyfile + COUNT=$(grep '^\[program' supervisord.conf | wc -l | tr -d ' ') + if ! mongod --help | grep -q -- --ssl; then + COUNT=$(($COUNT - 1)) + fi + echo "Running supervisord..." + supervisord || ( echo "Supervisord failed executing ($?)" && exit 1 ) + echo "Supervisord is up, starting $COUNT processes..." + for i in $(seq 10); do + RUNNING=$(supervisorctl status | grep RUNNING | wc -l | tr -d ' ') + echo "$RUNNING processes running..." + if [ x$COUNT = x$RUNNING ]; then + echo "Running setup.js with mongo..." + mongo --nodb ../testdb/init.js + exit 0 + fi + sleep 1 + done + echo "Failed to start all processes. Check out what's up at $PWD now!" + exit 1 +} + +stop() { + if [ -d _testdb ]; then + echo "Shutting down test cluster..." + (cd _testdb && supervisorctl shutdown) + rm -rf _testdb + fi +} + + +if [ ! -f suite_test.go ]; then + echo "This script must be run from within the source directory." + exit 1 +fi + +case "$1" in + + start) + start $2 + ;; + + stop) + stop $2 + ;; + +esac + +# vim:ts=4:sw=4:et diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/supervisord.conf b/Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/supervisord.conf new file mode 100644 index 0000000..1c2b859 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/supervisord.conf @@ -0,0 +1,65 @@ +[supervisord] +logfile = %(here)s/supervisord.log +pidfile = %(here)s/supervisord.pid +directory = %(here)s +#nodaemon = true + +[inet_http_server] +port = 127.0.0.1:9001 + +[supervisorctl] +serverurl = http://127.0.0.1:9001 + +[rpcinterface:supervisor] +supervisor.rpcinterface_factory = supervisor.rpcinterface:make_main_rpcinterface + +[program:db1] +command = mongod --nohttpinterface --noprealloc --nojournal --smallfiles --nssize=1 --oplogSize=1 --shardsvr --dbpath %(here)s/db1 --bind_ip=127.0.0.1 --port 40001 + +[program:db2] +command = mongod --nohttpinterface --noprealloc --nojournal --smallfiles --nssize=1 --oplogSize=1 --shardsvr --dbpath %(here)s/db2 --bind_ip=127.0.0.1 --port 40002 --auth + +[program:db3] +command = mongod -nohttpinterface --noprealloc --nojournal --smallfiles --nssize=1 --oplogSize=1 --dbpath %(here)s/db3 --bind_ip=127.0.0.1 --port 40003 --auth --sslMode preferSSL --sslCAFile %(here)s/server.pem --sslPEMKeyFile %(here)s/server.pem + +[program:rs1a] +command = mongod --nohttpinterface --noprealloc --nojournal --smallfiles --nssize=1 --oplogSize=1 --shardsvr --replSet rs1 --dbpath %(here)s/rs1a --bind_ip=127.0.0.1 --port 40011 +[program:rs1b] +command = mongod --nohttpinterface --noprealloc --nojournal --smallfiles --nssize=1 --oplogSize=1 --shardsvr --replSet rs1 --dbpath %(here)s/rs1b --bind_ip=127.0.0.1 --port 40012 +[program:rs1c] +command = mongod --nohttpinterface --noprealloc --nojournal --smallfiles --nssize=1 --oplogSize=1 --shardsvr --replSet rs1 --dbpath %(here)s/rs1c --bind_ip=127.0.0.1 --port 40013 + +[program:rs2a] +command = mongod --nohttpinterface --noprealloc --nojournal --smallfiles --nssize=1 --oplogSize=1 --shardsvr --replSet rs2 --dbpath %(here)s/rs2a --bind_ip=127.0.0.1 --port 40021 +[program:rs2b] +command = mongod --nohttpinterface --noprealloc --nojournal --smallfiles --nssize=1 --oplogSize=1 --shardsvr --replSet rs2 --dbpath %(here)s/rs2b --bind_ip=127.0.0.1 --port 40022 +[program:rs2c] +command = mongod --nohttpinterface --noprealloc --nojournal --smallfiles --nssize=1 --oplogSize=1 --shardsvr --replSet rs2 --dbpath %(here)s/rs2c --bind_ip=127.0.0.1 --port 40023 + +[program:rs3a] +command = mongod --nohttpinterface --noprealloc --nojournal --smallfiles --nssize=1 --oplogSize=1 --shardsvr --replSet rs3 --dbpath %(here)s/rs3a --bind_ip=127.0.0.1 --port 40031 --auth --keyFile=%(here)s/keyfile +[program:rs3b] +command = mongod --nohttpinterface --noprealloc --nojournal --smallfiles --nssize=1 --oplogSize=1 --shardsvr --replSet rs3 --dbpath %(here)s/rs3b --bind_ip=127.0.0.1 --port 40032 --auth --keyFile=%(here)s/keyfile +[program:rs3c] +command = mongod --nohttpinterface --noprealloc --nojournal --smallfiles --nssize=1 --oplogSize=1 --shardsvr --replSet rs3 --dbpath %(here)s/rs3c --bind_ip=127.0.0.1 --port 40033 --auth --keyFile=%(here)s/keyfile + +[program:rs4a] +command = mongod --nohttpinterface --noprealloc --nojournal --smallfiles --nssize=1 --oplogSize=1 --shardsvr --replSet rs4 --dbpath %(here)s/rs4a --bind_ip=127.0.0.1 --port 40041 + +[program:cfg1] +command = mongod --nohttpinterface --noprealloc --nojournal --smallfiles --nssize=1 --oplogSize=1 --configsvr --dbpath %(here)s/cfg1 --bind_ip=127.0.0.1 --port 40101 + +[program:cfg2] +command = mongod --nohttpinterface --noprealloc --nojournal --smallfiles --nssize=1 --oplogSize=1 --configsvr --dbpath %(here)s/cfg2 --bind_ip=127.0.0.1 --port 40102 + +[program:cfg3] +command = mongod --nohttpinterface --noprealloc --nojournal --smallfiles --nssize=1 --oplogSize=1 --configsvr --dbpath %(here)s/cfg3 --bind_ip=127.0.0.1 --port 40103 --auth --keyFile=%(here)s/keyfile + +[program:s1] +command = mongos --configdb 127.0.0.1:40101 --bind_ip=127.0.0.1 --port 40201 --chunkSize 1 + +[program:s2] +command = mongos --configdb 127.0.0.1:40102 --bind_ip=127.0.0.1 --port 40202 --chunkSize 1 + +[program:s3] +command = mongos --configdb 127.0.0.1:40103 --bind_ip=127.0.0.1 --port 40203 --chunkSize 1 --keyFile=%(here)s/keyfile diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/wait.js b/Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/wait.js new file mode 100644 index 0000000..de0d660 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/testdb/wait.js @@ -0,0 +1,58 @@ +// We know the master of the first set (pri=1), but not of the second. +var settings = {} +var rs1cfg = {_id: "rs1", + members: [{_id: 1, host: "127.0.0.1:40011", priority: 1}, + {_id: 2, host: "127.0.0.1:40012", priority: 0}, + {_id: 3, host: "127.0.0.1:40013", priority: 0}]} +var rs2cfg = {_id: "rs2", + members: [{_id: 1, host: "127.0.0.1:40021", priority: 1}, + {_id: 2, host: "127.0.0.1:40022", priority: 1}, + {_id: 3, host: "127.0.0.1:40023", priority: 0}]} +var rs3cfg = {_id: "rs3", + members: [{_id: 1, host: "127.0.0.1:40031", priority: 1}, + {_id: 2, host: "127.0.0.1:40032", priority: 1}, + {_id: 3, host: "127.0.0.1:40033", priority: 1}], + settings: settings} + +for (var i = 0; i != 60; i++) { + try { + rs1a = new Mongo("127.0.0.1:40011").getDB("admin") + rs2a = new Mongo("127.0.0.1:40021").getDB("admin") + rs3a = new Mongo("127.0.0.1:40031").getDB("admin") + rs3a.auth("root", "rapadura") + db1 = new Mongo("127.0.0.1:40001").getDB("admin") + db2 = new Mongo("127.0.0.1:40002").getDB("admin") + break + } catch(err) { + print("Can't connect yet...") + } + sleep(1000) +} + +function countHealthy(rs) { + var status = rs.runCommand({replSetGetStatus: 1}) + var count = 0 + if (typeof status.members != "undefined") { + for (var i = 0; i != status.members.length; i++) { + var m = status.members[i] + if (m.health == 1 && (m.state == 1 || m.state == 2)) { + count += 1 + } + } + } + return count +} + +var totalRSMembers = rs1cfg.members.length + rs2cfg.members.length + rs3cfg.members.length + +for (var i = 0; i != 60; i++) { + var count = countHealthy(rs1a) + countHealthy(rs2a) + countHealthy(rs3a) + print("Replica sets have", count, "healthy nodes.") + if (count == totalRSMembers) { + quit(0) + } + sleep(1000) +} + +print("Replica sets didn't sync up properly.") +quit(12) diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/chaos.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/chaos.go new file mode 100644 index 0000000..c98adb9 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/chaos.go @@ -0,0 +1,68 @@ +package txn + +import ( + mrand "math/rand" + "time" +) + +var chaosEnabled = false +var chaosSetting Chaos + +// Chaos holds parameters for the failure injection mechanism. +type Chaos struct { + // KillChance is the 0.0 to 1.0 chance that a given checkpoint + // within the algorithm will raise an interruption that will + // stop the procedure. + KillChance float64 + + // SlowdownChance is the 0.0 to 1.0 chance that a given checkpoint + // within the algorithm will be delayed by Slowdown before + // continuing. + SlowdownChance float64 + Slowdown time.Duration + + // If Breakpoint is set, the above settings will only affect the + // named breakpoint. + Breakpoint string +} + +// SetChaos sets the failure injection parameters to c. +func SetChaos(c Chaos) { + chaosSetting = c + chaosEnabled = c.KillChance > 0 || c.SlowdownChance > 0 +} + +func chaos(bpname string) { + if !chaosEnabled { + return + } + switch chaosSetting.Breakpoint { + case "", bpname: + kc := chaosSetting.KillChance + if kc > 0 && mrand.Intn(1000) < int(kc*1000) { + panic(chaosError{}) + } + if bpname == "insert" { + return + } + sc := chaosSetting.SlowdownChance + if sc > 0 && mrand.Intn(1000) < int(sc*1000) { + time.Sleep(chaosSetting.Slowdown) + } + } +} + +type chaosError struct{} + +func (f *flusher) handleChaos(err *error) { + v := recover() + if v == nil { + return + } + if _, ok := v.(chaosError); ok { + f.debugf("Killed by chaos!") + *err = ErrChaos + return + } + panic(v) +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/debug.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/debug.go new file mode 100644 index 0000000..8224bb3 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/debug.go @@ -0,0 +1,109 @@ +package txn + +import ( + "bytes" + "fmt" + "sort" + "sync/atomic" + + "gopkg.in/mgo.v2/bson" +) + +var ( + debugEnabled bool + logger log_Logger +) + +type log_Logger interface { + Output(calldepth int, s string) error +} + +// Specify the *log.Logger where logged messages should be sent to. +func SetLogger(l log_Logger) { + logger = l +} + +// SetDebug enables or disables debugging. +func SetDebug(debug bool) { + debugEnabled = debug +} + +var ErrChaos = fmt.Errorf("interrupted by chaos") + +var debugId uint32 + +func debugPrefix() string { + d := atomic.AddUint32(&debugId, 1) - 1 + s := make([]byte, 0, 10) + for i := uint(0); i < 8; i++ { + s = append(s, "abcdefghijklmnop"[(d>>(4*i))&0xf]) + if d>>(4*(i+1)) == 0 { + break + } + } + s = append(s, ')', ' ') + return string(s) +} + +func logf(format string, args ...interface{}) { + if logger != nil { + logger.Output(2, fmt.Sprintf(format, argsForLog(args)...)) + } +} + +func debugf(format string, args ...interface{}) { + if debugEnabled && logger != nil { + logger.Output(2, fmt.Sprintf(format, argsForLog(args)...)) + } +} + +func argsForLog(args []interface{}) []interface{} { + for i, arg := range args { + switch v := arg.(type) { + case bson.ObjectId: + args[i] = v.Hex() + case []bson.ObjectId: + lst := make([]string, len(v)) + for j, id := range v { + lst[j] = id.Hex() + } + args[i] = lst + case map[docKey][]bson.ObjectId: + buf := &bytes.Buffer{} + var dkeys docKeys + for dkey := range v { + dkeys = append(dkeys, dkey) + } + sort.Sort(dkeys) + for i, dkey := range dkeys { + if i > 0 { + buf.WriteByte(' ') + } + buf.WriteString(fmt.Sprintf("%v: {", dkey)) + for j, id := range v[dkey] { + if j > 0 { + buf.WriteByte(' ') + } + buf.WriteString(id.Hex()) + } + buf.WriteByte('}') + } + args[i] = buf.String() + case map[docKey][]int64: + buf := &bytes.Buffer{} + var dkeys docKeys + for dkey := range v { + dkeys = append(dkeys, dkey) + } + sort.Sort(dkeys) + for i, dkey := range dkeys { + if i > 0 { + buf.WriteByte(' ') + } + buf.WriteString(fmt.Sprintf("%v: %v", dkey, v[dkey])) + } + args[i] = buf.String() + } + } + return args +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/dockey_test.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/dockey_test.go new file mode 100644 index 0000000..e8dee95 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/dockey_test.go @@ -0,0 +1,205 @@ +package txn + +import ( + "sort" + + . "gopkg.in/check.v1" +) + +type DocKeySuite struct{} + +var _ = Suite(&DocKeySuite{}) + +type T struct { + A int + B string +} + +type T2 struct { + A int + B string +} + +type T3 struct { + A int + B string +} + +type T4 struct { + A int + B string +} + +type T5 struct { + F int + Q string +} + +type T6 struct { + A int + B string +} + +type T7 struct { + A bool + B float64 +} + +type T8 struct { + A int + B string +} + +type T9 struct { + A int + B string + C bool +} + +type T10 struct { + C int `bson:"a"` + D string `bson:"b,omitempty"` +} + +type T11 struct { + C int + D string +} + +type T12 struct { + S string +} + +type T13 struct { + p, q, r bool + S string +} + +var docKeysTests = [][]docKeys{ + {{ + {"c", 1}, + {"c", 5}, + {"c", 2}, + }, { + {"c", 1}, + {"c", 2}, + {"c", 5}, + }}, {{ + {"c", "foo"}, + {"c", "bar"}, + {"c", "bob"}, + }, { + {"c", "bar"}, + {"c", "bob"}, + {"c", "foo"}, + }}, {{ + {"c", 0.2}, + {"c", 0.07}, + {"c", 0.9}, + }, { + {"c", 0.07}, + {"c", 0.2}, + {"c", 0.9}, + }}, {{ + {"c", true}, + {"c", false}, + {"c", true}, + }, { + {"c", false}, + {"c", true}, + {"c", true}, + }}, {{ + {"c", T{1, "b"}}, + {"c", T{1, "a"}}, + {"c", T{0, "b"}}, + {"c", T{0, "a"}}, + }, { + {"c", T{0, "a"}}, + {"c", T{0, "b"}}, + {"c", T{1, "a"}}, + {"c", T{1, "b"}}, + }}, {{ + {"c", T{1, "a"}}, + {"c", T{0, "a"}}, + }, { + {"c", T{0, "a"}}, + {"c", T{1, "a"}}, + }}, {{ + {"c", T3{0, "b"}}, + {"c", T2{1, "b"}}, + {"c", T3{1, "a"}}, + {"c", T2{0, "a"}}, + }, { + {"c", T2{0, "a"}}, + {"c", T3{0, "b"}}, + {"c", T3{1, "a"}}, + {"c", T2{1, "b"}}, + }}, {{ + {"c", T5{1, "b"}}, + {"c", T4{1, "b"}}, + {"c", T5{0, "a"}}, + {"c", T4{0, "a"}}, + }, { + {"c", T4{0, "a"}}, + {"c", T5{0, "a"}}, + {"c", T4{1, "b"}}, + {"c", T5{1, "b"}}, + }}, {{ + {"c", T6{1, "b"}}, + {"c", T7{true, 0.2}}, + {"c", T6{0, "a"}}, + {"c", T7{false, 0.04}}, + }, { + {"c", T6{0, "a"}}, + {"c", T6{1, "b"}}, + {"c", T7{false, 0.04}}, + {"c", T7{true, 0.2}}, + }}, {{ + {"c", T9{1, "b", true}}, + {"c", T8{1, "b"}}, + {"c", T9{0, "a", false}}, + {"c", T8{0, "a"}}, + }, { + {"c", T9{0, "a", false}}, + {"c", T8{0, "a"}}, + {"c", T9{1, "b", true}}, + {"c", T8{1, "b"}}, + }}, {{ + {"b", 2}, + {"a", 5}, + {"c", 2}, + {"b", 1}, + }, { + {"a", 5}, + {"b", 1}, + {"b", 2}, + {"c", 2}, + }}, {{ + {"c", T11{1, "a"}}, + {"c", T11{1, "a"}}, + {"c", T10{1, "a"}}, + }, { + {"c", T10{1, "a"}}, + {"c", T11{1, "a"}}, + {"c", T11{1, "a"}}, + }}, {{ + {"c", T12{"a"}}, + {"c", T13{false, true, false, "a"}}, + {"c", T12{"b"}}, + {"c", T13{false, true, false, "b"}}, + }, { + {"c", T12{"a"}}, + {"c", T13{false, true, false, "a"}}, + {"c", T12{"b"}}, + {"c", T13{false, true, false, "b"}}, + }}, +} + +func (s *DocKeySuite) TestSort(c *C) { + for _, test := range docKeysTests { + keys := test[0] + expected := test[1] + sort.Sort(keys) + c.Check(keys, DeepEquals, expected) + } +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/flusher.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/flusher.go new file mode 100644 index 0000000..25b2f03 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/flusher.go @@ -0,0 +1,968 @@ +package txn + +import ( + "fmt" + + "gopkg.in/mgo.v2" + "gopkg.in/mgo.v2/bson" +) + +func flush(r *Runner, t *transaction) error { + f := &flusher{ + Runner: r, + goal: t, + goalKeys: make(map[docKey]bool), + queue: make(map[docKey][]token), + debugId: debugPrefix(), + } + for _, dkey := range f.goal.docKeys() { + f.goalKeys[dkey] = true + } + return f.run() +} + +type flusher struct { + *Runner + goal *transaction + goalKeys map[docKey]bool + queue map[docKey][]token + debugId string +} + +func (f *flusher) run() (err error) { + if chaosEnabled { + defer f.handleChaos(&err) + } + + f.debugf("Processing %s", f.goal) + seen := make(map[bson.ObjectId]*transaction) + if err := f.recurse(f.goal, seen); err != nil { + return err + } + if f.goal.done() { + return nil + } + + // Sparse workloads will generally be managed entirely by recurse. + // Getting here means one or more transactions have dependencies + // and perhaps cycles. + + // Build successors data for Tarjan's sort. Must consider + // that entries in txn-queue are not necessarily valid. + successors := make(map[bson.ObjectId][]bson.ObjectId) + ready := true + for _, dqueue := range f.queue { + NextPair: + for i := 0; i < len(dqueue); i++ { + pred := dqueue[i] + predid := pred.id() + predt := seen[predid] + if predt == nil || predt.Nonce != pred.nonce() { + continue + } + predsuccids, ok := successors[predid] + if !ok { + successors[predid] = nil + } + + for j := i + 1; j < len(dqueue); j++ { + succ := dqueue[j] + succid := succ.id() + succt := seen[succid] + if succt == nil || succt.Nonce != succ.nonce() { + continue + } + if _, ok := successors[succid]; !ok { + successors[succid] = nil + } + + // Found a valid pred/succ pair. + i = j - 1 + for _, predsuccid := range predsuccids { + if predsuccid == succid { + continue NextPair + } + } + successors[predid] = append(predsuccids, succid) + if succid == f.goal.Id { + // There are still pre-requisites to handle. + ready = false + } + continue NextPair + } + } + } + f.debugf("Queues: %v", f.queue) + f.debugf("Successors: %v", successors) + if ready { + f.debugf("Goal %s has no real pre-requisites", f.goal) + return f.advance(f.goal, nil, true) + } + + // Robert Tarjan's algorithm for detecting strongly-connected + // components is used for topological sorting and detecting + // cycles at once. The order in which transactions are applied + // in commonly affected documents must be a global agreement. + sorted := tarjanSort(successors) + if debugEnabled { + f.debugf("Tarjan output: %v", sorted) + } + pull := make(map[bson.ObjectId]*transaction) + for i := len(sorted) - 1; i >= 0; i-- { + scc := sorted[i] + f.debugf("Flushing %v", scc) + if len(scc) == 1 { + pull[scc[0]] = seen[scc[0]] + } + for _, id := range scc { + if err := f.advance(seen[id], pull, true); err != nil { + return err + } + } + if len(scc) > 1 { + for _, id := range scc { + pull[id] = seen[id] + } + } + } + return nil +} + +func (f *flusher) recurse(t *transaction, seen map[bson.ObjectId]*transaction) error { + seen[t.Id] = t + err := f.advance(t, nil, false) + if err != errPreReqs { + return err + } + for _, dkey := range t.docKeys() { + for _, dtt := range f.queue[dkey] { + id := dtt.id() + if seen[id] != nil { + continue + } + qt, err := f.load(id) + if err != nil { + return err + } + err = f.recurse(qt, seen) + if err != nil { + return err + } + } + } + return nil +} + +func (f *flusher) advance(t *transaction, pull map[bson.ObjectId]*transaction, force bool) error { + for { + switch t.State { + case tpreparing, tprepared: + revnos, err := f.prepare(t, force) + if err != nil { + return err + } + if t.State != tprepared { + continue + } + if err = f.assert(t, revnos, pull); err != nil { + return err + } + if t.State != tprepared { + continue + } + if err = f.checkpoint(t, revnos); err != nil { + return err + } + case tapplying: + return f.apply(t, pull) + case taborting: + return f.abortOrReload(t, nil, pull) + case tapplied, taborted: + return nil + default: + panic(fmt.Errorf("transaction in unknown state: %q", t.State)) + } + } + panic("unreachable") +} + +type stash string + +const ( + stashStable stash = "" + stashInsert stash = "insert" + stashRemove stash = "remove" +) + +type txnInfo struct { + Queue []token `bson:"txn-queue"` + Revno int64 `bson:"txn-revno,omitempty"` + Insert bson.ObjectId `bson:"txn-insert,omitempty"` + Remove bson.ObjectId `bson:"txn-remove,omitempty"` +} + +type stashState string + +const ( + stashNew stashState = "" + stashInserting stashState = "inserting" +) + +var txnFields = bson.D{{"txn-queue", 1}, {"txn-revno", 1}, {"txn-remove", 1}, {"txn-insert", 1}} + +var errPreReqs = fmt.Errorf("transaction has pre-requisites and force is false") + +// prepare injects t's id onto txn-queue for all affected documents +// and collects the current txn-queue and txn-revno values during +// the process. If the prepared txn-queue indicates that there are +// pre-requisite transactions to be applied and the force parameter +// is false, errPreReqs will be returned. Otherwise, the current +// tip revision numbers for all the documents are returned. +func (f *flusher) prepare(t *transaction, force bool) (revnos []int64, err error) { + if t.State != tpreparing { + return f.rescan(t, force) + } + f.debugf("Preparing %s", t) + + // dkeys being sorted means stable iteration across all runners. This + // isn't strictly required, but reduces the chances of cycles. + dkeys := t.docKeys() + + revno := make(map[docKey]int64) + info := txnInfo{} + tt := tokenFor(t) +NextDoc: + for _, dkey := range dkeys { + change := mgo.Change{ + Update: bson.D{{"$addToSet", bson.D{{"txn-queue", tt}}}}, + ReturnNew: true, + } + c := f.tc.Database.C(dkey.C) + cquery := c.FindId(dkey.Id).Select(txnFields) + + RetryDoc: + change.Upsert = false + chaos("") + if _, err := cquery.Apply(change, &info); err == nil { + if info.Remove == "" { + // Fast path, unless workload is insert/remove heavy. + revno[dkey] = info.Revno + f.queue[dkey] = info.Queue + f.debugf("[A] Prepared document %v with revno %d and queue: %v", dkey, info.Revno, info.Queue) + continue NextDoc + } else { + // Handle remove in progress before preparing it. + if err := f.loadAndApply(info.Remove); err != nil { + return nil, err + } + goto RetryDoc + } + } else if err != mgo.ErrNotFound { + return nil, err + } + + // Document missing. Use stash collection. + change.Upsert = true + chaos("") + _, err := f.sc.FindId(dkey).Apply(change, &info) + if err != nil { + return nil, err + } + if info.Insert != "" { + // Handle insert in progress before preparing it. + if err := f.loadAndApply(info.Insert); err != nil { + return nil, err + } + goto RetryDoc + } + + // Must confirm stash is still in use and is the same one + // prepared, since applying a remove overwrites the stash. + docFound := false + stashFound := false + if err = c.FindId(dkey.Id).Select(txnFields).One(&info); err == nil { + docFound = true + } else if err != mgo.ErrNotFound { + return nil, err + } else if err = f.sc.FindId(dkey).One(&info); err == nil { + stashFound = true + if info.Revno == 0 { + // Missing revno in the stash only happens when it + // has been upserted, in which case it defaults to -1. + // Txn-inserted documents get revno -1 while in the stash + // for the first time, and -revno-1 == 2 when they go live. + info.Revno = -1 + } + } else if err != mgo.ErrNotFound { + return nil, err + } + + if docFound && info.Remove == "" || stashFound && info.Insert == "" { + for _, dtt := range info.Queue { + if dtt != tt { + continue + } + // Found tt properly prepared. + if stashFound { + f.debugf("[B] Prepared document %v on stash with revno %d and queue: %v", dkey, info.Revno, info.Queue) + } else { + f.debugf("[B] Prepared document %v with revno %d and queue: %v", dkey, info.Revno, info.Queue) + } + revno[dkey] = info.Revno + f.queue[dkey] = info.Queue + continue NextDoc + } + } + + // The stash wasn't valid and tt got overwriten. Try again. + f.unstashToken(tt, dkey) + goto RetryDoc + } + + // Save the prepared nonce onto t. + nonce := tt.nonce() + qdoc := bson.D{{"_id", t.Id}, {"s", tpreparing}} + udoc := bson.D{{"$set", bson.D{{"s", tprepared}, {"n", nonce}}}} + chaos("set-prepared") + err = f.tc.Update(qdoc, udoc) + if err == nil { + t.State = tprepared + t.Nonce = nonce + } else if err == mgo.ErrNotFound { + f.debugf("Can't save nonce of %s: LOST RACE", tt) + if err := f.reload(t); err != nil { + return nil, err + } else if t.State == tpreparing { + panic("can't save nonce yet transaction is still preparing") + } else if t.State != tprepared { + return t.Revnos, nil + } + tt = t.token() + } else if err != nil { + return nil, err + } + + prereqs, found := f.hasPreReqs(tt, dkeys) + if !found { + // Must only happen when reloading above. + return f.rescan(t, force) + } else if prereqs && !force { + f.debugf("Prepared queue with %s [has prereqs & not forced].", tt) + return nil, errPreReqs + } + revnos = assembledRevnos(t.Ops, revno) + if !prereqs { + f.debugf("Prepared queue with %s [no prereqs]. Revnos: %v", tt, revnos) + } else { + f.debugf("Prepared queue with %s [forced] Revnos: %v", tt, revnos) + } + return revnos, nil +} + +func (f *flusher) unstashToken(tt token, dkey docKey) error { + qdoc := bson.D{{"_id", dkey}, {"txn-queue", tt}} + udoc := bson.D{{"$pull", bson.D{{"txn-queue", tt}}}} + chaos("") + if err := f.sc.Update(qdoc, udoc); err == nil { + chaos("") + err = f.sc.Remove(bson.D{{"_id", dkey}, {"txn-queue", bson.D{}}}) + } else if err != mgo.ErrNotFound { + return err + } + return nil +} + +func (f *flusher) rescan(t *transaction, force bool) (revnos []int64, err error) { + f.debugf("Rescanning %s", t) + if t.State != tprepared { + panic(fmt.Errorf("rescanning transaction in invalid state: %q", t.State)) + } + + // dkeys being sorted means stable iteration across all + // runners. This isn't strictly required, but reduces the chances + // of cycles. + dkeys := t.docKeys() + + tt := t.token() + if !force { + prereqs, found := f.hasPreReqs(tt, dkeys) + if found && prereqs { + // Its state is already known. + return nil, errPreReqs + } + } + + revno := make(map[docKey]int64) + info := txnInfo{} + for _, dkey := range dkeys { + retry := 0 + + RetryDoc: + c := f.tc.Database.C(dkey.C) + if err := c.FindId(dkey.Id).Select(txnFields).One(&info); err == mgo.ErrNotFound { + // Document is missing. Look in stash. + if err := f.sc.FindId(dkey).One(&info); err == mgo.ErrNotFound { + // Stash also doesn't exist. Maybe someone applied it. + if err := f.reload(t); err != nil { + return nil, err + } else if t.State != tprepared { + return t.Revnos, err + } + // Not applying either. + retry++ + if retry < 3 { + // Retry since there might be an insert/remove race. + goto RetryDoc + } + // Neither the doc nor the stash seem to exist. + return nil, fmt.Errorf("cannot find document %v for applying transaction %s", dkey, t) + } else if err != nil { + return nil, err + } + // Stash found. + if info.Insert != "" { + // Handle insert in progress before assuming ordering is good. + if err := f.loadAndApply(info.Insert); err != nil { + return nil, err + } + goto RetryDoc + } + if info.Revno == 0 { + // Missing revno in the stash means -1. + info.Revno = -1 + } + } else if err != nil { + return nil, err + } else if info.Remove != "" { + // Handle remove in progress before assuming ordering is good. + if err := f.loadAndApply(info.Remove); err != nil { + return nil, err + } + goto RetryDoc + } + revno[dkey] = info.Revno + + found := false + for _, id := range info.Queue { + if id == tt { + found = true + break + } + } + f.queue[dkey] = info.Queue + if !found { + // Previously set txn-queue was popped by someone. + // Transaction is being/has been applied elsewhere. + f.debugf("Rescanned document %v misses %s in queue: %v", dkey, tt, info.Queue) + err := f.reload(t) + if t.State == tpreparing || t.State == tprepared { + panic("rescanned document misses transaction in queue") + } + return t.Revnos, err + } + } + + prereqs, found := f.hasPreReqs(tt, dkeys) + if !found { + panic("rescanning loop guarantees that this can't happen") + } else if prereqs && !force { + f.debugf("Rescanned queue with %s: has prereqs, not forced", tt) + return nil, errPreReqs + } + revnos = assembledRevnos(t.Ops, revno) + if !prereqs { + f.debugf("Rescanned queue with %s: no prereqs, revnos: %v", tt, revnos) + } else { + f.debugf("Rescanned queue with %s: has prereqs, forced, revnos: %v", tt, revnos) + } + return revnos, nil +} + +func assembledRevnos(ops []Op, revno map[docKey]int64) []int64 { + revnos := make([]int64, len(ops)) + for i, op := range ops { + dkey := op.docKey() + revnos[i] = revno[dkey] + drevno := revno[dkey] + switch { + case op.Insert != nil && drevno < 0: + revno[dkey] = -drevno + 1 + case op.Update != nil && drevno >= 0: + revno[dkey] = drevno + 1 + case op.Remove && drevno >= 0: + revno[dkey] = -drevno - 1 + } + } + return revnos +} + +func (f *flusher) hasPreReqs(tt token, dkeys docKeys) (prereqs, found bool) { + found = true +NextDoc: + for _, dkey := range dkeys { + for _, dtt := range f.queue[dkey] { + if dtt == tt { + continue NextDoc + } else if dtt.id() != tt.id() { + prereqs = true + } + } + found = false + } + return +} + +func (f *flusher) reload(t *transaction) error { + var newt transaction + query := f.tc.FindId(t.Id) + query.Select(bson.D{{"s", 1}, {"n", 1}, {"r", 1}}) + if err := query.One(&newt); err != nil { + return fmt.Errorf("failed to reload transaction: %v", err) + } + t.State = newt.State + t.Nonce = newt.Nonce + t.Revnos = newt.Revnos + f.debugf("Reloaded %s: %q", t, t.State) + return nil +} + +func (f *flusher) loadAndApply(id bson.ObjectId) error { + t, err := f.load(id) + if err != nil { + return err + } + return f.advance(t, nil, true) +} + +// assert verifies that all assertions in t match the content that t +// will be applied upon. If an assertion fails, the transaction state +// is changed to aborted. +func (f *flusher) assert(t *transaction, revnos []int64, pull map[bson.ObjectId]*transaction) error { + f.debugf("Asserting %s with revnos %v", t, revnos) + if t.State != tprepared { + panic(fmt.Errorf("asserting transaction in invalid state: %q", t.State)) + } + qdoc := make(bson.D, 3) + revno := make(map[docKey]int64) + for i, op := range t.Ops { + dkey := op.docKey() + if _, ok := revno[dkey]; !ok { + revno[dkey] = revnos[i] + } + if op.Assert == nil { + continue + } + if op.Assert == DocMissing { + if revnos[i] >= 0 { + return f.abortOrReload(t, revnos, pull) + } + continue + } + if op.Insert != nil { + return fmt.Errorf("Insert can only Assert txn.DocMissing", op.Assert) + } + // if revnos[i] < 0 { abort }? + + qdoc = append(qdoc[:0], bson.DocElem{"_id", op.Id}) + if op.Assert != DocMissing { + var revnoq interface{} + if n := revno[dkey]; n == 0 { + revnoq = bson.D{{"$exists", false}} + } else { + revnoq = n + } + // XXX Add tt to the query here, once we're sure it's all working. + // Not having it increases the chances of breaking on bad logic. + qdoc = append(qdoc, bson.DocElem{"txn-revno", revnoq}) + if op.Assert != DocExists { + qdoc = append(qdoc, bson.DocElem{"$or", []interface{}{op.Assert}}) + } + } + + c := f.tc.Database.C(op.C) + if err := c.Find(qdoc).Select(bson.D{{"_id", 1}}).One(nil); err == mgo.ErrNotFound { + // Assertion failed or someone else started applying. + return f.abortOrReload(t, revnos, pull) + } else if err != nil { + return err + } + } + f.debugf("Asserting %s succeeded", t) + return nil +} + +func (f *flusher) abortOrReload(t *transaction, revnos []int64, pull map[bson.ObjectId]*transaction) (err error) { + f.debugf("Aborting or reloading %s (was %q)", t, t.State) + if t.State == tprepared { + qdoc := bson.D{{"_id", t.Id}, {"s", tprepared}} + udoc := bson.D{{"$set", bson.D{{"s", taborting}}}} + chaos("set-aborting") + if err = f.tc.Update(qdoc, udoc); err == nil { + t.State = taborting + } else if err == mgo.ErrNotFound { + if err = f.reload(t); err != nil || t.State != taborting { + f.debugf("Won't abort %s. Reloaded state: %q", t, t.State) + return err + } + } else { + return err + } + } else if t.State != taborting { + panic(fmt.Errorf("aborting transaction in invalid state: %q", t.State)) + } + + if len(revnos) > 0 { + if pull == nil { + pull = map[bson.ObjectId]*transaction{t.Id: t} + } + seen := make(map[docKey]bool) + for i, op := range t.Ops { + dkey := op.docKey() + if seen[op.docKey()] { + continue + } + seen[dkey] = true + + pullAll := tokensToPull(f.queue[dkey], pull, "") + if len(pullAll) == 0 { + continue + } + udoc := bson.D{{"$pullAll", bson.D{{"txn-queue", pullAll}}}} + chaos("") + if revnos[i] < 0 { + err = f.sc.UpdateId(dkey, udoc) + } else { + c := f.tc.Database.C(dkey.C) + err = c.UpdateId(dkey.Id, udoc) + } + if err != nil && err != mgo.ErrNotFound { + return err + } + } + } + udoc := bson.D{{"$set", bson.D{{"s", taborted}}}} + chaos("set-aborted") + if err := f.tc.UpdateId(t.Id, udoc); err != nil && err != mgo.ErrNotFound { + return err + } + t.State = taborted + f.debugf("Aborted %s", t) + return nil +} + +func (f *flusher) checkpoint(t *transaction, revnos []int64) error { + var debugRevnos map[docKey][]int64 + if debugEnabled { + debugRevnos = make(map[docKey][]int64) + for i, op := range t.Ops { + dkey := op.docKey() + debugRevnos[dkey] = append(debugRevnos[dkey], revnos[i]) + } + f.debugf("Ready to apply %s. Saving revnos %v", t, debugRevnos) + } + + // Save in t the txn-revno values the transaction must run on. + qdoc := bson.D{{"_id", t.Id}, {"s", tprepared}} + udoc := bson.D{{"$set", bson.D{{"s", tapplying}, {"r", revnos}}}} + chaos("set-applying") + err := f.tc.Update(qdoc, udoc) + if err == nil { + t.State = tapplying + t.Revnos = revnos + f.debugf("Ready to apply %s. Saving revnos %v: DONE", t, debugRevnos) + } else if err == mgo.ErrNotFound { + f.debugf("Ready to apply %s. Saving revnos %v: LOST RACE", t, debugRevnos) + return f.reload(t) + } + return nil +} + +func (f *flusher) apply(t *transaction, pull map[bson.ObjectId]*transaction) error { + f.debugf("Applying transaction %s", t) + if t.State != tapplying { + panic(fmt.Errorf("applying transaction in invalid state: %q", t.State)) + } + if pull == nil { + pull = map[bson.ObjectId]*transaction{t.Id: t} + } + + logRevnos := append([]int64(nil), t.Revnos...) + logDoc := bson.D{{"_id", t.Id}} + + tt := tokenFor(t) + for i := range t.Ops { + op := &t.Ops[i] + dkey := op.docKey() + dqueue := f.queue[dkey] + revno := t.Revnos[i] + + var opName string + if debugEnabled { + opName = op.name() + f.debugf("Applying %s op %d (%s) on %v with txn-revno %d", t, i, opName, dkey, revno) + } + + c := f.tc.Database.C(op.C) + + qdoc := bson.D{{"_id", dkey.Id}, {"txn-revno", revno}, {"txn-queue", tt}} + if op.Insert != nil { + qdoc[0].Value = dkey + if revno == -1 { + qdoc[1].Value = bson.D{{"$exists", false}} + } + } else if revno == 0 { + // There's no document with revno 0. The only way to see it is + // when an existent document participates in a transaction the + // first time. Txn-inserted documents get revno -1 while in the + // stash for the first time, and -revno-1 == 2 when they go live. + qdoc[1].Value = bson.D{{"$exists", false}} + } + + pullAll := tokensToPull(dqueue, pull, tt) + + var d bson.D + var outcome string + var err error + switch { + case op.Update != nil: + if revno < 0 { + err = mgo.ErrNotFound + f.debugf("Won't try to apply update op; negative revision means the document is missing or stashed") + } else { + newRevno := revno + 1 + logRevnos[i] = newRevno + if d, err = objToDoc(op.Update); err != nil { + return err + } + if d, err = addToDoc(d, "$pullAll", bson.D{{"txn-queue", pullAll}}); err != nil { + return err + } + if d, err = addToDoc(d, "$set", bson.D{{"txn-revno", newRevno}}); err != nil { + return err + } + chaos("") + err = c.Update(qdoc, d) + } + case op.Remove: + if revno < 0 { + err = mgo.ErrNotFound + } else { + newRevno := -revno - 1 + logRevnos[i] = newRevno + nonce := newNonce() + stash := txnInfo{} + change := mgo.Change{ + Update: bson.D{{"$push", bson.D{{"n", nonce}}}}, + Upsert: true, + ReturnNew: true, + } + if _, err = f.sc.FindId(dkey).Apply(change, &stash); err != nil { + return err + } + change = mgo.Change{ + Update: bson.D{{"$set", bson.D{{"txn-remove", t.Id}}}}, + ReturnNew: true, + } + var info txnInfo + if _, err = c.Find(qdoc).Apply(change, &info); err == nil { + // The document still exists so the stash previously + // observed was either out of date or necessarily + // contained the token being applied. + f.debugf("Marked document %v to be removed on revno %d with queue: %v", dkey, info.Revno, info.Queue) + updated := false + if !hasToken(stash.Queue, tt) { + var set, unset bson.D + if revno == 0 { + // Missing revno in stash means -1. + set = bson.D{{"txn-queue", info.Queue}} + unset = bson.D{{"n", 1}, {"txn-revno", 1}} + } else { + set = bson.D{{"txn-queue", info.Queue}, {"txn-revno", newRevno}} + unset = bson.D{{"n", 1}} + } + qdoc := bson.D{{"_id", dkey}, {"n", nonce}} + udoc := bson.D{{"$set", set}, {"$unset", unset}} + if err = f.sc.Update(qdoc, udoc); err == nil { + updated = true + } else if err != mgo.ErrNotFound { + return err + } + } + if updated { + f.debugf("Updated stash for document %v with revno %d and queue: %v", dkey, newRevno, info.Queue) + } else { + f.debugf("Stash for document %v was up-to-date", dkey) + } + err = c.Remove(qdoc) + } + } + case op.Insert != nil: + if revno >= 0 { + err = mgo.ErrNotFound + } else { + newRevno := -revno + 1 + logRevnos[i] = newRevno + if d, err = objToDoc(op.Insert); err != nil { + return err + } + change := mgo.Change{ + Update: bson.D{{"$set", bson.D{{"txn-insert", t.Id}}}}, + ReturnNew: true, + } + chaos("") + var info txnInfo + if _, err = f.sc.Find(qdoc).Apply(change, &info); err == nil { + f.debugf("Stash for document %v has revno %d and queue: %v", dkey, info.Revno, info.Queue) + d = setInDoc(d, bson.D{{"_id", op.Id}, {"txn-revno", newRevno}, {"txn-queue", info.Queue}}) + // Unlikely yet unfortunate race in here if this gets seriously + // delayed. If someone inserts+removes meanwhile, this will + // reinsert, and there's no way to avoid that while keeping the + // collection clean or compromising sharding. applyOps can solve + // the former, but it can't shard (SERVER-1439). + chaos("insert") + err = c.Insert(d) + if err == nil || mgo.IsDup(err) { + if err == nil { + f.debugf("New document %v inserted with revno %d and queue: %v", dkey, info.Revno, info.Queue) + } else { + f.debugf("Document %v already existed", dkey) + } + chaos("") + if err = f.sc.Remove(qdoc); err == nil { + f.debugf("Stash for document %v removed", dkey) + } + } + } + } + case op.Assert != nil: + // Pure assertion. No changes to apply. + } + if err == nil { + outcome = "DONE" + } else if err == mgo.ErrNotFound || mgo.IsDup(err) { + outcome = "MISS" + err = nil + } else { + outcome = err.Error() + } + if debugEnabled { + f.debugf("Applying %s op %d (%s) on %v with txn-revno %d: %s", t, i, opName, dkey, revno, outcome) + } + if err != nil { + return err + } + + if f.lc != nil && op.isChange() { + // Add change to the log document. + var dr bson.D + for li := range logDoc { + elem := &logDoc[li] + if elem.Name == op.C { + dr = elem.Value.(bson.D) + break + } + } + if dr == nil { + logDoc = append(logDoc, bson.DocElem{op.C, bson.D{{"d", []interface{}{}}, {"r", []int64{}}}}) + dr = logDoc[len(logDoc)-1].Value.(bson.D) + } + dr[0].Value = append(dr[0].Value.([]interface{}), op.Id) + dr[1].Value = append(dr[1].Value.([]int64), logRevnos[i]) + } + } + t.State = tapplied + + if f.lc != nil { + // Insert log document into the changelog collection. + f.debugf("Inserting %s into change log", t) + err := f.lc.Insert(logDoc) + if err != nil && !mgo.IsDup(err) { + return err + } + } + + // It's been applied, so errors are ignored here. It's fine for someone + // else to win the race and mark it as applied, and it's also fine for + // it to remain pending until a later point when someone will perceive + // it has been applied and mark it at such. + f.debugf("Marking %s as applied", t) + chaos("set-applied") + f.tc.Update(bson.D{{"_id", t.Id}, {"s", tapplying}}, bson.D{{"$set", bson.D{{"s", tapplied}}}}) + return nil +} + +func tokensToPull(dqueue []token, pull map[bson.ObjectId]*transaction, dontPull token) []token { + var result []token + for j := len(dqueue) - 1; j >= 0; j-- { + dtt := dqueue[j] + if dtt == dontPull { + continue + } + if _, ok := pull[dtt.id()]; ok { + // It was handled before and this is a leftover invalid + // nonce in the queue. Cherry-pick it out. + result = append(result, dtt) + } + } + return result +} + +func objToDoc(obj interface{}) (d bson.D, err error) { + data, err := bson.Marshal(obj) + if err != nil { + return nil, err + } + err = bson.Unmarshal(data, &d) + if err != nil { + return nil, err + } + return d, err +} + +func addToDoc(doc bson.D, key string, add bson.D) (bson.D, error) { + for i := range doc { + elem := &doc[i] + if elem.Name != key { + continue + } + if old, ok := elem.Value.(bson.D); ok { + elem.Value = append(old, add...) + return doc, nil + } else { + return nil, fmt.Errorf("invalid %q value in change document: %#v", key, elem.Value) + } + } + return append(doc, bson.DocElem{key, add}), nil +} + +func setInDoc(doc bson.D, set bson.D) bson.D { + dlen := len(doc) +NextS: + for s := range set { + sname := set[s].Name + for d := 0; d < dlen; d++ { + if doc[d].Name == sname { + doc[d].Value = set[s].Value + continue NextS + } + } + doc = append(doc, set[s]) + } + return doc +} + +func hasToken(tokens []token, tt token) bool { + for _, ttt := range tokens { + if ttt == tt { + return true + } + } + return false +} + +func (f *flusher) debugf(format string, args ...interface{}) { + if !debugEnabled { + return + } + debugf(f.debugId+format, args...) +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/mgo_test.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/mgo_test.go new file mode 100644 index 0000000..5abc473 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/mgo_test.go @@ -0,0 +1,101 @@ +package txn_test + +import ( + "bytes" + "gopkg.in/mgo.v2" + . "gopkg.in/check.v1" + "os/exec" + "time" +) + +// ---------------------------------------------------------------------------- +// The mgo test suite + +type MgoSuite struct { + output bytes.Buffer + server *exec.Cmd + session *mgo.Session +} + +var mgoaddr = "127.0.0.1:50017" + +func (s *MgoSuite) SetUpSuite(c *C) { + //mgo.SetDebug(true) + mgo.SetStats(true) + dbdir := c.MkDir() + args := []string{ + "--dbpath", dbdir, + "--bind_ip", "127.0.0.1", + "--port", "50017", + "--nssize", "1", + "--noprealloc", + "--smallfiles", + "--nojournal", + "-vvvvv", + } + s.server = exec.Command("mongod", args...) + s.server.Stdout = &s.output + s.server.Stderr = &s.output + err := s.server.Start() + if err != nil { + panic(err) + } +} + +func (s *MgoSuite) TearDownSuite(c *C) { + s.server.Process.Kill() + s.server.Process.Wait() +} + +func (s *MgoSuite) SetUpTest(c *C) { + err := DropAll(mgoaddr) + if err != nil { + panic(err) + } + mgo.SetLogger(c) + mgo.ResetStats() + + s.session, err = mgo.Dial(mgoaddr) + c.Assert(err, IsNil) +} + +func (s *MgoSuite) TearDownTest(c *C) { + if s.session != nil { + s.session.Close() + } + for i := 0; ; i++ { + stats := mgo.GetStats() + if stats.SocketsInUse == 0 && stats.SocketsAlive == 0 { + break + } + if i == 20 { + c.Fatal("Test left sockets in a dirty state") + } + c.Logf("Waiting for sockets to die: %d in use, %d alive", stats.SocketsInUse, stats.SocketsAlive) + time.Sleep(500 * time.Millisecond) + } +} + +func DropAll(mongourl string) (err error) { + session, err := mgo.Dial(mongourl) + if err != nil { + return err + } + defer session.Close() + + names, err := session.DatabaseNames() + if err != nil { + return err + } + for _, name := range names { + switch name { + case "admin", "local", "config": + default: + err = session.DB(name).DropDatabase() + if err != nil { + return err + } + } + } + return nil +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/sim_test.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/sim_test.go new file mode 100644 index 0000000..35f7048 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/sim_test.go @@ -0,0 +1,389 @@ +package txn_test + +import ( + "flag" + "gopkg.in/mgo.v2" + "gopkg.in/mgo.v2/bson" + "gopkg.in/mgo.v2/txn" + . "gopkg.in/check.v1" + "math/rand" + "time" +) + +var ( + duration = flag.Duration("duration", 200*time.Millisecond, "duration for each simulation") + seed = flag.Int64("seed", 0, "seed for rand") +) + +type params struct { + killChance float64 + slowdownChance float64 + slowdown time.Duration + + unsafe bool + workers int + accounts int + changeHalf bool + reinsertCopy bool + reinsertZeroed bool + changelog bool + + changes int +} + +func (s *S) TestSim1Worker(c *C) { + simulate(c, params{ + workers: 1, + accounts: 4, + killChance: 0.01, + slowdownChance: 0.3, + slowdown: 100 * time.Millisecond, + }) +} + +func (s *S) TestSim4WorkersDense(c *C) { + simulate(c, params{ + workers: 4, + accounts: 2, + killChance: 0.01, + slowdownChance: 0.3, + slowdown: 100 * time.Millisecond, + }) +} + +func (s *S) TestSim4WorkersSparse(c *C) { + simulate(c, params{ + workers: 4, + accounts: 10, + killChance: 0.01, + slowdownChance: 0.3, + slowdown: 100 * time.Millisecond, + }) +} + +func (s *S) TestSimHalf1Worker(c *C) { + simulate(c, params{ + workers: 1, + accounts: 4, + changeHalf: true, + killChance: 0.01, + slowdownChance: 0.3, + slowdown: 100 * time.Millisecond, + }) +} + +func (s *S) TestSimHalf4WorkersDense(c *C) { + simulate(c, params{ + workers: 4, + accounts: 2, + changeHalf: true, + killChance: 0.01, + slowdownChance: 0.3, + slowdown: 100 * time.Millisecond, + }) +} + +func (s *S) TestSimHalf4WorkersSparse(c *C) { + simulate(c, params{ + workers: 4, + accounts: 10, + changeHalf: true, + killChance: 0.01, + slowdownChance: 0.3, + slowdown: 100 * time.Millisecond, + }) +} + +func (s *S) TestSimReinsertCopy1Worker(c *C) { + simulate(c, params{ + workers: 1, + accounts: 10, + reinsertCopy: true, + killChance: 0.01, + slowdownChance: 0.3, + slowdown: 100 * time.Millisecond, + }) +} + +func (s *S) TestSimReinsertCopy4Workers(c *C) { + simulate(c, params{ + workers: 4, + accounts: 10, + reinsertCopy: true, + killChance: 0.01, + slowdownChance: 0.3, + slowdown: 100 * time.Millisecond, + }) +} + +func (s *S) TestSimReinsertZeroed1Worker(c *C) { + simulate(c, params{ + workers: 1, + accounts: 10, + reinsertZeroed: true, + killChance: 0.01, + slowdownChance: 0.3, + slowdown: 100 * time.Millisecond, + }) +} + +func (s *S) TestSimReinsertZeroed4Workers(c *C) { + simulate(c, params{ + workers: 4, + accounts: 10, + reinsertZeroed: true, + killChance: 0.01, + slowdownChance: 0.3, + slowdown: 100 * time.Millisecond, + }) +} + +func (s *S) TestSimChangeLog(c *C) { + simulate(c, params{ + workers: 4, + accounts: 10, + killChance: 0.01, + slowdownChance: 0.3, + slowdown: 100 * time.Millisecond, + changelog: true, + }) +} + + +type balanceChange struct { + id bson.ObjectId + origin int + target int + amount int +} + +func simulate(c *C, params params) { + seed := *seed + if seed == 0 { + seed = time.Now().UnixNano() + } + rand.Seed(seed) + c.Logf("Seed: %v", seed) + + txn.SetChaos(txn.Chaos{ + KillChance: params.killChance, + SlowdownChance: params.slowdownChance, + Slowdown: params.slowdown, + }) + defer txn.SetChaos(txn.Chaos{}) + + session, err := mgo.Dial(mgoaddr) + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("test") + tc := db.C("tc") + + runner := txn.NewRunner(tc) + + tclog := db.C("tc.log") + if params.changelog { + info := mgo.CollectionInfo{ + Capped: true, + MaxBytes: 1000000, + } + err := tclog.Create(&info) + c.Assert(err, IsNil) + runner.ChangeLog(tclog) + } + + accounts := db.C("accounts") + for i := 0; i < params.accounts; i++ { + err := accounts.Insert(M{"_id": i, "balance": 300}) + c.Assert(err, IsNil) + } + var stop time.Time + if params.changes <= 0 { + stop = time.Now().Add(*duration) + } + + max := params.accounts + if params.reinsertCopy || params.reinsertZeroed { + max = int(float64(params.accounts) * 1.5) + } + + changes := make(chan balanceChange, 1024) + + //session.SetMode(mgo.Eventual, true) + for i := 0; i < params.workers; i++ { + go func() { + n := 0 + for { + if n > 0 && n == params.changes { + break + } + if !stop.IsZero() && time.Now().After(stop) { + break + } + + change := balanceChange{ + id: bson.NewObjectId(), + origin: rand.Intn(max), + target: rand.Intn(max), + amount: 100, + } + + var old Account + var oldExists bool + if params.reinsertCopy || params.reinsertZeroed { + if err := accounts.FindId(change.origin).One(&old); err != mgo.ErrNotFound { + c.Check(err, IsNil) + change.amount = old.Balance + oldExists = true + } + } + + var ops []txn.Op + switch { + case params.reinsertCopy && oldExists: + ops = []txn.Op{{ + C: "accounts", + Id: change.origin, + Assert: M{"balance": change.amount}, + Remove: true, + }, { + C: "accounts", + Id: change.target, + Assert: txn.DocMissing, + Insert: M{"balance": change.amount}, + }} + case params.reinsertZeroed && oldExists: + ops = []txn.Op{{ + C: "accounts", + Id: change.target, + Assert: txn.DocMissing, + Insert: M{"balance": 0}, + }, { + C: "accounts", + Id: change.origin, + Assert: M{"balance": change.amount}, + Remove: true, + }, { + C: "accounts", + Id: change.target, + Assert: txn.DocExists, + Update: M{"$inc": M{"balance": change.amount}}, + }} + case params.changeHalf: + ops = []txn.Op{{ + C: "accounts", + Id: change.origin, + Assert: M{"balance": M{"$gte": change.amount}}, + Update: M{"$inc": M{"balance": -change.amount / 2}}, + }, { + C: "accounts", + Id: change.target, + Assert: txn.DocExists, + Update: M{"$inc": M{"balance": change.amount / 2}}, + }, { + C: "accounts", + Id: change.origin, + Update: M{"$inc": M{"balance": -change.amount / 2}}, + }, { + C: "accounts", + Id: change.target, + Update: M{"$inc": M{"balance": change.amount / 2}}, + }} + default: + ops = []txn.Op{{ + C: "accounts", + Id: change.origin, + Assert: M{"balance": M{"$gte": change.amount}}, + Update: M{"$inc": M{"balance": -change.amount}}, + }, { + C: "accounts", + Id: change.target, + Assert: txn.DocExists, + Update: M{"$inc": M{"balance": change.amount}}, + }} + } + + err = runner.Run(ops, change.id, nil) + if err != nil && err != txn.ErrAborted && err != txn.ErrChaos { + c.Check(err, IsNil) + } + n++ + changes <- change + } + changes <- balanceChange{} + }() + } + + alive := params.workers + changeLog := make([]balanceChange, 0, 1024) + for alive > 0 { + change := <-changes + if change.id == "" { + alive-- + } else { + changeLog = append(changeLog, change) + } + } + c.Check(len(changeLog), Not(Equals), 0, Commentf("No operations were even attempted.")) + + txn.SetChaos(txn.Chaos{}) + err = runner.ResumeAll() + c.Assert(err, IsNil) + + n, err := accounts.Count() + c.Check(err, IsNil) + c.Check(n, Equals, params.accounts, Commentf("Number of accounts has changed.")) + + n, err = accounts.Find(M{"balance": M{"$lt": 0}}).Count() + c.Check(err, IsNil) + c.Check(n, Equals, 0, Commentf("There are %d accounts with negative balance.", n)) + + globalBalance := 0 + iter := accounts.Find(nil).Iter() + account := Account{} + for iter.Next(&account) { + globalBalance += account.Balance + } + c.Check(iter.Close(), IsNil) + c.Check(globalBalance, Equals, params.accounts*300, Commentf("Total amount of money should be constant.")) + + // Compute and verify the exact final state of all accounts. + balance := make(map[int]int) + for i := 0; i < params.accounts; i++ { + balance[i] += 300 + } + var applied, aborted int + for _, change := range changeLog { + err := runner.Resume(change.id) + if err == txn.ErrAborted { + aborted++ + continue + } else if err != nil { + c.Fatalf("resuming %s failed: %v", change.id, err) + } + balance[change.origin] -= change.amount + balance[change.target] += change.amount + applied++ + } + iter = accounts.Find(nil).Iter() + for iter.Next(&account) { + c.Assert(account.Balance, Equals, balance[account.Id]) + } + c.Check(iter.Close(), IsNil) + c.Logf("Total transactions: %d (%d applied, %d aborted)", len(changeLog), applied, aborted) + + if params.changelog { + n, err := tclog.Count() + c.Assert(err, IsNil) + // Check if the capped collection is full. + dummy := make([]byte, 1024) + tclog.Insert(M{"_id": bson.NewObjectId(), "dummy": dummy}) + m, err := tclog.Count() + c.Assert(err, IsNil) + if m == n+1 { + // Wasn't full, so it must have seen it all. + c.Assert(err, IsNil) + c.Assert(n, Equals, applied) + } + } +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/tarjan.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/tarjan.go new file mode 100644 index 0000000..e56541c --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/tarjan.go @@ -0,0 +1,94 @@ +package txn + +import ( + "gopkg.in/mgo.v2/bson" + "sort" +) + +func tarjanSort(successors map[bson.ObjectId][]bson.ObjectId) [][]bson.ObjectId { + // http://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm + data := &tarjanData{ + successors: successors, + nodes: make([]tarjanNode, 0, len(successors)), + index: make(map[bson.ObjectId]int, len(successors)), + } + + for id := range successors { + id := bson.ObjectId(string(id)) + if _, seen := data.index[id]; !seen { + data.strongConnect(id) + } + } + + // Sort connected components to stabilize the algorithm. + for _, ids := range data.output { + if len(ids) > 1 { + sort.Sort(idList(ids)) + } + } + return data.output +} + +type tarjanData struct { + successors map[bson.ObjectId][]bson.ObjectId + output [][]bson.ObjectId + + nodes []tarjanNode + stack []bson.ObjectId + index map[bson.ObjectId]int +} + +type tarjanNode struct { + lowlink int + stacked bool +} + +type idList []bson.ObjectId + +func (l idList) Len() int { return len(l) } +func (l idList) Swap(i, j int) { l[i], l[j] = l[j], l[i] } +func (l idList) Less(i, j int) bool { return l[i] < l[j] } + +func (data *tarjanData) strongConnect(id bson.ObjectId) *tarjanNode { + index := len(data.nodes) + data.index[id] = index + data.stack = append(data.stack, id) + data.nodes = append(data.nodes, tarjanNode{index, true}) + node := &data.nodes[index] + + for _, succid := range data.successors[id] { + succindex, seen := data.index[succid] + if !seen { + succnode := data.strongConnect(succid) + if succnode.lowlink < node.lowlink { + node.lowlink = succnode.lowlink + } + } else if data.nodes[succindex].stacked { + // Part of the current strongly-connected component. + if succindex < node.lowlink { + node.lowlink = succindex + } + } + } + + if node.lowlink == index { + // Root node; pop stack and output new + // strongly-connected component. + var scc []bson.ObjectId + i := len(data.stack) - 1 + for { + stackid := data.stack[i] + stackindex := data.index[stackid] + data.nodes[stackindex].stacked = false + scc = append(scc, stackid) + if stackindex == index { + break + } + i-- + } + data.stack = data.stack[:i] + data.output = append(data.output, scc) + } + + return node +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/tarjan_test.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/tarjan_test.go new file mode 100644 index 0000000..79745c3 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/tarjan_test.go @@ -0,0 +1,44 @@ +package txn + +import ( + "fmt" + "gopkg.in/mgo.v2/bson" + . "gopkg.in/check.v1" +) + +type TarjanSuite struct{} + +var _ = Suite(TarjanSuite{}) + +func bid(n int) bson.ObjectId { + return bson.ObjectId(fmt.Sprintf("%024d", n)) +} + +func bids(ns ...int) (ids []bson.ObjectId) { + for _, n := range ns { + ids = append(ids, bid(n)) + } + return +} + +func (TarjanSuite) TestExample(c *C) { + successors := map[bson.ObjectId][]bson.ObjectId{ + bid(1): bids(2, 3), + bid(2): bids(1, 5), + bid(3): bids(4), + bid(4): bids(3, 5), + bid(5): bids(6), + bid(6): bids(7), + bid(7): bids(8), + bid(8): bids(6, 9), + bid(9): bids(), + } + + c.Assert(tarjanSort(successors), DeepEquals, [][]bson.ObjectId{ + bids(9), + bids(6, 7, 8), + bids(5), + bids(3, 4), + bids(1, 2), + }) +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/txn.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/txn.go new file mode 100644 index 0000000..5809e2d --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/txn.go @@ -0,0 +1,609 @@ +// The txn package implements support for multi-document transactions. +// +// For details check the following blog post: +// +// http://blog.labix.org/2012/08/22/multi-doc-transactions-for-mongodb +// +package txn + +import ( + "encoding/binary" + "fmt" + "reflect" + "sort" + "strings" + "sync" + + "gopkg.in/mgo.v2" + "gopkg.in/mgo.v2/bson" + + crand "crypto/rand" + mrand "math/rand" +) + +type state int + +const ( + tpreparing state = 1 // One or more documents not prepared + tprepared state = 2 // Prepared but not yet ready to run + taborting state = 3 // Assertions failed, cleaning up + tapplying state = 4 // Changes are in progress + taborted state = 5 // Pre-conditions failed, nothing done + tapplied state = 6 // All changes applied +) + +func (s state) String() string { + switch s { + case tpreparing: + return "preparing" + case tprepared: + return "prepared" + case taborting: + return "aborting" + case tapplying: + return "applying" + case taborted: + return "aborted" + case tapplied: + return "applied" + } + panic(fmt.Errorf("unknown state: %d", s)) +} + +var rand *mrand.Rand +var randmu sync.Mutex + +func init() { + var seed int64 + err := binary.Read(crand.Reader, binary.BigEndian, &seed) + if err != nil { + panic(err) + } + rand = mrand.New(mrand.NewSource(seed)) +} + +type transaction struct { + Id bson.ObjectId `bson:"_id"` + State state `bson:"s"` + Info interface{} `bson:"i,omitempty"` + Ops []Op `bson:"o"` + Nonce string `bson:"n,omitempty"` + Revnos []int64 `bson:"r,omitempty"` + + docKeysCached docKeys +} + +func (t *transaction) String() string { + if t.Nonce == "" { + return t.Id.Hex() + } + return string(t.token()) +} + +func (t *transaction) done() bool { + return t.State == tapplied || t.State == taborted +} + +func (t *transaction) token() token { + if t.Nonce == "" { + panic("transaction has no nonce") + } + return tokenFor(t) +} + +func (t *transaction) docKeys() docKeys { + if t.docKeysCached != nil { + return t.docKeysCached + } + dkeys := make(docKeys, 0, len(t.Ops)) +NextOp: + for _, op := range t.Ops { + dkey := op.docKey() + for i := range dkeys { + if dkey == dkeys[i] { + continue NextOp + } + } + dkeys = append(dkeys, dkey) + } + sort.Sort(dkeys) + t.docKeysCached = dkeys + return dkeys +} + +// tokenFor returns a unique transaction token that +// is composed by t's id and a nonce. If t already has +// a nonce assigned to it, it will be used, otherwise +// a new nonce will be generated. +func tokenFor(t *transaction) token { + nonce := t.Nonce + if nonce == "" { + nonce = newNonce() + } + return token(t.Id.Hex() + "_" + nonce) +} + +func newNonce() string { + randmu.Lock() + r := rand.Uint32() + randmu.Unlock() + n := make([]byte, 8) + for i := uint(0); i < 8; i++ { + n[i] = "0123456789abcdef"[(r>>(4*i))&0xf] + } + return string(n) +} + +type token string + +func (tt token) id() bson.ObjectId { return bson.ObjectIdHex(string(tt[:24])) } +func (tt token) nonce() string { return string(tt[25:]) } + +// Op represents an operation to a single document that may be +// applied as part of a transaction with other operations. +type Op struct { + // C and Id identify the collection and document this operation + // refers to. Id is matched against the "_id" document field. + C string `bson:"c"` + Id interface{} `bson:"d"` + + // Assert optionally holds a query document that is used to + // test the operation document at the time the transaction is + // going to be applied. The assertions for all operations in + // a transaction are tested before any changes take place, + // and the transaction is entirely aborted if any of them + // fails. This is also the only way to prevent a transaction + // from being being applied (the transaction continues despite + // the outcome of Insert, Update, and Remove). + Assert interface{} `bson:"a,omitempty"` + + // The Insert, Update and Remove fields describe the mutation + // intended by the operation. At most one of them may be set + // per operation. If none are set, Assert must be set and the + // operation becomes a read-only test. + // + // Insert holds the document to be inserted at the time the + // transaction is applied. The Id field will be inserted + // into the document automatically as its _id field. The + // transaction will continue even if the document already + // exists. Use Assert with txn.DocMissing if the insertion is + // required. + // + // Update holds the update document to be applied at the time + // the transaction is applied. The transaction will continue + // even if a document with Id is missing. Use Assert to + // test for the document presence or its contents. + // + // Remove indicates whether to remove the document with Id. + // The transaction continues even if the document doesn't yet + // exist at the time the transaction is applied. Use Assert + // with txn.DocExists to make sure it will be removed. + Insert interface{} `bson:"i,omitempty"` + Update interface{} `bson:"u,omitempty"` + Remove bool `bson:"r,omitempty"` +} + +func (op *Op) isChange() bool { + return op.Update != nil || op.Insert != nil || op.Remove +} + +func (op *Op) docKey() docKey { + return docKey{op.C, op.Id} +} + +func (op *Op) name() string { + switch { + case op.Update != nil: + return "update" + case op.Insert != nil: + return "insert" + case op.Remove: + return "remove" + case op.Assert != nil: + return "assert" + } + return "none" +} + +const ( + // DocExists and DocMissing may be used on an operation's + // Assert value to assert that the document with the given + // Id exists or does not exist, respectively. + DocExists = "d+" + DocMissing = "d-" +) + +// A Runner applies operations as part of a transaction onto any number +// of collections within a database. See the Run method for details. +type Runner struct { + tc *mgo.Collection // txns + sc *mgo.Collection // stash + lc *mgo.Collection // log +} + +// NewRunner returns a new transaction runner that uses tc to hold its +// transactions. +// +// Multiple transaction collections may exist in a single database, but +// all collections that are touched by operations in a given transaction +// collection must be handled exclusively by it. +// +// A second collection with the same name of tc but suffixed by ".stash" +// will be used for implementing the transactional behavior of insert +// and remove operations. +func NewRunner(tc *mgo.Collection) *Runner { + return &Runner{tc, tc.Database.C(tc.Name + ".stash"), nil} +} + +var ErrAborted = fmt.Errorf("transaction aborted") + +// Run creates a new transaction with ops and runs it immediately. +// The id parameter specifies the transaction id, and may be written +// down ahead of time to later verify the success of the change and +// resume it, when the procedure is interrupted for any reason. If +// empty, a random id will be generated. +// The info parameter, if not nil, is included under the "i" +// field of the transaction document. +// +// Operations across documents are not atomically applied, but are +// guaranteed to be eventually all applied in the order provided or +// all aborted, as long as the affected documents are only modified +// through transactions. If documents are simultaneously modified +// by transactions and out of transactions the behavior is undefined. +// +// If Run returns no errors, all operations were applied successfully. +// If it returns ErrAborted, one or more operations can't be applied +// and the transaction was entirely aborted with no changes performed. +// Otherwise, if the transaction is interrupted while running for any +// reason, it may be resumed explicitly or by attempting to apply +// another transaction on any of the documents targeted by ops, as +// long as the interruption was made after the transaction document +// itself was inserted. Run Resume with the obtained transaction id +// to confirm whether the transaction was applied or not. +// +// Any number of transactions may be run concurrently, with one +// runner or many. +func (r *Runner) Run(ops []Op, id bson.ObjectId, info interface{}) (err error) { + const efmt = "error in transaction op %d: %s" + for i := range ops { + op := &ops[i] + if op.C == "" || op.Id == nil { + return fmt.Errorf(efmt, i, "C or Id missing") + } + changes := 0 + if op.Insert != nil { + changes++ + } + if op.Update != nil { + changes++ + } + if op.Remove { + changes++ + } + if changes > 1 { + return fmt.Errorf(efmt, i, "more than one of Insert/Update/Remove set") + } + if changes == 0 && op.Assert == nil { + return fmt.Errorf(efmt, i, "none of Assert/Insert/Update/Remove set") + } + } + if id == "" { + id = bson.NewObjectId() + } + + // Insert transaction sooner rather than later, to stay on the safer side. + t := transaction{ + Id: id, + Ops: ops, + State: tpreparing, + Info: info, + } + if err = r.tc.Insert(&t); err != nil { + return err + } + if err = flush(r, &t); err != nil { + return err + } + if t.State == taborted { + return ErrAborted + } else if t.State != tapplied { + panic(fmt.Errorf("invalid state for %s after flush: %q", &t, t.State)) + } + return nil +} + +// ResumeAll resumes all pending transactions. All ErrAborted errors +// from individual transactions are ignored. +func (r *Runner) ResumeAll() (err error) { + debugf("Resuming all unfinished transactions") + iter := r.tc.Find(bson.D{{"s", bson.D{{"$in", []state{tpreparing, tprepared, tapplying}}}}}).Iter() + var t transaction + for iter.Next(&t) { + if t.State == tapplied || t.State == taborted { + continue + } + debugf("Resuming %s from %q", t.Id, t.State) + if err := flush(r, &t); err != nil { + return err + } + if !t.done() { + panic(fmt.Errorf("invalid state for %s after flush: %q", &t, t.State)) + } + } + return nil +} + +// Resume resumes the transaction with id. It returns mgo.ErrNotFound +// if the transaction is not found. Otherwise, it has the same semantics +// of the Run method after the transaction is inserted. +func (r *Runner) Resume(id bson.ObjectId) (err error) { + t, err := r.load(id) + if err != nil { + return err + } + if !t.done() { + debugf("Resuming %s from %q", t, t.State) + if err := flush(r, t); err != nil { + return err + } + } + if t.State == taborted { + return ErrAborted + } else if t.State != tapplied { + panic(fmt.Errorf("invalid state for %s after flush: %q", t, t.State)) + } + return nil +} + +// ChangeLog enables logging of changes to the given collection +// every time a transaction that modifies content is done being +// applied. +// +// Saved documents are in the format: +// +// {"_id": , : {"d": [, ...], "r": [, ...]}} +// +// The document revision is the value of the txn-revno field after +// the change has been applied. Negative values indicate the document +// was not present in the collection. Revisions will not change when +// updates or removes are applied to missing documents or inserts are +// attempted when the document isn't present. +func (r *Runner) ChangeLog(logc *mgo.Collection) { + r.lc = logc +} + +// PurgeMissing removes from collections any state that refers to transaction +// documents that for whatever reason have been lost from the system (removed +// by accident or lost in a hard crash, for example). +// +// This method should very rarely be needed, if at all, and should never be +// used during the normal operation of an application. Its purpose is to put +// a system that has seen unavoidable corruption back in a working state. +func (r *Runner) PurgeMissing(collections ...string) error { + type M map[string]interface{} + type S []interface{} + pipeline := []M{ + {"$project": M{"_id": 1, "txn-queue": 1}}, + {"$unwind": "$txn-queue"}, + {"$sort": M{"_id": 1, "txn-queue": 1}}, + //{"$group": M{"_id": M{"$substr": S{"$txn-queue", 0, 24}}, "docids": M{"$push": "$_id"}}}, + } + + type TRef struct { + DocId interface{} "_id" + TxnId string "txn-queue" + } + + found := make(map[bson.ObjectId]bool) + colls := make(map[string]bool) + + sort.Strings(collections) + for _, collection := range collections { + c := r.tc.Database.C(collection) + iter := c.Pipe(pipeline).Iter() + var tref TRef + for iter.Next(&tref) { + txnId := bson.ObjectIdHex(tref.TxnId[:24]) + if found[txnId] { + continue + } + if r.tc.FindId(txnId).One(nil) == nil { + found[txnId] = true + continue + } + logf("WARNING: purging from document %s/%v the missing transaction id %s", collection, tref.DocId, txnId) + err := c.UpdateId(tref.DocId, M{"$pull": M{"txn-queue": M{"$regex": "^" + txnId.Hex() + "_*"}}}) + if err != nil { + return fmt.Errorf("error purging missing transaction %s: %v", txnId.Hex(), err) + } + } + colls[collection] = true + } + + type StashTRef struct { + Id docKey "_id" + TxnId string "txn-queue" + } + + iter := r.sc.Pipe(pipeline).Iter() + var stref StashTRef + for iter.Next(&stref) { + txnId := bson.ObjectIdHex(stref.TxnId[:24]) + if found[txnId] { + continue + } + if r.tc.FindId(txnId).One(nil) == nil { + found[txnId] = true + continue + } + logf("WARNING: purging from stash document %s/%v the missing transaction id %s", stref.Id.C, stref.Id.Id, txnId) + err := r.sc.UpdateId(stref.Id, M{"$pull": M{"txn-queue": M{"$regex": "^" + txnId.Hex() + "_*"}}}) + if err != nil { + return fmt.Errorf("error purging missing transaction %s: %v", txnId.Hex(), err) + } + } + + return nil +} + +func (r *Runner) load(id bson.ObjectId) (*transaction, error) { + var t transaction + err := r.tc.FindId(id).One(&t) + if err == mgo.ErrNotFound { + return nil, fmt.Errorf("cannot find transaction %s", id) + } else if err != nil { + return nil, err + } + return &t, nil +} + +type typeNature int + +const ( + // The order of these values matters. Transactions + // from applications using different ordering will + // be incompatible with each other. + _ typeNature = iota + natureString + natureInt + natureFloat + natureBool + natureStruct +) + +func valueNature(v interface{}) (value interface{}, nature typeNature) { + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.String: + return rv.String(), natureString + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return rv.Int(), natureInt + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return int64(rv.Uint()), natureInt + case reflect.Float32, reflect.Float64: + return rv.Float(), natureFloat + case reflect.Bool: + return rv.Bool(), natureBool + case reflect.Struct: + return v, natureStruct + } + panic("document id type unsupported by txn: " + rv.Kind().String()) +} + +type docKey struct { + C string + Id interface{} +} + +type docKeys []docKey + +func (ks docKeys) Len() int { return len(ks) } +func (ks docKeys) Swap(i, j int) { ks[i], ks[j] = ks[j], ks[i] } +func (ks docKeys) Less(i, j int) bool { + a, b := ks[i], ks[j] + if a.C != b.C { + return a.C < b.C + } + return valuecmp(a.Id, b.Id) == -1 +} + +func valuecmp(a, b interface{}) int { + av, an := valueNature(a) + bv, bn := valueNature(b) + if an < bn { + return -1 + } + if an > bn { + return 1 + } + + if av == bv { + return 0 + } + var less bool + switch an { + case natureString: + less = av.(string) < bv.(string) + case natureInt: + less = av.(int64) < bv.(int64) + case natureFloat: + less = av.(float64) < bv.(float64) + case natureBool: + less = !av.(bool) && bv.(bool) + case natureStruct: + less = structcmp(av, bv) == -1 + default: + panic("unreachable") + } + if less { + return -1 + } + return 1 +} + +func structcmp(a, b interface{}) int { + av := reflect.ValueOf(a) + bv := reflect.ValueOf(b) + + var ai, bi = 0, 0 + var an, bn = av.NumField(), bv.NumField() + var avi, bvi interface{} + var af, bf reflect.StructField + for { + for ai < an { + af = av.Type().Field(ai) + if isExported(af.Name) { + avi = av.Field(ai).Interface() + ai++ + break + } + ai++ + } + for bi < bn { + bf = bv.Type().Field(bi) + if isExported(bf.Name) { + bvi = bv.Field(bi).Interface() + bi++ + break + } + bi++ + } + if n := valuecmp(avi, bvi); n != 0 { + return n + } + nameA := getFieldName(af) + nameB := getFieldName(bf) + if nameA < nameB { + return -1 + } + if nameA > nameB { + return 1 + } + if ai == an && bi == bn { + return 0 + } + if ai == an || bi == bn { + if ai == bn { + return -1 + } + return 1 + } + } + panic("unreachable") +} + +func isExported(name string) bool { + a := name[0] + return a >= 'A' && a <= 'Z' +} + +func getFieldName(f reflect.StructField) string { + name := f.Tag.Get("bson") + if i := strings.Index(name, ","); i >= 0 { + name = name[:i] + } + if name == "" { + name = strings.ToLower(f.Name) + } + return name +} diff --git a/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/txn_test.go b/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/txn_test.go new file mode 100644 index 0000000..1e396ea --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/mgo.v2/txn/txn_test.go @@ -0,0 +1,627 @@ +package txn_test + +import ( + "sync" + "testing" + "time" + + . "gopkg.in/check.v1" + "gopkg.in/mgo.v2" + "gopkg.in/mgo.v2/bson" + "gopkg.in/mgo.v2/txn" +) + +func TestAll(t *testing.T) { + TestingT(t) +} + +type S struct { + MgoSuite + + db *mgo.Database + tc, sc *mgo.Collection + accounts *mgo.Collection + runner *txn.Runner +} + +var _ = Suite(&S{}) + +type M map[string]interface{} + +func (s *S) SetUpTest(c *C) { + txn.SetChaos(txn.Chaos{}) + txn.SetLogger(c) + txn.SetDebug(true) + s.MgoSuite.SetUpTest(c) + + s.db = s.session.DB("test") + s.tc = s.db.C("tc") + s.sc = s.db.C("tc.stash") + s.accounts = s.db.C("accounts") + s.runner = txn.NewRunner(s.tc) +} + +func (s *S) TearDownTest(c *C) { + txn.SetLogger(nil) + txn.SetDebug(false) +} + +type Account struct { + Id int `bson:"_id"` + Balance int +} + +func (s *S) TestDocExists(c *C) { + err := s.accounts.Insert(M{"_id": 0, "balance": 300}) + c.Assert(err, IsNil) + + exists := []txn.Op{{ + C: "accounts", + Id: 0, + Assert: txn.DocExists, + }} + missing := []txn.Op{{ + C: "accounts", + Id: 0, + Assert: txn.DocMissing, + }} + + err = s.runner.Run(exists, "", nil) + c.Assert(err, IsNil) + err = s.runner.Run(missing, "", nil) + c.Assert(err, Equals, txn.ErrAborted) + + err = s.accounts.RemoveId(0) + c.Assert(err, IsNil) + + err = s.runner.Run(exists, "", nil) + c.Assert(err, Equals, txn.ErrAborted) + err = s.runner.Run(missing, "", nil) + c.Assert(err, IsNil) +} + +func (s *S) TestInsert(c *C) { + err := s.accounts.Insert(M{"_id": 0, "balance": 300}) + c.Assert(err, IsNil) + + ops := []txn.Op{{ + C: "accounts", + Id: 0, + Insert: M{"balance": 200}, + }} + + err = s.runner.Run(ops, "", nil) + c.Assert(err, IsNil) + + var account Account + err = s.accounts.FindId(0).One(&account) + c.Assert(err, IsNil) + c.Assert(account.Balance, Equals, 300) + + ops[0].Id = 1 + err = s.runner.Run(ops, "", nil) + c.Assert(err, IsNil) + + err = s.accounts.FindId(1).One(&account) + c.Assert(err, IsNil) + c.Assert(account.Balance, Equals, 200) +} + +func (s *S) TestInsertStructID(c *C) { + type id struct { + FirstName string + LastName string + } + ops := []txn.Op{{ + C: "accounts", + Id: id{FirstName: "John", LastName: "Jones"}, + Assert: txn.DocMissing, + Insert: M{"balance": 200}, + }, { + C: "accounts", + Id: id{FirstName: "Sally", LastName: "Smith"}, + Assert: txn.DocMissing, + Insert: M{"balance": 800}, + }} + + err := s.runner.Run(ops, "", nil) + c.Assert(err, IsNil) + + n, err := s.accounts.Find(nil).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 2) +} + +func (s *S) TestRemove(c *C) { + err := s.accounts.Insert(M{"_id": 0, "balance": 300}) + c.Assert(err, IsNil) + + ops := []txn.Op{{ + C: "accounts", + Id: 0, + Remove: true, + }} + + err = s.runner.Run(ops, "", nil) + c.Assert(err, IsNil) + + err = s.accounts.FindId(0).One(nil) + c.Assert(err, Equals, mgo.ErrNotFound) + + err = s.runner.Run(ops, "", nil) + c.Assert(err, IsNil) +} + +func (s *S) TestUpdate(c *C) { + var err error + err = s.accounts.Insert(M{"_id": 0, "balance": 200}) + c.Assert(err, IsNil) + err = s.accounts.Insert(M{"_id": 1, "balance": 200}) + c.Assert(err, IsNil) + + ops := []txn.Op{{ + C: "accounts", + Id: 0, + Update: M{"$inc": M{"balance": 100}}, + }} + + err = s.runner.Run(ops, "", nil) + c.Assert(err, IsNil) + + var account Account + err = s.accounts.FindId(0).One(&account) + c.Assert(err, IsNil) + c.Assert(account.Balance, Equals, 300) + + ops[0].Id = 1 + + err = s.accounts.FindId(1).One(&account) + c.Assert(err, IsNil) + c.Assert(account.Balance, Equals, 200) +} + +func (s *S) TestInsertUpdate(c *C) { + ops := []txn.Op{{ + C: "accounts", + Id: 0, + Insert: M{"_id": 0, "balance": 200}, + }, { + C: "accounts", + Id: 0, + Update: M{"$inc": M{"balance": 100}}, + }} + + err := s.runner.Run(ops, "", nil) + c.Assert(err, IsNil) + + var account Account + err = s.accounts.FindId(0).One(&account) + c.Assert(err, IsNil) + c.Assert(account.Balance, Equals, 300) + + err = s.runner.Run(ops, "", nil) + c.Assert(err, IsNil) + + err = s.accounts.FindId(0).One(&account) + c.Assert(err, IsNil) + c.Assert(account.Balance, Equals, 400) +} + +func (s *S) TestUpdateInsert(c *C) { + ops := []txn.Op{{ + C: "accounts", + Id: 0, + Update: M{"$inc": M{"balance": 100}}, + }, { + C: "accounts", + Id: 0, + Insert: M{"_id": 0, "balance": 200}, + }} + + err := s.runner.Run(ops, "", nil) + c.Assert(err, IsNil) + + var account Account + err = s.accounts.FindId(0).One(&account) + c.Assert(err, IsNil) + c.Assert(account.Balance, Equals, 200) + + err = s.runner.Run(ops, "", nil) + c.Assert(err, IsNil) + + err = s.accounts.FindId(0).One(&account) + c.Assert(err, IsNil) + c.Assert(account.Balance, Equals, 300) +} + +func (s *S) TestInsertRemoveInsert(c *C) { + ops := []txn.Op{{ + C: "accounts", + Id: 0, + Insert: M{"_id": 0, "balance": 200}, + }, { + C: "accounts", + Id: 0, + Remove: true, + }, { + C: "accounts", + Id: 0, + Insert: M{"_id": 0, "balance": 300}, + }} + + err := s.runner.Run(ops, "", nil) + c.Assert(err, IsNil) + + var account Account + err = s.accounts.FindId(0).One(&account) + c.Assert(err, IsNil) + c.Assert(account.Balance, Equals, 300) +} + +func (s *S) TestQueueStashing(c *C) { + txn.SetChaos(txn.Chaos{ + KillChance: 1, + Breakpoint: "set-applying", + }) + + opses := [][]txn.Op{{{ + C: "accounts", + Id: 0, + Insert: M{"balance": 100}, + }}, {{ + C: "accounts", + Id: 0, + Remove: true, + }}, {{ + C: "accounts", + Id: 0, + Insert: M{"balance": 200}, + }}, {{ + C: "accounts", + Id: 0, + Update: M{"$inc": M{"balance": 100}}, + }}} + + var last bson.ObjectId + for _, ops := range opses { + last = bson.NewObjectId() + err := s.runner.Run(ops, last, nil) + c.Assert(err, Equals, txn.ErrChaos) + } + + txn.SetChaos(txn.Chaos{}) + err := s.runner.Resume(last) + c.Assert(err, IsNil) + + var account Account + err = s.accounts.FindId(0).One(&account) + c.Assert(err, IsNil) + c.Assert(account.Balance, Equals, 300) +} + +func (s *S) TestInfo(c *C) { + ops := []txn.Op{{ + C: "accounts", + Id: 0, + Assert: txn.DocMissing, + }} + + id := bson.NewObjectId() + err := s.runner.Run(ops, id, M{"n": 42}) + c.Assert(err, IsNil) + + var t struct{ I struct{ N int } } + err = s.tc.FindId(id).One(&t) + c.Assert(err, IsNil) + c.Assert(t.I.N, Equals, 42) +} + +func (s *S) TestErrors(c *C) { + doc := bson.M{"foo": 1} + tests := []txn.Op{{ + C: "c", + Id: 0, + }, { + C: "c", + Id: 0, + Insert: doc, + Remove: true, + }, { + C: "c", + Id: 0, + Insert: doc, + Update: doc, + }, { + C: "c", + Id: 0, + Update: doc, + Remove: true, + }, { + C: "c", + Assert: doc, + }, { + Id: 0, + Assert: doc, + }} + + txn.SetChaos(txn.Chaos{KillChance: 1.0}) + for _, op := range tests { + c.Logf("op: %v", op) + err := s.runner.Run([]txn.Op{op}, "", nil) + c.Assert(err, ErrorMatches, "error in transaction op 0: .*") + } +} + +func (s *S) TestAssertNestedOr(c *C) { + // Assert uses $or internally. Ensure nesting works. + err := s.accounts.Insert(M{"_id": 0, "balance": 300}) + c.Assert(err, IsNil) + + ops := []txn.Op{{ + C: "accounts", + Id: 0, + Assert: bson.D{{"$or", []bson.D{{{"balance", 100}}, {{"balance", 300}}}}}, + Update: bson.D{{"$inc", bson.D{{"balance", 100}}}}, + }} + + err = s.runner.Run(ops, "", nil) + c.Assert(err, IsNil) + + var account Account + err = s.accounts.FindId(0).One(&account) + c.Assert(err, IsNil) + c.Assert(account.Balance, Equals, 400) +} + +func (s *S) TestVerifyFieldOrdering(c *C) { + // Used to have a map in certain operations, which means + // the ordering of fields would be messed up. + fields := bson.D{{"a", 1}, {"b", 2}, {"c", 3}} + ops := []txn.Op{{ + C: "accounts", + Id: 0, + Insert: fields, + }} + + err := s.runner.Run(ops, "", nil) + c.Assert(err, IsNil) + + var d bson.D + err = s.accounts.FindId(0).One(&d) + c.Assert(err, IsNil) + + var filtered bson.D + for _, e := range d { + switch e.Name { + case "a", "b", "c": + filtered = append(filtered, e) + } + } + c.Assert(filtered, DeepEquals, fields) +} + +func (s *S) TestChangeLog(c *C) { + chglog := s.db.C("chglog") + s.runner.ChangeLog(chglog) + + ops := []txn.Op{{ + C: "debts", + Id: 0, + Assert: txn.DocMissing, + }, { + C: "accounts", + Id: 0, + Insert: M{"balance": 300}, + }, { + C: "accounts", + Id: 1, + Insert: M{"balance": 300}, + }, { + C: "people", + Id: "joe", + Insert: M{"accounts": []int64{0, 1}}, + }} + id := bson.NewObjectId() + err := s.runner.Run(ops, id, nil) + c.Assert(err, IsNil) + + type IdList []interface{} + type Log struct { + Docs IdList "d" + Revnos []int64 "r" + } + var m map[string]*Log + err = chglog.FindId(id).One(&m) + c.Assert(err, IsNil) + + c.Assert(m["accounts"], DeepEquals, &Log{IdList{0, 1}, []int64{2, 2}}) + c.Assert(m["people"], DeepEquals, &Log{IdList{"joe"}, []int64{2}}) + c.Assert(m["debts"], IsNil) + + ops = []txn.Op{{ + C: "accounts", + Id: 0, + Update: M{"$inc": M{"balance": 100}}, + }, { + C: "accounts", + Id: 1, + Update: M{"$inc": M{"balance": 100}}, + }} + id = bson.NewObjectId() + err = s.runner.Run(ops, id, nil) + c.Assert(err, IsNil) + + m = nil + err = chglog.FindId(id).One(&m) + c.Assert(err, IsNil) + + c.Assert(m["accounts"], DeepEquals, &Log{IdList{0, 1}, []int64{3, 3}}) + c.Assert(m["people"], IsNil) + + ops = []txn.Op{{ + C: "accounts", + Id: 0, + Remove: true, + }, { + C: "people", + Id: "joe", + Remove: true, + }} + id = bson.NewObjectId() + err = s.runner.Run(ops, id, nil) + c.Assert(err, IsNil) + + m = nil + err = chglog.FindId(id).One(&m) + c.Assert(err, IsNil) + + c.Assert(m["accounts"], DeepEquals, &Log{IdList{0}, []int64{-4}}) + c.Assert(m["people"], DeepEquals, &Log{IdList{"joe"}, []int64{-3}}) +} + +func (s *S) TestPurgeMissing(c *C) { + txn.SetChaos(txn.Chaos{ + KillChance: 1, + Breakpoint: "set-applying", + }) + + err := s.accounts.Insert(M{"_id": 0, "balance": 100}) + c.Assert(err, IsNil) + err = s.accounts.Insert(M{"_id": 1, "balance": 100}) + c.Assert(err, IsNil) + + ops1 := []txn.Op{{ + C: "accounts", + Id: 3, + Insert: M{"balance": 100}, + }} + + ops2 := []txn.Op{{ + C: "accounts", + Id: 0, + Remove: true, + }, { + C: "accounts", + Id: 1, + Update: M{"$inc": M{"balance": 100}}, + }, { + C: "accounts", + Id: 2, + Insert: M{"balance": 100}, + }} + + first := bson.NewObjectId() + c.Logf("---- Running ops1 under transaction %q, to be canceled by chaos", first.Hex()) + err = s.runner.Run(ops1, first, nil) + c.Assert(err, Equals, txn.ErrChaos) + + last := bson.NewObjectId() + c.Logf("---- Running ops2 under transaction %q, to be canceled by chaos", last.Hex()) + err = s.runner.Run(ops2, last, nil) + c.Assert(err, Equals, txn.ErrChaos) + + c.Logf("---- Removing transaction %q", last.Hex()) + err = s.tc.RemoveId(last) + c.Assert(err, IsNil) + + c.Logf("---- Disabling chaos and attempting to resume all") + txn.SetChaos(txn.Chaos{}) + err = s.runner.ResumeAll() + c.Assert(err, IsNil) + + again := bson.NewObjectId() + c.Logf("---- Running ops2 again under transaction %q, to fail for missing transaction", again.Hex()) + err = s.runner.Run(ops2, "", nil) + c.Assert(err, ErrorMatches, "cannot find transaction .*") + + c.Logf("---- Puring missing transactions") + err = s.runner.PurgeMissing("accounts") + c.Assert(err, IsNil) + + c.Logf("---- Resuming pending transactions") + err = s.runner.ResumeAll() + c.Assert(err, IsNil) + + expect := []struct{ Id, Balance int }{ + {0, -1}, + {1, 200}, + {2, 100}, + {3, 100}, + } + var got Account + for _, want := range expect { + err = s.accounts.FindId(want.Id).One(&got) + if want.Balance == -1 { + if err != mgo.ErrNotFound { + c.Errorf("Account %d should not exist, find got err=%#v", err) + } + } else if err != nil { + c.Errorf("Account %d should have balance of %d, but wasn't found", want.Id, want.Balance) + } else if got.Balance != want.Balance { + c.Errorf("Account %d should have balance of %d, got %d", want.Id, want.Balance, got.Balance) + } + } +} + +func (s *S) TestTxnQueueStressTest(c *C) { + txn.SetChaos(txn.Chaos{ + SlowdownChance: 0.3, + Slowdown: 50 * time.Millisecond, + }) + defer txn.SetChaos(txn.Chaos{}) + + // So we can run more iterations of the test in less time. + txn.SetDebug(false) + + err := s.accounts.Insert(M{"_id": 0, "balance": 0}, M{"_id": 1, "balance": 0}) + c.Assert(err, IsNil) + + // Run half of the operations changing account 0 and then 1, + // and the other half in the opposite order. + ops01 := []txn.Op{{ + C: "accounts", + Id: 0, + Update: M{"$inc": M{"balance": 1}}, + }, { + C: "accounts", + Id: 1, + Update: M{"$inc": M{"balance": 1}}, + }} + + ops10 := []txn.Op{{ + C: "accounts", + Id: 1, + Update: M{"$inc": M{"balance": 1}}, + }, { + C: "accounts", + Id: 0, + Update: M{"$inc": M{"balance": 1}}, + }} + + ops := [][]txn.Op{ops01, ops10} + + const runners = 4 + const changes = 1000 + + var wg sync.WaitGroup + wg.Add(runners) + for n := 0; n < runners; n++ { + n := n + go func() { + defer wg.Done() + for i := 0; i < changes; i++ { + err = s.runner.Run(ops[n%2], "", nil) + c.Assert(err, IsNil) + } + }() + } + wg.Wait() + + for id := 0; id < 2; id++ { + var account Account + err = s.accounts.FindId(id).One(&account) + if account.Balance != runners*changes { + c.Errorf("Account should have balance of %d, got %d", runners*changes, account.Balance) + } + } +} diff --git a/Godeps/_workspace/src/labix.org/v2/mgo/bson/LICENSE b/Godeps/_workspace/src/labix.org/v2/mgo/bson/LICENSE new file mode 100644 index 0000000..8903260 --- /dev/null +++ b/Godeps/_workspace/src/labix.org/v2/mgo/bson/LICENSE @@ -0,0 +1,25 @@ +BSON library for Go + +Copyright (c) 2010-2012 - Gustavo Niemeyer + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/Godeps/_workspace/src/labix.org/v2/mgo/bson/bson.go b/Godeps/_workspace/src/labix.org/v2/mgo/bson/bson.go new file mode 100644 index 0000000..3ebfd84 --- /dev/null +++ b/Godeps/_workspace/src/labix.org/v2/mgo/bson/bson.go @@ -0,0 +1,682 @@ +// BSON library for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Package bson is an implementation of the BSON specification for Go: +// +// http://bsonspec.org +// +// It was created as part of the mgo MongoDB driver for Go, but is standalone +// and may be used on its own without the driver. +package bson + +import ( + "crypto/md5" + "crypto/rand" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io" + "os" + "reflect" + "runtime" + "strings" + "sync" + "sync/atomic" + "time" +) + +// -------------------------------------------------------------------------- +// The public API. + +// A value implementing the bson.Getter interface will have its GetBSON +// method called when the given value has to be marshalled, and the result +// of this method will be marshaled in place of the actual object. +// +// If GetBSON returns return a non-nil error, the marshalling procedure +// will stop and error out with the provided value. +type Getter interface { + GetBSON() (interface{}, error) +} + +// A value implementing the bson.Setter interface will receive the BSON +// value via the SetBSON method during unmarshaling, and the object +// itself will not be changed as usual. +// +// If setting the value works, the method should return nil or alternatively +// bson.SetZero to set the respective field to its zero value (nil for +// pointer types). If SetBSON returns a value of type bson.TypeError, the +// BSON value will be omitted from a map or slice being decoded and the +// unmarshalling will continue. If it returns any other non-nil error, the +// unmarshalling procedure will stop and error out with the provided value. +// +// This interface is generally useful in pointer receivers, since the method +// will want to change the receiver. A type field that implements the Setter +// interface doesn't have to be a pointer, though. +// +// Unlike the usual behavior, unmarshalling onto a value that implements a +// Setter interface will NOT reset the value to its zero state. This allows +// the value to decide by itself how to be unmarshalled. +// +// For example: +// +// type MyString string +// +// func (s *MyString) SetBSON(raw bson.Raw) error { +// return raw.Unmarshal(s) +// } +// +type Setter interface { + SetBSON(raw Raw) error +} + +// SetZero may be returned from a SetBSON method to have the value set to +// its respective zero value. When used in pointer values, this will set the +// field to nil rather than to the pre-allocated value. +var SetZero = errors.New("set to zero") + +// M is a convenient alias for a map[string]interface{} map, useful for +// dealing with BSON in a native way. For instance: +// +// bson.M{"a": 1, "b": true} +// +// There's no special handling for this type in addition to what's done anyway +// for an equivalent map type. Elements in the map will be dumped in an +// undefined ordered. See also the bson.D type for an ordered alternative. +type M map[string]interface{} + +// D represents a BSON document containing ordered elements. For example: +// +// bson.D{{"a", 1}, {"b", true}} +// +// In some situations, such as when creating indexes for MongoDB, the order in +// which the elements are defined is important. If the order is not important, +// using a map is generally more comfortable. See bson.M and bson.RawD. +type D []DocElem + +// See the D type. +type DocElem struct { + Name string + Value interface{} +} + +// Map returns a map out of the ordered element name/value pairs in d. +func (d D) Map() (m M) { + m = make(M, len(d)) + for _, item := range d { + m[item.Name] = item.Value + } + return m +} + +// The Raw type represents raw unprocessed BSON documents and elements. +// Kind is the kind of element as defined per the BSON specification, and +// Data is the raw unprocessed data for the respective element. +// Using this type it is possible to unmarshal or marshal values partially. +// +// Relevant documentation: +// +// http://bsonspec.org/#/specification +// +type Raw struct { + Kind byte + Data []byte +} + +// RawD represents a BSON document containing raw unprocessed elements. +// This low-level representation may be useful when lazily processing +// documents of uncertain content, or when manipulating the raw content +// documents in general. +type RawD []RawDocElem + +// See the RawD type. +type RawDocElem struct { + Name string + Value Raw +} + +// ObjectId is a unique ID identifying a BSON value. It must be exactly 12 bytes +// long. MongoDB objects by default have such a property set in their "_id" +// property. +// +// http://www.mongodb.org/display/DOCS/Object+IDs +type ObjectId string + +// ObjectIdHex returns an ObjectId from the provided hex representation. +// Calling this function with an invalid hex representation will +// cause a runtime panic. See the IsObjectIdHex function. +func ObjectIdHex(s string) ObjectId { + d, err := hex.DecodeString(s) + if err != nil || len(d) != 12 { + panic(fmt.Sprintf("Invalid input to ObjectIdHex: %q", s)) + } + return ObjectId(d) +} + +// IsObjectIdHex returns whether s is a valid hex representation of +// an ObjectId. See the ObjectIdHex function. +func IsObjectIdHex(s string) bool { + if len(s) != 24 { + return false + } + _, err := hex.DecodeString(s) + return err == nil +} + +// objectIdCounter is atomically incremented when generating a new ObjectId +// using NewObjectId() function. It's used as a counter part of an id. +var objectIdCounter uint32 = 0 + +// machineId stores machine id generated once and used in subsequent calls +// to NewObjectId function. +var machineId = readMachineId() + +// readMachineId generates machine id and puts it into the machineId global +// variable. If this function fails to get the hostname, it will cause +// a runtime error. +func readMachineId() []byte { + var sum [3]byte + id := sum[:] + hostname, err1 := os.Hostname() + if err1 != nil { + _, err2 := io.ReadFull(rand.Reader, id) + if err2 != nil { + panic(fmt.Errorf("cannot get hostname: %v; %v", err1, err2)) + } + return id + } + hw := md5.New() + hw.Write([]byte(hostname)) + copy(id, hw.Sum(nil)) + return id +} + +// NewObjectId returns a new unique ObjectId. +func NewObjectId() ObjectId { + var b [12]byte + // Timestamp, 4 bytes, big endian + binary.BigEndian.PutUint32(b[:], uint32(time.Now().Unix())) + // Machine, first 3 bytes of md5(hostname) + b[4] = machineId[0] + b[5] = machineId[1] + b[6] = machineId[2] + // Pid, 2 bytes, specs don't specify endianness, but we use big endian. + pid := os.Getpid() + b[7] = byte(pid >> 8) + b[8] = byte(pid) + // Increment, 3 bytes, big endian + i := atomic.AddUint32(&objectIdCounter, 1) + b[9] = byte(i >> 16) + b[10] = byte(i >> 8) + b[11] = byte(i) + return ObjectId(b[:]) +} + +// NewObjectIdWithTime returns a dummy ObjectId with the timestamp part filled +// with the provided number of seconds from epoch UTC, and all other parts +// filled with zeroes. It's not safe to insert a document with an id generated +// by this method, it is useful only for queries to find documents with ids +// generated before or after the specified timestamp. +func NewObjectIdWithTime(t time.Time) ObjectId { + var b [12]byte + binary.BigEndian.PutUint32(b[:4], uint32(t.Unix())) + return ObjectId(string(b[:])) +} + +// String returns a hex string representation of the id. +// Example: ObjectIdHex("4d88e15b60f486e428412dc9"). +func (id ObjectId) String() string { + return fmt.Sprintf(`ObjectIdHex("%x")`, string(id)) +} + +// Hex returns a hex representation of the ObjectId. +func (id ObjectId) Hex() string { + return hex.EncodeToString([]byte(id)) +} + +// MarshalJSON turns a bson.ObjectId into a json.Marshaller. +func (id ObjectId) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf(`"%x"`, string(id))), nil +} + +// UnmarshalJSON turns *bson.ObjectId into a json.Unmarshaller. +func (id *ObjectId) UnmarshalJSON(data []byte) error { + if len(data) != 26 || data[0] != '"' || data[25] != '"' { + return errors.New(fmt.Sprintf("Invalid ObjectId in JSON: %s", string(data))) + } + var buf [12]byte + _, err := hex.Decode(buf[:], data[1:25]) + if err != nil { + return errors.New(fmt.Sprintf("Invalid ObjectId in JSON: %s (%s)", string(data), err)) + } + *id = ObjectId(string(buf[:])) + return nil +} + +// Valid returns true if id is valid. A valid id must contain exactly 12 bytes. +func (id ObjectId) Valid() bool { + return len(id) == 12 +} + +// byteSlice returns byte slice of id from start to end. +// Calling this function with an invalid id will cause a runtime panic. +func (id ObjectId) byteSlice(start, end int) []byte { + if len(id) != 12 { + panic(fmt.Sprintf("Invalid ObjectId: %q", string(id))) + } + return []byte(string(id)[start:end]) +} + +// Time returns the timestamp part of the id. +// It's a runtime error to call this method with an invalid id. +func (id ObjectId) Time() time.Time { + // First 4 bytes of ObjectId is 32-bit big-endian seconds from epoch. + secs := int64(binary.BigEndian.Uint32(id.byteSlice(0, 4))) + return time.Unix(secs, 0) +} + +// Machine returns the 3-byte machine id part of the id. +// It's a runtime error to call this method with an invalid id. +func (id ObjectId) Machine() []byte { + return id.byteSlice(4, 7) +} + +// Pid returns the process id part of the id. +// It's a runtime error to call this method with an invalid id. +func (id ObjectId) Pid() uint16 { + return binary.BigEndian.Uint16(id.byteSlice(7, 9)) +} + +// Counter returns the incrementing value part of the id. +// It's a runtime error to call this method with an invalid id. +func (id ObjectId) Counter() int32 { + b := id.byteSlice(9, 12) + // Counter is stored as big-endian 3-byte value + return int32(uint32(b[0])<<16 | uint32(b[1])<<8 | uint32(b[2])) +} + +// The Symbol type is similar to a string and is used in languages with a +// distinct symbol type. +type Symbol string + +// Now returns the current time with millisecond precision. MongoDB stores +// timestamps with the same precision, so a Time returned from this method +// will not change after a roundtrip to the database. That's the only reason +// why this function exists. Using the time.Now function also works fine +// otherwise. +func Now() time.Time { + return time.Unix(0, time.Now().UnixNano()/1e6*1e6) +} + +// MongoTimestamp is a special internal type used by MongoDB that for some +// strange reason has its own datatype defined in BSON. +type MongoTimestamp int64 + +type orderKey int64 + +// MaxKey is a special value that compares higher than all other possible BSON +// values in a MongoDB database. +var MaxKey = orderKey(1<<63 - 1) + +// MinKey is a special value that compares lower than all other possible BSON +// values in a MongoDB database. +var MinKey = orderKey(-1 << 63) + +type undefined struct{} + +// Undefined represents the undefined BSON value. +var Undefined undefined + +// Binary is a representation for non-standard binary values. Any kind should +// work, but the following are known as of this writing: +// +// 0x00 - Generic. This is decoded as []byte(data), not Binary{0x00, data}. +// 0x01 - Function (!?) +// 0x02 - Obsolete generic. +// 0x03 - UUID +// 0x05 - MD5 +// 0x80 - User defined. +// +type Binary struct { + Kind byte + Data []byte +} + +// RegEx represents a regular expression. The Options field may contain +// individual characters defining the way in which the pattern should be +// applied, and must be sorted. Valid options as of this writing are 'i' for +// case insensitive matching, 'm' for multi-line matching, 'x' for verbose +// mode, 'l' to make \w, \W, and similar be locale-dependent, 's' for dot-all +// mode (a '.' matches everything), and 'u' to make \w, \W, and similar match +// unicode. The value of the Options parameter is not verified before being +// marshaled into the BSON format. +type RegEx struct { + Pattern string + Options string +} + +// JavaScript is a type that holds JavaScript code. If Scope is non-nil, it +// will be marshaled as a mapping from identifiers to values that may be +// used when evaluating the provided Code. +type JavaScript struct { + Code string + Scope interface{} +} + +const initialBufferSize = 64 + +func handleErr(err *error) { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } else if _, ok := r.(externalPanic); ok { + panic(r) + } else if s, ok := r.(string); ok { + *err = errors.New(s) + } else if e, ok := r.(error); ok { + *err = e + } else { + panic(r) + } + } +} + +// Marshal serializes the in value, which may be a map or a struct value. +// In the case of struct values, only exported fields will be serialized. +// The lowercased field name is used as the key for each exported field, +// but this behavior may be changed using the respective field tag. +// The tag may also contain flags to tweak the marshalling behavior for +// the field. The tag formats accepted are: +// +// "[][,[,]]" +// +// `(...) bson:"[][,[,]]" (...)` +// +// The following flags are currently supported: +// +// omitempty Only include the field if it's not set to the zero +// value for the type or to empty slices or maps. +// +// minsize Marshal an int64 value as an int32, if that's feasible +// while preserving the numeric value. +// +// inline Inline the field, which must be a struct or a map, +// causing all of its fields or keys to be processed as if +// they were part of the outer struct. For maps, keys must +// not conflict with the bson keys of other struct fields. +// +// Some examples: +// +// type T struct { +// A bool +// B int "myb" +// C string "myc,omitempty" +// D string `bson:",omitempty" json:"jsonkey"` +// E int64 ",minsize" +// F int64 "myf,omitempty,minsize" +// } +// +func Marshal(in interface{}) (out []byte, err error) { + defer handleErr(&err) + e := &encoder{make([]byte, 0, initialBufferSize)} + e.addDoc(reflect.ValueOf(in)) + return e.out, nil +} + +// Unmarshal deserializes data from in into the out value. The out value +// must be a map, a pointer to a struct, or a pointer to a bson.D value. +// The lowercased field name is used as the key for each exported field, +// but this behavior may be changed using the respective field tag. +// The tag may also contain flags to tweak the marshalling behavior for +// the field. The tag formats accepted are: +// +// "[][,[,]]" +// +// `(...) bson:"[][,[,]]" (...)` +// +// The following flags are currently supported during unmarshal (see the +// Marshal method for other flags): +// +// inline Inline the field, which must be a struct or a map. +// Inlined structs are handled as if its fields were part +// of the outer struct. An inlined map causes keys that do +// not match any other struct field to be inserted in the +// map rather than being discarded as usual. +// +// The target field or element types of out may not necessarily match +// the BSON values of the provided data. The following conversions are +// made automatically: +// +// - Numeric types are converted if at least the integer part of the +// value would be preserved correctly +// - Bools are converted to numeric types as 1 or 0 +// - Numeric types are converted to bools as true if not 0 or false otherwise +// - Binary and string BSON data is converted to a string, array or byte slice +// +// If the value would not fit the type and cannot be converted, it's +// silently skipped. +// +// Pointer values are initialized when necessary. +func Unmarshal(in []byte, out interface{}) (err error) { + defer handleErr(&err) + v := reflect.ValueOf(out) + switch v.Kind() { + case reflect.Map, reflect.Ptr: + d := newDecoder(in) + d.readDocTo(v) + case reflect.Struct: + return errors.New("Unmarshal can't deal with struct values. Use a pointer.") + default: + return errors.New("Unmarshal needs a map or a pointer to a struct.") + } + return nil +} + +// Unmarshal deserializes raw into the out value. If the out value type +// is not compatible with raw, a *bson.TypeError is returned. +// +// See the Unmarshal function documentation for more details on the +// unmarshalling process. +func (raw Raw) Unmarshal(out interface{}) (err error) { + defer handleErr(&err) + v := reflect.ValueOf(out) + switch v.Kind() { + case reflect.Ptr: + v = v.Elem() + fallthrough + case reflect.Map: + d := newDecoder(raw.Data) + good := d.readElemTo(v, raw.Kind) + if !good { + return &TypeError{v.Type(), raw.Kind} + } + case reflect.Struct: + return errors.New("Raw Unmarshal can't deal with struct values. Use a pointer.") + default: + return errors.New("Raw Unmarshal needs a map or a valid pointer.") + } + return nil +} + +type TypeError struct { + Type reflect.Type + Kind byte +} + +func (e *TypeError) Error() string { + return fmt.Sprintf("BSON kind 0x%02x isn't compatible with type %s", e.Kind, e.Type.String()) +} + +// -------------------------------------------------------------------------- +// Maintain a mapping of keys to structure field indexes + +type structInfo struct { + FieldsMap map[string]fieldInfo + FieldsList []fieldInfo + InlineMap int + Zero reflect.Value +} + +type fieldInfo struct { + Key string + Num int + OmitEmpty bool + MinSize bool + Inline []int +} + +var structMap = make(map[reflect.Type]*structInfo) +var structMapMutex sync.RWMutex + +type externalPanic string + +func (e externalPanic) String() string { + return string(e) +} + +func getStructInfo(st reflect.Type) (*structInfo, error) { + structMapMutex.RLock() + sinfo, found := structMap[st] + structMapMutex.RUnlock() + if found { + return sinfo, nil + } + n := st.NumField() + fieldsMap := make(map[string]fieldInfo) + fieldsList := make([]fieldInfo, 0, n) + inlineMap := -1 + for i := 0; i != n; i++ { + field := st.Field(i) + if field.PkgPath != "" { + continue // Private field + } + + info := fieldInfo{Num: i} + + tag := field.Tag.Get("bson") + if tag == "" && strings.Index(string(field.Tag), ":") < 0 { + tag = string(field.Tag) + } + if tag == "-" { + continue + } + + // XXX Drop this after a few releases. + if s := strings.Index(tag, "/"); s >= 0 { + recommend := tag[:s] + for _, c := range tag[s+1:] { + switch c { + case 'c': + recommend += ",omitempty" + case 's': + recommend += ",minsize" + default: + msg := fmt.Sprintf("Unsupported flag %q in tag %q of type %s", string([]byte{uint8(c)}), tag, st) + panic(externalPanic(msg)) + } + } + msg := fmt.Sprintf("Replace tag %q in field %s of type %s by %q", tag, field.Name, st, recommend) + panic(externalPanic(msg)) + } + + inline := false + fields := strings.Split(tag, ",") + if len(fields) > 1 { + for _, flag := range fields[1:] { + switch flag { + case "omitempty": + info.OmitEmpty = true + case "minsize": + info.MinSize = true + case "inline": + inline = true + default: + msg := fmt.Sprintf("Unsupported flag %q in tag %q of type %s", flag, tag, st) + panic(externalPanic(msg)) + } + } + tag = fields[0] + } + + if inline { + switch field.Type.Kind() { + case reflect.Map: + if inlineMap >= 0 { + return nil, errors.New("Multiple ,inline maps in struct " + st.String()) + } + if field.Type.Key() != reflect.TypeOf("") { + return nil, errors.New("Option ,inline needs a map with string keys in struct " + st.String()) + } + inlineMap = info.Num + case reflect.Struct: + sinfo, err := getStructInfo(field.Type) + if err != nil { + return nil, err + } + for _, finfo := range sinfo.FieldsList { + if _, found := fieldsMap[finfo.Key]; found { + msg := "Duplicated key '" + finfo.Key + "' in struct " + st.String() + return nil, errors.New(msg) + } + if finfo.Inline == nil { + finfo.Inline = []int{i, finfo.Num} + } else { + finfo.Inline = append([]int{i}, finfo.Inline...) + } + fieldsMap[finfo.Key] = finfo + fieldsList = append(fieldsList, finfo) + } + default: + panic("Option ,inline needs a struct value or map field") + } + continue + } + + if tag != "" { + info.Key = tag + } else { + info.Key = strings.ToLower(field.Name) + } + + if _, found = fieldsMap[info.Key]; found { + msg := "Duplicated key '" + info.Key + "' in struct " + st.String() + return nil, errors.New(msg) + } + + fieldsList = append(fieldsList, info) + fieldsMap[info.Key] = info + } + sinfo = &structInfo{ + fieldsMap, + fieldsList, + inlineMap, + reflect.New(st).Elem(), + } + structMapMutex.Lock() + structMap[st] = sinfo + structMapMutex.Unlock() + return sinfo, nil +} diff --git a/Godeps/_workspace/src/labix.org/v2/mgo/bson/bson_test.go b/Godeps/_workspace/src/labix.org/v2/mgo/bson/bson_test.go new file mode 100644 index 0000000..66359e7 --- /dev/null +++ b/Godeps/_workspace/src/labix.org/v2/mgo/bson/bson_test.go @@ -0,0 +1,1452 @@ +// BSON library for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// gobson - BSON library for Go. + +package bson_test + +import ( + "encoding/binary" + "encoding/json" + "errors" + "labix.org/v2/mgo/bson" + . "launchpad.net/gocheck" + "net/url" + "reflect" + "testing" + "time" +) + +func TestAll(t *testing.T) { + TestingT(t) +} + +type S struct{} + +var _ = Suite(&S{}) + +// Wrap up the document elements contained in data, prepending the int32 +// length of the data, and appending the '\x00' value closing the document. +func wrapInDoc(data string) string { + result := make([]byte, len(data)+5) + binary.LittleEndian.PutUint32(result, uint32(len(result))) + copy(result[4:], []byte(data)) + return string(result) +} + +func makeZeroDoc(value interface{}) (zero interface{}) { + v := reflect.ValueOf(value) + t := v.Type() + switch t.Kind() { + case reflect.Map: + mv := reflect.MakeMap(t) + zero = mv.Interface() + case reflect.Ptr: + pv := reflect.New(v.Type().Elem()) + zero = pv.Interface() + case reflect.Slice: + zero = reflect.New(t).Interface() + default: + panic("unsupported doc type") + } + return zero +} + +func testUnmarshal(c *C, data string, obj interface{}) { + zero := makeZeroDoc(obj) + err := bson.Unmarshal([]byte(data), zero) + c.Assert(err, IsNil) + c.Assert(zero, DeepEquals, obj) +} + +type testItemType struct { + obj interface{} + data string +} + +// -------------------------------------------------------------------------- +// Samples from bsonspec.org: + +var sampleItems = []testItemType{ + {bson.M{"hello": "world"}, + "\x16\x00\x00\x00\x02hello\x00\x06\x00\x00\x00world\x00\x00"}, + {bson.M{"BSON": []interface{}{"awesome", float64(5.05), 1986}}, + "1\x00\x00\x00\x04BSON\x00&\x00\x00\x00\x020\x00\x08\x00\x00\x00" + + "awesome\x00\x011\x00333333\x14@\x102\x00\xc2\x07\x00\x00\x00\x00"}, +} + +func (s *S) TestMarshalSampleItems(c *C) { + for i, item := range sampleItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, item.data, Commentf("Failed on item %d", i)) + } +} + +func (s *S) TestUnmarshalSampleItems(c *C) { + for i, item := range sampleItems { + value := bson.M{} + err := bson.Unmarshal([]byte(item.data), value) + c.Assert(err, IsNil) + c.Assert(value, DeepEquals, item.obj, Commentf("Failed on item %d", i)) + } +} + +// -------------------------------------------------------------------------- +// Every type, ordered by the type flag. These are not wrapped with the +// length and last \x00 from the document. wrapInDoc() computes them. +// Note that all of them should be supported as two-way conversions. + +var allItems = []testItemType{ + {bson.M{}, + ""}, + {bson.M{"_": float64(5.05)}, + "\x01_\x00333333\x14@"}, + {bson.M{"_": "yo"}, + "\x02_\x00\x03\x00\x00\x00yo\x00"}, + {bson.M{"_": bson.M{"a": true}}, + "\x03_\x00\x09\x00\x00\x00\x08a\x00\x01\x00"}, + {bson.M{"_": []interface{}{true, false}}, + "\x04_\x00\r\x00\x00\x00\x080\x00\x01\x081\x00\x00\x00"}, + {bson.M{"_": []byte("yo")}, + "\x05_\x00\x02\x00\x00\x00\x00yo"}, + {bson.M{"_": bson.Binary{0x80, []byte("udef")}}, + "\x05_\x00\x04\x00\x00\x00\x80udef"}, + {bson.M{"_": bson.Undefined}, // Obsolete, but still seen in the wild. + "\x06_\x00"}, + {bson.M{"_": bson.ObjectId("0123456789ab")}, + "\x07_\x000123456789ab"}, + {bson.M{"_": false}, + "\x08_\x00\x00"}, + {bson.M{"_": true}, + "\x08_\x00\x01"}, + {bson.M{"_": time.Unix(0, 258e6)}, // Note the NS <=> MS conversion. + "\x09_\x00\x02\x01\x00\x00\x00\x00\x00\x00"}, + {bson.M{"_": nil}, + "\x0A_\x00"}, + {bson.M{"_": bson.RegEx{"ab", "cd"}}, + "\x0B_\x00ab\x00cd\x00"}, + {bson.M{"_": bson.JavaScript{"code", nil}}, + "\x0D_\x00\x05\x00\x00\x00code\x00"}, + {bson.M{"_": bson.Symbol("sym")}, + "\x0E_\x00\x04\x00\x00\x00sym\x00"}, + {bson.M{"_": bson.JavaScript{"code", bson.M{"": nil}}}, + "\x0F_\x00\x14\x00\x00\x00\x05\x00\x00\x00code\x00" + + "\x07\x00\x00\x00\x0A\x00\x00"}, + {bson.M{"_": 258}, + "\x10_\x00\x02\x01\x00\x00"}, + {bson.M{"_": bson.MongoTimestamp(258)}, + "\x11_\x00\x02\x01\x00\x00\x00\x00\x00\x00"}, + {bson.M{"_": int64(258)}, + "\x12_\x00\x02\x01\x00\x00\x00\x00\x00\x00"}, + {bson.M{"_": int64(258 << 32)}, + "\x12_\x00\x00\x00\x00\x00\x02\x01\x00\x00"}, + {bson.M{"_": bson.MaxKey}, + "\x7F_\x00"}, + {bson.M{"_": bson.MinKey}, + "\xFF_\x00"}, +} + +func (s *S) TestMarshalAllItems(c *C) { + for i, item := range allItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, wrapInDoc(item.data), Commentf("Failed on item %d: %#v", i, item)) + } +} + +func (s *S) TestUnmarshalAllItems(c *C) { + for i, item := range allItems { + value := bson.M{} + err := bson.Unmarshal([]byte(wrapInDoc(item.data)), value) + c.Assert(err, IsNil) + c.Assert(value, DeepEquals, item.obj, Commentf("Failed on item %d: %#v", i, item)) + } +} + +func (s *S) TestUnmarshalRawAllItems(c *C) { + for i, item := range allItems { + if len(item.data) == 0 { + continue + } + value := item.obj.(bson.M)["_"] + if value == nil { + continue + } + pv := reflect.New(reflect.ValueOf(value).Type()) + raw := bson.Raw{item.data[0], []byte(item.data[3:])} + c.Logf("Unmarshal raw: %#v, %#v", raw, pv.Interface()) + err := raw.Unmarshal(pv.Interface()) + c.Assert(err, IsNil) + c.Assert(pv.Elem().Interface(), DeepEquals, value, Commentf("Failed on item %d: %#v", i, item)) + } +} + +func (s *S) TestUnmarshalRawIncompatible(c *C) { + raw := bson.Raw{0x08, []byte{0x01}} // true + err := raw.Unmarshal(&struct{}{}) + c.Assert(err, ErrorMatches, "BSON kind 0x08 isn't compatible with type struct \\{\\}") +} + +func (s *S) TestUnmarshalZeroesStruct(c *C) { + data, err := bson.Marshal(bson.M{"b": 2}) + c.Assert(err, IsNil) + type T struct{ A, B int } + v := T{A: 1} + err = bson.Unmarshal(data, &v) + c.Assert(err, IsNil) + c.Assert(v.A, Equals, 0) + c.Assert(v.B, Equals, 2) +} + +func (s *S) TestUnmarshalZeroesMap(c *C) { + data, err := bson.Marshal(bson.M{"b": 2}) + c.Assert(err, IsNil) + m := bson.M{"a": 1} + err = bson.Unmarshal(data, &m) + c.Assert(err, IsNil) + c.Assert(m, DeepEquals, bson.M{"b": 2}) +} + +func (s *S) TestUnmarshalNonNilInterface(c *C) { + data, err := bson.Marshal(bson.M{"b": 2}) + c.Assert(err, IsNil) + m := bson.M{"a": 1} + var i interface{} + i = m + err = bson.Unmarshal(data, &i) + c.Assert(err, IsNil) + c.Assert(i, DeepEquals, bson.M{"b": 2}) + c.Assert(m, DeepEquals, bson.M{"a": 1}) +} + +// -------------------------------------------------------------------------- +// Some one way marshaling operations which would unmarshal differently. + +var oneWayMarshalItems = []testItemType{ + // These are being passed as pointers, and will unmarshal as values. + {bson.M{"": &bson.Binary{0x02, []byte("old")}}, + "\x05\x00\x07\x00\x00\x00\x02\x03\x00\x00\x00old"}, + {bson.M{"": &bson.Binary{0x80, []byte("udef")}}, + "\x05\x00\x04\x00\x00\x00\x80udef"}, + {bson.M{"": &bson.RegEx{"ab", "cd"}}, + "\x0B\x00ab\x00cd\x00"}, + {bson.M{"": &bson.JavaScript{"code", nil}}, + "\x0D\x00\x05\x00\x00\x00code\x00"}, + {bson.M{"": &bson.JavaScript{"code", bson.M{"": nil}}}, + "\x0F\x00\x14\x00\x00\x00\x05\x00\x00\x00code\x00" + + "\x07\x00\x00\x00\x0A\x00\x00"}, + + // There's no float32 type in BSON. Will encode as a float64. + {bson.M{"": float32(5.05)}, + "\x01\x00\x00\x00\x00@33\x14@"}, + + // The array will be unmarshaled as a slice instead. + {bson.M{"": [2]bool{true, false}}, + "\x04\x00\r\x00\x00\x00\x080\x00\x01\x081\x00\x00\x00"}, + + // The typed slice will be unmarshaled as []interface{}. + {bson.M{"": []bool{true, false}}, + "\x04\x00\r\x00\x00\x00\x080\x00\x01\x081\x00\x00\x00"}, + + // Will unmarshal as a []byte. + {bson.M{"": bson.Binary{0x00, []byte("yo")}}, + "\x05\x00\x02\x00\x00\x00\x00yo"}, + {bson.M{"": bson.Binary{0x02, []byte("old")}}, + "\x05\x00\x07\x00\x00\x00\x02\x03\x00\x00\x00old"}, + + // No way to preserve the type information here. We might encode as a zero + // value, but this would mean that pointer values in structs wouldn't be + // able to correctly distinguish between unset and set to the zero value. + {bson.M{"": (*byte)(nil)}, + "\x0A\x00"}, + + // No int types smaller than int32 in BSON. Could encode this as a char, + // but it would still be ambiguous, take more, and be awkward in Go when + // loaded without typing information. + {bson.M{"": byte(8)}, + "\x10\x00\x08\x00\x00\x00"}, + + // There are no unsigned types in BSON. Will unmarshal as int32 or int64. + {bson.M{"": uint32(258)}, + "\x10\x00\x02\x01\x00\x00"}, + {bson.M{"": uint64(258)}, + "\x12\x00\x02\x01\x00\x00\x00\x00\x00\x00"}, + {bson.M{"": uint64(258 << 32)}, + "\x12\x00\x00\x00\x00\x00\x02\x01\x00\x00"}, + + // This will unmarshal as int. + {bson.M{"": int32(258)}, + "\x10\x00\x02\x01\x00\x00"}, + + // That's a special case. The unsigned value is too large for an int32, + // so an int64 is used instead. + {bson.M{"": uint32(1<<32 - 1)}, + "\x12\x00\xFF\xFF\xFF\xFF\x00\x00\x00\x00"}, + {bson.M{"": uint(1<<32 - 1)}, + "\x12\x00\xFF\xFF\xFF\xFF\x00\x00\x00\x00"}, +} + +func (s *S) TestOneWayMarshalItems(c *C) { + for i, item := range oneWayMarshalItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, wrapInDoc(item.data), + Commentf("Failed on item %d", i)) + } +} + +// -------------------------------------------------------------------------- +// Two-way tests for user-defined structures using the samples +// from bsonspec.org. + +type specSample1 struct { + Hello string +} + +type specSample2 struct { + BSON []interface{} "BSON" +} + +var structSampleItems = []testItemType{ + {&specSample1{"world"}, + "\x16\x00\x00\x00\x02hello\x00\x06\x00\x00\x00world\x00\x00"}, + {&specSample2{[]interface{}{"awesome", float64(5.05), 1986}}, + "1\x00\x00\x00\x04BSON\x00&\x00\x00\x00\x020\x00\x08\x00\x00\x00" + + "awesome\x00\x011\x00333333\x14@\x102\x00\xc2\x07\x00\x00\x00\x00"}, +} + +func (s *S) TestMarshalStructSampleItems(c *C) { + for i, item := range structSampleItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, item.data, + Commentf("Failed on item %d", i)) + } +} + +func (s *S) TestUnmarshalStructSampleItems(c *C) { + for _, item := range structSampleItems { + testUnmarshal(c, item.data, item.obj) + } +} + +func (s *S) Test64bitInt(c *C) { + var i int64 = (1 << 31) + if int(i) > 0 { + data, err := bson.Marshal(bson.M{"i": int(i)}) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, wrapInDoc("\x12i\x00\x00\x00\x00\x80\x00\x00\x00\x00")) + + var result struct{ I int } + err = bson.Unmarshal(data, &result) + c.Assert(err, IsNil) + c.Assert(int64(result.I), Equals, i) + } +} + +// -------------------------------------------------------------------------- +// Generic two-way struct marshaling tests. + +var bytevar = byte(8) +var byteptr = &bytevar + +var structItems = []testItemType{ + {&struct{ Ptr *byte }{nil}, + "\x0Aptr\x00"}, + {&struct{ Ptr *byte }{&bytevar}, + "\x10ptr\x00\x08\x00\x00\x00"}, + {&struct{ Ptr **byte }{&byteptr}, + "\x10ptr\x00\x08\x00\x00\x00"}, + {&struct{ Byte byte }{8}, + "\x10byte\x00\x08\x00\x00\x00"}, + {&struct{ Byte byte }{0}, + "\x10byte\x00\x00\x00\x00\x00"}, + {&struct { + V byte "Tag" + }{8}, + "\x10Tag\x00\x08\x00\x00\x00"}, + {&struct { + V *struct { + Byte byte + } + }{&struct{ Byte byte }{8}}, + "\x03v\x00" + "\x0f\x00\x00\x00\x10byte\x00\b\x00\x00\x00\x00"}, + {&struct{ priv byte }{}, ""}, + + // The order of the dumped fields should be the same in the struct. + {&struct{ A, C, B, D, F, E *byte }{}, + "\x0Aa\x00\x0Ac\x00\x0Ab\x00\x0Ad\x00\x0Af\x00\x0Ae\x00"}, + + {&struct{ V bson.Raw }{bson.Raw{0x03, []byte("\x0f\x00\x00\x00\x10byte\x00\b\x00\x00\x00\x00")}}, + "\x03v\x00" + "\x0f\x00\x00\x00\x10byte\x00\b\x00\x00\x00\x00"}, + {&struct{ V bson.Raw }{bson.Raw{0x10, []byte("\x00\x00\x00\x00")}}, + "\x10v\x00" + "\x00\x00\x00\x00"}, + + // Byte arrays. + {&struct{ V [2]byte }{[2]byte{'y', 'o'}}, + "\x05v\x00\x02\x00\x00\x00\x00yo"}, +} + +func (s *S) TestMarshalStructItems(c *C) { + for i, item := range structItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, wrapInDoc(item.data), + Commentf("Failed on item %d", i)) + } +} + +func (s *S) TestUnmarshalStructItems(c *C) { + for _, item := range structItems { + testUnmarshal(c, wrapInDoc(item.data), item.obj) + } +} + +func (s *S) TestUnmarshalRawStructItems(c *C) { + for i, item := range structItems { + raw := bson.Raw{0x03, []byte(wrapInDoc(item.data))} + zero := makeZeroDoc(item.obj) + err := raw.Unmarshal(zero) + c.Assert(err, IsNil) + c.Assert(zero, DeepEquals, item.obj, Commentf("Failed on item %d: %#v", i, item)) + } +} + +func (s *S) TestUnmarshalRawNil(c *C) { + // Regression test: shouldn't try to nil out the pointer itself, + // as it's not settable. + raw := bson.Raw{0x0A, []byte{}} + err := raw.Unmarshal(&struct{}{}) + c.Assert(err, IsNil) +} + +// -------------------------------------------------------------------------- +// One-way marshaling tests. + +type dOnIface struct { + D interface{} +} + +type ignoreField struct { + Before string + Ignore string `bson:"-"` + After string +} + +var marshalItems = []testItemType{ + // Ordered document dump. Will unmarshal as a dictionary by default. + {bson.D{{"a", nil}, {"c", nil}, {"b", nil}, {"d", nil}, {"f", nil}, {"e", true}}, + "\x0Aa\x00\x0Ac\x00\x0Ab\x00\x0Ad\x00\x0Af\x00\x08e\x00\x01"}, + {MyD{{"a", nil}, {"c", nil}, {"b", nil}, {"d", nil}, {"f", nil}, {"e", true}}, + "\x0Aa\x00\x0Ac\x00\x0Ab\x00\x0Ad\x00\x0Af\x00\x08e\x00\x01"}, + {&dOnIface{bson.D{{"a", nil}, {"c", nil}, {"b", nil}, {"d", true}}}, + "\x03d\x00" + wrapInDoc("\x0Aa\x00\x0Ac\x00\x0Ab\x00\x08d\x00\x01")}, + + {bson.RawD{{"a", bson.Raw{0x0A, nil}}, {"c", bson.Raw{0x0A, nil}}, {"b", bson.Raw{0x08, []byte{0x01}}}}, + "\x0Aa\x00" + "\x0Ac\x00" + "\x08b\x00\x01"}, + {MyRawD{{"a", bson.Raw{0x0A, nil}}, {"c", bson.Raw{0x0A, nil}}, {"b", bson.Raw{0x08, []byte{0x01}}}}, + "\x0Aa\x00" + "\x0Ac\x00" + "\x08b\x00\x01"}, + {&dOnIface{bson.RawD{{"a", bson.Raw{0x0A, nil}}, {"c", bson.Raw{0x0A, nil}}, {"b", bson.Raw{0x08, []byte{0x01}}}}}, + "\x03d\x00" + wrapInDoc("\x0Aa\x00"+"\x0Ac\x00"+"\x08b\x00\x01")}, + + {&ignoreField{"before", "ignore", "after"}, + "\x02before\x00\a\x00\x00\x00before\x00\x02after\x00\x06\x00\x00\x00after\x00"}, + + // Marshalling a Raw document does nothing. + {bson.Raw{0x03, []byte(wrapInDoc("anything"))}, + "anything"}, + {bson.Raw{Data: []byte(wrapInDoc("anything"))}, + "anything"}, +} + +func (s *S) TestMarshalOneWayItems(c *C) { + for _, item := range marshalItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, wrapInDoc(item.data)) + } +} + +// -------------------------------------------------------------------------- +// One-way unmarshaling tests. + +var unmarshalItems = []testItemType{ + // Field is private. Should not attempt to unmarshal it. + {&struct{ priv byte }{}, + "\x10priv\x00\x08\x00\x00\x00"}, + + // Wrong casing. Field names are lowercased. + {&struct{ Byte byte }{}, + "\x10Byte\x00\x08\x00\x00\x00"}, + + // Ignore non-existing field. + {&struct{ Byte byte }{9}, + "\x10boot\x00\x08\x00\x00\x00" + "\x10byte\x00\x09\x00\x00\x00"}, + + // Do not unmarshal on ignored field. + {&ignoreField{"before", "", "after"}, + "\x02before\x00\a\x00\x00\x00before\x00" + + "\x02-\x00\a\x00\x00\x00ignore\x00" + + "\x02after\x00\x06\x00\x00\x00after\x00"}, + + // Ignore unsuitable types silently. + {map[string]string{"str": "s"}, + "\x02str\x00\x02\x00\x00\x00s\x00" + "\x10int\x00\x01\x00\x00\x00"}, + {map[string][]int{"array": []int{5, 9}}, + "\x04array\x00" + wrapInDoc("\x100\x00\x05\x00\x00\x00"+"\x021\x00\x02\x00\x00\x00s\x00"+"\x102\x00\x09\x00\x00\x00")}, + + // Wrong type. Shouldn't init pointer. + {&struct{ Str *byte }{}, + "\x02str\x00\x02\x00\x00\x00s\x00"}, + {&struct{ Str *struct{ Str string } }{}, + "\x02str\x00\x02\x00\x00\x00s\x00"}, + + // Ordered document. + {&struct{ bson.D }{bson.D{{"a", nil}, {"c", nil}, {"b", nil}, {"d", true}}}, + "\x03d\x00" + wrapInDoc("\x0Aa\x00\x0Ac\x00\x0Ab\x00\x08d\x00\x01")}, + + // Raw document. + {&bson.Raw{0x03, []byte(wrapInDoc("\x10byte\x00\x08\x00\x00\x00"))}, + "\x10byte\x00\x08\x00\x00\x00"}, + + // RawD document. + {&struct{ bson.RawD }{bson.RawD{{"a", bson.Raw{0x0A, []byte{}}}, {"c", bson.Raw{0x0A, []byte{}}}, {"b", bson.Raw{0x08, []byte{0x01}}}}}, + "\x03rawd\x00" + wrapInDoc("\x0Aa\x00\x0Ac\x00\x08b\x00\x01")}, + + // Decode old binary. + {bson.M{"_": []byte("old")}, + "\x05_\x00\x07\x00\x00\x00\x02\x03\x00\x00\x00old"}, + + // Decode old binary without length. According to the spec, this shouldn't happen. + {bson.M{"_": []byte("old")}, + "\x05_\x00\x03\x00\x00\x00\x02old"}, +} + +func (s *S) TestUnmarshalOneWayItems(c *C) { + for _, item := range unmarshalItems { + testUnmarshal(c, wrapInDoc(item.data), item.obj) + } +} + +func (s *S) TestUnmarshalNilInStruct(c *C) { + // Nil is the default value, so we need to ensure it's indeed being set. + b := byte(1) + v := &struct{ Ptr *byte }{&b} + err := bson.Unmarshal([]byte(wrapInDoc("\x0Aptr\x00")), v) + c.Assert(err, IsNil) + c.Assert(v, DeepEquals, &struct{ Ptr *byte }{nil}) +} + +// -------------------------------------------------------------------------- +// Marshalling error cases. + +type structWithDupKeys struct { + Name byte + Other byte "name" // Tag should precede. +} + +var marshalErrorItems = []testItemType{ + {bson.M{"": uint64(1 << 63)}, + "BSON has no uint64 type, and value is too large to fit correctly in an int64"}, + {bson.M{"": bson.ObjectId("tooshort")}, + "ObjectIDs must be exactly 12 bytes long \\(got 8\\)"}, + {int64(123), + "Can't marshal int64 as a BSON document"}, + {bson.M{"": 1i}, + "Can't marshal complex128 in a BSON document"}, + {&structWithDupKeys{}, + "Duplicated key 'name' in struct bson_test.structWithDupKeys"}, + {bson.Raw{0x0A, []byte{}}, + "Attempted to unmarshal Raw kind 10 as a document"}, + {&inlineCantPtr{&struct{ A, B int }{1, 2}}, + "Option ,inline needs a struct value or map field"}, + {&inlineDupName{1, struct{ A, B int }{2, 3}}, + "Duplicated key 'a' in struct bson_test.inlineDupName"}, + {&inlineDupMap{}, + "Multiple ,inline maps in struct bson_test.inlineDupMap"}, + {&inlineBadKeyMap{}, + "Option ,inline needs a map with string keys in struct bson_test.inlineBadKeyMap"}, + {&inlineMap{A: 1, M: map[string]interface{}{"a": 1}}, + `Can't have key "a" in inlined map; conflicts with struct field`}, +} + +func (s *S) TestMarshalErrorItems(c *C) { + for _, item := range marshalErrorItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, ErrorMatches, item.data) + c.Assert(data, IsNil) + } +} + +// -------------------------------------------------------------------------- +// Unmarshalling error cases. + +type unmarshalErrorType struct { + obj interface{} + data string + error string +} + +var unmarshalErrorItems = []unmarshalErrorType{ + // Tag name conflicts with existing parameter. + {&structWithDupKeys{}, + "\x10name\x00\x08\x00\x00\x00", + "Duplicated key 'name' in struct bson_test.structWithDupKeys"}, + + // Non-string map key. + {map[int]interface{}{}, + "\x10name\x00\x08\x00\x00\x00", + "BSON map must have string keys. Got: map\\[int\\]interface \\{\\}"}, + + {nil, + "\xEEname\x00", + "Unknown element kind \\(0xEE\\)"}, + + {struct{ Name bool }{}, + "\x10name\x00\x08\x00\x00\x00", + "Unmarshal can't deal with struct values. Use a pointer."}, + + {123, + "\x10name\x00\x08\x00\x00\x00", + "Unmarshal needs a map or a pointer to a struct."}, +} + +func (s *S) TestUnmarshalErrorItems(c *C) { + for _, item := range unmarshalErrorItems { + data := []byte(wrapInDoc(item.data)) + var value interface{} + switch reflect.ValueOf(item.obj).Kind() { + case reflect.Map, reflect.Ptr: + value = makeZeroDoc(item.obj) + case reflect.Invalid: + value = bson.M{} + default: + value = item.obj + } + err := bson.Unmarshal(data, value) + c.Assert(err, ErrorMatches, item.error) + } +} + +type unmarshalRawErrorType struct { + obj interface{} + raw bson.Raw + error string +} + +var unmarshalRawErrorItems = []unmarshalRawErrorType{ + // Tag name conflicts with existing parameter. + {&structWithDupKeys{}, + bson.Raw{0x03, []byte("\x10byte\x00\x08\x00\x00\x00")}, + "Duplicated key 'name' in struct bson_test.structWithDupKeys"}, + + {&struct{}{}, + bson.Raw{0xEE, []byte{}}, + "Unknown element kind \\(0xEE\\)"}, + + {struct{ Name bool }{}, + bson.Raw{0x10, []byte("\x08\x00\x00\x00")}, + "Raw Unmarshal can't deal with struct values. Use a pointer."}, + + {123, + bson.Raw{0x10, []byte("\x08\x00\x00\x00")}, + "Raw Unmarshal needs a map or a valid pointer."}, +} + +func (s *S) TestUnmarshalRawErrorItems(c *C) { + for i, item := range unmarshalRawErrorItems { + err := item.raw.Unmarshal(item.obj) + c.Assert(err, ErrorMatches, item.error, Commentf("Failed on item %d: %#v\n", i, item)) + } +} + +var corruptedData = []string{ + "\x04\x00\x00\x00\x00", // Shorter than minimum + "\x06\x00\x00\x00\x00", // Not enough data + "\x05\x00\x00", // Broken length + "\x05\x00\x00\x00\xff", // Corrupted termination + "\x0A\x00\x00\x00\x0Aooop\x00", // Unfinished C string + + // Array end past end of string (s[2]=0x07 is correct) + wrapInDoc("\x04\x00\x09\x00\x00\x00\x0A\x00\x00"), + + // Array end within string, but past acceptable. + wrapInDoc("\x04\x00\x08\x00\x00\x00\x0A\x00\x00"), + + // Document end within string, but past acceptable. + wrapInDoc("\x03\x00\x08\x00\x00\x00\x0A\x00\x00"), + + // String with corrupted end. + wrapInDoc("\x02\x00\x03\x00\x00\x00yo\xFF"), +} + +func (s *S) TestUnmarshalMapDocumentTooShort(c *C) { + for _, data := range corruptedData { + err := bson.Unmarshal([]byte(data), bson.M{}) + c.Assert(err, ErrorMatches, "Document is corrupted") + + err = bson.Unmarshal([]byte(data), &struct{}{}) + c.Assert(err, ErrorMatches, "Document is corrupted") + } +} + +// -------------------------------------------------------------------------- +// Setter test cases. + +var setterResult = map[string]error{} + +type setterType struct { + received interface{} +} + +func (o *setterType) SetBSON(raw bson.Raw) error { + err := raw.Unmarshal(&o.received) + if err != nil { + panic("The panic:" + err.Error()) + } + if s, ok := o.received.(string); ok { + if result, ok := setterResult[s]; ok { + return result + } + } + return nil +} + +type ptrSetterDoc struct { + Field *setterType "_" +} + +type valSetterDoc struct { + Field setterType "_" +} + +func (s *S) TestUnmarshalAllItemsWithPtrSetter(c *C) { + for _, item := range allItems { + for i := 0; i != 2; i++ { + var field *setterType + if i == 0 { + obj := &ptrSetterDoc{} + err := bson.Unmarshal([]byte(wrapInDoc(item.data)), obj) + c.Assert(err, IsNil) + field = obj.Field + } else { + obj := &valSetterDoc{} + err := bson.Unmarshal([]byte(wrapInDoc(item.data)), obj) + c.Assert(err, IsNil) + field = &obj.Field + } + if item.data == "" { + // Nothing to unmarshal. Should be untouched. + if i == 0 { + c.Assert(field, IsNil) + } else { + c.Assert(field.received, IsNil) + } + } else { + expected := item.obj.(bson.M)["_"] + c.Assert(field, NotNil, Commentf("Pointer not initialized (%#v)", expected)) + c.Assert(field.received, DeepEquals, expected) + } + } + } +} + +func (s *S) TestUnmarshalWholeDocumentWithSetter(c *C) { + obj := &setterType{} + err := bson.Unmarshal([]byte(sampleItems[0].data), obj) + c.Assert(err, IsNil) + c.Assert(obj.received, DeepEquals, bson.M{"hello": "world"}) +} + +func (s *S) TestUnmarshalSetterOmits(c *C) { + setterResult["2"] = &bson.TypeError{} + setterResult["4"] = &bson.TypeError{} + defer func() { + delete(setterResult, "2") + delete(setterResult, "4") + }() + + m := map[string]*setterType{} + data := wrapInDoc("\x02abc\x00\x02\x00\x00\x001\x00" + + "\x02def\x00\x02\x00\x00\x002\x00" + + "\x02ghi\x00\x02\x00\x00\x003\x00" + + "\x02jkl\x00\x02\x00\x00\x004\x00") + err := bson.Unmarshal([]byte(data), m) + c.Assert(err, IsNil) + c.Assert(m["abc"], NotNil) + c.Assert(m["def"], IsNil) + c.Assert(m["ghi"], NotNil) + c.Assert(m["jkl"], IsNil) + + c.Assert(m["abc"].received, Equals, "1") + c.Assert(m["ghi"].received, Equals, "3") +} + +func (s *S) TestUnmarshalSetterErrors(c *C) { + boom := errors.New("BOOM") + setterResult["2"] = boom + defer delete(setterResult, "2") + + m := map[string]*setterType{} + data := wrapInDoc("\x02abc\x00\x02\x00\x00\x001\x00" + + "\x02def\x00\x02\x00\x00\x002\x00" + + "\x02ghi\x00\x02\x00\x00\x003\x00") + err := bson.Unmarshal([]byte(data), m) + c.Assert(err, Equals, boom) + c.Assert(m["abc"], NotNil) + c.Assert(m["def"], IsNil) + c.Assert(m["ghi"], IsNil) + + c.Assert(m["abc"].received, Equals, "1") +} + +func (s *S) TestDMap(c *C) { + d := bson.D{{"a", 1}, {"b", 2}} + c.Assert(d.Map(), DeepEquals, bson.M{"a": 1, "b": 2}) +} + +func (s *S) TestUnmarshalSetterSetZero(c *C) { + setterResult["foo"] = bson.SetZero + defer delete(setterResult, "field") + + data, err := bson.Marshal(bson.M{"field": "foo"}) + c.Assert(err, IsNil) + + m := map[string]*setterType{} + err = bson.Unmarshal([]byte(data), m) + c.Assert(err, IsNil) + + value, ok := m["field"] + c.Assert(ok, Equals, true) + c.Assert(value, IsNil) +} + +// -------------------------------------------------------------------------- +// Getter test cases. + +type typeWithGetter struct { + result interface{} + err error +} + +func (t *typeWithGetter) GetBSON() (interface{}, error) { + return t.result, t.err +} + +type docWithGetterField struct { + Field *typeWithGetter "_" +} + +func (s *S) TestMarshalAllItemsWithGetter(c *C) { + for i, item := range allItems { + if item.data == "" { + continue + } + obj := &docWithGetterField{} + obj.Field = &typeWithGetter{result: item.obj.(bson.M)["_"]} + data, err := bson.Marshal(obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, wrapInDoc(item.data), + Commentf("Failed on item #%d", i)) + } +} + +func (s *S) TestMarshalWholeDocumentWithGetter(c *C) { + obj := &typeWithGetter{result: sampleItems[0].obj} + data, err := bson.Marshal(obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, sampleItems[0].data) +} + +func (s *S) TestGetterErrors(c *C) { + e := errors.New("oops") + + obj1 := &docWithGetterField{} + obj1.Field = &typeWithGetter{sampleItems[0].obj, e} + data, err := bson.Marshal(obj1) + c.Assert(err, ErrorMatches, "oops") + c.Assert(data, IsNil) + + obj2 := &typeWithGetter{sampleItems[0].obj, e} + data, err = bson.Marshal(obj2) + c.Assert(err, ErrorMatches, "oops") + c.Assert(data, IsNil) +} + +type intGetter int64 + +func (t intGetter) GetBSON() (interface{}, error) { + return int64(t), nil +} + +type typeWithIntGetter struct { + V intGetter ",minsize" +} + +func (s *S) TestMarshalShortWithGetter(c *C) { + obj := typeWithIntGetter{42} + data, err := bson.Marshal(obj) + c.Assert(err, IsNil) + m := bson.M{} + err = bson.Unmarshal(data, m) + c.Assert(m["v"], Equals, 42) +} + +// -------------------------------------------------------------------------- +// Cross-type conversion tests. + +type crossTypeItem struct { + obj1 interface{} + obj2 interface{} +} + +type condStr struct { + V string ",omitempty" +} +type condStrNS struct { + V string `a:"A" bson:",omitempty" b:"B"` +} +type condBool struct { + V bool ",omitempty" +} +type condInt struct { + V int ",omitempty" +} +type condUInt struct { + V uint ",omitempty" +} +type condFloat struct { + V float64 ",omitempty" +} +type condIface struct { + V interface{} ",omitempty" +} +type condPtr struct { + V *bool ",omitempty" +} +type condSlice struct { + V []string ",omitempty" +} +type condMap struct { + V map[string]int ",omitempty" +} +type namedCondStr struct { + V string "myv,omitempty" +} +type condTime struct { + V time.Time ",omitempty" +} +type condStruct struct { + V struct{ A []int } ",omitempty" +} + +type shortInt struct { + V int64 ",minsize" +} +type shortUint struct { + V uint64 ",minsize" +} +type shortIface struct { + V interface{} ",minsize" +} +type shortPtr struct { + V *int64 ",minsize" +} +type shortNonEmptyInt struct { + V int64 ",minsize,omitempty" +} + +type inlineInt struct { + V struct{ A, B int } ",inline" +} +type inlineCantPtr struct { + V *struct{ A, B int } ",inline" +} +type inlineDupName struct { + A int + V struct{ A, B int } ",inline" +} +type inlineMap struct { + A int + M map[string]interface{} ",inline" +} +type inlineMapInt struct { + A int + M map[string]int ",inline" +} +type inlineMapMyM struct { + A int + M MyM ",inline" +} +type inlineDupMap struct { + M1 map[string]interface{} ",inline" + M2 map[string]interface{} ",inline" +} +type inlineBadKeyMap struct { + M map[int]int ",inline" +} + +type ( + MyString string + MyBytes []byte + MyBool bool + MyD []bson.DocElem + MyRawD []bson.RawDocElem + MyM map[string]interface{} +) + +var ( + truevar = true + falsevar = false + + int64var = int64(42) + int64ptr = &int64var + intvar = int(42) + intptr = &intvar +) + +func parseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil { + panic(err) + } + return u +} + +// That's a pretty fun test. It will dump the first item, generate a zero +// value equivalent to the second one, load the dumped data onto it, and then +// verify that the resulting value is deep-equal to the untouched second value. +// Then, it will do the same in the *opposite* direction! +var twoWayCrossItems = []crossTypeItem{ + // int<=>int + {&struct{ I int }{42}, &struct{ I int8 }{42}}, + {&struct{ I int }{42}, &struct{ I int32 }{42}}, + {&struct{ I int }{42}, &struct{ I int64 }{42}}, + {&struct{ I int8 }{42}, &struct{ I int32 }{42}}, + {&struct{ I int8 }{42}, &struct{ I int64 }{42}}, + {&struct{ I int32 }{42}, &struct{ I int64 }{42}}, + + // uint<=>uint + {&struct{ I uint }{42}, &struct{ I uint8 }{42}}, + {&struct{ I uint }{42}, &struct{ I uint32 }{42}}, + {&struct{ I uint }{42}, &struct{ I uint64 }{42}}, + {&struct{ I uint8 }{42}, &struct{ I uint32 }{42}}, + {&struct{ I uint8 }{42}, &struct{ I uint64 }{42}}, + {&struct{ I uint32 }{42}, &struct{ I uint64 }{42}}, + + // float32<=>float64 + {&struct{ I float32 }{42}, &struct{ I float64 }{42}}, + + // int<=>uint + {&struct{ I uint }{42}, &struct{ I int }{42}}, + {&struct{ I uint }{42}, &struct{ I int8 }{42}}, + {&struct{ I uint }{42}, &struct{ I int32 }{42}}, + {&struct{ I uint }{42}, &struct{ I int64 }{42}}, + {&struct{ I uint8 }{42}, &struct{ I int }{42}}, + {&struct{ I uint8 }{42}, &struct{ I int8 }{42}}, + {&struct{ I uint8 }{42}, &struct{ I int32 }{42}}, + {&struct{ I uint8 }{42}, &struct{ I int64 }{42}}, + {&struct{ I uint32 }{42}, &struct{ I int }{42}}, + {&struct{ I uint32 }{42}, &struct{ I int8 }{42}}, + {&struct{ I uint32 }{42}, &struct{ I int32 }{42}}, + {&struct{ I uint32 }{42}, &struct{ I int64 }{42}}, + {&struct{ I uint64 }{42}, &struct{ I int }{42}}, + {&struct{ I uint64 }{42}, &struct{ I int8 }{42}}, + {&struct{ I uint64 }{42}, &struct{ I int32 }{42}}, + {&struct{ I uint64 }{42}, &struct{ I int64 }{42}}, + + // int <=> float + {&struct{ I int }{42}, &struct{ I float64 }{42}}, + + // int <=> bool + {&struct{ I int }{1}, &struct{ I bool }{true}}, + {&struct{ I int }{0}, &struct{ I bool }{false}}, + + // uint <=> float64 + {&struct{ I uint }{42}, &struct{ I float64 }{42}}, + + // uint <=> bool + {&struct{ I uint }{1}, &struct{ I bool }{true}}, + {&struct{ I uint }{0}, &struct{ I bool }{false}}, + + // float64 <=> bool + {&struct{ I float64 }{1}, &struct{ I bool }{true}}, + {&struct{ I float64 }{0}, &struct{ I bool }{false}}, + + // string <=> string and string <=> []byte + {&struct{ S []byte }{[]byte("abc")}, &struct{ S string }{"abc"}}, + {&struct{ S []byte }{[]byte("def")}, &struct{ S bson.Symbol }{"def"}}, + {&struct{ S string }{"ghi"}, &struct{ S bson.Symbol }{"ghi"}}, + + // map <=> struct + {&struct { + A struct { + B, C int + } + }{struct{ B, C int }{1, 2}}, + map[string]map[string]int{"a": map[string]int{"b": 1, "c": 2}}}, + + {&struct{ A bson.Symbol }{"abc"}, map[string]string{"a": "abc"}}, + {&struct{ A bson.Symbol }{"abc"}, map[string][]byte{"a": []byte("abc")}}, + {&struct{ A []byte }{[]byte("abc")}, map[string]string{"a": "abc"}}, + {&struct{ A uint }{42}, map[string]int{"a": 42}}, + {&struct{ A uint }{42}, map[string]float64{"a": 42}}, + {&struct{ A uint }{1}, map[string]bool{"a": true}}, + {&struct{ A int }{42}, map[string]uint{"a": 42}}, + {&struct{ A int }{42}, map[string]float64{"a": 42}}, + {&struct{ A int }{1}, map[string]bool{"a": true}}, + {&struct{ A float64 }{42}, map[string]float32{"a": 42}}, + {&struct{ A float64 }{42}, map[string]int{"a": 42}}, + {&struct{ A float64 }{42}, map[string]uint{"a": 42}}, + {&struct{ A float64 }{1}, map[string]bool{"a": true}}, + {&struct{ A bool }{true}, map[string]int{"a": 1}}, + {&struct{ A bool }{true}, map[string]uint{"a": 1}}, + {&struct{ A bool }{true}, map[string]float64{"a": 1}}, + {&struct{ A **byte }{&byteptr}, map[string]byte{"a": 8}}, + + // url.URL <=> string + {&struct{ URL *url.URL }{parseURL("h://e.c/p")}, map[string]string{"url": "h://e.c/p"}}, + {&struct{ URL url.URL }{*parseURL("h://e.c/p")}, map[string]string{"url": "h://e.c/p"}}, + + // Slices + {&struct{ S []int }{[]int{1, 2, 3}}, map[string][]int{"s": []int{1, 2, 3}}}, + {&struct{ S *[]int }{&[]int{1, 2, 3}}, map[string][]int{"s": []int{1, 2, 3}}}, + + // Conditionals + {&condBool{true}, map[string]bool{"v": true}}, + {&condBool{}, map[string]bool{}}, + {&condInt{1}, map[string]int{"v": 1}}, + {&condInt{}, map[string]int{}}, + {&condUInt{1}, map[string]uint{"v": 1}}, + {&condUInt{}, map[string]uint{}}, + {&condFloat{}, map[string]int{}}, + {&condStr{"yo"}, map[string]string{"v": "yo"}}, + {&condStr{}, map[string]string{}}, + {&condStrNS{"yo"}, map[string]string{"v": "yo"}}, + {&condStrNS{}, map[string]string{}}, + {&condSlice{[]string{"yo"}}, map[string][]string{"v": []string{"yo"}}}, + {&condSlice{}, map[string][]string{}}, + {&condMap{map[string]int{"k": 1}}, bson.M{"v": bson.M{"k": 1}}}, + {&condMap{}, map[string][]string{}}, + {&condIface{"yo"}, map[string]string{"v": "yo"}}, + {&condIface{""}, map[string]string{"v": ""}}, + {&condIface{}, map[string]string{}}, + {&condPtr{&truevar}, map[string]bool{"v": true}}, + {&condPtr{&falsevar}, map[string]bool{"v": false}}, + {&condPtr{}, map[string]string{}}, + + {&condTime{time.Unix(123456789, 123e6)}, map[string]time.Time{"v": time.Unix(123456789, 123e6)}}, + {&condTime{}, map[string]string{}}, + + {&condStruct{struct{ A []int }{[]int{1}}}, bson.M{"v": bson.M{"a": []interface{}{1}}}}, + {&condStruct{struct{ A []int }{}}, bson.M{}}, + + {&namedCondStr{"yo"}, map[string]string{"myv": "yo"}}, + {&namedCondStr{}, map[string]string{}}, + + {&shortInt{1}, map[string]interface{}{"v": 1}}, + {&shortInt{1 << 30}, map[string]interface{}{"v": 1 << 30}}, + {&shortInt{1 << 31}, map[string]interface{}{"v": int64(1 << 31)}}, + {&shortUint{1 << 30}, map[string]interface{}{"v": 1 << 30}}, + {&shortUint{1 << 31}, map[string]interface{}{"v": int64(1 << 31)}}, + {&shortIface{int64(1) << 31}, map[string]interface{}{"v": int64(1 << 31)}}, + {&shortPtr{int64ptr}, map[string]interface{}{"v": intvar}}, + + {&shortNonEmptyInt{1}, map[string]interface{}{"v": 1}}, + {&shortNonEmptyInt{1 << 31}, map[string]interface{}{"v": int64(1 << 31)}}, + {&shortNonEmptyInt{}, map[string]interface{}{}}, + + {&inlineInt{struct{ A, B int }{1, 2}}, map[string]interface{}{"a": 1, "b": 2}}, + {&inlineMap{A: 1, M: map[string]interface{}{"b": 2}}, map[string]interface{}{"a": 1, "b": 2}}, + {&inlineMap{A: 1, M: nil}, map[string]interface{}{"a": 1}}, + {&inlineMapInt{A: 1, M: map[string]int{"b": 2}}, map[string]int{"a": 1, "b": 2}}, + {&inlineMapInt{A: 1, M: nil}, map[string]int{"a": 1}}, + {&inlineMapMyM{A: 1, M: MyM{"b": MyM{"c": 3}}}, map[string]interface{}{"a": 1, "b": map[string]interface{}{"c": 3}}}, + + // []byte <=> MyBytes + {&struct{ B MyBytes }{[]byte("abc")}, map[string]string{"b": "abc"}}, + {&struct{ B MyBytes }{[]byte{}}, map[string]string{"b": ""}}, + {&struct{ B MyBytes }{}, map[string]bool{}}, + {&struct{ B []byte }{[]byte("abc")}, map[string]MyBytes{"b": []byte("abc")}}, + + // bool <=> MyBool + {&struct{ B MyBool }{true}, map[string]bool{"b": true}}, + {&struct{ B MyBool }{}, map[string]bool{"b": false}}, + {&struct{ B MyBool }{}, map[string]string{}}, + {&struct{ B bool }{}, map[string]MyBool{"b": false}}, + + // arrays + {&struct{ V [2]int }{[...]int{1, 2}}, map[string][2]int{"v": [2]int{1, 2}}}, + + // zero time + {&struct{ V time.Time }{}, map[string]interface{}{"v": time.Time{}}}, + + // zero time + 1 second + 1 millisecond; overflows int64 as nanoseconds + {&struct{ V time.Time }{time.Unix(-62135596799, 1e6).Local()}, + map[string]interface{}{"v": time.Unix(-62135596799, 1e6).Local()}}, + + // bson.D <=> []DocElem + {&bson.D{{"a", bson.D{{"b", 1}, {"c", 2}}}}, &bson.D{{"a", bson.D{{"b", 1}, {"c", 2}}}}}, + {&bson.D{{"a", bson.D{{"b", 1}, {"c", 2}}}}, &MyD{{"a", MyD{{"b", 1}, {"c", 2}}}}}, + + // bson.RawD <=> []RawDocElem + {&bson.RawD{{"a", bson.Raw{0x08, []byte{0x01}}}}, &bson.RawD{{"a", bson.Raw{0x08, []byte{0x01}}}}}, + {&bson.RawD{{"a", bson.Raw{0x08, []byte{0x01}}}}, &MyRawD{{"a", bson.Raw{0x08, []byte{0x01}}}}}, + + // bson.M <=> map + {bson.M{"a": bson.M{"b": 1, "c": 2}}, MyM{"a": MyM{"b": 1, "c": 2}}}, + {bson.M{"a": bson.M{"b": 1, "c": 2}}, map[string]interface{}{"a": map[string]interface{}{"b": 1, "c": 2}}}, + + // bson.M <=> map[MyString] + {bson.M{"a": bson.M{"b": 1, "c": 2}}, map[MyString]interface{}{"a": map[MyString]interface{}{"b": 1, "c": 2}}}, +} + +// Same thing, but only one way (obj1 => obj2). +var oneWayCrossItems = []crossTypeItem{ + // map <=> struct + {map[string]interface{}{"a": 1, "b": "2", "c": 3}, map[string]int{"a": 1, "c": 3}}, + + // inline map elides badly typed values + {map[string]interface{}{"a": 1, "b": "2", "c": 3}, &inlineMapInt{A: 1, M: map[string]int{"c": 3}}}, + + // Can't decode int into struct. + {bson.M{"a": bson.M{"b": 2}}, &struct{ A bool }{}}, + + // Would get decoded into a int32 too in the opposite direction. + {&shortIface{int64(1) << 30}, map[string]interface{}{"v": 1 << 30}}, +} + +func testCrossPair(c *C, dump interface{}, load interface{}) { + c.Logf("Dump: %#v", dump) + c.Logf("Load: %#v", load) + zero := makeZeroDoc(load) + data, err := bson.Marshal(dump) + c.Assert(err, IsNil) + c.Logf("Dumped: %#v", string(data)) + err = bson.Unmarshal(data, zero) + c.Assert(err, IsNil) + c.Logf("Loaded: %#v", zero) + c.Assert(zero, DeepEquals, load) +} + +func (s *S) TestTwoWayCrossPairs(c *C) { + for _, item := range twoWayCrossItems { + testCrossPair(c, item.obj1, item.obj2) + testCrossPair(c, item.obj2, item.obj1) + } +} + +func (s *S) TestOneWayCrossPairs(c *C) { + for _, item := range oneWayCrossItems { + testCrossPair(c, item.obj1, item.obj2) + } +} + +// -------------------------------------------------------------------------- +// ObjectId hex representation test. + +func (s *S) TestObjectIdHex(c *C) { + id := bson.ObjectIdHex("4d88e15b60f486e428412dc9") + c.Assert(id.String(), Equals, `ObjectIdHex("4d88e15b60f486e428412dc9")`) + c.Assert(id.Hex(), Equals, "4d88e15b60f486e428412dc9") +} + +func (s *S) TestIsObjectIdHex(c *C) { + test := []struct { + id string + valid bool + }{ + {"4d88e15b60f486e428412dc9", true}, + {"4d88e15b60f486e428412dc", false}, + {"4d88e15b60f486e428412dc9e", false}, + {"4d88e15b60f486e428412dcx", false}, + } + for _, t := range test { + c.Assert(bson.IsObjectIdHex(t.id), Equals, t.valid) + } +} + +// -------------------------------------------------------------------------- +// ObjectId parts extraction tests. + +type objectIdParts struct { + id bson.ObjectId + timestamp int64 + machine []byte + pid uint16 + counter int32 +} + +var objectIds = []objectIdParts{ + objectIdParts{ + bson.ObjectIdHex("4d88e15b60f486e428412dc9"), + 1300816219, + []byte{0x60, 0xf4, 0x86}, + 0xe428, + 4271561, + }, + objectIdParts{ + bson.ObjectIdHex("000000000000000000000000"), + 0, + []byte{0x00, 0x00, 0x00}, + 0x0000, + 0, + }, + objectIdParts{ + bson.ObjectIdHex("00000000aabbccddee000001"), + 0, + []byte{0xaa, 0xbb, 0xcc}, + 0xddee, + 1, + }, +} + +func (s *S) TestObjectIdPartsExtraction(c *C) { + for i, v := range objectIds { + t := time.Unix(v.timestamp, 0) + c.Assert(v.id.Time(), Equals, t, Commentf("#%d Wrong timestamp value", i)) + c.Assert(v.id.Machine(), DeepEquals, v.machine, Commentf("#%d Wrong machine id value", i)) + c.Assert(v.id.Pid(), Equals, v.pid, Commentf("#%d Wrong pid value", i)) + c.Assert(v.id.Counter(), Equals, v.counter, Commentf("#%d Wrong counter value", i)) + } +} + +func (s *S) TestNow(c *C) { + before := time.Now() + time.Sleep(1e6) + now := bson.Now() + time.Sleep(1e6) + after := time.Now() + c.Assert(now.After(before) && now.Before(after), Equals, true, Commentf("now=%s, before=%s, after=%s", now, before, after)) +} + +// -------------------------------------------------------------------------- +// ObjectId generation tests. + +func (s *S) TestNewObjectId(c *C) { + // Generate 10 ids + ids := make([]bson.ObjectId, 10) + for i := 0; i < 10; i++ { + ids[i] = bson.NewObjectId() + } + for i := 1; i < 10; i++ { + prevId := ids[i-1] + id := ids[i] + // Test for uniqueness among all other 9 generated ids + for j, tid := range ids { + if j != i { + c.Assert(id, Not(Equals), tid, Commentf("Generated ObjectId is not unique")) + } + } + // Check that timestamp was incremented and is within 30 seconds of the previous one + secs := id.Time().Sub(prevId.Time()).Seconds() + c.Assert((secs >= 0 && secs <= 30), Equals, true, Commentf("Wrong timestamp in generated ObjectId")) + // Check that machine ids are the same + c.Assert(id.Machine(), DeepEquals, prevId.Machine()) + // Check that pids are the same + c.Assert(id.Pid(), Equals, prevId.Pid()) + // Test for proper increment + delta := int(id.Counter() - prevId.Counter()) + c.Assert(delta, Equals, 1, Commentf("Wrong increment in generated ObjectId")) + } +} + +func (s *S) TestNewObjectIdWithTime(c *C) { + t := time.Unix(12345678, 0) + id := bson.NewObjectIdWithTime(t) + c.Assert(id.Time(), Equals, t) + c.Assert(id.Machine(), DeepEquals, []byte{0x00, 0x00, 0x00}) + c.Assert(int(id.Pid()), Equals, 0) + c.Assert(int(id.Counter()), Equals, 0) +} + +// -------------------------------------------------------------------------- +// ObjectId JSON marshalling. + +type jsonType struct { + Id *bson.ObjectId +} + +func (s *S) TestObjectIdJSONMarshaling(c *C) { + id := bson.ObjectIdHex("4d88e15b60f486e428412dc9") + v := jsonType{Id: &id} + data, err := json.Marshal(&v) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, `{"Id":"4d88e15b60f486e428412dc9"}`) +} + +func (s *S) TestObjectIdJSONUnmarshaling(c *C) { + data := []byte(`{"Id":"4d88e15b60f486e428412dc9"}`) + v := jsonType{} + err := json.Unmarshal(data, &v) + c.Assert(err, IsNil) + c.Assert(*v.Id, Equals, bson.ObjectIdHex("4d88e15b60f486e428412dc9")) +} + +func (s *S) TestObjectIdJSONUnmarshalingError(c *C) { + v := jsonType{} + err := json.Unmarshal([]byte(`{"Id":"4d88e15b60f486e428412dc9A"}`), &v) + c.Assert(err, ErrorMatches, `Invalid ObjectId in JSON: "4d88e15b60f486e428412dc9A"`) + err = json.Unmarshal([]byte(`{"Id":"4d88e15b60f486e428412dcZ"}`), &v) + c.Assert(err, ErrorMatches, `Invalid ObjectId in JSON: "4d88e15b60f486e428412dcZ" .*`) +} + +// -------------------------------------------------------------------------- +// Some simple benchmarks. + +type BenchT struct { + A, B, C, D, E, F string +} + +func BenchmarkUnmarhsalStruct(b *testing.B) { + v := BenchT{A: "A", D: "D", E: "E"} + data, err := bson.Marshal(&v) + if err != nil { + panic(err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + err = bson.Unmarshal(data, &v) + } + if err != nil { + panic(err) + } +} + +func BenchmarkUnmarhsalMap(b *testing.B) { + m := bson.M{"a": "a", "d": "d", "e": "e"} + data, err := bson.Marshal(&m) + if err != nil { + panic(err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + err = bson.Unmarshal(data, &m) + } + if err != nil { + panic(err) + } +} diff --git a/Godeps/_workspace/src/labix.org/v2/mgo/bson/decode.go b/Godeps/_workspace/src/labix.org/v2/mgo/bson/decode.go new file mode 100644 index 0000000..1ec034e --- /dev/null +++ b/Godeps/_workspace/src/labix.org/v2/mgo/bson/decode.go @@ -0,0 +1,795 @@ +// BSON library for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// gobson - BSON library for Go. + +package bson + +import ( + "fmt" + "math" + "net/url" + "reflect" + "sync" + "time" +) + +type decoder struct { + in []byte + i int + docType reflect.Type +} + +var typeM = reflect.TypeOf(M{}) + +func newDecoder(in []byte) *decoder { + return &decoder{in, 0, typeM} +} + +// -------------------------------------------------------------------------- +// Some helper functions. + +func corrupted() { + panic("Document is corrupted") +} + +func settableValueOf(i interface{}) reflect.Value { + v := reflect.ValueOf(i) + sv := reflect.New(v.Type()).Elem() + sv.Set(v) + return sv +} + +// -------------------------------------------------------------------------- +// Unmarshaling of documents. + +const ( + setterUnknown = iota + setterNone + setterType + setterAddr +) + +var setterStyle map[reflect.Type]int +var setterIface reflect.Type +var setterMutex sync.RWMutex + +func init() { + var iface Setter + setterIface = reflect.TypeOf(&iface).Elem() + setterStyle = make(map[reflect.Type]int) +} + +func getSetter(outt reflect.Type, out reflect.Value) Setter { + setterMutex.RLock() + style := setterStyle[outt] + setterMutex.RUnlock() + if style == setterNone { + return nil + } + if style == setterUnknown { + setterMutex.Lock() + defer setterMutex.Unlock() + if outt.Implements(setterIface) { + setterStyle[outt] = setterType + } else if reflect.PtrTo(outt).Implements(setterIface) { + setterStyle[outt] = setterAddr + } else { + setterStyle[outt] = setterNone + return nil + } + style = setterStyle[outt] + } + if style == setterAddr { + if !out.CanAddr() { + return nil + } + out = out.Addr() + } else if outt.Kind() == reflect.Ptr && out.IsNil() { + out.Set(reflect.New(outt.Elem())) + } + return out.Interface().(Setter) +} + +func clearMap(m reflect.Value) { + var none reflect.Value + for _, k := range m.MapKeys() { + m.SetMapIndex(k, none) + } +} + +func (d *decoder) readDocTo(out reflect.Value) { + var elemType reflect.Type + outt := out.Type() + outk := outt.Kind() + + for { + if outk == reflect.Ptr && out.IsNil() { + out.Set(reflect.New(outt.Elem())) + } + if setter := getSetter(outt, out); setter != nil { + var raw Raw + d.readDocTo(reflect.ValueOf(&raw)) + err := setter.SetBSON(raw) + if _, ok := err.(*TypeError); err != nil && !ok { + panic(err) + } + return + } + if outk == reflect.Ptr { + out = out.Elem() + outt = out.Type() + outk = out.Kind() + continue + } + break + } + + var fieldsMap map[string]fieldInfo + var inlineMap reflect.Value + start := d.i + + origout := out + if outk == reflect.Interface { + if d.docType.Kind() == reflect.Map { + mv := reflect.MakeMap(d.docType) + out.Set(mv) + out = mv + } else { + dv := reflect.New(d.docType).Elem() + out.Set(dv) + out = dv + } + outt = out.Type() + outk = outt.Kind() + } + + docType := d.docType + keyType := typeString + convertKey := false + switch outk { + case reflect.Map: + keyType = outt.Key() + if keyType.Kind() != reflect.String { + panic("BSON map must have string keys. Got: " + outt.String()) + } + if keyType != typeString { + convertKey = true + } + elemType = outt.Elem() + if elemType == typeIface { + d.docType = outt + } + if out.IsNil() { + out.Set(reflect.MakeMap(out.Type())) + } else if out.Len() > 0 { + clearMap(out) + } + case reflect.Struct: + if outt != typeRaw { + sinfo, err := getStructInfo(out.Type()) + if err != nil { + panic(err) + } + fieldsMap = sinfo.FieldsMap + out.Set(sinfo.Zero) + if sinfo.InlineMap != -1 { + inlineMap = out.Field(sinfo.InlineMap) + if !inlineMap.IsNil() && inlineMap.Len() > 0 { + clearMap(inlineMap) + } + elemType = inlineMap.Type().Elem() + if elemType == typeIface { + d.docType = inlineMap.Type() + } + } + } + case reflect.Slice: + switch outt.Elem() { + case typeDocElem: + origout.Set(d.readDocElems(outt)) + return + case typeRawDocElem: + origout.Set(d.readRawDocElems(outt)) + return + } + fallthrough + default: + panic("Unsupported document type for unmarshalling: " + out.Type().String()) + } + + end := int(d.readInt32()) + end += d.i - 4 + if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { + corrupted() + } + for d.in[d.i] != '\x00' { + kind := d.readByte() + name := d.readCStr() + if d.i >= end { + corrupted() + } + + switch outk { + case reflect.Map: + e := reflect.New(elemType).Elem() + if d.readElemTo(e, kind) { + k := reflect.ValueOf(name) + if convertKey { + k = k.Convert(keyType) + } + out.SetMapIndex(k, e) + } + case reflect.Struct: + if outt == typeRaw { + d.dropElem(kind) + } else { + if info, ok := fieldsMap[name]; ok { + if info.Inline == nil { + d.readElemTo(out.Field(info.Num), kind) + } else { + d.readElemTo(out.FieldByIndex(info.Inline), kind) + } + } else if inlineMap.IsValid() { + if inlineMap.IsNil() { + inlineMap.Set(reflect.MakeMap(inlineMap.Type())) + } + e := reflect.New(elemType).Elem() + if d.readElemTo(e, kind) { + inlineMap.SetMapIndex(reflect.ValueOf(name), e) + } + } else { + d.dropElem(kind) + } + } + case reflect.Slice: + } + + if d.i >= end { + corrupted() + } + } + d.i++ // '\x00' + if d.i != end { + corrupted() + } + d.docType = docType + + if outt == typeRaw { + out.Set(reflect.ValueOf(Raw{0x03, d.in[start:d.i]})) + } +} + +func (d *decoder) readArrayDocTo(out reflect.Value) { + end := int(d.readInt32()) + end += d.i - 4 + if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { + corrupted() + } + i := 0 + l := out.Len() + for d.in[d.i] != '\x00' { + if i >= l { + panic("Length mismatch on array field") + } + kind := d.readByte() + for d.i < end && d.in[d.i] != '\x00' { + d.i++ + } + if d.i >= end { + corrupted() + } + d.i++ + d.readElemTo(out.Index(i), kind) + if d.i >= end { + corrupted() + } + i++ + } + if i != l { + panic("Length mismatch on array field") + } + d.i++ // '\x00' + if d.i != end { + corrupted() + } +} + +func (d *decoder) readSliceDoc(t reflect.Type) interface{} { + tmp := make([]reflect.Value, 0, 8) + elemType := t.Elem() + + end := int(d.readInt32()) + end += d.i - 4 + if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { + corrupted() + } + for d.in[d.i] != '\x00' { + kind := d.readByte() + for d.i < end && d.in[d.i] != '\x00' { + d.i++ + } + if d.i >= end { + corrupted() + } + d.i++ + e := reflect.New(elemType).Elem() + if d.readElemTo(e, kind) { + tmp = append(tmp, e) + } + if d.i >= end { + corrupted() + } + } + d.i++ // '\x00' + if d.i != end { + corrupted() + } + + n := len(tmp) + slice := reflect.MakeSlice(t, n, n) + for i := 0; i != n; i++ { + slice.Index(i).Set(tmp[i]) + } + return slice.Interface() +} + +var typeSlice = reflect.TypeOf([]interface{}{}) +var typeIface = typeSlice.Elem() + +func (d *decoder) readDocElems(typ reflect.Type) reflect.Value { + docType := d.docType + d.docType = typ + slice := make([]DocElem, 0, 8) + d.readDocWith(func(kind byte, name string) { + e := DocElem{Name: name} + v := reflect.ValueOf(&e.Value) + if d.readElemTo(v.Elem(), kind) { + slice = append(slice, e) + } + }) + slicev := reflect.New(typ).Elem() + slicev.Set(reflect.ValueOf(slice)) + d.docType = docType + return slicev +} + +func (d *decoder) readRawDocElems(typ reflect.Type) reflect.Value { + docType := d.docType + d.docType = typ + slice := make([]RawDocElem, 0, 8) + d.readDocWith(func(kind byte, name string) { + e := RawDocElem{Name: name} + v := reflect.ValueOf(&e.Value) + if d.readElemTo(v.Elem(), kind) { + slice = append(slice, e) + } + }) + slicev := reflect.New(typ).Elem() + slicev.Set(reflect.ValueOf(slice)) + d.docType = docType + return slicev +} + +func (d *decoder) readDocWith(f func(kind byte, name string)) { + end := int(d.readInt32()) + end += d.i - 4 + if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { + corrupted() + } + for d.in[d.i] != '\x00' { + kind := d.readByte() + name := d.readCStr() + if d.i >= end { + corrupted() + } + f(kind, name) + if d.i >= end { + corrupted() + } + } + d.i++ // '\x00' + if d.i != end { + corrupted() + } +} + +// -------------------------------------------------------------------------- +// Unmarshaling of individual elements within a document. + +var blackHole = settableValueOf(struct{}{}) + +func (d *decoder) dropElem(kind byte) { + d.readElemTo(blackHole, kind) +} + +// Attempt to decode an element from the document and put it into out. +// If the types are not compatible, the returned ok value will be +// false and out will be unchanged. +func (d *decoder) readElemTo(out reflect.Value, kind byte) (good bool) { + + start := d.i + + if kind == '\x03' { + // Special case for documents. Delegate to readDocTo(). + switch out.Kind() { + case reflect.Interface, reflect.Ptr, reflect.Struct, reflect.Map: + d.readDocTo(out) + default: + switch out.Interface().(type) { + case D: + out.Set(d.readDocElems(out.Type())) + case RawD: + out.Set(d.readRawDocElems(out.Type())) + default: + d.readDocTo(blackHole) + } + } + return true + } + + var in interface{} + + switch kind { + case 0x01: // Float64 + in = d.readFloat64() + case 0x02: // UTF-8 string + in = d.readStr() + case 0x03: // Document + panic("Can't happen. Handled above.") + case 0x04: // Array + outt := out.Type() + for outt.Kind() == reflect.Ptr { + outt = outt.Elem() + } + switch outt.Kind() { + case reflect.Array: + d.readArrayDocTo(out) + return true + case reflect.Slice: + in = d.readSliceDoc(outt) + default: + in = d.readSliceDoc(typeSlice) + } + case 0x05: // Binary + b := d.readBinary() + if b.Kind == 0x00 || b.Kind == 0x02 { + in = b.Data + } else { + in = b + } + case 0x06: // Undefined (obsolete, but still seen in the wild) + in = Undefined + case 0x07: // ObjectId + in = ObjectId(d.readBytes(12)) + case 0x08: // Bool + in = d.readBool() + case 0x09: // Timestamp + // MongoDB handles timestamps as milliseconds. + i := d.readInt64() + if i == -62135596800000 { + in = time.Time{} // In UTC for convenience. + } else { + in = time.Unix(i/1e3, i%1e3*1e6) + } + case 0x0A: // Nil + in = nil + case 0x0B: // RegEx + in = d.readRegEx() + case 0x0D: // JavaScript without scope + in = JavaScript{Code: d.readStr()} + case 0x0E: // Symbol + in = Symbol(d.readStr()) + case 0x0F: // JavaScript with scope + d.i += 4 // Skip length + js := JavaScript{d.readStr(), make(M)} + d.readDocTo(reflect.ValueOf(js.Scope)) + in = js + case 0x10: // Int32 + in = int(d.readInt32()) + case 0x11: // Mongo-specific timestamp + in = MongoTimestamp(d.readInt64()) + case 0x12: // Int64 + in = d.readInt64() + case 0x7F: // Max key + in = MaxKey + case 0xFF: // Min key + in = MinKey + default: + panic(fmt.Sprintf("Unknown element kind (0x%02X)", kind)) + } + + outt := out.Type() + + if outt == typeRaw { + out.Set(reflect.ValueOf(Raw{kind, d.in[start:d.i]})) + return true + } + + if setter := getSetter(outt, out); setter != nil { + err := setter.SetBSON(Raw{kind, d.in[start:d.i]}) + if err == SetZero { + out.Set(reflect.Zero(outt)) + return true + } + if err == nil { + return true + } + if _, ok := err.(*TypeError); !ok { + panic(err) + } + return false + } + + if in == nil { + out.Set(reflect.Zero(outt)) + return true + } + + outk := outt.Kind() + + // Dereference and initialize pointer if necessary. + first := true + for outk == reflect.Ptr { + if !out.IsNil() { + out = out.Elem() + } else { + elem := reflect.New(outt.Elem()) + if first { + // Only set if value is compatible. + first = false + defer func(out, elem reflect.Value) { + if good { + out.Set(elem) + } + }(out, elem) + } else { + out.Set(elem) + } + out = elem + } + outt = out.Type() + outk = outt.Kind() + } + + inv := reflect.ValueOf(in) + if outt == inv.Type() { + out.Set(inv) + return true + } + + switch outk { + case reflect.Interface: + out.Set(inv) + return true + case reflect.String: + switch inv.Kind() { + case reflect.String: + out.SetString(inv.String()) + return true + case reflect.Slice: + if b, ok := in.([]byte); ok { + out.SetString(string(b)) + return true + } + } + case reflect.Slice, reflect.Array: + // Remember, array (0x04) slices are built with the correct + // element type. If we are here, must be a cross BSON kind + // conversion (e.g. 0x05 unmarshalling on string). + if outt.Elem().Kind() != reflect.Uint8 { + break + } + switch inv.Kind() { + case reflect.String: + slice := []byte(inv.String()) + out.Set(reflect.ValueOf(slice)) + return true + case reflect.Slice: + switch outt.Kind() { + case reflect.Array: + reflect.Copy(out, inv) + case reflect.Slice: + out.SetBytes(inv.Bytes()) + } + return true + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + switch inv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + out.SetInt(inv.Int()) + return true + case reflect.Float32, reflect.Float64: + out.SetInt(int64(inv.Float())) + return true + case reflect.Bool: + if inv.Bool() { + out.SetInt(1) + } else { + out.SetInt(0) + } + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + panic("Can't happen. No uint types in BSON?") + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + switch inv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + out.SetUint(uint64(inv.Int())) + return true + case reflect.Float32, reflect.Float64: + out.SetUint(uint64(inv.Float())) + return true + case reflect.Bool: + if inv.Bool() { + out.SetUint(1) + } else { + out.SetUint(0) + } + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + panic("Can't happen. No uint types in BSON.") + } + case reflect.Float32, reflect.Float64: + switch inv.Kind() { + case reflect.Float32, reflect.Float64: + out.SetFloat(inv.Float()) + return true + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + out.SetFloat(float64(inv.Int())) + return true + case reflect.Bool: + if inv.Bool() { + out.SetFloat(1) + } else { + out.SetFloat(0) + } + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + panic("Can't happen. No uint types in BSON?") + } + case reflect.Bool: + switch inv.Kind() { + case reflect.Bool: + out.SetBool(inv.Bool()) + return true + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + out.SetBool(inv.Int() != 0) + return true + case reflect.Float32, reflect.Float64: + out.SetBool(inv.Float() != 0) + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + panic("Can't happen. No uint types in BSON?") + } + case reflect.Struct: + if outt == typeURL && inv.Kind() == reflect.String { + u, err := url.Parse(inv.String()) + if err != nil { + panic(err) + } + out.Set(reflect.ValueOf(u).Elem()) + return true + } + } + + return false +} + +// -------------------------------------------------------------------------- +// Parsers of basic types. + +func (d *decoder) readRegEx() RegEx { + re := RegEx{} + re.Pattern = d.readCStr() + re.Options = d.readCStr() + return re +} + +func (d *decoder) readBinary() Binary { + l := d.readInt32() + b := Binary{} + b.Kind = d.readByte() + b.Data = d.readBytes(l) + if b.Kind == 0x02 && len(b.Data) >= 4 { + // Weird obsolete format with redundant length. + b.Data = b.Data[4:] + } + return b +} + +func (d *decoder) readStr() string { + l := d.readInt32() + b := d.readBytes(l - 1) + if d.readByte() != '\x00' { + corrupted() + } + return string(b) +} + +func (d *decoder) readCStr() string { + start := d.i + end := start + l := len(d.in) + for ; end != l; end++ { + if d.in[end] == '\x00' { + break + } + } + d.i = end + 1 + if d.i > l { + corrupted() + } + return string(d.in[start:end]) +} + +func (d *decoder) readBool() bool { + if d.readByte() == 1 { + return true + } + return false +} + +func (d *decoder) readFloat64() float64 { + return math.Float64frombits(uint64(d.readInt64())) +} + +func (d *decoder) readInt32() int32 { + b := d.readBytes(4) + return int32((uint32(b[0]) << 0) | + (uint32(b[1]) << 8) | + (uint32(b[2]) << 16) | + (uint32(b[3]) << 24)) +} + +func (d *decoder) readInt64() int64 { + b := d.readBytes(8) + return int64((uint64(b[0]) << 0) | + (uint64(b[1]) << 8) | + (uint64(b[2]) << 16) | + (uint64(b[3]) << 24) | + (uint64(b[4]) << 32) | + (uint64(b[5]) << 40) | + (uint64(b[6]) << 48) | + (uint64(b[7]) << 56)) +} + +func (d *decoder) readByte() byte { + i := d.i + d.i++ + if d.i > len(d.in) { + corrupted() + } + return d.in[i] +} + +func (d *decoder) readBytes(length int32) []byte { + start := d.i + d.i += int(length) + if d.i > len(d.in) { + corrupted() + } + return d.in[start : start+int(length)] +} diff --git a/Godeps/_workspace/src/labix.org/v2/mgo/bson/encode.go b/Godeps/_workspace/src/labix.org/v2/mgo/bson/encode.go new file mode 100644 index 0000000..6ba383a --- /dev/null +++ b/Godeps/_workspace/src/labix.org/v2/mgo/bson/encode.go @@ -0,0 +1,462 @@ +// BSON library for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// gobson - BSON library for Go. + +package bson + +import ( + "fmt" + "math" + "net/url" + "reflect" + "strconv" + "time" +) + +// -------------------------------------------------------------------------- +// Some internal infrastructure. + +var ( + typeBinary = reflect.TypeOf(Binary{}) + typeObjectId = reflect.TypeOf(ObjectId("")) + typeSymbol = reflect.TypeOf(Symbol("")) + typeMongoTimestamp = reflect.TypeOf(MongoTimestamp(0)) + typeOrderKey = reflect.TypeOf(MinKey) + typeDocElem = reflect.TypeOf(DocElem{}) + typeRawDocElem = reflect.TypeOf(RawDocElem{}) + typeRaw = reflect.TypeOf(Raw{}) + typeURL = reflect.TypeOf(url.URL{}) + typeTime = reflect.TypeOf(time.Time{}) + typeString = reflect.TypeOf("") +) + +const itoaCacheSize = 32 + +var itoaCache []string + +func init() { + itoaCache = make([]string, itoaCacheSize) + for i := 0; i != itoaCacheSize; i++ { + itoaCache[i] = strconv.Itoa(i) + } +} + +func itoa(i int) string { + if i < itoaCacheSize { + return itoaCache[i] + } + return strconv.Itoa(i) +} + +// -------------------------------------------------------------------------- +// Marshaling of the document value itself. + +type encoder struct { + out []byte +} + +func (e *encoder) addDoc(v reflect.Value) { + for { + if vi, ok := v.Interface().(Getter); ok { + getv, err := vi.GetBSON() + if err != nil { + panic(err) + } + v = reflect.ValueOf(getv) + continue + } + if v.Kind() == reflect.Ptr { + v = v.Elem() + continue + } + break + } + + if v.Type() == typeRaw { + raw := v.Interface().(Raw) + if raw.Kind != 0x03 && raw.Kind != 0x00 { + panic("Attempted to unmarshal Raw kind " + strconv.Itoa(int(raw.Kind)) + " as a document") + } + e.addBytes(raw.Data...) + return + } + + start := e.reserveInt32() + + switch v.Kind() { + case reflect.Map: + e.addMap(v) + case reflect.Struct: + e.addStruct(v) + case reflect.Array, reflect.Slice: + e.addSlice(v) + default: + panic("Can't marshal " + v.Type().String() + " as a BSON document") + } + + e.addBytes(0) + e.setInt32(start, int32(len(e.out)-start)) +} + +func (e *encoder) addMap(v reflect.Value) { + for _, k := range v.MapKeys() { + e.addElem(k.String(), v.MapIndex(k), false) + } +} + +func (e *encoder) addStruct(v reflect.Value) { + sinfo, err := getStructInfo(v.Type()) + if err != nil { + panic(err) + } + var value reflect.Value + if sinfo.InlineMap >= 0 { + m := v.Field(sinfo.InlineMap) + if m.Len() > 0 { + for _, k := range m.MapKeys() { + ks := k.String() + if _, found := sinfo.FieldsMap[ks]; found { + panic(fmt.Sprintf("Can't have key %q in inlined map; conflicts with struct field", ks)) + } + e.addElem(ks, m.MapIndex(k), false) + } + } + } + for _, info := range sinfo.FieldsList { + if info.Inline == nil { + value = v.Field(info.Num) + } else { + value = v.FieldByIndex(info.Inline) + } + if info.OmitEmpty && isZero(value) { + continue + } + e.addElem(info.Key, value, info.MinSize) + } +} + +func isZero(v reflect.Value) bool { + switch v.Kind() { + case reflect.String: + return len(v.String()) == 0 + case reflect.Ptr, reflect.Interface: + return v.IsNil() + case reflect.Slice: + return v.Len() == 0 + case reflect.Map: + return v.Len() == 0 + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Struct: + if v.Type() == typeTime { + return v.Interface().(time.Time).IsZero() + } + for i := v.NumField()-1; i >= 0; i-- { + if !isZero(v.Field(i)) { + return false + } + } + return true + } + return false +} + +func (e *encoder) addSlice(v reflect.Value) { + vi := v.Interface() + if d, ok := vi.(D); ok { + for _, elem := range d { + e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + } + return + } + if d, ok := vi.(RawD); ok { + for _, elem := range d { + e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + } + return + } + l := v.Len() + et := v.Type().Elem() + if et == typeDocElem { + for i := 0; i < l; i++ { + elem := v.Index(i).Interface().(DocElem) + e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + } + return + } + if et == typeRawDocElem { + for i := 0; i < l; i++ { + elem := v.Index(i).Interface().(RawDocElem) + e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + } + return + } + for i := 0; i < l; i++ { + e.addElem(itoa(i), v.Index(i), false) + } +} + +// -------------------------------------------------------------------------- +// Marshaling of elements in a document. + +func (e *encoder) addElemName(kind byte, name string) { + e.addBytes(kind) + e.addBytes([]byte(name)...) + e.addBytes(0) +} + +func (e *encoder) addElem(name string, v reflect.Value, minSize bool) { + + if !v.IsValid() { + e.addElemName('\x0A', name) + return + } + + if getter, ok := v.Interface().(Getter); ok { + getv, err := getter.GetBSON() + if err != nil { + panic(err) + } + e.addElem(name, reflect.ValueOf(getv), minSize) + return + } + + switch v.Kind() { + + case reflect.Interface: + e.addElem(name, v.Elem(), minSize) + + case reflect.Ptr: + e.addElem(name, v.Elem(), minSize) + + case reflect.String: + s := v.String() + switch v.Type() { + case typeObjectId: + if len(s) != 12 { + panic("ObjectIDs must be exactly 12 bytes long (got " + + strconv.Itoa(len(s)) + ")") + } + e.addElemName('\x07', name) + e.addBytes([]byte(s)...) + case typeSymbol: + e.addElemName('\x0E', name) + e.addStr(s) + default: + e.addElemName('\x02', name) + e.addStr(s) + } + + case reflect.Float32, reflect.Float64: + e.addElemName('\x01', name) + e.addInt64(int64(math.Float64bits(v.Float()))) + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + u := v.Uint() + if int64(u) < 0 { + panic("BSON has no uint64 type, and value is too large to fit correctly in an int64") + } else if u <= math.MaxInt32 && (minSize || v.Kind() <= reflect.Uint32) { + e.addElemName('\x10', name) + e.addInt32(int32(u)) + } else { + e.addElemName('\x12', name) + e.addInt64(int64(u)) + } + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + switch v.Type() { + case typeMongoTimestamp: + e.addElemName('\x11', name) + e.addInt64(v.Int()) + + case typeOrderKey: + if v.Int() == int64(MaxKey) { + e.addElemName('\x7F', name) + } else { + e.addElemName('\xFF', name) + } + + default: + i := v.Int() + if (minSize || v.Type().Kind() != reflect.Int64) && i >= math.MinInt32 && i <= math.MaxInt32 { + // It fits into an int32, encode as such. + e.addElemName('\x10', name) + e.addInt32(int32(i)) + } else { + e.addElemName('\x12', name) + e.addInt64(i) + } + } + + case reflect.Bool: + e.addElemName('\x08', name) + if v.Bool() { + e.addBytes(1) + } else { + e.addBytes(0) + } + + case reflect.Map: + e.addElemName('\x03', name) + e.addDoc(v) + + case reflect.Slice: + vt := v.Type() + et := vt.Elem() + if et.Kind() == reflect.Uint8 { + e.addElemName('\x05', name) + e.addBinary('\x00', v.Bytes()) + } else if et == typeDocElem || et == typeRawDocElem { + e.addElemName('\x03', name) + e.addDoc(v) + } else { + e.addElemName('\x04', name) + e.addDoc(v) + } + + case reflect.Array: + et := v.Type().Elem() + if et.Kind() == reflect.Uint8 { + e.addElemName('\x05', name) + e.addBinary('\x00', v.Slice(0, v.Len()).Interface().([]byte)) + } else { + e.addElemName('\x04', name) + e.addDoc(v) + } + + case reflect.Struct: + switch s := v.Interface().(type) { + + case Raw: + kind := s.Kind + if kind == 0x00 { + kind = 0x03 + } + e.addElemName(kind, name) + e.addBytes(s.Data...) + + case Binary: + e.addElemName('\x05', name) + e.addBinary(s.Kind, s.Data) + + case RegEx: + e.addElemName('\x0B', name) + e.addCStr(s.Pattern) + e.addCStr(s.Options) + + case JavaScript: + if s.Scope == nil { + e.addElemName('\x0D', name) + e.addStr(s.Code) + } else { + e.addElemName('\x0F', name) + start := e.reserveInt32() + e.addStr(s.Code) + e.addDoc(reflect.ValueOf(s.Scope)) + e.setInt32(start, int32(len(e.out)-start)) + } + + case time.Time: + // MongoDB handles timestamps as milliseconds. + e.addElemName('\x09', name) + e.addInt64(s.Unix() * 1000 + int64(s.Nanosecond() / 1e6)) + + case url.URL: + e.addElemName('\x02', name) + e.addStr(s.String()) + + case undefined: + e.addElemName('\x06', name) + + default: + e.addElemName('\x03', name) + e.addDoc(v) + } + + default: + panic("Can't marshal " + v.Type().String() + " in a BSON document") + } +} + +// -------------------------------------------------------------------------- +// Marshaling of base types. + +func (e *encoder) addBinary(subtype byte, v []byte) { + if subtype == 0x02 { + // Wonder how that brilliant idea came to life. Obsolete, luckily. + e.addInt32(int32(len(v) + 4)) + e.addBytes(subtype) + e.addInt32(int32(len(v))) + } else { + e.addInt32(int32(len(v))) + e.addBytes(subtype) + } + e.addBytes(v...) +} + +func (e *encoder) addStr(v string) { + e.addInt32(int32(len(v) + 1)) + e.addCStr(v) +} + +func (e *encoder) addCStr(v string) { + e.addBytes([]byte(v)...) + e.addBytes(0) +} + +func (e *encoder) reserveInt32() (pos int) { + pos = len(e.out) + e.addBytes(0, 0, 0, 0) + return pos +} + +func (e *encoder) setInt32(pos int, v int32) { + e.out[pos+0] = byte(v) + e.out[pos+1] = byte(v >> 8) + e.out[pos+2] = byte(v >> 16) + e.out[pos+3] = byte(v >> 24) +} + +func (e *encoder) addInt32(v int32) { + u := uint32(v) + e.addBytes(byte(u), byte(u>>8), byte(u>>16), byte(u>>24)) +} + +func (e *encoder) addInt64(v int64) { + u := uint64(v) + e.addBytes(byte(u), byte(u>>8), byte(u>>16), byte(u>>24), + byte(u>>32), byte(u>>40), byte(u>>48), byte(u>>56)) +} + +func (e *encoder) addBytes(v ...byte) { + e.out = append(e.out, v...) +}