# Source code for drizzlepac.tweakutils

"""
:Authors: Warren Hack

:License: :doc:`/LICENSE`

"""
import string
import os
import sys

import numpy as np
from scipy import signal, ndimage

from stsci.tools import asnutil, irafglob, parseinput, fileutil, logutil
from astropy.io import fits
import astropy.coordinates as coords
import astropy.units as u
from astropy.utils import deprecated

import stsci.imagestats as imagestats

from . import findobj
from . import cdriz

# Names exported by ``from drizzlepac.tweakutils import *``.
__all__ = [
    'parse_input', 'atfile_sci', 'parse_atfile_cat', 'ndfind',
    'get_configobj_root', 'isfloat', 'parse_skypos', 'make_val_float',
    'radec_hmstodd', 'parse_exclusions', 'parse_colname', 'readcols',
    'read_FITS_cols', 'read_ASCII_cols', 'write_shiftfile', 'createWcsHDU',
    'idlgauss_convolve', 'gauss_array', 'gauss', 'make_vector_plot',
    'apply_db_fit', 'write_xy_file', 'find_xy_peak', 'plot_zeropoint',
    'build_xy_zeropoint', 'build_pos_grid'
]

# All ASCII letters and how many there are; used to build a translation
# table that maps every letter to a blank.
_ASCII_LETTERS = string.ascii_letters
_NASCII = len(string.ascii_letters)

# Module-level logger, created with level NOTSET.
log = logutil.create_logger(__name__, level=logutil.logging.NOTSET)


def _is_str_none(s):
    if s is None or s.strip().upper() in ['', 'NONE', 'INDEF']:
        return None
    return s


def parse_input(input, prodonly=False, sort_wildcards=True):
    """ Interpret ``input`` and return the files (and catalogs) it names.

    Parameters
    ----------
    input : str or list of str
        One of the following:

        * name of an association table (a string containing ``_asn`` or
          ``_asc``),
        * an ``@``-file specification (a string starting with ``@``),
        * a Python list whose elements are any of the string forms,
        * any other string, which is handed to ``parseinput.parseinput``.

    prodonly : bool [Default: False]
        Passed through to ``asnutil.readASNTable`` for association tables.

    sort_wildcards : bool [Default: True]
        If True, the file names produced by an input containing ``*`` are
        sorted.

    Returns
    -------
    filelist : list
        File names extracted from ``input``.

    catlist : list or None
        Catalog names read from the extra columns of an ``@``-file;
        ``None`` for every other kind of input.

    """
    catlist = None

    if not isinstance(input, list) and ('_asn' in input or '_asc' in input):
        # Input is an association table. Get the input files
        oldasndict = asnutil.readASNTable(input, prodonly=prodonly)
        filelist = [fileutil.buildRootname(fname) for fname in
                    oldasndict['order']]

    elif not isinstance(input, list) and input[0] == '@':
        # input is an @ file
        # Read the first line in order to determine whether
        # catalog files have been specified in a second column...
        with open(input[1:]) as f:
            line = f.readline()

        # Parse the @-file with irafglob to extract the input filename
        filelist = irafglob.irafglob(input, atfile=atfile_sci)
        print(line)

        # If there are additional columns for catalog files...
        if len(line.split()) > 1:
            # ...parse out the names of the catalog files as well
            catlist, _ = parse_atfile_cat(input)

    elif isinstance(input, list):
        # input a python list
        filelist = []
        for fn in input:
            # Forward 'sort_wildcards' so that the caller's choice is
            # honored for every element; without it the nested call would
            # always sort wild-card expansions.
            flist, _ = parse_input(fn, prodonly=prodonly,
                                   sort_wildcards=sort_wildcards)
            # if wild-cards are given, sort for uniform usage:
            if fn.find('*') > -1 and sort_wildcards:
                flist.sort()
            filelist += flist

    else:
        # input is either a string or something unrecognizable,
        # so give it a try:
        filelist, _ = parseinput.parseinput(input)
        # if wild-cards are given, sort for uniform usage:
        if input.find('*') > -1 and sort_wildcards:
            filelist.sort()

    return filelist, catlist


def atfile_sci(line):
    """ Return the first whitespace-delimited token of an ``@``-file line.

    An empty string is returned when ``line`` is ``None`` or blank.
    """
    if line is None:
        return ''
    tokens = line.split()
    return tokens[0] if tokens else ''


def parse_atfile_cat(input):
    """
    Return the list of catalog filenames specified as part of the input @-file

    Parameters
    ----------
    input : str
        ``@``-file specification; the leading character is dropped to get
        the name of the file that is read.

    Returns
    -------
    catlist : list
        One entry per non-blank, non-comment line: the list of tokens that
        follow the first token, or ``None`` when there are none.

    catdict : dict
        The same entries keyed by the first token of each line.

    """
    catlist = []
    catdict = {}
    with open(input[1:]) as f:
        for line in f:
            # skip comment lines and blank lines
            if line[0] == '#' or not line.strip():
                continue
            tokens = line.split()
            catdict[tokens[0]] = tokens[1:] or None
            catlist.append(tokens[1:] or None)
    return catlist, catdict
# functions to help work with configobj input
def get_configobj_root(configobj):
    """ Return the lower-case-named entries of ``configobj`` as a dict. """
    # Only copy in those entries which start with lower case letters
    # since sections are all upper-case for this task
    return {key: configobj[key] for key in configobj if key[0].islower()}


