/* Copyright (C) 2006 MySQL AB

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA */

#include <string.h>
#include <stdio.h>
#include <memory.h>
#include "Engine.h"
#include "StorageDatabase.h"
#include "StorageConnection.h"
#include "SyncObject.h"
#include "Sync.h"
#include "SQLError.h"
#include "Threads.h"
#include "StorageHandler.h"
#include "StorageTable.h"
#include "StorageTableShare.h"
#include "Sync.h"
#include "Threads.h"
#include "Configuration.h"
#include "Connection.h"
#include "Database.h"
#include "Table.h"
#include "Field.h"
#include "User.h"
#include "RoleModel.h"
#include "Value.h"
#include "Record.h"
#include "Transaction.h"
#include "Statement.h"
#include "Bitmap.h"
#include "PStatement.h"
#include "RSet.h"
#include "Sequence.h"
#include "StorageConnection.h"
#include "MySqlEnums.h"
#include "ScaledBinary.h"

// Account name and password handed to Connection::openDatabase() and
// Connection::createDatabase() for the master connection; ACCOUNT is also the
// user looked up by createTable().
#define ACCOUNT				"mysql"
#define PASSWORD			"mysql"

// File-wide objects created lazily on first use: "threads" in the
// StorageDatabase constructor, "configuration" in getConnection().
// Neither is released anywhere in the visible code.
static Threads			*threads;
static Configuration	*configuration;

// NULL-terminated list of statements that createDatabase() executes right
// after the database has been created.
static const char *ddl [] = {
	"upgrade sequence mystorage.indexes",
	NULL
	};

// Debug builds get a per-file THIS_FILE string.
// NOTE(review): presumably consumed by a debugging allocator macro -- confirm.
#ifdef _DEBUG
#undef THIS_FILE
static char THIS_FILE[]=__FILE__;
#endif

//////////////////////////////////////////////////////////////////////
// Construction/Destruction
//////////////////////////////////////////////////////////////////////


// Build the descriptor for database "dbName" stored in file "path".
// The share hash table starts out empty, no master connection or user is
// attached yet, and the object begins life with a use count of one.  The
// file-wide Threads object is created the first time any instance is built.
StorageDatabase::StorageDatabase(const char *dbName, const char* path)
{
	memset(shares, 0, sizeof(shares));
	name = dbName;
	filename = path;
	useCount = 1;
	user = NULL;
	masterConnection = NULL;
	lookupIndexAlias = NULL;

	if (threads == NULL)
		threads = new Threads(NULL);
}

// Drop the references this object still holds, in this order: the index
// alias lookup object, the master connection, and the cached user.
StorageDatabase::~StorageDatabase(void)
{
	if (lookupIndexAlias != NULL)
		{
		lookupIndexAlias->release();
		}

	if (masterConnection != NULL)
		{
		masterConnection->release();
		}

	if (user != NULL)
		{
		user->release();
		}
}

// Return the StorageTableShare for the table identified by "name", creating
// and registering a new one if none exists yet.
//
// "name" is split into table and schema parts by
// StorageTableShare::cleanupTableName(); for temporary tables the schema is
// forced to "TEMP".  The hash slot is derived from the table name only.
// The table is searched first under a shared lock and, on a miss, searched
// again under an exclusive lock before a new share is inserted, because the
// lock is dropped in between.
StorageTableShare* StorageDatabase::getTableShare(const char* name, int impureSize, bool tempTable)
{
	char tableName[256];
	char schemaName[256];
	StorageTableShare::cleanupTableName(name, tableName, sizeof(tableName), schemaName, sizeof(schemaName));

	if (tempTable)
		sprintf(schemaName, "%s", "TEMP");

	const int bucket = JString::hash(tableName, shareHashSize);
	Sync sync(&syncObject, "StorageDatabase::getTableShare");
	sync.lock(Shared);
	StorageTableShare *entry = shares[bucket];

	// First pass: optimistic lookup under the shared lock.

	while (entry)
		{
		if (entry->name == tableName && entry->schemaName == schemaName)
			return entry;

		entry = entry->collision;
		}

	// Miss: trade the shared lock for an exclusive one and look again.

	sync.unlock();
	sync.lock(Exclusive);
	entry = shares[bucket];

	while (entry)
		{
		if (entry->name == tableName && entry->schemaName == schemaName)
			return entry;

		entry = entry->collision;
		}

	// Still absent: create the share and push it on the front of the chain.

	StorageTableShare *created = new StorageTableShare(this, tableName, schemaName, impureSize);
	created->collision = shares[bucket];
	shares[bucket] = created;

	return created;
}

