diff --git a/db/db.go b/db/db.go new file mode 100644 index 0000000..cddd2da --- /dev/null +++ b/db/db.go @@ -0,0 +1,19 @@ +package db + +import "gopkg.in/mgo.v2" + +type Mongo struct { + URL string + Database string + CollectionName string +} + +func (db *Mongo) Collection() (*mgo.Collection, error) { + session, err := mgo.Dial(db.URL) + if err != nil { + return nil, err + } + c := session.DB(db.Database).C(db.CollectionName) + + return c, nil +} diff --git a/main.go b/main.go index c391aba..77b5e5e 100644 --- a/main.go +++ b/main.go @@ -12,7 +12,7 @@ import ( "time" "github.com/gojp/goreportcard/check" - "gopkg.in/mgo.v2" + "github.com/gojp/goreportcard/db" "labix.org/v2/mgo/bson" ) @@ -22,15 +22,6 @@ var ( mongoCollection = "reports" ) -func getMongoCollection() (*mgo.Collection, error) { - session, err := mgo.Dial(mongoURL) - if err != nil { - return nil, err - } - c := session.DB(mongoDatabase).C(mongoCollection) - return c, nil -} - func homeHandler(w http.ResponseWriter, r *http.Request) { log.Println("Serving home page") if r.URL.Path[1:] == "" { @@ -103,7 +94,8 @@ type checksResp struct { func getFromCache(repo string) (checksResp, error) { // try and fetch from mongo - coll, err := getMongoCollection() + db := db.Mongo{URL: mongoURL, Database: mongoDatabase, CollectionName: mongoCollection} + coll, err := db.Collection() if err != nil { return checksResp{}, fmt.Errorf("Failed to get mongo collection during GET: %v", err) } @@ -248,7 +240,8 @@ func checkHandler(w http.ResponseWriter, r *http.Request) { w.Write(b) // write to mongo - coll, err := getMongoCollection() + db := db.Mongo{URL: mongoURL, Database: mongoDatabase, CollectionName: mongoCollection} + coll, err := db.Collection() if err != nil { log.Println("Failed to get mongo collection: ", err) } else {