def ndfind(array, hmin, fwhm, skymode,
           sharplim=[0.2, 1.0], roundlim=[-1, 1], minpix=5,
           peakmin=None, peakmax=None, fluxmin=None, fluxmax=None,
           nsigma=1.5, ratio=1.0, theta=0.0,
           mask=None, use_sharp_round=False, nbright=None):
    """ Find sources in ``array`` with ``findobj.findstars``.

    All detection parameters are forwarded to ``findobj.findstars``;
    ``sharplim`` and ``roundlim`` supply its ``sharplo``/``sharphi`` and
    ``roundlo``/``roundhi`` limits. ``minpix`` is accepted but not used
    here.

    Returns
    -------
    tuple
        ``(x, y, flux, id, sharp, round1, round2)``. The last three are
        ``None`` unless ``use_sharp_round`` is True. ``id`` is a running
        index starting at 0. When ``nbright`` is not None the sources are
        ordered by decreasing flux.

    """
    # NOTE(review): when nothing is found and use_sharp_round is False, the
    # returned tuple has 4 elements although the normal return has 7.
    # NOTE(review): 'nbright' only triggers sorting; no truncation is done
    # in this function.
    star_list, fluxes = findobj.findstars(
        array, fwhm, hmin, skymode,
        peakmin=peakmin, peakmax=peakmax,
        fluxmin=fluxmin, fluxmax=fluxmax,
        ratio=ratio, nsigma=nsigma, theta=theta,
        use_sharp_round=use_sharp_round,
        mask=mask,
        sharplo=sharplim[0], sharphi=sharplim[1],
        roundlo=roundlim[0], roundhi=roundlim[1]
    )

    if len(star_list) == 0:
        print('No valid sources found...')
        n_out = 7 if use_sharp_round else 4
        return tuple([] for _ in range(n_out))

    # one array per measured quantity instead of one row per source
    columns = list(np.array(star_list).T)
    fluxes = np.array(fluxes, float)

    if nbright is not None:
        brightest_first = np.argsort(fluxes)[::-1]
        fluxes = fluxes[brightest_first]
        columns = [col[brightest_first] for col in columns]

    xpos = columns[0]
    ypos = columns[1]
    src_id = np.arange(xpos.size)

    if use_sharp_round:
        return (xpos, ypos, fluxes, src_id,
                columns[2], columns[3], columns[4])
    return (xpos, ypos, fluxes, src_id, None, None, None)
def isfloat(value):
    """ Return True if ``value`` can be converted with ``float()``.

    Only ``ValueError`` is treated as "not a float"; any other exception
    raised by the conversion propagates to the caller.
    """
    try:
        float(value)
    except ValueError:
        return False
    return True
def parse_skypos(ra, dec):
    """
    Function to parse RA and Dec input values and turn them into decimal
    degrees

    Input formats could be:
        ["nn","nn","nn.nn"]
        "nn nn nn.nnn"
        "nn:nn:nn.nn"
        "nnH nnM nn.nnS" or "nnD nnM nn.nnS"
        nn.nnnnnnnn
        "nn.nnnnnnn"

    Returns
    -------
    (ra, dec) : tuple
        Both values as produced by ``make_val_float`` when the RA converts
        directly to a float, otherwise the result of ``radec_hmstodd``.

    """
    ra_deg = make_val_float(ra)
    dec_deg = make_val_float(dec)
    if ra_deg is not None:
        return ra_deg, dec_deg
    # RA is not a plain number: convert both values as sexagesimal input
    ra_deg, dec_deg = radec_hmstodd(ra, dec)
    return ra_deg, dec_deg
def make_val_float(val):
    """ Return ``float(val)``, or ``None`` if ``val`` cannot be converted.

    Values of an unsupported type (for example a list such as
    ``["nn", "nn", "nn.nn"]`` or ``None``) also yield ``None`` instead of
    raising, so that callers can fall back to another parser.
    """
    try:
        return float(val)
    except (TypeError, ValueError):
        return None
def radec_hmstodd(ra, dec):
    """ Function to convert HMS values into decimal degrees.

    This function relies on the astropy.coordinates package to perform the
    conversion to decimal degrees.

    Parameters
    ----------
    ra : list, str or number
        Input RA position. Strings and lists are interpreted as
        hours/minutes/seconds.

    dec : list, str or number
        Input Dec position. Strings and lists are interpreted as
        degrees/minutes/seconds.

    Returns
    -------
    pos : tuple
        ``(ra, dec)`` in decimal degrees. When both inputs are already
        numbers they are returned unchanged.

    Raises
    ------
    TypeError
        If only one of ``ra``/``dec`` is a number; the two values must be
        given in the same form.

    Notes
    -----
    This function supports any specification of RA and Dec as HMS or DMS;
    specifically, the formats::

        ["nn","nn","nn.nn"]
        "nn nn nn.nnn"
        "nn:nn:nn.nn"
        "nnH nnM nn.nnS" or "nnD nnM nn.nnS"

    See Also
    --------
    astropy.coordinates

    """
    # Table which turns every ASCII letter (the unit markers) into a blank
    letters_to_blanks = str.maketrans(
        string.ascii_letters, ' ' * len(string.ascii_letters)
    )

    def _to_colon_form(value):
        """ Return ``value`` as a single 'nn:nn:nn.nn' string. """
        if isinstance(value, (list, tuple)):
            return ':'.join(value)
        if ':' in value:
            return value
        # Blank out the unit letters (we already know the units), then
        # join the remaining fields; split() absorbs any run of blanks so
        # "nnH nnM nn.nnS" cannot produce empty '::' fields.
        return ':'.join(value.translate(letters_to_blanks).split())

    numeric = (int, float, np.number)
    ra_is_number = isinstance(ra, numeric)
    dec_is_number = isinstance(dec, numeric)

    if ra_is_number and dec_is_number:
        return (ra, dec)
    if ra_is_number or dec_is_number:
        raise TypeError(
            "'ra' and 'dec' must both be numbers or both be "
            "sexagesimal strings/lists."
        )

    pos_coord = coords.SkyCoord(
        _to_colon_form(ra) + ' ' + _to_colon_form(dec),
        unit=(u.hourangle, u.deg)
    )
    return (pos_coord.ra.deg, pos_coord.dec.deg)
