####################################################################
# #
# ensemble_md, #
# a python package for running GROMACS simulation ensembles #
# #
# Written by Wei-Tse Hsu <wehs7661@colorado.edu> #
# Copyright (c) 2022 University of Colorado Boulder #
# #
####################################################################
"""
The :obj:`.analyze_traj` module provides methods for analyzing trajectories of a REXEE simulation.
"""
import copy
import numpy as np
import matplotlib.pyplot as plt
from itertools import chain
from matplotlib.ticker import MaxNLocator
from alchemlyb.parsing.gmx import _get_headers as get_headers
from alchemlyb.parsing.gmx import _extract_dataframe as extract_dataframe
from ensemble_md.utils import utils
def stitch_time_series(files, rep_trajs, shifts=None, dhdl=True, col_idx=-1, save_npy=True):
    """
    Stitches the state-space/CV-space trajectories for each starting configuration from DHDL files
    or PLUMED output files generated at different iterations.

    Parameters
    ----------
    files : list
        A list of lists of file paths to GROMACS DHDL files or general GROMACS XVG files or PLUMED output files.
        Specifically, :code:`files[i]` should be a list containing the files of interest from all iterations in
        replica :code:`i`. The files should be sorted naturally.
    rep_trajs : list
        A list of lists that represents the replica-space trajectories for each starting configuration. For example,
        :code:`rep_trajs[0] = [0, 2, 3, 0, 1, ...]` means that starting configuration 0 transitioned to replica 2, then
        3, 0, 1, in iterations 1, 2, 3, 4, ..., respectively. This can be read from the :code:`rep_trajs.npy` file
        generated by the REXEE simulation.
    shifts : list, Optional
        A list of values for shifting the local state indices to global indices for each replica. The length of the
        list should be equal to the number of replicas. This is required when :code:`dhdl=True` and ignored otherwise.
    dhdl : bool, Optional
        Whether the input files are GROMACS dhdl files. If :code:`dhdl=False`, the input files must be readable
        by :func:`numpy.loadtxt` assuming that the start of a comment is indicated by either the :code:`#` or :code:`@`
        characters. Such files include any GROMACS XVG files or PLUMED output files (output by plumed driver,
        for instance). In this case, trajectories of the configurational collective variable of interest are generated.
        The default is :code:`True`.
    col_idx : int, Optional
        The index of the column to be extracted from the input files. This is only needed when :code:`dhdl=False`.
        By default, we extract the last column.
    save_npy : bool, Optional
        Whether to save the output trajectories as an NPY file (:code:`state_trajs.npy` if :code:`dhdl=True`,
        :code:`cv_trajs.npy` otherwise). The default is :code:`True`.

    Returns
    -------
    trajs : list
        A list that contains lists of state-space/CV-space trajectory (in global indices) for each starting
        configuration. For example, :code:`trajs[i]` is the state-space/CV-space trajectory of starting
        configuration :code:`i`.

    Raises
    ------
    ValueError
        If :code:`dhdl=True` but :code:`shifts` is not given.

    Example
    -------
    >>> import glob
    >>> import natsort
    >>> import numpy as np
    >>> from ensemble_md.analysis import analyze_traj
    >>> n_sim = 4  # Assuming 4 replicas sampling states sets 0-3, 2-5, 4-7, and 6-9, respectively.
    >>> files = [natsort.natsorted(glob.glob(f'sim_{i}/iteration_*/*dhdl*xvg')) for i in range(n_sim)]
    >>> shifts = [0, 2, 4, 6]
    >>> rep_trajs = np.load('rep_trajs.npy')  # rep_trajs.npy is generated by the REXEE simulation
    >>> state_trajs = analyze_traj.stitch_time_series(files, rep_trajs, shifts, dhdl=True, save_npy=True)

    See also
    --------
    :func:`.stitch_time_series_for_sim`
    :func:`.stitch_xtc_trajs`
    """
    # Fail early with a clear message instead of a late "NoneType is not subscriptable" error.
    if dhdl and shifts is None:
        raise ValueError('The parameter "shifts" must be specified when dhdl=True.')

    n_configs = len(files)  # number of starting configurations
    n_iter = len(files[0])  # number of iterations per replica

    # First figure out which dhdl/plumed output files each starting configuration corresponds to.
    # files_sorted[i] contains the files for starting configuration i sorted by iteration index.
    files_sorted = [
        [files[rep_trajs[i][j]][j] for j in range(n_iter)]
        for i in range(n_configs)
    ]

    # Then, stitch the trajectories for each starting configuration.
    # Unlike stitch_time_series_for_sim, there is no way to check the continuity.
    trajs = [[] for _ in range(n_configs)]
    for i in range(n_configs):
        for j in range(n_iter):
            if dhdl:
                traj, _ = extract_state_traj(files_sorted[i][j])
                # Shift the indices so that global indices are used. The shift depends on
                # the replica in which this segment was sampled.
                traj = list(np.array(traj) + shifts[rep_trajs[i][j]])
            else:
                traj = np.loadtxt(files_sorted[i][j], comments=['#', '@'])[:, col_idx]

            if j != n_iter - 1:
                # Drop the last frame, which is repeated as the first frame of the next segment.
                traj = traj[:-1]

            trajs[i].extend(traj)

    if save_npy is True:
        np.save('state_trajs.npy' if dhdl else 'cv_trajs.npy', trajs)

    return trajs
