/* This file is part of cqual.
   Copyright (C) 2000-2002 The Regents of the University of California.

cqual is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2, or (at your option)
any later version.

cqual is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with cqual; see the file COPYING.  If not, write to
the Free Software Foundation, 59 Temple Place - Suite 330,
Boston, MA 02111-1307, USA. */

#include "parser.h"
#include "AST_utils.h"
#include "expr.h"
#include "builtins.h"
#include "c-parse.h"
#include "constants.h"
#include "analyze.h"
#include "cqual.h"
#include "hash.h"
#include "qerror.h"
#include "qtype.h"
#include "utils.h"
#include "pam.h"
#include "effect.h"

/* The position in which an expression is traversed:
     lpos -- l-value position: the expression is written (assignment
             target, ++/-- operand, change_type/deep_restrict target);
     rpos -- the expression's value is read;
     apos -- the expression's address is taken (operand of &). */
typedef enum { lpos, rpos, apos } context;

/* Forward declarations for the AST traversal below. */
static void confine_inf_declaration(declaration d, compound_stmt cs);
static void confine_inf_init(expression rhs);
static void confine_inf_statement(statement s);
static void confine_inf_expression(expression e, context context);
static void confine_inf_unary_expression(unary e, context context);
static void confine_inf_binary_expression(binary e, context context);
static void confine_this_expression(expression e);

/* Some handy functions defined in other files (defining file noted
   where known) */
void unparse_start(FILE *);                  /* unparse.c */
void prt_variable_decl(variable_decl);
void prt_expression(expression, int);
const char *binary_op_name(ast_kind);
bool is_void_parms(declaration);             /* semantics.c */
bool equal_expressions(expression e1, expression e2); /* eq_expressions.c */
bool isassignment(binary e); /* analyze.c */

/**************************************************************************
 *                                                                        *
 * Globals                                                                *
 *                                                                        *
 **************************************************************************/

/* Innermost compound statement enclosing the node currently being
   traversed; NULL at file scope.  Set on entry to each compound statement
   or statement expression and restored to its enclosing_cs on exit. */
static compound_stmt cur_cs = NULL; /* current enclosing compound stmt */

/**************************************************************************
 *                                                                        *
 * Utilities                                                              *
 *                                                                        *
 **************************************************************************/

/**************************************************************************
 *                                                                        *
 * Tree Traversal                                                         *
 *                                                                        *
 * Walks the AST, recording for each declaration, loop, switch, jump      *
 * and label the compound statement that encloses it, and collecting      *
 * candidate expressions for confine.  Think of this as a big             *
 * pattern-match on node kinds.                                           *
 *                                                                        *
 **************************************************************************/

/* Entry point of the pass: walk every top-level declaration of PROGRAM,
   recording enclosing compound statements and collecting candidate
   expressions for confine. */
void confine_inf(declaration program)
{
  declaration top;

  /* Set up the unparser on stdout and give every AST node a parent
     link before walking. */
  unparse_start(stdout);
  AST_set_parents(CAST(node, program));

  /* File-scope declarations have no enclosing compound statement. */
  cur_cs = NULL;

  /* Process declarations in source (left-to-right) order. */
  scan_declaration(top, program)
    confine_inf_declaration(top, NULL);
}

/* Record CS as the enclosing compound statement of declaration D (and of
   every declaration nested in it), then traverse any initializer or
   function body.  CS is NULL for file-scope declarations.  The parameters
   of a function definition are attributed to the function's body, which
   must be a compound statement. */
