{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "import pathlib\n",
    "import math\n",
    "from PIL import Image, ImageOps\n",
    "import numpy as np\n",
    "import scipy\n",
    "import scipy.ndimage\n",
    "import scipy.stats\n",
    "from scipy.fft import fft2, ifft2, fftshift, next_fast_len\n",
    "from scipy.ndimage import gaussian_filter\n",
    "from scipy.signal import correlate2d\n",
    "from scipy.interpolate import griddata\n",
    "from skimage.transform import radon\n",
    "from skimage.feature import match_template\n",
    "import cv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pth = pathlib.Path(\"/Users/margareteminizer/Desktop/coda/registration_testing_python\")\n",
    "IHC = False  # whether there are IHC images\n",
    "zc = None  # 0-based index of the center of the z-stack, automatically calculated if None\n",
    "regE = {\n",
    "    \"szE\": 251,  # size of registration tiles\n",
    "    \"bfE\": 200,  # size of buffer on registration tiles\n",
    "    \"diE\": 100,  # distance between tiles\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get list of images\n",
    "imlist = sorted(pth.glob(\"*.tif\"), key=lambda x: x.name)\n",
    "if not imlist:\n",
    "    raise FileNotFoundError(f\"no .tif images found in {pth}\")\n",
    "tp = imlist[0].suffix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# calculate center image and order (0-based indices): rf[i] is the neighbouring\n",
    "# image one step closer to zc that serves as reference for moving image mv[i]\n",
    "if zc is None:\n",
    "    zc = len(imlist) // 2\n",
    "rf = list(range(zc, 0, -1)) + list(range(zc, len(imlist) - 1))\n",
    "mv = list(range(zc - 1, -1, -1)) + list(range(zc + 1, len(imlist)))\n",
    "assert len(rf) == len(mv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# find max size of images in list\n",
    "szz = [0, 0]\n",
    "for impath in imlist:\n",
    "    with Image.open(impath) as im:\n",
    "        width, height = im.size\n",
    "    szz[0] = max(szz[0], height)\n",
    "    szz[1] = max(szz[1], width)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# global registration settings\n",
    "padall = 250  # padding around all images\n",
    "rsc = 6  # downsampling factor used for global registration\n",
    "iternum = 5  # max iterations of registration calculation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define outputs\n",
    "outpthG = pth / \"registered\"\n",
    "outpthE = outpthG / \"elastic registration\"\n",
    "outpthE2 = outpthE / \"check\"\n",
    "matpth = outpthE / \"save_warps\"\n",
    "matpthD = matpth / \"D\"\n",
    "for p in (outpthG, outpthE, outpthE2, matpth, matpthD):\n",
    "    p.mkdir(exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_ims(pth, nm, tp):\n",
    "    \"\"\"Load image nm + tp from pth and its tissue mask from pth / 'TA'; the mask is binarized to 0/1 uint8.\"\"\"\n",
    "    with Image.open(pth / f\"{nm}{tp}\") as img:\n",
    "        im = np.array(img)\n",
    "    with Image.open(pth / \"TA\" / f\"{nm}{tp}\") as img:\n",
    "        TA = np.array(img)\n",
    "    TA = (TA > 0).astype(np.uint8)\n",
    "    return im, TA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_im_both2(im, sz, ext, fillval):\n",
    "    \"\"\"\n",
    "    Pad `im` symmetrically up to size `sz` (rows, cols), then by `ext` more\n",
    "    pixels on every side.\n",
    "\n",
    "    2D images are padded with the scalar `fillval`; 3D images are padded\n",
    "    channel by channel with `fillval[i]` (a scalar is used for every channel).\n",
    "    \"\"\"\n",
    "    szim = np.array([sz[0] - im.shape[0], sz[1] - im.shape[1]])\n",
    "    szA = szim // 2 + ext\n",
    "    szB = szim - szim // 2 + ext\n",
    "    pad = ((szA[0], szB[0]), (szA[1], szB[1]))\n",
    "    if im.ndim == 2:\n",
    "        return np.pad(im, pad, mode=\"constant\", constant_values=fillval)\n",
    "    if im.ndim == 3:\n",
    "        fill = np.broadcast_to(np.asarray(fillval), (im.shape[-1],))\n",
    "        padded_channels = [\n",
    "            np.pad(im[..., i], pad, mode=\"constant\", constant_values=fill[i])\n",
    "            for i in range(im.shape[-1])\n",
    "        ]\n",
    "        return np.stack(padded_channels, axis=-1)\n",
    "    raise ValueError(\"Image must be 2D or 3D\")\n",
    "\n",
    "\n",
    "def preprocessing(im, TA, szz, padall):\n",
    "    \"\"\"\n",
    "    Pad an image and its tissue mask and build the image used for registration.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    im : the padded image as uint8.\n",
    "    impg : complemented grayscale image (non-tissue pixels set to 0), blurred\n",
    "        with a Gaussian of sigma 2.\n",
    "    TA : the padded boolean tissue mask.\n",
    "    fillvals : the most common pixel value (per channel for 3D images), used\n",
    "        as the padding value.\n",
    "    \"\"\"\n",
    "    # pad image with its most common value(s)\n",
    "    if im.ndim == 3:\n",
    "        fillvals = np.ravel(scipy.stats.mode(im.reshape(-1, im.shape[-1]), axis=0).mode)\n",
    "    else:\n",
    "        fillvals = np.ravel(scipy.stats.mode(im, axis=None).mode)[0]\n",
    "    im = pad_im_both2(im, szz, padall, fillvals)\n",
    "    if TA.shape != im.shape[:2]:\n",
    "        TA = pad_im_both2(TA, szz, padall, 0)\n",
    "    # remove noise and complement images\n",
    "    TA = TA > 0\n",
    "    im = im.astype(np.uint8)\n",
    "    impg = np.copy(im)\n",
    "    impg[~TA] = 255\n",
    "    impg = 255 - np.array(ImageOps.grayscale(Image.fromarray(impg)))\n",
    "    # apply gaussian filter\n",
    "    impg = gaussian_filter(impg, sigma=2)\n",
    "    return im, impg, TA, fillvals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set up center image (zc is 0-based and is the one index not in mv)\n",
    "nm = imlist[zc].stem\n",
    "imzc, TAzc = get_ims(pth, nm, tp)\n",
    "imzc, imzcg, TAzc, fillvals = preprocessing(imzc, TAzc, szz, padall)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# all reference slots start from the center image\n",
    "img = imzcg\n",
    "TA = TAzc\n",
    "img0 = imzcg\n",
    "TA0 = TAzc\n",
    "krf0 = zc\n",
    "img00 = imzcg\n",
    "TA00 = TAzc\n",
    "krf00 = zc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bpassW(arr, lnoise, lobject):\n",
    "    \"\"\"\n",
    "    Band-pass filter: a Gaussian blur with sigma lnoise * sqrt(2) minus a\n",
    "    (2 * lobject + 1)-wide box average, clipped at 0.\n",
    "    \"\"\"\n",
    "    b = 1.0 * lnoise\n",
    "    w = lobject\n",
    "    N = 2 * w + 1\n",
    "    arrg = cv2.GaussianBlur(\n",
    "        arr, (N, N), b * math.sqrt(2), borderType=cv2.BORDER_REFLECT\n",
    "    )\n",
    "    ha = np.ones((N, N), np.float32) / (N**2)\n",
    "    arra = cv2.filter2D(arr, -1, ha, borderType=cv2.BORDER_REFLECT)\n",
    "    return np.maximum(arrg - arra, 0)\n",
    "\n",
    "\n",
    "def xcorrf2(a, b=None, pad=True):\n",
    "    \"\"\"\n",
    "    Two-dimensional cross-correlation using Fourier transforms.\n",
    "    XCORRF2(A, B) computes the cross-correlation of matrices A and B.\n",
    "    XCORRF2(A) is the autocorrelation function.\n",
    "\n",
    "    With pad=True the result is the full correlation, of shape\n",
    "    (ma + mb - 1, na + nb - 1).\n",
    "    \"\"\"\n",
    "    if b is None:\n",
    "        b = a\n",
    "    ma, na = a.shape\n",
    "    mb, nb = b.shape\n",
    "    # Make reverse conjugate of one array\n",
    "    b = np.conj(b[::-1, ::-1])\n",
    "    if pad:\n",
    "        # Use fast (small prime factor) transform lengths\n",
    "        mf = next_fast_len(ma + mb)\n",
    "        nf = next_fast_len(na + nb)\n",
    "        at = fft2(a, s=(mf, nf))\n",
    "        bt = fft2(b, s=(mf, nf))\n",
    "    else:\n",
    "        at = fft2(a)\n",
    "        bt = fft2(b)\n",
    "    # Multiply transforms then inverse transform\n",
    "    c = ifft2(at * bt)\n",
    "    # Make real output for real input\n",
    "    if np.isrealobj(a) and np.isrealobj(b):\n",
    "        c = np.real(c)\n",
    "    # Trim to standard size\n",
    "    if pad:\n",
    "        return c[: ma + mb - 1, : na + nb - 1]\n",
    "    return fftshift(c)[: ma + mb - 1, : na + nb - 1]\n",
    "\n",
    "\n",
    "def mskcircle2_rect(sz, da=None):\n",
    "    \"\"\"\n",
    "    Generate a filled circular 0/1 uint8 mask of shape `sz`, centered in the\n",
    "    array, with diameter `da` (default: the smaller dimension of `sz`).\n",
    "    \"\"\"\n",
    "    if da is None:\n",
    "        da = min(sz)\n",
    "    msk = np.zeros(sz, dtype=np.uint8)\n",
    "    center = (int(sz[1]) // 2, int(sz[0]) // 2)\n",
    "    radius = (int(da) - 1) // 2\n",
    "    # Draw the circle\n",
    "    cv2.circle(msk, center, radius, 1, thickness=-1)\n",
    "    return msk\n",
    "\n",
    "\n",
    "def gcnt(im, mx, sz, th=0, rsq=0):\n",
    "    \"\"\"\n",
    "    Calculates the centroid of bright spots to sub-pixel accuracy.\n",
    "\n",
    "    A 2D Gaussian (a quadratic in log-intensity) is least-squares fitted to\n",
    "    the `sz` x `sz` window around each candidate location.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    im : 2D array.\n",
    "    mx : (N, 2) integer candidate locations, columns (x, y).\n",
    "    sz : odd window size.\n",
    "    th : value subtracted from the window before the fit.\n",
    "    rsq : minimum R^2 of the fit for a candidate to be kept.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    out : (M, 5) array with columns (x, y, peak, width, rsquare).\n",
    "    outid : (M,) indices of the kept candidates among those that survive\n",
    "        the edge filter.\n",
    "    \"\"\"\n",
    "    nr, nc = im.shape\n",
    "    mx = np.asarray(mx, dtype=int).reshape(-1, 2)\n",
    "    # Remove all potential locations within distance sz from edges of image\n",
    "    edge = 1.5 * sz / 2\n",
    "    valid_indices = (\n",
    "        (mx[:, 1] > edge)\n",
    "        & (mx[:, 1] < nr - edge)\n",
    "        & (mx[:, 0] > edge)\n",
    "        & (mx[:, 0] < nc - edge)\n",
    "    )\n",
    "    mx = mx[valid_indices]\n",
    "    # Inside of the window, assign an x and y coordinate for each pixel,\n",
    "    # centered so that the candidate pixel itself is (0, 0) (0-based indexing)\n",
    "    x1, y1 = np.meshgrid(np.arange(sz), np.arange(sz))\n",
    "    x1 = x1 - (sz - 1) / 2\n",
    "    y1 = y1 - (sz - 1) / 2\n",
    "    kk = (sz - 1) // 2\n",
    "    AAt = np.column_stack(\n",
    "        (x1.ravel() ** 2 + y1.ravel() ** 2, x1.ravel(), y1.ravel(), np.ones(x1.size))\n",
    "    )\n",
    "    pts = []\n",
    "    for i, (cx, cy) in enumerate(mx):\n",
    "        # Create a small working array around each candidate location\n",
    "        tmp = im[cy - kk : cy + kk + 1, cx - kk : cx + kk + 1].astype(np.float64)\n",
    "        tmp = np.maximum(tmp - th, 1)\n",
    "        intensity = np.log(tmp.ravel())\n",
    "        cci = intensity > 1\n",
    "        if np.count_nonzero(cci) < AAt.shape[1]:\n",
    "            continue  # too few points to fit the 4 parameters\n",
    "        intensity = intensity[cci]\n",
    "        AA = AAt[cci]\n",
    "        maa = AA.T @ AA\n",
    "        try:\n",
    "            c = np.linalg.solve(maa, AA.T @ intensity)\n",
    "        except np.linalg.LinAlgError:\n",
    "            continue\n",
    "        inthat = AA @ c\n",
    "        mint = np.mean(intensity)\n",
    "        eint = np.sum((inthat - mint) ** 2)\n",
    "        rint = np.sum((intensity - mint) ** 2)\n",
    "        if rint == 0:\n",
    "            continue\n",
    "        rsquare = eint / rint\n",
    "        if rsquare > rsq and c[0] < 0:\n",
    "            xavg = -c[1] / c[0] / 2\n",
    "            yavg = -c[2] / c[0] / 2\n",
    "            peak = np.exp(c[3] - (xavg**2 + yavg**2) * c[0])\n",
    "            width = np.sqrt(-1 / c[0] / 2)\n",
    "            pts.append([cx + xavg, cy + yavg, peak, width, rsquare, i])\n",
    "    if not pts:\n",
    "        return np.empty((0, 5)), np.empty(0, dtype=int)\n",
    "    pts = np.array(pts)\n",
    "    return pts[:, :-1], pts[:, -1].astype(int)\n",
    "\n",
    "\n",
    "def calculate_transform(im, imnxt, xy=None, p=None):\n",
    "    \"\"\"\n",
    "    Estimate the (x, y) shift between `im` and `imnxt` by cross-correlation.\n",
    "\n",
    "    A window of half-size p['rm'] around `xy` (row, col; default: the image\n",
    "    center) in `im` is correlated with a window of half-size p['rs'] around\n",
    "    the same point in `imnxt`. The sub-pixel correlation peak, measured from\n",
    "    the zero-shift position, is returned as an (x, y) array.\n",
    "\n",
    "    p keys: 'rm', 'rs' (half-sizes as (rows, cols)), 'tm' (method: 1 correlate2d,\n",
    "    3 FFT cross-correlation, 5 match_template), 'rg' (diameter of the circular\n",
    "    mask applied to the correlation map before taking its maximum).\n",
    "    \"\"\"\n",
    "    imly = min(im.shape[0], imnxt.shape[0])\n",
    "    imlx = min(im.shape[1], imnxt.shape[1])\n",
    "    if xy is None:\n",
    "        # use center point of image (0-based indices)\n",
    "        xy = np.array([(imly - 1) // 2, (imlx - 1) // 2])\n",
    "    if p is None:\n",
    "        def_rm = np.round(0.95 * (np.array([imly - 1, imlx - 1]) // 2)).astype(int)\n",
    "        p = {\n",
    "            \"rm\": def_rm,\n",
    "            \"rs\": def_rm,\n",
    "            \"tm\": 3,\n",
    "            \"rg\": max(imly, imlx),  # search range\n",
    "        }\n",
    "    rm = np.broadcast_to(np.asarray(p[\"rm\"]).astype(int), (2,))\n",
    "    rs = np.broadcast_to(np.asarray(p[\"rs\"]).astype(int), (2,))\n",
    "    # location of image pattern\n",
    "    y0, x0 = int(xy[0]), int(xy[1])\n",
    "    imptn = im[y0 - rm[0] : y0 + rm[0] + 1, x0 - rm[1] : x0 + rm[1] + 1]\n",
    "    imgrid = imnxt[y0 - rs[0] : y0 + rs[0] + 1, x0 - rs[1] : x0 + rs[1] + 1]\n",
    "    # intensity normalization (may help to take off scale effect,\n",
    "    # especially for fft-based transformation)\n",
    "    imptn = (imptn - np.mean(imptn)) / np.std(imptn)\n",
    "    imgrid = (imgrid - np.mean(imgrid)) / np.std(imgrid)\n",
    "    # decide pattern recognition method\n",
    "    if p[\"tm\"] == 1:\n",
    "        y1 = correlate2d(imgrid, imptn, mode=\"full\")\n",
    "    elif p[\"tm\"] == 2:\n",
    "        raise NotImplementedError(\"'patrecog' method not implemented\")\n",
    "    elif p[\"tm\"] == 3:\n",
    "        y1 = xcorrf2(imptn, imgrid)\n",
    "    elif p[\"tm\"] == 4:\n",
    "        raise NotImplementedError(\"'xcor2d_nmrd' method not implemented\")\n",
    "    elif p[\"tm\"] == 5:\n",
    "        y1 = match_template(imgrid, imptn)\n",
    "        y1 = y1[::-1, ::-1] * 100  # rescale to larger than 1\n",
    "    else:\n",
    "        raise ValueError(f\"Invalid pattern recognition method: {p['tm']}\")\n",
    "    msk = mskcircle2_rect(y1.shape, p[\"rg\"])\n",
    "    # Find the indices of the maximum value of the masked correlation\n",
    "    my, mx = np.unravel_index(np.argmax(y1 * msk), y1.shape)\n",
    "    cnt, _ = gcnt(y1, [[mx, my]], 3, 0, 0)\n",
    "    if len(cnt):\n",
    "        yx = np.array([cnt[0, 1], cnt[0, 0]])\n",
    "    else:\n",
    "        # sub-pixel fit failed (e.g. peak next to the border): use the integer peak\n",
    "        yx = np.array([my, mx], dtype=np.float64)\n",
    "    # in the full correlation, 0-based index rs + rm is the zero-shift position\n",
    "    res = yx - rs - rm\n",
    "    return res[[1, 0]]  # Swap to xy\n",
    "\n",
    "\n",
    "def reg_ims_com(amv0, arf, count, sz, rf, deg0, xy0, r, th=0):\n",
    "    \"\"\"\n",
    "    Iteratively estimate a rigid (rotation + translation) registration of the\n",
    "    low-resolution moving image `amv0` onto the reference image `arf`.\n",
    "\n",
    "    Each iteration estimates a rotation by cross-correlating the images' Radon\n",
    "    transforms, then a translation by cross-correlating the rotated moving\n",
    "    image with the reference in both directions.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    amv0, arf : 2D arrays of the same shape (moving and reference images).\n",
    "    count : maximum number of refinement iterations.\n",
    "    sz : (x, y) offset added to the translation of the returned `tform`.\n",
    "    rf : scale between these images and the full-resolution images; `xy0`,\n",
    "        the returned `xyt` and the translation in `tform` are in\n",
    "        full-resolution pixels.\n",
    "    deg0 : initial rotation in degrees.\n",
    "    xy0 : initial translation, a scalar or a (2, 1) array in (x, y) order.\n",
    "    r : if truthy, first add the translation that aligns the median column\n",
    "        and row of the band-passed moving and reference images.\n",
    "    th : unused; kept so both the 8- and 9-argument call forms work.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tform : 2x3 float32 OpenCV affine matrix (rotation by `rsft` about the\n",
    "        origin, in cv2.getRotationMatrix2D's convention, plus translation),\n",
    "        or [] if the rotated moving image is empty.\n",
    "    amv : registered low-resolution moving image.\n",
    "    rsft : total rotation in degrees.\n",
    "    xyt : (2, 1) total translation in full-resolution pixels.\n",
    "    RR : correlation of the registered moving and reference images where arf > 0.\n",
    "    \"\"\"\n",
    "    tform = []\n",
    "    RR = 0\n",
    "    amv0 = np.ascontiguousarray(amv0)\n",
    "    # most common value, used to fill pixels left empty by the warps\n",
    "    mm = np.ravel(scipy.stats.mode(amv0, axis=None).mode)[0]\n",
    "    theta = np.arange(-90, 90.5, 0.5)\n",
    "    thetaout = 2  # radon angles are sampled every 0.5 deg: 2 columns per degree\n",
    "    brf = arf.astype(np.float32)\n",
    "    h, w = amv0.shape[:2]\n",
    "    center = (w // 2, h // 2)\n",
    "    a = arf > 0\n",
    "\n",
    "    def rotate(img, angle):\n",
    "        # rotate about the image center; uncovered pixels become 0\n",
    "        rotation_matrix = cv2.getRotationMatrix2D(center, float(angle), 1.0)\n",
    "        return cv2.warpAffine(\n",
    "            img,\n",
    "            rotation_matrix,\n",
    "            (w, h),\n",
    "            flags=cv2.INTER_LINEAR,\n",
    "            borderMode=cv2.BORDER_CONSTANT,\n",
    "        )\n",
    "\n",
    "    def translate(img, shift):\n",
    "        # shift is in full-resolution pixels, so scale it down by rf\n",
    "        shift_values = (np.asarray(shift, dtype=np.float64) / rf).ravel()\n",
    "        translation_matrix = np.float32(\n",
    "            [[1, 0, shift_values[0]], [0, 1, shift_values[1]]]\n",
    "        )\n",
    "        return cv2.warpAffine(\n",
    "            img,\n",
    "            translation_matrix,\n",
    "            (w, h),\n",
    "            flags=cv2.INTER_LINEAR,\n",
    "            borderMode=cv2.BORDER_REFLECT,\n",
    "        )\n",
    "\n",
    "    def register(rotation, shift):\n",
    "        # apply a total rotation and translation to the original moving image\n",
    "        out = translate(rotate(amv0, rotation), shift)\n",
    "        out[out == 0] = mm\n",
    "        return out\n",
    "\n",
    "    def safe_transform(fixed, moving):\n",
    "        try:\n",
    "            return calculate_transform(fixed, moving)\n",
    "        except Exception as exc:\n",
    "            print(f\"WARNING: caught exception from calculate_transform: {exc}\")\n",
    "            return np.array([0, 0])  # differs from line 57 in reg_ims_com.m\n",
    "\n",
    "    def median_xy(mask):\n",
    "        # (x, y) of the column / row where the cumulative mask sum passes half\n",
    "        cx = np.cumsum(np.sum(mask, axis=0))\n",
    "        cy = np.cumsum(np.sum(mask, axis=1))\n",
    "        return np.array([np.argmax(cx > cx[-1] / 2), np.argmax(cy > cy[-1] / 2)])\n",
    "\n",
    "    def correlation(img):\n",
    "        return np.corrcoef(img[a].flatten(), arf[a].flatten())[0, 1]\n",
    "\n",
    "    rsft = np.zeros(count + 1)\n",
    "    rsft[0] = deg0\n",
    "    amvr = rotate(amv0, deg0)\n",
    "    if np.sum(amvr) == 0:\n",
    "        return tform, amv0, float(np.sum(rsft)), np.zeros((2, 1)), RR\n",
    "    # first translation dictated by center of mass\n",
    "    xyt = np.zeros((2, 1))\n",
    "    if np.any(r):\n",
    "        amv2 = bpassW(amvr, 2, 50) > 0\n",
    "        arf2 = bpassW(brf, 2, 50) > 0\n",
    "        xy = median_xy(arf2) - median_xy(amv2)\n",
    "        xyt = xy.reshape(2, 1) * rf\n",
    "    xyt = xyt + np.broadcast_to(np.asarray(xy0, dtype=np.float64).reshape(-1, 1), (2, 1))\n",
    "    amv = translate(amvr, xyt)\n",
    "    RR0 = correlation(amv)\n",
    "    # the reference sinogram does not change between iterations\n",
    "    R0 = bpassW(radon(arf, theta, circle=True), 1, 3)\n",
    "    # iterate \"count\" times to achieve sufficient global registration on resized, blurred image\n",
    "    for kk in range(count):\n",
    "        # use radon for rotational registration\n",
    "        Rn = bpassW(radon(amv, theta, circle=True), 1, 3)\n",
    "        rsf = float(safe_transform(R0, Rn)[0]) / thetaout\n",
    "        # rotate image then calculate translational registration\n",
    "        bmvr = rotate(amv.astype(np.float32), rsf)\n",
    "        bmvr[bmvr == 0] = mm\n",
    "        xy1 = safe_transform(brf, bmvr)\n",
    "        xy2 = safe_transform(bmvr, brf)\n",
    "        xy = np.mean(np.array([xy1, -xy2]), axis=0)\n",
    "        # keep old transform in case update is bad (copies: rsft is updated in place)\n",
    "        rsft0 = rsft.copy()\n",
    "        xyt0 = xyt.copy()\n",
    "        # update total rotation\n",
    "        rsft[kk + 1] = rsf\n",
    "        # rotating the image about its center by rsf also rotates the existing\n",
    "        # translation; then add the newly measured translation\n",
    "        xyt = cv2.getRotationMatrix2D((0, 0), rsf, 1.0)[:, :2] @ xyt\n",
    "        xyt = xyt + xy.reshape(2, 1) * rf\n",
    "        # update registration image\n",
    "        amv = register(np.sum(rsft), xyt)\n",
    "        RR = correlation(amv)\n",
    "        # if iteration hasn't improved correlation of images, then stop\n",
    "        if RR + 0.02 < RR0 and count > 2:\n",
    "            rsft = rsft0\n",
    "            xyt = xyt0\n",
    "            amv = register(np.sum(rsft), xyt)\n",
    "            RR = correlation(amv)\n",
    "            break\n",
    "        # maximum distance a point in the image moves\n",
    "        x1 = round(amv.shape[1] / 2)\n",
    "        y1 = round(amv.shape[0] / 2)\n",
    "        angle_rad = np.radians(rsft[kk + 1])\n",
    "        x2 = x1 * np.cos(angle_rad) - y1 * np.sin(angle_rad) + xy[0] - x1\n",
    "        y2 = x1 * np.sin(angle_rad) + y1 * np.cos(angle_rad) + xy[1] - y1\n",
    "        rff = np.sqrt(x2**2 + y2**2)\n",
    "        if rff < 0.75 or RR > 0.9:\n",
    "            break\n",
    "    # apply calculated registration to fullscale image\n",
    "    # (account for translation 'sz' due to cropping tissue from full images)\n",
    "    rsft = float(np.sum(rsft))\n",
    "    # OpenCV uses a 2x3 matrix for affine transformations (MATLAB is 3x3);\n",
    "    # use OpenCV's rotation convention, as for the low-resolution warps above\n",
    "    tform = cv2.getRotationMatrix2D((0, 0), rsft, 1.0).astype(np.float32)\n",
    "    tform[0, 2] = xyt[0, 0] + sz[0]\n",
    "    tform[1, 2] = xyt[1, 0] + sz[1]\n",
    "    return tform, amv, rsft, xyt, RR\n",
    "\n",
    "\n",
    "def group_of_reg(amv0, arf, iternum0, sz, rf, bb):\n",
    "    \"\"\"\n",
    "    Run reg_ims_com from each starting rotation in T and keep the best result.\n",
    "\n",
    "    Both images are intensity-normalized (originally-zero pixels stay 0).\n",
    "    Each candidate is scored by the overlap of nonzero pixels:\n",
    "    |moving & reference| / |moving | reference|. The search stops early once\n",
    "    a score above `bb` is reached after more than 17 candidates.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    R : best score (0.2 if no candidate beat it).\n",
    "    rs : rotation of the best candidate in degrees (0 if none).\n",
    "    xy : (2, 1) translation of the best candidate (0 if none).\n",
    "    amv : registered moving image of the best candidate (normalized input if none).\n",
    "    \"\"\"\n",
    "    T = (\n",
    "        [-2, 177, 87, 268, -1, 88, 269, 178]\n",
    "        + list(range(-7, 8, 2))\n",
    "        + list(range(179, 184))\n",
    "        + list(range(89, 94))\n",
    "        + list(range(270, 273))\n",
    "    )\n",
    "    R = 0.2\n",
    "    rs = 0\n",
    "    xy = 0\n",
    "    arf_bg = arf == 0\n",
    "    amv_bg = amv0 == 0\n",
    "    arf = arf.astype(np.float64)\n",
    "    amv0 = amv0.astype(np.float64)\n",
    "    arf = (arf - np.mean(arf)) / np.std(arf)\n",
    "    amv0 = (amv0 - np.mean(amv0)) / np.std(amv0)\n",
    "    arf -= np.min(arf)\n",
    "    arf[arf_bg] = 0\n",
    "    amv0 -= np.min(amv0)\n",
    "    amv0[amv_bg] = 0\n",
    "    amv = amv0\n",
    "    for kp, deg in enumerate(T):\n",
    "        try:\n",
    "            _, amv1, rs1, xy1, RR = reg_ims_com(\n",
    "                amv0, arf, iternum0, sz, rf, deg, np.array([[0], [0]]), 1\n",
    "            )\n",
    "            if RR != 0:\n",
    "                both = (arf > 0) & (amv1 > 0)\n",
    "                either = (arf > 0) | (amv1 > 0)\n",
    "                RR = np.sum(both) / np.sum(either)\n",
    "        except Exception as exc:\n",
    "            print(f\"WARNING: caught exception from reg_ims_com: {exc}\")\n",
    "            RR = 0\n",
    "        if RR > R:\n",
    "            R = RR\n",
    "            rs = rs1\n",
    "            xy = xy1\n",
    "            amv = amv1\n",
    "        if RR > bb and kp > 16:\n",
    "            break\n",
    "    return R, rs, xy, amv\n",
    "\n",
    "\n",
    "def calculate_global_reg(imrf, immv, rf, iternum):\n",
    "    \"\"\"\n",
    "    Globally (rigidly) register `immv` onto `imrf`, flipping it vertically if\n",
    "    that gives a better initial overlap.\n",
    "\n",
    "    The registration is computed on copies downsampled by `rf` and blurred,\n",
    "    then applied to the full-resolution `immv`, rotating about its center.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    imout : registered full-resolution moving image.\n",
    "    tform : 2x3 matrix from reg_ims_com (rotation about the origin plus\n",
    "        translation); it is re-centered on `cent` when applied.\n",
    "    cent : (x, y) rotation center in 1-based pixel (world) coordinates.\n",
    "    f : 1 if the moving image was flipped vertically, else 0.\n",
    "    Rout : nonzero-pixel overlap (intersection over union) of the final\n",
    "        low-resolution registration.\n",
    "    \"\"\"\n",
    "    bb = 0.9\n",
    "    amv = cv2.resize(\n",
    "        immv, (immv.shape[1] // rf, immv.shape[0] // rf), interpolation=cv2.INTER_AREA\n",
    "    )\n",
    "    amv = gaussian_filter(amv, sigma=2)\n",
    "    arf = cv2.resize(\n",
    "        imrf, (imrf.shape[1] // rf, imrf.shape[0] // rf), interpolation=cv2.INTER_AREA\n",
    "    )\n",
    "    arf = gaussian_filter(arf, sigma=2)\n",
    "    sz = np.array([0, 0])\n",
    "    # calculate registration, flipping image if necessary\n",
    "    iternum0 = 2\n",
    "    R, rs, xy, _ = group_of_reg(amv, arf, iternum0, sz, rf, bb)\n",
    "    f = 0\n",
    "    ct = 0.8\n",
    "    if R < ct:\n",
    "        print(\"try flipping image\")\n",
    "        amv2 = np.ascontiguousarray(amv[::-1, ...])\n",
    "        R2, rs2, xy2, _ = group_of_reg(amv2, arf, iternum0, sz, rf, bb)\n",
    "        if R2 > R:\n",
    "            rs = rs2\n",
    "            xy = xy2\n",
    "            f = 1\n",
    "            amv = amv2\n",
    "    tform, amvout, _, _, _ = reg_ims_com(\n",
    "        amv, arf, iternum - iternum0, sz, rf, rs, xy, 0\n",
    "    )\n",
    "    if len(tform) == 0:\n",
    "        # rotated moving image was empty: fall back to the identity transform\n",
    "        tform = np.float32([[1, 0, 0], [0, 1, 0]])\n",
    "    either = (arf > 0) | (amvout > 0)\n",
    "    Rout = np.sum((arf > 0) & (amvout > 0)) / np.sum(either)\n",
    "    # create output image\n",
    "    Rin = {\n",
    "        \"ImageSize\": immv.shape,\n",
    "        \"XWorldLimits\": np.array([0.5, immv.shape[1] + 0.5]),\n",
    "        \"YWorldLimits\": np.array([0.5, immv.shape[0] + 0.5]),\n",
    "    }\n",
    "    cent = np.array([np.mean(Rin[\"XWorldLimits\"]), np.mean(Rin[\"YWorldLimits\"])])\n",
    "    if f == 1:\n",
    "        immv = np.ascontiguousarray(immv[::-1, ...])\n",
    "    # tform rotates about the origin: re-center it on the image center\n",
    "    # (world coordinate cent is 0-based pixel coordinate cent - 1)\n",
    "    c = cent - 1\n",
    "    M = np.array(tform, dtype=np.float64)\n",
    "    M[:, 2] += c - M[:, :2] @ c\n",
    "    # register\n",
    "    output_size = (Rin[\"ImageSize\"][1], Rin[\"ImageSize\"][0])\n",
    "    imout = cv2.warpAffine(\n",
    "        immv,\n",
    "        M.astype(np.float32),\n",
    "        output_size,\n",
    "        flags=cv2.INTER_NEAREST,\n",
    "        borderMode=cv2.BORDER_CONSTANT,\n",
    "        borderValue=0,\n",
    "    )\n",
    "    return imout, tform, cent, f, Rout"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def register_global_im(im, tform, cent, f, fillval):\n",
    "    \"\"\"\n",
    "    Apply a global registration computed by calculate_global_reg to an image.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    im : 2D or 3D image to warp.\n",
    "    tform : 2x3 matrix (rotation about the origin plus translation); it is\n",
    "        copied, not modified.\n",
    "    cent : (x, y) rotation center in 1-based pixel (world) coordinates.\n",
    "    f : 1 to flip `im` vertically before warping.\n",
    "    fillval : background value (a scalar or one value per channel).\n",
    "    \"\"\"\n",
    "    # set up rotation point of reference: tform rotates about the origin, so\n",
    "    # re-center it on cent (world coordinate cent is 0-based pixel cent - 1)\n",
    "    c = np.asarray(cent, dtype=np.float64) - 1\n",
    "    M = np.array(tform, dtype=np.float64)\n",
    "    M[:, 2] += c - M[:, :2] @ c\n",
    "    # flip if necessary\n",
    "    if f == 1:\n",
    "        im = np.ascontiguousarray(im[::-1, ...])\n",
    "    # register\n",
    "    return cv2.warpAffine(\n",
    "        im,\n",
    "        M.astype(np.float32),\n",
    "        (im.shape[1], im.shape[0]),\n",
    "        flags=cv2.INTER_NEAREST,\n",
    "        borderMode=cv2.BORDER_CONSTANT,\n",
    "        borderValue=tuple(float(v) for v in np.ravel(fillval)),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getImLocalWindowInd_rf(xy, imsz, wndra, skipstep):\n",
    "    \"\"\"\n",
    "    Convert the locations in images to index matrices of regional windows.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    xy : (N, 2) integer window centers with columns (x, y); a single (x, y)\n",
    "        pair is also accepted.\n",
    "    imsz : image shape (rows, cols, ...).\n",
    "    wndra : window half-size; full windows are 2 * wndra + 1 pixels wide.\n",
    "    skipstep : integer step between sampled window pixels, or an index array\n",
    "        selecting which offsets to keep.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (N, K) array of flat row-major indices into the image. Each row lists its\n",
    "    window row by row, so reshaping it to (n, n) restores the window layout.\n",
    "    \"\"\"\n",
    "    xy = np.asarray(xy, dtype=int).reshape(-1, 2)\n",
    "    wndra = int(wndra)\n",
    "    xmin = xy[:, 0] - wndra\n",
    "    ymin = xy[:, 1] - wndra\n",
    "    indmin = np.ravel_multi_index((ymin, xmin), tuple(imsz[:2]))\n",
    "    offsets = np.arange(0, 2 * wndra + 1)\n",
    "    if np.isscalar(skipstep):  # if it's a value\n",
    "        offsets = offsets[:: int(skipstep)]\n",
    "    else:  # when skipstep is a vector\n",
    "        offsets = offsets[skipstep]\n",
    "    # one row down adds imsz[1] to a row-major flat index, one column right adds 1\n",
    "    gxy = (offsets[:, np.newaxis] * imsz[1] + offsets[np.newaxis, :]).ravel()\n",
    "    return indmin[:, np.newaxis] + gxy\n",
    "\n",
    "\n",
    "def reg_ims_ELS(amv0, arf0, rf, v=None):\n",
    "    \"\"\"\n",
    "    Estimate the translation between two image tiles at 1/rf resolution.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    amv0, arf0 : 2D moving and reference tiles of the same shape.\n",
    "    rf : integer downsampling factor used for the estimate.\n",
    "    v : if truthy, also translate `amv0` by the estimate and correlate it with `arf0`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    X, Y : displacement in full-resolution pixels; the registered tile is\n",
    "        `amv0` translated by (-X, -Y).\n",
    "    imout : the translated tile (empty array unless `v`).\n",
    "    RR : correlation of `imout` and `arf0` where both are positive\n",
    "        (empty array unless `v`).\n",
    "    \"\"\"\n",
    "    if v is None:\n",
    "        v = 0\n",
    "    imout = np.array([])\n",
    "    RR = np.array([])\n",
    "    small_size = (arf0.shape[1] // rf, arf0.shape[0] // rf)\n",
    "    arf = cv2.resize(arf0, small_size, interpolation=cv2.INTER_LINEAR)\n",
    "    amv = cv2.resize(amv0, small_size, interpolation=cv2.INTER_LINEAR)\n",
    "    try:\n",
    "        xy1 = calculate_transform(arf, amv)\n",
    "    except Exception as exc:\n",
    "        print(f\"WARNING: caught exception from calculate_transform: {exc}\")\n",
    "        xy1 = np.array([0, 0])\n",
    "    try:\n",
    "        xy2 = calculate_transform(amv, arf)\n",
    "    except Exception as exc:\n",
    "        print(f\"WARNING: caught exception from calculate_transform: {exc}\")\n",
    "        xy2 = np.array([0, 0])\n",
    "    xyt = np.mean(np.vstack([xy1, -xy2]), axis=0)\n",
    "    X = -(xyt[0] * rf)\n",
    "    Y = -(xyt[1] * rf)\n",
    "    if v:\n",
    "        translation_matrix = np.float32([[1, 0, -X], [0, 1, -Y]])\n",
    "        imout = cv2.warpAffine(\n",
    "            amv0,\n",
    "            translation_matrix,\n",
    "            (amv0.shape[1], amv0.shape[0]),\n",
    "            flags=cv2.INTER_LINEAR,\n",
    "            borderMode=cv2.BORDER_CONSTANT,\n",
    "            borderValue=0,\n",
    "        )\n",
    "        overlap = (imout > 0) & (arf0 > 0)\n",
    "        RR = np.corrcoef(imout[overlap], arf0[overlap])[0, 1]\n",
    "    return X, Y, imout, RR\n",
    "\n",
    "\n",
    "def get_nn_grids(xgg):\n",
    "    \"\"\"\n",
    "    Stack the 8 neighbours of every cell of a 2D grid along a new last axis.\n",
    "\n",
    "    Uses scipy.ndimage.convolve with one-hot 3x3 kernels (default 'reflect'\n",
    "    border handling); the result has shape xgg.shape + (8,).\n",
    "    \"\"\"\n",
    "    filters = [\n",
    "        np.array([[1, 0, 0], [0, 0, 0], [0, 0, 0]]),\n",
    "        np.array([[0, 1, 0], [0, 0, 0], [0, 0, 0]]),\n",
    "        np.array([[0, 0, 1], [0, 0, 0], [0, 0, 0]]),\n",
    "        np.array([[0, 0, 0], [1, 0, 0], [0, 0, 0]]),\n",
    "        np.array([[0, 0, 0], [0, 0, 1], [0, 0, 0]]),\n",
    "        np.array([[0, 0, 0], [0, 0, 0], [1, 0, 0]]),\n",
    "        np.array([[0, 0, 0], [0, 0, 0], [0, 1, 0]]),\n",
    "        np.array([[0, 0, 0], [0, 0, 0], [0, 0, 1]]),\n",
    "    ]\n",
    "    grids = [scipy.ndimage.convolve(xgg, f) for f in filters]\n",
    "    return np.stack(grids, axis=-1)\n",
    "\n",
    "\n",
    "def fill_vals(xgg, ygg, cc, xystd=None):\n",
    "    \"\"\"\n",
    "    Mean of the valid (not `cc`) 8-neighbours of every grid cell, used to fill\n",
    "    the `cc` cells of copies of `xgg` / `ygg`. The inputs are not modified.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    xgg, ygg : copies of the inputs with `cc` cells replaced by neighbour means.\n",
    "    dxgg, dygg : neighbour means for every cell.\n",
    "    denom : number of valid neighbours per cell (0 replaced by 1).\n",
    "    sxgg, sygg : if `xystd` is nonzero, RMS difference between each cell and\n",
    "        its valid neighbours (0 on `cc` cells); otherwise empty arrays.\n",
    "    \"\"\"\n",
    "    if xystd is None:\n",
    "        xystd = 0\n",
    "    # copies: callers keep using their grids, which must keep 0 in `cc` cells\n",
    "    xgg = np.array(xgg, dtype=np.float64)\n",
    "    ygg = np.array(ygg, dtype=np.float64)\n",
    "    cc = np.asarray(cc, dtype=bool)\n",
    "    sxgg = np.array([])\n",
    "    sygg = np.array([])\n",
    "    kernel = np.array([[1, 1, 1], [1, 0, 1], [1, 1, 1]], dtype=np.float64)\n",
    "    valid = (~cc).astype(np.float64)\n",
    "    denom = scipy.ndimage.convolve(valid, kernel)\n",
    "    denom[denom == 0] = 1\n",
    "    if xystd != 0:\n",
    "        gridD = get_nn_grids(valid)\n",
    "        gridX = (get_nn_grids(xgg) - xgg[..., np.newaxis]) ** 2 * gridD\n",
    "        gridY = (get_nn_grids(ygg) - ygg[..., np.newaxis]) ** 2 * gridD\n",
    "        sxgg = np.sqrt(np.sum(gridX, axis=-1) / denom)\n",
    "        sygg = np.sqrt(np.sum(gridY, axis=-1) / denom)\n",
    "        sxgg[cc] = 0\n",
    "        sygg[cc] = 0\n",
    "    dxgg = scipy.ndimage.convolve(xgg, kernel) / denom\n",
    "    dygg = scipy.ndimage.convolve(ygg, kernel) / denom\n",
    "    xgg[cc] = dxgg[cc]\n",
    "    ygg[cc] = dygg[cc]\n",
    "    return xgg, ygg, dxgg, dygg, denom, sxgg, sygg\n",
    "\n",
    "\n",
    "def make_final_grids(xgg0, ygg0, bf, x, y, szim):\n",
    "    \"\"\"\n",
    "    Clean the tile displacement grids and interpolate them to every pixel.\n",
    "\n",
    "    Cells still at the -5000 sentinel, or with a displacement above `mxy`,\n",
    "    are treated as empty; outliers are replaced by the mean of their\n",
    "    neighbours; empty cells are filled iteratively from cells with 3+ filled\n",
    "    neighbours; the grids are smoothed, edge-padded by one cell and\n",
    "    cubic-interpolated onto a (szim[0], szim[1]) pixel grid.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    D : (szim[0], szim[1], 2) array of interpolated (x, y) displacements.\n",
    "    xgg, ygg : cleaned, padded grids.\n",
    "    x, y : grid node coordinates (1-based pixels) along columns and rows.\n",
    "    \"\"\"\n",
    "    xgg = np.array(xgg0, dtype=np.float64)\n",
    "    ygg = np.array(ygg0, dtype=np.float64)\n",
    "    mxy = 75  # 50 # allow no translation larger than this cutoff\n",
    "    cmxy = (xgg > mxy) | (ygg > mxy)  # non-continuous values\n",
    "    xgg[cmxy] = -5000\n",
    "    # find points where registration was calculated\n",
    "    cempty = xgg == -5000\n",
    "    xgg[cempty] = 0\n",
    "    ygg[cempty] = 0\n",
    "    # replace non-continuous values with mean of neighbors\n",
    "    _, _, dxgg, dygg, _, sxgg, sygg = fill_vals(xgg, ygg, cempty, 1)\n",
    "    m1 = np.divide(\n",
    "        np.abs(xgg - dxgg), np.abs(dxgg), out=np.zeros_like(xgg), where=dxgg != 0\n",
    "    )  # percent difference between x and mean of surrounding\n",
    "    m2 = np.divide(\n",
    "        np.abs(ygg - dygg), np.abs(dygg), out=np.zeros_like(ygg), where=dygg != 0\n",
    "    )\n",
    "    dds = (sxgg > 50) | (sygg > 50)\n",
    "    ddm = (m1 > 5) | (m2 > 5)\n",
    "    ddp = (np.abs(xgg) > 80) | (np.abs(ygg) > 80)\n",
    "    dd = (dds | ddm | ddp) & ~cempty\n",
    "    xgg[dd] = dxgg[dd]\n",
    "    ygg[dd] = dygg[dd]\n",
    "    # fill in values outside tissue region with mean of neighbors\n",
    "    cc = cempty\n",
    "    count = 1\n",
    "    while np.any(cc) and count < 500:\n",
    "        _, _, dxgg, dygg, denom, _, _ = fill_vals(xgg, ygg, cc)\n",
    "        cfill = (denom > 2) & cc  # touching 3+ numbers and needs to be filled\n",
    "        if not np.any(cfill):\n",
    "            break  # no further progress is possible\n",
    "        xgg[cfill] = dxgg[cfill]\n",
    "        ygg[cfill] = dygg[cfill]\n",
    "        cc = cc & ~cfill  # needs to be filled and has not been filled\n",
    "        count += 1\n",
    "    print(f\"count = {count}/500\")\n",
    "    xgg = gaussian_filter(xgg, 1)\n",
    "    ygg = gaussian_filter(ygg, 1)\n",
    "    # add buffer to outline of displacement map to avoid discontinuity\n",
    "    xgg = np.pad(xgg, ((1, 1), (1, 1)), mode=\"edge\")\n",
    "    ygg = np.pad(ygg, ((1, 1), (1, 1)), mode=\"edge\")\n",
    "    x = np.concatenate(([1], np.unique(x) - bf, [szim[1]]))\n",
    "    y = np.concatenate(([1], np.unique(y) - bf, [szim[0]]))\n",
    "    # get D (griddata needs one (x, y) point per grid node)\n",
    "    gx, gy = np.meshgrid(x, y)\n",
    "    points = (gx.ravel(), gy.ravel())\n",
    "    xq, yq = np.meshgrid(np.arange(1, szim[1] + 1), np.arange(1, szim[0] + 1))\n",
    "    xgq = griddata(points, xgg.ravel(), (xq, yq), method=\"cubic\")\n",
    "    ygq = griddata(points, ygg.ravel(), (xq, yq), method=\"cubic\")\n",
    "    D = np.stack((xgq, ygq), axis=-1)\n",
    "    return D, xgg, ygg, x, y\n",
    "\n",
    "\n",
    "def calculate_elastic_registration(imrfR, immvR, TArf, TAmv, sz, bf, di, cutoff=0.15):\n",
    "    \"\"\"\n",
    "    Compute a dense elastic displacement map between two globally registered images.\n",
    "\n",
    "    The images are padded by `bf` and blurred; `sz` x `sz` tiles spaced `di`\n",
    "    pixels apart (grid start offset drawn with np.random) are registered with\n",
    "    reg_ims_ELS wherever the tissue fraction of both masks (interior pixels\n",
    "    over sz**2) exceeds `cutoff`; the tile displacements are then cleaned and\n",
    "    interpolated by make_final_grids (at 1/5 scale for images over 2000 px).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    imrfR, immvR : 2D reference and moving images of the same shape.\n",
    "    TArf, TAmv : tissue masks of the same shape.\n",
    "    sz : odd tile size. bf : padding. di : tile spacing.\n",
    "    cutoff : minimum tissue fraction for a tile to be registered.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    D, xgg, ygg, xx, yy : as returned by make_final_grids.\n",
    "    \"\"\"\n",
    "    cc = 10\n",
    "    cc2 = cc + 1\n",
    "    szim = immvR.shape\n",
    "    m = (sz - 1) // 2 + 1\n",
    "    # pad with the most common value and blur images, pad masks\n",
    "    immvR = np.pad(\n",
    "        immvR,\n",
    "        pad_width=bf,\n",
    "        mode=\"constant\",\n",
    "        constant_values=np.ravel(scipy.stats.mode(immvR, axis=None).mode)[0],\n",
    "    )\n",
    "    immvR = cv2.GaussianBlur(immvR, (0, 0), 3)\n",
    "    imrfR = np.pad(\n",
    "        imrfR,\n",
    "        pad_width=bf,\n",
    "        mode=\"constant\",\n",
    "        constant_values=np.ravel(scipy.stats.mode(imrfR, axis=None).mode)[0],\n",
    "    )\n",
    "    imrfR = cv2.GaussianBlur(imrfR, (0, 0), 3)\n",
    "    TAmv = np.pad(TAmv, pad_width=bf, mode=\"constant\", constant_values=0)\n",
    "    TArf = np.pad(TArf, pad_width=bf, mode=\"constant\", constant_values=0)\n",
    "    # make grid for registration points\n",
    "    n1 = np.random.randint(1, round(di / 2) + 1) + bf + m\n",
    "    n2 = np.random.randint(1, round(di / 2) + 1) + bf + m\n",
    "    x, y = np.meshgrid(\n",
    "        np.arange(n1, immvR.shape[1] - m - bf, di),\n",
    "        np.arange(n2, immvR.shape[0] - m - bf, di),\n",
    "    )\n",
    "    x = x.ravel()\n",
    "    y = y.ravel()\n",
    "    # get percentage of tissue in each registration ROI\n",
    "    checkS = np.zeros(len(x))\n",
    "    numb = 200\n",
    "    for b in range(0, len(x), numb):\n",
    "        chunk = slice(b, b + numb)\n",
    "        ii = getImLocalWindowInd_rf(\n",
    "            np.column_stack((x[chunk], y[chunk])), TAmv.shape, m - 1, 1\n",
    "        )\n",
    "        imcheck = TAmv.ravel()[ii].T.reshape(sz, sz, -1)\n",
    "        imcheck2 = np.zeros(imcheck.shape)\n",
    "        imcheck2[cc2:-cc, cc2:-cc, :] = imcheck[cc2:-cc, cc2:-cc, :]\n",
    "        mvS = np.sum(imcheck2, axis=(0, 1))\n",
    "        imcheck = TArf.ravel()[ii].T.reshape(sz, sz, -1)\n",
    "        imcheck2 = np.zeros(imcheck.shape)\n",
    "        imcheck2[cc2:-cc, cc2:-cc, :] = imcheck[cc2:-cc, cc2:-cc, :]\n",
    "        rfS = np.sum(imcheck2, axis=(0, 1))\n",
    "        checkS[chunk] = np.minimum(mvS, rfS)\n",
    "        del ii, imcheck, imcheck2\n",
    "    checkS /= sz**2\n",
    "    yg = (y - np.min(y)) // di\n",
    "    xg = (x - np.min(x)) // di\n",
    "    unique_x_len = len(np.unique(x))\n",
    "    unique_y_len = len(np.unique(y))\n",
    "    xgg0 = -5000 * np.ones((unique_y_len, unique_x_len))\n",
    "    ygg0 = -5000 * np.ones((unique_y_len, unique_x_len))\n",
    "    for kk in np.flatnonzero(checkS > cutoff):\n",
    "        ii = getImLocalWindowInd_rf([[x[kk], y[kk]]], TAmv.shape, m - 1, 1)\n",
    "        immvS = immvR.ravel()[ii].reshape(sz, sz)\n",
    "        imrfS = imrfR.ravel()[ii].reshape(sz, sz)\n",
    "        X, Y, _, _ = reg_ims_ELS(immvS, imrfS, 2, 1)\n",
    "        xgg0[yg[kk], xg[kk]] = X\n",
    "        ygg0[yg[kk], xg[kk]] = Y\n",
    "    # smooth registration grid and make interpolated displacement map\n",
    "    if np.max(szim) > 2000:\n",
    "        szimout = np.round(np.array(szim[:2]) / 5).astype(int)\n",
    "        x = x / 5\n",
    "        y = y / 5\n",
    "        bf = bf / 5\n",
    "    else:\n",
    "        szimout = np.array(szim[:2])\n",
    "    D, xgg, ygg, xx, yy = make_final_grids(xgg0, ygg0, bf, x, y, szimout)\n",
    "    return D, xgg, ygg, xx, yy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for kk in range(len(mv)):\n",
    "    # create moving image\n",
    "    immv0 = np.array(Image.open(imlist[mv[kk]]))\n",
    "    TAmv = np.array(Image.open(pth / \"TA\" / 
imlist[mv[kk]].name))\n", " immv, immvg, TAmv,fillvals = preprocessing(immv0, TAmv, szz, padall)\n", " # reset reference images when at center\n", " if rf[kk]==zc:\n", " imrfgA=img\n", " TArfA=TA\n", " krfA=zc\n", " imrfgB=img0\n", " TArfB=TA0\n", " krfB=krf0\n", " imrfgC=img00\n", " TArfC=TA00\n", " krfC=krf00\n", " imvEold=imzc\n", " rc=0\n", " # skipping a conditional that's here in the matlab to load a previously-calculated\n", " # registration, lines below should be in an \"else\" block\n", " rc = 1\n", " RB = 0.4\n", " RC = 0.4\n", " immvGgB = immvg\n", " immvGgC = immvg\n", " ct = 0.945\n", " # try with registration pairs 1\n", " immvGg, tform, cent, f, R = calculate_global_reg(imrfgA, immvg, rsc, iternum)\n", " # try with registration pairs 2\n", " if R