def stitch_time_series_for_sim(files, shifts=None, dhdl=True, col_idx=-1, save_npy=True):
    """
    Stitches the state-space/CV-space time series in the same replica/simulation folder.
    That is, the output time series is contributed by multiple different trajectories (initiated by
    different starting configurations) to a certain state set.

    Parameters
    ----------
    files : list
        A list of lists of file paths to GROMACS DHDL files or general GROMACS XVG files
        or PLUMED output files. Specifically, :code:`files[i]` should be a list containing
        the files of interest from all iterations in replica :code:`i`. The files should be sorted naturally.
    shifts : list, Optional
        A list of values for shifting the local state indices to global indices for each replica. The length of the
        list should be equal to the number of replicas. This is required when :code:`dhdl=True` and ignored otherwise.
    dhdl : bool, Optional
        Whether the input files are GROMACS dhdl files. If :code:`dhdl=False`, the input files must be readable
        by :func:`numpy.loadtxt` assuming that the start of a comment is indicated by either the :code:`#` or :code:`@`
        characters. Such files include any GROMACS XVG files or PLUMED output files (output by plumed driver, for
        instance). In this case, trajectories of the configurational collective variable of interest are generated
        and the first column of each file is taken as the time. The default is :code:`True`.
    col_idx : int, Optional
        The index of the column to be extracted from the input files. This is only needed when :code:`dhdl=False`.
        By default, we extract the last column.
    save_npy : bool, Optional
        Whether to save the output trajectories as an NPY file (:code:`state_trajs_for_sim.npy`).
        The default is :code:`True`.

    Returns
    -------
    trajs : list
        A list that contains lists of state-space/CV-space trajectory (in global indices) for each replica.
        For example, :code:`trajs[i]` is the state-space/CV-space trajectory of replica :code:`i`.

    Raises
    ------
    ValueError
        If :code:`dhdl=True` but :code:`shifts` is not given, or if the first frame of an iteration
        is not continuous (in time or in value) with the last frame of the previous iteration.

    Example
    -------
    >>> import glob
    >>> import natsort
    >>> from ensemble_md.analysis import analyze_traj
    >>> n_sim = 4  # Assuming 4 replicas sampling states sets 0-3, 2-5, 4-7, and 6-9, respectively.
    >>> files = [natsort.natsorted(glob.glob(f'sim_{i}/iteration_*/*dhdl*xvg')) for i in range(n_sim)]
    >>> shifts = [0, 2, 4, 6]
    >>> state_trajs = analyze_traj.stitch_time_series_for_sim(files, shifts, dhdl=True, save_npy=True)

    See also
    --------
    :func:`.stitch_time_series`
    :func:`.stitch_xtc_trajs`
    """
    # Fail early with a clear message instead of a late "NoneType is not subscriptable" error.
    if dhdl and shifts is None:
        raise ValueError('The parameter "shifts" must be specified when dhdl=True.')

    n_sim = len(files)      # number of replicas
    n_iter = len(files[0])  # number of iterations per replica
    trajs = [[] for _ in range(n_sim)]
    t_last, val_last = None, None  # just for checking the continuity of the trajectory
    for i in range(n_sim):
        for j in range(n_iter):
            if dhdl:
                traj, t = extract_state_traj(files[i][j])
            else:
                # Parse the file only once and slice both columns from the same array.
                data = np.loadtxt(files[i][j], comments=['#', '@'])
                traj, t = data[:, col_idx], data[:, 0]

            if j != 0:
                # Check the continuity of the trajectory
                if traj[0] != val_last or t[0] != t_last:
                    err_str = f'The first frame of iteration {j} in replica {i} is not continuous with the last frame of the previous iteration. '  # noqa: E501
                    err_str += f'Please check files {files[i][j - 1]} and {files[i][j]}.'
                    raise ValueError(err_str)

            t_last = t[-1]
            val_last = traj[-1]

            if j != n_iter - 1:
                traj = traj[:-1]  # remove the last frame, which is the same as the first of the next time series.

            trajs[i].extend(traj)

        # All segments for the same replica should have the same shift
        if dhdl:
            trajs[i] = list(np.array(trajs[i]) + shifts[i])

    # Save the trajectories as an NPY file if desired
    if save_npy:
        np.save('state_trajs_for_sim.npy', trajs)

    return trajs
def stitch_xtc_trajs(gmx_executable, files, rep_trajs):
    """
    Demuxes GROMACS trajectories from different replicas into individual continuous trajectories.
    One output file :code:`traj_{i}.xtc` is written (via :code:`gmx trjcat`) per starting configuration.

    Parameters
    ----------
    gmx_executable : str
        The path of the GROMACS executable.
    files : list
        A list of lists of file paths to GROMACS XTC files. Specifically, :code:`files[i]` should be a list containing
        the paths to the files of interest from all iterations in replica :code:`i`. The files should be sorted
        naturally.
    rep_trajs : list
        A list of lists that represents the replica space trajectories for each starting configuration. For example,
        :code:`rep_trajs[0] = [0, 2, 3, 0, 1, ...]` means that starting configuration 0 transitioned to replica 2, then
        3, 0, 1, in iterations 1, 2, 3, 4, ..., respectively. This can be read from the :code:`rep_trajs.npy` file
        generated by the REXEE simulation.

    See also
    --------
    :func:`.stitch_time_series`
    :func:`.stitch_time_series_for_sim`
    """
    n_sim = len(files)      # number of replicas
    n_iter = len(files[0])  # number of iterations per replica

    # xtc_per_config[c] lists, in iteration order, the XTC files visited by starting configuration c.
    xtc_per_config = [
        [files[rep_trajs[c][it]][it] for it in range(n_iter)]
        for c in range(n_sim)
    ]

    # Concatenate the segments of each starting configuration with gmx trjcat.
    for c, xtc_files in enumerate(xtc_per_config):
        print(f'Recovering the continuous trajectory {c} by concatenating the XTC files ...')
        arguments = [gmx_executable, 'trjcat', '-f', *xtc_files, '-o', f'traj_{c}.xtc']
        returncode, stdout, stderr = utils.run_gmx_cmd(arguments)
        if returncode != 0:
            # A failure is only reported; the remaining trajectories are still processed.
            print(f'Error with return code: {returncode}):\n{stderr}')
def convert_npy2xvg(trajs, dt, subsampling=1):
    """
    Convert a :code:`state_trajs.npy` or :code:`cv_trajs.npy` file to :math:`R` XVG files
    that have two columns: time (ps) and state index/CV value. (:math:`R` is the number of replicas.)
    The files are written to the current working directory as :code:`traj_0.xvg`, :code:`traj_1.xvg`, etc.

    Parameters
    ----------
    trajs : numpy.ndarray or list
        The state-space or CV-space trajectories read from :code:`state_trajs.npy` or :code:`cv_trajs.npy`.
        A list of lists (e.g., the output of :func:`.stitch_time_series`) is accepted as well.
    dt : float
        The time interval (in ps) between consecutive frames of the trajectories.
    subsampling : int, Optional
        The stride for subsampling the time series. The default is 1.
    """
    for i, traj in enumerate(trajs):
        # Coerce to an array so that plain Python lists work too (they have no dtype and
        # slicing/transposing them together with the time array relies on array semantics).
        traj = np.asarray(traj)
        t = np.arange(len(traj)) * dt

        # Integer-valued data are treated as state indices; anything else as CV values.
        if 'int' in str(traj.dtype):
            column_title, value_fmt = 'Time (ps) v.s. State index', '%4.0f'
        else:
            column_title, value_fmt = 'Time (ps) v.s. CV', '%8.6f'

        headers = ['This file was created by ensemble_md', column_title]
        np.savetxt(
            f'traj_{i}.xvg',
            np.transpose([t[::subsampling], traj[::subsampling]]),
            header='\n'.join(headers),
            fmt=['%-8.1f', value_fmt],
        )
