# Init and code

In [None]:
from datetime import datetime
import os
import re
import numpy as np

%matplotlib ipympl
import matplotlib.pyplot as plt
plt.rcParams.update({'figure.figsize': (9, 7)})   # default figure size (inches)

from IPython.display import display, HTML
# Let the notebook container use (almost) the full browser width
_wideContainerCss = "<style>.container { width:99% !important; }</style>"
display(HTML(_wideContainerCss))

In [None]:
import lsst.daf.persistence as dafPersist
import lsst.afw.image as afwImage
import lsst.afw.display as afwDisplay

from pfs.drp.stella.utils import addPfsCursor, showDetectorMap

afwDisplay.setDefaultBackend("matplotlib")
# Default mask-plane colours: BAD_FLAT in cyan, REFLINE set to IGNORE (not drawn)
for _plane, _color in [("BAD_FLAT", afwDisplay.CYAN), ("REFLINE", afwDisplay.IGNORE)]:
    afwDisplay.setDefaultMaskPlaneColor(_plane, _color)

In [None]:
def butlerName(_butler=None):
    """Return the root of a butler's first input repository, with everything
    up to and including the last ``/rerun/`` removed (i.e. the rerun name).

    Parameters
    ----------
    _butler : `lsst.daf.persistence.Butler`, optional
        The butler to describe; defaults to the module-level ``butler``.
    """
    theButler = butler if _butler is None else _butler
    rootDir = theButler._repos.inputs()[0]._repoArgs.root
    return re.sub(r"^.*/rerun/", "", rootDir)

# butlers

In [None]:
butlers = {}

# Use the Subaru-side repository if it is mounted here, otherwise /work/drp
repo = "/projects/HSC/PFS/Subaru" if os.path.exists("/projects/HSC/PFS/Subaru") else "/work/drp"

reruns = [
    "rhl/eng-2024-08",
]
for rerun in reruns:
    kwargs = {}
    # Absolute paths are used as-is; anything else is a rerun under `repo`
    dataDir = rerun if rerun.startswith('/') else os.path.join(repo, "rerun", rerun)

    if "tanaka/fiberProfiles/FIBERPROFILES" in rerun:
        # NOTE(review): the "drp" butler below is given `calibRoot`; confirm that the mapper
        # really accepts a `calibName` keyword here
        kwargs.update(calibName="/work/drp/rerun/hassans/tanaka/fiberProfiles/FIBERPROFILES")

    if not os.path.exists(dataDir):
        print(f"Skipping {dataDir}: not found")
        continue

    print(dataDir)
    butlers[rerun] = dafPersist.Butler(dataDir, **kwargs)

if os.path.exists("/work/drp"):
    kwargs = {}
    repoRoot = "/work/drp"

    calibName = "CALIB"
    kwargs.update(calibRoot=os.path.join(repoRoot, calibName))
    rerun = f'drpActor/{calibName}'

    butlers["drp"] = dafPersist.Butler(os.path.join(repoRoot, 'rerun', rerun), **kwargs)

if not butlers:
    raise RuntimeError(f"No butler could be created: none of {reruns} exist under {repo}, "
                       "and /work/drp is not available")

butler = next(iter(butlers.values()))   # default: the first butler created

butlerName()

# Read and process CR splits

## Read data
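N.b. we needn't reset `postISRCCDs` and `pfsArms` for new visits, but doing so saves memory and makes rerunning the pipeline easier (`detMaps` and `pfsConfigs` are kept between runs, as they're only needed for display)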

In [None]:
# The CR-split visits to combine
visits = [
    111488,
    111489,
    111490,
]

def dataIdToKey(did, **kwargs):
    """Return the ``(visit, arm, spectrograph)`` key used to index the per-detector dicts.

    Parameters
    ----------
    did : `dict`
        Data ID; must supply ``visit``, ``arm`` and ``spectrograph`` unless they are given in ``kwargs``.
    **kwargs
        Overrides for entries of ``did`` (e.g. ``visit=visits[0]``); ``did`` itself is not modified.
    """
    merged = {**did, **kwargs}
    return tuple(merged[name] for name in ("visit", "arm", "spectrograph"))

def skipDetector(dataId, detectors=(("r", 2),)):
    """Return True to skip this dataId

    Parameters
    ----------
    dataId : `dict`
        Data ID; only its ``arm`` and ``spectrograph`` entries are used.
    detectors : iterable of `tuple` (`str`, `int`), or `None`
        The ``(arm, spectrograph)`` pairs to process; all other detectors are skipped.
        If `None`, process all detectors.

    Returns
    -------
    skip : `bool`
        True if this detector should not be processed.
    """
    # NOTE(review): n1 and n4 may also have been meant to be processed; a clause
    # ``not (r2) or (arm == 'n' and spectrograph in [1, 4])`` was redundant as written
    # (``not`` applied only to the r2 test).  Pass/add ("n", 1), ("n", 4) if so.
    if detectors is None:      # process all detectors
        return False
    return (dataId["arm"], dataId["spectrograph"]) not in detectors

# detMaps and pfsConfigs are only needed for display, so keep them across re-runs of this cell
try:
    detMaps
except NameError:
    detMaps = {}
    pfsConfigs = {}

# Per-visit data are re-read every time this cell is run
postISRCCDs = {}
pfsArms = {}

for v, spectrograph, arm in ((vv, ss, aa) for vv in visits for ss in (1, 2, 3, 4) for aa in "brn"):
    dataId = dict(visit=v, arm=arm, spectrograph=spectrograph)
    if skipDetector(dataId):
        continue
    kk = dataIdToKey(dataId)

    if kk not in detMaps:
        detMaps[kk] = butler.get("detectorMap_used", dataId)  # only needed for display
        pfsConfigs[kk] = butler.get("pfsConfig", dataId)  # only needed for display

    postISRCCDs[kk] = butler.get('postISRCCD', dataId)
    pfsArms[kk] = None   # not currently read; was: butler.get("pfsArm", dataId)

## Estimate a "clean" version for each `(visit, arm, spectrograph)`

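We'd like to use a median, but with only a few visits we use a per-pixel minimum instead (currently whenever there are fewer than five visits).  Each visit is first scaled to the median pixel value of the first visit.  The choice between minimum and median is made in code, but we should think about what to do for e.g. four or six visits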

In [None]:
# Keep any existing `clean`/`fluxes` dicts; entries for every processed detector are recomputed below
try:
    clean
except NameError:
    clean = {}
    fluxes = {}

dataId = {}
for spectrograph in (1, 2, 3, 4):
    for arm in "brn":
        dataId.update(spectrograph=spectrograph, arm=arm)
        if skipDetector(dataId):
            continue

        # The clean estimate is stored under the key for the first visit
        kk0 = dataIdToKey(dataId, visit=visits[0])
        ref = postISRCCDs[kk0]
        nVisit = len(visits)
        stack = np.empty((nVisit, ref.getHeight(), ref.getWidth()), dtype=np.float32)

        for i, v in enumerate(visits):
            dataId.update(visit=v)
            kk = dataIdToKey(dataId)

            pixels = postISRCCDs[kk].image.array
            fluxes[kk] = np.nanmedian(pixels)   # median pixel value, used as this visit's flux level

            # Copy this visit into the stack, scaled to the flux level of visits[0]
            stack[i] = pixels
            stack[i] *= fluxes[kk0]/fluxes[kk]

        # Per-pixel minimum for a few visits, median once there are enough of them
        combine = np.nanmedian if nVisit >= 5 else np.nanmin
        clean[kk0] = afwImage.MaskedImageF(ref.getDimensions())
        clean[kk0].image.array = combine(stack, axis=0)

### Display the junk in the exposure, i.e. the postISRCCD with the clean estimate of the signal removed

This will have more than cosmic rays in it, as there are at least three components in the data (objects, OH, and O$_2$) and they vary independently depending on atmospheric chemistry and the seeing.

In [None]:
# Show the "junk" (postISRCCD minus the rescaled clean estimate) for one detector of one visit.
# Build dataId here rather than updating the one left over from the previous cell.
v = visits[1]
dataId = dict(visit=v, arm='r', spectrograph=2)

kk = dataIdToKey(dataId)
kk0 = dataIdToKey(dataId, visit=visits[0])   # `clean` is keyed by the first visit

im = postISRCCDs[kk].clone()
# `clean` is at the flux level of visits[0]; rescale it to this visit's before subtracting
im.image.array -= clean[kk0].image.array*fluxes[kk]/fluxes[kk0]

fig = 1; plt.close(fig); fig = plt.figure(fig)
disp = afwDisplay.Display(fig)
# alternative stretch: disp.scale('asinh', 'zscale', Q=3)
disp.scale('asinh', 0, 10, Q=5)

disp.mtv(im, title='%(visit)d %(arm)s%(spectrograph)d' % dataId)
addPfsCursor(disp, detMaps[kk]);   # this visit's detectorMap, matching the image shown

disp.zoom(128, 1670, 2228)

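## Write calexps with cosmic rays masked
Note that we haven't done the full repair: the CRs are only flagged in the `CR` mask plane (the pixels are not interpolated), and e.g. scattered light has not been removed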

In [None]:
from lsst.pex.config import makePropertySet
from lsst.pipe.tasks.repair import RepairConfig
from lsst.meas.algorithms import findCosmicRays
from lsst.afw.detection import setMaskFromFootprintList

from lsst.ip.isr.isrFunctions import growMasks

repConfig = RepairConfig()
config = makePropertySet(repConfig.cosmicray)
config["nCrPixelMax"] = 1000000

# Passed as findCosmicRays' `keep` argument; only the "junk" clone is given to findCosmicRays,
# so the postISRCCD pixels are untouched either way
keepCRs = True

modelPsfConfig = repConfig.interp.modelPsf
modelPsfConfig.defaultFwhm = 3.5 + 0.5
psf = modelPsfConfig.apply()

nGrow = 3             # radius passed to growMasks before looking for CRs common to all visits
writeCalexps = True   # set False to find/flag CRs without writing calexps to the butler

dataId = {}
for spectrograph in (1, 2, 3, 4):
    for arm in "brn":
        dataId.update(spectrograph=spectrograph, arm=arm)
        if skipDetector(dataId):
            continue
        kk0 = dataIdToKey(dataId, visit=visits[0])
        crBit = postISRCCDs[kk0].mask.getPlaneBitMask("CR")

        # Find CRs in each visit's "junk" image and set them in that visit's postISRCCD mask.
        # NOTE(review): CR bits already in postISRCCDs[kk].mask are never cleared, so re-running
        # this cell without re-reading the postISRCCDs accumulates them
        for visit in visits:
            dataId.update(visit=visit)
            kk = dataIdToKey(dataId)

            im = postISRCCDs[kk].clone()
            im.setPsf(psf)
            # `clean` is at the flux level of visits[0]; rescale it to this visit's
            im.image.array -= clean[kk0].image.array*fluxes[kk]/fluxes[kk0]
            im.mask &= ~im.mask.getPlaneBitMask("CR")   # start the clone's CR plane afresh

            cosmicrays = findCosmicRays(im.maskedImage, psf, 0.0, config, keepCRs)

            num = 0
            numPixels = 0
            if cosmicrays is not None:
                mask = postISRCCDs[kk].mask
                setMaskFromFootprintList(mask, cosmicrays, crBit)

                num = len(cosmicrays)
                numPixels = np.sum(mask.array & crBit != 0)

            print(f"{kk} nCR {num} nPixel {numPixels}")

        # Find pixels where the grown CR masks overlap in every input visit
        grownCRs = np.empty((len(visits), postISRCCDs[kk0].getHeight(), postISRCCDs[kk0].getWidth()),
                            dtype=np.int32)
        for i, visit in enumerate(visits):
            kk = dataIdToKey(dataId, visit=visit)

            grown = postISRCCDs[kk].mask.clone()
            grown &= crBit                      # keep only the CR plane
            growMasks(grown, nGrow, "CR", "CR")

            grownCRs[i] = grown.array
        inAllVisits = np.min(grownCRs, axis=0) > 0

        # Flag those pixels as CR in every visit.
        # NOTE(review): no separate (e.g. EDGE) bit is set to record these pixels; confirm that
        # setting only CR is the intent
        for visit in visits:
            kk = dataIdToKey(dataId, visit=visit)
            postISRCCDs[kk].mask.array[inAllVisits] |= crBit

        if writeCalexps:
            for visit in visits:
                dataId.update(visit=visit)
                butler.put(postISRCCDs[dataIdToKey(dataId)], 'calexp', dataId)

### Display one of those calexps

In [None]:
fig = 2; plt.close(fig); fig = plt.figure(fig)
disp = afwDisplay.Display(fig)
disp.scale('asinh', 'zscale', Q=8)
disp.scale('asinh', -20, 30, Q=3)

useCachedExposure = False   # True: show the in-memory postISRCCD rather than reading the calexp
showDifference = False      # True: show (visits[0] calexp) - (this calexp) instead

visit = visits[1]
dataId = dict(visit=visit, arm='r', spectrograph=2)
kk = dataIdToKey(dataId)
if useCachedExposure:
    exp = postISRCCDs[kk]
else:
    exp = butler.get("calexp", dataId)
if showDifference:
    exp = exp.clone()
    exp.maskedImage -= butler.get("calexp", dataId, visit=visits[0]).maskedImage
    exp.maskedImage *= -1
title = '%(visit)d %(arm)s%(spectrograph)d' % dataId
disp.mtv(exp, title=title)
addPfsCursor(disp, detMaps[kk]);

#disp.zoom(128//2, 2276, 1128)   # horizontal CR which we missed
#disp.zoom(2*128, 1684, 3064)
disp.zoom(128, 1670, 2228)