/*
 * nntest.cc
 * 
 * Neural network test program. 
 * 
 * Note: This is a quick test program. IF YOU INTEND TO SEE GOOD 
 *       PROGRAMMING, DO NOT LOOK AT THIS FILE. 
 * 
 * Copyright (c) 2004 by Wolfgang Wieser (wwieser@gmx.de) 
 * 
 * This file may be distributed and/or modified under the terms of the 
 * GNU General Public License version 2 as published by the Free Software 
 * Foundation. 
 * 
 * This file is provided AS IS with NO WARRANTY OF ANY KIND, INCLUDING THE
 * WARRANTY OF DESIGN, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE.
 * 
 */

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <assert.h>

#include <vector>

#include <Qt/qapplication.h>
#include <Qt/qpainter.h>
#include <Qt/qbrush.h>
#include <Qt/qlayout.h>
#include <Qt/qlabel.h>
#include <Qt/qcdestyle.h>
#include <Qt/qevent.h>

#include <QTX/xpainter.h>
#include <QTX/rgbcolorsource.h>

//gcc -W -Wall -O2 nntest.cc -o nntest -fno-rtti -fno-exceptions -I. -IQt -fmessage-length=$COLUMNS -lqt-mt -L/opt/Qt/lib libqtxlib.a


using namespace QTX;

//------------------------------------------------------------------------------

// SQR(x): square of x. Overloaded for double, float and int so that the
// result keeps the type of the argument (a float stays a float).
inline double SQR(double x)
{
	return x*x;
}

inline float SQR(float x)
{
	return x*x;
}

inline int SQR(int x)
{
	return x*x;
}

// Pseudo-random number from rand(), scaled to the closed interval [0,1].
inline double rand01()
{
	const double r=rand();
	return r/RAND_MAX;
}

// Pseudo-random number from rand(), scaled to the closed interval [-1,1].
inline double rand11()
{
	return 2.0*rand01()-1.0;
}


// Kohonen-style self-organizing map: a 2D input layer (iw x ih) fully
// connected to a 2D output layer (ow x oh). CalcOutput() finds the output
// cell whose weight vector is closest (Euclidean) to the current input;
// Feed() additionally pulls the weights towards the input.
class NNetwork
{
	public:
		typedef float weight_t;
		struct ILT  // input layer cell type
		{
			float x;  // value
		};
		struct OLT  // output layer cell type
		{
			float x;  // value
		};
		
	public:
		// Input layer size: 
		int iw,ih;
		
		// Kohonen/output layer size: 
		int ow,oh;
		
		// Time counter: number of Feed() calls since the last FillRandom(). 
		int cyc;
		
	private:
		// Weights from any input node to any output node: 
		weight_t *_w;  // [ih*iw*oh*ow], malloc()ed
		
		// Access function: 
		// iy*iw*oh*ow + ix*oh*ow + oy*ow + ox
		weight_t &w(int ix,int iy,int ox,int oy)
			{  return(_w[((iy*iw+ix)*oh+oy)*ow+ox]);  }
		
		// Input, output layer current state: 
		ILT *_il;  // [iw*ih], malloc()ed
		OLT *_ol;  // [ow*oh], malloc()ed
		
		// The object owns the three malloc()ed arrays above and free()s them
		// in its destructor, so a member-wise copy would free them twice.
		// Copying is therefore disabled: declared here, intentionally never
		// defined (pre-C++11 idiom, matching the rest of this file).
		NNetwork(const NNetwork &);
		NNetwork &operator=(const NNetwork &);
		
	public:
		// Input/output layer access functions (no bounds checking): 
		ILT *il(int ix,int iy)  {  return(&_il[iy*iw+ix]);  }
		OLT *ol(int ox,int oy)  {  return(&_ol[oy*ow+ox]);  }
		
		// Result of the most recent CalcOutput(): 
		int winx,winy;  // coos of winner (-1,-1 before the first run)
		OLT *winner;    // output cell closest to the input (NULL before first run)
		OLT *loser;     // output cell farthest from the input (NULL before first run)
		
