Lamellae Analysis

[ ]:
%load_ext autoreload
%autoreload 2
[ ]:
import numpy as np
import pandas as pd
import warnings
from copy import deepcopy
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from cryocat import cryomotl
from cryocat import mathutils

Functions used in the analysis

Functions to flatten lamellae

[ ]:
def fit_single_plane(coord):
    """Least-squares fit of the plane z = a*x + b*y + c through a set of points.

    Parameters
    ----------
    coord : numpy.ndarray
        2D array whose columns 0, 1 and 2 hold x, y and z.

    Returns
    -------
    numpy.ndarray
        The plane coefficients ``[a, b, c]``.
    """
    x, y, z = coord[:, 0], coord[:, 1], coord[:, 2]
    # Design matrix [x, y, 1] so that design_matrix @ [a, b, c] approximates z
    design_matrix = np.column_stack((x, y, np.ones_like(x)))
    coefficients, _residuals, _rank, _singular_values = np.linalg.lstsq(design_matrix, z, rcond=None)
    return coefficients

def flatten_coordinates(coord, fitted_plane = None):
    """Return the z offset of every point from the plane z = a*x + b*y + c.

    Parameters
    ----------
    coord : numpy.ndarray
        2D array whose columns 0, 1 and 2 hold x, y and z.
    fitted_plane : array-like with 3 elements, optional
        Plane coefficients ``[a, b, c]``; any shape holding 3 values (e.g. ``(1, 3)``)
        is accepted. When omitted, the plane is fitted to ``coord`` itself.

    Returns
    -------
    numpy.ndarray
        ``z - (a*x + b*y + c)`` per point (positive above the plane).
    """
    plane = fit_single_plane(coord) if fitted_plane is None else fitted_plane
    slope_x, slope_y, offset = np.reshape(plane, (3,))
    plane_z = slope_x * coord[:, 0] + slope_y * coord[:, 1] + offset
    return coord[:, 2] - plane_z

def _histogram_bounds(values, fraction, bins=100):
    """Return ``(lower, upper)`` edges of the main peak of a histogram of ``values``.

    The threshold is ``fraction`` times the highest bin count. The lower bound is the
    edge right after the last below-threshold bin preceding the peak. The upper bound is
    the right edge of the first below-threshold bin at or after the peak. When no bin on
    a side is below the threshold, the outermost histogram edge on that side is used
    instead of raising.
    """
    counts, edges = np.histogram(values, bins=bins)
    peak = int(np.argmax(counts))
    threshold = fraction * counts[peak]

    below_left = np.flatnonzero(counts[:peak] < threshold)
    below_right = np.flatnonzero(counts[peak:] < threshold)

    # NOTE(review): the lower bound is the left edge of the first above-threshold bin, while the
    # upper bound is the right edge of the first below-threshold bin (one bin further out).
    # Kept as in the original implementation; confirm whether this asymmetry is intended.
    lower = edges[below_left[-1] + 1] if below_left.size else edges[0]
    upper = edges[peak + below_right[0] + 1] if below_right.size else edges[-1]
    return lower, upper