static void confine_inf_declaration(declaration d, compound_stmt cs)
{
  switch (d->kind)
    {
    case kind_asm_decl:
      return;

    case kind_data_decl:
      {
        data_decl dd = CAST(data_decl, d);
        declaration decl;

        /* One data_decl may hold several declarators, as in "int a, b;". */
        scan_declaration(decl, dd->decls)
          {
            confine_inf_declaration(decl, cs);
          }
        return;
      }

    case kind_variable_decl:
      {
        variable_decl vd = CAST(variable_decl, d);

        root_ddecl(vd->ddecl)->cs = cs;
        if (vd->arg1)
          confine_inf_init(vd->arg1);
        return;
      }

    case kind_function_decl:
      {
        function_decl fd = CAST(function_decl, d);
        compound_stmt body;
        declaration arg;

        root_ddecl(fd->ddecl)->cs = cs;

        assert(is_compound_stmt(fd->stmt));
        body = CAST(compound_stmt, fd->stmt);

        /* Parameters are scoped to the function body. */
        if (!is_void_parms(fd->fdeclarator->parms))
          {
            scan_declaration(arg, fd->fdeclarator->parms)
              if (!is_ellipsis_decl(arg))   /* "..." declares nothing */
                {
                  data_declaration root_vd;

                  if (arg->kind == kind_data_decl)
                    {
                      data_decl argd = CAST(data_decl, arg);
                      variable_decl argvd = CAST(variable_decl, argd->decls);

                      /* Only one variable_decl per parameter data_decl;
                         multiple variable_decls are for things like
                         "int a, b". */
                      assert(!argvd->next);
                      root_vd = root_ddecl(argvd->ddecl);
                    }
                  else
                    {
                      /* Presumably an old-style (K&R) parameter name. */
                      oldidentifier_decl oid = CAST(oldidentifier_decl, arg);

                      root_vd = root_ddecl(oid->ddecl);
                    }

                  root_vd->cs = body;
                }
          }

        /* Traverse the body. */
        confine_inf_statement(fd->stmt);
        return;
      }

    case kind_extension_decl:
      {
        /* The __extension__ marker itself is ignored. */
        extension_decl ed = CAST(extension_decl, d);

        confine_inf_declaration(ed->decl, cs);
        return;
      }

    default:
      fail_loc(d->loc,
               "Unexpected decl kind 0x%x\n", d->kind);
    }
}

/* Traverse the initializer RHS.  A plain expression is read (rpos).  An
   initializer list is walked element by element:
     - init_index: its index expression(s) are read, then its value
       initializer is traversed (arg2 is optional -- presumably the upper
       bound of a GNU range designator);
     - init_field: only its value initializer is traversed;
     - anything else is itself an initializer and is traversed
       recursively. */
static void confine_inf_init(expression rhs)
{
  init_list il;
  expression elem;

  if (rhs->kind != kind_init_list)
    {
      confine_inf_expression(rhs, rpos);
      return;
    }

  il = CAST(init_list, rhs);
  scan_expression(elem, il->args)
    {
      switch (elem->kind)
        {
        case kind_init_index:
          {
            init_index ii = CAST(init_index, elem);

            confine_inf_expression(ii->arg1, rpos);
            if (ii->arg2)
              confine_inf_expression(ii->arg2, rpos);
            confine_inf_init(ii->init_expr);
            break;
          }
        case kind_init_field:
          confine_inf_init(CAST(init_field, elem)->init_expr);
          break;
        default:
          confine_inf_init(elem);
          break;
        }
    }
}

/* Traverse statement S.
   - A compound statement is linked to the current one (enclosing_cs) and
     becomes cur_cs while its declarations and statements are processed.
   - Loops, switches, break/continue/goto statements and labels record
     the compound statement they appear in.
   - Expressions are traversed in rpos, except the targets of
     change_type and deep_restrict statements, which are in lpos; a
     change_type target is also offered to confine_this_expression. */