	public:
		NNetwork();
		~NNetwork();
		
		// Set weights randomly (uniform in [-1,1]) and reset cyc: 
		void FillRandom();
		
		// Clear (zero) input/output layer: 
		void ZapIL();
		void ZapOL();
		
		// Feed net with input (one training step): 
		void Feed(const char *str);
		
		// Load input layer with passed string (used by Feed()): 
		void LoadInput(const char *str);
		
		// Calculate output from input (used by Feed()): 
		void CalcOutput(int print_winner);
};


// Load a pattern string into the input layer.
// str holds ih rows of iw characters each, concatenated row by row:
//   ' ' -> -1.0,  '#' -> +1.0,  '+' -> 0.0.
// A string that is too short, or an unknown character, trips an assertion.
// With assertions disabled (NDEBUG), such cells are set to the neutral
// value 0.0. The read pointer never advances past the terminating '\0'.
// Previously, a short string made the loop read beyond the end of the
// buffer in release builds.
void NNetwork::LoadInput(const char *str)
{
	// Load string into input layer: 
	const char *s=str;
	for(int iy=0; iy<ih; iy++) for(int ix=0; ix<iw; ix++)
	{
		float v=0.0f;
		switch(*s)
		{
			case ' ':  v=-1.0f;  break;
			case '#':  v=+1.0f;  break;
			case '+':  v= 0.0f;  break;
			case '\0':  assert(!"pattern string too short");  break;
			default:    assert(!"invalid pattern character");  break;
		}
		il(ix,iy)->x=v;
		
		// Stay on the terminator once reached: never read past the string. 
		if(*s)  ++s;
	}
}

void NNetwork::CalcOutput(int print_winner)
{
	// For every output cell, calculate activity: 
	// Find most active cell ("winner"): 
	int wx=-1,wy=-1;
	int loserx=-1,losery=-1;
	double wact=+1e100;
	double loseract=-1e100;
	for(int oy=0; oy<oh; oy++) for(int ox=0; ox<ow; ox++)
	{
		OLT *oc=ol(ox,oy);
		double act=0.0;
		
		#if 0
		// Scalar product w * x: 
		for(int iy=0; iy<ih; iy++) for(int ix=0; ix<iw; ix++)
			act+=w(ix,iy,ox,oy)*il(ix,iy)->x;
		#else
		// Euclidic distance |w-x|
		for(int iy=0; iy<ih; iy++) for(int ix=0; ix<iw; ix++)
			act+=SQR(w(ix,iy,ox,oy)-il(ix,iy)->x);
		act=sqrt(act);
		#endif
		
		oc->x=act;
		if(wact>act)  wx=ox,wy=oy,wact=act;
		if(loseract<act)  loserx=ox,losery=oy,loseract=act;
	}
	assert(wx>=0);
	
	winx=wx;  winy=wy;
	winner=ol(wx,wy);
	loser=ol(loserx,losery);
	
	if(print_winner)
		printf("Winner=%d,%d, act=%g; Loser=%d,%d, act=%g\n",
			winx,winy,wact,loserx,losery,loseract);
}

void NNetwork::Feed(const char *str)
{
	LoadInput(str);
	
	CalcOutput(1);
	
	// Update weights: 
	for(int oy=0; oy<oh; oy++) for(int ox=0; ox<ow; ox++)
	{
		OLT *oc=ol(ox,oy);
		
		const weight_t eta0=0.3;
		weight_t eta=eta0;
		if(oc!=winner)
		{
			// Calc distance from winner on output plane: 
			float wdist2=SQR(ox-winx)+SQR(oy-winy);
			
			#if 0
			const weight_t a2=1.0/( 2.0 );
			const weight_t sigma2=1.0/( 3.0*a2 );   // 3.0*a2
			//const weight_t a2=1.0/( 4.0 );
			//const weight_t sigma2=1.0/( 3.5*3.5 - a2 );
			eta *= (1.0f-wdist2*a2)*expf(-wdist2*sigma2);
			#else
			const weight_t sigma2=1.0/( 3.0 );
			eta *= expf(-wdist2*sigma2);
			#endif
		}
		
		for(int iy=0; iy<ih; iy++) for(int ix=0; ix<iw; ix++)
			w(ix,iy,ox,oy)+=eta*(il(ix,iy)->x-w(ix,iy,ox,oy));
		//for(int iy=0; iy<ih; iy++) for(int ix=0; ix<iw; ix++)
		//	w(ix,iy,ox,oy)+=eta*();
		
		continue;
	}
	
	++cyc;
}


