#!/usr/bin/env bash

# Copyright (c) 2015 Alexandra Figlovskaya <fglval@gmail.com>
# Copyright (c) 2015 Aleksey Cheusov <vle@gmx.net>
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

# Variables settable by the user through the environment.  Each gets a
# default only when it is unset or empty.  The arguments of ":" are
# quoted so the values are not subjected to globbing or word splitting.
: "${SVM_TRAIN_CMD:=svm-train}"
: "${SVM_PREDICT_CMD:=svm-predict}"
: "${TMPDIR:=/tmp}"

############################################################
set -e
# Byte-wise, locale-independent behaviour for sort/uniq/awk below.
export LC_ALL=C

# Copy the input (the files named as arguments, or stdin) to stdout,
# prefixing every non-empty line with two spaces.  Empty lines are
# passed through unchanged.
indent2 (){
    local script='/./s/^/  /'
    sed "$script" "$@"
}

# Handler for terminating signals: clean up, then re-send the signal to
# this shell with its default disposition restored, so the script ends
# because of that signal.
#   $1 -- signal name, e.g. INT
sig_handler (){
    # Disarm the EXIT trap first.  Otherwise the shell may run on_exit a
    # second time while it dies from the re-sent signal (duplicated
    # stderr dumps / "Temporary files are here" messages).
    trap - 0
    on_exit
    trap - "$1"
    kill -"$1" "$$"
}

# EXIT-trap handler: dump the saved train/predict stderr logs, then
# remove the temporary directory.  With -K/-D the directory is kept
# and its location is reported on stderr instead.
on_exit(){
    show_stderr
    if test -n "$keep_tmp"; then
	echo "Temporary files are here $tmp_dir" 1>&2
    elif test -n "$tmp_dir"; then
	rm -rf "$tmp_dir"
    fi
}

# Print the result lines of all testing folds in the original object
# order.
#  - Without $tmp_dir/testing_fold.txt (hold-out mode) the single result
#    file result_single1.txt is printed as is.
#  - Otherwise line k of testing_fold.txt names the fold of object k
#    (presumably written by heri-split -- confirm).  The files in
#    $result_all (one per fold, in fold order) are read, and every line
#    is put back at the position of its object.
# Exits the awk with status 12 if the line counts do not add up.
results_from_testing_sets (){
    test -s "$tmp_dir/testing_fold.txt" || {
	cat "$tmp_dir/result_single1.txt"
	return
    }

    awk '
    NR == FNR {
	# pass 1: testing_fold.txt, remember where each fold member lives
	fold = $1
	pos[fold] += 1
	where[fold, pos[fold]] = NR
	next
    }

    FNR == 1 {
	# pass 2: every new result file is the next fold
	fold_idx += 1
    }

    {
	out[where[fold_idx, FNR]] = $0
    }

    END {
	if (NR % 2 != 0) {
	    print "internal error!" > "/dev/stderr"
	    exit 12
	}
	total = NR / 2
	for (k = 1; k <= total; k++)
	    print out[k]
    }' "$tmp_dir/testing_fold.txt" $result_all
}

# Copy the non-empty stderr logs of the training and predicting runs
# (train_stderr<i>, predict_stderr<i> for folds 1..$last) to stderr,
# each preceded by a "---- <stage> stderr <i> ----" header.
# Does nothing when $last is not set yet.  It always returns 0 in that
# case, because it runs from the EXIT trap under "set -e".
show_stderr (){
    local n stage log
    if test -z "$last"; then
	return 0
    fi
    for n in $(seq "$last"); do
	for stage in train predict; do
	    log="$tmp_dir/${stage}_stderr${n}"
	    if test -s "$log"; then
		echo "---- $stage stderr $n ----" 1>&2
		cat -- "$log" 1>&2
	    fi
	done
    done
}

# Wait for every background job whose PID is stored in pid[1..last].
# Returns 0 when all of them succeeded; otherwise returns the exit
# status of the last failing job (in index order).
wait_all (){
    local n
    local rc=0
    for n in $(seq $last); do
	wait ${pid[$n]} || rc=$?
    done
    return "$rc"
}