static void confine_inf_statement(statement s)
{
  switch (s->kind)
    {
    case kind_asm_stmt:
    case kind_empty_stmt:
      break;

    case kind_compound_stmt:
      {
        compound_stmt block = CAST(compound_stmt, s);
        declaration decl;
        statement stmt;

        /* Enter the block. */
        block->enclosing_cs = cur_cs;
        cur_cs = block;

        scan_declaration(decl, block->decls)
          {
            /* asm_decl only appears at top level */
            assert(decl->kind != kind_asm_decl);
            confine_inf_declaration(decl, block);
          }

        scan_statement(stmt, block->stmts)
          confine_inf_statement(stmt);

        /* Leave the block. */
        cur_cs = block->enclosing_cs;
        break;
      }

    case kind_if_stmt:
      {
        if_stmt ifs = CAST(if_stmt, s);

        confine_inf_expression(ifs->condition, rpos);
        confine_inf_statement(ifs->stmt1);
        if (ifs->stmt2)
          confine_inf_statement(ifs->stmt2);
        break;
      }

    case kind_labeled_stmt:
      {
        labeled_stmt ls = CAST(labeled_stmt, s);

        ls->label->enclosing_cs = cur_cs;
        confine_inf_statement(ls->stmt);
        break;
      }

    case kind_expression_stmt:
      confine_inf_expression(CAST(expression_stmt, s)->arg1, rpos);
      break;

    case kind_while_stmt:
      {
        while_stmt loop = CAST(while_stmt, s);

        loop->enclosing_cs = cur_cs;
        confine_inf_expression(loop->condition, rpos);
        confine_inf_statement(loop->stmt);
        break;
      }

    case kind_dowhile_stmt:
      {
        dowhile_stmt loop = CAST(dowhile_stmt, s);

        loop->enclosing_cs = cur_cs;
        confine_inf_statement(loop->stmt);
        /* Skip the constant-zero condition of the "do { x } while (0)"
           idiom used in macro expansions. */
        if (!definite_zero(loop->condition))
          confine_inf_expression(loop->condition, rpos);
        break;
      }

    case kind_switch_stmt:
      {
        switch_stmt sw = CAST(switch_stmt, s);

        sw->enclosing_cs = cur_cs;
        confine_inf_expression(sw->condition, rpos);
        confine_inf_statement(sw->stmt);
        break;
      }

    case kind_for_stmt:
      {
        for_stmt loop = CAST(for_stmt, s);

        loop->enclosing_cs = cur_cs;
        /* Init and test, then the body, then the step expression. */
        if (loop->arg1)
          confine_inf_expression(loop->arg1, rpos);
        if (loop->arg2)
          confine_inf_expression(loop->arg2, rpos);
        confine_inf_statement(loop->stmt);
        if (loop->arg3)
          confine_inf_expression(loop->arg3, rpos);
        break;
      }

    case kind_return_stmt:
      {
        return_stmt ret = CAST(return_stmt, s);

        if (ret->arg1)
          confine_inf_expression(ret->arg1, rpos);
        break;
      }

    case kind_computed_goto_stmt:
      confine_inf_expression(CAST(computed_goto_stmt, s)->arg1, rpos);
      break;

    case kind_break_stmt:
      CAST(break_stmt, s)->enclosing_cs = cur_cs;
      break;

    case kind_continue_stmt:
      CAST(continue_stmt, s)->enclosing_cs = cur_cs;
      break;

    case kind_goto_stmt:
      CAST(goto_stmt, s)->enclosing_cs = cur_cs;
      break;

    case kind_change_type_stmt:
      {
        change_type_stmt ct = CAST(change_type_stmt, s);

        /* The target is written, and is also a candidate for confine. */
        confine_inf_expression(ct->arg1, lpos);
        confine_this_expression(ct->arg1);
        break;
      }

    case kind_assert_type_stmt:
      confine_inf_expression(CAST(assert_type_stmt, s)->arg1, rpos);
      break;

    case kind_deep_restrict_stmt:
      {
        deep_restrict_stmt dr = CAST(deep_restrict_stmt, s);

        confine_inf_expression(dr->arg1, lpos);
        confine_inf_statement(dr->stmt);
        break;
      }

    default:
      fail_loc(s->loc, "Unexpected statement kind 0x%x\n", s->kind);
      break;
    }
}