def parse_exclusions(exclusions):
    """ Read in exclusion definitions from file named by 'exclusions'
        and return a list of positions and distances

    Returns
    -------
    exclusion_list : list of dict or None
        One dict per parsed region with the keys ``'pos'``, ``'distance'``
        and ``'units'``; ``None`` if the file does not exist.

    """
    fname = fileutil.osfn(exclusions)
    if not os.path.exists(fname):
        print('No valid exclusions file "', fname, '" could be found!')
        print('Skipping application of exclusions files to source catalogs.')
        return None

    with open(fname) as f:
        flines = f.readlines()

    # Parse out lines which can be interpreted as positions and distances
    exclusion_list = []
    units = None
    for line in flines:
        if line[0] == '#' or 'global' in line[:6]:
            continue
        # Only interpret the part of the line prior to the comment
        # if a comment has been attached to the line
        if '#' in line:
            line = line.split('#')[0].rstrip()

        if units is None:
            # The first usable line only sets the units ('pixels' unless
            # it starts with a sky-frame keyword) and is then skipped.
            units = 'sky' if line[:3] in ['fk4', 'fk5', 'sky'] else 'pixels'
            continue

        if 'circle(' in line:
            stripped = line.replace('circle(', '')
            stripped = stripped.replace(')', '').replace('"', '')
            vals = stripped.split(',')
        else:
            # Try to interpret unformatted line
            vals = line.split(',' if ',' in line else ' ')
            if len(vals) != 3:
                continue

        if ':' in vals[0]:
            posval = vals[0] + ' ' + vals[1]
        else:
            posval = (float(vals[0]), float(vals[1]))

        exclusion_list.append(
            {'pos': posval, 'distance': float(vals[2]), 'units': units}
        )
    return exclusion_list
def parse_colname(colname):
    """ Common function to interpret input column names provided by the user.

    This function translates column specification provided by the user
    into a column number.

    Notes
    -----
    This function will understand the following inputs::

        '1,2,3' or 'c1,c2,c3' or ['c1','c2','c3']
        '1-3' or 'c1-c3'
        '1:3' or 'c1:c3'
        '1 2 3' or 'c1 c2 c3'
        '1' or 'c1'
        1

    Parameters
    ----------
    colname :
        Column name or names to be interpreted

    Returns
    -------
    cols : list
        The return value will be a list of strings.

    """
    if isinstance(colname, list):
        cname = ''.join(str(item) + ',' for item in colname).rstrip(',')
    elif isinstance(colname, int) or colname.isdigit():
        cname = str(colname)
    else:
        cname = colname

    # drop the 'c' prefixes when the specification starts with one
    if 'c' in cname[0]:
        cname = cname.replace('c', '')

    # A range separator expands to every column in between; ':' wins over
    # '-' when both are present.
    if ':' in cname:
        range_tok = ':'
    elif '-' in cname:
        range_tok = '-'
    else:
        range_tok = None

    if range_tok is not None:
        bounds = cname.split(range_tok)
        first, last = int(bounds[0]), int(bounds[1])
        return [str(num) for num in range(first, last + 1)]

    return cname.split(',' if ',' in cname else ' ')
def readcols(infile, cols=None):
    """ Function which reads specified columns from either FITS tables or
        ASCII files

    This function reads in the columns specified by the user into numpy
    arrays regardless of the format of the input table (ASCII or FITS
    table).

    Parameters
    ----------
    infile : string
        Filename of the input file

    cols : string or list of strings
        Columns to be read into arrays

    Returns
    -------
    outarr : array
        Numpy array or arrays of columns from the table; ``None`` when
        ``infile`` is a "null" name as judged by ``_is_str_none``.

    """
    if _is_str_none(infile) is None:
        return None

    # the file name extension selects the reader
    reader = read_FITS_cols if infile.endswith('.fits') else read_ASCII_cols
    return reader(infile, cols=cols)
def read_FITS_cols(infile, cols=None):  # noqa: N802
    """ Read columns from FITS table

    Parameters
    ----------
    infile : str
        Name of the FITS file. The first extension whose header contains
        ``tfields`` is used.

    cols : list or None [Default: None]
        Columns to read. All columns of the table are read when ``cols``
        is ``None``, empty, or when its first entry is ``None`` or a
        "null" string as judged by ``_is_str_none``.

    Returns
    -------
    outarr : list
        One entry per requested column, as returned by the table's
        ``field()`` method.

    Raises
    ------
    ValueError
        If no extension of the file holds a table.

    """
    with fits.open(infile, memmap=False) as ftab:
        table_hdu = next(
            (hdu for hdu in ftab if 'tfields' in hdu.header), None
        )
        if table_hdu is None:
            print('ERROR: No catalog table found in ', infile)
            raise ValueError(
                'No catalog table found in {}'.format(infile)
            )

        table = table_hdu.data

        # Now, read columns from the table in this extension. If no column
        # names were provided by user, simply read in all columns from
        # table. Only string entries are passed to _is_str_none, which
        # expects a string or None.
        first = cols[0] if cols is not None and len(cols) > 0 else None
        if first is None or (isinstance(first, str) and
                             _is_str_none(first) is None):
            cols = table.names

        # Define the output
        outarr = [table.field(c) for c in cols]

    return outarr