// Create a new Connection bound to the file-wide Configuration object,
// creating that Configuration on first use.
Connection* StorageDatabase::getConnection()
{
	if (configuration == NULL)
		{
		configuration = new Configuration(NULL);
		}

	Connection *connection = new Connection(configuration);

	return connection;
}

// Return a clone of the master connection, creating the master connection
// and opening the database on it first if that has not happened yet.
// If anything throws while doing so, the master connection (if one exists)
// is closed and forgotten, and the exception is re-thrown to the caller.
Connection* StorageDatabase::getOpenConnection(void)
{
	try
		{
		if (masterConnection == NULL)
			masterConnection = getConnection();

		const bool databaseOpen = (masterConnection->database != NULL);

		if (!databaseOpen)
			masterConnection->openDatabase(name, filename, ACCOUNT, PASSWORD, NULL, threads);
		}
	catch (...)
		{
		if (masterConnection != NULL)
			{
			masterConnection->close();
			masterConnection = NULL;
			}

		throw;
		}

	Connection *copy = masterConnection->clone();

	return copy;
}

// Create the database file on a fresh master connection, run the bootstrap
// statements listed in "ddl", and return a clone of the master connection.
//
// On any exception the bootstrap statement (if it was created) is released,
// the master connection is closed and forgotten, and the exception is
// re-thrown to the caller.
Connection* StorageDatabase::createDatabase(void)
{
	// Declared outside the try block so the handler can release it; the
	// original code leaked the statement whenever execute() threw.
	Statement *statement = NULL;
	
	try
		{
		masterConnection = getConnection();
		masterConnection->createDatabase(name, filename, ACCOUNT, PASSWORD, threads);
		statement = masterConnection->createStatement();
		
		for (const char **sql = ddl; *sql; ++sql)
			statement->execute(*sql);
		
		// Clear the pointer before releasing so that a failure inside
		// release() cannot lead to a second release in the handler.
		
		Statement *finished = statement;
		statement = NULL;
		finished->release();
		}
	catch (...)
		{
		if (statement)
			statement->release();
		
		if (masterConnection)
			{
			masterConnection->close();
			masterConnection = NULL;
			}
		
		throw;
		}
	
	return masterConnection->clone();
}


// Execute the "create table" statement in "sql" on the master connection and
// return the resulting Table (looked up by schema and table name).
//
// When autoIncrementValue is non-zero a sequence named after the table is
// created as well, starting at autoIncrementValue - 1.
//
// The ACCOUNT user is looked up and cached (with an added reference) on the
// first call.  On an SQLException the error text is stored on
// storageConnection and NULL is returned.
Table* StorageDatabase::createTable(StorageConnection *storageConnection, const char* tableName, const char *schemaName, const char* sql, int64 autoIncrementValue)
{
	Database *database = masterConnection->database;
	
	if (!user)
		if ((user = database->findUser(ACCOUNT)))
			user->addRef();		
	
	Statement *statement = masterConnection->createStatement();
	
	try
		{
		statement->execute(sql);
		
		if (autoIncrementValue)
			{
			char buffer[1024];
			snprintf(buffer, sizeof(buffer), "create sequence  %s.\"%s\" start with " I64FORMAT, schemaName, tableName, autoIncrementValue - 1);
			statement->execute(buffer);
			}
		}
	catch (SQLException& exception)
		{
		statement->release();
		storageConnection->setErrorText(&exception);
		
		return NULL;
		}
	
	// Success path: release once, outside the try block, as createIndex() does.
	
	statement->release();
		
	return findTable(tableName, schemaName);
}