/* Traverse expression E, which is used in position CONTEXT (lpos, rpos or
   apos).  Subexpressions are traversed in the position in which they are
   used.  A statement expression (compound_expr) is handled like a block:
   its compound statement becomes cur_cs while its contents are processed.
   An array reference is rewritten into the equivalent "*(a + i)" form,
   stored in ar->alt, and that form is traversed instead.  Other unary and
   binary operators are dispatched to confine_inf_unary_expression and
   confine_inf_binary_expression; any remaining kind is ignored. */
static void confine_inf_expression(expression e, context context)
{
  switch (e->kind) {
  case kind_comma:
    {
      comma c = CAST(comma, e);
      expression e2;

      /* Only the last operand's value is used, so only it inherits
         CONTEXT; the earlier operands are read. */
      scan_expression (e2, c->arg1)
        {
          confine_inf_expression(e2, e2->next ? rpos : context);
        }
    }
    break;
  case kind_sizeof_type:
    assert(context == rpos);
    break;
  case kind_alignof_type:
  case kind_label_address:
  case kind_identifier:
    /* No subexpressions to traverse. */
    break;
  case kind_cast:
    {
      cast c = CAST(cast, e);

      confine_inf_expression(c->arg1, context);
    }
    break;
  case kind_cast_list:
    {
      /* XXX Fix!  The initializer is traversed as an ordinary
         initializer and CONTEXT is ignored. */
      cast_list cl = CAST(cast_list, e);

      confine_inf_init(cl->init_expr);
    }
    break;
  case kind_conditional:
    {
      conditional c = CAST(conditional, e);

      confine_inf_expression(c->condition, rpos);
      /* arg1 may be NULL (presumably the GNU "c ?: b" form). */
      if (c->arg1)
        confine_inf_expression(c->arg1, context);
      confine_inf_expression(c->arg2, context);
    }
    break;
  case kind_compound_expr:
    {
      compound_expr ce = CAST(compound_expr, e);
      compound_stmt cs = CAST(compound_stmt, ce->stmt);
      statement cur_stmt;
      declaration d;

      /* Enter the statement expression's block. */
      cs->enclosing_cs = cur_cs;
      cur_cs = cs;

      if (cs->id_labels)
        fail_loc(cs->loc, "Unimplemented: id_labels\n", 0);

      /* Analyze the declarations in the block */
      scan_declaration(d, cs->decls)
        {
          assert(d->kind != kind_asm_decl); /* asm_decl only at toplevel */
          confine_inf_declaration(d, cs);
        }

      /* All statements but the last are ordinary statements. */
      cur_stmt = cs->stmts;
      while (cur_stmt && cur_stmt->next)
        {
          confine_inf_statement(cur_stmt);
          cur_stmt = CAST(statement, cur_stmt->next);
        }

      /* If the last statement is an expression statement it supplies the
         value of the whole expression, so it inherits CONTEXT; otherwise
         the value is void and it is traversed as an ordinary statement. */
      if (cur_stmt && is_expression_stmt(cur_stmt))
        confine_inf_expression(CAST(expression_stmt, cur_stmt)->arg1,
                               context);
      else if (cur_stmt)
        confine_inf_statement(cur_stmt);

      /* Leave the block. */
      cur_cs = cs->enclosing_cs;
    }
    break;
  case kind_function_call:
    {
      function_call fc = CAST(function_call, e);
      expression arg;

      assert(context == rpos);
      /* va_arg calls are not traversed. */
      if (fc->va_arg_call)
        break;

      confine_inf_expression(fc->arg1, rpos);
      scan_expression(arg, fc->args)
        {
          confine_inf_expression(arg, rpos);
        }
    }
    break;
  case kind_array_ref:
    {
      array_ref ar = CAST(array_ref, e);
      expression array, plus, star_plus;

      /* Either operand may be the array (a[i] == i[a]). */
      if (type_array(ar->arg1->type))
        array = ar->arg1;
      else
        array = ar->arg2;

      array->lvalue = TRUE; /* XXX: Hack to fix problem
                               w/default_conversion */
      array->cst = NULL; /* XXX: Hack to fix problem w/default_conversion */

      /* Rewrite a[i] as *(a + i); the rewrite is stored once in ar->alt. */
      plus = make_binary(ar->loc, kind_plus, ar->arg1, ar->arg2);
      star_plus = make_dereference(ar->loc, plus);
      assert(!ar->alt);
      ar->alt = star_plus;

      confine_inf_expression(star_plus, context);
    }
    break;
  case kind_field_ref:
    {
      field_ref fr = CAST(field_ref, e);

      confine_inf_expression(fr->arg1, context);
    }
    break;
  /* Initializer-only forms are handled by confine_inf_init and must not
     reach this function. */
  case kind_init_list:
    fail_loc(e->loc, "Unexpected init list\n", 0);
    break;
  case kind_init_index:
    fail_loc(e->loc, "Unexpected init index\n", 0);
    break;
  case kind_init_field:
    fail_loc(e->loc, "Unexpected init field\n", 0);
    break;
  case kind_lexical_cst:
    assert(context == rpos);
    break;
  case kind_string:
    /* Could use string name -- see string_to_charp */
    assert(context == rpos);
    break;
  default:
    if (is_unary(e))
      confine_inf_unary_expression(CAST(unary, e), context);
    else if (is_binary(e))
      confine_inf_binary_expression(CAST(binary, e), context);
    break;
  }
}