void NNetwork::FillRandom()
{
	for(int iy=0; iy<ih; iy++) for(int ix=0; ix<iw; ix++)
		for(int oy=0; oy<oh; oy++) for(int ox=0; ox<ow; ox++)
			w(ix,iy,ox,oy)=rand11();
	cyc=0;
}

void NNetwork::ZapIL()
{
	for(int iy=0; iy<ih; iy++) for(int ix=0; ix<iw; ix++)
		il(ix,iy)->x=0.0;
}

void NNetwork::ZapOL()
{
	for(int oy=0; oy<oh; oy++) for(int ox=0; ox<ow; ox++)
		ol(ox,oy)->x=0.0;
}


// Create a 10x10 -> 10x10 network with random weights and cleared layers.
NNetwork::NNetwork()
{
	// Layer sizes (fixed for this test program): 
	iw=10;
	ih=10;
	ow=10;
	oh=10;
	
	// One weight per (input cell, output cell) pair. The size is computed
	// in size_t so that larger layers cannot overflow int: 
	size_t size=size_t(ih)*size_t(iw)*size_t(oh)*size_t(ow)*sizeof(weight_t);
	// size_t is not necessarily unsigned int (e.g. on LP64), so passing it
	// to "%u" is undefined behaviour. Print it via unsigned long instead: 
	printf("Weight array: %lu kb\n",static_cast<unsigned long>(size>>10));
	_w=static_cast<weight_t*>(malloc(size));
	assert(_w);
	
	_il=static_cast<ILT*>(malloc(size_t(ih)*size_t(iw)*sizeof(ILT)));
	_ol=static_cast<OLT*>(malloc(size_t(oh)*size_t(ow)*sizeof(OLT)));
	assert(_il && _ol);
	
	// No winner until the first CalcOutput(): 
	winx=-1;
	winy=-1;
	winner=NULL;
	loser=NULL;
	
	FillRandom();  // also resets cyc
	
	ZapIL();
	ZapOL();
}

// Release the weight and layer arrays allocated in the constructor.
NNetwork::~NNetwork()
{
	// free(NULL) is a no-op, so no NULL checks are needed. 
	free(_ol);
	free(_il);
	free(_w);
}

//------------------------------------------------------------------------------

// Main window: displays the input layer (left) and the output layer
// (right) of the network and lets the user train it via the keyboard.
class MainWindow : public QWidget
{
	private:
		// Painters with fixed black/blue/red brushes, plus one painter
		// ("any") whose foreground is recoloured per cell before drawing. 
		XPainter *black;
		XPainter *blue;
		XPainter *red;
		XPainter *any;
		// Colour source; color(r,g,b) is called with components in [0,1]. 
		RGBColorSource *rgb;
		
		// The network being trained and displayed. 
		NNetwork *nn;
		
		// Draw/display parameters for one layer: 
		struct LayerGeometry
		{
			int x0,y0;  // offset in window
			int w,h;    // pixel size of a cell
		};
		LayerGeometry dpi;  // input layer
		LayerGeometry dpo;  // output layer
		
		// What was drawn last, used by paintEvent(): 'f' = layer view
		// (_Redraw), 'p' = pattern map (DrawPatterns), '\0' = nothing yet. 
		char drawn;
		
		void _Redraw(int all_white=0);
		void _DrawWinnerPattern(const char *str,int cnt);
		void _DrawStat();
		
