#!/usr/bin/env python

"""Convert dicom TrioTim dirs based on heuristic info

This function uses DicomStack and mri_convert to convert Siemens
TrioTim directories. It proceeds by extracting dicominfo from each
subject and writing a config file $subject_id/$subject_id.auto.txt in
the output directory. Users can create a copy of the file called
$subject_id.edit.txt and modify it to change the files that are
converted. This edited file will always overwrite the original file. If
there is a need to revert to original state, please delete this edit.txt
file and rerun the conversion.

"""

import argparse
from glob import glob
import inspect
import json
import os
import shutil
import sys
from tempfile import mkdtemp
import tarfile

import dicom as dcm
import dcmstack as ds


def save_json(filename, data):
    """Save data to a json file

    The output is written with sorted keys and an indentation of 4 so that
    repeated runs produce stable, diff-friendly files.

    Parameters
    ----------
    filename : str
        Filename to save data in.
    data : dict
        Dictionary to save in json file.

    """

    # ``open`` replaces the Python-2-only ``file`` builtin; the context
    # manager closes the handle even when serialization raises.
    with open(filename, 'w') as fp:
        json.dump(data, fp, sort_keys=True, indent=4)


def load_json(filename):
    """Load data from a json file

    Parameters
    ----------
    filename : str
        Filename to load data from.

    Returns
    -------
    data : dict
        The decoded content of the file.

    """

    # ``open`` replaces the Python-2-only ``file`` builtin; the context
    # manager closes the handle even when decoding raises.
    with open(filename, 'r') as fp:
        return json.load(fp)


def process_dicoms(fl, dcmsession=None, basedir=None, dcmfilter=None):
    """Group DICOM files into series and summarize every series

    Parameters
    ----------
    fl : list of str
        DICOM filenames (relative to `basedir` when that is given).
    dcmsession : list of int, optional
        Session index for every entry of `fl`. When omitted, all files are
        assigned session 0 and the session index is left out of the
        generated series labels.
    basedir : str, optional
        Directory that is prepended to every filename before reading it.
    dcmfilter : callable, optional
        Called with the DICOM dataset of each regular DICOM file; a true
        return value excludes that file from the result.

    Returns
    -------
    seqinfo : list of list
        One row per series with image data: running total of files, basename
        of the first file, series label, three '-' placeholders, the image
        shape followed by the number of files (padded with a trailing 1 to
        at least four entries), TR, TE, protocol name and whether
        'MoCo' occurs in the series description. TR and TE are -1 when the
        corresponding DICOM attribute is missing.
    filegroup : dict
        Maps each series label to the list of its filenames (as given in
        `fl`, i.e. without `basedir`).

    """
    if dcmsession is None:
        multi_session = False
        dcmsession = [0] * len(fl)
    else:
        multi_session = True

    # groups[0] collects the series_id assigned to each processed file,
    # groups[1] the index into `mwgroup` of the wrapper representing it
    groups = [[], []]
    mwgroup = []
    for fidx, filename in enumerate(fl):
        if basedir is not None:
            filename = os.path.join(basedir, filename)
        mw = ds.wrapper_from_data(dcm.read_file(filename, force=True))
        # drop these fields from the signature so that they do not take part
        # in the `is_same_series` comparison below
        for sigkey in ('iop', 'ICE_Dims', 'SequenceName'):
            try:
                del mw.series_signature[sigkey]
            except Exception:
                # best effort -- the field may simply not be there
                pass
        try:
            series_id = (dcmsession[fidx],
                         mw.dcm_data.SeriesNumber,
                         mw.dcm_data.ProtocolName)
        except AttributeError:
            # not a normal DICOM -> ignore
            series_id = (dcmsession[fidx], -1, 'none')
        # Only consult the user filter for regular DICOMs. The test has to
        # look at the series number (index 1); index 0 is the session index,
        # which is never negative, so testing it would run the filter on
        # non-DICOM files and fail on their missing ProtocolName.
        if not series_id[1] < 0:
            if dcmfilter is not None and dcmfilter(mw.dcm_data):
                series_id = (dcmsession[fidx], -1, mw.dcm_data.ProtocolName)

        # filter out unwanted non-image-data DICOMs by assigning
        # a series number < 0 (see test below); tag (0008,0016) is the
        # SOP Class UID
        if not series_id[1] < 0 and mw.dcm_data[0x0008, 0x0016].repval in (
                'Raw Data Storage',
                'GrayscaleSoftcopyPresentationStateStorage'):
            series_id = (dcmsession[fidx], -1, mw.dcm_data.ProtocolName)
        ingrp = False
        for idx, known in enumerate(mwgroup):
            if mw.is_same_series(known):
                ingrp = True
                if series_id[1] >= 0:
                    # adopt the identity of the series representative
                    series_id = (dcmsession[fidx],
                                 known.dcm_data.SeriesNumber,
                                 known.dcm_data.ProtocolName)
                # NOTE(review): a file matching more than one known series
                # is recorded once per match, which would shift the
                # position-based lookup of `fl` further down -- confirm that
                # `is_same_series` cannot match several groups.
                groups[0].append(series_id)
                groups[1].append(idx)
        if not ingrp:
            # first file of a new series becomes its representative
            mwgroup.append(mw)
            groups[0].append(series_id)
            groups[1].append(len(mwgroup) - 1)

    group_map = dict(zip(groups[0], groups[1]))

    total = 0
    filegroup = {}
    seqinfo = []
    # for the next line to make any sense the series_id needs to
    # be sortable in a way that preserves the series order
    for series, mwidx in sorted(group_map.items()):
        if series[1] < 0:
            # skip our fake series with unwanted files
            continue
        mw = mwgroup[mwidx]
        if mw.image_shape is None:
            # this whole thing has no image data (maybe just PSg DICOMs)
            # nothing to see here, just move on
            continue
        dcminfo = mw.dcm_data
        files = [fl[i] for i, s in enumerate(groups[0]) if s == series]
        # turn the series_id into a human-readable string -- string is needed
        # for JSON storage later on
        if multi_session:
            series = '%i-%i-%s' % series
        else:
            series = '%i-%s' % series[1:]
        filegroup[series] = files
        size = list(mw.image_shape) + [len(files)]
        total += size[-1]
        if len(size) < 4:
            size.append(1)
        try:
            TR = float(dcminfo.RepetitionTime) / 1000.
        except AttributeError:
            TR = -1
        try:
            TE = float(dcminfo.EchoTime)
        except AttributeError:
            TE = -1
        info = [total, os.path.split(files[0])[1], series, '-', '-', '-'] + \
            size + [TR, TE, dcminfo.ProtocolName, 'MoCo' in dcminfo.SeriesDescription]
        seqinfo.append(info)
    return seqinfo, filegroup