/* Traverse unary expression E, used in position CONTEXT.
   - *p reads p, whatever position *p itself is in.
   - &x puts x in apos; for a function designator the (rpos) context is
     passed through unchanged instead.
   - ++/-- write their operand (lpos).
   - sizeof/alignof do not traverse their operand.
   - The remaining operators read their operand (rpos). */
static void confine_inf_unary_expression(unary e, context context)
{
  switch (e->kind)
    {
    case kind_dereference:
      confine_inf_expression(e->arg1, rpos);
      break;

    case kind_address_of:
      assert(context == rpos);
      confine_inf_expression(e->arg1,
                             type_function(e->arg1->type) ? context : apos);
      break;

    case kind_extension_expr:
      /* __extension__ is transparent. */
      confine_inf_expression(e->arg1, context);
      break;

    case kind_sizeof_expr:
    case kind_alignof_expr:
      assert(context == rpos);
      break;

    case kind_realpart:
    case kind_imagpart:
    case kind_unary_minus:
    case kind_unary_plus:
    case kind_conjugate:
    case kind_bitnot:
      confine_inf_expression(e->arg1, rpos);
      break;

    case kind_not:
      assert(context == rpos);
      confine_inf_expression(e->arg1, rpos);
      break;

    case kind_preincrement:
    case kind_postincrement:
    case kind_predecrement:
    case kind_postdecrement:
      assert(context == rpos);
      confine_inf_expression(e->arg1, lpos);
      break;

    default:
      fail_loc(e->loc, "Unexpected unary op kind 0x%x\n", e->kind);
    }
}

/* Traverse binary expression E, which is only expected in read position.
   The left operand of an assignment operator is written (lpos); every
   other operand is read (rpos).  The left operand is visited first. */
static void confine_inf_binary_expression(binary e, context context)
{
  assert(context == rpos);

  confine_inf_expression(e->arg1, isassignment(e) ? lpos : rpos);
  confine_inf_expression(e->arg2, rpos);
}

/**************************************************************************
 *                                                                        *
 * Confine                                                                *
 *                                                                        *
 **************************************************************************/

/* For now, confine is only attempted on simple expressions (see
   get_identifiers_expression); the intent is to restrict it to
   computations without side effects. */