def read_ASCII_cols(infile, cols=None):  # noqa: N802
    """ Interpret input ASCII file to return arrays for specified columns.

    Parameters
    ----------
    infile : str
        Name of the ASCII catalog file.

    cols : list or None [Default: None]
        Column specifications, each understood by ``parse_colname``.
        ``None`` selects the columns ``[1, 2, 3]``.

    Notes
    -----
    The specification of the columns should be expected to have lists for
    each 'column', with all columns in each list combined into a single
    entry.
    For example::

        cols = ['1,2,3','4,5,6',7]

    where '1,2,3' represent the X/RA values, '4,5,6' represent the Y/Dec
    values and 7 represents the flux value for a total of 3 requested
    columns of data to be returned.

    Rows that are blank, start with ``#``, contain ``INDEF`` or have fewer
    fields than requested columns are skipped.

    Returns
    -------
    outarr : list of arrays
        The return value will be a list of numpy arrays, one for each
        'column'.

    """
    if cols is None:
        cols = [1, 2, 3]

    with open(infile, 'r') as f:
        flines = f.readlines()

    # build dictionary representing format of each row
    # Format of dictionary: {'colname':col_number,...}
    # This provides the mapping between column name and column number.
    # The first data row defines it; leading blanks are removed before
    # testing for '#' so that an indented comment is not taken for data.
    coldict = {}
    for line in flines:
        stripped = line.strip()
        if not stripped or stripped.startswith('#'):
            continue
        coldict = {str(i + 1): i for i in range(len(stripped.split()))}
        break

    numcols = len(cols)
    # The specifications do not depend on the row: interpret them once.
    colspecs = [parse_colname(c) for c in cols]
    outarr = [[] for _ in range(numcols)]
    convert_radec = False

    # Now, map specified columns to columns in file and populate output
    # arrays
    for line in flines:
        line = line.strip()
        lspl = line.split()
        # skip blank lines, comment lines, or lines with
        # fewer columns than requested by user
        if (not line or len(lspl) < numcols or line[0] == '#' or
                "INDEF" in line):
            continue

        # For each 'column' requested by user, pull data from row
        for values, cnames in zip(outarr, colspecs):
            if len(cnames) > 1:
                # interpret multi-column specification as one value
                values.append(
                    ''.join(lspl[coldict[cn]] + ' ' for cn in cnames)
                )
                convert_radec = True
                continue

            # pull single value from row for this column
            token = lspl[coldict[cnames[0]]]
            if isfloat(token):
                cval = float(token)
            else:
                cval = token
                # Check for multi-column values given as "nn:nn:nn.s"
                if ':' in cval:
                    cval = cval.replace(':', ' ')
                    convert_radec = True
            values.append(cval)

    # convert multi-column RA/Dec specifications
    if convert_radec:
        outra = []
        outdec = []
        for ra, dec in zip(outarr[0], outarr[1]):
            radd, decdd = radec_hmstodd(ra, dec)
            outra.append(radd)
            outdec.append(decdd)
        outarr[0] = outra
        outarr[1] = outdec

    # convert all lists to numpy arrays
    return [np.array(values) for values in outarr]
def write_shiftfile(image_list, filename, outwcs='tweak_wcs.fits'):
    """ Write out a shiftfile for a given list of input Image class objects

    Parameters
    ----------
    image_list : list
        Objects providing ``get_shiftfile_row()``; the first one must also
        provide the ``refWCS`` attribute.

    filename : str
        Name of the shiftfile to write.

    outwcs : str [Default: 'tweak_wcs.fits']
        Name of the reference WCS file written alongside; an existing file
        of that name is replaced.

    """
    reported = (img.get_shiftfile_row() for img in image_list)
    rows = [row for row in reported if row is not None]
    if not rows:
        # If there are no fits to report, do not write out a file
        return

    # write out reference WCS now
    if os.path.exists(outwcs):
        os.remove(outwcs)
    hdulist = fits.HDUList()
    hdulist.append(fits.PrimaryHDU())
    hdulist.append(createWcsHDU(image_list[0].refWCS))
    hdulist.writeto(outwcs)

    # Write out shiftfile to go with reference WCS
    header_lines = [
        '# frame: output\n',
        '# refimage: %s[wcs]\n' % outwcs,
        '# form: delta\n',
        '# units: pixels\n',
    ]
    with open(filename, 'w') as f:
        f.writelines(header_lines)
        f.write(''.join(rows))
    print('Writing out shiftfile :', filename)
def createWcsHDU(wcs):  # noqa: N802
    """ Generate a WCS header object that can be used to populate a reference
    WCS HDU.

    For most applications, stwcs.wcsutil.HSTWCS.wcs2header()
    will work just as well.

    Returns
    -------
    hdu : fits.ImageHDU
        Data-less image HDU whose header is ``wcs.to_header()`` extended
        with EXTNAME, EXTVER, NPIX1, NPIX2, PIXVALUE and ORIENTAT.

    Raises
    ------
    ValueError
        If ``wcs`` has no ``orientat`` attribute and no CD or PC matrix.

    """
    header = wcs.to_header()

    header['EXTNAME'] = 'WCS'
    header['EXTVER'] = 1

    # Now, update original image size information
    header['NPIX1'] = (wcs.pixel_shape[0], "Length of array axis 1")
    header['NPIX2'] = (wcs.pixel_shape[1], "Length of array axis 2")
    header['PIXVALUE'] = (0.0, "values of pixels in array")

    if hasattr(wcs, 'orientat'):
        orientat = wcs.orientat
    else:
        # find orientat from CD or PC matrix
        kernel = wcs.wcs
        if kernel.has_cd():
            cd12 = kernel.cd[0][1]
            cd22 = kernel.cd[1][1]
        elif kernel.has_pc():
            cd12 = kernel.cdelt[0] * kernel.pc[0][1]
            cd22 = kernel.cdelt[1] * kernel.pc[1][1]
        else:
            raise ValueError("Invalid WCS: WCS does not contain neither "
                             "a CD nor a PC matrix.")
        orientat = np.rad2deg(np.arctan2(cd12, cd22))

    header['ORIENTAT'] = (orientat, "position angle of "
                          "image y axis (deg. e of n)")

    return fits.ImageHDU(None, header)