def fit_all_planes(input_motl, percentile_flattening, percentile_hist):
    """Fit a plane to the particles of every tomogram and estimate the lamella boundaries.

    For each tomogram, the coordinates are flattened with a first plane fit. Points whose
    flattened z lies outside ``(-p, p)`` are discarded, where ``p`` is the
    ``percentile_flattening`` percentile of the flattened z values. The plane is then
    refitted to the remaining points. All coordinates are flattened with the refitted
    plane and the lamella boundaries are read from a 100-bin histogram of those values.

    Parameters
    ----------
    input_motl : cryomotl motl
        Particle list providing ``get_unique_values``, ``get_motl_subset`` and, on the
        subsets, ``get_coordinates``.
    percentile_flattening : float
        Percentile (0-100) of the first-pass flattened z used for outlier rejection.
    percentile_hist : float
        Fraction of the highest histogram bin count below which a bin is treated as
        lying outside the lamella.

    Returns
    -------
    pandas.DataFrame
        One row per tomogram with the following columns:

        - ``tomo_id``
        - ``fitted_plane_x/y/z``: plane coefficients a, b, c
        - ``cleaned_x_min/max`` and ``cleaned_y_min/max``: x/y extent of the points
          used for the refit
        - ``lower_bound`` and ``upper_bound``: lamella boundaries in flattened z

    Raises
    ------
    ValueError
        If outlier rejection leaves no points for a tomogram.
    """
    tomo_ids = input_motl.get_unique_values("tomo_id")

    fitted_planes = []
    boundaries = []
    cleaned_coord_stats = []

    for t in tomo_ids:
        tm = input_motl.get_motl_subset(t, feature_id="tomo_id", reset_index=True)
        coord = tm.get_coordinates()

        # 1st iteration: flatten with a plane fitted through all points
        z_flattened_it1 = flatten_coordinates(coord)

        # Remove outliers outside (-percentile, percentile) for a more accurate 2nd fit
        percentile = np.percentile(z_flattened_it1, percentile_flattening)
        keep = (z_flattened_it1 < percentile) & (z_flattened_it1 > -percentile)
        coord_percent = coord[keep]
        if coord_percent.shape[0] == 0:
            raise ValueError(
                f"Tomogram #{t}: no coordinates left after outlier rejection "
                f"(percentile_flattening={percentile_flattening})."
            )

        cleaned_coord_stats.append([np.min(coord_percent[:, 0]), np.max(coord_percent[:, 0]),
                                    np.min(coord_percent[:, 1]), np.max(coord_percent[:, 1])])

        # 2nd iteration: refit on the cleaned points, then flatten ALL coordinates with it
        fitted_plane = fit_single_plane(coord_percent)
        fitted_planes.append(fitted_plane)
        z_flattened_it2 = flatten_coordinates(coord, fitted_plane=fitted_plane)

        boundaries.append(_histogram_bounds(z_flattened_it2, percentile_hist))

    plane_stats = pd.DataFrame()
    plane_stats["tomo_id"] = tomo_ids
    plane_stats[["fitted_plane_x", "fitted_plane_y", "fitted_plane_z"]] = np.stack(fitted_planes)
    plane_stats[["cleaned_x_min", "cleaned_x_max", "cleaned_y_min", "cleaned_y_max"]] = np.stack(cleaned_coord_stats)
    plane_stats[["lower_bound", "upper_bound"]] = np.stack(boundaries)

    return plane_stats

def flatten_z_on_plane(input_motl, plane_stats, feature_id_z="geom4", feature_id_edge_dist="geom5"):
    """Store flattened z and distance to the nearest lamella edge in ``input_motl.df``, in place.

    For every tomogram in ``input_motl``:

    - The plane row for that ``tomo_id`` is looked up in ``plane_stats`` (as produced by
      ``fit_all_planes``). Tomograms without a row are skipped with a warning.
    - ``feature_id_z`` receives the z offset from that plane.
    - ``feature_id_edge_dist`` receives ``min(upper_bound - z, z - lower_bound)``. This is
      negative for particles outside the bounds.

    Nothing is returned.
    """
    plane_columns = ["fitted_plane_x", "fitted_plane_y", "fitted_plane_z"]

    for tomo in input_motl.get_unique_values("tomo_id"):
        subset = input_motl.get_motl_subset(tomo, feature_id="tomo_id", reset_index=True)
        coords = subset.get_coordinates()

        stats_rows = plane_stats.loc[plane_stats["tomo_id"] == tomo]
        plane = stats_rows[plane_columns].values

        # No fitted plane for this tomogram: warn and leave its rows untouched
        if plane.shape == (0, 3):
            warnings.warn(f"Tomogram #{tomo} does not have any entry for the fitted plane and will be skipped.")
            continue

        z_flat = flatten_coordinates(coords, plane)
        dist_to_top = stats_rows["upper_bound"].values - z_flat
        dist_to_bottom = z_flat - stats_rows["lower_bound"].values

        # Assumes the subset's row order matches the order of the masked rows in input_motl.df
        input_motl.df.loc[input_motl.df["tomo_id"] == tomo, feature_id_z] = z_flat
        input_motl.df.loc[input_motl.df["tomo_id"] == tomo, feature_id_edge_dist] = np.minimum(dist_to_top, dist_to_bottom)