def traj2transmtx(traj, N, normalize=True):
    """
    Computes the transition matrix given a trajectory. For example, if a state-space
    trajectory from a EXE or HREX simulation is given, a state-space transition matrix is returned.
    If a trajectory showing transitions between replicas in a REXEE simulation is given,
    a replica-space transition matrix is returned.

    Parameters
    ----------
    traj : list
        A list of state indices showing the trajectory in the state space. The index
        is assumed to start from 0.
    N : int
        The size (N) of the expected transition matrix (N by N).
    normalize : bool
        Whether to normalize the matrix so that each row sum to 1. If :code:`normalize=False`, then
        the entries will be the counts of transitions.

    Returns
    -------
    transmtx : numpy.ndarray
        The transition matrix computed from the trajectory. Rows of states that were never left
        (e.g., non-sampled states) are all zeros.
    """
    transmtx = np.zeros([N, N])
    steps = np.asarray(traj, dtype=int)
    if steps.size > 1:
        # Count every (from, to) pair of consecutive frames. np.add.at accumulates
        # repeated index pairs, unlike plain fancy-index assignment.
        np.add.at(transmtx, (steps[:-1], steps[1:]), 1)

    if normalize:
        row_sums = transmtx.sum(axis=1, keepdims=True)
        # Only divide rows with at least one transition; rows of non-sampled states keep
        # their zeros, which avoids the 0/0 RuntimeWarning and the subsequent NaN clean-up.
        np.divide(transmtx, row_sums, out=transmtx, where=row_sums > 0)

    return transmtx
def plot_rep_trajs(trajs, fig_name, dt=None, stride=None):
    """
    Plots the replica-space trajectories for a REXEE simulation.

    Parameters
    ----------
    trajs : list
        A list of lists that represent all the replica-space trajectories.
    fig_name : str
        The file path to save the figure.
    dt : float or None, Optional
        One trajectory timestep in ps. If :code:`dt=None`, the function assumes there are no time frames but MC steps.
        The default is :code:`None`.
    stride : int, Optional
        The stride for plotting the time series. The default is 100 if the first trajectory
        has more than one million frames. Otherwise, it will be 1. Typically
        plotting more than 10 million frames can take a lot of memory.

    See also
    --------
    :func:`.plot_state_trajs`
    """
    n_sim = len(trajs)
    n_frames = len(trajs[0])
    cmap = plt.cm.ocean  # other good options are CMRmap, gnuplot, terrain, turbo, brg, etc.
    colors = [cmap(frac) for frac in np.arange(n_sim) / n_sim]

    # The x-axis is in MC moves unless a time step is given, in which case it is in ps or ns.
    x = np.arange(n_frames)
    if dt is not None:
        x = x * dt
        units = 'ps'
        if max(x) >= 10000:
            x = x / 1000  # convert to ns
            units = 'ns'

    if stride is None:
        stride = 100 if n_frames > 1000000 else 1

    # Markers are only shown for short trajectories.
    marker_kwargs = {} if n_frames >= 100 else {'marker': 'o'}

    fig = plt.figure()
    ax = fig.add_subplot(111)
    for i in range(n_sim):
        plt.plot(x[::stride], trajs[i][::stride], color=colors[i], label=f'Trajectory {i}', **marker_kwargs)

    plt.xlabel('MC moves' if dt is None else f'Time ({units})')
    plt.ylabel('Replica')
    plt.grid()
    plt.legend()
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))  # replica indices are integers
    plt.savefig(f'{fig_name}', dpi=600)
def plot_state_trajs(trajs, state_ranges, fig_name, dt=None, stride=None, title_prefix='Trajectory'):
    """
    Plots the state-space trajectories for a REXEE simulation, one subplot per trajectory.

    Parameters
    ----------
    trajs : list
        A list of lists of state indices generated either from different continuous trajectories or from different
        state sets (i.e. from different simulation folders). This can be generated by either
        :func:`.stitch_time_series` or :func:`.stitch_time_series_for_sim`.
    state_ranges : list
        A list of lists showing the state indices sampled by each replica.
    fig_name : str
        The file path to save the figure.
    dt : float or None, Optional
        One trajectory timestep in ps. If :code:`dt=None`, the function assumes there are no time frames but MC steps.
        The default is :code:`None`.
    stride : int, Optional
        The stride for plotting the time series. The default is 10 if the first trajectory
        has more than 100,000 frames. Otherwise, it will be 1. Typically
        plotting more than 10 million frames can take a lot of memory.
    title_prefix : str, Optional
        The prefix shared by the titles of the subplots. For example, if :code:`title_prefix` is
        set to "Trajectory", then the titles of the subplots will be "Trajectory 0", "Trajectory 1", ..., etc.
        The default is :code:`'Trajectory'`.

    See also
    --------
    :func:`.plot_rep_trajs`
    :func:`.plot_state_hist`
    """
    n_configs = len(trajs)
    n_frames = len(trajs[0])
    cmap = plt.cm.ocean  # other good options are CMRmap, gnuplot, terrain, turbo, brg, etc.
    colors = [cmap(frac) for frac in np.arange(n_configs) / n_configs]

    # The x-axis is in MC moves unless a time step is given, in which case it is in ps or ns.
    x = np.arange(n_frames)
    if dt is not None:
        x = x * dt
        units = 'ps'
        if max(x) >= 10000:
            x = x / 1000  # convert to ns
            units = 'ns'

    if stride is None:
        stride = 10 if n_frames > 100000 else 1

    x_range = [np.min(x), np.max(x)]
    y_range = [-0.2, np.max(trajs) + 0.2]
    linewidth = 0.01 if n_frames > 100000 else 1  # 1 is the default
    x_label = 'MC moves' if dt is None else f'Time ({units})'

    n_rows, n_cols = utils._get_subplot_dimension(n_configs)
    _, ax = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(2.5 * n_cols, 2.5 * n_rows))

    # Vertical extent of the shaded band of each state set; the first and last bands are
    # padded by half a state so that they cover the extreme states entirely.
    bands = []
    for j in range(n_configs):
        states = list(state_ranges[j])
        lower, upper = states[0], states[-1]
        if j == 0:
            lower -= 0.5
        if j == n_configs - 1:
            upper += 0.5
        bands.append((lower, upper))

    for i in range(n_configs):
        plt.subplot(n_rows, n_cols, i + 1)

        # First color different regions showing state sets
        for j, (lower, upper) in enumerate(bands):
            plt.fill_between(x_range, y1=upper, y2=lower, color=colors[j], alpha=0.1)

        # Then plot the trajectory itself
        plt.plot(x[::stride], trajs[i][::stride], color=colors[i], linewidth=linewidth)
        plt.xlabel(x_label)
        plt.ylabel('State')
        plt.title(f'{title_prefix} {i}', fontweight='bold')
        if n_frames >= 10000:
            plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
        plt.xlim(x_range)
        plt.ylim(y_range)
        plt.grid()

    # Hide the trailing axes of the grid that do not correspond to any trajectory
    n_unused = n_cols * n_rows - n_configs
    for offset in range(1, n_unused + 1):
        ax.flat[-offset].set_visible(False)

    plt.tight_layout()
    plt.savefig(f'{fig_name}', dpi=600)