#
# Code used for testing source finding algorithms
#
@deprecated(since='3.0.0', name='idlgauss_convolve', warning_type=Warning)
def idlgauss_convolve(image, fwhm):
    """ Convolve ``image`` with a truncated, zero-mean Gaussian kernel.

    Parameters
    ----------
    image : array
        2-D image to be convolved.

    fwhm : float
        Full-width, half-maximum of the Gaussian. Values for which the
        kernel radius (1.5 sigma) would drop below one pixel are replaced
        by the smallest supported value, with a warning printed.

    Returns
    -------
    h : array
        Convolved image with a border of one kernel half-width set to zero.

    c1 : array
        Normalized central row of the Gaussian, used to test the roundness.

    """
    sigmatofwhm = 2 * np.sqrt(2 * np.log(2))
    radius = 1.5 * fwhm / sigmatofwhm  # Radius is 1.5 sigma
    if radius < 1.0:
        radius = 1.0
        fwhm = sigmatofwhm / 1.5
        print("WARNING!!! Radius of convolution box smaller than one.")
        print("Setting the 'fwhm' to minimum value, %f." % fwhm)
    sigsq = (fwhm / sigmatofwhm)**2  # sigma squared
    nhalf = int(radius)  # Center of the kernel
    nbox = 2 * nhalf + 1  # Number of pixels inside of convolution box

    # Square of the distance of every kernel pixel to the kernel center:
    offsets = np.arange(nbox) - nhalf
    dist_sq = offsets[np.newaxis, :]**2 + offsets[:, np.newaxis]**2

    # Pixels within the inner circle of radius "radius", and their number:
    inside = dist_sq <= radius**2
    ninside = inside.sum()

    # The 2D gaussian profile (which is assumed to be the psf)
    profile = np.exp(-0.5 * dist_sq / sigsq)

    # For the kernel, values further than "radius" are equal to zero;
    # the values inside are normalized.
    kernel = profile * inside
    inner = kernel[inside]
    kernel[inside] = (inner - inner.mean()) / (inner.var() * ninside)

    # c1 will be used to the test the roundness
    c1 = profile[nhalf]
    c1 = (c1 - c1.mean()) / ((c1**2).sum() - c1.mean())

    # Convolve image with the kernel:
    h = signal.convolve2d(image, kernel, boundary='fill', mode='same',
                          fillvalue=0)

    # Set the sides to zero in order to avoid border effects
    h[:nhalf, :] = 0
    h[-nhalf:, :] = 0
    h[:, :nhalf] = 0
    h[:, -nhalf:] = 0
    return h, c1
def gauss_array(nx, ny=None, fwhm=1.0, sigma_x=None, sigma_y=None,
                zero_norm=False):
    """ Computes the 2D Gaussian with size nx*ny.

    Parameters
    ----------
    nx : int
    ny : int [Default: None]
        Size of output array for the generated Gaussian. If ny == None,
        output will be an array nx X nx pixels.

    fwhm : float [Default: 1.0]
        Full-width, half-maximum of the Gaussian to be generated. Only
        used when ``sigma_x`` is not given.

    sigma_x : float [Default: None]
    sigma_y : float [Default: None]
        Sigma_x and sigma_y are the stddev of the Gaussian functions.
        ``sigma_y`` defaults to ``sigma_x``.

    zero_norm : bool [Default: False]
        The kernel is always normalized to a sum of 1. When True, the mean
        is subtracted afterwards so that the kernel sums to zero.

    Returns
    -------
    gauss_arr : array
        A numpy array of shape ``(ny, nx)`` with the generated gaussian
        function, centered on the middle of the array.

    Raises
    ------
    ValueError
        If both ``fwhm`` and ``sigma_x`` are ``None``.

    """
    if ny is None:
        ny = nx

    if sigma_x is None:
        if fwhm is None:
            raise ValueError('A value for either "fwhm" or "sigma_x" needs '
                             'to be specified!')
        # Convert input FWHM into sigma
        sigma_x = fwhm / (2 * np.sqrt(2 * np.log(2)))

    if sigma_y is None:
        sigma_y = sigma_x

    # Offsets of each pixel from the array center in X and Y. For an odd
    # size these are the integers -n//2 .. n//2; for an even size they are
    # half-integers, so the output always has exactly the requested size.
    xarr = np.arange(nx) - (nx - 1) / 2.0
    yarr = np.arange(ny) - (ny - 1) / 2.0

    hnx = gauss(xarr, sigma_x)
    hny = gauss(yarr, sigma_y).reshape((ny, 1))
    h = hnx * hny

    # Normalize gaussian kernel to a sum of 1
    h = h / np.abs(h).sum()
    if zero_norm:
        h -= h.mean()

    return h


def gauss(x, sigma):
    """ Compute 1-D value of gaussian at position x relative to center."""
    exponent = -np.power(x, 2) / (2 * np.power(sigma, 2))
    amplitude = sigma * np.sqrt(2 * np.pi)
    return np.exp(exponent) / amplitude