		// Qt event handlers: 
		void timerEvent(QTimerEvent *);
		void keyPressEvent(QKeyEvent *);
		void mousePressEvent(QMouseEvent *);
		void mouseMoveEvent(QMouseEvent *);
		void paintEvent(QPaintEvent *);
		
	public:
		MainWindow(QWidget *parent=NULL,const char *name=NULL);
		~MainWindow();
		
		// Train with one randomly chosen pattern; redraw if do_draw. 
		void RandomFeed(int do_draw=1);
		// Show which output cell each training pattern maps to. 
		void DrawPatterns();
};

void MainWindow::_Redraw(int all_white)
{
	if(all_white)
	{  any->SetForeground(rgb->color(1,1,1));  }
	
	// Draw input layer: 
	NNetwork::weight_t wmax=1.0,wmin=-1.0;
	for(int iy=0; iy<nn->ih; iy++) for(int ix=0; ix<nn->iw; ix++)
	{
		XPainter *pnt;
		if(all_white)
		{  pnt=any;  }
		else
		{
			NNetwork::weight_t a=nn->il(ix,iy)->x;
			if(a<wmin || a>wmax)
			{  pnt=blue;  }
			else
			{
				// 1.0f-XXX -> active=1.0=black
				float col=1.0f-(a-wmin)/(wmax-wmin);
				any->SetForeground(rgb->color(col,col,col));
				pnt=any;
			}
		}
		
		pnt->FillRectangle(
			ix*dpi.w+dpi.x0+1,iy*dpi.h+dpi.y0+1,
			dpi.w-1,dpi.h-1);
		blue->DrawRectangle(
			ix*dpi.w+dpi.x0,iy*dpi.h+dpi.y0,
			dpi.w,dpi.h);
	}
	
	// Draw output layer: 
	wmax=nn->loser ? nn->loser->x : 1.0;
	wmin=nn->winner ? nn->winner->x : -1.0;
	if(wmax<wmin)
	{  NNetwork::weight_t tmp=wmax; wmax=wmin; wmin=tmp;  }
	for(int oy=0; oy<nn->oh; oy++) for(int ox=0; ox<nn->ow; ox++)
	{
		XPainter *pnt;
		if(all_white)
		{  pnt=any;  }
		else
		{
			NNetwork::weight_t a=nn->ol(ox,oy)->x;
			if(nn->ol(ox,oy)==nn->winner)
			{  pnt=red;  }
			else if(a<wmin || a>wmax)
			{  pnt=blue;  }
			else
			{
				// 1.0f-XXX -> active=wmin=white
				//      XXX -> active=wnin=black
				float col=1.0-(a-wmin)/(wmax-wmin);
				any->SetForeground(rgb->color(col,col,col));
				pnt=any;
			}
		}
		
		pnt->FillRectangle(
			ox*dpo.w+dpo.x0+1,oy*dpo.h+dpo.y0+1,
			dpo.w-1,dpo.h-1);
		blue->DrawRectangle(
			ox*dpo.w+dpo.x0,oy*dpo.h+dpo.y0,
			dpo.w,dpo.h);
	}
	
	drawn='f';
	
	_DrawStat();
}

// Draw a miniature of pattern str (one pixel per input cell, positive
// cells only) inside the current winner's output cell. cnt selects the
// quadrant, so that several patterns sharing a winner do not fully overlap.
// Side effect: loads str into the network's input layer.
void MainWindow::_DrawWinnerPattern(const char *str,int cnt)
{
	nn->LoadInput(str);
	
	// Quadrant cnt%4: top-left, top-right, bottom-left, bottom-right. 
	static const int qdx[4]={ 0, 9, 0, 9 };
	static const int qdy[4]={ 0, 0, 9, 9 };
	int q=cnt%4;
	if(q<0)  q=0;  // negative counts get no offset
	
	const int left=nn->winx*dpo.w+dpo.x0+1+qdx[q];
	const int top =nn->winy*dpo.h+dpo.y0+1+qdy[q];
	
	for(int iy=0; iy<nn->ih; iy++)
	{
		for(int ix=0; ix<nn->iw; ix++)
		{
			// Only cells with a positive value get a pixel. 
			if(nn->il(ix,iy)->x<=0.0)  continue;
			blue->DrawPoint(left+ix,top+iy);
		}
	}
}


