/** \file
  Trains a neural network.

  It uses the
  Fast Artificial Neural Network library (FANN, http://fann.sourceforge.net).

  Loads an existing network (or creates a new one), trains it on a new batch of
  data, outputs the learning error and saves the network.
  */

#include <assert.h>
#include <errno.h>
#include <float.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <getopt.h>
#include <sys/time.h>
#include "local.h"

/* Prototypes
 * Zero-argument functions are declared with (void): an empty () in C before
 * C23 declares a function without a prototype and disables argument checking. */
static int help(char *argv[]);
static int calcoptions(int argc, char * argv[]);
static struct fann * create_network(void);
static void save_network(struct fann * ann);
static void destroy_network(struct fann * ann);
static void read_datasets(void);
int main(int argc, char * argv[]);

/** reports an invalid command-line value on stderr, prints the help page and terminates with ERRARG */
static void fail_option(const char * msg, char * argv[]) {
	fprintf(stderr, "%s\n", msg);
	(void) help(argv);
	exit(ERRARG);
}

/** converts an option value to a float that must be a finite number greater than zero.
    Malformed input, trailing garbage and out-of-range values are fatal (atof would
    silently turn them into 0 or a partial number). */
static float parse_positive_float(const char * arg, const char * msg, char * argv[]) {
	char * end = NULL;
	double val;
	errno = 0;
	val = strtod(arg, &end);
	/* !(val > 0.0) also rejects NaN; checking FLT_MAX before the cast keeps the
	   double->float conversion defined; the last test rejects values that round to 0 */
	if (end == arg || '\0' != *end || ERANGE == errno
			|| !(val > 0.0) || val > FLT_MAX || !((float) val > 0.0f)) {
		fail_option(msg, argv);
	}
	return (float) val;
}

/** converts an option value to an unsigned int in 1..UINT_MAX.
    Signs, trailing garbage and overflow are fatal (atoi would turn "-5" into a huge
    unsigned value that slips past a "< 1" check). */
static unsigned int parse_positive_uint(const char * arg, const char * msg, char * argv[]) {
	char * end = NULL;
	unsigned long val;
	/* strtoul accepts and negates a leading '-', so refuse any minus sign up front */
	if (NULL != strchr(arg, '-')) {
		fail_option(msg, argv);
	}
	errno = 0;
	val = strtoul(arg, &end, 10);
	if (end == arg || '\0' != *end || ERANGE == errno || 0UL == val || val > UINT_MAX) {
		fail_option(msg, argv);
	}
	return (unsigned int) val;
}

/** returns "<prefix><ext>" in freshly malloc'ed memory; terminates with ERRMEM if allocation fails */
static char * make_filename(const char * prefix, const char * ext) {
	size_t len = strlen(prefix) + strlen(ext) + 1;
	char * name = malloc(len);
	if (NULL == name) {
		fprintf(stderr, "could not alloc mem for buf, %s\n", strerror(errno));
		exit(ERRMEM);
	}
	(void) snprintf(name, len, "%s%s", prefix, ext);
	return name;
}

/** calculates and checks the given options and sets the struct net.
    Fills the global net structure with defaults, overrides them from the command
    line, and stores the two mandatory positional arguments (train and test file).
    Any invalid option or argument count prints the help page and terminates
    with ERRARG; -h prints the help page and exits with 0.
    @return always 0 */
