#!/usr/bin/python
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
#
#   See COPYING file distributed along with the PyMVPA package for the
#   copyright and license terms.
#
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Tiny tool to prepare a directory for a typical analysis of fMRI data with
PyMVPA. Tools from the FSL suite will be used for preprocessing. It takes a 4D
fMRI timeseries as input and performs the following steps:

- extract an example volume

- perform motion correction using the example volume as reference

- conservative skull-stripping and brain mask generation

- masking of the motion-corrected timeseries with the brain mask

All results will be stored either in the current directory, or in a
subdirectory with the subject ID (if specified)."""

import sys
import os
from subprocess import call

# Import mvpa2 first so it could augment matplotlib backend if necessary
import mvpa2

import numpy as np

from mvpa2.base import verbose, externals, error
from mvpa2.misc.cmdline import parser, opt
from mvpa2.misc.fsl import McFlirtParams

if __debug__:
    from mvpa2.base import debug


_EXFUNC_CONV_DICT = {'last' :  lambda x: x-1,
                     'first':  lambda x: 0,
                     'middle': lambda x: int(x/2)}
"""Dictionary to get exemplar volume given a literal string"""

##REF: Name was automagically refactored
def prep_parser(parser):
    """Configure the command line parser of this tool.

    Sets the usage and version strings and registers all options the script
    understands. The parser is modified in place.

    Parameters
    ----------
    parser
      Option parser to configure. It has to provide ``add_option()`` and
      accept assignment of ``usage`` and ``version``.
    """
    # use module docstring for help output; it is None when docstrings got
    # stripped (python -OO), hence the fallback to an empty string
    parser.usage = ("%s [OPTIONS] <fmri-data>\n\n" % sys.argv[0]) \
                   + (__doc__ or '')
    parser.version = "%prog " + mvpa2.__version__

    parser.add_option(opt.verbose)
    parser.add_option(opt.help)

    parser.add_option("-s", "--subject-id",
                      action="store", type="string", dest="subj",
                      default=None,
                      help="Subject ID used as output path")

    # the help text has to state the actual default ('middle')
    parser.add_option("-e", "--example-func-vol",
                      action="store", type="string", dest="exfunc",
                      default='middle',
                      help="Volume (numeric ID or 'last', 'first', 'middle') "
                           "to be used as an example functional image. "
                           "Default: 'middle'")

    parser.add_option("-m", "--mcflirt-options",
                      action="store", type="string", dest="mcflirt_opts",
                      default='',
                      help="Options for MCFLIRT. '-plots' is auto-added")

    parser.add_option("-p", "--mcflirt-plots",
                      action="store_true", dest="mcflirt_plots",
                      default=False,
                      help="Create a .pdf with plots of motion parameters")

    parser.add_option("-b", "--bet-options",
                      action="store", type="string", dest="bet_opts",
                      default='-f 0.3',
                      help="Options for BET. '-m' is auto-added. "
                           "Default: '-f 0.3' for a safe guess of the brain "
                           "outline")



def _resolve_exfunc_id(exfunc, timepoints):
    """Translate an example volume specification into a volume index.

    Parameters
    ----------
    exfunc : str
      Either one of the keys of `_EXFUNC_CONV_DICT` (matched
      case-insensitively) or the decimal index of a volume.
    timepoints : int
      Number of volumes in the timeseries.

    Returns
    -------
    int
      Index of the selected volume. No range check is performed here.
    """
    exfunc = exfunc.lower()
    if exfunc in _EXFUNC_CONV_DICT:
        return _EXFUNC_CONV_DICT[exfunc](timepoints)
    try:
        return int(exfunc)
    except ValueError:
        # NOTE(review): error() is presumably fatal -- otherwise None is
        # returned to the caller
        error("Failed to convert '%s' into numerical id of "
              "volume." % (exfunc))


def _run_fsl_tool(cmd, failure_msg):
    """Run an external command and report `failure_msg` if it fails.

    Parameters
    ----------
    cmd : list of str
      Program name followed by its arguments. It is executed directly, i.e.
      without an intermediate shell, so arguments need no quoting.
    failure_msg : str
      Message passed to `error()` if the command cannot be started or
      returns a non-zero exit status.
    """
    try:
        status = call(cmd)
    except OSError as exc:
        # e.g. the executable is not installed or not in the PATH
        error("%s Could not execute '%s': %s" % (failure_msg, cmd[0], exc))
    else:
        if status:
            error(failure_msg)


def main():
    """Run the preprocessing for the image given on the command line.

    Parses the command line and then

    - stores the selected example volume as ``example_func``
    - runs ``mcflirt`` with the example volume as reference (``func_mc``)
    - optionally plots the estimated motion parameters (``func_mc.pdf``)
    - runs ``bet`` on the example volume to obtain a brain mask
    - zeroes all voxels of the motion-corrected timeseries outside of the
      brain mask and overwrites ``func_mc`` with the result

    All files are written into the output directory, which is the subject ID
    if one was specified and the current directory otherwise. Problems are
    reported through `error()`.
    """
    # only needed here, to tokenize the user-supplied tool options
    import shlex

    prep_parser(parser)
    (options, infiles) = parser.parse_args()

    # late import of nibabel to be able to get help output without a big
    # external dep.
    externals.exists('nibabel', raise_=True)
    import nibabel as nb

    if len(infiles) != 1:
        error("%s needs exactly one input fMRI image as argument. "
                "Got %s" % (sys.argv[0], str(infiles)))

    func_fname = infiles[0]

    # compressed or uncompressed? decide by input image
    # XXX maybe add override option
    if func_fname.lower().endswith('nii.gz'):
        nii_ext = '.nii.gz'
        verbose(2, "Output files will be compressed NIfTI images")
    else:
        nii_ext = '.nii'
        verbose(2, "Output files will be uncompressed NIfTI images")

    # determine output path
    if options.subj is not None:
        opath = options.subj
    else:
        opath = os.path.curdir

    if not os.path.exists(opath):
        verbose(1, "Create output directory '%s'" % opath)
        os.makedirs(opath)
    else:
        verbose(2, "Using output path '%s'" % opath)

    verbose(2, "Load image file from '%s'" % func_fname)
    func_nim = nb.load(func_fname)
    func_nim_hdr = func_nim.header
    # array as stored in the image (same as the deprecated get_data())
    func_nim_data = np.asanyarray(func_nim.dataobj)
    func_nim_shape = func_nim_data.shape

    # process exfunc option
    has_time_axis = len(func_nim_shape) > 3
    timepoints = func_nim_shape[3] if has_time_axis else 1
    exfuncid = _resolve_exfunc_id(options.exfunc, timepoints)

    if exfuncid >= timepoints or exfuncid < 0:
        error("Example functional volume id must be in the "
              "range 0 .. %d. Got %d." % (timepoints-1, exfuncid))

    verbose(2, "Extract volume %i as example volume" % exfuncid)
    if has_time_axis:
        ef_data = func_nim_data[..., exfuncid]
    else:
        # without a time axis the image itself is the only volume; indexing
        # the last axis would extract a slice instead
        ef_data = func_nim_data
    ef_nim = nb.Nifti1Image(ef_data, None, func_nim_hdr)
    ef_nim.to_filename(os.path.join(opath, 'example_func' + nii_ext))

    # close input file -- will operate on motion-corrected one later on
    del func_nim, func_nim_data, func_nim_hdr, ef_data

    # argument list instead of a shell command line, so that paths containing
    # spaces or shell metacharacters are passed on unmodified
    mcflirt_call = ['mcflirt',
                    '-in', func_fname,
                    '-out', os.path.join(opath, 'func_mc'),
                    '-reffile', os.path.join(opath, 'example_func'),
                    '-verbose', '0',
                    '-plots'] + shlex.split(options.mcflirt_opts or '')

    verbose(2, "Perform motion correction ('%s')" % ' '.join(mcflirt_call))

    # run MCFLIRT; its output streams are inherited from this process
    _run_fsl_tool(mcflirt_call,
                  "MCFLIRT failed to perform the motion correction.")

    if options.mcflirt_plots:
        verbose(2, "Plot motion parameters estimates")
        externals.exists('pylab', raise_=True)
        import pylab as pl

        mc = McFlirtParams(os.path.join(opath, 'func_mc.par'))

        # one subplot per parameter group, stacked vertically
        for k, (title, fields, ylabel) in enumerate(
            (('Translation', ('x', 'y', 'z'), 'mm'),
             ('Rotation', ('rot1', 'rot2', 'rot3'), 'radians'))):
            pl.subplot(211+k)
            pl.title(title)
            # gray zero line as a visual reference
            pl.plot([0, timepoints], [0, 0], '0.6')
            for dim in fields:
                pl.plot(mc[dim], label=dim)
            pl.legend()
            pl.axis('tight')
            pl.ylabel(ylabel)

        pl.gcf().savefig(os.path.join(opath, 'func_mc.pdf'))

    bet_call = ['bet',
                os.path.join(opath, 'example_func'),
                os.path.join(opath, 'example_func_brain'),
                '-m'] + shlex.split(options.bet_opts or '')

    verbose(2, "Determine brain mask in functional space ('%s')"
               % ' '.join(bet_call))

    # run BET; its output streams are inherited from this process
    _run_fsl_tool(bet_call, "BET failed to perform the skull stripping.")

    verbose(2, "Threshold image background using brain mask")
    mask_nim = nb.load(os.path.join(opath, 'example_func_brain_mask' + nii_ext))
    mask_nim_data = np.asanyarray(mask_nim.dataobj)
    func_nim_filename = os.path.join(opath, 'func_mc' + nii_ext)
    func_nim = nb.load(func_nim_filename)
    # in-memory copy: the array gets modified and the file it was loaded from
    # is overwritten below, so it should not be backed by that file
    func_nim_data = np.array(func_nim.dataobj)
    # the boolean mask indexes the leading (spatial) axes, hence the selected
    # voxels are zeroed in every volume
    func_nim_data[mask_nim_data == 0] = 0
    nb.Nifti1Image(
        func_nim_data, None, func_nim.header).to_filename(func_nim_filename)


# Run the preprocessing only when executed as a script, not upon import
if __name__ == "__main__":
    main()