// Clear the status area below the input layer and print the network's
// time counter there.
void MainWindow::_DrawStat()
{
	char line[32];
	snprintf(line,sizeof(line),"t=%d",nn->cyc);
	
	erase(0,130,98,30);
	black->DrawString(4,150,line);
}


// Mouse clicks are currently ignored. The cell under the pointer is
// reported by mouseMoveEvent() instead.
void MainWindow::mousePressEvent(QMouseEvent * /*ev*/)
{
}

// Show the value of the layer cell under the mouse pointer in the status
// line at the bottom of the window. Mouse tracking is switched on in the
// constructor, so this also runs without a button pressed.
void MainWindow::mouseMoveEvent(QMouseEvent *ev)
{
	int mx=ev->x();
	int my=ev->y();
	
	char tmp[64];
	*tmp='\0';
	
	// Pixel offsets are checked for >=0 BEFORE dividing. Integer division
	// truncates toward zero, so an offset in (-w,0) would otherwise yield
	// cell 0. For example, the gap left of the output layer used to report
	// O[0,y].
	
	// See if we're in the input layer: 
	int dx=mx-dpi.x0;
	int dy=my-dpi.y0;
	if(dx>=0 && dy>=0)
	{
		int cx=dx/dpi.w;
		int cy=dy/dpi.h;
		if(cx<nn->iw && cy<nn->ih)
		{  snprintf(tmp,sizeof(tmp),"I[%d,%d]=%f",cx,cy,nn->il(cx,cy)->x);  }
	}
	
	// Or if we're in the output layer: 
	dx=mx-dpo.x0;
	dy=my-dpo.y0;
	if(dx>=0 && dy>=0)
	{
		int cx=dx/dpo.w;
		int cy=dy/dpo.h;
		if(cx<nn->ow && cy<nn->oh)
		{  snprintf(tmp,sizeof(tmp),"O[%d,%d]=%f",cx,cy,nn->ol(cx,cy)->x);  }
	}
	
	// Clear the status line and print the result (if any): 
	erase(0,460,200,30);
	if(*tmp)
	{  black->DrawString(4,480,tmp);  }
}

void MainWindow::keyPressEvent(QKeyEvent *ev)
{
	int iter=1;
	bool is_s=ev->state() & ShiftButton;
	bool is_c=ev->state() & ControlButton;
	if(is_s && is_c)  iter=1000;
	else if(is_c)  iter=100;
	else if(is_s)  iter=10;
	
	switch(ev->key())
	{
		case Key_Return:
		case Key_Enter:  break;
		case Key_Q:  close();  break;
		case Key_F:
			for(int i=0; i<iter; i++)
				RandomFeed();
			break;
		case Key_D:
			DrawPatterns();
			break;
		case Key_R:
			nn->ZapIL();
			nn->ZapOL();
			nn->FillRandom();
			_Redraw();
			break;
		case Key_V:
			for(int i=0; i<iter; i++)
				RandomFeed(0);
			DrawPatterns();
			break;
		case Key_C:
			nn->ZapIL();
			nn->ZapOL();
			nn->FillRandom();
			DrawPatterns();
			break;
	}
	
	ev->accept();
}

// Repaint whichever view was shown last: the pattern map if DrawPatterns()
// drew last, otherwise the layer view.
void MainWindow::paintEvent(QPaintEvent *)
{
	if(drawn!='p')
	{
		_Redraw();
		return;
	}
	DrawPatterns();
}

// Timer hook; intentionally does nothing. (No timer is started in the
// code visible here.)
void MainWindow::timerEvent(QTimerEvent *)
{
}