// Look up a table by schema and name in the database attached to the master
// connection.  Note the argument order is swapped when forwarding.
Table* StorageDatabase::findTable(const char* tableName, const char *schemaName)
{
	Database *database = masterConnection->database;

	return database->findTable(schemaName, tableName);
}

// Insert the record encoded in "stream" into "table" within the connection's
// current transaction; returns whatever Table::insert() returns.
int StorageDatabase::insert(Connection* connection, Table* table, Stream* stream)
{
	Transaction *transaction = connection->getTransaction();

	return table->insert(transaction, stream);
}

// Sequential scan step: starting at recordNumber, find the next record for
// which fetchVersion() yields a version for the connection's transaction,
// hand it to storageTable->setRecord(), and return its record number.
//
// Returns StorageErrorRecordNotFound when the table has no further records
// or when an SQLException is raised (the error text is then stored on the
// storage connection).
int StorageDatabase::next(StorageTable* storageTable, int recordNumber)
{
	StorageConnection *storageConnection = storageTable->storageConnection;
	Connection *connection = storageConnection->connection;
	Table *table = storageTable->share->table;

	try
		{
		for (;;)
			{
			Record *candidate = table->fetchNext(recordNumber);

			if (candidate == NULL)
				return StorageErrorRecordNotFound;

			Record *visible = candidate->fetchVersion(connection->getTransaction());

			if (visible != NULL)
				{
				// Keep a reference on the version that is handed out and
				// drop the one held on the candidate if they differ.

				if (visible != candidate)
					{
					visible->addRef();
					candidate->release();
					}

				const int found = visible->recordNumber;
				storageTable->setRecord(visible);

				return found;
				}

			// No version for this transaction: skip and try the next number.

			candidate->release();
			++recordNumber;
			}
		}
	catch (SQLException& exception)
		{
		storageConnection->setErrorText(&exception);

		return StorageErrorRecordNotFound;
		}
}

// Fetch record "recordNumber", resolve the version returned by
// fetchVersion() for the connection's transaction, and hand it to
// storageTable->setRecord().
//
// Returns 0 on success.  Returns StorageErrorRecordNotFound when the record
// does not exist, when fetchVersion() yields no version, or when an
// SQLException is raised (its text is stored on storageConnection).
int StorageDatabase::fetch(StorageConnection *storageConnection, StorageTable* storageTable, int recordNumber)
{
	Connection *connection = storageConnection->connection;
	Table *table = storageTable->share->table;

	try
		{
		Record *candidate = table->fetch(recordNumber);

		if (candidate == NULL)
			return StorageErrorRecordNotFound;

		Record *visible = candidate->fetchVersion(connection->getTransaction());

		if (visible == NULL)
			{
			candidate->release();

			return StorageErrorRecordNotFound;
			}

		// Move our reference from the candidate to the chosen version.

		if (visible != candidate)
			{
			visible->addRef();
			candidate->release();
			}

		storageTable->setRecord(visible);
		}
	catch (SQLException& exception)
		{
		storageConnection->setErrorText(&exception);

		return StorageErrorRecordNotFound;
		}

	return 0;
}


// Index scan step: walk the set bits of "recordBitmap" (a Bitmap*) starting
// at recordNumber until a record is found for which fetchVersion() yields a
// version for the connection's transaction.  That record is handed to
// storageTable->setRecord() and its record number is returned.
//
// Returns StorageErrorRecordNotFound when recordBitmap is NULL or the bitmap
// is exhausted, and -2 when an SQLException is raised (its text is stored on
// the storage connection).
int StorageDatabase::nextIndexed(StorageTable *storageTable, void* recordBitmap, int recordNumber)
{
	if (recordBitmap == NULL)
		return StorageErrorRecordNotFound;

	StorageConnection *storageConnection = storageTable->storageConnection;
	Connection *connection = storageConnection->connection;
	Table *table = storageTable->share->table;
	Bitmap *bitmap = (Bitmap*) recordBitmap;

	try
		{
		for (;;)
			{
			recordNumber = bitmap->nextSet(recordNumber);

			if (recordNumber < 0)
				return StorageErrorRecordNotFound;

			Record *candidate = table->fetch(recordNumber);

			// Advance past this bit so the next nextSet() call moves on.

			++recordNumber;

			if (candidate == NULL)
				continue;

			Record *visible = candidate->fetchVersion(connection->getTransaction());

			if (visible == NULL)
				{
				candidate->release();
				continue;
				}

			recordNumber = visible->recordNumber;

			if (visible != candidate)
				{
				visible->addRef();
				candidate->release();
				}

			storageTable->setRecord(visible);

			return recordNumber;
			}
		}
	catch (SQLException& exception)
		{
		storageConnection->setErrorText(&exception);

		return -2;
		}
}