def plot_state_hist(trajs, state_ranges, fig_name, stack=True, figsize=None, prefix='Trajectory', subplots=False, save_hist=True):  # noqa: E501
    """
    Plots the histograms of state visitation for all replicas in a REXEE simulation.

    Parameters
    ----------
    trajs : list
        A list of lists of state indices generated either from different continuous trajectories or from different
        state sets (i.e. from different simulation folders). This can be generated by either
        :func:`.stitch_time_series` or :func:`.stitch_time_series_for_sim`.
    state_ranges : list
        A list of lists showing the state indices sampled by each replica.
    fig_name : str
        The file path to save the figure.
    stack : bool, Optional
        Whether to stack the histograms. This parameter is only relevant when :code:`subplots` is :code:`False`.
        The default is :code:`True`.
    figsize : tuple, Optional
        A tuple specifying the length and width of the output figure. The
        default is :code:`(6.4, 4.8)` for cases having less than 30 states and :code:`(10, 4.8)` otherwise.
        This parameter is only relevant when :code:`subplots` is :code:`False`.
    prefix : str, Optional
        The prefix shared by the titles of the subplots, or the labels shown in the same plot.
        For example, if :code:`prefix` is set to "Trajectory", then the titles/labels
        will be "Trajectory 0", "Trajectory 1", ..., etc. The default is :code:`'Trajectory'`.
    subplots : bool, Optional
        Whether to plot the histograms in multiple subplots, with the title of
        each based on the value of :code:`prefix`. The default is :code:`False`.
    save_hist : bool, Optional
        Whether to save the histogram data (as :code:`hist_data.npy`). The default is :code:`True`.

    Returns
    -------
    hist_data : list
        The histogram data of each state index time series.

    See also
    --------
    :func:`.plot_state_trajs`
    """
    n_configs = len(trajs)
    # Take the largest index over ALL state sets. Note that max(max(state_ranges)) would compare
    # the lists lexicographically and only inspect the "largest" list.
    n_states = max(max(states) for states in state_ranges) + 1
    cmap = plt.cm.ocean  # other good options are CMRmap, gnuplot, terrain, turbo, brg, etc.
    colors = [cmap(frac) for frac in np.arange(n_configs) / n_configs]

    # One unit-width bin centered on each state index; the same bins are used for all trajectories.
    lower_bound, upper_bound = -0.5, n_states - 0.5
    bin_edges = np.arange(lower_bound, upper_bound + 1, 1)
    hist_data = [np.histogram(traj, bins=bin_edges)[0] for traj in trajs]
    if save_hist is True:
        np.save('hist_data.npy', hist_data)

    # Start plotting
    if subplots is False:
        if figsize is None:
            # Use a wider canvas when there are many states, as documented.
            figsize = (10, 4.8) if n_states >= 30 else (6.4, 4.8)

        fig = plt.figure(figsize=figsize)
        fig.add_subplot(111)

        # The bottom of the bars only changes between trajectories when stack is True
        bottom = np.zeros(n_states, dtype=int)
        y_max = 0
        for i in range(n_configs):
            y_max = max(y_max, np.max(bottom + hist_data[i]))
            plt.bar(
                range(n_states),
                hist_data[i],
                align='center',
                width=1,
                color=colors[i],
                edgecolor='black',
                label=f'{prefix} {i}',
                alpha=0.5,
                bottom=bottom
            )
            if stack is True:
                bottom = bottom + hist_data[i]
        plt.xticks(range(n_states))

        # Here we color the different regions to show state sets
        y_max *= 1.05
        for i in range(n_configs):
            bounds = [list(state_ranges[i])[0], list(state_ranges[i])[-1]]
            if i == 0:
                bounds[0] -= 0.5
            if i == n_configs - 1:
                bounds[1] += 0.5
            plt.fill_betweenx([0, y_max], x1=bounds[1] + 0.5, x2=bounds[0] - 0.5, color=colors[i], alpha=0.1, zorder=0)
        plt.xlim([lower_bound, upper_bound])
        plt.ylim([0, y_max])
        plt.xlabel('State index')
        plt.ylabel('Count')
        plt.grid()
        plt.legend()
        plt.tight_layout()
        plt.savefig(f'{fig_name}', dpi=600)
    else:
        # The grid of subplots gets its own figure, so no other figure is created in this branch.
        n_rows, n_cols = utils._get_subplot_dimension(n_configs)
        _, ax = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(4 * n_cols, 3 * n_rows))
        for i in range(n_configs):
            plt.subplot(n_rows, n_cols, i + 1)
            plt.bar(
                state_ranges[i],
                hist_data[i][state_ranges[i]],
                align='center',
                width=1,
                edgecolor='black',
                alpha=0.5
            )
            plt.xticks(state_ranges[i], fontsize=8)
            plt.xlim([state_ranges[i][0] - 0.5, state_ranges[i][-1] + 0.5])
            plt.xlabel('State index')
            plt.ylabel('Count')
            plt.title(f'{prefix} {i}')
            plt.grid()

        # Hide the trailing axes of the grid that do not correspond to any trajectory
        n_unused = n_cols * n_rows - n_configs
        for offset in range(1, n_unused + 1):
            ax.flat[-offset].set_visible(False)

        plt.tight_layout()
        plt.savefig(f'{fig_name}', dpi=600)

    return hist_data
def calc_hist_rmse(hist_data, state_ranges):
    """
    Calculates the RMSE of accumulated histogram counts of the state index. The reference
    is determined by assuming all alchemical states have equal chances to be visited, i.e.
    the alchemical weights are perfect.

    Parameters
    ----------
    hist_data : list
        The histogram data of the state index for each trajectory.
    state_ranges : list
        A list of lists showing the state indices sampled by each replica.

    Returns
    -------
    rmse : float
        The RMSE value of accumulated histogram counts of the state index, with respect to the case
        where equal sampling is reached for all states.
    """
    n_states = np.max(state_ranges) + 1
    # For each state, count how many replicas can access it (unit-width bins centered on the indices).
    bin_edges = np.arange(-0.5, n_states + 0.5)
    n_accessible, _ = np.histogram(state_ranges, bins=bin_edges)

    total_samples = np.sum(hist_data)  # should be equal to (n_iter * nst_sim / nstdhdl + 1) * n_sim
    total_slots = np.sum(n_accessible)  # n_sub * n_sim

    # With perfect weights, the samples are spread in proportion to the accessibility of each state.
    # The reference counts may not all be integers, which is fine.
    hist_ref = total_samples * (n_accessible / total_slots)
    hist_acc = np.sum(hist_data, axis=0)

    squared_dev = (hist_acc - hist_ref) ** 2
    rmse = np.sqrt(np.sum(squared_dev) / len(hist_ref))
    return rmse