def write_config(outfile, info):
    """Write `info` to `outfile` as a pretty-printed Python literal

    Parameters
    ----------
    outfile : str
        Filename to write to; an existing file is overwritten.
    info : object
        Structure to store. It is formatted with `pprint.pformat`, so it can
        be read back with `read_config` as long as its repr is a valid
        Python expression.

    """
    from pprint import pformat
    # pformat returns one string; write it in one go instead of handing it
    # to writelines(), which would iterate over it character by character
    with open(outfile, 'wt') as fp:
        fp.write(pformat(info))


def read_config(infile):
    """Read a structure previously stored with `write_config`

    Parameters
    ----------
    infile : str
        Filename to read from.

    Returns
    -------
    info : object
        Whatever the Python expression stored in `infile` evaluates to.

    """
    # SECURITY NOTE: the file content is passed to eval(), i.e. it is
    # executed as Python code. The file is meant to be edited by the user
    # running the conversion; never point this at files from an untrusted
    # source. (ast.literal_eval would be the safe choice if configs are
    # guaranteed to contain plain literals only -- TODO confirm.)
    with open(infile, 'rt') as fp:
        return eval(fp.read())


def conversion_info(subject, outdir, info, filegroup, basedir=None):
    """Expand heuristic `info` into a flat list of conversion jobs

    Parameters
    ----------
    subject : str
        Subject ID made available to the output template as ``{subject}``.
    outdir : str
        Directory that every formatted output prefix is joined onto.
    info : dict
        Maps keys, whose first element is an output template and whose
        second element is the output type, to a list of series labels (or
        lists of series labels). Keys with empty values are skipped.
    filegroup : dict
        Maps series labels to lists of filenames.
    basedir : str, optional
        Directory to prepend to every filename.

    Returns
    -------
    list of tuple
        ``(output prefix, output type, files)`` for every series label.

    """
    jobs = []
    for spec, entries in info.items():
        if entries:
            template, outtype = spec[0], spec[1]
            # ``item`` counts entries, ``subindex`` counts the members of a
            # single (list) entry; both start at 1
            for run, members in enumerate(entries, 1):
                if not isinstance(members, list):
                    members = [members]
                for position, seqitem in enumerate(members, 1):
                    try:
                        paths = filegroup[seqitem]
                    except KeyError:
                        # retry with the text form of the label
                        paths = filegroup[unicode(seqitem)]
                    if basedir is not None:
                        paths = [os.path.join(basedir, p) for p in paths]
                    relprefix = template.format(item=run,
                                                subject=subject,
                                                seqitem=seqitem,
                                                subindex=position)
                    jobs.append(
                        (os.path.join(outdir, relprefix), outtype, paths))
    return jobs