static int calcoptions(int argc, char* argv[]) {
	int opt_case;
	optind=1;
	/* defaults for network parameters */
	net.connection_rate=1.0f; /* full connected MLP network */
	net.learning_rate=0.07f;
	net.num_layers=3; /* one input, one hidden, one output layer */
	net.num_input=30; /* count of input neurons */
	net.num_neurons_hidden=30; /* how many hidden neurons */
	net.num_output=1;
	net.desired_error=0.0001f; /* training is aborted if the test error is less than... */
	net.max_steps=10000; /* how many training steps */
	net.steps_between_reports=1; /* how often reports will be generated */
	net.net_infile=NULL;
	net.outfilenames=NULL;
	net.net_outfile=NULL;
	net.net_mqlefile=NULL;
	net.net_mqgefile=NULL;
	net.net_cfgfile=NULL;
	net.verbose=0;
	net.infilename=NULL;
	net.trainfilename=NULL;
	net.testfilename=NULL;
	/* the leading ':' makes getopt return ':' for a missing option argument
	   (instead of '?') and suppresses getopt's own messages, so both error
	   cases are reported below */
	while (-1 != (opt_case = getopt(argc, argv, ":c:e:hi:l:m:o:r:v"))) {
		switch (opt_case) {
			case 'c':
				/* accepted for compatibility, but nothing uses it */
				fprintf(stderr, "option -c is currently ignored\n");
				break;
			case 'e':
				net.desired_error = parse_positive_float(optarg, "learning error should be greater than zero", argv);
				break;
			case 'i':
				net.net_infile = optarg;
				break;
			case 'l':
				net.learning_rate = parse_positive_float(optarg, "learning rate should be greater than zero", argv);
				break;
			case 'm':
				net.max_steps = parse_positive_uint(optarg, "max_steps should be greater than zero", argv);
				break;
			case 'o':
				net.outfilenames = optarg;
				/* -o may be repeated: release names built for an earlier prefix */
				free(net.net_outfile);
				free(net.net_mqlefile);
				free(net.net_mqgefile);
				free(net.net_cfgfile);
				net.net_outfile  = make_filename(net.outfilenames, ".net");
				net.net_mqlefile = make_filename(net.outfilenames, ".mqle");
				net.net_mqgefile = make_filename(net.outfilenames, ".mqge");
				net.net_cfgfile  = make_filename(net.outfilenames, ".cfg");
				break;
			case 'r':
				net.steps_between_reports = parse_positive_uint(optarg, "steps_between_reports should be greater than zero", argv);
				break;
			case 'v':
				net.verbose = 1;
				break;
			case 'h':
				(void) help(argv);
				exit(0);
			case ':':
				fprintf(stderr, "there was a required parameter missed for option -%c \n", optopt);
				(void) help(argv);
				exit(ERRARG);
			case '?':
			default:
				fprintf(stderr, "there was an unknown parameter: -%c \n", optopt);
				(void) help(argv);
				exit(ERRARG);
		}
	}
	/* periodic backups of the net follow the report interval; assigned only
	   after parsing so that -r is taken into account */
	net.net_writeperiod = net.steps_between_reports;
	if ((argc - optind) != 2) {
		fprintf(stderr, "wrong count of arguments, <trainfile>/<testfile> missed\n");
		(void) help(argv);
		exit(ERRARG);
	}
	net.trainfilename = argv[optind];
	net.testfilename  = argv[optind+1];
	return (0);
}

/** prints the help-page if there were errors on the command-line or if explicitely ordered */
static int help(char * argv[]) {
	fprintf(stderr,"train net <train data> <test data> -o <outputprefix>\n");
	fprintf(stderr,"  reads fann DAT-files 'train' and 'test', and trains a\n");
	fprintf(stderr,"  neural network. The network will be saved,\n");
	fprintf(stderr,"  if max_steps will be reached or actual\n");
	fprintf(stderr,"  error below desired_error.\n");
	fprintf(stderr,"call\n %s [options] <train data> <test data>\n", argv[0]);
	fprintf(stderr,"or\n %s -h\n",argv[0]);
	fprintf(stderr,"options are:\n");
	fprintf(stderr, "\t-a\tadaptr learnrate, means per epoch the learningrate decrease (default:%u)\n", net.adapt_learnrate);
	fprintf(stderr, "\t-e VAL\tlearning error over an epoch (float, default: %03f)\n", net.desired_error);
	fprintf(stderr, "\t-h    \thelp\n");
	fprintf(stderr, "\t-i VAL\tnet file to load (string, default: %s)\n", net.net_infile);
	fprintf(stderr, "\t-l VAL\tlearning_rate (float, default:%0.10f)\n", net.learning_rate);
	fprintf(stderr, "\t-m VAL\tmax_steps (unsigned int, default: %u)\n", net.max_steps);
	fprintf(stderr, "\t-o VAL\toutfiles (outfile names without extensions) to store (string, default:%s)\n",net.outfilenames);
	fprintf(stderr, "\t-r VAL\tsteps_between_reports (unsigned int, default: %u)\n", net.steps_between_reports);
	fprintf(stderr, "\t-v\t verbose (default: %u)\n", net.verbose);
	fprintf(stderr, "\n");
	return(0);
}

/** returns the ANN to train.
    With -i (net.net_infile set) the network is read from that file; otherwise a
    new one is built from the layer sizes and rates in the global net structure,
    its weights are randomized between -0.9 and 0.9 and both hidden and output
    layers use FANN_SIGMOID_SYMMETRIC.
    Terminates the program with ERRFANN if FANN cannot provide a network. */
static struct fann * create_network() {
	struct fann * network;

	if (NULL == net.net_infile) {
		/* fresh network described by the net structure */
		network = fann_create(net.connection_rate, net.learning_rate, net.num_layers,
				net.num_input, net.num_neurons_hidden, net.num_output);
		if (NULL == network) {
			fprintf(stderr, "could not create network: %s\n", strerror(errno));
			exit(ERRFANN);
		}
		fann_randomize_weights(network, -0.9f, 0.9f);
		/* fann_set_momentum(network, 0.9) would speed up learning, but needs fann > 1.2 */
		fann_set_activation_function_output(network, FANN_SIGMOID_SYMMETRIC);
		fann_set_activation_function_hidden(network, FANN_SIGMOID_SYMMETRIC);
		return network;
	}

	/* previously saved network.
	   NOTE(review): the original carried a bare TODO here. Command-line settings
	   such as -l are only passed to fann_create above, so they do not affect a
	   loaded network — confirm that this is intended. */
	network = fann_create_from_file(net.net_infile);
	if (NULL == network) {
		fprintf(stderr, "could not create network from file %s: %s\n", net.net_infile, strerror(errno));
		exit(ERRFANN);
	}
	return network;
}

