/* Module support.
 *
 * IRC Services is copyright (c) 1996-2009 Andrew Church.
 *     E-mail: <achurch@achurch.org>
 * Parts written by Andrew Kempe and others.
 * This program is free but copyrighted software; see the file COPYING for
 * details.
 */

#include "services.h"
#include "modules.h"
#include "conffile.h"
#undef use_module
#undef unuse_module

#ifndef STATIC_MODULES
# include <dlfcn.h>
#endif

/*************************************************************************/

/* Set this to nonzero to allow use_module() and unuse_module() to be used
 * on a module by itself (i.e. with the same Module as both the module
 * being used and the user).  While zero, such calls are logged as bugs
 * and ignored. */
int modules_allow_use_self = 0;

/*************************************************************************/

/* Internal structure for callback lists.  Each module (and the core, via
 * `coremodule') owns an array of these; the ID returned by
 * register_callback() is the list's index in that array. */
typedef struct callbacklist_ CallbackList;
struct callbacklist_ {
    char *name;   /* List name; NULL once the list has been unregistered */
    int calling;  /* used by {call,remove}_callback() for safe callback
		   * removal from inside the callback: 0 = no call in
		   * progress, 1 = call in progress, 2 = call in progress
		   * and some entries were removed (their `func' was set to
		   * NULL, to be compacted when the call finishes) */
    struct {
	callback_t func;  /* Function to call (NULL = removed mid-call) */
	int pri;          /* Priority; higher values are called first */
    } *funcs;         /* Kept sorted by descending priority */
    int funcs_count;  /* Number of entries in `funcs' */
};

/* Structure for module data.  NOTE: `coremodule' below is initialized
 * positionally, so the field order here must not change. */
struct Module_ {
    struct Module_ *next, *prev;	/* Links for `modulelist' */
    const char *path;		/* Path passed to load_module() */
    const char *name;		/* Name of the module from `module_name'
				 * (or the path, if it has none) */
    ConfigDirective *modconfig;	/* `module_config' in this module */
    char *module_name_ptr;	/* `module_name' in this module */
    const int32 *module_version_ptr;  /* `module_version' in this module */
    void *dllhandle;		/* Handle used by dynamic linker */
    CallbackList *callbacks;	/* Callback lists owned by this module */
    int callbacks_count;
    const Module **user;	/* Array of module's users (use_module()) */
    int user_count;
};


/* Module record standing in for the Services core itself.  Its empty
 * `path' is what marks it as the core (unload_all_modules() skips it);
 * every other field starts out zero/NULL. */
static Module coremodule = {
    NULL,	/* next */
    NULL,	/* prev */
    "",		/* path */
    "core",	/* name */
    NULL,	/* modconfig */
    NULL,	/* module_name_ptr */
    NULL,	/* module_version_ptr */
    NULL,	/* dllhandle */
    NULL,	/* callbacks */
    0,		/* callbacks_count */
    NULL,	/* user */
    0		/* user_count */
};

/* Global list of modules; initially holds only the core. */
static Module *modulelist = &coremodule;

/* IDs of the core callbacks fired when modules are loaded, unloaded, or
 * reconfigured; -1 until modules_init() registers them. */
static int cb_load_module   = -1;
static int cb_unload_module = -1;
static int cb_reconfigure   = -1;


/*************************************************************************/

#ifdef STATIC_MODULES

/* Structure for a module symbol: one entry in a static module's symbol
 * table, which ends with an entry whose `symname' is NULL. */
struct modsym {
    const char *symname;	/* Symbol name (NULL = end of table) */
    void *value;		/* Address of the symbol */
};

/* Static module information (from modules/modules.a): */
struct modinfo {
    const char *modname;	/* Module name (NULL = end of `modlist') */
    struct modsym *modsyms;	/* This module's symbol table */
};
/* All statically linked modules, ending with an entry whose `modname' is
 * NULL.  In static builds, my_dlopen() returns a pointer into this table
 * as the module handle. */
extern struct modinfo modlist[];

#else  /* !STATIC_MODULES */

/* Set by modules_init(); my_dlsym() uses it for modules without their own
 * dynamic-linker handle (i.e. the core). */
static void *program_handle;	/* Handle for the main program */

#endif  /* STATIC_MODULES */

/*************************************************************************/
/********************* Initialization and cleanup ************************/
/*************************************************************************/

int modules_init(int ac, char **av)
{
#ifndef STATIC_MODULES
    program_handle = dlopen(NULL, 0);
#endif
    cb_load_module   = register_callback(NULL, "load module");
    cb_unload_module = register_callback(NULL, "unload module");
    cb_reconfigure   = register_callback(NULL, "reconfigure");
    if (cb_load_module < 0 || cb_unload_module < 0 || cb_reconfigure < 0) {
	log("modules_init: register_callback() failed\n");
	return 0;
    }
    return 1;
}

/*************************************************************************/

/* Shut down the module subsystem: unload all modules, unregister the core
 * callbacks created by modules_init(), and free (with a log message) any
 * core callback lists that were never unregistered.  Leaves the core's
 * callback table empty and the callback IDs invalid, so a later
 * modules_init() starts from a clean state.
 */
void modules_cleanup(void)
{
    int i;

    unload_all_modules();
    unregister_callback(NULL, cb_reconfigure);
    unregister_callback(NULL, cb_unload_module);
    unregister_callback(NULL, cb_load_module);
    /* The table is freed below and new registrations would restart at
     * index 0, so stale IDs must not survive. */
    cb_reconfigure = cb_unload_module = cb_load_module = -1;
    ARRAY_FOREACH(i, coremodule.callbacks) {
	if (coremodule.callbacks[i].name) {
	    log("modules: Core forgot to unregister callback `%s'",
		coremodule.callbacks[i].name);
	    free(coremodule.callbacks[i].name);
	    free(coremodule.callbacks[i].funcs);
	}
    }
    free(coremodule.callbacks);
    /* Don't leave a dangling pointer/count in the static core record. */
    coremodule.callbacks = NULL;
    coremodule.callbacks_count = 0;
}

/*************************************************************************/

/* Unload every module except the core (recognized by its empty path),
 * logging each module that refuses to be unloaded. */
void unload_all_modules(void)
{
    Module *cur, *nextmod;

    LIST_FOREACH_SAFE (cur, modulelist, nextmod) {
	if (*cur->path == '\0')
	    continue;  /* never try to unload the core */
	if (!unload_module(cur))
	    log("modules: Failed to unload `%s' on exit", cur->name);
    }
}

/*************************************************************************/
/*********************** Low-level module routines ***********************/
/*************************************************************************/

/* These low-level routines take care of all changes in processing with
 * regard to dynamic vs. static modules and different platforms. */

/* Common variables: */

#ifdef STATIC_MODULES
/* Error message from the last failed my_dlopen() in static builds;
 * my_dlerror() returns it and clears it. */
static const char *dl_last_error;
#endif

/*************************************************************************/

/* Low-level routine to open a module by name and return a handle, or NULL
 * on failure (the reason can then be fetched with my_dlerror()).  Dynamic
 * builds load "<services_dir>/modules/<name>.so" with immediate binding
 * and global symbol visibility; static builds return the matching entry
 * of the linked-in `modlist' table. */

static void *my_dlopen(const char *name)
{
#ifndef STATIC_MODULES

    char path[PATH_MAX];

    snprintf(path, sizeof(path), "%s/modules/%s.so", services_dir, name);
    return dlopen(path, RTLD_NOW | RTLD_GLOBAL);

#else  /* STATIC_MODULES */

    struct modinfo *info;

    for (info = modlist; info->modname != NULL; info++) {
	if (strcmp(info->modname, name) == 0)
	    return info;
    }
    dl_last_error = "Module not found";
    return NULL;

#endif /* STATIC_MODULES */
}

/*************************************************************************/

/* Low-level routine to close a module handle obtained from my_dlopen().
 * Static builds have nothing to release. */

static void my_dlclose(void *handle)
{
#ifdef STATIC_MODULES
    (void)handle;
#else
    dlclose(handle);
#endif
}

/*************************************************************************/

/* Low-level routine to retrieve a symbol from a module given its handle.
 * With a NULL handle, every module is searched (in `modulelist' order for
 * dynamic builds, `modlist' order for static builds) and the first match
 * is returned.  Returns NULL if the symbol is not found. */

static void *my_dlsym(void *handle, const char *symname)
{
#ifndef STATIC_MODULES

    Module *mod;
    void *value;
# ifdef SYMS_NEED_UNDERSCORES
    char prefixed[256];

    if (strlen(symname) > sizeof(prefixed)-2) {  /* too long for buffer */
	log("modules: symbol name too long in my_dlsym(): %s", symname);
	return NULL;
    }
    snprintf(prefixed, sizeof(prefixed), "_%s", symname);
    symname = prefixed;
# endif

    if (handle)
	return dlsym(handle, symname);

    /* Modules without a handle of their own (the core) are looked up
     * through the main program's handle. */
    LIST_FOREACH (mod, modulelist) {
	value = dlsym(mod->dllhandle ? mod->dllhandle : program_handle,
		      symname);
	if (value)
	    return value;
    }
    return NULL;

#else  /* STATIC_MODULES */

    struct modinfo *info;
    struct modsym *sym;

    if (!handle) {
	for (info = modlist; info->modname; info++) {
	    void *value = my_dlsym(info, symname);
	    if (value)
		return value;
	}
	return NULL;
    }

    for (sym = ((struct modinfo *)handle)->modsyms; sym->symname; sym++) {
	if (strcmp(sym->symname, symname) == 0)
	    return sym->value;
    }
    return NULL;

#endif /* STATIC_MODULES */
}

/*************************************************************************/

/* Low-level routine to return the error message (if any) from the previous
 * call, or NULL if there is none.  Each message is reported only once. */

static const char *my_dlerror(void)
{
#ifdef STATIC_MODULES
    const char *msg = dl_last_error;

    dl_last_error = NULL;
    return msg;
#else
    return dlerror();
#endif
}

/*************************************************************************/
/************************ Module-level functions *************************/
/*************************************************************************/

/* Internal routine to load a module.  Opens the module with my_dlopen(),
 * allocates a Module for it, and fills in its path, name, version, and
 * configuration pointers from the module's symbols.  The module is not
 * inserted into `modulelist', configured, or initialized here; that is
 * left to load_module().  Returns the module pointer or NULL on error
 * (the reason is logged).
 */

static Module *internal_load_module(const char *modulename)
{
    void *handle;
    Module *module, *mptr;
    const char *name;
    const int32 *verptr;
    ConfigDirective *confptr;

    /* The name is embedded in a filesystem path by my_dlopen() in dynamic
     * builds, so refuse anything containing "../". */
    if (strstr(modulename, "../")) {
	log("modules: Attempt to load bad module name: %s", modulename);
	goto err_return;
    }
    /* Refuse to load the same path twice. */
    LIST_SEARCH(modulelist, path, modulename, strcmp, mptr);
    if (mptr) {
	log("modules: Attempt to load module `%s' twice", modulename);
	goto err_return;
    }

    handle = my_dlopen(modulename);
    if (!handle) {
	const char *error = my_dlerror();
	if (!error)
	    error = "Unknown error";
	log("modules: Unable to load module `%s': %s", modulename, error);
	goto err_return;
    }

    module = scalloc(sizeof(*module), 1);
    module->dllhandle = handle;
    module->path = sstrdup(modulename);

    name = get_module_symbol(module, "module_name");
    /* The above will return the first instance of `module_name' found in
     * _any_ module (though giving priority to the given module), so we
     * need to check if we've seen this address before. */
    LIST_SEARCH_SCALAR(modulelist, module_name_ptr, name, mptr);
    if (mptr)
	name = NULL;
    module->module_name_ptr = (char *)name;
    /* Without a `module_name' of its own, the module is named by its
     * path; either way the name must not already be in use. */
    if (!name)
	name = modulename;
    LIST_SEARCH(modulelist, name, name, strcmp, mptr);
    if (mptr) {
	log("modules: Unable to load module `%s': Name `%s' already in use"
	    " (by module `%s')", modulename, name, mptr->path);
	goto err_freemod;
    }
    module->name = sstrdup(name);

    verptr = get_module_symbol(module, "module_version");
    /* With static modules, module_version is known to be defined at
     * compile time; furthermore, some linkers (hey there, MacOS) merge
     * common constants into a single value, so all these would end up
     * pointing to the same value. */
#if !defined(STATIC_MODULES)
    if (verptr) {
	LIST_SEARCH_SCALAR(modulelist, module_version_ptr, verptr, mptr);
	if (mptr)
	    verptr = NULL;
    }
#endif
    /* The module must be built against the same module interface version
     * as the core. */
    if (!verptr) {
	log("modules: Unable to load module `%s': No `module_version'"
	    " symbol found", modulename);
	goto err_freename;
    } else if (*verptr != MODULE_VERSION_CODE) {
	log("modules: Unable to load module `%s': Version mismatch"
	    " (module version = %06X, core version = %06X)",
	    modulename, *verptr, MODULE_VERSION_CODE);
	goto err_freename;
    }
    module->module_version_ptr = verptr;

    /* `module_config' is optional; a pointer already owned by another
     * module means this module has none of its own. */
    confptr = get_module_symbol(module, "module_config");  /* as above */
    if (confptr) {
	LIST_SEARCH_SCALAR(modulelist, modconfig, confptr, mptr);
	if (mptr)
	    confptr = NULL;
    }
    module->modconfig = confptr;

    return module;

    /* Error exits.  Each label falls through to the ones below it, so a
     * jump releases everything acquired before that failure point. */
  err_freename:
    free((char *)module->name);
  err_freemod:
    free((char *)module->path);
    free(module);
    my_dlclose(handle);
  err_return:
    return NULL;
}

/*************************************************************************/

/* Run a module's init_module() function, if it defines one.  Returns that
 * function's result, or 1 when the module has no init_module().
 */

static int internal_init_module(Module *module)
{
    int (*init)(Module *module) = get_module_symbol(module, "init_module");

    return init ? init(module) : 1;
}

/*************************************************************************/

/* Load a new module and return the Module pointer, or NULL on error.
 * (External interface to the above functions.)
 */

Module *load_module(const char *modulename)
{
    Module *module;

    if (debug)
	log("debug: Loading module `%s'", modulename);

    module = internal_load_module(modulename);
    if (!module)
	return NULL;
    LIST_INSERT(module, modulelist);

    if (!configure(module->name, module->modconfig,
		   CONFIGURE_READ | CONFIGURE_SET)) {
	log("modules: configure() failed for %s", modulename);
	goto fail;
    }

    if (!internal_init_module(module)) {
	log("modules: init_module() failed for %s", modulename);
	deconfigure(module->modconfig);
	goto fail;
    }

    if (debug)
	log("debug: Successfully loaded module `%s'", modulename);
    call_callback_2(NULL, cb_load_module, module, module->name);
    return module;

  fail:
    free((char *)module->name);
    my_dlclose(module->dllhandle);
    LIST_REMOVE(module, modulelist);
    free(module);
    return NULL;
}

/*************************************************************************/

/* Remove a module from memory.  Fails (returning zero) if other modules
 * still use it or if its exit_module() returns zero.  Otherwise the module
 * is removed from the module list, the "unload module" callback is called,
 * its configuration is released, leftover use_module() references it holds
 * on other modules and callback lists it never unregistered are cleaned up
 * (with log messages), and the module is closed and freed.  Returns
 * nonzero on success, zero on failure.
 */

int unload_module(Module *module)
{
    int (*exit_module)(Module *module);
    Module *tmp;
    int i;

    if (module->user_count > 0) {
	log("modules: Attempt to unload in-use module `%s' (in use by %s%s)",
	    module->name, module->user[0]->name,
	    module->user_count>1 ? " and others" : "");
	return 0;
    }

    if (debug)
	log("debug: Unloading module `%s'", module->name);

    exit_module = get_module_symbol(module, "exit_module");
    if (exit_module && !exit_module(module))
	return 0;
    LIST_REMOVE(module, modulelist);
    call_callback_1(NULL, cb_unload_module, module);
    deconfigure(module->modconfig);
    LIST_FOREACH (tmp, modulelist) {
	ARRAY_FOREACH (i, tmp->user) {
	    if (tmp->user[i] == module) {
		log("modules: Module `%s' forgot to unuse_module() for"
		    " module `%s'", module->name, tmp->name);
		ARRAY_REMOVE(tmp->user, i);
		i--;
	    }
	}
    }
    ARRAY_FOREACH(i, module->callbacks) {
	if (module->callbacks[i].name) {
	    log("modules: Module `%s' forgot to unregister callback `%s'",
		module->name, module->callbacks[i].name);
	    free(module->callbacks[i].name);
	    free(module->callbacks[i].funcs);
	}
    }
    free(module->callbacks);
    /* user_count is 0 here, but the array itself may still be allocated
     * after use_module()/unuse_module() pairs, depending on how
     * ARRAY_REMOVE shrinks it; free(NULL) is harmless otherwise. */
    free(module->user);
    free((char *)module->name);
    free((char *)module->path);
    my_dlclose(module->dllhandle);
    free(module);
    return 1;
}

/*************************************************************************/

/* Return the Module pointer for the named module, or NULL if no such
 * module exists.
 */

Module *find_module(const char *modulename)
{
    Module *result;
    LIST_SEARCH(modulelist, name, modulename, strcmp, result);
    return result;
}

/*************************************************************************/

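/* use_module() increments the use count for the given module (by adding
 * the caller to its user list).  A module cannot be unloaded while its use
 * count is nonzero.  use_module_loopcheck() is a helper for use_module().
 */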

static int use_module_loopcheck(const Module *module, const Module *check)
{
    /* Return whether `module' is used by `check' (self-references are
     * ignored). */

    int i;

    ARRAY_FOREACH (i, module->user) {
	if (module->user[i] != module) {
	    if (module->user[i] == check
	     || use_module_loopcheck(module->user[i], check))
		return 1;
	}
    }
    return 0;
}

/* Record `user' as a user of `module'; `module' cannot be unloaded while
 * it has users.  Self-use (unless modules_allow_use_self is set) and uses
 * that would create a dependency loop are logged as bugs and ignored. */
void use_module(Module *module, const Module *user)
{
    if (module == user && !modules_allow_use_self) {
	log("modules: BUG: Module `%s' called use_module() for itself!",
	    module->name);
    } else if (use_module_loopcheck(user, module)) {
	log("modules: BUG: use_module loop detected (called by `%s' for `%s')",
	    user->name, module->name);
    } else {
	ARRAY_EXTEND(module->user);
	module->user[module->user_count-1] = user;
    }
}

/*************************************************************************/

/* Remove `user' from `module''s user list, undoing a use_module() call.
 * `module' may be NULL, in which case nothing happens.  Self-unuse (unless
 * modules_allow_use_self is set), an empty user list, and a `user' missing
 * from the list are each logged as bugs and otherwise ignored.
 */

void unuse_module(Module *module, const Module *user)
{
    int pos;

    if (!module)
	return;

    if (module == user && !modules_allow_use_self) {
	log("modules: BUG: Module `%s' called unuse_module() for itself!",
	    module->name);
    } else if (module->user_count == 0) {
	log("modules: BUG: trying to unuse module `%s' with use count 0"
	    " from module `%s'", module->name, user->name);
    } else {
	ARRAY_SEARCH_PLAIN_SCALAR(module->user, user, pos);
	if (pos < module->user_count) {
	    ARRAY_REMOVE(module->user, pos);
	} else {
	    log("modules: BUG: trying to unuse module `%s' from module `%s' but"
		" caller not found in user list!", module->name, user->name);
	}
    }
}

/*************************************************************************/

/* Reconfigure all modules (the core's entry included).  The "reconfigure"
 * callback is called with an `int' parameter of 0 before reconfiguration
 * and 1 after.  Every module is first configured with CONFIGURE_READ; only
 * if all of those succeed is CONFIGURE_SET applied to every module.
 * Returns 1 on success, 0 on failure (on failure, all modules'
 * configuration data will be left alone).
 * NOTE(review): on failure the callback is never called with 1; confirm
 * that listeners do not depend on the 0/1 calls being paired.
 */

int reconfigure_modules(void)
{
    Module *m;

    call_callback_1(NULL, cb_reconfigure, 0);

    /* Pass 1: read every module's settings, stopping at the first error. */
    LIST_FOREACH (m, modulelist) {
	if (!configure(m->name, m->modconfig, CONFIGURE_READ))
	    return 0;
    }

    /* Pass 2: every read succeeded, so apply them all. */
    LIST_FOREACH (m, modulelist) {
	configure(m->name, m->modconfig, CONFIGURE_SET);
    }

    call_callback_1(NULL, cb_reconfigure, 1);
    return 1;
}

/*************************************************************************/
/****************** Module symbol/information retrieval ******************/
/*************************************************************************/

/* Retrieve the value of the named symbol in the given module.  If `module'
 * is NULL (or has no dynamic-linker handle, like the core), all modules
 * are searched.  Return NULL if no such symbol exists.  Note that this
 * function should not be used for symbols whose value might be NULL,
 * because there is no way to distinguish a symbol value of NULL from an
 * error return.
 */

void *get_module_symbol(Module *module, const char *symname)
{
    void *handle = NULL;

    if (module != NULL)
	handle = module->dllhandle;
    return my_dlsym(handle, symname);
}

/*************************************************************************/

/* Return the name of the given module; a NULL module denotes the core and
 * yields the string "core".
 */

const char *get_module_name(Module *module)
{
    if (module == NULL)
	return "core";
    return module->name;
}

/*************************************************************************/
/********************** Callback-related functions ***********************/
/*************************************************************************/

/* Translate the NULL module to &coremodule.  `m' must be an lvalue; it is
 * assigned in place (and evaluated more than once). */

#define NO_NULL_MOD(m)	((m) = ((m) ? (m) : &coremodule))

/*************************************************************************/

/* Local function to look up a registered (not yet unregistered) callback
 * list of `module' by name.  Returns NULL if not found, including when
 * `name' is NULL.
 */

static CallbackList *find_callback(Module *module, const char *name)
{
    int i;

    if (!name)  /* would otherwise be passed to strcmp() */
	return NULL;
    ARRAY_FOREACH (i, module->callbacks) {
	/* Unregistered slots have a NULL name and never match. */
	if (module->callbacks[i].name
	 && strcmp(module->callbacks[i].name, name) == 0)
	    return &module->callbacks[i];
    }
    return NULL;
}

/*************************************************************************/

/* Register a new callback list.  "module" is the calling module's own
 * Module pointer, or NULL for core Services callbacks.  The new list is
 * appended to the module's table (IDs are never reused) and starts with no
 * functions.  Return the callback identifier (a nonnegative integer) or -1
 * if a list with that name is already registered.
 */

int register_callback(Module *module, const char *name)
{
    CallbackList *cl;
    int id;

    if (debug >= 2)
	log("debug: register_callback(%s, \"%s\")",
	    module ? module->name : "core", name);
    NO_NULL_MOD(module);
    if (find_callback(module, name) != NULL) {
	/* module is non-NULL here; the core's name is "core". */
	log("BUG: register_callback(%s,\"%s\"): callback already registered",
	    module->name, name);
	return -1;
    }
    id = module->callbacks_count;
    ARRAY_EXTEND(module->callbacks);
    cl = &module->callbacks[id];
    cl->name        = sstrdup(name);
    cl->calling     = 0;
    cl->funcs_count = 0;
    cl->funcs       = NULL;
    return id;
}

/*************************************************************************/

/* Call all functions in a callback list in priority order, stopping at the
 * first one that returns nonzero.  Returns that nonzero value, 0 if all
 * callbacks returned zero (or the list is empty or has been unregistered),
 * or -1 if `id' is out of range.
 *
 * Callbacks may safely remove functions from the list being called,
 * unregister it, register new lists, or call the same list recursively.
 * While a call is in progress, remove_callback() only marks entries
 * (func = NULL); such entries are skipped here and compacted when the
 * outermost call on the list finishes.
 */

#undef call_callback_5
int call_callback_5(Module *module, int id, void *arg1, void *arg2,
		    void *arg3, void *arg4, void *arg5)
{
    CallbackList *cl;
    int outermost;
    int res = 0;
    int i;

    NO_NULL_MOD(module);
    if (id < 0 || id >= module->callbacks_count)
	return -1;
    cl = &module->callbacks[id];
    if (!cl->name)  /* unregistered list: nothing to call */
	return 0;

    /* Only the outermost call on this list owns the `calling' flag, so a
     * nested call cannot clear it or compact the array under us. */
    outermost = !cl->calling;
    if (outermost)
	cl->calling = 1;

    for (i = 0; i < cl->funcs_count; i++) {
	callback_t func = cl->funcs[i].func;
	if (!func)  /* removed earlier during this call */
	    continue;
	res = func(arg1, arg2, arg3, arg4, arg5);
	/* The callback may have registered a new list (reallocating
	 * module->callbacks) or unregistered this one; refresh `cl'. */
	cl = &module->callbacks[id];
	if (res != 0 || !cl->name)
	    break;
    }

    if (outermost) {
	if (cl->calling == 2 && cl->name) {  /* some callbacks were removed */
	    ARRAY_FOREACH (i, cl->funcs) {
		if (!cl->funcs[i].func) {
		    ARRAY_REMOVE(cl->funcs, i);
		    i--;
		}
	    }
	}
	cl->calling = 0;
    }
    return res;
}

/*************************************************************************/

/* Delete a callback list.  `module' owns the list (NULL = core) and `id'
 * is the value register_callback() returned.  The list's name and function
 * array are freed and the slot is marked unused (name == NULL); the slot
 * is never handed out again by register_callback().  Returns 1 on success,
 * or 0 (after logging) if the ID is invalid or already unregistered.
 */

int unregister_callback(Module *module, int id)
{
    CallbackList *cl;

    NO_NULL_MOD(module);
    if (debug >= 2)
	log("debug: unregister_callback(%s, %d)", module->name, id);
    if (id < 0 || id >= module->callbacks_count) {
	log("unregister_callback(): BUG: invalid callback ID %d for module"
	    " `%s'", id, module->name);
	return 0;
    }
    cl = &module->callbacks[id];
    if (!cl->name) {
	log("unregister_callback(): BUG: callback ID %d for module `%s'"
	    " is unused (double unregister?)", id, module->name);
	return 0;
    }
    free(cl->funcs);
    free(cl->name);
    cl->funcs = NULL;
    /* Keep the count in step with the freed array, so any loop over
     * `funcs' (e.g. a call_callback_5() in progress when the list is
     * unregistered) sees an empty list rather than a NULL array. */
    cl->funcs_count = 0;
    cl->name = NULL;
    return 1;
}

/*************************************************************************/

/* Add a function to a callback list with the given priority (higher
 * priority value = called sooner).  Callbacks with the same priority are
 * called in the order they were added.  `module' owns the list (NULL =
 * core) and `name' names it.  Returns 1 on success, or 0 if no such list
 * exists (including a NULL `name') or the priority is outside
 * CBPRI_MIN..CBPRI_MAX.
 */

int add_callback_pri(Module *module, const char *name, callback_t callback,
		     int priority)
{
    CallbackList *cl;
    int n;

    if (debug >= 2) {
	log("debug: add_callback_pri(%s, \"%s\", %p, %d)",
	    module ? module->name : "core", name ? name : "(null)",
	    callback, priority);
    }
    NO_NULL_MOD(module);
    /* A NULL name can never match a registered list; don't hand it on
     * to find_callback(), which compares names with strcmp(). */
    cl = name ? find_callback(module, name) : NULL;
    if (!cl) {
	if (debug >= 2)
	    log("debug: -- callback not found");
	return 0;
    }
    if (priority < CBPRI_MIN || priority > CBPRI_MAX) {
	log("add_callback_pri(): priority (%d) out of range for callback"
	    " `%s' in module `%s'",
	    priority, name, module->name);
	return 0;
    }
    /* Insert before the first entry with a lower priority: the list stays
     * sorted by descending priority, and equal priorities keep their
     * insertion order. */
    ARRAY_FOREACH (n, cl->funcs) {
	if (cl->funcs[n].pri < priority)
	    break;
    }
    ARRAY_INSERT(cl->funcs, n);
    cl->funcs[n].func = callback;
    cl->funcs[n].pri = priority;
    return 1;
}

/*************************************************************************/

/* Remove a function from a callback list.  `module' owns the list (NULL =
 * core).  Returns 1 if `callback' was found in the list named `name' and
 * removed, 0 otherwise.  While the list is being called, the entry is only
 * marked (func = NULL) and the list is flagged so that call_callback_5()
 * compacts it afterwards.
 */
int remove_callback(Module *module, const char *name, callback_t callback)
{
    CallbackList *cl;
    int idx;

    if (debug >= 2)
	log("debug: remove_callback(%s, \"%s\", %p)",
	    module ? module->name : "core", name, callback);
    NO_NULL_MOD(module);
    cl = find_callback(module, name);
    if (cl == NULL)
	return 0;
    ARRAY_SEARCH_SCALAR(cl->funcs, func, callback, idx);
    if (idx == cl->funcs_count)
	return 0;
    if (!cl->calling) {
	ARRAY_REMOVE(cl->funcs, idx);
    } else {
	cl->funcs[idx].func = NULL;
	cl->calling = 2;  /* tells call_callback_5() a callback was removed */
    }
    return 1;
}

/*************************************************************************/