def embed_nifti(dcmfiles, niftifile, infofile, force=False):
    """Extract DICOM meta data for a NIfTI file, creating the file if needed

    Parameters
    ----------
    dcmfiles : list of str
        DICOM files of a single series.
    niftifile : str
        NIfTI file that belongs to `dcmfiles`. If it does not exist yet, it
        is created from the DICOM stack with embedded meta data.
    infofile : str
        File that receives the meta data as JSON (only written when
        `niftifile` already exists).
    force : bool
        Passed on to `dcmstack.parse_and_stack`.

    Returns
    -------
    str or tuple
        The meta data as a JSON string if `niftifile` had to be created,
        ``(niftifile, infofile)`` otherwise.

    Raises
    ------
    ValueError
        If `dcmfiles` contain more than one series.

    """
    # Imports are repeated locally although the module imports them too --
    # presumably because this function is executed through a nipype
    # ``Function`` node (see ``convert``); keep them here.
    import dcmstack as ds
    import nibabel as nb
    import os
    # list() makes the result indexable regardless of whether values()
    # returns a list or a view
    stacks = list(ds.parse_and_stack(dcmfiles, force=force).values())
    if len(stacks) > 1:
        raise ValueError('Found multiple series')
    stack = stacks[0]

    # Create the nifti image using the data array
    if not os.path.exists(niftifile):
        nifti_image = stack.to_nifti(embed_meta=True)
        nifti_image.to_filename(niftifile)
        # NOTE(review): this branch returns a single JSON string and does
        # not write `infofile`, unlike the branch below -- left unchanged so
        # existing callers keep working; confirm whether that is intended.
        return ds.NiftiWrapper(nifti_image).meta_ext.to_json()
    orig_nii = nb.load(niftifile)
    aff = orig_nii.get_affine()
    ornt = nb.orientations.io_orientation(aff)
    axcodes = nb.orientations.ornt2axcodes(ornt)
    # stack the DICOMs in the voxel order of the existing file; only the
    # meta data of this new image is used, the existing file is not touched
    new_nii = stack.to_nifti(voxel_order=''.join(axcodes), embed_meta=True)
    meta = ds.NiftiWrapper(new_nii).meta_ext.to_json()
    with open(infofile, 'wt') as fp:
        # `meta` is written in one go (writelines() would iterate over it
        # character by character)
        fp.write(meta)
    return niftifile, infofile