Function to select particles within a given distance from the edge

[ ]:
def get_particles_within_distance(input_motl, distance, number_of_particles, feature_id_edge_dist="geom5"):
    """Return a copy of ``input_motl`` with the particles closest to a given edge distance.

    Rows with a negative (or NaN) value in ``feature_id_edge_dist`` are treated as outliers
    and dropped. When the column is filled by ``flatten_z_on_plane``, negative values mean
    the particle lies outside the lamella bounds. Of the remaining rows, the
    ``number_of_particles`` with the smallest ``|edge_dist - distance|`` are kept. Ties are
    broken by original row order. The kept rows are returned sorted by index.

    Parameters
    ----------
    input_motl : cryomotl motl
        Particle list with a ``df`` DataFrame; it is not modified.
    distance : float
        Target distance from the edge (same units as ``feature_id_edge_dist``).
    number_of_particles : int
        Maximum number of particles to keep (fewer are returned if fewer are available).
    feature_id_edge_dist : str
        Column holding the distance to the nearest edge.

    Returns
    -------
    Same type as ``input_motl``
        Deep copy of ``input_motl`` whose ``df`` holds only the selected rows.
    """
    new_motl = deepcopy(input_motl)

    # Remove negative values (outliers); the comparison also drops NaN
    df = new_motl.df
    df = df.loc[df[feature_id_edge_dist] >= 0.0]

    # Rank by distance error without adding a scratch column, which could clobber (and then
    # drop) a real column of the same name. A stable sort makes tie-breaking deterministic.
    depth_error = np.abs(df[feature_id_edge_dist].to_numpy() - distance)
    closest = np.argsort(depth_error, kind="stable")[:number_of_particles]

    # Restore the original ordering of the selected particles
    new_motl.df = df.iloc[closest].sort_index()

    return new_motl

Functions to plot the results