# Print the command-line help to stderr.
usage(){
    cat 1>&2 <<'EOF'
usage: heri-eval [OPTIONS] training_set [-- SVM_TRAIN_OPTIONS]
Examples:
    heri-eval -n5 dataset.libsvm                # 5-fold cross-validation
    heri-eval -t10 -n5 dataset.libsvm           # 10*5-fold cross-validation
    heri-eval -e testing.libsvm dataset.libsvm  # testing on testing.libsvm

OPTIONS:
      -h                   help message

      -n N                 N-fold cross validation mode
      -t T                 T*N-fold cross validation mode (1 by default)

      -e testing_set       testing set for hold-out method
                           (either -n or -e must be specified)

      -o <filename>        save results from testing sets
                           to the specified file
                           (golden_tag result_tag [score])
      -O <filename>        save incorrectly classified objects
                           to the specified file
                           (#object_number: golden_tag result_tag [score])
      -m <filename>        save confusion matrix to the specified file
                           (frequency : golden_tag result_tag)

      -f                   Enable output of per-fold statistics (see -Mf)
      -M <chars>           output mode:
                              t -- output total statistics,
                              f -- output per-fold statistics,
                              c -- output cross-fold statistics.
                           The default is "-M tc".
      -p <stat_opts>       options passed to heri-stat(1)
      -S <seed>            seed pseudo-random generator used for splitting
                           dataset into training and testing parts.
                           The default is empty, which means
                           'split dataset randomly every invocation'
      -K                   keep temporary directory after exiting
      -D                   debugging mode, implies -K

SVM_TRAIN_OPTIONS: options passed to svm-train(1) and alike

Environment variables:
  SVM_TRAIN_CMD   -- training utility, e.g., liblinear-train
                     (the default is svm-train)
  SVM_PREDICT_CMD -- predicting utility, e.g., liblinear-predict
                     (the default is svm-predict)
  TMPDIR          -- temporary directory (the default is /tmp)

Examples:
  Ex1: heri-eval -e testing_set.libsvm training_set.libsvm -- -s 0 -t 0
  Ex2: export SVM_TRAIN_CMD='liblinear-train'
       export SVM_PREDICT_CMD='liblinear-predict'
       heri-eval -p '-mr' -n 5 training_set.libsvm -- -s 4 -q
EOF
}

# Defaults for command-line options.
output_mode=tc
times=1

while getopts De:fhKm:M:n:o:O:p:S:t: f; do
    case "$f" in
	'?')
	    usage
	    exit 1;;
	h)
	    usage
	    exit 0;;
	n)
	    number_of_folds="$OPTARG";;
	e)
	    testing_set="$OPTARG";;
	t)
	    times="$OPTARG";;
	m)
	    confusion_matrix="$OPTARG";;
	o)
	    results="$OPTARG";;
	O)
	    incorrect_results="$OPTARG";;
	p)
	    # kept as a plain string; expanded unquoted later on purpose
	    heristat_args="$heristat_args $OPTARG";;
	f)
	    output_mode="f$output_mode";;
	M)
	    output_mode="$OPTARG";;
	S)
	    seed="$OPTARG";;
	K)
	    keep_tmp=1;;
	D)
	    keep_tmp=1
	    debug=1;;
    esac
done
shift "$((OPTIND - 1))"

# Everything up to "--" is a training set file.  The names are stored
# printf %q-quoted in one string, so they can be decoded with eval later
# without being re-split on blanks.  Whatever follows "--" is left in
# "$@" for svm-train.
while test "$#" -gt 0; do
    case "$1" in
	--)
	    shift
	    break;;
	*)
	    printf -v print_sh '%q' "$1"
	    files="$files $print_sh"
	    shift;;
    esac
done

# Remove the temporary directory on normal exit and on the common
# terminating signals.
trap "sig_handler INT"  INT
trap "sig_handler TERM" TERM
trap "sig_handler HUP"  HUP
trap "on_exit" 0

if [[ -z "$number_of_folds" && -z "$testing_set" ]]; then
    echo 'Either -n or -e must be specified, run heri-eval -h for details' 1>&2
    exit 1
fi

if [[ -z "$files" ]]; then
    echo 'Training set is mandatory, run heri-eval -h for details' 1>&2
    exit 1
fi

tmp_dir=$(mktemp -d "$TMPDIR/svm.XXXXXX")

# Run one train/predict cycle.
#  - With -n, split the training set(s) into $number_of_folds folds with
#    heri-split.  Otherwise concatenate the training set(s) into
#    train1.txt and copy the -e testing set to test1.txt.
#  - Then train all folds in the background, wait, predict all folds in
#    the background, and wait again.
# Arguments: options passed verbatim to $SVM_TRAIN_CMD.
# Globals read: number_of_folds, testing_set, files, tmp_dir,
#   SVM_TRAIN_CMD, SVM_PREDICT_CMD.
# Globals written: last (number of folds), pid (background PIDs used by
#   wait_all), seed (suffixed, presumably so that the next -t run gets a
#   different split).
training_testing (){
    local i
    local -a train_files
    # $files holds printf %q-quoted names.  eval decodes them into one
    # array element per file, so names containing blanks survive.  The
    # old unquoted $files broke such names apart for heri-split.
    eval "train_files=($files)"

    if test -n "$number_of_folds"; then
	heri-split -c "$number_of_folds" -d "$tmp_dir" -s "$seed" \
	    "${train_files[@]}"
	if test -n "$seed"; then
	    seed="${seed}9876"
	fi
	last="$number_of_folds"
    else
	cat -- "${train_files[@]}" > "$tmp_dir/train1.txt"
	cp -- "$testing_set" "$tmp_dir/test1.txt"
	last=1
    fi

    # ${SVM_TRAIN_CMD}/${SVM_PREDICT_CMD} are left unquoted on purpose so
    # they may carry their own arguments.
    for i in $(seq "$last"); do
	${SVM_TRAIN_CMD} "$@" "$tmp_dir/train$i.txt" "$tmp_dir/svm$i.bin" \
	    2> "$tmp_dir/train_stderr${i}" \
	    >  "$tmp_dir/train_stdout${i}" &
	pid[$i]=$!
    done

    wait_all

    for i in $(seq "$last"); do
	${SVM_PREDICT_CMD} "$tmp_dir/test$i.txt" "$tmp_dir/svm$i.bin" \
	    "$tmp_dir/result${i}.txt" \
	    2> "$tmp_dir/predict_stderr${i}" \
	    >  "$tmp_dir/predict_stdout${i}" &
	pid[$i]=$!
    done

    wait_all

    rm -f "$tmp_dir/golden_tags" "$tmp_dir/result.txt"
}