def convert(items, anonymizer=None, symlink=True, converter=None,
        scaninfo_suffix='_scaninfo.json', custom_callable=None):
    """Run the conversion jobs described by `items`

    Parameters
    ----------
    items : list of tuple
        ``(output prefix, output type(s), files)`` entries, e.g. as returned
        by `conversion_info`. The output type is 'dicom', 'nii' or 'nii.gz',
        or a list/tuple of those; other types are silently ignored.
    anonymizer : object, optional
        Currently unused.
    symlink : bool
        For 'dicom' output: symlink the source files (True) or hard-link
        them (False).
    converter : str, optional
        'mri_convert' or 'dcm2nii' run the respective nipype interface. With
        any other value no external converter is run and the NIfTI file is
        left to `embed_nifti` to create.
    scaninfo_suffix : str
        Appended to the output prefix to name the meta data file.
    custom_callable : callable, optional
        Called as ``custom_callable(*item)`` after each item was processed.

    Notes
    -----
    Existing NIfTI files are never overwritten. Generated NIfTI, scan info
    and provenance files are made read-only (mode 0440).

    """
    prov_files = []
    tmpdir = mkdtemp(prefix='heudiconvtmp')
    try:
        for item in items:
            prefix = item[0]
            if isinstance(item[1], (list, tuple)):
                outtypes = item[1]
            else:
                outtypes = [item[1]]
            print('Converting %s' % prefix)
            dirname = os.path.dirname(prefix + '.ext')
            print(dirname)
            if not os.path.exists(dirname):
                os.makedirs(dirname)
            for outtype in outtypes:
                print(outtype)
                if outtype == 'dicom':
                    # start from a clean directory and link all source files
                    dicomdir = prefix + '_dicom'
                    if os.path.exists(dicomdir):
                        shutil.rmtree(dicomdir)
                    os.mkdir(dicomdir)
                    for filename in item[2]:
                        outfile = os.path.join(dicomdir,
                                               os.path.split(filename)[1])
                        if not os.path.islink(outfile):
                            if symlink:
                                os.symlink(filename, outfile)
                            else:
                                os.link(filename, outfile)
                elif outtype in ['nii', 'nii.gz']:
                    outname = prefix + '.' + outtype
                    scaninfo = prefix + scaninfo_suffix
                    if os.path.exists(outname):
                        # never overwrite an existing conversion
                        continue
                    from nipype import config
                    config.enable_provenance()
                    from nipype import Function, Node
                    from nipype.interfaces.base import isdefined
                    print(converter)
                    # stays None when no external converter is run, so the
                    # provenance handling below knows there is nothing to
                    # pick up (used to fail with a NameError)
                    convertnode = None
                    if converter == 'mri_convert':
                        from nipype.interfaces.freesurfer.preprocess import MRIConvert
                        convertnode = Node(MRIConvert(), name='convert')
                        convertnode.base_dir = tmpdir
                        if outtype == 'nii.gz':
                            convertnode.inputs.out_type = 'niigz'
                        convertnode.inputs.in_file = item[2][0]
                        convertnode.inputs.out_file = os.path.abspath(outname)
                        res = convertnode.run()
                    elif converter == 'dcm2nii':
                        from nipype.interfaces.dcm2nii import Dcm2nii
                        convertnode = Node(Dcm2nii(), name='convert')
                        convertnode.base_dir = tmpdir
                        convertnode.inputs.source_names = item[2]
                        convertnode.inputs.gzip_output = outtype == 'nii.gz'
                        convertnode.inputs.terminal_output = 'allatonce'
                        res = convertnode.run()
                        if isinstance(res.outputs.converted_files, list):
                            # a single message string prints the same way
                            # with the print statement and print function
                            print("Cannot convert dicom files - series "
                                  "likely has multiple orientations: %s"
                                  % (item[2],))
                            continue
                        else:
                            shutil.copyfile(res.outputs.converted_files,
                                            outname)
                        if isdefined(res.outputs.bvecs):
                            outname_bvecs = prefix + '.bvecs'
                            outname_bvals = prefix + '.bvals'
                            shutil.copyfile(res.outputs.bvecs, outname_bvecs)
                            shutil.copyfile(res.outputs.bvals, outname_bvals)
                    prov_file = prefix + '_prov.ttl'
                    if convertnode is not None:
                        shutil.copyfile(os.path.join(convertnode.base_dir,
                                                     convertnode.name,
                                                     'provenance.ttl'),
                                        prov_file)
                    prov_files.append(prov_file)
                    embedfunc = Node(Function(input_names=['dcmfiles',
                                                           'niftifile',
                                                           'infofile',
                                                           'force'],
                                              output_names=['outfile',
                                                            'meta'],
                                              function=embed_nifti),
                                     name='embedder')
                    embedfunc.inputs.dcmfiles = item[2]
                    embedfunc.inputs.niftifile = os.path.abspath(outname)
                    embedfunc.inputs.infofile = os.path.abspath(scaninfo)
                    embedfunc.inputs.force = True
                    embedfunc.base_dir = tmpdir
                    res = embedfunc.run()
                    # merge the converter's provenance (if any) with the
                    # provenance of the embedding step
                    g = res.provenance.rdf()
                    if convertnode is not None:
                        g.parse(prov_file,
                                format='turtle')
                    g.serialize(prov_file, format='turtle')
                    # protect the results against accidental modification;
                    # the scan info file is not written when embed_nifti had
                    # to create the NIfTI file itself, hence the check
                    for result in (outname, scaninfo, prov_file):
                        if os.path.exists(result):
                            os.chmod(result, 0o440)
            if custom_callable is not None:
                custom_callable(*item)
    finally:
        # remove the nipype working directory even if a conversion failed
        shutil.rmtree(tmpdir)