// Training patterns. Each entry is one image of ih rows with iw characters
// each (10x10 here), written as adjacent string literals that the compiler
// concatenates row by row into a single string. LoadInput() maps ' ' to -1,
// '#' to +1 and '+' to 0. Any other character, or a string shorter than
// iw*ih, trips an assertion there.
// Several entries differ in only a few cells: 1/10, 4/5, 7/8, 11/12 and
// 14/15. Presumably this tests how well the map separates similar inputs.

// Number of entries in pattern[]: 
const int npatterns=16;
const char *pattern[npatterns]=
{
//------------------0
	"    #     "
	"    #     "
	"    #     "
	"    #     "
	" ######## "
	"    #     "
	"    #     "
	"    #     "
	"    #     "
	"    #     ",
//------------------1
	"          "
	"   ###    "
	" ##   ##  "
	"#       # "
	"#       # "
	"#       # "
	"#       # "
	" ##   ##  "
	"   ###    "
	"          ",
//------------------2
	"          "
	" #       #"
	" ##     ##"
	" # #   # #"
	" #  # #  #"
	" #   #   #"
	" #       #"
	" #       #"
	" #       #"
	" #       #",
//------------------3
	"          "
	"  ####### "
	"     #    "
	"     #    "
	"     #    "
	"     #    "
	"     #    "
	"     #    "
	"     #    "
	"  ####### ",
//------------------4
	"          "
	"  ####    "
	" #    #   "
	" #        "
	"  ####    "
	"      #   "
	" #    #   "
	"  ####    "
	"          "
	"          ",
//------------------5
	"          "
	"  ####    "
	" #    #   "
	" #    #   "
	"  ####    "
	" #    #   "
	" #    #   "
	"  ####    "
	"          "
	"          ",
//------------------6
	"          "
	" #        "
	" #        "
	" #        "
	" #        "
	" #        "
	" #        "
	" #        "
	" #######  "
	"          ",
//------------------7
	"          "
	" #######  "
	"  #     # "
	"  #     # "
	"  #     # "
	"  ######  "
	"  #       "
	"  #       "
	"  #       "
	" ###      ",
//------------------8
	"          "
	" #######  "
	"  #     # "
	"  #     # "
	"  #     # "
	"  ######  "
	"  #   #   "
	"  #    #  "
	"  #     # "
	" ##    ###",
//------------------9
	"          "
	" #      # "
	"  #    #  "
	"   #  #   "
	"    ##    "
	"    ##    "
	"   #  #   "
	"  #    #  "
	" #      # "
	"          ",
//------------------10
	"          "
	"   ###    "
	" ##   ##  "
	"#       # "
	"#   #   # "
	"#   #   # "
	"#       # "
	" ##   ##  "
	"   ###    "
	"          ",
//------------------11
	"          "
	" ######## "
	"       #  "
	"      #   "
	"     #    "
	"    #     "
	"   #      "
	"  #       "
	" ######## "
	"          ",
//------------------12
	"          "
	" ######## "
	"       #  "
	"      #   "
	"   ####   "
	"    #     "
	"   #      "
	"  #       "
	" ######## "
	"          ",
//------------------13
	" ######## "
	"#        #"
	"          "
	"  #   #   "
	"          "
	"    #     "
	"    #     "
	" #     #  "
	"  #####   "
	"          ",
//------------------14
	"##########"
	"##########"
	"### ## ###"
	"####  ####"
	"####  ####"
	"### ## ###"
	"##########"
	"##########"
	"##########"
	"##########",
//------------------15
	"##########"
	"##########"
	"### ## ###"
	"####  ####"
	"####  ####"
	"### ## ###"
	"##########"
	"##########"
	"####  ####"
	"##########",
};

// Perform one training step with a randomly chosen pattern. If do_draw is
// non-zero, both layers are redrawn; otherwise only the time counter in
// the status area is refreshed.
void MainWindow::RandomFeed(int do_draw)
{
	const char *sample=pattern[rand()%npatterns];
	nn->Feed(sample);
	
	if(!do_draw)
	{
		_DrawStat();
		return;
	}
	_Redraw();
}