/* Forward declarations for the confine helpers below. */
static bool get_identifiers_expression(expression e, identifier_set * identifiers);
static bool get_identifiers_unary_expression(unary e, identifier_set * identifiers);
static bool get_identifiers_binary_expression(binary e, identifier_set * identifiers);
static compound_stmt find_largest_scope(identifier_set identifiers);
static bool is_new_expression(expression e, dd_list expressions);


static void confine_this_expression(expression e)
{
  identifier_set identifiers = empty_identifier_set(parse_region);
  compound_stmt cs;

  if (!get_identifiers_expression(e, &identifiers))
    return;
  
  if ((cs = find_largest_scope(identifiers)))
    {      
      /* Now add this to every scope leading to cs */
      compound_stmt cur;
      cur = cur_cs;
      if (flag_confine_inf_aggressive)
	{
	  while (1)
	    {
	      assert(cur);

	      if (!cur->confine_expressions)
		cur->confine_expressions = dd_new_list(parse_region);

	      if (is_new_expression(e, cur->confine_expressions))
		{
		  exprdrinfo ed = ralloc(parse_region, struct ExprDrinfoPair);
		  ed->e = e;
		  ed->drinfo = NULL;
		  dd_add_last(parse_region, cur->confine_expressions, ed);
		}

	      if (cur == cs)
		break;
	      else
		cur = cur->enclosing_cs;
	    }
	}
      else {
	if (!cs->confine_expressions)
	  cs->confine_expressions = dd_new_list(parse_region);
	if (is_new_expression(e, cs->confine_expressions))
	  {
	    exprdrinfo ed = ralloc(parse_region, struct ExprDrinfoPair);
	    ed->e = e;
	    ed->drinfo = NULL;
	    dd_add_last(parse_region, cs->confine_expressions, ed);
	  }
      }
    }
  return;
}

/* Return TRUE iff no element of EXPRESSIONS (a list of exprdrinfo) holds
   an expression equal to E according to equal_expressions. */
static bool is_new_expression(expression e, dd_list expressions)
{
  dd_list_pos cur;

  dd_scan(cur, expressions)
    {
      exprdrinfo ed = DD_GET(exprdrinfo, cur);

      /* The first match decides the answer; no need to scan further. */
      if (equal_expressions(e, ed->e))
        return FALSE;
    }

  return TRUE;
}

/* Walk up the enclosing_cs chain from CS, setting visited = 1 on each
   unmarked compound statement, and report where the walk stopped:
     - at the implicit root (NULL): return FIRST, since reaching the root
       is only acceptable for the first chain marked by find_largest_scope;
     - at a statement marked 1 (an intermediate node of another chain):
       return FALSE;
     - at a statement marked 2 (an identifier's own scope): return TRUE. */
static bool mark_cs_upward(compound_stmt cs, bool first)
{
  while (cs && cs->visited == 0)
    {
      cs->visited = 1;
      cs = cs->enclosing_cs;
    }

  if (!cs)
    return first;
  if (cs->visited == 1)
    return FALSE;

  assert(cs->visited == 2);
  return TRUE;
}

/* Reset visited to 0 on CS and on each enclosing compound statement,
   stopping at the root (NULL) or at the first one already unmarked. */
static void clean_cs_upward(compound_stmt cs)
{
  for (; cs && cs->visited != 0; cs = cs->enclosing_cs)
    cs->visited = 0;
}

/* Return the innermost compound statement that declares an identifier in
   IDENTIFIERS, provided all of their declaring compound statements lie on
   one chain of nested scopes; otherwise return NULL.

   Each identifier's declaring compound statement is marked visited = 2,
   and mark_cs_upward marks the statements above it 1, until it hits
   either 1.) a statement already marked 1, which means the scopes are
   not nested (failure), or 2.) a statement marked 2, i.e. another
   identifier's scope.  Only the first chain may run all the way to the
   implicit root (NULL).  An identifier whose scope is already marked
   encloses (or equals) a scope seen earlier and is skipped.

   NULL is also returned when an identifier has no recorded compound
   statement (e.g. it is declared at file scope, or was never visited).
   All visited marks are cleared before returning, on failure as well as
   success, so that later calls start from a clean state. */