def convert_dicoms(subjs, dicom_dir_template, outdir, heuristic_file, converter,
                   queue=None, anon_sid_cmd=None, anon_outdir=None):
    """Convert the DICOMs of all given subjects

    For every subject the files matching the input template are grouped
    into series, the heuristic decides what to convert, and the resulting
    setup (copy of the heuristic, ``filegroup.json``, ``dicominfo.txt``,
    ``<sid>.auto.txt`` and ``<sid>.edit.txt``) is stored in the subject's
    info directory. If ``<sid>.edit.txt`` already exists, it is used
    together with the stored ``filegroup.json`` instead of re-analyzing
    the DICOMs.

    Parameters
    ----------
    subjs : list of str
        Subject IDs.
    dicom_dir_template : str
        Glob pattern with a ``%s`` placeholder for the subject ID. It may
        match DICOM files or tarballs of DICOM files (not a mix of both).
    outdir : str
        Output directory for the conversion setup.
    heuristic_file : str
        Python file providing ``infotodict(seqinfo)`` and optionally
        ``filter_dicom``, ``scaninfo_suffix`` and ``custom_callable``.
    converter : str
        Passed on to `convert`; 'none' skips the conversion step.
    queue : str, optional
        If 'slurm', one batch job per subject is submitted instead of
        converting right away.
    anon_sid_cmd : str, optional
        Command that is called with a subject ID and prints the ID to use
        for the converted files.
    anon_outdir : str, optional
        Output directory for converted files; defaults to `outdir`.

    Raises
    ------
    ValueError
        If only some of the files matching the template are tarballs.

    """
    for sid in subjs:
        if queue == 'slurm':
            # TODO This needs to be updated to better scale with additional args
            # NOTE(review): arguments are not shell-quoted and the
            # anonymization options are not passed on to the job.
            progname = os.path.abspath(inspect.getfile(inspect.currentframe()))
            convertcmd = ' '.join(['python', progname, '-d', dicom_dir_template,
                                   '-o', outdir, '-f', heuristic_file, '-s', sid,
                                   '-c', converter])
            script_file = 'sg-%s.sh' % sid
            with open(script_file, 'wt') as fp:
                fp.writelines(['#!/bin/bash\n', convertcmd])
            outcmd = 'sbatch -J sg-%s -p %s -N1 -c2 --mem=20G %s' % (sid, queue, script_file)
            os.system(outcmd)
            continue
        tmpdir = None
        try:
            # expand the input template
            sdir = dicom_dir_template % sid
            # and see what matches
            fl = sorted(glob(sdir))
            dcmsessions = None
            # some people keep compressed tarballs around -- be nice!
            if fl and tarfile.is_tarfile(fl[0]):
                # check if we got a list of tarfiles
                if not all(tarfile.is_tarfile(f) for f in fl):
                    raise ValueError("some but not all input files are tar files")
                # tarfiles already know what they contain, and often the
                # filenames are unique, or at least in a unique subdir per
                # session
                # strategy: extract everything in a temp dir and assemble a
                # list of all files in all tarballs
                content = []
                tmpdir = mkdtemp(prefix='heudiconvtmp')
                # needs sorting to keep the generated "session" label
                # deterministic
                for i, t in enumerate(sorted(fl)):
                    tf = tarfile.open(t)
                    try:
                        # check content and sanitize permission bits
                        tmembers = tf.getmembers()
                        for tm in tmembers:
                            tm.mode = 0o700
                        # get all files, assemble full path in tmp dir
                        tf_content = [m.name for m in tmembers if m.isfile()]
                        content += tf_content
                        if len(fl) > 1:
                            # more than one tarball, take care of "session"
                            # indicator
                            if dcmsessions is None:
                                dcmsessions = []
                            dcmsessions += ([i] * len(tf_content))
                        # extract into tmp dir
                        # SECURITY NOTE: member names are not validated, a
                        # crafted tarball could write outside of tmpdir --
                        # only process tarballs from trusted sources.
                        tf.extractall(path=tmpdir, members=tmembers)
                    finally:
                        tf.close()
                fl = content
            anon_sid = sid
            if anon_sid_cmd is not None:
                from subprocess import check_output
                anon_sid = check_output([anon_sid_cmd, sid]).strip()
            if anon_outdir is None:
                anon_outdir = outdir
            tdir = os.path.join(anon_outdir, anon_sid)
            idir = os.path.join(outdir, sid)
            if anon_outdir == outdir:
                # if all goes into a single dir, have a dedicated 'info' subdir
                idir = os.path.join(idir, 'info')
            if not os.path.exists(idir):
                os.makedirs(idir)
            shutil.copy(heuristic_file, idir)
            path, fname = os.path.split(heuristic_file)
            # make the heuristic importable; avoid adding the same
            # directory once per subject
            if path not in sys.path:
                sys.path.append(path)
            mod = __import__(fname.split('.')[0])

            infofile = os.path.join(idir, '%s.auto.txt' % sid)
            editfile = os.path.join(idir, '%s.edit.txt' % sid)
            if os.path.exists(editfile):
                # a previous run (or the user) left a config -- it wins
                info = read_config(editfile)
                filegroup = load_json(os.path.join(idir, 'filegroup.json'))
            else:
                seqinfo, filegroup = process_dicoms(
                    fl, dcmsessions, basedir=tmpdir,
                    dcmfilter=getattr(mod, 'filter_dicom', None))
                save_json(os.path.join(idir, 'filegroup.json'), filegroup)
                with open(os.path.join(idir, 'dicominfo.txt'), 'wt') as fp:
                    for seq in seqinfo:
                        fp.write('\t'.join([str(val) for val in seq]) + '\n')
                info = mod.infotodict(seqinfo)
                write_config(infofile, info)
                write_config(editfile, info)
            if converter != 'none':
                cinfo = conversion_info(anon_sid, tdir, info, filegroup,
                                        basedir=tmpdir)
                convert(cinfo, converter=converter,
                        scaninfo_suffix=getattr(
                            mod, 'scaninfo_suffix', '_scaninfo.json'),
                        custom_callable=getattr(
                            mod, 'custom_callable', None))
        finally:
            if tmpdir is not None:
                # clean up tmp dir with extracted tarball, also if
                # anything above failed
                shutil.rmtree(tmpdir)