// Show the pattern map. The layer grid is blanked first. Then, for every
// training pattern, the output cell it maps to (its winner) is determined,
// and a miniature of the pattern is drawn into that cell. Up to four
// patterns per cell land in different quadrants.
// Side effect: overwrites the network's input/output layer state and its
// winner/loser fields. Weights are not changed.
void MainWindow::DrawPatterns()
{
	_Redraw(1);
	
	// Number of patterns already drawn into each output cell, indexed
	// [oy*ow+ox]. This was formerly a variable-length array, which is a
	// non-standard compiler extension in C++. 
	std::vector<int> cnt(size_t(nn->ow)*size_t(nn->oh),0);
	
	for(int i=0; i<npatterns; i++)
	{
		nn->LoadInput(pattern[i]);
		nn->CalcOutput(0);
		_DrawWinnerPattern(pattern[i],cnt[nn->winy*nn->ow+nn->winx]++);
	}
	
	drawn='p';
}


// Create the drawing helpers and the network, set up the layer geometry,
// and show the 500x500 window.
MainWindow::MainWindow(QWidget *parent,const char *name) : 
	QWidget(parent,name)
{
	// Each XPainter is created from pnt while a particular brush is set.
	// NOTE(review): presumably XPainter copies the painter state at
	// construction, so the setBrush()/new order below matters -- confirm in
	// QTX. "any" starts with the black brush; _Redraw() recolours it per
	// cell with SetForeground().
	QPainter pnt(this);
	pnt.setBrush(QBrush(Qt::black,Qt::SolidPattern));
	black=new XPainter(&pnt,true);
	any=new XPainter(&pnt,true);
	pnt.setBrush(QBrush(Qt::blue,Qt::SolidPattern));
	blue=new XPainter(&pnt,true);
	pnt.setBrush(QBrush(Qt::red,Qt::SolidPattern));
	red=new XPainter(&pnt,true);
	assert(black && blue && red && any);
	
	rgb=new RGBColorSource(black);
	assert(rgb);
	
	nn=new NNetwork();
	assert(nn);
	
	// Input layer: 8x8 pixel cells at the window's top-left corner. 
	dpi.x0=0;
	dpi.y0=0;
	dpi.w=8;
	dpi.h=8;
	
	// Output layer: 20x20 pixel cells starting at x=99. 
	dpo.x0=99;
	dpo.y0=0;
	dpo.w=20;
	dpo.h=20;
	
	// Nothing drawn yet (see paintEvent()). 
	drawn='\0';
	
	resize(500,500);
	move(100,100);
	
	setBackgroundColor(Qt::white);
	setCaption("NN test");
	
	// Deliver mouse move events without a button pressed (cell read-out). 
	setMouseTracking(1);
	show();
}


// Release everything allocated in the constructor.
MainWindow::~MainWindow()
{
	delete nn;
	
	// rgb was constructed from the black painter, so it is deleted before
	// the painters (NOTE(review): in case it still refers to black --
	// confirm in QTX). 
	delete rgb;
	
	delete black;
	delete blue;
	delete red;
	delete any;
}


// Program name, exported as a global char*. NOTE(review): presumably read
// elsewhere (e.g. by the QTX library); confirm. Binding a string literal
// to a non-const char* is ill-formed since C++11 (deprecated before), so
// the pointer refers to a writable array instead. The type stays char*.
static char prg_name_storage[]="nntest";
char *prg_name=prg_name_storage;

// Program entry: set up the Qt application with the CDE style, create the
// main window, and run the event loop until the window is closed.
int main(int argc,char **arg)
{
	QApplication qapp(argc,arg);
	qapp.setStyle(new QCDEStyle(TRUE));
	//qapp.setStyle(new QSGIStyle(TRUE));
	
	// The window shows itself from its constructor. 
	MainWindow mainwin(NULL);
	qapp.setMainWidget(&mainwin);
	
	return qapp.exec();
}

