Source code for ensemble_md.analysis.clustering

####################################################################
#                                                                  #
#    ensemble_md,                                                  #
#    a python package for running GROMACS simulation ensembles     #
#                                                                  #
#    Written by Wei-Tse Hsu <wehs7661@colorado.edu>                #
#    Copyright (c) 2022 University of Colorado Boulder             #
#                                                                  #
####################################################################
import numpy as np
import matplotlib.pyplot as plt
from ensemble_md.utils.utils import run_gmx_cmd
from ensemble_md.analysis import analyze_traj


[docs]def cluster_traj(gmx_executable, inputs, grps, coupled_only=True, method='linkage', cutoff=0.1, suffix=None): """ Performs clustering analysis on a trajectory using the GROMACS command :code:`gmx cluster`. Note that this function encompasses the use of all the other functions in the module, including :func:`get_cluster_info`, :func:`get_cluster_members`, and :func:`analyze_transitions`. Parameters ---------- gmx_executable : str The path of the GROMACS executable. inputs : dict A dictionary that contains the different input files required for the clustering analysis. The dictionary must have the following four keys: :code:`traj` (input trajectory file in XTC or TRR format), :code:`config` (the configuration file in TPR or GRO format), :code:`xvg` (a GROMACS XVG file), and :code:`index` (an index/NDX file), with the values being the paths to the files. Note that the value of the key :code:`index` can be :code:`None`,in which case the function will use a default index file generated by :code:`gmx make_ndx`. If the parameter :code:`coupled_only` is set to :code:`True`, an XVG file that contains the time series of the state index (e.g., :code:`dhdl.xvg`) must be provided with the key :code:`xvg`. Otherwise, the key :code:`xvg` can be set to :code:`None`. grps : dict A dictionary that contains the names of the groups in the index file (NDX) for centering the system, calculating the RMSD, and outputting. The corresponding keys are :code:`center`, :code:`rmsd`, and :code:`output`. coupled_only : bool, Optional Whether to only consider the fully coupled configurations. The default is :code:`True`. method : str, Optional The method for clustering available for the GROMACS command :code:`gmx cluster`. The default is :code:`'linkage'`. Check the `GROMACS documentation <https://manual.gromacs.org/current/onlinehelp/gmx-cluster.html>`_ for other available options. cutoff : float, Optional The RMSD cutoff for clustering in nm. The default is 0.1. suffix : str, Optional The suffix for the output files. The default is :code:`None`, which means no suffix will be added. Example ------- Below is an example of performing a cluster analysis for all 4 replicas that compose of a REXEE simulation of a host-guest binding complex. >>> import glob >>> import natsort >>> from ensemble_md.analysis.clustering import cluster_traj >>> from ensemble_md.analysis.analyze_traj import stitch_trajs, convert_npy2xvg >>> rep_trajs = np.load('rep_trajs.npy') # Usually genrated by the REXEE simulation >>> state_trajs = np.load('state_trajs.npy') # Usually generated by analyze_traj.stitch_time_series >>> files = [natsort.natsorted(glob.glob(f'sim_{i}/iteration_*/*xtc')) for i in range(4)] >>> stitch_trajs('gmx', files, rep_trajs) >>> convert_npy2xvg(state_trajs, 0.2, subsampling=10) >>> for i in range(4): >>> print() >>> print(f'Performing clustering analysis for traj_{i}.xtc ...') >>> inputs = { >>> 'traj': f'traj_{i}.xtc', >>> 'config': 'complex.gro', >>> 'xvg': f'traj_{i}.xvg', >>> 'index': 'complex.ndx' >>> } >>> grps = { >>> 'center': 'HOS_MOL', >>> 'rmsd': 'complex_heavy', >>> 'output': 'HOS_MOL' >>> } >>> cluster_traj('gmx', inputs, grps, coupled_only=False, cutoff=0.13, suffix=f'{i}') """ # Check input parameters required_keys_1 = ['traj', 'config', 'xvg', 'index'] for key in required_keys_1: if key not in inputs: raise ValueError(f'The key "{key}" is missing in the inputs dictionary.') required_keys_2 = ['center', 'rmsd', 'output'] for key in required_keys_2: if key not in grps: raise ValueError(f'The key "{key}" is missing in the grps dictionary.') if coupled_only and inputs['xvg'] is None: raise ValueError('The parameter "coupled_only" is set to True but no XVG file is provided.') # Check if the index file is provided if inputs['index'] is None: print('Running gmx make_ndx to generate an index file ...') args = [ gmx_executable, 'make_ndx', '-f', inputs['config'], '-o', 'index.ndx', ] returncode, stdout, stderr = run_gmx_cmd(args, prompt_input='q\n') inputs['index'] = 'index.ndx' # Check if the groups are present in the index file with open(inputs['index'], 'r') as f: content = f.read() for key in grps: if grps[key] not in content: raise ValueError(f'The group "{grps[key]}" is not present in the provided/generated index file.') outputs = { 'nojump': 'nojump.xtc', 'center': 'center.xtc', 'rmsd-clust': 'rmsd_clust.xpm', 'rmsd-dist': 'rmsd_dist.xvg', 'cluster-log': 'cluster.log', 'cluster-pdb': 'clusters.pdb', 'rmsd': 'rmsd.xvg', # inter-medoid RMSD } if suffix is not None: for key in outputs: outputs[key] = outputs[key].replace('.', f'_{suffix}.') # Check if there is any fully coupled state in the trajectory lambda_data = np.transpose(np.loadtxt(inputs['xvg'], comments=['#', '@']))[1] if coupled_only is True and 0 not in lambda_data: print('Terminating clustering analysis since no fully decoupled state is present in the input trajectory while coupled_only is set to True.') # noqa: E501 else: # Either coupled_only is False or coupled_only is True but there are coupled configurations. print('Eliminating jumps across periodic boundaries for the input trajectory ...') args = [ gmx_executable, 'trjconv', '-f', inputs['traj'], '-s', inputs['config'], '-n', inputs['index'], '-o', outputs['nojump'], '-center', 'yes', '-pbc', 'nojump', ] if coupled_only: args.extend([ '-drop', inputs['xvg'], '-dropover', '0' ]) returncode, stdout, stderr = run_gmx_cmd(args, prompt_input=f'{grps["center"]}\n{grps["output"]}\n') if returncode != 0: raise ValueError(f'Error with return code {returncode}:\n{stderr}') print('Centering the system ...') args = [ gmx_executable, 'trjconv', '-f', outputs['nojump'], '-s', inputs['config'], '-n', inputs['index'], '-o', outputs['center'], '-center', 'yes', '-pbc', 'mol', '-ur', 'compact', ] returncode, stdout, stderr = run_gmx_cmd(args, prompt_input=f'{grps["center"]}\n{grps["output"]}\n') if returncode != 0: raise ValueError(f'Error with return code {returncode}:\n{stderr}') if coupled_only is True: N_coupled = np.count_nonzero(lambda_data == 0) print(f'Number of fully coupled configurations: {N_coupled}') print('Performing clustering analysis ...') args = [ gmx_executable, 'cluster', '-f', outputs['center'], '-s', inputs['config'], '-n', inputs['index'], '-o', outputs['rmsd-clust'], '-dist', outputs['rmsd-dist'], '-g', outputs['cluster-log'], '-cl', outputs['cluster-pdb'], '-cutoff', str(cutoff), '-method', method, ] returncode, stdout, stderr = run_gmx_cmd(args, prompt_input=f'{grps["rmsd"]}\n{grps["output"]}\n') if returncode != 0: raise ValueError(f'Error with return code {returncode}:\n{stderr}') rmsd_range, rmsd_avg, n_clusters = get_cluster_info(outputs['cluster-log']) print(f'Range of RMSD values: from {rmsd_range[0]:.3f} to {rmsd_range[1]:.3f} nm') print(f'Average RMSD: {rmsd_avg:.3f} nm') print(f'Number of clusters: {n_clusters}') if n_clusters > 1: clusters, sizes = get_cluster_members(outputs['cluster-log']) for i in range(1, n_clusters + 1): print(f' - Cluster {i} accounts for {sizes[i] * 100:.2f}% of the total configurations.') if n_clusters == 2: transmtx, _, t_transitions = analyze_transitions(clusters, normalize=False) # Note that this is a 2D count matrix. # noqa: E501 n_transitions = np.sum(transmtx) - np.trace(transmtx) # This is the sum of all off-diagonal elements. np.trace calculates the sum of the diagonal elements. # noqa: E501 print(f'Number of transitions between the two clusters: {n_transitions}') if n_transitions > 0: print(f'Time frames of the transitions (ps): {t_transitions[(1, 2)]}') print('Calculating the inter-medoid RMSD between the two biggest clusters ...') # Note that we pass outputs['cluster-pdb'] to -s so that the first medoid will be used as the reference args = [ gmx_executable, 'rms', '-f', outputs['cluster-pdb'], '-s', outputs['cluster-pdb'], '-o', outputs['rmsd'], ] if inputs['index'] is not None: args.extend(['-n', inputs['index']]) # Here we simply assume same groups for least-squares fitting and RMSD calculation returncode, stdout, stderr = run_gmx_cmd(args, prompt_input=f'{grps["rmsd"]}\n{grps["rmsd"]}\n') if returncode != 0: print(f'Error with return code: {returncode}):\n{stderr}') rmsd = np.transpose(np.loadtxt(outputs['rmsd'], comments=['@', '#']))[1][1] # inter-medoid RMSD print(f'Inter-medoid RMSD between the two biggest clusters: {rmsd:.3f} nm')
[docs]def get_cluster_info(cluster_log): """ Extracts basic results from the clustering analysis by parsing the LOG file generated by the GROMACS :code:`gmx cluster` command. Parameters ---------- cluster_log : str The LOG file generated by the GROMACS :code:`gmx cluster` command. Returns ------- rmsd_range: list The range of RMSD values rmsd_avg: float The average RMSD value. n_clusters : int The number of clusters. """ f = open(cluster_log, 'r') lines = f.readlines() f.close() rmsd_range = [] for line in lines: if 'The RMSD ranges from' in line: rmsd_range.append(float(line.split('from')[-1].split('to')[0])) rmsd_range.append(float(line.split('from')[-1].split('to')[-1].split('nm')[0])) if 'Average RMSD' in line: rmsd_avg = float(line.split('is')[-1]) if 'Found' in line: n_clusters = int(line.split()[1]) break return rmsd_range, rmsd_avg, n_clusters
[docs]def get_cluster_members(cluster_log): """ Gets the members of each cluster from the LOG file generated by the GROMACS :code:`gmx cluster` command. Parameters ---------- cluster_log : str The LOG file generated by the GROMACS :code:`gmx cluster` command. Returns ------- clusters : dict A dictionary that contains the cluster indices (starting from 1) as the keys and the lists of members (represented by time frames) as the values. sizes : dict A dictionary that contains the cluster indices (starting from 1) as the keys and the sizes of the cluster (in fraction) as the values. """ clusters = {} current_cluster = 0 start_processing = False f = open(cluster_log, 'r') lines = f.readlines() f.close() for line in lines: # Start processing when we reach the line that starts with "cl." if line.strip().startswith("cl."): start_processing = True continue # Skip this line and continue to the next iteration if start_processing: parts = line.split('|') try: current_cluster = int(parts[0].strip()) clusters[current_cluster] = [] except ValueError: pass # This is either a new cluster or continuation of it, add members members = parts[-1].split() clusters[current_cluster].extend([int(i) for i in members]) sizes_list = [len(clusters[i]) for i in clusters] sizes = {i: sizes_list[i - 1] / sum(sizes_list) for i in clusters} return clusters, sizes
[docs]def analyze_transitions(clusters, normalize=True, plot_type=None): """ Analyzes transitions between clusters, including estimating the transition matrix, generating/plotting a trajectory showing which cluster each configuration belongs to, and/or plotting the distribution of the clusters. Parameters ---------- clusters : dict A dictionary that contains the cluster indices (starting from 1) as the keys and the lists of members (represented by time frames) as the values. normalize : bool, Optional Whether to normalize the output transition matrix. The default is :code:`True`. plot_type : str, Optional The type of the figure to be plotted. The default is :code:`None`, which means no figure will be plotted. The other options are :code:`'bar'` and :code:`'xy'`. The former plots the distribution of the clusters, while the latter plots the trajectory showing which cluster each configuration belongs to. Returns ------- transmtx: numpy.ndarray The transition matrix. traj: numpy.ndarray The trajectory showing which cluster each configuration belongs to. t_transitions: dict A dictionary with keys being pairs of cluster indices and values being the time frames of transitions between the two clusters. If there was no transition, an empty dictionary will be returned. """ # Combine all cluster members and sort them all_members = [] for key in clusters: all_members.extend([(member, key) for member in clusters[key]]) all_members.sort() # Generate the trajectory t = np.array([member[0] for member in all_members]) traj = np.array([member[1] for member in all_members]) # Generate the transition matrix # Since traj2transmtx assumes an index starting from 0, we subtract 1 from the trajectory transmtx = analyze_traj.traj2transmtx(traj - 1, len(clusters), normalize=normalize) # Generate the dictionary of transitions t_transitions = {} for i in range(len(traj) - 1): if traj[i] != traj[i + 1]: pair = tuple(sorted((traj[i], traj[i + 1]))) if pair not in t_transitions: t_transitions[pair] = [t[i + 1]] else: t_transitions[pair].append(t[i + 1]) if plot_type is not None: if plot_type == 'bar': fig = plt.figure() ax = fig.add_subplot(111) plt.bar(clusters.keys(), [len(clusters[i]) for i in clusters], width=0.35) plt.xlabel('Cluster index') plt.ylabel('Number of configurations') plt.grid() ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True)) ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True)) plt.savefig('cluster_distribution.png', dpi=600) elif plot_type == 'xy': fig = plt.figure() ax = fig.add_subplot(111) if len(t) > 1000: t = t / 1000 # convert to ns units = 'ns' else: units = 'ps' plt.plot(t, traj) plt.xlabel(f'Time frame ({units})') plt.ylabel('Cluster index') ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True)) plt.grid() plt.savefig('cluster_traj.png', dpi=600) else: raise ValueError(f'Invalid plot type: {plot_type}. The plot type must be either "bar" or "xy" or unspecified.') # noqa: E501 return transmtx, traj, t_transitions