# Command line entry point: parse the arguments and hand them to
# convert_dicoms().
if __name__ == '__main__':
    docstr = '\n'.join((__doc__,
"""
           Example:

           heudiconv -d rawdata/%s -o . -f heuristic.py -s s1 s2 s3
"""))
    # the raw formatter keeps the layout of the description (and with it
    # the example command) instead of re-wrapping it into one paragraph
    parser = argparse.ArgumentParser(
        description=docstr,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('-d', '--dicom_dir_template',
                        dest='dicom_dir_template',
                        required=True,
                        help='''location of dicomdir that can be indexed with
                        subject id. Tarballs (can be compressed) are supported
                        in addition to directory. All matching tarballs for a
                        subject are extracted and their content processed in
                        a single pass''')
    parser.add_argument('-s', '--subjects', dest='subjs', required=True,
                        type=str, nargs='+', help='list of subjects')
    parser.add_argument('-c', '--converter', dest='converter',
                        default='dcm2nii',
                        choices=('mri_convert', 'dcmstack', 'dcm2nii', 'none'),
                        help='''tool to use for dicom conversion. Setting to
                        "none" disables the actual conversion step -- useful
                        for testing heuristics.''')
    parser.add_argument('-o', '--outdir', dest='outputdir',
                        default=os.getcwd(),
                        help='''output directory for conversion setup (for
                        further customization and future reference). This
                        directory will refer to non-anonymized subject IDs''')
    parser.add_argument('-a', '--conv-outdir', dest='conv_outputdir',
                        default=None,
                        help='''output directory for converted files. By
                        default this is identical to --outdir. This option is
                        most useful in combination with --anon-cmd''')
    parser.add_argument('--anon-cmd', dest='anon_cmd',
                        default=None,
                        help='''command to run to convert subject IDs used for
                        DICOMs to anonymized IDs. Such command must take a
                        single argument and return a single anonymized ID.
                        Also see --conv-outdir''')
    parser.add_argument('-f', '--heuristic', dest='heuristic_file', required=True,
                        help='python script containing heuristic')
    parser.add_argument('-q', '--queue', dest='queue', default=None,
                        choices=('slurm',),
                        help='''select batch system to submit jobs to instead
                        of running the conversion serially''')
    args = parser.parse_args()
    # Normalize the conversion directory like --outdir; convert_dicoms()
    # compares both paths to decide whether a separate 'info'
    # subdirectory is needed.
    conv_outputdir = args.conv_outputdir
    if conv_outputdir is not None:
        conv_outputdir = os.path.abspath(conv_outputdir)
    convert_dicoms(args.subjs, os.path.abspath(args.dicom_dir_template),
                   os.path.abspath(args.outputdir),
                   heuristic_file=os.path.realpath(args.heuristic_file),
                   converter=args.converter,
                   queue=args.queue,
                   anon_sid_cmd=args.anon_cmd,
                   anon_outdir=conv_outputdir)
