The law of total variance in practice


In this post, I’ll solve an example that requires the use of the law of total variance in RL.

Background

Background on some technical tools used in the main section.

The law of total variance

Let $(\Omega, \mathcal{F}, \mathbb{P})$ be a probability space, $\mathcal{G}_1 \subseteq \mathcal{G}_2 \subseteq \mathcal{F}$ two sub-$\sigma$-algebras of $\mathcal{F}$, $X$ an integrable random variable on $(\Omega, \mathcal{F}, \mathbb{P})$ with finite variance. The following holds true:

\[\mathbb{V}(X \mid \mathcal{G}_1) = \mathbb{E}[\mathbb{V}(X \mid \mathcal{G}_2) \mid \mathcal{G}_1] + \mathbb{V}[\mathbb{E}(X \mid \mathcal{G}_2) \mid \mathcal{G}_1].\]
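
The identity can be sanity-checked numerically. Below is a minimal sketch on a made-up two-stage experiment (a latent coin $G$ selects one of two normal distributions for $X$; with $\mathcal{G}_1$ the trivial $\sigma$-algebra and $\mathcal{G}_2 = \sigma(G)$, the law reads $\mathbb{V}(X) = \mathbb{E}[\mathbb{V}(X \mid G)] + \mathbb{V}[\mathbb{E}(X \mid G)]$):

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up example: G picks which distribution X is drawn from.
n = 1_000_000
g = rng.integers(0, 2, size=n)                 # latent "coin", P(G=k) = 1/2
x = np.where(g == 0,
             rng.normal(0.0, 1.0, size=n),     # coin 0: N(0, 1)
             rng.normal(3.0, 2.0, size=n))     # coin 1: N(3, 4)

total_var = x.var()                                        # V(X)
within = np.mean([x[g == k].var() for k in (0, 1)])        # E[V(X | G)]
between = np.var([x[g == k].mean() for k in (0, 1)])       # V(E[X | G])
print(total_var, within + between)  # the two agree up to Monte Carlo noise
```

The exact value here is $0.5 \cdot 1 + 0.5 \cdot 4 + 0.5(0 - 1.5)^2 + 0.5(3 - 1.5)^2 = 4.75$.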

The law of the unconscious statistician (LOTUS)

Let $\mathbb{P}_X$ be the pushforward measure of the random element $X \in \mathcal{X}$. For any measurable function $f: \mathcal{X} \to \mathbb{R}$,

\[\mathbb{E}[f(X)] = \sum_xf(x)\mathbb{P}_X(x),\]

or

\[\mathbb{E}[f(X)] = \int_\mathcal{X} f(x)\mathrm{d}\mathbb{P}_X(x),\]

provided that either the right-hand side or the left-hand side exists. This is known as the “law of the unconscious statistician”, or LOTUS.
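
A quick numerical illustration of the discrete case, with a made-up pushforward on $\{0, 1, 2\}$: the right-hand side is computed directly from $\mathbb{P}_X$, without ever working out the distribution of $f(X)$ itself.

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy pushforward measure P_X on the support {0, 1, 2} (made up).
support = np.array([0, 1, 2])
p_x = np.array([0.2, 0.5, 0.3])
f = lambda x: (x - 1) ** 2               # any measurable f

# LOTUS: E[f(X)] = sum_x f(x) P_X(x)
lotus = np.sum(f(support) * p_x)

# Monte Carlo estimate of the left-hand side E[f(X)]
samples = rng.choice(support, size=500_000, p=p_x)
mc = f(samples).mean()
print(lotus, mc)  # 0.5 and a nearby Monte Carlo estimate
```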

Jensen’s inequality

Let $\bar{\mathbb{R}} = \mathbb{R} \cup \{-\infty, +\infty\}$, and $\mathrm{dom}(f) = \{x \in \mathbb{R}^d : f(x) < \infty\}$ for an extended-real-valued function $f$ on $\mathbb{R}^d$.

Jensen’s inequality: Let $f: \mathbb{R}^d \to \bar{\mathbb{R}}$ be a measurable convex function and $X$ be an $\mathbb{R}^d$-valued random element on some probability space such that $\mathbb{E}[X]$ exists and $X \in \mathrm{dom}(f)$ holds almost surely. Then,

\[\mathbb{E}[f(X)] \geq f(\mathbb{E}[X]).\]
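
The instance of Jensen’s inequality used later is $f(x) = x^2$, which gives $\mathbb{E}[X^2] \geq \mathbb{E}[X]^2$. A minimal numerical check (distribution chosen arbitrarily):

```python
import numpy as np

rng = np.random.default_rng(2)

# Jensen with the convex function f(x) = x**2: E[X^2] >= (E[X])^2.
x = rng.exponential(scale=2.0, size=100_000)
lhs = np.mean(x ** 2)     # E[f(X)]
rhs = np.mean(x) ** 2     # f(E[X])
print(lhs >= rhs)         # True; the gap is exactly the (empirical) variance
```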

An example of the law of total variance in practice

We want to compare the variance of the target for SARSA and expected SARSA. The update rule of SARSA is

\[Q_{t + 1}(S_t, A_t) = Q_t(S_t, A_t) + \alpha \left[R_{t + 1} + \gamma Q_t(S_{t + 1}, A_{t + 1}) - Q_t(S_t, A_t) \right].\]

And the update rule for expected SARSA is

\[Q_{t + 1}(S_t, A_t) = Q_t(S_t, A_t) + \alpha \left[R_{t + 1} + \gamma \sum_{a' \in \mathcal{A}} \pi\left(a' \mid S_{t + 1}\right)Q_t \left(S_{t + 1}, a' \right) - Q_t(S_t, A_t) \right],\]

where $\pi$ is the fixed policy that was used to generate the data $S_0, A_0, R_1, \, \dots\;$.

Let $H_t = (S_0, A_0, R_1, \dots, S_t)$, and $H'_t = (S_0, A_0, R_1, \dots, S_t, A_t)$.

  1. Show that
\[\mathbb{V}\left[Q_t(S_{t + 1}, A_{t + 1}) \middle\vert H_{t + 1}\right] \geq \mathbb{V}\left[\sum_{a' \in \mathcal{A}} \pi\left(a' \mid S_{t + 1}\right)Q_t \left(S_{t + 1}, a' \right) \middle\vert H_{t + 1}\right].\]

First, note that by the Markov property, we can replace $H_{t + 1}$ in the expressions above with $S_{t + 1}$. Second,

\[\mathbb{V}\left[\sum_{a' \in \mathcal{A}} \pi\left(a' \mid S_{t + 1}\right)Q_t \left(S_{t + 1}, a' \right) \middle\vert S_{t + 1}\right] = 0.\]

Given $S_{t + 1}$, the sum is a deterministic function of $S_{t + 1}$: the randomness in the action has already been averaged out by the expectation, so no randomness remains. Third,

\[\begin{align*} \mathbb{V}\left[Q_t(S_{t + 1}, A_{t + 1}) \middle\vert S_{t + 1}\right] & = \mathbb{E}\left[Q_t(S_{t + 1}, A_{t + 1})^2 \middle\vert S_{t + 1}\right] - \mathbb{E}\left[Q_t(S_{t + 1}, A_{t + 1}) \middle\vert S_{t + 1}\right]^2 \quad \text{(def. of variance)} \\ & = \sum_{a \in \mathcal{A}} \pi(a \mid S_{t + 1}) Q_t(S_{t + 1}, a)^2 - \left(\sum_{a \in \mathcal{A}} \pi(a \mid S_{t + 1}) Q_t(S_{t + 1}, a) \right)^2 \quad \text{(def. of expectation and LOTUS)} \\ &\geq \sum_{a \in \mathcal{A}} \pi(a \mid S_{t + 1}) Q_t(S_{t + 1}, a)^2 - \sum_{a \in \mathcal{A}} \pi(a \mid S_{t + 1}) Q_t(S_{t + 1}, a)^2 \quad \text{(Jensen's inequality with } x \mapsto x^2 \text{)} \\ & = 0. \end{align*}\]

Right off the bat, we already knew that the variance is always non-negative anyway. 😅
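
The inequality can also be checked numerically for a single next state: fix arbitrary (made-up) action values and a stochastic policy, sample $A_{t + 1} \sim \pi(\cdot \mid S_{t + 1})$, and compare the variance of $Q_t(S_{t + 1}, A_{t + 1})$ with that of its policy average, which is identically zero given $S_{t + 1}$.

```python
import numpy as np

rng = np.random.default_rng(3)

# One fixed next state s' with |A| = 3 actions; all numbers are made up.
q = np.array([1.0, -0.5, 2.0])       # Q_t(s', a) for each action a
pi = np.array([0.2, 0.3, 0.5])       # pi(a | s')

# SARSA's inner term Q_t(s', A') is random because A' ~ pi(. | s').
a = rng.choice(3, size=200_000, p=pi)
var_sarsa = q[a].var()

# Exact value: sum_a pi(a) q(a)^2 - (sum_a pi(a) q(a))^2
exact = np.sum(pi * q ** 2) - np.sum(pi * q) ** 2

# Expected SARSA's inner term is a constant given s': zero variance.
var_exp_sarsa = 0.0

print(var_sarsa, exact, var_exp_sarsa)
```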

  2. When does equality hold?

We are asking when the variance of a random variable is zero, and the answer is: when it is (almost surely) deterministic. For a deterministic random variable $X$ we have $\mathbb{E}[X] = X$. Hence, we want

\[\mathbb{E}\left[Q_t(S_{t + 1}, A_{t + 1}) \middle\vert S_{t + 1}\right] = Q_t(S_{t + 1}, A_{t + 1}),\]

which happens when the policy is deterministic at $S_{t + 1}$ (or, more generally, when $Q_t(S_{t + 1}, \cdot)$ is constant on the support of $\pi(\cdot \mid S_{t + 1})$).

  3. Show that
\[\mathbb{V} \left[R_{t + 1} + \gamma Q_t(S_{t + 1}, A_{t + 1}) \middle\vert H'_t \right] \geq \mathbb{V}\left[R_{t + 1} + \gamma \sum_{a' \in \mathcal{A}} \pi\left(a' \mid S_{t + 1}\right)Q_t \left(S_{t + 1}, a' \right) \middle\vert H'_t\right],\]

that is, the appropriate conditional variance of the SARSA target is always at least as large as that of the expected SARSA target.

For convenience, let $Z_t = (S_t, A_t)$; by the Markov property, conditioning on $H'_t$ is equivalent to conditioning on $Z_t$. We have,

\[\begin{align*} & \mathbb{V} \left[R_{t + 1} + \gamma Q_t(S_{t + 1}, A_{t + 1}) \middle\vert H'_t \right] - \mathbb{V}\left[R_{t + 1} + \gamma \sum_{a' \in \mathcal{A}} \pi\left(a' \mid S_{t + 1}\right)Q_t \left(S_{t + 1}, a' \right) \middle\vert H'_t\right] = \\ & \mathbb{V} [R_{t + 1} \mid Z_t] + \gamma^2 \mathbb{V} \left[Q_t(S_{t + 1}, A_{t + 1}) \middle\vert Z_t \right] + 2\mathrm{Cov}\left[R_{t + 1}, \gamma Q_t(S_{t + 1}, A_{t + 1}) \middle\vert Z_t \right] - \mathbb{V} [R_{t + 1} \mid Z_t] - \gamma^2 \mathbb{V}\left[\sum_{a' \in \mathcal{A}} \pi\left(a' \mid S_{t + 1}\right)Q_t \left(S_{t + 1}, a' \right) \middle\vert Z_t\right] - 2\mathrm{Cov}\left[R_{t + 1}, \gamma \sum_{a' \in \mathcal{A}} \pi\left(a' \mid S_{t + 1}\right)Q_t \left(S_{t + 1}, a' \right) \middle\vert Z_t \right]= \\ & \gamma^2 \mathbb{V} \left[Q_t(S_{t + 1}, A_{t + 1}) \middle\vert Z_t \right] + 2\mathrm{Cov}\left[R_{t + 1}, \gamma Q_t(S_{t + 1}, A_{t + 1}) \middle\vert Z_t \right] - \gamma^2 \mathbb{V}\left[\sum_{a' \in \mathcal{A}} \pi\left(a' \mid S_{t + 1}\right)Q_t \left(S_{t + 1}, a' \right) \middle\vert Z_t\right] - 2\mathrm{Cov}\left[R_{t + 1}, \gamma \sum_{a' \in \mathcal{A}} \pi\left(a' \mid S_{t + 1}\right)Q_t \left(S_{t + 1}, a' \right) \middle\vert Z_t \right]. \end{align*}\]

For the covariance terms, note that they are in fact equal, so they cancel: conditionally on $Z_t$ and $(R_{t + 1}, S_{t + 1})$, the action $A_{t + 1} \sim \pi(\cdot \mid S_{t + 1})$ is drawn independently of $R_{t + 1}$, so $\mathbb{E}\left[R_{t + 1} Q_t(S_{t + 1}, A_{t + 1}) \mid Z_t\right] = \mathbb{E}\left[R_{t + 1} \sum_{a' \in \mathcal{A}} \pi(a' \mid S_{t + 1}) Q_t(S_{t + 1}, a') \mid Z_t\right]$, and the corresponding conditional means coincide as well. Hence, we end up with

\[\mathbb{V} \left[R_{t + 1} + \gamma Q_t(S_{t + 1}, A_{t + 1}) \middle\vert H'_t \right] - \mathbb{V}\left[R_{t + 1} + \gamma \sum_{a' \in \mathcal{A}} \pi\left(a' \mid S_{t + 1}\right)Q_t \left(S_{t + 1}, a' \right) \middle\vert H'_t\right] = \gamma^2 \left(\mathbb{V} \left[Q_t(S_{t + 1}, A_{t + 1}) \middle\vert Z_t \right] - \mathbb{V}\left[\sum_{a' \in \mathcal{A}} \pi\left(a' \mid S_{t + 1}\right)Q_t \left(S_{t + 1}, a' \right) \middle\vert Z_t\right]\right).\]
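
The cancellation of the covariance terms can be checked by simulation in a tiny made-up model where $R_{t + 1}$ and $S_{t + 1}$ are deliberately correlated given $Z_t$:

```python
import numpy as np

rng = np.random.default_rng(4)

# Given Z_t, draw (R, S') jointly and correlated, then A' ~ pi(. | S').
# The two covariance terms should agree, so they cancel in the difference.
n = 1_000_000
s_next = rng.integers(0, 2, size=n)              # two next states, made up
r = s_next + rng.normal(0, 0.1, size=n)          # reward correlated with S'

q = np.array([[1.0, -1.0],                       # Q_t(s', a'), made up
              [0.5, 2.0]])
pi = np.array([[0.7, 0.3],                       # pi(a' | s'), made up
               [0.4, 0.6]])

# Sample A' ~ pi(. | S'), row by row.
a_next = (rng.random(n) > pi[s_next, 0]).astype(int)

sarsa_term = q[s_next, a_next]                   # Q_t(S', A')
exp_term = (pi * q).sum(axis=1)[s_next]          # sum_a pi(a|S') Q_t(S', a)

cov_sarsa = np.cov(r, sarsa_term)[0, 1]
cov_exp = np.cov(r, exp_term)[0, 1]
print(cov_sarsa, cov_exp)   # agree up to Monte Carlo noise
```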

Now we apply the law of total variance:

\[\begin{align*} &\gamma^2 \left(\mathbb{V} \left[Q_t(S_{t + 1}, A_{t + 1}) \middle\vert Z_t \right] - \mathbb{V}\left[\sum_{a' \in \mathcal{A}} \pi\left(a' \mid S_{t + 1}\right)Q_t \left(S_{t + 1}, a' \right) \middle\vert Z_t\right]\right) = \\ &\gamma^2 \left(\mathbb{E}\left[\mathbb{V}\left[Q_t(S_{t + 1}, A_{t + 1}) \middle\vert S_{t + 1} \right] \middle\vert Z_t \right] + \mathbb{V}\left[\mathbb{E}\left[Q_t(S_{t + 1}, A_{t + 1}) \middle\vert S_{t + 1} \right] \middle\vert Z_t \right] \right) - \\ &\gamma^2 \left( \mathbb{E}\left[\mathbb{V}\left[\sum_{a' \in \mathcal{A}} \pi\left(a' \mid S_{t + 1}\right)Q_t \left(S_{t + 1}, a' \right) \middle\vert S_{t + 1}\right] \middle\vert Z_t \right] + \mathbb{V}\left[\mathbb{E}\left[\sum_{a' \in \mathcal{A}} \pi\left(a' \mid S_{t + 1}\right)Q_t \left(S_{t + 1}, a' \right) \middle\vert S_{t + 1}\right] \middle\vert Z_t \right] \right) \end{align*}\]

First, $\mathbb{E}\left[\mathbb{V}\left[\sum_{a' \in \mathcal{A}} \pi\left(a' \mid S_{t + 1}\right)Q_t \left(S_{t + 1}, a' \right) \middle\vert S_{t + 1} \right] \middle\vert Z_t \right] = 0$, since given $S_{t + 1}$ everything inside the inner variance becomes deterministic. Also,

\[\mathbb{V}\left[\mathbb{E}\left[Q_t(S_{t + 1}, A_{t + 1}) \middle\vert S_{t + 1} \right] \middle\vert Z_t \right] = \mathbb{V}\left[ \sum_{a'} \pi(a' \mid S_{t + 1})Q_t(S_{t + 1}, a') \middle\vert Z_t \right]= \mathbb{V}\left[\mathbb{E}\left[ \sum_{a'} \pi(a' \mid S_{t + 1})Q_t(S_{t + 1}, a') \middle\vert S_{t + 1}\right] \middle\vert Z_t \right]\]

so the two remaining variance terms are equal and cancel. We are left with

\[\mathbb{V} \left[R_{t + 1} + \gamma Q_t(S_{t + 1}, A_{t + 1}) \middle\vert H'_t \right] - \mathbb{V}\left[R_{t + 1} + \gamma \sum_{a' \in \mathcal{A}} \pi\left(a' \mid S_{t + 1}\right)Q_t \left(S_{t + 1}, a' \right) \middle\vert H'_t\right] = \gamma^2 \mathbb{E}[\mathbb{V}[Q_t(S_{t + 1}, A_{t + 1}) \mid S_{t + 1}] \mid Z_t] \geq 0.\]

since the conditional variance is always non-negative, and so is its expectation.
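
Putting everything together, the two target variances can be compared by simulating one transition from a fixed $(s, a)$ pair in a small made-up MDP; the gap should be $\gamma^2 \, \mathbb{E}\left[\mathbb{V}\left[Q_t(S_{t + 1}, A_{t + 1}) \mid S_{t + 1}\right] \mid Z_t\right]$.

```python
import numpy as np

rng = np.random.default_rng(5)

gamma = 0.9
n = 1_000_000

# Tiny made-up MDP: from a fixed (s, a), transition to one of two next
# states with a noisy reward, then pick A' ~ pi(. | S').
p_next = np.array([0.6, 0.4])                    # P(S' = s' | s, a)
q = np.array([[1.0, -1.0], [0.5, 2.0]])          # Q_t(s', a'), made up
pi = np.array([[0.7, 0.3], [0.4, 0.6]])          # pi(a' | s'), made up

s_next = rng.choice(2, size=n, p=p_next)
r = s_next + rng.normal(0.0, 0.3, size=n)        # reward correlated with S'
a_next = (rng.random(n) > pi[s_next, 0]).astype(int)

sarsa_target = r + gamma * q[s_next, a_next]
exp_sarsa_target = r + gamma * (pi * q).sum(axis=1)[s_next]

v_sarsa = sarsa_target.var()
v_exp = exp_sarsa_target.var()
print(v_sarsa, v_exp)  # SARSA's target has the larger variance
```

With these numbers, $\mathbb{V}[Q_t(S', A') \mid S' = 0] = 0.84$ and $\mathbb{V}[Q_t(S', A') \mid S' = 1] = 0.54$, so the gap is $0.81 \cdot (0.6 \cdot 0.84 + 0.4 \cdot 0.54) = 0.5832$.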