[ ]:
def plot_single_tomogram(coord_list, label_list, flattened_z_list, plane_stats, percentile_hist = 0.15, show_plot = True, output_file = None):
    """Plot the plane fit, flattening and boundary detection for a single tomogram.

    Parameters
    ----------
    coord_list : list of numpy.ndarray
        Coordinate arrays (columns 0, 1, 2 = x, y, z), one per particle list.
    label_list : list of str
        Legend label for each entry of ``coord_list``.
    flattened_z_list : list of array-like
        Flattened z values for each entry of ``coord_list``. The first entry is the one
        histogrammed (it should be the list the bounds were estimated from).
    plane_stats : pandas.DataFrame
        Rows of the ``fit_all_planes`` table for this tomogram; the first row is used.
    percentile_hist : float
        Fraction of the highest histogram count at which the threshold line is drawn
        (should match the value used in ``fit_all_planes``).
    show_plot : bool
        If False, the figure is closed after being (optionally) saved.
    output_file : str or None
        If given, the figure is saved to this path.
    """

    def add_info_to_subplot(subplot, s_title):
        # Shared title, axis labels and legend of the two 3D subplots
        subplot.set_title(s_title)
        subplot.set_xlabel('x (nm)')
        subplot.set_ylabel('y (nm)')
        subplot.set_zlabel('z (nm)')
        subplot.legend(loc='upper left')

    # All plane parameters come from the first row, as scalars
    stats = plane_stats.iloc[0]
    lower_bound = stats["lower_bound"]
    upper_bound = stats["upper_bound"]

    fig = plt.figure(figsize=(16, 14))

    # first subplot: original coordinates and fitted plane
    ax1 = fig.add_subplot(221, projection='3d', computed_zorder=False)

    # second subplot: flattened z coordinates
    ax2 = fig.add_subplot(222, projection='3d', computed_zorder=False)

    # plot coordinates (both original and flattened)
    for i, coord in enumerate(coord_list):
        ax1.scatter(coord[:, 0], coord[:, 1], coord[:, 2], s=5, label=label_list[i], alpha=1)
        ax2.scatter(coord[:, 0], coord[:, 1], flattened_z_list[i], s=5, label=label_list[i], alpha=1)

    # surface of the fitted plane over the x/y extent of the points used for the refit
    xv = np.linspace(stats["cleaned_x_min"], stats["cleaned_x_max"], 100)
    yv = np.linspace(stats["cleaned_y_min"], stats["cleaned_y_max"], 100)
    xm, ym = np.meshgrid(xv, yv)
    fitted_plane = stats[["fitted_plane_x", "fitted_plane_y", "fitted_plane_z"]].to_numpy(dtype=float)
    plane_z = fitted_plane[0] * xm + fitted_plane[1] * ym + fitted_plane[2]
    ax1.plot_surface(xm, ym, plane_z, alpha=0.5, cmap='viridis')

    add_info_to_subplot(ax1, "Original coordinates with fitted plane")
    add_info_to_subplot(ax2, "Flattened z coordinates")

    # histogram of flattened z coordinates from the raw TM data
    ax3 = fig.add_subplot(223)
    counts, _, _ = ax3.hist(flattened_z_list[0], bins=100, alpha=0.7)
    ax3.axvline(lower_bound, color='red', linestyle='dashed', linewidth=2, label='Lower Bound')
    ax3.axvline(upper_bound, color='green', linestyle='dashed', linewidth=2, label='Upper Bound')
    # Same threshold as fit_all_planes: a fraction of the highest bin count (not of the axis
    # limit, which includes a plotting margin)
    ax3.axhline(percentile_hist * np.max(counts), color='orange', linestyle='dashed', linewidth=2, label='Percentile')
    ax3.set_xlabel('Z coordinates (nm)')
    ax3.set_title('Histogram of all flattened z-coordinates of Raw TM')
    ax3.legend(loc='upper right')

    # fourth quadrant: a hidden placeholder axis subdivided into a 2x2 grid of 2D views
    ax4 = fig.add_subplot(224)
    ax4.axis('off')
    ax4.set_title("Coordinates in 2D")
    box = ax4.get_position()
    sub_gs = GridSpec(2, 2, figure=fig, left=box.x0, right=box.x1,
                      bottom=box.y0, top=box.y1 - 0.02, wspace=0.35, hspace=0.5)

    ax4_axes = [fig.add_subplot(sub_gs[i, j]) for i in (0, 1) for j in (0, 1)]

    # plot coordinates (both original and flattened) in 2D
    for i, coord in enumerate(coord_list):
        ax4_axes[0].scatter(coord[:, 0], coord[:, 2], s=3, alpha=1)
        ax4_axes[1].scatter(coord[:, 1], coord[:, 2], s=3, alpha=1)
        ax4_axes[2].scatter(coord[:, 0], flattened_z_list[i], s=3, alpha=1)
        ax4_axes[3].scatter(coord[:, 1], flattened_z_list[i], s=3, alpha=1)

    for ax in ax4_axes[2:]:
        ax.axhline(lower_bound, color='red', linestyle='dashed', linewidth=2, label='Lower Bound')
        ax.axhline(upper_bound, color='green', linestyle='dashed', linewidth=2, label='Upper Bound')

    titles = ["Original coord in ZX", "Original coord in ZY", "Flattened coord in ZX", "Flattened coord in ZY"]
    for ax, title in zip(ax4_axes, titles):
        ax.set_title(title)
        ax.set_ylabel('z (nm)')
        ax.set_xlabel('x (nm)' if title.endswith("ZX") else 'y (nm)')

    if output_file is not None:
        fig.savefig(output_file, transparent=True, bbox_inches='tight', pad_inches=0)

    if not show_plot:
        plt.close(fig)