# Plotting Utilities for drizzlepac
def make_vector_plot(coordfile, columns=[1, 2, 3, 4], data=None,
                     figure_id=None, title=None, axes=None, every=1,
                     labelsize=8, ylimit=None, limit=None, xlower=None,
                     ylower=None, output=None, headl=4, headw=3,
                     xsh=0.0, ysh=0.0, fit=None, scale=1.0, vector=True,
                     textscale=5, append=False, linfit=False, rms=True,
                     plotname=None):
    """ Convert a XYXYMATCH file into a vector plot or set of residuals plots.

    This function provides a single interface for generating either a
    vector plot of residuals or a set of 4 plots showing residuals. The
    data being plotted can also be adjusted for a linear fit on-the-fly.

    Parameters
    ----------
    coordfile : string
        Name of file with matched sets of coordinates. This input file can
        be a file compatible for use with IRAF's geomap.

    columns : list [Default: [1,2,3,4]]
        Column numbers for the X,Y positions from each image

    data : list of arrays
        If specified, this can be used to input matched data directly

    figure_id : int or str
        Identifier of the matplotlib figure to draw into.

    title : string
        Title to be used for the generated plot

    axes : list
        List of X and Y min/max values to customize the plot axes

    every : int [Default: 1]
        Slice value for the data to be plotted

    limit : float
        Radial offset limit for selecting which sources are included in
        the plot

    labelsize : int [Default: 8] or str
        Font size to use for tick labels, either in font points or as a
        string understood by tick_params().

    ylimit : float
        Limit to use for Y range of plots.

    xlower : float
    ylower : float
        Limit in X and/or Y offset for selecting which sources are
        included in the plot. Only ``xlower`` is applied by this function;
        ``ylower`` is accepted but currently has no effect.

    output : string
        Filename of output file for the plotted positions and residuals

    headl : int [Default: 4]
        Length of arrow head to be used in vector plot

    headw : int [Default: 3]
        Width of arrow head to be used in vector plot

    xsh : float
    ysh : float
        Shift in X and Y from linear fit to be applied to source positions
        from the first image

    scale : float
        Scale from linear fit to be applied to source positions from the
        first image

    fit : array
        Array of linear coefficients for rotation (and scale?) in X and Y
        from a linear fit to be applied to source positions from the first
        image

    vector : bool [Default: True]
        Specifies whether or not to generate a vector plot. If False, task
        will generate a set of 4 residuals plots instead

    textscale : int [Default: 5]
        Scale factor for text used for labelling the generated plot

    append : bool [Default: False]
        If True, will overplot new plot on any pre-existing plot

    linfit : bool [Default: False]
        If True, a linear fit to the residuals will be generated and
        added to the generated residuals plots

    rms : bool [Default: True]
        Specifies whether or not to report the RMS of the residuals as a
        label on the generated plot(s).

    plotname : str [Default: None]
        Write out plot to a file with this name if specified. A name
        without an extension gets ``.png`` appended.

    """
    from matplotlib import pyplot as plt

    if data is None:
        data = readcols(coordfile, cols=columns)

    xy1x = data[0]
    xy1y = data[1]
    xy2x = data[2]
    xy2y = data[3]

    numpts = xy1x.shape[0]
    if fit is not None:
        xy1x, xy1y = apply_db_fit(data, fit, xsh=xsh, ysh=ysh)
        dx = xy2x - xy1x
        dy = xy2y - xy1y
    else:
        dx = xy2x - xy1x - xsh
        dy = xy2y - xy1y - ysh
    # apply scaling factor to deltas
    dx *= scale
    dy *= scale

    print('Total # points: {:d}'.format(len(dx)))
    if limit is not None:
        indx = np.sqrt(dx**2 + dy**2) <= limit
        dx = dx[indx].copy()
        dy = dy[indx].copy()
        xy1x = xy1x[indx].copy()
        xy1y = xy1y[indx].copy()

    if xlower is not None:
        xindx = np.abs(dx) >= xlower
        dx = dx[xindx].copy()
        dy = dy[xindx].copy()
        xy1x = xy1x[xindx].copy()
        xy1y = xy1y[xindx].copy()
    print('# of points after clipping: {:d}'.format(len(dx)))

    dr = np.sqrt(dx**2 + dy**2)
    max_vector = dr.max()

    if output is not None:
        # write_xy_file expects a list of column groups, one format per
        # group: hand over the four columns as a single group.
        write_xy_file(output, [[xy1x, xy1y, dx, dy]])

    fig = plt.figure(num=figure_id)
    if not append:
        plt.clf()

    if vector:
        dxs = imagestats.ImageStats(dx.astype(np.float32))
        dys = imagestats.ImageStats(dy.astype(np.float32))
        minx = xy1x.min()
        maxx = xy1x.max()
        miny = xy1y.min()
        maxy = xy1y.max()
        plt_xrange = maxx - minx
        plt_yrange = maxy - miny

        qplot = plt.quiver(xy1x[::every], xy1y[::every], dx[::every],
                           dy[::every], units='y', headwidth=headw,
                           headlength=headl)
        key_dx = 0.01 * plt_xrange
        key_dy = 0.005 * plt_yrange * textscale
        maxvec = max_vector / 2.
        key_len = round(maxvec + 0.005, 2)

        plt.xlabel('DX: %.4f to %.4f +/- %.4f' % (dxs.min, dxs.max,
                                                    dxs.stddev))
        plt.ylabel('DY: %.4f to %.4f +/- %.4f' % (dys.min, dys.max,
                                                    dys.stddev))
        plt.title(r"$Vector\ plot\ of\ %d/%d\ residuals:\ %s$" %
                  (xy1x.shape[0], numpts, title))
        plt.quiverkey(qplot, minx + key_dx, miny - key_dy, key_len,
                      "%0.2f pixels" % (key_len),
                      coordinates='data', labelpos='E',
                      labelcolor='Maroon', color='Maroon')
    else:
        plot_defs = [[xy1x, dx, "X (pixels)", "DX (pixels)"],
                     [xy1y, dx, "Y (pixels)", "DX (pixels)"],
                     [xy1x, dy, "X (pixels)", "DY (pixels)"],
                     [xy1y, dy, "Y (pixels)", "DY (pixels)"]]
        if axes is None:
            # Compute a global set of axis limits for all plots
            minx = min(xy1x.min(), xy1y.min())
            maxx = max(xy1x.max(), xy1y.max())
            miny = min(dx.min(), dy.min())
            maxy = max(dx.max(), dy.max())
        else:
            minx = axes[0][0]
            maxx = axes[0][1]
            miny = axes[1][0]
            maxy = axes[1][1]

        if ylimit is not None:
            miny = -ylimit
            maxy = ylimit

        rms_labelled = False
        if title is None:
            fig.suptitle("Residuals [%d/%d]" % (xy1x.shape[0], numpts),
                         ha='center', fontsize=labelsize + 6)
        else:
            # This definition of the title supports math symbols in the
            # title
            fig.suptitle(r"$" + title + "$", ha='center',
                         fontsize=labelsize + 6)

        for pnum, p in enumerate(plot_defs):
            pn = pnum + 1
            ax = fig.add_subplot(2, 2, pn)
            plt.plot(
                p[0], p[1], 'b.',
                label='RMS(X) = %.4f, RMS(Y) = %.4f' % (dx.std(), dy.std())
            )
            lx = [int((p[0].min() - 500) / 500) * 500,
                  int((p[0].max() + 500) / 500) * 500]
            plt.plot(lx, [0.0, 0.0], 'k', linewidth=3)
            plt.axis([minx, maxx, miny, maxy])
            if rms and not rms_labelled:
                leg_handles, leg_labels = ax.get_legend_handles_labels()
                fig.legend(leg_handles, leg_labels, loc='center left',
                           fontsize='small', frameon=False,
                           bbox_to_anchor=(0.33, 0.51), borderaxespad=0)
                rms_labelled = True

            ax.tick_params(labelsize=labelsize)

            # Fine-tune figure; hide x ticks for top plots and y ticks for
            # right plots
            if pn <= 2:
                plt.setp(ax.get_xticklabels(), visible=False)
            else:
                ax.set_xlabel(plot_defs[pnum][2])

            if pn % 2 == 0:
                plt.setp(ax.get_yticklabels(), visible=False)
            else:
                ax.set_ylabel(plot_defs[pnum][3])

            if linfit:
                lxr = int((lx[-1] - lx[0]) / 100)
                lyr = int((p[1].max() - p[1].min()) / 100)
                a = np.vstack([p[0], np.ones(len(p[0]))]).T
                m, c = np.linalg.lstsq(a, p[1], rcond=None)[0]
                yr = [m * lx[0] + c, lx[-1] * m + c]
                plt.plot([lx[0], lx[-1]], yr, 'r')
                plt.text(
                    lx[0] + lxr, p[1].max() + lyr,
                    "%0.5g*x + %0.5g [%0.5g,%0.5g]" % (m, c, yr[0], yr[1]),
                    color='r'
                )

    plt.draw()

    if plotname:
        extension = os.path.splitext(plotname)[1][1:]
        if extension in ('png', 'pdf', 'ps', 'eps', 'svg'):
            plot_format = extension
        elif '.' not in plotname[-4:]:
            # no recognizable extension: save as PNG under a '.png' name
            plotname += '.png'
            plot_format = 'png'
        else:
            # unknown extension: let matplotlib derive the format from it
            plot_format = None
        plt.savefig(plotname, format=plot_format)