def _find_transit_steps(traj, k):
    """Scan one trajectory and return (t_0k, t_k0, t_roundtrip, both_ends_visited), all in frames."""
    last_visited = None          # the end state (0 or k) that was visited most recently
    arrival_0, arrival_k = None, None  # frame of the latest arrival at state 0 / state k
    t_0k, t_k0 = [], []
    end_0_found, end_k_found = False, False
    for t, state in enumerate(traj):
        if state == 0:
            end_0_found = True
            if last_visited != 0:
                arrival_0 = t
                if last_visited == k:
                    t_k0.append(t - arrival_k)
            last_visited = 0
        # Deliberately not an elif: the two ends coincide when k is 0.
        if state == k:
            end_k_found = True
            if last_visited != k:
                arrival_k = t
                if last_visited == 0:
                    t_0k.append(t - arrival_0)
            last_visited = k

    # The two lists differ in length by at most one; zip drops the unpaired last event.
    t_roundtrip = [a + b for a, b in zip(t_0k, t_k0)]
    return t_0k, t_k0, t_roundtrip, end_0_found and end_k_found


def plot_transit_time(trajs, N, fig_prefix=None, dt=None, folder='.'):
    """
    Calculates and plots the transit times for each trajectory, including the time
    it takes from states 0 to k, from k to 0 and from 0 to k back to 0 (i.e., round-trip time).
    Here, k is the last state, i.e., :code:`N - 1`. For each type of transit time that has at least one
    event, a figure of the transit time as a function of the event index is generated. If every trajectory
    has at least 100 events of a type, a histogram of that transit time is generated as well.

    Parameters
    ----------
    trajs : list
        A list of lists that represent the state-space trajectories of all continuous trajectories.
    N : int
        The total number of states in the whole range.
    fig_prefix : str, Optional
        A prefix to use for all generated figures. The default is :code:`None`, which means no prefix.
    dt : float or None, Optional
        One trajectory timestep in ps. If :code:`dt=None`, the function assumes there are no time frames but MC steps.
        The default is :code:`None`.
    folder : str, Optional
        The directory where the figures will be saved. The default is the current directory.

    Returns
    -------
    t_0k_list : list
        A list of transit times from states 0 to k for each trajectory.
    t_k0_list : list
        A list of transit times from states k to 0 for each trajectory.
    t_roundtrip_list : list
        A list of round-trip times for each trajectory.
    units : str
        The units of the time (:code:`'step'`, :code:`'ps'` or :code:`'ns'`). The same units apply
        to all trajectories.
    """
    # Provisional units. Without any trajectory visiting both ends, these are the returned units.
    if dt is None:
        units = 'step'
    else:
        x = np.arange(len(trajs[0])) * dt
        units = 'ns' if max(x) >= 10000 else 'ps'

    # Transit times in frames for each trajectory
    scans = [_find_transit_steps(traj, N - 1) for traj in trajs]
    t_0k_list = [scan[0] for scan in scans]
    t_k0_list = [scan[1] for scan in scans]
    t_roundtrip_list = [scan[2] for scan in scans]

    if dt is not None and any(scan[3] for scan in scans):
        # Convert frames to ps. The ps/ns decision is made once from the longest event across ALL
        # trajectories so that every returned list is expressed in the returned units.
        units = 'ps'
        t_0k_list = [list(np.array(ts) * dt) for ts in t_0k_list]
        t_k0_list = [list(np.array(ts) * dt) for ts in t_k0_list]
        t_roundtrip_list = [list(np.array(ts) * dt) for ts in t_roundtrip_list]
        all_times = list(chain.from_iterable(t_0k_list + t_k0_list + t_roundtrip_list))
        if all_times and np.max(all_times) >= 10000:
            units = 'ns'
            t_0k_list = [list(np.array(ts) / 1000) for ts in t_0k_list]
            t_k0_list = [list(np.array(ts) / 1000) for ts in t_k0_list]
            t_roundtrip_list = [list(np.array(ts) / 1000) for ts in t_roundtrip_list]

    # Now we plot! (A type of transit time without any events gets no figure.)
    meta_list = [t_0k_list, t_k0_list, t_roundtrip_list]
    y_labels = [
        f'Average transit time from states 0 to k ({units})',
        f'Average transit time from states k to 0 ({units})',
        f'Average round-trip time ({units})',
    ]
    fig_names = ['t_0k.png', 't_k0.png', 't_roundtrip.png']
    name_prefix = '' if fig_prefix is None else f'{fig_prefix}_'
    for t_list, y_label, fig_name in zip(meta_list, y_labels, fig_names):
        if all(not events for events in t_list):
            continue

        n_events = [len(events) for events in t_list]
        marker = 'o' if np.max(n_events) <= 10 else ''

        plt.figure()
        for i, events in enumerate(t_list):
            plt.plot(np.arange(len(events)) + 1, events, label=f'Trajectory {i}', marker=marker)

        if np.max(list(chain.from_iterable(t_list))) >= 10000:
            plt.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
        plt.xlabel('Event index')
        plt.ylabel(f'{y_label}')
        plt.grid()
        plt.legend()
        plt.savefig(f'{folder}/{name_prefix}{fig_name}', dpi=600)

        if np.min(n_events) >= 100:  # plot a histogram
            plt.figure()
            max_count = 0
            for i, events in enumerate(t_list):
                # Use the counts of the bars that are actually drawn to decide on the tick format.
                counts, _, _ = plt.hist(events, bins=int(len(events) / 20), label=f'Trajectory {i}')
                max_count = max(max_count, np.max(counts))
            if max_count >= 10000:
                plt.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
            plt.xlabel(f'{y_label}')
            plt.ylabel('Event count')
            plt.grid()
            plt.legend()
            plt.savefig(f'{folder}/{name_prefix}hist_{fig_name}', dpi=600)

    return t_0k_list, t_k0_list, t_roundtrip_list, units