static compound_stmt find_largest_scope(identifier_set identifiers)
{
  identifier_set_scanner ss;
  identifier id;
  compound_stmt last_cs = NULL;
  bool first = TRUE;
  bool failed = FALSE;

  scan_identifier_set(id, ss, identifiers)
    {
      compound_stmt cs;

      /* After a failure just finish the scan; marks are cleared below. */
      if (failed)
        continue;

      cs = root_ddecl(id->ddecl)->cs;

      if (!cs)
        {
          /* Undeclared identifier (or that we somehow missed it) */
          failed = TRUE;
          continue;
        }

      if (cs->visited)
        continue;

      cs->visited = 2;
      if (!mark_cs_upward(cs->enclosing_cs, first))
        {
          failed = TRUE;
          continue;
        }
      last_cs = cs;
      first = FALSE;
    }

  /* Every mark was set on a chain starting at some identifier's scope,
     so walking up from each of them clears all of them. */
  scan_identifier_set(id, ss, identifiers)
    {
      clean_cs_upward(root_ddecl(id->ddecl)->cs);
    }

  return failed ? NULL : last_cs;
}

/* Insert every identifier occurring in E into *IDENTIFIERS.  Return TRUE
   if E is built only from casts, conditionals, identifiers, field and
   array references, lexical constants, and unary/binary operators over
   such operands; any other kind of expression yields FALSE, in which case
   the set may be partially filled.
   NOTE(review): unary/binary here includes side-effecting operators such
   as ++/-- and assignments; they are not rejected. */
static bool get_identifiers_expression(expression e, identifier_set * identifiers)
{
  switch (e->kind) {
  case kind_cast:
    return get_identifiers_expression(CAST(cast, e)->arg1, identifiers);

  case kind_conditional:
    {
      conditional c = CAST(conditional, e);

      if (!get_identifiers_expression(c->condition, identifiers))
        return FALSE;
      /* arg1 may be NULL (presumably the GNU "c ?: b" form). */
      if (c->arg1 && !get_identifiers_expression(c->arg1, identifiers))
        return FALSE;
      return get_identifiers_expression(c->arg2, identifiers);
    }

  case kind_identifier:
    identifier_set_insert(parse_region, identifiers, CAST(identifier, e));
    return TRUE;

  case kind_field_ref:
    return get_identifiers_expression(CAST(field_ref, e)->arg1, identifiers);

  case kind_array_ref:
    {
      array_ref ar = CAST(array_ref, e);

      return (get_identifiers_expression(ar->arg1, identifiers) &&
              get_identifiers_expression(ar->arg2, identifiers));
    }

  case kind_lexical_cst:
    return TRUE;

  default:
    if (is_unary(e))
      return get_identifiers_unary_expression(CAST(unary, e), identifiers);
    if (is_binary(e))
      return get_identifiers_binary_expression(CAST(binary, e), identifiers);
    return FALSE;
  }
}

/* Collect the identifiers of unary expression E's operand into
   *IDENTIFIERS; the result is the operand's result. */
static bool get_identifiers_unary_expression(unary e, identifier_set * identifiers)
{
  expression operand = e->arg1;

  return get_identifiers_expression(operand, identifiers);
}

/* Collect the identifiers of both operands of binary expression E into
   *IDENTIFIERS; return TRUE only if both operands are acceptable.  Stops
   after the left operand if it is rejected, like the array_ref case of
   get_identifiers_expression: a FALSE result propagates up and
   confine_this_expression then ignores the partially filled set. */
static bool get_identifiers_binary_expression(binary e, identifier_set * identifiers)
{
  return (get_identifiers_expression(e->arg1, identifiers) &&
          get_identifiers_expression(e->arg2, identifiers));
}