// Create a savepoint in the connection's current transaction and return the
// value produced by Transaction::createSavepoint().
int StorageDatabase::savepointSet(Connection* connection)
{
	return connection->getTransaction()->createSavepoint();
}

// Roll the connection's current transaction back to "savePoint".
// Always returns 0.
int StorageDatabase::savepointRollback(Connection* connection, int savePoint)
{
	connection->getTransaction()->rollbackSavepoint(savePoint);

	return 0;
}

// Release "savePoint" in the connection's current transaction.
// Always returns 0.
int StorageDatabase::savepointRelease(Connection* connection, int savePoint)
{
	connection->getTransaction()->releaseSavepoint(savePoint);

	return 0;
}

// Drop table schemaName."tableName" and, if one exists, the sequence of the
// same name, using the storage connection's own connection.
//
// Returns 0 on success.  If dropping the table fails, returns
// StorageErrorTableNotFound or StorageErrorUncommittedUpdates for the
// matching SQL codes and 200 - sqlcode for anything else.  If dropping the
// sequence fails, returns 200 - sqlcode.  In every failure case the error
// text is stored on storageConnection.
int StorageDatabase::deleteTable(StorageConnection* storageConnection, const char* tableName, const char *schemaName)
{
	Connection *connection = storageConnection->connection;
	char buffer[512];
	snprintf(buffer, sizeof(buffer), "drop table %s.\"%s\"", schemaName, tableName);
	Statement *statement = connection->createStatement();
	
	try
		{
		statement->execute(buffer);
		}
	catch (SQLException& exception)
		{
		int errorCode = exception.getSqlcode();
		storageConnection->setErrorText(&exception);
		
		// Release before any early return; the statement used to be leaked
		// on this path.
		
		statement->release();
		
		switch (errorCode)
			{
			case NO_SUCH_TABLE:
				return StorageErrorTableNotFound;
				
			case UNCOMMITTED_UPDATES:
				return StorageErrorUncommittedUpdates;
			}
			
		return 200 - errorCode;
		}

	// Drop the sequence only if one exists.  A failure here is reported
	// through the return value rather than ignored.
	
	int res = 0;
	
	if (connection->findSequence(schemaName, tableName))
		try
			{
			snprintf(buffer, sizeof(buffer), "drop sequence %s.\"%s\"", schemaName, tableName);
			statement->execute(buffer);
			}
		catch (SQLException& exception)
			{
			storageConnection->setErrorText(&exception);
			res = 200 - exception.getSqlcode();
			}
	
	statement->release();
	
	return res;
}

// Delete record "recordNumber" from "table" within the connection's current
// transaction.
//
// Returns 0 on success.  Returns StorageErrorRecordNotFound when the record
// does not exist or fetchVersion() yields no version for this transaction.
// On an SQLException the error text is stored on storageConnection and the
// result is StorageErrorDeadlock, StorageErrorUpdateConflict, or
// StorageErrorRecordNotFound for any other SQL code.
int StorageDatabase::deleteRow(StorageConnection *storageConnection, Table* table, int recordNumber)
{
	Connection *connection = storageConnection->connection;

	try
		{
		Record *candidate = table->fetch(recordNumber);
		
		if (!candidate)
			return StorageErrorRecordNotFound;
		
		Record *record = candidate->fetchVersion(connection->getTransaction());

		// fetchVersion() may return NULL (fetch() and next() handle that
		// case too); without this check the addRef() below would
		// dereference a null pointer.

		if (!record)
			{
			candidate->release();
			
			return StorageErrorRecordNotFound;
			}

		if (record != candidate)
			{
			record->addRef();
			candidate->release();
			}
		
		table->deleteRecord(connection->getTransaction(), record);
		record->release();
		
		return 0;
		}
	catch (SQLException& exception)
		{
		int code;
		int sqlCode = exception.getSqlcode();

		switch (sqlCode)
			{
			case DEADLOCK:
				code = StorageErrorDeadlock;
				break;

			case UPDATE_CONFLICT:
				code = StorageErrorUpdateConflict;
				break;

			default:
				code = StorageErrorRecordNotFound;
			}

		storageConnection->setErrorText(&exception);

		return code;
		}
}