def plot_g_vecs(g_vecs, refs=None, refs_err=None, plot_rmse=True):
    """
    For each alchemical intermediate state, plots the alchemical weight as a function of
    the iteration index. Note that the alchemical weight of the first state (which is always 0)
    is skipped. If the reference values are given, they will be plotted in the figure (as horizontal lines)
    and a final RMSE will be calculated. Note that this function is only meaningful for weight-updating
    REXEE simulations. The figures are saved as :code:`g_vecs.png` and (optionally) :code:`g_vecs_rmse.png`
    in the current working directory.

    Parameters
    ----------
    g_vecs : numpy.ndarray
        The alchemical weights of all states as a function of the iteration index. The shape should
        be (n_iterations, n_states). Such an array can be directly read from :code:`g_vecs.npy`
        generated by a REXEE simulation.
    refs : numpy.ndarray
        The reference values of the alchemical weights. The default is :code:`None`.
    refs_err : list or numpy.ndarray, Optional
        The errors of the reference values. The default is :code:`None`.
    plot_rmse : bool, Optional
        Whether to plot RMSE as a function of the iteration index. This is only relevant when
        :code:`refs` is given. The default is :code:`True`.
    """
    g_vecs = np.transpose(g_vecs)  # now g_vecs[i] is the weight of state i over the iterations
    n_states = len(g_vecs)
    n_iter = len(g_vecs[0]) if n_states > 0 else 0
    cmap = plt.cm.ocean  # other good options are CMRmap, gnuplot, terrain, turbo, brg, etc.
    colors = [cmap(frac) for frac in np.arange(n_states) / n_states]

    # Markers are only shown for short simulations.
    marker_kwargs = {'marker': 'o', 'markersize': 2} if n_iter < 100 else {}

    plt.figure()
    for i in range(1, n_states):  # the first state is skipped since its weight is always 0
        plt.plot(range(n_iter), g_vecs[i], label=f'State {i}', c=colors[i], linewidth=0.8, **marker_kwargs)
    plt.xlabel('Iteration index')
    plt.ylabel('Alchemical weight (kT)')
    ax = plt.gca()
    x_range = ax.get_xlim()
    plt.xlim([0, x_range[1]])
    plt.grid()
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.2))

    if refs is not None:
        for i in range(1, len(refs)):
            plt.axhline(y=refs[i], c='black', linestyle='--', linewidth=0.5)
            if refs_err is not None:
                ax = plt.gca()
                x_range = ax.get_xlim()
                plt.fill_between(x_range, y1=refs[i] - refs_err[i], y2=refs[i] + refs_err[i], color='lightgreen')

        # Calculate the RMSE as a function of the iteration index
        RMSE_list = [np.sqrt(np.mean((g_vecs[:, it] - refs) ** 2)) for it in range(n_iter)]
        plt.text(0.02, 0.95, f'Final RMSE: {RMSE_list[-1]:.3f} kT', transform=ax.transAxes)
        print(f'Final RMSE: {RMSE_list[-1]: .3f} kT')

    plt.tight_layout()
    plt.savefig('g_vecs.png', dpi=600)

    if refs is not None and plot_rmse is True:
        plt.figure()
        # One RMSE value per iteration; do not rely on a loop variable leaked from the loops above.
        plt.plot(range(len(RMSE_list)), RMSE_list)
        plt.xlabel('Iteration index')
        plt.ylabel('RMSE in the alchemical weights (kT)')
        plt.grid()
        plt.savefig('g_vecs_rmse.png', dpi=600)
def get_swaps(REXEE_log='run_REXEE_log.txt'):
    """
    For each replica, identifies the states involved in proposed and accepted swaps.

    Parameters
    ----------
    REXEE_log : str, Optional
        The output log file of the REXEE simulation. The default is :code:`'run_REXEE_log.txt'`.

    Returns
    -------
    proposed_swaps : list
        A list of dictionaries showing where the swaps were proposed in
        each replica. Each dictionary (corresponding to one replica) have
        keys being the global state indices and values being the number of
        proposed swaps that involved the state indicated by the key.
    accepted_swaps : list
        A list of dictionaries showing where the swaps were accepted in
        each replica. Each dictionary (corresponding to one replica) have
        keys being the global state indices and values being the number of
        accepted swaps that involved the state indicated by the key.

    Example
    -------
    Below is an example based on a REXEE simulations having four replicas sampling states
    0-4, 1-5, 2-6 and 3-7, respectively.

    >>> from ensemble_md.analysis import analyze_traj
    >>> proposed_swaps, accepted_swaps = analyze_traj.get_swaps('run_REXEE_log.txt')
    >>> for i in range(len(proposed_swaps)):
    >>>     print(proposed_swaps[i])
    {0: 0, 1: 3, 2: 1, 3: 0, 4: 0}
    {1: 2, 2: 2, 3: 0, 4: 1, 5: 1}
    {2: 3, 3: 3, 4: 2, 5: 0, 6: 0}
    {3: 0, 4: 1, 5: 0, 6: 3, 7: 0}

    todo
    ----
    We should be able to only use :code:`rep_trajs.npy` and :code:`state_trajs.npy`
    instead of parsing the REXEE log file to reach the same goal.

    See also
    --------
    :func:`.plot_swaps`
    """
    import ast  # local import: only needed for parsing literals in the log file

    with open(REXEE_log, 'r') as f:
        lines = f.readlines()

    # Parse the header of the log file (everything before the first iteration)
    n_sim = None
    state_list = []
    for line in lines:
        if 'Number of replicas: ' in line:
            n_sim = int(line.split('Number of replicas: ')[-1])
        if '- Replica' in line:
            # NOTE(review): this used to be eval(...). ast.literal_eval is used instead so that text
            # from the log file is never executed; it assumes the states are logged as a plain
            # Python literal such as "[0, 1, 2]" -- confirm against the log writer.
            state_list.append(ast.literal_eval(line.split('States ')[-1].strip()))
        if 'Iteration' in line:
            break
    if n_sim is None:
        # Fall back to the number of replicas listed in the header
        n_sim = len(state_list)

    # Note that proposed_swaps and accepted_swaps are initialized in the same way
    proposed_swaps = [{i: 0 for i in state_list[j]} for j in range(n_sim)]  # Key: global state index; Value: The number of proposed swaps  # noqa: E501
    accepted_swaps = [{i: 0 for i in state_list[j]} for j in range(n_sim)]  # Key: global state index; Value: The number of accepted swaps  # noqa: E501
    state_trajs = [[] for _ in range(n_sim)]  # the state-space trajectory for each REPLICA (not trajectory)
    for line in lines:
        if 'Simulation' in line and 'Global state' in line:
            rep = int(line.split(':')[0].split()[-1])
            state = int(line.split(',')[0].split()[-1])
            state_trajs[rep].append(state)
        if 'Proposed swap' in line:
            # NOTE(review): same eval -> ast.literal_eval replacement as above; the swap is
            # assumed to be logged as a plain tuple such as "(0, 1)".
            swap = ast.literal_eval(line.split(': ')[-1].strip())
            # state_trajs[r][-1] is the last state sampled by replica r
            for r in (swap[0], swap[1]):
                proposed_swaps[r][state_trajs[r][-1]] += 1
        if 'Swap accepted!' in line:
            # This refers to the most recently proposed swap
            for r in (swap[0], swap[1]):
                accepted_swaps[r][state_trajs[r][-1]] += 1

    return proposed_swaps, accepted_swaps
