/* -*-c++-*- OpenThreads - Copyright (C) 1998-2007 Robert Osfield
 *
 * This library is open source and may be redistributed and/or modified under
 * the terms of the OpenSceneGraph Public License (OSGPL) version 0.0 or
 * (at your option) any later version.  The full license is in LICENSE file
 * included with this distribution, and on the openscenegraph.org website.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * OpenSceneGraph Public License for more details.
*/

#ifndef _OPENTHREADS_READWRITEMUTEX_
#define _OPENTHREADS_READWRITEMUTEX_

#include <OpenThreads/Thread>
#include <OpenThreads/ReentrantMutex>

namespace OpenThreads {

/** Read/write lock assembled from two mutexes and a reader counter.
  *
  * _readWriteMutex is the lock that provides exclusion: a writer holds it
  * for the duration of its write, and the group of active readers holds it
  * collectively (the first reader acquires it, the last reader releases it).
  * _readCountMutex only protects _readCount.
  *
  * NOTE(review): because the last reader out releases _readWriteMutex, the
  * releasing thread is not necessarily the thread that acquired it. Confirm
  * that the underlying OpenThreads::Mutex implementation permits this.
  */
class ReadWriteMutex
{
    public:

        ReadWriteMutex():
            _readCount(0) {}

        virtual ~ReadWriteMutex() {}

        /** Register the calling thread as a reader.
          *
          * The first reader acquires _readWriteMutex on behalf of all
          * readers; subsequent readers only bump the counter.
          *
          * @return 0 on success, otherwise the non-zero value returned by
          *         _readWriteMutex.lock(), in which case the caller is NOT
          *         counted as a reader and must not call readUnlock().
          */
        virtual int readLock()
        {
            OpenThreads::ScopedLock<OpenThreads::Mutex> lock(_readCountMutex);
            if (_readCount==0)
            {
                // _readCountMutex stays held while waiting here, so other
                // would-be readers queue behind this one until the write
                // side has been released.
                const int result = _readWriteMutex.lock();
                if (result!=0)
                {
                    // Do not count a reader that failed to obtain the lock:
                    // otherwise later readers would skip the acquisition and
                    // a later readUnlock() would release a mutex that was
                    // never acquired.
                    return result;
                }
            }
            ++_readCount;
            return 0;
        }


        /** Deregister a reader.
          *
          * When the last reader leaves, _readWriteMutex is released.
          * Calling this with no registered readers is a no-op.
          *
          * @return the result of _readWriteMutex.unlock() when this call
          *         released it, 0 otherwise.
          */
        virtual int readUnlock()
        {
            OpenThreads::ScopedLock<OpenThreads::Mutex> lock(_readCountMutex);
            int result = 0;
            if (_readCount>0)
            {
                --_readCount;
                if (_readCount==0)
                {
                    result = _readWriteMutex.unlock();
                }
            }
            return result;
        }

        /** Acquire exclusive access by locking _readWriteMutex directly.
          * @return the result of _readWriteMutex.lock().
          */
        virtual int writeLock()
        {
            return _readWriteMutex.lock();
        }

        /** Release exclusive access taken with writeLock().
          * @return the result of _readWriteMutex.unlock().
          */
        virtual int writeUnlock()
        {
            return _readWriteMutex.unlock();
        }

    protected:

        // Copying is restricted to derived classes and never transfers lock
        // state: a copy starts with fresh mutexes and no registered readers.
        ReadWriteMutex(const ReadWriteMutex&):
            _readCount(0) {}

        // Assignment deliberately leaves the lock state of *this untouched.
        ReadWriteMutex& operator = (const ReadWriteMutex&) { return *(this); }

#if 0
        ReentrantMutex  _readWriteMutex;
        ReentrantMutex  _readCountMutex;
#else
        OpenThreads::Mutex  _readWriteMutex;   // held by the writer, or by the readers as a group
        OpenThreads::Mutex  _readCountMutex;   // guards _readCount
#endif
        unsigned int    _readCount;            // number of currently registered readers

};

/** Scope guard that calls readLock() on construction and readUnlock() on
  * destruction of the guarded ReadWriteMutex.
  *
  * The return values of readLock()/readUnlock() are ignored.
  */
class ScopedReadLock
{
    public:

        ScopedReadLock(ReadWriteMutex& mutex):_mutex(mutex) { _mutex.readLock(); }
        ~ScopedReadLock() { _mutex.readUnlock(); }

    protected:
        ReadWriteMutex& _mutex;

        // Not copyable: a copy would call readUnlock() in its destructor
        // without a matching readLock(). Declared but intentionally left
        // undefined so that any attempted copy fails to compile or link.
        ScopedReadLock(const ScopedReadLock&);

        // Assignment is inaccessible to users and does nothing.
        ScopedReadLock& operator = (const ScopedReadLock&) { return *this; }
};


/** Scope guard that calls writeLock() on construction and writeUnlock() on
  * destruction of the guarded ReadWriteMutex.
  *
  * The return values of writeLock()/writeUnlock() are ignored.
  */
class ScopedWriteLock
{
    public:

        ScopedWriteLock(ReadWriteMutex& mutex):_mutex(mutex) { _mutex.writeLock(); }
        ~ScopedWriteLock() { _mutex.writeUnlock(); }

    protected:
        ReadWriteMutex& _mutex;

        // Not copyable: a copy would call writeUnlock() a second time for a
        // single writeLock(). Declared but intentionally left undefined so
        // that any attempted copy fails to compile or link.
        ScopedWriteLock(const ScopedWriteLock&);

        // Assignment is inaccessible to users and does nothing.
        ScopedWriteLock& operator = (const ScopedWriteLock&) { return *this; }
};

}

#endif