# Build the per-fold evaluation files for every run t (1..$times) and
# fold i (1..$last):
#   golden_tags${t}_${i}     -- first column of test${t}_${i}.txt
#   evaluation${t}_${i}.txt  -- "heri-stat -R" output for the fold
#   result_single${t}_${i}.txt -- "golden_tag result..." lines
# With "f" in $output_mode, the per-fold heri-stat report is printed too.
# Globals read: times, last, tmp_dir, output_mode, heristat_args.
# Global written: result_all -- the result_single${i}.txt files of the
#   LAST run (each a hard link to result_single${times}_${i}.txt).  It is
#   consumed by results_from_testing_sets, so it must stay global.
show_stat (){
    local t i golden
    for t in $(seq "$times"); do
	result_all=''
	for i in $(seq "$last"); do
	    golden="$tmp_dir/golden_tags${t}_${i}"
	    awk '{print $1}' "$tmp_dir/test${t}_${i}.txt" > "$golden"
	    if [[ "_$output_mode" =~ f ]]; then
		echo "Fold ${t}x$i statistics"
		# $heristat_args is word-split on purpose (several options)
		heri-stat $heristat_args \
		    "$golden" "$tmp_dir/result${t}_${i}.txt" |
		indent2
		echo ''
	    fi
	    heri-stat -R "$golden" "$tmp_dir/result${t}_${i}.txt" \
		> "$tmp_dir/evaluation${t}_${i}.txt"
	    paste "$golden" "$tmp_dir/result${t}_${i}.txt" |
		tr '\t' ' ' > "$tmp_dir/result_single${t}_${i}.txt"

	    ln -f "$tmp_dir/result_single${t}_${i}.txt" \
		"$tmp_dir/result_single${i}.txt"
	    result_all="$result_all $tmp_dir/result_single${i}.txt"
	done
    done
}

# Main flow: run the train/predict cycle $times times, keeping each
# run's per-fold files under run-specific names, then write the reports.
for t in $(seq "$times"); do
    training_testing "$@"
    # Move (do not hard-link) the per-fold files out of the way.  The
    # next run recreates test$i.txt and result$i.txt.  Tools that open an
    # existing output file with truncation (cp(1) does; a typical
    # predictor does too) would otherwise rewrite the same inode.  That
    # would silently replace every hard-linked copy from earlier runs
    # with the last run's data.
    for i in $(seq "$last"); do
	mv -- "$tmp_dir/test${i}.txt" "$tmp_dir/test${t}_${i}.txt"
	mv -- "$tmp_dir/result${i}.txt" "$tmp_dir/result${t}_${i}.txt"
    done
done

show_stat

results_from_testing_sets > "$tmp_dir/result.txt"

# -o: all testing-set results in original object order
if test -n "$results"; then
    cp -- "$tmp_dir/result.txt" "$results"
fi

# -O: misclassified objects, each prefixed with "#<object number>"
if test -n "$incorrect_results"; then
    awk '$1 != $2 {print "#" NR, $0}' "$tmp_dir/result.txt" \
	> "$incorrect_results"
fi

# -m: counts of distinct misclassified result lines, most frequent
# first, printed as "count : golden_tag result_tag"
if test -n "$confusion_matrix"; then
    awk '$1 != $2' "$tmp_dir/result.txt" |
    sort | uniq -c | sort -rn |
	awk '{print $1, ":", $2, $3}' > "$confusion_matrix"
fi

# -M t: statistics over all runs and folds together
if [[ "_$output_mode" =~ t ]]; then
    echo 'Total statistics'
    heri-stat -1 $heristat_args "$tmp_dir"/result_single*_*.txt | indent2
    echo ''
fi

# -M c: cross-fold statistics (cross-validation mode only)
if test -n "$number_of_folds" && [[ "_$output_mode" =~ c ]]; then
    echo 'Total cross-folds statistics'
    heri-stat-addons "$tmp_dir"/evaluation*.txt | indent2
fi