// Replace record "recordNumber" of "table" with the data in "stream" within
// the storage connection's current transaction.  Always returns 0;
// exceptions raised by Table::update() are not caught here.
int StorageDatabase::updateRow(StorageConnection* storageConnection, Table* table, int recordNumber, Stream* stream)
{
	Transaction *transaction = storageConnection->connection->getTransaction();
	table->update(transaction, recordNumber, stream);

	return 0;
}


// Execute the "create index" statement in "sql" on the storage connection's
// connection.  The "table" and "indexName" arguments are not used here.
//
// Returns 0 on success, or -2 after storing the error text on
// storageConnection when an SQLException is raised.
int StorageDatabase::createIndex(StorageConnection *storageConnection, Table* table, const char* indexName, const char* sql)
{
	Connection *connection = storageConnection->connection;
	Statement *statement = connection->createStatement();
	
	try
		{
		statement->execute(sql);
		}
	catch (SQLException& exception)
		{
		storageConnection->setErrorText(&exception);
		statement->release();
		return -2;
		}
	
	statement->release();
	
	return 0;
}

// Rename "table" to schemaName.tableName, together with its generated
// secondary indexes (named "<table>$<n>") and its sequence, if any, then
// commit the system transaction.
//
// Returns 0 on success, or -2 after storing the error text on
// storageConnection when an SQLException is raised.
int StorageDatabase::renameTable(StorageConnection* storageConnection, Table* table, const char* tableName, const char *schemaName)
{
	Connection *connection = storageConnection->connection;

	try
		{
		Database *database = connection->database;
		Sequence *sequence = connection->findSequence(schemaName, table->name);
		int numberIndexes = 0;
		int firstIndex = 0;

		// Count the indexes; when a primary key is present, the numbered
		// names are taken to start at 1 instead of 0.

		for (Index *index = table->indexes; index; index = index->next)
			{
			if (index->type == PrimaryKey)
				firstIndex = 1;

			++numberIndexes;
			}

		for (int n = firstIndex; n < numberIndexes; ++n)
			{
			// Bounded formatting: table names are caller supplied, so an
			// unbounded sprintf could overrun this buffer.

			char indexName[256];
			snprintf(indexName, sizeof(indexName), "%s$%d", (const char*) table->name, n);
			Index *secondary = table->findIndex(indexName);
			
			if (secondary)
				{
				snprintf(indexName, sizeof(indexName), "%s$%d", tableName, n);
				secondary->rename(indexName);
				}
			}
			
		table->rename(schemaName, tableName);
		
		if (sequence)
			sequence->rename(tableName);
		
		database->commitSystemTransaction();
		
		return 0;
		}
	catch (SQLException& exception)
		{
		storageConnection->setErrorText(&exception);
		return -2;
		}
}

// Thin forwarder: look up the index called "indexName" on "table".
Index* StorageDatabase::findIndex(Table* table, const char* indexName)
{
	Index *found = table->findIndex(indexName);

	return found;
}