def plot_swaps(swaps, swap_type='', stack=True, figsize=None):
    """
    Plots the histogram of the proposed swaps or accepted swaps for each replica
    and saves the figure as a PNG file in the current working directory.

    Parameters
    ----------
    swaps : list
        A list of dictionaries showing the number of swaps for each
        state for each replica. This list could be either of the outputs from :obj:`.get_swaps`.
    swap_type : str, Optional
        The type of swaps to be plotted. Common options include :code:`'accepted'` and :code:`'proposed'`.
        This value will only influence the name of y-axis and the output file name. The default is an empty string.
    stack : bool, Optional
        Whether to stack the histograms. The default is :code:`True`.
    figsize : tuple, Optional
        A tuple specifying the length and width of the output figure. The
        default is :code:`(6.4, 4.8)` for cases having less than 30 states and :code:`(10, 4.8)` otherwise.

    See also
    --------
    :func:`.get_swaps`
    """
    n_sim = len(swaps)
    n_states = 1 + max(max(d.keys()) for d in swaps)
    state_idx = range(n_states)
    x_limits = [-0.5, n_states - 0.5]

    # One color per replica, evenly sampled from the colormap
    palette = [plt.cm.ocean(frac) for frac in np.arange(n_sim) / n_sim]

    # Swap counts over ALL state indices; states not sampled by a replica count as 0
    counts_list = [[d.get(s, 0) for s in state_idx] for d in swaps]

    if figsize is None:
        figsize = (10, 4.8) if n_states > 30 else (6.4, 4.8)

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)

    # Running baseline of the bars, which only grows if the histograms are stacked
    bottom = [0] * n_states
    for rep, (counts, color) in enumerate(zip(counts_list, palette)):
        plt.bar(
            state_idx,
            counts,
            align='center',
            width=1,
            color=color,
            edgecolor='black',
            label=f'Replica {rep}',
            alpha=0.5,
            bottom=bottom
        )
        if stack is True:
            bottom = [b + c for b, c in zip(bottom, counts)]
    plt.xticks(state_idx)

    # Shade the region spanned by the state set of each replica
    y_min, y_max = ax.get_ylim()
    for rep, (d, color) in enumerate(zip(swaps, palette)):
        keys = list(d.keys())
        left, right = keys[0], keys[-1]
        if rep == 0:
            left -= 0.5
        if rep == n_sim - 1:
            right += 0.5
        plt.fill_betweenx([y_min, y_max], x1=right + 0.5, x2=left - 0.5, color=color, alpha=0.1, zorder=0)
    plt.xlim(x_limits)

    plt.xlabel('State')
    if swap_type == '':
        y_label, fig_name = 'Number of swaps', 'swaps.png'
    else:
        y_label, fig_name = f'Number of {swap_type} swaps', f'{swap_type}_swaps.png'
    plt.ylabel(y_label)
    plt.grid()
    plt.legend()
    plt.tight_layout()
    plt.savefig(fig_name, dpi=600)
def get_g_evolution(log_files, start_state, end_state, avg_frac=0, avg_from_last_update=False):
    """
    For a weight-updating simulation, gets the time series of the alchemical
    weights of the states of interest. Note that this function is only suitable for analyzing
    either a single expanded ensemble simulation or a replica in a REXEE simulation.
    For the latter case, all the log files for the replica should be provided.

    Parameters
    ----------
    log_files : list
        The list of file paths to the log file(s). If multiple log files are provided (for a REXEE simulation),
        please make sure the files are in the correct order such that the time series of the alchemical weights
        are continuous.
    start_state : int
        The index of the first state of interest. The index starts from 0.
    end_state : int
        The index of the last state of interest (inclusive). The index starts from 0. For example,
        if :code:`start_state` is set to 1 and :code:`end_state` is set to 3, then the weight evolution for
        states 1, 2 and 3 will be extracted.
    avg_frac : float, Optional
        The fraction of the last part of the simulation to be averaged. The
        default is 0, which means no averaging. Note that this parameter is
        ignored if :code:`avg_from_last_update` is :code:`True`.
    avg_from_last_update : bool, Optional
        Whether to average from the last update of the Wang-Landau incrementor. If this option is set to
        :code:`True` but the Wang-Landau incrementor was not updated in the provided log file(s),
        all weights will be used for averaging.

    Returns
    -------
    g_vecs_all : list
        The alchemical weights of the states of interest as a function of time.
        It is a list of lists. For example, :code:`g_vecs_all[i]` holds the
        alchemical weights of states :code:`start_state` to :code:`end_state` at the time frame
        with index :code:`i`. Within each log file, parsing stops at the equilibration message.
    g_vecs_avg : numpy.ndarray or None
        The alchemical weights averaged over the last part of
        the simulation. :code:`None` is returned if :code:`avg_frac` is 0 and
        :code:`avg_from_last_update` is :code:`False`.
    g_vecs_err : numpy.ndarray or None
        The sample standard deviation (:code:`ddof=1`) of the alchemical weights over the
        last part of the simulation. :code:`None` is returned if :code:`avg_frac` is 0 and
        :code:`avg_from_last_update` is :code:`False`.

    Example
    -------
    >>> from ensemble_md import analyze_traj
    >>> log_files = ['EXE.log']  # For analyzing a single expanded ensemble simulation
    >>> results = analyze_traj.get_g_evolution(log_files, start_state=0, end_state=6)
    """
    g_vecs_all = []
    idx_updates = []  # the indices of the data points corresponding to the updates of wl-delta
    find_equil = False  # defined up front so that an empty list of log files is handled
    for log_file in log_files:
        with open(log_file, "r") as f:
            lines = f.readlines()

        find_equil = False
        for n, line in enumerate(lines):
            # Runs of whitespace are collapsed before matching so that the table header is
            # recognized regardless of how its columns are padded.
            if "Count G(in kT)" in " ".join(line.split()):  # this line is lines[n]
                # The row of state k is lines[n + k + 1], so states start_state..end_state
                # (both inclusive) correspond to offsets start_state + 1 .. end_state + 1.
                w = []  # the list of weights at this time frame
                for i in range(start_state + 1, end_state + 2):
                    fields = lines[n + i].split()
                    # The row of the currently sampled state ends with an extra "<<" marker.
                    w.append(float(fields[-3] if "<<" in lines[n + i] else fields[-2]))
                g_vecs_all.append(w)

            if 'weights are now' in line:
                idx_updates.append(len(g_vecs_all) - 1)

            if "Weights have equilibrated" in line:
                find_equil = True
                # Usually, the line two lines above "Weights have equilibrated" is the line
                # "Step xxx: weights are now: xxx", but there could be exceptions, in which case
                # we just do not append anything since the last fixed weights should have been already appended.
                # The exception happens when the change of the WL incrementor happened at the time when
                # the log file is written, in which case one WL incrementor lower than the cutoff will be printed,
                # leading to a different format of the log file where "weights are now" is not in lines[n-2].
                if "weights are now:" in lines[n - 2]:
                    w_final = [float(i) for i in lines[n - 2].split(':')[-1].split()]
                    # That line lists the weights of ALL states, so keep only the requested range
                    # to stay consistent with the other entries of g_vecs_all.
                    g_vecs_all.append(w_final[start_state:end_state + 1])
                break

    if avg_from_last_update is True:
        # If the weights are equilibrated, then the last occurrence of "weights are now"
        # is right before the equilibration message, in which case we want to average
        # from the second last occurrence of "weights are now".
        if find_equil is True:
            idx_updates = idx_updates[:-1]

        if not idx_updates:
            print('Note: wl-delta was not updated in the provided log file(s) so all weights are used for averaging.')
            idx_last_update = -1  # so that all weights are used for averaging
        else:
            idx_last_update = idx_updates[-1]

        g_vecs_avg = np.mean(g_vecs_all[idx_last_update + 1:], axis=0)
        g_vecs_err = np.std(g_vecs_all[idx_last_update + 1:], axis=0, ddof=1)
    else:
        if avg_frac != 0:
            n_avg = int(avg_frac * len(g_vecs_all))
            g_vecs_avg = np.mean(g_vecs_all[-n_avg:], axis=0)
            g_vecs_err = np.std(g_vecs_all[-n_avg:], axis=0, ddof=1)
        else:
            g_vecs_avg = None
            g_vecs_err = None

    return g_vecs_all, g_vecs_avg, g_vecs_err


