#!/usr/bin/perl

# Driver: trains the network on the four XOR patterns, then prints the
# learned weights and the network's response to each pattern.

$| = 1;    # unbuffered STDOUT so progress lines appear as they are printed
srand;

# Layer sizes are carried as the last index of these package arrays; the
# subs below loop over 0..$#input, 0..$#hidden and 0..$#output.
$#input  = 1;
$#hidden = 1;
$#output = 0;

# Multiplier applied to every weight adjustment accumulated by train.
$eta = .5;

$weightsa = &random_weights($#hidden, $#input);
$weightsb = &random_weights($#output, $#hidden);

# Each entry is [ input vector, desired output vector ].
my @patterns = (
    [ [0, 0], [0] ],
    [ [0, 1], [1] ],
    [ [1, 0], [1] ],
    [ [1, 1], [0] ],
);

foreach my $ct (1 .. 15000) {
    # train leaves the squared error of the pattern in the global $chisq;
    # sum it over the four patterns for the progress line.
    my $nch = 0;
    foreach my $pattern (@patterns) {
        &train(@$pattern);
        $nch += $chisq;
    }
    print("Training loop $ct error $nch\n");

    # Adjustments are only accumulated by train; apply them once per pass.
    &burn_weights();
}

print("---------------------\n");
&show_weights($weightsa, $#hidden, $#input);
&show_weights($weightsb, $#output, $#hidden);
print("---------------------\n");

# With $show set, forward also prints the hidden-layer activations.
$show = 1;
foreach my $pattern (@patterns) {
    print "-> ", join(" ", &forward(@{ $pattern->[0] })), "\n";
}

# burn_weights()
#
# Adds the accumulated adjustments in $dweightsa / $dweightsb to the
# weights in $weightsa / $weightsb, then resets every accumulator entry
# to 0. Layer sizes are taken from $#input, $#hidden and $#output.
# Takes no arguments; works entirely on package globals.
sub burn_weights {
    # Second layer: indexed [output][hidden].
    for my $out_idx (0 .. $#output) {
        for my $hid_idx (0 .. $#hidden) {
            $weightsb->[$out_idx]->[$hid_idx] += $dweightsb->[$out_idx]->[$hid_idx];
            $dweightsb->[$out_idx]->[$hid_idx] = 0;
        }
    }

    # First layer: indexed [hidden][input].
    for my $hid_idx (0 .. $#hidden) {
        for my $in_idx (0 .. $#input) {
            $weightsa->[$hid_idx]->[$in_idx] += $dweightsa->[$hid_idx]->[$in_idx];
            $dweightsa->[$hid_idx]->[$in_idx] = 0;
        }
    }
}

# train($input, $dout)
#
#   $input - array ref of input values for one pattern
#   $dout  - array ref of desired output values for that pattern
#
# Runs forward() on the pattern, stores the summed squared output error in
# the package global $chisq, and ADDS the resulting weight adjustments to
# the accumulators $dweightsb (output layer) and $dweightsa (hidden layer).
# The weights themselves are not modified here; burn_weights applies the
# accumulated adjustments.
#
# Reads the globals left behind by forward: @h / @H (pre-activation sums of
# the hidden / output layer) and @o (hidden-layer activations), plus $eta
# and the layer sizes $#input, $#hidden, $#output.
sub train {
    my($input, $dout) = @_;
    my @out = &forward(@$input);
    $chisq = 0;

    # Per-output error and activation slope, computed once and reused by
    # both layers. Note that sech() in this file returns sech squared.
    my (@err, @slope);
    foreach my $i (0..$#output) {
        $err[$i]   = $$dout[$i] - $out[$i];
        $slope[$i] = &sech($H[$i]);
        $chisq    += $err[$i]**2;

        foreach my $j (0..$#hidden) {
            $dweightsb->[$i]->[$j] += ($eta*$o[$j]*($err[$i]*$slope[$i]));
        }
    }

    foreach my $j (0..$#hidden) {
        # Error propagated back to hidden unit $j. It does not depend on
        # the input index, so it is computed once per hidden unit.
        my $delta = 0;
        foreach my $i (0..$#output) {
            $delta += $weightsb->[$i]->[$j]*$err[$i]*$slope[$i];
        }

        # The slope of the hidden activation must be taken at the unit's
        # pre-activation sum $h[$j], exactly as the output layer uses
        # $H[$i] above; evaluating it at the squashed value $o[$j] gives
        # a wrong gradient.
        $delta *= &sech($h[$j]);

        foreach my $k (0..$#input) {
            $dweightsa->[$j]->[$k] += ($eta*$delta*$$input[$k]);
        }
    }
}

# random_weights($x, $y)
#
#   $x - last row index    (rows    0..$x are created)
#   $y - last column index (columns 0..$y are created)
#
# Returns a reference to a NEW array of array refs with every entry set to
# the constant .9 (despite the name; the rand() variant is left commented
# out below).
#
# The matrix is held in a lexical so that each call returns its own
# structure. With a package-global $w, every call would fill and return
# the same reference, leaving the two weight layers aliased to one array.
sub random_weights {
    my($x, $y) = @_;
    my $w = [];
    foreach my $row (0..$x) {
        foreach my $col (0..$y) {
#           $w->[$row]->[$col] = rand();
            $w->[$row]->[$col] = .9;
        }
    }
    return $w;
}

# show_weights($w, $x, $y)
#
#   $w - reference to an array of array refs
#   $x - last row index to print
#   $y - last column index to print
#
# Prints one "row col: value" line per entry for rows 0..$x and columns
# 0..$y, followed by a single blank line.
sub show_weights {
    my ($matrix, $last_row, $last_col) = @_;

    for my $row (0 .. $last_row) {
        for my $col (0 .. $last_col) {
            print "$row $col: ", $matrix->[$row]->[$col], "\n";
        }
    }

    print("\n");
}

# forward(@input)
#
# Propagates the given input values through both weight layers and returns
# the list of output-layer activations.
#
# Side effects on package globals, which train depends on:
#   @h - hidden-layer weighted sums (before squashing)
#   @o - hidden-layer activations
#   @H - output-layer weighted sums (before squashing)
#   @O - output-layer activations (also the return value)
# When $show is true the hidden activations are printed as well.
#
# Note: the lexical @input declared here shadows the package array of the
# same name, so $#input below is the index of the last argument passed in.
sub forward {
    my(@input) = @_;

    # Hidden layer: weighted sum of the inputs for every hidden unit.
    foreach my $i (0..$#hidden){
        $h[$i] = 0;
        foreach my $j (0..$#input){
            $h[$i] += $weightsa->[$i]->[$j]*$input[$j];
        }
    }
    my @hidden_out = &squish(\@h);
    @o = @hidden_out;

    # Output layer: weighted sum of the hidden activations.
    foreach my $i (0..$#output){
        $H[$i] = 0;
        foreach my $j (0..$#hidden){
            $H[$i] += $weightsb->[$i]->[$j]*$o[$j];
        }
    }

    print("Hidden...\n") if $show;
    print join("\n", @o) if $show;
    print("\n") if $show;

    my @output_out = &squish(\@H);

    # squish is not relied upon to leave @o untouched: if it builds its
    # result in the global @o, the call above has just replaced the hidden
    # activations with the output activations. Put the hidden activations
    # back so that train reads the right values.
    @o = @hidden_out;
    @O = @output_out;
#    print(join(" ", @input), " -> ", join(" ", @O), "\n");
    return @O;
}

# squish($h)
#
#   $h - reference to an array of numbers
#
# Returns a list holding this file's tanh() applied to each element, in
# the same order.
#
# The result is collected in a lexical array on purpose: forward keeps the
# hidden-layer activations in the package-global @o and train reads them
# afterwards, so this routine must not write to @o itself.
sub squish {
    my $h = shift;
    my @squashed = map { &tanh($$h[$_]) } 0 .. $#$h;
    return @squashed;
}

# tanh($h)
#
# Despite the name this returns (tanh($h) + 1) / 2, i.e. the hyperbolic
# tangent shifted and scaled into the range 0..1.
# Arguments above 100 return 1 and arguments below -100 return 0 without
# evaluating any exponentials.
sub tanh {
    my ($x) = @_;

    if ($x > 100) {
        return 1;
    }
    if ($x < -100) {
        return 0;
    }

    my $grow   = exp($x);
    my $decay  = exp(-1*$x);
    my $hypتان = undef;
    my $ratio  = ($grow - $decay)/($grow + $decay);
    return ($ratio + 1)/2;
}

# tanh2($h)
#
# Returns 1 / (1 + exp(-$h)); arguments below -30 return 0 directly.
sub tanh2 {
    my ($x) = @_;

    return $x < -30
        ? 0
        : (1/(1 + exp(-1*$x)));
}


# cosh($h)
#
# Returns the hyperbolic cosine of $h, computed as (e^h + e^-h) / 2.
sub cosh {
    my ($x) = @_;

    return (exp($x) + exp(-1*$x))/2;
}

# sech2($h)
#
# Returns 0 when $h lies outside the interval -1..1, otherwise
# .1 + $h*(1-$h).
sub sech2 {
    my ($x) = @_;

    return 0 if abs($x) > 1;

    my $bump = $x*(1 - $x);
    return .1 + $bump;
}

# sech($h)
#
# Despite the name this returns the SQUARE of the hyperbolic secant,
# 1 / cosh($h)**2. Arguments beyond +/-100 return 0 without calling cosh.
# Prints a diagnostic line if the un-squared secant ever exceeds 1.
sub sech {
    my ($x) = @_;

    if ($x > 100) {
        return 0;
    }
    if ($x < -100) {
        return 0;
    }

    my $secant = 1/&cosh($x);
    print "Whoa baby\n" if $secant > 1;

    return $secant*$secant;
}

# printary(@values)
#
# Prints every argument on its own line, prefixed with "-> ".
sub printary {
    my @values = @_;

    for my $value (@values) {
        print "-> $value\n";
    }
}