// Scan "index" between the optional "lower" and "upper" keys within the
// storage connection's current transaction and return the Bitmap produced by
// Index::scanIndex().  Either bound may be NULL; any bound supplied is first
// tagged with the index it belongs to.  Returns NULL when "index" is NULL.
Bitmap* StorageDatabase::indexScan(Index* index, StorageKey *lower, StorageKey *upper, bool partial, StorageConnection* storageConnection, Bitmap *bitmap)
{
	if (index == NULL)
		return NULL;

	if (lower != NULL)
		lower->indexKey.index = index;

	if (upper != NULL)
		upper->indexKey.index = index;

	Transaction *transaction = storageConnection->connection->getTransaction();

	if (lower != NULL && upper != NULL)
		return index->scanIndex(&lower->indexKey, &upper->indexKey, partial, transaction, bitmap);

	if (lower != NULL)
		return index->scanIndex(&lower->indexKey, NULL, partial, transaction, bitmap);

	if (upper != NULL)
		return index->scanIndex(NULL, &upper->indexKey, partial, transaction, bitmap);

	return index->scanIndex(NULL, NULL, partial, transaction, bitmap);
}

// Decode the key image "key" (keyLength bytes) segment by segment according
// to indexDesc and build the corresponding index key in
// storageKey->indexKey.  Decoding stops at the end of the key image or after
// indexDesc->numberSegments segments, whichever comes first; the number of
// segments decoded is stored in storageKey->numberSegments.
//
// Returns 0 on success and StorageErrorBadKey when the descriptor has no
// index attached or an SQLError is raised.
int StorageDatabase::makeKey(StorageIndexDesc* indexDesc, const UCHAR* key, int keyLength, StorageKey* storageKey)
{
	int segmentNumber = 0;
	Value vals [MAX_KEY_SEGMENTS];
	Value *values[MAX_KEY_SEGMENTS];
	Index *index = indexDesc->index;

	if (index == NULL)
		return StorageErrorBadKey;

	try
		{
		const UCHAR *cursor = key;
		const UCHAR *limit = key + keyLength;

		while (cursor < limit && segmentNumber < indexDesc->numberSegments)
			{
			StorageSegment *segment = indexDesc->segments + segmentNumber;
			Value *slot = vals + segmentNumber;
			values[segmentNumber] = slot;

			// Nullable segments carry a one-byte null indicator in front
			// of the value bytes.

			int nullFlag = 0;

			if (segment->nullBit)
				nullFlag = *cursor++;

			int consumed = getSegmentValue(segment, cursor, slot, index->fields[segmentNumber]);

			if (nullFlag)
				slot->setNull();

			cursor += consumed;
			++segmentNumber;
			}

		index->makeKey(segmentNumber, values, &storageKey->indexKey);
		storageKey->numberSegments = segmentNumber;
		}
	catch (SQLError&)
		{
		return StorageErrorBadKey;
		}

	return 0;
}


// Report whether every segment present in the key image "key" (keyLength
// bytes) is flagged null.  Walks at most indexDesc->numberSegments segments.
//
// Returns false as soon as a segment is not nullable or its null indicator
// byte is zero, and true otherwise (including for an empty key image).
int StorageDatabase::isKeyNull(StorageIndexDesc* indexDesc, const UCHAR* key, int keyLength)
{
	int segmentNumber = 0;
	
	for (const UCHAR *p = key, *end = key + keyLength; p < end && segmentNumber < indexDesc->numberSegments; ++segmentNumber)
		{
		StorageSegment *segment = indexDesc->segments + segmentNumber;
		int nullFlag = (segment->nullBit) ? *p++ : 0;
	
		if (!nullFlag)
			return false;
		
		// Step over the segment's value bytes; variable-length types carry
		// two extra bytes in front of the data.
		
		switch (segment->type)
			{
			case HA_KEYTYPE_VARBINARY1:
			case HA_KEYTYPE_VARBINARY2:
			case HA_KEYTYPE_VARTEXT1:
			case HA_KEYTYPE_VARTEXT2:
				p += segment->length + 2;
				break;
			
			default:
				p += segment->length;
			}
		}
	
	return true;
}

// Store the bytes described by "blob" in "table" within the connection's
// current transaction; returns whatever Table::storeBlob() returns.
int StorageDatabase::storeBlob(Connection* connection, Table* table, StorageBlob *blob)
{
	Transaction *transaction = connection->getTransaction();

	return table->storeBlob(transaction, blob->length, blob->data);
}