def plot_all(input_motl_raw, input_motl_cleaned, plane_stats, feature_id_z="geom4", percentile_hist = 0.15, show_plots = True, output_file_base = None, output_suffix=".png"):
    """Call ``plot_single_tomogram`` for every tomogram of ``input_motl_raw``.

    Each figure shows the raw list (also used for the histogram) together with the
    cleaned list of the same tomogram. Tomograms without an entry in ``plane_stats``
    are skipped with a warning, as in ``flatten_z_on_plane``.

    Parameters
    ----------
    input_motl_raw, input_motl_cleaned : cryomotl motl
        Raw and cleaned particle lists, with flattened z stored in ``feature_id_z``.
    plane_stats : pandas.DataFrame
        Table returned by ``fit_all_planes``.
    feature_id_z : str
        Column holding the flattened z coordinate.
    percentile_hist : float
        Histogram threshold fraction, forwarded to ``plot_single_tomogram``.
    show_plots : bool
        If False, figures are closed after being saved.
    output_file_base : str or None
        If given, each figure is saved as ``<output_file_base><tomo_id><output_suffix>``.
    output_suffix : str
        File suffix including the leading "." (e.g. ".png", ".svg").
    """
    for t in input_motl_raw.get_unique_values("tomo_id"):
        tm_plane_stats = plane_stats.loc[plane_stats["tomo_id"] == t, :]

        # Without a fitted plane there is nothing to plot for this tomogram
        if tm_plane_stats.empty:
            warnings.warn(f"Tomogram #{t} does not have any entry for the fitted plane and will be skipped.")
            continue

        tm_raw = input_motl_raw.get_motl_subset(t, feature_id="tomo_id", reset_index=True)
        tm_cleaned = input_motl_cleaned.get_motl_subset(t, feature_id="tomo_id", reset_index=True)

        if output_file_base is not None:
            output_file = f"{output_file_base}{int(t)}{output_suffix}"
        else:
            output_file = None

        plot_single_tomogram([tm_raw.get_coordinates(), tm_cleaned.get_coordinates()],
                             ["Raw TM Coord", "Relion Coord"],
                             [tm_raw.get_feature(feature_id_z), tm_cleaned.get_feature(feature_id_z)],
                             tm_plane_stats,
                             percentile_hist = percentile_hist,
                             show_plot = show_plots,
                             output_file = output_file)

General workflow

  1. Set up inputs

  2. Prepare particle lists

  3. Fit a plane to the coordinates

  4. Flatten the z coordinates

  5. Plot the results

  6. Select particles at a certain distance from the edges

1. Set up inputs

[ ]:
# Acquisition and binning parameters
pixel_size = 0.1223  # nm per pixel of the raw (unbinned) data
bin_raw_tm = 6  # binning of the raw template-matching STAR file
bin_clean_tm = 6  # binning of the ribosome STAR file (after 3D classification)

# Flattening and edge-detection parameters
percentile_flattening = 95  # percentile used for outlier rejection before the 2nd plane fit
percentile_hist = 0.15  # histogram cut-off, as a fraction of the highest bin count

# Motl columns used to store the results
feature_id_flatten_z = "geom4"  # flattened z coordinate
feature_id_edge_dist = "geom5"  # distance from the nearest lamella edge

2. Prepare particle lists

  • Note that this should give you one warning about gimbal lock present in the data. This often happens with orientations estimated on a discrete set of sampling points, e.g. during template matching.

[ ]:
# Load the raw template-matching particle list and the cleaned particle list
raw_motl = cryomotl.RelionMotl("./inputs/bin6_Raw_TM_starfile.star")
cleaned_motl = cryomotl.RelionMotl("./inputs/bin6_TM_starfile.star")

# Scale the coordinates so they correspond to the physical coordinates of unbinned data:
# the factor is binning * pixel size (nm/pixel), so the coordinates end up in nm.
# NOTE(review): assumes scale_coordinates multiplies the stored coordinates by the factor — confirm in cryomotl.
raw_motl.scale_coordinates(bin_raw_tm*pixel_size)
cleaned_motl.scale_coordinates(bin_clean_tm*pixel_size)