def apply_db_fit(data, fit, xsh=0.0, ysh=0.0):
    """ Apply a linear fit and shift to the first two columns of ``data``.

    Parameters
    ----------
    data : sequence of arrays
        ``data[0]`` and ``data[1]`` hold the X and Y positions.

    fit : array or None
        2x2 matrix the (x, y) row vectors are multiplied with. When
        ``None`` the positions are returned unchanged and the shifts are
        not applied.

    xsh, ysh : float [Default: 0.0]
        Shifts added to the transformed X and Y positions.

    Returns
    -------
    xy1x, xy1y : arrays
        The (transformed) X and Y positions.

    """
    xy1x = data[0]
    xy1y = data[1]
    if fit is not None:
        xy1 = np.zeros((xy1x.shape[0], 2), np.float64)
        xy1[:, 0] = xy1x
        xy1[:, 1] = xy1y
        xy1 = np.dot(xy1, fit)
        xy1x = xy1[:, 0] + xsh
        xy1y = xy1[:, 1] + ysh
    return xy1x, xy1y


def write_xy_file(outname, xydata, append=False, format=["%20.6f"]):
    """ Write columns of values to the text file ``outname``.

    Parameters
    ----------
    outname : str
        Name of the output file.

    xydata : list
        Either a list of column groups (each group being a list of
        equal-length columns) or a flat list of columns, which is treated
        as one single group.

    append : bool [Default: False]
        If False, an existing file named ``outname`` is removed first;
        otherwise the rows are appended to it.

    format : list of str [Default: ["%20.6f"]]
        One ``%``-format per column group, applied to every column of that
        group. Groups without a matching format are not written.

    """
    if not isinstance(xydata, list):
        xydata = list(xydata)

    # A flat list of columns has scalars, not columns, as the elements of
    # its first entry: wrap it so that it forms one group.
    if xydata and np.ndim(xydata[0][0]) == 0:
        xydata = [xydata]

    if not append and os.path.exists(outname):
        os.remove(outname)

    with open(outname, 'a+') as f:
        for row in range(len(xydata[0][0])):
            fields = [fmt % (col[row])
                      for cols, fmt in zip(xydata, format)
                      for col in cols]
            f.write(''.join(fields) + "\n")
    print('wrote XY data to: ', outname)
@deprecated(since='3.0.0', name='find_xy_peak', warning_type=Warning)
def find_xy_peak(img, center=None, sigma=3.0):
    """ Find the center of the peak of offsets

    Parameters
    ----------
    img : array
        2-D histogram of offsets.

    center : tuple [Default: None]
        ``(x, y)`` position subtracted from the peak position, if given.

    sigma : float [Default: 3.0]
        Values below ``mode + sigma * stddev`` of the histogram are ignored
        when looking for the peak.

    Returns
    -------
    xp, yp : float
        Position of the peak (center of mass of a box of at most 7x7
        pixels around the highest value); 0.0 if it cannot be computed.

    flux : float
        Highest clipped value within the box; 0.0 if no peak position
        could be computed.

    zpqual : float or None
        Significance estimate for the peak; ``None`` if it is undefined.

    """
    # find level of noise in histogram
    istats = imagestats.ImageStats(img.astype(np.float32), nclip=1,
                                   fields='stddev,mode,mean,max,min')

    if istats.stddev == 0.0:
        istats = imagestats.ImageStats(img.astype(np.float32),
                                       fields='stddev,mode,mean,max,min')

    # clip out all values below mode+sigma*stddev from histogram
    imgc = img[:, :].copy()
    imgc[imgc < istats.mode + istats.stddev * sigma] = 0.0

    # identify position of peak
    yp0, xp0 = np.where(imgc == imgc.max())

    # Perform bounds checking on slice from img
    ymin = max(0, int(yp0[0]) - 3)
    ymax = min(img.shape[0], int(yp0[0]) + 4)
    xmin = max(0, int(xp0[0]) - 3)
    xmax = min(img.shape[1], int(xp0[0]) + 4)

    # take sum of at most a 7x7 pixel box around peak
    xp_slice = (slice(ymin, ymax), slice(xmin, xmax))
    # 'center_of_mass' is called from the top-level ndimage namespace; the
    # 'ndimage.measurements' sub-namespace is deprecated in SciPy.
    yp, xp = ndimage.center_of_mass(img[xp_slice])

    if np.isnan(xp) or np.isnan(yp):
        xp = 0.0
        yp = 0.0
        flux = 0.0
        zpqual = None
    else:
        xp += xp_slice[1].start
        yp += xp_slice[0].start

        # compute S/N criteria for this peak from the clipped box
        box = imgc[xp_slice]
        flux = box.sum()
        delta_size = float(img.size - box.size)
        if delta_size == 0:
            delta_size = 1

        if flux > box.max():
            delta_flux = flux - box.max()
        else:
            delta_flux = flux

        zpqual = flux / np.sqrt(delta_flux / delta_size)
        if np.isnan(zpqual) or np.isinf(zpqual):
            zpqual = None

        if center is not None:
            xp -= center[0]
            yp -= center[1]

        # report the height of the peak rather than the summed box
        flux = box.max()

    del imgc
    return xp, yp, flux, zpqual