// Read blob "recordNumber" of "table" into a newly allocated buffer.
// On return blob->length holds the blob size and blob->data points to a
// new[] allocated array of that size; freeBlob() deletes it.
void StorageDatabase::getBlob(Table* table, int recordNumber, StorageBlob* blob)
{
	Stream contents;
	table->getBlob(recordNumber, &contents);

	blob->length = contents.totalLength;
	blob->data = new UCHAR[blob->length];

	contents.getSegment(0, blob->length, blob->data);
}

// Unlink "share" from the hash chain selected by its table name, under an
// exclusive lock.  Does nothing if the share is not on that chain.  The
// share itself is not deleted.
void StorageDatabase::remove(StorageTableShare* share)
{
	Sync sync(&syncObject, "StorageDatabase::remove");
	sync.lock(Exclusive);
	int slot = share->name.hash(shareHashSize);
	StorageTableShare **link = shares + slot;

	while (*link)
		{
		if (*link == share)
			{
			*link = share->collision;

			return;
			}

		link = &(*link)->collision;
		}
}

// Push "share" onto the front of the hash chain selected by its table name,
// under an exclusive lock.  No check for an existing entry is made.
void StorageDatabase::add(StorageTableShare* share)
{
	Sync sync(&syncObject, "StorageDatabase::add");
	sync.lock(Exclusive);

	StorageTableShare **head = shares + share->name.hash(shareHashSize);
	share->collision = *head;
	*head = share;
}

// Look up a sequence by schema and name through the master connection.
// Note the argument order is swapped when forwarding.
Sequence* StorageDatabase::findSequence(const char* name, const char *schemaName)
{
	Sequence *sequence = masterConnection->findSequence(schemaName, name);

	return sequence;
}

// Take one more reference on this object.
void StorageDatabase::addRef(void)
{
	useCount += 1;
}

// Give up one reference; the object deletes itself when the count reaches
// zero, so the caller must not touch it afterwards.
void StorageDatabase::release(void)
{
	--useCount;

	if (useCount == 0)
		{
		delete this;
		}
}

// Shut the database down: release the cached user, close the master
// connection, then shut down and delete its Database object.  The master
// connection pointer is cleared but the connection is not release()d here.
// Does nothing when there is no master connection.
void StorageDatabase::close(void)
{
	if (masterConnection == NULL)
		return;

	// Capture the database before the connection is closed.

	Database *database = masterConnection->database;

	if (user != NULL)
		{
		user->release();
		user = NULL;
		}

	masterConnection->close();
	masterConnection = NULL;
	database->shutdown();
	delete database;
}

void StorageDatabase::dropDatabase(void)
{
	if (user)
		{
		user->release();
		user = NULL;
		}
			
	masterConnection->dropDatabase();
	masterConnection->release();
	masterConnection = NULL;
}

// Free the array held in blob->data (as allocated by getBlob()) and clear
// the pointer.  blob->length is left untouched.
void StorageDatabase::freeBlob(StorageBlob *blob)
{
	if (blob->data != NULL)
		{
		delete [] blob->data;
		blob->data = NULL;
		}
}

void StorageDatabase::validateCache(void)
{
	if (masterConnection)
		masterConnection->database->validateCache();
}