def get_dg_evolution(log_files, start_state, end_state):
    """
    For a weight-updating simulation, gets the time series of the weight
    difference (:math:`Δg = g_2-g_1`) between the states of interest.

    Parameters
    ----------
    log_files : list
        The list of file paths to the log file(s). If multiple log files are provided (for a REXEE simulation),
        please make sure the files are in the correct order such that the time series of the alchemical weights
        are continuous.
    start_state : int
        The index of the state (starting from 0) whose weight is :math:`g_1`.
    end_state : int
        The index of the state (starting from 0) whose weight is :math:`g_2`.

    Returns
    -------
    dg : list
        The time series of :math:`Δg`.
    """
    g_vecs, _, _ = get_g_evolution(log_files, start_state, end_state)
    # Each entry of g_vecs only covers states start_state..end_state, so g_1 and g_2
    # are its first and last elements rather than the elements at the absolute state indices.
    dg = [w[-1] - w[0] for w in g_vecs]

    return dg
def plot_dg_evolution(log_files, start_state, end_state, start_idx=None, end_idx=None, dt_log=2):
    """
    For a weight-updating simulation, plots the time series of the weight
    difference (:math:`Δg = g_2-g_1`) between the states of interest and saves
    the figure as :code:`dg_evolution.png`.

    Parameters
    ----------
    log_files : list
        The list of file paths to the log file(s). If multiple log files are provided (for a REXEE simulation),
        please make sure the files are in the correct order such that the time series of the alchemical weights
        are continuous.
    start_state : int
        The index of the state (starting from 0) whose weight is :math:`g_1`.
    end_state : int
        The index of the state (starting from 0) whose weight is :math:`g_2`.
    start_idx : int, Optional
        The index of the first frame to be plotted. The default is :code:`None`, which means the first frame.
    end_idx : int, Optional
        The index at which the plotted frames stop (used as the end of a slice). The default is
        :code:`None`, which means that frames up to the last one are plotted.
    dt_log : float, Optional
        The time interval (in ps) between two consecutive frames in the log file. The
        default is 2 ps.

    Returns
    -------
    dg : list
        The plotted portion of the time series of :math:`Δg`.
    """
    # Only keep the frames requested by the user
    dg = get_dg_evolution(log_files, start_state, end_state)[start_idx:end_idx]
    t = dt_log * np.arange(len(dg))

    plt.figure()

    # Switch from ps to ns for long time series
    units = 'ps'
    if max(t) >= 10000:
        t = t / 1000
        units = 'ns'

    plt.plot(t, dg)
    plt.xlabel(f'Time ({units})')
    plt.ylabel(r'$\Delta g$')
    plt.grid()
    plt.savefig('dg_evolution.png', dpi=600)

    return dg
def get_delta_w_updates(log_file, plot=False):
    """
    Parses the log file of a weight-updating simulation and identifies the
    time frames when the Wang-Landau incrementor was updated.

    Parameters
    ----------
    log_file : str
        The file path of the LOG file.
    plot : bool, Optional
        Whether to plot the Wang-Landau incrementor as a function of time. If :code:`True`,
        the figure is saved as :code:`delta_w_updates.png`. The default is :code:`False`.

    Returns
    -------
    t_updates : list
        A list of time frames (in ns) when the Wang-Landau incrementor was updated.
        The first entry is always 0.
    delta_w_updates : list
        A list of the updated Wang-Landau incrementors, starting with the initial one.
        Should be the same length as :code:`t_updates`.
    equil : bool
        Whether the weights got equilibrated during the simulation.
    """
    with open(log_file, "r") as f:
        lines = f.readlines()

    # Get the parameters, which are all read from the part of the log file before mdrun starts
    for line in lines:
        if ' dt ' in line:
            dt = float(line.split('=')[-1])
        if 'init-wl-delta ' in line:
            init_wl_delta = float(line.split('=')[-1])
        if 'wl-scale ' in line:
            wl_scale = float(line.split('=')[-1])
        if 'weight-equil-wl-delta ' in line:
            wl_delta_cutoff = float(line.split('=')[-1])
        if 'Started mdrun' in line:
            break

    # Start parsing the data
    equil = False  # stays False if the equilibration message never shows up in the log file
    t_updates, delta_w_updates = [0], [init_wl_delta]
    for n, line in enumerate(lines):
        if 'weights are now' in line:
            t_updates.append(int(line.split(':')[0].split('Step')[-1]) * dt / 1000)  # in ns
            # Search this line and the following 9 lines to find the Wang-Landau incrementor.
            # Slicing (rather than indexing) keeps the search safe near the end of the file.
            for nearby in lines[n:n + 10]:
                if 'Wang-Landau incrementor is:' in nearby:
                    delta_w_updates.append(float(nearby.split()[-1]))
                    break

        if 'Weights have equilibrated' in line:
            equil = True
            break

    if equil is True:
        # The last incrementor is computed from the previous one and wl-scale instead of
        # being read from the log file.
        delta_w_updates.append(delta_w_updates[-1] * wl_scale)

    # Plot the Wang-Landau incrementor as a function of time if requested
    # Note that between adjacent entries in t_updates, a horizontal line should be drawn.
    if plot is True:
        plt.figure()
        for i in range(len(t_updates) - 1):
            plt.plot([t_updates[i], t_updates[i + 1]], [delta_w_updates[i], delta_w_updates[i]], c='C0')
            plt.plot([t_updates[i + 1], t_updates[i + 1]], [delta_w_updates[i], delta_w_updates[i + 1]], c='C0')

        plt.text(0.65, 0.95, f'init_wl_delta: {init_wl_delta}', transform=plt.gca().transAxes)
        plt.text(0.65, 0.9, f'wl_scale: {wl_scale}', transform=plt.gca().transAxes)
        plt.text(0.65, 0.85, f'wl_delta_cutoff: {wl_delta_cutoff}', transform=plt.gca().transAxes)

        plt.xlabel('Time (ns)')
        plt.ylabel(r'Wang-Landau incrementor ($k_{B}T$)')
        plt.grid()
        plt.savefig('delta_w_updates.png', dpi=600)

    return t_updates, delta_w_updates, equil