def plot_zeropoint(pars):
    """ Plot 2d histogram.

    Parameters
    ----------
    pars : dict
        Must contain the following keys:

        ``data``
            2-D histogram to display.
        ``figure_id``
            Identifier of the matplotlib figure to draw into.
        ``vmax``
            Upper limit of the displayed intensity range.
        ``title_str``
            Title of the plot.
        ``xp``, ``yp``
            Position of the peak relative to the zero-offset position.
        ``searchrad``
            Search radius; rounded to the nearest integer, it gives the
            pixel position of the zero offset.
        ``interactive``
            If true, interactive mode is switched on and the plot is
            shown; otherwise interactive mode is switched off.
        ``plotname``
            Name of the file to save the plot to; nothing is saved when
            it is empty or ``None``. A name without an extension gets
            ``.png`` appended.

    """
    from matplotlib import pyplot as plt

    xp = pars['xp']
    yp = pars['yp']
    searchrad = int(pars['searchrad'] + 0.5)

    plt.figure(num=pars['figure_id'])
    plt.clf()

    if pars['interactive']:
        plt.ion()
    else:
        plt.ioff()

    plt.imshow(pars['data'], vmin=0, vmax=pars['vmax'],
               interpolation='nearest')
    plt.viridis()
    plt.colorbar()
    plt.title(pars['title_str'])
    plt.plot(xp + searchrad, yp + searchrad, color='red', marker='+',
             markersize=24)
    plt.plot(searchrad, searchrad, color='yellow', marker='+',
             markersize=120)
    plt.text(searchrad, searchrad, "Offset=0,0", verticalalignment='bottom',
             color='yellow')
    plt.xlabel("Offset in X (pixels)")
    plt.ylabel("Offset in Y (pixels)")
    if pars['interactive']:
        plt.show()

    if pars['plotname']:
        output = pars['plotname']
        extension = os.path.splitext(output)[1][1:]
        if extension in ('png', 'pdf', 'ps', 'eps', 'svg'):
            plot_format = extension
        elif '.' not in output[-4:]:
            # no recognizable extension: save as PNG under a '.png' name
            output += '.png'
            plot_format = 'png'
        else:
            # unknown extension: let matplotlib derive the format from it
            plot_format = None
        plt.savefig(output, format=plot_format)
@deprecated(since='3.0.0', name='build_xy_zeropoint', warning_type=Warning)
def build_xy_zeropoint(imgxy, refxy, searchrad=3.0, histplot=False,
                       figure_id=1, plotname=None, interactive=True):
    """ Create a matrix which contains the delta between each XY position and
        each UV position.

    Returns
    -------
    xp, yp, flux, zpqual
        The values returned by ``find_xy_peak`` for the matrix of offsets,
        using ``sigma=1.0`` when the default detection yields no
        significance.

    """
    print('Computing initial guess for X and Y shifts...')

    # run C function to create ZP matrix
    zpmat = cdriz.arrxyzero(imgxy.astype(np.float32),
                            refxy.astype(np.float32), searchrad)

    zero_offset = (searchrad, searchrad)
    xp, yp, flux, zpqual = find_xy_peak(zpmat, center=zero_offset)

    if zpqual is not None:
        shift_found = True
    else:
        # try with a lower sigma to detect a peak in a sparse set of sources
        xp, yp, flux, zpqual = find_xy_peak(
            zpmat, center=zero_offset, sigma=1.0
        )
        shift_found = bool(zpqual)

    if shift_found:
        print('Found initial X and Y shifts of ', xp, yp)
        print('    with significance of ', zpqual, 'and ', flux, ' matches')
    else:
        print('!' * 80)
        print('!')
        print('! WARNING: No valid shift found within a search radius of ',
              searchrad, ' pixels.')
        print('!')
        print('!' * 80)

    if histplot:
        # upper display limit: a fifth of the peak height, but at least 10
        zpstd = flux // 5
        if zpqual is None or zpstd < 10:
            zpstd = 10

        title_str = ("Histogram of offsets: Peak has %d matches at "
                     "(%0.4g, %0.4g)" % (flux, xp, yp))

        plot_zeropoint({
            'data': zpmat,
            'figure_id': figure_id,
            'vmax': zpstd,
            'xp': xp,
            'yp': yp,
            'searchrad': searchrad,
            'title_str': title_str,
            'plotname': plotname,
            'interactive': interactive,
        })

    return xp, yp, flux, zpqual
@deprecated(since='3.0.0', name='build_pos_grid', warning_type=Warning)
def build_pos_grid(start, end, nstep, mesh=False):
    """
    Return a grid of positions starting at X,Y given by 'start', and ending
    at X,Y given by 'end'. The grid will be completely filled in X and Y by
    every 'step' interval.

    Parameters
    ----------
    start, end : sequence
        ``(x, y)`` of the two end points; they are swapped if ``end`` lies
        at a smaller X than ``start``.

    nstep : int
        Number of intervals the X range is divided into.

    mesh : bool [Default: False]
        If True, return the flattened mesh of all X and Y values instead of
        the positions along the connecting line.

    Returns
    -------
    xarr, yarr : arrays
        X and Y positions.

    """
    # Order the end points so that X increases from 'start' to 'end'
    if end[0] - start[0] < 0:
        start, end = end, start
    stepx = (end[0] - start[0]) / nstep

    # Positions along the exact line that connects start and end; the half
    # step added to the stop value keeps the end point in the range.
    xarr = np.arange(start[0], end[0] + stepx / 2.0, stepx)
    yarr = np.interp(xarr, [start[0], end[0]], [start[1], end[1]])

    # create grid of positions
    if mesh:
        xgrid, ygrid = np.meshgrid(xarr, yarr)
        return xgrid.ravel(), ygrid.ravel()

    return xarr, yarr