// Decode one key segment starting at "ptr" into "value" according to
// segment->type, and return the number of key bytes the segment occupies
// (segment->length, plus 2 for the length prefix of the variable-length
// types).
//
// Fixed-size numeric types are copied with memcpy, i.e. taken in host byte
// order.  "field" is consulted for HA_KEYTYPE_BINARY (string vs. number,
// precision and scale) and for HA_KEYTYPE_ULONG_INT (timestamp scaling).
// Unsupported key types hit NOT_YET_IMPLEMENTED.
int StorageDatabase::getSegmentValue(StorageSegment* segment, const UCHAR* ptr, Value* value, Field *field)
{
	int length = segment->length;
	
	switch (segment->type)
		{
		case HA_KEYTYPE_LONG_INT:
			{
			int32 temp;
			memcpy(&temp, ptr, sizeof(temp));
			value->setValue(temp);
			}
			break;

		case HA_KEYTYPE_SHORT_INT:
			{
			short temp;
			memcpy(&temp, ptr, sizeof(temp));
			value->setValue(temp);
			}
			break;

		// Unsigned 64-bit keys are stored through the signed 64-bit setter.
		
		case HA_KEYTYPE_ULONGLONG:
		case HA_KEYTYPE_LONGLONG:
			{
			int64 temp;
			memcpy(&temp, ptr, sizeof(temp));
			value->setValue(temp);
			}
			break;
				
		case HA_KEYTYPE_FLOAT:
			{
			float temp;
			memcpy(&temp, ptr, sizeof(temp));
			value->setValue(temp);
			}
			break;
		
		case HA_KEYTYPE_DOUBLE:
			{
			double temp;
			memcpy(&temp, ptr, sizeof(temp));
			value->setValue(temp);
			}
			break;
		
		// Variable-length types: a two-byte length followed by the data.
		
		case HA_KEYTYPE_VARBINARY1:
		case HA_KEYTYPE_VARBINARY2:
		case HA_KEYTYPE_VARTEXT1:
		case HA_KEYTYPE_VARTEXT2:
			{
			unsigned short len;
			memcpy(&len, ptr, sizeof(len));
			value->setString(len, (const char*) ptr + 2, false);
			length += 2;
			}
			break;
		
		// NOTE(review): "field" is dereferenced here without the NULL check
		// that HA_KEYTYPE_ULONG_INT applies -- confirm callers never pass NULL.
		
		case HA_KEYTYPE_BINARY:
			if (field->isString())
				value->setString(length, (const char*) ptr, false);
			else if (segment->isUnsigned)
				{
				// Most significant byte first.
				
				int64 number = 0;
				
				for (int n = 0; n < length; ++n)
					number = number << 8 | *ptr++;
					
				value->setValue(number);
				}
			else if (field->precision < 19 && field->scale == 0)
				{
				// Most significant byte first, with the top bit of the first
				// byte inverted before sign extension.
				
				int64 number = (signed char) (*ptr++ ^ 0x80);
				
				for (int n = 1; n < length; ++n)
					number = (number << 8) | *ptr++;
				
				// NOTE(review): negative values are bumped by one --
				// presumably compensating for the binary decimal encoding;
				// confirm against the key format.
				
				if (number < 0)
					++number;
					
				value->setValue(number);
				}
			else
				{
				BigInt bigInt;
				ScaledBinary::getBigIntFromBinaryDecimal((const char*) ptr, field->precision, field->scale, &bigInt);
				value->setValue(&bigInt);
				}

			break;
			
		case HA_KEYTYPE_TEXT:
			value->setString(length, (const char*) ptr, false);
			break;
		
		case HA_KEYTYPE_ULONG_INT:
			{
			uint32 temp;
			memcpy(&temp, ptr, sizeof(temp));
			
			// Timestamp fields are scaled by 1000 before being stored.
			
			if (field && field->type == Timestamp)
				value->setValue((int64) temp * 1000);
			else
				value->setValue((int64) temp);
			}
			break;
		
		case HA_KEYTYPE_INT8:
			value->setValue(*(signed char*) ptr);
			break;
		
		case HA_KEYTYPE_USHORT_INT:
			{
			// Store the full 16-bit value; the original passed *ptr, which
			// is only the first byte of the key.
			
			unsigned short temp;
			memcpy(&temp, ptr, sizeof(temp));
			value->setValue((int) temp);
			}
			break;
			
		case HA_KEYTYPE_UINT24:
			{
			// Zero-initialised so the byte not covered by the 3-byte copy
			// is never read uninitialised.
			
			uint32 temp = 0;
			memcpy(&temp, ptr, 3);
			value->setValue((int) (temp & 0xffffff));
			}
			break;
			
		case HA_KEYTYPE_INT24:
			{
			// Zero-initialised for the same reason; the shift pair then
			// sign-extends from bit 23.
			
			int32 temp = 0;
			memcpy(&temp, ptr, 3);
			value->setValue((temp << 8) >> 8);
			}
			break;
		
		default:
			NOT_YET_IMPLEMENTED;
		}
	
	return length;
}