/** stores the current state of the ANN in net.net_outfile, the "<prefix>.net" file
    derived from -o. Used both for the periodic backups during training and for the
    final save; does nothing when no output prefix was given.
    ann must not be NULL. */
static void save_network(struct fann * ann) {
	assert(NULL != ann);
	if (NULL == net.net_outfile) {
		return; /* no -o given: nowhere to write to */
	}
	fann_save(ann, net.net_outfile);
}

/** hands the ANN back to fann_destroy once training is over; ann must not be NULL */
static void destroy_network(struct fann * ann) {
	assert(NULL != ann); /* a NULL network here means create_network() was bypassed */
	fann_destroy(ann);
}

/** readin data into training and test set */
static void read_datasets () {
	net.train_data = fann_read_train_from_file(net.trainfilename);
	net.test_data = fann_read_train_from_file(net.testfilename);
	fann_shuffle_train_data(net.train_data);
}

/** main logic to prepare ANN and feed it with the right data to train an ANN */
int main(int argc, char * argv[]) {
	FILE * fh_mqle;
	FILE * fh_mqge;
	struct fann * ann;
	float mqle;
	float mqge;
	unsigned int i,j;
	struct timeval lasttimestamp;
	struct timeval timestamp;
	gettimeofday(&timestamp,NULL);
	lasttimestamp.tv_sec=timestamp.tv_sec;
	lasttimestamp.tv_usec=timestamp.tv_usec;
#ifndef FLOATFANN
	/* assertion that we use float as fann_type */
#error "to avoid type conflicts, the float-variant of libfann must be used"
#endif
	(void) calcoptions(argc, argv);
	if (net.outfilenames!=NULL) {
		fh_mqle=fopen(net.net_mqlefile, "w");
		if (fh_mqle==NULL) {
			fprintf(stderr, "File %s could not be open, %s (line %i)\n", net.net_mqlefile, strerror(errno), __LINE__);
			exit(ERRSTREAM);
		}
		fprintf(fh_mqle, "# sample, fh_mqle over %u steps, time in s\n", net.steps_between_reports);
		fh_mqge=fopen(net.net_mqgefile, "w");
		if (fh_mqge==NULL) {
			fprintf(stderr, "File %s could not be open, %s (line %i)\n", net.net_mqgefile, strerror(errno), __LINE__);
			exit(ERRSTREAM);
		}
		fprintf(fh_mqge, "# sample, fh_mqge over %i steps\n", net.max_steps);
	} else {
		fh_mqge=stdout;
		fh_mqle=stderr;
	}
	read_datasets();
	ann=create_network();
	/* train network */
	fprintf(stderr, "total neurons: %u, total connections: %u\n", fann_get_total_neurons(ann), fann_get_total_connections(ann));
	fprintf(stderr, "start with training\n"); (void) fflush(stderr);
	mqle=10000.0f;
	mqge=10000.0f;
	/* for every epoch */
	for (i=0; i< net.max_steps && mqge > net.desired_error; i++) {
		int rindex=i % net.train_data->num_data;
		fann_reset_MSE(ann); /* reset mean learning error */
		fann_train(ann, net.train_data->input[rindex],net.train_data->output[rindex]);
		mqle=fann_get_MSE(ann);
		if (0 == i%net.steps_between_reports) {
			/* calc MQGE */
			fann_reset_MSE(ann); /* reset mean learning error */
			for (j=0; j <net.test_data->num_data; j++) {
				fann_test(ann, net.test_data->input[j], net.test_data->output[j]);
			}
			mqge=fann_get_MSE(ann);
			if (net.outfilenames!=NULL) {
				gettimeofday(&timestamp,NULL);
				fprintf(fh_mqle, "%u %0.5f %0.5f\n", i, mqle, (float) (timestamp.tv_sec-lasttimestamp.tv_sec)+((float) (timestamp.tv_usec-lasttimestamp.tv_usec)/1000));
				fprintf(fh_mqge, "%u %0.5f\n", i, mqge);
				lasttimestamp.tv_sec=timestamp.tv_sec;
				lasttimestamp.tv_usec=timestamp.tv_usec;
				/* to trace learning with gnuplot, per epoch */
				(void) fflush(fh_mqle);
				(void) fflush(fh_mqge);
			}
			printf ("step:%i MQLE:%0.5f MQGE:%0.5f\n", i, mqle, mqge);
		}
		if ( /* save network periodically, only if filename specified... */
				(net.net_outfile!=NULL) &&
				(0 ==  i%net.net_writeperiod)
		   ) { 			save_network(ann);
		}
		(void) fflush(stderr);
	} /* end for loop, end of training */ 
	save_network(ann);
	destroy_network(ann);
	if (net.outfilenames!=NULL) {
		fclose(fh_mqle);
		fclose(fh_mqge);
	}
	return(0);
}


