Commit ca033279 authored by Boers, Frank's avatar Boers, Frank
Browse files

update for path cls

parent 0c63907b
......@@ -16,10 +16,10 @@ import collections,pprint
import mne
from mne.viz import topomap
from dcnn_utils import logger, isFile, rescale, read_raw, get_raw_filename
from dcnn_utils import apply_noise_reduction_4d, fig2rgb_array
from dcnn_utils import logger, isFile, isPath, rescale, read_raw, get_raw_filename,expandvars
from dcnn_utils import apply_noise_reduction_4d,fig2rgb_array
__version__ = "2020.06.23.002"
__version__ = "2020.07.06.001"
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++#
# SLOTS class
......@@ -53,8 +53,24 @@ class _SLOTS(object):
for k in self.__slots__:
if k.startswith("_"): continue
data[k]= getattr(self,k)
return data
return data
def _get_slots_attr_hidden(self,lstrip=True):
data = dict()
for k in self.__slots__:
if k.startswith("_"):
if lstrip:
data[ k.lstrip("_") ]= getattr(self,k)
else:
data[k] = getattr(self,k)
return data
def _set_slots_attr_hidden(self,**kwargs):
for k in kwargs:
slot = "_"+ k
if slot in self.__slots__:
self.__setattr__(slot,kwargs.get(k) )
def _update_from_kwargs(self, **kwargs):
if not kwargs: return
for k in kwargs:
......@@ -78,15 +94,33 @@ class _SLOTS(object):
def update(self,**kwargs):
self._update_from_kwargs(**kwargs)
def get_info(self):
msg=["Info => {}".format(self._cls_name)]
def get_info(self,hidden=False,msg=None):
"""
log class parameter in slots
Parameters
----------
hidden: if True log hidden slots e.g. _data
msg: string or list, message to append
Returns
-------
"""
_msg=["Info => {}".format(self._cls_name)]
for k in self.__slots__:
if k.startswith("_"): continue
msg.append( " -> {} : {}".format(k,getattr(self,k)) )
_msg.append( " -> {} : {}".format(k,getattr(self,k)) )
if hidden:
_msg.append("-"*20)
for k,v in self._get_slots_attr_hidden().items():
_msg.append( " -> {} : {}".format(k,v) )
if isinstance(msg,(list)):
_msg.extend(msg)
elif msg:
_msg.append(msg)
try:
logger.info("\n".join(msg))
logger.info("\n".join(_msg))
except:
print("--->"+"\n".join(msg))
print("--->"+"\n".join(_msg))
def dump(self):
"""
......@@ -156,8 +190,6 @@ class DCNN_CONFIG(object):
self._missing_keys = None
self.fname = None
self.path_list = ['data_meg','data_train','report']
self.defaults = {
'version': 'v0.2',
......@@ -218,7 +250,7 @@ class DCNN_CONFIG(object):
self._cfg = None #deepcopy( self.defaults )
if not v: return
self._cfg,self._missing_keys = self._merge_and_check(v)
self._add_base_dir()
# self._add_base_dir()
@property
def missing_keys(self): return self._missing_keys
......@@ -228,9 +260,7 @@ class DCNN_CONFIG(object):
self.defaults = kwargs.get("defaults")
if "keys_to_check" in kwargs:
self.keys_to_chek = kwargs.get("keys_to_check")
if "path_list" in kwargs:
self.path_list = kwargs.get("path_list")
if "fname" in kwargs:
if "fname" in kwargs:
self.fname = kwargs.get("fname")
if "config" in kwargs:
self.config = kwargs.get("config")
......@@ -359,29 +389,6 @@ class DCNN_CONFIG(object):
'''
pp = pprint.PrettyPrinter(indent=2)
return ''.join(map(str,pp.pformat(d)))
def _add_base_dir(self):
'''
update path settings
Returns
-------
None.
'''
if not self.path_list: return
path = self.config.get('path')
if not path: return
basedir = path.get('basedir')
if not basedir: return
# update path settings
for p in self.path_list:
if not self.config['path'][p].startswith(basedir):
self.config['path'][p] = op.join(basedir,self.config['path'][p])
def load(self,**kwargs):
'''
......@@ -434,6 +441,107 @@ class DCNN_CONFIG(object):
logger.info("\n".join(msg))
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# PATH class
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
class DCNN_PATH(_SLOTS):
__slots__= ("expandvars","exit_on_error","logmsg","mkdir","overwrite","add_basedir",
"_basedir","_data_meg","_data_train","_report")
def __init__(self,**kwargs):
super().__init__(**kwargs)
self._cls_name = "DCNN_PATH"
self.init(**kwargs)
def init(self,**kwargs): # set overwrite=True
self.clear()
self._update_from_kwargs(**kwargs)
#-- set overwrite to True for real init if not set
self.overwrite = kwargs.get("overwrite",True)
self.add_basedir = kwargs.get("add_basedir",True)
self.expandvars = kwargs.get("expandvars",True)
self.exit_on_error = kwargs.get("exit_on_error",True)
self.logmsg = kwargs.get("logmsg",True)
self.mkdir = kwargs.get("mkdir",True)
# --
self.basedir = kwargs.get("basedir")
self.data_meg = kwargs.get("data_meg")
self.data_train = kwargs.get("data_train")
self.report = kwargs.get("report")
#-- set overwrite to False for update if not set
self.overwrite = kwargs.get("overwrite",False)
#self.__set_slots_attr_hidden(**kwargs) # sets hidden slots e.g. _data-xyz
@property
def basedir(self): return self._basedir
# self._get_path(self._basedir)
@basedir.setter
def basedir(self,v):
self._set_path(v,"basedir")
@property
def data_meg(self):
return self._get_path(self._data_meg)
@data_meg.setter
def data_meg(self,v):
return self._set_path(v,"data_meg")
@property
def data_train(self):
return self._get_path(self._data_train)
@data_train.setter
def data_train(self,v):
self._set_path(v,"data_train")
@property
def report(self):
self._get_path(self._report)
@report.setter
def report(self,v):
return self._set_path(v,"report")
def _get_path(self,v):
if self.expandvars:
return expandvars(v)
return v
def _set_path(self,path,label):
if not path:
logger.warning("Warning {} can not set path, not defined for: {}".format(self._cls_name,label))
return None
if self.add_basedir and self.basedir:
if not path.startswith(self.basedir):
path = op.join(self.basedir,path)
logger.warning(" ADD basedir: {} => {}".format(self.basedir,path))
is_path = isPath(path,head="{} check path: {}".format(self._cls_name,label),
exit_on_error=self.exit_on_error,logmsg=self.logmsg,mkdir=self.mkdir)
if not getattr(self,"_" +label) or self.overwrite:
self.__setattr__("_" +label,path)
else:
pass
# TEST
msg = ["Warning in {} not allowed to overwrite path settings for: {}".format(self._cls_name,label),
" -> path orig: {}".format(getattr(self,"_"+label)),
" -> path : {}".format(path),
" -> overwrite : {}".format(self.overwrite),
" -> add basedir: {}".format(self.add_basedir),
" -> basedir : {}".format(self.basedir)]
logger.warning("\n".join(msg))
def dump(self):
d1 = super().dump()
d2 = self._get_slots_attr_hidden()
return {**d1,**d2}
def get_info(self):
super().get_info(hidden=True,msg="basedir expand: {}\n".format( expandvars(self.basedir) ))
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# PICKS class
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -517,12 +625,13 @@ class PICKS(_SLOTS):
'''
error_msg = "\nERROR => Wrong or missing aux channel labels in config!\n" \
" -> aux labels: {}".format(self.aux_labels)
" -> aux labels in config: {}\n".format(self.aux_labels)+\
" -> aux labels in raw : {}".format(self.get_aux_labels_in_raw())
try:
picks = self.labels2picks(self.aux_labels)
assert picks is not None, error_msg
assert (len(picks) == len(self._aux_labels)), error_msg + "\n -> picks found: {}".format(picks)
assert (len(picks) == len(self._aux_labels)), error_msg + "\n -> picks found: {}\n".format(picks)
self._aux_types = self.raw.get_channel_types(picks, unique=False, only_data_chs=False)
except ValueError:
......@@ -531,7 +640,7 @@ class PICKS(_SLOTS):
return picks
def _get_all_aux(self):
def get_aux_picks_in_raw(self):
'''
get all aux picks
......@@ -542,6 +651,10 @@ class PICKS(_SLOTS):
return mne.pick_types(self.info, ecg=True, eog=True, meg=False, eeg=False,
ref_meg=False,exclude=[])
def get_aux_labels_in_raw( self ):
picks = self.get_aux_picks_in_raw()
return self.picks2labels(picks)
def _get_chan(self):
'''
get meg,ecg,eog picks
......
......@@ -23,7 +23,7 @@ import inspect
from distutils.dir_util import mkpath
import logging
__version__="2020.06.23.001"
__version__="2020.07.06.001"
try:
# https://github.com/borntyping/python-colorlog
......
......@@ -12,12 +12,12 @@ import numpy as np
import mne
from mne.preprocessing import ICA
from dcnn_utils import (logger, get_chop_times_indices, auto_label_cardiac,
from dcnn_utils import (logger,get_chop_times_indices, auto_label_cardiac,expandvars,
auto_label_ocular, get_unique_list, transform_mne_ica2data)
from dcnn_base import DCNN_MEG, DCNN_ICA, DCNN_SOURCES
from dcnn_utils import collect_source_info
from dcnn_base import DCNN_PATH,DCNN_MEG, DCNN_ICA, DCNN_SOURCES
__version__= "2020.06.30.001"
__version__= "2020.07.06.001"
# -----------------------------------------------------
# compute downsampling frequency and final chop length
......@@ -116,7 +116,7 @@ class DCNN(object):
"""
def __init__(self, version=None, n_jobs=1, path=None, meg={}, ica={}, sources={}, res={},
def __init__(self, version=None, n_jobs=1, path={}, meg={}, ica={}, sources={}, res={},
info=None, verbose=False):
"""
init via config dict:
......@@ -126,7 +126,7 @@ class DCNN(object):
----------
version : string, optional, default: None
n_jobs : int, optional, default: 1.
path : dict, optional, default: None
path : dict, optional, default: {}
meg : dict, optional, default: {}
init MEG CLS
......@@ -144,7 +144,7 @@ class DCNN(object):
"""
self.version = version
self.path = path
#self.path = path
self.config_info = info
if res:
......@@ -153,12 +153,16 @@ class DCNN(object):
self.model = {'res_time': None ,'res_space': None}
# the following will later be saved in one file for each experiment
self._PATH = DCNN_PATH(**path)
self._MEG = DCNN_MEG(**meg) # MEG system and data settings
self._ICA = DCNN_ICA(**ica)
self._SOURCES = DCNN_SOURCES(**sources)
self.n_jobs = n_jobs # set n_jobs for MEG not ICA
self.verbose = verbose
@property
def path(self): return self._PATH
@property
def meg(self): return self._MEG
......@@ -184,9 +188,10 @@ class DCNN(object):
@verbose.setter
def verbose(self ,v):
self._verbose = v
self.meg.verbose = v
self.ica.verbose = v
self._verbose = v
self.path.verbose = v
self.meg.verbose = v
self.ica.verbose = v
self.sources.verbose = v
def _meg_filter_resample(self,raw=None):
......@@ -706,6 +711,7 @@ class DCNN(object):
# data = pickle.load(f)
# load gDCNN data
fname= expandvars(fname)
logger.info('Start loading gDCNN results from disk ....\n -> {}'.format(fname))
npz = np.load(fname, allow_pickle=True)
logger.info('Done loading gDCNN results from disk\n')
......@@ -715,7 +721,7 @@ class DCNN(object):
logger.info("gDCNN data: {}\n{}".format(k, v))
# get details
self.path = npz.get('path').item() # load dict
self.path.init(**npz.get('path').item()) # load dict
logger.info("path:\n -> {}".format(self.path))
self.model = npz.get('model').item() # load dict
......@@ -747,7 +753,7 @@ class DCNN(object):
logger.info("input fnout: ".format(fnout))
name = op.basename(self.meg.fname).rsplit('.')[0]
if not path_out:
path_out = self.path['data_train']
path_out = self.path.data_train
if not op.exists(path_out):
makedirs(path_out)
fnout = os.path.join(path_out, name + '-gdcnn.npz')
......@@ -760,7 +766,7 @@ class DCNN(object):
self.model['fname_gdcnn'] = fnout
np.savez(fnout, path=self.path, model=self.model,
np.savez(fnout,model=self.model, path=self.path.dump(),
meg=self.meg.dump(), ica=self.ica.dump(), sources=self.sources.dump())
logger.info('Done saving results to disk: {}'.format(fnout))
......@@ -771,15 +777,12 @@ class DCNN(object):
# get details
# --- init dicts
msg = ["DCNN Info",
" -> path :",
" -> {}".format(self.path),
" -> sources:",
" -> {}".format(self.sources),
" -> model:",
" -> {}".format(self.model)]
logger.info("\n".join(msg))
# -- objs
self.path.get_info()
self.meg.get_info()
self.ica.get_info()
self.sources.get_info()
......
......@@ -6,7 +6,7 @@ import os,sys
from dcnn_base import DCNN_CONFIG
from dcnn_main import DCNN
from dcnn_logger import setup_script_logging,init_logfile
__version__= "2020-06-05-001"
__version__= "2020-07-06-001"
####################################################
......@@ -41,16 +41,16 @@ def run(fnconfig=None,basedir=None,data_meg=None,pattern='-raw.fif',
dcnn.verbose = True
if basedir: # FB test
cfg.config['path']['basedir'] = basedir
dcnn.path.basedir = basedir
if data_meg:
cfg.config['path']['data_meg']= data_meg # input directory
dcnn.path.data_meg = data_meg # input directory
# ==========================================================
# run ICA auto labelling
# ==========================================================
if do_label_ica:
# get file list to process
path_in = os.path.join( cfg.config['path']['basedir'],cfg.config['path']['data_meg'])
path_in = os.path.join( dcnn.path.basedir,dcnn.path.data_meg)
fnames = find_files(path_in, pattern=pattern)
if not fnames:
......@@ -86,7 +86,7 @@ def run(fnconfig=None,basedir=None,data_meg=None,pattern='-raw.fif',
# check ICA auto labelling
# ==========================================================
if do_label_check:
path_in = cfg.config['path']['data_train']
path_in = dcnn.path.data_train
fnames = find_files(path_in, pattern= '*.npz')
fname = fnames[0] # use first file for testing
......@@ -99,8 +99,6 @@ def run(fnconfig=None,basedir=None,data_meg=None,pattern='-raw.fif',
logger.info("dcnn ica n_chop.\n{}\n\n".format(dcnn.ica.chop.n_chop))
if __name__ == "__main__":
# -- get parameter / flags from cmd line
argv = sys.argv
......
......@@ -6,7 +6,7 @@ from dcnn_main import DCNN
from dcnn_logger import setup_script_logging #FB
from dcnn_utils import file_looper,get_args,dict2str,expandvars #FB
__version__= "2020-06-23-001"
__version__= "2020-07-06-001"
####################################################
......@@ -19,7 +19,7 @@ fnconfig = 'config_CTF_Philly.yaml'
pattern = '-raw.fif'
verbose = True
do_label_ica = True
do_label_ica = False
do_label_check = True
do_performance_check = False
# --
......@@ -34,27 +34,31 @@ def run(fnconfig=None,basedir=None,data_meg=None,data_train=None,pattern='-raw.f
# -- init config CLS
cfg = DCNN_CONFIG(verbose=verbose)
cfg.load(fname=fnconfig)
#-- init dcnn CLS
dcnn = DCNN(**cfg.config) # init object with config details
dcnn.verbose = True
#---
if basedir: # FB test
cfg.config['path']['basedir'] = expandvars(basedir)
dcnn.path.basedir = basedir
if data_meg:
cfg.config['path']['data_meg']= expandvars(data_meg) # input directory
dcnn.path.data_meg = data_meg # input directory
if data_train:
cfg.config['path']['data_train']= expandvars(data_train) # input directory
dcnn.path.data_train = data_train # input directory
dcnn.verbose = True
dcnn.get_info()
# ==========================================================
# run ICA auto labelling
# ==========================================================
if do_label_ica:
path_in = os.path.join(cfg.config['path']['basedir'],cfg.config['path']['data_meg'])
# -- looper catch error via try/exception setup log2file
#path_in = os.path.join(cfg.config['path']['basedir'],cfg.config['path']['data_meg'])
path_in = dcnn.path.data_meg
# -- looper catch error via try/exception setup log2file
for fnraw in file_looper(rootdir=path_in,pattern=pattern,version=__version__,verbose=verbose,logoverwrite=True,log2file=log2file):
logger.info(fnraw)
#logger.info(fnraw)
# -- read raw data and apply noise reduction
dcnn.meg.update(fname=fnraw)
......@@ -70,15 +74,28 @@ def run(fnconfig=None,basedir=None,data_meg=None,data_train=None,pattern='-raw.f
# check ICA auto labelling
# ==========================================================
if do_label_check:
path_in = cfg.config['path']['data_train']
path_in = dcnn.path.data_train
for fname in file_looper(rootdir=path_in, pattern='*.npz',version=__version__,verbose=verbose,log2file=log2file,logoverwrite=False):
dcnn.load_gdcnn(fname)
# check IC labels (and apply corrections)
dcnn.check_labels(save=True)
#if verbose: # ToDo set verbose level True,2,3
#dcnn.check_labels(save=True)
# check IC labels (and apply corrections)
name = os.path.basename(fname[:-4])
print ('>>> working on %s' % name)
figs, captions = dcnn.plot_ica_traces(fname)
report.add_figs_to_section(figs, captions=captions, section=name, replace=True)
report.save(fnreport + '.h5', overwrite=True)
report.save(fnreport + '.html', overwrite=True, open_browser=True)
#dcnn.check_labels(save=True, path_out=cfg.config['path']['data_train'])
#if verbose: # ToDo set verbose level True,2,3
# logger.debug("dcnn ica chop dump.\n{}\n{}\n\n".format( dcnn.ica.chop, dict2str(dcnn.ica.chop.dump()) ))
# logger.debug("dcnn ica n_chop.\n{}\n\n".format(dcnn.ica.chop.n_chop))
# logger.debug("dcnn ica topo data.\n{}\n\n".format(dcnn.ica.topo.data))
......@@ -116,12 +133,12 @@ if __name__ == "__main__":
opt.do_label_check = do_label_check
opt.log2file = do_log2file
elif opt.fb: # call from shell
opt.basedir = "$JUMEG_LOCAL_DATA"+"/dcnn"
opt.data_meg = "data_examples"
opt.data_train = "$JUMEG_LOCAL_DATA"+"/dcnn/ica_labeled/Juelich"
#opt.basedir = "$JUMEG_PATH_LOCAL_DATA"+"/gDCNN"
#opt.data_meg = "data_examples"
#opt.data_train = "$JUMEG_PATH_LOCAL_DATA"+"/exp/dcnn/ica_labeled/Juelich"
#-- 4D
opt.pattern = "*nr-raw.fif" #"*.c,rfDC_bcc,nr-raw.fif"
opt.config = "config_4D.yaml"
opt.pattern = "*-raw.fif" #"*.c,rfDC_bcc,nr-raw.fif"
opt.config = "config_4D_Juelich.yaml"
opt.ica = do_label_ica
opt.check = do_label_check
......
......@@ -19,7 +19,7 @@ import pprint
from dcnn_logger import get_logger,init_logfile
logger = get_logger()
__version__= "2020.06.23.001"
__version__= "2020.07.06.001"
#--- FB
def setup_logfile(fname,version=__version__,verbose=False,overwrite=True,level=None):
......@@ -151,9 +151,11 @@ def expandvars(v):
"""
if not v: return None
if isinstance(v,(list)):
for i in range(len(v)):
v[i] = os.path.expandvars(os.path.expanduser( str(v[i]) ))
return v
for i in range(len(v)):
v[i] = os.path.expandvars(os.path.expanduser( str(v[i]) ))
return v
else:
return os.path.expandvars(os.path.expanduser(str(v)))
return os.path.expandvars(os.path.expanduser( str(v) ))
......@@ -176,15 +178,15 @@ def isPath(pin,head="check path exist",exit_on_error=False,logmsg=False,mkdir=Fa
p = os.path.abspath(expandvars(pin))
if os.path.isdir(p):
if logmsg:
logger.info(head+"\n --> dir exist: {}\n -> abs dir{:>18} {}".format(pin,':',p))
logger.info(head+"\n --> dir exist: {}\n -> abs dir: {}".format(pin,p))
return p
elif mkdir:
os.makedirs(p)
if logmsg:
logger.info(head+"\n --> make dirs: {}\n -> abs dir{:>18} {}".format(pin,':',p))
logger.info(head+"\n --> make dirs: {}\n -> abs dir: {}".format(pin,p))
return p
#--- error no such file
logger.error(head+"\n --> no such directory: {}\n -> abs dir{:>18} {}".format(pin,':',p))
logger.error(head+"\n --> no such directory: {}\n -> abs dir: {}".format(pin,p))
if exit_on_error:
raise SystemError()
return False
......@@ -442,9 +444,9 @@ def find_files(rootdir='.', pattern='*', recursive=False):
"""
import os
import fnmatch
files = []
for root, dirnames, filenames in os.walk(rootdir):
rootdir = expandvars(rootdir)
files = []
for root, dirnames, filenames in os.walk( rootdir ):
if not