3. Fit a plane to the coordinates

[ ]:
# One fitted plane and pair of lamella bounds per tomogram, estimated from the raw list
plane_stats = fit_all_planes(raw_motl, percentile_flattening=percentile_flattening, percentile_hist=percentile_hist)

4. Flatten the z coordinates

  • This stores the “flattened” z coordinates, and the distance of each particle from the nearest lamella edge, based on the planes fitted in the previous step.

[ ]:
# Store the flattened z and the distance to the nearest lamella edge for both lists, using the
# planes fitted on the raw list. The edge-distance column is passed explicitly so it follows the
# feature_id_edge_dist setting, which is also used when selecting and plotting particles later.
flatten_z_on_plane(raw_motl, plane_stats, feature_id_z=feature_id_flatten_z, feature_id_edge_dist=feature_id_edge_dist)
flatten_z_on_plane(cleaned_motl, plane_stats, feature_id_z=feature_id_flatten_z, feature_id_edge_dist=feature_id_edge_dist)

5. Plot the results

[ ]:
# Plot the results for each tomogram in the dataset (figures are shown, not saved)
plot_all(
    raw_motl,
    cleaned_motl,
    plane_stats,
    feature_id_z=feature_id_flatten_z,
    percentile_hist=percentile_hist,
)
[ ]:
# To create the results and save them without displaying them, use the following parameters.
# One image is written per tomogram, named results_<tomo_id>.png
plot_all(
    raw_motl,
    cleaned_motl,
    plane_stats,
    feature_id_z=feature_id_flatten_z,
    percentile_hist=percentile_hist,
    show_plots=False,
    output_file_base="results_",
    output_suffix=".png",
)

# A different image format is chosen with output_suffix (note the leading ".")
plot_all(
    raw_motl,
    cleaned_motl,
    plane_stats,
    feature_id_z=feature_id_flatten_z,
    percentile_hist=percentile_hist,
    show_plots=False,
    output_file_base="results_",
    output_suffix=".svg",
)
  • For the test dataset you should get the following figures:

results_243.png

results_360.png

6. Select particles at a certain distance from the edges

[ ]:
n_particles_to_pick = 10
output_path = "./"
output_name_base = "new_list_dist_"
output_suffix = "nm.star"
depth_range = range(5, 52, 5)

plot_histograms = True  # sanity check to see if particles from the correct depth were selected

if plot_histograms:
    # NOTE(review): assumes get_similar_size_factors returns rows * columns >= len(depth_range)
    rows, columns = mathutils.get_similar_size_factors(len(depth_range))
    # squeeze=False keeps `axs` 2D even for a single row/column, so axs.flat always works
    fig, axs = plt.subplots(rows, columns, figsize=(columns*4, rows*4), squeeze=False)

# Create new particle lists, each with the n_particles_to_pick particles closest to the given
# lamella depth, and write them to the specified path with the depth in the filename
for k, depth in enumerate(depth_range):
    new_motl = get_particles_within_distance(cleaned_motl, depth, n_particles_to_pick,
                                             feature_id_edge_dist=feature_id_edge_dist)

    # Write out the selected particles as a starfile - use_original_entries and keep_all_entries
    # ensure the lists have the same entries as the input
    new_motl.write_out(output_path + output_name_base + str(depth) + output_suffix,
                       use_original_entries=True, keep_all_entries=True)

    if plot_histograms:
        ax = axs.flat[k]  # row-major, same order as filling rows left to right
        ax.hist(new_motl.df[feature_id_edge_dist], bins=10, edgecolor='black')
        ax.set_title("Depth " + str(depth) + " nm")
        ax.set_xlabel("Distance to the edge (nm)")

if plot_histograms:
    # Lay out once, after all histograms have been drawn
    plt.tight_layout()
  • For the test dataset you should get the following histogram figure:

histograms.png