{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "fsuA88fu5HSV" }, "source": [ "# No-dependency introduction to learned optimizers in JAX\n", "\n", "This notebook contains a self-contained implementation of learned optimizers in JAX.\n", "It is deliberately minimal, in the hope that it is easier to follow and gives readers a better understanding of what is involved. We start with some background describing what learned optimizers are. We then implement a simple MLP and train it with a hand-designed optimizer. Next, we introduce a simple learned optimizer and discuss multiple ways to meta-train its weights, including gradient-based methods and evolution strategies.\n", "\n", "The design ideas and patterns are the same as those used by [`learned_optimization`](https://github.com/google/learned_optimization), but greatly stripped down and simplified." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "executionInfo": { "elapsed": 4826, "status": "ok", "timestamp": 1647716615044, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "AdcD2g_d5Gw4" }, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import tensorflow_datasets as tfds\n", "import matplotlib.pylab as plt\n", "import numpy as onp\n", "import functools\n", "import os" ] }, { "cell_type": "markdown", "metadata": { "id": "TacG7U72It6c" }, "source": [ "## What is a learned optimizer?\n", "\n", "Learned optimizers are machine learning models that themselves optimize other machine learning models.\n", "\n", "To understand exactly what this means, first consider a simple hand-designed optimizer: SGD. 
We can write the update equation as a single function of both the parameter values, $x$, and the gradients, $\\nabla l$, computed on some loss $l$:\n", "\n", "$$U_{sgd}(x, \\nabla l; \\alpha) = - \\alpha \\nabla l $$\n", "\n", "This update is then added to the current parameters to produce our next iterate:\n", "\n", "$$x' = x + U_{sgd}(x, \\nabla l; \\alpha)$$\n", "\n", "This update rule is simple, effective, and widely used. Can we do better?\n", "\n", "Framed this way, the algorithm is simply a function. One idea to improve training is to replace this hand-designed function with a learned function parameterized by some set of weights, $\\theta$:\n", "\n", "$$U(x, \\nabla l; \\theta) = \\text{NN}(x, \\nabla l; \\theta)$$\n", "\n", "We call the weights of the optimizer, $\\theta$, the meta-parameters, or outer-parameters. We refer to the weights that this optimizer is optimizing as the inner-parameters, or simply the parameters.\n", "\n", "Given this more flexible form, how do we choose a particular value of the learned optimizer weights so that the learned optimizer \"performs well\"? To do this, we must first define what it means to perform well. In standard optimization, this could mean finding a low-loss solution after applying the optimizer many times. In machine learning, this could mean finding a solution that generalizes. This objective, or measurement of the learned optimizer's performance, often goes by the name of a meta-loss, or outer loss.\n", "\n", "With this metric in hand, we can **optimize** the weights of the learned optimizer with respect to this meta-loss. If we have a flexible enough set of weights and can solve this optimization problem, we will be left with a performant optimizer!\n", "\n", "In this notebook, we start by defining the type of problem on which we want our optimizer to perform well. Next, we introduce optimizers, followed by learned optimizers. We then define our meta-objective, or our measurement of how well our optimizers perform. 
Finally, we discuss a variety of techniques and tricks for meta-training, including gradient-based methods, evolution strategies, and approaches that leverage truncated unrolls." ] }, { "cell_type": "markdown", "metadata": { "id": "XqSOLXZ-5SJ0" }, "source": [ "## The inner problem\n", "\n", "We seek to train a learned optimizer to perform well on some task. In this demo notebook, we define our task to be a single MLP trained on resized Fashion MNIST.\n", "\n", "### Data iterators\n", "Data iterators are fairly standard, so rather than reinvent the wheel, we use TensorFlow Datasets to create a Python iterator that yields batches of data.\n", "\n", "To keep meta-training fast, we work with images resized to 8x8." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "height": 592 }, "executionInfo": { "elapsed": 7501, "status": "ok", "timestamp": 1647716622662, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "UM2Yg-HP6LhO", "outputId": "56efd60d-6fc6-4f25-e820-99a375d0121d" }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAkYAAAI/CAYAAACS8BZlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90\nbGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsT\nAAALEwEAmpwYAAAylElEQVR4nO3de4xc9XnG8ef17K7vBmJsg21uSQkpIg4Gh4YYSEouhZDGVZMq\npklU0oulNhDSRkmNolbqReo/beO0jZBcQpoWEKoMKCQiEESgKbR1bMAt+EZsx8SbxTb4is2u17v7\n9g//Ijl4Z87xe87czn4/EsK7nse/H8cPx6/PzJwxdxcAAACkSe3eAAAAQKdgMAIAAEgYjAAAABIG\nIwAAgITBCAAAIGEwAgAASHqa8YuaWVvuATBpUnzOu/DCC0O50dHR8JoDAwPh7PHjx8PZItzdWrFO\nuzp02WWXhbN9fX2hXLt+L1944YW2rNuqDknt65FZ/D/xiiuuCOWK9GjHjh3h7JEjR8LZIqp+Lor2\nQJJGRkZCud7e3vCaRfq3adOmcDb63yrV75A14z5G7SrStGnTwtl/+Zd/CeUOHToUXvMv//Ivw9ld\nu3aFs0V0w8moyID84x//OJxduHBhKLd3797wmkUG8+hfBoqaCIPR5MmTw9mhoaFQbvfu3eE1ly9f\nHs7+x3/8RzhbRDeci4oMyIODg+HsgQMHQrlzzjknvOaePXvC2UWLFoWzRc6f9TrEU2kAAAAJgxEA\nAECSazAysxvMbKuZbTOzlc3eFKqHDqEM9AhloEdoJHMwMrOapK9LulHSpZJuNrNLm70xVAcdQhno\nEcpAj5AlzxWjqyRtc/cd7j4s6X5Jy5q7LVQMHUIZ6BHKQI/QUJ7BaIGkk98C1Z++B+RFh1AGeoQy\n0CM0lOc+RuO9ne2Uty+a2QpJKwrvCFVEh1AGeoQyZPaIDk1seQajfknnnfT1Qkmn3JnQ3VdLWi21\n794h6Fh0CGWgRyhDZo/o0MSW56m0dZIuNrOLzKxP0nJJDzd3W6gYOoQy0COUgR6hocwrRu4+Yma3\nSnpMUk3S3e6+sek7Q2XQIZSBHqEM9AhZcn1Wmrs/IumRJu8FFUaHUAZ6hDLQIzTCna8BAACSXFeM\nusXixYvD2VqtFspdeeWV4TWnT58ezqK+sbGxcDb64YuS1NfXF8odO3YsvOaMGTPC2b/6q78KZ//s\nz/4snJ0IfvmXfzmcfeqpp0K52bNnh9f84z/+43C2XR8i2w3+4A/+IJx98MEHw9noB1pv3bo1vGaR\nD1R/4IEHwtlrr702nK2HK0YAAAAJgxEAAEDCYAQAAJAwGAEAACQMRgAAAAmDEQAAQMJgBAAAkDAY\nAQAAJAxGAAAACYMRAABAwmAEAACQMBgBAAAkDEYAAAAJgxEAAEDS0+4NlOmVV14JZ2u1Wig3derU\n8JojIyPhLOrr6YnXenBwMJydNCn294wi+508eXI4e8UVV4SzaKzI/9v/9V//Fcr9+q//enjNIudO\n1HfTTTeFs//zP/8Tzo6OjoZyP/nJT8JrTp8+PZy9/PLLw9non8FDQ0N1f44rRgAAAAmDEQAAQMJg\nBAAAkGQORmZ2npk9aWabzWyjmd3eio2hWugRiqJDKAM9QpY8r/ockfRFd3/OzGZKetbMHnf3TU3e\nG6qFHqEoOoQy0CM0lHnFyN1fcffn0o9fl7RZ0oJmbwzVQo9QFB1CGegRspzWa4zM7EJJiyWtbcpu\nMCHQIxRFh1AGeoTx5L6BipnNkPSApC+4++Fxfn6FpBUl7g0V1KhHdAh5cC5CGTgXoZ5cg5GZ9epE\nge519wfHe4y7r5a0Oj3eS9shKiOrR3QIWTgXoQyci9BInnelmaRvSNrs7n/f/C2hiugRiqJDKAM9\nQpY8rzFaKukzkq43sw3pn480eV+oHnqEougQykCP0FDmU2nu/
rQka8FeUGH0CEXRIZSBHiELd74G\nAABIGIwAAACS3G/X7wZ79+4NZ+fPnx/KzZ49O7zmwYMHw1nUd80114Szc+bMCWdPvKazdTlJGh4e\nDmenTJkSzqIx9/gbma677rpQrsjv52uvvRbOor5ly5aFsz/60Y/C2X//938P5Xp7e8NrnnvuueHs\nr/3ar4WzIyMj4Ww9XDECAABIGIwAAAASBiMAAICEwQgAACBhMAIAAEgYjAAAABIGIwAAgITBCAAA\nIGEwAgAASBiMAAAAEgYjAACAhMEIAAAgYTACAABIetq9gTIdOXIknD377LNDuSKfxs4nWjfH+eef\nH84W+f0cGxsL5Wq1WnjN0dHRcHZ4eDicRWOvvvpqODtlypRQ7tZbbw2vuXjx4nAW9U2aFL/2MHPm\nzHA2+vvZ0xMfCfbv3x/OjoyMhLPNwBUjAACAhMEIAAAgYTACAABIGIwAAACS3IORmdXM7Hkz+24z\nN4TqokMoAz1CUXQIjZzOFaPbJW1u1kYwIdAhlIEeoSg6hLpyDUZmtlDSTZLuau52UFV0CGWgRyiK\nDiFL3itGqyR9WVLsRi0AHUI5VokeoZhVokNoIHMwMrOPStrr7s9mPG6Fma03s/Wl7Q6VQIdQBnqE\nougQ8shzxWippI+Z2U5J90u63szuefOD3H21uy9x9yUl7xHdjw6hDPQIRdEhZMocjNz9Dndf6O4X\nSlou6Qfu/umm7wyVQYdQBnqEougQ8uA+RgAAAMlpfWKcuz8l6amm7AQTAh1CGegRiqJDqIcrRgAA\nAMlpXTGqsqNHj7Z7CyjJL/3SL4Wzx44dC2enTp0ayk2aFP/7yWuvvRbOHjlyJJxFY0V+T3/2s5+F\ncocPHw6vOX369HAW9bl7OFvkz6RarRbKvf766+E1Z8+eHc52Gq4YAQAAJAxGAAAACYMRAABAwmAE\nAACQMBgBAAAkDEYAAAAJgxEAAEDCYAQAAJAwGAEAACQMRgAAAAmDEQAAQMJgBAAAkDAYAQAAJAxG\nAAAASU+7N9ApDh48GModOHCg3I2gsCNHjoSzfX194ayZhXKTJsX/flIke9FFF4WzaGxoaCicXbBg\nQSg3NjYWXjPaXTRW5Li241zU0xMfCYr0r9NwxQgAACBhMAIAAEgYjAAAAJJcg5GZnWlma8xsi5lt\nNrOrm70xVA89QlF0CGWgR2gk7yutvibpUXf/hJn1SZrWxD2huugRiqJDKAM9Ql2Zg5GZzZJ0naRb\nJMndhyUNN3dbqBp6hKLoEMpAj5Alz1Npb5X0qqRvmtnzZnaXmU1v8r5QPfQIRdEhlIEeoaE8g1GP\npCsk3enuiyUdlbTyzQ8ysxVmtt7M1pe8R1RDZo/oEDJwLkIZOBehoTyDUb+kfndfm75eoxOl+gXu\nvtrdl7j7kjI3iMrI7BEdQgbORSgD5yI0lDkYuftuSbvM7JL0rQ9I2tTUXaFy6BGKokMoAz1Clrzv\nSrtN0r3p1fs7JH22eVtChdEjFEWHUAZ6hLpyDUbuvkESlxRRCD1CUXQIZaBHaIQ7XwMAACQMRgAA\nAEne1xhV3vBw7P5e06Zxw9RO87d/+7fh7Fe+8pVwdurUqaFcb29veM0tW7aEs+9+97vDWTQ2MjIS\nzo6NjYVyPT3x0/mRI0fCWTTHO9/5zpavefDgwXB28uTJ5W2kzbhiBAAAkDAYAQAAJAxGAAAACYMR\nAABAwmAEAACQMBgBAAAkDEYAAAAJgxEAAEDCYAQAAJAwGAEAACQMRgAAAAmDEQAAQMJgBAAAkDAY\nAQAAJObu5f+iZq9KernOT58t6bXSF62eTjxOF7j7nFYslNEhqTOPT6fpxGPUsg5JnItK0onHiXNR\n9+m041S3Q00ZjBoxs/XuvqSli3YhjlNjHJ9sHKPGOD75cJwa4/jk003HiafSAAAAEgYjAACApB2D\n0eo2rNmNOE6NcXyycYwa4
/jkw3FqjOOTT9ccp5a/xggAAKBT8VQaAABA0rLByMxuMLOtZrbNzFa2\nat1uY2Y7zewFM9tgZuvbvZ9OQ4/yoUeN0aN86FF9dCifbuxQS55KM7OapJckfUhSv6R1km52901N\nX7zLmNlOSUvcvZPu99AR6FF+9Kg+epQfPRofHcqvGzvUqitGV0na5u473H1Y0v2SlrVobVQHPUIZ\n6BGKokMV1qrBaIGkXSd93Z++h1O5pO+b2bNmtqLdm+kw9Cg/elQfPcqPHo2PDuXXdR3qadE6Ns73\neDvc+Ja6+4CZzZX0uJltcfcftntTHYIe5UeP6qNH+dGj8dGh/LquQ626YtQv6byTvl4oaaBFa3cV\ndx9I/94r6SGduGSLE+hRTvSoIXqUEz2qiw7l1I0datVgtE7SxWZ2kZn1SVou6eEWrd01zGy6mc38\n+Y8lfVjSi+3dVUehRznQo0z0KAd61BAdyqFbO9SSp9LcfcTMbpX0mKSapLvdfWMr1u4y8yQ9ZGbS\nid+b+9z90fZuqXPQo9zoUQP0KDd6VAcdyq0rO8SdrwEAABLufA0AAJAwGAEAACQMRgAAAAmDEQAA\nQMJgBAAAkDAYAQAAJAxGAAAACYMRAABA0pQ7X5tZW+4aOXv27HD2ggsuCOXSHT1DDhw4EM7u2LEj\nnC3C3eP/waehXR0q4p3vfGcoN2lS/O8nb7zxRjj74x//OJwtolUdktrXo/nz54ez8+bNK3En+ezf\nvz+cffnll0vcSX5VPxddeeWV4eyrr74ayo2OjobXLNLb5557Lpwtol6HmnLn63YV6Xd+53fC2bvu\nuiuUK/KH2kMPPRTOfuITnwhni6j6yainJ/53he3bt4dyM2bMCK/5ox/9KJy98cYbw9kiJsJg9Bd/\n8Rfh7Je+9KVQbmRkJLzmmjVrwtnf/d3fDWeL6IZzUZG/OI+NjYWzd955Zyh38ODB8JrR3kpSb29v\nOFtEvQ7xVBoAAECSazAysxvMbKuZbTOzlc3eFKqHDqEM9AhloEdoJHMwMrOapK9LulHSpZJuNrNL\nm70xVAcdQhnoEcpAj5AlzxWjqyRtc/cd7j4s6X5Jy5q7LVQMHUIZ6BHKQI/QUJ7BaIGkXSd93Z++\nB+RFh1AGeoQy0CM0lOctOOO9avuUV+mb2QpJKwrvCFVEh1AGeoQyZPaIDk1seQajfknnnfT1QkkD\nb36Qu6+WtFrqznvQoKnoEMpAj1CGzB7RoYktz1Np6yRdbGYXmVmfpOWSHm7utlAxdAhloEcoAz1C\nQ5lXjNx9xMxulfSYpJqku919Y9N3hsqgQygDPUIZ6BGy5LrNr7s/IumRJu8FFUaHUAZ6hDLQIzTC\nna8BAAASBiMAAICk4z5Edvr06eF1ox/iKRX7sL92eOGFF8LZD37wg+FsN3xwYxHRT6WWpMHBwVDu\nJz/5SXjNIh9ivGxZ/J52RT6NvVs+RPaJJ54Ir7to0aJwdmhoKJQrci4/duxYOFvkg0ff/e53h7Pd\ncC667LLLwuved9994Wy0C0V6UKvVwtmbb745nH3ppZfCWT5EFgAAIAODEQAAQMJgBAAAkDAYAQAA\nJAxGAAAACYMRAABAwmAEAACQMBgBAAAkDEYAAAAJgxEAAEDCYAQAAJAwGAEAACQMRgAAAAmDEQAA\nQNLT7g282be+9a1wdtKkiTPnvfe97233FjrWjBkzwtn9+/eXuJN8ivR2bGwsnP32t78dzl577bXh\nbLeYM2dOODtlypRw1sxCueHh4fCaRbJFjlPVfeQjHwlnd+/eHc7OnDkzlBscHAyvedZZZ4WzH//4\nx8PZv/mbvwln65k4kwQAAEAGBiMAAICEwQgAACDJHIzM7Dwze9LMNpvZRjO7vRUbQ7XQIxRFh1AG\neoQseV58PSLpi+7+nJnNlPSsmT3u7puavDdUCz1CUXQIZaBHaCjzipG7v+Luz6Ufvy5ps6Q
Fzd4Y\nqoUeoSg6hDLQI2Q5rdcYmdmFkhZLWtuU3WBCoEcoig6hDPQI48l9HyMzmyHpAUlfcPfD4/z8Ckkr\nStwbKqhRj+gQ8uBchDJwLkI9uQYjM+vViQLd6+4PjvcYd18taXV6vJe2Q1RGVo/oELJwLkIZOBeh\nkTzvSjNJ35C02d3/vvlbQhXRIxRFh1AGeoQseV5jtFTSZyRdb2Yb0j/x+5xjoqJHKIoOoQz0CA1l\nPpXm7k9Lin14D5DQIxRFh1AGeoQs3PkaAAAgYTACAABIcr9dv1UWLIjfZ+vEa+omhjfeeKPdW+hY\nv/3bvx3OXnjhheHsli1bQrn+/v7wmmeddVY4O3/+/HB2Ipg6dWo4+8QTT4Sz1157bSjnHn/zVK1W\nC2dnzJgRzlbd5MmTw9ndu3eHs6+//noo19MTHwlGRkbC2be97W3hbDNwxQgAACBhMAIAAEgYjAAA\nABIGIwAAgITBCAAAIGEwAgAASBiMAAAAEgYjAACAhMEIAAAgYTACAABIGIwAAAASBiMAAICEwQgA\nACCJf5Ruk/T29oazY2Nj4Wz0U4WLfKJ1kf1OmsRMW89v/dZvhbOf+9znwtnbbrstlCvye3n06NFw\nlk9Fb8zMwtn/+7//C2eXLl0ayhXZ7/z588PZwcHBcLbqihzXdvx5VsQrr7wSzs6dO7fEnRTHn64A\nAAAJgxEAAEDCYAQAAJDkHozMrGZmz5vZd5u5IVQXHUIZ6BGKokNo5HSuGN0uaXOzNoIJgQ6hDPQI\nRdEh1JVrMDKzhZJuknRXc7eDqqJDKAM9QlF0CFnyXjFaJenLkuLvH8REt0p0CMWtEj1CMatEh9BA\n5mBkZh+VtNfdn8143AozW29m60vbHSqBDqEM9AhF0SHkkeeK0VJJHzOznZLul3S9md3z5ge5+2p3\nX+LuS0reI7ofHUIZ6BGKokPIlDkYufsd7r7Q3S+UtFzSD9z9003fGSqDDqEM9AhF0SHkwX2MAAAA\nktP6QBV3f0rSU03ZCSYEOoQy0CMURYdQD1eMAAAAEgYjAACA5LSeSmuFnp74lnp7e8NZdw9no2q1\nWjhrZiXupFrOOOOMcPaiiy4KZ6dNmxbKzZo1K7zm5ZdfHs5G9ztRHD9+PJxdsGBBODtpUuzvq0XO\nYdE1JWn69OnhbNW94x3vCGePHDkSzo6OjrY0JxX7M6nIObAZuGIEAACQMBgBAAAkDEYAAAAJgxEA\nAEDCYAQAAJAwGAEAACQMRgAAAAmDEQAAQMJgBAAAkDAYAQAAJAxGAAAACYMRAABAwmAEAACQMBgB\nAAAkPe3ewJuNjo6Gs//4j/8Yzn7+858P5Yrst6+vL5ydMWNGOFt1R44cCWff8573hLOTJ08O5WbN\nmhVe84033ghnzznnnHB2Iijy/+fu3bvD2Z6e2Gm5t7c3vObY2Fg4W+Q4Vd3cuXPD2ZGRkXB2cHAw\nlBsaGgqveezYsXC2VquFs83AFSMAAICEwQgAACBhMAIAAEhyDUZmdqaZrTGzLWa22cyubvbGUD30\nCEXRIZSBHqGRvK/y+5qkR939E2bWJ2laE/eE6qJHKIoOoQz0CHVlDkZmNkvSdZJukSR3H5Y03Nxt\noWroEYqiQygDPUKWPE+lvVXSq5K+aWbPm9ldZja9yftC9dAjFEWHUAZ6hIbyDEY9kq6QdKe7L5Z0\nVNLKNz/IzFaY2XozW1/yHlENmT2iQ8jAuQhl4FyEhvIMRv2S+t19bfp6jU6U6he4+2p3X+LuS8rc\nICojs0d0CBk4F6EMnIvQUOZg5O67Je0ys0vStz4gaVNTd4XKoUcoig6hDPQIWfK+K+02SfemV+/v\nkPTZ5m0JFUaPUBQdQhnoEerKNRi5+wZJXFJEIfQIRdEhlIEeoRHufA0AAJAwGAEAACTm7uX/ombh\nX/SRRx4Jr3vLLbeEs2vXrs1+0DjOOOOM8JobNmwIZ9/
1rneFs7Nnzw5n3d3C4dNQpEPR30tJ+pVf\n+ZVwNqqnJ+9L/U41Ojoazu7bty+cfctb3hLOtqpDUrEe7d27N7zuvHnzwlmz2OEZGxtr+ZqS9L//\n+7/h7KJFi8LZbjgXTZ06NbzuSy+9FM5u3rw5lNuzZ094zZkzZ4azv/EbvxHOFlGvQ1wxAgAASBiM\nAAAAEgYjAACAhMEIAAAgYTACAABIGIwAAAASBiMAAICEwQgAACBhMAIAAEgYjAAAABIGIwAAgITB\nCAAAIGEwAgAASBiMAAAAEnP38n9Rs1clvVznp8+W9Frpi1ZPJx6nC9x9TisWyuiQ1JnHp9N04jFq\nWYckzkUl6cTjxLmo+3TacarboaYMRo2Y2Xp3X9LSRbsQx6kxjk82jlFjHJ98OE6NcXzy6abjxFNp\nAAAACYMRAABA0o7BaHUb1uxGHKfGOD7ZOEaNcXzy4Tg1xvHJp2uOU8tfYwQAANCpeCoNAAAgadlg\nZGY3mNlWM9tmZitbtW63MbOdZvaCmW0ws/Xt3k+noUf50KPG6FE+9Kg+OpRPN3aoJU+lmVlN0kuS\nPiSpX9I6STe7+6amL95lzGynpCXu3kn3e+gI9Cg/elQfPcqPHo2PDuXXjR1q1RWjqyRtc/cd7j4s\n6X5Jy1q0NqqDHqEM9AhF0aEKa9VgtEDSrpO+7k/fw6lc0vfN7FkzW9HuzXQYepQfPaqPHuVHj8ZH\nh/Lrug71tGgdG+d7vB1ufEvdfcDM5kp63My2uPsP272pDkGP8qNH9dGj/OjR+OhQfl3XoVZdMeqX\ndN5JXy+UNNCitbuKuw+kf++V9JBOXLLFCfQoJ3rUED3KiR7VRYdy6sYOtWowWifpYjO7yMz6JC2X\n9HCL1u4aZjbdzGb+/MeSPizpxfbuqqPQoxzoUSZ6lAM9aogO5dCtHWrJU2nuPmJmt0p6TFJN0t3u\nvrEVa3eZeZIeMjPpxO/Nfe7+aHu31DnoUW70qAF6lBs9qoMO5daVHeLO1wAAAAl3vgYAAEgYjAAA\nABIGIwAAgITBCAAAIGEwAgAASBiMAAAAEgYjAACAhMEIAAAgacqdr82Mu0ZWlLuP9+GJpWtXh6ZO\nnRrOvuMd7wjl0l1hQ37605+Gs6+99lo4W0SrOiQV69Hb3/728LpTpkwJZzdt2hTKjY2NhddcvHhx\nODs8PBzObt26NZQbGRnR6Ohopc9Fl112WTgbPacU6dDg4GA4u2PHjnC2iHrnoqbc+ZrBqLqqPhhd\nfvnl4ex//ud/hnI9PfG/n3z+858PZ//5n/85nC2iWwajJ598MrxudEiWpEWLFoVyx44dC6954MCB\ncHbXrl3h7Pvf//5QbmBgQMeOHav0uWj79u3h7KRJsSeDigw3L7zwQjj7yU9+Mpwtot65iKfSAAAA\nklyDkZndYGZbzWybma1s9qZQPXQIZaBHKAM9QiOZg5GZ1SR9XdKNki6VdLOZXdrsjaE66BDKQI9Q\nBnqELHmuGF0laZu773D3YUn3S1rW3G2hYugQykCPUAZ6hIbyDEYLJJ386rr+9D0gLzqEMtAjlIEe\noaE8b4cZ71Xbp7xK38xWSFpReEeoIjqEMtAjlCGzR3RoYsszGPVLOu+krxdKGnjzg9x9taTVEm/X\nxynoEMpAj1CGzB7RoYktz1Np6yRdbGYXmVmfpOWSHm7utlAxdAhloEcoAz1CQ5lXjNx9xMxulfSY\npJqku919Y9N3hsqgQygDPUIZ6BGy5Lrlrrs/IumRJu8FFUaHUAZ6hDLQIzTCna8BAAASBiMAAICE\nD5HFaemGD5GdO3dueN2vfvWr4Wz0E62L/D/Y29sbzh4+fDic/f3f//1wtls+RHbevHnhdaMfKCxJ\nU6ZMCeX27dsXXrPI/zP/+q//Gs7ecccd4Ww3nIuKeOmll8LZ6KfVFzkXnXnmmeHs1VdfHc4WwYfI\nAgAAZGAwAgAASBi
MAAAAEgYjAACAhMEIAAAgYTACAABIGIwAAAASBiMAAICEwQgAACBhMAIAAEgY\njAAAABIGIwAAgITBCAAAIGEwAgAASHravYFOsWXLllDunnvuCa/513/91+Hs+eefH87+9Kc/DWe7\nwac+9alw9q1vfWs4+8orr4Ryx48fD685ZcqUcHbevHnh7ESwd+/ecLavry+cPe+880K5rVu3htfc\nvXt3OLt9+/ZwFvUNDQ2Fs/v27QvlLrnkkvCahw4dCmc7DVeMAAAAEgYjAACAhMEIAAAgyRyMzOw8\nM3vSzDab2UYzu70VG0O10CMURYdQBnqELHlefD0i6Yvu/pyZzZT0rJk97u6bmrw3VAs9QlF0CGWg\nR2go84qRu7/i7s+lH78uabOkBc3eGKqFHqEoOoQy0CNkOa3XGJnZhZIWS1rblN1gQqBHKIoOoQz0\nCOPJfR8jM5sh6QFJX3D3w+P8/ApJK0rcGyqoUY/oEPLgXIQycC5CPbkGIzPr1YkC3evuD473GHdf\nLWl1eryXtkNURlaP6BCycC5CGTgXoZE870ozSd+QtNnd/775W0IV0SMURYdQBnqELHleY7RU0mck\nXW9mG9I/H2nyvlA99AhF0SGUgR6hocyn0tz9aUnWgr2gwugRiqJDKAM9QhbufA0AAJAwGAEAACS5\n365/uk68vu30TZoUn9VGRkbC2ajf+73fC2dvvPHGcLbIf+vnPve5UG7btm3hNVtpwYL4vdre8573\nhLPf/va3Q7ne3t7wmgcPHgxn58+fH85OBO7xNyMV+f/ziSeeCOXmzZsXXnPnzp3h7JVXXhnO3nXX\nXeFs1fX0xP94Pvfcc0O5K664IrzmD3/4w3C203DFCAAAIGEwAgAASBiMAAAAEgYjAACAhMEIAAAg\nYTACAABIGIwAAAASBiMAAICEwQgAACBhMAIAAEgYjAAAABIGIwAAgITBCAAAIIl/fG8D559/vr7y\nla+EskU+afzpp58OZwcHB0O56667LrxmrVYLZxcuXBjOfvKTnwzl7rzzzvCarfT2t789nN2+fXs4\nOzY2Fs5GHT16NJwt8mnsaGzatGnhbPRT1Y8cORJec3h4OJw9//zzw1nUt2/fvnA2+ueomYXXPHTo\nUDjbabhiBAAAkDAYAQAAJAxGAAAASe7ByMxqZva8mX23mRtCddEhlIEeoSg6hEZO54rR7ZI2N2sj\nmBDoEMpAj1AUHUJduQYjM1so6SZJdzV3O6gqOoQy0CMURYeQJe8Vo1WSviyp9e9HRlWsEh1CcatE\nj1DMKtEhNJA5GJnZRyXtdfdnMx63wszWm9n6IvfTQPVEOtSiraGL0CMURYeQR54rRkslfczMdkq6\nX9L1ZnbPmx/k7qvdfYm7L5kxY0bJ20SXO+0OtXqD6Ar0CEXRIWTKHIzc/Q53X+juF0paLukH7v7p\npu8MlUGHUAZ6hKLoEPLgPkYAAADJaX0oj7s/JemppuwEEwIdQhnoEYqiQ6iHK0YAAAAJgxEAAEBy\nWk+l5TU0NKSNGzeGsgMDA+F1/+iP/iic7evrC+W+9KUvhddcu3ZtOPvMM8+Esy+++GIoNzg4GF6z\nlcwsnP2nf/qncPbaa68N5Wq1WnjNhQsXhrPHjx8PZyeCm266KZzdsmVLODtlypRQ7tChQ+E1R0dH\nw9menqb8MTLhRc/TkjR//vxQbtOmTeE19+7dG852Gq4YAQAAJAxGAAAACYMRAABAwmAEAACQMBgB\nAAAkDEYAAAAJgxEAAEDCYAQAAJAwGAEAACQMRgAAAAmDEQAAQMJgBAAAkDAYAQAAJAxGAAAASU8z\nftG9e/fqH/7hH5rxSze0Zs2alq/ZLr29ve3eQsc6ePBgOLtnz55wdtKk2N8zxsbGwmvOnTs3nJ0z\nZ044OxH85m/+Zjh79OjRcHb69OnhbNSsWbPCWXcvcSf4ue9973vh7I033hjK7du3L
7zmY489Fs52\nGq4YAQAAJAxGAAAACYMRAABAkmswMrMzzWyNmW0xs81mdnWzN4bqoUcoig6hDPQIjeR98fXXJD3q\n7p8wsz5J05q4J1QXPUJRdAhloEeoK3MwMrNZkq6TdIskufuwpOHmbgtVQ49QFB1CGegRsuR5Ku2t\nkl6V9E0ze97M7jKz1r+fFN2OHqEoOoQy0CM0lGcw6pF0haQ73X2xpKOSVr75QWa2wszWm9n6kveI\nasjsER1CBs5FKAPnIjSUZzDql9Tv7mvT12t0olS/wN1Xu/sSd19S5gZRGZk9okPIwLkIZeBchIYy\nByN33y1pl5ldkr71AUmbmrorVA49QlF0CGWgR8iS911pt0m6N716f4ekzzZvS6gweoSi6BDKQI9Q\nV67ByN03SOKSIgqhRyiKDqEM9AiNcOdrAACAhMEIAAAgyfsaI6BrTJoUn/fXrFkTzn7ve98LZ6N6\neuL/C2/btq3EnVTP2NhYOPv+978/nN25c2coN3v27PCahw4dCmdnzpwZzkb7OzIyEl6zW+zZsyec\nPXjwYChXpEO7du0KZzsNV4wAAAASBiMAAICEwQgAACBhMAIAAEgYjAAAABIGIwAAgITBCAAAIGEw\nAgAASBiMAAAAEgYjAACAhMEIAAAgYTACAABIGIwAAAASBiMAAIDE3L38X9TsVUkv1/npsyW9Vvqi\n1dOJx+kCd5/TioUyOiR15vHpNJ14jFrWIYlzUUk68ThxLuo+nXac6naoKYNRI2a23t2XtHTRLsRx\naozjk41j1BjHJx+OU2Mcn3y66TjxVBoAAEDCYAQAAJC0YzBa3YY1uxHHqTGOTzaOUWMcn3w4To1x\nfPLpmuPU8tcYAQAAdCqeSgMAAEhaNhiZ2Q1mttXMtpnZylat223MbKeZvWBmG8xsfbv302noUT70\nqDF6lA89qo8O5dONHWrJU2lmVpP0kqQPSeqXtE7Sze6+qemLdxkz2ylpibt30v0eOgI9yo8e1UeP\n8qNH46ND+XVjh1p1xegqSdvcfYe7D0u6X9KyFq2N6qBHKAM9QlF0qMJaNRgtkLTrpK/70/dwKpf0\nfTN71sxWtHszHYYe5UeP6qNH+dGj8dGh/LquQz0tWsfG+R5vhxvfUncfMLO5kh43sy3u/sN2b6pD\n0KP86FF99Cg/ejQ+OpRf13WoVVeM+iWdd9LXCyUNtGjtruLuA+nfeyU9pBOXbHECPcqJHjVEj3Ki\nR3XRoZy6sUOtGozWSbrYzC4ysz5JyyU93KK1u4aZTTezmT//saQPS3qxvbvqKPQoB3qUiR7lQI8a\nokM5dGuHWvJUmruPmNmtkh6TVJN0t7tvbMXaXWaepIfMTDrxe3Ofuz/a3i11DnqUGz1qgB7lRo/q\noEO5dWWHuPM1AABAwp2vAQAAEgYjAACAhMEIAAAgYTACAABIGIwAAAASBiMAAICEwQgAACBhMAIA\nAEiacudrM2vLXSPPPvvscHbevHmh3MjISHjNgYH4R+u8/vrr4WwR7j7ehyeWrl0dQvO1qkNS+3p0\n5plnhrPnnHNOKDc8PBxe82c/+1k4e+zYsXC2CM5FKKpeh1rykSCnI906PGTZsmXh7J/8yZ+Ecvv3\n7w+v+ed//ufh7JNPPhnOAmiuX/3VXw1n//RP/zSUe/nll8NrFjkXbd26NZwFOhFPpQEAACS5BiMz\nu8HMtprZNjNb2exNoXroEMpAj1AGeoRGMgcjM6tJ+rqkGyVdKulmM7u02RtDddAhlIEeoQz0CFny\nXDG6StI2d9/h7sOS7pcUfzEPJiI6hDLQI5SBHqGhPIPRAkm7Tvq6P30PyIsOoQz0CGWgR2goz7vS\nxnub2ClvXzSzFZJWFN4RqogOoQz0CGXI7BEdmtjyDEb9ks476euFkk65AY+7r5a0WuK+DzgFHUIZ\n6BHKkNkjOjSx5XkqbZ2ki83sIjPrk7Rc0sPN3
RYqhg6hDPQIZaBHaCjzipG7j5jZrZIek1STdLe7\nb2z6zlAZdAhloEcoAz1Cllx3vnb3RyQ90uS9oMLoEMpAj1AGeoRGuPM1AABAwmAEAACQdNyHyP7d\n3/1dOHvllVeGs+vWrQvlarVaeM077rgjnC3ySdrPPPNMONsqtVpNZ5xxRij74IMPhtcdGRkJZ9ux\n5htvvBHOzpo1K5y94YYbQrl2HN92KPKhrNFzysyZM8Nrvu997wtn+RBZVA1XjAAAABIGIwAAgITB\nCAAAIGEwAgAASBiMAAAAEgYjAACAhMEIAAAgYTACAABIGIwAAAASBiMAAICEwQgAACBhMAIAAEgY\njAAAABIGIwAAgKSn3Rt4s1tuuSWcPXz4cMuzfX194TUvuOCCcPaBBx4IZ88555xwtlUWLFiglStX\nhrL79+8Przs4OBjOmllLc5I0MjISzg4NDYWzixYtCuW2bNkSXrOb9PTET60bNmwI5S699NLwmrNn\nzw5nUR2rV68OZ1esWBHOLlu2LJyN/tm9fv36uj/HFSMAAICEwQgAACBhMAIAAEgyByMzO8/MnjSz\nzWa20cxub8XGUC30CEXRIZSBHiFLnlcIjkj6ors/Z2YzJT1rZo+7+6Ym7w3VQo9QFB1CGegRGsq8\nYuTur7j7c+nHr0vaLGlBszeGaqFHKIoOoQz0CFlO6zVGZnahpMWS1jZlN5gQ6BGKokMoAz3CeHLf\nbMPMZkh6QNIX3P2UGweY2QpJ8RsZYEJo1KOTO/SWt7ylDbtDN+BchDLkPRdh4sl1xcjMenWiQPe6\n+4PjPcbdV7v7EndfUuYGUR1ZPTq5QzNmzGj9BtHxOBehDKdzLmr97tBued6VZpK+IWmzu/9987eE\nKqJHKIoOoQz0CFnyXDFaKukzkq43sw3pn480eV+oHnqEougQykCP0FDma4zc/WlJ8Q9zAkSPUBwd\nQhnoEbJw52sAAICEwQgAACDJ/Xb9Vtm/f39b1p0/f34ot2vXrvCas2bNCmf7+/vD2W4wa9YsffCD\nHwxl//u//zu87ujoaDg7aVLs7xm9vb3hNYso8t96zTXXhHJV7+3Pfec73wlnr7/++lDuXe96V3jN\nb33rW+Es6jvxOu/Wc/dQbv369eE1t23bFs5ecskl4WyR81g9XDECAABIGIwAAAASBiMAAICEwQgA\nACBhMAIAAEgYjAAAABIGIwAAgITBCAAAIGEwAgAASBiMAAAAEgYjAACAhMEIAAAgYTACAABIetq9\ngTebPHlyODs2NhbORj8ZvcinJw8NDYWzRY5TNxgdHdXhw4dD2eHh4fC60U+llqRardbSnCT19fW1\nJbt169ZQrkjnu8l9990Xzn7qU58K5Yr06Jlnnglnq87M1NvbG8o+8MADJe8mnz/8wz8M5Yr0tsi5\n86tf/Wo4+773vS+Uu/nmm+v+HFeMAAAAEgYjAACAhMEIAAAgyT0YmVnNzJ43s+82c0OoLjqEMtAj\nFEWH0MjpXDG6XdLmZm0EEwIdQhnoEYqiQ6gr12BkZgsl3STpruZuB1VFh1AGeoSi6BCy5L1itErS\nlyXF3w+PiW6V6BCKWyV6hGJWiQ6hgczByMw+Kmmvuz+b8bgVZrbezNaXtjtUQqRDBw4caNHu0C04\nF6GoSIeK3J8H3SnPFaOlkj5mZjsl3S/pejO7580PcvfV7r7E3ZeUvEd0v9Pu0FlnndXqPaLzcS5C\nUafdoSI38UV3yhyM3P0Od1/o7hdKWi7pB+7+6abvDJVBh1AGeoSi6BDy4D5GAAAAyWl9Vpq7PyXp\nqabsBBMCHUIZ6BGKokOohytGAAAACYMRAABAclpPpbXC8PBwOHv06NFwtlarhXJjY/FbYezduzec\nnTp1ajjbDYaHh7Vr165Qdt68eeF1d+7cGc7OmTMnlFu+fHl4zW3btoWzvb294ezg4GAoV+T/l25y\n8ODBcHZgY
CCUu+eeU95cldvo6Gg4W3UzZszQkiWxNzju2bMnvO53vvOdcPbqq68O5d72treF1yxy\nPrntttvC2RdffDGcrYcrRgAAAAmDEQAAQMJgBAAAkDAYAQAAJAxGAAAACYMRAABAwmAEAACQMBgB\nAAAkDEYAAAAJgxEAAEDCYAQAAJAwGAEAACQMRgAAAAmDEQAAQGLuXv4vahb+Rffs2RNet7+/P5w9\nduxYKHf48OHwmueee244O2fOnHB2/vz54ay7Wzh8GqZMmeILFy4MZf/t3/4tvO7x48fD2cmTJ4dy\nBw4cCK9ZxDXXXBPOLliwIJQ7cuSIRkdHW9Ihqdi5qIhoFyTp+eefD+W+8IUvhNccGBgIZ1988cVw\ntohWnYva1SE0X70OccUIAAAgYTACAABIGIwAAACSXIORmZ1pZmvMbIuZbTazq5u9MVQPPUJRdAhl\noEdopCfn474m6VF3/4SZ9Uma1sQ9obroEYqiQygDPUJdmYORmc2SdJ2kWyTJ3YclDTd3W6gaeoSi\n6BDKQI+QJc9TaW+V9Kqkb5rZ82Z2l5lNb/K+UD30CEXRIZSBHqGhPINRj6QrJN3p7oslHZW08s0P\nMrMVZrbezNaXvEdUQ2aPTu7Q6OhoO/aIzsa5CGU4rXNROzaI9sozGPVL6nf3tenrNTpRql/g7qvd\nfYm7Lylzg6iMzB6d3KFardbyDaLjcS5CGU7rXNTy3aHtMgcjd98taZeZXZK+9QFJm5q6K1QOPUJR\ndAhloEfIkvddabdJuje9en+HpM82b0uoMHqEougQykCPUFeuwcjdN0jikiIKoUcoig6hDPQIjXDn\nawAAgITBCAAAIMn7GqOWmTt3bjj78Y9/PJwdGhoK5Xp64oewyDuvnn766XC2Gxw7dkzbt28PZd/7\n3veWvBvg9Bw7diycXbt2bfaDxrFnz57wmvv27QtngarhihEAAEDCYAQAAJAwGAEAACQMRgAAAAmD\nEQAAQMJgBAAAkDAYAQAAJAxGAAAACYMRAABAwmAEAACQMBgBAAAkDEYAAAAJgxEAAEDCYAQAAJCY\nu5f/i5q9KunlOj99tqTXSl+0ejrxOF3g7nNasVBGh6TOPD6dphOPUcs6JHEuKkknHifORd2n045T\n3Q41ZTBqxMzWu/uSli7ahThOjXF8snGMGuP45MNxaozjk083HSeeSgMAAEgYjAAAAJJ2DEar27Bm\nN+I4NcbxycYxaozjkw/HqTGOTz5dc5xa/hojAACATsVTaQAAAEnLBiMzu8HMtprZNjNb2ap1u42Z\n7TSzF8xsg5mtb/d+Og09yoceNUaP8qFH9dGhfLqxQy15Ks3MapJekvQhSf2S1km62d03NX3xLmNm\nOyUtcfdOut9DR6BH+dGj+uhRfvRofHQov27sUKuuGF0laZu773D3YUn3S1rWorVRHfQIZaBHKIoO\nVVirBqMFknad9HV/+h5O5ZK+b2bPmtmKdm+mw9Cj/OhRffQoP3o0PjqUX9d1qKdF69g43+PtcONb\n6u4DZjZX0uNmtsXdf9juTXUIepQfPaqPHuVHj8ZHh/Lrug616opRv6TzTvp6oaSBFq3dVdx9IP17\nr6SHdOKSLU6gRznRo4boUU70qC46lFM3dqhVg9E6SReb2UVm1idpuaSHW7R21zCz6WY28+c/lvRh\nSS+2d1cdhR7lQI8y0aMc6FFDdCiHbu1QS55Kc/cRM7tV0mOSapLudveNrVi7y8yT9JCZSSd+b+5z\n90fbu6XOQY9yo0cN0KPc6FEddCi3ruwQd74GAABIuPM1AABAwmAEAACQMBgBAAAkDEYAAAAJgxEA\nAEDCYAQAAJAwGAEAACQMRgAAAMn/AxwYF5h7sI3UAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import tensorflow as tf\n", "\n", "ds = tfds.load(\"fashion_mnist\", split=\"train\")\n", "\n", "\n", "def resize_and_scale(batch):\n", " batch[\"image\"] = tf.image.resize(batch[\"image\"], (8, 8)) / 255.\n", " return batch\n", "\n", "\n", "ds = ds.map(resize_and_scale).cache().repeat(-1).shuffle(\n", " 64 * 10).batch(128).prefetch(5)\n", "data_iterator = ds.as_numpy_iterator()\n", "batch = next(data_iterator)\n", "fig, axs = plt.subplots(4, 4, figsize=(10, 10))\n", "for ai, a in enumerate(axs.ravel()):\n", " a.imshow(batch[\"image\"][ai][:, :, 0], cmap=\"gray\")\n", "\n", "input_size = onp.prod(batch[\"image\"].shape[1:])" ] }, { "cell_type": "markdown", "metadata": { "id": "ruw2lKT1I0Ez" }, "source": [ "### Inner problem loss function & initialization\n", "\n", "Next, we must define the inner problem with which we seek to train.\n", "One important note here is no parameters are stored in the task itself! See [this](https://jax.readthedocs.io/en/latest/jax-101/07-state.html) jax tutorial for more information on this.\n", "\n", "Our task will have 2 methods -- an init which constructs the initial values of the weights, and a loss which applies the MLP, and returns the average cross entropy loss." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "executionInfo": { "elapsed": 1547, "status": "ok", "timestamp": 1647716624386, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "SNUAeWO65TzL", "outputId": "69ff87c8-4b39-4cd3-c16c-b42edf249113" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray(2.3008037, dtype=float32)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class MLPTask:\n", "\n", "  def init(self, key):\n", "    # Small random weights and fixed biases for a one-hidden-layer MLP.\n", "    key1, key2 = jax.random.split(key)\n", "    w0 = jax.random.normal(key1, [input_size, 128]) * 0.02\n", "    w1 = jax.random.normal(key2, [128, 10]) * 0.02\n", "    b0 = jnp.zeros([128])\n", "    b1 = jnp.ones([10])\n", "    return (w0, b0, w1, b1)\n", "\n", "  def loss(self, params, batch):\n", "    # Flatten each image into a vector.\n", "    data = batch[\"image\"]\n", "    data = jnp.reshape(data, [data.shape[0], -1])\n", "    w0, b0, w1, b1 = params\n", "    logits = jax.nn.relu(data @ w0 + b0) @ w1 + b1\n", "    # Mean softmax cross-entropy against one-hot labels.\n", "    labels = jax.nn.one_hot(batch[\"label\"], 10)\n", "    vec_loss = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)\n", "    return jnp.mean(vec_loss)\n", "\n", "\n", "task = MLPTask()\n", "key = jax.random.PRNGKey(0)\n", "params = task.init(key)\n", "task.loss(params, batch)" ] }, { "cell_type": "markdown", "metadata": { "id": "YL8XVt6JI3BI" }, "source": [ "### Inner training with SGD\n", "\n", "With our newly defined model, let's train it with SGD." 
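, "\n",
"\n",
"Before running the training loop, it may help to see the SGD rule from the background section written as code. That rule is $x' = x + U_{sgd}(x, \\nabla l; \\alpha)$, with $U_{sgd}(x, \\nabla l; \\alpha) = -\\alpha \\nabla l$. The following is a minimal sketch under our own assumptions. `sgd_update`, `sgd_step`, and the toy `quadratic_loss` are hypothetical helpers for illustration, not necessarily how the training cell that follows is written. Because `sgd_update` maps over the parameter pytree, the same step should also work for the MLP's `(w0, b0, w1, b1)` tuple, for example `sgd_step(task.loss, params, batch, 0.1)`.\n",
"\n",
"```python\n",
"import functools\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"\n",
"def sgd_update(params, grads, lr):\n",
"  # x' = x + U_sgd(x, grad; lr), where U_sgd = -lr * grad, applied per leaf.\n",
"  return jax.tree_util.tree_map(lambda x, g: x - lr * g, params, grads)\n",
"\n",
"\n",
"@functools.partial(jax.jit, static_argnums=(0,))\n",
"def sgd_step(loss_fn, params, batch, lr):\n",
"  # One inner-training step: loss and gradient, then the SGD update.\n",
"  loss, grads = jax.value_and_grad(loss_fn)(params, batch)\n",
"  return sgd_update(params, grads, lr), loss\n",
"\n",
"\n",
"def quadratic_loss(params, batch):\n",
"  # Hypothetical toy loss: squared distance to a target vector.\n",
"  return jnp.sum((params[\"x\"] - batch[\"target\"]) ** 2)\n",
"\n",
"\n",
"toy_params = {\"x\": jnp.zeros([3])}\n",
"toy_batch = {\"target\": jnp.array([1.0, 2.0, 3.0])}\n",
"for _ in range(100):\n",
"  toy_params, toy_loss = sgd_step(quadratic_loss, toy_params, toy_batch, 0.1)\n",
"print(toy_params[\"x\"])  # close to [1. 2. 3.]\n",
"```\n",
"\n",
"Replacing `sgd_update` with a function of `(params, grads)` that has its own trainable weights is exactly the step from $U_{sgd}$ to the learned update $U(x, \\nabla l; \\theta)$ described earlier."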
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "height": 282 }, "executionInfo": { "elapsed": 2351, "status": "ok", "timestamp": 1647716626848, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "wTMYdlvO5RVC", "outputId": "9c39fa33-cf66-4a6b-84de-81478deeb2c3" }, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90\nbGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsT\nAAALEwEAmpwYAAA4VElEQVR4nO3deXxU5bnA8d+TnYSwh30JIIIgiBhRQFY3wIXaelu0arUqpbZX\n7SrWurRWq/W2va1LLa7VWu3iUm8BxQVFBMGAgKwCAQHZAggkQJbJvPePOWdyZubMmgmTTJ7v55NP\nZs4y854sz3nnXZ5XjDEopZRKXxmpLoBSSqnGpYFeKaXSnAZ6pZRKcxrolVIqzWmgV0qpNJeV6gK4\n6dSpkykuLk51MZRSqtlYvnz5fmNMkdu+Jhnoi4uLKS0tTXUxlFKq2RCRz8Pt06YbpZRKcxrolVIq\nzWmgV0qpNKeBXiml0pwGeqWUSnMa6JVSKs1poFdKqTSXVoH+4Xc28ebaPWjqZaWUqtckJ0wl4suj\nNby4bDu7Dldx4ZAuTB3ajQsGd6VVTmaqi6aUUimVNjX69gU5vPGDcdw0oT8LP9vPLS+t5Kt/Wkx5\nRXWqi6aUUiklTbGZo6SkxDQkBUJVbR33z13Pc0t8M4I7F+Zyx0WnMG14j2QVUSmlmhQRWW6MKXHb\nF7VGLyK9RGSBiKwXkbUicovLMd8UkdXW12IROc2xb5uIfCoiK0XkhCSwycvO5JfTTuVP3xwBwL6K\nam55aSXPfxQ2FYRSSqWtWJpuPMCPjDGnAGcD3xORwUHHbAXGG2OGAfcCs4P2TzTGDA93t2ksU4Z2\nY+uvp/KYFfDvfG0N723cdyKLoJRSKRc10BtjdhtjVliPK4D1QI+gYxYbY760nn4E9Ex2QRMlIkwd\n2o2XvzsKgB/8fSX7KqpSXCqllDpx4uqMFZFi4HRgaYTDrgfmOZ4bYL6ILBeRGRFee4aIlIpIaXl5\neTzFiskZfTpw77QhHKnycMnDi/ji0HEdhqmUahFiDvQi0hp4GbjVGHMkzDET8QX62xybxxhjRgBT\n8DX7jHM71xgz2xhTYowpKSpyzZ3fYFePKua1m8ZwoLKGMQ+8y1+1zV4p1QLEFOhFJBtfkH/BGPNK\nmGOGAU8C04wxB+ztxphd1vd9wKvAyIYWuiGG9mzLI1f62uzvfn0tdV6t1Sul0lsso24EeApYb4z5\nXZhjegOvAFcbYz5zbC8QkUL7MXABsCYZBW+Iyad25ZZzB+A18Pj7W1JdHKWUalSx1OjHAFcDk6wh\nkitFZKqIzBSRmdYxdwEdgceChlF2ARaJyCpgGTDHGPNGsi8iETPG9QPgmQ+3sveIds4qpdJXWk6Y\nitWnOw9zySOL6NQ6lyW3TyI7M20mCiulWpgGTZhKZ0N7tqV/UQH7K6t5Z/3eVBdHKaUaRYsO9ABP\nfutMAGb+dQW7Dh1PcWmUUir5WnygL+6Y73+8frfrq
FGllGrWWnygFxFKf34eAOt2aaBXSqWfFh/o\nATq1zmVkcQdeWLqdqtq6VBdHKaWSSgO95cZx/dhzpIpBd77Btv1HU10cpZRKGg30lvMHd/E//mTH\nlxGOVEqp5kUDvcPCn0wEYNchnUCllEofGugdenVoBcBDb27UzJZKqbShgd7Bl9bH53/mb0xhSZRS\nKnk00AdZbg21fHTBFrya2VIplQY00Afp2DqXS0/rDsBXHvswxaVRSqmG00Dv4rIRvpUSV+88nOKS\nKKVUw2mgd9GxIMf/+JPtOtRSKdW8aaB30SYv2/94nea/UUo1cxroXfTqkM9lp/uab8orqlNcGqWU\nahgN9C4yM4Tff2M4AP/79iYdU6+UatY00EdQ0qc9ACu2H0ptQZRSqgE00Edw32VDAXjygzJdV1Yp\n1WxFDfQi0ktEFojIehFZKyK3uBwjIvJHEdksIqtFZIRj32QR2Wjtm5XsC2hMHazRN/PW7OHaZz5O\ncWmUUioxsdToPcCPjDGnAGcD3xORwUHHTAEGWF8zgD8BiEgm8Ki1fzBwhcu5TVa7/PrRN2XllSks\niVJKJS5qoDfG7DbGrLAeVwDrgR5Bh00DnjM+HwHtRKQbMBLYbIwpM8bUAC9ZxzYL2Zn1P55qjzeF\nJVFKqcTF1UYvIsXA6cDSoF09gB2O5zutbeG2u732DBEpFZHS8vLyeIrVqP4+42z/40Wb9qewJEop\nlZiYA72ItAZeBm41xgTPIhKXU0yE7aEbjZltjCkxxpQUFRXFWqxGd1a/jtxy7gAArnoq+P6mlFJN\nX0yBXkSy8QX5F4wxr7gcshPo5XjeE9gVYXuzMnZAp1QXQSmlEhbLqBsBngLWG2N+F+aw14FrrNE3\nZwOHjTG7gY+BASLSV0RygOnWsc1KSXEHijvm0yo7UydPKaWanawYjhkDXA18KiIrrW0/A3oDGGMe\nB+YCU4HNwDHgOmufR0S+D7wJZAJPG2PWJvMCTpSMDOF4bR1zPt3NxcO6p7o4SikVs6iB3hizCPe2\nducxBvhemH1z8d0ImrWy8qMArNt1RAO9UqpZ0ZmxMZoxrh8QOLZeKaWaAw30MZo1eRAisGFPhbbT\nK6WaFQ30McrIEIyBV1Z8wVOLtqa6OEopFTMN9HHo16kAgA8368QppVTzoYE+Dq//9zkM7dGWBRvL\nKd12MNXFUUqpmGigj0Pr3CyG9mwLwOWPL0lxaZRSKjYa6ONUmFc/InXAHXPZtLcihaVRSqnoNNDH\nyVNXP+Kmts7wz+U7U1gapZSKTgN9nPYcDlxpyhn4lVKqKdJAH6dbzxsQ8NyrY+qVUk2cBvo4DehS\nyCs3jfY/93h1QRKlVNOmgT4Bp3Zv639c59UavVKqadNAn4CcrPof28GjNSksiVJKRaeBvoG2WFkt\nlVKqqdJAn6CXrLVkDx3TGr1SqmnTQJ+gs/t1ZMa4flRWe1JdFKWUikgDfQO0ys6kqtbLR2UHUl0U\npZQKSwN9A9TW+YZWTp/9kdbslVJNlgb6BjhWU+d/POKXb6WwJEopFV7UNWNF5GngYmCfMeZUl/0/\nAb7peL1TgCJjzEER2QZUAHWAxxhTkqyCNwVHHbX4mjovxhhEIi6vq5RSJ1wsNfpngcnhdhpjHjLG\nDDfGDAduB943xjiTtU+09qdVkAf47oT+Ac9nPL88RSVRSqnwogZ6Y8xCINZVNq4AXmxQiZqRfkWt\nA56/tW5vikqilFLhJa2NXkTy8dX8X3ZsNsB8EVkuIjOinD9DREpFpLS8vDxZxWp07/xoPAU5maku\nhlJKhZXMzthLgA+Dmm3GGGNGAFOA74nIuHAnG2NmG2NKjDElRUVFSSxW4+pf1JrTerVLdTGUUiqs\nZAb66QQ12xhjdlnf9wGvAiOT+H5NxqFjtakuglJKhZWUQC8ibYHxwL8d2wpEpNB+DFwArEnG+zU1\ne4/4FiPJ1yYcp
VQTFDXQi8iLwBJgoIjsFJHrRWSmiMx0HHYZMN8Y48zw1QVYJCKrgGXAHGPMG8ks\nfFNhj75p1yo7xSVRSqlQUcfRG2OuiOGYZ/ENw3RuKwNOS7RgzckNY/ux7cBR5n66J9VFUUqpEDoz\nNkk65Odw6FiNLkSilGpyNNAnSVFhLl4Dv3/rM55bsi3VxVFKKb+oTTcqNkWFuQA8smAzANeMKk5h\naZRSqp7W6JNkcLe20Q9SSqkU0ECfJL075gc8r/F4U1QSpZQKpIE+ido6hlc+uagshSVRSql6GuiT\n6K/Xn4Wdpfg3b2xMbWGUUsqigT6JhvZsyws3nJXqYiilVAAN9ElW0qeD//GOg8d0iUGlVMrp8Mok\ny8mqv3eO/c0CALY9cFGqiqOUUlqjV0qpdKeBvhEM6loY8LzaUxfmSKWUanwa6BvB89cHdsj+4+Md\nHK/RYK+USg0N9I2gqDA3oFZ/57/Xcvnji1NYIqVUS6aBvpFs2FMR8HztriMpKolSqqXTQK+UUmlO\nA30jefiK00O2Ld6yPwUlUUq1dBroG8nFw7oxaVDngG1XPrE0RaVRSrVkGugbiYjw9LVnBiQ6U0qp\nVIhlcfCnRWSfiKwJs3+CiBwWkZXW112OfZNFZKOIbBaRWckseHPhNYFLCz7+/hYqqmpTVBqlVEsU\nS43+WWBylGM+MMYMt75+CSAimcCjwBRgMHCFiAxuSGGbI2/QGrIPzNvA/XM3pKg0SqmWKGqgN8Ys\nBA4m8Nojgc3GmDJjTA3wEjAtgddp1upM6GLhVbU6eUopdeIkq41+lIisEpF5IjLE2tYD2OE4Zqe1\nzZWIzBCRUhEpLS8vT1KxUm/Kqd1Ctr36yRe8uXZPCkqjlGqJkhHoVwB9jDGnAQ8Dr1nbxeXY0Oqt\nvcOY2caYEmNMSVFRURKK1TQ8+LVhrtu/8/zyE1wSpVRL1eBAb4w5YoyptB7PBbJFpBO+Gnwvx6E9\ngV0Nfb/mxpm2WCmlUqHBUUhEuor4FtATkZHWax4APgYGiEhfEckBpgOvN/T9mqOz+3WIfpBSSjWS\nWIZXvggsAQaKyE4RuV5EZorITOuQy4E1IrIK+CMw3fh4gO8DbwLrgX8YY9Y2zmU0bc9eN5KJA9On\nOUop1byIcRkVkmolJSWmtLQ01cVIqsPHa/nHxzu4b+56/zZdeUoplSwistwYU+K2TxuQT5C2rbK5\ncVy/gG0LNuxLUWmUUi2JrhmbQtc9+zEA144u5p5Lh0Q5WimlEqM1+ibg2cXbUl0EpVQa00B/gn1t\nRM9UF0Ep1cJooD/Bfvv107hpQv+Q7dc/+zFXPalpjJVSyadt9CmQlRl6f33HpWO2xuPleG2dpjpW\nSjWI1uhTwFPnjem4Gc+Xctov5jdyaZRS6U4DfQp4vLHNXXhvY/okd1NKpY4G+hSotWr0ednuP/7g\nSWxNcVKbUqr50ECfAp46X+DuX9Q6ZN+Cjfvoe/tctpRX+rfV1mmgV0olTgN9Cni8vhr9f50ROtTy\nb0u3A7BxT4V/W22MbfpKKeVGA30K3Di2Hyd1bs0lp3XniWsCU1O8tW4vAEerPf5tNR4N9EqpxGmg\nT4F+Ra15+4fj6dg6l25t81yP+cm/Vvsf1wTV6B9dsJniWXN0SUKlVEw00KdYYV70qQzBNfqnFm0F\noNJR61dKqXA00KdYYV70yVCHjtUGBPXMDN8qjR7tpFVKxUBnxqZYLDX6Sx5ZBMA/Z47ijN7tybIC\nvbbdK6VioTX6FMt2SYcQzn89voTnlmzz1+irPdpGr5SKTgN9E7DotoncOy22fPRl+4/6a/TVWqNX\nSsVAA30T0LN9Pp3buI++Cfbcks/9Ad4edXOkqpZjNdoxq5RyF8vi4E+LyD4RWRNm/zdFZLX1tVhE\nTnPs2yYin4rIShFJr0Vgk6yNS6fsqH4daePShr/7cBVQX6Mfds98Rv363cYtoFK
q2YqlRv8sMDnC\n/q3AeGPMMOBeYHbQ/onGmOHhFq1VPtmZvuaY03u34/pz+gIwqFshrXPDd9au333E//jw8VqKZ81p\n3EIqpZqlqEM+jDELRaQ4wv7FjqcfAbqEUgLaF+QAMLK4AweP1gAwoHMhH+TuD3vOr+asp2f7Viek\nfEqp5ivZbfTXA/Mczw0wX0SWi8iMSCeKyAwRKRWR0vLylpeet39Ra/7z3+fwkwsHBmS3LIhQowfY\n4MiJE83uw8c1b45SLVDSAr2ITMQX6G9zbB5jjBkBTAG+JyLjwp1vjJltjCkxxpQUFRUlq1jNyqk9\n2pKVmeFPeZCTlUGO1aQTjj0CxxYupfGRqlpG/fpd7nl9bXIKq5RqNpIS6EVkGPAkMM0Yc8DebozZ\nZX3fB7wKjEzG+6W7Go8vWGdnZpAhkQN9ZkbgrzA4L47NTpL29vq9SSihUqo5aXCgF5HewCvA1caY\nzxzbC0Sk0H4MXAC4jtxRgezmlZwYAn1wYrNws2Xt19E1TJRqeaJ2xorIi8AEoJOI7ATuBrIBjDGP\nA3cBHYHHxBdMPNYImy7Aq9a2LOBvxpg3GuEa0o4drLMzM8iK0nTzh3c2BTwvr6h2zZ9jB/gYVzFU\nSqWRWEbdXBFl/w3ADS7by4DTQs9Q0bRt5QvUBbmZUWv0wa5+ahkfzpoUsr3OX5XXSK9US6NJzZqg\nB742lFH9OzK8Vzt/XptYfXHouOv2OivTpTbdKNXyaAqEJqhdfg7fGl2MiBBnnA/wwtLP2X7gGFC/\nfKHGeaVaHq3RN3E92+cD8MINZ+E1htU7D3PhkC6c97uFEc+rrPZwx6tr6NMxn/d/MpE6q3Heq1V6\npVocDfRN3Kwpgzi9dztG9++IiDB2QPQ5BsWz5vDOj8YDcOR4LQAerzbdKNVSaaBv4vKyM5k2vEfc\n523eVwlAq+xMAH+NPtyEKqVU+tI2+jRw18WDQ7Z95/nlABytqePfK7+goso3YaqmzsuByuqIr1dZ\n7aF41hxe/WQnG/Yc4fCx2uQXWil1wmigTwPftrJdujl8vJZbXlrJFU98BEBVrZczfvV2xNf74kvf\nyJ0/vbeFyf/7Ad+YvSR5hVVKnXAa6FuoWJpw7Hb9eBKnKaWaHg30zdycm88B4Pnr40sj5PGasMHe\nnqN1rFrXpFUqHWigb+aGdG8LENNoHKcBd8zjwTc2uu6zh+4ftZYnjHNyrlKqidFA30y99YNxfHT7\nuQ16jcff3+K63V+jr/HV6POyMhv0Pkqp1NJA30wN6FJI17bhFxT/4KcTY3qd+Wv3sHlfYBu83aJj\nD8nMzU78z0QXLlcq9XQcfZrq1SGffp0KKNt/NOJxM6xhmI9fdQbGGKYM7eZIgOaTm5V4oB92z3za\n52fzyV0XJPwaSqmG0Rp9OoujbX3mX5fz3RdWcPBojb8mbztWXcehYzUB255YWMZb62JbxORLHYev\nVEppoE8jwemJE+lDnfy/C/EGrV1SUe3hvx4PHEt/39z13PhcaQLvoJQ60TTQp5Ee7VrROjeL9vm+\nfPaSwHCZfRXV/kyXTpv2Vfo7bz0uyxXe+591lESZiPXC0s8pnjUn7CpYSqnGoYE+zay483yW3XEe\nQECK4/5FBTGPtQ9uurE9MG8DR6s9XPnE0pB9Ty3ayv6g1Ap2vh2AT7Z/yW/n+1aaPFKVeFPO8Rod\n269UvDTQp5mcrAyyM32/VrEab+bdMpZ3fjQhYBnB8wd3Cfsa/zPffXw9wJbySpZtOxh2v3MS1nm/\ne9//+LLHFvubkmLJq1bj8frXzrXNWb2bU+56g3W7jkR/AaWUnwb6NGa33NiB1dmQc/OkAWHP+6gs\nfCC/9JEPI75ndYRmGbspaUt5Jcu2HuTw8Vo+2f6l67En/3weEx56L2Db2+t9nb/rdx/hs70VvPbJ\nFxHL0twt//xLimfNoay8MvrBSkUQy+LgTwM
XA/uMMae67BfgD8BU4BhwrTFmhbVvsrUvE3jSGPNA\nEsuuonjkyhE8sbCMgV0LARjdvyPXji5m5vj+EcfgN8TRag952ZksLTsQss++8Uyf7UuwNrxXO1bu\nOETZ/VPJcFlKK3hZRLtJKTND+Npji6mo9jB1aDdyGjD8symzb2SLNu+nX1HrFJdGNWex/Ic8C0yO\nsH8KMMD6mgH8CUBEMoFHrf2DgStEJDSfrmo0J3VuzYOXD/OvO5uVmcE9lw5JapC/7V+rA57bs2lf\nWLo95NjyisA2/JU7DgGxr3plj+/PyBD/461R5gmkA11CIDWMMXx+ID3+vqIGemPMQiD8Z3mYBjxn\nfD4C2olIN2AksNkYU2aMqQFeso5VTdS1o4vjPufvpTv498r6JpTKat8s2OWfuzfJuLH7DlbtOMSj\nCzaHP86u0YvQtpVvZNGXQeP709XmfZUBP2fV+P66dDvjH3qPFWGaF5uTZHzm7QHscDzfaW0Lt92V\niMwQkVIRKS0vL09CsVQ0JX3aM3ZAJ8run8qCH09gWM+2Cb3OLS+t9D+e9cqnbCmvDGl2icSu0V/2\n2Ic89ObGkE5Ym8fRdGM34xyr8bDmi8MJlbs5Oe937wf8nFXjW2FVVralwafGZAR6t8HaJsJ2V8aY\n2caYEmNMSVFRfJkYVWL+9d3RPH/9WWRkCH07FfibeKJ56PJhYfet2nHIv7pVrOxAb48WOnjUvZbu\ndQR6+1PAT/+1mosfXhQytDMdaNbQ1EqnZTeTEeh3Ar0cz3sCuyJsV01UuPHzAKvvqc9Vk5cdOZvl\n0er4kpjZ79uxIAeAPYerXI+za/RZGeK/Oeyv9N0U7KUS01FjB5ylZQf4MszNVaWHZAT614FrxOds\n4LAxZjfwMTBARPqKSA4w3TpWNVGeuvABJSez/k+lT8f8iK+zO0ygDufDzfvZV1FF23xfoJ/2aP0Q\nTmMMv/rPOjbuqfAH9wxH041tf2U1s15eHTKh6omFZRTPmuP/NNCU7Th4jJJfvcWOg8eAxFJYxMvr\nNXxj9kdc+WToJLiWLpGZ5U1VLMMrXwQmAJ1EZCdwN5ANYIx5HJiLb2jlZnzDK6+z9nlE5PvAm/iG\nVz5tjFnbCNegkqTWJfWBLdsR6If1bEdWhvhr2A01868r6N42j6I2oaOBdn55nCcXbeXt9Xvp3q4V\n4OuMDR6p8/u3PmPxlgMM6d6Gq0cV+7c/+p6vc/eLQ8fp1SHyDcrNvooqqmq89I5yc0uGf5buYH9l\nDS+v2Mmt553seowxJqkByB69tH63TkILlk5NN1EDvTHmiij7DfC9MPvm4rsRqGYgUo0+uP2+T8d8\ntpQnr5Nq1+EqerRvFbLdHq657cAxth3w1XQNJqSGbgf+o44a/Q1/KeWQlTmzbP/RhAL9yPve8b3/\nAxfFfW687CvKiBDIvQYyk1jRjNRcB768Rrf8fSU3TejvX81MNT/pOdNEJaSLS40aApttbLHW5rPj\niEofbwsdxua2aInXEJIz357Ne8wR6O2ZtEDERGqHjtXw1KKtUWtwldUetu4/yhtr9kQ8LlH2zSr4\nJ+YsVbTAHK9ov8et+48yZ/Vubn7xk6S+b3PQoppuVMtx4ZAuvHDDWbyy4gteXrGTZ687k2E929HB\n6iR1ijUDZUMDU6VLx67XmJBUyrZ5n+7mutHFtLMyeMZSjtteXs2ba/cyvFdbzujTIexxlz6yiDLr\nU8zbPxzPSZ2TO1vVn6pC7O+hgSbWyWWxqovwKQ7qbzLBZXl73V5W7jjEjy8cmNTyNCXp1HSjNXrl\nJyKMOakT9112Kn+78SwmDOwcEuTPHdQZgMoYR7k0tALqNprGGBNSo7dt2lfJ6fe+xT9LdwZsDw70\nnjqv/x/ZHs757oZ91Hi8fFR2gPc27gt57TJHU9Xh44ll4Ny0tyLuiU/OS410w/ryaA3Fs+bwr+U7\nwx4TzC0
ltdt7B99ybniulEcWbOYvi7fF/F6x+GBTeVzlV7HRQK9C5GVnMrp/p5Dtm+6bwuxrSgCo\ndGlS6enSxn7FyN4NKss/S3eEbPN6o39S+GDz/oDn9o3hqUVbKZ41h5PumMfv3vKlTbZf6tEFW/jt\n/I1Mn/0R1z7zccTXd5vU5fUaDgSN5999+HhAzfD83y8MO/EpXO3Z7TrcbLOm6z+/ZFvYY/YdqQoo\nT7SmG2OVKlyR7n49ueMrrn5qGT/+56qkvmYy7Dp0nGVbIyUIaNo00KuYZWdm+Dtl3eLNX74dmu/+\nvq+E5MGLy4KNobOkqzzRc9L/36rAKRt1Xl8N/t7/rPNve3rRVmtf/cXE2sHsNtb/0QWbOeNXb/v3\nbd1/lFG/fpfH3y+L6TWDm27cRBom6hx+6ubTnYcZef87/MNx84wa6P01+tjbq4/X1IXkNWrORODc\n377P1/+8JPrBMXp0wWaKZ80JOws82TTQq4TkZYf+6RTkhHb5ZGQIr9w0Oqnv/f2/xd8xWOeFDzcH\nZtQ0/n31wS7GycHc+veVIU0Mb2/wNffsOeIL9LusNBAfbAq9Wbl9IjH+ztgINfqIgd73Pdyonc/2\nVgAwf+1e/vz+Fl8TWLQ2+hhuPsG++qfFnHnf2zy/ZBvFs+ZQHeHGHO7Gdbymjh0HjzHoznk8/9Hn\nEd/vt/M3MujOeVHL9fqqXa59PtEYA8drk7vgzZ/e863WluzXDUcDvUrInJvH8uiVI3jte2P827Iz\nJSBQ9raGM/ZvAil2vV7DgaOBtUx7hM6njlw54QLa4qCmIIB3NwQujp5lXXyd1e5tv5ZbB6pbZ/af\nF5b5z1u76zDPWu3fAaNujGHhZ+Wuibb8aZyjROV3Nuzj1/M28MGm/dHb6MNnLQnLHpP/+7c3AeH7\nc1buOES/n81l8eb9/Om9Lew7Uv8p6dpnljH2NwuoqvVy7/+tcz3f9vC7m6mq9UbsPF236wg3v/gJ\nZ933dtJHLiXiRHf06qgblZD+Ra1DAnh2lq9px2vVEt//yQTfdscQy5snncQf3w2fobKxeLyGDJfa\na/A/XLjatNvM0WNBs3DtAGt/Grdr1m5xpdpTR6uc+lQSznII8PA77j8jrxeueXoZEDq23z88M0yc\nXxR0s7rm6WX84tIhIce9u2Ev97y+jmnDu5Nr5fpPZKhh8M/WGEOd15BlDdddttX3CeuRBZtZvOUA\nHzrKt9TRHp4bZr2BeZ/u5qAje+nRmjpa57qHNLvmfLSmjqcXbeXGcf1ivo7GHGV5omZsa6BXSZOT\nmWEFN8OHsyb5g0NWRv0/6uDubQDIz8kMCZSN6WevfsqI3u0CtrVtlU1tUPCP55/6QGUNr33yBdOG\nd0dEsC/TrjHWL50Y+s/sXInrpWXbmfXKpwFlOHS8PoA5P01c+eRHAa/jqfNy8GgNndvk+ZtZwjXd\nvOqyIpe94LvTHa+uYffhKh523JAbEuvsWHb/3PU88cFWbjl3ABMGFvnLuXiLL+CHa1YJt7DMd19Y\nEfC8oqo2bKB3Tvgrb0ACvGM1HvJdmijjZf9FJGt2eTTadKOSJtsf6ANr8c7HdkK09vmBwzY7uozV\nT7YV2w8FPD98vDakOSeeQP/pF4e59e8reeCNDVTV1vlvaHagt/+HXWv0tfWB3m6y8ZcBCZil/M6G\n+qGeZUGdxXe9vpaR97/D0WqPv2MvM0PYcfAY33p6Gafe/WbEtQGyXCa0uS0HGcvPpaKqNqBZ5Etr\nVrK97S9LfG3tf3hnE5c9ttjf1OUvS5gOkmgriNk1/ooqD/uOVHHW/W8HLEwPgc1Z+TmRk/JFMviu\nNxM+182JakbSQK8abFS/joAvwNj/q9mOWrxdsz+7Xwd/gO8fNNnouetDR+ycCPsrArM2Ov/xYm1H\n/fP7ZQy68w3/aBe78/Epa1SP1xg8dV62ONZ+rfbUceZ9b3P1U0tDRqiIQ
G2MAWD+Wt8s3WM1df5P\nJyK+TzDvf1ZOZbWHJxaGH/Xj/D35y+bSQRgt0Nd4vAy9Zz6/+L/Q4Zb2DSi4CSYzaMZ1uDTZ0QK9\nHbiPHK9l3po97D1SHTK+33mZzkED2620Go1h876KsBML7T+tE1Wj16Yb1WBPfqvEP9LEDnbBQ/w+\n+OlEOrXOJS87g3suGczIvh1Z+Fn9aBS3ETsnQnCNPjervrYXaaFzN/YlL//8Sz4qO+BPwXD4WC0n\n3RE4KqTa46W8ojrsMMTaGN/bTtNsMP6gkhmU3XPVzkMs2hTamQyx1+jLyo/yr+U7uez0Hq4B2e7U\nfenj0HkPdV6DMSZk8ltwp3G2S6oNcE/BEfA6VhSv8Xjrh5gGFdHZnJWf6/sdL9i4j+ue+ZjHvjmC\nqUO7RXyPeO06dJzzfreQa0cXc49LP4jdyX2i2ui1Rq8arCA3y98x+7URPYHQ2luvDvm0yslERLh2\nTF9/W70tVQt8By9y4pzxWhPnGOf3rDH/j723hSc+2OrfXuayQpGdbM3Nr+asZ10M2SR3OVbx8nph\n3prdgC+oOTtPdx+u4qqn3NMQf7a3MmSbWy3zWE0dP/7nKm55yX1oq31jcavBerxe//rATsFNNeFq\n9LkuQ3kDGet9jL+ZLLjz2Dnyyb5xfLbHN9x0admBBgfcn/5rFQscTWx2xecTl+uGE1+j10CvkurO\niwez6u4Loi5OEizcyIp4PXzF6XEdHxzo33d8ygjObZ9M4QJvPJy19iNVtfxndX2g3xXHUo7xsN/D\n6Y/vbGLoPfPDnuPxGteO95++vNrl6FB2YP5420FW7jjkz9cfzP7kAPVNTcWz5vDNJz8KyI1kB1n7\nE8RflnzO/XPXh33/WFrw/lG6k+uerZ9Nbf/t5Ef5P6iLMrw1WTTQq6TKzKhfuDsesdTol9w+Keox\nl5zWPa73XeCS08YWqdbdFDgrrc4moAwhpDMyme549dOA53YqiXA8dYaqGCYGhRvTb/9t/NfjS/jK\nox8y9jcLwpxvHE039T+cDzcfCEgdYdeisx1/c393aXKyReswdfbl2LOi7Rub3X+w4+CxgE87OupG\ntRg/PL9+cQ1n23g4bsMGf/WVU7n/sqEB22aO788vp4W2i/5z5qiQbcGzZZ3CrV3bVDhrms6UDPPX\n7XU5Orr7566neNacqMe9sHR7XK/r8RqqaqPXXGs9htNcFqh3a9Jx+7TlS3Phexx8ijNY20E/x9E/\nkRkhnXa0Gr1ziO51z37Mkapa/422VU4mx2o8jP3NAma94vgEYzfdRJmZnCwa6FXKOCdcueWtP61X\nu4Dnzjj/4NeGsuHeyVx1dh+uPCswcdqsKYO4ZlQxD3w18AZwZnEH5v9gHKf3bse9MeTgueKJj6Ie\nk0rvOPLt2wnNGmJ2hNE5DVHn9VJZHf3TUZ0xrkNRa+tMSK36SFX969m7fvbqGn+tOVIbvd0e7+z8\ntfsLqmrrqLBe27ic68aZr2Z/ZTWTf7+QB9/YAPjmWtgznOev3ctLy7bzyfYv/Z2x/yjdwUarr6Ax\n6agblTLOfyC3mZdt8nx/nhcM7sK1o4sDavTfODMwuN82eRBDewTWBqeP7E3vjvlc+UR9e/jJXQp5\n9aYxVNXWcedrawKO79GuFV8EtW13Lsxl3MlFrqlzN9w7mXMefNc/8uVEu8eRGuDhFMw2jtULS7ez\nKkynpJPHGxrQwRdIL398ccA25wL0dqA9eLSGJWW+T2jBf07vrK9vorObS5zHZGYIK3cc4ivWesX3\nThvCv1f6EuMFZww9XhM4qzk4Mdkux6erJWUH/GWqrPYETIwDeG7J5zy35PNGX8Esphq9iEwWkY0i\nsllEZrns/4mIrLS+1ohInYh0sPZtE5FPrX2lyb4A1XwF15T+89/n+NMmnNGnvX9hj7P6dWT0SZ0i\n5nD57oT+nDMgNLWyW7pliL3z9/Dx2
rDvm5edyaNXjojpdVqyV1Z8EVNW0Dqv17X2XFvn5ZOgyW7O\nzl1n27fdJh78O3POALZr9M4ml71HqgM6ZO/8d/18gOB7z9B7AidNOUdnNSSFzZItBwIyiyZT1L92\nEckEHgWmAIOBK0RksPMYY8xDxpjhxpjhwO3A+8YYZ/Lmidb+kuQVXTV3wbW3U3u0pU/HArY9cBEv\nf3e0P++Mcelgayi3TxBuL1/t8YZN+wu+m1C8HcDNydVn9zlh71XrMTEngHNmfXTWqO0/qcfe28KM\n59zrlXaNPrh9PFy++eCJc8EdqMFpNBJhjOH1VV/w0JsbG/xabmKp1owENhtjyowxNcBLwLQIx18B\nvJiMwqn0Zv+/fPX0Hq777fjqT9aVoh4ltxmyzuDuNqTzxxf4OpqLO8a/IHlT0rF146emsFV76lyb\nbo4cD82B46zRO09x3hTCdUqv3nmIWS+vZs2uw677g0UbdRPr5LZI+t4+lxeX7UjaMONgsbTR9wCc\nnyd2Ame5HSgi+cBk4PuOzQaYLyIG+LMxZnaCZVVpxv4IHa7GPGZAJ55ctJURvdv7jmvkxZrDvfwB\nx+ibm88dEDBaKFin1rnsr6yma1vfaltNICNug5yomZvg+/TUymXc+Z4joYu8zPs0dDy/7zWiD+Oc\nF+fi7s5F5t3YqS6g4ZkuGyvQx/KqbkUP99u/BPgwqNlmjDFmBL6mn++JyDjXNxGZISKlIlJaXh66\nUINKPxMGFdEmL4tvj+nrun/iwM6s/cWFlBT7FuyOdVGQRA3v1d51+/7Kav9audEC33BrpFCXNrmu\n+78TlB73O+NjT5drG9ytTfSDkuREjfMGX6CPtFSik1uqBWichTzchuD+z5sbOWYtpxltYZR45MQw\nzDgRsQT6nUAvx/OewK4wx04nqNnGGLPL+r4PeBVfU1AIY8xsY0yJMaakqKgohmKp5q5zYR6r77kw\nJB2CU4Ej7WyiNfq/fHsk7/5ofMRjnrimhIcuH+a6r7yimhF9fDeBaEPt/jB9OHNvHusvt3PhjsK8\nLG6fegrPXHemf9tVZ/XhD9OHx3IZfm6rezWWhsT5CQMD/49L+rjfSG3VtXVRf77RrPkieuqIZHhk\nwWam/uGDkHkHDV1PJJU1+o+BASLSV0Ry8AXz14MPEpG2wHjg345tBSJSaD8GLgDWBJ+rVCwS/Vg8\n/uQi+kVZ5Wpk3w7kZWfym6+FBvv9ldX+1bJ6dYjc5l6Qm8Xg7m3847Kd//j/tlbjmjiwM09fW0Kv\nDq3o3CaXQV3jq6EXhMm5Hs2sKYPiPqchgXfG2H785MKB/uduawo7VXu8nKCMAEmxzSXz5f4G5LqH\nFAZ6Y4wHX5v7m8B64B/GmLUiMlNEZjoOvQyYb4xxjqPqAiwSkVXAMmCOMeaN5BVftSTJbqNfdse5\n/sd22/D4gaGfJkf378TFw7rxtxvOYvqZvUL2A1w3ppg+jo5Xt0XUnTebSYO68MFPJ5GblUmU5Iwh\nogWDd380ntN7t6NT68Dmo4KcTP7v++e4ntOtbZ7rdk+d4ZpR8Y+8WXTbREaf1Ik2jnQY0W5QHq85\nYYtlN1W5ceaIilVMf2LGmLnGmJONMf2NMfdZ2x43xjzuOOZZY8z0oPPKjDGnWV9D7HOVSkSym+g7\nF9YHNzufSnCq3JO7tObRK0cgIow+qVPYJfXuvmQI7/9kov95uJS7buw0u1kZws+mDuLXQTN6AV64\n4SyGWekBor12v6LWvHrTGEp/fl7A9owMCZsJ8pZzB7huz8vOYPzJ8TWlXnV2b3q299304l0bdV+Y\ntM3p6KJh3UJSMEdLyZwonRmrmo3MDOGsvh24bkxxo71HcH72Ph0LAmZBxqq+Rh890NmTe3KzMpgx\nrj8APdu34uqnlvmPEakP8OE+2cy9eSzFncI3LWWIkBemsy/ch6X/njSA9z8Ln/gt2Is3ns2o/h39\nz
5vCQtyJGn9yUUA202Tr2iaPoT3bBqwAFj0lc2I0141qNkSEv39nFJNPTd4iETefO4B+nQr8zwty\nsujUOtc/ysYtB08snLnWZ199Bree515jhvqEWs4MnmMHhNai7bKEy/JYmJcVcT3TDAkfSNw+qfRs\n34pWOZmcGpRaIpLgBGTJCvTx9i9cFmZuRjxE4KYJ/Rv8OuHkZmWEVARS2RmrVNr64fkn8+6PJ/if\nZ2YIpT8/j0uHd7eeJ/YvkmXVvg1wwZCu3Hpe+LH3tkgZPAXx1+jt4Nk9qF09WpOORKrRBz1vl5/N\n2z/0jVTq2T4/IBfL09eWMNYl3QQQ0t9gd+baC7OH6+OI5szi9nHddIMzmiZiWI+2/t9jY8jLzgzJ\n0x/vOg6x0kCvlAt7Wnt2goP3g2f1RmLP5oyWk99uv7XLNn5g54AAHC0QjizuEHONPi8rMyTozLn5\nHJ657kwmDerCY98cwXfG96Nn+1YBxwQ3K9l9q/65EAn/PCXi0MXgiVbhmttunnQSt08ZRL+igpB9\n3594UkAb+S3nnRzQmz4kwjDgRORmZfh/93aOnqLW7vMvGkoDvVIuBnUtBGCC1YQTL4mj67hrmzyy\nM4XbJodvnujZvpU/SNs1+uCP/dkRbhTbHriI4k4FMTcNuN2ghnRvy8SBvp9HYV42t085hZ1fBmb7\nDG66CV4IxN7984tOYcO9k6OWY0Tvdpzeux2nRJkkFuvIoG+e3YfvjO/Pm7eGztv88YUDOX9wF//z\nzAyhygrEP77gZM60blbxcpZ9cLc2/t9BblZGyPq87fLjX7QnFhrolXJxao+2rLr7Ai5NMGGZXbGN\nZdBJq5xMNt03lYuGufc9fHLn+fTqkM89lw7hGyW9+PY5xQAMtG5GtlhGbNg1d2dK5zdvHRdyW0q0\naT14QZDzTvEFzouta7MDflaGxNRMMbxXe169aQx52ZkRP/G49QWM6tcxZJv9/uGauYJvcAO7+H7G\nI3q3T7i/4aHLh/kXvcnLzvDfTNrmZ/sDvb04ebhhrg2lgV6pMBJZEtFmt4UH58hPRPsCX2KxzoV5\nPHj5MCYN6sJr3xvDtaOLA45zC15v3DqWRbdNDNj29g/H87cb69NVDexa6L8xXWAFoUmDYhtSGTws\ns7I6MAHZwK6FbHvgIn+Hrh1og2NmuNnBzpagkX19NepZUwbx9ZKeAce5pWp4ccbZIZ24Bbn1N5c/\nTB/O32ecHbDfDua/sham+crpPVg8axKjT+oUthM8mpysDP/rZmaIf65Ap9a5/tw83zizF2/eOo4L\nh3RN6D2i0eGVSjWCtvnZvHLTaH+NMNmGB62+Be5L7rnNurXz/P/ma8N4xlr9yA6orXIyWTxrUshk\nq3C6WjXQiQOL6FyY57omgFOPdr42/eCsmHbZzzulM18d0ZObXlgBwKk96sv/h+mn84+Pd3DD2L4B\nicQg9hm8zlFJ04b7Rua89YNx/pzy9v2ic2Guv1zdrTInmo44JzPD/3pn9+vI9JG+eQaj+nWkMDeL\niioP+TmZIZ/QkkkDvVKNxM66majfXD6MAZ0jp24AX7t3Iq0KXz+zF1+3RsHYfQpegz+wxcJuLmqf\nn8ODYXIFOX37nL70aN+KKacG1lztG+K04T2YOrQbi26bSFWtl/6OTtO2rbK50UoKF7zYeLjka87m\nFrcbIcAAx8040toHngRn7eZkZdC9XSve/dF4+nQsIDNDuPNi35IeT193Jq+s+IKTG6lCYNNAr1QT\n9fWS2IYizv/BOFbtiC23ejj2KBV7+cZY2e3m1TEGwcwM8bdHOw3oUsj6X072l8OeWRuOnaWye9s8\ndh2uwus1/sf2sFCoX1zkO+P78YMYhrja2TPdRtXWJthGb3e+uuVbGtS1DT+b2vjZSDXQK9XMndS5\nkJM6N6xGeP4pXfj5Radwxcje0Q92GHNSJzoU5DBjbPzpliFw5Ek
8M5CP1/jOsfPneLyGhT/19UU4\nx75fO7qYz/ZWMHNc/5g6f+1PAG41+roYmm7ysjNo2yqbvUfqUzlEGzZ7ImigV0qRkSHckECw7lCQ\nw4o7z0/4fT+4bSKHj9XGfZ6d/rnQ+gTi9RrXyU1t87N59Juxr+s7qGshH2zaT1FhaB9FcGfsM9ed\nyTMfbmOhI03Cul9M5lhtHafeXb+urAZ6pVSL1rkwLyC5XKxuPfdkBGFg19as2H4o5gVLovnp5EFc\nOKQrQ7qHjpYa0ac9b6+vz/szcWBnyo9UBwT6jAyhICeT807p4l+ZKtKM5xNFA71Sqtlpm5/NXZcM\n5sPN+wFCZugmKjszwz+LN9jMcf2Zemo3Nu6toKM15NVtXpyI8OS3Snhq0daAnEeppIFeKdVsjTmp\nk5V7p/FXpcvIEIo7FVDsSILnDOMjg24Q15/jvkRmKmigV0o1a5MGdYl+UCP76oge/O7rw1NdjLBS\n30uglFLNlD8ZXBNPu6+BXimlEmQ33TTxOK+BXimlEmWvSNZYSwAmS0ylE5HJIrJRRDaLyCyX/RNE\n5LCIrLS+7or1XKWUaq6mDu3GzPH9+dnUU1JdlIiidsaKSCbwKHA+sBP4WEReN8asCzr0A2PMxQme\nq5RSzU52ZkbcyxymQiw1+pHAZmNMmTGmBngJmBbj6zfkXKWUUkkQS6DvAexwPN9pbQs2SkRWicg8\nERkS57lKKaUaSSzj6N2mdgV3Mq8A+hhjKkVkKvAaMCDGc31vIjIDmAHQu3d8iZWUUkqFF0uNfifg\nzJfaE9jlPMAYc8QYU2k9ngtki0inWM51vMZsY0yJMaakqKjxZ7kppVRLEUug/xgYICJ9RSQHmA68\n7jxARLqKNXNAREZar3sglnOVUko1rqhNN8YYj4h8H3gTyASeNsasFZGZ1v7HgcuB74qIBzgOTDe+\npVpcz22ka1FKKeVCTJLSeyZTSUmJKS0tTXUxlFKq2RCR5caYErd9TXs6l1JKqQZrkjV6ESkHPk/w\n9E7A/iQWpznQa24Z9JrTX0Out48xxnUkS5MM9A0hIqXhPr6kK73mlkGvOf011vVq041SSqU5DfRK\nKZXm0jHQz051AVJAr7ll0GtOf41yvWnXRq+UUipQOtbolVJKOWigV0qpNJc2gT5dV7ISkV4iskBE\n1ovIWhG5xdreQUTeEpFN1vf2jnNut34OG0XkwtSVvmFEJFNEPhGR/1jP0/qaRaSdiPxLRDZYv+9R\nLeCaf2D9Xa8RkRdFJC/drllEnhaRfSKyxrEt7msUkTNE5FNr3x/t/GIxMcY0+y98eXS2AP2AHGAV\nMDjV5UrStXUDRliPC4HPgMHAb4BZ1vZZwIPW48HW9ecCfa2fS2aqryPBa/8h8DfgP9bztL5m4C/A\nDdbjHKBdOl8zvrUptgKtrOf/AK5Nt2sGxgEjgDWObXFfI7AMGIUv/fs8YEqsZUiXGn3armRljNlt\njFlhPa4A1uP7B5mGLzBgff+K9Xga8JIxptoYsxXYjO/n06yISE/gIuBJx+a0vWYRaYMvIDwFYIyp\nMcYcIo2v2ZIFtBKRLCAfXxrztLpmY8xC4GDQ5riuUUS6AW2MMUuML+o/5zgnqnQJ9C1iJSsRKQZO\nB5YCXYwxu8F3MwA6W4ely8/if4GfAl7HtnS+5n5AOfCM1Vz1pIgUkMbXbIz5AvgfYDuwGzhsjJlP\nGl+zQ7zX2MN6HLw9JukS6GNeyaq5EpHWwMvArcaYI5EOddnWrH4WInIxsM8YszzWU1y2Natrxlez\nHQH8yRhzOnAU30f6cJr9NVvt0tPwNVF0BwpE5KpIp7hsa1bXHINw19iga0+XQB/zSlbNkYhk4wvy\nLxhjXrE277U+zmF932dtT4efxRjgUhHZhq8ZbpKI/JX0vuadwE5jzFLr+b/wBf50vubzgK3GmHJj\nTC3wCjCa9L5mW7zXuNN6HLw
9JukS6NN2JSurZ/0pYL0x5neOXa8D37Iefwv4t2P7dBHJFZG++Nbu\nXXaiypsMxpjbjTE9jTHF+H6X7xpjriK9r3kPsENEBlqbzgXWkcbXjK/J5mwRybf+zs/F1weVztds\ni+sareadChE52/pZXeM4J7pU90gnsWd7Kr4RKVuAO1JdniRe1zn4PqKtBlZaX1OBjsA7wCbrewfH\nOXdYP4eNxNEz3xS/gAnUj7pJ62sGhgOl1u/6NaB9C7jmXwAbgDXA8/hGm6TVNQMv4uuDqMVXM78+\nkWsESqyf0xbgEazMBrF8aQoEpZRKc+nSdKOUUioMDfRKKZXmNNArpVSa00CvlFJpTgO9UkqlOQ30\nSimV5jTQK6VUmvt/1qN01DNqkMEAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "value_grad_fn = jax.jit(jax.value_and_grad(task.loss))\n", "lr = 0.1\n", "\n", "losses = []\n", "params = task.init(key)\n", "# Read the number of steps from an environment variable so this notebook can be tested automatically.\n", "num_steps = int(os.environ.get(\"LOPT_TRAIN_LENGTH\", 1000))\n", "for i in range(num_steps):\n", "  batch = next(data_iterator)\n", "  loss, grads = value_grad_fn(params, batch)\n", "  params = [p - lr * g for p, g in zip(params, grads)]\n", "  losses.append(loss)\n", "plt.plot(losses)" ] }, { "cell_type": "markdown", "metadata": { "id": "jddnOrHu8WCb" }, "source": [ "## Optimizers\n", "SGD is all fine and good, but it is often useful to abstract away the specific update rule. This abstraction has two methods: `init`, which sets up the initial optimizer state, and `update`, which uses this state and the gradients to produce a new state.\n", "\n", "In the case of SGD, this state is just the parameter values." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "executionInfo": { "elapsed": 54, "status": "ok", "timestamp": 1647716627049, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "WX6fbsYu8Xmy" }, "outputs": [], "source": [ "class SGD:\n", "\n", "  def __init__(self, lr):\n", "    self.lr = lr\n", "\n", "  def init(self, params):\n", "    # The SGD state is a 1-tuple holding only the current parameters.\n", "    return (params,)\n", "\n", "  def update(self, opt_state, grads):\n", "    # Step each parameter by -lr * grad and return the new state.\n", "    return (tuple([p - self.lr * g for p, g in zip(opt_state[0], grads)]),)" ] }, { "cell_type": "markdown", "metadata": { "id": "ah5U6H1_qzpv" }, "source": [ "Instead of inlining SGD, we can now use our optimizer class."
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "height": 282 }, "executionInfo": { "elapsed": 904, "status": "ok", "timestamp": 1647716628067, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "Ul52LhQc8x3w", "outputId": "755da128-a07c-48af-d15e-c43495229ffd" }, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90\nbGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsT\nAAALEwEAmpwYAAA30ElEQVR4nO3dd3wUdfoH8M+TTgotCZ0QQu/FiBGQYkFQEfXneWI7K6J4eqen\n4p3t7If1PPEQhbODBaxUFWkiJSAdggEChJaEmoS0zT6/P3ZmMrs7szu72c0my/N+vfJi9juzs99Z\nkme+863EzBBCCBG+IkKdASGEEMElgV4IIcKcBHohhAhzEuiFECLMSaAXQogwFxXqDBhJSUnh9PT0\nUGdDCCEajPXr1xcxc6rRvnoZ6NPT05GdnR3qbAghRINBRPvM9knVjRBChDkJ9EIIEeYk0AshRJiT\nQC+EEGFOAr0QQoQ5CfRCCBHmJNALIUSYC6tA/+8ff8fCrUdgt8vUy0IIoaqXA6b8UW1nTF++G6WV\n1YiPicSZymos/uswdG2ZFOqsCSFESIVNiT4ygrDpqVGYcm1ftEiKBQCMen053lm2G7K4ihDibEb1\nMQhmZmZybadA+CL7AB7+cjMAoHFcFObeOxidW0jpXggRnohoPTNnGu3zWqInovZE9DMR7SCibUT0\ngMExNxLRZuVnFRH10+3LI6ItRLSRiOpsAps/ZLZH9uMXY2S3VJwut+Hi15Zj2a7Cuvp4IYSoN6xU\n3dgAPMTMPQBkAZhERD1djtkLYDgz9wXwLIDpLvtHMnN/s7tNsKQkxuJ/tw3CvPuHAgDu+iAbi7Yd\nqcssCCFEyHkN9Mx8mJk3KNvFAHYAaOtyzCpmPqG8XA2gXaAzWhu92jTBikdGIi05Hnd/tB6vLc4J\ndZaEEKLO+NQYS0TpAAYAWOPhsDsALNC9ZgCLiWg9EU3wcO4JRJRNRNmFhYGvYmnfPB5TbxgIAHhz\nSS7u+Xg9KmzVAf8cIYSobywHeiJKBDAHwF+Y+bTJMSPhCPSP6pKHMPNAAGPgqPYZZvReZp7OzJnM\nnJmaajh3fq11a5WEKf/XFwCwYOsRzFyZF5TPEUKI+sRSoCeiaDiC/CfMPNfkmL4A3gMwjpmPqenM\nfEj5twDAVwAG1TbTtXHdue2R89xopCfH418Ld+KZ77aHMjtCCBF0VnrdEIAZAHYw82smx6QBmAvg\nZmbepUtPIKIkdRvAKABbA5Hx2oiNisT5nZIBADN/2Sv97IUQYc3KyNghAG4GsIWINippfweQBgDM\nPA3AkwCSAbztuC/ApvSwaQngKyUtCsCnzLwwkBfgr6fG9kJhcQV+3FGAJTsLcFGPlqHOkhBCBEXY\nDpiyYk9hCS58dRl6tG6MBQ9cEPTPE0KIYKnVgKlwlpGaiCGdk7Hj8GnMWrs/1NkRQoigOKsDPQDc\nf2EXAMBjc7fgoc83hTg3QggReGd9oD8vI1nbnrMhH+
VV0rdeCBFezvpADwAvXdNH284tKAlhToQQ\nIvAk0AO4flAaFv/VMY5r68FTIc6NEEIElgR6RefURHRKTcDkuVtw6GRZqLMjhBABI4FeERFBeODi\nrgCA+VsOhzg3QggROBLoda7s1wZdWybiy/X5oc6KEEIEjAR6F5f1aY2dR4ql940QImxIoHeRkZoI\nAJixUubAEUKEBwn0LjJSEgAALy/KwTcbD4U4N0IIUXsS6F2kK4EeAD5avS+EORFCiMCQQO8iMTYK\nT1zRE91bJWH9vhM4XV4V6iwJIUStSKA3cMfQjqistgMA3lu+J8S5EUKI2pFAb+K5cb0BABVKwBdC\niIZKAr2JwZ1TAADvLNsjvW+EEA2aBHoPUhJjAQCFJRUhzokQQvhPAr0Hr13XDwAw8uWl0igrhGiw\nJNB70K5ZIwBAaWU1vpU+9UKIBsproCei9kT0MxHtIKJtRPSAwTFERG8SUS4RbSaigbp9o4koR9k3\nOdAXEEytmsRp26fKpEQvhGiYrJTobQAeYuYeALIATCKini7HjAHQRfmZAOC/AEBEkQCmKvt7Ahhv\n8N56Kz4mStt+eVEOKmwy/40QouHxGuiZ+TAzb1C2iwHsANDW5bBxAD5kh9UAmhJRawCDAOQy8x5m\nrgQwWzm2wbhtSLq2nXOkOHQZEUIIP/lUR09E6QAGAFjjsqstgAO61/lKmlm60bknEFE2EWUXFhb6\nkq2gempsL227pNwWwpwIIYR/LAd6IkoEMAfAX5j5tOtug7ewh3T3RObpzJzJzJmpqalWs1WnpJul\nEKIhivJ+CEBE0XAE+U+Yea7BIfkA2utetwNwCECMSXqDdPR0eaizIIQQPrPS64YAzACwg5lfMzns\nWwC3KL1vsgCcYubDANYB6EJEHYkoBsD1yrEN0orfi0KdBSGE8JmVqpshAG4GcCERbVR+LiOiiUQ0\nUTlmPoA9AHIBvAvgXgBgZhuA+wAsgqMR93Nm3hboiwi2j+4YBMAR6GXhcCFEQ+O16oaZV8K4rl1/\nDAOYZLJvPhw3ggbrgi6pGD8oDbPW7se6vOMY19+wPVkIIeolGRlr0eQx3QEAhcXSICuEaFgk0FvU\nOC4KcdERyC0oQZH0vhFCNCAS6C0iIlTY7Ji97gAyn/tRpi4WQjQYEuh9cP25adr2W0tyQ5gTIYSw\nTgK9Dx65tJu2/fn6Ax6OFEKI+kMCvQ+axkdr2+S5I5IQQtQbEuh9QEQY2S1V2Q5xZoQQwiIJ9D6y\nK22wtmppjBVCNAwS6H1kV3rbVNjsIc6JEEJYI4HeT5WyCIkQooGQQO+jYV0cdfQpSbHSl14I0SBI\noPfRnRd0RHpyPPYUluLmGWtDnR0hhPBKAr2PiAh92jUFAKzMlWmLhRD1nwR6P7RpEqdtPzD7N5RV\nSn29EKL+kkDvh9SkWG37m42HsHj7kRDmRgghPJNA74eLerR0eh0TKV+jEKL+kgjlh44pCejcIlF7\nHSWBXghRj0mE8pPdXtO10i7dLIUQ9ZgEej/NvjtL2y6vksZYIUT95TXQE9FMIiogoq0m+x/WLRq+\nlYiqiai5si+PiLYo+7IDnflQapEUh/sv6gJAAr0Qon6zUqJ/H8Bos53M/DIz92fm/gAeA7CMmY/r\nDhmp7M+sVU7roVsHpwMA9h8/g+fnbcepsqrQZkgIIQxEeTuAmZcTUbrF840HMKtWOWpAGkVHAgCm\n/rwbADAwrRnG9GkdyiwJIYSbgNXRE1E8HCX/ObpkBrCYiNYT0QQv759ARNlElF1YWBiobAVVXHQE\nYqJqvsJymehMCFEPBbIxdiyAX1yqbYYw80AAYwBMIqJhZm9m5unMnMnMmampqQHMVvAQEVrrRsmW\nV8nUxUKI+ieQgf56uFTbMPMh5d8CAF8BGBTAz6sX2jVrpG3/a+HOEOZECCGMBSTQE1ETAMMBfKNL\nSyCiJHUbwCgAhj
13GrKxfdto2yfPVOHgybIQ5kYIIdxZ6V45C8CvALoRUT4R3UFEE4loou6wqwEs\nZuZSXVpLACuJaBOAtQDmMfPCQGa+Pji/U7LT6yv/szJEORFCCGNUHxfPyMzM5OzshtPtPq+oFCNe\nWaq9/u2JS9AsISZ0GRJCnHWIaL1ZN3YZGRsA6SkJTq8HPPtDiHIihBDuJNALIUSYk0AvhBBhTgJ9\ngNw9PCPUWRBCCEMS6AOl/rVpCyEEAAn0AaOP80TA8dJKPPH1VlTItAhCiBDzOqmZsOae4Z1w5FQ5\n2jRthGnLdmOg0vNmQFpTXDOwXYhzJ4Q4m0mJPkCaJcTgzfED0CE53im9Hg5TEEKcZSTQB9igjs2d\nXhOFKCNCCKGQQB9gnVITMbJbzeybEuiFEKEmgT4IkuKite0IifRCiBCTQB8E52XUVN88MHsjTp2R\nJQaFEKEjgT4IbhiU5vT61R9yQpQTIYSQQB8U5FJd8+Gv+1Bpk9WnhBChIYG+jryzbHeosyCEOEtJ\noK8jJ6SeXggRIhLog+T+Czs7vW7SKNrkSCGECC4J9EFy/0VdnF7Hx0SGKCdCiLOdBPogiYqMwBNX\n9MS7tzhW9np+/g78uP1oiHMlhDgbyaRmQXTH0I5Or6cv3wObnTG6d6sQ5UgIcTbyWqInoplEVEBE\nW032jyCiU0S0Ufl5UrdvNBHlEFEuEU0OZMYborV5xzHx4/UoLK4IdVaEEGcRK1U37wMY7eWYFczc\nX/l5BgCIKBLAVABjAPQEMJ6IetYms+HCZpc+9UKIuuM10DPzcgDH/Tj3IAC5zLyHmSsBzAYwzo/z\nNHi92zZ2er3zcDEA4LXFOUifPA9V1RL4hRDBE6jG2POJaBMRLSCiXkpaWwAHdMfkK2mGiGgCEWUT\nUXZhYWGAslU/fDNpqNPr295fB7udMX3FHgCQUbNCiKAKRKDfAKADM/cD8B8AXyvpRtM2mi7DwczT\nmTmTmTNTU1PNDmuQIiPcv4pTZVXaoiQ2u6xOIoQInloHemY+zcwlyvZ8ANFElAJHCb697tB2AA7V\n9vPCxdcbD6JCKcnbpOpGCBFEtQ70RNSKlFm8iGiQcs5jANYB6EJEHYkoBsD1AL6t7eeFi39+t13b\nrqqWEr0QIni89qMnolkARgBIIaJ8AE8BiAYAZp4G4FoA9xCRDUAZgOuZmQHYiOg+AIsARAKYyczb\ngnIVDZw0xgohgslroGfm8V72vwXgLZN98wHM9y9rZ4/Xf9yFF67ug7homSZBCBF4MgVCHblnRCcM\n65qKP7tMdgYAczccxPJd4dXTSAhRf8gUCHXk0dHdAQAFxeX4z5Jct/3S8UYIESxSoq9jZtUzxeUy\nX70QIjgk0NexuCjjQF9SYdO2V+85hvNf/AmlujQhhPCXBPo6Fh1pNI4MKCmvCeovzN+Bw6fKseto\ncV1lSwgRxiTQ1zHXhcNV+hK9OiVCTJT89wghak8iST1xutw90EdFyH+PEKL2JJKEwPCujrl8Luze\nQkvbd6wUBcXlAFAzNYJMZyyECAAJ9CHw2nX9cMN5aZh6w0A8MrobOqYkYNXuYxj0/E9gZt0cONLn\nUghRexLoQyA5MRYvXN0HjWIice+IzmjZOFbb9+2mQ6i0VQMADp8qwx/f+RWrcotClVUhRBiQQF8P\nrN1bs67Lmr3HUanMfTPx4w1Ys/c4bnhvTaiyJoQIAxLo64EXr+mDrIzmAICVvxehvMq9bv60DKgS\nQvhJAn098Mdz0/DB7YMAAPuPnzE8Zl/RGTBLnb0QwncS6OuJmEjP/xVj31qJj9fsr6PcCCHCiQT6\nesJsIJXegi2H6yAnQohwI4G+AZGRskIIf0jkaEBiJdALIfwgkaMe+tf/9TFMjzGZ+VIIITyRhUfq\nkYcv7YbkhBgM7ZJquN9bg60QQhjxGjmIaCYRFRDRVpP9NxLRZuVnFRH10+3LI6It
RLSRiLIDmfFw\nNGlkZ1w/KA2N44zvv7HRvgf6YyUVuOzfK5BXVFrb7AkhGigrkeN9AKM97N8LYDgz9wXwLIDpLvtH\nMnN/Zs70L4tnn6S4aOx98TK39A37TuC9FXtQYau23Kf+202HsP3wacxYuTfQ2RRCNBBeq26YeTkR\npXvYv0r3cjWAdgHI11nPqLvlziPFeG7eDjw3bweevKInOrdIxN6iUvxpcLrbsV9kH8DWg6fQJD4G\nANAsPjrYWRZC1FOBrqO/A8AC3WsGsJiIGMA7zOxa2tcQ0QQAEwAgLS0twNkKP99vPoQN+08CgGGg\nf/jLzQCAW87vAABawBdCnH0C1rpHRCPhCPSP6pKHMPNAAGMATCKiYWbvZ+bpzJzJzJmpqcaNkWeb\na88xfzhqpgvcZyrN15aV1aqEEAH56yeivgDeAzCOmY+p6cx8SPm3AMBXAAYF4vPOFi9f29d03087\nC7TtPYXmDa1qDVCVTRYxEeJsVetAT0RpAOYCuJmZd+nSE4goSd0GMAqAYc8dYYyIMONP3tuwr/jP\nShwrqTA7CwBoUx8LIc4+XuvoiWgWgBEAUogoH8BTAKIBgJmnAXgSQDKAt5UGRJvSw6YlgK+UtCgA\nnzLzwiBcQ1jr3baJpePyjpUiOTHWLV0t0VdKiV6Is5aVXjfjvey/E8CdBul7APRzf4fwRWSE8WRn\ncdERTvPWHzxZjnM6uB9ntzu6YVZJiV6Is5a00NVzUSaBvkPzBKfXOw6fNjyuSll3Vkr0Qpy9JNDX\nc2YletdRsv9duttwEJXN7gjwFTY7qu0si5cIcRaSQF/PRbvMb/PHzPa4e1iG4UyWFUqpffG2I1qa\nWpJ/f1UeOv19Pt5eujuIuTVWWmFDwenyOv9cIYSDBPp6Li46EvPuH6q9fmR0Nzx2WQ/EGsxkOWPl\nXqRPnocJH63X0sqrqp2Omb2u7lepuvrtXzDohZ/q/HOFEA4S6BuAXm1qet40inEEeKMBUC8vynFL\nK3MJ9KGYAXPX0ZI6/0whRA0J9A1Et5ZJAIA4pSRfUm4+GlZv+yHnRlrXqiAhRPiT+egbiE/vOg+/\nF5QgQmmcLTIdIOXstMsNweoqVSfPVKJJo2hLa9kKIeo3Kd41EMmJscjKSNZeF1oM9K6slOhPlFai\n/zM/4NI3lmNz/km/PkcIUX9IoG+gii1W3biKiYpAcXmVx2NKKhzn3nW0BFe+9YtfnyOEqD8k0DdQ\nE4d38ut9m/NPoc/Ti9H57/Pxw/ajTvuKSiqwdu9xVNud+9ov21Xodz6FEKEngb6BmjymO7b+81Kf\n36eW1m12xl0fZuOdZbtRqqRd/fYvuO6dX92mS/jTzLW1z7AQImQk0Ddg6vQIrRrHaWmje7XCq3+w\nPsXQiwt24qUFOwEAB46XAagZeCWECA8S6BuwuOhIfH73+fjkrvO0tGk3n4NLerX06Twfrd7n9Np1\nkFWgyPQLQoSGdK9s4AZ1bK5Vx6iSYmv333qmMjiB3s5ApPTWFKLOSYk+DMS59I0nIjSKdp8iwZM+\nTy3Stm8xqJP/bf8Jr+f4/Wixx+PsUqIXIiSkRB8Gogz6xu94djT2HzuDyupqFJVU4vrpqz2eo7jC\nc3fN8e+uxs5nx5ju31tUikteXw4AyHvpcsNjqu0MH+8/QogAkBJ9GEtLjkfnFknIykjG30Z1rdW5\n9IucGBn5ylKv55ACvRChIYE+jHia3uC+C7sE5DMOHD+DCptzHf6kTzZYem+1j5F+T2EJ+jy9CPkn\nzvj0PiGEMwn0YeL7Pw/F8kdGejzmpqy0Wn0GM+OCKT9jzBsrnNLnbTls6f2uA7G8+WzdARSX2/Dd\nJmvnF0IY8xroiWgmERUQ0VaT/UREbxJRLhFtJqKBun2jiShH2Tc5kBkXznq3bYKWuv70Rp67qo/f\n509PjkelMpBqT1GpX+fwtXulOqGaNOIKUTtW
SvTvAxjtYf8YAF2UnwkA/gsARBQJYKqyvyeA8UTU\nszaZFaFTVc2GA6m8zZujp5boT52pwsGTZSgoLseJ0krT49VVFKX/vRC147XXDTMvJ6J0D4eMA/Ah\nO/4aVxNRUyJqDSAdQC4z7wEAIpqtHLu91rkWAXFuejOsy/PebRIAqqrtqNA1yDIziAgPfr7J8uep\nNTcjX12K47oAb9ZLJ0Ir0Vv+CCGEgUDU0bcFcED3Ol9JM0s3REQTiCibiLILC2USrWD5x2U9tO0Z\nt54LAMhITfD6PpudnRphOz42Hy8t2Ikt+acMj2dm2F0itFoFc9xDKV5PLdG7Vt3Y7ayN3p21dj8+\nXr0Px0oq8ODnG1EWpMFewpojp8rd/t9F6AUi0BuNdWQP6YaYeTozZzJzZmpqagCyJYzcNSwDO54Z\njfWPX4zGcdHIe+lyvH/rIK/vO15aibV7jzulTVu2G0dMFv2+ecZaZPx9vlNatZ3dJkzzhHQl+rV7\njyPnSDEA4Nl529H9iYWwVdvx2NwtePzrrXh5UQ7mbjiIr347aPn8wru9RaWWb56HTpYh68Wf8O+f\nfg9yroSvAhHo8wG0171uB+CQh3QRYo1iIpGcGKu9jo6yNi+BL9U0K3OL3NLKq6rR5R8LLJ9Drbph\nZlz3zq+49A3HgKxZax0LnFdVu5cbarsg1tSfc3HH++tqd5IwMvKVpbjzQ2vfh/qk5jr9tQi9QAT6\nbwHcovS+yQJwipkPA1gHoAsRdSSiGADXK8eKesZswfBLevo2OZrqxQU7DNMf/nKzT+dRs+VadUPw\nvTdOdt5xjw2/qpcX5eCnnQWWzllWWY1vNwWu7JI+eR5u/V/9mRJabQT/JfeYpePjlGHP5TapPqtv\nrHSvnAXgVwDdiCifiO4goolENFE5ZD6APQByAbwL4F4AYGYbgPsALAKwA8DnzLwtCNcgaikqwvjX\nQF2QPMLHUvI7y/Zo2/r62vX7jBt+9x0z7q5JXhpjF28/om17ivl2O+Paab/i5plrzA/yw/Pzt+P+\nWb9hXd5x7wdbtDQnsO1TuwtLcNSkes0bX6va1Wmzy6WdpN7xGuiZeTwzt2bmaGZux8wzmHkaM09T\n9jMzT2LmTszch5mzde+dz8xdlX3PB/NChP+axEfjrRsG4OFLuzmls9KkYmWdWTOlld6XPBz+8lKc\nLq/CiJd/xqYDJ7X0CJN+9Gr1zF8/c69KMronqXnYebjYYz58aT8AgILTjnV7rTYu1wVmxte/HUSl\n0hX2oleX4bwXfjI8tqC4HKt2u1exqXwd4KaOfC4L0jTXwn8yMlYAAK7o2wZ3D8vAY2O6a2kXdm8B\nALigi+fG8YFpTU339Xl6saXP7/v0YuQdO4OXF+VoaVqvm1r24lCnXY7zMKOardruU/sB4Fh/F4AW\nVOuDH7YfxV8+24inv9uGN700il7z9irc8K75U46vA9XUqh5v8yL5ipmxJf8UbD7eiEUNCfRCExUZ\ngbt1a9Ge06E58l66HH3bNfH4vsaNogOWh71Fpdiw/wRGvb5M62Wj//t+ZVGOSXcu86CkztdvNhdQ\ncXkVOvsY5IGatg2zQP/r7mO4btqvWoCasz4fA55ZbFhS9rX0bOZkmWMA26dr9uO1H3Z5PDb/RJnH\n/b6OU1MvIdAl+o9X78PYt1Zabjup7yZ8mI1LXltWp58pgV54lZnezHTf+EFpfjfaGjl4sgzXvL0K\nu46WYK7SVVJfpfLWz7la3b1V6pq4ZiV6f6tetBK9SUnzwc83Ym3ecRwtdlTxPPHNVpw4U4UzBtVZ\nrhPF+c2P+4XZyGNfJ6EL1lQVa5RuvcFa+ayuLd5+FL8XlABwPK3eMnMtVv5uXoUWCBLohZs7hnbE\n5X1ba68Hd0rB7AlZyH78Ytyc1cHp2Bev6YOkuMCV6I24lhB97UGproV78KRxCZZ8PqODvuqmvKoa\nn63b7xQ0
1VJ6pHJjUtscbAbdQvV91T/8NQ8Hjvs2Y+eX6/Ox/dBp3y7AJZ+ufA3cgXoqqY38E2ca\n1ICt4nIblu8qxL2frA/q50igF26euKInpt4w0CktKyMZKYmxaNXEfeI01xWuAs1tmmKDuPx5dr62\nPenTDXhn2W7t9c85NY/86qCvVbuLsHyXo4eLp2ofT/RVN6//uAuPztmCxbo+5Hate6KjtKa2OahP\nANsPnUb65HlY8Xuh083syW+24cb3fOsh9LcvNuGyN1d4P9CAWcmdDR5Udhw+bTqAKlhTEll9gtt/\n7AyG/utnvLnE0Tax62gxcpWSc33l7++eryTQC5+o9d392jfFor8MA+C5kdOb24d09HrM6j3O3Rc9\n/dnnFpRg3ubDeHHBTpRW2FBSYXOajG368t1InzwPN7y7Rlsy0agkaqVUGKksgFtlt2s9cO7+aL1W\n1aSe4qEvHL2DIpRIr84ZlL3PcV2Lth1xa8A8eca8OumEwShl1SNzfBurAJiXxF1vAGWV1Rjz7xX4\n86zfDI8P9iyjRvk8eaYSq5QbaWGJoxvpMuUGPur15bg4QHXhdjtjyc6jQZtgz9fqSF9JoBc+6d7K\n0bf+hat7o5uy7WnBE2+aBLAhFwDeW7lX2+73z8Xo/dQirY4eAH7c4d6gZzMIIDY744HZvyF98jzT\nz1KrfJidg9z6fSdQVW13C0xq1Y1aHx+pBP5qO/vUc+eWmWtx3Tu/BqyqxGrVjZrHtXuNB1AFq+pG\nDYFGp7/jg2zc8N4alFVWa92AjarGauuTtftx+/vZ+GZjYAf319XErLJmrPDJlf3a4MLuLZzq5RNi\n/f81ahrve6C3+rehBvBjJRUejzMKstV29vhHPfRfS7ReK5+tO4D9ujr166evxsC0pm5PBTWB3vF5\nUbpAb7N7D/R5RaVonhiDLQcdE8mVV1WjuNyGrBeN+8lbZfbRvpbQrcb5vUWlaJEUa/n3Ri3suubn\nzg/WaYPwbHa7Fuh9HQ9hxdFTjqeF/T62nXijXlOQC/RSohe+ISK3xtcWjWNNjq5x25B0TBiW4ZZu\npUSf1jze6XVxufdBWHqbTGbYVF3xn5VuafrA+9GveW779V0Tjf74N+w/6RaY1Dp6tR4/UhmRbLOz\n21PF6XKb29PEiFeWoq9uXEJZVTU27Lc2zfSeQvO6arM6etcbgHo91XbG20tz3erq9dUat3uYL2jk\nK0u1ajNfuN449U9ndq4Jlma9oHyxfFchth6s+b0J1k1EvaQgx3kJ9KL2UhK8B/oxvVtjZLcWbulN\nLJTog1n3a1o/rUt/4ptt+DmnAMyMx7/eYjm46gPoD9uPaiV6dSCTOuC42s645u1Vhuew2xmz1u7H\nl+vz3faVVVZb7pZ54avmddX6az1dXqX1+3f93tXrKa2sxpSFOXjjR+d++vqvcsnOAizfVag1RLsy\nmw7DiKeqGy1vdtaqbAJRdXPLzLVOBQBvg+P2HSt1qiK0qq5WT5NAL2otIoLQv31T0/15L12OQR2b\nIyujuTZ/jirJ5PE9UZce5WGyndq0DwBAicnTgWsJ+4vsA8g/UYaPV+/Hg59ttHRufYn4l9wip3Nu\nOnBSmw3UU9324u1H8NjcLfjbF+7TPZRXVTstBuOv137YhW2HHKXXvk8vxu0fZOPo6XIMfmmJ03Gu\n+SxRGruvmvoLfj9a7Lb/lplrtd5DP2w/im6PLzAcQ2CVp379NrtdC5rBqLrxNmZi+MtL/ZqQrqbq\nRhpjRQPw9aQhGNI52eMxRITZE7IwrGsqmikl+RiTQP3FxPO17em3ZJqe09NNwIrjJr1bPvx1n9Nr\nux24YMrPAIDWTRpZOre+tPb+qjwU6doKbpm5VmuI8xToJ368wXRfWVW14fKOvpq1dj+umvqLVvWy\nfFch5m12XpD9vRV73PLJcNzANh44iZcX5Zj2SGFmTFm4ExU2O/KKfK/jJt
101WaqddVfh0+Va73D\nAsVTiV59ArK6Wptqx+HTdTb2QAK9CJgp1/bDTVlp2Pms+RLDzRJi8OHtg3DjeY6BV0b98gGgQ3JN\nvXxXl6cAvdr+mRwvNW6odZ0nZuG2mpkyzfLsyqwE2ik1AafKatba9fePvayyOmDz7LiuCexapfDc\nvB2GbSN5ykLxdjavWpmxcq/2Xej7jZdXVeOL7ANInzzPdAZTvWo749r/rsK4twzaVKqdVzS79r/G\nVWGevLY4Bwu3HnFKs9sd541RutIafd+lfs7WOebfK7RutVJHLxqMtk0b4bmr+ljqV//gJV2x7h8X\no0VSHB4Z3Q0tkmK1SdQAR+NX15aJXs9T2/rY46XWFzdXWV3FyqwAql/0BahFoK+q9nlaAHWCsL9/\ntcVt30sLdmrbJwyedDYecC6xEoAXlfdUVdux4nfjKZYd7RuObf10E3d8sE5bo8Cs3eP3o8Xa9z1l\nYQ6y953ApvxTboub2Ozs9D3uPOJ5plIjby7JxcSPnUeoXvz6MvR+epGWf9dqoSOnytHvnzUN5M9+\nv92naaHVJ49g97qR7pUiqH574hLD9IgIQmqSI+DdO6Iz7h3RGQC0niZREYS59w7RSr7RkWS4olRt\ne1jM33LY6zFNGkU7lcBry7VUuMLPeU5KKmz4TTetsxXfbz5sOuDp/VV52vbUn3e77X90jvvNQbVs\nV6E2UMlVpa1mTMHNM2rqsfULmkSYRLpLXl+ubetHDy/Z6Rzoq+3u4xYCYU+h40lDrRZyrUv/dY/z\n/92MlXuxu7AE79/mfXlOALVqs/CFlOhFUPRr3xTPjOuFZgkxfr2fiJAYG4W2TRtpr4PBW+n80l4t\n0U9paDZrT/CVa6D392b1wrwdWOLjjI45fpR0zVgNq5XVHPDeJa6/DzY7e2ysPVNpw8hXliLbwiIx\n05a53+TMbiJGNyhfbjhnKtSblzTGigbom0lDcMv56QE7X13MRZ6V0dwtrazKDqV61ueVtswEop83\nABw65fvKUYF8MrEqOoK8jgA1K9GbcT3aVs1O3RvjY5yrD7cePI29RaVO1VNmjI4xGj1txZo9xzyO\nYbCyME8gSKAXDYL6dzauf5ugfcZL1/R1SztWUoFspc93bRfUeOiSrhjaOSVkC5VU2uz4aPU+7wcG\nWGJclNe2BDXQz9t8GG8vzfV6TtcbwxX/WenUQ0n/9LVk51FtUJuvNxSVWtAgOG6WO4+cNj0f6a7l\nj9NXexzDoC6KIyNjhdB5/uo+Tq/NRtZe1N15cNbbNw7Es+N6mZ73u/uGor3LCFwA6Nm6sel7jJ4A\nPGnXvBHaNI0LWaD/ZqO1RmSrrMamM5XVOOZlzn/1aWnSpxswZWEOTpd7fvLwFhgjdQfc/n62Ng7B\n34Cq/p/N/e0gRr+xHKPfWAFmNjwfwTEz6aRPzbvGqs7U0fq6EuhFvWJWPaIubqIfSJX30uXolJpg\neHwj3aP7koeG47I+rXGzh6qkuOgIREaQ0yN/alIsnr2qt1s1gGpUz1am5zMSFRGBmKiIgFXd+Ork\nmcBW2xR5mUNIZWXEqGud+w4vc+t7C5BmNxYiR8+ZDftP4JuNBy33WnpVt1rXYaXKrMJmN13L4NXF\nOU6v/6H0cnKtvz+j9rqxlAv/WQr0RDSaiHKIKJeIJhvsf5iINio/W4momoiaK/vyiGiLsi/b/exC\n1Fg1+SLMv/8Ct/S3bhiA9Y9f7Jaurzu9d0Qn3JSVBsAxdfIdQx1TIOu7M06/+Ry0a+Y+4Ck2yhHM\nP70rS0v78PZBiIuO1Pa5ior07c8zOpKQGBsdssXEA10fvGjbUe8HwVqp1fUGn1tYYtqLB4DhlBBW\nRBBhysKduObtVXhg9ka8MH+HX+cBHHMuGZXoq6rtbssefrJmPwCg09/nO6Xr++BnPvcDHv/avGdT\nbXgN9EQUCWAqgDEAegIYT0Q99ccw88
vM3J+Z+wN4DMAyZtY3b49U9psPcRQCjsFIPdu4V5fERkW6\n9T8HnPvRPzK6O/q3dyx7aGfG45f3QO7zY5yqd0b1amW4Bm5stONPQV1MBIA2XYPZ435kBCHBpLRv\nJDoyAh1T3KuHrPA0xYRVdVVN4MpKiX7KohzY7axN3VxwusKvKQW8iSByGvV8RCmd+7MqVXF5lWEP\nmxMmT05GVTn67pVFJZX4ePV+n/NhhZUS/SAAucy8h5krAcwGMM7D8eMBzApE5oTwxnV6X7VkyOyo\nDoiKNP8Vf+uGAdp2nFJqVxvxmsZHawuFmPUYiYogvHWjYyUusyokp+MjI9C+mX+BvrZz+gDA9OV7\n/H7vNQPaolVjayOCXVm5weQWlOB/q/K0wPnZugNBmau9qtruNAJ48fajGPqvJT6vjws4SvRG00uf\nMHlic51WAgBKK+pPY2xbAAd0r/OVNDdEFA9gNIA5umQGsJiI1hPRBLMPIaIJRJRNRNmFheaPbEJc\nM7CtVi2jVt18frdjbhy1F4SVvsz6+lXXEr2+hGe23BtRzRnaN4/H+RnJ6NLCfDQvAYjz8ATQu615\nw2+sD6t49QtA6V/v7mEZeOaq3tp35KpPW/cnJD2rVUbPfr9d2z7iw+hSX6wxWJkr/0SZX4OtNuWf\nxM7D7uMSjEYVmymrqj/dK43uNWbfylgAv7hU2wxh5oFwVP1MIqJhRm9k5unMnMnMmampqRayJc5W\nr13XH09c4ag9VAOyOspWrTf3tYSklpjVEr3+F1wttF3QJQVz7hmsTc2gn/mSGZg1IQs/PDjc9DPK\nqqo9lswHpjVzen3fyM7a9g2D0ixdBwAkxvq/tKOREd1aIDE2yjTvSXGeB9jX1SpKteFPoH/ym214\nx+ApyZeJ5krUEn09GDCVD6C97nU7AGZL71wPl2obZj6k/FsA4Cs4qoKECAi1J40a6Ef1bIWbszpo\nNwIjRoFH7fWhBXqDY/5+WQ+c06EZLuvTGoBj3pYB7ZshKTYKk3RB2cyZSpthw65a3aSeV3XDeWla\n3XxqUix+fNCwjKS58bw0jB/UHsO7BragpObPrFHa0+jnGA9VZ/WJvwOiakvtdePPXPa+sPK/sA5A\nFyLqSEQxcATzb10PIqImAIYD+EaXlkBESeo2gFEAtgYi40IAwB1DOyLvpcu1bpcxURF49qreSDFo\nuLUiWnkiMBqyr64ylKwEtuLyKjSJj8aWf16KQR1r+tT/MbO90/vUBsZz0poblorVSeCiIyPQWjcz\nZlQkae+1MyM10XMdeXpyAl68pi/uGJrh8UbniVHpXC3txplU3Vzcw31BGVW75tamdA41fxpj9UZ0\n8+/mqrZfFIc60DOzDcB9ABYB2AHgc2beRkQTiWii7tCrASxmZv18oy0BrCSiTQDWApjHzAsDl30h\nfHeeEpQ7JMfjusx2TvvUEr0+cKtBX70JXHtOe/zhnHa478Iuhud//IoeTq+zMpoj76XLkZYcb1jP\nrfbTj4mMcOohFBsZiWbxjptKBBGaxEfj5qwOptelNihGRhBuzurg08Lr024aiM1Pj8KPStWT/kap\nBiOzuX7OTW+O2ROyDPd1SvU+A2l9UNsSfcsk/xqqtx92Hi/wl4uNf6dqy9Lslcw8H8B8l7RpLq/f\nB/C+S9oeAP1qlUMhAuxPg9NxUY+WaN88HlOu7Ycp19b8isZGRWLhXy5wWqdWLdyrPXgaxUTi5T+Y\n/1rHx5j/WRlVf2gl+ihCY11wToqLwpRr++KL7AMYmNYUgOMmsreoFCsNlujT1zPHREVg01OjcM3b\nv2DD/pOm+VGN7u1cbdS7bWMszXF0ijijDCoyq7pp27SR6aRznVIT8QOs9bc3MrZfG3y3yXyR9kAg\nAu5xmZ7YV1cNaIvPsg94P9CL2i6kY6ZhVKAJEUBEZDjdgap7q8ZOwVrtdWP1bzDSw4H66o9bB6fj\n8j
6tnebv76UbQxARQWieEIO7h3fSAmlsVCQ+vvM8jO1XM+fPpJGdABhXP/haUG0cF40595yPt24Y\nqKVVK63R94zo5Hb8Py7r4ejGanLN/dt77pHjyTPjeuGla/q4pb/nYcUxfzBDm8/IX+d3SnZbcCfR\nZJnMGX/KxM5nR7steg/ULBgfaBLohfBCLdH70jPip4eGY9pNA93S9Y2TT1/ZC1NvHIhGSqAvr7Lj\nsTE93N5jRD9vjzqvi1FfcH0/74yUBOx6bozXc5/ToblTkLq8j+Omcm56c6x+7CKnY+8aluHIg0Gg\nv2dEJ62R3B8pibFIMAiW+tXH6gP1/9R1wZ3M9GZGhyM2KhJx0ZGGVWFSohciRLIyHGvhNvJhFGyn\n1ETDIKWWzBvpgsJfL3HUy3ZMSbA8533T+JqeLurALqMS/dHTNfPRlFVV+zWnvv49ZssoGgWoR0d3\nd8qnr8wafz2tYPa3UV39/jx/mXUvTTXpEKBeV7RBjyRPT4O1IStMCeHFlGv74r4LO/vUuAk4+sV3\nbZmIR0d3d0qfdtNA9NDNinlh95bIe+ly7fWsu7LclqzzxFOJXj+vjrpC05KHhjtNnRuIUZlmAapZ\nLQK9UZvAU2N7mg7cAuDUxuGPfu2botJmx7GSChQUV2Bkt1T8nON5AGeiSaA3+31Rr2tvkfs89b7O\nn2SVlOiF8CIuOtLjAuVmEmKjsPivw9G3XVOn9NG9W6NDsvmUCed3SsYwH/rCn9PBUUXgOuAKAD69\n8zytp05zpVtoRmoiMlIcn//U2J5Y+rcRlj8LgLbql34q6CiTumWrN8cWSbHaaGeVa4k+IyUBtw3p\n6PQ09P2fhzodo58fPt2PKp5LerTAggcuwN8u7QYAaJ7gveop0uROqT4B3jm0I6bq2jzU6zJa30BK\n9EIIQ4M7p2DDE5dogVzvvIxknJeRjIEdmmJQx2QtXe1OOKJbC9ObzsThnTBA6e2jN+eewdicfxIX\n92ippcVGRYAIGNwp2Wkt2MgIQqvGcdqUBt1bJRku3D3nnsGIiiRsPHASuQUlOFVW5Vain6fMaqqv\nuundtgn+b2A7zNngmM1SH+h/eHA4np+3A8yMD3QTmXlSVOJ4AmKtq6qltxlS8xkVGeHUrmDWewmQ\nOnohhAu1ZA3AMMjrXT2gndPx6pw8CR6mS5g8pjsu7eU+536rJnEY1auV1jYAONoJdjwzGu/ekomn\nxvbEp3edp+2bc+9grfHYqI1g9oQstG8ej9ZNGmHOPYPRLN7xFOBaF6+WkF3rtl+9rh/+cI5jPIQ+\nUEZHRuDpK3uhTVPrg7ZuUp5+1OYO/Y3j3VsyMeVa91XI9AX6WwenAwAu7N5CGx9RXlXttKaBWdsD\nELxeN1KiF6KB+umh4X7N0QIAb1zfH5vzT6GFnwN9jKiB+bYhzlUwbZs2Qh+X6itV84QYrbFbVe0y\nQM0K9XuIMCgRqwO+xvZrgyqbHQ+P7oaLlDaKyWO6a2vEntOhGTorN0B1kJx+fMCg9OaIjY7AI19u\ndjq//pinxvbE01c6bmqz1zqmHC6rrHZqOwhFiV4CvRANlKfeJ94kxUVjSOeUAObGGn0YW//4xYYl\nfLVHqLf1XUfqph2o9lDVUm5zBPrurZK0OYn+d+u5iIggbToLwHnai5oSfc15oiLJ63TR+qCfplTX\npCXHOzVKe2pMljp6IURYmHvvYJworTRcSAYA7rygI/753XYkJ5pXR+U8N9qpAVgr0RvcHMqVEr3+\nxjhS15D85vgBuH/Wb05BX62jjyACkWMsRVQkGY4ANgvNgzulYPaELGR2aOYUwD3dLKREL4RosFhX\nWjbqHaR325CObtU/rlyrP9TSuFHvH3XaYLO68bF9W6O4vAqX62YPtdtrRkM/OtpRvRPtcu5vJg3B\nuKm/eOye6lotBbivj6snJXohRIOlBmBfBp35olG0I5TFRUfgzfED
sE63wIi6AHgjk6ouIsKN5zlP\nFqdW3RARJg7vhInD3ad/0NY+sDhi+rv7hnps/AbgU8OxLyTQCyGCrnfbxnjwkq7447ntvR/shyfH\n9kTHlHiM7NYCERGEK3VzAV09sB2+3ngI56Y393AGZ3Y2rwoCgIzUBK30bXXAWR+DtYoBR9VRRVU1\nrujbJmg3Qgn0QoigIyLcf5F/U/DePqQjGjfyHKqaNIo2nTZ6eNdUp5HHVrBBY6xq+zOXIjKCsLeo\n1H2nD27KSsPHq/c73ZSCRQK9EKJee3Ksf4uo1IZWojeI9OrMpgbrgvvkuav64Lmr3GfnDAYZMCWE\nEC4u7ukY9Tuuv3lpW626iQ9SdUsgSYleCCFcdEpN9Frd07VlIh64qEvQ2h0CSQK9EEL4gYjw10vq\nflpkf0jVjRBChDlLgZ6IRhNRDhHlEtFkg/0jiOgUEW1Ufp60+l4hhBDB5bXqhogiAUwFcAmAfADr\niOhbZt7ucugKZr7Cz/cKIYQIEisl+kEAcpl5DzNXApgNYJzF89fmvUIIIQLASqBvC+CA7nW+kubq\nfCLaREQLiEhdudjqe0FEE4gom4iyCws9L90lhBDCOiuB3miAr+sk2BsAdGDmfgD+A+BrH97rSGSe\nzsyZzJyZmmp9GTUhhBCeWQn0+QD0HUXbATikP4CZTzNzibI9H0A0EaVYea8QQojgshLo1wHoQkQd\niSgGwPUAvtUfQEStSJl7k4gGKec9ZuW9QgghgstrrxtmthHRfQAWAYgEMJOZtxHRRGX/NADXAriH\niGwAygBcz44JqA3f6+0z169fX0RE1lbzdZcCoMjP9zZUcs1nB7nm8Feb6+1gtoP0CwKEAyLKZubM\nUOejLsk1nx3kmsNfsK5XRsYKIUSYk0AvhBBhLhwD/fRQZyAE5JrPDnLN4S8o1xt2dfRCCCGchWOJ\nXgghhI4EeiGECHNhE+jDdTpkImpPRD8T0Q4i2kZEDyjpzYnoByL6Xfm3me49jynfQw4RXRq63NcO\nEUUS0W9E9L3yOqyvmYiaEtGXRLRT+f8+/yy45r8qv9dbiWgWEcWF2zUT0UwiKiCirbo0n6+RiM4h\noi3KvjfVQaqWMHOD/4FjMNZuABkAYgBsAtAz1PkK0LW1BjBQ2U4CsAtATwBTAExW0icD+Jey3VO5\n/lgAHZXvJTLU1+HntT8I4FMA3yuvw/qaAXwA4E5lOwZA03C+ZjgmONwLoJHy+nMAt4bbNQMYBmAg\ngK26NJ+vEcBaAOfDMYfYAgBjrOYhXEr0YTsdMjMfZuYNynYxgB1w/IGMgyMwQPn3KmV7HIDZzFzB\nzHsB5MLx/TQoRNQOwOUA3tMlh+01E1FjOALCDABg5kpmPokwvmZFFIBGRBQFIB6OubDC6pqZeTmA\n4y7JPl0jEbUG0JiZf2VH1P9Q9x6vwiXQW54OuSEjonQAAwCsAdCSmQ8DjpsBgBbKYeHyXbwB4BEA\ndl1aOF9zBoBCAP9TqqveI6IEhPE1M/NBAK8A2A/gMIBTzLwYYXzNOr5eY1tl2zXdknAJ9JanQ26o\niCgRwBwAf2Hm054ONUhrUN8FEV0BoICZ11t9i0Fag7pmOEq2AwH8l5kHACiF45HeTIO/ZqVeehwc\nVRRtACQQ0U2e3mKQ1qCu2QKza6zVtYdLoA/r6ZCJKBqOIP8JM89Vko8qj3NQ/i1Q0sPhuxgC4Eoi\nyoOjGu5CIvoY4X3N+QDymXmN8vpLOAJ/OF/zxQD2MnMhM1cBmAtgMML7mlW+XmO+su2abkm4BPqw\nnQ5ZaVmfAWAHM7+m2/UtgD8p238C8I0u/XoiiiWijgC6wNGI02Aw82PM3I6Z0+H4v1zCzDchvK/5\nCIADRNRNSboIwHaE8TXDUWWTRUTxyu/5RXC0QYXzNat8ukaleqeYiLKU7+oW3Xu8C3WLdABbti+D\no0fKbgD/CHV+AnhdQ+F4RNsM
YKPycxmAZAA/Afhd+be57j3/UL6HHPjQMl8ffwCMQE2vm7C+ZgD9\nAWQr/9dfA2h2FlzzPwHsBLAVwEdw9DYJq2sGMAuONogqOErmd/hzjQAyle9pN4C3oMxsYOVHpkAQ\nQogwFy5VN0IIIUxIoBdCiDAngV4IIcKcBHohhAhzEuiFECLMSaAXQogwJ4FeCCHC3P8Drf4ykaTI\nUXkAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "losses = []\n", "opt = SGD(0.1)\n", "opt_state = opt.init(task.init(key))\n", "\n", "num_steps = int(os.environ.get(\"LOPT_TRAIN_LENGTH\", 1000))\n", "for i in range(num_steps):\n", " batch = next(data_iterator)\n", " loss, grads = value_grad_fn(opt_state[0], batch)\n", " opt_state = opt.update(opt_state, grads)\n", " losses.append(loss)\n", "plt.plot(losses)" ] }, { "cell_type": "markdown", "metadata": { "id": "uCOBWVkeq2qD" }, "source": [ "Now, let's define some other optimizers. Momentum makes use of an additional accumulator variable. We can define it as follows." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "executionInfo": { "elapsed": 52, "status": "ok", "timestamp": 1647716628264, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "h70Uo7TB89zk" }, "outputs": [], "source": [ "class Momentum:\n", "\n", " def __init__(self, lr, decay=0.9):\n", " self.lr = lr\n", " self.decay = decay\n", "\n", " def init(self, params):\n", " return (params, [jnp.zeros_like(p) for p in params])\n", "\n", " def update(self, state, grads):\n", " params, momentum = state\n", " momentum = [m * self.decay + self.lr * g for m, g in zip(momentum, grads)]\n", " params = [p - m for p, m in zip(params, momentum)]\n", " return (params, momentum)" ] }, { "cell_type": "markdown", "metadata": { "id": "UuLD7NdxrHC_" }, "source": [ "We can use this in our same training loop again. Here, the parameters are stored in the 0th entry of opt_state." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "height": 282 }, "executionInfo": { "elapsed": 1342, "status": "ok", "timestamp": 1647716629720, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "-suAcnqC9QpH", "outputId": "892958ae-bf4f-4083-800e-a32a7d82f2b8" }, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90\nbGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsT\nAAALEwEAmpwYAAA2zUlEQVR4nO3dd3zU9f3A8df7MkgChJWwN4KAyjIiiIiICjiKo7XYVpxFrf5a\nd9FWrau1tY46qqWuqlWKFUdFEVy4QAlL9gYJK2EFEiDrPr8/7r6X712+t5K7XHK8n49HHtx9x93n\neyTv+3w/4/0RYwxKKaWSlyvRBVBKKRVfGuiVUirJaaBXSqkkp4FeKaWSnAZ6pZRKcqmJLoCTnJwc\n071790QXQymlGo2FCxfuNsbkOu1rkIG+e/fu5OfnJ7oYSinVaIjIlmD7tOlGKaWSnAZ6pZRKchro\nlVIqyWmgV0qpJKeBXimlkpwGeqWUSnIa6JVSKsk1yHH0tVFe6eadxdsoKatkYJcWDO7SCpdLEl0s\npZRKuKQJ9Kku4cGZKzlwpBKALq0zGdCpJZcN78awnm0SXDqllEqcpAn0Lpfw0c2nsa+0gpU7DvD6\nt1uYuWwHM5ftoFPLTK47vRcTT+pCWoq2Vimlji7SEFeYysvLM7FIgbBlTymvztvC/77fzq4DZQA8\n87MhnHpMDi2y0ur8+kop1VCIyEJjTJ7jvnCBXkS6AK8A7QE3MNUY87eAY34O/Nb7tAS43hiz1Ltv\nM3AQqAIqgxXELlaB3uJ2G372/Hzmb9wLQGZaCovuPovM9JSYvYdSSiVSqEAfSTtGJXCrMaYfMAy4\nQUT6BxyzCRhljBkAPABMDdg/2hgzKJIgHw8ulzBt8nCuGtEDgMMVVfS/dxZ/eG8F7y7ZlogiKaVU\nvQnbRm+M2QHs8D4+KCKrgE7AStsx39hOmQ90jnE5Y+L35/bj2lE9mb9xD//8ciMvf7OZl7+Bk7q3\npmPLzEQXTyml4iKqnkkR6Q4MBr4NcdjVwIe25waYLSILRWRyiNeeLCL5IpJfVFQUTbEi5nIJ7bIz\nmDCoE+/8agQje+cAMP5vX/LJql1xeU+llEq0iAO9iDQD3gJuMsYcCHLMaDyB/re2zSOMMUOA8Xia\nfU5zOtcYM9UYk2eMycvNdcydH1OpKS5evfpkpozvS/HhCq7+Vz57S8vj/r5KKVXfIgr0IpKGJ8j/\n2xgzI8gxA4DngQnGmD3WdmPMdu+/hcDbwNC6FjqWrhvVi8cuGQjAkAfmsG3/4QSXSCmlYitsoBcR\nAV4AVhljHgtyTFdgBnCZMWatbXtTEWluPQbOBpbHouCxdMGgTr5mnMdmr6X4cEWCS6SUUrETyfDK\nU4EvgWV4hlcC3AV0BTDGPCcizwMXA9ZSVpXGmDwR6YmnFg+ejt/XjTEPhStUrIdXRsIYQ487P/A9\nn3PzafRu17xey6CUUrUVanhlJKNuvgJCJo0xxlwDX
OOwfSMwMMJyJpTnxqXa4q37NdArpZKC5gOw\nWXn/WN/jO/77Pa/O25y4wiilVIxooLfJSk/lD+f3p0WmJz3C3e+uSHCJlFKq7jTQB7hiRA/eun64\n7/lr87eEOFoppRo+DfQOjmnbnL7tPe3zv39nOQeP6CgcpVTjpYE+iIkndfE9Xrq1OIElUUqputFA\nH8QVI3qw9J6zEYGFW/YlujhKKVVrGuhDaJGVRnZGGo9/vJY9JWWJLo5SStWKBvowfn5yVwB+9PTX\nXDp1Pg1xoRallApFA30Yd4zry5n92rJt/2HmbdxD4UGt2SulGhcN9BE4UuH2Pd60uzSBJVFKqehp\noI+AfcnBzRrolVKNjAb6CDx0wfH87px+pKUIm/ZooFdKNS4a6CPQNjuDX57Wk1ZZ6fxj7kaKD+kE\nKqVU46GBPgpjj2sPwBxddlAp1YhooI/CfT86DoDb3lxKSVllgkujlFKR0UAfBZerOmf9mp0HE1gS\npZSKnAb6KP3rKs+Stxc/+w2Hy6sSXBqllApPA32UOrXM8D3euu9QAkuilFKR0UAfpZxmTXyPn/h4\nbYgjlVKqYQgb6EWki4h8JiKrRGSFiPzG4RgRkSdFZL2IfC8iQ2z7xonIGu++KbG+gPrWMiudFy73\nrL9beEDTISilGr5IavSVwK3GmH7AMOAGEekfcMx4oLf3ZzLwLICIpADPePf3By51OLfRGdOvHecN\n6MBuzWiplGoEwgZ6Y8wOY8wi7+ODwCqgU8BhE4BXjMd8oKWIdACGAuuNMRuNMeXANO+xjV6X1lls\n23+YUh1mqZRq4KJqoxeR7sBg4NuAXZ2ArbbnBd5twbY7vfZkEckXkfyioqJoipUQpx6TQ0WV4X9L\ntye6KEopFVLEgV5EmgFvATcZYw4E7nY4xYTYXnOjMVONMXnGmLzc3NxIi5Uwvds2A2DKjGXsKy1P\ncGmUUiq4iAK9iKThCfL/NsbMcDikAOhie94Z2B5ie6OX27x69M2Sgv2JK4hSSoURyagbAV4AVhlj\nHgty2HvAJO/om2FAsTFmB7AA6C0iPUQkHZjoPbbRExGevHQwABsKSxJcGqWUCi6SGv0I4DLgDBFZ\n4v05R0SuE5HrvMd8AGwE1gP/BH4FYIypBG4EPsLTiTvdGLMi1heRKD8a2JGcZums26WBXinVcKWG\nO8AY8xXObe32YwxwQ5B9H+D5IkhKvds257M1hZRXuklP1flnSqmGRyNTHY09rh2FB8vYtv9wooui\nlFKONNDXUY9cz+ibPTp5SinVQGmgr6M2TdMBeHS25r1RSjVMGujrqGPLTADmbdzD0q37dUy9UqrB\n0UBfR62bptMjpykAE575mp8/HzhpWCmlEksDfQzsO1Rdi1+5I3DSsFJKJZYG+hg4q1+7RBdBKaWC\n0kAfAw9eeDxjj6sO9lv2lOJ2O6b0UUqpeqeBPgaapKbw4AUn+J6PeuRzXvx6UwJLpJRS1TTQx4g9\nyRnAsm3FCSqJUkr500AfJ62y0hNdBKWUAjTQx40GeqVUQ6GBPobeuWGE7/HjH6+lRJcZVEo1ABro\nY2hQl5Z8+JuRvucfr9yVwNIopZSHBvoY69ch2/dYQiZ3Vkqp+qGBPo5Ky6oSXQSllNJAH0+l2kav\nlGoANNDHkXbGKqUaAg30cfDdXWMA+Gr97gSXRCmlIgj0IvKiiBSKyPIg+2+3LRq+XESqRKS1d99m\nEVnm3Zcf68I3VG2zMwBYuGUfX64rSnBplFJHu0hq9C8D44LtNMY8YowZZIwZBNwJzDXG7LUdMtq7\nP69OJW1k7hzfF4CrXl6AZ+10pZRKjLCB3hjzBbA33HFelwJv1KlESeLaUb0AqKgy7DxwJMGlUUod\nzWLWRi8iWXhq/m/ZNhtgtogsFJHJYc6fLCL5IpJfVJQczR1XnNIdgEPlOsxSKZU4seyMPR/4OqDZ\nZoQxZggwHrhBR
E4LdrIxZqoxJs8Yk5ebmxvDYiXOsJ5tABjz6Fze/367NuEopRIiloF+IgHNNsaY\n7d5/C4G3gaExfL8GLyOt+uO98fXFvL14WwJLo5Q6WsUk0ItIC2AU8K5tW1MRaW49Bs4GHEfuJKuM\ntBS/57tLyhJUEqXU0Sw13AEi8gZwOpAjIgXAvUAagDHmOe9hFwKzjTGltlPbAW+LJ+FLKvC6MWZW\n7Ire8AUG+hSXTltQStW/sIHeGHNpBMe8jGcYpn3bRmBgbQuWDOxNNwCpLs1yppSqf1rFjCNXQPrK\nwxU6+kYpVf800MdRYA2+5IjmvlFK1T8N9HHUM7cZf/3JQH4xrCugSc6UUomhgT7OfnxiZx684ARy\nmqXznwVbE10cpdRRKGxnrIqN3SXlANz3vxW4RLj7vP4JLpFS6mihNfp69tLXm3nhq02UV7oTXRSl\n1FFCA32C7D9cnugiKKWOEhroE2T/oYpEF0EpdZTQQF9PvvWuOmXZV6o1eqVU/dBAX0/aZWew+eFz\nfc/XFpYksDRKqaOJBvp6dsGgjgAsLyjmje9+YI8mOlNKxZkG+nr2xMTB9OuQzaIf9nHnjGXc+Pri\nRBdJKZXkNNAnQE6zdIq8Nflt+w8nuDRKqWSngT4B2jRN9426qajS8fRKqfjSQJ8Ag7q09D3WQK+U\nijcN9AnQM7eZ7/HuknJW7TiQwNIopZKdBvoEyEz3X3lq/N++BGBn8REdX6+UijlNapYAmQFLDFqG\n/ekTRGDTn8513K+UUrWhNfoECKzR2xlTjwVRSh0VwgZ6EXlRRApFZHmQ/aeLSLGILPH+3GPbN05E\n1ojIehGZEsuCN2bBavRKKRUPkdToXwbGhTnmS2PMIO/P/QAikgI8A4wH+gOXiogmYUcDvVKqfoUN\n9MaYL4C9tXjtocB6Y8xGY0w5MA2YUIvXSTqhmm6UUirWYtVGP1xElorIhyJynHdbJ8C+dl6Bd5sj\nEZksIvkikl9UVBSjYjVMTVI9H/tFgz0fR6eWmYksjlIqycVi1M0ioJsxpkREzgHeAXoD4nBs0K5G\nY8xUYCpAXl5eUndJiohfJssZi7fxxnc/+J5vLCqhfYsMstJ1UJRSqu7qXKM3xhwwxpR4H38ApIlI\nDp4afBfboZ2B7XV9v2TTJM3zX3DnjGW+bWc8Opf+93yUqCIppZJMnQO9iLQXEfE+Hup9zT3AAqC3\niPQQkXRgIvBeXd8v2TTPSAu6z5ox+58FP7Bdk58ppWopkuGVbwDzgGNFpEBErhaR60TkOu8hPwaW\ni8hS4ElgovGoBG4EPgJWAdONMSvicxmN19jj2gXdV+U2HDhSwW/fWsZlL3xbj6VSSiWTsI3AxphL\nw+x/Gng6yL4PgA9qV7SjQ592zYPuS091UVnl6a7Yq6kRlFK1pDNjE6x5RhoTT+riuK+yylDl9gR6\nlzj1bSulVHga6BuAhy8e4Li90u32BXrRQK+UqiUN9A1YRZXbl6/epXFeKVVLGugbiPbZGTW27Sg+\nQrkv0GukV0rVjgb6BmL+XWN4/ZqT/bbd+PpiSssqAUjRKr1SqpY00DcgpxyTU2Pbo7PXAp5FxN3u\npJ4wrJSKEw30DdzctdV5f9YWHkxgSZRSjZUG+kak5EhloouglGqENNA3IqXlVYkuglKqEdJA34hY\nHbNKKRUNDfSNyOIf9iW6CEqpRkgDfSPyr2+2YHT1cKVUlDTQN1ATT+rCB78eydWn9vBtK69yU3iw\nLIGlUko1RrqEUQNl5b/p274fL3y1ybf95D9+wpi+bWmWkcrfJg5OVPGUUo2IBvoGZvq1w1mzq3q8\nvMthRuwnqwsBNNArpSKiTTcNzNAerblsWLeozvnvwgKKD1fEqURKqcZOA30jt3xbMbe9uZTfvb0s\n/MFKqaOSBvpGbv8hT01+T4muQKWUcqZt9I3Yq/M2s9sb4Juk6Xe2UspZ2EAvIi8
C5wGFxpjjHfb/\nHPit92kJcL0xZql332bgIFAFVBpj8mJUbgXc/W71WusZqSkJLIlSqiGLpBr4MjAuxP5NwChjzADg\nAWBqwP7RxphBGuRjI69bK07rk1tj+4odxbw6bzMAs5bv4NHZa9iyp7SeS6eUaojCBnpjzBfA3hD7\nvzHGWHPz5wOdY1Q25TXimDa+x73bNWNYz9Y1jtm69zB3v7uCd5ds47rXFvHUp+s587G59VlMpVQD\nFeuG3auBD23PDTBbRBaKyORQJ4rIZBHJF5H8oqKiUIcedf59zTBeuWooAH3bZ5OZFryZ5jfTlvge\nV1RpugSlVAwDvYiMxhPof2vbPMIYMwQYD9wgIqcFO98YM9UYk2eMycvNrdk0cbQ7rU8u794wgknD\nu4UM9EopFSgmgV5EBgDPAxOMMXus7caY7d5/C4G3gaGxeL+j1cAuLRER0lN1hI1SKnJ1jhgi0hWY\nAVxmjFlr295URJpbj4GzgeV1fT8V3ULhbrfRjJdKHeUiGV75BnA6kCMiBcC9QBqAMeY54B6gDfB3\nEYHqYZTtgLe921KB140xs+JwDUcd72cakRP+8BFd2zTlw9+MjGOJlFINWdhAb4y5NMz+a4BrHLZv\nBAbWvmgqmCgq9JSWV7Fqx4H4FUYp1eBpY28j5LLV6L+9a0wCS6KUagw00DdC3ds09T1ul50R1bmP\nzVlL9ykzY10kpVQDpoG+EerfMZsXr8jjrz/xtIzddnYfv5WonLz41SZWbj/Ak5+sC/v6h8or6XXX\nB3ywbEdMyquUSixNatZIndG3ne/xjWf0Zl9pud9KVIHuf3+l3/MNRSWUVbjp3zG7xrG7DpRR5Tb8\nedZqzjmhQ+wKrZRKCK3RJ4los1eOeXQu5zz5Jbe/uZTCg0f89mV4X2vLnkPMWFQQ8nX2lZbzzuJt\n0RVWKVWvNNAniSa27JXTrx3ON1POiOi8NxcW1GjOqXJXj7u/ZfrSkOff+MYibvrPErbuPRRFaZVS\n9UkDfZKwJlFdkteZoT1ak9OsScTnrtl50O+5PdADLN26n9KySsdzdxR77gbKKt3RFFcpVY+0jT6J\nrH5gHOkpnu/uaNIkLNi8z+95YKCf8MzXjD42l5euDJXBQmffKtVQaY0+iWSkpeCyzabqmdM0xNHB\nuR1SJizeut/x2CjmbimlEkQDfTKrZRSucmiFCZYux0rHoOl0lGq4NNAnsWji/PrC6nb6wKYbIGhi\nNOs9NM4r1XBpoE9iriiSn5352BdsLCoBnJtuwgVypy8HpVTDoIE+iUUT6AGWeNvhHYN20KYbgp8T\nZ9v2H+ZQufNoIKVUNQ30SezE7q2iOv6W6UspPlRBlUON/mBZZci89pUJCPQjHv6USS98V+/vq1Rj\no4E+id17fn9m/vpUAFo3TY/onIL9h1i/q8Rx36vzt9TYJt5W+ip3YsbR52/ZF/4gpY5yOo4+iTVJ\nTeG4ji344vbRZGem8n1BMZNeDF0DPvfJr4Lue3XeFvq0a86wnm1826ymm09XFzKka6uoFkVRStUP\nrdEfBbq2yaJlVjqn9anbouvrCkuYOHW+475nPtvAWY9/EXQG7eHyKrpPmcmb+VvrVAaLLo+oVOQ0\n0KuoHThS4bh9fWEJN7y+yPd8R/FhnvlsPcYYdpeUAfD4nLWO50ZLB/koFTkN9EeZ9tkZnN2/XfgD\nQ7hzxrKgNeqv1u32PZ78ykIe+WgNP+w9RGqKp0knVp22OpxTqciFDfQi8qKIFIrI8iD7RUSeFJH1\nIvK9iAyx7RsnImu8+6bEsuCqdubfNYapk/LY8Mdzav0aM7/fwQfLdlJ8qKJGm3xmenUWzaKDnlp8\nWorL12mrgV6p+hdJZ+zLwNPAK0H2jwd6e39OBp4FThaRFOAZ4CygAFggIu8ZY1YGeR1Vj1KiWWHc\ngb2Jxu7gkUo+W1PIse2as/OAJ7NlldtgvFW
KCqf8CrXgNAQUoPhQBc0zUv1y/ih1tAtbozfGfAHs\nDXHIBOAV4zEfaCkiHYChwHpjzEZjTDkwzXusamCGdm/te/zAhOPq/HpXvrSAUx7+1Pd85F8+4/2l\n2wGorKoZoJds3c9nqwtrbH9v6XZumb7E8T3sNfqLn/2GFduL2VtazsD7Z/NEBMslKnU0iUUbfSfA\nPpSiwLst2HZHIjJZRPJFJL+oqCgGxVKR+s+1w7hyRHdm3TQybu/xpw9XA3C4ooorXvqOwgNHmLV8\nJz9+9hsueOZrrnx5AcWHKlixvdh3zq/fWMyMRc6rV9kD/cIt+3jw/VW+Dt8Pda1bpfzEItA73SOb\nENsdGWOmGmPyjDF5ubl1GwaooiMi3Hv+cfRtn01RSblv+8VDOnNsu+Yxf7/P1xTx9883cN1rC/0m\nPP34uW9CjuO3C2yjN5i4ZNCschvW7joY/sBGavm24hpLSarkE4tAXwB0sT3vDGwPsV01YFec0t33\n+IbRvcjOjM+cupe/2Vxj27pCz4zcdxZvo/uUmb7tTiN8AhOv2Z/Gcs7WEx+v5ezHv2Bdkgb78576\nijMfnZvoYqg4i0Wgfw+Y5B19MwwoNsbsABYAvUWkh4ikAxO9x6oGJHAlKnuqhBSXJGSm65Of+rex\nlzt04DqmUo5DsuRvN3q6p/aUloc5MjpVbsMf3lvB9v2HY/q6tXHgSCXrC0t49vMNiS6KipOw1TUR\neQM4HcgRkQLgXiANwBjzHPABcA6wHjgEXOndVykiNwIfASnAi8aYFXG4BlVLS+85G1eIr3qXCE6D\nV645tQcGeOGrTXEpV+Ds2vJKN01SU3C7DUsK9jOka6t6G15pfcmkpdT8oB6fs5ZKt5vBXVpxpndu\nQvHhCtbuOshJtg5uJ99u2sPL32xmfWEJr11zcuwLHqWf/mMee0rLufyUbmSla2aUZBP2f9QYc2mY\n/Qa4Ici+D/B8EagGqEVWmuP24T3bMG/jHsA51fHvz+vPc3PD1/5uPasPj9ZiJuzBI/6BvqzSzbQv\nNlJ8uIKnP1sPwNu/OsXvmGXbivluk6f2LTFc4NAaDpoeEOgPHqngb7bRPZsfPheAq19eQP6Wfax9\ncHzIdXutHHBOuf8ToVTTPSc1/epWNTx32Yl8tHwnXVpn+cbbXz68G7NW7OSmM/sAcFzH7LCv07uW\nHbmHyqv8nn+1bjcPfbDKb9tfZq2pcc4979a8YZz8Sj6dWmVy7/m1GzZqBfrAO58T/jDb8fhl2zyj\nhirdbtJDtIxazUzRrhkQL9b3jc5DS06aAkHV0CIzjUtO8vSjW230p/dty7d3ncmlQ7sCMLJ3Lh/8\neiQTBnVkZO8cx9dp1iQ29YiDDonSrDsOJ/bYOXvlLl76ejMAb+ZvZcaigqgWK6nwjvuPNAuzFbjn\nrNwVsv3dCqi1ifPrCw+yde+h6E+MgM44Tk4a6FVI3hQ1jiNf+nfM5m8TB/Pq1c5tzGkpsamt/rCn\nNKrjN+0upd/ds3gsoNno9v9+zy3Tl9L/no8wxnA44M7BbsueUrpPmcmm3Z73DjYTN5AVuH8zbQnn\nPvll0OOszzOws7v4cAU7i0MPdzzzsS8Y+ZfPIipPtNwJCPTfbdrLWY/N5UhFFQ9/uJqrXl5Q72VI\ndhroVUhWDTVc5oJ7zuvve9y8SSqThnerc5oFyz+/jK7Tt6zSzeGKKp4MMUP2759voN89s9jnHU2z\nde8hv/QM1mgbS6QLq9iveN8h5yyfUN1UEvgJjfzzpwz70ycRvZfdNf/K9xuSGi0rvEf6hRaJTbtL\nefrTdWFTSt///grWFZawdtdBnpu7gU8dZkk3BMYYCg80zjkHGuhVSOcO6ABAn3bNQh531ak9ePnK\nkwAY1LUl9084PmaBPh4e+cjTxr+7pIwv1xUx8i+f+aVQ/nZTYKCP7HUDa+jWbN1AV3prrYEf0YEj\ntesU/Xj
VLoCwdwPhxLJGf9kL3/LX2Wt9Q1Nnfr+Dgw4prlO8HSCJWI4yGi9+vZmhf/yE9YXOK7A1\nZBroVUgXDenMuofG061N07DHtszyjMHf763JpoYau9mAXOZdd3bVjgOApyb61qICv2MibbsObHNf\nVlDsfKCXS4TNu0u5/c2lMUn49tOp82psu/k/S5j5fZi0EN7LK69ys3BLqNRWkTtS4bket/HMLr7h\n9UVMeWtZjePSXNZdY8MO9PO9/UIbi+oe6PeVltdrM1nj+EtUCeU0htxJjxzPl8FP8joD1SNV0lKE\nnGaRrVkbL8GaDwyQmeZJrdyvg2ckkVNnbaTDIANH0YTrbBWBW99cypsLC/i+YH9E7xHKlj3+nbRu\nt+Htxdscs406fSaPz1nHxc/OY9EPdV+L17pbqXIb39yIAocOauvOzynhXbSOVFTFraPaGi5bVlm3\nL+S9peUMfmAOj85ZE/7gGNFAr2KmRWYam/50DpOGdweqa/TNmqT6jUN/flJe2Ne62TuMM1YqggSR\n4sMVHK7wdMpa7dNOdyL7D1Xw6epdbAnTMRwY2K3A/8xn67ll+hLueXe53yLrIkKlNYQzDkMt9x6q\nOaP33SWeFBM/OARE666mrk1AUP1ZlAcExneXbGPW8h0cqahifWGJL9DHokZ/7asL49ZRbf0OB15P\ntPZ4m/M+WrGrzmWKlI6jVzFlb6O2/oAz01JI89aGfjSwI2P6tQ37Om1ifAdw+3+XOm7/aPlO3+Oy\nCjezlu/kutcW1jjOXiPe6LBoS1llFZ+sKqyRWM36DKw+gUBCddt0YPPK9wX7ads8g/YtMqrLu6K6\nvAu37KPoYBljj/NfMWzGogLOHdCBVJeLXd7Ow5a2yXH/86aMXrn9gG9bYPqIYEHXGBNxWgzri8sv\nMBrDb6YtATyd9vahs5WRjmENYe5aT+bbKreJeR9RE+/vsFNKjmi4g3TEx5MGehVHnt/ojPQU3x/9\nr0b3iihQDOnaKqYleXeJcz69dYUliEB2Rhovf7PZMdlaIKeRKf9buoPb3qz5ZRLuSl0ivqD6fEBK\niR89/TXpKS7WPjQe8KSGuPbV6i+hS/4xzzEg3zJ9KbdMX8pFgzsx0TvvoaktrYEvADsELOu/xqmp\n6uqXF/DJ6kL+fc3JjDim5tyJL9YWUVbp5ixvOgjrfcoq3b7/c/tnFzg/ojY1+pKySpqmp/DD3kN+\nd0QVVW72lFYw9KFPmDZ5GAu37GPc8e3plRt6UEEoVtNNXWr03afM9E02rM+5ctp0o+LG6ozLTEvx\n3fZa7bCL7j7Ld1xmWgqXDevmd26fds3Y8Mdz+GleF+Jp7toicpo1iSpLp1OHolOQB1hfVMKMgI5d\nu1krdoZs/y+vcnPCHz5i1vIdviYmS7jAOGPxNl8tedv+wxRbneQp1QHYElgEp/byT7zDHn/+/Lc1\n9h2pqGLSi9/xy1fyAdhQVOLr67B/oVRUBi9ztKNu9pWWc/y9H3HeU18x6pHP/XIvHTxSySJvCuyn\nPl3HIx+tYeLU+YCnM/WGfy+KOmDHqulmhfdOqj5nRWugV3HTtnkTAC4c3MnXdGP90duzZKalCNef\n3sv3fGDnFqSmuEhxCX/+8QDf9rd/dUpcOnX3lZaTkZoS/kCvwBE5odzz7gpume78JWAJF98OHqnk\nwZmratVuvnRr9aifnz3vCXTWcMY7/vu9b58VZK3YE248/caiEl8zyfyNe+h79yzfPmMMYx6d65tH\nUF7p9nX8hhpZ9PHK6Nqs95R62rqtwGnvzL7o2a99dxFfr/eMljninSA3ZcYyZi7bwULbWgiRCPwd\njlawAQEHjlTwyar4ttdroFdx0zY7g5X3j+XqU3vQxFsbqnCoDV0xogcdW2byp4tOAKBve+c8Oj1y\nmpL/+7Mc99VFpduQkRZ5oI81p3HZgUHhSIWb856KbFEWuz/PWu17bAXE1
AjarsPdLZzx6Fwuf9Ez\nLHXeBv90FD3u9M9jWF7p9t21hAqSby4M/gU64emv+FNAvqPAJkD7HcHWvYdrNptZzVLe45YW7Ofp\nT0MvO7noh30UH/Z8YVk1+tqOugl23k3TlnD1v/LZFseU1RroVVxlpaciIvzp4hM4u387BnVtWeOY\nW87yjLCxRp+kBkmd0DzDOdtmLGQmMNBbOto6XQODQrCJV9E6/6mveHux8/KMUP1lUFZRMz1EsM8o\nXINLeaXb1xQU6VyBwC+6pQXF/OOLjZRVVpcrJSDQB44iCtYXZN2tPPzhav46O3h21YoqNxf9/RsG\n3jebkrLK6qaWCIba/rDnEL9/Z5nfF2Zg+m3r9axx+U6feaxooFf1olduM6ZOyqNJiCaScm8wCBy3\n//7/ncr/bjw1ZqMoTutTc6nK5hmJHZcwsncO7WyB/kic/uit7JrBWDFs36EKuk+Zya22ZienkVDd\np8zkrRA1cfDU4q2AF2yYa6C1u5wnJb06zz401X/f/oCUE4G/LtbTwIlKe0rK6D5lJm9894N/uW1f\ntpt3l/rezzp79c4DNYK35f+mLea1+T/4fd6BWVlFPH0Zm71zH+K5yI8GepUwn9w6iq9+O9r33Mon\nE9i0cHynFpzQuUXQ17HnfT+zX7ugx1msJiK7RAf6JqkuvxpzYFCob1au/bcWFfCfBZ4A2Kapc/9I\nuCaHooNlvmaVkiCBMZDVLBTI/gUYrm8jWA0/sNPXOm7adz+wde8hbp2+lPJKd41OV5etQn/Jc/MY\n98SXDL5/Dmt21lxm0voysf8mB3ami8AY2zKO8eya1UCvEqZXbjM6t8ryPT/nhA40a5LqGxIYqSb2\nyViXV0/GmjS8G78e07vm8QELglw2rFtcm4UikZ7qIiu94QR6u996Rxm18Ka4iPbOasnW/b6abagR\nKwNtX+bBmnjsMTpcorn7/rfScfsK2/wB+2sa4PfvLOetRQVMW/ADgx+Y43dcdc5+w3ebPWkiyqvc\njH3iC991PfXJOvrdPcs3L8FeSQ/8Pw0cdRPPhAga6FWD0blVFsvvG8sxbSMb6/zW9cOB4G36vzu3\nH71yPWkZrJp+dkYqrbOqa6bL/nA2D1xwfNjhlZlpKfxyZI+IylUb6Skuvw7hbzbsjtt71ZZVS+3Y\nMiPMkf4OHqkIOmHMzur0tBw4UsGfZ632C/rGL9BHVQyq3Ibnv9xYY7vVUby3tNxXu7cmlVme/XyD\nL3GcU0C2hpI+OmcthyuqfOsXfLdpr68NPvDLK/C3NtIMqbWhE6ZUo5Oe6qJddhMGdWnFwC4tuenM\n3lz5Us0c5qkul69NODsj1bfcH8C/rhpKeaXbV5NvkRm6Rp+WIpSUxa+WnZ7qwt4w4rRaViKt23XQ\nFxCjzUkT2GQRzOaAPD1/mbWa1+b/4Jc51T7nINoJViVlnmGqgaa85RlmWrCvugkq8LVnLqueteyU\njKy0vIqW1TenvvkLD85cxYMzV7H54XNrBPrApqVI+y9qI6JALyLjgL/hWeT7eWPMwwH7bwd+bnvN\nfkCuMWaviGwGDgJVQKUxJnyiE6VCWHnfWESEFJfw7g0jgh7nkupaVGAH76iADtm2zUPXUtNTXQzq\n0oI3nJuO/fTvkM0LV+Txh/dW1Mhn0q1NVo3EY9brN+Rsn2c9/gVNvU1LO6Icz19WEX1NVaR6wp09\nAC7YvJeSskqaNUmN2Xq7G4pq5i8KFXPtuYosIx7+1O95YGfyL1/JZ07APIHA9Qrimb0z7G+WiKQA\nzwDjgf7ApSLS336MMeYRY8wgY8wg4E5grjHGnut0tHe/BnlVZ9ZkqnBEhPbZngDep33o9WvbZjcJ\nuX9I11ZckteFL+8YzY9P7BymfEKHFpk1tr9y1VD+7wxPn8GAzi04u391x3GT1BQy0xM/xDOU0lr2\nGxypjP683SXljqMYv9mwh+u8aSDiG
RhDpRCuTf9JYJB3Es98/JFUIYYC640xG40x5cA0YEKI4y8F\n3ohF4ZSqq9F92/KfycO48pTuIY+z1+hP7NaqRoftExMHISJ0aZ3F3ef25yXvIiuh7CutrrFdPrwb\np/XJ9S2vKOC31m56qoufn+zfCf3ABceHfY/GINSSjaEEm4H87aY9TM/fyj3vxa95K1Z3C9HYd6ic\nvaU1s43GQiSBvhOw1fa8wLutBhHJAsYBb9k2G2C2iCwUkcnB3kREJotIvojkFxUVRVAspar1adeM\nX59xjOO+k3u2wRXmDsDeGfvW9aew5sHxfvuzbEnBWmSlMfrY8Bk4K2ydawM6twRsTUgifp16zTNS\n6dwqi79cXJ3yoXkUi6v3tnVgp7gk6FDIRKhr/vaNAU0rFVWGO/77PUu37q/T64aSiEVQHpq5irMf\nnxv+wFqIJNA7/YUE+xTOB74OaLYZYYwZgqfp5wYROc3pRGPMVGNMnjEmLze35oQWpUKZffMobjn7\nWMC/phyp7DgMr7SCxT3n9eeiIZ66kX2OgL3S2DLTE5jtKQLso3A6tAjdh/CjgR35pzfP/8jeOdw2\n9ti6FT6G7J2ctfHc3A0xKknk6pq4rDbKKqvitvxmJIG+ALCnEOwMOOd8hYkENNsYY7Z7/y0E3sbT\nFKRU3Lxw+UksvffsqM4JbKqJxJybT+PN64bX2G4FcKsTcUDnFr7JOlaNPvDPeWiP1oD/5KOMNJft\ncej2+1+NPsY3zNRtYOJJXVhx39jIL4bo7iDipW+YvpT6snF36AVm4uFIhTtuHfKRvOoCoLeI9BCR\ndDzB/L3Ag0SkBTAKeNe2ramINLceA2cDy2NRcKWCSU91hR0uGSjU9PPeQcb1927XnJO6t66x3Zos\nY+XusXeyWm+T0yydn+R15qLBnVh091m+uQP2Cl1mWopvxm6oml7nVpmkuMSX+8Xt9iwO0jQgcF+S\nF7oT+ZGfDAQiS3oWL1kx6JC+NMoJdw3FkfIqXx9OrIUN9MaYSuBG4CNgFTDdGLNCRK4Tketsh14I\nzDbG2L8K2wFfichS4DtgpjFmFko1UPaVmCz/ubZmrd3u3RtG8NltpzOwS0u/7VbTjT2/T1731lw3\nqheP/HggWempPPbTQX4pm28cXT2TNyMthTF9PX0Bkfz5WzMtg3UkBgb+QNZdTV2CbV3vCux9IbUV\n4RLHDU5ZpTuhTTcYYz4wxvQxxvQyxjzk3facMeY52zEvG2MmBpy30Rgz0PtznHWuUg3RN1PO4PPb\nTvc9t2q2rcN0bA7s0pIeOU15cIL/KJkKh9w9zZqkMmV8X1oFec3M9BRfHv/M9BQevngAL115El1b\nZzkeb2e9jT3QXzWiejZvsxBBeGCXlgzp1opJw7vxR4dcQJH62cl1q03HokYf6wEzVrNaMIFLOdZW\neZW7xnyPWGmk331KxV7Hlpm0tKVH+Oy203n9lydH/TpWoJk8sicQfox+IOuP3UqLMPrYtr4mn0nD\nuwU9zxpZZJ9Jf8/51VNegtWWrz2tJ+/eMIIWmWncP+F4Rvau3WCI8ce3DznHIJLafnot+kriLTAd\nst1vx/WNaXBOaI1eqaNRl9ZZnNIr8hE8gfHgsuHd2fzwuVE3Rzj1L1jNLgM6t6yxvKK1IIbVdBNs\ndahmTTy15Z/mdfHL2xPY1NMiMy3imvntttE9T/9sCL3bNfc1NwWy+gDCCdYnkiiBC6fb9cxt6tin\nMSBEttVQAhOuxYoGeqVirK5NB/+47ERuOrM33dpUN9dYgf5weSUPXVjdRPSr03vxwhWeyVtOTTd2\n/Tp4Vu46vlM2d53TjxtG9wpaXuvLo5VDn4XTcVBdG506yXkCvNPIplev9h+E5xJhzi2jfM+nXnZi\nyPevjWiC8KtXDw35/5kiUqMj/63rT+GJnw6qZeniQwO9Ug1Ml9ZZ3HRmH78AYjV7HCyr9Lu9v2Nc\nX
3rkeDJ0+ppuggSmvO6t+eTWUfxiWDdEhB45npqzU3+BFZSv8TY/Bcr19iO0bprOAxOO46UrqmcK\nh2p+aNM0nYknVd+RBDYTBa5iNbhrq6CvFQ1rbsUtZ/XhH7Yvj5Md2t+/vKN6jYRTjwl9R5fikhp3\ncid2axXVMMn6GCWU+IGzSiWJPu2ac1qfXG4/O/aTlUYdm8s/vtjIoC4tgw4F7dzSk1/nnOPbB32d\nXrnVzSIXDe6E2224cEjNie5WW3lllaFJqqvG7NaXrjiJxVv3c+HgTmFnHVvcxrDwbs+av9MWbK2x\n/+cnd+Vm77KSlmiHyTr5/LbTeWXeFr5ct5v0VJdfm/svR/bk203V8zuf/tlgutg6viVgBnOgFJf4\n5ZW3PoqUKIZJhrtrigUN9ErFSHqqi1euis98wFN65bDivrEhh0i2zc5g+X1jfVkmw3G5hEtO6uK4\nz6rRl1c5z9Zsm92Ey4YF7xh2Ei6rwEMX+o/26ZnTNCadsy4RvzULKmwFsc9x+ODXI+nf0dO8dUbf\ntny6ujDsa6e4xG/oa7AJcZZRfXKZu9Y/xUtqPYwH1aYbpRqJcOPgwTOEMhZrj1oBtqzCzdjj/O8Q\n5tx8Wti0zk6iSRS2+oFxfHjTyJDH5DisYetEpLovobzSTbZt2Uh7oLdPVnruFyey2Hv3YVXp7z7P\nL2kv4PkSsdforUBv/V+NC/jsEjWqSAO9UqoGa7TR6ce25c8XD/Bb27d3u/BpCtp5h5R2aV2drjlU\n6t9AGWkpIReSBzhvQMeIXivFJUw6xZM9dEy/tjTPSPP1edg7iO2lS091+fouyrwznHvlNq3RlJSa\n4t9Gb905tMhMY+2D4/m/Mf6J9oIN1bxyRPeIrqW2NNArleQGBczYjcTxnVqw8Y/ncGrvHNJTXX5r\n+0biiztGs+r+cX7B2B7nl983luUR5uJpkZlG9zY13//2CBO3uURo2zyDV64aynEdPSNu0rwB3p7M\nLtgNx2HvMoHNM9JqdLy6AkbdjLB13qanump8WRkM5w/s6BvxBJ5mnnsc7hZiSQO9Uklu+rXDWXl/\ndAnOAMdO1kiTv1kLqdgnE9mbbpo1SQ05U9du4e/P5GPbkEtLZlqKr2b+zyBDOsE/f1DNcroY6s1X\nFKyp3FpoJDsjtcaC3p7OWM/jy4Z149GAuQKBn1eVG566dDC3j+3r23Zmv3YxaW4LRTtjlUpynvVo\n616nm3rZiRwbZXZJ+4gSaxhotAI7K1//5cnMWr4Tl0t8Xx4n96weJtk8IOV0qCDqcglP/2ww//t+\nh9+IJDtr4ZTmGWm+oN6xRQbbi49QcqTSV8vv3a5ZjSyjgYHeqZ/ihFpOroqGBnqlGqELB3dyzJwZ\nT2cfF3zYZjC/GNaNKrdh7HHt/YYt2k2bPCyqVahO6ZXj60N49JKBPDp7LU1ts49/M6Y3LbPS+NvH\n6zhc4TxqyL6lbXYGV5/ao8YxlkO+QF/d0X31yJ488P5KOrXKrE4m59AHEdh0E65Dumdu7b4Mw9FA\nr1Qj9HgDm3kZTFqKK+ikK8uwnm0ieq2RvXP4ZsMev23jju/AuOM7+G3LTE/hulG9ePrT9YBz081r\n15zM9PytEa3E9fdfDOGFLzeRlZ7i+4I4b0AHJg3vRlqKyxfonUJ4k7TAppvqo2b++lS/3Erz7xxD\ns4z4hGQN9EqpRuHVq6NLMCe+f2tG+n4dsrn3/OMiep3Rx7b1LR05pGsrZq3YSZNUV41kZk6DitJr\nHFN9kNUxbGkfZhWxutBAr5RKTnHo33zspwO5dmdPv5q4r0bv0CwT2KHtrv8VCgEN9EqpJPHSlSdR\nWRXfRb2z0lNr5N+xYnmw5vfubbL46UldWbZtP7fGIT1GJDTQK6WSgtW8YunUMpPVOw8Sp2VYfaxR\nN8E6Wj+/fbTj9vqkgV4plZReuWoo8zbuqTHcMtasIZXxWjQkFjT
QK6WSUtvsDCYMqpmZM9auP70X\nZZVufhFlkrf6FNFNjYiME5E1IrJeRKY47D9dRIpFZIn3555Iz1VKqcYsKz2Vu87pV2OyVEMStkYv\nIinAM8BZQAGwQETeM8asDDj0S2PMebU8VymlVJxEUqMfCqw3xmw0xpQD04AJEb5+Xc5VSikVA5EE\n+k6AfTmYAu+2QMNFZKmIfCgi1kyESM9FRCaLSL6I5BcVFTkdopRSqhYiCfROXcmB44gWAd2MMQOB\np4B3ojjXs9GYqcaYPGNMXm5urtMhSimlaiGSQF8A2Ncb6wxstx9gjDlgjCnxPv4ASBORnEjOVUop\nFV+RBPoFQG8R6SEi6cBE4D37ASLSXrxp3URkqPd190RyrlJKqfgKO+rGGFMpIjcCHwEpwIvGmBUi\ncp13/3PAj4HrRaQSOAxMNJ7ED47nxulalFJKORCnRDyJlpeXZ/Lz8xNdDKWUajREZKExxnGprQYZ\n6EWkCNhSy9NzgN0xLE5joNd8dNBrTn51ud5uxhjHkSwNMtDXhYjkB/tWS1Z6zUcHvebkF6/r1cXB\nlVIqyWmgV0qpJJeMgX5qoguQAHrNRwe95uQXl+tNujZ6pZRS/pKxRq+UUspGA71SSiW5pAn0ybrA\niYh0EZHPRGSViKwQkd94t7cWkTkiss77byvbOXd6P4c1IjI2caWvGxFJEZHFIvK+93lSX7OItBSR\n/4rIau//9/Cj4Jpv9v5eLxeRN0QkI9muWUReFJFCEVlu2xb1NYrIiSKyzLvvSSvtTESMMY3+B096\nhQ1ATyAdWAr0T3S5YnRtHYAh3sfNgbVAf+AvwBTv9inAn72P+3uvvwnQw/u5pCT6Omp57bcArwPv\ne58n9TUD/wKu8T5OB1om8zXjSVm+Ccj0Pp8OXJFs1wycBgwBltu2RX2NwHfAcDxZgT8ExkdahmSp\n0SftAifGmB3GmEXexweBVXj+QCbgCQx4/73A+3gCMM0YU2aM2QSsx/P5NCoi0hk4F3jetjlpr1lE\nsvEEhBcAjDHlxpj9JPE1e6UCmSKSCmThyW6bVNdsjPkC2BuwOaprFJEOQLYxZp7xRP1XbOeElSyB\nPuIFThozEekODAa+BdoZY3aA58sAaOs9LFk+iyeAOwC3bVsyX3NPoAh4ydtc9byINCWJr9kYsw34\nK/ADsAMoNsbMJomv2Sbaa+zkfRy4PSLJEugjXuCksRKRZsBbwE3GmAOhDnXY1qg+CxE5Dyg0xiyM\n9BSHbY3qmvHUbIcAzxpjBgOleG7pg2n01+xtl56Ap4miI9BURH4R6hSHbY3qmiMQ7BrrdO3JEuiT\neoETEUnDE+T/bYyZ4d28y3s7h/ffQu/2ZPgsRgA/EpHNeJrhzhCR10juay4ACowx33qf/xdP4E/m\naz4T2GSMKTLGVAAzgFNI7mu2RHuNBd7HgdsjkiyBPmkXOPH2rL8ArDLGPGbb9R5wuffx5cC7tu0T\nRaSJiPQAeuPpxGk0jDF3GmM6G2O64/m//NQY8wuS+5p3AltF5FjvpjHASpL4mvE02QwTkSzv7/kY\nPH1QyXzNlqiu0du8c1BEhnk/q0m2c8JLdI90DHu2z8EzImUD8LtElyeG13Uqnlu074El3p9zgDbA\nJ8A677+tbef8zvs5rCGKnvmG+AOcTvWom6S+ZmAQkO/9v34HaHUUXPN9wGpgOfAqntEmSXXNwBt4\n+iAq8NTMr67NNQJ53s9pA/A03swGkfxoCgSllEpyydJ0o5RSKggN9EopleQ00CulVJLTQK+UUklO\nA71SSiU5DfRKKZXkNNArpVSS+3/vwPgVTpj8UgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "opt = Momentum(0.01)\n", "params = task.init(key)\n", "opt_state = opt.init(params)\n", "del params\n", "\n", "losses = []\n", "num_steps = int(os.environ.get(\"LOPT_TRAIN_LENGTH\", 1000))\n", "for i in range(num_steps):\n", " batch = next(data_iterator)\n", " loss, grads = value_grad_fn(opt_state[0], batch)\n", " opt_state = opt.update(opt_state, grads)\n", " losses.append(loss)\n", "plt.plot(losses)" ] }, { "cell_type": "markdown", "metadata": { "id": "RrwIMhObH29t" }, "source": [ "And finally, we can implement Adam." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "executionInfo": { "elapsed": 53, "status": "ok", "timestamp": 1647716629923, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "L7gd-MqEH2da" }, "outputs": [], "source": [ "class Adam:\n", "\n", " def __init__(self, lr, beta1=0.9, beta2=0.999, epsilon=1e-8):\n", " self.lr = lr\n", " self.beta1 = beta1\n", " self.beta2 = beta2\n", " self.epsilon = epsilon\n", "\n", " def init(self, params):\n", " return (tuple(params), jnp.asarray(0),\n", " tuple([jnp.zeros_like(p) for p in params]),\n", " tuple([jnp.zeros_like(p) for p in params]))\n", "\n", " @functools.partial(jax.jit, static_argnums=(0,))\n", " def update(self, state, grads):\n", " params, iteration, momentum, rms = state\n", " iteration += 1\n", " momentum = tuple([\n", " m * self.beta1 + (1 - self.beta1) * g for m, g in zip(momentum, grads)\n", " ])\n", " rms = tuple([\n", " v * self.beta2 + (1 - self.beta2) * (g**2) for v, g in zip(rms, grads)\n", " ])\n", " mhat = [m / (1 - self.beta1**iteration) for m in momentum]\n", " vhat = [v / (1 - self.beta2**iteration) for v in rms]\n", " params = tuple([\n", " p - self.lr * m / (jnp.sqrt(v) + self.epsilon)\n", " for p, m, v in zip(params, mhat, vhat)\n", " ])\n", " return (params, iteration, momentum, rms)" ] }, { "cell_type": "markdown", "metadata": { "id": "VfBmLGVdrQKo" }, 
"source": [ "## Learned optimizers\n", "\n", "A learned optimizer is simply an optimizer which is itself some function of meta-parameters. The actual function can be anything ranging from more fixed form, to more exotic with the meta-parameters encoding neural network weights.\n", "\n", "### Per parameter learned optimizers\n", "The family of learned optimizer we will explore in this notebook is \"per parameter\". What this means, is that the update function operates on each parameter independently.\n", "\n", "In our case, the learned optimizer will operate on the parameter value, the gradient value, and momentum. These values get fed into a neural network. This neural network produces 2 outputs: $a$, $b$. These outputs are combined to produce a change in the inner parameters:\n", "\n", "$$\\Delta w = 0.001 \\cdot a \\cdot \\text{exp}(0.001 \\cdot b)$$\n", "\n", "We use this formulation, as opposed to simply outputting a direct value, as empirically it is easier to meta-train.\n", "\n", "Choosing input parameterizations, and output parameterizations varies across learned optimizer architecture and paper." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "executionInfo": { "elapsed": 53, "status": "ok", "timestamp": 1647716630098, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "ymF3QnR0-UdM" }, "outputs": [], "source": [ "class LOpt:\n", "\n", "  def __init__(self, decay=0.9):\n", "    self.decay = decay\n", "    self.hidden_size = 64\n", "\n", "  def init_meta_params(self, key):\n", "    \"\"\"Initialize the learned optimizer weights -- in this case the weights of\n", "    the per parameter MLP.\n", "    \"\"\"\n", "    key1, key2 = jax.random.split(key)\n", "    input_feats = 3  # parameter value, momentum value, and gradient value\n", "\n", "    # the optimizer is an MLP with a single hidden layer.\n", "    w0 = jax.random.normal(key1, [input_feats, self.hidden_size])\n", "    b0 = jnp.zeros([self.hidden_size])\n", "\n", "    w1 = jax.random.normal(key2, [self.hidden_size, 2])\n", "    b1 = jnp.zeros([2])\n", "    return (w0, b0, w1, b1)\n", "\n", "  def initial_inner_opt_state(self, meta_params, params):\n", "    # The inner opt state contains the parameter values and the momentum values.\n", "    momentum = [jnp.zeros_like(p) for p in params]\n", "    return tuple(params), tuple(momentum)\n", "\n", "  @functools.partial(jax.jit, static_argnums=(0,))\n", "  def update_inner_opt_state(self, meta_params, inner_opt_state, inner_grads):\n", "    \"\"\"Perform 1 step of learning using the learned optimizer.\"\"\"\n", "    params, momentum = inner_opt_state\n", "\n", "    # compute momentum as an exponential moving average of the gradients\n", "    momentum = [\n", "        m * self.decay + (g * (1 - self.decay))\n", "        for m, g in zip(momentum, inner_grads)\n", "    ]\n", "\n", "    def predict_step(features):\n", "      \"\"\"Predict the update for a single ndarray.\"\"\"\n", "      w0, b0, w1, b1 = meta_params\n", "      outs = jax.nn.relu(features @ w0 + b0) @ w1 + b1\n", "      # split the 2 outputs along the last (feature) axis\n", "      scale = outs[..., 0]\n", "      mag = outs[..., 1]\n", "      # Compute a step as follows.\n", "      return scale * 0.01 * jnp.exp(mag * 0.01)\n", "\n", "    out_params = []\n", "    for p, m, g in zip(params, momentum, inner_grads):\n", "      features = jnp.asarray([p, m, g])\n", "      # transpose to have features dim last. The MLP will operate on this,\n", "      # and treat the leading dimensions as a batch dimension.\n", "      features = jnp.transpose(features, list(range(1, 1 + len(p.shape))) + [0])\n", "\n", "      step = predict_step(features)\n", "      out_params.append(p - step)\n", "\n", "    return tuple(out_params), tuple(momentum)" ] }, { "cell_type": "markdown", "metadata": { "id": "WH9_EHD6rfrL" }, "source": [ "We can now randomly initialize the meta-parameters a few times, apply each resulting learned optimizer to our target task, and see what we get.\n", "\n", "Unsurprisingly, our randomly initialized learned optimizers do not train the target problem very well, and many of them even diverge (the loss becomes NaN)." ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "colab": { "height": 300 }, "executionInfo": { "elapsed": 2701, "status": "ok", "timestamp": 1647716865746, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "iJ_uf9uwDASS", "outputId": "08ef3727-489c-4b2c-eee4-7532dbf6392d" }, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'inner loss')" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90\nbGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsT\nAAALEwEAmpwYAAAwPElEQVR4nO3deZhcZZn38e/de6ezdJbOQnaQRcIamhAElU2GVWYAARlkUYzg\nyrgiOAqK4+g4iMhIjCiIIryDIMO+g4AEQhKSkJAACUESkpCNJJ2l17rfP86p9KnqqurqTp/udNfv\nc111ne2p089JuuuuZzd3R0RECldRT2dARER6lgKBiEiBUyAQESlwCgQiIgVOgUBEpMApEIiIFLjY\nA4GZFZvZq2b2YIZrZmY3mtlSM1tgZpPjzo+IiKTqjhLB14DFWa6dDOwdvqYBN3dDfkREJCLWQGBm\nY4BTgVuyJDkDuN0DLwHVZjYqzjyJiEiqkpjvfwPwbWBAluujgRWR45XhudXRRGY2jaDEQFVV1WH7\n7bdfl2dU2mpuSrBx1TY2FzljR/anojTL94btG2DTu1A+ABrqoHoc9BvavZktNFtWwdb3YeBo6D88\nc5qGLbBhGZT1h8atue9XUQ31m4L9AaOgLvInWD0u+P8FGHkgFMX9sSFxmDNnznp3r8l0Lbb/UTM7\nDVjr7nPM7JhsyTKcazPnhbvPAGYA1NbW+uzZs7sqm5LDxlXbuPOHL3N/v0ZmXP1Rxg+typxw7h/h\n/i8HHxJrFsA5v4T9P9m9mS009ZvhyWvgxB9DWb/MaZY+CX86C/Y8FgbuAfPuyHHDZqB/sHvs1+CZ\n61ovnX4NPPDVYL+6Ba6Yt8vZl+5nZv/Idi3O0H4U8EkzOwWoAAaa2Z/c/YJImpXA2MjxGGBVjHmS\nDvBITC4vKc6e0MKSwo4Pgm3l4BhzJQBUDILTftFOouT3LIeyLEE849vSvp9ZpCSYLBlInxJbG4G7\nf9fdx7j7BOA84Om0IABwP3Bh2HtoKrDZ3Ven30t6luOUl+T4VUl+cGzfEGz7DYk/U9Ix0cklT70+\nd1pL+7/2RNfnR3Yr3V7ZZ2aXAbj7dOBh4BRgKbAduKS78yP5Kc/WPgCtHxxN24P644GjuydTktuA\nkcF21MHwwTvB/qdug/KBHbtPoqkrcyW7oW4JBO7+LPBsuD89ct6BL3VHHqQTIl8iy4rzCAQANftB\nZXVsWZIOGDEJPv80jDwYNi6DkkrY52R4b07H7tOSIxC8+CuwYjjyi7uWV+lRav6XdhUVGSX5BoLS\nyvgzJPkbfViwrdkXzvxNsF9Skfs9T12bepwrEDz+vWCrQNCraYoJaVdpUabOXVm09yEjPa+kvGPp\n178RTz5kt6FAIO3KWRqA1F4m69+KNzOy6zoaCF79U/tplj2d2iAtvYoCgbSrtN1AELm+dU28mZFd\nV1y6a++f9Vv48R7wSmTCgD/+C8y/c9fuKz1GgUCySn7BK2mvaii9u6Hs3orCQDCgk7O5PPxNaNoG\nD30j9XxyHIn0OvoLlnYVFXegjUB2f8kSQVdX5WhakV5LgUByCD4oitNHmrZJpgFHvUpVDXzkK/CZ\ne1PPf+oPbdN2ZJR4ebYpxWR3p+6j0q6ijgSC4g42REr3M4MTr2t7vqx/23P9huVf5aMvBL2WSgTS\nrqL2fksSkQ+Azz8da14kRpn+o6syTlaZWaKl6/Ii3UqBQLJ6e902oJ0SgTskmluPRx4Qc66kS409\nonXfMkwsOLoDiwa6AkFvpUAgGc1bsYkv3zEXCEYWZ3VTLdx3WTflSrrcRQ/C8P2D/aJiOOwSODQy\nN+TxP0hNP/6o7PdKqGqot1IgkIxWfrB9537OEsGGpd2QG4lNSVnrFNVWDKffAFPD6SKmTAuuJ5VW\n5R45rhJBr6XGYmlXe+PJdjr797HmQ2KWHA8yYhJc/BCMOTz1e
mlF7g97tRH0WioRSLva7TWUNO4j\n8WZEus+Eo9tORVFSmbtnUDRINGyFRX9VdVEvoUAg7Wp3HEFSUY5VzKT3K63I/cEe7TSw8B64+2L4\n+w1x50q6gAKBtMuyNRY31KUnjD8zEp/2An57JYJo1VBzQ7BdlDZobdO78JuPQZ3mpNqd6C9X2pV1\nhomnfph6rEDQu2WbciJZ5bd2Ebz7Yo73J6BxO6x7M5iLCGDj8tQ0L02H1fNh/l27nl/pMvrLlayS\nn/9F2UaUNW5Le4N+nXqndkoC598VdDNtb+RwogXuOh/+5/DW0mLjVqjf3Jom2Y6w4P+pVLAbie0v\n18wqzGyWmc03s0Vmdm2GNMeY2WYzmxe+vh9XfqRjol8ON2xryJwovTFRbQR9U8UgmPjR9tM1boW3\nnwn2t29oPV+/pXU/WX209nW487yO52Xdm9C0o+Pvk5zi/ArXABzn7gcDhwAnmdnUDOmed/dDwtcP\nM1yXHrZ8/bbMF9L7lGcamSqF45kft+5vW9+6/8h3YOlT4UHkG8aqVzvWq6hxW1DauHfaLmVT2oot\nEHhga3hYGr60hFEvlPU/Lb1EoKqhXq4L/zyjgeCNh+BPZ8KahW3XP968Iv97NtUH23de2PX8SYpY\n/3LNrNjM5gFrgSfc/eUMyY4Mq48eMbNJceZHulhR2kpXqhrqnY76arAdtk/X3XPr+23PzTim7aCz\naJfTdoWBKt/uzJK3WAOBu7e4+yHAGGCKmaXPSDYXGB9WH/0KuC/TfcxsmpnNNrPZ69atizPLEnLa\nbUKEsn6pxyoR9E4fPh2u2Qz9hnTdPT9YDv1Hpp5LNAWvqI60E2hN5Nh0y1+uu28CngVOSju/JVl9\n5O4PA6VmNizD+2e4e62719bUdGBaXIlX+rc5BQKJGjSm7bn0qqH1b+Z/v2SvJf2edbk4ew3VmFl1\nuF8JnAAsSUsz0iwo55nZlDA/G5DeoSUSCI7/vorsheLU/84vXaZAkF4i6IidU1jo96yrxRlaRwHP\nmNkC4BWCNoIHzewyM0vOW3w2sNDM5gM3Aue5q/y3u/nmiftmvhD9o/7QCd2TGel5h1+aX7r+I9qe\na2mnTSCRgPVvZbmmSe3iEtvso+6+ADg0w/npkf2bgJviyoN0jT2qKzNfSCnm61tanzd4YlD3n6/o\nFNZJjVvbnot68Zfw5DVw2d/bLnKULBGoaqjL6V9UMnL3nR/tWaehjrYRqFqo77v87/Ctt4P9LzwH\n/YZmTvfxK4Nteq8ygK1rs99/w7IgCABsea/t9SfDMan6XetyCgTSruJskSBaItC3tL6vrAqqwg//\nUQfD6TfmTl9cClO/lHquaTuMPzr13PuLgu2vIstiRgcnblsP1wxqO4FdVHT0snSY/nolo5ZEa1NN\nSbbZR6NtBOUDYs6R7HbSBxQmtTQG2+JSOOk/4EOfaL22eQUUp9VI3/wRWDEr7R6RaU3WvJZ6Lf1L\nx/Ln4D/HwrKnU8+vXdI6C6rkpEAgGeUVCFqaoKIaLnoAqsd1T8Zk91GcoeoHWr8gJKuGitPaCjJV\nGaUveRqd4rxiUFpig+0b4YErgnmHVoTjVJc/15pk23r49RHw4L/legIJaalKySjad6s463oEW6Cq\nBiZ+rHsyJbuX9A/4pMMvhTcfg4PODdOlfcwUZfjYSZ+nKhoIMvUWeubHMOdWaK5vnYQuWlWZbJR+\n5/ns+ZedFAgkoxZ3LNlcnK1x7v1FMCJ9sLgUjGyBYPAE+PIrrcfpJYD0wABtpyeJ9i5KH3tgRa3B\nYf6dkXTNqWlAo5HzpKohyShaNZQxDNRvgY1vw8iDui1PspvJ9M0+k/RxB5mqhtJXu3vyGnj1jmC/\nzQh2yDhBXqbOC+2toSCAAoFk0e64vmTD3KiD48+M7J6iJYKvzIWPfBXO+l3bdOOPhEsebT1e9hR8\nfXHwSpp7e9v3/d8Xg236t
BSQ+Zt+IsO4FgWCvKhqSDKKlggyFgnefQmKy2GvY7stT7KbSQaCIXvB\n0L3gxB+1nxZg0DgYuEewf8Tl8PLNsHll9vemtxG0NMPcP7RNlxIwwt9fVQ3lRSUCyWjJmrrcCVoa\ngi6j2XqOSN83cFSwTU5jnUv09+TMGa37J14XbLdlGWjm3raNoG5V5rTRQJAsCahEkBeVCCSju15Z\nwehs3xPWvQmzfw8DR3dvpmT3UlYVTF+dj2jf//7DW/eLS4J1ELLNQvrKLanpc4m2JSgQdIhKBNLG\nOb+ZmXJs6XVDd5wdbNMb+ESyiX4gVw5OvTb8w9nft2pe5jaC9n6GAkGHKBBIipaEM2v5xnYShX+Y\n+iOTvIV19SMPattVNNeo9KbtcM/n8vwR4e/jizfBjYemnpOcFAgkRWNzHn84yakFNHxf8jVkz6Db\n6LFXtb1WPjD7+3LNL5Qu2TD8+NVtz6VLtASjk5Pq3of/HAer5+f/8/oQBQJJkTEQpPcaSgaCXVlk\nRApL+QD4/nrY9+S217INTOuoTN/+M5176Bsw/aPws4mt1ZvLnoL6zTDz112Tl15GjcWSorEljxJB\nV/3higCseyP1eOAY2JKjO2k23hIsbJNyLsPv8yu3tO43bA2rppLfdgqzu6lKBJKivqm1z/bIgeE3\n/zYlgoruy5D0feljUTpb0mxphEe/k3ouPRDUvZ96/OdPBdvkNCrRqqSGdhbR6UMUCCTFyb9snaSr\nKNtkc9kWJBHpjMM/H3RFHhiucdxc37n7NDfCrBlpJyMf7H+/Ef57n9TLa16D+XfBX7+Qmv4fL8JP\nRsPSpzqXl15GgUB2mrV8I1sbWvtiF2dbfjLbPPQinVFUBF9/HS66PzjubCeElgzv80Trt/wn/j3z\n+3YGAVrTLn0y2K6c3bm89DKxBQIzqzCzWWY238wWmdm1GdKYmd1oZkvNbIGZTc50L+ker69KHRyU\nLBG0CQfJhUeO1lzv0oWS3Ug7UyIYfVhQIkjnCbi2Gq4bmeeNwkCwY1OwTTQHq6PNOKbjeepF4iwR\nNADHufvBwCHASWY2NS3NycDe4WsacHOM+ZEcHlu0hodfW5NyLuvKsM0NMLoWTrgm7mxJISnrn/3a\nR7+R+739hkLzjrbnk6ONM13LpLkh+OCfHU6e99zPgu2qV4NeRZnWRsjEPZhBdePy/NL3sNgCgQeS\nrS2l4Su9Sf4M4PYw7UtAtZmNiitPkt0X/jiHWe+kDiTb2W6Wvh5BS6Oqh6TrlVa27vePfIM/9Xo4\nLku1DsDX5gcr5KWvctYZ23MMpvzVYfmXDDYsgxd+AXf9667nqRvE2kZgZsVmNg9YCzzh7i+nJRkN\nrIgcrwzPpd9nmpnNNrPZ69atiy2/ksqzdaVr3KaeQ9L1zGDPY+HM38KlT7aeP/xz2RdHgmAhnD2P\n6Zo85OqxtG0drFkAbzyaPc1O4d9OxnYLb6162k3EGgjcvcXdDwHGAFPMLH05q0z/u20+fdx9hrvX\nunttTU1NDDmVjLJ1qd66FvqP6NasSIG48D446ByoHgtXrYZvLcvvfRXVXfPz6/OYRO/Oc/P4IM/Q\nHTVp5v/AT8fDpnc7mrvYdEuvIXffBDwLnJR2aSUwNnI8Bsgyx6x0t501Q9GTiQRsfR8GKBBIzMr6\nQdWwPNNWdc3PzFU1FNW4Deb8AV64oe215sZgjiQg47epJQ8F200r2l7rIXH2Gqoxs+pwvxI4AViS\nlux+4MKw99BUYLO7r44rT9IxGVcp27ExKD73z7cXhkiMxn0k2OZqaO6IHXkGgu0b4IGvwpM/aD2X\n7LX0x3+G33w02E8f0LZtA7z7YrCfq7qrm8VZIhgFPGNmC4BXCNoIHjSzy8zssjDNw8DbwFLgt8AX\nY8yPdNBZk4MBPhYdWFYX9ixSiUC625FfhqoaOPjTwfF+p8G/3h3sl/Xb9fuXDch/ttLkBz0
E8xNd\nMwiuqwm2//h767UP3gnW9k567r8iN8kQCFbNg998LHVU83M/D+6bPn1GF4qz19ACdz/U3Q9y9wPc\n/Yfh+enuPj3cd3f/krvv5e4HunthjN7oJaZOGAJAUXH4C/vIlTD9qGB/gDp3STf7px/Dt5YGC9kA\nDP0QlIclgWjV0CWPwKCxbd/fnlzTYefy2HdzX7/lhNb96AI9luHj94l/D2ZAXRHpV/N0uIqb59l1\ntRM0sliySoTrFluyCPtyZJjHAFUNSQ9JfoBGv72XRgJBzX65G49P+Xnm8xU5psPeFds3tO4XR+b5\nzFQ1ZOFaDfdcGmlMTvZAim+2XwUCySrREvwCFhVbMB9L1KBxPZAjETIHgpLIjLhl/XPXv4+pzXy+\nq3oe5VIUWbs5UxtcctGeHRuDEnhUjNO+KxBIGzUDgsFiycZi27g0dT4WCOaHEekJOwNBlv7NJWW5\nA0FVli7oQ/fatXzl0rA1GGVcFCkR/OMFeOYnqekssnrbhqXw6yNbj/Md1dwJWo9AUpzw4eHcdP5k\nWhLOqteCIm1Ry7bURIMndH/GRJIylQjaJmp7asQB8P7CtoGgqCSYiiLOQPCLSVC/CQ79TOu5p34Y\nbPc5ERbeC3ufCG891np9fdo6DS9cDydeF0v29LVOaEm0frNKOFSUFlNVXoKH54vq07rUXfxwd2ZP\nJNXOtQNyBIJksDj4/KCL6fij4IJ74fy7g+lR/iUyXXVylHw+VUMHntOpLFO/KdiuSJ9cAfjtcTDz\nJrj9k7nv8eKvOvez86ASgdAUWZUsESlu72wsjgaCQeNgUJtZQES6UR6riSUDQe1nYezhrecHnBhs\nDz4X/jot2C8ph8atqXMdZTNlGix5MDJgrIPWv9m598VMJQLh2gde37kfrXZNNhZbc+SXviXDVL8i\n3ak67Bo6eGL2NHlVH4WKwwkU85k/a/Rk+O5KmPIFOOeP7afvJRQIhNfe27RzP1oi2Fk11BKZwlcL\n1ktP2+9UuPB+OOKy1POffxo+G9ax51N9lHTij4LBZCMPbD9tUXHwOuVnsP8n4fBLU69XVMNJP4V+\neU6NkUmm8QUxUyAQ+pe31hBGSwQZA0GMfZlF8rbnx9v2XBt9GIwLlzzZ+WGax2L0+50GV63MPK/R\nN5fCUVdkf++ISanHRcUw9bJdG+n8uSfbT9PFFAiEuvpmykqCX4VMbQQpgaBhS7fmTaRTOlQ1FPbt\nL80wcV3/GvhEuLhick3lqKIszazR8wed234eovoN7lj6LqBAINTVNzO0KhiQU1Ha2o95ZxtBy/bW\naacz/bGI7G5q9gu2FYPaT5scxJUclDYkQzfSry+BL87M8N7S1OPkF6loIKhM+2A/63e589MdA9vS\nqNdQAWtJONP/tozVm3fwqdqx1PQv54Kp43deT9StB6CoZTuU9oML7oHqCT2UW5EOOOkn8OHTc9f7\nD9kzdUI4gM8+Hp5flvqBPDDL3Fr9swxOSwaCSWfC8T+Al6e3XhuXvmJvmlzBa8cmqKzOfr2TFAgK\n2CMLV/NfjwWDVkYNrOArx++dct1f/i3wKaxhS9C17kMnZLiLyG6opBz2OjZ3ms890XZN4XFHBNts\nH/Dp9joehk+CtYuC42QjdXKE8NH/FrQXfOav8NA3YPKF5FgNPFBUnP3am48FXV+7mKqGCtj6utZl\n9A4d17ZeMtEcNAwXLX1US1NK31M1LHWMQWeYwaR/bj2edGawTX6YJ5qD7V7HwVdfDQJDtp53ow8L\nglPUFQvho99oPY6p154CQQGb8Vxrsbi6X2mb6+7Br4eRgFWvdlu+RHqVZIP0lGlw8k+D/Z2BIMP8\nQNXjYf8z2p6vqoGxU9LSjg2m205KrgfSxVQ1VMAsMjFXtJE4KUExRktQ2s2n0U2kECUDQeWQ1gAw\n4aPw3hyoGto2vRmcc3vr8VtPwh1nZR+sGV1rIaaSuQJ
BgXps0Rre29TaLbRfWVog2PEBTjFFhN9o\nLri3G3Mn0oskewpFZzw9/vtwyPlBw3N7kmsURMfoXPp0a6PwfqfDyT8L2hfymQajE1Q1VKD+/PK7\nKcdtAsGfzyXhRZiFgSCfUZcihejAs4Ntsn0AgpJBzb75vb847LYaLRGMOax1NtSiIjjiC7EFAYh3\n8fqxZvaMmS02s0Vm9rUMaY4xs81mNi98fT+u/EiqA0anrsbUpmpoxcu0UEpxskQQXfhDRFrV7AvX\nbIaafTr3/j0mB1VJp/xX+2ljEmfVUDPwDXefa2YDgDlm9oS7v56W7nl3Py3GfEgGzQmnrLiIxnDm\n0fKStt8JmryCEqvv7qyJFJbSCrj4wR7NQpyL169297nhfh2wGND8xbuB9Vsb2LC1karyYsqKw55B\n0frNsM6z2csptYZMtxCRPqRbGovNbAJwKJBhVQaONLP5wCrgm+6+KMP7pwHTAMaN01q5nfHLJ9/i\n6L2Hcdj4wdReF0xqNWZwJfd/+WiWrtuamnhbMKK4adghlG4jWDlJRPqs2BuLzaw/cA9whbunz1g2\nFxjv7gcDvwLuy3QPd5/h7rXuXltTk+eIP0nxiyff5KybX0w519SSYOyQfhy77/DWk4kEvPkoAM1U\nUjJ0NPzr3d2ZVRHpZrEGAjMrJQgCd7h7m/6H7r7F3beG+w8DpWa2CxN5SyYemVH06/9v3s79qXtm\n6OO85AG4/8u4w3sri1OrjESkT4qz15ABvwMWu/v1WdKMDNNhZlPC/GyIK0+FqqmlNRDc++p7O/e/\neeK+wSRWjdta+zA31AGwvnkCAKve2tRNuRSRnhJnG8FRwGeA18xsXnjuKmAcgLtPB84GLjezZmAH\ncJ5Hv75Kl4iuSRzVr6wYfto62yj9R0L9ZgBaPFi+b9z+Q2LPn4j0rHYDQdj//1agDriFoNH3Snd/\nPNf73P0F2plmz91vAm7KO7fSKUvW1GU8X1We9t+/tXUekxWNBwFw2CkT4sqWiOwm8qka+mzYyHsi\nUANcAvxnrLmSLpXeSFxOMIKxvDh7nJ619XwASstzTIkrIn1CPlVDyU+LU4Bb3X2+qQWxVzrIlvHL\n0puYWPR+cGLDLABavJiGRH/+d8N/84nqX1BC61D30vSpJ0Skz8knEMwxs8eBicB3w1HCeSwEKj1p\ne30Di//3Gg6tWMU9ZYuYn9iLC4ufAgx3eL9pX6p/dRzlBreuvY0G7w/AfRuvS7lPaYUCgUhfl08g\n+BxwCPC2u283syEE1UO9yub3V7PuxeepOmQ8VlFGS6IREi2UWjGeaMY8QWNzPS2JZooqh+DejDmU\nllRgJeUkPEEi0QI4JRsHM3BAC6VlztKF9Qzdo5J+A0up39pIv6oiEs3NNOxIUFJWRHHlAOq3bKWx\nPgFFRnllCc0NzZSVGzu2NlFSatRva2ZAVQObNrRQWj2cbQ2VWMt2hg1uYOPbKykbNIRBowaxo66J\n0paNlNUt5a26wyltXkdFWRNDi5bz0mvjGD/RSVgJzctmMvcfk5hUuZG/Nk5hTdNFDCxewy0tl1Be\ntBXD2ZbI0HU0gxKVCET6vHwCwZHAPHffZmYXAJOBX8abra734M9vZFPdJyh6+APKGjdR3rCR4pZ6\nyho2Yd5CS0kl5Q0b2TJwb6q2Pk3dgIk0lkBLSRXNJeVQPJxmX0WxjcTYFGNO30k7LgG2hK+kcUBY\nvUMxECxc8fbOdWaOBmD2ttYl7ba0jARge6JjvYDKVCIQ6fPyCQQ3Aweb2cHAtwnGBtwOfDzOjHW1\nPU76GNv/71mK6/vhpQPZVjaABCNIFO1H0GZevHOd0c3VmaeP7ddYRWN5a/v6gLJt1DUGi0YUFTmJ\nhLHP3jsoKoZhwxIkWhIUJerZXl/BkCEtQIJ160qpHpzAzGhsKiLhRiJhVPaDhh0tbF7XQN22UoZW\n11PRv5SBg4xNm4s
pbtxISVECmnfwzuohVJbWc0DZ/9E05mPMXXUEW9bVMaJ8OUNH9aNu+VsMK1nO\nguaDWMAYHi4dwid2QJmXwuQhjJu1hZqSZRzQ7xHmln+L4RMGMnR0FQceM4aWpgTrVtTxwI3zmXL6\nRA0oEykA1l63fTOb6+6Twymi33P33yXPdU8WU9XW1vrs2bO79J6JhLN5RyMfbNhBg8G819ezo8Ko\nb0zw9gfbuHf2SipamhizdQ3VzQlOf+sJatcGi743FpezdfhYFh/3zzw5/MPsaHbu/PwR9Cvr2TV/\nrn/8DW58emnrCYerT/0wn//Ynmzf0kjZ8ocpadoItZ/N+P5Vb33AsLEDKKvQ2kUifYGZzXH32ozX\n8ggEfwMeBT4LfBRYR1BV1CMrlcQRCNrTknBaEs5ba+sYMbCCJavr+PP9LzHmlWf5l3kP7Ux36Qnf\n4b3+NRQZXHnyfhz1oWFM2qP7l3h8dOEaLvvTnJ3HPzh9f6594HWuP+dgzpw8ptvzIyI9b1cDwUjg\nfOAVd3/ezMYBx7j77TnfGJOeCATZuDsbH3iQtd/+9s5zz518ET8pT42Rv72wljffr2PqnkM4bHz8\nI3UnXPlQyvHb/3EKj7++hn+aNFJVPSIFapcCQXiDEcDh4eEsd1/bhfnrkN0pEEStn/4b1t1wAwAV\nTzzP3YvW8+tnl7VJ9++n7c9nj5qQ9wdy09q11L/+Ov1qaynu3z9n2rnvfsCm7Y189rbg3+dnZx9E\ndWUpJ04a2bGHEZE+J1cgaHdksZmdA8wCPgWcA7xsZmd3bRZ7v2GXfYGSESMASFx+MZe8+TivfryY\nfk07UtL96MHXWby6Luv8P+l2zH2VlZddTtOqVTnTvb1uK2f++sWdQeD7p+3PObVjFQREpF35tARe\nDRyeLAWYWQ3wJPCXODPWG4279VbePuUUGpcuY/3Sm4FgDu7Ztz1KSZHx3XtfA+CUG58H4IdnTOLC\nIyfkvKc3BbOCWmlpxuv1TS289t5mPjV9Zsr5EQMrduFJRKSQ5BMIitKqgjbQDQva9EZlYzKvxHnB\nYXuQaGjgtINO5OZnl+2sMpq5bAOHjK3mxWUbuOSoCZSXtO2z3xoIUheP/+1zbzOwsoTv3PNaxp85\nclD5rjyKiBSQfALBo2b2GHBneHwu8HB8Weq9rKyMDy9ZjDc20vjuu7x92ukALDkwmMnzw0sWc8UJ\n+7BHdSXfu28hjyxcwyMLgxk/y0uKqB0/hAPHpPYySi8RbN7RxIMLVvHjhxfnzEtNf5UIRCQ/7QYC\nd/+WmZ1FsL6AATPc/a+x56wXs7IyigcPbnN+0z33UDpmLBdMncL37luYcu3aB14HoKy4iJsvmMxz\nb67jgqnjGZoMBGVBIPjPR5Zw56x3s/7sGz99KH+Zs5I9qhUIRCQ/efUa2p3srr2G0nlzM0sOyDzU\nYsRVV/HWAUey7fe3sLD2BK5f0pD1Pv+y9G9MW/gAZ536I2ZedwYHX5t9GYhRgyqY+d3jdznvItL3\n5Oo1lLVEYGZ1QKYoYYC7+8Auyl+fZCUljLv9D2x74e9YSTHrf33zzmvv/8d/MBAYCIxdMJMpt95D\nU0Mjn7n91Tb3KUm0ANBcVMIzS7L32r3h3EM4dr/hWa+LiGSTtdHX3Qe4+8AMrwEKAvmpmjKF4V//\nN4qHDcuapnntWibe/TuGnXk850wayj4fvMsj932Tm565nsPeX0JpojlIV1TMFZGF52vHp1Y9HTRm\nEIMqM/csEhHJJc7F68ea2TNmttjMFoVLXqanMTO70cyWmtkCM+uR+YviNvicc6j5+tezXt94220A\nXFW1ml/+7UYA9tq8iutm3sKAxu20YCQs9b/qL5d/BIAPDe/P9AsOY8+a3IPNRESyiXNGsWbgG+4+\nN1zMZo6ZPeHur0fSnAzsHb6OIJjp9IgY89QjrKSEYdM+T+moUdQ98QSNK1bQsLhtr
5/VV13V5tx5\nk4aw+Z3UbqW3XhwM8n75quMZWFFKpdYMEJFdEFsgcPfVwOpwv87MFgOjgWggOAO43YMW65fMrNrM\nRoXv7XMGnX4ag04/jYZly3j71NPyek/9A/fTXNLaA6hmQPnOtgANGhORrpCzasjMis3syV39IWY2\nATgUeDnt0mhgReR4ZXgu/f3TzGy2mc1et27drmanx5XvtRcT7v5f+h1xBIPPP7/d9M1FwTf+V64+\ngee+dWzc2RORApOzRODuLWa23cwGufvmzvwAM+tPMNPCFe6+Jf1yph+bIR8zgBkQdB/tTD52N5UH\nHsj4P9wGwLaZM2lcvjxr2v79K/nthbXUDNBoYRHpevlUDdUDr5nZE8C25El3/2p7bzSzUoIgcIe7\n35shyUpgbOR4DJB7drU+aK9HHuaNKUeQ2BLESSst3TmiGKC8rIRP7D+ip7InIn1cPr2GHgL+HXgO\nmBN55WTBPMu/Axa7+/VZkt0PXBj2HpoKbO6r7QPtKaqsBGDo5y9l75kvplwr33efnsiSiBSIfKaY\n+IOZVQLj3P2NDtz7KOAzBKWJeeG5qwhWXsfdpxPMWXQKsBTYDlzSgfv3KckSQMnwEW3WHRh83nk9\nkSURKRDtBgIzOx34OVAGTDSzQ4Afuvsnc73P3V8gcxtANI0DX8o7t31Yy8aNAJTvtScAE++9h+Vn\nnsXwK7/DgGPVQCwi8cmnjeAaYArwLIC7zzOziTHmqSBVn3sudU89ReWhhwJQsf/+7Dt/HkXlaiAW\nkXjl00bQnKHHUJ/oubM7GXXtNezzwvM72woABQER6Rb5lAgWmtn5QLGZ7Q18FXixnfeIiEgvkU+J\n4CvAJKCBYHGaLcAVMeZJRES6UT69hrYTrFt8dfzZERGR7pZPr6F9gG8CE6Lp3f24+LIlIiLdJZ82\ngruB6cAtQEu82RERke6WTyBodveb208mIiK9UT6NxQ+Y2RfNbJSZDUm+Ys+ZiIh0i3xKBBeF229F\nzjmwZ9dnR0REuls+vYY0ilhEpA/La4UyM/sIbXsN3R5TnkREpBvl0330j8BewDxaew05oEAgItIH\n5FMiqAX2D2cKFRGRPiafXkMLgZFxZ0RERHpGPiWCYcDrZjaLYL4hANpbj0BERHqHfNcjEBGRPiqf\n7qN/646MiIhIz8jaRmBmL4TbOjPbEnnVmdmW9m5sZr83s7VmtjDL9WPMbLOZzQtf3+/8Y4iISGdl\nLRG4+9HhdkAn730bcBO5u5k+7+6ndfL+IiLSBfLpNdQp7v4csDGu+4uISNeILRDk6Ugzm29mj5jZ\npGyJzGyamc02s9nr1q3rzvyJiPR5PRkI5gLj3f1g4FfAfdkSuvsMd69199qampruyp+ISEHosUDg\n7lvcfWu4/zBQambDeio/IiKFqscCgZmNNDML96eEednQU/kRESlUec0+2hlmdidwDDDMzFYCPwBK\nAdx9OnA2cLmZNQM7gPM0n5GISPeLLRC4+6fbuX4TQfdSERHpQT3da0hERHqYAoGISIFTIBARKXAK\nBCIiBU6BQESkwCkQiIgUOAUCEZECp0AgIlLgFAhERAqcAoGISIFTIBARKXAKBCIiBU6BQESkwCkQ\niIgUOAUCEZECp0AgIlLgFAhERAqcAoGISIGLLRCY2e/NbK2ZLcxy3czsRjNbamYLzGxyXHkREZHs\n4iwR3AaclOP6ycDe4WsacHOMeRERkSxiCwTu/hywMUeSM4DbPfASUG1mo+LKj4iIZNaTbQSjgRWR\n45XhuTbMbJqZzTaz2evWreuWzImIFIqeDASW4ZxnSujuM9y91t1ra2pqYs6WiEhh6clAsBIYGzke\nA6zqobyIiBSsngwE9wMXhr2HpgKb3X11D+ZHRKQglcR1YzO7EzgGGGZmK4EfAKUA7j4deBg4BVgK\nbAcuiSsvIiKSXWyBwN0/3c51B74U188XEZH8a
GSxiEiBUyAQESlwCgQiIgVOgUBEpMApEIiIFDgF\nAhGRAqdAICJS4BQIREQKnAKBiEiBUyAQESlwCgQiIgVOgUBEpMApEIiIFDgFAhGRAqdAICJS4BQI\nREQKnAKBiEiBUyAQESlwsQYCMzvJzN4ws6VmdmWG68eY2WYzmxe+vh9nfkREpK04F68vBv4H+ASw\nEnjFzO5399fTkj7v7qfFlQ8REcktzhLBFGCpu7/t7o3AXcAZMf48ERHphDgDwWhgReR4ZXgu3ZFm\nNt/MHjGzSTHmR0REMoitagiwDOc87XguMN7dt5rZKcB9wN5tbmQ2DZgGMG7cuC7OpohIYYuzRLAS\nGBs5HgOsiiZw9y3uvjXcfxgoNbNh6Tdy9xnuXuvutTU1NTFmWUSk8MQZCF4B9jaziWZWBpwH3B9N\nYGYjzczC/SlhfjbEmCcREUkTW9WQuzeb2ZeBx4Bi4PfuvsjMLguvTwfOBi43s2ZgB3Ceu6dXH4mI\nSIyst33u1tbW+uzZs3s6GyIivYqZzXH32kzXNLJYRKTAKRCIiBQ4BQIRkQKnQCAiUuAUCERECpwC\ngYhIgVMgEBEpcAoEIiIFToFARKTAKRCIiBQ4BQIRkQKnQCAiUuAUCERECpwCgYhIgVMgEBEpcAoE\nIiIFToFARKTAKRCIiBQ4BQIRkQIXayAws5PM7A0zW2pmV2a4bmZ2Y3h9gZlNjjM/IiLSVmyBwMyK\ngf8BTgb2Bz5tZvunJTsZ2Dt8TQNujis/IiKSWZwlginAUnd/290bgbuAM9LSnAHc7oGXgGozGxVj\nnkREJE1JjPceDayIHK8EjsgjzWhgdTSRmU0jKDEAbDWzNzqZp2HA+k6+t7fSMxcGPXNh2JVnHp/t\nQpyBwDKc806kwd1nADN2OUNms929dlfv05vomQuDnrkwxPXMcVYNrQTGRo7HAKs6kUZERGIUZyB4\nBdjbzCaaWRlwHnB/Wpr7gQvD3kNTgc3uvjr9RiIiEp/YqobcvdnMvgw8BhQDv3f3RWZ2WXh9OvAw\ncAqwFNgOXBJXfkK7XL3UC+mZC4OeuTDE8szm3qZKXkRECohGFouIFDgFAhGRAlcwgaC96S56KzMb\na2bPmNliM1tkZl8Lzw8xsyfM7K1wOzjynu+G/w5vmNk/9VzuO8/Mis3sVTN7MDzu689bbWZ/MbMl\n4f/1kQXwzP8W/k4vNLM7zayirz2zmf3ezNaa2cLIuQ4/o5kdZmavhdduNLNMXfOzc/c+/yJorF4G\n7AmUAfOB/Xs6X130bKOAyeH+AOBNgik9fgZcGZ6/EvhpuL9/+PzlwMTw36W4p5+jE8/9deDPwIPh\ncV9/3j8Al4b7ZUB1X35mgoGly4HK8Ph/gYv72jMDHwMmAwsj5zr8jMAs4EiCsVmPACd3JB+FUiLI\nZ7qLXsndV7v73HC/DlhM8Ed0BsGHB+H2n8P9M4C73L3B3ZcT9Nia0q2Z3kVmNgY4FbglcrovP+9A\ngg+M3wG4e6O7b6IPP3OoBKg0sxKgH8EYoz71zO7+HLAx7XSHnjGclmegu8/0ICrcHnlPXgolEGSb\nyqJPMbMJwKHAy8AID8dkhNvhYbK+8G9xA/BtIBE515efd09gHXBrWB12i5lV0Yef2d3fA34OvEsw\n5cxmd3+cPvzMER19xtHhfvr5vBVKIMhrKovezMz6A/cAV7j7llxJM5zrNf8WZnYasNbd5+T7lgzn\nes3zhkoIqg9udvdDgW0EVQbZ9PpnDuvFzyCoAtkDqDKzC3K9JcO5XvXMecj2jLv87IUSCPr0VBZm\nVkoQBO5w93vD0+8nZ3INt2vD87393+Io4JNm9g5BFd9xZvYn+u7zQvAMK9395fD4LwSBoS8/8wnA\ncndf5+5NwL3AR+jbz5zU0WdcGe6nn89boQSCfKa76JXC3gG/Axa7+/WRS/cDF4X7FwH/Fzl/npmV\nm9lEgrUgZ
nVXfneVu3/X3ce4+wSC/8en3f0C+ujzArj7GmCFme0bnjoeeJ0+/MwEVUJTzaxf+Dt+\nPEH7V19+5qQOPWNYfVRnZlPDf6sLI+/JT0+3mndj6/wpBD1qlgFX93R+uvC5jiYoBi4A5oWvU4Ch\nwFPAW+F2SOQ9V4f/Dm/Qwd4Fu9MLOIbWXkN9+nmBQ4DZ4f/zfcDgAnjma4ElwELgjwS9ZfrUMwN3\nErSBNBF8s/9cZ54RqA3/nZYBNxHOGpHvS1NMiIgUuEKpGhIRkSwUCERECpwCgYhIgVMgEBEpcAoE\nIiIFToFACoaZvdjTeQAwsyvMrF9P50MkSd1HRbqImZW4e3Me6d4Bat19ffy5EmmfSgRSMMxsa7g9\nxsyejczvf0dy/nYze8fMrjWzueH87vuF56vCueNfCSd+OyM8f7GZ3W1mDwCPp/28KjN7yMzmh3Pq\nn2tmXyWYO+cZM3smTHeimc0Mf+bd4bxRybz81Mxmha8Pdds/lhQUBQIpVIcCVxDM8b4nwRxGSevd\nfTJwM/DN8NzVBNNZHA4cC/xXOAMoBPPAX+Tux6X9jJOAVe5+sLsfADzq7jcSzANzrLsfa2bDgO8B\nJ4Q/czbBWgtJW9x9CsFo0Ru64LlF2lAgkEI1y91XunuCYFqOCZFryYn75kTOnwhcaWbzgGeBCmBc\neO0Jd0+fUx7gNeCE8Fv9R919c4Y0UwmC0d/De18EjI9cvzOyPTLPZxPpkJKezoBID2mI7LeQ+rfQ\nkOG8AWe5+xvRm5jZEQTTQrfh7m+a2WEEcz/9xMwed/cfpiUzgkDy6Sz59Cz7Il1GJQKR/DwGfCXS\nlnBoe28wsz2A7e7+J4JFViaHl+oIlhUFeAk4Kln/H862uU/kNudGtjN3+SlEMlCJQCQ/PyKoo18Q\nBoN3gNPaec+BBG0JCYLZJS8Pz88AHjGz1WE7wcXAnWZWHl7/HsFMuQDlZvYywZe2bKUGkV2i7qMi\nuyl1M5XuoqohEZECpxKBiEiBU4lARKTAKRCIiBQ4BQIRkQKnQCAiUuAUCERECtz/B9r3QV+rnddD\nAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "lopt = LOpt()\n", "for i in range(5):\n", " losses = []\n", " key = jax.random.PRNGKey(i)\n", " meta_params = lopt.init_meta_params(key)\n", "\n", " key = jax.random.PRNGKey(0)\n", " params = task.init(key)\n", " opt_state = lopt.initial_inner_opt_state(meta_params, params)\n", "\n", " num_steps = int(os.environ.get(\"LOPT_TRAIN_LENGTH\", 1000))\n", " for i in range(num_steps):\n", " batch = next(data_iterator)\n", " loss, grads = value_grad_fn(opt_state[0], batch)\n", " opt_state = lopt.update_inner_opt_state(meta_params, opt_state, grads)\n", " losses.append(loss)\n", " plt.plot(losses)\n", "plt.ylim(0, 4)\n", "plt.xlabel(\"inner step\")\n", "plt.ylabel(\"inner loss\")" ] }, { "cell_type": "markdown", "metadata": { "id": "FH1b4PfrDFcH" }, "source": [ "## Meta-loss: Measuring the performance of the learned optimizer.\n", "\n", "Now we must define our measurement of performance for our learned optimizers. For this, we will define a meta_loss function. This function takes in as inputs the weights of the meta-parameters, initializes the weights of the inner-problem, and performs some number of steps of inner-training using a learned optimizer and the passed in meta-parameters. Each step we return the training loss, and use this average loss as the meta-loss. Depending on what we use, e.g. different unroll lengths, or different objectives (such as returning just loss at the end of training, or validation loss) we will get different behaving optimizers." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": { "executionInfo": { "elapsed": 699, "status": "ok", "timestamp": 1647716634001, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "D_7V3TQHD5Ju", "outputId": "73b63800-2b6b-4f4c-f507-1b1ef8105b4b" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray(2.2990375, dtype=float32)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lopt = LOpt()\n", "\n", "\n", "def get_batch_seq(seq_len):\n", "  batches = [next(data_iterator) for _ in range(seq_len)]\n", "  # Stack the batches along a new leading axis so `jax.lax.scan` can iterate over them.\n", "  return {\n", "      \"image\": jnp.asarray([b[\"image\"] for b in batches]),\n", "      \"label\": jnp.asarray([b[\"label\"] for b in batches])\n", "  }\n", "\n", "\n", "@jax.jit\n", "def meta_loss(meta_params, key, sequence_of_batches):\n", "\n", "  def step(opt_state, batch):\n", "    loss, grads = value_grad_fn(opt_state[0], batch)\n", "    opt_state = lopt.update_inner_opt_state(meta_params, opt_state, grads)\n", "    return opt_state, loss\n", "\n", "  params = task.init(key)\n", "  opt_state = lopt.initial_inner_opt_state(meta_params, params)\n", "  # Iterate N times where N is the number of batches in sequence_of_batches\n", "  opt_state, losses = jax.lax.scan(step, opt_state, sequence_of_batches)\n", "\n", "  return jnp.mean(losses)\n", "\n", "\n", "key = jax.random.PRNGKey(0)\n", "meta_loss(meta_params, key, get_batch_seq(10))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "executionInfo": { "elapsed": 2, "status": "ok", "timestamp": 1647716634110, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "NGHImLJ9FcjO" }, "outputs": [], "source": [] },
{ "cell_type": "markdown", "metadata": { "id": "EFeC6wiTtYPu" }, "source": [ "## Meta-training with Gradients\n", "Meta-training means training the weights of the learned optimizer to perform well in some setting. There are many ways to solve this optimization problem, and we will run through a few different examples here.\n", "\n", "\n", "One of the conceptually simplest ways to meta-train is with gradients: specifically, the gradients of the meta-loss with respect to the meta-parameters.\n", "\n", "We will use our meta-loss and `jax.value_and_grad` to compute these gradients. For this simple example, we use the average training loss over 10 applications of the learned optimizer as our meta-loss." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "executionInfo": { "elapsed": 2555, "status": "ok", "timestamp": 1647716636790, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "hQtTvMmnFAFb" }, "outputs": [], "source": [ "key = jax.random.PRNGKey(0)\n", "meta_value_grad_fn = jax.jit(jax.value_and_grad(meta_loss))\n", "loss, meta_grad = meta_value_grad_fn(meta_params, key, get_batch_seq(10))" ] }, { "cell_type": "markdown", "metadata": { "id": "rgY-bMy1I-p_" }, "source": [ "We can use this meta-gradient with Adam to update the weights of our learned optimizer."
] }, { "cell_type": "code", "execution_count": 14, "metadata": { "executionInfo": { "elapsed": 2078, "status": "ok", "timestamp": 1647716638983, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "Bh0v3tZcFSeJ", "outputId": "8a214403-0cb8-446a-953d-e13e9f337217" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.2990098\n", "2.2865815\n", "2.263659\n", "2.2415566\n", "2.2126753\n", "2.1844268\n", "2.1526406\n", "2.1231513\n", "2.0907454\n", "2.0569248\n", "2.0191243\n", "2.0076103\n", "1.9771802\n", "1.9669901\n", "1.9574435\n" ] } ], "source": [ "meta_opt = Adam(0.001)\n", "key = jax.random.PRNGKey(0)\n", "meta_params = lopt.init_meta_params(key)\n", "meta_opt_state = meta_opt.init(meta_params)\n", "meta_losses = []\n", "\n", "num_steps = int(os.environ.get(\"LOPT_TRAIN_LENGTH\", 300))\n", "for i in range(num_steps):\n", "  data = get_batch_seq(10)\n", "  key1, key = jax.random.split(key)\n", "  loss, meta_grad = meta_value_grad_fn(meta_opt_state[0], key1, data)\n", "  meta_losses.append(loss)\n", "  if i % 20 == 0:\n", "    print(onp.mean(meta_losses[-20:]))\n", "  meta_opt_state = meta_opt.update(meta_opt_state, meta_grad)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "height": 296 }, "executionInfo": { "elapsed": 147, "status": "ok", "timestamp": 1647716639240, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "r8-fAIqGI6X2", "outputId": "94ea45b3-7199-4ef8-97a4-13cf5d0e6cd3" }, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'meta-loss')" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90\nbGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsT\nAAALEwEAmpwYAABLuElEQVR4nO2dd5hcV3n/P2f6zPYqrVZaq8uSmyTLxsbYuGGMCZgWMCSGGPiZ\nJE6AQDDESRxagBSKIQRwDNg4hG66Ke7doGIVq1mW1bXS7mr7zk4/vz/uPXfulN2dlXa2aN/P8+jx\nlDt3ztVY53vfrrTWCIIgCLMXz1QvQBAEQZhaRAgEQRBmOSIEgiAIsxwRAkEQhFmOCIEgCMIsxzfV\nCxgvjY2NeuHChVO9DEEQhBnFxo0bu7TWTcXem3FCsHDhQjZs2DDVyxAEQZhRKKUOjPSeuIYEQRBm\nOSIEgiAIsxwRAkEQhFmOCIEgCMIsR4RAEARhliNCIAiCMMspmxAopRYopR5RSu1USm1XSn2gyDHX\nK6W2KqU2K6U2KKVeUa71CIIgCMUpp0WQAj6stV4JXATcopRalXfMQ8B5WuvVwLuBu8q4HvpjSX76\n3GFS6Uw5v0YQBGFGUbaCMq11O9BuPx5QSu0EWoEdrmMGXR+pAMo2HEFrzYd/uIUHdhwH4HXnzmPj\ngR4uWFiPx6PK9bWCIAjTnkmJESilFgJrgD8Uee+NSqldwK+xrIJin7/Zdh1t6OzsPKk1/HjjYR7Y\ncZygz8PdTx/giw++wNvufJbvrz90UucTBEE4XVDlnlCmlKoEHgP+VWt93yjHXQbcrrW+erTzrVu3\nTp9Mi4loIsUP1h9CAR//pWWUeD2K5qogn37D2VyxolksA0EQTluUUhu11uuKvVfWXkNKKT/wE+C7\no4kAgNb6caXUEqVUo9a6a6LXEgn4uOmSRaQzmkjAx/4TQ5x/Rh1/9d1NvOeeDdx67Qr++vKlE/21\ngiAI056yCYFSSgHfBHZqrb8wwjFLgb1aa62UWgsEgBPlWhNYVsBbL1jgPP/jbVdx64+3cseDe2is\nDHLt2XOpDvnLuQRBEIRpRTljBJcANwJX2umhm5VS1yml/lIp9Zf2MW8GnldKbQa+CrxNl9tXlUdt\nJMCn3nA2i5squfXHWzn/Uw/wyy1HJ3MJgiAIU0rZYwQTzcnGCMZCa83GAz187je72Hyol3DAy5ff\nvoYrVjRP+HcJgiBMNqPFCKSy2EYpxbqF9dz1rnXccOECIgEvX3zgBWaaUAqCIIwXEYI8aiMBPv2G\nc/jg1cvZeriPS//9EZ56sYuO/hiZjIiCIAinHzNuQtlk8aa1rexs7+f+bce47afbONIzzKvPnstX\nblgjaaaCIJxWiEUwAkGfl09efzYfuGopB05ECfg8/HprO5/81Q7e/LWn+e3z7QAMxlNc+flH+daT\n+6Z4xYIgCCeHWARj8KfrFrD5UB9vv3AB3/vjIe5+ej8AiVSGa89u4d9+s4uXOof46iMv8u5XLJra\nxQqCIJwEIgRjEPJ7+fxbzwNgZUs1sVSaVDrD77Yf5+9/tIUfbzwMwHAyzZcefIHVC2q5XDKNBEGY\nQUj66EnQPZTgos8+RCKV4S3nz+fixQ18+EdbAFjVUs39H7h0StcnCIKQz5S1mDhdqa8IcP/7LyXg\n9dDWEOHgiajz3o72fnYfG2DF3Kqin+2NJsho6xyCIAjTAQkWnyRLmytpa4gAsKA+TFt9hNedNw+v\nR/G+ezfwm23tBZ9JZzQ33Pks77s316J5/IVO3nvPBr704AuTsnZBEAQ3YhFMAEopfvfBy/B7FZcu\na+SLD7zAVx5+kdec0+Ics35/N7/e2s6uYwN4FDy86zgepbh8RTP/9fCL/HF/N4+90MH7r1wm6amC\nIEwqIgQTRDjgBeCt6xbQG03wmft3caR3mN3H+okEfHzg+89xv
D9OdchHfyzFe+/ZQCTg44lbr2Dr\nkV7Cfi/DyTRH+4aZXxehN5rgaG+MVfOqp/jKBEE43REhKANXrZzDZ+7fxYM7jvNfj7zIYCzFcDLN\nbdedyfWrW7n03x8hkcowGE/x8V9uJ5bM8K6Lz+CeZw5wx4N7iAS89MdSPLjjOFv+5RqxEARBKCsi\nBGVgSVMlixsruPfZA3QOxAEIeD3ccGEb1SE/r1jayLG+GJUhHz/fbHU6fesFC7jnmQP8yE5H9XsV\nybTmSO8wC+ojU3YtgiCc/ogQlImrV83hzsdfAqClJsTaM+qcOQdfefsa0lrTPZjg8v98FICz5tVQ\nXxGgeygBQDJtpfXu7RwUIRAEoaxI1lCZuOpMq6gsEvDy8Icv5463rXbeqwj6qA75WdhYwYMfuoyf\n3XIJAEubKgG48sxm5laHALjriX3c8t1NxJLpUb9v2+E+rvzPR+mLJstwNYIgnM6IRVAmzj+jjtqI\nn5Vzq51AcjGWNmfrDV6+tIFEOsNd71xHMpPhos88xJMvWlM7GysDfOL6swE42jvMTzYe5m0XLOBQ\nT5Tzz6jnj/u7ealriH0nhlgdqS3rtQmCcHohQlAmfF4P3/jz86mNlF449sGrl/PBq5cDEPR4Wdpc\nyfr9PTRVBfnOswd4/1XLaKgM8t+Pvsj/PnuQzz9g1R089bErae8dBqAnmpj4ixEE4bRGXENl5GWL\nG0asMC6Fpc2Wq+iTrz8LrXGsgw37e3KOe7FjkPa+GAA9QyIEgiCMDxGCaczNly3hjhtWc81Zc6mN\n+HliTxcnBuPsOjbAR169gvX/eDUA+zoHOdpnLAKJEQiCMD7ENTSNWdRYwaLGCgAuWdrI4y908srl\nTQBcvKSBxsoAVUEf+7qGOGq7hnrFNSQIwjgRi2CG8KqVc+gYiPOfv99NfUWAc1prUEqxqKmCF44P\n0mHXK7hjBH3RJB/64WbJJBIEYVRECGYI15w1h8qgjwMnorxxTSt+r/XTLWqs4Nl9JzDdxHuGspv+\n+v3d3LfpCI/v6ZyKJQuCMEMQIZghRAI+Xms3sXvrugXO64saK3CPlHBbBObxnuMDgDVV7fO/3y0W\ngiAIOYgQzCA+/OrlfPUda3MykS6zYwYAS5oqcoLFRgheOD4IwOZDvXzl4Rd5YOfxSVqxIAgzARGC\nGURzVYjXntuS89ratjqe/OgVfOPG81nbVpeTPmpE4YUOyyIwAeUjPcMF59Za851n9tPeV/ieIAin\nNyIEpwHz6yK8+qy51FcEcl1DtigcOBElnkpzxBaCwz3RgnM8f6Sf23++nX/+2fZTWsutP97Cx36y\n9ZTOIQjC5CLpo6cRtZEA8VSG4USacMDriEI6o3mpc8i52z/cM0xvNME1X3ycZXMq+eo71vL0XqtY\nzXeKLa93HxsAJW2zBWEmIUJwGlEXsbqb/nzzES5cVE/PUJK6iJ+eaJI9HYMc7bWqjw/3Rtl1bICO\ngTgdA3F+vvmok1nkPUUhiCUzZNzRa0EQpj0iBKcRy+dW4fMoPnbfNue1q85s5tEXOtlzfMCJEbT3\nxtjfNeQcc7A7yvp9VtuKjoHYKa0hlkqTSosQCMJMQmIEpxFr2+rY8i/XcPdNFzivzakJcUZDhBeO\nD3Ckd5iw30sqo1m/vwevR9FYGWTDgR4S6QyAM0jnZIkl0wwlUqd0DkEQJhcRgtOMiqCPS5c1EfJb\nP21dxM/y5io2HexlIJbi/DPqAHhmbxfzakO01oXZcbQPgPMW1E6AEGSIxkefnSAIwvRChOA0xOtR\nzKsJA1AXCbB8TqWzwb9iWSMAR/titNVHmFMVdKahrVlQy1AizVDcuqPfdriPvuHc4rOvPLSHR3Z1\njPjdsWSaRDpDIpWZ8OsSBKE8iBCcpjRXBwFLCJbOyRagvf2CNla2VAPQVl/hHOfzKFbNs17/yabD\ndPTHePPXn+Ybj+11Pqu15
quPvshPNh0u+p2ZjCZuC8BwQqwCQZgpiBCcpsyxR12mteaiRfWsaavl\nZ7dcQk3Ez5VnWtXIXg/MqbKOm1sToqXGenz7z7fz3u9sIJHKsOvYgHPOE0MJYsmM0+Aun7jLCpA4\ngSDMHEQITlNuuKANgNULammuDvHTv76E1QtqAXjL+VavomvPanEEo7U2TFNV0Pn81sNW3GCPXZUc\nT6U51G0VonWNIATuucpREQJBmDGULX1UKbUA+A4wF8gAd2qt78g75s+Aj9pPB4G/0lpvKdeaZhMX\nL2lg32evQxUp7lrUWOG89+huy9/fWhem2bYO3BzuGWY4kea6Lz9B0GfdN7gtgrueeIl5tWGuO6eF\nWCorBEMSMBaEGUM56whSwIe11puUUlXARqXUA1rrHa5j9gGv1Fr3KKVeA9wJvKyMa5pVFBOB/PfM\n5t9aG6Yu4ucfXnMmT+89wWMvWAVmWsPGAz3sc9UdDMZTRBMpook0n/71TgD2f+61xJLiGhKEmUjZ\nXENa63at9Sb78QCwE2jNO+ZprbUZwPssML9c6xGK09YQoaEiwNq2OpRSvO+VS7h8hRVDMO0mfr/j\nWMHnOgfi/Oy5Izmv5biGilgEdz6+l7+8d+NELl8QhAlgUiqLlVILgTXAH0Y57D3Ab0b4/M3AzQBt\nbW0TvbxZTWXQx8Z/flXOa2Y85tq2OjYd7OH32wvbVncOxLlvkyUEPo8indEMu4SgmEWw+VAvGw50\nT+TyBUGYAMoeLFZKVQI/AT6ote4f4ZgrsITgo8Xe11rfqbVep7Ve19TUVOwQYQIxQrCosYLzFtRy\nrD/bdmKenVm09XAfO9r7WdgQIZXRHOuP5QWLCy2CwXiawbi4jARhulFWIVBK+bFE4Lta6/tGOOZc\n4C7geq31iXKuRyiN1towc6qDrG6r5ZKlVgHaosYK5teFnUE4P9po1RL8+UVnAPDb54+x/UhW54eK\nbPjReIpYMkMqLcVmgjCdKJsQKCsa+U1gp9b6CyMc0wbcB9yotX6hXGsRxofP6+Hpj13FDRcs4JIl\nDYAlBL9+/6V8/PVnAbCzvZ+51SGuPLMZgE/9agf/ev9O5xzFLQJLHIak2EwQphXljBFcAtwIbFNK\nbbZfuw1oA9Bafx24HWgA/tvOYklprdeVcU1CiZh21Gva6qiL+FnVUk1N2Gpz3VYf4WB3lD9dN5/W\nunDRzxeLEZjXhuIp51yCIEw9ZRMCrfWTwKjN7bXW7wXeW641CKdOwOfhd393GdWh7MZ9z7svJJHK\nsHxO5YgpqtF4Gm3PJTDHmNqCYm4jQRCmDqksFsakuSpEyO91ni9qrGDF3Cpng//ry5fkDCWrDPoY\nSqT4yI+38q5vr3deNwJQLGB836bDvP3OZx3xEARh8pDBNMIpc+u1Z1IT9vPZ3+wCoL4iQOdAnD/u\n6yaRztA1GKc27Hd6ERWrOv7QD62C8oF4Ksf6EASh/IhFIEwIdRWBnMdP7OkinsqgNfzsuSMcsPsU\nQdYiONQd5f5t7TmdSk8MJpzHx/pifP2xvWIlCEKZEYtAmBDqI1khaKkOsQWIBLx4PYpP/3onX3e1\nsx6Kp0hnNO+7dyM72vtzAsddg3G2Hu7lA9/fzLyaEEf7YrzuvHm01hYPSguCcOqIEAgTQn1lVgj+\n863n8cotTbTUhDjUM8w//+x5ulx3+k+92MXPNh9hR3s/V6+cw4M7s5XLJwbjfOeZA4A1PAckuCwI\n5UaEQJgQ3BZBZdDH2y/MtgLpiyb4z99ny0Tue+4IQZ+Hmy9bzD+85kyGEmn2dQ7xuv96kq7BBMvn\nVLLxQI9z/InBBFsSvZxnt9EuhXRG85WH9/DOixdS73JbCYJQiMQIhAmhbpTN1j3nwLB6QS23XbcS\npRSVQR8r5lpT1E4MJgqmm/33oy9y/Vef4uebj/DVR14knRk7ZrD72ABfenAP929rH+eVCML
sQywC\nYUKoDo38v1KxOQf5Pv+Az0NN2E/XYJzevDnJZiDOB76/GYCLFjdw/hl1o66nc9CamdDeNzzm2gVh\ntiMWgTAhjDb7oJhFMK9I8LexMsCJoTh9w0kuXdbIQx9+JQDdQ4mc4w6cGCr4bD5milp7b2yMIwVB\nECEQyk6pQtBQGaRrMEFfNElN2O/MUO6PWcHi61fPA8gZkjMSXbZFcFQsAkEYExECYcK49doVfMJu\nSuemviJAvsEwr7bQXdRYGXBcQzVhP2G/1+l5tHxOJXfcsIa2+khRIejoj/GWrz3NkV5r4++0LYKj\nYhEIwpiIEAgTxl9fvpR3vXxhwet+r4f6SACPSwyK1QU0VATpGrBcQ7URP0opquzYQ23YCkYvaqwo\nKgTffno/Gw708MP1h4CsRXCsL0ZmlODyNx7by/ajfSVfoyCcjogQCJNCU1WQikA2oNxSRAiaqoL0\nx6xiM7PxO0IQsYrOFjVWsL9rqKDa+HCPZQk02PUMpm4hkc5wIi/GYMhkNJ/9zS5e++UnT+XSBGHG\nI0IgTApNVUEiwWzjuspgYZbR4qYK57GpNq4KWv91C8FQIs0r/u0ReqPZDd5kFg3Y8YTOgTgBr/W/\n99He4nGCaFLmIggCSPqoMEmsbasj6PPynXevcDbtfJbPqXIe19gbv7EI6uyCtUuXWRPTjvQO81LX\nEGvbAmQymhc7BgHose/+uwbjrJxXzZZDvbT3DRctRosWmZkgCLMRsQiESeHvXrWcu961jhVzq7h6\n1ZyixyxsKGIR2J1IjTAsbqrk//7fywCcGckHuqNOI7ueaJJUOkN3NMF582uAkQPGUVcXVGlsJ8xm\nRAiEaUPAl/3f0biCqvMsAsCZjRBPWm2t212un95ogu6hBFrDsuZKgj7PiK4h9xS13miy6DGCMBsQ\nIRCmJVmLwGQNZTuUhm0hMBaBmYFcHfLRE03QbccO6iuCtNaGae8bwSJwtbIwwWZBmI2IEAjTiqvO\nbAZwZQ2ZYHGhRRBLWRu58fW31kXoHU7SP2w9rwn7aakNjVhUlisExeMWpTAQSxZUPwvCTEKEQJhW\nfPXP1vKbD1xKOGBt9vnpowAhv/W/bcx2DZn4wPy6ML3RJAOxpPPZlprwyFlDrvbWp2IRfOKXO3jv\nPevHPlAQpikiBMK0IuT3srKl2nneWhcm4PUwpzpbiRzyWSIxGEvxv88ecFJGW2vD9EYT9A1nhWBe\nTYiOgTjJdKagsGzIZREc6z/5CuRjfTH2nzh5i0IQphoRAmFa85qzW3j0I5fnzBQwrqFHdnfwTz97\nnsdf6ASsthUZna0bqAr5mVcbRmurUd3i2+7nfx5/yTnPsO1S8nmUY0UA3Pn4Xp7Y01nyGgfjKXqi\nCVLpTMF7O472890/HBjHFQvC5CNCIExrvB5V0KAuaGcXmTYSR3qHCfu91FdYze0O2nUKVSGfU8G8\nYb816OZf79/ppIoai2BOdcixKgC+9uhefrrpSMlrHIqn0LqwSyrADzcc4hO/2CHpqcK0RoRAmHF4\nPIqAz+NsvO19MSqCXursOMLB7igBr4eQ38s8u4Pp1iPZfkKbDvYCVoxAKWiuDuYIwVA8zcA4xmOa\nUZpmBoKbeCpDIp0hniq0FgRhuiBCIMxIQi4hSKQyRAI+J7Po4Iko1WEryGwsgq2He53Pvu/eDezr\nGmIokSbi91Id8tNvu4YS9sY96BKG4USaeCq3HcVQPOUEqc1/3XOZDXE7xbV/WOoUhOmLCIEwIwn5\nvbhjv5GAl2Z77sHRvpiTdloZ9FEd8rGrfQCAH77vYroGE/xqy1GiiTSRoI+qkM+xCMzd/UA8u3G/\n5571fOKXO3K+/9afbOWD338OrbXjYjLDcNwYS6A/JkIgTF/GLQRKqTql1LnlWIwglIoJGBsqgz6a\nq7MDcKpcozPn1YZJZTRej2LdGXWE/B4G4imiiRQVAS9
VIb8TLHbu8l0WwaGeaE71MlgB6SO9MeKp\njDNDuauoa8gSib5h6WskTF9KEgKl1KNKqWqlVD2wBfi2UuoL5V2aIIyMqSUwRII+gj6vk12ULwQA\nDRUBPB5FZdDPQCzFUDxNOOCjOuxzpqDlu3vAcg0l8jKCovE00UQq57jiQiAWgTD9KdUiqNFa9wNv\nAr6ttT4fuLp8yxKE0cm3CCrsAjRTb2DaVwPOyEszMrM65GMglmQ4aVkE1SE/iVSGWDKddQ25LIJo\nIu30NXJeS1pCMpQjBIUxgpjECIQZQKlC4FNKtQBvBX5VxvUIQkmYojJDxB56M9d2D5lgMWQtgsZK\n673KkI/BuLWRmxgBWJu/ucOPpzIkUlYRWtS2CG776TY+95tdwElYBCIEwjSm1HkEnwR+BzyptV6v\nlFoM7CnfsgRhdIJ5rqEKe+jNXPvu3wSLodAiMMHhaCLF3OqQSwiSDLlaUw/FU873JFIZnjvY6zS/\niybSDCfTTiwhEvA6c5LdGEuiPyYxAmH6UpIQaK1/BPzI9fwl4M3lWpQgjEW+a8hYBI5rqEiMwLEI\ngj66BqK2RWC5hgA7bpDdsAdiKVIZu+V1KoNSEEt5yGQ0w7bLx9QOzKkO5TSxM5hgsVgEwnSm1GDx\nv9vBYr9S6iGlVJdS6s/LvThBGInCrCHbIqgutAhabSHIWgR+O0aQpiLgc47tjyVzCskG4kmG7c09\nkcoQT2aIJTOOCAB09FtC0FARKKg1gGxjPAkWC9OZUmME19jB4j8BDgPLgY+UbVWCMAYhX17WkLEI\nagotgvl1Yf7tzefwxjWtgGURDMStu/9IwJsTI3BbBIOxFNGkiRlYRWXxZDpnoE2H7Q6qrwg4m76b\nrEUgriFh+lJqjMDcXl0HfE9r3a2UKtOSBGFsCrKGbItgaVMlPo/ijPqI855Sirdd0OY8r7aDxVpb\ngpEbI3AJQTyFz2sJTjyVAQ0xb9qxEgA6BqyupQ2VxS0CSR8VZgKlWgS/VErtAtYBDymlmoCT79sr\nCKeIqSMwk8yMRbCgPsLGf3oVL1vcMOJnK0M+TA+4lpow1fY5+odzs4AG46lc15D9xx1Q7nRZBPFU\npqC5nDtraOOBHt7/vefYdriPUuiNJrjtp9tyhEcQykFJQqC1/hhwMbBOa50EhoDrR/uMUmqBUuoR\npdROpdR2pdQHihxzplLqGaVUXCn19ydzAcLsxFgEJiPIWASQHXQ/Eu74wbzaMJUBH0plLQIjMh//\nxXa+v/4gkG0eF0umGU66XEP9cUJ+D5GAJS7uwrNkOlt13Dec5IEdx/nFlqPccOczzlS10fi33+7m\n//5wkF9uPTrmsYJwKpQaLPYDNwI/UEr9GHgPcGKMj6WAD2utVwIXAbcopVblHdMNvB/4z3GtWpj1\nGCE4a14NPo+izeUKGovKoDujKITHo2ioCNDeF2Mwnqalxgou90ST/Gpre85nY6lMTnZQx0CMyqDP\naY3t7jLqftwfS+XMWN51bGDMdXbabqfqUKkeXEE4OUp1DX0NOB/4b/vPWvu1EdFat2utN9mPB4Cd\nQGveMR1a6/WAOFCFcWE23tULanj+E69maXNVyZ91B5JN3cF582vZdLCHwXiSxsrASB8lndE5gd+e\naJKKoC87R9mVUWQ6j1YGffQPJ3OsgO1H+8dc5wm7u6pxewlCuSj1/7ALtNbnuZ4/rJTaUuqXKKUW\nAmuAP4xjbe7P3wzcDNDW1jbG0cJswGy8kYCvIHA8FkYIGiuDBO0K5bVn1PHQrg4S6QxLmipH/Xz3\nUG7hWI5FkCy0COorAhzsjtIbTbK4qYITgwl2HB07TmDabKcyMstAKC+lWgRppdQS88SuLC4pgqWU\nqgR+AnzQTkEdN1rrO7XW67TW65qamk7mFMJphtn83bGBUjExgtba7Bzk88+oA+BQ9zAVwdHvj7qH\ncg3YebVhgv5s4Zn
BWAemEV7HQJyKgI+z5lWXZBF0D5p5CzLdTCgvpQrBR4BH7C6kjwEPAx8e60N2\nbOEnwHe11ved/DIFIRcT0D0Zt4mJEbhHYJ43vxafx0qJrgr62P6JV/OhVy0v+vl8i2BBXcSpa8hx\nDdmi0GALQedAnEjAy1nzqtl1bICO/tET70xxW7LILOTpzq+3tnPxZx+akWufjZSaNfQQsAwrsPt+\nYIXW+pHRPqOsQoNvAju11tKyWphQIgFjEYxfCIxryASFAcIBL//+lnO56ZKF3HjxGVQEfU5GUj7d\n0VyLoK2+uEXgCEFlrhC85fwFeJXiff+7ccRZxu7XZ6JraF/XIO19sZwqbGH6Muq/IqXUm0Z4a4lS\nijHu8i/ByjTappTabL92G9AGoLX+ulJqLrABqAYySqkPAqtO1oUkzB5evqSRf3rtSlYvqB33ZyuD\nPq5ZNYcrz2zOef1Na+fzprXzneeNVcH8jwKWRRD0eZyNvq0haxHEiwSL6yus8yTSGSJBHyvmVvG3\nVy3l33+7m+6hBA2Vhd/T6xKb5Ax0DSXT1ppT6Zm39tnIWLdTrxvlPQ2MKARa6yeBUcuPtdbHgPmj\nHSMIxQj5vbz30sUn9VmlFHe+c92YxzVWjCQEVqZQPGX58BfURZxxlaaoTClFLM81BBCxLYc5VZa1\nMRRP01AkNn18IOs2Ss5Ai8C4hFLiGpoRjCoEWuubJmshgjDdaKwqnkbaM5QgEvDSPWQ9n18XYf8J\n68nN924gmdbs/9xrHYugziUExpVl/uuejezmWJ9LCFIzbzM1QpA/2U2YnpzMzGIZTCPMCky2T8Cb\n+8+kO5pwYhRgxRdMFpNxiZiWFJCNEZhjIRunGBxhToFJHXWfcyYhrqGZxbiFgLyiMEE4XQn6vFSH\nfNTmtaxIpDJEAj4uXtzAgvqwfWzuP6WD3UMFWUOQHalpMpf6hpM5d/+GHneMYAa6howlMBMD3bOR\nkylZfG7CVyEI05S2hggBr8dpN22oCvm49z0vc57nF7W91DlUUEcA2XTXStsi+MrDL7LtSB8PffiV\nOYVsvVGXRTATg8W2CM5Ea2Y2Mm6LQGv97nIsRBCmI9961wV88vqzC15vysv0ybcIfrjhMN98ch8A\nVUG/414yLqUq2yLYdsSqMH5wx3Eg61vviSaorwjgUTOzjiAbLBYhmAmU2nRumVLqx0qpHUqpl8yf\nci9OEKaa5uqQM+LSTVPV6ELw4M7j7OuyAshBv8epgI4Ecy0Cw0O7OnhyTxfnfPx3dA7E6RlKUhvx\n4/d6Ttk1tK9rKGfOwmRgLIGZ6NaajZRqEXwbq8lcCrgC+A5wb7kWJQjTiYCv8J9JvhD4vB6nMjmf\noM/jZAmZ9NGw34v78A37u3lw53FiyQz7uoboiSaoiwQsIUhpUukM33lmf9HhN6Ohteb1X3mSu5/e\nP67PnSoJsQhmFKUKQdiuLlZa6wNa648DV5ZvWYIwfTB3+35vdufOFwL3cfkopaiwYwMR2zJQSuW0\nw85o+JU9d+B4f4yeaNIWAkUqk+HBnce5/efb+Y/f7h7X2gfjKQbiqZwspMlA6ghmFqUKQUwp5QH2\nKKX+Rin1RqB5rA8JwumAsQjcAeH8GIH7/U+94Wy23H5NznuOa8jVG8k0vztvfg0AXXaTueP9MXqj\nCeqMayidweux1vDEnq5xrb1v2Mo+Gq8lcapIHcHMotSsoQ8CEaw+Q5/Ccg+9s0xrEoRphc+j8Cgr\nnXQAy9feXD2yRdBUGaAm4ufJj17B/q4o4Cokc9UfGItgSVMlHQNx2u00UssiSFBXYbmGEint9OzZ\nfXzsgTZuHCFITu6GbDKdxDU0MyjVIliotR7UWh/WWt+ktX4zds8gQTjdUUoR8HmcjqcATZWFDelM\n4zkzA3l+XYRXLGsEspt+2CUExkqojQRY1VLtvL7/RJRYMmMHiy3XkDvYe6g7WvLaj
RDEJrk6WeoI\nZhalCsE/lPiaIJyWBH3enBhAdbjQmDbvV4cKZyZnLYLs5yrt4+or/KyaV+2cY7c9xrIuEsBnu4bc\nQrDpYE/J6+53LIKpcQ1JHcHMYKzuo68BrgNalVJfdr1VDUxuPpogTCGWRZC9m7e6rOdiLIKacKEQ\nFLMITC1BXUWAc1preOrFLmojAR7e1WG9bscIEinNUNzayL0exWMvdHLHQ3v4ytvXcNa8mlHXnY0R\nZO/Mo4kUAa8Hn/dkGguURlIsghnFWP8nHMVqEx0DNrr+/AJ4dXmXJgjTh4DXM2JWkCE0ikXQUhOi\nNuLPOYcRh/pIgHPn13LfX1/C4sYK5/26SICAcQ0lUoT8Hs6cW8XPnjvCS51DPLyzY8x1FwsWr7r9\nd/zVdzeN+dlTwakjmKYWwff+eJAfrj801cuYNozVfXQLsEUp9X/2sW1a6/HlrwnCaUDQ7yHo8/KJ\n15/FCOUCjkWQXywG8K6XL+QNa1pzLAlzXG0k24LCPQOhrSHiuIYG4ykqgz5WtWTHXG49Mvbc43yL\nYNB2MT1gVzIDtiXiH9O6GA+J1PSuI/jhhkP4PR7eesGCqV7KtKBU2/BaYDPwWwCl1Gql1C/KtShB\nmG4EfV5Cfg/vevlCbrx44QjHeKgK+vAWUYqQ38uc6twAs2MRuHoRLbItgtuuO5OWmjB+ryKZ1kTj\nKSqCPieWALD1cO+Y687PGtpvVzu7i+T+7K4/8NovPznmufYcH+APL50Y8ziY/q6hRCpDXFJbHUpN\nH/04cCHwKIDWerNSamF5liQI04+/u3pZ0Tt9NxUBLzWRQrfQSJjsIrcQvGrlHB75+8sdQfB7PQzG\nUwzG00QCPs5ute7az5pnWQbH+2M0VQb51lP7uH51a0GhW/+wZQEY15CZm9DqmtdcKq/64uPWOT73\n2jGPTWWmt2sonsqQzkzPtU0FpQpBSmvdVyxAJgizgWvOmjvmMX9z5dKCLqWj8YbV86gN+3M2b49H\nOSIAOAVlQ/EUlUEv686o4+t/vpbaSIAb7nyWrYf78HkVn/71Tl7sGORzbz435zvyXUPGIig2jzmd\n0UWtGci6lEol23209LvuWDLNlx7cw99eufSkZlGPh0QqQ8Yr+5mh1L/t55VS7wC8SikzxP7p8i1L\nEGYeS5urWNpcVfLxDZVB3nz+6JNa/V5FKq0ZSqSorwiglOLas1ucjfmF4wNOu+tim2eBEJywahDM\nhu8OIrf3DTO/LlJ0Hc/uLc0lZEicRIuJzYd6+fpje7lgYR1XrZwzru8bL4lUhnRGhMBQaozgb4Gz\ngDjwf0Af8IFyLUoQBAuf10PCtghyahCCPlprw+w+NsAfXuoGig8Iz68jMBaBCeYOuCakHTwxcqHa\nE3s6gcJpbSNxMnUE0YS1FjP/uZzEU2lpf+GiVCFYZf/xASHgemB9uRYlCIJFwHENpZ1KZMOyOZVs\nOdzrFJiZu383+RbBAbsq2WyCbiE4MErF8k67yE0z9saezmiM+308wWJTKzEZLbMTqYwjhkLprqHv\nAn8PPA/I354gTBKOa8jOGnKzfE4Vj+7udJ73x3KFIJPROUKQzmhODFoxjKxFkP3MwVGE4LD9XjKt\niSXTxJLpnLRXN+64wHjSR40ATIYQxFMZ0lqCxYZShaBTa/3Lsq5EEIQCfF4PiZRVUOZ2DQEsa7ZG\nW7bWhmmpCRVYBId6oqQymrb6CAe7o3QMxJw79WQRi2Ak11AilaG9P0bQ5yGeyvCvv97Jvc8e4E1r\nWvnC21YXHu8SgvG4hoxLaLyB6fGSyWhSGU1Ga7TWRavEZxulCsG/KKXuAh7CihMAoLW+ryyrEgQB\nsFxDA/EUGV0YDF4+xwpMv2HNPHYfG2Bv5xBLb7ufO25Yw9YjvQR9litp9YJaDnZHOdprdTf1KGtz\nv/Pxvew5PghAbcTPsf5YwfcPJ9Ls7RxEa1jaX
Mn2o/281GV95rfbj/GFImtOulwu43ENRSfJIjBC\nldFWmqtfsodKjhHcBKzGKix7nf3nT8q0JkEQbPxe5bhxKvNiBOe01vCP163kPa9YTHXYz76uIVIZ\nzWfu38k3HnuJLz+0B4+Cc+15B+19wwDMrQ6RSGX4zP27+NHGwwAsbarkeBEheON/P8WffMUqNltq\nWyCmNiGaSDsZS27cVsCG/T28++71JaWRZi2C8gaL3S25JyJOsPVwL7f836YxM6RiyTSf+80uJyg+\nnSjVIjhPa31OWVciCEIB7sZwkTzXkMej+H+XLQZyG92573AXNVY477XbFsGcmlBBK+slTZVsPdxX\n4CrZdSw7/2BpkyUEvcPZaWcnhhIFxWnuTX9Hez872vvpGIiPWcRmLIFyb5TxdFZo4qkMFYWjJcbF\nTd9ez4mhBP943UrmjXKNP9pwiK8/thel4KPXnnlqXzrBlGoRPKuUWlXWlQiCUIDfJQSjFVnlCkH2\nMytbqp0eSEdti6ClJlRwJ7y0uZJEOkNPtDDzyNDWYNUY9LqO6R4sHIFZLC2zv0hGUz5DiclxDU20\nRXDCHgM6UjGewfwuneMoOpwsShWCVwCblVK7lVJblVLblFJby7kwQRDA79pc8tNH3bg7nro3pIsW\nNzgdTx2LoDqU05Y67PfSWmfdyea7h6pc4mMskoFYymmLcWKocFMr5gZyB6UNHQMxbv/5844FEI2P\nP1g8FE+Nu1WEW6gmMoU0NcY6jFgXS/Odakp1DV1b1lUIglAUv6s5XGOROckGt0Vg7r6/9RfruGJF\nM4+9YKWYtvcNUxXyURHw5QhBVcjHHHv05rH+GCtbqrn1x1sYiFmD7y9f0cRNlyzKKVhrrQ3TPZSg\ne6jQIjBjKt0Uswje/73nePalbi5b1sTVq+a4LILSYgTpjOaV//EIf/eq5fzZy84o6TOQu/lP5Czn\nzBhCYAS6bxSra6ooSQi01gfKvRBBEApxu3mK9QcyuIWgy3bX1EaslhQme+hoX4wGew6ym6qQj+Yq\n69wdtkXwww2HnfevWjmHVy5vyuk82lobZtuRvqJCUMw1NBDP3fz6hpM8a1dEH+qx4hXjrSOIJdN0\nDSY42jtc0vGGeI4QnJpF4F7rWJaJsRimo0VQvhFFgiCcMu7Ab7HJZ4Zq13tmIzZ1B0F/1jddXxHI\naUEN1sjMZtsiON5f6OqptwvHgq4Jbc3VQfxeRddggo6BGBv2dzvvFXMNmUwjwzN7u5zH++y2F9Fx\n1hGYjKXxuncSEygE7iK8sVxD5u/FHWyfLogQCMI0xty9B32eUQufiolExB6L6Z6KVl8RLBCC6pCP\noM9LfUWgaAppXYW/4DzhgJe6SIDuoTh3PbGPm+5ezx0P7mHhx35ddHMdyKt67h6yns+pDvJSpyUE\n4w0Wx+zvGa8QuN1BpxojcAtBZoxKZWMx9E5D15AIgSBMY4wQjBYfAGu+cT5ZIcjeyTdUWOMv3ZhA\n85zqUFEhMIHhHCHwW8LRPZTgxGCCwXiKOx/fC+CkpoZdFkR/XrDY3BWvWVCXtQjs2EA0mR7T3w4u\ni2CczeMmMkbgdkuN6Rqy6ytO1QopByIEgjCNMa6hhsrifX0MzdUh7rhhNVed2ey8ZtJN3Rv43JpQ\njkXwzovP4K8uXwJAY2WAziLpoHVFXENhv5fGyiAnhhL0x5JonR2z+WLHoP392eMHYkliyTS/334M\nsAKmQZ+HVfOqOdI7zHAizVAihUeB1jBcpFAtHyME491Y3UJwqhZBzJWKOpYQJDOlHzvZiBAIwjTG\nZJqMZREAXL+6lZZaK+irVFYATIwAYNW86hwhuPbsuc7Us6bKIF0D8YJNqjZS3DVkLAKTEWQEY0/H\ngHOMoX84xWfv38nN925k08EeeqNJaiN+ZwjP3s5BYsmMM6SnFPeQ2YTHKwTu40+2FXU8leZQdzTH\noijVIgDoi
U6vOEHZhEAptUAp9YhSaqdSartSqmB+gbL4slLqRbs+YW251iMIMxGTf+8eZzkaJkBc\nEfA5MQW3a2hVSzUBb/Z5yHWX31gVpGswnlPZWxn0OZ93C0HIuIYGE47bx1gAL9j9iyL+bFJifyzp\n9DI63hejdzhBTdjPXDsT6oDd8M5kL5USMI5PRLDYdUf/7Esn+NXWoyWd4wfrD/HqLz3uBLih9GAx\nwIkiltdUUk6LIAV8WGu9ErgIuKVIdfJrgGX2n5uBr5VxPYIw4zBto8dyDRnMXXjEdTfu3sDn14Vz\nMpHcfvzGygDxVCZn3KYJFFvnyXUNNVQEGIin6LLXaETLVM7mWASxbBvtwXiKvuEkteGAE9swvvas\nRTC2a2j4JIXAPbTebRHc9cQ+/uN3u0s6R0d/nGginVMTMFaw2C0U5u9sulA2IdBat2utN9mPB4Cd\nQGveYdcD39EWzwK1SqmWcq1JEGYaZlbyW9aOPtLSYCyCkYRAKZXjGsqxCGz3kzsTpsnlkvJ7FSZx\nKez3Um+Lk9n484vG3GsYGE5SaQvBUDxFbzRJTcRPTdg6xxFbCJptISjFIjCuoXELgSv+4H48GE+W\nnONvRMg9A2Js11B2neVutT1eJiVGoJRaCKwB/pD3VitwyPX8MIVigVLqZqXUBqXUhs7Ozvy3BeG0\nZfmcKvZ/7rUsm1PaLOSsRZB1yxgX0fI5VtO4XCEorFw2cwluvXYFn33TuTnnMaISDlgWgZv8TdTd\n6qK4ReB30l7zhaCUxnPZYHFx6+Hdd693MpncJEawCIbiafqHkyVlLBkhcLfOGDNY7IoRDE/COM7x\nUGqLiZNGKVUJ/AT4oNa6P//tIh8p+NvUWt8J3Amwbt266RVuF4RphPHT5/cleuDvLnP88Tm+fndq\nqX2HbyyCc1prWDE3V4CCPi+xZIaQ31tQj9Bn3/Xn3+0GfR76Y0knFXbAtghqI34CPg+VQZ+TctpU\nba0xWsJGGUuNnj665VAv1aHCLc4dF3A/HrLnPgwlUlSFRi7eg+xGPjAei8CVNVTK9U0mZbUIlFJ+\nLBH47ghDbA4DC1zP5wOlRWsEQSgg7DeuobxpZnOqnM3N3WLC7RpqynMN5Z8DsiJi6gjcZDS8bFF9\nQbvpuXa3U5MJ1NEfZ9g16rIm7GdvpxVgXtRgZRGVcsc8lmtoOFl8QH0inXHcXO73B+z15Vs2P9xw\niBvufCb33InRLYLdxwZ41Rceo9eVHeTOGiolPXYyKWfWkAK+CezUWhcbZATwC+CddvbQRUCf1rq9\nXGsShNMdYwm4/fP5uO/kc6uOAyiVdQ0VO4dJRS3mGjLvP/ihV/LLv3kFJnY6x77LN4FvU0Bm3EJ1\nFX7HbbKw0Wp1PZ46gmJCoLVmOJnOueM3JFIZgj4vAXsMqGFoBCG49cdbefal7pyNPjpGjGD70T72\ndAw62VBgBYsDtggPT7PhNOV0DV0C3AhsU0pttl+7DWgD0Fp/HbgfuA54EYhiTUITBOEkiRSJEeQT\ncLWt8Lj8+D6vh7pIwLEI8mckW5+xzh/2e6kO+fF6VM4GGPJ5CQe8nDO/xrnbbrRdTubuef8JSwhM\nfUJtOGsZNFSYGMGpCUE8lUHr4m6jeCpNwOfBo7I1BZmMdr5zpIDxb58/xsYDPdz+ulXEEkYIUs4s\n57Qra8hMW3PXQ6TSGcIBL+m4nnauobIJgdb6SYrHANzHaOCWcq1BEGYbRgBGm11gLAK3W8jQWBnI\n1gEUOYfbNeTxKOoigZxUSHfxmsmbNy4g43oxvXYcAbAFYW51yAlej3XHfKR32BGW/IKyLz+0h8VN\nlotpZIvAg0cp57NDru/Lz34K+CzL4VtP7WPjgR5uvXZFTupqXcRvCUHGHQy2zueOlyTt+cgRv3fa\nuYbKHiwWBGHyMHfx4dFcQ14jBIWe4ZaacFYIirmG8qqVGyryhMAVfDZCYGo
F8ofTOK4h+/05NSGU\nUoT93lHvmKOJFFd//jHnbj/fIvjGY3u5YFE9kFszYIinMrZFoJ2MI/eGnd8ptTrkp2swzo6jVq6L\nVUOQPSbs99JDMtd1ZCyCRK5F4PN48ATUtMsakhYTgnAaYQSgmFvHMJpFYFo+QG5GkSHo8+a0r8gv\ndHNbBGaDNkHqQddMgkjAy5Jm67uMZTDXboUdCYx+x9w5YAWbzcbr3uzTGc1QIu1U7saLnMdYBEFf\nNkbgduHku4ZM5pFZU8dALKfHUMj+O8+1CIxryFV5nNb4vIpwYHShmwpECAThNKIm7KexMsgSe9B8\nMRwhKLLRn2HPJQZy4geGoN9D2O91ahPyM4dyLQJrYzSFZG6L4BVLGx03lokVzK2xso3CAa+zkZ4Y\njLPwY7/miT3Z+qETecNwEqkM2vbPmztwE5gumjVkWwQB27cPMOjasPOFoCqvxffxPIsgUkQIirXU\ntlxD1t/fdHMNiRAIwmlEwOdhwz9dzWvPHblA33ENFXH9LGyoKHjNTdDnyWlLsaA+kpM95M5CMmJk\ngsWDLiF445ps3aiJIcy1s4sirjvmjQd6ALj7qf3O8cX69Fz1+ce45+n9znd0DRmLYATXkDfXInCv\nLV8Ignn1Eh0DsZyN3PRUcgeLoyMEi30eZVk8YhEIgjCVeD1WDn3IV/jP320RFCPs9+YEkW+5Yik/\nu+USp/WE2930xbedx/+992XMr7POmcpozm6t5t73XMi1Z891jqtzLIKg8x0mPdNsuO6YR/dQYZ+e\nl7qG+JdfbHd8/WaDL9aZ1EkfdQvBKK6h/BjEsb5c11B4FNeQ29JIpjU+r8d2DY0/ffS992zghxsO\njX3gSSDBYkGYZSilCHg9RWMEZtMeib+8fAkdrnGWlUGf3aHUQyyZybl7rgr5efnSRvYcH8g5/tJl\nTTnnXLewnhsuWMAFC60AbzjgddIzzZ21O3Cd7xpykx+QThRpP3F8IMaqlmoG4ymnc6q5c68J+/nF\nlqNktOa/3mE1Q84XE3dtAGQb9+W6hopYBBmrkC3s9zkT2kolndE8uPM4Z82rHtfnSkUsAkGYhQR8\nnqJZQ/ltI/I5c241ly1vKnjdiEoxcXHHDdxVzYaasJ/PvflcJ6gcCfiIJnM3aHddxGgtnPvy5gHn\nb+LdQwkOnIhy3oJaKoM+5/zGp28skF9tzda1usWkNuLnQHeuEBSLEZj00dysIe1yDY3PIjCtLEab\nW30qiBAIwixkJIsA4M1r5/MXL184rvM5aaVFhMSdSVTs/XzcWTUmp999ju5RLILnj+S2M0uks4Fk\ngM2HrJjD6jwhMJbE4sbCGIk74Hzu/FoO2AVx7vUCxdNH8ywCn8eTEwMpFeOuEiEQBGHCWD6niuUj\ndDT9/FvP4+OvP2tc5ys2vCb7Xva1YhZBPhF/NpjabffqcfvpR+vlv+lgT85zrXO7fm4+2ItHwbnz\na6gI+pwg8VA8hc+juOfdF3LVmc05a44nM4T8Hl65vImFDZGCTdy4htzzCEZLHw35xx8sFiEQBGHC\n+d7NF3HLFUsn7HzBUWoTxnIN5RN21RH02FXIbhdP91DCCU7ns+lAT8Fr7jv65w71smJuNZGAj6qQ\nj8FECq01Q3GrTfac6hBntdYQT2WcdtSJdIa3rlvAPe++kJaacMH5jWsoVSR9NL+y2Of1jFknUQxH\nCCIiBIIgTFOMABSzCNxxh7FiEJDrGuqx3UAx18Z5YjBBW33xoHZ/rND37i4qO9wzzBK7/URF0IfW\nlhtnMJ526h1M7MSIT8JONwWKfm/Yjl8Ucw1F8yqL/XaMIJXR4xqoIxaBIAjTnmzriUKLwOtR+Ozi\ntNJcQz4Sdu+e7rx6AK2t15Y1F7q1zAafj9si6BqMOwN4ciemJaiyK4hNoV128E3GiVEUS68N2+8Z\nC0JrXTR9NFtZbFcqj8M9ZISgeow5CSe
LCIEgCKeM2SiLZSLB6MHkfIyrJZpI0WPHCKKJFP/66x3s\nbB8gkc4U3fRHinkYEYkl0wzEUk6BmxGC/liKTQd7OGteDZAN/sZSVhuLdEYT8FqvtbmEwLinTEZT\nyuVKMo+H4ik6B+Lc/vPn6Y8lrToCWyzHcg8d7ok6VoNYBIIgTHtCTrC4eCaSsRT83lEbElvnsjfi\n4USaHjvffl/XEP/zxD6+98eDACxtruTqlXP4p9eudD63YAR3kbEITP1BvkWwfn83PdEklyxtsL7f\nFrNYMuNsxMal5b4jN5ty0G+1tDbBYnOnXx3yMZxM847/eZbvPHOA9r6Y4xqC0cdxDifSvOoLj/MD\nu4CsbzhpZ3qVZ8sWIRAE4ZQp1SIoJUYQsUWjczCe3cTt2oHddnFaY1WQu961LqemocUexZmPsQhM\n/6EGWwjMDOXfbT8GwCVLG61rcLmGTHfSYpZMnd0aI+D14PUoxwow8YEme/7yno5B5zOmsth9XDG6\nBq3GeoftmoX+4STVYb/T42miESEQBOGUCY5lEdgbaUkxAnujPNIz7LxmZhmYKuVGe4BNwHW+ebWF\nGT0AibS14Zq0U+MaMjGBR3d3sripwpmkFnK5bvItAjfGIgjYsw0yjhBYazVC4MbvzVoEsVFcQ8Yl\nZmIkfcNJasLlawQhQiAIwikzVgzACESpWUOQbeVQ4WovYdJJ6+3N3H2+/FnJhhu/+Uc+8P3n6BrM\ndQ0ZiwBgcWO2W6sRAssiyBRc16XLLMvBdE0N+rz4XJPazJ2+Eb0F9dl1+TzZGMFoFoERACMIlhCU\nJz4AIgSCIEwAo6WPQtZ1FCjJIrA26Je6LJfKwiLVvqbjqVsI8i0Cs5ZoIs3PNx9l436rxqAhL1hs\nfTbrVnLSR5MZRwjc33Pnjet48EOXOTMfAvbIz3zX0OUrmqkK+vj8n652Puv1KKrtDb3DDiJ3DhQW\nyBkhyLUIRAgEQZjGRAJWN0/fCBv9uGIEtgVgJqXlC0FFwOsIj/t8da6Rl5C9Yzf8YMMhIgGvIzRu\nIXAXirktgkQRiyAc8LK0uSrnmnweRUZbGUabD/UCsLatlm2feDVr22qdz/q9imbbZXT30/v4zjMH\n+MIDuwv+DrIWgWUBlVsIpPuoIAinzI0Xn+F0Dy2GcQ2VEiNYYHdA3Xq4l4qAl6bKXF97g+u5e4NW\nSvGLv7mElpowF37mQWrDAY73595tu2cpuAPbuRZBNkaQDRYXqZj2Z91hJlj8u+3H+NxvdgFZ15PP\nnn0QT2XweT3UhP0EfB52tQ+MeO6CGEFULAJBEKY5LTVhrjizecT3zZ17Ka6hmog1ZS2Z1sytCeU0\nnIPcqWjmfGe3Wu2Zz51fS1NVkIjfm7NxmiZ67hbW7gwct1sp7FgEhemjbtxxDxMs3tdlNaT70KuW\ns9Q1Jc5YOX6PQilFU2XQcSUV2+BNm+q+4STxVJqBeIqaSKDguIlChEAQhLLjZA2V4BoCWGrPM55X\nGy4YqemeiKaU4me3XMJ333NRzjHzasM5xV9Xr5wDwAUL64p+nzv1NFtHkHbSV4sLQVbcTLD4WF+M\nmrCf91+1LGfUpxEX4zprrs5aNcYCSbo6pfa4BGt/VxSts9lO5UCEQBCEshMch0UA2TGXc6tDBY3s\nGvI2xNULaguasf3oLy/m1levcJ7XVwTY9vFruPc9Lyv6fSZ1FLIb84mhuJO5NFoPpaAdLE5nNMf6\nY06MotixPrugrtmVWhpLpukbTrL6E7/n0Res2cym6yrAng7LhdRQUZiOOlFIjEAQhLKTdaOUVhBl\nhKClJlSwCdeXsCHWRgI5efoNlQFn8E0x3LEL831ffWSv89poFkHQ58XrUaS15nh/jDlFCtucGInH\ntgiqssfEUmmO9g4zlEiz+9gAV6xopnsoQXXIR38sxR47aC4WgSAIM5ps+mjxgrN8ljbbFkFN2LlD\nNxt
hqRui2/rIzyAy1BV5XSlVID7FLJmzWms4u7WaiqDXCRa398VoGa9FkEgXBId7hhIssf8OXrQr\nkxsqxSIQBGEGk60sLs0iWNNWy+UrmrhkaQMb7RkDi5sq+dhrFnDlKEFpN24f/UgVz09+9Ep0kdfD\nAW/ODIRiXVWvWNHMFSustXiVIpHK0DUYL2oRZIWgMEYQS2botdNETwwmyGQ0PdEEV5zZzHMHex3X\nUDktAhECQRDKzngqi8EafH/3TRcC2fGTlUEfbzl//oSuy11d7MYKUGcHzI8V2/B6FMf6YmhN0RiB\nI4QeYxFkjxlOui2COAOxFBkNi+0Oq/u6hvB5VNlaUIO4hgRBmATG02soH5PFM9KmXQ7yBWssAfN6\nFEd7rd5Ic2sKXTj5FsHCxgqnjfVwMu1YBN1DCWe6WUNFgLb6CMm0pqEykGPhTDQiBIIglB0nRlCi\nReDGxAgqJ1EI8mcFjDVHwetRTo3CnFEsAjOgZ1FjBY9/5ApWL6gllkzTa1sEJ4YSzsD7SMDHyhZr\nxkI5M4ZAhEAQhEnAuFZKTR91YyyCymBpgWY3Z7dW866Lzxj35wZiyZznpbiGDMU27WIxigX1EcJ+\nL7Fk2mklYVkElghVBL2sbLEK5fJTZicaiREIglB2aiMBlDo5947ZRE/ms7/620vH/RmwArhuxnLL\neF1VysVmMhhLKH9OcTjgpXMg5VgE0UTamZtQEfA5QtBYxowhECEQBGESuO6cFtoaIkV79I9F1iKY\nvtuV2yLIL4CDrGsons4VgpDfYweLsxbIIXsYTUXQ5zTDK2fGEIhrSBCESSDg87C2rXh7h7Ew2TLu\nHkOTxatWzSnpOG9OqurIFkE8L/YQ8nutkZzRhPO5Q/ZAnoqgjwX1Yf7k3BZeuby0lNmTZfpKrCAI\nAtBcHeL7N1/EGlc758nizhvPL+k4IwRBn6foOEnj3krkWQRhv5d4Ks1wMs3ipkp2tvdnLYKAF6UU\n//WOtadyCSUhFoEgCNOeixY3jFgUVg7uuGE1f/ayNpRSJc0JNkJQzC0E8OcXtbG4qYI3rcmtgwj5\nvUQTVtbQErtu4KAtBJFJdIWV7ZuUUt8C/gTo0FqfXeT9OuBbwBIgBrxba/18udYjCIJQKtevbuX6\n1a0lH2+CxcUCxQDz6yI8/OHLC14P20IA2f5Kh23XUGQEUSkH5bQI7gauHeX924DNWutzgXcCd5Rx\nLYIgCGVjLItgJNzC0VYfwe9VDMZTRALeshaQ5VM2IdBaPw50j3LIKuAh+9hdwEKlVGmRGUEQhGmE\nIwTjdF+5haOlNkSdPXxmMquoYWpjBFuANwEopS4EzgCKNhJRSt2slNqglNrQ2dk5iUsUBEEYG49n\ndNfQSLiFoLU27GRGVQQmzy0EUysEnwPqlFKbgb8FngNSxQ7UWt+ptV6ntV7X1NQ0iUsUBEEYG9M6\noliX0tFwz1CeWxNyKogjgcm1CKYsfVRr3Q/cBKCssPw++48gCMKMIhssPnnXUNDndYbuTHbx3JRZ\nBEqpWqWUqRB5L/C4LQ6CIAgzimyMYHxbajiQe7yZxxw5ib5Kp0I500e/B1wONCqlDgP/AvgBtNZf\nB1YC31FKpYEdwHvKtRZBEIRyYoQgPE7fvgkum0lpToxgki2Csn2b1vrtY7z/DLCsXN8vCIIwWXhO\nMmvIxBTqbAGYjcFiQRCE0wLfSWYNJe2WE/V22qjjGprkYLEIgSAIwiniOclgsRlHefNli4GsRTDZ\nwWJpOicIgnCKGCEYb/poc1WI/Z97rfPcSR+d5GCxWASCIAinSDpjuXjG6xrKZ051iEjAS2tteCKW\nVTJiEQiCIJwiyYwGxh8szqcq5Oepj15JTdg/EcsqGRECQRCEUySZMhbBqbt06qZgAI+4hgRBEE6R\nlG0RFJtONhOYmasWBEGYRpjJYwERAkEQhNmJcQ35vZM3Q2AiESEQB
EE4RYxryO+dmVvqzFy1IAjC\nNMJUCPtECARBEGYnRgjENSQIgjBLSaYt11BALAJBEITZibEEZmrWkBSUCYIgnCJffNtqvveHg5zT\nWjPVSzkpRAgEQRBOkZaaMB+6ZsVUL+OkmZl2jCAIgjBhiBAIgiDMckQIBEEQZjkiBIIgCLMcEQJB\nEIRZjgiBIAjCLEeEQBAEYZYjQiAIgjDLUVrrqV7DuFBKdQIHTvLjjUDXBC5nKpFrmZ7ItUxP5Frg\nDK11U7E3ZpwQnApKqQ1a63VTvY6JQK5leiLXMj2RaxkdcQ0JgiDMckQIBEEQZjmzTQjunOoFTCBy\nLdMTuZbpiVzLKMyqGIEgCIJQyGyzCARBEIQ8RAgEQRBmObNGCJRS1yqldiulXlRKfWyq1zNelFL7\nlVLblFKblVIb7NfqlVIPKKX22P+tm+p1FkMp9S2lVIdS6nnXayOuXSn1D/bvtFsp9eqpWXVxRriW\njyuljti/zWal1HWu96bltSilFiilHlFK7VRKbVdKfcB+fcb9LqNcy0z8XUJKqT8qpbbY1/IJ+/Xy\n/i5a69P+D+AF9gKLgQCwBVg11esa5zXsBxrzXvt34GP2448B/zbV6xxh7ZcBa4Hnx1o7sMr+fYLA\nIvt38071NYxxLR8H/r7IsdP2WoAWYK39uAp4wV7vjPtdRrmWmfi7KKDSfuwH/gBcVO7fZbZYBBcC\nL2qtX9JaJ4DvA9dP8ZomguuBe+zH9wBvmLqljIzW+nGgO+/lkdZ+PfB9rXVca70PeBHr95sWjHAt\nIzFtr0Vr3a613mQ/HgB2Aq3MwN9llGsZiel8LVprPWg/9dt/NGX+XWaLELQCh1zPDzP6/yjTEQ38\nXim1USl1s/3aHK11O1j/GIDmKVvd+Blp7TP1t/obpdRW23VkzPYZcS1KqYXAGqy7zxn9u+RdC8zA\n30Up5VVKbQY6gAe01mX/XWaLEKgir820vNlLtNZrgdcAtyilLpvqBZWJmfhbfQ1YAqwG2oHP269P\n+2tRSlUCPwE+qLXuH+3QIq9N92uZkb+L1jqttV4NzAcuVEqdPcrhE3Its0UIDgMLXM/nA0enaC0n\nhdb6qP3fDuCnWObfcaVUC4D9346pW+G4GWntM+630loft//xZoD/IWuaT+trUUr5sTbO72qt77Nf\nnpG/S7Frmam/i0Fr3Qs8ClxLmX+X2SIE64FlSqlFSqkAcAPwiyleU8kopSqUUlXmMXAN8DzWNbzL\nPuxdwM+nZoUnxUhr/wVwg1IqqJRaBCwD/jgF6ysZ8w/U5o1Yvw1M42tRSingm8BOrfUXXG/NuN9l\npGuZob9Lk1Kq1n4cBq4GdlHu32Wqo+STGI2/DiubYC/wj1O9nnGufTFWZsAWYLtZP9AAPATssf9b\nP9VrHWH938MyzZNYdzDvGW3twD/av9Nu4DVTvf4SruVeYBuw1f6H2TLdrwV4BZYLYSuw2f5z3Uz8\nXUa5lpn4u5wLPGev+Xngdvv1sv4u0mJCEARhljNbXEOCIAjCCIgQCIIgzHJECARBEGY5IgSCIAiz\nHBECQRCEWY4IgTCrUUqtdnelHMfnPqmUutp+/EGlVGQC1/QGpdSqYt8lCOVA0keFWY1S6i+AdVrr\nvzmFc+y3z9E1js94tdbpEd67G/iV1vrHJ7smQRgPYhEIMx6l1EKl1C6l1F1KqeeVUt9VSl2tlHrK\n7t9+oV2d/S2l1Hql1HNKqevtKvNPAm+z+9W/zT72afuYp5VSK0b4zruVUm9RSr0fmAc8opR6xH7v\nGqXUM0qpTUqpH9k9cMxMiduVUk8Cf6qU+n/2erYopX6ilIoopV4OvB74D3tNS8x32ee4yl7bNvt6\ngq5zf8L+zm1KqTPL/hcvnDaIEAinC0uBO7AqM88E3oFVcfr3wG1Y1ZcPa60vAK4A/gOrxe/twA+0\n1qu11j/AKue/TGu9xn7vM6N9q
db6y1i9Xa7QWl+hlGoE/gm4WltNAjcAH3J9JKa1foXW+vvAfVrr\nC7TW52G1Tn6P1vpprCrYj9hr2ms+qJQKAXcDb9NanwP4gL9ynbvL/s6v2dctCCXhm+oFCMIEsU9r\nvQ1AKbUdeEhrrZVS24CFWM24Xq+UMhtkCGgrcp4a4B6l1DKstgX+ca7jIqxhIU9ZLXAIAM+43v+B\n6/HZSqlPA7VAJfC7Mc69Aus6X7Cf3wPcAnzJfm4ax20E3jTOdQuzGBEC4XQh7nqccT3PYP1/ngbe\nrLXe7f6QUupleef5FPCI1vqNdm/7R+3jvo3V5/6o1nq04LLC6iH/9hHeH3I9vht4g9Z6ix2ruHyU\n85pzj4a55jTyb1sYB+IaEmYLvwP+1u5UiVJqjf36ANZ4Q0MNcMR+/BfmRa31TbarppgIuM/xLHCJ\nUmqp/T0RpdTyEdZUBbTbLZT/bITzudkFLDTnBm4EHhvh3IJQMiIEwmzhU1hunq3KGjz/Kfv1R4BV\nJliMNRv2s0qpp7BmXZfCncBvlFKPaK07sQTke0qprVjCMFLg9p+xJmk9gLXJG74PfMQOCi8xL2qt\nY8BNwI9sl1cG+HqJaxSEEZH0UUEQhFmOWASCIAizHBECQRCEWY4IgSAIwixHhEAQBGGWI0IgCIIw\nyxEhEARBmOWIEAiCIMxy/j8eNSgqSE6cEgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(meta_losses)\n", "plt.xlabel(\"meta-iteration\")\n", "plt.ylabel(\"meta-loss\")" ] }, { "cell_type": "markdown", "metadata": { "id": "D5hgIM4STRQ3" }, "source": [ "Our meta-loss is decreasing which means our learned optimizer is learning to perform well on the meta-loss which means it is able to optimize our inner problem. Let's see what it learned to do by applying it to some target problem." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "height": 296 }, "executionInfo": { "elapsed": 436, "status": "ok", "timestamp": 1647716639849, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "hpkbGP89JH45", "outputId": "17dfb41c-c701-4e7d-ae2b-0e05841df9c1" }, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'loss')" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90\nbGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsT\nAAALEwEAmpwYAABlvElEQVR4nO3ddXgVx/rA8e8cjbsT4kJwCQ7BilOBuht1vfX23tv2Vm7td+tK\nqXuLtBQoVtzdJSQhhBD35OScHJvfH0mBlCQECQEyn+fJQ9idnZ3ZlvOend15R0gpURRFUdouTWs3\nQFEURWldKhAoiqK0cSoQKIqitHEqECiKorRxKhAoiqK0cbrWbsDJCggIkFFRUSd9XH7hIcxGHVIA\nCGQTP0f308j+06dBopUOtDhq/5QONNKBVjrROB1onU600glSINAghECgRaPRIIQWjVaHRqNFo9Ui\nNBqE0CA0mtr9f/2IM9NWRVHOf5s2bSqSUgY2tO+8CwRRUVFs3LjxpI+zOqxUWMqx20xYrSZMlioq\nTaVUWSqptlRitlRhsZqwW0wIkwWN2YrO6kBnc6BzgN4JeinQSYFGaHFq9KDV4dTqcWj1OLVa7Fod\nNi1YtRKbENRowKaRdX8KrAKsAkxaLSVaV4p1HpRqvSjVeFMmfHCI2v8czrofu3TiTTm+FONLCX6U\n4HvMjx8leDtKMTpqkDaJ0w5Om8RhkzitDqRNIB16kAaENCJwQaN1QatxQ6/3wM0jktDIwYTGJuLh\n568Ch6JcwIQQBxvbd94FglNl0BoIcA8EGgyIp8zplFTbHFRZ7FSZbZiqrJgrarBUWbGabNhMNpxm\nO06zHVljR1vjRFvjwNVcg5fVhicaXIUBKcyUGQQFRsh3dVLg6iDXaCfPaKfA6Emh0Z9UoytVepf6\nDdCBXmPFW1uOj7McH1mKjyytCxTF+Isi/DUF+GsPoxe2eofagfSCL9m8wRNzXiDeHr0IielAcEwc\nwTHxePj6ndFrp
SjKuUmcbxPKkpOT5ancEZyrHE5JcWUNRYVVlBWaqSwxYymrwVZhQVtmwrWyGs8a\nKx5OiZvGAHo3il31FBo1FLgICo2CQhcNeUZJvsFJoYuGEhc9Nu3xj39crRbcrdV42KrxtJkIoIAo\nYzqx7rsI02ajszspKwqkJCeQihxvdNKdgKBgQiMiiIhPpH1iEm7ePmf/IimKctqEEJuklMkN7lOB\n4PxRVWOnoNxMUbGZsqJqTEUm7LklaIvLMVZW42Gx4uFw4ip01Li4U+LpSZG7C4VGDYUuggKjOBJA\nCoxQZtTWq9/HXkaYyCZMk02IzMXLZMJYrMGR74nV4gkOOzrpxM1oxNvHh4DgYMIiowgICsbb2xsv\nLy90ujZzk6ko55WmAoH6V3se8TDq8AjyJCbIs8lyNoeToqoatJU1kF+Be2Yu4YcKEAfLMFRU42a2\n4mF34uIRRF5IBFluGtL1FezyrGK3nyf7XFNwaN3AE/AEbaSNQGcR/uZyPCtrcC214mWuwivtAG67\n99Z7fO5iNOLj64OPj++R4ODt7X3kd09PTzQa9bKaopxLVCC4AOm1GkK9XQn1doVwH+gV0WC5jIJK\nCn5eRsC2gwzQeuPuEwmA3ZRNptjOwrBclkbrMbpHoDEEU+YeQKp7DLZQw5E6DHY7AdWVeFeW4Fle\ngk91JV6Fxfjl5aFD4PzbOYUQeHl51QsQcXFxREdHt9DVUBTlRNTQkILTKdmQWcLyBdsI2pxGD40b\nvl7hCI0WZ3URzuqtFBlWsjE6jcIwFzQe7TC6+VGoDSbXGc5hWzx5hFJidEMe8+aRh8WET2khPpVl\n+FSVE4yTUJ2WAJ1ACA0Wq42KykocDgcxMTGMGDGCdu3ateKVUJQLl3pGoDSbxeZgwe58Fq1KI2p7\nOoOFgRDPMIRGh9NcCqY9GLUbMHuvZG+MljJ/V9y9dLjpbViljoPlyWSV9SG7JpF8D08qfLUUu+io\n0B69+RROJ15VZfiVFxNir6GjVuKZexiLxUJSUhLDhw8nMPDMvt2lKG2dCgTKKSmosDBrWw5/bDhE\nx4zDjJKCSLdgNFo9zppKnOV7Mbqk4+G1DodfBunBLhT7u6J1qz3eZPKl+nAPqrN7UVIRQ16wlcoI\nLbZAN6qMBg7bIFujo0arx7+0kIElObTLPwx2G926dWPIkCH4+vq27kVQlAuECgTKadudU8HMLdnM\n2ZxDz5IKLnPYiTf6odUakTYzjqJd6Fzz8QjOQue7hzLXHAr8DZT6GEADNruR0oJELJl9Med1xWE3\nUOaZi71dNdVJ7VmBD5kGN9yqq+iVm0F8QTYudjvJycmkpKTg4eHR2pdAUc5rKhAoZ4zd4WRFWhEz\nNh9myc48elgdXOe0kaT1RKcxIu1W7AU70egLcY8y4xpWRJV2M0XGEor8DNj1GnBqqK6MoDi3G+b9\n/XGa/SnxzUGb0o7VBlc2alzQ2ax0PpxOp/xD+Dls9OvXjwEDBuDq6tral0BRzksqECgtosJi448d\nuUzffJhNB0rogYZbdJLOdgM6jEiHDUfhXqTlAC4J7ngkGLB57aDIspVCdzMW19p5DNbKEA7tHYUt\ncwAW12rcUgLZEODC/JraVBuxeQfpkpdFRI2JwSkp9OnTB4PB0HTjFEWpRwUCpcUdKqlm5pbDzNic\nzcHianpq9dzmYaBTlQOdw4B0OnAUp+Io2oUxyhX3folo2hVSUvUnOSIVi1GDw+pCVuYAbLsvxuY0\n4tXPmz3x3vxSZaJaoyO0OI9uuZkkVpYwfMRwkpN7o9VqT9w4RVFaJxAIIdoDXwMh1H6xmyKlfOdv\nZa4Hnqz7axVwj5RyW1P1qkBwbpNSsjmrjBmbs/l9Ww4VFjv93F24xdudjiUWNBYtUjpxlmRgy92C\nzqcGtx7RWAI2Uei1iTJvgdMpOJzXgZqdl2Mri8CnizuHewfwXVUZBUKHT1U5XXM
y6Faax6ihw+jV\np4+apKYoJ9BagSAUCJVSbhZCeAKbgMuklLuPKTMA2COlLBVCjAWel1L2bapeFQjOHxabgyV7C5i+\nOZul+wqxOyUjg7y4zseT+HwTlNdON7Pnbcey/QdkdTF086R6QBnlSXakTlBYHkjVnvHUHOqHZzt3\nqoaG8JOjjH1ODa41FjrlHiA57yBjBvSjb8oQlUFVURpxTgwNCSF+A96XUi5sZL8vsFNK2eSMIhUI\nzk9FVTX8vi2HGZsPs+NwOVqNYGKUP1e7exC6twycTnReedgyl2LeugVbTQnVA5yYhjhw+IPVZKTo\nwGBM6ePQG33RDA9jrreZFRYHOoedhLws+mTvZ1yXJAaNGY9GDRkpSj2tHgiEEFHAcqCzlLKikTKP\nAR2klJMb2HcncCdAREREr4MHG02rrZwHUvMrmbH5ML9uOUxehYUkdyMvuXkRXFiDPswdn4lxCFmG\neds2qlf8QXH5Ckq7mrEmSbAJTAc6UpA5EVt5e1y7e7Cmmzdzqs3YEUQX5ZKcsYuLI4IZfPFluHp6\ntXZ3FeWc0KqBQAjhASwDXpZSzmikzDDgQ2CQlLK4qfrUHcGFw+GUrEwrYuqKDFbsL2KM3sgjwhU3\nmxOPAWF4jYpEY6xbrCd7JyWznyPNsZmqWA3CAPacQAoOXExFTm+M2iJ2JHvxe7tgqnV6gsuL6bV/\nGxPcBIMvnkhgRFTrdlZRWlmrBQIhhB6YDcyXUr7ZSJmuwExgrJQy9UR1qkBwYdqRXc7Hy9NZtj2X\nu4ULl0oDeOoJmBiPa0f/owWrCrCtf4/Nh74lP1CDuwvYql2oTBtK8YGRiIpq9gcXMXdAN0o8vfGq\nrqLH/m2Mqchh0NgJxPbqg0ajho2Utqe1HhYL4CugREr5cCNlIoDFwE1SytXNqVcFggvbgSITU5Zn\nsGdjDo84jcSgxRbrRcRVHdB6G48WtJqQm79h9d53SfOvIdwdHE6BOTeZot0jsZe147BPFn90j+BQ\nUAgu1hq6pO9k8PYVDIhNoPukq3CPiVUPl5U2o7UCwSBgBbADjmQjfgaIAJBSfiyEmApcDvw16G9v\nrKF/UYGgbSiosPD5igysq3O53qFHagTmfiF0Gh+L5tjV1xx25O5fWbzpNbZ5lZLoJXHRgNUUTdGO\nEVTmJGOKhz/CnOz2D0LrdJJ0cC/91i2iR85hktrHEJzcG9fu3XHt0QOhXkNVLlCt/rD4TFKBoG0p\nN9uYuSSdgFV5dHdoOWAAx6gIUgZEoNEc821eShwHljF71YusNR6iq7eTIL3EafeiLG0oxamDqYkJ\n5M/2VjZ5+eDQaIk5nEGfjYvpnrqb6MJyonv1IfzN/6Fxc2u9DitKC1GBQDnvma12VvyWSrvNRbhL\nmO8q8R8dxSXJ7THo6n+Lt+Zu45fl/+ZPZyrdvRx0cnUipRZzfh8Kdg6h0hjB6kQH6/38seiNhJbk\nk7xxCf23b6Cvhx+JH32Czs+vlXqqKC1DBQLlgmGtrGH3D3sIyKgkByefuzlIHhbFtX0icDfWX3Cv\nuiSdb5Y+zSzTTnp52hno4USnkdhNMRTsHEJRWUe2dJCsDQqmws2D8IJsLvnjO1KKShnw4RQMEQ2v\n7KYo5yMVCJQLjiW9jJyf92Eot7IQG1+62Ll0YBS3DIjCz71+Qrqy8kN8vuxpppVuoYe7gzEeDtwM\nTqTdm+L9g8jL6samWDeWRMehkZKLVs3hok1rGPPsS/j26dNKPVSUM0sFAuWCJO1OKpYcomLJISxI\n3nGa+VPn4Oo+EUweHE24b/2x/vzKw3y87GlmFm0myWjnKjcHXh5OQEtVbk+2ZQ1nRqd48rx86ZCV\nyoQFPzNh7MV0uOHm1umgopxBKhAoFzRbQTWlM/djPVDBYQ8tz1RXkImTS7uFcdeQWBJDPOuVP1ie\nyQcrn+WPoi1ECju3uVjx9dfjkBoObbiOOaH
JrIoIxtNiYszqPxiDlZHPvoTeYGykBYpy7lOBQLng\nSaekelM+ZXMPIGscbG3nwj/ziii3ObgoKYi7h8SSHFX/AfCe4j28u/ZlVhZtI0bauc/PgtZDQ1X2\nQFZnXsOvffypMOjpnb6L4VtXcM0jjxMSG99KPVSU06MCgdJmOKqslM/OoHprIcLfhWVRrryxJ4fS\nahu9o3y5Z2gswxKD6k0k25S/iTfXvcLO0r3c6mahi58TaQ5g75oHmNUllu3BboSUFzFy03JGJsZw\n0XU3qtnJynlHBQKlzbGkllL6axqOEguGHoEsDNLx4bosDpeZ6RDiyd1DYpnQNRRd3eQ0p3QyLXUa\nb296i3aaCm73NaPXCEp3XsEiywT+SHZFSicp+7fSKz+Tq26dTLvomFbupaI0nwoESpvktDqoXJxF\n5fLDaFy1eIyL5k9h46NlGaTmV9HOx5U7U2K4Krk9robab/hF5iLe2PAGyw7OYbK/g2gXK+S2Y932\np/ittx8H/V1IyDvIoNSt9IyK4NLrbkCv17dyTxXlxFQgUNo0W56J0hn7sWZVYozzwfvSWFYUVfLR\n0nQ2HizFz93ArQOiuHVQNB51cxFW56zm5bUvEUMGl3jb0NXoOLxmMrOCBrKiowueZhPDUzcTUVnK\nuAkX061nT5W3SDmnqUCgtHnSKTGty6V8XibSIfEa0R7PweFszC7j46Xp/Lm3gOgAd96/rgedwrwB\nqHHU8On2T1mwbwo3+pnx1Uo0qV1ZkP0gv/V1o9xdR6+MnfQ4fIBQf18mXnkVISEhrdxTRWmYCgSK\nUsdRUUPZrHTMO4vRBbvhOykeY6QX6zKKeejHrZRUW/n3+CRu6Bd55Bt+RnkGr619nnjbWnq4OdBY\n25P+50381LETO6KMtCvOY+jeTXjabXTv1o2LRo7Ew8OjlXuqKPWpQKAof2PeXUzZb+k4Kmpw7xOC\n95hoypxOHv15K0v2FTKuSwivTOqKt2vt+L+Ukt/TZ7F45wuM8SwD4Yp+/9VMKxrC3GR3BDZStq0i\nxlSB3mBg6NCh9O3bF51O13RDFOUsUYFAURrgrHFQsSCTqtU5aDz0+F4ah7GjP1NXZvD6vH2E+rjw\n/rU96dbe58gxZZYyPt7wLOHVcwnRSRxiFLt+v4gf+kdwKFBP5wM76J+6Ha2LO76+vowaNYoOHTqo\n5wdKq1OBQFGaYM2upHT6fmy5Jly7BuBzSSzbSkw88P0WCiotPDmmA7cPiq73Yb4xdzWrtz1IJ0Mp\nJU5fXBZdy3eh/Vje2Q2f6lIuWvkHIQY9Vq2B6OhoRo8erZ4fKK1KBQJFOQHpcFK5PJuKRVlojFp8\nLonFGu/DE9O3s2B3PhclBfHGFd3wPSahnc1hY9rmJ/Eq/w2JQLOjPxsPXcYPw6IocxcM3LyYXhm7\ncIZEYJcaevbsyfDhw3F3d2/FniptlQoEitJMtnwTJdP2YztUiUtHf3wujeHbnbn8d+5e/D0MvHdt\nj+NSVaQXrmXz9nvwExXkZ3tj+HUkn06YyLZoFyKLshm5cBrBXjpKPcPRGwwMGTKUPn36qOcHylml\nAoGinATplFStPEz5goMInQafCTFkBBu5/8ctZJeaeWRkAvcMia23QprDUcOfW+9HW76YkkpB2NQI\npvd6hOn9w0DjZOSKX+l8cAdukWHkaILx8/Vh1OgxJCYmqucHylmhAoGinAJbkZnSaalYMyswJvhi\nGBfFPxenMnt7LoPjA3jr6u4EeNTPSHoobw67dj+OdNTg8YuBA1VX88Fl4zkYqKdrVipDFv1MjE81\nRf4JFONDTFQEo8eOJzg4uJV6qbQVKhAoyimSTolpbS7l8w4AAu9xUfwubDz/+268XPW8c013BsQG\n1DvGUpPH+q13YTPthM0ajLMi+Oz655kf54FXTQ1jF3xLUtlBIsOsbDH0pAYjvbp3YdjIMer5gdJi\nVCBQlNN
kL7FQOmM/NWllGGO8KRkSyr2zd3GgyMSDI+J5YHg82mOGiqR0kJ7xLpkHP0AWg/8UPfMv\nepxvOnaj1F3DoN0b6L16Dj1jjNRoa9goumHQCgYPHEB4TCIeHh64u7vj4uKiho6UM0IFAkU5A6SU\nVG/Ip2xOBjglLiMieCW/mGlbDtMvxo93rulBsJdLvWNKS9eyfceD2CzF+EzTstLQlT8HPcGGYAPh\npaWM+eMLEjyM9E2UbMizk0ZUveO1Wi3u7u64u7sfCQ5//fn3393c3NBoNGfxiijnExUIFOUMspfV\nUDZzP5Z9pRgivVif5Mnjf6biZtDy5tXdGZIQWK+81VrC7t2PU1yyFJetgl0bDayf8B8W+cfiFJJR\n6+bRae8GBl88nsjcb6mocVI19CVMTj0mk4mqqqp6f5pMJpxO53HtEkLg5ubWaKA4NpC4ubmpt5ba\nGBUIFOUMk1JSvaWAst8zkDYHtv4h3Jeaw+78Su4ZGssjIxPQazX1ymdlfUZa2mtoS5yU/KFn6sBk\nSv0f5IC3gY5ZBxmx6BviIsLo55hP+7goNLfOBu3xKa6llJjN5iNBoaFAcew2u93eYB9cXFwavcsI\nCgqiffv2LXb9lLOvVQKBEKI98DUQAjiBKVLKd/5WRgDvAOOAauAWKeXmpupVgUA5lzgqrZT+moZl\nVzG6du5846vhg52H6RXpy3vX9iDMx7Ve+YqK7WxbNxkrxYhlRp4L1uIZ/C+2+ifgYbZy8bJfCMva\ni4euhqQO7eh48/MERESdVhtramqaDBTHbqupqTly3LXXXktiYuJpnVs5d7RWIAgFQqWUm4UQnsAm\n4DIp5e5jyowDHqA2EPQF3pFS9m2qXhUIlHONlBLzjiLKfkvHabGT09GHO/YdRugE/3dFNy7qWP/V\nULu9kh0r76LEuQ59up4pNk9SPTpSFXA/JS46JuzPZeTu7zmcV4pEEBQVS8eU4XQYmIK7j2+L9sVm\ns2Eymfjhhx+orKzknnvuwdPTs0XPqZwd58TQkBDiN+B9KeXCY7Z9AiyVUv5Q9/d9wFApZW5j9ahA\noJyrHFVWyn7PwLytEBnoyn8x80dhBbcPiubJMR0w6OoPFR3c9DYZRR8gamC/qSvv2rIweD9LjncE\nvTJqeOrQLFzL5rJX15f8Q9kIjYaobj3pmDKc2OS+6A3GJlpzegoKCpgyZQpRUVFcd9116iH0BaDV\nA4EQIgpYDnSWUlYcs3028KqUcmXd3/8EnpRSNvpJrwKBcq4z7yqm9Nc0nCYrW0Nd+MfhApLCvXn/\n2p5E+LvVK1uavpIdG+/AFmjFYErmP5YiDuvGUeozgthcK3ds3Me40OloL3+F3WvXsXvlUqqKizC4\nupHQbxCdUobTrkNHRAt8UG/YsIE5c+YwZswY+vXrd8brV86uVg0EQggPYBnwspRyxt/2zQFe+Vsg\neEJKuelv5e4E7gSIiIjodfDgwRZts6KcLme1jbI5B6jelI/F28DT5gr2CievXdGVcV1C65WtKc5l\n+4+TqEgqwMUSymcygJWW9lT53kZgpYNrlpWTErCLPo/ei1YvOLRrB3tWLCF17SpsNRa8AoPpOHgo\nSYOH4xfW7oz1QUrJDz/8QHp6Onfeeaea/Xyea7VAIITQA7OB+VLKNxvYr4aGlAuaZV8JpTPSsFfU\nsNgD/ltZwZX9IvjX+I646LVHyjktFva+fQN5XbYgNHo2WzsyxazHHPwYbjbB5UtMdLCZGTa5D+Ed\napPe2SwW9m9Yw+7lizm4YytISWh8Ih1TRpA4YDCuHqc/tm8ymfjwww9xd3fnjjvuQK8//i0m5fzQ\nWg+LBfAVUCKlfLiRMuOB+zn6sPhdKWWfpupVgUA53zgtdsr/OIBpXR4VLhqetlRiDXXng+t6EBN4\ndElL6XCQ++2bpGmmYGvnpOZQe57UGDC1+xcO6cY1a3OIOORB0sBQBkyKw
8X96IdyZUkRe1cuY/fy\nxRQdOohGqyOmZ286DhlOTI9ktLpT/wBPS0vj22+/pU+fPowbN+60roXSelorEAwCVgA7qH19FOAZ\nIAJASvlxXbB4HxhD7eujtzb1fABUIFDOX5a0Mkpn7MdRYmG2zs4UUcOzk7pwWY/6wzmWQxnsmH8L\nFTGH0WUZ+ajUl42dn8KkbcdNqduI2B6Jm4eBlGsSiO0ZVO9YKSWFBw+we/mf7Fm5jOryMlw8PEkc\nkEKnlOGExCWcUsqKefPmsXbtWq677joSEhJO6zooraPVHxafSSoQKOczp9VBxfxMqlblUKKFFxwm\nYpJD+c8lnXE1HB0qklJy4M+XyHR8haiWpK3y4n9D7qfMoxdXlmyn/74OlGZVEdM9kJRrE3D3Pv4N\nIqfDQeb2zexevoT0DWux26z4hraj4+BhJA0ehndQ88f8bTYbU6dOpaqqinvuuQcPD48TH6ScU1Qg\nUJRzTE1mOSXTUnEUWfgdK38E6Pi/G3uSEFx/XL+8YAvbN9yO1ViObp6e5zpcx97YSxjizOd+eyK7\nZmei1WkYeHkcSQNDG/22X1NtInXtKnavWEz27p0AhHfsTMfBw0noNxCj24mznv71Sml0dDTXXXed\nSoZ3nlGBQFHOQdLmoHxRFpXLsylG8pbGwvjLkriqd/3UDnZ7FbvW30+RZQXGHYIZeSP58aLbSdA7\n+SA6kdSf08nZX0a7RB+GXt8BnyC3Rs5Yq7wgnz0rlrB7xWJKc3PQ6Q3E9u5Hx5RhRHXtiUarbfTY\ndevW8ccffzB27Fj69m1y7qdyjlGBQFHOYdZDlRT+vA9ZaGYeVhgZweQR8fXKSCnJPvQ1+/e/jCh1\nkrGoGy9d9gjubu580zUew44yVk9Pw+GQ9Lk4mu4j2qPRNj23QEpJXloqu5YvZt/q5ViqKnHz9iFp\n0BA6Dx3ZYGoLKSXff/89GRkZ3HXXXQQFBR1fsXJOUoFAUc5x0u6kfHEWFYsPsR47hqsTuKTH8XMC\nKiq2s33bPdRY8ihfEMHz/Z6mwjeAD7rEMFTvyrIf9nFgWxGBEZ4Mu7EDge2b9wqp3WbjwJYN7F6+\nhIzNGwDJDa++Q2ADwaCqqoqPPvpIvVJ6nmkqEKh544pyDhA6DT6jovC6NJZ+6Cj+eR9r0oqOK+fl\n1ZW+/eYQEDQc77FZPJ/7NNEH05i84wBf5B5mzF2dGX1HZ6rKavjllY2smZmO3eo44fl1ej3xfQZw\n6WP/5I73P8Pg4sqybz6joS+KHh4eXHrppRQUFPDnn3+ekf4rrUsFAkU5h/j0D8MwNJyRUs+6L7aT\nml95XBm93oeuXT8hNvYJvJPKedD3OQbtWcNLeRU8/McyIrr6cd1zfUnsF8Lm+Qf58aX15OwvbXYb\nPPz86Xf5tRzcvoUDWxu++05ISKBPnz6sXbuW/fv3n3J/lXODCgSKco4JHB0FPQO5wqFn+kcbya+w\nHFdGCA1RkXfRs8e3+Lo7uS3hTcbsn8nPbr5c+f0sKg6kMuKmJC55qDvSKZn5vy0s/W4vNeaG1yb4\nu+6jx+EbGsayrz/D0ch6BiNHjiQoKIhff/0Vk8l0Ol1WWpkKBIpyjhFC0O6KRKxx3txo0fHxBxuo\ntNgaLOvr14/+feYQYBbcGPctE4s/YmO7CC7ZlsGW/71NWHsj1/y7L90uas/ulTn88PxaDmwrPGEb\ntDo9KdffRklONtv/nNdgGb1ez+WXX47FYuG3335rcBhJOT+oQKAo5yChEUTf0hlzqBs3l0ve+mQj\nVvvxy1MCGL3jSE7+mYhDVq7wW8Rkx8vkBwZwXWx3Zt95L+blixl4eRyXP5mMi4eeuR/tYP6nO6mu\nsDbZhtjkvrTv2IXVv3yPxVTVYJng4GBGjhxJamoq6iWO85cKBIpyjhI6DbF3daPGx8i1uTbe/mZL\no9+6RVgP4ru+Srcd5QzU7uYZzZM43
e08eMt9fP/5t2Tfcy9++kqufLo3fS+JIWNbId8/v5Y9q3Mb\nr1MIhtw0GUtVJWtn/NRoO/v27UtcXBzz58+noKDgjPRdObtUIFCUc5jGRUfcfT1wuOoYt6+KT3/d\n3XjhHtcTEHM9KRsKiNHaeUX7OH6aLP5zx8N85uZL2oSLKfvyc3qNbMc1/+qDX6g7i7/ew+/vbqWi\nyNxglcHRsXQaMoItf/xOWV7DSYGFEFx66aUYDAamT5/e6BrJyrlLBQJFOcdpPQ3E3tsdg05Lz3VF\nTFua0XjhMa/h4teVYesOEeaVzH81/6azcw0fX3wV797/ODlvvs2BSZdjzNnHxEd7knJNAnkZFfzw\nwjq2LsrC6Tz+7mDQ1Tei1elY/v0XjZ7W09OTSy+9lPz8fPVK6XlIBQJFOQ/oA91of2dXfIWGoHlZ\nLN2W00hBF7j6GzRCS78Nu4mNfpqHxNuMdU5nZkwSz7//OeU2Owevu568554lqZs71z7Xl3aJvqya\nlsb01zdRfLj+8wAPP396X3o5+9etPpKnqCGJiYn07t2bNWvWkJ6efia7r7QwFQgU5TzhGuGF/00d\niUCL+cdUtmeWNFzQJwImTYX8XSRsWUX/5N+4WMzmTvk+a4SWh557A/Pd91I2YyYZ48bjWDaPcfd0\nYeTtHakoMvPzyxtYP/tAvWcHyRMm4uEfwNJvpiKdDT+0Bhg1ahQBAQHMnDlTvVJ6HlGBQFHOIz5J\n/hgnxtJJakj9dDtZhY182MZfBEOfgu0/4p+6mgmDV9JZn8dT8gUOmcu4sVcK5T/+giEyktynn+bQ\nzbcQ6V/Ndc/3JbZXEBtmHyBt49EHv3qjC4OvvZn8jDR2r1jSaPv0ej1XXHEFZrOZWbNmqVdKzxMq\nECjKeSa0bxiOYe3p59Cy7P2NFFfVNFww5QmIGwnznsKQv5erBs8myrcj/3Y+jbDlcVWxhR3vfEjI\nf/6DZd8+Mi69jMrPPmL4tTEERniy8pf99SagJQ0cQnBMPCt//Bqb5fhJbn8JCQnhoosuYt++fWza\ntKnRcsq5QwUCRTkPRY+OpqpnAMNqNMx8dz0WWwP5hDQamDQFPELg55vAVMwlPV6hQ+z9PGx/hgjn\nPu7YfZDv+6UQM3cO3uPGUvzRx2Reeil9OlqorrSyftbRB9NCo2HozZOpKilmw+8zmmxf3759iY2N\nZd68eRQWnngCm9K6VCBQlPNU4pUdKIn1YkwFfPPBBhwNvPGDmx9c/TWYCmH67eB0MCD6Kob2+oor\nrP+jn1zJixl5PFlQRuArrxLx5RcIrRbz03eREAM7lmZTmHU031F4h04k9B3Iht+nU1lyfFK8v2g0\nGi677DL0er16pfQ8oAKBopynhBB0ua0rBcEujM6z8u2XWxsekw/rAePegIwlsPQVAOL8O3PjkHkk\n22ZxsZzO9/lVXLN5G47k3kTPnIExqQOh057H6Kph2Q/7kMcEmcHX34p0OFj14zdNtu+vV0rz8vJY\nvHjxGe27cmapQKAo5zGhFfS4vyeFXnoGp1YyY+behgv2vAm63wDL34DU+QD4u/rz+EW/E6XN43r7\nB6ytsDN67XqypYbwd9/DIGzEZ88h/0AFu1cdfV3VJziEHmMvYdeyP8nPSGuyfR06dCA5OZnVq1eT\nkdHE/AelValAoCjnOaHX0u2hZMqNGjqvL2ThkgMNFBIw/v8gpCvMuANKassYtUaeSZlCQkASV9e8\nRL7Vwaj1m9jmbqTd228RsHMu/qKINTPTMVcezU3Ub9LVuHp61b5OeoI3g0aNGoW/vz8zZ86kurr6\njPZdOTNUIFCUC4DOXU/Cgz2p0QqC52exYWsD6SD0rnDV17W//3wT2GrTSgghuKvHI4xIuJ1hlc+j\nd1Zy+dYMspJiCH78ceLWfYS12sbqmUcniRnd3Blw1Q1k795J2oY1TbbNYDBwxRVXYDKZ1Cul5ygV\nC
BTlAuHu70b4XV0xCIH8KZX9GQ1MOPOLholTIG87zH2s3q7xsRdzf/83SCr7H3pZw93bNuF5042E\nDutF+6yF7F2dS05a2ZHyXUeMxj88guXffoHD3nCa7L+EhoYyYsQI9u7dy+bNm89Ed5UzSAUCRbmA\n+EV443FDEn5SkPfZTvIamnCWOAZSHoct38Kmr+rt6hHUgxeHvkcn0+fsc4Ty2s4/CH3pRRKN6bjU\nlLL0q504HLUzizVaLUNuvJ2y/Fy2zJt9wrb179+f6Oho5s2bR1FR428cKWefCgSKcoEJ7xSI/eJo\nIhyw/b3NVDQ04Wzo0xAzDOY+Djlb6u1K9Eskyt2TDo51fFISzObyfUS99xaJh2dTWmhl27yjD32j\nu/ciqltP1k7/keqK8ibbpdFomDhxIjqdTr1Seo5psUAghPhcCFEghGgwS5UQwlsI8bsQYpsQYpcQ\n4taWaouitDUdBranaHAYHa2w8u2NWP8+4Uyjhcs/A/fA2ucF1fWHke7ocgeVeV/hioWH9mZhD/Sg\n27O341+8g/W/Z1BRfDRt9ZAbb8dqNrNm2g8nbJeXlxeXXHIJubm5LF269Ex0VTkDWvKO4EtgTBP7\n7wN2Sym7AUOB/wkhDC3YHkVpU3qNjyOziy9dq5zMf3cDzr8ni3P3r314XJELM+6EY/Z3CexCj6CO\nRFX9RLqM4IUtP+MxcCD9h/kgHU6WvLHwSNmA9pF0vWg02xbOpfjwoRO2KykpiZ49e7Jy5UoOHGjg\nDSflrGuxQCClXA40kh6xtgjgKYQQgEddWXWvqChn0MDrOpEa4UaPQhvzP9t2fIHwXjD2VUhbWDvH\n4BiTu0wmr2QZvXUZfFPdg8XpPxFxz00kGtPJLvNg74/Lj5QdcOX16I0uLP/282a1a8yYMfj7+zNj\nxgz1Suk5oDWfEbwPJAE5wA7gISllg/lthRB3CiE2CiE2qrwlitJ8QgiG3tWDvX56uqRXseSnXccX\nSr4dul5TO+t4/6Ijm/uG9KWzf2cshd/ipbHyzCEjxeXbGfzfm3G3lbB6fj7VGQcBcPP2oe/Eq8jY\nvIGD27eesF0Gg4HLL78ck8nE7Nmz1Sulraw1A8FoYCsQBnQH3hdCeDVUUEo5RUqZLKVMDgwMPHst\nVJQLgEarIeWhZFLdBDFbilm/4G8zfIWACW9BUEeYMRlKD9ZtFrV3BZX7ucGvnINE88L2PxBukiE3\nd8Zs9GfFv7/HWfeNvufYS/AKDGbpN1NxOhtIgvc3YWFhDB8+nN27d7N169Yz3W3lJLRmILgVmCFr\npQEHgA6t2B5FuWAZjDr6PNybQ3qB/+Jsdq8//LcCbnD1N+B01E02q00zPSxiGDHeMWzO/IxLfJ1M\nd1zErB3/R9TgRKKjNKS792T/M68gpURnMJBy/a0UZWWyc8nCBlpxvAEDBhAVFcXcuXMpLi4+091W\nmqk1A0EWMAJACBEMJAIqGYmitBAPLyPx93WnWAOaGekc2ve3D17/WJj4MeRuhXlPAqARGm7vcjup\npalc5luFr9bJK6U9OXDoB4bc0x+dTsOmnBCKP6tdzzih30DCEjuy6qdvsZpPPPb/1yulWq2W6dOn\n43Cc+E5COfNa8vXRH4A1QKIQIlsIcbsQ4m4hxN11RV4EBgghdgB/Ak9KKdUsE0VpQcEhnvjf2oka\noPSrXRQfrqxfoMN4GPgwbPoS9tROEhsbPZYw9zB+2PUpbybFc0hE8nr6PpyadPpe0YFSvyR2fr0Y\n0+rVtc8kbrqd6vIy1v36S7Pa5O3tzSWXXEJOTo56pbSVtORbQ9dKKUOllHopZbiU8jMp5cdSyo/r\n9udIKUdJKbtIKTtLKb9tqbYoinJUbLw/zivjMDgh86OtVJf9bbWx4f+CkC4w+x9QXYJeo+fmTjez\ntXArfo50rgpyZxaXMmP76yQN8iGgnRtpiVdx8NGnsWZnExqXSNK
goWya8ysVhQUNN+JvOnbsSI8e\nPVixYgWZmZlnvtNKk9TMYkVpg3r0CiP/onB87JLt72zCZjnmzW2tHi79EMwl8McTAEyKn4Sfix9T\nd07lxYRogvSCd2quYGfq86Rc34EarQdpISPIfuBBnGYzg669GYFgxQ9fNdKC440ZMwY/Pz9mzJiB\n2Ww+8QHKGaMCgaK0UUMuimF3cgChZgcb396I037M+Hxo19p8RDt+gT2zcdG5cGPHG1l1eBU5Fft5\nKymWHBHOhwUeOI3z6DQ4jOyQQRQdqiT32efw9A8g+eKJ7F21jJzURtZI+Buj0cjll19OVVWVeqX0\nLFOBQFHasIuv6MjaWA/al9nY8NGWeiuRMfjRekNEVydejYfeg6k7pjLM34sbQv2YKy5lduoPdBll\nx8VDT0bKw5T/PpuSr76i96VX4O7jy9KvP232h3q7du0YNmwYu3btYtu2BibAKS1CBQJFaeOuvL0H\ny4L0tDtsZvM3x6QG+9sQkafBk2s6XMPCgwvJLM/k+bh2hBn1fMy97Eh7kn4TwyiudqV0xG0UvPF/\n2LZuY+A1N5K7fx/7Vi9vvAF/M3DgQCIjI5k7dy4lJU0lJ1DOFBUIFKWN02gEk+5PZqWnIHhPGbt+\nTT26829DRDck3YBBa+CLXV/godPyVococgnmK3M/pM9HhMZ5s8fYB2I6cPgfj5AQl0RgVAzLv/8S\nm7WBLKgNtkfDpEmT0Gg0zJgx4/gcScoZpwKBoii4GnSMfqg36w0S77X5ZCw9eHTnMUNE/lIwMW4i\ns9JnkWfKY7CfJ7e0C2CemMDSvL10HrcXm8VB9tgnkDYbOQ89TMrVN1JZVMjmOb81uz3e3t6MHTuW\n7OxstZDNWaACgaIoAPh6GOl9f092aJ1o5h1kycqDtWP7fxsiuqXzLUgp+WpX7RtB/44JJcLFyFTt\no6Tnv0mXkU5St1ehe+I1LLt3Y/hpGrG9+rLu118wlZU2uz1du3YlMjKSRYsWYTI1sMCOcsaoQKAo\nyhHtgzyIvK0zNQLKZmdwxUer2XSwpN4QUbtDWxgfM57p+6dTainFXafl7aQIcp2+/CSuRxvyBp7+\nsGGfO37330/5b7Po5uGPw2Zl1c/Nny4khGD8+PFYrVYWLVp04gOUU6YCgaIo9XSM9SdkfAy90RFS\nUMPlH63h7m82kdHhriNDRLfFXY7Zbub7vd8D0N/HgzvCA5gnh7PBbCRx7CyKD5vITZyAx/DhWD74\nkE7dktm5eCGFB5u/BkFQUBD9+vVjy5YtZGVltVSX2zwVCBRFOY5X/3bogt34p4sHjw2PZ8X+Qka+\nu4YPvB5BmkuIXfUhw9sP57s932Gy1Q7bPB0TRrSrgS90T1FoXUjMwJ2sn52J1zMvYoiIIOz3+Rhc\nXVj6zWcnNUdgyJAheHl5MWfOHJWLqIWoQKAoynGEVuBzSSzOshpu0rqw9PFhXNcngrd2uvCB/TLY\n8Qs3uSVQaa3kl321OYXctBre6RBBnsOVGYZHcWk/Ba1bLmvm5hD+wfvozRYSq2xk7djKgS0bm90W\no9HImDFjyM/PZ8OGDS3U47ZNBQJFURrkEuuDa5cAKpcewtchefGyziz4Rwp74u9klzOSqHmvE+3a\nia92f43VYQWgj48Hd7UPZK6tJ7tET2JGfEbGtmzyzd6Evf4aYTv34qk3sOybz3CcxOL1SUlJxMbG\nsnjxYiorK098gHJSVCBQFKVR3uOiASifWzuuHxPowQc39kVc9iHeVDIkLZ8icyGvrvjmyHDPk9Gh\nxLsZmap9mCpNHuH9p7P8x1RcU4YSdM+9JKQepCQnm+2L/mh2O4QQjBs3DofDwYIFC858R9u4ZgUC\nIcRDQggvUeszIcRmIcSolm6coiitS+frgueQcMzbi7Cklx3Z3rHnILRDHucR62Z8rP78tP9rrvpk\nFVuySnGtGyIqsGn4zeMV3MM
W43RZzub5WQTcfx8xPXrjX2Vh9Y9fY6mqanZb/P39GTRoEDt27FCL\n3p9hzb0juE1KWQGMAgKpXV3s1RZrlaIo5wzPIeFofY2U/56OdBx9yCtSHkOEdOG56lw0hhLSqlcz\n8cPV3PfdZvxscF9EELNN4ex3u5Kwvt+wY8V6yosstHvjdbpIHZbqalZ/89lJtWXQoEH4+PgwZ84c\n7CcxtKQ0rbmBQNT9OQ74Qkq57ZhtiqJcwIRei8/4GGx51ZjW5R7dUTfRbHh5ETHChajYtTw0Io4l\n+wq46M1lmHcWE+dq5CP7dVh03oT2/ZgVP+1E4+lJ53ffp32Fma1LFlKSldnstuj1esaNG0dRURFr\n1649851to5obCDYJIRZQGwjmCyE8AZUARFHaCJdO/hjjfChfcBCHyXZ0R2hXNCmPc3vBYdLK9tMr\nKZ+ljw/l6t7t+WHdIYpW5FBgdTLT838YfbKoMUwhfXMhxrg4Uu5+AI3TycLn/3lSr5MmJCTQoUMH\nli1bRllZ2ZnvbBvU3EBwO/AU0FtKWQ3oqR0eUhSlDRBC4HNxDNJqp2JBZv2dgx9lrEc0YQ7Jp1s/\nItDDyMsTuzD/4RRSgr3RZFQwp8KF1dqH8I1fwqZl32O12Am5bCKdo+LINpWz54N3T6o9Y8aMQUrJ\n/Pnzz1wn27DmBoL+wD4pZZkQ4gbgX0B5yzVLUZRzjT7YHY/+YZjW52E9fMxDXq0e/aUfcXN5OVuL\nd7IpfxMAcUEeTLkpmV9GdMLN4uBD8wAOVHXCt+NU1s1dDUDKi6/iKrSsWjAH06bmzy3w8fFhyJAh\n7Nmzh/3795/RfrZFzQ0EHwHVQohuwBPAQeDrFmuVoijnJK+LItG46SmblV5/OCe0K5O63YWfw8HU\nta/UO2ZAtD+/DUxCY9TxtuMxHALKHf9h1YaDGFzdGHLbXVS4Glj79BPY8pu3xjFA//798ff3Z+7c\nudhsthMfoDSquYHALmv/q18KvCOlfAfwbLlmKYpyLtK46vAeE4X1YAXmbYX19rkMeZIbnW6sKk9l\nz+H6D3K7eLnxSFQIRT5erNC/hKt/JqmbX+L+7zbh3mMwQe3as8fTyMEHH8BptTarLTqdjvHjx1Na\nWsqqVavOWB/bouYGgkohxNPAjcAcIYSW2ucEiqK0MW69gtGHe1A29wDOmmNy/2j1XD3qfTycTqYu\nffK44x6MDKarhyu/GeKwiGuJiF+MMWcxI99eQWanCdTotOzKO0T+Sy83uy0xMTF07tyZFStWqNXM\nTkNzA8HVQA218wnygHbAGy3WKkVRzllCU5eHqMJK5ZL6GUE9I/pxjU8XFtqKydz8Rb19eo3gnaQI\nKuwOZgbcit0Uw+guX3NTNx1fpMEBzzjSQ/zJmzGd0p9+bnZ7Ro0ahVarZe7cuWrB+1PUrEBQ9+H/\nHeAthJgAWKSU6hmBorRRxggv3HoGUbniMLYic719N4x6BwOCL9a/DtX1v6UnebjyeHQIs4sqyY1+\nC4RkgNc7zHuoP45e47EjWBcbTc4LL1K1eUuz2uLl5cWwYcNIS0tj7969Z6yPbUlzU0xcBawHrgSu\nAtYJIa44wTGfCyEKhBA7mygzVAixVQixSwix7GQarihK6/IeE43QaSifnVFvu797MBMjRjLLRUve\nnIeOO+7e9kF093Tj9TKJxfo4GPZSnf0eH90zkoiUMZhcJAe8/Ng5+R7WrG/eB3ufPn0IDg7mjz/+\nwNrMZwzKUc0dGvontXMIbpZS3gT0Af59gmO+BMY0tlMI4QN8CFwipexEbZBRFOU8ofUy4DUiAsve\nEsx763/zv6XPo0ih4avc5bBndr19Oo3g3aQITA4n09sPovLQcIorvqag4E8m3nozrl7eZCd3xGg1\nk/3Io6zeX/+hdINt0WoZP348FRUVLFumvlOerOYGAo2U8tj3uopPdKyUcjnQ1NOb64AZUsqsu
vLN\nf29MUZRzgseAMHSBrpTPzkDajyYbaOfRrnY5Sy8vSuf847ghogR3F56MDmV+aSWHE5/EUhrOzh2P\nITUVDLzqekwFh7Ddcj2dizL46r9TSSs4cerpiIgIunfvzpo1aygsPHHwUI5qbiCYJ4SYL4S4RQhx\nCzAHmHua504AfIUQS4UQm4QQNzVWUAhxpxBioxBio/oPrCjnDqHT4DMhBnuRmapVh+vtu63LZMwC\nvtfb4I8njjv2rvaBJHu58Z7NTGnp0zgcNWzf9iCdho3APzyC7VmpyPh4rt8+mzumrqaoquaE7Rk5\nciQGg4E5c+aoB8cnobkPix8HpgBdgW7AFCnl8e+HnRwd0AsYD4wG/i2ESGjk/FOklMlSyuTAwMDT\nPK2iKGeSS6IfLkl+VPx5CEfF0Q/rWJ/Y2uUs/fwx7Zx23BCRVtS+RWRxOpndLZb8zTdRadpM5sH3\nGHrj7ZTn51E2ehiBphL6bVnInV9vxGJreqlKd3d3RowYQWZmJjt27GiR/l6Imr0wjZRyupTyESnl\nP6SUM8/AubOBeVJKk5SyCFhObZBRFOU84zMhBulwUv5HZr3tk7tMptJp5Zd28TD7+CGiWDcXnokJ\nZWmViUPdJ1KWMZiDBz/Cs72J6O692LB2BRUD+nJD2hIy0g7x6C/bcDqb/qbfq1cvwsLCWLBgARaL\n5Ux39YLUZCAQQlQKISoa+KkUQlSc5rl/AwYLIXRCCDegL7DnNOtUFKUV6Pxd8UwJp3pLATUHj340\ndAnsQt/Qvnzt4YLVXNLgENHk8ED6ebvzhVsNuQW3YqsKZ9fuRxl+x434t49klbmEXFcd/zNtYM72\nXP63cF+TbdFoNIwfP56qqiqWLFlyxvt6ITrRA19PKaVXAz+eUkqvpo4VQvwArAEShRDZQojbhRB3\nCyHurqt7DzAP2E7tq6lTpZSNvmqqKMq5zXNoe7Rehto8RMd8a5/cZTKFNWX81uMy2PHLcUNEGiF4\nOykCm4RFQ4LJWnEnNms16VnPcuW/XyA8qTPb2gdi3b2KuyMFHyxJ5+eNh5psS7t27UhOTmb9+vXk\n5uY2WVZpwTWLpZTXSilDpZR6KWW4lPIzKeXHUsqPjynzhpSyo5Sys5Ty7ZZqi6IoLU9j1OI9Phrb\n4SpMG/OObO8b0pfO/p353JqDPaRzg0NEUa5G/h0bylqrhQO9upC36TrKytaRlvkfBtzajYQhwWR2\nd6fbjk8ZHu/KMzO2szqtqMn2jBgxAldXV+bMmYPTqZZPaYo4356sJycny40bm5+uVlGUs0dKSeEn\n27EXVhPyaDIat9qUZH8e/JOHlz7Ma13uY9zsf0KniXD51HrHOqXkyq3pbK2o5q4/yuiW/BM6v3kN\nnsfu1FFl8yDQOxhP1wD0el/0Bj8Mej/0er+6331JTy9g/vxVjBlzBb169W7x/p/LhBCbpJTJDe5T\ngUBRlDPJmlNFwXtb8Ogfhs8lsQA4pZOJv01Ep9ExzTMZsexVuPo7SJpQ79gscw3DNuwj3q5hws8F\njL4zmvAkLVZbKTVVuax/9wVKPCX+MaHswA0vYzU92oF0lmGzlWC3NzzfQErQ670xGPzR633rAoYv\neoP/kd8Nhrpten8MBj+0WtcWv1ZnU1OBQHe2G6MoyoXNEOaBe99Qqtbm4N4nBH2IOxqh4fYut/PP\nlf9kRbf7SNnXpXaIKHIAuPkdOTbC1cjzcWE8vi+buF5euH6Xzdi7uxAWHwFe3Ric4mT5qy+yN9RK\naKfOvGbpTYc8P76/ox8uei1OpxWbrQyrrQSbtQSbrYTi4izWrVtIaJgXfn5e2KwlmM1ZlFdsxWYr\nRUp7g/3QaFxqg4TBl+DgS4iMmHy2LuFZp+4IFEU54xwmG3n/txFDmDsBk7sghMDmtDFhxgSC3IL4\nuvtjiKnDGxwiklJyzbYMNpRV8fAaC/rsalKuSaDT4HZIK
Tl47XWklhaww9cNt4gE3mQQo7tH8u41\nPdBoRIPtmTdvHmvXrmXy5MmEh4fXO5fdXonNVozNVorVWlL7p60Em7V2m6k6g4qKrSQmvEB4+PUt\net1aUlN3BC32sFhRlLZL667He1QkNenlmHcWA6DX6Lm5081sLdzKJo0NUh5v8C0iIQRvdmiPViOY\nM8IHty6+LP1uH8u+34fTKQl+6knaH8xhUFwnLNlp3Fu9gEVbMnhzYWqj7Rk6dCienp7HPTgWQqDX\ne+HmFo23d08CAy8iLOxKoiLvIj7+GTp2fINePX8iwH84+1Kfp7BwUctcsFamAoGiKC3CvW8o+lB3\nyudk4LTWzgieFD8JPxc/pu6cCoMfhZAuDb5F1M7FwDtJEey31PBSRw0lE0LZsfwws97eCrEd8Ro3\nDu/Z8xl/272IsnxuLZ3Nlwu3NvpaqYuLC6NHjyY3N5eTHVHQaHR07vwOXp6d2bnrIcrLt57K5Tin\nqUCgKEqLEBqBz8WxOMpqqFyWDYCLzoUbO97IqsOr2FOWBpd+CI1MNBsX6MOS3ol08XTjA/caFl0b\nyv68Sn5+ZQOa6+4FpxOPhYu54p8v4Gav5vrC33j9x+WsTm/4tdJOnToRExPDn3/+SVVV1Un1Rat1\no2u3TzEagti2/Q6qqw+e/AU5h6lAoChKizHGeOPaLZDKZdnYS2rTPVydeDUeeg+m7pgKoV0bHSIC\niHQ1Mq17LC/Ft2MbVj6d4MvmEB2/fX0I06QHKf9tFv52uOq5V/A2CCbl/sYzU/4greD4D3ohBOPG\njcNms7Fw4cKT74shgO7dPwckW7fdhtV64SyNqQKBoigtyntsNEJA+dzaBWw8DZ5c0+EaFh5cSGZ5\nZpNDRFA783hyeCCLe3cgydOVn7oYmT7Um0VlMRxIupL8114jKCqGa194HV8vd8ZkzeSx96ZT3EC2\n0oCAAAYOHMi2bdvIzMw86b64uUXTtesn1NTksm37nTgc5hMfdB5QgUBRlBal8zHiOaw95p3FWNJK\nAbgh6QYMWgNf7PoCtPomh4j+Eu1mZGaPOJ6PDWO/v4ZPJ/jye/Io1lp6UTLvT3xD23HDS2/gHRBA\n//0zePKtHxvMVjp48GC8vb2ZM2cODkfT2Uwb4uPdi04d36KiYiu7dv0DKU++jnONCgSKorQ4z8Hh\naP1cKPs9A+lw4u/qz8S4icxKn0WeKe+EQ0R/0QrB3RFBLExOJMHHjZkDPPjkkv788Hs5pdnlePoH\ncPPLb+AREk7izmk8/9bXx2UrNRgMjB07lsLCQtatW3dK/QkKGk1C/L8oLFpI6v6Xzvu1D1QgUBSl\nxQl93QI2+dVUralNAndL51uQUvLVrq9qC51giOhY8e4uzOoZz79iQkkLN/DmxVE8//U2snYX4+bl\nzR2vvIGuXRyBm6bz1jufHXd8YmIiCQkJLF26lIqKU0uk3L79LUS0v53s7K/JOnT8Oc4nKhAoinJW\nuCT5YUzwpWLRQRxV1qPLWe6fTqmltNlDRH/RaQT3RwazsE8Hwk0l/NTfi1s3pLF8USZ6F1fuf+VV\natp1hLW/MvWdj+p9axdCMHbsWJxOJ/Pnzz/lPsXFPUVQ0DjS0l4hP7/xO5lznQoEiqKcFUKI2gVs\nrE4q5te+fnlb59sw2818v/f72kLNHCI6VgcPV+Z0CueW36ezN9zArdYSXv9pF0Kj49FXX6Y4rBvl\nq+fwwzvvIo+ZTObr68vgwYPZtWsXaWlpp9gnDR2T/g9v72R27X6c0tL1p1RPczgdDmw1LbPQjgoE\niqKcNfogNzwGhmHamIc1u/LocpZ7vsNkM9UWOokhor94JiXxoJvk49f/TZBO8FawnYkzt1FUbuOJ\nl54lM6QXuWsW8vObb+CwH80tNGDAAPz8/Jg7dy52e8M5h05EqzXSresnuLq2Z/uOuzCZTi2oNKY0\n9zArvv+SKffdyua5s
85o3X9RgUBRlLPKa0QEGnf9kQVsJneZTKW1kl/2/VJboN4QUfOXRg988EHi\nC3P5fvZX3ObiySZ/wbCNe1lwoITHn32CbcH9yd6wgumvv4TNWvtqqV6vZ9y4cZSUlLB69epT7pNe\n70P3bp+j0RjYuu02amoKTrkuAFuNhd3LF/PTf57i84fvYsOsGQRHxxIS1+Cy7qdNBQJFUc4qjYsO\n77HRWLMqqd5ScHQ5y91fY3VYawuFdq29M9jxMxxY3qx69UFB+E++HcuC+fxLX8pPUe1xs8ODhfk8\nu/MQd/7jblYEDiFr20amvfwcNdXVAMTFxdGxY0eWL19OaWnpKffL1TWcbl2nYrOVsm3bZOz2k5u9\nLKUkL30/i6Z+wMd33cQfH7xJVXExg665iTs+/JyJTz5HZJfup9y+pqjso4qinHXSKSn8aBv2Mgsh\njyazvnQjdyy4g2f7P8uVCVfWFrKZ4YO+oHeFu1fW3imcgNNsJn3MWHSBgUT9/BMVJhsPz9nNvEDw\nkxpu8vFi0Ze/MbpoMcFR0Vz+zAu4eXlTXl7O+++/T3R0NNddd91p9a2oaAnbd9yFn+9AunadgkbT\ndLvNVZXsWbGUnYvnU5iViU5vIL7fQLoMG0l4UmeE5sx8X1fZRxVFOacIjcDnkliclTYqFh86upzl\njs+xO+vG6vWuMOZVKNwL66c0q16NqyuB/3gYy86dVMyZg7enkc+u7MZ/y13BZOetynL0l13Eb2Hj\nKMg6yI/PPUlFUSHe3t4MHTqU1NRU9u3bd1p9CwgYRmLiixSXLGffvmcbnGMgnU4Obt/K7Hde55O7\nb2LJl5+g0ekYcfu93PXJ14y7/1Had+p6xoLAiag7AkVRWk3JtFSqtxQQ/HBPllevqV3OcvBrjIsZ\nV1tASvjuSshaCw9sBM+QE9YpnU4yr7gSe2kpsX/MRePiAsC2VTk8vzOLtfFG3B3gsS6Vm1Kn4eHh\nwRX/ehHv4FA+/vhjbDYb9957LwaD4bT6lp7xJpmZHxAT/TDR0Q8AUFFUyK6li9i5dBEVhfm4uHvQ\nYdBQugwfRVBUzGmd70TUUpWKopyTHJXW2gVsIr3wuyWJSbMm1S5nefE0hKhbZKY4HT7sB50mwaRP\nmlWvad16sm6+mcCHHybg7ruObM87UM7HP+7il85Gij20eGYWcdeab3Bz2rj8mRcwa3R8+eWXDB48\nmBEjRpxW36SU7N7zOHl5M/EzTCZjhYXMbZtBSiI6d6Pz8FHE9+6P7jQDTnOpoSFFUc5JWk8DXhdF\nUpNainVfGbd3uZ3U0lRWHF5xtJB/LAx4ALb/CAfXNKte97598BgxguIpU7AXFh7ZHhLtzeP39OKZ\nPZK++yxURvrz9qX3cTA0ip9feBptdSVdu3Zl9erVFBU1nM66uYqzsyjYGIcp14ti81RM1q30m3Q1\nk9+bypX/fpmkgUPOWhA4ERUIFEVpVR4DQtEFuVI2O4Mx4aMJcw/j0+2f1h9bH/woeIXD3MfA0bz3\n/YMeexSn1Urhe+/X2+7uY+Tqf/TgARcvblpSiYvU8tWwK1ky6GJ+eu1FEkIC0Ol0zJ0796RzCFnN\n1Wz/cz7f//NRvnrsPrbOmweFl2I0RBAx9ABdx/bEO+jEw1tnmwoEiqK0KqHV1C5gU2zBsrrg6HKW\n+ZuOFjK4w+iXIX8nbPy8WfUao6PxvfZayqZNw5JafxlLnV7LiJuTuH5oNPfNK6dzuoXVsV358sr7\n+OLnn+gY2Z6MjAx27959wvNIKTm8dzfzPnqbj++6iYVT3sNqMTPkxtu56+OvuOTh5+nd7wd0Ok+2\nbrsdiyXnpK7P2aCeESiKck4o+mY3Naml+DzchXGLLibJP4mPL/r4aAEp4etLIXcr3L8JPAJPWKe9\ntJT00WNw7dqViKmfNljm0O4SZn+ynT0+grkDPDHpIXnbKlLsFjQaDffffz9Go/G440x
lpexevpid\nSxZSkpON3sWVxP6D6TJ8FKHxiUefcdSpqtrHps1XYzSG0Kvnz+j1Xid1fU5XqzwjEEJ8LoQoEELs\nPEG53kIIhxDiipZqi6Io5z6f8TFIKbEszDm6nGXxnqMFhIBxb4DVBH/+p1l16nx9CbjnHkwrV1K1\nYkWDZdp39OPaf/Yh0azh7t/L6FrqZEP3wXwR14M0tCxduvRIWafDQcbmDfz2fy8z5d5bWP7dF7h4\neDLq7ge5+5OvGX33g4QldDguCAB4eCTSpctHVFdnsn3H3Tidxy+c01pa7I5ACJECVAFfSyk7N1JG\nCywELMDnUsppJ6pX3REoyoWrfEEmlYsP4X57HOM2TCLOJ45J8ZOI9Ykl1icWd707LPgXrH4PJv8J\n4Q1+wa3HabWSMeFiNEYD0TNnInS6BsvVmG28+991uBVaOdTTg9+inJTpDHTP2s//9e5I6Y4t7Fq6\niKrSEly9vOmYMpwuw0bhH97+pPqYl/cbu3Y/QnDwxXTq+CZCnKW5Aq31+qgQIgqY3UQgeBiwAb3r\nyqlAoChtmNPqIP9/m9C46Vg8ch//t+n/sDqtR/aHuIcQ6xlJbMZqYvVexFz8IbG+8XgaPJust2L+\nAg4/9BAhzz+P7zVXN1rOanPw7H9X0T7Xjgh3ZXHnSlZ6BxFenMekxdOJb9+O7sNHEdOrN1rdiWc6\nNyYz8yPSM/6PyIi7iIs7ccrtM6GpQNBwaDwLhBDtgInAcGoDQVNl7wTuBIiIiGj5ximK0io0Bi3e\n46Mp+X4vF1cM5arrr+Jw1WHSy9JJL0+v/bMsnU3uRizSBPNuBiDILYhY79gjdw6xPrHEeMfgbfQG\nwHPUSFx79aLwvffwmjAerYdHg+c36LU8/UR/Hnh9FT0PVzO63IXIvjl85x/Ghh4paIoPE1BdQzur\nDbfTCASRkXdjqcnhYNYnuLi0Izz8+lOu60xotTsCIcQvwP+klGuFEF+i7ggURaH2LZyiT3dgyzMR\n8lgyGrfjP3AdDjs5X48jozyd9JR/kF6dS3pZOhnlGZjtRxeUD3QNJMYnhljvWDoXuhD32BTcb7uJ\niCeebrINh0qque3tVYws0eKBYNEVwWzEyj8KMyjftR2DwUDv3r3p378/Ho0ElRNxOu3s2HEPRcVL\n6drlIwIDLzqleprrnBwaEkIcAP56ohIAVAN3Sil/bapOFQgU5cJnyzOR/85mXLsG4j02Gp3P8W/t\nkLcTPkmBXjfDhLcAcEonuaa6oFCWQVpZGhnlGaSXpVNtr+aBWQ767pU8/6A/flGJxHjH1LuL8HPx\nO1L95qxSbvt4LVdYXTHYJFMn+BLmaeSbSF/WrVrJrl270Gq19OzZk4EDB+Lt7X3S/XQ4qtm8+Xqq\nTKn07Pk93l7dTvmancg5GQj+Vu5L1B2BoijHKJ93gMql2QAYIjxx7RyAa+cAdH4uRwv98SSs+wTu\nXAph3RutS0pJfnU+GfvW4X3Lv8jqGcZ3VwWSXpZOle1oumhfoy8xPjHE+cQR4x1DYYkP78yp4CrP\ncGqc8MMAD67Vu/PmwDhKSkpYuXIl27ZtA6B79+4MGjQIPz+/xprRoBprERs3XoHDYSK51zTc3CJP\n6vjmapVAIIT4ARhK7bf9fOA5QA8gpfz4b2W/RAUCRVH+xlZYjXlnMeadRdgO135g69t54NqlNijo\n3Wvg/WTwjYbb5kMzsnUWvPU2xZ98QtTPP+HSpQsF1QX1nj9klNfeSVRaK48c4zBHElN+K7qIKNYG\na3gyW8tdl3XAzctAWVkZq1atYvPmzTidTjp37szgwYMJCgpqdj+rqw+wcdOV6HTeJPf6BYPh5IJJ\nc6ikc4qinPfsxWbMu4ox7yjCeqj2Q1of6o6r/2FcU59CP/EZ6HHih66OKhPpo0djiIwk8rtvG3zn\nX0pJkbmI9PJ09hTv4eNtU6i2WnEWT0J0HY/V5uT
BFSbGXJlIXK/aD/zKykrWrFnDhg0bsNlsJCUl\nMXjwYMLCwprVv7LyTWzZciMeHh3p2eMbtFrXk7g6J6YCgaIoFxR7maX2TmFHEdasCpCg0x7GdWAP\n3HqGowt2a/AD/i+lP/1M3nPP0e6dd/AaPeqE58sz5fH40qfZWrQRs+0iTDE30bPIybjF5cQlB5Fy\nTQKuHrUJ5Kqrq1m7di3r1q2jpqaGuLg4UlJSmvXGY0HBfHbsvI/AgIvo0uUDaqdanRkqECiKcsFy\nVNRgXrUN84oN1Di7AAJdgGvtM4UuAejD3I8LCtJu58DEiTgtNcTMmY2mGVlAndLJ5zu+5N0t71Lp\ncSVmv/E8ZDLi80ceRjcdQ6/vQEz3o2kvLBYLGzZsYM2aNVRXVxMVFUVKSgrR0dFNBqlDh74kdf+L\nhIffREL8s02WPRkqECiKcuGb/Q8cG3/FPGAa5iwDNRll4AStnwuunf1x7RyAob3nkQ/WqhUrOXTH\nHQQ9+ST+t97S7NPsLdnLPQseY7/PLTh07XnXPRDHn0UUZ1eR0DeYwVcl4OJ+9JVXq9XKpk2bWL16\nNZWVlbRr146UlBQSEhIa/ZDfv/+/ZB36jLi4p4mMmHw6V+UIFQgURbnwVZfAe70gKAlumYOj2o5l\nd+2DZktaGTgkWm9jbVDoEoAhwotDd96Jeft2YufPQ+fr2+xTWewWHlz+HrNkCnpLFg+7BDDIFMmm\neQdx9dQz7IYORHUJqHeM3W5n69atrFy5krKyMoKDg0lJSSEpKQnN3x5yS+lk566HKCiYS+dO7xAc\nPOG0L48KBIqitA0bv4DZD8OkqdD1yiObndU2zHtKaoPC/lKwSzSeBgzhWkqmvozn6N6E/uuZkz7d\nU9tW82WJG+4l39O9xo+3+tzFqu9TKckx0WFAKIOujMfoWj+Bg8PhYMeOHaxcuZKioiL8/f0ZPHgw\nXbp0QavVHlOuhi1bb6KiYjs9un+Fr2+fU74soAKBoihthdMBU0dARW7tGsfG43MQOS12LHtLMO8o\nwpJairQ5cdZU4NY9FI8B0RhjvBHa5iWCk1Jy7dZ9LCutwifv37hXe/DJyDewbnGwZf5B3H2MDLux\nAxEd/Y9vh9PJnj17WL58Ofn5+fj4+DBo0CC6d++Ori4xns1WxsZNV2G1FpLc6xfc3eNO+dKoQKAo\nStuRvbE2GAx4AEa91GRRp9VB9YaDFE35HW1IF4TQo3HT4dKxdvjIJdYHoWs6KBRabQxbvw9pq0Bk\nPoyQgtuTnuDakJH8+dVuSvOq6TQ4jAGXx2FwOT69m5SS1NRUli9fzuHDh/H09GTAgAH06tULg8GA\n2ZzNxk2Xo9EYSe41DaOx+fMTjqUCgaIobctv98O2H+Ce1RCYeMLiRR9/QuG7HxDy8ifIGl/Me0qQ\nNQ6EixbXJH88BrfDENZ4TqEFReXctOMAl3lrWb/lKWy6LBLchvPZmJfYtSCfrYuy8PRzYfhNSYQn\nNvwsQkrJgQMHWL58OZmZmbi5udG/f3969+6N1bqfzVuuIyTkMjokvnhKl0QFAkVR2hZTEbzXE0K7\nw02/1S5q0wSnxUL62HHofH2JmvYLOMGyv7R2rsKuIoROQ8gjvRpMgPeXx/cd4tucYr7pGMG7C/9H\nhv13XAjknRGvE1MTy59f7aG80EyXoeH0nxiL3tj4HIGsrCxWrFjB/v37MRqN9O3bl86dPfH374pW\n20DepWZQgUBRlLZn3RT443G48ivodNkJi5f//js5jz9B6Kuv4HPZ0fLWw1UUvL8F994h+E6Kb/R4\nk93BRRv3YXVKFvdO5IMV8/k+4zWErpwrYm/lid73sPG3g2xfko1XgAsjbu5IWLxPk23KyclhxYoV\n7NmzB71ez0UXXUTfvn2beQHqU4FAUZS2x2GHKUPBXAr3rweDe5PFpdNJ5tXXYC8oIHbeH2hcj6Z4\nKJuTQdWKwwT
e1RVjdONZRjeXm7h4y34mBvnyfsdIVqQf4oEF/8bhtolwtw5MGf0mmjwPFn+9h4pi\nC92Gt6ffpTHoDE3PIC4oKGDlypUkJCTQuXOTOTwb1SprFiuKorQqra52jeOKbFjxvxMWFxoNwU89\niT0/n+Ivvqi3z2tkJFofI6Uz9iPtzkbr6OntzsORwUzLL+W3glIGx7Zn4Q2f0M46mUOVmVwycxJr\n5BKu/mdvOqe0Y9ufh/jp5Q3kZZQ32bagoCAmTZp0ykHgRFQgUBTlwhXZH7peU7vGcXH6CYu79eqF\n56hRFE/9DFtBwZHtGoMWn4lx2AvNVC491GQdD0eG0MPTjSf3ZZNbYyXQ08js2x7gkoD/YTGF8p+1\nz/HwysfoNjGYSx7qjt3mYMYbm1g9Iw27zXHaXT4VKhAoinJhG/kf0Bpr1y5oxlB40GOPIm02Ct99\nt95210Q/XLsFUrHkELaC6kaP12sEH3SMpMYpeWhPFk4p0Wk1/PeSFF4b+AHO4nGsOLyMi2dOJNt7\nH9f+uy9JA0LZsiCLn/+7kYKDFafd5ZOlAoGiKBc2zxAY+hSkLYR9f5ywuCEiAr/rr6d8+gwse/fW\n2+czIQah11I6cz/S2XhQiXEz8kJ8GMtLq5iaXXhk+yXd2vPr9f/Cp+xRSioFdy68k7d3vMmA62KY\n8EA3rGY7017bxLpZGTiaGII601QgUBTlwtf3LgjsAPOeBJv5hMUD7rkbrZcX+a+9xrEv1Gg9DfiM\ni8Z6oILqTflN1nFDqD8j/b14OSOXPVVHzxkf7Mncu6+hv8uLWEv68c2eb7j692uxhpVw7bN9SOwT\nzMa5mfzyykYKsyqbOMOZowKBoigXPq2+9sFxWRaseufExb29CbjvPqrXrKVq2bJ6+9ySgzFEeVE2\n9wCOSmujdQgheLNDezy0Wu7fc5Aa59Fv+J4uej69cQAP93gKy6FbOFCax9Wzr+GXzJ8YdnMHxt3T\nBXOllWmvbmT97AM4HC17d6ACgaIobUN0CnSaBCvfgtLMExb3veZqDJGRFLz+BtJmO7JdaAS+k+KR\nVgdlszOarCPQoOetDu3ZVWXh9QN59fYJIbhnaCxfXH0zmpzHsFbG8tqG17hn0T24x0uufbYvsb2C\n2DD7ANNe3Ujx4apGznL6VCBQFKXtGPUSCC3MO3GmUWEwEPTE41gzMsh74QXKf/8d0/r1WLOy0Hpr\n8RzaHvO2Qiz7Spo+ZYA3N4T682FWAatLj/8wHxQfwOz7xhLjfABL7mWsy9nI5bMuZ1Xxckbd3okx\nd3bGVFbDz//dwI6l2afc9Sb7qiaUKYrSpqx8CxY9D9dPg/iRTRaVUnL44X9QOX/+cfu0/oG49n0U\ndAa0bhsxhAajDwlGFxyCPjQEXXAwGhcXoP6s4yV9OuClO34CmcXm4D+/7+KnrZsIiJmGWWQxKX4S\nT/Z+EmHRseyHfST0DiGmR+BxxzaHmlmsKIryF3sNfDQApBPuXQu6E+fucZpM2PILsOflYsvLx56f\nhy0vH1uRA43nKGxZS7Fs/v6447Q+PuhCQtCHhLA7vgO39Uhhgs3Em946dMHB6ENC6s1gBvhpQxb/\n/m0bniGLsXn+SbhnOK8OfpWugV1Pq9sqECiKohwrbRF8ezmMeBYGP3paVZVO349pUx6Bd3REiEps\neXnY8vKw5+Vjy8/DnpuHLT8fe14enw0YzlcTruDZT99h2Oa1AGi8vdGHhKALCUZfdzeRa/DinW3l\nZBhLoOMiSrRl3NXtLu7ocgc6zfGprJtDBQJFUZS/+/F6SF8M960Hn/anXI2z2kbem5vQ+hgJurc7\nQtN4ptOaajOXbk3jQI2d36ty8MvNOXqXkVcbMBzFxccdV+2io8jTTsno3lz7r69PqZ1NBYJTCy2K\noijnu9H/hQ/6wIJ/wVVfnXI1Gjc9PhfHUPLDPqpW5+A5qF2jZY1urnzYPYERG
/bxr/aJ/DhhHJq/\npch21tRgLyjAnpdHTW4eS1buZP+OdCI02SSG9j/ldjbZhxapFRBCfC6EKBBC7Gxk//VCiO11P6uF\nEN1aqi2KoijH8Y2sHRba/StkLD2tqly7BmJM8KViQSb2MkuTZWPcjPwnrnbW8WfZRcft1xiNGNq3\nx613b3wvuZhJrz9Nj/97if8m/4NZgSmn1c7GtOTro18CY5rYfwAYIqXsCrwITGnBtiiKohxvwIPg\nGwVznwB745PDTkQIge9lcSCh7Ld0TjTkfmNY7azjlzJy2Gs68UznUZ1C+P2BQTw1tsMpt7EpLRYI\npJTLgUZfsJVSrpZSltb9dS0Q3lJtURRFaZDeBca8CkX7YP0np1WVzs8Fr5GRWPaUYN55/Dj/serN\nOt6dhdV54pnDUQHuuOibXrfgVJ0rE8puBxrNBiWEuFMIsVEIsbGwsLCxYoqiKCcvcSzEj4alr0JF\n7mlV5TGwHfpQd8pmpeO02JssG2jQ82aH9uysMh836/hsa/VAIIQYRm0geLKxMlLKKVLKZCllcmDg\nqU2mUBRFadSYV8BhhYXPnlY1QivwvTweZ5WV8nmZJyw/OsCb60P9+CCrgDVljaSQqMyHLd/CTzfC\n1h9Oq32NadVAIIToCkwFLpVSNn0vpSiK0lL8Y2HgQ7DjZ8hcdVpVGcI98RgQhmltLjXNWFvghbh2\nRLoaeGDPQSrsDnA6IXsTLPlv7VKb/0uA3+6D7I1gbZl8Q60WCIQQEcAM4EYpZWprtUNRFAWAQY+A\nd3uY+3jtesenwWtUVLOWtgRw12n5IMaPXIuVZ/6cBv8XD1OHw/I3QGuA4f+Gu1fCI7uhzx2n1a7G\ntNg8AiHED8BQIEAIkQ08B+gBpJQfA88C/sCHovY9Wntjkx0URVFanMGtdm7BzzfCxs9q1zA4RRqj\nFp9LYyn+ajeVy7PxGh5Rv4CUULgXUufD/oX0ylrDQxE38WbULYzqMJlLouIgbgS4+Z1mp5pHzSxW\nFEX5i5TwzUQ4vBke2AgeQadVXfF3ezDvKSb44V7ovYEDK2D/fEhdAOVZtYWCu0DCKGxxo7i40JdM\ns5UlfRIJNRpOvz/HUCkmFEVRmqswtTYpXder4bIPTqsqx6EM8qZkYTDkEMA/EA4L6N0gZijEj6r9\n8T46Ezm92sJFG1Lp4+3OD91ijpt1fDpUiglFUZTmCkyA/vfWrmTW62Zo36f5xzpscGhd3ZDPArSF\ne/GWYygz3U914r9wH9wRIgfWzl9oQKybC8/HhfFkajafHy5icvjZeUtSBQJFUZS/S3kctv8Mcx+D\nO5aApomJXFWFkLaw9sM/fQnUlINGD5EDoOdNuMeNonpaNeWHknEJ6YVW3/SQz01h/iwsruCl9BwG\n+3qS6N5w0DiT1NCQoihKQ3ZMg+m3w/g3offtR7c7nZC7FfYvqP05vBmQ4BFcu9BN/OjaoR8XryOH\n2PJN5L+7BbeugfhdnXjCUxfU2Bi6YS9hRgNze8Vj0Jz+C55qaEhRFOVkdb4cNn4Bi1+E2OGQuw32\nL6z98DcVAALa9YJhz9SO9Yd0hUY+sPXB7ngOCady8SHcegbhEu/b5KmDjHreTIzglp0HeONAHv+M\nDWuBDh6lAoGiKEpDhIBxb8DHg+Dd7rXbXLwhdgQkjIa4i8A9oNnVeQ2LwLy9iNKZaQQ/3BONoem8\nQWMCa2cdv59VwHB/L/r7eJxGZ5qmhoYURVGasvkbKE6r/dbfvi9oT/37syW9jKJPd+A5NBzvMdEn\nLG+yOxixcR92KVncu+G1jpurqaGhVs81pCiKck7reSOM/A9EDTytIADgEuuDW69gKpdnY801nbC8\nu07L+0mR5Fhs/HN/9mmduykqECiKopxF3uOi0bjqKJuxH+k88YhMsrc7D0UG80teKb8XlLVIm1Qg\nUBRFOYu07np8JsRiPVSJaW3z0l4/EhVCP
293Sm2nlwOpMephsaIoylnm2j0Q4+Z8yudn4tLJH523\nscnyeo1gRo+4MzrT+FjqjkBRFOUsO7K0pVNS9lt6s45pqSAAKhAoiqK0Cp2/K14XRWDZXYx55/GL\n2J9NKhAoiqK0Eo9B7dCHNG9py5akAoGiKEorEVoNvpfH46i0Uj4/s9XaoQKBoihKKzK098Sjf93S\nllknXtqyJahAoCiK0sq8Rkei9TLUzi1wNL20ZUtQgUBRFKWVaYw6fC6Jw5ZXTeWKw2f//Gf9jIqi\nKMpxXDv549rJn4pFWdiLzWf13CoQKIqinCN8LolFaAWlM9M4mwlBVSBQFEU5R2i9jXiPiaImrYzq\nrYVn7bwqECiKopxD3PuGYojwpHx2Og6T7aycUwUCRVGUc4jQCHwnxeM0Oyife+CsnFMFAkVRlHOM\nPsQdz5RwqjflY0kra/HzqUCgKIpyDvIa0R6tvwtlM/cjbY4WPVeLBQIhxOdCiAIhxM5G9gshxLtC\niDQhxHYhRM+WaouiKMr5Rui1+E6Mw15soWLxoRY9V0veEXwJjGli/1ggvu7nTuCjFmyLoijKeccl\nzhe3nkFULsvGlnfipS1PVYsFAinlcqCkiSKXAl/LWmsBHyFEaEu1R1EU5XzkPT4GjYuW0mYubXkq\nWnOFsnbAsfc72XXbjlu7TQhxJ7V3DQBVQoh9p3jOAKB1E3+fW9T1qE9dj6PUtajv3Lge953W0ZGN\n7WjNQNDQcjsNhjsp5RRgymmfUIiNUsrk063nQqGuR33qehylrkV9F/r1aM23hrKB9sf8PRzIaaW2\nKIqitFmtGQhmATfVvT3UDyiXUh43LKQoiqK0rBYbGhJC/AAMBQKEENnAc4AeQEr5MTAXGAekAdXA\nrS3VlmOc9vDSBUZdj/rU9ThKXYv6LujrIc5mhjtFURTl3KNmFiuKorRxKhAoiqK0cW0mEAghxggh\n9tWltHiqtdvTmoQQ7YUQS4QQe4QQu4QQD7V2m1qbEEIrhNgihJjd2m1pbUIIHyHENCHE3rr/R/q3\ndptaixDiH3X/RnYKIX4QQri0dptaQpsIBEIILfABtWktOgLXCiE6tm6rWpUdeFRKmQT0A+5r49cD\n4CFgT2s34hzxDjBPStkB6EYbvS5CiHbAg0CylLIzoAWuad1WtYw2EQiAPkCalDJDSmkFfqQ2xUWb\nJKXMlVJurvu9ktp/6O1at1WtRwgRDowHprZ2W1qbEMILSAE+A5BSWqWUZa3aqNalA1yFEDrAjQt0\nrlNbCQSNpbNo84QQUUAPYF0rN6U1vQ08AThbuR3nghigEPiibqhsqhDCvbUb1RqklIeB/wOyqE19\nUy6lXNC6rWoZbSUQNDudRVsihPAApgMPSykrWrs9rUEIMQEokFJuau22nCN0QE/gIyllD8AEtMln\nakIIX2pHDqKBMMBdCHFD67aqZbSVQKDSWfyNEEJPbRD4Tko5o7Xb04oGApcIITKpHTIcLoT4tnWb\n1KqygWwp5V93iNOoDQxt0UXAASlloZTSBswABrRym1pEWwkEG4B4IUS0EMJA7QOfWa3cplYjhBDU\njgHvkVK+2drtaU1SyqellOFSyihq/79YLKW8IL/1NYeUMg84JIRIrNs0Atjdik1qTVlAPyGEW92/\nmRFcoA/OWzP76FkjpbQLIe4H5lP75P9zKeWuVm5WaxoI3AjsEEJsrdv2jJRybus1STmHPAB8V/el\nKYOzk/7lnCOlXCeEmAZspvZNuy1coKkmVIoJRVGUNq6tDA0piqIojVCBQFEUpY1TgUBRFKWNU4FA\nURSljVOBQFEUpY1TgUA57wkhVp/l84XVvVaIEKK7EGLcGazbRwhxb0PnUpSWol4fVZQTEELopJT2\nRvbdQm12yvvPUH1RwOy6bJeKclaoOwLlvCeEqKr7c6gQYukxufS/q5sRihAiUwjxHyHEZiHEDiFE\nh7rt7
kKIz4UQG+qSrF1at/0WIcQvQojfgQV/O19UXX56A/ACcLUQYqsQ4urm1ieE8BBC/HlMe/7K\nhvsqEFtX3xt/nauuDhchxBd15bcIIYYdU/cMIcQ8IcR+IcTrLX3NlQtLm5hZrLQpPYBO1OaSWkXt\nLOqVdfuKpJQ964ZeHgMmA/+kNq3EbUIIH2C9EGJRXfn+QFcpZUlDJ5JSWoUQz3LMHYEQ4r/Nqa8u\nrfFEKWWFECIAWCuEmEVtgrfOUsrudfVFHXPK++rO26UukC0QQiTU7ete1/caYJ8Q4j0p5bEZdxWl\nUeqOQLnQrJdSZkspncBWIOqYfX8l19t0zPZRwFN1qTaWAi5ARN2+hY0FgSY0tz4B/FcIsR1YRG1a\n9OAT1D0I+AZASrkXOAj8FQj+lFKWSykt1OYGijzJdittmLojUC40Ncf87qD+/+M1DWwXwOVSyn3H\nViKE6EttCua/fv+kbtezwPYmzn/C+upcDwQCvaSUtrrspydaBrGhdOp/aarfitIkdUegtHXzgQeO\neZbQ4+8FpJTrpJTd637+nrW2EvA8mfrqeFO7DoKtbqz/r2/wf6/vWMupDSDUDQlFAPsaKasozaYC\ngdLWvQjoge11D2VfPMnjlwAd/3pYfBL1fQckCyE2UvvhvhdASlkMrKp7GP3G3475ENAKIXYAPwG3\nSClrUJTTpF4fVRRFaePUHYGiKEobpwKBoihKG6cCgaIoShunAoGiKEobpwKBoihKG6cCgaIoShun\nAoGiKEob9//eyOAJR/KoiwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "meta_params = meta_opt_state[0]\n", "\n", "for j in range(10):\n", "  losses = []\n", "  key = jax.random.PRNGKey(j)\n", "  params = task.init(key)\n", "  opt_state = lopt.initial_inner_opt_state(meta_params, params)\n", "\n", "  for i in range(10):\n", "    batch = next(data_iterator)\n", "    loss, grads = value_grad_fn(opt_state[0], batch)\n", "    opt_state = lopt.update_inner_opt_state(meta_params, opt_state, grads)\n", "    losses.append(loss)\n", "  plt.plot(losses)\n", "  plt.ylim(1.0, 2.3)\n", "plt.xlabel(\"inner-iteration\")\n", "plt.ylabel(\"loss\")" ] }, { "cell_type": "markdown", "metadata": { "id": "9hMBooBKTxy5" }, "source": [ "We can see that our learned optimizer works: it reduces the loss over these first 10 inner steps." ] }, { "cell_type": "markdown", "metadata": { "id": "oOTPnSM5m8pV" }, "source": [ "## Vectorization: Speeding up Meta-training\n", "\n", "In the above example, we train a single problem instance for 10 iterations and use this single training run to compute meta-gradients. Often we want to compute meta-gradients from more than one problem, or to average over multiple random initializations and batches of data. To do this, we will leverage `jax.vmap`.\n", "\n", "We will define a vectorized meta-loss, which evaluates the original `meta_loss` function on several samples in parallel and then averages the losses. Calling `jax.value_and_grad` on this function then yields meta-gradients that are the average of the per-sample meta-gradients.\n", "\n", "One big advantage of vectorizing in this way is better use of hardware accelerators. When training learned optimizers, we often apply them to small problems so that meta-training is fast. These small problems can be a poor fit for the underlying hardware, which often consists of large matrix multiplication units. Vectorization computes several of these small problems *at the same time*, which, depending on the details, can be considerably faster."
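, "\n", "\n", "As a minimal, self-contained sketch of this pattern (a toy example, not part of the original notebook), the next cell vectorizes a hypothetical per-task loss, `toy_task_loss`, with `jax.vmap`. Here `in_axes=(None, 0, 0)` shares the parameters across samples while mapping over the leading axis of the keys and batches. The check at the end confirms that the gradient of the averaged loss matches the average of the per-sample gradients computed one at a time." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "\n", "# Hypothetical stand-in for `meta_loss`: a toy quadratic loss whose\n", "# target depends on a per-sample random key and a per-sample batch.\n", "def toy_task_loss(theta, key, batch):\n", "  noise = jax.random.normal(key, theta.shape)\n", "  return jnp.mean((batch * theta - noise) ** 2)\n", "\n", "\n", "def toy_vectorized_loss(theta, keys, batches):\n", "  # in_axes=(None, 0, 0): share theta, map over axis 0 of keys and batches.\n", "  per_sample = jax.vmap(toy_task_loss, in_axes=(None, 0, 0))(theta, keys, batches)\n", "  return jnp.mean(per_sample)\n", "\n", "\n", "toy_theta = jnp.ones(3)\n", "toy_keys = jax.random.split(jax.random.PRNGKey(0), 4)  # one key per sample\n", "toy_batches = jnp.arange(12.0).reshape(4, 3)  # one batch per sample\n", "\n", "toy_loss, toy_grad = jax.value_and_grad(toy_vectorized_loss)(\n", "    toy_theta, toy_keys, toy_batches)\n", "\n", "# Compare against averaging gradients computed one sample at a time.\n", "toy_per_task_grads = jnp.stack([\n", "    jax.grad(toy_task_loss)(toy_theta, toy_keys[i], toy_batches[i])\n", "    for i in range(4)\n", "])\n", "print(toy_loss, jnp.allclose(toy_grad, toy_per_task_grads.mean(axis=0), atol=1e-5))"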
] }, { "cell_type": "code", "execution_count": 17, "metadata": { "executionInfo": { "elapsed": 3598, "status": "ok", "timestamp": 1647716643584, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "8qud5EuWniIF" }, "outputs": [], "source": [ "def get_vec_batch_seq(vec_size, seq_len):\n", "  batches = [get_batch_seq(seq_len) for _ in range(vec_size)]\n", "  # Stack along a new leading axis so that each entry is one vectorized sample.\n", "  return {\n", "      \"image\": jnp.asarray([b[\"image\"] for b in batches]),\n", "      \"label\": jnp.asarray([b[\"label\"] for b in batches])\n", "  }\n", "\n", "\n", "def vectorized_meta_loss(meta_params, keys, sequence_of_batches):\n", "  # Share meta_params across samples; map over the leading axis of keys and batches.\n", "  vec_loss = jax.vmap(\n", "      meta_loss, in_axes=(None, 0, 0))(meta_params, keys, sequence_of_batches)\n", "  return jnp.mean(vec_loss)\n", "\n", "\n", "vec_meta_loss_grad = jax.jit(jax.value_and_grad(vectorized_meta_loss))\n", "vec_seq_batch = get_vec_batch_seq(4, 10)\n", "keys = jax.random.split(key, 4)\n", "loss, meta_grad = vec_meta_loss_grad(meta_params, keys, vec_seq_batch)" ] }, { "cell_type": "markdown", "metadata": { "id": "C2cz28gppua9" }, "source": [ "We can now meta-train with this vectorized loss in the same way as before."
] }, { "cell_type": "code", "execution_count": 18, "metadata": { "executionInfo": { "elapsed": 8431, "status": "ok", "timestamp": 1647716652129, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "TfZvfCJ_uGTp", "outputId": "a8156fa1-4d03-40cc-a961-d570f425ec54" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.3015163\n", "2.2858846\n", "2.262492\n", "2.2396119\n", "2.2089443\n", "2.1768482\n", "2.1378422\n", "2.0993867\n", "2.0581489\n", "2.0262241\n" ] } ], "source": [ "meta_opt = Adam(0.001)\n", "key = jax.random.PRNGKey(0)\n", "meta_params = lopt.init_meta_params(key)\n", "meta_opt_state = meta_opt.init(meta_params)\n", "meta_losses = []\n", "num_steps = int(os.environ.get(\"LOPT_TRAIN_LENGTH\", 200))\n", "for i in range(num_steps):\n", "  data = get_vec_batch_seq(8, 10)\n", "  # Fresh random keys for each of the 8 vectorized samples.\n", "  key1, key = jax.random.split(key)\n", "  keys = jax.random.split(key1, 8)\n", "  loss, meta_grad = vec_meta_loss_grad(meta_opt_state[0], keys, data)\n", "  meta_losses.append(loss)\n", "  if i % 20 == 0:\n", "    print(onp.mean(meta_losses[-20:]))\n", "  meta_opt_state = meta_opt.update(meta_opt_state, meta_grad)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "colab": { "height": 282 }, "executionInfo": { "elapsed": 189, "status": "ok", "timestamp": 1647716652440, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "vLBEKi6WuTcx", "outputId": "215a253b-831a-4f2e-d477-e27e530ba220" }, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90\nbGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsT\nAAALEwEAmpwYAAAzHUlEQVR4nO3deXycVbnA8d8zk33f96RJ2qZNurehC13YWjbBCiKWVRStekFE\nwavievW6a6+gaK2AgIIIsioghVIo0DUt3dMlSZMmzb7v65z7x0xCkmZtk0wyeb6fTz+ZOe95Z555\nZ/rMmfOe9xwxxqCUUsp1WZwdgFJKqdGliV4ppVycJnqllHJxmuiVUsrFaaJXSikX5+bsAPoSFhZm\nEhMTnR2GUkpNGHv37i03xoT3tW1cJvrExEQyMjKcHYZSSk0YIpLX3zbtulFKKReniV4ppVycJnql\nlHJxmuiVUsrFaaJXSikXp4leKaVc3KCJXkTiRWSriGSKyBER+WofddaKyEER2S8iGSKyotu2K0Xk\nuIhkici3RvoFKKWUGthQWvTtwH3GmFRgKXCXiKT1qrMFmGeMmQ98DngEQESswMPAVUAacFMf+46Y\nh7ac5MPTVaP18EopNSENmuiNMUXGmH2O23VAJhDbq069+Whie1+g8/ZiIMsYk2OMaQWeAdaOVPDd\n1TS28fSu01z/x+18+4WDFFQ1jsbTKKXUhDOsPnoRSQQWALv62HadiBwDXsXeqgf7F0J+t2oF9PqS\n6Lb/eke3T0ZZWdlwwgIg0MedN7++is9emMRzGQVc9Kt32PDmCXRhFaXUZDfkRC8ifsDzwL3GmNre\n240xLxpjZgKfAH7cuVsfD9Vn5jXGbDLGpBtj0sPD+5yuYVD+Xu58/9o0tv33JXx8XgwPbTnJt54/\nxGuHisiv1Ba+UmpyGtJcNyLijj3JP2WMeWGgusaYbSIyVUTCsLfg47ttjgMKzzXYoYoJ8mbDjfMI\n8/Pgz++d4h8Z9h8Vy6eFsuHG+UQGeI12CEopNW7IYF0bIiLAE0ClMebefupMA7KNMUZEFgL/wp7U\nrcAJ4DLgDLAHuNkYc2Sg50xPTzcjNalZaW0zpXUtvHO8lD++k018iA8/WjubvIoGLpoRToS/Jn2l\n1MQnInuNMel9bRtKi345cBtwSET2O8oeABIAjDEbgU8Ct4tIG9AEfNpxcrZdRO4G3sCe9B8bLMmP\ntIgALyICvJgdG8j8+GA++/hubvzTDgA8rBZuWZrAA1en4m7VSwqUUq5p0Ba9M4xki763QwU15FY0\nkBjqy9O78/j77nyWJYey8bZFBHq7j8pzKqXUaBuoRT/pEn1vL+wr4JvPH2R6hD9P3rmYMD/PMXle\npZQaSQMl+knfX3H9wjge+cwF5JTXc8dfdtPeYXN2SEopNaImfaIHuCglnF/dMI/DZ2p5Zk/+4Dso\npdQEoone4Zq50SxJCuE3m49zqrzB2eEopdSI0UTvICL88OOzaGjt4JJfv8NNm3aSV2FP+Dab4eX9\nZyitbXZylEopNXyT/mRsb4XVTbxyoJCHt2ZhsxluXTqFnPIG3jxawsrpYfz1ziVOiUsppQaiJ2OH\nISbImy9dNJU37l3FyunhPPr+KbZklrAqJZz3Tpbz7onhz8OjlFLOpC36QZTXt9Dc1kG4vydrNmyj\ntd3GBUkhLE4K4ePzYnTsvVJqXNAW/XkI8/MkLtgHTzcrv7phLjFBXuzLq+J7Lx1m1S+36nTISqlx\nT1v058AYw77T1dz6yC5WTg/j1qVT+NeBQv5n7Sx8PIY0T5xSSo2o853rRvUiIiyaEsw9l03nF/85\nxpuZJRgDS5JDuWFRHADP7y2guqmN6xbEEuLr4eSIlVKTmXbdnIc7VySxNDmE6xbEkhjqw3MZ+Rhj\neGjLSe577gA//vdRLvz5FvbnVzs7VKXUJKaJ/jx4uFl4Zv0yN
tw4nxsWxbHrVCXf+OdBNrx5gusX\nxvLqPSsI8HLnB68cwWb7qIusw2Z4fm8BFfUtToxeKTVZaKIfIdcvjEME/rm3gNuWTuHXN8xjVkwg\n37xyJgfyq/nWCwf53kuHyato4G8787jvuQPc8sguahrbnB26UsrF6cnYEfTrN47j7WHlvy6ein29\nFvtVtTds3M6+09W4WYTIAC9qmtqIC/Ymp6yBmCAvrpkbw50rkgjWvnyl1DnSaYqdrKm1g+a2Dgqq\nmrjpzztp67Cx+WuryK1o5OG3s9h7uor58UE8/YUleLpZnR2uUmoC0kQ/jpwoqaOmqY0LEkO6yl49\nWMRdT+/jmrnR/M/HZxHabU78msY29uVXcXFKODVNbbyfVc7Vs6OxWPpad10pNVnp8MpxJCXS/6yy\nj82N5lR5Cr/efIK3MktYd0ECd1yYSKC3Ozf9eSfHiuv4ztWpvJVZwq5TlWy6zcLls6KcEL1SaiLS\nFv04klVax5/ezeHFD8/QbjO4WwURYU5sIHvzqgDw83RjVkwA//jiMidHq5QaT7RFP0FMi/DnV5+a\nx9fWpLD5SDHZZQ1cPSeauXGBrP9rBoumhODnaeWnrx3jw9NVzIoJxMNNB04ppQamLfoJpqapjWU/\n20Jjawcebha+97FUbl06pWuUj1JqcjqvFr2IxANPAlGADdhkjHmwV51bgG867tYDXzbGHHBsywXq\ngA6gvb9A1NAEervz59vTOVhQw46cCr738hHePlbK7csSWZUSjlVP0iqlehm0RS8i0UC0MWafiPgD\ne4FPGGOOdqtzIZBpjKkSkauAHxpjlji25QLpxpjyoQalLfqhsdkMG7dl8+h7p6hoaCU60IsvXTSV\n25dpC1+pyea8pik2xhQZY/Y5btcBmUBsrzrbjTFVjrs7gbjzC1kNhcUi/NfF09jx7cvYeOtCpoT6\n8INXjvA//zpKh238dckppZxjWGfyRCQRWADsGqDancDr3e4bYLOI7BWR9QM89noRyRCRjLIyXcVp\nODzcLFw5O5qnP7+Uz69I4vHtuax9+H125VTQ3NbBT149yr3PfNhjvh2l1OQx5JOxIuIHvAv8xBjz\nQj91LgH+AKwwxlQ4ymKMMYUiEgG8CXzFGLNtoOfSrpvz8++DhfzoX0cprWvBx8NKY2sHAL+8YS43\npsc7OTql1Gg47xWmRMQdeB54aoAkPxd4BFjbmeQBjDGFjr+lwIvA4uGFr4brmrkxbL3/Yn55w1zW\npEXy2B3pLEgI4pf/OUZtc9+TqNW3tOt6uEq5qEETvdjP6j2K/WTrhn7qJAAvALcZY050K/d1nMBF\nRHyBy4HDIxG4Gpivpxs3psfz4LoFXDozkh99fDYVDa1sfCe7z/q/fzuLzzy2m9zyhjGOVCk12oZy\nwdRy4DbgkIjsd5Q9ACQAGGM2At8HQoE/OEZ7dA6jjARedJS5AU8bY/4zki9ADc2cuECumRvD49tz\nuXlJAn96N4cjhTX4ebmz6bZF/OtAIQDbsytIDPN1crRKqZGkF0xNIlmldaz5v234ebjR0NpO+pQQ\ndudWcsWsSN44UgLAtfNi+N1NC5wcqVJquM67j165hmkR/lw3P5amtg5+d9NCnv3SMi6ZEc4bR0rw\ndreyJi2SHdnljMcvf6XUudNEP8n87JNz2Hr/xXxsbjQA3746FYvAmrRI1qRGUl7fysnSeidHqZQa\nSTqp2STj6WYlPsSn635KpD9Pf2EpyWG+tLTbAPjzthyuWxjLsuRQAPbkVvHeyTJmxQRy5WydHlmp\niUYTvWKpI6EDzIoJ4Lm9BTy3t4DvfiyV0roWNm3LAewXZr12zwqmRZw9p75SavzSrhvVw8t3LWfX\nA5dx5awo/vfVTDZty+HmJQlsvf9ifD2sfP3ZA1Q1tDo7TKXUMOioG9WnptYOvvS3vcQEefOTT8zG\nYhFeP1TEl5/ah7tVuGRGB
DcsimNNWqROoKbUOKBrxqoRk1lUy/N7C3hpfyHl9S08dNMCPj4vxtlh\nKTXp6fBKNWJSowP47jVp7Pz2pSSE+PD3XaedHZJSahCa6NU5cbNa+NSiOHbkVJBf2ejscJRSA9BE\nr87ZJxfFIQJ/25lHq2NoplJq/NFEr85ZTJA3F6eE86dtOcz+wRtsfDe766rafx8s5Bf/OebkCJVS\noOPo1Xl66KYFvH2slH8fLOLnrx9jb14VKZF+/OGdbIyBTy6MY1qEHwDtHTasFtFROkqNMW3Rq/Pi\n7+XO2vmx/OnWRXx9TQq7cip4eGs2F0wJAWDz0WIAGlvbuehX7/Dw1ixnhqvUpKSJXo0Ii0W457Lp\n7PveGt6+7yKe/sIS5sYFstkxK+Zfd+RxprqJZzMKdNI0pcaYJno1otysFpLD/XCzWrg8LZL9+dWc\nKm/gT9ty8Pd043RlI0cKa50dplKTiiZ6NWqumGWfAG3NhnepbGjlwZvmY7UIf999mruf3sd3Xjyk\nQzOVGgN6MlaNmmkRfnzjihlUN7ayICGYS2dGcuHUUJ7adRo3i2AR4YV9Z9j8tVU9ZtRUSo0sTfRq\n1IgId10yrUfZ7csSySlr4Fc3zCUq0IvVG97l77tP85kLE9m0LYf7Lk/Bx0M/lkqNJP0fpcbUmrRI\n1qRFdt2/dGYkz2YUcLiwlm0nypgR6c+NF8Q7MUKlXI/20SununlJPOX1LWw7UYabRXhp/5kh7dfe\nYeN3W05yukL7+JUazKCJXkTiRWSriGSKyBER+WofdW4RkYOOf9tFZF63bVeKyHERyRKRb430C1AT\n20UpEcSHeDMvPogvXTSVHTkVFNc0D7rfkzvy+M2bJ3h6t06qptRghtKibwfuM8akAkuBu0QkrVed\nU8BFxpi5wI+BTQAiYgUeBq4C0oCb+thXTWJWi/DCl5fztzsXc/3CWIyxT58wkMLqJn6z+TgAH56u\nGoswlZrQBk30xpgiY8w+x+06IBOI7VVnuzGm83/cTiDOcXsxkGWMyTHGtALPAGtHKnjlGsL9PfH3\ncic53I/58UH85YNcmlo7+q3/u7ez6DCG1amRHCyoob1DJ1RTaiDD6qMXkURgAbBrgGp3Aq87bscC\n+d22FdDrS6LbY68XkQwRySgrKxtOWMqFfPuqmZypbuJ/Xz3Kl/+2l/ufO4Axhoe3ZvGdFw9hsxne\nPFrC6tRIrp0XTVNbB8eK65wdtlLj2pBH3YiIH/A8cK8xps9LG0XkEuyJfkVnUR/V+rz+3RizCUeX\nT3p6ul4jP0ktSQ7lkwvjeGrXaSwCNgN+nm48sSMXYyA9MZjy+hYuS41gYUIwYO++mR0b6OTIlRq/\nhtSiFxF37En+KWPMC/3UmQs8Aqw1xlQ4iguA7mPl4oCBO2DVpPfdj6WyflUyb9y7ioUJQTy+PZfo\nAC/cLML3Xz6CReDilAjigr0J8/Pkw9PVzg5ZqXFtKKNuBHgUyDTGbOinTgLwAnCbMeZEt017gOki\nkiQiHsA64JXzD1u5smBfDx64OpXpkf788oZ5zI0L5LfrFrAmLZK65nYWTQkm2NcDEWFhQhB79YSs\nUgMaStfNcuA24JCI7HeUPQAkABhjNgLfB0KBPzjmGm83xqQbY9pF5G7gDcAKPGaMOTKyL0G5smkR\nfrxyt70nsKW9g9cPF3NZ6kcXXK2YHsbmoyUcLawlLSbAWWEqNa7JeJwyNj093WRkZDg7DDXOGGN4\n/XAxF88I75omoaqhlSU/3cKtS6fw/Wt15K6avERkrzEmva9temWsmjBEhKvnRPeYCyfY14PVaRG8\ntP8MH56u4uldpymvb3FilEqNP5ro1YT3qUXxVDa0ct0ftvPAi4e48Gdvs2HzcWy28fdrVSln0EnN\n1IS3cnoYV8+JIjHUlytnR/Ho+6d46O0sTpbW8/DNC7FYdI1aNblpolcTnpvVwh9uWdR1/7e
fns+U\nUF8e2nKSQ2dqmBcf5LzglBoHtOtGuRwR4TPLpiAC757Qq6yV0kSvXFKonyezYwLZdqKMrNJ6/vuf\nB2hu63v+nFPlDWSV1o9xhEqNHU30ymVdlBLOh/nVfP3Z/TybUcB/DhcD9mGaRTVNlNW1sDevkmt/\n9z7rNu2gsbXdyRErNTq0j165rFUp4fx+axYHC2pwtwrP7ysgPsSHL/41g/L6VgAsAhH+XhTXNvPE\n9jy+fPFUJ0et1MjTRK9c1oKEIAK83EgK82Xl9HD+8E4Wp8ob8HSz8uO1s2hpt1FS28ydK5L59gsH\n2fhuNrcsTSDAy93ZoSs1ojTRK5flbrXwzPplhPl70NDSwe+3ZlFQ1cTf7lzCiulhPep+dXUKn3j4\nA147WMS6xQlOilip0aGJXrm0rvlv/OHqOVFEBnidleQB5sUFEhXgxXsnyzXRK5ejiV5NGt3H2vcm\nIqx0TJDWYTNY9SIr5UJ01I1SDqtSwqlpauNgQbWzQ1FqRGmiV8phxbQwRGDbiXJnh6LUiNJEr5RD\nsK8Hc2MD+dfBQmoa25wdjlIjRvvolerm7kun819P7eWTG7dz8+IE6prb+SCrnCXJIXxhVbIOvVQT\nkrbolepmTVokT35uCTVNbfzo30f57ZYT1DS18bu3s7j6wffOunp2Z04FN/xxe7/TKyg1HmiLXqle\nlk0NZfcDl1HR0IpgnzfnneOl3PGXPWddPbv1WCkZeVXsz69maXKo84JWagDaoleqDyJCmJ8noX6e\nAFw8I4JLZ0aw8d1sapo+6r/PLmsAYJ8uUK7GMU30Sg3RfZenUNPUxt93n+4qyymzz3q5L6/aSVEp\nNThN9EoN0ayYQGZG+fNBln34ZWu7jbzKRsDeojdGly5U49OgiV5E4kVkq4hkisgREflqH3VmisgO\nEWkRkft7bcsVkUMisl9EMkYyeKXG2pKkEPbmVdHWYeN0ZQMdNsPixBAqG1rJq2h0dnhK9WkoLfp2\n4D5jTCqwFLhLRNJ61akE7gF+3c9jXGKMmW+MST/3UJVyvsVJoTS2dnDoTE1X//yn0uOAj/rp8ysb\nKaltdlqMSvU2aKI3xhQZY/Y5btcBmUBsrzqlxpg9gF5lolza4qQQAHblVJLt6J+/fFYU/p5u7D5V\niTGGdZt2svo37/L6oSJnhqpUl2H10YtIIrAA2DWM3QywWUT2isj6AR57vYhkiEhGWZmu86nGp3B/\nT6aG+7LrVAXZpQ1EBngS6O3OypQwthwr5fCZWs5UN+HhZuHLT+1jR3aFs0NWauiJXkT8gOeBe40x\ntcN4juXGmIXAVdi7fVb1VckYs8kYk26MSQ8PDx/Gwys1tpYkh7Irp5Id2eVMDfcD4IpZUZTVtfCb\nN48jAq98ZQXRgV784j/HaGnvYGdOhZ6sVU4zpEQvIu7Yk/xTxpgXhvMExphCx99S4EVg8XCDVGo8\n+fyKJKaE+lBY08z0CHuiv3hGBG4W4Z3jZcyPDyI2yJt7V09nf341l/76XdZt2sl2bd0rJxnKqBsB\nHgUyjTEbhvPgIuIrIv6dt4HLgcPnEqhS40VyuB+v3rOSv3z2Au66dBoAgd7uLJtqvzJ2dWokAJ9c\nGEdKpB8NjmkTMouG80NYqZEzlCkQlgO3AYdEZL+j7AEgAcAYs1FEooAMIACwici9QBoQBrxo/67A\nDXjaGPOfkXwBSjmD1SJcMiOiR9nVc6J5P6ucNWn2RO9mtfDPL1+IRYRVv9xKVmm9M0JVavBEb4x5\nHxhwuR1jTDEQ18emWmDeuYWm1MTy6fR4Fk0JJiXSv6usc7bLaeF+muiV0+iVsUqNEItFeiT57qZG\n+JFVVt91QvbxD07xw1eOjGV4ahLTRK/UGJgW4Ud1YxsVDa0YY9i0LYenduWdNb1xY2s7p8obnBSl\nclWa6JUaA52jc7JK6zlaVEthTTNtHYYD+dU96n3vpSN
87KGz571X6nxooldqDExzJPqTpfW8ebQE\ncZz1ysj7aHrj4ppmXt5/hsbWDt47aZ84rcOmY+/V+dOFR5QaA9GBXvh6WMkurScjr5KFCcHUNLWR\nkVvZVefx7bnYjMHXw8pbR0vIr2xk47s5/PNLy0gM83Vi9Gqi0xa9UmNARJga4cc/9uRz+Ewtq1Mj\nuSAxmIy8Kmw2Q0t7B0/vyuPK2VGsTovkzcwSNrx5gvL6Fu56ep8uVajOiyZ6pcbIVy6dzpWzo7hp\ncTw3LIojfUoIdc3tnCit43hxHbXN7VwzN4bVqZFUN7bR2m7jux9L5UhhLd996bBOoaDOmXbdKDVG\n1qRFdl1MBZCeGAxARm4VbhZ7p/2smACCfT3w8bByY3o8n1+ZTG1zOw9tOUlskDdfW5PilNjVxKaJ\nXiknSQjxIdzfk4zcSgK83fHzdCM+2AeLRXj7vosJ8/MA4Gurp1NQ2ciDW07y6QviiQnydnLkaqLR\nrhulnEREuCAxmD25VRwprCU12h+Lo2UfFeiFm9XSVe/OlUkA7MzRidHU8GmiV8qJ0qeEcKa6iYMF\n1cyKCey33syoAAK87IubKDVcmuiVcqLOfvq2DkNadEC/9awWYXFSCLsGSPQ1jW387LVMapt1oTfV\nkyZ6pZwoLToAHw+r/XZM/4keYElSKKfKGyjtZz3aX28+zp+25fDyh2dGPE41sWmiV8qJ3KwWFiQE\n4WaRrqtn+7Mk2b5e7U5Hqz6zqJb8ykYATpTU8dSuPAA2Hy0ZxYjVRKSJXikn+/yKZO65bDpe7tYB\n66VFB+Dn6cbuU/YTsuv/msFX/v4hAD9//Rh+nm7cmB7HjuwKapraaO+w9djfZjPYdEqFSUkTvVJO\ndsnMCO65bPqg9dysFubGBXIgv4aK+hbyK5vYn1/Ny/vP8PaxUr6wMplPX5BAu81w37MHSPvBG+zt\nNpfOb986wbW/f380X4oapzTRKzWBzIsP4lhxbY8E/o1/HsTL3cKtS6ewID6IMD9P3sosobXdxr5u\n9T7IriCzqPaslr5yfZrolZpA5sUF0tZheDajAIDLZkbQ2m7jhkVxBPt6YLEI37xyBvdcNp0gH3dy\nHHPb22yG48V12AyU1rU48yUoJ9ArY5WaQObGBQHw9rESEkN9uHd1CjnlDXxhZXJXnU+lxwPw3sky\nch2JvqCqifoW+xz3RTVNenXtJKMteqUmkOhAL8L8PLEZmB0byJy4QLbefzFTQs+exjgp1JfcCnui\nzyyu7Sovqul7eKZyXYMmehGJF5GtIpIpIkdE5Kt91JkpIjtEpEVE7u+17UoROS4iWSLyrZEMXqnJ\nRkSYH2+/gnZObP9X0gIkhflSVNNMU2sHmUXdEn21JvrJZigt+nbgPmNMKrAUuEtE0nrVqQTuAX7d\nvVBErMDDwFVAGnBTH/sqpYahs/tmTtzAib5zsZLcigaOFdWRFOaLj4dVW/ST0KB99MaYIqDIcbtO\nRDKBWOBotzqlQKmIfKzX7ouBLGNMDoCIPAOs7b6vUmp41s6PIbe8gYUJwQPWS+pM9OUNZBbXMism\nABF7H72aXIbVRy8iicACYNcQd4kF8rvdL3CUKaXO0ZRQXzZ8ev6gF1h1tuj351eTV9FIalQAMYHe\n2qKfhIac6EXED3geuNcYUztY/c7d+ijr89I8EVkvIhkiklFWVjbUsJRS/fDzdCPC35NH3j+F1SJc\nMjOC6EAvbdFPQkNK9CLijj3JP2WMeWEYj18AxHe7HwcU9lXRGLPJGJNujEkPDw8fxlMopfqTGOZL\nh81w3+UpzI4NJDrQi9K6Fk6VN/DGkeIedXPLG7qGYCrXMpRRNwI8CmQaYzYM8/H3ANNFJElEPIB1\nwCvDD1MpdS6uXxDLzUsS+NKqqQBEB3ljDHzhyQy++Ne9HD5TA0BZXQtXPriNDZtPODNcNUqGcsHU\ncuA24JCI7HeUPQA
kABhjNopIFJABBAA2EbkXSDPG1IrI3cAbgBV4zBhzZGRfglKqP+sWJ7Cu2/2o\nQC8AskrrAfjVG8d54nOLeXz7KZrbbOzN04VNXNFQRt28T9997d3rFGPvlulr22vAa+cUnVJqRMUE\n2q+IDfPz4PZliWx48wSPvJfDkzvysAgcLaqlua1j0BO9amLRK2OVmkTigr3x93TjK5dOZ/2qZObE\nBvK/r2ZS19zOly+eSluH4WjRUMdaqIlC57pRahLx9XRjz3dXd7XYX7l7OXvzqiira2FBQjAPb83m\nQH71oGP01cSiiV6pSaZ7t4yIkJ4Y0nU/MsCT/fnVTohKjSbtulFKdZkfH9SV6I0xlNcPbUrjDpuh\nTee5H7c00SuluixMCCavopEPT1fx4JaTLPnplq51aTsZY9ieVY4xH137+ON/H+XmP+8c63DVEGmi\nV0p1WXdBAvEh3nz+iQwe3HKSDpthZ05Fjzp7cqu4+ZFdbM/+qDwjr5K9eVU06AVX45ImeqVUl0Af\nd/54yyLqWtqZHuFHkI87GblV7M+vZtnPtlBS29zVwj9ZUgfYW/g5ZQ3YDBwp1BE745GejFVK9TA7\nNpDX7llBqK8n3/jnAfbkVmIzhqKaZo4W1VJca58ULbfCnvCLa5tpbO0A4GBBNYuTQvp9bOUcmuiV\nUmeZFuEPQHpiCG9llnYl98LqJkq7Er199ars0oau/Q4W1IxxpGootOtGKdWvCxxDLztb7IXVTV1J\nP8/Ros8pt0+nMD8+iENnNNGPR5rolVL9mh0bgKebhXB/T6ICvCiqbqa41j7kMr+ykfYOG9ml9fh5\nurE6NYJT5Q3UNLU5OWrVmyZ6pVS/PN2srF+VzNdWpxAf4s2Z6iZKaprxcLPQbjOcqW4iu6yBqeG+\nXUscHtZW/bijiV4pNaD7Lp/BzUsSiA70pqCqibL6FhbEBwFwqryB7LJ6ksP9mBcfhJtF2HZSFw4a\nbzTRK6WGJCbI3qLvsBmWJIcC9tkui2qamRruS6C3OxdOC+P1Q8U9LqZSzqeJXik1JDFBXl23Z8cE\n4Oth5cnteQDMiAoA4GNzojhd2ajj6ccZTfRKqSGJdsxlD/YFTKaE+lJc28za+TFcOjMCgDVpUVgt\nwmuHigD7RVXP7sl3SrzqI5rolVJD0r1FHxXgxfULY7llSQK/+dQ8rBb72kQhvh5cODWU1w/b16N9\n6O0svvnCQZocwzOVc2iiV0oNSefqVFaLEOrnyedXJvOT6+bgZu2ZRlanRnKqvIHc8gZ2ZFdgDGSX\n1TsjZOWgiV4pNSRBPu54u1sJ9/PsasH35aKUcAAe++BU1zTHJ0vrePVgEVc9+N5ZE5/lVzbynRcP\nUdnQOnrBT3Ka6JVSQyIiRAd5ERnoNWC9xDBfEkN9eHrX6a6ykyX1vHaoiMyiWv7Rq8/+nRNlPLXr\nNLc8sosqTfajQhO9UmrIbl86hZsuiB+03kUp4bTbDDGBXkyP8ONEST17cisBePT9Uz0WKSmtbcYi\n9u6dWx/dRXXj2cneZjP8Y89pmtu0r/9cDJroRSReRLaKSKaIHBGRr/ZRR0TkIRHJEpGDIrKw27Zc\nETkkIvtFJGOkX4BSauzcsTyJdYsTBq130Qx7983SqaGkRPqzM6eC0roWLpkRzpnqJv59sLCrbnFN\nMxH+Xmy6bREnS+q57dHdZyX03bmVfPP5Q7xxpHhkX9AkMZQWfTtwnzEmFVgK3CUiab3qXAVMd/xb\nD/yx1/ZLjDHzjTHp5xuwUmr8W5Ycxvz4IK5bEMu0CD/qHf3y918xgzA/T94/+dGiJSV1LUQGeHLx\njAh+ccMcDp2pYfepyh6Pd9QxLr+wunnsXoQLGTTRG2OKjDH7HLfrgEwgtle1tcCTxm4nECQi0SMe\nrVJqQvD2sPLSXctZOT2clEj7lMf+nm7MjApgWoRv14yXYO+6iQiw9/tfnGIfj3+su
OcFV5lF9vtF\nNU1jEb7LGVYfvYgkAguAXb02xQLdz7AU8NGXgQE2i8heEVk/wGOvF5EMEckoK9O5MpRyFdMj/QBY\nMCUYq0VIDvcjp6yha5qEktpmohyJPtjXg6gAL44V1fV4jKNF2qI/H0NO9CLiBzwP3GuM6X19c19j\nrTonu1hujFmIvXvnLhFZ1dfjG2M2GWPSjTHp4eHhQw1LKTXOJYb6EurrwSWOfvvkMF9qmtqobGil\npb2DqsY2IgM8u+rPjPbvSuwAbR02TpbYfwEUVmuL/lwMaYUpEXHHnuSfMsa80EeVAqD7qfg4oBDA\nGNP5t1REXgQWA9vOJ2il1MTh4Wbh/W9eiqebvV05Ndzewj9V3kCkoyXf2XUDkBodwAdZ5dQ0tvFs\nRj7zE4Jo7bAR6O2uXTfnaNBELyICPApkGmM29FPtFeBuEXkGWALUGGOKRMQXsBhj6hy3Lwd+NEKx\nK6UmCG8Pa9ft5HBfAHLKPlqCMLJbop8Z5U9bh+H7rxzm5f2FXd06l8wI56X9hTS1dvR4PDW4obTo\nlwO3AYdEZL+j7AEgAcAYsxF4DbgayAIagc866kUCL9q/K3ADnjbG/GekgldKTTyxQd64W4Xs8np8\nPe0pqHvXTWq0fSbMl/cX4u1upbi2GU83C8unhfHS/kKKappIdvwqUEMzaKI3xrxP333w3esY4K4+\nynOAeeccnVLK5bhZLUwJ9SWnrIFIf3trvfMvQFKYLx5WC60dNh5cN5//e+sk/p5uxIf4AFBU06yJ\nfpiG1EevlFIjKTnM17EylS8ebhaCfNy7trlbLaRG+9PQ2sHq1EiWTQ3FZqNrLVo9ITt8muiVUmMu\nOdyPrcdLKahsIjLAE0f3bpff3bQQq1WwWAR/L/uXgJeH/WTugYJqntiRS2F1M3NiA3n8sxectb/q\nSee6UUqNuQunhtLWYXjtcFGPbptOCaE+xAZ59yjzdLMS5ufJ07tOc7y4joUJQbx7oowdORVn7a96\n0kSvlBpzq1LC+dJFUzGm54ibwcQEeWEzcMuSKfz+5oWE+Hrw+Ae5oxeoi9BEr5Ryim9cMYPPr0ji\nEwt6z6jSv/gQH/w83fjKpdPwcrdy0+J43sosIb+y8ay6eRUNXXPfF9U0Udds7+M/UlhDXkXDWfVd\nmSZ6pZRTWC3Cd69JY01a5JD3eeDqVJ794jJC/ezDMW9dOgUR4a8783rU67AZrv3d+/z89WN02Axr\nf/8BP3j5CDab4Y6/7OHmP++i1pH4JwNN9EqpCSM2yJu0mICu+9GB3lw5K4pndp+msfWjlavyKhqo\nbW5n89FiMnIrKa1rYfPREvbkVlJW18KZ6iZ+/K+jw3ru2ua2rvl5AKobW3nv5MSYl0sTvVJqQrtj\neSK1ze289OFHc9wfK7ZPilZS28Jv3zoJQH1LOz99LROrRbht6RSe21vA8eK6Ph+zt/qWdpb9dAtP\nbM/tKrv/uYN85rHdXVMwj2ea6JVSE1r6lGDSogPYtC2b0xX2vvpjRbVYBCwCO3IquCAxGB8PKwcK\narggMZg7licCcOhMzZCeI7u0nobWDh794BQdNsO2E2W8lVmCzcCZqvE/rl8TvVJqQhMRvn31TMrr\nW7nit9t4/2Q5mcV1JIf7sWhKMADXzI3pWrR8dWokU0J88HCzcKJkaC36U+X2k7f5lU08v6+AH/7r\nCF7u9vR5pvrsE8HjjSZ6pdSEt3J6OJu/tooQXw9+v/Ukx4prmRnlz1Wzo3GzCJelRrB2fiwebhau\nmBWFm9XC9Ai/ri6eweSUN2ARCPf35L//eZDC6iZ+dv0cAAq0Ra+UUmMjJsibmxbHszOnkvzKJlKj\nA7h92RTe+vpFxAX7cOXsKD783pquOXNmRPpzvNdKVrXNbdzxl91klfb8Asgpqycu2Ie7Lp5KYqgP\nz35xGWvnxeJhtWjXjVJKjaVPLoqjczaEGZH+u
FktJIb5dm3vnC0TYEaUPyW1LVQ3tnaV7cur4p3j\nZfyf4wRup1PlDSSF+XLH8iTe+cYlzI0LwmIRYoK8KJgAc+9ooldKuYzoQG9WTrf3xc+M9h+wbkqU\nfXv3kTfZjjnyXz9UxLMZ+az65VY+yCrnVHlD1zz63cUGe2uLXimlxtrX16Tw2eWJZ82V09tMR6Lv\nfkI2u6weXw8rbhYL//3Pg5yubOTnrx+jsbWD5LCzE31ckA9nJkCLXmevVEq5lPnxQcyPDxq0XlSA\nF/5ebhztthB5dmk9M6MDWJYcyu7cSpLDfHlmTz4ASWFnz4EfG+xNWV0LzW0deLlb+cmrR5kbF8S1\n82JG7PWMBG3RK6UmJRFhSVIIz+w5zQ9ePkxbh43ssgamhvty/xUzePaLy/j8yqSu+kl9dd04fjUU\n1TRjsxme2JHHKwcKz6rnbNqiV0pNWv/36fn88j/HeWJHHolhvpTXt3QtXg4wLcKfeXGBHCuuI7qP\nWTZjg+2JvqCqER8PK63ttnHZZ6+JXik1afl7ufOjtbN4+1gpf3wnG6BHogf4/rVpnCypx2I5e3GT\nOEeiP1PVhLe7fcHygqrxdwGVdt0opSY1EeHaeTGU1rUAMDWiZ6JfNCWEdYsT+tw3KsALq0UoqGoi\n35Hga5vbx93MmJrolVKT3rXzogFwtwrxwQOP1unOvtC5DydK6siv/KjLZrx13wya6EUkXkS2ikim\niBwRka/2UUdE5CERyRKRgyKysNu2K0XkuGPbt0b6BSil1PlKiw5gargvSWG+uFmH1/6dHRPIkcLa\nHoufjLdEP5Q++nbgPmPMPhHxB/aKyJvGmO6TOV8FTHf8WwL8EVgiIlbgYWANUADsEZFXeu2rlFJO\nJSI8uG4BLe22Ye87KyaAVw4UcuhMDYmhPuRWNFJQ1UhOWT0iQlIf4+87Nbd10NJuI9Db/XzCH9Sg\nX13GmCJjzD7H7TogE+i99tda4EljtxMIEpFoYDGQZYzJMca0As846iql1LgyOzawa7bL4e4H9jnw\n58UH4eVuIb+qiVse2cVlv3mHb79wkA6b6XPfH/37KDdu3HFecQ/FsH6jiEgisADY1WtTLJDf7X6B\no6y/8r4ee72IZIhIRlnZxFi1RSmlZnVb8SohxIfYIG/eOFJMUU0zc2ID+fvufHafquxz36ySeo6X\n1I36GrZDTvQi4gc8D9xrjKntvbmPXcwA5WcXGrPJGJNujEkPDw8falhKKeVUQT4eXRdOxQV7Exvs\nQ0FVE1aL8FPHVMY55fV97ltS1wzAeyfLRzXGISV6EXHHnuSfMsa80EeVAiC+2/04oHCAcqWUchmd\nrfr4YJ+usfUXJAaTGhWAl7uFU2Vnt9iNMZTUdib60e3FGMqoGwEeBTKNMRv6qfYKcLtj9M1SoMYY\nUwTsAaaLSJKIeADrHHWVUspldPbTx4d8lOjXpEVhsQiJob7klJ+d6Gub2mlus+FuFbZnVdDeMfwT\nwUM1lFE3y4HbgEMist9R9gCQAGCM2Qi8BlwNZAGNwGcd29pF5G7gDcAKPGaMOTKSL0AppZzt5iUJ\nhPp5EBfszby4IHw8rFw5OwqA5HBfMovOXsmqs9tmdWokrx8u5kBBNYumhIxKfIMmemPM+/Td1969\njgHu6mfba9i/CJRSyiWF+Xlyy5IpACyfFsahH16B1TFlQnKYH28cKaG13YaH20edKJ3dNtctiOWN\nI8VsO1E+aoler4xVSqkRZu02L05SmC8dNtM1RUKn4hp7op8R5c/cuCDezxq9E7Ka6JVSahR1Tm/c\n+4Rs59w6kQFerJoexv78amqaRmeOHE30Sik1ijpXpuo9xLKktplAb3e83K2sTAmnw2bYkV0xKjFo\noldKqVEU5ONBiK8Hp3qNvCmuaSYywBOwr4rl5+k2asMsNdErpdQomxUTwKsHizh8pqarrKSuhUjH\nYibuVgvLp
oaO2oVTmuiVUmqU/fS6Ofh7uXPro7u6TsKW1jZ3JXqAmxbHc8eFif3Oi3M+NNErpdQo\niw/xYdPti6hubGPr8VI6bIbSupaurhuAS2dG8rkVST1G7IwUTfRKKTUG0qIDCPZx58PTVVQ0tNBh\nMz1a9KNJE71SSo0BEWFBQjAfnq7mZIl9BE7cMFazOh+a6JVSaowsiA/iZGk9f999Gh8PK8uSw8bk\neTXRK6XUGFmQYF/Y5N8Hi7gsNRJvD+uYPK8meqWUGiPz4gMRx7nWa+ZGj9nzaqJXSqkx4u/lTkqE\nP/6eblyUMnYLLA1lmmKllFIj5OuXp9DQ0o6X+9h024AmeqWUGlNXzIoa8+fUrhullHJxmuiVUsrF\naaJXSikXp4leKaVcnCZ6pZRycZrolVLKxWmiV0opF6eJXimlXJwYM/KrmZwvESkD8s5x9zBgdNbj\nOj8a1/CN19g0ruHRuIbvXGKbYozpc16FcZnoz4eIZBhj0p0dR28a1/CN19g0ruHRuIZvpGPTrhul\nlHJxmuiVUsrFuWKi3+TsAPqhcQ3feI1N4xoejWv4RjQ2l+ujV0op1ZMrtuiVUkp1o4leKaVcnMsk\nehG5UkSOi0iWiHzLiXHEi8hWEckUkSMi8lVH+Q9F5IyI7Hf8u9pJ8eWKyCFHDBmOshAReVNETjr+\nBo9xTDO6HZf9IlIrIvc645iJyGMiUioih7uV9Xt8ROTbjs/ccRG5wgmx/UpEjonIQRF5UUSCHOWJ\nItLU7dhtHOO4+n3vxuqY9RPXP7rFlCsi+x3lY3m8+ssRo/c5M8ZM+H+AFcgGkgEP4ACQ5qRYooGF\njtv+wAkgDfghcP84OFa5QFivsl8C33Lc/hbwCye/l8XAFGccM2AVsBA4PNjxcbyvBwBPIMnxGbSO\ncWyXA26O27/oFlti93pOOGZ9vndjecz6iqvX9t8A33fC8eovR4za58xVWvSLgSxjTI4xphV4Bljr\njECMMUXGmH2O23VAJhDrjFiGYS3whOP2E8AnnBcKlwHZxphzvTL6vBhjtgGVvYr7Oz5rgWeMMS3G\nmFNAFvbP4pjFZozZbIxpd9zdCcSN1vMPJ64BjNkxGyguERHgRuDvo/HcAxkgR4za58xVEn0skN/t\nfgHjILmKSCKwANjlKLrb8RP7sbHuHunGAJtFZK+IrHeURRpjisD+IQQinBQbwDp6/ucbD8esv+Mz\n3j53nwNe73Y/SUQ+FJF3RWSlE+Lp670bL8dsJVBijDnZrWzMj1evHDFqnzNXSfTSR5lTx42KiB/w\nPHCvMaYW+CMwFZgPFGH/2egMy40xC4GrgLtEZJWT4jiLiHgAHweecxSNl2PWn3HzuROR7wDtwFOO\noiIgwRizAPg68LSIBIxhSP29d+PlmN1EzwbFmB+vPnJEv1X7KBvWMXOVRF8AxHe7HwcUOikWRMQd\n+xv4lDHmBQBjTIkxpsMYYwP+zCj+xB+IMabQ8bcUeNERR4mIRDtijwZKnREb9i+ffcaYEkeM4+KY\n0f/xGRefOxH5DHANcItxdOo6fuZXOG7vxd6vmzJWMQ3w3jn9mImIG3A98I/OsrE+Xn3lCEbxc+Yq\niX4PMF1EkhytwnXAK84IxNH39yiQaYzZ0K08ulu164DDvfcdg9h8RcS/8zb2E3mHsR+rzziqfQZ4\neaxjc+jRyhoPx8yhv+PzCrBORDxFJAmYDuwey8BE5Ergm8DHjTGN3crDRcTquJ3siC1nDOPq771z\n+jEDVgPHjDEFnQVjebz6yxGM5udsLM4yj9GZ7Kuxn73OBr7jxDhWYP9ZdRDY7/h3NfBX4JCj/BUg\n2gmxJWM/e38AONJ5nIBQYAtw0vE3xAmx+QAVQGC3sjE/Zti/aIqANuwtqTsHOj7AdxyfuePAVU6I\nLQt7/23nZ22jo+4nHe/xAWAfcO0Yx9XvezdWx6yvuBzljwNf6lV3LI9Xfzl
i1D5nOgWCUkq5OFfp\nulFKKdUPTfRKKeXiNNErpZSL00SvlFIuThO9Ukq5OE30Sinl4jTRK6WUi/t/HvKahvKzbLIAAAAA\nSUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(meta_losses)" ] }, { "cell_type": "markdown", "metadata": { "id": "W9Ao4iSNnl16" }, "source": [ "## Evolution Strategies (ES): Meta-training without meta-gradients\n", "Computing gradients through long optimization procedures can sometimes lead to chaotic dynamics and result in exploding gradients. See [https://arxiv.org/abs/1810.10180](https://arxiv.org/abs/1810.10180) and [https://arxiv.org/abs/2111.05803](https://arxiv.org/abs/2111.05803) for more details.\n", "\n", "An alternative is to use black-box optimization techniques. One method we found to work well is evolution strategies with antithetic sampling. This estimator can be thought of as a randomized finite difference: we sample a random direction in the meta-parameters, compute the meta-loss after shifting the meta-parameters along this direction and along its negation, and move in whichever direction lowers the loss. The estimator can be written as:\n", "\n", "$$\\nabla_\\theta L \\approx \\mathbb{E}_{\\epsilon \\sim \\mathcal{N}(0, \\sigma^2 I)} \\left[ \\dfrac{\\epsilon}{2 \\sigma^2} \\left(L(\\theta + \\epsilon) - L(\\theta - \\epsilon)\\right) \\right]$$\n", "\n", "where $L$ is the meta-loss and $\\sigma$ is the standard deviation of the perturbation (`std` in the code below).\n", "\n", "As before, we construct a vectorized version of this estimator to average over many different random directions." 
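, "\n",
 "\n",
 "As a quick standalone sanity check (a sketch, not part of the learned optimizer pipeline), we can apply the same antithetic estimator to a toy quadratic $L(\\theta) = \\frac{1}{2}\\|\\theta\\|^2$, whose true gradient is simply $\\theta$. The function names `toy_loss` and `toy_antithetic_es`, as well as the choices of $\\sigma$ and the number of sampled directions, are illustrative assumptions. With enough samples, the ES estimate should be close to the analytic gradient computed by `jax.grad`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "esToyCheck01" }, "outputs": [], "source": [ "import jax\n",
 "import jax.numpy as jnp\n",
 "\n",
 "\n",
 "def toy_loss(theta):\n",
 "  # Hypothetical toy objective: L(theta) = 0.5 * ||theta||^2, so grad L = theta.\n",
 "  return 0.5 * jnp.sum(theta**2)\n",
 "\n",
 "\n",
 "def toy_antithetic_es(theta, key, std=0.01, n_samples=10000):\n",
 "  # Sample n_samples perturbations eps ~ N(0, std^2 I), one per row.\n",
 "  eps = jax.random.normal(key, (n_samples,) + theta.shape) * std\n",
 "  # Evaluate the loss at theta + eps and theta - eps for every direction.\n",
 "  pos = jax.vmap(lambda e: toy_loss(theta + e))(eps)\n",
 "  neg = jax.vmap(lambda e: toy_loss(theta - e))(eps)\n",
 "  # Antithetic estimate: mean of eps / (2 std^2) * (L(theta + eps) - L(theta - eps)).\n",
 "  factors = (pos - neg) / (2 * std**2)\n",
 "  return jnp.mean(factors[:, None] * eps, axis=0)\n",
 "\n",
 "\n",
 "theta = jnp.array([1.0, -2.0, 0.5])\n",
 "es_grad = toy_antithetic_es(theta, jax.random.PRNGKey(0))\n",
 "# The ES estimate should be close to the exact gradient printed below it.\n",
 "print(es_grad)\n",
 "print(jax.grad(toy_loss)(theta))"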
] }, { "cell_type": "code", "execution_count": 20, "metadata": { "executionInfo": { "elapsed": 2115, "status": "ok", "timestamp": 1647716654716, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "KHd6Sf0mu_CZ" }, "outputs": [], "source": [ "def antithetic_es_estimate(meta_params, key, seq_of_batches):\n", "  \"\"\"Compute an ES-estimated gradient along a single random direction.\"\"\"\n", "  std = 0.001\n", "  keys = jax.random.split(key, len(meta_params))\n", "  # Sample one Gaussian perturbation per meta-parameter array, scaled by std.\n", "  noise = [\n", "      jax.random.normal(keys[i], p.shape) * std\n", "      for i, p in enumerate(meta_params)\n", "  ]\n", "  meta_params_pos = [p + n for p, n in zip(meta_params, noise)]\n", "  meta_params_neg = [p - n for p, n in zip(meta_params, noise)]\n", "\n", "  # Evaluate both perturbations with the same key and batches so the loss\n", "  # difference reflects only the perturbation.\n", "  pos_loss = meta_loss(meta_params_pos, key, seq_of_batches)\n", "  neg_loss = meta_loss(meta_params_neg, key, seq_of_batches)\n", "\n", "  # Antithetic estimate: noise * (L(theta + noise) - L(theta - noise)) / (2 * std^2).\n", "  factor = (pos_loss - neg_loss) / (2 * std**2)\n", "  es_grads = [factor * n for n in noise]\n", "  return (pos_loss + neg_loss) / 2.0, es_grads\n", "\n", "\n", "@jax.jit\n", "def vec_antithetic_es_estimate(meta_params, keys, vec_seq_batches):\n", "  \"\"\"Average ES-estimated gradients over multiple random directions.\"\"\"\n", "  losses, grads = jax.vmap(\n", "      antithetic_es_estimate, in_axes=(None, 0, 0))(meta_params, keys,\n", "                                                     vec_seq_batches)\n", "  return jnp.mean(losses), [jnp.mean(g, axis=0) for g in grads]\n", "\n", "\n", "keys = jax.random.split(key, 8)\n", "vec_seq_batch = get_vec_batch_seq(8, 10)\n", "loss, es_grads = vec_antithetic_es_estimate(meta_params, keys, vec_seq_batch)" ] }, { "cell_type": "markdown", "metadata": { "id": "nXqXfO1QquBY" }, "source": [ "We can now use a similar meta-training procedure as before, with this ES estimator in place of backpropagated meta-gradients." 
] }, { "cell_type": "code", "execution_count": 21, "metadata": { "executionInfo": { "elapsed": 18328, "status": "ok", "timestamp": 1647716673161, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "QAte_Dm3wg-9", "outputId": "cefff89b-82db-406f-f961-5041d1dfe8d1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.3011694\n", "2.2831185\n", "2.2696424\n", "2.256878\n", "2.246539\n", "2.2384717\n", "2.229327\n", "2.2206445\n", "2.213926\n", "2.2045903\n" ] } ], "source": [ "meta_opt = Adam(0.003)\n", "key = jax.random.PRNGKey(0)\n", "meta_params = lopt.init_meta_params(key)\n", "meta_opt_state = meta_opt.init(meta_params)\n", "meta_losses = []\n", "n_particles = 32\n", "\n", "num_steps = int(os.environ.get(\"LOPT_TRAIN_LENGTH\", 200))\n", "for i in range(num_steps):\n", " data = get_vec_batch_seq(n_particles, 10)\n", " key1, key = jax.random.split(key)\n", " keys = jax.random.split(key1, n_particles)\n", " loss, meta_grad = vec_antithetic_es_estimate(meta_opt_state[0], keys, data)\n", " meta_losses.append(loss)\n", " if i % 20 == 0:\n", " print(onp.mean(meta_losses[-20:]))\n", "\n", " meta_opt_state = meta_opt.update(meta_opt_state, meta_grad)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "colab": { "height": 282 }, "executionInfo": { "elapsed": 137, "status": "ok", "timestamp": 1647716673416, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "X_fqy1t-xzJ_", "outputId": "dcbdd3d6-77d1-44fc-d806-9b15e2eedca1" }, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90\nbGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsT\nAAALEwEAmpwYAAAyZUlEQVR4nO3dd3hc1bX38e+aUe/dVm+We7dwN8bG9GJKQhwMCQHChUCCCbmB\nG1Ivyb2QQsKbBLgOJCRA6L1jwMY2rnIvclGxrGr1bpXR7PePGcmSkSzJljQq6/M8ejw6Z89ozZnx\nT1t79tlHjDEopZQaviyuLkAppVT/0qBXSqlhToNeKaWGOQ16pZQa5jTolVJqmHNzdQGdCQsLMwkJ\nCa4uQymlhowdO3aUGmPCO9s3KIM+ISGBtLQ0V5ehlFJDhojkdLVPh26UUmqY06BXSqlhToNeKaWG\nOQ16pZQa5jTolVJqmNOgV0qpYU6DXimlhrlug15EYkVkrYiki8gBEbm3kzbLRWSviOwWkTQRWdhu\n36UiclhEMkTkwb5+Aq0amltYvT6TLzNK++tHKKXUkNSTHr0NuN8YMwGYC9wtIhNPa/MZMM0YMx24\nFXgaQESswF+By4CJwDc7uW+fcLdaWL0+mxe2dnnOgFJKjUjdBr0xptAYs9N5uwZIB6JPa1NrTl3B\nxBdovT0byDDGZBljmoCXgOV9VXx7Votw+ZTRfH6omLpGW3/8CKWUGpJ6NUYvIgnADGBrJ/uuFZFD\nwPs4evXg+IWQ265ZHqf9kmh3/zucwz5pJSUlvSmrzRVTImlotvP5oeKzur9SSg1HPQ56EfEDXgdW\nGWOqT99vjHnTGDMeuAZ4uPVunTxUp9cuNMasNsakGmNSw8M7XZenW6kJIUT4e/L+3sKzur9SSg1H\nPQp6EXHHEfIvGGPeOFNbY8x6IFlEwnD04GPb7Y4BCs6y1m45hm8iWXu4mKKqhv76MUopNaT0ZNaN\nAM8A6caYx7poM8bZDhGZCXgAZcB2IEVEEkXEA1gBvNNXxXfmlvkJWES4/9Xd2O164XOllOpJj34B\ncDOw1Dl9creIXC4id4rInc421wP7RWQ3jlk23zAONuAe4GMcH+K+Yow50PdP45SEMF9+cdVEvswo\n43mdgaOUUsipyTKDR2pqqjmX9eiNMVzz1y8BePuehd20VkqpoU9EdhhjUjvbNyzPjBURLpwwir35\nVZTVNrq6HKWUcqlhGfQAF4wLxxhYf/TspmoqpdRwMWyDfnJUIGF+Hqw7rEGvlBrZhm3QWyzC+Snh\nrD9SQovOvlFKjWDDNugBloyPoKK+ma3ZZa4uRSmlXGZYB/2yCaPw93TjtR15ri5FKaVcZlgHvbeH\nlSunRfLhviJqdaEzpdQINayDHuBrs2I52dzCB/t0/Rul1Mg07IN+ZlwQCaE+utCZUmrEGvZBLyKc\nPzac7cfKaW6xu7ocpZQacMM+6AHmJYVS39TC3rxKV5eilFIDbkQE/dykUAA2Z+o0S6XUyDMigj7Y\n14MJkQFs0qBXSo1AIyLoAeYnh7Ijp4KG5hZXl6KUUgNqxAT9gjGhNNrsbMsud3UpSik1oEZM0M9P\nDsPHw8onB4tcXYpSSg2oERP0Xu5WFo8N55MDJ/QSg0qpEWXEBD3AJZNGU1zTyG6dZqmUGkFGVNAv\nGR+Bm0X4+IAO3yilRo4RFfSB3u4sTAnj3d0Fuka9UmrEGFFBD3BDaiwFVQ1s0EsMKqVGiBEX9Msm\njCLE14OXt+e6uhSllBoQIy7oPdwsXDcjmjUHT1Ba2+jqcpRSqt+NuKAHWDE7Dpvd8NK2464uRSml\n+t2IDPoxEX4sSgnjuS05NNl06WKl1PA2IoMe4NaFiZyobuTD/XpBEqXU8DZig35xSjhJ4b78bUMW\nxuhUS6XU8NVt0ItIrIisFZF0ETkgIvd20maliOx1fm0SkWnt9
t3nvN9+EXlRRLz6+kmcDYtFuGtx\nMvvzq/k0vRiA0tpGlj32BVuzdDljpdTw0ZMevQ243xgzAZgL3C0iE09rkw0sNsZMBR4GVgOISDTw\nAyDVGDMZsAIr+qr4c3XtjGgSQn14bM0R7HbDW7vyySiu5V+bc1xdmlJK9Zlug94YU2iM2em8XQOk\nA9GntdlkjKlwfrsFiGm32w3wFhE3wAco6IvC+4Kb1cK9y1JIL6zmzV35vL4zH4BP009Q3dDcoW1p\nbaMuhqaUGpJ6NUYvIgnADGDrGZrdBnwIYIzJB34PHAcKgSpjzCddPPYdIpImImklJQN31urV06KZ\nGRfEz97eT3phNdfOiKbRZuej/afWw/lofxFz/+cznlqfOWB1KaVUX+lx0IuIH/A6sMoYU91FmyU4\ngv4B5/fBwHIgEYgCfEXkps7ua4xZbYxJNcakhoeH9+5ZnAOrRfjd16fRYje4W4WfXTmR+FAfHv/0\nKHc9v4M7/pXGPf/eic1u+OTAiQGrSyml+opbTxqJiDuOkH/BGPNGF22mAk8DlxljWj/NXAZkG2NK\nnG3eAOYDz59r4X0pOdyP3399GiU1jYT4enD3kjGsXp9FZkktgnDRxFFEBXnzjy+zqaxvIsjHw9Ul\nK6VUj3Ub9CIiwDNAujHmsS7axAFvADcbY46023UcmCsiPsBJ4EIg7Zyr7gdXTYtqu31Daiw3pMZ2\n2L8jp4JnNmazMaOUK6dGnX53pZQatHrSo18A3AzsE5Hdzm0/AeIAjDFPAT8HQoEnHL8XsDmHYbaK\nyGvAThyzd3bhnJEz1EyLCSTAy423dxfwpTPsF4wJc3VZSinVrW6D3hizEZBu2twO3N7Fvl8Avzir\n6gYRN6uFhSlhfLDP8SHtgYJq3rlnoYurUkqp7vVojF453DI/EWMgMtCbv3+ZzYGCKirqmokM8iI5\n3I+0Y+WE+3sSH+rr6lKVUqqNBn0vzE4MYXZiCJX1TTy/NYf7X9nDoaIafDysXDElkld35DE7IYRX\n7pzn6lKVUqrNiF3r5lwE+Xhw+eTRHCqqYeGYMMZE+PHqjjyiAr1IyynXde6VUoOK9ujP0g8vGkds\niA/fu2AMIrAvvwpvdytX/nkjn6Wf4Bvnxbm6RKWUArRHf9biQn24/+JxeHtY8XK3cl5CCJOiAogO\n8tYTq5RSg4oGfR8SES6ZNJoNGaXc/cJOntt8DGMMf12bwc/e2k9+5UlXl6iUGoF06KaPfW1WDB8f\nKGLX8Qre31fIazvz2ZNbiQi8vD2XFbNjuWfJGCICBsVqzUqpEUB79H1sYlQAXz64lI0PLGXlnDj2\n5FZyy/wENj6wlOtnxfDvrce59olNnGxqcXWpSqkRQgbj1ZVSU1NNWtqgXCmhV4wxZJXWkRTmi/OM\nYTZnlvHNv23hB0vH8MOLx7m4QqXUcCEiO4wxqZ3t0x59PxIRksP92kIeYF5yKMunR/HU+iyOl9UD\n6KUMlVL9SoPeBf7rsgm4WYT/fu8gGcW1zPvfz3ltR56ry1JKDVMa9C4wOtCL7y9N4dP0E6x8egtF\n1Q088mE6dY02V5emlBqGNOhd5NaFCSSG+XKiupEfXjSW0tomnlyXSZPN7urSlFLDjE6vdBFPNyvP\nfDuVnLJ6loyP4EBBFX9Zm8GTX2Ry5+IkfnTxuA5j+0opdbZ01s0gUd9k46P9Raw7XMI7ewqYFhNI\nZkkddy5O4p6lKa4uTyk1yOmsmyHAx8ON62bG8PiK6axalkJ1g40If0/+74ssqhuaXV2eUmoI06Af\nZESEVcvGsvZHF/D4ihnUNNp4YctxV5ellBrCNOgHsSkxgSxKCeOZjVkcKKhydTlKqSFKg36Qe+DS\n8QAs/8uXvLI918XVKKWGIg36QW5ydCBr7lvMeQkh/PLdAxToCphKqV7SoB8Cgn09+O3XpmI3hh+/\ntpenN2SxP1+HcpRSPaNBP
0TEhviwatlYNmaU8uv30/naU5vYlFnq6rKUUkOAzqMfQowxZJbU4WYR\n7ngujZyyem5dmMh/nJ9EkI+Hq8tTSrmQzqMfJkSEMRF+JIT58uJ353LxpNE89UUmq17e7erSlFKD\nmC6BMESF+nny52/OICrIi2c2ZFPd0EyAl7ury1JKDULaox/iLhw/CpvdsOGIjtcrpTqnQT/EzYwL\nItDbnc8PFVNV36wXIFdKfUW3QzciEgv8CxgN2IHVxpjHT2uzEnjA+W0tcJcxZo9zXxDwNDAZMMCt\nxpjNffUERjo3q4XFY8NZc7CIzw+doKK+mblJIfy/FTP0AuRKKaBnPXobcL8xZgIwF7hbRCae1iYb\nWGyMmQo8DKxut+9x4CNjzHhgGpB+7mWr9i6cEEF1g40wP09+eNFYduRU8MS6TIqqGrjyzxt4ZmO2\nXq5QqRGs2x69MaYQKHTerhGRdCAaONiuzaZ2d9kCxACISABwPnCLs10T0NRHtSunK6ZEYjeGZRNG\n4e/lzrHSOl5Jy6Ww6iT786vZn3+QoydqeOT6qa4uVSnlAr0aoxeRBGAGsPUMzW4DPnTeTgJKgH+I\nyC4ReVpEfLt47DtEJE1E0kpKSnpT1ojnZrVw7YwY/J2zbm5dmEh9UwsfHzjBf5yfxHcXJfLS9lw2\nHNXjqtRI1OOgFxE/4HVglTGmuos2S3AEfet4vRswE3jSGDMDqAMe7Oy+xpjVxphUY0xqeHh4L56C\nOt3k6EDmJYUS5ufB3UvH8KNLxhEb4s1v3k+nxa5DOEqNND0KehFxxxHyLxhj3uiizVQcH7ouN8aU\nOTfnAXnGmNa/AF7DEfyqnz2xciZv37OQAC93PN2sPHDpeA4V1fDSdsfa9nWNNh23V2qE6DboxXHh\n0meAdGPMY120iQPeAG42xhxp3W6MKQJyRWScc9OFtBvbV/0n2NeD6CDvtu+vmBLJ3KQQHv3wEO/s\nKWDWr9fw1BdZAFTWN7E1q4w9uZUuqlYp1Z+6XetGRBYCG4B9OKZXAvwEiAMwxjwlIk8D1wM5zv22\n1jUXRGQ6jp6+B5AFfMcYU3Gmn6lr3fSPzJJaLvvTBppaHC9jVKAX7/1gERc99gVldU2IwKYHlxIZ\n6N3NIymlBpszrXWji5qNMP/cdIy3dudz9bQofvXuQVLjg9lxvIIHLh3PIx8e4vEV01k+PdrVZSql\nekkXNVNtvj0/gTe/t4Ab58QR7ONOWk4FV06N4ruLkvD3dGNrdnmH9vvzq1j6h3UUVTW4qGKl1LnS\noB+hPN2sXDczBqtFWLUsBatFSE0IZttpQf/y9lyySupYd7jYRZUqpc6VBv0I9qOLx/HBDxaRHO4H\nwOzEUDKKaymtbQSgxW74cH8RAFuyyrp8HKXU4KZBP4J5e1gZN9q/7fvZiSEApB1z9Oq3HyuntLaR\nIB93tmaX63RMpYYoXY9etZkSHYiXu4VVL+8mfs1RPNwseLlbuGtxMv/74SGOl9cTH9rpic1KqUFM\ne/SqjYebhcdXzOCbs+MI9/fkYGE1l0wazdLxEQC8v6+QD/YV0mSzd/NISqnBRHv0qoNLJo3mkkmj\nAag62YyXuwUPq4VQXw9++9FhAC4cH8FfV87Ey93qylKVUj2kQa+6FOh96tKED1w2nsySWgK93fnt\nR4f57r/SWH1zKt4eGvZKDXYa9KpHbkiNbbsd5ufJA6/v5dZnt/PkTTMJ8vFwYWVKqe7oGL3qtRtS\nY/njDdPZdqycJb9fx1u78gH4/ceHueUf21xcnVLqdNqjV2flmhnRjBvtz8/e2s+ql3ez63gF/9zs\nWOqorLaRUD9PF1eolGqlPXp11iZEBvD87XOYkxjCPzfnEOzjGNPfebwSgEZbC58fOkFVfbMLq1RK\nadCrc+LlbmX1t1L5zoIEXrh9Lu5WYefxCr7MKGXho2u59dk0Vm/IdHWZSo1oOnSjzlmgtzu
/uGoS\nAJOiAtlxrIK1h4rxcrcQHeTNvvxOL0imlBog2qNXfWpWfDDbjpVzqKiGe5aMYV5yKAcLqnT5BKVc\nSINe9alZ8cEABPu4s3x6NJOiAiitbaK4ppHaRhu2Fj2rVqmBpkGv+tSs+GCsFmHlnHi83K1MjAwA\nYEdOBRf8bh1/WHOkm0dQSvU1DXrVp0YFePHuPQv5wYUpAEyMcgT9Y2uOUFrbyKtpudqrV2qA6Yex\nqs+1hjuAv5c78aE+ZBTX4ulmobS2iTUHT/DqjjwsAivnxLPEuWgaQEZxDSCMifBzQeVKDU/ao1f9\nrnX45r6LxhLs486ql3fz+aFidudW8p1nt/OLt/dja7HTYjfc/Mw2rv7LxrY18ZVS506DXvW7OYkh\n+Hu58Y3UWK6eFkWjzc69F6aw+b8u5LuLEvnn5hweenM/XxwpprCqAatFuOUf23lrV77O1lGqD8hg\n/I+Umppq0tLSXF2G6iN2u6GuyYa/lzuV9U18ll7MtTOisVgEgEc+PMRTX2SSFOZLdYONN783n3te\n3MWe3EpmxgVx09x4rpoWhbvV0S95bUce+/Iq+dXyya58WkoNKiKywxiT2tk+7dGrfmexCP5ejuUR\ngnw8uH5WTFvIA6xalkJ8qA9ZpXV8bVYMsSE+vHHXfH5z7WQq6pv54St7uPzxDWzLLqe6oZn/fvcA\nz289fsYLoLTYDcXVDf3+3JQaCjTolct5uVt55LqpxIX4sHJOHEDbFM3P71/M6ptn0WBrYeXTW7jv\npd1UN9hosRtyyuq6fMzXd+Sx8LdrKag8OVBPQ6lBS4NeDQrzkkNZ/+MlxIb4dNguIlw8aTTvfX8R\n40b789mhYpLCHNetPVpc2+Xj7cuvoslm5/29hf1at1JDgQa9GhICvd15/rY5fGdBAk/eNAsRyDhD\n0GeVOva9u7dgoEpUatDSoFdDRpCPB7+4ahLjRvsTHeTN0eJaqk42k1741UXTskrq8LBa2JtXRXZp\n10M8So0E3Qa9iMSKyFoRSReRAyJybydtVorIXufXJhGZdtp+q4jsEpH3+rJ4NXKlRPiRUVzLr949\nwPVPbqK53dm2dY02Cqsa+HpqDABv7sxzVZlKDQo96dHbgPuNMROAucDdIjLxtDbZwGJjzFTgYWD1\nafvvBdLPtVilWqWM8iezpJb39xZS39RCdmkdzS12Smoa23rwC8eEcfHEUazekMUx7dWrEazboDfG\nFBpjdjpv1+AI7OjT2mwyxlQ4v90CxLTuE5EY4Arg6b4qWqkx4X402ew0OqdYphdW88zGbBb/bm3b\nWbVJ4X789/LJuFst/Pj1vdjtg++cEaUGQq/G6EUkAZgBbD1Ds9uAD9t9/yfgx8AZV7ISkTtEJE1E\n0kpKSnpTlhqBxoxyrIUzdpQf7lYhvbCG9UdKqG9q4Yl1mYhAfKgPowO9+MnlE9iWXc4XR0uob7Lx\nyvZcWtqFfnFNg/b41bDW46AXET/gdWCVMabTSwaJyBIcQf+A8/srgWJjzI7uHt8Ys9oYk2qMSQ0P\nD+9pWWqEGjvKH38vN25flERyuB/78ivZedzxR2VxTSMxwd54uVsBuH5mDGF+njy3OYc/fHKEH7++\nl08OFAGQW17PVX/eyAW/X8d1T3ypJ1mpYalHQS8i7jhC/gVjzBtdtJmKY3hmuTGmzLl5AXC1iBwD\nXgKWisjz51y1GvH8PN3Y8dOLuCE1lgmRAWzKLKOh2c6ilDAAksNPrX7p4WbhxtmxrD1czD83HQPg\njV35lNU2svLprTTa7KxalsLO45W8vD2X8romLv3TejZnlnX2o5Uacnoy60aAZ4B0Y8xjXbSJA94A\nbjbGtF1ZwhjzX8aYGGNMArAC+NwYc1OfVK5GPA83x9t3QqQ/rUs2Pbx8Mn6ebm0rZra6cU48FhG8\n3K1cOyOadYeL+fFreymsOsmz35nNqmVjmRYbxKeHinl
vbwGHimp4bM3hgX5KSvWLnqxHvwC4Gdgn\nIrud234CxAEYY54Cfg6EAk84fi9g62pxHaX62vjRjlBPDvclIcyXj+87nxAfjw5tRgd68dDlExgV\n4EVCmA9v7srns0PF3HthCtNjgwBYNj6CP6w5Qm1DM1aLsP1YBTtyypkVHzLQT0mpPtVt0BtjNgLS\nTZvbgdu7abMOWNeL2pTqkfGR/gDMTgwFIDrIu9N2ty5MBMAYw8TIAJpa7HxvSXLb/gsnjOIPa46Q\nWVLHD5aO4bktOTy5Lounv+0IeluLnTuf38nKOXEdLpai1GCnV5hSQ16Evxe/uGoii8f27EN8EeG5\n22ZjtQiebta27RMi/YkK9KKgqoEbzosF4M9rM8gpqyM+1JfNWWV8mn6C2sZmDXo1pOgSCGpY+M6C\nRJLCe375wVA/T4JOG94REW5flMTKOXHEBPuwcm48VhH+tTkHgHd2O9bN2ZpdTlGVzs5RQ4cGvVLt\n3Lowkd9cOwVwXOj80smjeSUtl8r6Jj46UERqfDDGwHu6WJoaQjTolTqDW+YnUNNg46q/bKSmwcb3\nL0xhcnQA7+7RoFdDhwa9UmcwKz6Y31w7GR93N5LCfVmQHMrlUyLZk1dFaW2jq8tTqkf0w1ilzkDE\ncaWrlXPi27bNTnDMwtl1vJKLJo5yVWlK9Zj26JXqpcnRgbhZhF3HK7ps8/mhEyx89HO9lKEaFDTo\nleolL3crE6MC2HW8ku3Hyrnt2e002lra9meX1nHvS7vJqzjJZ4eKXVipUg4a9EqdhZlxwezJq+Q3\n76fz2aFiDhQ41vmz2w0/eHEXbhYhzM+TTRmlLq5UKQ16pc7KjLgg6pta2J1bCcDu445/P9hfyL78\nKn525UQuGBfO5qwyXQdfuZwGvVJnYWZcMADBPu6E+3uyO7cSW4udx9YcISXCj+XTo1kwJpTK+mYO\ndnJNW6UGkga9UmchJtib2Qkh3HthCrOcwzjv7S0kq6SO+y8ei9UizE92LJm88bThm4bmFp5cl8nT\nG7IA2JtXyZu79Lq2qv/o9EqlzoKI8Mqd8wA42WznowNFPP7ZUcZE+HHxxNGA48zacaP8eWJtBo3N\ndu5ekkzVyWaue3ITOWX1WC3C1dOi+Olb+zlYUM0FYyMI9vU4049V6qxoj16pc9S6zHF2aR23L0zE\nYjm12Oufb5zB7MRQ/vjpEV5Oy+Wl7bnklNXz62sm02I3/PztA+zNq8JmN3zkvOqVMYaH3tzHdue1\nb5U6Vxr0Sp2jqTGBWATC/Dy4ZkZ0h31jR/nzt2/NYkp0IM9syObFbceZlxTKTXPjmZcUykcHivDx\nsBIT7N22fk5BVQMvbD3Oc87F1AAOF9Xwy3cOdLjWrVI9pUGv1Dny9XTjW/MSePCyCW3XqW3PsSpm\nIlmldeRVnOTGOXEArJjtWAp5+fQorp0RzebMMkpqGjnk/PB2S1YZxpi2Hv6zm46xN69ywJ6XGj40\n6JXqA7+8ehJfmxXT5f7Lp0QSFehFiK8HF09yLJtw6eTRfHdRIncvGcOVU6OwG/jkYBGHimoAx0XO\ns0rrWHu4mLQcx1m4m7P0Oraq9/TDWKUGgLvVwpM3zaKpxd52sRNPNysPXTERcIzLRwd5s/FoKVaL\n4ONhpb6phfVHSnhpWy4JoT5YLcLmzDK+d8EYVz4VNQRp0Cs1QKY5P7TtjIgwPzmUTw6eINTXg/nJ\noezPr+bRjw7R0Gzn6W+lsuFoCa+k5dFks7ddGL0zxTUNtNgNkYGdX1JRjTw6dKPUILFgTBhVJ5vJ\nKq1jQmQA85JDaWi286158SybOIp5yaGcbG7pdpz+nn/v4p5/7xqYotWQoD16pQaJ+cmhbbfHjw7g\n/LHhiMBPLp8AwJzEUETgy4wyUp1LJZ+uvK6J7cfKCfbR+fjqFO3RKzVIRAR4kRLhuO7t+Eh/zksI\n4bEbprfN5An29WB
2Qgj/3pZDQ3NLp4+x7nAxxjgCv77JNmC1q8FNg16pQWTp+AiCfNxJCPXtdP99\nF43lRHUjz2/J6XR/+2WRdS181UqHbpQaRO67aCy3LEjA2u7s2vbmJoWyKCWMv67N4MY5cTQ22/nj\np0corGpgdIAX6w+XkBzuS2ZJHfmVDYyJ8G+7b2HVSQqrGtoWZFMjh/bolRpEvNyt3c6WueuCZCrq\nm1l7qIRX0nL51+YcjpfV8+qOXGoabXx7fgIA+RUde/Q/fm0v31y9hZIavdbtSKM9eqWGmDmJoYT6\nevDRgSKOl9czNSaQd+5ZSE1DMwcKqkmND+ZX7x4kv7K+7T45ZXVsOOpYRfPZTdn85yXjz/gz6hpt\nuFstZ5zGqYYOfRWVGmKsFuGiiaNYc7CIPbmVXD4lEgB/L3fmJoXiZrUwOsCrQ4/+39uOY7UIsxND\n+NfmHGoams/4M657YhO//+Rwvz4PNXC6DXoRiRWRtSKSLiIHROTeTtqsFJG9zq9NIjKtp/dVSvXe\nJZNG09BsB+CyyaO/sj862JuCygYASmoaeS0tjwvHR/DTKyZQ02Dj9R1dr39/sqmFwydq2OO8epYa\n+nrSo7cB9xtjJgBzgbtFZOJpbbKBxcaYqcDDwOpe3Fcp1Uvzx4Ti5+nGpKgA4juZoRMT5E1+5Um2\nZZdz2ePrqW208R+Lk5gaE8T40f68t7cQgDUHT1Bc09DhvsfK6gDHsstqeOg26I0xhcaYnc7bNUA6\nEH1am03GmArnt1uAmJ7eVynVe55uVv70jek8fM3kTvdHBXlTVN3AfS/vxs/TjXfuWciseMdJVpdP\niSQtp4L39hbw3X+lseql3RhzavnjY86AL65ppLZR5+IPB70aoxeRBGAGsPUMzW4DPuztfUXkDhFJ\nE5G0kpKS3pSl1Ii0bOKoLqdKRgd702I35Fee5FfLJzNu9Klplq1j+ve/sgeLwKbMMj5vN/8+q11P\n/lgve/UZxbV6MfRBqMdBLyJ+wOvAKmNMp1c7FpElOIL+gd7e1xiz2hiTaoxJDQ8P72lZSqlORAc5\npmjOTgjh/JSwDvvGRPgxbpQ/jTY7D142nqQwX/7ng/S2Xn37cO/N8E1ueT0X//ELXt2R2wfPQPWl\nHgW9iLjjCOoXjDFvdNFmKvA0sNwYU9ab+yql+tbk6EAmRwfw0BUTEPnqyVc3zY1jUlQA35qXwO2L\nksgsqWsL9WNldUyJDgS+GvT/+0E6v3n/YKc/c1duJXYDn6YXd7pfuU638+jF8S55Bkg3xjzWRZs4\n4A3gZmPMkd7cVynV90J8PXjv+4u63H/zvARunpcAwKx4x/DP7txKksL9yC6t58LxEZTVNnYI+iab\nnee25FDf1MKS8RHMT+74l8L+/CoANmWUdruUshpYPXklFgA3A0tFZLfz63IRuVNE7nS2+TkQCjzh\n3J92pvv2+bNQSp21MRF++HpY2XW8kpqGZkprG0kI8yUx3LdD0O88XkF9UwsebhZ++tZ+KuqaOjzO\nvrwq3CxCXVMLO49XnP5jMMboNW9dpNsevTFmI9D5whun2twO3H4291VKuZbVIkyLDWJ3biU5ZY6z\naRPDfMiv9OWd3QUYYxCRtqtf/fGG6Xz/xZ3Me+Qzlo6PYFJUIN9ZkMD+giqumBrJ+3sL+eJICXOT\nQjv8nJe25/KHT46w6cGl2tsfYHq0lVJMjw0ivbCaAwWO4ZeEMF8Sw/yobrBRUOWYZ7/haAnTY4O4\nYmokH606n2umR3OgoJrffXyYH726h5oGG3OTQpkZH8z6I1+dOfdZ+glKaxs5Xl7/lX2qf2nQK6WY\nHhuEzW54+L10Rgd4kRjmywXjwvF0s/DDl3dTVNXA3vwqFo5xjMuPHeXPI9dP5Yv/XMI106P4YF8R\nAFOiA1k8NpwDBdUdFk+z203bBc6zSmoH/gmOcBr0SimmxwUB0Ghr4YmbZuLpZiU53
I9Hrp/C1uxy\n5j3yGeBYL/90q5aNxWoR3K1Cyig/Fo91TI/ecPRUrz6zpJbKesf6Oll6xu2A09UrlVJE+Hux4rxY\nx9BLu5Owrp0RQ2V9M0XVDSybMKrTC5wnhPly5+Ik8ipO4ulmZWJkAKG+HnxxpITrZsYAtPXm3a2i\nPXoX0KBXSgHwyPVTO93+nQWJ3d63/bLHFotw/thwvjhSgt1usFiE7cfKCfX1IDncj6wS7dEPNB26\nUUr1ufPHhlFe18R+54e7accqSE0IJjnC9ytDN5kltdz1/A69xm0/0qBXSvW5RSnhuFmE57fksPZw\nMcfL61mUEk5imC/ldU1U1p+ag//ungI+3F9E2rGvzr1XfUODXinV58L8PLltYSKvpOXxwGt7SQj1\n4YbUWJLC/ADIbDd8s9u57r2uf99/NOiVUv3i3mUpRAd5U1zTyENXTMTDzUJSuGPt/FfTclm9PhO7\n3bQF/J68yrb7PvDaXm7/ZxoFlSc7eWTVW/phrFKqX/h4uPHUTbPYml3GsgmOaZmxIT54WC28tN2x\nwqWfpzsV9c14uVvYnVuFMYYvjpTwclouIrD+aAmeVguLx4XzlxtnuvLpDGnao1dK9ZspMYHcviip\nbQVNd6uF//vWLP5+SyoBXm789uNDAFwzPbrtrNmH3ztIQqgPa+47nxtnxzE9Loj39hZyuKjGlU9l\nSNOgV0oNqCXjIlg6fhTXzIim0tmb/3pqLAB3Pr+TzJI6HrpiImMi/Pnl1ZP4fytm4OVu4dlN2S6u\nfOjSoFdKucQNznCf4lw738NqIb2wmjvOT+KiiaPa2gX7enDtjGje2Jn/lRUz22t/OUTVkQa9Usol\nJkcHcs30KK6fGYOnm5Urp0Zy45w4Hrx0/FfarpwTT6PNzpr0E50+1pPrMrn0Txv6u+QhSz+MVUq5\nzJ9WzGi7/dg3pnfZbmJkAP5ebuzOrWz7S2DNwRNsySrj7iVj+OvaDGobbTQ0t+Dlbu3vsoccDXql\n1KBnsQjTY4PYdbwSgPomG//1xl5Ka5v4LP0EtY2Os2pLaxuJCfZxYaWDkw7dKKWGhOmxQRwuqqa+\nycbzW3IorW1iUUoYx8rqCfJxB6Cs9tQYflV9M5syS9mSVdbVQ44YGvRKqSFhRlwQdgObM8v4vy+y\nWJQSxt9vOY+fXD6eX18zGYCyOsca+GW1jVz42Dpu/NtWvvm3LR3Wxgc4eqKGJb9fR3F1w4A/D1fQ\noFdKDQnTYoIA+PFre6mob+I/LxmHu9XCHecnt+0rdfboH/3oEJX1zXzvgmSMgZyyjgup7cqtJLu0\njj15VQP5FFxGg14pNSSE+nkSH+pDWV0Tty5IZKoz3B37PADH0M2u4xW8kpbHbYsS+dosx3r4p1++\nsLUnf/ovgOFKg14pNWQsHBNGcrgv9188rsN2Hw83fDyslNY28tGBItytwg+WphAd7I2II+g3ZZQy\n9ZcfU1bbyIlqx1DOSLl+rc66UUoNGQ8vn0yz3Y6n21enUIb6eVBW20ijzU5ssA++no54Gx3gxfHy\nehptdqobbBwtrqW4prVH333QP7Mxmw/2FfLid+fi4TY0+8ZDs2ql1IhksUinIQ8Q6utJWV0TOWX1\nxIWemmIZG+JDbnk96YXVAORVnOxVj37d4WJ25FTwr83Hzv0JuIgGvVJqWAjz86CkppHc8nriQ04F\nfVyID7nlJ9uCPr/iZNsYfV5FPS32My+dkFnsuMbt458epbS28YxtBysNeqXUsBDq68mxsjpqGm3E\nhfq2bY8L8aGouqFDL76ktpEQXw+aWwwFlSfJq6jvdK2cukYbBVUNXDsjmppGG2/tyh+w59OXNOiV\nUsNCmL8HDc12gA49+tgQ77bbbhZhf34VzS2G1PhgwDEGv/DRtTz5ReZXHjOzxNGbv2TSKAK93Tk2\nRGfpaNArpYaFUF/PttvxoR2HblrNTQrlSLFjX
fvZiSEAbWPvv//4MBuOlnR4zNagHxPh1zYENBR1\nG/QiEisia0UkXUQOiMi9nbRZKSJ7nV+bRGRau32XishhEckQkQf7+gkopRScmksPjg9gT78d4e/J\n1JhAWkdopsYE4WG1YDfwo4vHkhLhz13P72Tn8VMXKc8orsXNIsSH+jqDfmhOx+xJj94G3G+MmQDM\nBe4WkYmntckGFhtjpgIPA6sBRMQK/BW4DJgIfLOT+yql1DkL83P06EcFeHZYwTLczxMvdwsTIgM6\nLHgWGehFTIg3fp5ufHt+As/eeh5hfh5865ltZDh7/RnFtcSF+uButRAT4k1excluP7wdjLoNemNM\noTFmp/N2DZAORJ/WZpMxpvXX4BYgxnl7NpBhjMkyxjQBLwHL+6p4pZRq1dqjjw/x7bBdRPj+0hRu\nmZ9AdPCp8fpwf0/uvmAMv75mMv5e7kQGevPiHXMR4PcfHwEgs6SOMeF+gGMIqKnFzokhuD5Or06Y\nEpEEYAaw9QzNbgM+dN6OBnLb7csD5nTx2HcAdwDExcX1piyllGrr0befQ9/q7iVjgFNj7oHe7ni5\nW7l+VkyHdpGB3ty2KJE/fXqUbdnlHCut42Ln1a5ax/qPl9cTFeTNUNLjD2NFxA94HVhljKnuos0S\nHEH/QOumTpp1+nePMWa1MSbVGJMaHh7e07KUUgqAYB8Pwv09mR4b1GWbaGdAjwrw7LLNrQsTCfBy\n44b/24zNbjgvwfGhbWvQn2mc/siJGv7y+VF++tY+GppbzuJZ9I8e9ehFxB1HyL9gjHmjizZTgaeB\ny4wxrQtA5wGx7ZrFAAVnX65SSnXOahG+fGApbpbO+pcOXu5Wwv09GRXg1WWbAC93fn3tFDZnlnLd\nzJi2oI8K8sYinQe9rcXOX9Zm8OfPM9rG8KdGB3HDebFfaesKPZl1I8AzQLox5rEu2sQBbwA3G2OO\ntNu1HUgRkUQR8QBWAO+ce9lKKfVVHm4WLGcIeoDbFp5a1bIrV0+L4n+vm9oW8gDuVgtRQd6dLpvw\ns7cP8KdPj3L1tCi2P7SMlAg/nt+ac3ZPoh/0pEe/ALgZ2Cciu53bfgLEARhjngJ+DoQCTzh+L2Bz\nDsPYROQe4GPACvzdGHOgb5+CUkr13J2Lk8/6vnEhPl8J+ue25PDituPcdUEyDzgvbH7T3Hh+8c4B\n9uZVdlhO2VW6DXpjzEY6H2tv3+Z24PYu9n0AfHBW1Sml1CASG+zDp+knMMZQUd/Mz97ez/t7C1ky\nLpwftVs6+dqZ0Tz60SFe2p7L1Jgg1h8p4R9fZrO/oJpnvp3aFv5PrMsgo7iWx26Y3q9165mxSinV\nQ7Pigymra2Jbdjm/eT+dTw4UsWpZCk/eNAtruyGjAC93UhNC2O28mPnD7x1kX34VTTY7v34/HWMM\nxhhe2HKcD/YVYu/nufka9Eop1UNXTYsi0NudP3xyhLd353PT3HhWLRvb4QStVuNG+ZFRUktdo43M\nklpunB3Hjy4Zx7bscj4/VMyxsnryK0/S0GynoKp/l1bQoFdKqR7y9rDyjfNi2XasHIDbFyV12Xbc\n6ACabHY+OViE3cCEyABWnBdLUpgvj350iHWHi9vaZpX072JpGvRKKdULN8+NxyJwzYzotnn5nRk/\n2h+AN3c5ZpRPiAzA3WrhPy8Zx5ETtfxxzRGCfdyBUydy9RcNeqWU6oXYEB9ev2s+P7/qzMt2jYnw\nwyKw8WgJvh7WthOuLp08mhlxQVQ32Lh0ciQBXm4a9EopNdjMiAsmwMv9jG283K0khPliNzA+MqBt\nfr+I8JPLJ2ARuHjSKJIj/Mgs1qEbpZQaksaNcgzfTIj077D9vIQQdvz0IpaMiyA53I+sUu3RK6XU\nkDRudGvQB3xlX7CvY7XN5HA/TlQ3UtPQ3G91aNArpVQ/meY8MWraGc6OTQp3LKvcnzNvNOiVUqqf\nXDAunE/uO
5/J0YFdtkmJcKx3v2L1Fr7+1KZOL1J+rnq1Hr1SSqmeExHGjvI/Y5ukcD/+cuMMduZU\nUt9kw7leWJ/SoFdKKRe7cmoUV06N6rfH16EbpZQa5jTolVJqmNOgV0qpYU6DXimlhjkNeqWUGuY0\n6JVSapjToFdKqWFOg14ppYY56Y/Tbc+ViJQAOWd59zCgtA/L6StaV+8N1tq0rt7RunrvbGqLN8aE\nd7ZjUAb9uRCRNGNMqqvrOJ3W1XuDtTatq3e0rt7r69p06EYppYY5DXqllBrmhmPQr3Z1AV3Qunpv\nsNamdfWO1tV7fVrbsBujV0op1dFw7NErpZRqR4NeKaWGuWET9CJyqYgcFpEMEXnQhXXEishaEUkX\nkQMicq9z+y9FJF9Edju/LndRfcdEZJ+zhjTnthARWSMiR53/Bg9wTePaHZfdIlItIqtcccxE5O8i\nUiwi+9tt6/L4iMh/Od9zh0XkEhfU9jsROSQie0XkTREJcm5PEJGT7Y7dUwNcV5ev3UAdsy7qerld\nTcdEZLdz+0Aer64yov/eZ8aYIf8FWIFMIAnwAPYAE11USyQw03nbHzgCTAR+CfxoEByrY0DYadt+\nCzzovP0g8KiLX8siIN4Vxww4H5gJ7O/u+Dhf1z2AJ5DofA9aB7i2iwE35+1H29WW0L6dC45Zp6/d\nQB6zzuo6bf8fgJ+74Hh1lRH99j4bLj362UCGMSbLGNMEvAQsd0UhxphCY8xO5+0aIB2IdkUtvbAc\n+Kfz9j+Ba1xXChcCmcaYsz0z+pwYY9YD5adt7ur4LAdeMsY0GmOygQwc78UBq80Y84kxxub8dgsQ\n018/vzd1ncGAHbMz1SWOC7PeALzYHz/7TM6QEf32PhsuQR8N5Lb7Po9BEK4ikgDMALY6N93j/BP7\n7wM9PNKOAT4RkR0icodz2yhjTCE43oRAhItqA1hBx/98g+GYdXV8Btv77lbgw3bfJ4rILhH5QkQW\nuaCezl67wXLMFgEnjDFH220b8ON1Wkb02/tsuAR9Z5dNd+m8URHxA14HVhljqoEngWRgOlCI489G\nV1hgjJkJXAbcLSLnu6iOrxARD+Bq4FXnpsFyzLoyaN53IvIQYANecG4qBOKMMTOAHwL/FpGAASyp\nq9dusByzb9KxQzHgx6uTjOiyaSfbenXMhkvQ5wGx7b6PAQpcVAsi4o7jBXzBGPMGgDHmhDGmxRhj\nB/5GP/6JfybGmALnv8XAm846TohIpLP2SKDYFbXh+OWz0xhzwlnjoDhmdH18BsX7TkS+DVwJrDTO\nQV3nn/llzts7cIzrjh2oms7w2rn8mImIG3Ad8HLrtoE+Xp1lBP34PhsuQb8dSBGRRGevcAXwjisK\ncY79PQOkG2Mea7c9sl2za4H9p993AGrzFRH/1ts4Psjbj+NYfdvZ7NvA2wNdm1OHXtZgOGZOXR2f\nd4AVIuIpIolACrBtIAsTkUuBB4CrjTH17baHi4jVeTvJWVvWANbV1Wvn8mMGLAMOGWPyWjcM5PHq\nKiPoz/fZQHzKPECfZF+O49PrTOAhF9axEMefVXuB3c6vy4HngH3O7e8AkS6oLQnHp/d7gAOtxwkI\nBT4Djjr/DXFBbT5AGRDYbtuAHzMcv2gKgWYcPanbznR8gIec77nDwGUuqC0Dx/ht63vtKWfb652v\n8R5gJ3DVANfV5Ws3UMess7qc258F7jyt7UAer64yot/eZ7oEglJKDXPDZehGKaVUFzTolVJqmNOg\nV0qpYU6DXimlhjkNeqWUGuY06JVSapjToFdKqWHu/wNYqfbOkJ9wHwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(meta_losses)" ] }, { "cell_type": "markdown", "metadata": { "id": "64YicoxfJtZG" }, "source": [ "## Meta-training with Truncated backprop through time\n", "In the previous meta-training examples, the meta-loss always initialized the inner problem and applied the optimizer for a fixed number of steps.\n", "\n", "This is fine for short inner-problem training runs, but it becomes costly as the number of inner iterations grows, because every meta-gradient estimate requires unrolling the entire inner training.\n", "Truncated backprop through time, and truncated meta-training techniques more generally, are one solution to this. The core idea is to split one long inner-training sequence into smaller chunks and compute meta-gradients only within a chunk. This lets us compute meta-gradients more frequently (we get a gradient estimate for every chunk), but these estimates are generally biased because they ignore how the chunks interact with each other.\n", "\n", "The code for this is a bit more involved. First, we need to keep track of the state of each inner problem: in our case, the inner problem's optimizer state and its current training iteration. Next, we must check whether an inner problem has reached the end of its training, and reset it if so. We fix the length of inner training to 100 steps for this example. We can then define a function, `short_segment_unroll`, to both progress training by some number of steps,\n", 
] }, { "cell_type": "code", "execution_count": 23, "metadata": { "executionInfo": { "elapsed": 693, "status": "ok", "timestamp": 1647716674268, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "ri9wWMPizb-G", "outputId": "6025ce7a-98e5-459a-e1a7-6e19792056d0" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray(10, dtype=int32, weak_type=True)" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def short_segment_unroll(meta_params,\n", "                         key,\n", "                         inner_opt_state,\n", "                         on_iteration,\n", "                         seq_of_batches,\n", "                         inner_problem_length=100):\n", "\n", "  def step(scan_state, batch):\n", "    opt_state, i, key = scan_state\n", "\n", "    # Once this inner problem has run for `inner_problem_length` steps, reset it\n", "    # with freshly initialized parameters; otherwise advance the step counter.\n", "    key1, key = jax.random.split(key)\n", "    opt_state, i = jax.lax.cond(\n", "        i >= inner_problem_length, lambda k:\n", "        (lopt.initial_inner_opt_state(meta_params, task.init(k)), 0), lambda k:\n", "        (opt_state, i + 1), key1)\n", "\n", "    loss, grads = value_grad_fn(opt_state[0], batch)\n", "    opt_state = lopt.update_inner_opt_state(meta_params, opt_state, grads)\n", "\n", "    # Clip the loss (mapping NaNs to 3.0) so diverging inner models do not\n", "    # dominate the meta-loss.\n", "    loss = jax.lax.cond(\n", "        jnp.isnan(loss), lambda loss: 3.0, lambda loss: jnp.minimum(loss, 3.0),\n", "        loss)\n", "\n", "    return (opt_state, i, key), loss\n", "\n", "  (inner_opt_state, on_iteration,\n", "   _), losses = jax.lax.scan(step, (inner_opt_state, on_iteration, key),\n", "                             seq_of_batches)\n", "\n", "  return losses, inner_opt_state, on_iteration\n", "\n", "\n", "inner_opt_state = lopt.initial_inner_opt_state(meta_params, task.init(key))\n", "batch = get_batch_seq(10)\n", "\n", "loss, inner_opt_state, on_iteration = short_segment_unroll(\n", "    meta_params, key, inner_opt_state, 0, batch)\n", "on_iteration" ] },
{ "cell_type": "markdown", "metadata": { "id": "EBYgsPmn0-6E" }, "source": [ "Now, with this function, we are free to estimate gradients over just this one short unroll rather than the full inner-training. We can use whatever gradient estimator we want, either ES or backprop, but here we use backprop gradients.\n", "\n", "As before, we construct a vectorized version of this unroll function with `jax.vmap`, mapping over the per-task arguments while sharing `meta_params` across tasks (its entry in `in_axes` is `None`), and compute gradients with `jax.value_and_grad`." ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "executionInfo": { "elapsed": 3, "status": "ok", "timestamp": 1647716674380, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "xJRkyAX_1Oge" }, "outputs": [], "source": [ "def vec_short_segment_unroll(meta_params,\n", "                             keys,\n", "                             inner_opt_states,\n", "                             on_iterations,\n", "                             vec_seq_of_batches,\n", "                             inner_problem_length=100):\n", "  # Share meta_params (and the inner problem length) across tasks; map over\n", "  # the per-task keys, optimizer states, iteration counters, and batches.\n", "  losses, inner_opt_states, on_iterations = jax.vmap(\n", "      short_segment_unroll,\n", "      in_axes=(None, 0, 0, 0, 0, None))(meta_params, keys, inner_opt_states,\n", "                                        on_iterations, vec_seq_of_batches,\n", "                                        inner_problem_length)\n", "  return jnp.mean(losses), (inner_opt_states, on_iterations)\n", "\n", "\n", "vec_short_segment_grad = jax.jit(\n", "    jax.value_and_grad(vec_short_segment_unroll, has_aux=True))" ] }, { "cell_type": "markdown", "metadata": { "id": "QYrO5_ik1vbm" }, "source": [ "We can then use this function to compute meta-gradients. Before doing that, though, we must set up the initial state (parameter values and optimizer state) of the problems being trained."
] }, { "cell_type": "code", "execution_count": 25, "metadata": { "executionInfo": { "elapsed": 826, "status": "ok", "timestamp": 1647716675325, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "iw2i_2kLsFRf" }, "outputs": [], "source": [ "#num_tasks = 32\n", "num_tasks = 16\n", "\n", "key = jax.random.PRNGKey(1)\n", "meta_params = lopt.init_meta_params(key)\n", "\n", "\n", "def init_single_inner_opt_state(key):\n", " return lopt.initial_inner_opt_state(meta_params, task.init(key))\n", "\n", "\n", "keys = jax.random.split(key, num_tasks)\n", "inner_opt_states = jax.vmap(init_single_inner_opt_state)(keys)\n", "\n", "# Randomly set the initial iteration to prevent the tasks from running in lock step.\n", "on_iterations = jax.random.randint(key, [num_tasks], 0, 100)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "executionInfo": { "elapsed": 24333, "status": "ok", "timestamp": 1647716699782, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "Kgltlb7b1n1M", "outputId": "9cf9de7b-5171-4d8b-cea1-a4df07b61075" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 2.3025281\n", "20 2.302613\n", "40 2.302506\n", "60 2.2939959\n", "80 2.228974\n", "100 2.146706\n", "120 2.062469\n", "140 1.9681495\n", "160 1.8850164\n", "180 1.8478796\n", "200 1.8156666\n", "220 1.7869465\n", "240 1.7575241\n", "260 1.7495174\n", "280 1.7104518\n", "300 1.7048458\n", "320 1.6748368\n", "340 1.6512482\n", "360 1.6221116\n", "380 1.5998225\n" ] } ], "source": [ "meta_opt = Adam(0.0001)\n", "meta_opt_state = meta_opt.init(meta_params)\n", "\n", "meta_losses = []\n", "\n", "num_steps = int(os.environ.get(\"LOPT_TRAIN_LENGTH\", 400))\n", "for i in range(num_steps):\n", " data = get_vec_batch_seq(num_tasks, 10)\n", " key1, key = jax.random.split(key)\n", " keys = jax.random.split(key1, num_tasks)\n", " (loss, (inner_opt_states, on_iterations)), meta_grad = vec_short_segment_grad(\n", " 
meta_opt_state[0], keys, inner_opt_states, on_iterations, data)\n", " meta_losses.append(loss)\n", " if i % 20 == 0:\n", " print(i, onp.mean(meta_losses[-20:]))\n", " meta_opt_state = meta_opt.update(meta_opt_state, meta_grad)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "colab": { "height": 296 }, "executionInfo": { "elapsed": 191, "status": "ok", "timestamp": 1647716700105, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "6Br4uOBt66yf", "outputId": "c4f6530f-42b9-4910-aa24-b6ca56f830c3" }, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'meta-loss')" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90\nbGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsT\nAAALEwEAmpwYAABJwklEQVR4nO29d5hcZ3mwf7/T6/ZdrXqzZLk3YRsw3QFsiilJgAAJJEAKJSSE\nBJJ8CYF8SSC/EBISIHwEDKGEACb0YrDBgDFgGVtykSxbXVppe5te3t8f57xnzjkzK+1KOzvanee+\nrr08M+85Z549Xr3PebrSWiMIgiC0L4FWCyAIgiC0FlEEgiAIbY4oAkEQhDZHFIEgCEKbI4pAEASh\nzQm1WoCF0tfXpzdt2tRqMQRBEJYVu3btGtVa9zdaW3aKYNOmTdx7772tFkMQBGFZoZQ6PNeauIYE\nQRDaHFEEgiAIbY4oAkEQhDZHFIEgCEKbI4pAEAShzRFFIAiC0OaIIhAEQWhzll0dwdny2PAstz98\n6qzP12i0Bq3t/4L9X+s9QCigiIWDAFS1JhhQRMNBosEA0XCAzniYrkSERCRIOhZisCOGUmoRfjtB\nEISzp20Uwb6TM7z323tbLYaH/nSU6zb3cN2WXp5yQR+b+pKtFkkQhDZELbfBNDt37tRnU1lcrlQp\nV8/td1UKFAqlIGA/ySvzuVKUKlXypQpgrZermkKpQqFcpVCuMpUrMpktkStVGM8U2XV4gp8dGOfk\ndJ6AgnfctIPXP2WLWAmCICw6SqldWuudjdbaxiIIBQOEgs39jnAwQDjoC7vEw3Me/5tP3ITWmiPj\nWd737X383Tf3svvYFO964SX0paLNFVYQBMGmbRTB+YpSio29Sf7tN67i0h928o/f2ct3HjrJTZeu\n5vefvpWLVne0WkRBEFY4kjV0nqCU4vefvpXv/tHTeNX1G/nBvmFu+fef8N8/P9Jq0QRBWOGIIjjP\nuGAgxV+/4BLu/JOnc93mHt5x2x7e/Llfcv/RyVaLJgjCCkUUwXlKbyrKra+9ludeMsjXHjjB6z75\nC6ZypVaLJQjCCkQUwXlMMKD48Kuu5jOvu46xTJGP//hgq0USBGEFIorgPEcpxZMv6OOGC/r44q5j\nVM8xBVYQBMGPKIJlwq9es47jkznuOTDWalEEQVhhiCJYJjznkkHS0RBf2HWs1aIIgrDCEEWwTIiF\ng7zgyjV8c88QJ6fyrRZHEIQ
VhCiCZcTvPXUrWsPvf2YX+0/NtFocQRBWCKIIlhEbehO8+5ZLODia\n4XWfupfZQrnVIgmCsAIQRbDMePm1G/iPV13D4bEsn77ncKvFEQRhBSCKYBly3ZZenrS1l1t/cohi\nudpqcQRBWOaIIlimvP4pWzg5nefru0/wlfuPc+tPDrLcWooLgnB+IN1HlylP297PBQMp/vh/HnA+\nu3pjN5ev62qdUIIgLEvEIlimBAKKf3n5lWzpT3LLlWsA+MG+kRZLJQjCckQsgmXMJWs6ueNtTwfg\n0GiGO/YO85ZnbWutUIIgLDvEIlghPPfS1dx/dJIDI7OtFkUQhGWGKIIVwkuvXkswoPiitKAQBGGB\nNE0RKKXWK6XuVEo9opR6SCn1hw2OeaVSarf9c7dS6opmybPSGeiIce2mHu7YO9xqUQRBWGY00yIo\nA2/TWl8EXA+8USl1se+Yg8DTtNaXA+8BPtpEeVY8T93ez96TMwxPSy8iQRDmT9MUgdZ6SGt9n/16\nBngEWOs75m6t9YT99h5gXbPkaQeeur0PgBd/6G4+efchHjoxxRs+dS/P+qcf8P1HTrVYOkEQzlfU\nUhQhKaU2AXcBl2qtp+c45k+AHVrr1zVYewPwBoANGzZcc/iwtFaYiy/tOsYH79jPobEsAB2xEPly\nlR2Dab76phtaLJ0gCK1CKbVLa72z0VrTg8VKqRTwJeCtp1ECzwB+B/izRuta649qrXdqrXf29/c3\nT9gVwEuvWcdXXBv+D97+DF553QYeG56V6WaCIDSkqXUESqkwlhL4jNb6tjmOuRz4GHCT1lrGby0C\nnfEw337rU0jHwvQkI2wbSJMtVjgxlWNdd6LV4gmCcJ7RzKwhBfwn8IjW+v1zHLMBuA14tdb60WbJ\n0o7sGOxgbVccgG2rUgDsH5YaA0EQ6mmmRfBk4NXAHqXU/fZnfw5sANBafwT4K6AX+JClNyjP5cMS\nzp5tA5YiePTkDM+4cKDF0giCcL7RNEWgtf4xoM5wzOuAuuCwsLh0JSKs647zwLHJVosiCMJ5iFQW\ntwnXbOxm1+GJulbV+VKFf/jWXqk9EIQ2RhRBm3D1hm5OTRc4PpnzfP6fPz7IR374OP97//EWSSYI\nQqsRRdAmXLelB4C7Hh11PhubLfDhHzwOwH2HJ1shliAI5wGiCNqEC1el2dSb4FsPDjnuoQ/e8Ri5\nUoUr13ex60i920gQhPZAFEGboJTieZev5kf7R9n8zm/yZ1/czWd+dpiXPWE9L75qLSMzBU5KnEAQ\n2hJRBG3E7z1tK1v6kwB8/t6jaA1vvXEbF9jppQdHMq0UTxCEFiGKoI1Ix8J87U038I233MAV67u4\n9bXXMpCOsbnPUg4HRkURCEI7IqMq24xkNMQlazr5yhuf7Hw22BEjGgpwSBSBILQlYhEIBAKKTb1J\nDo2JIhCEdkQUgQDApr6EuIYEoU0RRSAAcOFgB4dGM2SLZc/nR8ezfGP3UIukEgRhKRBFIABw2dpO\nqhoePuEdGfGcD9zFGz97H6VKtUWSCYLQbEQRCABcvq4TgD3HpzyfZ4sVAEZnC0sukyAIS4MoAgGA\nVR0xVnfG+JuvPcyT/v77DM/k2XV4wlk/NS2KQBBWKqIIBId/+42r6E9HOTGV56v3n+A1H/+5s3ZK\nqo4FYcUiikBwuGZjDz9757PoTUZ433f2MVMo8w8vuQwQRSAIKxlRBIKHQEBx/dZeiuUqAQUvvWYd\nwYASRSAIKxipLBbq+KMbt3N4LMOTt/YRDgYYSEclRiAIKxhRBEIdFwyk+Pqbn+K8H+iI1VkEDx6f\n4n3f2cdV67t44ZVr2NqfWmoxBUFYJMQ1JJyRVekowz6L4Lb7jnPXoyP8y/f388Hv72+RZIIgLAai\nCIQzsqojxqkZr0Xw2Mgs3YkwAQVjmWKLJBMEYTEQ15BwRlZ1RJnMlsiXKsTCQQD2nZzmGRcOM
FMo\nc2Qs22IJBUE4F8QiEM7IQEcMwHEPTWaLnJoucOFgmoF0lBGpOhaEZY1YBMIZWWUrglMzeT724wN8\n6qeHAbhqQze5UoXxTJFiuUokJM8VgrAckX+5whkZNIpgOu8ogXQ0xBM2dTOQttbGMvVWQb5U4WM/\nOiAN6wThPEcUgXBGVnVEATg+kSMUUCQiQf77d69HKUV/2lrzZxUBfOB7+/nbbzzCN/dIG2tBOJ8R\n15BwRjrjYTrjYW5/+BTlqua9t1zKJWusbqUDtiIYmalXBPcdsZrWFctiEQjC+YxYBMIZUUrxhE09\n3Gt3I93Sn3TWVndZrqGjE97MIa21MwNZgsmCcH4jikCYF9dt7nFeb3FVEfenonTGw+wfnvUc/95v\n72PYthIauY0EQTh/EEUgzItnXTTAQDrKXz3/YjrjYedzpRTbV6XYf2rGc/w39wzxtO39XDCQkoZ1\ngnCeIzECYV5s6U/x87+4seHatlVpvrF7CK01SimqVc3QVI6bL1tNpao5KYpAEM5rxCIQzpntAymm\nciWnQ+nobIFSRbO2K8ZAR32fIkEQzi9EEQjnzM5NVvzg7sdHATg+mQNgTVecVR0xhmfyVKu67rxd\nh8f5+u4TSyeoIAgNEUUgnDMXr+6gNxnhrkdHADgxabmC1nTFWdMVp1TRDTOHXvrhn/Kmz/5ySWUV\nBKEeUQTCORMIKG7Y1sdPD4zxrT1D7Ds5DViKYF1XHIBjvvTSRhaCIAitQYLFwqJw2dpOvnL/CX7/\nM/cBtSK0dd1GEeS4ZmPt+ENjGed1uVIlFJRnEkFoFfKvT1gUdgx2eN7feNEqANbaisDEDQx7jk85\nr2cL5SZLJwjC6RBFICwKFw6mPe9vvmwQgEQkRHcizLEJryJ41FV3MJUrNV9AQRDmpGmKQCm1Xil1\np1LqEaXUQ0qpP2xwjFJK/atS6jGl1G6l1NXNkkdoLqb5HMBnXncdz9wx4Lxf2x2vUwQHRmquoemc\nWASC0EqaGSMoA2/TWt+nlEoDu5RSt2utH3YdcxOwzf65Dviw/V9hGfKJ1z6B/lSUS9d2ej7fMdjB\nHXuHnYIzsBRBRyzEdL7MdF4sAkFoJU2zCLTWQ1rr++zXM8AjwFrfYbcAn9IW9wBdSqnVzZJJaC7P\nuHCgTgkAXLuph/FMkcdHrH5Elarm4FiGKzd0AzAtriFBaClLEiNQSm0CrgJ+5ltaCxx1vT9GvbJA\nKfUGpdS9Sql7R0ZGmian0ByeYDes+/lBq3vpickcxXKVK9d3AYhFIAgtpumKQCmVAr4EvFVrPe1f\nbnBKXYK51vqjWuudWuud/f39zRBTaCKbehOkoiEnQGya0O2wA8wSIxCE1tJURaCUCmMpgc9orW9r\ncMgxYL3r/TpAeg6sMJRSDHbGHAUwlikCsL47QUCJRSAIraaZWUMK+E/gEa31++c47KvAb9rZQ9cD\nU1prmWu4AlndGWNoylIE47Yi6EtH6IiHJUYgCC2mmVlDTwZeDexRSt1vf/bnwAYArfVHgG8CNwOP\nAVngtU2UR2ghqzpiPDZsNaUbs/sO9SQjdMTCTOfFNSQIrWTBikAp1Q2s11rvPt1xWusf0zgG4D5G\nA29cqAzC8mOwI8bwTIFKVTOWKZKKhoiGgnTEQ3MWlM3kSyQiIYKB0/4ZCYJwjszLNaSU+oFSqkMp\n1QM8AHxCKTWXu0cQ6hjsjFGpakZnC4xnivQkIwCWRdBAEZQqVS5713d511cfWmpRBaHtmG+MoNPO\n+HkJ8Amt9TVA43FVgtCAwQ5ryP3B0Uy9ImgQLDZxhP+65/DSCSkIbcp8FUHILvT6deDrTZRHWKFc\nuraTRCTI6z91L7sOT9BrFEE81DB9dGRGppoJwlIxX0XwbuA7wGNa618opbYA+5snlrDSGOyM8Y23\nPIVgQJEtVug+g0UgikAQlo55BYu11l8AvuB6fwB4abOEE
lYmm/uSfOq3r+UTPznEi660Csg74mGy\nxQqlSpWwayaBWxFUq5qABIwFoWnMN1j8PjtYHFZKfV8pNaqUelWzhRNWHpev6+KfX3YlN2zrA6wB\nNgAzvhRS92jL0Uy9dXBwNMPbv/AAhXKlidIKQnswX9fQs+1g8fOxqoG3A29vmlRC29ARt4xSf+aQ\n2yI47mthDfB/v/EwX9h1jLsfH2uugILQBsxXEYTt/94MfE5rPd4keYQ2oyNm/Wn5awncimAiW6w7\nLx6xFEgjJSEIwsKYryL4mlJqL7AT+L5Sqh/IN08soV3osF1D/oDx0FTOSTltlFUUDVl/uu5JZ27u\nOzLBvpON1wRB8DIvRaC1fgfwRGCn1roEZLBmCQjCOWEsAvdmXyhXeOjENE/c2mutNcgqMm0q9g7V\nb/bf2jPESz50N2/87H3NEFkQVhzzDRaHsfoGfV4p9UXgdwBxzgrnTF/KSiM9Ppl1PttzbIpCucqz\nLrLGXU5lG6SX2org8Himbs3EDTIF6WEkCPNhvq6hDwPXAB+yf662PxOEc6I3FWVzX5KfH6yFnX5x\nyBpg88QtvcTDwYYWweiMFTeYbdCwzsQXZkURCMK8mG/TuSdora9wvb9DKfVAMwQS2o/rt/Tw9d1D\nVKqaYEBxbCJLbzJCbypKZzxcFyOo2j2LADLFSl2dgbEWZvJlCuUK0VBwQfL88sgE337wJO+4aYcz\nY1kQVjLztQgqSqmt5o1dWSwJ3MKicM3GHmbyZQ6OWm6e8UyxVnkcD9VZBFO5EuWqZk2nFUzOFL2K\nYtRVg2B6Fi2EF3/obv7jrgNzdkUVhJXGfBXB24E77S6kPwTuAN7WPLGEdmJNl7WhD0/XBtf0JGot\nKPwbsskU2m6PuvS7gEZmCqzviQMwNrswRZB1KZUj49nTHCkIK4f5Zg19H9gGvMX+uVBrfWczBRPa\nh4G0pQiMS2ciW6Q7aWUTdcTrexF968GTREIBnnvJIOCNE2QKZbLFCjsGOwCvdTAf7js86bwWRSC0\nC6eNESilXjLH0lalFHPMIRaEBTHQEQVgeNratMczJa7ZaH3WEQvx2LD3if97j5ziadv7WWW7hmZc\nFoHZ+HcMprn94VMLtgiGpmoFaqIIhHbhTMHiF5xmTQOiCIRzJh0NEQ0FGJ7JU61qJrJFemyLoNNn\nEWitOTmV5wVXrCEdtf583WmiJmPIWARjDfoUAeSKFWLhQF0weNg+PxEJclQUgdAmnFYRaK1lhrDQ\ndJRSDHREGZkpMJMvU6lquk2MwB5ubzKKZgtlylVNdyJM0lYEbtfQhF1zsK47TjioGM/UB3yPTWS5\n4b138r6XXs6vP2G9Z21kpkA6FmJrf0osAqFtmG+w2EEpJYNphEVnIG3NNDadRs0Es819Saoap13E\nhL2xdycipGxF4HYNGesgFQuRjoWZLdQrAlON/JUHjtetnZrOM5CO0p+OOrUKgrDSWbAiANYuuhRC\n2xMNBbj78TH+4NNWWwijCJ6wqQeAXYetgjPTgK47ESEdq7cITCppKhoiFQ01LDgzcQD3/APD8EyB\ngXSM7kS4YbO7M7Hv5Az5kmRWC8uLs1EEv1x0KYS254lbrL5C++zU0FV2w7l13XEG0lGn2thRBMlI\nzTXUwCJIRIKWImhQXXzArlcoVap1a8MzeQY6onQnI0xmS2it5/07lCtVnvOBu3jev/5o3ucIwvnA\nghWB1vq3myGI0N686ZkX8I233ADA5es62WHXCCiluGxtJ/uHZwG3RRAmHAwQCwc8weLZgvU0noyE\nSMdCTDewCA7ZiuDEpLeBrtaa4ekCA+ko3YkIxUqVbHH+T/cZ+9jHRzIcn5T22MLyYV4tJpRS24C/\nBy4GYuZzrfWWJskltBlKKS5Z08l//c61XLm+y5PNM9AR44Fjk0AtRmBcR6loyBMjyBbKJCJBAgFF\nOhZquNkbpXJ8IudpT
zFbKFMoV+lLRelKWFlLE9miY3mcCbdC2n10krVd8YXcAkFoGfO1CD6B1WSu\nDDwD+BTwX80SSmhfnrKtn3Qs7PlsIB1lLFOkXKkykS0SULX21b3JqPOED1aMwGzcjVxDD52Y5thE\nju2rUhQrVU/B2aSdcdSdjNBlZy1NNMg6mgu3IjgmA3OEZcR8FUHcri5WWuvDWut3Ac9snliCUGOg\nI4rWMDpbZCJbpCsRcZ7in33JKu45MMYpuz3FbKFCMmI1mUvHwsz4qpK/cv9xwkHFa560GYAxVy8i\nowi64mHH4lhIwNitdMQ1JCwn5qsI8kqpALBfKfUmpdSLgYEmyiUIDqYFxfBMnrHZouO2AXjhFWuo\narhj7zBguYYciyBmWQTugO/jIxm2r0qzsTcBeEdkugPR3S7X0HzJFGrxBLEIhOXEfBXBW4EEVp+h\na4BXAb/ZJJkEwUN/utaCYv/wLFv6Us7a5r4kAVWbXTxb8LqGShVNoVzLDhrPFOlJRui0R2ROZhso\ngkTYcQ1NNhiKMzyT57kfuMtJaTWY1NU1nTGOTUgxmrB8mK8i2KS1ntVaH9Nav1Zr/VJgQzMFEwTD\ngK0IvvzL4zw2PMvFq9POWigYoD8d5aTtGsoWK06hWYddZzDjqTwu0p2oKYJpl0VgrIPOeISu+NwW\nwT/fvp+9J2f45p6Tns9NjGD7YPqsXENaa972Pw/wo/0jCz53Jl/iz7+8h/9314HTXl8QGjFfRfDO\neX4mCItOX8pSBN/YMwTAxWs6POuDHTEnRpCxs4bAcg2B13c/YSwC2/XjcQ3ZgeGuRJhQMEAqGqob\niqO15rsPWQogEvL+8zGKYEtfipl8ecGFZccmcnzpvmP89q2/WNB5AJ++5wif/dkRPvC9RxuuT2SK\nbH7nN/n8L44s+NrCyue0ikApdZNS6oPAWqXUv7p+bsXKIBKEphMJBXjbr2x33l+02qsIVnXEODll\ngsVlxyJIRa3N3lQXlytVpvNluhJhUpEQAVUfI0hHQ07FcToWqgs2j8wUnADzuK+zqalhWG26ojao\nYTgde45PARALL2yiGtS6rhYrVSrV+id/0zfpP344t8UgtC9nSpA+AdwLvBDY5fp8BvijZgklCH7e\n/KxtvOSadXznwZNs6El41gY7Y9xzwBpYny1WnBiBaUFhNvvJXK0GIRBQdMbDTOZqm/lUrkRXshaI\nbpR++ojd8whg3Oc2yhTKBBT0pa34wmyh7MQ35oOplTAWzUIwyq5U0ZycztfVMJiurAspkBPah9Na\nBFrrB7TWnwQuAP4HuEdr/Umt9W1a64klkVAQbNZ2xfntGzbXtY5e1RFjOl8mWyx76giMwjg0ZtUZ\nTNhP8iYQ3BkPM5Xzxg+64hHnvWUR1Na11nzlfqtR3cWrO5zrGUygOm1bIn5r4mcHxnjq++6cM37w\n0PFpwNq0/e0vSpUqf/bF3c44Tz+z7slqY/WBahND8Y/1FASYf4zgucD9wLcBlFJXKqW+2iyhBGEh\nGFfMwyem0RpnTsHqzhjJSJDH7EpiM7+4x6MIapv16GzBmZUMkIqFPVXLX33gBLfdd5xVHVE29yfr\n5iFnbLdUo2Z4ALfdd5wj41n+v+/sa/h7mEyjqq5lQRn2nZzh8/cenTN+MJsvO7UPjeYoDLuC6Y1c\nR0J7M19F8C7gWmASQGt9P7CpGQIJwkK5Yn0XAB//yUEALnT1KbpgIOUoggmnctgeepOIOIogX6qw\n7+QMF7kykvwxAlMb8F+/cx09iUi9a6hoBapNkHrG51Y6NWNtxl/+5XF+vH/Us1atak5M5Z3fxV+H\nYJTOnBZBocy2gRQBBUcbpK4ai6BS1ZyQYjfBx3wVQVlrPdVUSQThLNnSl2R9T9xJ57xsbaezdsFA\nmv3Dll/fbKbdbovA3swfOjFNqaK5an23c27a18Z6IlMkHg6yfVXa6U5adrlwMgUrdbX
mGvLFF4am\nufmyQfpSUT5/71HP2limSLFc5SJbifnnKJjJaUDDjXzWBMGjoYZB6pPTtfPHMjJnQfAyX0XwoFLq\nN4CgUmqbnUl0dxPlEoR5o5TixotWOe/d7p1tq1Kcmi4wlSvx8NAUqWjIaXG9tT/J4fEsB0cz/PKI\nFfK6emOXc64/RjCRLTkVx732d0y6XEsZO0bgpK26rImx2QKnpgtctb6b1Z0xzxrUZiRcMJCyr+UN\n6g7P1JrnNZrDbOITiUiIbIM4wKmpvCOzu3ZCEGD+iuDNwCVAAfgsMAX8YbOEEoSF8ke/sp1NvQle\nca23znGbvbE+NjzLvYcmuGpDF0G7T9Err9tIOBjgk3cf4pGhGQbSUaedBVjpp7lSxXnqn7T7HEGt\n++mI60ndbMapBnMSTMB660CSRCTotKw2mKd8owj8m/mw64neH4Q235WOhkhE668NVkbUhgZtNQQB\n5q8ILrZ/QlhtqG8BFl71IghNoiMW5vtvezp/9+JLPZ+bjfXf73yMvSdn2Lmxx1nrT0e5Yl0njwxN\nc2gsw5b+pOdcf0HaeLboKIC13VZ6pttNkylaweJIKEA0FPBYEyaPf0NPklS0/qn9uN0ue9uqtH2t\nuS2CaZ8i0Fpb9ROxEIlIkFwDRZApllnTGW94/rkyNJXj2v/7PfaenF7U6wpLx3wVwWeAjwMvAZ5v\n/7zgdCcopT6ulBpWSj04x3qnUuprSqkHlFIPKaVeuxDBBcFPMKDqUkvXdVtPwXfsHSag4MaLvb0S\n13TFOT6Z4+Bohs19XkWQ9rWomMyWnIZ36+w8fXcqaKZQIRkNOue6g8VHxnK2PHES0RBZv+tnOk8k\nFGB1RwylrOZ53vUC63vMRu5dK5StIrJUNEwiXK9ktNZkixUG7eyqxbYI7tg7zPBMQYrVljHzVQQj\nWuuvaa0P2m2oD2utD5/hnFux0k7n4o3Aw1rrK4CnA/+klIqc5nhBWDDGDQRw158+g0vWdHrW13bF\nOTaRYzxTrFcEUa8iGM8UnUBzXypKJBTwZPe4G96lY2FPoPnoRJbBjhixcJBkJFiXzz+WKdJrF7ol\nwvXunZHZAlv7LevG7+M38qWiQRLRYF3RmKk27klGiAQDdW0zzhWTjupPpxWWD/MbvQR/rZT6GPB9\nrDgBAFrr2+Y6QWt9l1Jq02muqYG0sh7hUsA40rZCaALvedGljM4UHOvAjXHxAGzq9VsEdouKQtlu\nT1FyAtGBgGJtV9zJ9y9VqhTLVZKRWudTty//yHjWeaJPREJ1wWDTAwmwLAafopjJl1ljWyH+rCDj\nujKuoWMT3msb6yMRCdIRDy26a8hMgTNtPoTlx3wVwWuBHUAYMPlyGphTEcyDfwO+itXGIg28TGtd\nP00cUEq9AXgDwIYN0vRUWBivvn7jnGvuVgz+HkbGDTSesbKOtMbJGgLLzWOKwMxm626B7d6wj45n\neeKWXvsYyyLQWjuurDGXIkhGgnWKIlMok46FSEfrN3LT7C4VDRMPh+piBNlSbY5zh6+IbjEwcZLH\nRmYplCtEQwtvkSG0lvkqgiu01pct8nc/B6ta+ZnAVuB2pdSPtNZ1ESet9UeBjwLs3LlTyiKFRWOd\nyyJY7+thZIbXHBjNOGs9rtTUdd1xbh+y/lxNi4eUHSPoSUac4GmmUGZoKu8EoxOREFpDvlQlbvcV\nmsgWne/zp4CWK1UK5SqJsFW17HftGMWQjAYdJePGxBsS0SAdsfCip4+a1NdKVXNqquBkJwnLh/nG\nCO5RSl28yN/9WuA2bfEYcBDL6hCEJWNDT5Inbe3ls6+7rm4tHQszkI5yYCTD9x8ZRim43n6qB8ua\nGJ0tkitWnKdyYxH0pSJOaumBESt11GQwGWXh3rDHZ2vxh2TUaxE4T/TRIB3x+vGbpq6gLxUlHqmP\nEZh4QyISpDPeWBHkSxWK5YYGOQ8en+Jfvrd/Ttf
Picm8MzPixJRULS9H5qsIbgDuV0rtU0rtVkrt\nUUrtPsfvPgI8C0AptQq4EJC0A2FJiYQCfPb11/OkC/oarm/pT3JgZJZv7hli58ZupxgNahlJxydz\njp/eKIL+dJRpeybB4yNWiwsT7E3YcQTjTiqWq8wUyk7Bl98iyDkbech6ovcpAtOCui8VJRkJUSxX\nPRXP5loJ2zXkzzp6fGSWy9/1Xa74m+9ypz3y083Hf3KQf/7eo7z9iw/UrWmtOTWd56oNXYDECZYr\n83UNnS77pyFKqc9hZQP1KaWOAX+NFWNAa/0R4D3ArUqpPYAC/kxrPTrH5QShJWzpT/HZn1nDXP7q\n+V6j2ASaj01kneyklEsRgLVJPz4ySzCg2GgHo5M+i8A9K9msH5+sPdXXrI0g6VjI6RtkGJkpEAoo\nuuJhp4V1tlShw56r4A4Wd8ZDdTGCh05MU6xUoQLvv/1RnrHDm2JrXFE/fXyMqVzJme4GVhO7clWz\nY7CD7zx0iiFRBMuSeSmCeaSKNjrnFWdYPwE8e6HXFYSlxN236KbLBj1rJr5wfDLnepq3NuKaIiha\nMYbuuDPRzFgEZoM3aZcei8BVR2BcPfGw5Rrad6o2EwEsRdCXihIIKCfmkCtW6LCznjIuiyDZYMaC\nCXjfdOkgu4/VtxSbyZcIBRTlquZH+0d4/uVrnDVzrVUdMTpiISdesBAePD7Fpr6ko0SFpWe+riFB\naEt+7Zp1vOjKNbzgijWs7vQOexlIxwgFFMcmcs50spQTI7AUwchMgbHZgqd1Rc0isM5xmuG5s4aK\njSyCEFv7kxybyPHYcE0ZjM4WnGE4Jn3VHScwrqVkNEjKdh255x0cm8jRk4ywpivOZIMZzTP5Mtdt\nsSqyj/haXJt4RSoWYnVnfMEWwQ8fHeH5H/wxb/9CvdtJWDpEEQjCaQgFA3zg5VfxwVdcVbcWDCir\nMnkiVxcsNhbByEyBiUzJaX3tPsY89Y/5LQJfHYEJFiciQV5x7QZi4QCfvLtmpI/MFui3FY+xCDKu\np34nWBwOOd/tXj82kWNdd5zuRJhMsT5oPFMoMZCOEQsH6obxmHhDOhZisLM2O3q+fOxHVljwWw+e\ndNqFC0uPKAJBOAcG0lFGZgqO+8VYBL1Jl0WQKdKTrI2sNE/tZoOeaGARlCqaQtlad9co9Kai7Bjs\n8MwlGJkpOIrHuKZyJbdFYMkWjwQbNsQ7NpFlXXfcaajntwpm8lYNQ08i4sx0cK8BdMSsrqt+txNA\noVyZs+p40nW9G9//Q4+CMlSrum5im7C4iCIQhHOgJxlhPFNkNl8mGFBE7ThAJBSgOxFmZDbPRLZI\nTwOLwLhVxjJFlIIuOwjrbNb2JmuUTNweat+XijhWRLWqGZ0tOq4of/zBOr9CJBggEgq4LAJLUWit\nOT6RY21X3Cmgc7fW1lo7iqArEamzCMzvkI6FSYTrG95NZUu86N/v5ob33sHP7LnSbqbzJZ5/+Wqu\n3dzjvPfzJ198gG1/8a26z4XFQxSBIJwDvakoY5kCB0czbOhJeJre9aWsGoRKVTs1AmA9PUOt+dtE\npkhnPEzIzvJJuVpbQM2FZDZxS/lYKaPT+ZLTRwhqSsRTh1AoOy4jE58wg29mC2UK5SoD6Zgjo3uz\nN6Mt07Gw9b0NrAWwXEOJSLDuif5ru0/wyNA0iUiQf/ruo3X3b8YeqPPK6zY43+fntvuO27+TdKBp\nFqIIBOEc6LUtgoeHptkxmPas9aejPGpn+LgrkkPBAB2xkOMWGXe1l4D6rqfuGAFYymc8U0Rr7bhq\nzPm1c2tP1tlixTm35hqyrmlk6EyEnbRQt0Xg3ujNVDY3TrA4GiIeCXlcUgCPnpohHQ3x0mvW8cuj\nE57Yh9aa6VzJsiZ8tRUGt4XRaASnsDiIIhCEc6A3FaGq4fBYlh2D3l5F/ekoo3bVr3ujB+hKRBxf\n/HimSE/iNIq
gUPG4nXqTEUoVzXS+XKtBSPgVQW3Ddef++4PFZmPvioedGIU7RuB2/XQnwnW+/pl8\nGaWsuEfCjm24/fmPnpph26oUN1zQR6mi+cWhCWctX6pSrmo6YmGSpv7B1x5j97FJ5/WRsbNTBEfH\ns9y5r75QTqghikAQzgH3Bn+h3yJIRRseB1ZDu4m5LAJn5rG1nimWSYSDjtvJHDueKTqbtvHvJyMh\nlPJaBG5F4A8WT+ZqgWoTo3A/9buzgroTEaZy3jnNM3lrGE8goGrFbK6n+EdPzbJ9VZrL13VZ70/O\nuK5tfU9HPOS4rvyuoaOuNt/u1/OlVKnylPfdyWs/IXO0TocoAkE4B/pcm/3l67yzDvrStTV3jABs\ni8B2wbgnn0Htqb4WI6iQiNY6evba3zk2a6Wmuq8fCChS0ZCnjcRUrjZQJ+WzCCZcFkEiEiQSDHgy\ng4xC6YiFnM6rUz7XkSlccxezgVXfMJ4psn1V2pnt4M4qclsbTkqtP9js+q6j4wu3CNwtM+bqpSSI\nIhCEc8K9ga/p8hacuS2CgY6oZ60rHmYya/v5fRZByufemc6XPFW3pt5gLFOscw2BNbbT7RqazJbo\nipv2FV5FMOVYFBGUUqRiIU9QtmYR1FxHEz7XkVFcCZ97xzz9b1+VdiwGt+tnKldLPTXn+junmgZ5\n2wZSnmlw88WdZtto1rNgIYpAEM4BsykPpKN1a+YpfG1XvK5Hf3cizGS2xImpPOWqdsZIQr1FcGQ8\n6xmqY5TG2KylCIIB5ZxjzndvepO5Ip22LJFQgEgw4ASLzdO/cR35M3/MRtwRCztP/m4lM+1SBPGw\n96neBMq3D9aa7c26gsGe1FMnWFzfYjsdtdxSc23kpUqV9393X8M5C+5KaP9AH0O1qtG6vbvbiyIQ\nhHOgPx3l7c+5kC/+3pPq1rbbg+j/8nkX1a11JiJM50vsOmwFT69c3+WsRUOWi2Y6X0JrzeGxLJt6\n6xXBZK7IRLZEVzxMwDWS01IE1qaXL1XIl6qeRnFWm+tasDgVDTl9kJKRkOep3GyuXYlwwxiA6XME\n9cVs+07N0pUIO5ZRKuq1CIy10RkPeZrluZnKleiIh+csVgOrKvlf73iM9393X93afBTBlj//Jm//\n4rk2U17eiCIQhHNAKcUbn3FBw2Es63sSPP53N3PTZavr1roTYbSGH+4bIRoK1GUcmc18PFNktlBm\ng2uMZiwcJBoKMJktMZktOpZH7dwwM3adgHsjNySjNffPZK5YpyTcG/1UrkQsHLBmLTdoTzE8U3Cs\nIb+i2H9qhu0DaSfIbY3obBwjiIYCBFR9+uh0rkxHPGzLXF9jYB1jXafQIAZwbCJHXyri+T43JsPp\ni7uONbx2uyCKQBCaSND1pO7GbMxfuu8Yl67tdJ7IDalYiNl8mcP2E+1G3/S0bjv9dCJTqgtEu8dk\n1tJD3cHo2kyDyWx9HyT3Zj2ZLTrn+jf6XLHCTL7MgD2joRYsts4/OpFlU19N7pSv86lpb52OhVBK\nkYyE6oLF07kSnfFQ3ehPN3nbivDfw0pVc2wi64wgbVS1bIb6tDuiCAShBbgnnb3q+vo53MbPb3Ln\nN/osji47xjCRLTpBXO+5RhF400sBBjuizsD5sUzRo0gSvglnk9kGNQj2Rj88Y12jZhF4YwTTubLH\n2kj4rI3JXJFIKOC0zoj7gslgbd4dsTCpaH3VssG02/Cr3KGpHKWK5lK7lbh/II/7d2h3RBEIQgtY\n3Rnn46/Zye8+dQsvunJt3Xo6Gma2UHYyZdZ2ezOSOuNhJnO2ImjkGrLjCyZF1b0hr+2OOyMlhyZz\nrHYFqv0xgslcyQk0OxaB7aIZtkdxGovAbTGUKlVypdpMBKBuFsL4rFVIZ1xHyWi9ReDECKJhcqWK\np4bBMDxdcI51Y0aEXmHXMDSyKMy57Y4oAkFoEc/csYp33nyRpz+RoTsZZmy2y
KnpvJ1e6R3aYlkE\nVrDY7xpKx0J299KqYxF4FEFXwokvjMwWPGmviWjQ46efsoPR4GpoZysK03J6VYe3BbZxGRlZDMmI\n99p+ayYebmAR2MVw/hkObsxTvb8z6gF7ROgV6y2LoFGMwCizSLDxVvjj/aNsesc3eGRouuH6SkEU\ngSCch6zvTnBsIseJybwntdTQFY8wNJWnWK467aMNZq7y0FSeo+M5ggHlucaaLuv1fUcm0BrWuAbu\n1FsEtWB0MKCIhQPOU7t5mjZDdxLhmkXgpJ3G544/jGWKTvqtte51HZUrVTL2pDV/Sq0bo5Am/RbB\naIZUNMRgR4xEJNjQIhixFUE01HgrvP3hkwDc/Xh959SVhCgCQTgPWd+ToFipsvvYpLOxu+lK1IrG\n/K6hzXaA9uDoLIfGMqzrjhN2PfGaEZum78/qrtr1E5GQ1QPIdsFMZkseRZN0Zf6Mzlqzks33h4JW\njUK2VHZZBDXZUlFLyZic/YmMzyKIhDxP/JOOMmk8UAesQPFxu/WEf47CwdEMW/qTKKXqaivAshA+\n8zNrwE+uVGlYSxC1lVujqmTT+G8lIIpAEM5D1ttZQsMzBQYbKIJO1+bvDxZvslNND45mOTSWYaMr\n9RQs1xDAzw+OA96KaOOCyZYq5EsVCmV/DULI1Z6i6FQkG4xradrVmsJZi4SoaqvZHJhme65rR4Ke\ngrL7j0wCcOGqdMOBOgBf/uVxMsUKFwyk6jqjHhjJsKXP+t3TvmprgM//4qjjGipXtSOXG2MpmCFB\nhkOjGa5+z+382kd+2nDOwmLxqZ8e4vBY5swHniOiCAThPGSDK110LteQwR8j6ElGSMdCHBrNcHg0\ny2ZfxtGqjihdibBTzOZ2DbnbQTeqQUi45ilbqatea6QzHmYqV/LUCBhSziyEMqVKlel82TO5Le7L\nWPrR/hFi4QDXbOquG9Zj+NRPD3PJmg5uvmy1M5sBrDjF8ckcW/pT9nfXF6SdtOcr/97TtgKNYwjm\nen5LxATx7z08wZ9/eU/deYvBeKbIX33lIV7/qXubcn03oggE4TxkrespvZFryL0B+zdjpRSb+5Ls\nOjzBTKFcZxEopbh6QzcAm/uSTpAXcAVly06OvbtFdtI1T9nKWPI107OzmUyNQEfcaxGA1YvI9Cvy\nTG6LeGc1/+ixUa7b3Es0FHT6L7k35EeGpnlkaJqXPWG9U6BnYhOmx9CWfut3j4eDTr2BYSxTZH1P\nnItWWxXgjdJLjRXhrzcw8l+zsbsuSL1YnLCVTbnSfPeTKAJBOA+JhAK8+vqNbOhJcM3G7rr1K1wt\nKfzBYrA2+IftTBd3UZfBWBxP2dbn+dxtEYzMWm4TdxdVqxdRbaiNv6q5MxFhKlt0XENpX/ooWBbB\nuG9Os7m2sQiOT+Y4MJJx5DNznmdciuC7D51CKXj+5WuI2b78vO3COTBqZQxt6UvVXdswOlugNxlt\nOMzHYD4z98Jg3FCb+5JM5UpNiRUcs2Mfvan6/7+LTejMhwiC0Are86JL51xz+/X9mzHU4gRAnUUA\n8PJr1/P13UP81pM2eT53WwSjtv/c3Wo7GQk5WToT2aKnRxJYFsHhsQzT9sCadNRtEdTSS0v2U663\nmC1EoVylUtX8eP8IAE/Z1g+4GvG5ntofHppic2+SnmTESf80QV1TQ7DZjhHEIsG66Wljs0XWdMUc\nZdUoq2gui8C4zTb2JKhUNZlixdMhdjEw7qfeZH1Dw8VGLAJBWKa85ZkXsLE34ckIMpgNMKCsVFQ/\nOwY7uPcvb2Sr7UM3mCfvTKHMqP0U3O+2CKKWRaC1tiyCpFcJmYrnmXyJVCTkaYYXdzWlMy6gVANF\nkS2Wue/wJD3JCNtXperONTwyNOO0j4iGvYrg2ESWgXTUOS8eDpJ3WQQPn5jm8FiGvpTbIphbEYzW\nWQRF4uGg0168UefTc+WYPZozFm7+Ni2KQ
BCWKX/87Av54duf0XBtk60I1nbH63rwnA6TSnpoLMvI\nTIFYOOCMkYRanUG2WKFYqTaMEUznS0xlS55iMsBpJZErVpyAc8J1bTN8J1u0so56krWMpEgwgFK1\nvkKzhTJHxrOOf99YBKbx3KSv0C7hsggqVc3N//ojMsUKvamII5c/Mwjw9GRyM2G7xUxG1dRZxgkO\njMyy6/B4wzWTFtsom2mxEdeQIKxANtvuoE0N3EKnYyAdYyAd5aHjU1S1pi8V9aSH9qWiTGZLnLTd\nQ3VZQ4kIWlvtn/2xC/dTvdnQ3YFqd4uKbLHiURJKKWKhWsB3vz3r4EK7a6tRdkW7/sE9nhNM1bJp\njVHrL9SbjDqzIhp1LzUWQbFiuaxME0HTg8kUzJmRnwvlmf/0QwAO/cPz6tbMPc43UFCLjVgEgrAC\n6UyE2dCTcBquLYTL1nay5/gUo7NFT3wAas3vHjw+BdQHqk07igOjmbogp9siyDoWQX1WUaZQJles\nOMcbYuGA83RsUj9NdpWjCMo1ReCuao6FgxTKVapV7TTcA0spObUCvhjCO2/b7ZmK5rYYpuyKa5PG\nO+1zDX3+F0d47gfu4l1ffYhfef8PGza3O1OA2QTU/dlOzUAsAkFYoXztzTeclX/5krWd3LlvmFyp\n4vjgDabQ7T67BqHHV8xmAtfjmfrU0oQnRtDANeReL5U9oz7BmwJaa3hnHVMr/LIUwUze2/nUWB75\ncsVJywS4+bLVTnzBbRForfncz496vj9fqmJ+pclsia39Kaewzx8j+D//+xDFSpW99rjOAyMZpxWH\nwWQFzcWErQhyS+AaEotAEFYonfFw3YjM+XDDBX1UtbVRrfXNYTYWgem94193ZzD5lUTM1YsoV6wQ\nUN4eP26LwHINherOz9ub9fBMnmBAOTUOkaC3FYTfNeTOWBqyO6/uftez2dyXrIsvgNWnyWBiJO4n\n8wl7joMTI/ApAn82VSO30wPHJp3XfuugUK7FUfyWSjMQRSAIgoedrrqFX9+53rPWm4yQjATZPzxL\nOKjqit3cqY69PkUQDdUCvplimUQk5G1PEfEqCre1AFbfn5yr4V1fKuJkJbmzhsqVKrMFr0XgVkIn\nJvOkoiGnRXYoGCAUUB7Xj5m3/NYbt/GOm61Ro0YRlCtVJrJFeux7EQyoumDyTKFMLBxwUoAbuXce\nta0FqFcU7usthWtIFIEgCB4CAcW7b7mE33riRi5e43UNKaWcuoQ1XfG6CWzuuQn+HkhKKRL2Zt5o\no0+6Btv4g8VgxQjMZm2NyKwpIaeOoFLxzEI2mHhDvmS5htZ0eRVYNBSg4HLBPDZsFaT91hM30W/H\nOkx84vhkjkpVs7HHamhn2mq4GZ0tcMsVa3miPYCokUXw+Gith1DOV+xmKpfT0ZBkDQmC0Bp+84mb\n5lzbuambh4emGzbDc9c0+C0CsPsJlRpv9HFXHUGuWCHudw2FvDGCNa4eTO5g8VSDFtju+MPobKEu\nCB61g8lguWm+/eBJVnfG6E5GnA6kJnvnkD01zqToxsPeYrVqVTOeKdKXjswZiIZa0RtYTf7c9eMm\nULy6K8apJRieIxaBIAgL4klbradcs1n5MZuu3yIA289vP/H7N3pT1TyTL1OsVBtaBObpeGQm7wSK\nwasIphtMZXNnLE37AslgWwT2Rv/jx0a59/AEb3nWNut7Q94YwSH7SX6THS+Jhb3WxES2SKVqpd7W\nWl94n+qPjGV5ZGjaGfNpLIJ8qcJzP3AXn77Hao+9ujMuriFBEM4/zLzlG3x9igwmhbRR6wtT2JUt\nlus3envDNVW8jSyGXKlCtaoZy3hTWyOurKGpBoogZqyNUsWZeubGUgTWZn1k3Hrif+aOAetck1Vk\nb/aHxjIkIkGn4jrma2g3arej6EtFaxlJrnWtNU/9xzsBnPReowgOjGTYe3KGb+6xBuKs6Yo5aa/N\nRBSBI
AgLoisR4SfveCbvvOmihuvmSXp1Z7xuzRR2NXINBQKKRCTobKTxBorCBJq19o7BdGf+NFIE\nToygWKmrMQCIhoLORu90TrWDybGw1yI4MpZlQ0/CCXTHw0FP0ZdRZH2paF1aK9QKxQBuuXINgNNy\nwz97YLAjXnd+M5AYgSAIC8afNurm5ddu4OXXbmi4Zp7qc8WKM+vYTTIaYtjeKBtlDeVLVaf7aSpa\n28zdTefM03WiQR+jyVyJQrnqGZhjXbvmGprKlYgEA44l4O9sOpopevovxXyzlk0R2mBnzGmN4bYI\nTGzgs6+7zmOpABy0FcGOwTQjMwWnjXe+VKlTjIuJKAJBEJaMeNh64s+WynV1AmAFmE2hVTzsryMI\nUChVnAEzJqYAljURCQYoVqoU7DYT7oH0xiIwnVNP5xqazpfoiNdSW41CMPGJqWyRja7BQbFwgPFM\n7Yl979AMsXDAsRrc1959bJJXfuxnAGwdSDnZQUZ5HR7N0peK8q0/fAr5UpX/vf+49d1NbjPRNNeQ\nUurjSqlhpdSDpznm6Uqp+5VSDymlftgsWQRBOD8wFkG2UO8aAqv3vnmirg8WWy4YM5zG3/Y5EgpQ\nLFedp293sz3z5G0UQUPX0FztKXzB4gnfHIaozzW09+Q0F65KO6m17hjCrXcfco4bSEdJhGsps2BZ\nBJv7LAUSjwTrlFCzaGaM4FbguXMtKqW6gA8BL9RaXwL8WhNlEQThPCAeDjm9hhopAvfoykbB5FJF\nO0Ptk3MoAtN4zl21nIyECAcVjw9brpd6RVBzDU3nSk58ANwxAqvx3HS+5OmxFA/X4gtaax4Zmva0\n5nBbBKYAb8dg2mqkF7FkNOmnw9N5Bl2xFXe2UzNpmiLQWt8FNO6vavEbwG1a6yP28cPNkkUQhPOD\neCRg1QmU6tNHwVt74PeJx+1NczxjBWPrLIKgrQjK9a6hYECxtT/Fzw9ZW5J7owc7RuAEi71ZRUah\n5O2MI629XVdj4YCzkY/MFpjIlrhwMO06v2YR5IpWk7v/feOTgVpbjZwdY5gtlD1BcH8NQ7NoZdbQ\ndqBbKfUDpdQupdRvznWgUuoNSql7lVL3joyMLKGIgiAsJolIyJnx629hDdDn6liabNBrCGB0xvKr\nN3INFcoViuUq4aDyDMUBPE/p7qpj8LqGpvNlj8UQCCgioQD5csXx6btdQ+5CtyO+YjNL7ppFkC9V\n6IyHnd8l7mp9AVYNhXuqm98t1SxaqQhCwDXA84DnAP9HKbW90YFa649qrXdqrXf29/cvpYyCICwi\n7o3enXljcLuG/G2s/XUGDV1DFcsiiDSY2rbD9ZR+OteQ1bDOX9VsWQzGLeV2DZkYgNbaqUHY4Aom\neyyCktclFgxYwWRrfGeVQrnq+b2MVVRocoyglVlDx4BRrXUGyCil7gKuAB5toUyCIDQRd5M6f5tp\n8G7+6QbuG6gVbPktgqgJFperDaeyXbWh1sShzjVk+/G11nUxAqht9pPGIvC1uK5qa3jNkfEsSnnT\na90xgmyx4lgDhkTEqq1oFAQ3wWL/vOXFppWK4CvAvymlQkAEuA745xbKIwhCk/EoggYWgdlgTbtr\nN2YDHcsUCAZU3ayFiL3hFudQBNdu7uH9v34FQ1P5us04agd8s8UK5aquSy81imAiY9xaNYVViyFU\nOTqeY7Aj5rl+zNWLqFE9QMIe/znbSBEskWuoaYpAKfU54OlAn1LqGPDXQBhAa/0RrfUjSqlvA7uB\nKvAxrfWcqaaCICx/Bs+gCMzgmz94+ta6NRMzODmVJxkJelpYgytYXGmsCABecvW6hp8b11CjhnVQ\n63NkXEPdPtcQWEVjR8ezzu/gvra7XsCfDZWOhZjJuxRBzG0R1DKWmknTFIHW+hXzOOYfgX9slgyC\nIJxfuC0Cv2sHrNbWe9/z3LondoDupLU5H5vI1VUGg2URzBbKc8YITkc
0FKCqa430GloEZcs1FFDe\n9hbuzfrkdJ6rNnR5r+0KFudKlboeTB2xMDP5ErP5eovASR9dwcFiQRDaDLdbxP9Eb2ikBKA28Wy2\nUK4LFENtpkChXF3wZDZz/Ig9AtMfI0hEgmQKZWdovTsjye3Ht4LBc7fPzjWIEfgtAvfvFg3XUleb\niSgCQRCWBW53TKqBRRANBa2sodO4hubCbLhmyHyHL2vIemovM5Gtn8XsHnqTL1ac9+5ruy2CRq6h\n6XzJUQSeOoJQfa+iZiC9hgRBWFLe+9LL5nzqPx2xcNAZAuPfjMFuX1GsUChVFqwIjDvm6LjV3sLv\nGuqIh5nOlSyLIFHvNgJLEViFct7vtjqb1tJH/YqiI24pmUwDi0ApZVkU0n1UEISVxMue0Lgz6Xzo\nSVq9iNZ113c/TUVDzORLFCvRhvGH0zFoTzvbZ88q9ruGOmJhpvNlJnNFz4hMqLmGZgtlylXd0CLI\nu9NH5wgWzzSIEZjrL9sWE4IgCItNwN6x1nfXp5emYyFmC2UKpYUHi9fY/X322QPl0z7XU0fcuvbY\nbLEu2GviCybQ7Ld24uGg0/qiWK46jeZqcoepVLUTn6hXBEGJEQiCIBjMLIK5LIKqtucJLNA1ZCyC\nI+NZUtEQIZ8iMRbC0FSerrgvRmA/4ZsUUX+dgEl7ra17r22UzompPPFw0OlaarAylpZv91FBEIRF\nxfjR/bn6UAsgj2UKns6j8yEWDjq9j/zxAfDWFfh7JJneQM5kNX/lcNTbGqMuRmArmROTuYZBcLEI\nBEEQXJjsm0YWgWlJkS8tPGsIcNo/+91CgKduoSvptQjM95pZB/6N3lgEtRGcfteQ9f7gaMYzh9lg\nFbOJIhAEQQDgQ6+8midu6W341O7u2nk2imC17R46k0XQ1aDqOBxUDE9bT/z+YLBJFx2bwyIwimQ8\nU/Q05XOuHxKLQBAEweHmy1bzuTdc37AYLeUZZr/w9NSr7Yrgx0dm69bcWUQ9PotAKUVHLMypmTks\nAltBjc02jhG4rY1GjfjikeCynlAmCIKwZKTO0SJ47ZM3A/DMHQN1a+4Cs8vXdTZYDzsWQV2MwLYI\nTLGav/LY3XOpUf8l9+CbZiF1BIIgrAjOVREkoyEe/JvnNAw0u91F/vbY1mchDo5aYzDrsoZsuYbn\nSA/tSkScMZsNYwRL4BoSRSAIworA35rhbJirEC0ZCdGXivCGp25puO52Hc3lGjIWQ6M+SZ3xMCMz\nhYYWQTTcfNeQKAJBEFYEnmZtZ6kI5iIQUNz7l78y57rbdeQvKEv6XEPJaH38ostWBI0sgng42PRe\nQxIjEARhRRB2FYGdjWvoXPBYBA0Gz8DcriGALf1J+9x6uSVGIAiCcBYstMXEueKZT+BTQpGQlV46\nky8TUPWuI4B/eMnlXL7uCFet765bS8VClKuaXLF+utliIRaBIAgrhg+98mp2DKa5rEFmTzNxWwT+\n9hRQswqSkVDD1NfuZIQ3PuMCz5wDg2lyZ1xLzUAsAkEQVgw3X7aamy9bveTf6x9t6ScZCTKVKzUM\nFJ8JM97z5FSejb3Js5LvTIhFIAiCcI5curbjtOsJWwE0ChSfiVUdVgD5lB1jaAZiEQiCIJwj12zs\n4XOvv56R2cabtckcWuicBIAB2yIYnhbXkCAIwnnNE7f2zrnWZU9UOxvXUEcsRCwccJraNQNxDQmC\nIDSZrf0pgLp5xfNBKcVAOsZdj45ydDy72KIBoggEQRCazrZVliKYypXO6vy1XXH2nZrh0z87vJhi\nOYhrSBAEoclsG7AUwYnJs3PvvPuWS3h4aLphncFiIIpAEAShyVxgK4JKVZ/V+dtWpdm2Kr2YInkQ\nRSAIgtBkuhIR3nnTDp66vb/VojREFIEgCMIS8LtP29pqEeZEgsWCIAhtjigCQRCENkcUgSAIQpsj\nikAQBKHNEUUgCILQ5ogiEARBaHN
EEQiCILQ5oggEQRDaHKX12ZU8twql1Ahwtp2X+oDRRRRnMTlf\nZRO5FobItTBEroVztrJt1Fo3LG1edorgXFBK3au13tlqORpxvsomci0MkWthiFwLpxmyiWtIEASh\nzRFFIAiC0Oa0myL4aKsFOA3nq2wi18IQuRaGyLVwFl22tooRCIIgCPW0m0UgCIIg+BBFIAiC0Oa0\njSJQSj1XKbVPKfWYUuodLZblkFJqj1LqfqXUvfZnPUqp25VS++3/Nmc4qVeOjyulhpVSD7o+m1MO\npdQ77fu3Tyn1nCWW611KqeP2PbtfKXVzC+Rar5S6Uyn1iFLqIaXUH9qft/SenUault4zpVRMKfVz\npdQDtlx/Y39+PvyNzSXb+fB3FlRK/VIp9XX7ffPvl9Z6xf8AQeBxYAsQAR4ALm6hPIeAPt9n7wPe\nYb9+B/DeJZDjqcDVwINnkgO42L5vUWCzfT+DSyjXu4A/aXDsUsq1Grjafp0GHrW/v6X37DRytfSe\nAQpI2a/DwM+A61t9v84g2/nwd/bHwGeBr9vvm36/2sUiuBZ4TGt9QGtdBP4buKXFMvm5Bfik/fqT\nwIua/YVa67uA8XnKcQvw31rrgtb6IPAY1n1dKrnmYinlGtJa32e/ngEeAdbS4nt2GrnmYqnk0lrr\nWftt2P7RnB9/Y3PJNhdLIptSah3wPOBjvu9u6v1qF0WwFjjqen+M0/9DaTYa+K5SapdS6g32Z6u0\n1kNg/cMGBlok21xynA/38E1Kqd2268iYxy2RSym1CbgK60nyvLlnPrmgxffMdnPcDwwDt2utz5v7\nNYds0Np79gHgT4Gq67Om3692UQSqwWetzJt9stb6auAm4I1Kqae2UJb50up7+GFgK3AlMAT8k/35\nksullEoBXwLeqrWePt2hDT5rmmwN5Gr5PdNaV7TWVwLrgGuVUpee5vAlvV9zyNaye6aUej4wrLXe\nNd9TGnx2VjK1iyI4Bqx3vV8HnGiRLGitT9j/HQa+jGXOnVJKrQaw/zvcIvHmkqOl91Brfcr+h1sF\n/h81E3hJ5VJKhbE2289orW+zP275PWsk1/lyz2xZJoEfAM/lPLhfc8nW4nv2ZOCFSqlDWO7rZyql\nPs0S3K92UQS/ALYppTYrpSLAy4GvtkIQpVRSKZU2r4FnAw/a8vyWfdhvAV9phXynkeOrwMuVUlGl\n1GZgG/DzpRLK/EOweTHWPVtSuZRSCvhP4BGt9ftdSy29Z3PJ1ep7ppTqV0p12a/jwI3AXs6Dv7G5\nZGvlPdNav1NrvU5rvQlrj7pDa/0qluJ+NSPqfT7+ADdjZVM8DvxFC+XYghXpfwB4yMgC9ALfB/bb\n/+1ZAlk+h2X+lrCeLn7ndHIAf2Hfv33ATUss138Be4Dd9j+A1S2Q6wYs03s3cL/9c3Or79lp5Grp\nPQMuB35pf/+DwF+d6W99Cf9fziVby//O7O96OrWsoabfL2kxIQiC0Oa0i2tIEARBmANRBIIgCG2O\nKAJBEIQ2RxSBIAhCmyOKQBAEoc0RRSC0HUqpK91dJRdw3ruVUjfar9+qlEosokwvUkpd3Oi7BKHZ\nSPqo0HYopV4D7NRav+kcrnHIvsboAs4Jaq0rc6zdipU3/sWzlUkQzhaxCIRliVJqk1Jqr1LqY0qp\nB5VSn1FK3aiU+ondt/1au4r740qpX9j93W+xK8vfDbzM7jf/MvvYu+1j7lZKXTjHd96qlPpVpdRb\ngDXAnUqpO+21ZyulfqqUuk8p9QW774+ZPfFXSqkfA7+mlHq9Lc8DSqkvKaUSSqknAS8E/tGWaav5\nLvsaz7Jl22P/PlHXtf/G/s49Sqkd9udPU7V++r80leyCMCfNrI6TH/lp1g+wCSgDl2E90OwCPo7V\niOsW4H+BvwNeZR/fhVVZngReA/yb61odQMh+fSPwpTm+81bgV+3Xh7BnSgB9wF1A0n7/Z9QqVQ8B\
nf+q6Rq/r9d8Cb/Zf2/0eiGF1mNxuf/4prKZy5trm/D8APma//hpWY0OAlPnd5Ed+5voJnYXuEITz\nhYNa6z0ASqmHgO9rrbVSag+WoliH1cTrT+zjY8CGBtfpBD6plNqG1aohvEA5rscaEvITq+0PEeCn\nrvXPu15fqpT6WyzFlAK+c4ZrX4j1ez5qv/8k8EasdsUApvHdLuAl9uufAO9XSn0GuE1rfWyBv4/Q\nZogiEJYzBdfrqut9FetvuwK8VGu9z32SUuo633XeA9yptX6xsvr5/8A+7hNYvf1PaK1PF1xWWP3s\nXzHHesb1+lbgRVrrB+xYxdNPc11z7dNhfucK9r9nrfU/KKW+gdVv6B6l1I1a671nuI7QxkiMQFjJ\nfAd4s92dE6XUVfbnM1gjHQ2dwHH79WvMh1rr12qtr5xDCbivcQ/wZKXUBfb3JJRS2+eQKQ0MKatt\n9CvnuJ6bvcAmc23g1cAP57g29vdv1Vrv0Vq/F7gX2HG64wVBFIGwknkPlptnt1LqQfs9wJ3AxSZY\njDUT9u+VUj/Bmm89Hz4KfEspdafWegRLgXxOKbUbSzHMtfn+H6zpYbdjbfKG/wbebgd3t5oPtdZ5\n4LXAF2yXVxX4yBlke6sdQH8AyAHfmufvJLQpkj4qCILQ5ohFIAiC0OaIIhAEQWhzRBEIgiC0OaII\nBEEQ2hxRBIIgCG2OKAJBEIQ2RxSBIAhCm/P/A12fv5T4wUoOAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(meta_losses)\n", "plt.xlabel(\"meta-iterations\")\n", "plt.ylabel(\"meta-loss\")" ] }, { "cell_type": "markdown", "metadata": { "id": "4b_tUUU-LpIG" }, "source": [ "Our meta-loss is going down which is great! There is a periodic behavior to the loss as we are averaging over different positions in inner-training.\n", "For example, if we are averaging more samples from earlier in training, we will have higher loss.\n", "\n", "\n", "We can now apply our optimizer for 100 steps. We can see that the resulting optimizer optimizes for ~50 steps, and then diverages. This is an indication that meta-training could have been more successful. One can improve this by meta-training for longer, or with different hparams to improve this!" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "colab": { "height": 269 }, "executionInfo": { "elapsed": 713, "status": "ok", "timestamp": 1647716700957, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "np2hiv0m4S1F", "outputId": "871d47a3-ec57-4acd-cdcd-f1a5eeacdfeb" }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90\nbGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsT\nAAALEwEAmpwYAACGXklEQVR4nOydd5gcV5W331udc5jpyVkTNcrZcpBzwgYnjE1YbJNM2GUJu7Cw\nfLCwYFhgl7RkzAI2wThnOcuWrZzDaHJO3dM5p6rvjxpJli3ZMrYsW9T7PPPMTNftqlvV3b8+de4J\nQlEUNDQ0NDTe/kgnewIaGhoaGm8MmqBraGhonCJogq6hoaFxiqAJuoaGhsYpgiboGhoaGqcImqBr\naGhonCK8qqALIWqFEE8LIbqEEPuEEJ8+ypizhRBRIcTO2Z//d2Kmq6GhoaFxLPTHMaYAfE5RlO1C\nCAewTQjxuKIo+18y7jlFUS5746eooaGhoXE8vKqFrijKpKIo22f/jgNdQPWJnpiGhoaGxmvjeCz0\nQwghGoDFwKajbD5NCLELmAA+ryjKvqM8/6PARwFsNtvS9vb21zxhDQ0Njb9ntm3bNqMoiu9o28Tx\npv4LIezAOuCbiqLc/ZJtTkBWFCUhhLgU+KGiKC2vtL9ly5YpW7duPa5ja2hoaGioCCG2KYqy7Gjb\njivKRQhhAO4Cbn+pmAMoihJTFCUx+/fDgEEIUfo65qyhoaGh8Ro5nigXAfwG6FIU5b+PMaZidhxC\niBWz+w2+kRPV0NDQ0HhljseHfjrwAWCPEGLn7GNfAuoAFEX5OXAN8HEhRAFIA9cpWhlHDQ0NjTeV\nVxV0RVHWA+JVxvwE+MkbNSkNDQ0NjdeOlimqoaGhcYqgCbqGhobGKYIm6BoaGhqnCJqga2hoaJwi\naIKuoaGhcYLJF2XejMA/TdA1NDQ0TiDpXJEV33yCO7aOnvBjaYKuoaGhcQLpDyQIp/I8sGvyhB9L\nE3QNDQ2NE0h/IAHA5sEQqVzhhB5LE3QNDQ2NE0i/XxX0XFFmQ/+JrYiiCbqGhobGCaQvkKDabcFq\n1PFMd+CEHus11UN/K/CXex7lqZkE9kwaZzFFWTFNhcWIw2nH4nJTUlNNRU0trlIfs/XCNDQ0NE4a\n/f4k7RUOhIBnevwoinLCtOltJ+hbw3u4r/mClz1ultPY5SSLR5/mrN51uEWEaLiKwEwLhaKJ1tZW\nzj77bMrLy0/CrDU0NE4Vtg6FuPG3W3jq82fjc5hecWyhKDM4k2RNm49ar5UnuvwMziRp8tlPyNze\ndoL+7gYHi3v+g4jRR9RYQtTkIm7SkTDomTR4ecJ+Ls/Y13BafgPnlDzKkoY7SU+2Eoju49777qGy\nsob5895LQ8NizYLX0NB4zWwfCRPPFuidjr+qoI+F0+SKMs0+O6fNKQHgme6AJugHWXHuzaw49+aX\nPR5L+tnc/SCbe9azXWlho+8MnhVn4REJltY8z2niOdroQrCFA92P8Ojad9HUuIYVK1bg8XhOwplo\naGi8HRkNpQEYj6RfdWzf7ILonDIbtV4rTT4bz/QEuOmMxhMyt7edoB8Lp62M85fcxPlL1P9jhSK/\n3P4CT/Yc4Nmyc3jCcBGV6QIrJntor7iThsZH6eoKsW3Ldj7wD++ntq720L6y/f2gKJiam0/S2Who\naLxVGQmlAJiIZF517MGQxTmzFvnZrWXcvmmYTL6I2aB7w+d2yka5OPU6Pr/iTB65/kb2uffzqfE7\nKEnkeKCxg+9Y/x8fN97KLxdcQ2+DiVtv/S17duwGQE6nGbnpQ4x86MMoudxJPgsNDY23GqNhVdAn\no8dnoZfajbitRgDWtPnIFmQ2DJyY8MVTxkI/Jjo9jhU38O8r4N9Sce79xZOEijamnSaebIjySPW5\nfEDcxl333smzD93LQpuT/7rmBnRykc///g8MWy2YTCbe8Y53Y
DabT/bZaGhonERkWWEsfPwul/5A\n4pB1DrCy0YvZILGuO8A5bWVv+PxOWQv9aOisDq7+zBVc/4ml1EiP8I31RVqyk9xReTXOjn7GhMLH\nmxayuXMhGxYs5b+EidGhcfbu2cutt/6WWDQKWmc9DY2/WwKJLLmCDMBk9JVdLoqi0OdP0Fx2WNDN\nBh2/v2kln7mg9YTM7+9K0A9it3q54cO3oFxh4NvPF3EUZH5TdgOPrVpCxGHjw1O/ZnFqP1sb2jE1\njNHZuBub9THu/uvHGf7V+Sd7+hoaGieJg/7zOT4bE5H0K1ZQnEnkiGUKR1joACsavbgshhMyv79L\nQT/ImUuuxHdNLd/fnCche5hRyvh0+FesKX2Cj5m/g1NE+b+yS7BWh6mu2Utt0yZ6GkcIj6072VPX\n0NA4CYzOCvrKphJSuSKx9LFrsxyMcHmxhX6i+bsWdIDWhQuZ16nj1k1Zfvv4OB3r3sfjz7+P0XWX\ncuED9zIllXO39C2sfV+n5/F/pVA0sn3/FygWsyd76hoaGm8yB0MWlzeooc6v5Ec/FOGiCfqbS/WV\nZ7KgPEGz1cVCq45PZc5mYfI8KrJlnLHlSR4UZj7rNJGJbGVg95mgD9Dd/c2TPW0NDY03mdFwinKn\niYYSG/DKkS59/gRWo45K55sXTHHqR7kcJ74PX4SiKET2B9jz6300mUzo7R1kRn9I69gYf7j8vfzs\n8n/g0ifvwjHWDtxOqe9synznnuypa2hovEmMhlLUea1Uuy0ATLyKhd7ksyFJb15GumahvwghBJ7O\nMire08aUfiN1WTfvL/0qSyZ28b/f+wouq427L7ye/v6FJBIedu/8HNnsia2epqGh8dZhNJSi1mOl\n1G7CoBOMv0JyUb8/QfMJSvE/FpqgH4XWVVXM/8KNSKb/xprJsKj1X6jMyXzrf78NOh1PNi3kwIEz\nUEiyc9u/vim9AjU0NE4uuYLMZCxDjdeKJAkqXOZjulyS2QIT0czLIlxONJqgHwOToxTdP3wTh/HL\nmMjgWvU56kZjvPPxh+mra2M4W8Hg4GISmWcZn7jj0PNkWSaVSp3EmWtoaJwI1DBFqPWo7pZKl+WY\nLpeBQBJ4cxdEQRP0V6R8zmIi1/6UroVZcmYdhrM+y3Vb91AVGOOxxauJ9nUQCVfQtf+r9HatY/36\n9fz4xz/me9/7HuFw+GRPX0ND4w3kYMp/ndcKQLXbcsx6Ln2BOHA4ZHGqv5d85tVrv7xeNEF/Ferm\nnc6F113PnH9cRdiSxrvqQ5w5dCdJi529p69mbMNqZAX2HPgiTzzxOOSyyLJMT0/PyZ66hobGG8jB\npKLaWUGvcpuZimUoyi93ufb5E+gkQUOJjWQkzB+//Dl2PvbQCZ+jJujHiaXEgf3qBpxFG069k86e\n51nrLKNXyTOxtRWXy0+LZYDC1meRcll2btl8sqesoaHxGtn/7FPc853/OOq62GgojUEnKJ8NQ6x0\nWSjKCoH4y3NS+vwJ5totGCTBWNc+FEUmNDF+wuevCfproG3BQoyNTj4cv4KO2B9wxCe596IrmQnd\ngJwz0ejo5vpPfxGXQWLS7+eFu/6sLZhqaLyNGNm7m4HtWwiNj71s22g4RbXbgm42DPFg6OLRkovG\nJuJcNKLQs2Wa8e59AESnJ0/gzFU0QX+NOM+uRZc2cEuinuuGH6Ogs/P7Mw1ERlaRbI6x8UePM7ei\nBSQdzz10P4/94scosnyyp62hoXEcpONRAAZ3bn3ZtrFQ6pC7BaDSrVrqL10YzRdlsv4UApjcP8V4\n134AIv6pEzTrw2iC/hoxtXowVNpIKNfzb2Ibl/U+RcBZy3d9c/mj/r185YMXcGPzQsJWJ6ULFzA2\n8lcGd2052dPW0NA4DtKxGABDu7a/bNtoOH2EoFfNWugvDV0cDiZZNqNa+BMv7CYwPIjeZCIenKFY\nyJ+oqQOaoL9mhBA4zq6lU
PAhpq38fPrbLArsY8R9Po9xCXW6IcwKbG5YiKvicRrOn2Bw+Icne9oa\nGhrHQTquCvrY/j1HRKUksgVCyRy1nsOC7jQbsJv0hyJdFEVhePdOeseCLIyqCYeRbAxFkWldeToo\nClG//4TOXxP0vwHL/FJ0bj0x2+fhH/djCjdi2rufj03/O5/Xf5V/GPcz6CthwmcnG7WjODYRDD7P\nxMSE5lPX0HgLk47H8FbVUCwUGN2/59Djo4ciXCxHjK9ymw+5XALDg9z5zX9n6g8/RegdAOREGKFA\n27KVAERPsNtFE/S/ASEJXJc2k4+aCP5pkp9dPo+vtizlqu6PgqxjjeublCgBfl/4CHvXXU02YmT7\n9k9x660/Ztu2bSd7+hoaGkehWCiQTSVpWbkavcnE00+s48qfPs8nb9/OT57uAw7HoB+k0mVhYtbl\nEg+qVnlyZC9RWzlmmwG5MIY1B8HHn0IREpETvDD6qoIuhKgVQjwthOgSQuwTQnz6KGOEEOJHQog+\nIcRuIcSSEzPdtw7WBT5K3t9BYSqJ+P0BLhvM4Mj76E6aKbjHeV96LWOGGvY2NNCzZzmIBG3tW3j2\n2XUUCseuoayhoXFyyCTUZCC7txRnYwfDu7cTSubYOxHlkT2TWAw6GkptRzynym1hctblkpxNJszq\nJTK6CHXz3CiFKUy+OTycL5AvryE6fWIt9OOptlgAPqcoynYhhAPYJoR4XFGU/S8acwnQMvuzEvjZ\n7O9TGktnKb6PmZj53T5ywzEc76rniVEDblspV26dx18Xx3ihrZnaLcPMDCQoa97F9NRetm/fzooV\nK0729DU0/m55rjdAMlvk4nkVhx5Lx9QIl6zOzGMxD8vyMX5/dSP1TQ1kC0VyBRmH+chOQ9VuM8Fk\njky+SCKkNn7WF6CY3YHVNhcoUmxsoKXxIULTNYSmTrKFrijKpKIo22f/jgNdQPVLhr0L+L2ishFw\nCyEq3/DZvgUx1jgo/9RivO9tx7OqjqV1V3HLeIqHrX18ojtJ0mThwc7TGXjBiNXaTHPLPp57bh35\nvLranUgkePjhh5mZmTnJZ6Kh8feBoij8+717+ac/72AsfLju0sEF0Z+8MMmQuRaAYLfqRzfpdS8T\nc1BdLqCGLsb6ezEUikyb21AK44wfeAYhKZS03Ud5+SDt89eTMu0+oef2mnzoQogGYDGw6SWbqoHR\nF/0/xstFHyHER4UQW4UQWwOBU6fsrM5lwrrAhxCCjyz4CDfOuxFd21IuCdr5yIHHCTpd3PauDxHv\nqcRgCGK17eT3/3M/O3f08ul77uQ7uhh/2nri04I1NDSg159gOJgiV5D57truQ4+nZkMWuyJFvvGB\nNXgqqxg6Sjz6i2kpV2u1rOsJEJ+awpQvMupaARgY3beJhvPDOD2jZCbPJRbyUTl/O6Njfzhh53bc\ngi6EsAN3Af+sKErspZuP8pSXhXMoivJLRVGWKYqyzOfzvbaZvk0os5bx2aWf5UMXXsYBkeMDE+0s\n3vcn0mYbnyz5MD+Vv8j6OYv4S6XClf4gD9UsotvYznolfrKnrqHxd8Fj+1Q/9rXLarhv5wQ7RlTf\n97rdgwC8b81czp9bTuOiZQzv2cXuJx49ZnTa/GoXq5q8/O/T/SQjYYwK6PVu9Jb5VCwP4GqcZmho\nITX1NxN8uIngTA09PV9jeOTXJ+TcjkvQhRAGVDG/XVGUu48yZAyofdH/NcDE65/e2xeH2cBorQNL\nwcN1+QV84MH7KQkG6Ml2sEF3Gnuqm5ij9PNv8n/QnO9mzFRGKHTiM8k0NP6e+M36QR7fP33EY4/t\nn2ZxnZv/d3knpXYT//lQFztGwjy1UxX0T128AICVV15LTUcnj//qJ9zznf8gEQ69bP9CCD53YRsz\niSyJVJKiXodPJ1O5wEPFkiCx0ALGhhZjs1YzR9HR1bWGQmIpNkPLCTnfV10UFUII4DdAl6I
o/32M\nYfcDnxJC/Bl1MTSqKMqJL1zwFmfpO1q475c7eFewkYVWL3OeuJ+07mzmnvFNcg05JKPMxIY6yqr9\nbKlbyY6df+K8cz9zsqetoXFKEE7muOXhLlwWA6vnnIPNpGcikmb3WJQvXNyO3aTn8xe28sW79/D+\nX2/iLF0eg8WKwWgEwOpyc82Xv8GOtQ/y3O3/x68+eSNGixWD2YzV6eHdX/lPTFYryxu8nDnHQ75f\nJm3L4hR+LK1/oiB56RtcjjHrIjKVpmHFSrYOT7Bn2xKWFyrhqjf+nI/HQj8d+ABwrhBi5+zPpUKI\nm4UQN8+OeRgYAPqAXwGfeOOn+vZjQb2HT37zXEo/PB+DxcS5le/DxDrMD+SRCjKBXe34d9s5z1pF\nXhjZE9nz6jvV0NA4Lh7dN0VBVggmc/zfC0MAPNGlWusXzC0H4N3LammvcJAvKpxRY8HqdB6xDyFJ\nLLnknbz/Oz9k2WVX0rb6LGzuSqYHehjafQCAbDbAp9uMpH1VJJvPoqZlHVZrkDtGBclkDpexguB4\nkpIF57Fct4CEyBCae2LO+VUtdEVR1nN0H/mLxyjAJ9+oSZ1qmJvd+D6zmMh/7aLZuoC/mvuQHw9S\nPSphMHTQ9qMfwzd+wITDRioewepwA7Bz3dNUNTZRVld/ck9AQ+NtyIO7J2gstdFUauMX6/p5/6p6\nHt8/TZPPdqjxhE4S/P6mFQQSWbp++wLC4TrqvkqqaznzvTcA8NxfNjHVt5vpgVGs1bvp6f0GuWgL\nSvUidEqAqubnSSc7aGQVOXLY3E7yozESIwolaQmdrkjPQB/N7a1v+DlrmaJvEl5PKZ5ldTS651Mb\ndlI+XoUiKUjWs7CXz8WeTzGsa2Djkz9GURQe/92vWfvCPfzxtz/RygVoaLxGAvEsG/qDXLagks9e\n2EosU+C/H+tmQ3+QC+dWHDG2zGmms8pFOh7DqNMRffCVI85yGTW5KJbYQE/vN5Fz1egd/Sxdci/N\nnc8hgD372igOFEnak+zMbWZutoDkMtGVeoKy/du5YNmyE3LemqC/iViXliGKggsuuhm9LDHcGMag\nkwjVr6Z9oJ9+WpiJPM2Tv/kZvf61LF3xMFUdGxneu+tkT11D423Fo3snkRW4bEEVnVUu3jG/kt9t\nGKYgK1zYWX7U56TjMcTwCJP/9m8vK3mdy80wMvIbEsleEuE8JrceU/196Kllw+YzCW98D/monjJX\nkmDXJSxecBGKonDG4jMwiDhGIZBWVWBvrCKvFEk8/cwJOW9N0N9EjPVO9KUWSnOVmD5xNs+0xKky\nbWMi7aKjr5dxajCUhxmcuI/WJTuQZTMlpRM8/+SvTvbUNTTeVjywe5KWMjttFWqRrM9c0IIkwOcw\nsajGfdTnpOMxpGk/Sj5P4SWJfkO9P6W371ts2nQxpqbP0nRJL7IMucAHkYsGliXOo2zDZ+jpeQ+R\nvsu56OIL+NjHPsa7znkX53mXoygKm4I9eFvayBr05OMvj5h5I9AE/U1ECIF1aRm5wShLSleApHCP\nxUxWsWGX61GExLiphqYz+kilqjht1aPk8yZM5VvxDw2c7OlraLwtCO+8H//Qfi5bUHXoseYyB1+4\nuJ3PXdCKJL18STCfzVDIZjGk1EJbhckjg/QCow9h7BM0lX+GQtaAwZJn5MkWJqcLuBUrk0MPU21p\nZfnYebhLbAhJUFlZiSRJLBLNRBWZDZu7KEqqkOvMJyZjVBP0Nxnr4nIQ0DRSBkBXVRyncRh7Xn3z\n9cmtRCLlzO34GQ5HDYX8uXjKZ9j82E9O5rQ1NN5SBMdG+eUnb2Rwx0syOeUijvs/zI8MP+Yd8490\nrXxszRyuW1F31P0dTPs3FooA5CcOp9Gk0yNkjDOYd0jYo2cy/OQX2f+nDxAdVZiKBfAJF6PBTWwM\nPUxFwUi95XCsiSIrFMaSKOVmaoId3N59LwCRxhMQs4g
m6G86ercJU7MbZVeMWlsNVVXT9Du2c8PM\nhygLzrA1vIYDW66lqbEDgJaWj5LPmchZ1xE5wYV9NDTeDmSSCe773jeIzwSY7Os5cmN0FL2cZYE0\nSLP/sePe58FORebZKJf8xCQTfRGeuf0AgcCT6rY9EpFRtQCX01eOoteTKmTw5Yrk9Hp2FbaSlGVK\ns4c7GOWnkijZIjWr6xBSEX1slTrNaPJvPv9XQhP0k4BtaTnFSJZ/nH4vrkSSLcbTsFdk6RjqY9ri\nxBWex0iX+sZpbu5kamohrqoImx5VOx8lEglSqdQrHUJD45RElos8/OPvEfVPI0t6hkaOTEhXZtS6\n5RmdDZ78OhSyx7Xf1GwWqHvxYiSHg/zEBLueHGXfcxNMTz2BYcaIfkYQm0oAUNFUS9GsRruUprJk\njDpSZomJnII5KSOn1RLZuSH1i8LhC9NhfprW8GryOoXt/Rtf/8U4CpqgnwQsnaVYF/noHK/n610f\n4zu6OvbYz6Zuqp+4zU3EkmPHU108985r2fXVW3C7r+CJwsVsNw3yzNpH+eEPf8gdd9xxsk9DQ+NN\n54U7/sjgjq20XvEBgno301NHtnRLT6sW+6a2L0JkGLbeetT9pGJRkpHwof+je/cC4Fm2DENlJdmJ\nKUa7Qkj6NLHEVkw71dDheCiDEFA7tx7ZYkMoAncyT8ZkQOhtTOYVBBLpbeoXS3Yois5lRP/0P7HI\n8zSSYiS77AJqLlvzhl8b0AT9pCAMEt7r2gl/zMFPKv6ERSriy/8jpQa1PGdX5RhTe5I4+nuw3Hk7\ne0aj/NbwER4sv5pNO9ei0+kYHh4mnU6/ypE0NE4dYgE/m+75C/POuYB000qSOiv5eOSIMbnpHuKK\nhWjr1dB0Dqz7L8hEjxgz3t3Fbz/7ce773n+qDzz+VWIPqmtUnuRjGJw6poMCa75IW+1+oIBp16yg\nx2RsHhMVTTUUzTY8ig1iUfJCYNWXEi4qSARJb9qPoijkhmIY7X4Y34rnis/RsKCUmomlnF99wQm5\nRpqgn0TaqjtY693A2nP3MGib4grpEnRFmZHySkIOC7+96Et8+4Ib+HnHEgzFPAM0U1K6n9M6WlEU\nhf7+/pN9ChoafzszvfDMd+A4E+diAdUab1+9hr5AkoTehpI8UqyVmT7WTrehn+iGC/4D0iF47Cug\nKAxEBnjssdv56ze+TCYRZ3qgj0I2C3v+SjxtAUXBOvE4huh2pvI+Gsw6HNW7EXkbxkE1MiaR0eEs\nseAotSBbbPhkJ7lkGBBYKcPmMmItnSATcFEYD1OM5TDN3AXNF8C8q1l0fi2ZRJ4DG09MIT5N0E8i\nZr2ZVm8rz86s51/rvk9A2sZ/7M0yXlHBLy52kl/eyhPvvABnPM0Xfv9zFCExVV/BgSf+gtlkoq+v\nj8h0ipkxrfSuxtuQbf8Hz3wLUsHjGp6KRQCwulz0+RMkdTb0uSTFF7V01IUHGQ7bCW15CioXwumf\nhu2/g4f/he/f+nl2/+aPRFx5Si5fhVwsMrN7HeFgjJDOhclgRHr3bzBY88w4Wii1SCRLd2MPLEDv\naUNfVkayaMZZaiaZjqPodPgUJ9m8+vkzF8ooqbFjOW0+YCR692YAjLoD8I7vgxBUtbhpWuzDYNK9\nkVfyEJqgn2Tml85nX3AfaTKM123nnaN38ecNWZoTOe6uALNeR3J3kmX1c3CkE+yRFmCrGcahl+jr\n6+ORX+7hsd/sf/UDaWi81ZiczYBOHV+SzcEGFBaniz5/nKROXZQ85AsvZCnGZgBBeLBHzfY8/z9g\n9T8ib/4VlTvTJMsM7DvPwC9idwIwve0J1ubPYnDeYswOJ9SuImUrI2MtR7j2UzQm0E8vxLzswxjL\nrGQlG85SCxNdarhkqewgU1QXSnUZNyVVdoyrTkeS4mQmbAiSGM57P3jUekxCCC752HzaVh5ZfuCN\nQhP0k8z80vkAlFn
KuGTNRyhz/Ipe77P8fEuSf9ud5ha/gXq9gf/1rWRZbzd75CWUL06SGj5AIpFg\nenqayHSKYkF+lSNpaLyFUBSYmk2uSR1f+8VUNAKAMNsYCaVIGVRBj4dmnx8aJJ43AZBLp5gZHQYh\n4IJv0Nd5E5asjtp6A3dceReX1J+LLMmMdG3AsqyP5WfeQ+WFB+gavIXBs0ppuvgrjC39L4RioWu0\nA2GyE6u6AgCnQ2biwGaEAh7FTkKZjThT7JRU2xA6CUuD6kYyWsYRq27mzUIT9JPMfJ8q6OfVn4ez\n+XSUioWclvgfPlf/XS6fzNOwKcB1UxK2viKtFg8hnYcxfSneEjVzNGcKocgKEb8WxqjxNiIyfHix\n8hVcLslsga/dv49IKkcqFsVsszMcziIrUFmpJucF/bPtLEP9RPPmQ88dPzB75yoEXTQD0JDZhkgG\nOO+B3bSfPYL1XRMYnQlCoRpkScLvf4Rc1Ri6hBtf93vA/CUiGSuDoX14rHOxSmDf9HUmpqZxygZ0\nSATss4eRbHir1X8sZy5lr0sisGQR6F61qO0bhiboJ5lGZyNfXPFFbpp3E0gS4saHqZ9zCdOmUYbd\n26my62lZ4qMtryN7QE162BVbSuXyEIZCDNmu3m6GJpLE4/uYmXn6ZJ6OhsbxMfmi1PdXEPRNg0H+\n74UhHtw9SToWxeJy0+tXfdbzW1U3xqHQxWAfsbwZBYHN7WHswL5D+5nYv4+4pUAbUaY3f4nsOwcx\n1WeY2OZjy+Yr6e0+jV27L2bF0g303vtjStddj2voXP6599vo8zESippB6pAEYmojE3k7pZKTTDFJ\n0KynKElIOgPeCvWuYbjGxQ2rbJxr17Fiw34+e2CEZ0Ix5BNcOVUT9JOMEIL3dbyPCtusT81kR3ft\n7+iwVrHVsBVyMmef4eH0f2ijNiUoj2fpDi/AYE4zZ84mMiKKIhWYmdrHtu3vZfeeT5DLnZjCPxoa\nbxhTu0HMys8rCPpYWA3Nfb5vhlQsgtWpLojqJMGK9hqKSIQOWujBPgJ5BwWzk5qOeYwf2IeiKMhy\nkdTgBH5fhuSCDvbyNNKkgZ0P1DO5tx5ZNmCIhSjICjue348sS7hMBhLpYTqczTjjY8QkDwBOSWZt\n3WpymCg3VZAuJMgIAToL7nIbOoN6Tk+F1S+df2uspNNu4cFAhOt2DbBqYxc/Gp7Gn82fkMuqCfpb\nESGY23g+a70jAEzd9RP+3/hNZGr9NE0p7K2ZS2jnanz1U5T6hjBWTJI0fBkh9ChKjqnp+07q9Hu3\nbGBk74kpPqRxijC5G0rbwGB7xUXRg4L+Qn+QVDSK1emidzpBfYmVGq+dlM5KLDT7/GA/wZwNHF6q\nOzpJhILEAn4CQ4OQKbB0VZRR+yTlA1kmHv8kusJ1yEbV524ITaOTJPbu24fJrMNg9ZGODXGecT72\nxBhhXSU5AW7y7J/JAFCpKyF9cEFUOCipth2a9zOhOK1WM59uKOe38xvZc/o8fj63nhqzkW8NTPLD\n4SP7nL5RaIL+FqWztJNpKUzamaQr28R4cpK7Sn/OnKk8eb2ejYkbyE57aW7eSMXCn4EUYfGi3+J0\nLmRi4i8ntSnGc7f/H5vv++tJO77G24Cp3VC5AKwlkDz2ouhYWF0biqbzxCMRrKU5ev1RWsrslDlM\nJPQ2UhFV0JVgH/G8EU9jjrI5qn99vHs/I3t3YSnJ0OCJUV93MzV7aklaajDRRNFsBhSkbJqKEg/T\nkREaGxwIyYAUHGbZRARHYgwZHdGCgqJPY9K1MDanDE/GTKaYA0BSnJTM+s9TRZmN0QTnlDgOnYdJ\nkrii3MPdi5tZv7KdT9SVnYirqgn6W5W5JWrTwYf1G2jNNHPrdJqcaQaP+Vn0BYX+Sislt5qQhIzV\nHuDA/jO594lxhhNrSCZ7Gexaz7ZHh970eSuKQjw0c6jYkYbGy0j4IT6pxolbvS9zu
RQLBYZ370RR\nFMbCaeZWOhGKjM44g6j9A17dRlrKHLitBtJ6G/lYGDIx5Lgfb2eQ5sVPMB76JiablfGufYzs3YV3\nuZ8iRurrP0bcuISc0YkQgoLFisjnEYpCTWklRZHDZFOtbsvkMHUbH8WeGAMgWiiyzTZIT2UDD9as\n5vYyiYysxsALyXpI0DdEEmRlhXO8Do5Gs9VMtdl4Qi6tJuhvURqcDbhNbkZK/BgUPfPPuYerjZXc\nU3oncwJJ+qsluoxl8Dsb/c+2EplsZGD3JjY/G0WSrPQe+AMb7x0gPHViqrodi2wqSSGbPVSOVEPj\nZRxcEK1YALbSlwn6C3+9nTu/+e/4B/sZC6dZWOtmXqkeW4VqrVfaxmkusyOEQLY4UVIxUoE+DtQ4\nqF7lJ5+rJp7YS8MZgtH9uwlMbqO0LknWvhqDwUkoergRtKK3IeVUF4pFKQVFMJ0cJUcGEgGyYxms\n2QB6g8SgfgK/LsZMaS0AP2o1sd+tRtUIYaOkSnW5PB2KYZEEq1z2E3oZj4Ym6G9RJCFx26W38ZVr\nvwU6QSZg54Z3/p6iXqHGvI6gzciKswZpHEmR36fDOjLIgWQzVlkmLy9Hsq9H0mfo2XxifHXHIhFU\nb581Qdc4JlNqQlHUU89zRukIQQ+MDLH1gbsB8I+PE0rmqPFYWFVhwlqu+tPLrQGaPSYmv/Y1nJJE\nUS6wpC/DbXMuJ9TtQm/4BiXes7DU7yYRG6Zk7iQFGUoqrgUg4lcFXO9LIHRGpJxakTE2JbDKpQzE\nRhkRw4BCKmDC6PPgrNYzahnAJfk44Cnlgy4nFRmFX67uIG2yIBltOEpUcX8mFOc0tx2z7s2XV03Q\n38LUO+txOzwY65xk+8JU2iu5bM7ldGWeQgBfb/h3igvMNE6HUIoT1MsyQdnKgX0lSPosnjnb6Nk8\nQTI5SDS6A1k+vlKir4dESP1wFnJZ8tnMCT+extuQyd3grueOkbV8In2A0WwEAEWWefxXP8FgVoVx\nbFQtjVvjsTC/VIet7LCgl+/aQOTPf6EuEWSqpoYIVkaSTYw8U0lVTR1tbV9HCBAX6Sm2GNgQN9Dg\nnY+cShHLqa7KiuVOFJ2CyKufi+mhHLXljSTlDPtNo4yUV/HH1e/E0NhC2jKOLIoUPfOQheAanYVb\ndqWJWYw8cs7VOMt8CCEYTmfpS2U5x+t8+Xm/CWiC/jbA3OImP5GkmMxz07ybKOQDVOmi9FW08k3n\ne2iX/QjM1Ce2s0+uIhS0komVUbbwTsrPuJmNm85n67ZrWPfsYrbveD+jo/+HopyYzNJ46LC1pVnp\nGkdldkF0bNY3/ayhCIUsu554lMmeA5x7w8cwWqwEptQCVjUeK7WWJCZ3jqIsUWkLkPzrXwBw5zOM\nN6rx6FOFKmQkamsrsFhqaWj4R37u+2d+JT7Bs2kLFbYKsgODpKzlOByC6gY3ALpiAUmnI5uWWNDc\ngqQIAtY89591Pj+75v28sGgFwcwYpoKXjWUWKhIZ2tMKc2My12/pob+hnUdWtlFUFJ4JqeGKL14Q\nfTPRBP1tgKnZDUC2P0Kjq5Hz688nE3mKMbnIC+YFWNvzOOQa0pE+fCYDomBiZLidQsJGbKST2OR1\nJOIfwOu9knwuRE/vN5iefuCEzDXxYkHXFkY1XkomBqEBqFhIf2hW0C0WEhP9PPfH/6Nu3kI6zjwH\nZ6mP6Gx1xVqPhXx2P0LA/plWzPo0iQNqLRVDYoCRyiYApg1lpPQ2HFYLAOW1HyIofOwX87HY5yIJ\niVx/HylrOZ4KGxadWq7aZ2hgpe+dtFt0lEYLVMsliKKJ4doGAH7UuoBEOs6ZV5/FplID54wnkGM5\nFEXhXWMTnLEvzTMOiU93jfBEMEaN2cAci+nNvKqH0AT9bYCx2oEw68h0q1mh17dfjxx/HoCK9jI2\n1S6hIq8DYWL5/j9h8nczHa5n9xML6OqpYlevg
Z07BfffZyaZ/Bx2eyd9/f9FsfjGlwtIaBa6xksI\njo8ysGOL+s+02kiCygUMRVVB32Ixs/f5p8ilU5x7480IIXD6ysiEZzDqJUrtJtK5bgA2TC4HIFsJ\nPVUgJaOMOuvRK3lCFg9pi+fQcUdz6l2oIiQUx+kApHv7SZt9lDSWEA6rn6c241JqzC20mXXkd8xQ\nqTiRswoDVTV4o2EGzTYGKuuZLK8iLwnODeoozCRQsjG8ZSbO2ZvmH71e7pwO83gwxrleNYLmZKAJ\n+tsAoRNY5pWS3hNAzhRYVr6MdrsDazGAp8nFqss/TJVnBL3lLIyWKsz5xVitdhLeChSHB/N0AGv3\nDkqtZp599jn27Z1HNjvF8PCv3vC5xkMzGMyqhaQJ+qnPgReeZXTfKyeRvXDH7dz/vW+STaUORbgo\nFQuIFwKY0yXkhaB7aB9GixVvdQ2ZTIak2YySiFDjtiBJgjxDZKNmxsNqM/WeBQpSmZ5ASSuy0LM0\ntwVZ0pEoKefTT32a+/vvZyChruHoigX8ktqjN9TvR5F0eKochEIhFAmckpOZgsLI0goeOG8HP63/\nI3mDkZDNwZXPrMWXjLFjTicPhRKUFGQWxPVkhyIomQgVjU5u+PbpfHlhHV+bU4UA3uFzn7Dr/Wpo\ngv42wb6qEiUnk9o2jRCC69qug8QLbE+kMXaeS02dH71pPnr95aTN82hcegadnZ188pOfxGe7lpKK\npaQ2r6PZYWZkxIjBcBrDI78kk5l49YO/BhKhIL66BkAT9L8Hnvndr9h4zyu3QwwMD1IsFFQr3b8P\nrCUEDQZk8uhjHTiKMhNTk3iqqtm1axff/cF36QunQSdR65DUJDnjBIWYl/8u00ERpjv0LJVDdDfM\nRacUWJw4AMCkL85To0/xnc3fYcPIMABzJ4cYVzwMx5OEZ3uCeiqshMNh9A49JkmQkaGk1cz/+W9n\nbtsCrC3tALhTCZYN7GNGZ+TRmRgXoKAD5JiCkolgqKrC5lbdKzfXldFz5nzWHCP+/M1AE/S3CcYa\nB8ZaB4mNkyiKwuVzLsed348MPLBhnGfbygFQJBPDxjgBXSnvfve7KSktoXV5JZns6Sy+5FqmNq/H\nhEJ/30JApq//u2/oPOOhIKW19QghaYJ+ipOOx0hGwsT8xw6NzWczhKdUo6F30/MQ6AZfB1NJdcFT\nUMXp6TT5WJFJq4d7772XiBwBoGi2UGPIkMmMIQwZlEwFjsfvRglDQ7kTm7XIjrmdNNODY2K2TIZT\n5t2t7yaWi7FpZBBjPsf56+8B4JtPrCOSVisfusuthEIhnB4HZgFZFJ4rPkaqkOKmeTch1TWq5zin\nntZUlBVOKwCX+g6n98uZKPrKyiPO16E/MY0rjhdN0N9G2FZVUgikyfZHsRqsfMJ6FmVpmd9Fo2wL\nLkIWagy4MG9n/8RhMV18TgVms47JwVbWvP8jMD3G8HCSEu/1TE/fTzD4HACFxOvzqRdyOTLxGI6S\nUsx2u7YoeooTHFVFNDbjRy4Wjz1GUXCU+hjcuY38VDf42phIqCJfV9rGGYkMenM5yYKM3CnzVM1T\nIIFsskJ+F9HoDgDkCTMD2Q5CBTtlZEiVeOitbqSTPVhG46DISO5OvrLqK1xUeRFhrJTEEyw/oKc+\nFWOjZKa/2YvVrCDpIRqNUu4uQxICqVnm9t4/sLpqNe3eduIuD/pigWI6SefcuXynrZYPVpVwVlM1\n8mzZXyUTxVBV9SZc6eNHE/S3EdYFPiSrnuTGCYrJPBftWMS7hwP02uF3nZfw1BnlTLkE7sQT7Js8\nXPAo/dc/0P7CfxP1pwiMN1BqMYCiMDa2AJutlX37P0v3nXfzm39+igP3b/+b53dwQdReUorF4dQs\n9FOUDf1BRkMpZsZUQZeLReLBwFHHBkaGAFh11XsoZLMMBnXga2cyOQmAz1pNa9KObLZSlPLck7qH\nD8z9AF6vD
9lsYTK0jmhsJ3JBEN5dZKjhUoivJE2SLW3nI0sSc9lLNuXEkYzhKFmMEIIzx5uJmm24\nkiYOtF7FRU89ht/pZajGjK40SSQSAaDWUQ3A05ZHCGaCahlrYDBXpLyYRwDz58+nw27hO221mBx2\n5Nm7C6WYQud2n5iL/DeiCfrbCGGQsC6vIL0/SPC2LkRagcr1tEa+yY1R2O1z8quLPaxf9jFSxq2k\ncmqdieTzL+AJ99BR2MHgziBVjWvQx8Ls3LGPjvYfUCymGcj+iILBxLq1ISLTf5ulfkjQvSVYnJqg\nn6p84vZtfO+xboJjw4cei0wdvelxYHgQg8lM55rzsdgs9MZLwdfKRGICpWjGZ/MgKKVothLSzeCz\n+PjYwo9hcvkomiwUY1HGhtaSCpiJO9qI+cZwuppAV+R286UY5BwNhQGivgU4clniBiPFVIreA4Mk\nzFZK4hKKzkHnsJqUNOCtZ4JupqdVN1GZqQSAfdlu5pbMZUXFCgC6kxnmeZzMnz+f+vr6Q+cjhICC\naqHrbLqTFs1yLDRBf5thX1EBCuQGo7jf0cQXLv8Kj1z5Z7552ULWDgo+0Z1kytfE8LyLufSZZ/jM\nngF+X1LF/hWrKXvu19TXCob2l2LPpcnl8wwPZ3HuOQ1j5SiN9T9FyAUe/eVe8rmj30K/EgdbgTm8\nmoV+qpLOFQmn8nRNxgiOjmD3qoIY9R9D0EcGKa2rR6fX09zooz/hpeCew1hiAjnvwmszECy6UYwm\nZixh/mnxP8PUCzS7v05Lx2bOsApgipTfAjo7H/r4O1l+mtp9qKvEQFNxgGLcitE/jj2bYSqdYsuP\nfsSEpwRZkvAkihRMGQp1FzJ3oI/uikbShSSPPvooAC7UiKyQIcpN825CCEG8UGQim2dZeSlXX301\nknSkTAq9+uWgc1tOwBV+fWiC/jZDX2LBtrIS28oKbKdVYtQZMeqMCINE9Zo6bhqSuesHP8c7/Qfs\nyVHWTvv56VXv45M3/iP/+O/fITj9KFanBauuHZHLsP6xtWzd3YB/vB3j8t002G8jOJHg2T/3vOa5\nHWGha4J+SjIVU0MB+wNJZsZGqF+wGJ1eT+Qlgr53Zi/PjT3HzPAQvtkFxtbyInlZz/DABOPxCZS8\nB7fVyHjeA0IwlVxINryIyYHfgCVNSckoTaePgx6S0xbq5tVSVuXGFk+TxIbfWso89iKmHJiCU5iD\nM8QNRh7KZEma1UXMkozM/EsriKQMnD2ZI+wwUuxcSTwex2g0Ysqoi5jeEh/n150PQG9SPcdWm5mj\noTOHyO69C33Z0befTDRBfxviuaIZz5UtL7vdMze7EQZwlC1kyf5HWTn5C/76i/dx1xdv5p3GXiJV\n1Xzhyuv545l6kvkOjNEw/kSSyeo8ExNryKQdKGdsZ8kFXg68MMnwvmN3kgHIZqcpFg/Xh4mHZjBa\nLJisVlXQY7GTWpdd441nKqqKnTGfJB2L4qtrwOkrJzo1ecS4H2z/AV9d+0UyyQSl9Q0A1OrGMOkV\nejY9z1RyEjnvxmszEiiqkSNu+0p++dwAjyXhlsxXeXrr++h+bhWRniuJDjtYcHYLAKM9IwwV1S+J\nZn0XiVQnAFI0DUKQMpiwlanRJ61uG2eduxCz3UDJdBXOVJHddY04nU5KSkqQE3kkq54/vvNP6CRV\n3A+kZgXdegxBL3WR61uLsaryqNtPJq8q6EKIW4UQfiHE3mNsP1sIERVC7Jz9+X9v/DQ1jgehlzB3\neNFXLGTeRBl32JIMBoxMOaNs6Ps6H/Js5VOb17HHZmbaOYMlXoZpchhPoJ2PffSfyM9chs6exVr6\nc0w2PT2bj34bDaqYb9h4IT09Xzv0WCIUxO5Rb8EtThdysUAunT7Rp63xJjI9a6F7c2qWZUlNHe7y\nipdZ6IPRQYwhtfmDr14VX13wAM21dvq2biKVi2PMlCFNZ4hjwlDMcuOaNia
lMLeYP8Ne2wLGvGWE\nIl7GN9WDLHCXl1IsFll7IMlQURX3WobIRetRkNAXVDnr3LkbS3Mb+qJCc4UdvUHHvLOqIS+zciDH\n5lSGle95L1dddRXFWA7JYcQwWwYAoCeZwSwJ6ixHr1muLy0FwFD91opwgeOz0P8PuPhVxjynKMqi\n2Z+vv/5pafytWJdUIow2zkwvxjR5Gk3TsNSd4FJ7E7/c81POP72WsmSCntNqsNoXYYwE8JUEsXvM\nLFxwFePjHcwUHqdx5TiDu2Yo5I/uS+/v/z7FYoLJqbtIp0cBSASD2EvUN7vFoVab09wupxYHXS6+\ngiropbX1uMorjvChp/Ip/Ck/3rgqiL66BkgGITXDnM5Wcqkk5WET8yOt9N2xn5zBgrsYwldtJLqg\nlnKmKYmHVD+4yYLJGgHA6nKxdetWZnJG/FI9diWGmwhVe/djcrwHp0EV+ZnScoYw4k7IlNep78N5\na6qRdILzUnosksRfYll8Ph/FeA6d80jh7klmaLaa0R1jwVNf6gPAUPk2tNAVRXkW0LoOv00wN7uB\nPJX6Glr2z0EoULZ8Gf9v/wvUWcr49OAPCfmzvOD0UHPtMvSmajIxtTt6w6JFzOxrJZdyoC/7CcVC\ngpF9L3/pY7HdTE7dRUXFlYCOoeGfARAPB3F4D1ros4Iei74Zp31SiGfy3L/rjc20faszFc1gN+lp\n0ico6M3YPF5cZRVkk0nSCbXS4FBsCIDatJeEpUCUJMyotVgalqxG6HXUTlvxZJ3IcgjZZEaxyHzw\nwCQ+EeErka8xv6eLQW8tBbMFiyOCkCQUScfTTz9NE8NM6BupYxi9YqH6wAaS+lI6Igb0xSIzza0M\npnJ4EkV8dWrWps1lYvVVzZy2pobrKr3cOx3Bn80jx7LoHEcKencyc0z/OYBt1Upsa87C3NFxAq7w\n6+ON8qGfJoTYJYR4RAjReaxBQoiPCiG2CiG2BgJHj1vVeH0IvYTek8dQMpd3z+wipzOgnPddcvKn\n+e54ilQxhlf3VxCCj3cPw7KVhCdHCIwMIUkSLVmF/d2nU5D9VC7/K33b/EfsX1EUenr/E4OhhLbW\nr1JV9W4mJ+8mlRolGQ5h9/79WOj37Zzgn/60g9HQG1/k7K3KdCzD8qpB6qzDhI1e9gX3kXOovufo\ntGqlD0WHAKjNegg5cjw88DAE1NR8Q/U8TI0V1PotePIWCroZEBLDvjoSMnyJ/6B8b4yyGQMZgxG/\n08tMYBKLw8m27dvJZDKcx3oGRQW1jGDV1yEBgUIQT1zGFwkzVVPHuFygJCXjqbAemvvC82ppW1XJ\nR2p85BWFW8cCFON5dM7DlREThSLj2Txtx/CfAxjr66n7xS+QbLZjjjlZvBGCvh2oVxRlIfBj4N5j\nDVQU5ZeKoixTFGWZz+d7Aw6tcTSsi8sRRhvzEnki864h9NdJkuk1OKc+z0cn5qHwPF5lBlOjnZ+N\nOkFIdK1/BoDOmhrisVKQL8JR+xyBwKMUXhTC6Pc/TDS6jTlzPode76Ch/mYA+vt+giLLh8LYLA4X\ncGoLeiCuLghPRv9+GnlMx1K8Z86P6Tz3BTxz/dz8+M38aepe4HDo4lBsCF1RkAtEMJR7uK//PhT/\nATDawVVDtsGOM2XAk4+h96hrLBOmOThFnhImkHvKqUk2IRSFUU85OUmPxeVm06ZNzCk1kTfrSaOn\njmFs9lYAsqlhBOCNZthTUkZWQK3BgHSUrkFNVhOX+lz8bDTAHodA5zjsP+9Nqa9pq+3klL99vbxu\nQVcUJaYoSmL274cBgxCi9HXPTONvxn5mB0o+jXnJDcxpPIcelw7PR+cTF0aujH2I/8l9jFzoIWIG\niarWcqYc9XStfwZlbBu+eZ1UTE2xa0c1Jn0HvoW/o3f9ZhRZJhTeQG/vN7HbO6iqvAYAs7mKyoor\nCczci8Gax1FyUNBPfZdLMDnb6SZ24gQ
9dGcPkUcGT9j+s9ksa9euJZU6vruMYnYYg5Qhn9CzbOke\n3u2cpqfQBUBkNtJlKDpEi1yFIst0dqykL9LHgcBuKG0BIfBXqmVt9bkBCvosQpYZMTVQoQQRssQ+\ny6dxojDfKBjz+JBNFnJ2N4lEgtWeIPs98wGoFeO4fMtBCFa7kthsBWw5M5OzC5zNTgvRx4ZIHmVx\n/7tttZTrdXx+sYUp2+H6K91J9QvmlVwub2Vet6ALISrEbPycEGLF7D5fOd5N44QiWUwo2REks5MD\njhSfiEf503iIqxUZnaWb9v6F3CKtBKVAqn6YXeZ6EsEZxn54JWbdMI2Dg8SSWUp9n0fo8gyP/Svb\nH72CHTvej5D0dLTfghCHPwQNDR9HQca3MHQoysVosSDp9Ke0hR5MqFEcJ1LQc0MxcgMn7ktx59bt\nbNiwge6d+wFIbtiAnMu9bNx9ffcxEB7ApusHYPCxGh7ta6LdLHOhK4jZ6SS69V748TKGJrfQkFHf\nB+cuvRyDZOD+9Cj41AqG42KGuM1IMTdARsljFhC0mXAXx0lNt6AUTVz5gRrOqyrH7/SQtLsICQPl\n5eU0pXbS5VuBAK5a9hOqat6DvrSUs70KzdZpbLnDLpYOn434s2OE7+0jNxo/4ny8Bj2/dvvISIKP\nJkMEcnn+MDHDj4f9mCVBvfkUtdCFEH8CNgBtQogxIcSHhBA3CyFunh1yDbBXCLEL+BFwnaIFH590\nzM05Mrt+Su0NK8jKMt96uIuqcidl75uLSdrF6u1VLJCSDEs1RBY+gSJBV8yHoftWauIJJAWGh9Lk\nNq3GUDNBVNdFU+NnWLXyMZzO+Uccy2Kpw1hcQklHGLNLtY6EEC9L/w+OjTDWddTo17clBwXdHz9x\nvVrldJ5i/OUC+0axY/tOAGaGpsl09zBy403E7r8fCll48DMwuZtcMcdXnv8KX9vwDWrtI8iynkzY\nxLpCls1JB3PNRSSvieC0n3x4mKG0nzmxAzScP05g5qv8R2WOqmrwe9Qb94nkBHFHOXJxgpxOh9vl\nJGYXlEkjJMbnsWj3T/AtmsM5XgeKkBipmUNGVli9eC7Cv5/9zlYaLEZKHY1Ikh59ZSWFySnKprbg\nTqjuQSErzLWaoKAACsE/H0DOFo449zlphW/tTnMgl2Ph8/v4l+4xTJLgZ3Pr0UtvrZT+4+V4olyu\nVxSlUlEUg6IoNYqi/EZRlJ8rivLz2e0/URSlU1GUhYqirFIU5YUTP22NV8P3yY/QeMcvaav2sKze\ng6zAx86ag65uKV7jDxAiz839dhSdm5naj7K/XrA7XkbO3429poSyZJLuLVtwPTzD+PM3I/1oAbWO\n69DpDt+KjuwLEg+p1qkcXoLOoBCM3n9ou8XhRApBdlC1MB/5359y73dveXMvxAlk5gS7XBRFQU4X\nKCZyb2iCVjYbYPeeTzIx0ctUUK1pEolESG3apG4fGIRnvwdbb4Xdf2EyOYmCwo7AVurd/eRTXgxW\nG2lrmP7w6RgloDJILC34hXIWaSGom5vEWRNHhAaIByspMxXZVryLUHg7gVSAvLEJ2WgGSYe5soqC\nJOFT/LTv3IjXmkIyGlnitGEu5Bj1lmMqppj36BWQT3LAUE6HzUJi4yTFeA5DRQX5qSnYv52G2YbT\nrrSMc7ZjkefqVoqhDJH7B464DsVYjtNniny/pYZrK7w8sKSFJ5e3cclJbFDxetEyRU9RhMGAzqGG\nbP3z+a1cMq+Cdy6qAqMVXXUd3vJ7WNGb4ntJE7J5DuvO/gwjFS182vQxPn3ZtXzzomvZ4q2g7KxO\nsjPLCBqWEn/iyUP7TydyPPi/u9k8+yFJTutITXkZG/89xaIqcBaHk6pYPZEHB1BkmcBwH9lklFQ0\n8qZfjxPBQQt96gQtiirZIshAQUHJvPbaOsdiZuZJAoFH2bPnXiQhcMlWIokoqa1qn85c9x5Y/9/q\nYP9
+xhPjAAgU6hwTZENWCl4TQkhs7j+TnCIhlUyQLBjplQwsTxowOfN4o4tY+kIvg/tOZ9fOi0kW\niuzY+T4WWPLoi3XIVq96vEq1auPiySFcQ+MYGxoA0EuCznyaMU8ZzSUWdJd9n9RlP2ZANtEmdETu\n7SP+3BiGygpyw8MUg0GWeNTrVJ4XFCYSSFY91iVlOM6pJbVtmtTuw9F1xXgOyarn+ppSftBRx3KX\n7S1XbOu1ogn63wFntJTys/cvxXBwxb92FebIX7GfUcnZ62e421jEK2e44/KbuHf19ez31OCORVjX\ntpiBd15Ow8Iygr4FRNY+Rv+2TexY+yBP/fZP5JOb6Nu6jsm+HqIBP9nJeeTzISYn7wLAYndikW0U\nI1lmxkaRZ4V+sq//ZF2KN4xcQSaazgOvzeUiy/Lxj00ddhEct9slFYLhV75JjidUf/nE5G4aPbX4\nFCfRTPyQoOcPbAeLB1ovhul9hwR9uWsJFl2R0IyFabNEo20BCbOX3YX5VPqiKIDJbeYMPcgFgft5\nDz+RPkReglTKzaM9KxnLCW4szXHasj9gqjJRWjLImNwLwNKxYXJxA8bm9kNzPdtlIWG2Yl1+ISy7\niZ7Wa1CAltnv0Mz+ELrySpitx75sRTXGvEK9wUB+PIGhxqH2KD2vDkOFjfjTo4ev6WyW6KmEJuh/\nj9SthEIG17wwOq+ZqqdDrN12M/Xxh/H1/oL/uv9HfGtrEAX4eqxApsJIlhxPhqa597++wVO3/pwD\n6/9KIbOeROAh/vjlzzLRvR+zrg2ncxEjI79GlgvYbR4MwoSczNO/Zc+hw491vfbCX/F4nHvuuYf0\nW6SUQDilCqzVqGM6ljkul8iePXu45ZZbOHDgwHEdQ06/dkEvPPcDYr+6Us3MPAbxuCroOilEmcmJ\nQzGTlNPkw2F0DjO5SAHlom9Dw5mQmGYi1Ide6FlsVtdOQnIjFn0jC0fnUljq5bbc+7AZZSylGdyJ\nLFU1SZQuI7nH12PYFTl03Lr0Mv5nSuKesBGnZ5j5ZzxCW8cLxEyLkBSZymgAOS8wNhwuV/ueuS2Y\nigV+by5FURT2z0ahzAmr16Ywk0bnOpyx6Vw4l9811PCN5Q3kp5MYa+wACJ2EZZGP/GSSYjR76Jq+\nNEv07Y4m6H+P1K4EQExswnVJI4WYCSMXck2lAZNhGw2lN7BYqebKniEGzA5+uP9pcrHfEzNJlC9Y\nyTu+/G0srpupUC7E6LyBRRd8iNPf8wFWXHEt9fUfJZ0ZIRBYi8PgPXTI6b39IEwgbEwPDhxrZsdk\n79697Nq1iz179rz6YFT/8+9+9zv27j0xi7AzCVUU2iscpHJFEi9ZcHvpXJ577jnuuusu8vk8g4PH\nF4Yoz94BAMjHKejbX9jJb/uWkOl67BhzKZKY7b9ps2ew5CUcigUFSNkcuGrCKEVBoXQ1+6fnkix6\nGA91U2GrwFScoihLpNJuxnUHGHPVIBv1JKRyZAVcDQlKwpMYLEXkbTZ+tfJ6TFYLKKDLW9GlodbZ\nyNZgCf2P/CcT4x0I0UTGdRFVSgpm188PulwA6jwe/r2tnnWxNPcHInQl0lgkiYqpNJJNbScnZ9zq\n8+rr0dltnNPqoyKjgAzG6sP9PS3t6vsx062WLZBjuZdlib7d0QT97xFHBbjrYXQjlnklGI29xDLX\nssS1iI9Mv4tSvOj1v+Ad4xGaAuPsX7yMp1e9i+F519BVhLvuvpu8UWFBYx5TUU9yVE/Y4sSfSOEr\nPR+rtYn+ge9jFoc/LOmpAJK+AklfRmh8+BUmdxhFUfDPLjgODKhfAvv27Tuu52YyGQYHB+nq6nqN\nF+f4OOg/n1ulxttPx47udskUi9yw9jkeeHY98+fPp6KiAr/ff9SxL+VIl0v+FUYeJhQIU1B0DL5w\ndEFPpQaR5TSKInC58uSTGRyyWtc71VBH4JIihRKF4L4hnl4L+9IXM
BEfo9pejVnpJZH0IJJpRJVM\nZNEqALJmK4NZPc76OCVzosgpwS/t/8DdlUsJzZmDrmDFmBIUinG+fca3udLzMXKKwsDAUqoq/4ex\nrKBOZMnF1VBY44saSgDcWF3KAruFr/SOsyWaos1mRplKYWp0YaixUwioMmaaezgV/2CYorHWfugx\nfbkVnctEujuEIiuaha5xClG3CkY2IZIzuPkJctFM+1o374icRW9biA2Vtaw0/Ddn9eykJZpi28JF\n/OXMBfz29HfQ7y4n6tlN5b/cgF2Ms6ewj02bNnHfffcRicRob/sG6fQwKfeDhw6ny+ZwlTWiN5aT\nikxROEqs80t5ZO8Uq7/9FCPBBENDQ+j1eoaHh4nH46/63INjDnameaM5mFQ0t1LNiPUfI9LlnoFR\n1pqcRJefzlVXXfXaBP21ulwSAWJp1fXT1zUER3EDHXS3RKNlCCmESMsYZ2UgP8/JULOTxDlF/N3q\nHKM0MpENU2WrxGscJJXwYojOcNbCD/JCNEm1BLIkMZppxlqaxdUUR5kq4fqbruAz57eQEhmsspXq\n6VEUqYA97KGhuJS8QX19ampqGMlkqdcXycX1oNNhqK4+Ys56SfCdtloCuQI74ynaLSYKoQyGChuW\njhLyUxkMje3Yzzjz0HPyY3Ekp/GItH4hBOZ2D9neiHo9ZUWz0DVOEWpXQtIPu/+CUerH2i6hjGfo\nsY/wYP2zXPmRr+B2FKkjznt3buOru3byjl0bKIkl2F21AEkPt932BwZqQxR0Bc5uW4AQggceeAC3\neyW1tTcRdz1FyLUFAIveTmVtAnedDkWRCc72o3wl1vfNUJAVtuztI5/Pc+aZ6gf2eKzu2GyD6mAw\nSD5/pHWbz0cpFl9f/ZWDFnrnQQs9fnRBf2pGXVDcYzEghKCsrIxkMkkymXzVYxwUdMlmOD6Xy/Re\nYnlVwIYiJgr+l69VxBP7kGUdU8E6UNI4hCBpTiApgukKHZ/gVr5y1lfYG1TPJ0IdASVHndWGWZ8i\nEfdiiIUZdC5AJ+BfGsoBUHrb1LnqFJRl1/PuZbV8cGkZBSVHWWk55VOqm6l77yCJQIacIYbN7kBv\nszOdK1BfWkNO14ixtgah179s3oudVm6sVuPY22QJFFRB71QTmCq+9lPcV191aHxuLIGxxvGy/Zjb\nvCi5IundanctSbPQNU4J6tTbZTb+DCQDrisWYj+jmvUrutk5sxPFYIEzPkNzYS+TUpip6DDzw046\nhwVjPguL1lxOMpnE7fHgCS7GsHGMCy64gIGBAf6wZQdKxacwpiqZmf97Es4uLKevx7jwD1Su+gNC\nJxMYfnU/ct9AmI9iYqC/HyEEK1asoLS0lP3797/qcw9a6IqivMwi3rnrRnp6vvHar9mLCCZzGHSC\n5jL1ln4qenSXy5YM6JQCAwULA6ksZWVlABxPcTo5XQCdQO81U0y8uqDLk3uJ5034KsvIyXpG1931\nsjGx2D6SSTfhhLqQaDNkMRVM2BUzg147OWGiSz+XL5zewgvtZoJ51e9ciirwmUQpRruDOwJRzi9x\n0vKj/wFg2NxMJioI5wR1tRcB0NulvsYNHS3YpCIoMDw0RiacIWeMUeX1sucPtwNQV1JJLl+CcbYZ\nxtH4t6ZKPlhVwvkpVbb0FTbVjeI1k9l/eBFYThcozKQPLYi+GFOzG3SC5Ba1HMCLLfhTAU3Q/17x\ndYDJBbExqFqEzu3AfVkTrXVz8af9alf2pTcwx6IKo0e20d7cxPwRNTzsm+lnWHqmkY++92Jsio7J\nsSzNlZVsnbeCf01KfHTjASp3fwjZnGB81XegYpSSbD3CmMXVkML/KoIez+RpCWR5v5BIT49RVVWF\nxWKhs7OT4eFhEonEKz4/Gj2cofpit4uiKCQS3SSSrz3S5sUEE1lKbCZsJj0Ok/6oyUXD6SxT2LiY\nhwB4OBA5JOjH43ZRUgUkqx7JY
TwuCz05vBsZic4L3olBkunfvvXI/SkKsdheEgkP3pDqpxaWMMGY\nB1vRwLRFFcCvy//G3OQQTy608ny9G0PRhBToQ5YFRKyMtS0ikCtwfYkD82OPYiwWmPaW8Zu4i1+H\nTNQ71X0P9A2DImid10Bx1VkYsxLBsJ98LInQZXHv3sW+Bx8BoM6oJzcycsSC6Etx6HV8p60Wnz+D\nMEjovWY1I3luCZm+yKFM0Nz4rP/8KBa6ZNRhanJR8Kt3aJrLRePUQJKgdrn6d82KQw8vKlsEwA7/\nDjBYqF/zPuYrI6zJdxLL9eJOyZSEkozpWvj84O/5xRP/RHWbhz2Nczn/+b1s81RQHg3SrTcxVWgh\nvWk5pu6zKHnq/xGt+ALP5C6ipD3B9MDRI13y+TwTExPsHJlh1dzb6Tv3k5ikEZqamgCYO3cuiqK8\nqtslMBVCyHokdEy9qCN9Ph9CljNkMq+vjnkwkaPEropBmdOE/ygul2dCqrCczRM0KX086A/icDgw\nm83HJehyOo9k0aNzGI7Lhx4bU7+kvFU1NFTb6RtNohRm3U0TO8nedR2yHCeR8KJEGgDIWWYoFrOY\nZCMBkwu7kqIjWuCfxXeoS8jsqzPizJRCZoRE0oMxEWP7nHlUGA2snhxFyuepRSZqtmGMNiIbq7Ea\n1HoqU1NTGIo2fDUuHBdegCuWIkecYj4CgGPrNmYWLALAeMu3UNLpI0IWj0V+OoW+3IqYTc+3zPVC\nUSG5ZRpFUciNqV/2R7PQQXW7HEQTdI1Th9pZt0vtYUFvdjdj1VvZ6d8JgH7ZB7nK8QI+SaD4o0Q7\n+2mRLeQNlSwp+PhTdpTcghJuvbCKosPFbVKSrxpyoCjcVZUnOlCCvHcZZqmczyYr+ZXpowxV1xKZ\nOXDU2O3Nmzfz29/+iODAjVhrnkOR8lRWHqCxUW1jVlZWRmlp6atGu4QCEaSiCV3edoSgZzKqTzuX\n8yPLf3sNlplkjhK7erte7jQfNcrl6VCMUiWAN1tgORvZmcgxkc3j8/mOT9BTBSSLAZ3DiJwsoBRf\nISmpmCfmn6a7cS6fzJhpWrKUZMHA1OaHIRGAP7+X+ORTAEQiNRSFk3RRIm8JYIiNY5QM+PVl1Iko\npsxKjPYg87MBJkr0uIt1GM0BJuJVhAt59rnKua7SS37XLgBa7EaSFjMt2QaWlS8D1LuBaCqI01KC\nJAlql8zDGI0i63LkTUFQFEplmfg112IuFtHffy/w8giXo5GfSmIoP1yL3FjvwlBtJ/rgADO/2kNm\nfxBdiRnJajjq882z4YuSVY8wnFoSeGqdjcZro/MKaDwLmtYcekgv6Znvm8+uwK7ZB0yI1Z8iqN/G\nvEQz5101n/9YWYleLuAwryYmmfgKAQxFhR/YqjnvnDN55/IWaiMzPFZhIF2MkS4k2FRmZjwvYy1m\n+I24Gd3cDLu2bnnZlEZHt7No8cPY9AN4dn+IrL+dsvIBJPtsgogQdMwtw+X+GTt3foqxsdtIJI6y\n+BePIclGVdAnpw59ebzYMs9kjt0z9dVQXS6qdacK+pEWel5WeD4cZz47KeZXslJsA+CRmShlZWX4\n/f5XTUaS0wUki/5QNqOceIXQxWAf8ayOwbo2tmcKuM66BoFC37OPwl9vgFSQkN2CokAhUkshv45U\nzkzeEsAcHcJoTjFBNXVSgXz2LBRZR3lQrctT3eRFpyuwfmwFO2taUITgukovqV27CbaeQ0U+T8Ti\nQEra+cqyrwAQCMwgk6e8vIJNkQQ7MhkmJdVyz5j9OKNJyv7hHxgtyjQ4bbivuRp0OozNza94TYqJ\nHHIij+FFjSuETlD2iYW43zWH/HSS3EgcY/XRrXMAQ6kFfanllMsSBU3Q/74pbYEPPqCmeb+IRb5F\ndIe7SeVnI0GW3sBTzl5cRTsLix0snH6YNeEtrPVejN71AaYVE+/dlyW+S21Xp9/2K97lHyFos9N
X\nVkm6mOC2Oh2ubJrz92xhBh/rWs/l/kceJvei8MVisYjJ9BeMxiI7d17MxpCFsckWDIYcQ+MPMhAd\nQFEUbNYHsdnChEIb6O75Kps2X3Ko3MBB0tkUdpsdfd5GLp8jGlULhB0h6Nm/3e0STOQOCXqZ04Q/\nlj1CoHfEksSLCvPZicnUQZuznHrJz0OzfvRMJvOq4ZdyWvWhH3QLvKLbZUqNcIl61cYxEVclte4C\nffu7UYbWw+U/wu/2kU47MU7vRgltQ4o5yFtmsEV7ydniRISX0oJENF1GemoOK8rWU56bocvSSX/P\nCkaiTexpX8JC8tTo9Gz117Gr6hriOzMUJD0Jo4XeXjWNv2fPIGm9kfvmVPOuHX18aN8Qw81L1LlK\nMp5wGM/738dwOke9xUjlN75B81NPYphdYzgW+Sk1OshQcWS3IKGTsJ9WRcW/LMd1SSOOs2tfcT/u\nd87BdVHDK455O6IJusbLWFS2CFmRuatXFcnRbJg73aoVnNvZC5t+wTszPSTMdqaca7BG72VhVZaR\n/SEi41HY9SeuibmQFIU9HYvZUl3CHq+JxdOjfKoixOnTPTylP59oBQwPH04y6uv7M07XJIr+veRC\np+GXokzEnBQSPsLB/+Nd976L32z6HMnU84wOdTLY9UFWn/Y0TscC+gf+55ALJRXLUiSLSZfAplcF\n4ODC6EGXC0D2b/Sjp3IF0vniYZeLw0yuKBNJzdZ2GRrgv//4RyQUOpW92O1tuNzLWCo/y6ZIEv1s\nI+0XR7r0b9tEZPrIO4aDFvpBQf/RX+7kSz/47dEt++m9xAoWQm5135PZPG2djYSyVv4auYhw2Rmk\nLDmScTeZYorHzr4SEamgYJnBWaZjxqvO3RIrMBMokO+pweFKsMrwDAfoRA658PgEEVcpl+gEd96y\niQnXQqo9KSwBdT45TwmPPfYYf/7zn/lrfw9/WXEe65C53OcmKyvMzGvHOFtkLFdSg2SzMZzJUWc2\nIoTAUF7+qtc+P6UaGS8V9INIZj2ONTUYq45toQOYWz1Y5pa86vHebmiCrvEyVlSsYGXFSv5ry3/x\n5fVf5o8H/khcnwLdAJmte2Gmm0ubV2MrKCwKDlEfuY+n7Xch6QS7H9gEmSiurI/lWegrr+X5tkYM\nxSLfnLuSBYuWcOWWACXM8GDT2dwzOMqesSjffXQH4xM/IBH3Ms3lLJAraDXWAQLX2Bqc0jgNxiLW\n6P1k4yamDlQyFQxjNtfSNOdzZLOTjE/8BYCh7ikQEB7aSXJMjTDp3q26kDKZcSyWutm/J4iHZkiE\nXl73ZO1MlO8OTh71+hyMQT+4KFrhUksKH4xFf/6O2+j2VtIkjyJSOlyuUtyupSxVNiAD+/Squ+Cg\nHz0RDnHfd7/JC3fcdugYSlFGyRZnXS4GZkSMRHyYYnic2zeNICeTFEIvauA9vQ+/zkfCpO57IpuH\nVYto/4cx8nOD/Pm7H0ZvTJEKWtm24oPsal/KHuNSFKmAfmE1IbcaIaLvHWK651eMDo1TzEo0Tg4h\nC4mZEheJCgvGXAbro9PEZ9Is2PMzLnh3FY1qwzKqliynsrKSYDDI45V1mPMyjy9v42dz6/EadMRc\nBpS8mg2aWnU+M/kCqaJMveX4QwfzU0kkmx7JfnT/+N87mqBrvAyjzsgvLvgFn1j4CR7of4Dbum7j\n9OrTcVRDLldPSn8p8Sdquee5JLdsV3hfJMRz4WeoWGDhwB6FrLGeYlLiwqyRmMVGd3kdl08UKcvq\noWoxZXKAG/puwyol+W+Dl5v399A/eRuKEmJicg17/UU60HFhx+m896yrqJw4l7wscXM5lBsUxtf7\n0CVT5GU1xtzrOR23ewVDg/9L8ZdnMLZjOwBSIcecxechcjl2b1jPjrUPkslMYLU2YjSWkslM8PCP\nvsddt3xVjZwJdrF9Wn3unVNhfjA8Taoos3XrVn75y18euj4
H67iU2g/60FVBmopmmB7oY9/ePUz5\nqpknbyOZ9OB0OnG5llLHKAZkhosKNpvtkKAfWP8MiiIzun/PIev7YFJRMJhhYiTOVr0aFWQRBb73\n0G4OfPtL9H7+PYfmpEztZdhy2MKdyuYZmn4YoylJefMwrZepd1iGpI0dLapb44BNTQYqtLqYspiR\nlCKOcAwhTBhSPnpvm0NysxV7JsXm8nkM+Cpp799LLqhnnmeU0lg3ls5O1vh2oisqBBU773vf+5jT\n/g5CdifXmrx02C3oJcH5FisjFFlp7qSq6KG0pILRtPrFWGc+fl/2wQXRt3uZ2xOFJugaR0Un6fj4\noo/z6wt/TWdJJzfNuwnTmWcAekKJTyDZDNQurcAnl3Jl1IdR6NngfYh80UCv6VMAVGwLo5MVFEnw\n7qE0O/b72RAw4zDEMW5cxif3/ohr5D8xoldYW38VzwfPQ6TrWP2nn2BGYG300NDaiK5gZSRaipks\nxVgd4SknQy7VXdLX14cQgqbGz5DLzzAm+pgeVV0ZeuCcG27CIFegWB0M7dxGJjuB2VSF2VRFJjvJ\nzNgIMyND9O3eyqee+hRfWv8lAGbyeYoK7Imn6OvrY2Ji4lClx4MWusesY3Tfbnyzrhd/LMuGu/7E\nZFMHiiSxQLeFRMKL0+nEYHDitLdQIUUYSmcPLYwC7H/2KRCCRChIZHoSEgHk2Tj6rm1+7r99I2O6\nIF6D6k7x6lKMtT7B1CUDFBNJpgJdbEpH8dtnXQiKwlMDkxiNveh0p9Hacg+p1EVMBhrY7L6QnEFg\nLgbpdaop9r3xbiYMPkryUYp6PUbHNTRXNeMO5dDrjDQFptjq6SSnN7KoextCGHB1rcXc3o5kMtFh\n20hJMs+eyRhTg1H+uFt9bT60Wo1YSWycZNUz0yR1gofc5VyaX0JDtMhwZlbQLccn6IqsUJhOHdPd\noqEJusarsKJyBX++7M8sq1iGqaMRY50V+xmVlH9qMa4L6kEnIH8R58UqeSR1L5K5l/FACwCpYA77\nRIq6viHm+GcY75nis3/dhS4pIxk6KW5wUrs1x+fj38FBnF96Psr46AjVQv3AGmsd6L2qOyM0ORdz\n2kzfo2YWrLmQGW8GKZumezZr1ONZgVfXwHCthXB2tslBTQ3uMgdWvZuiTk8oMEY+H8ZsrsZkriKd\nGiMz2yLvobt+gT/lZzwxTjQbZSanWsjbYqlDwntwYfVgHZforvXc8fUvsfMPP0EnF5gc6KN/6yYS\nK8/BIueZQx/JhAf7bISOy7UMnzzMYCpLqc/H6MQUjzyzlcDIEIsuvBSAsf174XeXI6/7GQDJVIGo\n1IdJMXK6ThXIGxYFuK36vXy15D9J9e7iV1v+l29bmwm7S0FREPE8k4kZDIYcpf1WairaSFf8G1/e\n+VkerZ1H9UyKD1aVMm6xEMZNoriLCaopzyTISHkyQmbesiZciTRYHLRMqNfIF/RTG0niMIQQB3Zi\nWbAAAF1ylCY5wgRFHv7pbnrrTbRbTNRb1dcucyDEasWAWQjuKJMIIFPqz3CvP4xFEtQdR/9ORVFI\nbphAycsYKjVBPxaaoGscN0IvUfaJpbgva0YYJCSrAev8UtLyOVw3HeT0vI3Ha55HSqu3w89bp7EW\nv0C79GvkTJhFuiLT4SRyzzhCCJaseAeZmIXIs81c/XgYZzrJrZdexzNLTyNdTDHQux1h0ZPQpamI\ntmNZK5MO61hyzoXMnbsCXTLG+OQkhYIqvk3xSvJGCXvNHlAU6trU6nsVFRUgBKmiarmbzdWYzVWz\nUS4Krqoq8t2TzBGqxdoV6jok6FsjCcJhtdxqJBIBYGbWQk/P7KF6tZ++zeu41n8/6efvw2Sz0e0q\nY77kR0+RYqES/WxtErdrKWXKGIPpDAWjA0kpMvjYBiSdjtOueS9Wl5vRfbtgpgc5rNYaiehC5E1R\nmtNupKk4oljEZ3mcDZx
Bj+hgbGgLxW1lLIqdR9BThjlbwIsEVglFEeh++gzf+sf/4l/v3I27PUXc\nYmb1YC9XWVVf9oHcUuy1aaappDxZpChkiq4CzqYGPNVnoBhM1E5EaE/mOH3bMxSLDipK85AHy5wK\ntQBYbIIFtghhu45QsciwR8el5e5D75tCKIO9zMa5JU7kcgsvUOD5Qoa1MzH+tbESq+6VZUjOFQnf\n0UPkgQFMrR4sC31/83v4VEcTdI3XhW1FBSgWWqjhlrF+wo4DmPR5knKRTXO/R4XDwQ79MDl9BmtO\nYfXkXowBNdJklyOLK5ZAWE1UimnetX0dvuA031vdweXnevnw3kE+8b8/J6AEaE6WsidSQbk5juXu\nb3LN3eOIdJSiLDM2NgaAc3wYJa/H5BxFFPLcm3maJ4efpKFe9RnrnGpijtlchdlchaJkeX7Vuaxr\nVTvkfLBwAQB7Z7oIF1Qrf2s0ccivfVDYg4kcNqOOvLQB3/wgZ368A082hGG6j5p3XMNQNs8i0xDZ\njBk9h6Mt3O7llDNFWoaBiCqo5RmFxsXLsTpd1Mydz+i+3ShyESWZJSZSTNn3IPI5ytNmXI5yrJKf\nEX2GuFCrPG6OzyByehAQKq1ElypQpmSJ6e1k/U6klOD8zQ/wn5dUE62yUhIrcNX0DjonnseWl9kv\nLyRsLCEvjHhny9fayuNkeqwk5yxWr1c6zteTbjpHDoBw4ZgttGWxTEE6DIUMcyw6CjpB/Lo6ZOCS\nUnV+iqJQDGfQe0xc6nOhmHTc55P4bouJxUYjH619ZXEuJvP4/3cnqZ1+nBfUU3pDJ5JR99rfqH8n\naIKu8bowNrrQu2TShQvwkOOrZf+C11ygzz7IJ5d/nPuvvJ+za8+mx+pH0tu4cf/DyA6JosizeXwM\njAuRdRJ+Q5zGvIPz+/7Ml/ammTcVZqp8Dvd0nsY3FtRgkn3MZG3UxA2M/WYLhg07sSbCoChqrfRC\njmhwlJv1v2GPdw6ikGNXdAfB7/8Phm9/CSELDC5VmM1m1YeexcjmeWfS62hEavYx+cI2asxV7Aiq\ngjXHIOEvyCSMqutgbFANfQwm1aQiYVbbmcULD+E/61LGK5cQXXoGAPPEFlIJN1LxcAnc6EQS12xk\nyvhERH3QZKZh5VkA1HbMIxEOE82bGUjpuM+4BVkq4A4nkXWNGIomquu62clSJFlGKDLbDFYUXQ5d\nWiFqc+JKRijNTJEVZmy9dorVtbjSMeI77iejr+a07iw+W4L81qdZHC7QpW9jAvXOJBZWhTIT7ae7\nb5ARKYBVNqBXZohMx8gVcgjJjuX5+9CZBYbQcxBXI4EabWp0zV25FNUmA/Psao11OZFHycvovGbO\nL3EiFIXdC12kdfCfYT26V1ncjD0+TCGQovTGeTjPqzuU7q9xdDRB13hdCCGwragip3QSFl+keXMF\nLruLhe85iw/P/zAGycAXV3yRfeUxhM5AdT7LXYvTpE0hOqdXY0mrYlIQRXxFC/8c/yhXjRf4yj2/\np3rya1y2Oc4+n4NPnFGJo2imbDRK+dl2PO15avwJRCZJd283BA7wuHs5CcnOVutiZCXLf/wxT+fj\n/ZSvnIspZ8fgVkAWmEzlmIzl7GYRBb2RqN3DRedcQzoeY+VAGakXVJEq3fQ0ANNONyKXYaJfFfBg\nIseZwa1YSlPI20xIsp4VTQ/xQtlZbEhmKTPo8KZeIJEspZg6XESsZ+N6TH1qaKNJGqREthMwpIhb\nGwCo7VRbvK0rLOUupRSrYsI2cIDVQ4Po3WZkXRxf3QRbC6fRNtxPdWGMLqeaQFNMFsnpDTRmAvgK\nakeiWJee8quvwH7+eYyOqgu6HaM53NUlpPccYElYZtJQygHmApArmFiWn0NBkXncuJsxXZha2YfI\nBglPqusIVpsDUy6GubkKMboJJncD0OBS0+ln8gUuLnUdikIphNXz1XvMuA166hQd6CQ+E
dVR1RV9\nxWzZvD9FcvMktpWVmFs9xxyncRhN0DVeN9aV9SBkkukzsC4tp/ozK6hpm3Noe7W9mo75s91tnB6e\nnmek0ZxGL5uwOIOUyGpVvNvK7iFclwVRwDjYRUJM4A3v5DMvREnqBd+56bPY7rwb73/+iZK2IK5M\nDkMyxvTkNJmR7dxXorpMDkhzSVnSlMTh1gslvF/6KFWFBCZrmkJcT7ann+xj29mCOqeIw4k3IeOt\nrsWxM4wpolqb169cgU6WCdh86PJp4qkY4YfWoh8Z5NypPyN0YOmyU7q/nRLDHhrsz/NsKEJHag9C\nFEkkPWTDQRRZRlEUejY9T43Sjo4CM9WC6qKHaRFlar/q2/dW12KymthpWUqZbOCSzHxqAlGskyP4\nWj1Eap4irrMzaqylfaSP6uQo/ZZGFEUQNqnCqEvlqLKodxjhbCnm1lZKPv1PjPiqcKVSmPMKntYW\n0n6JJWHVrbSOc7AVE1x9eiOL5AY+0HoZ4fYwUXOIucUavPpSYv4+ACrbGyn7l3+h5MMfARTYeqv6\nGnvKMc6K+CU+16HXvhiaFfTZxe2vz6vjihIXN1eVUAxmKASO3SM2+vAgwqDDeV7da39T/p2iCbrG\n60ZnN+K9fi6lN3bivaYVyfzyBgXnd6o1su+8sAyR+ig1c2owSEmqTUV8igmkIkOeGWpvWobjrDzk\nU1xpXMH+8heonoKfbk2TtFh5SDZAaTOG5sXUVCvoEmoExqbdg6zzLqMiEaEgDMTKVYtuuEzQs/s2\nOkQPJlOSTNbK5Bf/Bf8Pf8MOlmKQsyiSxMb127j8M1+k+YNX8PQyddFz5ZIllMdDxKoES9Y8gWzM\nsOeW3/Lx+75PtkVBKUJFySoMf54gJ9o5b+5zhArQJj1NeGgFwWAtSjpFeGqCwPAgkalJOlaswUuS\nKadEudeFIhRis80+hBA4yr0oQkdjqgSlWKB5agopnaW83ki8fAu7Mup1rIoGsMfjpISNuNlOaDYW\nfnfUxxyvKugzrhJMra0MevLsaaqh2h9Hr5OxLLyQXLqU9piMRc4TE27Kc0GuXNNE+WeWUnn9fL53\n7fdYc8U5eCUHZeY6Chk1pb+2s56SD92E7aJrwV0HY5vV94Cziloh4USw0nV43eCgha7zqIJ+Ubmb\nny9oxN6hhlhmDrwoOepFZPrCZA6EcJ5bi85+6tVcOVFogq7xhmBd4DuiLOlLMXlUq3du1U1M+stZ\n8L5zeO+/nkZbtpE+5xAvVDzI8vL5uM1uzG1q2OO1hlXMX9KIbMjQnJDxxZKs9auhg3S8k8baPgzJ\nOBQL/D7eQl4vsSawBauSwF/aAMCEF7qGnqKpMoXZmCRdcJHp6mFzVR0pYWdpShWkqZFRXFW1nH7O\nFRT0ahci3cQuOq1bGTVUo7PmKSkfxj//YoyFHMUOHZmwA9+Z5yMHw1jT7+eZ9DkA2B67kshzK1EU\nHaKQxz/YT++m5xFCorF9HuaMYFr4qFgcQY+eQvpwoa6M2YYo5HDnzeSUArbZaJtCapCsY4zt+eVY\ns2lMcg5LThXxkEdH1OFGKsrUGUfxKOoicdBXhqGmhucnNuAvqcaXkbBExhj956+Qz7rQK7Ai3K2+\nLiYTOklgKFPL0hokA1e0X4mpzkmZpQ65oJZoaFqsXleEgPbL1L9tPhSdgav7Mtw8WsDwIj93MZRF\nshuQTEcuZOrdZgwVNtJdLxd0RVaIPjSIzm3Cvrr6Zds1jo0m6BpvCpLdCBK0mHwoCgxHUuQG1Loc\nz7l3MmkqcJFe/UIw1tUijEYcYxFuWfMtGi3bKSgK7YEcO1JpcrIMc9+J2VHErZMwJLL0VLrRFVLI\nXS+wQNnJPvc8MmYTKbONrtQ47kUrMJrTpPMuRLtgwyXnYpIzrMmrgp5ylHPf8z2UWkoxmyrQUaC3\n/wZaOEBemBjMdFBZ2cuksZQPveOLWNwR8ulafjfST
85opGUgxLrMeZQk8jgiOprqVQGTCkUm+/ro\n3vg81U3NjF96GRUDIaapIiQ9hNfiIyqFGJ2Mk81m8eNAH4tgkswUixlSJlUcxwKPUxASe611NE1P\nkDFbqI3rsCoJgk61hosnnuHsigMYkHFnkgTrGhCSxDNTB1AkI2WyE29zGZn9+xGzNctX16hfnotq\nj17l0DzHjdtYhkEyIOnMOLwvahrR/g71t6OS/FSKa3szXNuVQs4VDw0phDPoZ63zl+27w0tuOIqc\nOrKKZOZAiPxkEtfFDadcedsTjXa1NN4UhCTQOUyUyKpA9fkTxHf42U+R6oolmBQ41z/retDrMc6Z\nQ3a2cl/d9AMkiworInpyqMk+eJugfD5ul5v5+jaGS8vpzE5QmXHTmdtP1OBicEEH3rSZ/SYrN9rP\n4w5xPemsjYlzz2NdVQsdiT1UGVXL80DzuTx959MUZQW7pQaXEqaYhuzeGgD2R1twufwIaxcr3BMI\nSWE64SM0E2Vi2WKSGzeieE00TuSYV5wilXcDIMlO9j+7jvDEGBXRJEo2y8rJSdLCwpg8RJt3kpTI\nsmtDN709PRTRYUnHMekcmJJhcssXUZAgktxGr9JGWq+naSqGbDJj84eZQy8TzioirhJK4nk6Snfh\ndi/HFwwyXVbNur900TujhmvaJzKULWun6cEHcF6u9t88u0mN1V9YevTCWKYmNwKBz1yLxfmSO7Da\nVWDxgquGbK8a0olyuCIiqDHoOu8xBL3dCzJk+iJHPJ7pCiFMOizzS1/pLaVxFDRB13jT0LmMWFIF\n7IB/MIIukOZx8nxywee503sGrqHnIT8bBdLcTLavDxIBxPgkhVyMNSEFFIV1s52A6Lgcj6GDVFk7\nOb2Bpq5hREGiKT6CpBTZPr+TZYUEO8o+yGNJI/eLq/lL2RU8bJxPzOSmM7ALr3UCHRB22GiImHhw\n9wRmnRk3Efy7l6CPGjEWIjxlbkaWJdxzNjLf3Y1ShEBcFZze2gq2xDMUJUGjX6ampIykyGBSJKxF\nI5lECITA9fxGZpadSW1O9QlPK+WY6u6ktHSI4P7d7N2xFTtJ6qrOwqCzY0qEMK15B3ecfy4Plszn\nQfk69LJMfcAMko60yNGQG2bCWk3IVYI3H8Npm6DUcjq+YIAxs5u9T08yJ6CKty9axF1hxVBRgblj\nIQhYUOpg62lzOdP78nZtoGbrKpJCmbkWX33VS15QPVz/Zzj/a2T6IkhWde0kP6kKuiIrFCPZY1ro\nxloHwqIn0x0+9JiiKGS6Q5hb3IhXSTjSeDnaFdN409B7zRSG4zyKkws3h1CAp8jTWu6hof1KyKdg\neD0AppYWCpOTFAc2kwoYMYkg5UWwhvM8OaMuhCodl2MzzGNduRFjoUhFKkHDitPRp020coDnGjux\nWeeQdF7IxaZR3q/cSre7mT81L0WSCzRNjyCh4ChmKZjDWK0N3PngATJFBXcmykRmPgbFDLlegs5F\n/LPyc77Q9G98v+0GJoKV5LMmZGRCwsQd512CJVvkjOExjJKdlC6HDQuVfrWolt1Ygqkgc/eKK/Ga\nVGEU3SXoChk65j5H6fxfUBB/paN5A1LD0xRK9xEjz/tr5vKrKz/CbZ5/YKe+k/NCKTwx1a9eNBqo\nSYyiCAlZp8dnUcMV7cFavNEYQaMeuTlCf3kEZ0rGVADPbKcfebZfqZAENa9QHEsYJPRVVnzmOpwl\nR7GY61aiuFvIDUaxLipDmHSHBL0YzYKsoPMePbVfSAJzq4dMdwhFVtcQ8lMpirHcK67HaBwbTdA1\n3jRclzXhfU8bj5fpedKs8FytGWE34DAboPFM0Fug5zEATC2qTzf55KMUUnrctRZ0QtA2lWdvMk0o\nXyBSrMNtcPNMmcTSmRQoCpPhCEq+hMVso7ukkd/V/wv6bB/twb9yCQ/xzmfvQaBQOzNJ3loBgDUX\nY9xjIGXbydnGp
4lhwzdjQdYXSYg080OPo88NUR4LcDEPkhVG7tFdi6To6CzUErS52Dh/Cct7s3SK\nPGmhkHEpWGUztahilleasVx6OU8FJVqLdiRZYXCskfJvKSS3nM6z5tOIlQt0ngT26h0EOv7Ib5c1\nE0bwr9t+xS+Vf+DJfT/kW7uK1I0cbv5cEj5c4rfBuZN8qIHscxLmgpW0SWJH5bNEfEYWltk574Md\nlNaqEShyMn/MFm0vxdrmw2Mqp2n+MhRFIT+TPsJPnh2OouRlTC1uDJW2Q4JeCB2OQT8W5lYPciJ/\n6DmZbnWR1NymxZ3/LWiCrvGmobMbsS4uwz/Xw7eyCe4UORpKZgstGSxqO7zetaAomFrUxbrw42qb\nOutStdTrBdNFFOC5cJyJjZP8ttVE2KTnsp5pJASxWAwp42EJqujpFT3nDK8jnx1Cr/eyfCzCBzc+\nw7k9u1FCqqC1lpiRbXYythnKOh8khpuyfCMZKYcwZrht8HFqQ//Dmh3PcWXybs7lcTaVnEHEYmeO\nUsGe2jYMhQJndcdxeet5UMkRy6ewKSZ8886jNSCBdSU/lVbhzMrYZKhMRxlxV5KP6fGbzuP3phv5\nNl9j99Of59GAm2GbjjsXd/C+yhIWl89gI4kuPIpSkJiTC6MrqF8UfcYIVYqa8NRmPYB5egly1IXL\novZg3SS6SEsldHpstJ9WeSjhR07mkWzHJ+jmOS4EAnefk+nvb2P6e1sJ/mH/ocicbG8EJIGpyTUr\n6AnV3fKSGPSj7ntWuA+GL2a6Qxgqbeicx18jXeMwmqBrvOm0lNkpyAo7RyM0lr6ocl7rhRAegmAf\nhqoqhNVKaiCKZJSwn9EJbhPnpRSMBdWP/vBUmN/OMXHd6FrOXP8w1UU1tlkXcVKpTHBt5Dm+3x3k\ny6OXU2aMk8yXcEbJRVybmMt7U0uYH1VrvJTo4iQtNs7O3Y9sUSgKHZmMmuF5pr4Ej6WEjtJOEsYs\n46PtvCt9N4aizObGuRx4l5HusirmTgyyIBdACIl7yJJOp3DYHejLO5kz0YsuF8YRt3Bxs9pbtDE7\nzsziRVRflOTu2mZM+TxGWeEPp5dzIFPJ7cpNmJQsX6h1Y65UkOIwmpwCqYC5ehFeJYZehgO6OO10\nYZcTuIlgn16qXkrUKJassYUCOtpsR4qqnDp+C91Y60SYdKS7Qug8JqxLysj2RkjvUhOiMn0RjPUO\nJJMeY5UdJSdTDGXUGHQBOvexxVlnN2KosZPpCSNnCuSGY5q75XWgCbrGm05z2extvwINLxb0lgvV\n3z1rEZKEaU4TAJamMiSDAdeaGkr0Ek3+PI8Eovyk1UBbqMBnlp2FyI/QWFSFwJwqwZD28JH6PVx8\n9XLMspEmSY/R78QqbOydeoR8OswcyxkYcmBJ7yYqKyQuLJIMq30vQ6gRN83xKpTKJbR6Opi0xpn0\nd9D85OdYNTLMgK+KP+kakBSZBeN9FEqNmFrcmD3qx8pR7qGQMLKtZgl73Dq8RR2+SSsyRZqIMKi3\n8yPjz9jutXHNcJ7vvtBHzOqku/LT7JQWcQV/xZXsQzZNYRwQBDMmkMfRVy7i3S1BrvK8gF9WuFa5\nna+If2cmbcSVqSKvKMwpqJZ40aoW2HqpoBeTBXTHaaELg0T5Z5ZS9e8r8X1oPp5rWjHUOog8OEB+\nJk1+IoG5WbW0D5a2zf3/9u48yK6yzOP497n71nv37eX2mqSzdiAJTSAhIBBA0EiIMiMUDopajJaK\nC6OFhaPF1IwzNWU5jhsMMiAoSokLRATiKGhEgZAACQ0kIUsnadL7vt/tnT/OTac76TXpTleffj5V\nt/rec849931v9/31ue953/fU95JoG8CZ4Z3w5KZvSTbRo1307W6GpDa3nA0NdHXOLcw7OZJwxBF6\nZinkLYN9z0DrQbw5Vl9u//lVAATX5BMTuKI5Tns8QWbUcH1NL2XL1+FbvoRISyM
e4yKHHLzdAfri\n9biLMji6fieOQBu5HQt5yPcoDf2H2dH8O3C4cPcVEIhZw9rdb91BYP+NAHijg7gcLvzxYuIZG2hr\ny6Xd1wMITm8mlXVv448OsKurn2WtreQMwGFnI60Lk6x31wLw09oeHAaS19/M3XdvpM/bxfKepXQ5\nWulpWUo3hm2rsggmBvlcUT4XRou4+5Ef0edLp6Q/ynsdT3Pk2IMMxN7FWetgoCNJonUP4nATXPY1\nPB/+JjEj0J+kmGMUvOjA7xAODSbJSl27M+6z3rslgZOBboyxjtCDp4/oHYsr0zt0RC8OIWvLIpJ9\nMVoeqgED3spM6z3MD4ADYvU9xNsHh0aIjse3JAuMNRGX+Jx4StMnXS41kga6OueCXheRTGs2vqE2\n9BMWX2P1dPneGrwd2wEIbLDmaHF4nfRXZrC5NUFVc4xvvd6PlFsffv/KKtj9DLcMXkrpu/X4HLn0\n9R0iHu8lmfEr+gfS+J/EUf7qs7o89sY7ObQig/Cbn2PxvisBkOXXIB9bB8DuvkLSM7IRkjS1LmHb\na078WL1BdsdqcJk4lx6sJ+h0cGWNg+6kcMzZymN/+g2eaBevxiNQlAuOKNc3/o03fn8rrxf8kbAJ\n0T2QS7DT2te+vBCfqH+CwqpcwMniRJxvP/Yk39kZJbs9zrHOZwFINoG7NUrsyG4wPfS90Ukz1oCc\n2gGrD7n/QCEAnQmDozBBKGaIiZdin5ug6+RITTOYgISZdJPLaDxFIUIbIiRaBxCfC0+x1e1R3E5c\nuQFi9b3E2wbGbT8f2ldxGo6Ai2RPDF9lFuLUGRXP1ISBLiIPikiTiNSMsV5E5LsickBE9ojImukv\nprKbE80u5bmBkSvWfQ6u+Te44T7S7/wh2bfcSOA91w6tLnv/QnKihgd39VPSkWD5GivMfFVVJDuP\nEShvY/DVhwmlVZJMDvD23rtIxJp4qCvB61lHeLfd2lev00/uZWVkli1mYafVDbC7LE5LwjqyvXjx\nArqaavA6aji2101nZxZbOqyh/Udd1jS61W+5+UNehIwmB4fTDpNWmMamTZv48p1f4tFPX8aPknfj\nZycDjg38pqOBqx35BBzCkcBb3HqbNbOinyS3H30Uj7eOXkcbpvQCLjzeRlmfIdIwgCGJiAtnV4Jw\nW5JYQz0OfweDBzvoaLEubr27ZSnvdFTywqobAAg27uKtC4+Tm7py05KAf8RbnOy1/hGcTaADpF9V\nhjPHh39p1ohpbd2FQaLHukl2R3FlTXxyUxyCNzWborafn53JfOf6MfB94JEx1l8HVKZuFwH3pn4q\nNab3LM5jMJ4g4DnlTzCUB+uta5K6gfwLbxixOpgfZIcTKpNCTTLO1YusCyT4qqymhd4/Pw6xPtLy\nV0Hn4zQ1PU1R4d9zY9F5LMpcxDc7Bhk84qHTncmSwnR8t65hWb0f9g7S+NKDdFxwOw7gvz54Hr19\nXXDkDYr7zufx6krCO9N5x/c8UV+QuMRxxP28+IsDiAMaig7RsSBMdXU1e/Y+wdbnvsLVgwmq1l9I\n/wtuvnr0P3EiHPC/yE9KnuBX2f9I0Ong1nRDbqwT3tnGK6FuLum+FBNqx5gk4bYo75h0Ap4iQr4D\nhFo9QAxPsTBwADx7rUnEnj96Jc/XXsc3ixIkOpIU/lMVD/Q8jTtZAORQ6Rr5Hif7rH9gU2lyGY3D\n4yT/jtWIY+RxoacoOHTCdKxRoqcKrAozuL8d31JtPz8bEx6hG2O2A6NPiWbZDDxiLC8BmSJSOF0F\nVPb08Q0VPHb7ujN67t6F1tF9Ta6LDL91lOnKysIdidD/2msApC9YD4DbncOiRXdx89KbubDgQu64\najH7QpV05S/F57aaITIKluITQ4Nx0/L2s+S4nXicDrI6akhfkACBgl3N9PmjDLhSF4r2tuFxx2mv\n7yWyOIuqyDJern+Zb7xwN7e8/M/8IuDhk9k
Bvi5baQi28XzOK2Rd38RbWfdyzNmNUxL8Ze1SvrZ8\nKYiT+v2/45nMPbiNG++SK3D4nbhCBZzXs5IlzsvJC0SHPqzeijzcxSFyD/nxOwOQtI6CSxMZNLpb\nOJabYEfDDnwhq6zlxwdGvH/ROqvZyZlx9l0DHV7XafOtuAtPniOZTJMLgH9pNkVfX6czK56l6WhD\njwDHhj2uSy07jYjcLiI7RWRnc3PzNLy0mo/yVoW5kW7S14ycf8S30mrGcIXD+LNKiURuYcXyb+F2\nn5yf+/LFebD+Qyy9atPQMhEh3+ulYdH7aY4bcvuOQ08ztB/GWbYc74IMMFB/0SDdLqs744CvgzKn\nNbHXwtV5VOdXU99bz9aDv+W2ji7+fNG/cucFd/JS1yvcVvo1eq/yE1r/d4Sv+XcM0NrfSpHPg9Pj\nh9xKXmzfS03gAMZnMIPgDPkgLZ+sjgHSuwbwZ578qLoKCvAvyyG7I0iZp5gTrR15gx7qPE1sq91G\nb6yXRdlW80Wkph0Ts+ZzMQlD9/Y6PCVpM3ax5eH7HW9QkZp+Z/edyzLaGYxRL0NijLkfuB+gurp6\n7EuVKDWOSytzeaAgxAfOG/lF0F+1gu5nn8VTYQ2qWbrkX057rojw8MfXnra8wOumQXKJhVeR17oX\nfvYf1orIGtIjpfQXtxOuzqKvxpohMpSbzuKWlzkaq6ZimZ9I+gdoPP4KH9rxMxas/gQsu4GPAVsq\nt/Dc0ed4b7k1j3k4ZA37b+htoCBojVQlfwV/a2wmJxEjuCJM365mHH4XpBVAyzvQkY4nUgBYr92f\nEyIYsAb7XBBdwfEMP8c7+gn0JGjKbGN73Xac4mRLZCEHe1upaO6mb08zwQvy6dvTTKJ9kMwPLBwa\nZDTdnGkeHCE3yYE4jjQ94j6XpiPQ64CSYY+LgePTsF+lRpWf7uPZL1x22nJflXWE7llQMeV9Fnjd\nvNHdT8KTzQUZ+bDnVWtF4Sq8/ky8CzIJJKIMuKzmi5LSBVSULeaTf7gJuddJsGgVX+44CpmL4Op7\nhvab4c1gS+WWk2UPWN8qmvqahpaZ8Ap2dL3EpTHBvzIV6AEXpBXC4e3g8uKKlGP8B4nGBvlrz242\nlW8iKjFW9Cxgd3aAeMcAEjf0pA+SMAlW5a3imnA+V+eFadyxi54XjxNYHab7+WO4CwLWTIczyBMJ\nEe8Y1GuAnmPT0eSyFbg11dvlYqDTGFM/0ZOUmm6+qhU4MjIIrJl6R6sCj5uGaIzmWJy8ktWw+h+g\n/FLwZw5t43F6oBh25e5iWWQZbPgi8rHfwqVfAqcHxAkf/JE1jcEYTgR6Y1/j0LK2nAranU6WuTPw\nLcpEvE5rWH4oHwY6oeUAkl2Or7yC9kwnz9f9CXE52Buopaw9n/LcAItTJz4TWdZ5gfUR6xyCiBBa\nV0SsroeubbXEm/pIu7xkxoM2c8sicm5ZNqOvoU434RG6iPwcuBzIFZE64BtYHRAwxtwHPA28DziA\n9Z3wtpkqrFLjcYZCVP5lO+Keene8fK+bvoTVzpzrdcPm74+63cKChTzZ+ySLsxZbV+2puMy6TVKG\nNwOv0zviCL3Wn+rCGShEXA7yPrnSCvQjqSalWC9klpH14bW8vG8rL7z7Am0Dbbzu38vKlkV8Zl0F\nbS4f/K0RV54PjsD6ovVD+w+syafz2Vq6/1xndTNcmTfVt2fKXJnadj4bJgx0Y8zNE6w3wGemrURK\nnQWH58zabAu8J/8J5J7alXKYK0qvYH/7fhZkLDij1xERwoHwiCP02rg1HXB5xLpotackNTd527CT\nvlllZF2ymZJ3I/T/4dM8dfAp9gT3Iy1CTluUoHHQ53WyqqKatYP7qcqpGnqqw+skWJ1Pz1+Pk/ae\nYh24Y2PT0Yau1JyXPyzEc91jfyw2lm5kY+nGs3qtcCBMY+/JQD/cWYvH4aHwki+N3DBt2EnfzDIA\n1hasJeg
O8vj+x3nXV0fSCYOHOoi39OPK87OxbD0by04vX9oVJTgCboJrRr8ykbIHHfqvFCOP0PM8\nZzeCciL5gfyRTS5dtZRllOF0jLyQ8ohAz7IC3eP0sCGygdquWmKOOBR7GDzUSbypH3feKaNuh3GG\nPKRvLEVc+pG3M/3tKoV1UvSE8ZpcpsOJQD8xn3htVy3l6eWnb+jPsk62ejOs+ylXllw5dN+3IItY\nQy+JzkFcuWOfjFXzgwa6UkDQ5SQtNc3reE0u0yEcCBNNRukY7CCWiFHXXTd6oItAqACySkcs3lC8\nAZe48Dg8pFeGh0Z9uMIa6POdtqErlVLgdSPRGL4ZvjhxfvBkX/R2RzsJk6AiY4y+8+UbIDCyz3i6\nJ52LCi+irqcOb0k6uBwQT47b5KLmBw10pVLyPW4S52D8cjhgXSmpsa+RWNKa+XDMQN9y76iL71l/\nD13RLsTtwFuaxuDhTlw52lVwvtNAVyrlU6Vh2mPxGX+d4YOLugZTXRZHa3IZbx/B/KEj/eC6Ilzh\nAOJ2TvAsZXca6EqlXJVzbq6Uk+PPwSEOGnsbaehtIM+fR8gTmviJYwiszCWwMncaS6jmKj0pqtQ5\n5na4yfHl0NTXZPVwySif7SIpm9BAV2oWnOi6OGaXRaXOgAa6UrMgHAizr30fnYOdY58QVWqKNNCV\nmgXhQJiW/hZg6idElRqLBrpSs+BEDxVA29DVtNFAV2oWnOi66HF4KAoWzXJplF1ooCs1C04MLipN\nLz19Ui6lzpAGulKz4MQRup4QVdNJA12pWRAOhHGK84wvlKHUaHSkqFKzIOAOcP/V97Mke8lsF0XZ\niAa6UrNkbeHa2S6CshltclFKKZvQQFdKKZvQQFdKKZvQQFdKKZvQQFdKKZvQQFdKKZvQQFdKKZvQ\nQFdKKZvQQFdKKZvQQFdKKZvQQFdKKZvQQFdKKZvQQFdKKZvQQFdKKZvQQFdKKZuYVKCLyLUisk9E\nDojIXaOsv1xEOkXk9dTt69NfVKWUUuOZ8AIXIuIEfgBcDdQBr4jIVmPMW6ds+hdjzKYZKKNSSqlJ\nmMwR+lrggDHmkDEmCjwGbJ7ZYimllJqqyQR6BDg27HFdatmp1onIbhF5RkRWjLYjEbldRHaKyM7m\n5uYzKK5SSqmxTCbQZZRl5pTHrwJlxpjzge8BT4y2I2PM/caYamNMdV5e3pQKqpRSanyTCfQ6oGTY\n42Lg+PANjDFdxpie1P2nAbeI5E5bKZVSSk1oMoH+ClApIhUi4gFuArYO30BECkREUvfXpvbbOt2F\nVUopNbYJe7kYY+Ii8llgG+AEHjTGvCkin0qtvw+4Efi0iMSBfuAmY8ypzTJKKaVmkMxW7lZXV5ud\nO3fOymsrpdRcJSK7jDHVo63TkaJKKWUTGuhKKWUTGuhKKWUTGuhKKWUTGuhKKWUTGuhKKWUTGuhK\nKWUTGuhKKWUTGuhKKWUTGuhKKWUTGuhKKWUTGuhKKWUTGuhKKWUTGuhKKWUTGuhKKWUTGuhKKWUT\nGuhKKWUTGuhKKWUTGuhKKWUTGuhKKWUTGuhKKWUTGuhKKWUTGuhKKWUTGuhKKWUTGuhKKWUTGuhK\nKWUTGuhKKWUTGuhKKWUTGuhKKWUTGuhKKWUTGuhKKWUTGuhKKWUTGuhKKWUTGuhKKWUTGuhKKWUT\nkwp0EblWRPaJyAERuWuU9SIi302t3yMia6a/qEoppcYzYaCLiBP4AXAdsBy4WUSWn7LZdUBl6nY7\ncO80l1MppdQEJnOEvhY4YIw5ZIyJAo8Bm0/ZZjPwiLG8BGSKSOE0l1UppdQ4XJPYJgIcG/a4Drho\nEttEgPrhG4nI7VhH8AA9IrJvSqU9KRdoOcPnzmXzsd7zsc4wP+s9H+sMU6932VgrJhPoMsoycwbb\nYIy5H7h/Eq85foFEdhpjqs92P3PNfKz3fKwzzM96z8c6w/TWezJNLnVAy
bDHxcDxM9hGKaXUDJpM\noL8CVIpIhYh4gJuAradssxW4NdXb5WKg0xhTf+qOlFJKzZwJm1yMMXER+SywDXACDxpj3hSRT6XW\n3wc8DbwPOAD0AbfNXJGBaWi2maPmY73nY51hftZ7PtYZprHeYsxpTd1KKaXmIB0pqpRSNqGBrpRS\nNjHnAn2iaQjsQERKROR5EXlbRN4Ukc+nlmeLyP+JyDupn1mzXdbpJiJOEXlNRJ5KPZ4Pdc4UkV+K\nyN7U73zdPKn3F1N/3zUi8nMR8dmt3iLyoIg0iUjNsGVj1lFEvprKtn0i8t6pvt6cCvRJTkNgB3Hg\nTmPMMuBi4DOpet4F/NEYUwn8MfXYbj4PvD3s8Xyo838DzxpjlgLnY9Xf1vUWkQhwB1BtjKnC6nBx\nE/ar94+Ba09ZNmodU5/xm4AVqef8MJV5kzanAp3JTUMw5xlj6o0xr6bud2N9wCNYdX04tdnDwA2z\nUsAZIiLFwPuBB4Yttnud04HLgP8FMMZEjTEd2LzeKS7ALyIuIIA1dsVW9TbGbAfaTlk8Vh03A48Z\nYwaNMYexeg2uncrrzbVAH2uKAdsSkXJgNfAykH+if3/qZ3gWizYTvgN8BUgOW2b3Oi8AmoGHUk1N\nD4hIEJvX2xjzLvAt4CjWFCGdxpjfY/N6p4xVx7POt7kW6JOaYsAuRCQE/Ar4gjGma7bLM5NEZBPQ\nZIzZNdtlOcdcwBrgXmPMaqCXud/MMKFUu/FmoAIoAoIi8pHZLdWsO+t8m2uBPm+mGBARN1aYP2qM\n+XVqceOJWSxTP5tmq3wz4BLgehGpxWpKu1JEfoq96wzW33SdMebl1ONfYgW83et9FXDYGNNsjIkB\nvwbWY/96w9h1POt8m2uBPplpCOY8ERGsNtW3jTHfHrZqK/DR1P2PAk+e67LNFGPMV40xxcaYcqzf\n63PGmI9g4zoDGGMagGMisiS1aCPwFjavN1ZTy8UiEkj9vW/EOldk93rD2HXcCtwkIl4RqcC6vsSO\nKe3ZGDOnblhTDOwHDgJ3z3Z5ZqiOG7C+au0BXk/d3gfkYJ0Vfyf1M3u2yzpD9b8ceCp13/Z1BlYB\nO1O/7yeArHlS73uAvUAN8BPAa7d6Az/HOkcQwzoC/8R4dQTuTmXbPuC6qb6eDv1XSimbmGtNLkop\npcagga6UUjahga6UUjahga6UUjahga6UUjahga6UUjahga6UUjbx/67A1vF9WfwcAAAAAElFTkSu\nQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "meta_params = meta_opt_state[0]\n", "\n", "for j in range(10):\n", " losses = []\n", " key = jax.random.PRNGKey(j)\n", " params = task.init(key)\n", " opt_state = lopt.initial_inner_opt_state(meta_params, params)\n", "\n", " num_steps = int(os.environ.get(\"LOPT_TRAIN_LENGTH\", 100))\n", " for i in range(num_steps):\n", " batch = next(data_iterator)\n", " loss, grads = value_grad_fn(opt_state[0], batch)\n", " opt_state = lopt.update_inner_opt_state(meta_params, opt_state, grads)\n", " losses.append(loss)\n", " plt.plot(losses)\n", " plt.ylim(0.0, 2.5)" ] }, { "cell_type": "markdown", "metadata": { "id": "tCyfmmax2_rp" }, "source": [ "## Meta-training with truncated ES\n", "Next, instead of meta-training with truncated gradients, we meta-train with truncated evolution strategies (ES). Rather than backpropagating through the inner-training unroll, ES perturbs the meta-parameters with Gaussian noise. It then runs the unroll with each antithetic pair $\\theta + \\epsilon$ and $\\theta - \\epsilon$ and uses the difference in meta-loss to estimate the meta-gradient." ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "executionInfo": { "elapsed": 53, "status": "ok", "timestamp": 1647716701170, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "rxQbMHDf3Rfn" }, "outputs": [], "source": [ "@jax.jit\n", "def vec_short_segment_es(meta_param,\n", " keys,\n", " inner_opt_state,\n", " on_iterations,\n", " vec_seq_of_batches,\n", " std=0.01):\n", " # Compute an ES estimate on a single inner problem\n", " def do_one(meta_param, key, inner_opt_state, on_iteration, seq_of_batches):\n", " # Sample Gaussian noise with the same tree structure and shapes as the meta-parameters\n", " flat_params, struct = jax.tree_util.tree_flatten(meta_param)\n", " keys = [jax.random.fold_in(key, i) for i in range(len(flat_params))]\n", " keys = jax.tree_util.tree_unflatten(struct, keys)\n", " perturbs = jax.tree_util.tree_map(lambda k, v: jax.random.normal(k, v.shape) * std,\n", " keys, meta_param)\n", "\n", " # Compute the positive and negative antithetic samples\n", " pos_theta = jax.tree_util.tree_map(lambda eps, v: v + eps, perturbs, meta_param)\n", " neg_theta = jax.tree_util.tree_map(lambda 
eps, v: v - eps, perturbs, meta_param)\n", "\n", " # Unroll the inner problem with both antithetic samples (same key and batches)\n", " p_losses, p_opt_state, p_on_iteration = short_segment_unroll(\n", " pos_theta,\n", " key,\n", " inner_opt_state,\n", " on_iteration,\n", " seq_of_batches,\n", " inner_problem_length=30)\n", " n_losses, n_opt_state, n_on_iteration = short_segment_unroll(\n", " neg_theta,\n", " key,\n", " inner_opt_state,\n", " on_iteration,\n", " seq_of_batches,\n", " inner_problem_length=30)\n", " p_loss = jnp.mean(p_losses)\n", " n_loss = jnp.mean(n_losses)\n", "\n", " # Antithetic ES gradient estimate. Each perturbation p equals std * z with z ~ N(0, I),\n", " # so (p_loss - n_loss) / (2 * std) * z == (p_loss - n_loss) / (2 * std**2) * p.\n", " es_grad = jax.tree_util.tree_map(lambda p: (p_loss - n_loss) / (2. * std**2) * p,\n", " perturbs)\n", "\n", " return ((p_loss + n_loss) / 2.0, (p_opt_state, p_on_iteration)), es_grad\n", "\n", " (loss, inner_opt_state), es_grad = jax.vmap(\n", " do_one, in_axes=(None, 0, 0, 0, 0))(meta_param, keys, inner_opt_state,\n", " on_iterations, vec_seq_of_batches)\n", "\n", " # The loss and gradient have an extra task dimension from the vmap; average over it.\n", " return (jnp.mean(loss),\n", " inner_opt_state), jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=0), es_grad)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "executionInfo": { "elapsed": 686, "status": "ok", "timestamp": 1647716701983, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "Hn1LFTrV3Rfo" }, "outputs": [], "source": [ "num_tasks = 32\n", "\n", "key = jax.random.PRNGKey(1)\n", "# Start ES meta-training from freshly initialized meta-parameters.\n", "meta_params = lopt.init_meta_params(key)\n", "inner_opt_state = lopt.initial_inner_opt_state(meta_params, task.init(key))\n", "batch = get_batch_seq(10)\n", "\n", "\n", "def init_single_inner_opt_state(key):\n", " return lopt.initial_inner_opt_state(meta_params, task.init(key))\n", "\n", "\n", "keys = jax.random.split(key, num_tasks)\n", "inner_opt_states = jax.vmap(init_single_inner_opt_state)(keys)\n", "\n", "# Randomly set the initial iteration to prevent the tasks from running in lock step.\n", 
"on_iterations = jax.random.randint(key, [num_tasks], 0, 30)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "executionInfo": { "elapsed": 32729, "status": "ok", "timestamp": 1647716734817, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "YmdqBmi83Rfo", "outputId": "a420fdd1-8d5d-4eab-f64e-aa0f99051eb9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 2.3028803\n", "20 2.301912\n", "40 2.2743235\n", "60 2.2011428\n", "80 2.0796442\n", "100 1.9650596\n", "120 1.8779399\n", "140 1.8373306\n", "160 1.8046211\n", "180 1.7904768\n", "200 1.7764364\n", "220 1.7700386\n", "240 1.7758758\n", "260 1.773428\n", "280 1.7594522\n", "300 1.7603681\n", "320 1.7564806\n", "340 1.7545398\n", "360 1.744331\n", "380 1.7509981\n" ] } ], "source": [ "meta_opt = Adam(0.001)\n", "meta_opt_state = meta_opt.init(meta_params)\n", "\n", "meta_losses = []\n", "\n", "num_steps = int(os.environ.get(\"LOPT_TRAIN_LENGTH\", 400))\n", "for i in range(num_steps):\n", " data = get_vec_batch_seq(num_tasks, 10)\n", " key1, key = jax.random.split(key)\n", " keys = jax.random.split(key1, num_tasks)\n", " (loss, (inner_opt_states, on_iterations)), meta_grad = vec_short_segment_es(\n", " meta_opt_state[0], keys, inner_opt_states, on_iterations, data)\n", " meta_losses.append(loss)\n", " if i % 20 == 0:\n", " print(i, onp.mean(meta_losses[-20:]))\n", " meta_opt_state = meta_opt.update(meta_opt_state, meta_grad)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "colab": { "height": 296 }, "executionInfo": { "elapsed": 203, "status": "ok", "timestamp": 1647716735126, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "3ZxRdgKg3Rfo", "outputId": "cd1e4eb0-8eed-4c6f-f778-52e517a1e60f" }, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'meta-loss')" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90\nbGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsT\nAAALEwEAmpwYAABCQ0lEQVR4nO2dd5hcZdn/P/fM7mwvyaYnhE0lBBJCCL33JkRRFBBUBJEf8iKv\nWEBRUVFQXrBheRER9QUsgErvgdAhCek9IaRn07O9Pr8/znnOnpk5MztbZmc3c3+ua6+dOXPmnHvO\n7jzfc5fnfsQYg6IoipK9hDJtgKIoipJZVAgURVGyHBUCRVGULEeFQFEUJctRIVAURclycjJtQGcZ\nNGiQqayszLQZiqIo/Yq5c+fuMMYMDnqt3wlBZWUlc+bMybQZiqIo/QoR+SjRaxoaUhRFyXJUCBRF\nUbIcFQJFUZQsR4VAURQly1EhUBRFyXJUCBRFUbKctAmBiBwgIrNEZJmILBGRrwbsM1NEForIfBGZ\nIyInpMseRVEUJZh0ziNoAW4yxswTkRJgroi8aIxZ6tvnZeAJY4wRkanAP4BJ6TBm5bZqnlqwmXAo\nRE5YiIRD5IaF3JwQueGQ+zxEJCdEaX4O5YURBhTmUl4YIZKjjpOiKPsvaRMCY8wWYIv7uFpElgEj\ngaW+fWp8bykC0rY4wspt1fzqldWdfl9uWDh23CA+e/Rozpo8FBFJg3WKoiiZQ3pjYRoRqQRmA4ca\nY/bFvPYJ4A5gCHC+MebtgPdfA1wDMHr06CM++ijhBLkOaW0ztLS10dxqaG5po7m1jaZW93lrG43N\nbexraGZ3XRN76ppZt6OW55duZcOues44eCi/uGQaxXn9bkK2oihZjojMNcbMCHwt3UIgIsXAa8CP\njTGPJ9nvJOB7xpgzkh1vxowZprdbTLS0tvHAmx9y57PLOefQYfzmsunqGSiK0q9IJgRpDX6LSC7w\nGPBQMhEAMMbMBsaJyKB02tQVcsIhrjlpHDefO4lnFm3lt6+uybRJiqIoPUY6q4YE+COwzBhzT4J9\nxrv7ISLTgQiwM102dZcvnTiWCw8bwV3Pr2DCd55hxdbqTJukKIrSbdLpERwPXAGc5paHzheR80Tk\nWhG51t3nk8BiEZkP/Ab4jOmNpEUXERHuungqVx5fSXOr4bF5GzNtkqIoSrdJZ9XQG0DSQLox5qfA\nT9NlQzrIywnz/QsOYd2OWp5csJlLjjyAUQMKtcRUUZR+i45eXeSLJ4xhy94GTrv7Ne54dlmmzVEU\nRekyKgRd5MQJg7nn04cB8I/3N7B+Zx0bd9dl2CpFUZTOo0LQDS6aPorH/t9x1Da1ctJdszjjntdo\naG6lqaUt06YpiqKkjApBNzniwAF892OTAWhobmPSd59jxu0vsmlPfYYtUxRFSQ0Vgh7gqhPGsO7O\n87ngsBEA7Gto4W/vrc+wVYqiKKmhQtCD/PIz0/jLF4/i+PEV/P39DazdXkMfroZVFEUBVAh6lFBI\nOGniYK48bgxV1Y2cdvdr/P61tTS1tNHapoKgKErfRIUgDZw6aQhjBxUB8NPnlnPo95/nW48tzLBV\niqIowagQpIFwSHjqhhP461VHEQ4JTa1tPDp3I23qFSiK0gdRIUgThZEcTpwwmLduPo3/Om08ACu2\naW8iRVH6HioEaWZoaT6XHjUagNufXkpDc2uGLVIURYlGhaAXGFFewA8uPIQ3V+/k2cVbMm2OoihK\nFCoEvcQVxxzI8LJ8nl64NdOmKIqiRKFC0EuEQsLHpg7ntZVV2pNIUZQ+hQpBL/LFE8YgItz2xBIa\nWzRXoChK30CFoBcZXlbAzedM4qVlVfzP8ysybY6iKAqgQtDrfPGEMVx0+Ej++s5H7KhpzLQ5iqIo\nKgSZ4LpTx9PY0sb9r3+YaVMURVFUCDLB+CHFnD9lOH99e52uX
aAoSsZRIcgQpx88hNqmVj7aWZtp\nUxRFyXJUCDLE+MElAKyuqsmwJYqiZDsqBBli3BCnO6kKgaIomUaFIEMURnIYWV7Aal28RlGUDJM2\nIRCRA0RklogsE5ElIvLVgH0+KyIL3Z+3ROSwdNnTF5k0rIT/zN/MmFue4YP1uzNtjqIoWUo6PYIW\n4CZjzMHAMcBXRGRyzD4fAicbY6YCPwLuS6M9fY6zDxnmPZ61vCqDliiKks2kTQiMMVuMMfPcx9XA\nMmBkzD5vGWPsrfA7wKh02dMXOWdKuxDsqW/OoCWKomQzvZIjEJFK4HDg3SS7XQU82xv29BVK83N5\n/saTqKwo5MMdtby1Zgfrd2pDOkVRepe0C4GIFAOPATcaY/Yl2OdUHCH4VoLXrxGROSIyZ/v27ekz\nNgMcNKyEKaPKWbezlsv+8C7n/nJ2pk1SFCXLSKsQiEgujgg8ZIx5PME+U4H7gZnGmJ1B+xhj7jPG\nzDDGzBg8eHD6DM4QYyoK2bCrHoDaJu1KqihK75LOqiEB/ggsM8bck2Cf0cDjwBXGmJXpsqWvc/Dw\n0kyboChKFpOTxmMfD1wBLBKR+e62bwOjAYwxvwe+B1QAv3V0gxZjzIw02tQnOeLAAd7jSDjEs4u2\nsGjTXr55zqQMWqUoSraQNiEwxrwBSAf7XA1cnS4b+gtDSvO9x7lh4alFW3h+8VZuOH0C+bnhDFqm\nKEo2oDOL+wgXTXcqa2ubWqna10BLm2HltuoMW6UoSjagQtBHuOfT0/jxJw4FYO12pyPp4k2BRVaK\noig9igpBH6KiKALAztomAGatqNL1jRVFSTvpTBYrnWRgUV7U8xeXbgPglIMGc8pBQzJhkqIoWYB6\nBH2IiuKI9zg33J5nb9RVzBRFSSMqBH2IUQMKvMf+klJd5F5RlHSiQtCHyMtpLxU9Yfwg7/H2ahUC\nRVHShwpBH6M030nbTD9wAL+//Agi4ZAKgaIoaUWTxX2MwSV57GtooSiSw3HjBjFmUJEKgaIoaUU9\ngj7G1886CIADKwoBRxi21zSyt76Z1jZd0lJRlJ5HhaCPce6U4ay783zKC50KosEleazbUcthP3iB\nu19YkWHrFEXZH1Eh6OOMKM9nd52zetmTCzdn2BpFUfZHVAj6OFNGlnmPBxRGkuypKIrSNVQI+jhT\nRpV7jyNh/XMpitLz6MjSxxlR1t6iWieWKYqSDlQI+jgiwk8/OYWR5QVUaRmpoihpQIWgH/CZI0dz\nxbEHUtfUSk1jS6bNURRlP0OFoJ8wtNTpTFq1ryHDliiKsr+hQtBPGFHmNKR7auEWFm3cm2FrFEXZ\nn1Ah6CccMrIMEbjnxZVccO8bmTZHUZT9CBWCfkJxXg65Wj6qKEoa0JGlH2HzBAANzbp8paIoPYMK\nQT/igc8fyWGjnJnGm/fUZ9gaRVH2F1QI+hEThpZwy3kHA7BJhUBRlB4ibUIgIgeIyCwRWSYiS0Tk\nqwH7TBKRt0WkUUS+ni5b9idGljvVQ5t2qxAoitIzpHNhmhbgJmPMPBEpAeaKyIvGmKW+fXYBNwAf\nT6Md+xXDyvIJh4TXV+9gT30zXz5pLCLS8RsVRVESkDaPwBizxRgzz31cDSwDRsbsU2WMeR9oTpcd\n+xu54RAHDizk6YVbuPPZ5eyqbcq0SYqi9HN6JUcgIpXA4cC7XXz/NSIyR0TmbN++vUdt649MGFrs\nPdb+Q4qidJe0C4GIFAOPATcaY/Z15RjGmPuMMTOMMTMGDx7cswb2QyYMKfEeb9OWE4qidJO0CoGI\n5OKIwEPGmMfTea5soji/PbVTtU89AkVRukc6q4YE+COwzBhzT7rOk43MnDaCUQOc6iH1CBRF6S7p\n9AiOB64AThOR+e7PeSJyrYhcCyAiw0RkI/A14FYR2SgipWm0ab9geFkBb3zrNMoLc9lW3cBzi7ew\ndHOXom6KoijpKx81xrwBJ
K1rNMZsBUaly4b9naEl+by0tIr/e2c9AHNvPYOK4rwO3qUoihKNzizu\nxwwpzWOrLzT00a66DFqjKEp/RYWgH3OYb2F70NnGiqJ0DRWCfswJEwZFPd+oQqAoShdQIejHTB89\nAIDjx1dQXpjLw+99xOqqmgxbpShKf0OFoB8TyQkx+xun8r9XzKC11bBhVz1n3PMajS26VoGiKKnT\naSEQkQEiMjUdxiidZ3RFIcV5OQwsjnjbnl+yLYMWKYrS30hJCETkVREpFZGBwALgTyKik8T6EA9e\neRQPXnkkAFt0rQJFUTpBqh5Bmdsn6CLgT8aYI4Az0meW0lnGDCri5ImDiYRD2pFUUZROkaoQ5IjI\ncODTwFNptEfpBiLCwKKICoGiKJ0iVSH4IfA8sNoY876IjAVWpc8spasMcIXgsbkbeWPVjkyboyhK\nPyClFhPGmH8C//Q9Xwt8Ml1GKV2noijCrrombvrnAgDW3Xl+hi1SFKWvk2qy+GdusjhXRF4WkR0i\ncnm6jVM6z4CiCBu01YSiKJ0g1dDQWW6y+GPARmAi8I20WaV0mYqiCDtqNEegKErqpCoEue7v84BH\njDG70mSP0k0GFkWinre1GZpa2gA0iawoSiCpCsGTIrIcmAG8LCKDAV0RpQ8ywCcEIYGvP7qAibc+\ny6srqpj+oxd5fZWu+awoSjQpCYEx5mbgWGCGMaYZqAVmptMwpWscWTmA8kLHgWsz8Pi8TQC8v85x\n4uZ9tIe31uygSlc2UxTFJdVkcS7OamN/F5FHgauAnek0TOkak4aVMu/WM/nexyZHba9tdPoPGQyX\n/eFdPvHbtzJhnqIofZBUVyj7HU6e4Lfu8yvcbVenwyile4RC4nkFlprGFgBaWg0Am7QNhaIoLqkK\nwZHGmMN8z18RkQXpMEjpGUrzo4Wg1hWChmbtTKooSjSpJotbRWScfeLOLNYRpQ9TWhDsEdSrECiK\nEkOqHsE3gFkishZnQfoDgSvTZpXSbcoKgj2C+iYVAkVRokm1xcTLIjIBOAhHCJYbYxrTapnSLUoL\nov+0dg7BvobmTJijKEofJqkQiMhFCV4aJyIYYx5Pg01KD1BRlBf13M421klliqLE0pFHcEGS1wyQ\nUAhE5ADgL8AwoA24zxjzy5h9BPglzozlOuALxph5KditdEAkJ0RxXo6XG7C/VQgURYklqRAYY7qT\nB2gBbjLGzBOREmCuiLxojFnq2+dcYIL7czROSerR3Tin4qNyUCGLN+2L2qZCoChKLF1ZszilhWmM\nMVvs3b0xphpYBoyM2W0m8Bfj8A5Q7i6Ao/QAp0wcErdtX0OL97i5ta03zVEUpY/SaSEgfjDvEBGp\nBA4H3g041gbf841BxxeRa0RkjojM2b5de+Wkyn+fOZE/XXkkg0vyAl9/78NdvLVGF69RlGynK0Lw\nQWd2FpFi4DHgRreVddTLAW8xcRuMuc8YM8MYM2Pw4MGdOX1WEw4Jpx40hJL84AjgZ+9/l8v+EKvN\niqJkG50WAmPMF1Pd1+1R9BjwUIIKo43AAb7no4DNnbVJSU5JzCxjRVEUP6k2nZsgIo+KyFIRWWt/\nOniPAH8Elhlj7kmw2xPA58ThGGCvMWZLpz6B0iHDS/OTvt7WZrQbqaJkMal6BH/CqehpAU7FKQv9\nawfvOR6nOd1pIjLf/TlPRK4VkWvdfZ4B1gKrgT8A13X2AygdM7qiMOnrv35lNUf95GU2ayM6RclK\nUm0xUeDOLhZjzEfAbSLyOvD9RG8wxrxBcA7Av48BvpKytUqXOGBAQdLXn1nkOGHbqxsZUZ58X0VR\n9j9SFYIGEQkBq0TkemATEF+bqPRJRg2M9ggKcsNRzeeq3bYTIUmq24qi7KekGhq6ESgEbgCOAC4H\nPpcmm5QeZnSMEMR2ILVzC5p0XoGiZCWpCkGlMabGGLPRGHOlMeaTwOh0Gqb0HCPdcM8RBw4
IfN22\nn2jUFtWKkpWkKgS3pLhN6YPk54aZ990z+b+r2rt3RMLxf3pdq0BRspOOuo+ei9MQbqSI/Mr3UilO\nBZHSTxhYFIl6np8bigsFqRAoSnbSUbJ4MzAHuBCY69teDfx3uoxS0se/rjuONmN48K2PeHJB9Nw9\nXbRGUbKTjrqPLgAWiMjD7r6jjTEresUyJS0cPtrJExwyooyzDxnK9Q+3dwxpaNFksaJkI6nmCM4B\n5gPPAYjINBF5Il1GKeknPzfMUZUDo7Zt39dA5c1P8+windytKNlEqkJwG3AUsAfAGDMfqEyHQUrv\nUZgX7RCu3l4DwC9eWpUJcxRFyRCpCkGLMWZvWi1Rep2C3HDUc5sjqGvWOgBFySZSFYLFInIZEHYb\n0P0aeCuNdim9QDgkRHLa/wXs6mX1TZorUJRsIlUh+C/gEKAReBjYC3w1XUYpvUexLzy0q84KgXoE\nipJNpCoEk92fHCAfZ4nJ99NllNJ7HD9+kPd4V40rBDqfQFGyilSbzj0EfB1YDGjcYD/i8qNHe/MJ\nat0cQVvcGnGKouzPpOoRbDfGPGmM+dAY85H9SatlSq9w9NgK3rr5NKaOKova7nQIVxQlG0jVI/i+\niNwPvIyTJwAgwfKTSj9jRHlBXAXRz55fwagBBXz26AMzZJWiKL1FqkJwJTAJyKU9NGQAFYL9hIJI\ntBD87tU1ACoEipIFpCoEhxljpqTVEiWjxHoElrY2QyikC9Yoyv5MqjmCd0RkclotUTJKIiHYvFfX\nMVaU/Z1UheAEYL6IrBCRhSKySEQWptMwpXcZXJIXuH3VtpqUEsc7ahpZta26p81SFKUXSDU0dE5a\nrVAyTqJF69dsr+H2p5cyrCyfh64+Ju71+qZWcsLCaf/zKvsaWlh35/npNlVRlB4mJSHQUtH9n0RC\nsL2mkTXba1mzvZbH523kN7NW88J/n0zYzRsc/L3nOHZshbfuMcDOmkYqioM9DEVR+h6phoaU/Zzh\nZfmB22ev3OE9/to/FrBmey21MS0o3l6703v8zzkbOOL2l1iyWXsUKkp/IW1CICIPiEiViCxO8PoA\nEfmXm3N4T0QOTZctSseMTOARLNuyL25bspXMXl/lCMdKzRcoSr8hnR7BgyTPLXwbmG+MmQp8Dvhl\nGm1ROqC8MDflfWsbHY+gNaAXRUubM81EJyYrSv8hbUJgjJkN7Eqyy2ScmcoYY5YDlSIyNF32KMkR\nESYMKeaLx4/xtg0qjgTuW+d6BA0Bzekam9uF4LG5G9m6tyEN1iqK0pNkMkewALgIQESOAg4ERmXQ\nnqznxa+dzPcuaJ8uMrwsOFy0dPM+Km9+mpeWbYt7rdFd97i2qYWb/rmAy+5/Jz3GKorSY2RSCO4E\nBojIfJz1Dj4AAhvhi8g1IjJHROZs3769F03Mboa5CeTS/OjisleWVwHwzzkb495jvYQmVxA27Krj\n8vvf5dL7VBAUpa+S6jyCHscYsw+nhxEiIsCH7k/QvvcB9wHMmDFDo8+9xAA3bzB2cDHzN+zxtu+t\nbwba8wF+rEdgfze3Gt5YvSNuP0VR+g4Z8whEpFxEbBD6amC2Kw5KhvnNZdO59KjRFOe5QjCoKOr1\nPa4QBOiA5wnU6SpnitJvSGf56CPA28BBIrJRRK4SkWtF5Fp3l4OBJSKyHDgXXfqyz3D+1OHccdEU\n8nOdf49RAwsBOH3SEAC2VzudyIM9Aic0VN2gQqAo/YW0hYaMMZd28PrbwIR0nV/pPq1uDWheToiF\nt50FwNTbXmBHjSMEu+ua495jQ0IqBIrSf9CZxUpCWlsdIcgNC6X5uRRHou8brGfgp10I4kWiI+at\n383qqpouWKooSndQIVAS0uJOGAuHnH+TUEi8cBFATWP8Xb+dbLavCx7BRb99izPueS3pPj94cgmf\nf+C9wNfeXL2DBb6ktqIoqaFCoCTkzMnO/L5jxg70thV
GkkcTk4WGgmYid5Y/vbmO11YGlxB/9v53\nmfmbN7t9DkXJNjJWPqr0fY4fP4gP7zgPp7rXoTASZldtx+8NCg01trR2KCSKovQ+6hEoSfGLADhL\nV6ZCkEewp645YTM6W22kKErvo0KgdIrNbu+giUOLk+4XlD/46t8+4Kyfz/YmpPnRKqOus3Jbtc7b\nULqFCoHSJU6aMBiASE6IgUXxzemC8gHvr9sNwL4AIQja1lss2riXg259lqp9/a9BXlNLG2f9fDbX\nPTQv06Yo/RgVAqVT3PWpqdx6/sEMcAf/sAj/uu44Ljt6NDkh6eDdDkF3/0HbGppbueeFFTQ0t7K3\nrtmbtdzT/OnND2lsaWP2qv7XCsNO6ntr9c4O9lSUxKgQKJ3i4hkHcPWJY73F7uubWzmwooiffGIK\nRXmpJYLXbK9h8veeY9763d62fQHJ5TnrdvOrV1YzZ91uDvvhCwnvehdt3NsjOYbUZKxv0ezO9WjV\nBSCUbqBCoHSJ86cMj9uWl5Pav9OSzfuoa2qNmjwW5BHYuLf9HdT2+qOdtVxw7xv8+OllKZ17f8OG\n4HqiNFfJXrSWT+kSRXk5/O8VR3jtJgDyc8MpvXfznnoAanyDvz9HYIxBRKhvTty3qLXNEA6Jt/BN\n0JKaqdKfh9CW1vSEy5TsQj0CpcucfcgwPnv0gd7zATHLXYYEbj3/4Lj3bXKFwD/A+0NDdlKaXRt5\nT0AiudkdAOtcsUhVhIIwblhF+mFsqEU9AaUHUCFQeozYFc0iOSGuPnEsr3/z1Kjtm3a7HkFj+wDv\nFwW73KX1CPbWNcWdywpBgysW/kR1V++SReAfczZ4QgVOWGrhxj1dOl5v0NKqQqB0HxUCpccYXp4f\n9TwSdv69Yu/Wt7plmjWNLfzsueX8+4NNUaGhBjfxa4Ug2CMw3jEAcsPt/8pNnRQCO5TWNbXyzUcX\n8tk/tK+m9t1/L+HCe99kWw+Ulv79/fVU3vx0j5apNgctCqEonUSFQOkxhpfFCEGOIwAFkeCwTXVD\nC799dQ03/n0+NY3tVT81jS3UNLa0h4YC2l03t7axdW+DNznNLwTWo4jFLqNpue6huVTe/LT33N5d\nb9vXnvdYt9Ppp7F2ewp9NTrgsbmbnGPt6P6xLJokVnoCFQKlx7AlpZZI2AnX5CeoJvKHg/wzY7/y\n0DwO/f7zSXMEv3hpJcfc8TJLNztJ4pAvNNQYMN9g+dZ9TPruczy3eIu37ZlFWwGwlZfWAzG+9PGo\nAU64a30qDZY6wjWxJys9m1P0ft5ft4sPe1CAlP0LFQKlxxhYFC0E1W7YJiccIhIgBv4Esb8lxfKt\nTj8ie7cflCN4aVkVAMvcfa1oQHDfovc/3AXAaysTTxqzLbT9A3WF+5nW9IBHYKXKYPjD7LU9EiJK\nNUdw8e/f5tT/ebXb51P2T1QIlB5jZLlz91zsTizz3/EPLs6L29+/sM2WvQ1xVTs2lxDkEdiQiBUJ\nv0cR5BHY0FNJfnzFtB1KaxvjBcTO3F3TAwvm2M+3alsNP35mGdc//EG3j6lVQ0pPoEKg9BjjhxTz\nx8/P4I+fnxH32pDSeCGo8sXiV1fVUBHTs8jOEQjKEdjKILtcZq3fIwjIEdgKpaKANthtrgtgxcQ/\ntNpj+SuJOktdUwt765sJuUpgwzl76uM9nc6SSoWUSfOs47qmFm55fCF7Ajw3pX+gQqD0KKcfPJRJ\nw0rjtpcVOHMMJg0r8bbFVvdUxISWrEcQ1K3UegQ2rl/X6PcInG1tbYanFm6mtc14k9cKIiFmLa9i\nla8dtu1hVGfFxDduWhvrm7vewuL2p5dx9Z/fj/N4/OPzwo17uOeFFZ0+dirJ4iAPKVXa2gwbd9cl\n3efhd9fzyHsb+O2ra7p8HiWzqBAoPU5pQfxdd0l+vBDEEtvFNFlr6tjeOnVROQJn4Hty4Wauf/gD\nHnjjQy801NxquPL
B9znz57O9/W01UW1A62wrKrEVR52hal8D2/Y1IliPwLG9zfcZnl28lV+9sjrl\n5K+lOQUhCGoJniq/mbWaE346K2miuVnnMvR7VAiUHid2MRtoj82PGZR4HYOBxfHtrBMReyccnSOI\nbk2xZnuN51UEDeg20VzrhYbaj+1NbmvquhA0NLfR2NLqeQRBSWnrlQSJkZ/qhuaoMJU/NPTUws08\ntXBz3HtqurHWw5trnOT6liShMSto4RS7z/pZunlfn56wly2oEChp4ZkbTuTlm072npe6HkFOOPFg\nMShgXYNExCZJg3IERXnO/IWaxhZ21Tr5iKAwk71jtl6Ff4C23kVDN8IrjS2tNLa0eQJpz+f/BE1J\n1nr2M/PeNzn+zle85/7rcP3DH3D9wx8wf8MerntorieW3fEILAa4+4UVbNgVHyay5wl3oUfHeb96\nnQvv7dl1pmsaW/jHnA2BuZFt+xp4YcnWHj3f/oAKgZIWJo8oZdzg9rt/Gy6qbWxh3nfP5DeXTY97\nz4jygrhtiYj9jvvXKrCDd3OL8c65s9ZJZPqb5FnsQFmbZIBuamlj7ke7aW0z7Kpt6tSKYI0tbTQ2\nt3nlo/Y8/tCQPY9/0P7JM8uYtbwq6lixk9Fiy0cjOSG+8tA8nlm01Wvu15GXsWFXHTNuf4n1O+MH\neRvOWrO9hl+/sjqwFbgVo1AXPIJ08L3/LOabjy7kgw174l77zP++zTV/ndvjE/G27K1Pe1I+naRN\nCETkARGpEpHFCV4vE5EnRWSBiCwRkSvTZYuSea445kAuPGwEV50whoFFEW+ilp/SglzGDS7q9rls\naMgO1jWNLeyqcYTAX7JqqY3zCHyhId+chE/+7i1eXraN6T96kfN/9Ubq9jS30eAPDTUFCEFrfGjo\nvtlrufLB9wOPaW1siWkxccCAAnJdr6uqupHpP3qRN1YnX3DnkffWs6OmkX/P35RwH7tWdV1TC4s2\n7o1KttvXUl2YKIgte+v505sfdvn9Ucfa4xQZNASE89a5YtcT61c8t3grU297niWb93LsHa9w/+vx\n9n+4o5ZjfvIyW/Z2veqsrc2kbVEmSzo9ggeBc5K8/hVgqTHmMOAU4G4RST02oPQrSvJz+dWlh1Ph\nzifwl5OWu11LCyNhfn3pdKaOKqMoQVuKVLAegQ0XVTe0UOMOvkFCEOQRXHjvG5z189fiKm52uIKS\nKHm6YMMeKm9+Oiru3djSijHtd+82cW2ME1Zp9X3R7SS8jspCrV2xHkFTa5vXbuOdtTvZVdvEr19Z\n7b0edNdqw1FBcyws9jQGuODeN6KS7a0BOYLH5m7ka/+Y36E3Yrn6z3P4wZNLAwfM255YQuXNT1Pf\n1MqvX17VYUK9zesmK1RVN0Td/VsTGxK0IekMtz+9lH0NLcxxl2D1C+5vZq3mpJ/NYtW2arbua+jW\nrO4fPrWUibc+6wluOkibEBhjZgO7ku0ClIgTOC1299UVuLOEQb4JZgMLHf0vzsth8ohSnrj+BA6s\n6Lpn8D/Pr6Dy5qc9j6CqutELJQUJga168ecIFm7cy8ptNXFC4O+YaozhFy+t5MkFm1mzvYbH523k\n6UVOC4s3V+/kd6+uYf6GPd4x7F1/nS9ZfM1f5vDd/yz29rGJXX/OI6g+3yavYz2CxuZ2IQgaOOqb\nW+MS31YIi5OsMFcdsIKcxQ60xhgWb9rLjppG/v7+Bh6ft4mfv7gy4fv87HJDd1v2NlB589M8t7g9\njv/gW+sA+N/Za7j7xZX87b31SY9lP/Wu2iaO+vHL3PV8e1muncvREx6B/Z+ynp4/RXLX8ytYv6vO\nKzvuTtWZ/fydbabYGTK5MM29wBPAZqAE+IwxJvCTisg1wDUAo0eP7jUDlfThbxJX6s4x8Den+/LJ\nY3l07kYAXu/kWsL73MHUTlizgwxED7CxBM3SbWxuJS8nFDdQAzz83np+8dIqhpXme
3MeLjnyAPcz\n5fCdfzlRUVsWG5sHMMbw4c5a6ptbyXGvR01jCxt21UXdXS/dvI/SglwqfFVVdc2tDAiwuam1zWvn\nETT34e4XVvLqiipevukUb5sd5IOqveymZElsfx7lY79+g8EleUw7oByA9QHJZYvfO7FismjjXsBp\nB37OocOi9rfez+6ACYZBx920xzn3qyuq2FvfzIDCXFcITI94BNbz8Hsc5/xiNjMqB3jPrejWN3X/\nfE2tbd1adyMZmRSCs4H5wGnAOOBFEXndGBO31JQx5j7gPoAZM2b034yMEkh+rjNw+aMWM6eNZOa0\nkTw6d2OHQlCan+MN/n4+Ckh+dpam1jbKCnKpcj0J/3nsoDW0rF0IbDI6uhuqMxjECkGbcQaKuqZW\nXC3k7TU7ueXxRVx2dPsNT1V1I5fd/y6DfEJQ32RDSNFfh8bmNq/9d9CAuXJbNWu219LY0kqe2x12\nX71zrGR3yZ5HEPDts+9rdO9Yt1c3euGbHTWNvLN2J8eMrQh4X/vgaAdV+z5jDCu2VnsTEaH9mnZ0\nd23HZRuCywkLj7hehBXJ7tyhx9ps/66C0yfL9sqCdgH1Fxf8Z/4m9tY387ljKzt1vnTmCTJZNXQl\n8LhxWA18CEzKoD1Khpg+2rmD8n/pLZUVhR2+v6ww/n0AH/VAx9DG5rYou/xVR/aL6e9DtMVti7G3\nLn7FtQYvie3OfDaGuqZW6ppaPLff3kE//G57+MMKms1PgNMXacnmvXHx8qbWNnJznNt4WzLrZ6d7\njC17Gnhx6TaMMV5eorG5jXnrd7O9upFXlm/j/975yHufFYuguzD7efwDlX08b/0eLrnvHRZv2hv1\nnrXba6Kupb2rtmE6A5z9i9kcc8fL3j629Liugzkd1iOw3ls41D7MWUersaWNj3bWdqvSx7412czt\n3W5Yzy88X/3bfL73nyUdHr+huTVKnNMpBJn0CNYDpwOvi8hQ4CBgbQbtUXqZmdNGMGfdbr525kRO\nP3goh44si9unclDHuYKyglw24CQZIzkh7wvjX1cAnKqWzjZpa2yJFoKdvsG4MaDk0/ZH2uWL69tz\n2sSpfx6B9Qjs3bk/jGWZv2F33Lbfv7aGZxdv5bRJQ6K2t7YZLw6+uzbeI7DHv+PZZTy/ZBu/++x0\n725/T30zF/32LQ4fXc4H6/cAcPx4506+2s2NBA2cNvxR15h40PL3i2pubeO0u1/joKHts8y9kl/P\nI4g7jUeqHoHNrUR8c1fstXl1RRW/eGkVd198GJ88YhTQnrgP6pSb7Dy2CiwotGYbJnalRcmk7z4X\n5QX2S49ARB4B3gYOEpGNInKViFwrIte6u/wIOE5EFgEvA98yxnQuGKz0a355yeG8efNp5IRDHHHg\ngMB9YhvRBeEfqMsDvArLgE5MWLM0tbZ5VU0AO3132UGhFDtfISjBa7ubNvkGvabWNuqbWr1tQcns\n+W49fKmvquddt632Ol81yvlThgPtg+iuABusELznvr+qutELX7yxajsAG3fHV+5YjyAIO8hV+xLp\nsYlNv+diBXGFrwTVHqN9TYh4bHilo0HVipX9W+REeQTOYL3QDet94BPZS//wDhNvfTbpscGZlLin\nrsk7j71+bQHqZf8PgnIEqQzsfi+wXyaLjTGXdvD6ZuCsdJ1f2T/w32WVF+YGdiItL4hE7VPlG0z9\n7xlYGAkcaDuizHf8II8giKA7+9gZvvYOuq6pNWlzOxvr93szNuxhSzfX/uQ8HnxrHU8v2uINMLt9\nNgwqzmNHTaN3HntMY4y3LsQ81wuYOLQ47jrZfYIGaHuH7k8oxw5y/hndQTOd7RhqvZMgz8N+5vqm\nVr781znMnDaS81zxizqW+9uGZfyz2e2/kxWmNgOVNz/NZUeP9sSxI779r0VUN7R4A7+1K8hTsf97\ndc0tvLV6BxN9vbZ21zUxtDQ/7j2J2F9DQ4qSE
r+//AiaW9sozs/hp88uj0rGQXvVEUSLAsCw0nzv\nyzigKLG3kIyo0JBvcA1qV2EJCsvE4h/8U0le+mPjTb4QSkicWb15btLdxvz9tk4eUcrsldvjjtnQ\n0pZ0lratTrWDvH/fpxZu5v7XP/QGVf8AH3v3ure+mRVbq7n8j+9y98WHJfyMye6u7eeqbmjh7bU7\neX7JNm7/+KHsrGniq2dM4KZ/LGDJ5r3ewG9FOxKO9wjsZ7SC48/JdMSm3fU0trT5ktJWCOIHaiu4\nNQ0tXHb/u1HhsB01jVFCYIwJDC9ZeqLkNRHaYkLp85xz6DAuOGwEpx40hOduPCluBqt/oC6NCQ35\nv2ix3U1TJSiJDe13nJ19LYh9CUQl0bktrW3GC33YAS+o5v/gBF1fdwa03PALnPUE7DH9IR7b18ju\nXxPjEdhqMHA+3/Kt+9he3cjSLXGFgR7VDYkHVXt8f3ju1n8v5ucvrWRPXROPzdvI8q3VWBPt38Bf\nimvHWStUQa0mOmo/sa+hmcbm1pQ8Artw0ja3qswfDttZ08Qdzyzj8w+8x78+2MiYW57x9guiO+3E\nO0I9AqXfkROOTvr6B8vYgbPYF1e3IZ7hZfledU8qlAW01QbYU9ucsHS1o1r3WBLNb6gojiT1PKob\nmr07YJvkDAq9jBoYXH1VFRAq8+cDrDdlq3mCwhO2+ifKI2hpY0hJvlcFtbe+2Rvkk4XnrOAEDar2\n+FsD/nZ2Ip//ve3VWe37WY+gPqDSyVLb1OI1SQy2scURXRvOakycu7DJ4k1u24twSDyh2VnbyP/O\ndupjrGht3F2fMFzUL5PFipIu/DX6AMV57ZNsbJuEEycMAmCbb9CwbRvGpFCJ5CdReWp1Y0vCL21n\nPQI/eb6qlaAV1fzUNLZ4HpKtPIpt6zB9dDmnx1QXWYIGZf/KabGfI6h00969+z2C2saWqHUp/EIQ\nJD6W6iR311YIgoR3yeZ2LyM2P1Pf3L6/dQ6sBxb0ed7/cBdn3vNaVAkwOP8/xhj21TfT4PcIGhMf\ny26zrTNyQuLN4Pbnm7x8RpJ+TSoEiuIjEiME/nI/23r6hPGDuOH0Cdx24SHea/bLNjrB3XEiCnIT\nD8Z2tm+sTd3pblnom2FtZ1vnJShpbDN4s5LtPv6FYvJzQzx+3fEM8wmWPwwdNCj7wzKxg1uyip1q\nnwDVNrUyoqy9seDe+mZvwNxendgbq/YlhBO9FoS/eirWg/ILlI3BJyvrfG/dLlZV1bB2hzPXobXN\n0NLaxvjvPMsdzy53usn6cwQNNneR2HOznlVOSLy/03ZfWG5PbbQnFJQsT2fVkAqB0u+IjfX7PYRC\n9w66KC+Hr505MWpugi3FGxnQ7jpZ40x/rDuWkeWFVBRFuOviqd620iTN25JhB+gCXxsB21LAb3N5\njIfS7hHE22m9hFBIvM/h7/xaFROTLuxGs79Yxg0pZvY3TmXGgQM64RHYnET8oJ9MCFZuq0n4mv99\ndoC1AhcURrN36ht21zPj9pe4/eml3szx+9xQTrRHkDivEUtOOOQN6P4Qlzepr6WNZVv2RYmqRT0C\nRfExuMRpWDd5eClfO3Oitwzm2YcM9dzuoAZqN501kWGl+UwPmLNw3LhBCc83oryAv11zDLeef3Dc\nayX5Ocz97pnMnDbS2za8LF5oUmGA23wv3zcY24HLv1ZDbLjICmHQRCi/OFiR9K8SFxtmGV6Wejlj\nR0TCIUZXFFJe6OQ57J1z1b7EQmAH1aCBMFkn06B1Jiz+Y8UmXHcHlPnaY33kehkPvbOeTe7cCtvi\nu6XNeInzzizVmRsW731BHUk37K7j3F++zrceXRj3mgqBoviwE8POnzqcG06fwCkHDebx647j95cf\n4d3RBt3ZnjhhMO98+/S4O2qACUMTL6E5sryAY8ZWMHVUedxreQHewqCSrlUnWU+nIDfM3FvP4N1v\nn+4NGnaAL
oyE49pF26qYQCHw2Wc9jbFJciSdWRzIT5BHZe0pK8hlX32zJzr+u/DYv1OyiFqisFRH\nkw79oaFYIdgZIATWI7AT65pa27zHfu+zK9E/R0CcNwYJgfUSgtaQ6JcTyhQlXQxwB3J7d5YbDnn9\nioI8gq+dOTEq6WnLLUeWFzCwKMKwsnxGDUicNyhyjzUoYE3lvHD84Juf03F4xd/R1OIXArtug801\nDHOFoKI4wsjygqi5FLZqKK+D89pBd2ySxX+GdWKCk58hJe2N9yzWGynJz6G2qTUwhj60NL9bvfoB\nDqwoDBzQLX4Bib2rDgo3WY9g4572poV2neiurMscez77Nw06t/0/DWohrhPKFMWHDaEEVWlMGFrM\noOJIVI+iG06fELXPmEFFXDR9JP912gSvgmje+vh+PrFU+NZQsOQFtAVOpVXwoSPLmPtR9Dntna2/\nHbe9exziDtBHVVbEeQS5ofjQUMSNRfvj1lYIxgwq4rKjR7Nxd703yUzEmSx2yIhS/jm3Q/PjGFKa\nFycE9u45PzdMfXNrYDy+o3kSyfjNZdOZvXI7DS2tzFu/J1Bcu4IVgg272ltt2LWau7v+c0dFBDZs\n1hqQLE7nPAINDSn9jnJXCILaTYwfUsKcW89MGuKI5IS459PTospIDxlR2uF5g5LAgQnagHBRbjj6\nTjIodGU9Ar+Q2IVnpows409fOJIff+LQuGU+wwHJYuu9+MswrcAIwk8+MYVjxg70XrMJ5wM6WVFl\nGVISL5JWmApywzS1tAXOhwgK0yXDX5117LgKfvqpqZ7351/1blKCCXSpYMXXrvkMsMjtntqVZqWd\ncSJsIj1IMDRHoCg+bGO57tTqx5KXEyYSDnkLqlgm+nIH/un/9svtH3wf+dIxPHrtsYEeQWXMimuR\ncIjDYs7lDw1Z7HoDOSHh1ElDyM8Nx4WxcsPxOQKbUG/0eQTWBtuDzT9pyobLCiM5HDYqvgtsRwwJ\nCCnZQbsg4pZLBlQLddYj8LcJsWJqQ3el+bmU5udwVOVALxyUrOKrI/yTFlduq06yZ3JuPGMit5w7\nids/fmiH+9oqrqAuuVo+qig+bKfS2BbM3WXxD87msf93nPf8gS/M4J9fPi5wXzvY+wffY8dVMKNy\nYKCXYJfmtAPTgo17+Ps1x/Ded073RCVICH5w4SFMGlbC+CHtguS/84Vgj8AKgX/w+N4Fk/nJJ6Zw\nrLtIjL8dh80zFOWF+ctVR3PvZYeTpO1NHEEJW79HAMGhjWTdYoMYWNT+2e3n9QvC+7eewcNfOtpb\nw+GoMfEL4nSF7iwXfPz4Cr588jhGlHecf7EeQZDnoR6BovioHFTEitvPiSrZ7AkiOaGoZOCMyoEJ\nZxXbQSgoQWtLDW1b6NEDC70736tPGAvAyROdu/shJfleiGpgQI7g6LEVPHfjSVFeRmzVT05A+ejg\ngFBNYSSHy44e7Xk2/lCXDQ0V5eVQVpDLx6aOoCSmBHdoafwxLbET6vz2BHlINrx15JiBca8lwy84\n9nPYctrcsJCXEyYnHPLWURiXJDHeW9hy4uK89v8lr9Irag5MOOl6GSoEihJDRxUyPXOO+K/HzGkj\nGFKS5w1uQfvYAfBId+3aykFFXruF4vwclv7wbH76ySne/rEtLzpKNpcXRlh5+7me0HihobA/R5B4\n0LZYj8Avfv45CtbjGlySx4iy/KRLKxYE5DzaQ0Pxr335pLE8f+NJfGzqCG/bgamsRhfgQRS6s8n9\nK5Hd/7kjmXvrGYEltamcpyex+RN/kt+KsF+wCzr4u6sQKEovcsu5zoqpQXe5v7zkcN77zhmBoSHL\nj2YeysNfOpqZ00YydnARN58zKWoAK4zkeHfxgLc+8YQhJVxz0ljOnNxxyCuSE/KS0naimD+HkZIQ\nuDmCvJyQdyfqH4w+d1wl4FyPt245nU/POCDhsYJ6LsWGhvxUDiriIDehe+v
5B/PE9cfH9ZAKIkh4\nbbI41ydoBRGnBDfobzhhSOI5I6mQLGQWVFBg/9b+kmYrwv6S5I5uAHQegaL0Il8+eRxfPnlc0n2S\nhYYGFEW8mcqv3HQKAC8vc774QbX0p00ayqofn0tuOMTkFKqXLDaBeebBQ71tFUURdtY2Bd6Fx2K9\nlPzcMLecO4nbnlji9WoCZy3p975zOoPcuHzQ3fg5hwzj6hPHBPbRTyYEfi/o6hOdcJkdtA8bVcaC\njXvj3uM/ph8rIEE1/kFCMG5IMS8tqwo8fnFeDjWNLV75bRADCyPevIXYktXRFYUs3hTcZtufnLfe\nQZFPHBL1k7KoR6AofQz7pU11fVt7B5hoycdU7oZjsVVA50wZ5m2zA2wkHOKHMw+JSn7H2eQOTPk5\nIS6ecQBLfnhOlKcCzkSxUJKZy8eNdxLk/kSoHXztb3/LDBvGGhHQhsMe3y+G9jrbQT7oOllvJiec\nWIz8TBgSX1pqj29ncNvP4w/n2M/jD+fE9q3yNzS87YLJ3PPp9kV4iqNCQ/FhuaA5KZaC3LDOI1CU\nvob90gatpBWEvfvel6RDZWe585NTefXrp0TdadpZw/samvncsZUJ14IGxxOI5IRSmgCX8BiuRzSk\npF0IbGWUHbT9HsFDVx/Dq18/xRMXP3bOwwEDC70BecH3z+K5G0/0EuRBA7udhevPEViChMMO8l86\ncYy3rdC1cbg7sNsE78CiiBcKsqWrfiGwC9dbRg9s93S+cPwYLpo+ynseDolX4VTmy8+cPHEw3zzn\noKSlrsPL8rV8VFH6GvZONZUlJqH97j12Mlh3KM7LiZpBDTDtAGfgb0mxEVppfm7KXo0fe2dv8xTh\nmPg8BIeGhpXmx9lssat3TRlZxktfO5n3vnM6+blhJg0r9XIekZwQ//7K8Txx/fHe+84+xFnB7tvn\nTYo7ZtBnKyvI5cM7zuM750+Os3l4abQnUBjJ8RLotnR1sC//ss2dCfzlk53wlq1Wip1AaLHHtb9D\nIvz5i0dx3SnjCbuKE5SDqCiO0JTGpSo1R6AoXeCqE8bw+qodKcf0Dx89gL9fcwyHj058h94TXHLk\nAURyQlx42IiOd8bxVLriEYweWMia7bWBORI78Ftt8OcrYttj+LHO1aEjyijKy4mKn8+oHMDba3fS\n3NIWN+mvIBLm15ceHnjMICEYXJwXl9Owd+r2bj83J0ROSCiKhCmIhKlpbPFKV4NKc284bQK3nHuw\n1zcpqPut3b6NRs+LC/nssHMIpowsY2FMjqQ0P9frd5QO1CNQlC5wykFDWHfn+VEhkY44emxFl+6+\nO0MoJHzqiFEpn6eiKJJ0cE7ESHd2s388/eUl0/jBhYcwxe3S6u81ZClOci4bgx8QMDnt8NHOMRdv\nDk4iJ8IOtP4WGP71LOzj8W7ewF6LhqZWBhRFKIiEPZGwdlX4Kn1yvYl40Z8r0ee0LdOLPY+g/TXb\n0uKQEfEzu8cNKWbtjtpuLXiUDPUIFCWL+eHMQ7vUUfNANynqX87RTvC7eMYoPj1jlNe3yMa+83ND\nSZPis75xCnUJmrrNqHQmnp176PBO2bnVXSLyuHEV/Hv+ZoCohPgb3zrVSzbPWl7lCUN9cyvjBhcx\nakCB1xpjoDu50N7NjxlUxENXHx21ToItB73htOhGhxYrNPY6+D0Ca4edBBeS9hnN44cU09TSxvpd\ndZ1eajUV0iYEIvIA8DGgyhgT12RDRL4BfNZnx8HAYGPMrnTZpChKNAcPT71cFZy7/peWVfH1sw8C\n4GOHxQ/MhZEcTpww2HseCYcISfTM2iCCVo6zlObnsuYn53VatD5z5Gg27WngyuMrPSGItdUyc9pI\nVlc5eYqJQ0v4xtkHEQ4Jl/3hHaA9R1CSn8vTN5zAyPICrwGipSQ/l3V3np/QHisE9nP4P8+xYyt4\ne+1OTpwwmJHl67jzk1O44o/vAe1zH1Z
tq+5fQgA8CNwL/CXoRWPMXcBdACJyAfDfKgKK0reZOW2k\nd+f/oxSaqIEz0a0gN9zlJTwtXfFcBpfkccdFU+KW5EzE+CElPPb/juPQkaVe/sOKhe3xVF6YGxi+\nSYUSVwy9Rep9H+lPVx5JvRuSevPm06IqzCYMdUJXq6pqOKt9Ge4eI21CYIyZLSKVKe5+KfBIumxR\nFCWzFETCSfMD6aYz8zRiS25tjuCE8YO461NTOWZs1xvZ2WtwwgRnwuEX3Nnb4ORS/PkUW5k2dlAR\nxXk5fHzaiB6tOvOT8RyBiBQC5wDXJ9nnGuAagNGjR/eSZYqi9BT5ufFLbPYmud1I0lshKMnP4eIk\nbTZSwV6DUQMKk4aQwJm1fu9lh3Okmx/5xSXBlVE9QcaFALgAeDNZWMgYcx9wH8CMGTPSkzZXFCVt\nVBTndarCqqexFUkXHzGqgz3jKYi0t+LoLmccPJQdNY0UpdACBIhqypdO+oIQXIKGhRRlv+Y3lx3e\nIwNpV4nkhJhz6xmdXv8AHI8gJB33AkqFQ0eWcfvIKR3v2MtkVAhEpAw4Gbg8k3YoipJeYldVywSp\ndGQN4qLpIxlZXhDYWG9/IZ3lo48ApwCDRGQj8H0gF8AY83t3t08ALxhjatNlh6IoSnc4ZERZl6uE\n+gvprBq6NIV9HsQpM1UURVEyhLaYUBRFyXJUCBRFUbIcFQJFUZQsR4VAURQly1EhUBRFyXJUCBRF\nUbIcFQJFUZQsR0yKi2/3FURkO/BRF98+CNjRg+b0JH3VNrWrc6hdnUPt6jxdte1AY8zgoBf6nRB0\nBxGZY4yZkWk7guirtqldnUPt6hxqV+dJh20aGlIURclyVAgURVGynGwTgvsybUAS+qptalfnULs6\nh9rVeXrctqzKESiKoijxZJtHoCiKosSgQqAoipLlZI0QiMg5IrJCRFaLyM0ZtmWdiCwSkfkiMsfd\nNlBEXhSRVe7vAb1gxwMiUiUii33bEtohIre412+FiJzdy3bdJiKb3Gs2X0TOy4BdB4jILBFZJiJL\nROSr7vaMXrMkdmX0molIvoi8JyILXLt+4G7vC/9jiWzrC/9nYRH5QESecp+n/3oZY/b7HyAMrAHG\nAhFgATA5g/asAwbFbPsZcLP7+Gbgp71gx0nAdGBxR3YAk93rlgeMca9nuBftug34esC+vWnXcGC6\n+7gEWOmeP6PXLIldGb1mgADF7uNc4F3gmExfrw5s6wv/Z18DHgaecp+n/Xpli0dwFLDaGLPWGNME\n/A2YmWGbYpkJ/Nl9/Gfg4+k+oTFmNrArRTtmAn8zxjQaYz4EVuNc196yKxG9adcWY8w893E1sAwY\nSYavWRK7EtFbdhljTI37NNf9MfSN/7FEtiWiV2wTkVHA+cD9MedO6/XKFiEYCWzwPd9I8i9KujHA\nCyIyV0SucbcNNcZsAeeLDQzJkG2J7OgL1/B6EVnoho6se5wRu0SkEjgc506yz1yzGLsgw9fMDXPM\nB6qAF40xfeZ6JbANMnvNfgF8E2jzbUv79coWIZCAbZmsmz3eGDMdOBf4ioiclEFbUiXT1/B3wDhg\nGrAFuNvd3ut2iUgx8BhwozFmX7JdA7alzbYAuzJ+zYwxrcaYacAo4CgROTTJ7r16vRLYlrFrJiIf\nA6qMMXNTfUvAti7ZlC1CsBE4wPd8FLA5Q7ZgjNns/q4C/oXjzm0TkeEA7u+qDJmXyI6MXkNjzDb3\ni9sG/IF2F7hX7RKRXJzB9iFjzOPu5oxfsyC7+so1c23ZA7wKnEMfuF6JbMvwNTseuFBE1uGEr08T\nkf+jF65XtgjB+8AEERkjIhHgEuCJTBgiIkUiUmIfA2cBi117Pu/u9nngP5mwL4kdTwCXiEieiIwB\nJgDv9ZZR9ovg8gmca9ardomIAH8Elhlj7vG9lNFrlsiuTF8zERksIuXu4wLgDGA5feB/LJFtmbxm\nxph
bjDGjjDGVOGPUK8aYy+mN65WOrHdf/AHOw6mmWAN8J4N2jMXJ9C8AllhbgArgZWCV+3tgL9jy\nCI7724xzd3FVMjuA77jXbwVwbi/b9VdgEbDQ/QIMz4BdJ+C43guB+e7PeZm+Zknsyug1A6YCH7jn\nXwx8r6P/9V78WyayLeP/Z+65TqG9aijt10tbTCiKomQ52RIaUhRFURKgQqAoipLlqBAoiqJkOSoE\niqIoWY4KgaIoSpajQqBkHSIyzd9VshPv+6GInOE+vlFECnvQpo+LyOSgcylKutHyUSXrEJEvADOM\nMdd34xjr3GPs6MR7wsaY1gSvPYhTN/5oV21SlK6iHoHSLxGRShFZLiL3i8hiEXlIRM4QkTfdvu1H\nubO4HxCR993+7jPdmeU/BD7j9pv/jLvvW+4+b4nIQQnO+aCIfEpEbgBGALNEZJb72lki8raIzBOR\nf7p9f+zaE98TkTeAi0XkS649C0TkMREpFJHjgAuBu1ybxtlzucc43bVtkft58nzH/oF7zkUiMsnd\nfrK099P/wM5kV5SEpHN2nP7oT7p+gEqgBZiCc0MzF3gApxHXTODfwE+Ay939y3FmlhcBXwDu9R2r\nFMhxH58BPJbgnA8Cn3Ifr8NdUwIYBMwGitzn36J9puo64Ju+Y1T4Ht8O/Ffssf3PgXycDpMT3e1/\nwWkqZ49t338dcL/7+EmcxoYAxfaz6Y/+JPrJ6YJ2KEpf4UNjzCIAEVkCvGyMMSKyCEcoRuE08fq6\nu38+MDrgOGXAn0VkAk6rhtxO2nEMziIhbzptf4gAb/te/7vv8aEicjuOMBUDz3dw7INwPudK9/mf\nga/gtCsGsI3v5gIXuY/fBO4RkYeAx40xGzv5eZQsQ4VA6c80+h63+Z634fxvtwKfNMas8L9JRI6O\nOc6PgFnGmE+I08//VXe/P+H09t9sjEmWXBacfvaXJni91vf4QeDjxpgFbq7ilCTHtcdOhv3Mrbjf\nZ2PMnSLyNE6/oXdE5AxjzPIOjqNkMZojUPZnngf+y+3OiYgc7m6vxlnS0VIGbHIff8FuNMZcaYyZ\nlkAE/Md4BzheRMa75ykUkYkJbCoBtojTNvqzCY7nZzlQaY8NXAG8luDYuOcfZ4xZZIz5KTAHmJRs\nf0VRIVD2Z36EE+ZZKCKL3ecAs4DJNlmMsybsHSLyJs761qlwH/CsiMwyxmzHEZBHRGQhjjAkGny/\ni7N62Is4g7zlb8A33OTuOLvRGNMAXAn80w15tQG/78C2G90E+gKgHng2xc+kZClaPqooipLlqEeg\nKIqS5agQKIqiZDkqBIqiKFmOCoGiKEqWo0KgKIqS5agQKIqiZDkqBIqiKFnO/weEEPIaoRH8CQAA\nAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(meta_losses)\n", "plt.xlabel(\"meta-iterations\")\n", "plt.ylabel(\"meta-loss\")" ] }, { "cell_type": "markdown", "metadata": { "id": "D-8DhN_o2R7L" }, "source": [ "## Meta-training with truncations with less bias: Persistent Evolution Strategies (PES)\n", "When meta-training with truncated unrolls, whether with truncated backprop through time or truncated evolution strategies, the estimate computed on one truncated segment cannot account for the effect that segment has on later segments. This introduces bias, which matters most when the full inner-training trajectory is much longer than a single truncation.\n", "\n", "Persistent Evolution Strategies (PES) is an ES-based algorithm that removes this bias. Each inner-problem keeps an accumulator of every perturbation applied since the start of its current trajectory, and the loss difference of each antithetic pair is multiplied by this accumulated sum rather than by the current perturbation alone." ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "executionInfo": { "elapsed": 53, "status": "ok", "timestamp": 1647716735333, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "WEfssK2R9-Gs" }, "outputs": [], "source": [ "@jax.jit\n", "def vec_short_segment_pes(meta_param,\n", "                          keys,\n", "                          pes_state,\n", "                          vec_seq_of_batches,\n", "                          std=0.01):\n", "  # Compute a PES estimate on a single inner-problem.\n", "  def do_one(meta_param, key, pes_state, seq_of_batches):\n", "    accumulator, pos_opt_state, neg_opt_state, on_iteration = pes_state\n", "\n", "    # Sample random noise of the same shape as the meta-parameters.\n", "    flat_params, struct = jax.tree_util.tree_flatten(meta_param)\n", "    keys = [jax.random.fold_in(key, i) for i in range(len(flat_params))]\n", "    keys = jax.tree_util.tree_unflatten(struct, keys)\n", "    perturbs = jax.tree_util.tree_map(lambda k, v: jax.random.normal(k, v.shape) * std,\n", "                                      keys, meta_param)\n", "\n", "    # Compute the positive and negative antithetic samples.\n", "    pos_theta = jax.tree_util.tree_map(lambda eps, v: v + eps, perturbs, meta_param)\n", "    neg_theta = jax.tree_util.tree_map(lambda eps, v: v - eps, perturbs, meta_param)\n", "\n", "    # Unroll with both of the antithetic samples.\n", "    p_losses, pos_opt_state, _ = 
short_segment_unroll(\n", "        pos_theta,\n", "        key,\n", "        pos_opt_state,\n", "        on_iteration,\n", "        seq_of_batches,\n", "        inner_problem_length=30)\n", "    n_losses, neg_opt_state, next_on_iteration = short_segment_unroll(\n", "        neg_theta,\n", "        key,\n", "        neg_opt_state,\n", "        on_iteration,\n", "        seq_of_batches,\n", "        inner_problem_length=30)\n", "\n", "    # Estimate the gradient. PES multiplies the loss difference by the sum of\n", "    # all perturbations applied so far in the current trajectory.\n", "    new_accum = jax.tree_util.tree_map(lambda a, b: a + b, accumulator, perturbs)\n", "    delta_losses = p_losses - n_losses\n", "    unroll_length = p_losses.shape[0]\n", "\n", "    # One unroll can span 2 inner-problems: the current trajectory may end and a\n", "    # new one start partway through. We therefore compute 2 gradient terms and\n", "    # sum them. Losses from before the reset are weighted by the accumulated\n", "    # perturbations; losses from after the reset only by the current perturbation.\n", "    has_finished = (jnp.arange(unroll_length) + on_iteration) > 30\n", "\n", "    last_unroll_losses = jnp.mean(delta_losses * (1.0 - has_finished), axis=0)\n", "    new_unroll = jnp.mean(delta_losses * has_finished)\n", "\n", "    es_grad_from_accum = jax.tree_util.tree_map(\n", "        lambda p: last_unroll_losses * 1 / (2. * std) * p, new_accum)\n", "    es_grad_from_new_perturb = jax.tree_util.tree_map(\n", "        lambda p: new_unroll * 1 / (2.
* std) * p, perturbs)\n", "    es_grad = jax.tree_util.tree_map(lambda a, b: a + b, es_grad_from_accum,\n", "                                     es_grad_from_new_perturb)\n", "\n", "    # Finally, if the trajectory finished during this unroll, reset the\n", "    # accumulator to the current perturbation.\n", "    def _switch_one_accum(a, b):\n", "      return jnp.where(has_finished[-1], a, b)\n", "\n", "    new_accum = jax.tree_util.tree_map(_switch_one_accum, perturbs, new_accum)\n", "\n", "    next_pes_state = (new_accum, pos_opt_state, neg_opt_state,\n", "                      next_on_iteration)\n", "\n", "    return ((jnp.mean(p_losses) + jnp.mean(n_losses)) / 2.0,\n", "            next_pes_state), es_grad\n", "\n", "  (loss, pes_state), es_grad = jax.vmap(\n", "      do_one, in_axes=(None, 0, 0, 0))(meta_param, keys, pes_state,\n", "                                       vec_seq_of_batches)\n", "\n", "  # The gradient has an extra batch dimension from the vmap; average over it.\n", "  return (jnp.mean(loss),\n", "          pes_state), jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=0), es_grad)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "executionInfo": { "elapsed": 60, "status": "ok", "timestamp": 1647716735518, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "f-DFEv7vbNjQ" }, "outputs": [], "source": [ "num_tasks = 32\n", "\n", "key = jax.random.PRNGKey(1)\n", "meta_params = lopt.init_meta_params(key)\n", "\n", "inner_opt_state = lopt.initial_inner_opt_state(meta_params, task.init(key))\n", "batch = get_batch_seq(10)\n", "\n", "# Construct the initial PES state, which is passed from iteration to iteration.\n", "def init_single_inner_opt_state(key):\n", "  return lopt.initial_inner_opt_state(meta_params, task.init(key))\n", "keys = jax.random.split(key, num_tasks)\n", "inner_opt_states = jax.vmap(init_single_inner_opt_state)(keys)\n", "accumulator = jax.tree_util.tree_map(lambda x: jnp.zeros([num_tasks] + list(x.shape)),\n", "                                     meta_params)\n", "# Randomly set the initial iteration to prevent the tasks from running in lockstep.\n", 
"on_iterations = jax.random.randint(key, [num_tasks], 0, 30)\n", "pes_state = (accumulator, inner_opt_states, inner_opt_states, on_iterations)" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "executionInfo": { "elapsed": 34013, "status": "ok", "timestamp": 1647716769658, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "BRdSQI6dbNjQ", "outputId": "1dbf6073-f748-4aa7-de21-2fce80c0e188" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 2.302895\n", "20 2.3025956\n", "40 2.3024507\n", "60 2.3007588\n", "80 2.2814593\n", "100 2.2528963\n", "120 2.221389\n", "140 2.1904716\n", "160 2.1542785\n", "180 2.1154919\n", "200 2.0792098\n", "220 2.0493665\n", "240 2.0263429\n", "260 2.000341\n", "280 1.9719956\n", "300 1.946686\n", "320 1.9295752\n", "340 1.9169009\n", "360 1.8944466\n", "380 1.880933\n" ] } ], "source": [ "meta_opt = Adam(0.0003)\n", "meta_opt_state = meta_opt.init(meta_params)\n", "\n", "meta_losses = []\n", "\n", "num_steps = int(os.environ.get(\"LOPT_TRAIN_LENGTH\", 400))\n", "for i in range(num_steps):\n", "  data = get_vec_batch_seq(num_tasks, 10)\n", "  key1, key = jax.random.split(key)\n", "  keys = jax.random.split(key1, num_tasks)\n", "  (loss, pes_state), meta_grad = vec_short_segment_pes(meta_opt_state[0], keys,\n", "                                                       pes_state, data)\n", "  meta_losses.append(loss)\n", "  if i % 20 == 0:\n", "    print(i, onp.mean(meta_losses[-20:]))\n", "  meta_opt_state = meta_opt.update(meta_opt_state, meta_grad)" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "colab": { "height": 296 }, "executionInfo": { "elapsed": 226, "status": "ok", "timestamp": 1647716770004, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "noPvV56CjOrd", "outputId": "526cd7cf-1962-45b3-dcdb-1d956e9ace64" }, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'meta-loss')" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" }, { "data": { 
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90\nbGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsT\nAAALEwEAmpwYAAA5eklEQVR4nO3dd3hb5dn48e/tvVfsTCcx2QkhCYkTKGHvWSgt9KVltpSXFgq0\nlL4tHfCjtFDa0kXLKKsUSlsolFEgrLAhe+9JEjuOHSfeS+P5/XHOkWX7SJZNZNnW/bkuXdY4R7qt\nxLr1rPsRYwxKKaXiV0KsA1BKKRVbmgiUUirOaSJQSqk4p4lAKaXinCYCpZSKc0mxDqCnCgsLTUlJ\nSazDUEqpAWXZsmX7jTFFbo8NuERQUlLC0qVLYx2GUkoNKCLyaajHtGtIKaXinCYCpZSKc5oIlFIq\nzmkiUEqpOKeJQCml4pwmAqWUinNRSwQiMlpEForIBhFZJyI3uhxzvoisFpGVIrJURI6NVjxKKaXc\nRXMdgRe42RizXESygWUi8oYxZn3QMW8BLxpjjIjMAP4FTIlGMJv31fPyqnKSExNISQq62LcTE4Q2\nrx+Pz09GShJ+Y/DbJbqdSt2Bn52eOzEBEkQCF+d2cmICSYlCUkICPr+hzecjIyWJrNQkctKSyUpL\nIjstCQHqW7zkZ6ZE41dXSqmwopYIjDF7gb329XoR2QCMAtYHHdMQdEomXT9jD5kt+xr4w9tbo/X0\nn0mCgN/AlOHZnH3ECKYMz+bEyUNJSdKeO6VU9ElfbEwjIiXAe8B0Y0xdp8e+ANwFDAXOMcZ87HL+\nNcA1AGPGjJnz6achF8iFZYzB4zN4fH7avH7a7J+tXj9+Y0i1WwZNbT4SE4REkaAY7J9Ih9vGgN8Y\nfMZgjMHnt2/7rdfx2j+TEqyWR1Orl7oWLw2tXupbPDS0eGn2+MhOS+bFVeVs2Gu9PcX56Zw8ZSg+\nv+GmUydRlJ3aq99ZKaUARGSZMabU9bFoJwIRyQLeBX5ujHkuzHHHAz81xpwa7vlKS0vNYC4x0dzm\n45Pt1dz16gYqalto8foZlZfOk1cfxai89FiHp5QaoGKWCEQkGXgZWGCMuTeC43cAc40x+0MdM9gT\nQWdLdx7gqseWIAL3XzqH+RMKYx2SUmoACpcIojlrSIBHgA2hkoCITLCPQ0RmAylAdbRiGohKSwr4\nz/XzGZ6bxrVPLuPH/1nDqt01sQ5LKTWIRHM0cj5wGXCyPT10pYicLSLXisi19jFfBNaKyErgT8CX\nTV8MWgww44uyePjyucwancfzy8u48rHF1DS1xTospdQg0SeDxYdSvHUNdba+vI5z//g+Q7PT+PLc\n0Xz75AkkJersIqVUeDHpGlLRMW1kDtefPJGKuhZ+/9YW/vDWlliHpJQa4DQRDEA3njKRv319HufN\nHMkf3t7KSb9+h4+2hRxfV0qpsDQRDECJCcJxE4v47cUzufGUidS3ePjfvy1j4aZKvD5/rMNTSg0w\nmggGsKTEBL5z2iReuP5Y0pMTueqxJVzw5w91IFkp1SOaCAaBUXnp/Oe6+dx82iQ27q3nF69sYG9t\nMwNtIoBSKjYG3Ob1yt3IvHS+fcpE9hxs5p9Ld/OvpXu4ZN4Y7rrwiFiHppTq5zQRDDK3nDmZ4blp\nbN5Xz9OLd3HatKEUZqUyozgv1qEppfopTQSDTGFWKt85bRKNrV4W7TjA1x631lws+/GpLNl5kNOm\nDSMxQbp5FqVUPNExgkEqMzWJH58zNXB7zp1vcu2Ty1iwroL3t1Sxr64lhtEppfoTTQSD2IWzi1l1\n2+mMHZIRuO+VNXu57JHFXPHo4hhGppTqT7RraJDLTU/mkStKWf5pDX/9eCcvr94LwMaKeowxiGg3\nkVLxTlsEcWDC0GwunjuaKz5X0uH++9/dxjNL
d8cmKKVUv6Etgjhy8dzRTB6eTYvHx2WPLOae1zYB\ncN7MkTS2ehmSpbugKRWPtEUQZ2aOzuOocUO4qLQ4cN8X7/+IOXe+yd7a5hhGppSKFU0EceqHZ0/l\nqauPYkhmCuvKrX2SP9yqewIpFY80EcSprNQk5k8o5JErrQ1vAD7aup8XVpZR2+yJbXBKqT6liSDO\nzRqdx3+um8+5M0bw3IoybvzHSu56ZQO1zR4aW72xDk8p1Qc0ESgArj1hfOD6h9v2c/w9Cznvvg9i\nGJFSqq9oIlAATB+Vy/1fnc1XjxrD7gPN1DZ72F7VyIpdB2MdmlIqynT6qAo464gRnDx1KE8t2hW4\n79W1FTy7bA9ZaUn88KypYc5WSg1UmghUB6lJibz//ZPYfaCJX7++iYfe2x547NKjxlKcn66rkZUa\nZLRrSHUxuiCDYyYUMmt0PgAZKYkkCBx3z0KO/eVCWr2+GEeolDqUNBGokGaOzgXgJ+dO476vzAag\nrKaZzRUNsQxLKXWIadeQCum8GSMpzs9g9pg8RIR3bzmRE371DmvLazmiOLfDsevL69h1oIkzpw+P\nUbRKqd7SRKBCSkgQ5ozND9weU5BBdloSS3YewOvz86U5o/H6/aQmJXL2H94HYOfd58QqXKVUL2ki\nUBETEWYW5/Hc8jKeW16Gz2+4/aX1HD+pKHBMXYuHnLTkGEaplOopHSNQPfLzL0znwtmjAPj38jIA\n3ttcFXh8V3UT72yqjElsSqne0USgemTskEzuvXgWF80pZk1ZbZfHb31+DVc+toQPtuyPQXRKqd7Q\nRKB65dyZI13vX73HSg6V9S388a0tbK2s78uwlFK9oIlA9coJk4q4/bxp3HjKRC49ekyXx7dXNfKb\nNzbz1YcXxSA6pVRP6GCx6rUr5x8GgN9vuPKYw/jWU8vYvM9aY7DcrlG0r66VHz2/BhG484IjYhar\nUio0bRGozywhQZgwNCuwrwHAsk/bi9U9tWgXT36yy+VMpVR/ELVEICKjRWShiGwQkXUicqPLMV8V\nkdX25SMRmRmteFT0lY4tCFxv9fq7PN7mcp9SKvai2SLwAjcbY6YCRwPXici0TsfsAE4wxswAfgY8\nFMV4VJTNDGoRuNl9sKlvAlFK9UjUEoExZq8xZrl9vR7YAIzqdMxHxhinD+EToBg1YE0ens2jV5Zy\n82mTXB9fW1bL0p0H+jgqpVR3+mSMQERKgCOBcFNIvg68GuL8a0RkqYgsraqqcjtE9RMnTxnG5ceU\nBG4/+fWjOHfGCADueGk9X3rgY3bsb4xRdEopN1GfNSQiWcC/gZuMMXUhjjkJKxEc6/a4MeYh7G6j\n0tJSE6VQ1SGSm57MA5fOJictmWMmFDJ/whA+3lZNdWMbAC+sLOOmU91bDUqpvhfVFoGIJGMlgaeM\nMc+FOGYG8DBwvjGmOprxqL5z5vQRHDOhELBqFAVXK33i40+59m/L+LS6Y8vg8Q938OqavX0ap1Iq\nurOGBHgE2GCMuTfEMWOA54DLjDGboxWLir0pw3MAyM9I5kBjG6+tq+Dvi3ZR3+Kh1evDGKuA3Tef\nWh7jSJWKP9FsEcwHLgNOFpGV9uVsEblWRK61j/kpMAT4s/340ijGo2JoTEEGABeXjg7cV17bwhG3\nv87XH1/KnoPNsQpNqbgXtTECY8wHQNjNbY0xVwNXRysG1X9cVFpMU5uXS48ey+mHD+dnL6/nzfX7\nAPhg637WBhWwa/H4SEtOjFWoSsUdXVms+kRyYgJXHzeOtORE5ozN5+QpQ2n2tO99/NG29uGh3Qd0\nvYFSfUkTgYqJz3eqXvq3Tz4NXN9Z3TURNLf58Pl1wphS0aCJQMVESWEmf/rKbG48ZSIJdgfihUda\n6w23VlqF62qbPHzvmVXUNnuY+tPX+N4zq9haWR8oaKeUOjS0+qiKmXNmjOAcRvDv5XvYc7CZEyYX\nsWJ3TaBg
3Sc7qnl22R5OmjwUgOdXlPH8CmtXNN0bWalDR1sEKuZ+9+VZTBqWxfETi5gzNp83N+zj\n4fe3U1HbAsCmCtd1iEqpQ0QTgYq50pICXv/OCeRnplA6Nh+AO/+7IdBFtLGi6y5nxuh4gVKHiiYC\n1a9ccOQoTplidQW9v8WqK7VpX9dEEDzjSCn12WgiUP1KWnIit513ONA+e+hTl1lEtc2ePo1LqcFM\nE4Hqd4rz00nvZkHZyl01fPefK3WzG6UOAU0Eqt9JSOhYpM7NN59aznMryli5u6ZvglJqENNEoPql\ney+eySXzxnBxafi9ijw+P9UNrfh1sZlSvaaJQPVLxfkZ3HXhEXx57uiwx22qqOeYu9/mv2v2cv87\n2/jXkt19FKFSg4cuKFP9mlO+OpT/rtlLq9fP9qpGfvumVcn84m6Sh1KqI20RqH4tMzX8dxVnFXJl\nfUtfhKPUoKSJQPV793xxBj+7YHrYY7ZVNQSu62IzpXpGE4Hq9y6eO5rLjh7L5GHZAIzKS+9yzLqy\n9jIUB5t0jYFSPaGJQA0Yj39tLj89dxrzJwzpcP/ognTqW72B23trdbczpXpCE4EaMEbkpvO1Yw9j\n7JDMwH1pyQmB6qSOfXU6XqBUT+isITXgDM1OBWBcUSbfP2MKO/Y3dni8vEYTgVI9oS0CNeCcNm0Y\nk4Zlcd8lszlz+nCK7MSQlCCkJCYEtrpsbvNx16sbqNUxA6XC0kSgBpy8jBRe/84JTBtprTE4depQ\nvnnieF676ThKCjPYWtlAi8fHgnUVPPjudn7xyoYYR6xU/6aJQA14eRkp/N+ZU5gwNJvxRVm8tbGS\nKT95jbIaa9D4g637AbQMhVIhaCJQg8q4ovaB5L8v2gVAWU0zi3ccYNytr/De5qrA48YYTQ5KoYlA\nDTI5acmB606LAOD+d7YCdEgE33xyOeNufaXD+W+s38f+htYoR6lU/6KJQA0qF5dai8/OPHw4ACNy\n0wB4104AOenJ7DnYRG2zh9fWVQDwzqZKSn7wX9aV1/KNJ5Zy1WNLYhO8UjGi00fVoJKfmcLPLphO\nTVMbO6sb+cZx47jl2VU4PUAHGts49pcLGTskI3DOPxZbFUs/3lYNdCxXoVQ80BaBGpTyMlJ47abj\n+eKcYoKHAarqrW6f4O0vDzS1AdBi74OspYpUvNFEoAa9aSNyAj8376vv8vjBRisRNLVZicCvmUDF\nGU0EatB78uqjWHzrKRTnp7Olsmu3zwE7EVQ3WD+D04AxBq9P90VWg1uPE4GI5IvIjGgEo1Q0FGSm\nMDQnjSFZqa6PO11DgT0NgjLBn9/ZxoQfvUpjUFE7pQabiBKBiLwjIjkiUgCsAh4TkXujG5pSh1Zh\nVorr/U5PUKU9fmCCMsHfPv4UQKeUqkEt0llDucaYOhG5GnjMGHObiKyOZmBKHWpDc9LCPu4kAr+B\nO19ej8fnJylRAKv7KLjqqVKDSaRdQ0kiMgK4GHg5khNEZLSILBSRDSKyTkRudDlmioh8LCKtIvK9\nHsStVI+dPm1Y2MedGUXGGJZ8epAlOw+SnGj9iey3xw+UGowiTQR3AAuArcaYJSIyDtjSzTle4GZj\nzFTgaOA6EZnW6ZgDwA3Ar3sQs1K9MiwnjYlDs7o9zm+gvtlDXYuHxASrRVCtXUNqEIsoERhjnjHG\nzDDGfMu+vd0Y88VuztlrjFluX68HNgCjOh1TaYxZAmidYNUnXr7hWP79zc91e1xdi4f6Fi+JYiWC\n/Q2tvLy6PLDWQKnBJNLB4nvsweJkEXlLRPaLyKWRvoiIlABHAot6E6SIXCMiS0VkaVVVVfcnKBVC\nalIi4wrbWwVzS/JdjzvQ2EZ9i4cmjzVb6OXVe7n+7yu4943NfRKnUn0p0q6h040xdcC5wB5gEnBL\nJCeKSBbwb+Am+zl6zBjzkDGm1BhTWlRU1JunUCogP7N99tApU93HDfzGuu
y1dzvbWGEtRHM2vVFq\nMIk0ETglHc8GnjbGHIjkJBFJxkoCTxljnutFfEpFRYGdDI6dUMhdFx7BaSEGkr2dylQnJAgLN1Wy\nsaJX32mU6pciTQQvichGoBR4S0SKgLAbw4qIAI8AG4wxuuZA9Ss/PGsKYFUnvWTeGKYOz4743Kse\nW8KZv3s/WqEp1eciWkdgjPmBiPwSqDPG+ESkETi/m9PmA5cBa0RkpX3frcAY+zkfEJHhwFIgB/CL\nyE3AtN52ISkVqYtKR/PF2cUk2LOCnG0v3RTnp7PnoLW3wb7asN9/lBqQIkoEdhfPZcDx1hd93gUe\nCHeOMeYDQLo5pgIojihSpQ4xJwkAHD4yN+RxR48bwrPL9gCwtry2x6+zr66FxlYv44q6n7qqVCxE\n2jV0PzAH+LN9mW3fp9SgUJyfHvKxC2a1z3pu8fS8AN1Rv3iLk3/zbq/iUqovRFpiYq4xZmbQ7bdF\nZFU0AlIqFkSErxw1hlF56fxqwaYOjx0zfojrOa1eH6lJiX0RnlJRFWmLwCci450b9spiXVmjBpVf\nfOEIrjtpQpf7ExKEl799LN8/c3KH+93KTrR5/Tz03jZavV3/PLw+P+vLdfhL9T+RtghuARaKyHas\nfv+xwFVRi0qpGHrhuvksWFfBgnUVXDJvDADTR+UydkgG97zW3loor2nm/c1VTBmRw6zReQA89uEO\n7np1I6lJidz24jquPKYkcPyvXt/Eg+9u583vnsCECEpdKNVXIp019JaITAQmYyWCjcYYLb6iBqWZ\no/OYOTqP7585pcP92WnJ3PeVI3lpVTkL1u3j9hfXsa68jqPHFfCPa6yyFWU11uwij72ZzeMf7Qyc\nv2SHtfymsr5FE4HqV8ImAhG5MMRD40UEXSSm4s25M0Zy5Jh8Fqzbxzq7m6chaNOaumarbFbnhWgA\nrV4rOfj8hn8t3c1pU4d1WOWsVKx01yI4L8xjBtBEoOJO5w1u9te30er1Udfspa7FSgoVLusNnIJ1\nW/Y1cMfL6zl+UhFPfG1e9ANWqhthE4ExRscBlOokNSmRnLSkwId+dWMr33pyOW9trOSIUdZ6hHK7\niyiYM/XU2fZyr8sxSsVCb/YsjmhjGqUGs8Ls9v2PPT7DWxsrAVhTZi042+vSInBmEtXbicBvunYf\nKRULPU4EdNpTQKl4VJRlJYKxQzJcHw/XItgf2BtZqf6hN4lgxSGPQqkBxmkRTBrmXqyuurHrGgOn\nRVDV4GyJGaXglOqhHicCY8zXohGIUgOJ0yKY7JIIQm2H6fFZn/zOQrTgrqHymmYWrKvocRzGGG55\nZhVLdkZUGV4pV5EWnZsI3AVMA9Kc+40x46IUl1L9WpHdIpgywkoEM4pzWb3HGh+YOCyLLZUNIc91\n9j/2G8PFD37MiNw0lu48SFlNM1t/fhZJiZF/P2v2+Hhm2R5eWFnO5p+f1dtfR8W5SFcWPwbcBvwW\nOAlrVXHYyqJKDWZOi2BcYRZ//do8Zo/J48g73iA5MYHc9OSw5+53EoEfFu/o+E2+rsUb2DQnEh6v\n1apo8/W8GJ5SjkgTQbq9uliMMZ8Ct4vI+1jJQam4c8b04dQ2e5gyPJtpCdZeBst+choi8Ke3t4Y9\n11lrZlwGCWqa2shNTyYxwf171rryWrZXNXLezJEAtPq05Jf67CJtg7aISAKwRUSuF5EvAEOjGJdS\n/VpuejLfOH5chz0NctOTyUlLDnQbdcfjsvp49Z5axt/6Ci+uKnc95/EPd/KTF9YGbrd5tSWgPrtI\nE8FNQAZwA9a+BJcCl0cpJqUGtNKSgoiOq6rvWq7LGfS94ekVHVoMmyrqqaxvoabZQ12zB7+dRDQR\nqEMh0q6hEmPMEqABu+qoiFwELIpWYEoNVNPDbHvZneD1B9WNbRTaYxFn/O49CrNSGV+Uid9AY5uX\n7LRkHRtQh0SkLYIfRnifUnEvKTGBMQ
UZlAQtNhuRmxbmjHa7DjQFrjuDyk5Ru/0NrdTaRe2cn8Et\ngvoWD3Utns8WvIpL3VUfPQs4GxglIn8IeigH8LqfpZR66+YTSBDhskcW8dG2aobnprmWnegsOBE8\nv7yMyz6XRF1z+59avV3fqK7ZS2OGt0MimPH/XscY2Hn3OYfwN1HxoLsWQTmwFGgBlgVdXgTOiG5o\nSg1cyYkJJCYID11eymNXzeWwIZkRnecsOgN48L3tHPvLhew60Bi4z2kJ/HPJLg6/bQErd9cEHtOV\nyqq3wiYCY8wqY8xfgQnAv4BPjDF/NcY8Z4w52CcRKjWAZaUmcdLkoeRmWGsLultjAJCe3HEf5H8t\n3RO47nQTObOK3FYU7z7QxIPvbnOdnqqUm0jHCM4EVgKvAYjILBF5MVpBKTXYOAng2AmF3R7beTzh\nnU2VXY5xuoicQnbBvvHEUu56dSPlEXRFKQWRJ4LbgXlADYAxZiVQEo2AlBqMslKt4bhAaYrh7TWK\nHrxsDgDTR1mzjXI6tRpclhsEdkBzNrsJ5rQamtt83P7iOirrNCGo8CKdPuo1xtSKaFUJpXojPcXq\n7mnx+Fh06ymkpySycW89y3cd5IzDh7PjrrO5943NrC2r69A1lJeRTE1T6JlAzS6JwFmV/M6mSh7/\naCd7DjazaHs1x08q4k9fnX2IfzM1GETaIlgrIl8BEkVkooj8EfgoinEpNaikJbUngmE5aeSkJTPv\nsAKuPWE8ACIS2ND+YFMbF84exdyS/EBNo0nD3CuaHmzqWu46wf7C5gw81zS1Ud/q5b9r9h7aX0oN\nGpEmgm8DhwOtwN+BWuDGaAWl1GDjtAjcvsE7xhdZH/aV9a3ce/Esnrn2GLLSrEb7tBHui9SqG7om\nAq+/45aYTleRUqFE2jU0zb4k2Zfzgc8DM6IUl1KDypgCa3GZ2/4FjnFF1hTTkXntg8VOKYmjxw0h\nMSEBj8/Ppop6Nu2rB6CprWticSqSVjdaC9LcWg1KBYs0ETwFfA9YC+iadqV6aPqoXF64bj6Hhyk/\nkZGSxONXze3w7d8ZFB5dkMH/zBsDwOWPLg4kAjdO2QlnA5x9de01jV5bW8E7myq5+4v6HU61izQR\nVBljXopqJEoNcjNH53V7zImTOxb19dmJIHj9wZBu9itwVhs7G+AEu/7vy/H6DV8/9jAmhmmdqPgS\n6RjBbSLysIhcIiIXOpeoRqaU4rdfnsV5M0cyOWi6aeeNa1KSOv4ZOy0Ct32TnS6ql1d3HDj2hChe\n19jq5cVV5Xh9fq55Yilr7F3Y1OASaYvgKmAKkEx715ABnotGUEopy9QROfzxkiM73Nc5EaQmJnSo\nOdTeIuiaCJzB6kU7qgP3ba9q4IzfvccrNxzXpZVw24vreHbZHtoumsnr6/exraqBt24+8TP9Tqr/\niTQRzDTGHNGTJxaR0cATwHCs5PGQMeb3nY4R4PdYhe2agCuNMct78jpKxRvXFkHXXiDX2UJO4bvt\nVY34/YaEBGFrZQMen+HT6iZEoCg7LdAV5ZTFbmqznsttcZsa+CLtGvpERKb18Lm9wM3GmKnA0cB1\nLs9xFjDRvlwD3N/D11Aq7rglgseumsvogvSIn6OyvpVxt77C9qqGwKyiJo+PU+99j6/85ROu/utS\nZv/sjcCaBGf2kk8zwaAUaSI4FlgpIptEZLWIrBGR1eFOMMbsdb7dG2PqgQ3AqE6HnQ88YSyfAHki\nMqKHv4NScaXzYLHfGE6aPJTHrpwb0fmj8toTxrryusBYgjO4vK68jjc37ONAY1tgK85Wu7vJbwwV\ntS2BvRLU4NCTonMTgdOB84Bz7Z8REZES4Ei67mg2CtgddHsPXZMFInKNiCwVkaVVVVWRvqxSg1Jm\nasceXWd6aFqnqqUA+Rldq53OGpMXuH6wqY2DdiLYc7C5y7FOURmnuJ3fbzj6rrcovfPNwDG1TR6u\neH
Qx+7Sm0YAVUSIwxnzqdonkXBHJAv4N3GSMqev8sNvLubz+Q8aYUmNMaVFRUSQvq9SgNTzHWnB2\n4uSOfwudy1cDZKe1J4LkROvPbWZxLidPsaap7q1t4UCjVcuozC0R2H+hjfYYgc+ltPUzy3bz7uYq\n7n9nW09/FdVPRNoi6BURScZKAk8ZY9xmGO0BRgfdLsbaDEcpFUJ+ZgqrbjudX180s8P9ThmLDvcF\nJQdnALggM5VHr5zLqLx09tY0B8YI9tQ0dTnfGROoszfE0S2SB6eoJQJ7RtAjwAZjzL0hDnsRuFws\nRwO1xhitjKVUN3LTk7uMFTiF7Trcl5LIA5fO4ZYzJgcSQY5dv2hknrV9ZnWYriGnhIWzF7LbZjda\nlXjgi3T6aG/MBy4D1ojISvu+W4ExAMaYB4BXsKaObsWaPnpVFONRalDp/AGckCAkJUigLAVASqJw\n5vThALy1YR/QvjfC8Nx0Vu+pCWxx6ZS7Tk6UQOXSZicRNIfuGtI0MPBFLREYYz6gm/8jxvp6cV20\nYlBqsPvg/04iKaG9YT80O5Xy2hYmD8tm0756DgbtZTBlRA7Ld9UE9isYkZvGgnUtpCSG7hhwFqDV\ntzhdQ6Gnj+rWmANXVMcIlFLRVZyfwfCgrS2d605Jin1B21X+5Jxp/Pqimcw7rACA/IwU2rz+LgvP\nnNYAtC8kq7O3xvS7JAKnPEVv0kB5TTO7DzTx8upyJv/4Vdcd11T0RbNrSCnVx0bkpgM1gU1u6oM+\n5NNTEvnSnOLAbbeppZ05LQpnsNgTlAjavH7qWjyBVoO/Fy2CY+5+G4DRBem0ev1U1bcy2q6HpPqO\ntgiUGkScFkGy3d3jDAy7ycsIX8UU2usWOYPFwTWNvv30ckrvfDMwoNzi8fPrBZt4b3PP1/qI3Yus\nvUuxoS0CpQYRp0soOVF44mvzGDsk9Lfr4BZBfkYyB5s8jMxNo7y268Kw4O4ix4J11uBzVb21oK25\nzcd9C7fCQth59zk9itsetnAdjFbRp4lAqUHkS7OLSRDh8zNHdilP3Vl+0PTT/IwUDjZ5mDoixzUR\nhLO31pp26iw6i8Seg00UZacGbjszoIJbHKrvaNeQUoNIQoLwpTnF3SYBgLzgMQL7G/mM4rwev6ZT\n4iJ4hlI4Xp+fY3+5kO/8c2X7y9uv3+rVweJY0ESgVJzKS29vETizgT43fkiPn8dpEXTeEa2yrsW1\nFHaL/a3/dbtrCSAxSi2CtWW1OhMpApoIlIpTKUkJZKUmkZueHFiEVhJmTCEUpyDdgaAd0Vo8Pub9\n4i2u/duyLsc7i9SCZxm1twgOXSKobfZw/p8+5MWVWrWmO5oIlIpjuenJFGSm8OBlc/jf48d16Lfv\nKWf2EMDzK8oA+Hh7dZfjnG/owcPCzqzU+hYP1/19ObsPdK171PN4vPj8htrmyLqs4pkOFisVx/Iz\nk0lJTODwkbkcPjIXgOzUJOpbvQzNTqWyvnf7Djz5iVWceM6Y/MB9Pr/hk+3VDMvpmmycRWlvbqjk\nv6v30tzm49EI91cIxelmatNKed3SFoFScez6kybyrRMndLjvjgsOBwgsSuuNDXutivPNQf3zj36w\ng68+vIhX11QAHdcMOB/aTkLozeI0R1lNM7c+vybQQjmU3U2DlbYIlIpjTkG6YF84spgvHFnM955Z\n1evnDe7q2X2giT0Hm9lR3QjAtqqGLsc7CaA9EcCZv3uPuSUF3HjqRCpqW5g+Kjei1/7R82t4Z1NV\nYLzDoy2CbmkiUEq56rw3cjgpSQmuM37qW7wcd89CAC47eiwADa1dZ/E439rbvFYGMcawsaKejRX1\nvLlhH3trW3q8SK3Rfp02r5+Vu2tISUxg2sicHj1HJFq9Pr7/7Gq+d/rkAVseQ7uGlFKusu1y1Vmp\n3X9fPMouZNdZddBMIqfqaaPLlFIniTR77OJ2QV1De+0Fbt1VN/3l
axu5/NHFgamozqB0m9fP7S+u\n49evb+r29+iNj7ZV88LKcn70n7VRef6+oIlAKeXK2Rs5NczitCnDs/nfE8Zx9XHjIn5eZ0e0YM6A\nbn2gymnX87ob9N1UUc/mivpAwnHGCDw+P81tvsC01UPNqbU/kMtwayJQSrlyWgJpLnshD7Wnmd5w\nykR+eNbUDnWLnP0Nsju1JJxpnBUum9w7n6HOMW79+t19kDe0eGls9QYSQX1QoTyPz7qU1TTz0qpD\nu65gMOzQpolAKeUqK81JBF0/JiYNs4rbOR/Y+UGVTHPtpDAiL63DOZX1VgKoCVOKwil33ejyod/U\nTSKob/XS0OYlwU4ETsmLNp+fVq+fNp+fix/4mG8/vYKK2hZO+NVCtrsMXEfq5dXlTLj1lUGxclkT\ngVLK1fgia/royVOGdnns8s9ZA7/jCq1jcoNaBBkpVgvC2huhnVOTKByna6jJpYCd230Af1+0i/Xl\ndTS0ejAGWu2Vzk4XlNMiaPP6KbfLYTy3Yg+fVjfx2Ic7u40plHte24TXb6isc8Ywev1UMaezhpRS\nriYPz+btm09gTEEGf3l/R4fHTj98OBvuOJN0+0M/uBvI6ZoZ2blF4NIl1JlT6sJtQNmtRdDQ6uXW\n59d0KKDnJIxAIvD5A5dEEbzGBAanP0vZ60CPkH3F9GqPtv5BWwRKqZDGFWWRFGJPYycJQMd+cufa\n8JyOLQJnu8tINLpMMW1s9QW6YQ42tvHdf67kfXsTHMEaI4D2mkc1je3jDR67VeB0GwW21/wsicC5\nYkzwjwFJWwRKqW794KwpjC3IIDcjmU+2Hwh7bIKdFJy6RaPy0imrae7R6zW79Lu/vXEfl/zlE569\n9nP8Y8lunltRxordNYDVDXWwyVrNvL/BSgTONp1t9vhAm9ffpcqpz2UP5kg5yW8wrFzWRKCU6ta1\nJ4wPXD9mfGHYY51EMKM4l8eunEtmahIXP/jxZ47BSUAvr94b2A5zx35rtXJL0D4G1Y0dxyJavX48\nPoPHZwLdVu2JoP245jYfrV5fRFt4QnuLIFBEbwC3CLRrSCl1SOSmW/30Ti+RCJw0ZSgFme399z1Z\nrdyZs55hW1VDl2J4e2vaxx86fyA3BrUMnNic1c3GGFo8Prw+P/cs2MhljyyOPCD7uZyxC4PhpVXl\nLN0ZvsXUH2mLQCl1SHz0g5PxGcOaPbX88Lk1gRlF6SntHzOHj8zh/S37e/X8++2Nb9aU1XZ5zK0r\nyREoNeHzk2mPazjrFfzGMOUnr3Hq1KEkJyZQ3oMuLKdFEEgEBr799Aqg53s2x5q2CJRSh0RmahI5\nacnMn1DIe98/KTCYnBG0IO3IoLLUPVVltwLCrUNw47QIPD5/oNvKWa/gs1sPb26opNnjc91RLRTn\nuUJNax1ItEWglIqq4NlFpWN7nwjcFpkBjMxNo7w29NTUBvuD2pj2zXBqmq0BZV9QLQtrjMCP1+cP\nOVMK4MI/f0hNk4ekxI6lLEINERhj+v3qY20RKKWiKrhW0RiX6pyZKV1LWPTEpOHZYR8PHjPw2qPD\nTtdQ8O5lzl7KblNXgy3fVcN2e5Aa2geL3TLBpop6Jv/4tUOy41o0aSJQSkVV8LfhEXlpPHDpHB6/\nqn33sYKs3g8gg7XwzZESpkAetO+T4HQvVTcE7bNsf7NviLCrR+jcIuiaCXZWN9Lm83+mRPD6ugou\nfvDjqBa100SglOozqUmJnDl9OCdObi9bMSTTWm/g1loI55J5owGYOLQ9EUwoCr+rmtfuCnLm/geX\nyXamoDZEuPBNOs0a8vjaP6jLa5r5aNv+QKG8cIPZ3fnmU8tZvONAVLfc1ESglIqpIfaU0rFDuiaC\nnLSOw5jB009/fM40tvz8LIbYLYrCrNRAMnHbFxm6LiA7EJQInA/tngwYB58XvLDsjN++x1f+siiQ\nAD5LIrCXPrhu/HOoaCJQSkXd
cRMLA9/gHZ8bNwQg8EE+NNuqTZSSmMDUETnkZSRTnG99sDtrFEbn\nt5etyEhJJDloUPewwgyG56Z1eK7Ogr+1Q8fE4HxYu9U5cuN0eTnntQZ92Durmp3WwmfZC0E6rYaO\nBp01pJSKur99/agu9z1yZSmfVjfx/IoyALLtb/9tPj8vXj8fn9/wyAc7WL+3jsnDslm88wBDc9IA\nax2B8wF5+MgcUhITuOWMKSzfdRDoXQE4p2ppqETw+Ic7OH5SUeB253UEbuWonT0RPkup6kBZjIHY\nNSQij4pIpYi47t8mIvki8ryIrBaRxSIyPVqxKKX6n4yUJKaOyAlULk0N2vcgOTGBtOREvnXieN74\nzvGcOMX6AE50mYY5NDuNzT8/i3mHFTA8x2oJOB/qPeF80Lp1DTW2ern9pfX8x05aEDxGYB3v1v1z\n0O566m4vBTf1LR4aW70DvmvoceDMMI/fCqw0xswALgd+H8VYlFL91LCc9i6hi+YUc99Xjgw8JiJM\nHJbNqDyrSyh4uqebo+3upiuOKel1PG4tAmcxW33QY06vkvMh77bO4YA9O6k3YwRH3P46c+58I7Bw\nbUB2DRlj3hORkjCHTAPuso/dKCIlIjLMGLMvWjEppfqfC2eP4mBTG5d/rqTD4rNgziY3dS0eTp4y\nlORE9wVaw3PT2Hn3ObR5/fy4l5vJB3+g76tr4XdvbuHsI4YDHVc1d95O0+2D2mkRNHt83P7iOs4+\nYgTzDiuIOJYWj5+UNOv7ejSrnMZyjGAVcCHwgYjMA8YCxUCXRCAi1wDXAIwZM6YvY1RKRVlSYgL/\nG1Td1I2zR3Jts4f/3nBct8/Z3XqCcBpavby7uYrmNh/3LdzC2rK6QOKpdpllFI4zK6m+xcvfF+0i\nNSmhQyKoqm+lodXLYYWZIZ8jUDE1imMEsUwEdwO/F5GVwBpgBeA6SmOMeQh4CKC0tHQAF3tVSvVG\ncX46R48r4IZTJh6y50xNSnD9lr1wYyX3v7Otw31lB61idAeDEkFjBAvPnF3S9ttdS53HH+5+dSNr\ny2pZ8J3jO7z+nJL2UhxO11Bvxj0iFbNEYIypA64CEGv4f4d9UUqpDpISE/jHNZ/r0Tl/vORIdh1o\n4lcLNrk+XjIkk0376rvcv7Gi633OvgfB6w4imWbqHF9lV07tfE5FXTPVjW2sLavlR8+v4d4vz+Kq\nx5dwStA+0TKQZw11R0TyRMRZHXI18J6dHJRS6jM7b+ZIzp81ssv9+fb+xsH7HHdnl10iInjTm85r\nEtw4LY6qQIvAh99vOPKO13nkgx0caLRmBt3x8npW7anlk+3VgLXngsNZKjEgZw2JyNPAx8BkEdkj\nIl8XkWtF5Fr7kKnAOhHZCJwF3BitWJRS8cmZkRTM6aN3vqUHC1Uk1GtPEWrpZffM/qAWwdryWg42\nebjntY3UNLXR7PEFWgrJLlVPB/qsoUu6efxj4NB1+CmlVCduH6ylYwtYsG5fYLVysIKMlA4DwoeK\nk0Aa27y8ud6aD3P0uCEs2mG1AJwuJGefhOBCfYFE4Ov9orTuaIkJpVRcyc9M4YmvzeOBS+d0eaww\nq2uNopQwexP0VEOrl9X2DmuG9gThdB05SSi4YSJ9sKBMS0wopeJKSlJCoFTERXOKqahrYX9DGxv2\n1gXGDbJSkwIzfAqzUsJufNMTTa2+wLTT4G0xna4nZ1ZS8OhDYNaQJgKllOqdMw4fxoJ17cuTUoIW\no/3qopmAtXDs2WV7qKpvZdGOA4zITWNLpTVgm5WWBLXWBjqhdkmLVGOrN7ABjjMlNZjTRRRcDG+g\nl5hQSqmY+9NXZrPhjvZqN26LzYblpHHdSRPITLVWNmcHlb/OsmshjRkSetFXpBrbvIENcFxrE9nr\nDrwuU0Wj2SLQRKCUGtSSEhNIT0lkir2TWUKY/YMz7Q9952dWahLZaVZ30dgebpzjCE4qftO+X7
Ib\nZ4zAE9QicKapaotAKaU+I2dcIFwlUOfbf0ZKIotvPYUP/u8kq2sIGFvYu0RwctDiMOi4PWZnTtdQ\ncB0jp4T1YC0xoZRSfebm0ycxKi+d06cNC3lMZkp7i2CovQbBKZM9tqB3XUPnzxrJCyvLA7e9/tAL\n0Zyidp6gb/9O4tIWgVJKfUapSYlccUwJSWGmgwa6hlLavyPnpieTnCjkpFv3zRmb3+W8/AzrmOCt\nNB1OaWw3TkG5zoJXLTtjCZoIlFKqDwS6hlLby2FfcUwJD11eyrETCjnniBH88ZL2/RKcD/K5JQVs\nvvMsnr22az2kjJSuHS/OecODVj4XZbevYXDrBmr1Rm9BmXYNKaWUzZk1FNwiGJmXzkh7Y5w/fXV2\nh+MnDs1iY0U9NU0eRKTDjKTfXDSTUruK6B8uOZJ1ZbU8+N52APIzUtjf0Ephdipl9nqCnLSkwMIy\nN9oiUEqpPpDVadZQdyYOs2Yi7T5oFaULLltxwuQixtpTTj8/cyT/M699L5WCTOu4/KDCd24th2A6\nWKyUUn1gSFYqyYnCiNyuxeqCFWalsL+hjYlDswDYa688zk5LZvGtp5CTnkxacsfd1oKnkeZnpAR+\n/uyC6YwvyuR3b24J+5raIlBKqT5QkJnCO7ecxBmHDw973PRRuQCBvZSDDc1J65IEoGMicAaV8zNS\nuOzosRwzvpB0l3OCaYkJpZTqI24f7p1NH5nLO5uq2H2wie+fOZkZo/K6PSc1qf2DPj+QCIK7hsIn\nAi06p5RS/cjXjz2MjRV1XFw6OjCQ3BMFdtdQXtB00/RuEoFbSYpDRROBUkr1UH5mCg9fMbfX5zvJ\nY2TQWMTo/PArl+tbut8as7d0jEAppfrY5OFZvHT9sZw0ub38xLUnjOewwkzGFbqvYHY2rYkGTQRK\nKdXHUpMSOaI4l4SglcXpKYm8ffMJ/PjcqV2Oz0hJpK7FgzHd75PcG5oIlFKqj7nNKgJri0q3XdIK\nMlPw+EzUZg5pIlBKqT7iNADSkkN/9A7N7rqGYYg9qByt7iFNBEop1UeS7YJ3oVoEAEOyuhauc9Yd\n1LVoIlBKqQHtMHsgONVllzRHskt1VGfdQW1zdGYO6fRRpZTqI098bR4fb68O7HoWKWfdgbYIlFJq\ngBuak8b5s0ZFfLwzNlBgdxdFay2BJgKllOqnRuRZA8c6WKyUUnHmhlMmkpOWxIhcawVyUoL1Ua1d\nQ0opFSe+e9okVt9+BiVDrLITLV4fyYlCnQ4WK6VUfLnp1EkAfHF2MaPy0hlTEL4eUW9pIlBKqX4q\nMzWJH50zDYATg+oSHWraNaSUUnFOE4FSSsU5TQRKKRXnopYIRORREakUkbUhHs8VkZdEZJWIrBOR\nq6IVi1JKqdCi2SJ4HDgzzOPXAeuNMTOBE4HfiEjXaktKKaWiKmqJwBjzHnAg3CFAtogIkGUfG729\n2JRSSrmK5RjBfcBUoBxYA9xojHHddUFErhGRpSKytKqqqi9jVEqpQS+WieAMYCUwEpgF3CciOW4H\nGmMeMsaUGmNKi4qK+i5CpZSKA7FcUHYVcLexNuHcKiI7gCnA4nAnLVu2bL+IfNrL1ywE9vfy3Gjr\nr7FpXD2jcfWMxtVzvY1tbKgHYpkIdgGnAO+LyDBgMrC9u5OMMb1uEojIUmNMaW/Pj6b+GpvG1TMa\nV89oXD0XjdiilghE5Gms2UCFIrIHuA1IBjDGPAD8DHhcRNYAAvyfMaa/ZmCllBq0opYIjDGXdPN4\nOXB6tF5fKaVUZOJtZfFDsQ4gjP4am8bVMxpXz2hcPXfIYxNrrFYppVS8ircWgVJKqU40ESilVJyL\nm0QgImeKyCYR2SoiP4hxLDtFZI2IrBSRpfZ9BSLyhohssX/m90EcXQoDhotDRH5ov3+bROSMPo7r\ndhEps9+zlSJydgziGi0iC0Vkg10o8Ub7/pi+Z2Hiiul7Ji
JpIrI4qLDk/7Pv7w//x0LF1h/+nyWK\nyAoRedm+Hf33yxgz6C9AIrANGAekAKuAaTGMZydQ2Om+e4Af2Nd/APyyD+I4HpgNrO0uDmCa/b6l\nAofZ72diH8Z1O/A9l2P7Mq4RwGz7ejaw2X79mL5nYeKK6XuGNS08y76eDCwCjo71+9VNbP3h/9l3\ngb8DL9u3o/5+xUuLYB6w1Riz3RjTBvwDOD/GMXV2PvBX+/pfgQui/YLGvTBgqDjOB/5hjGk1xuwA\ntmK9r30VVyh9GddeY8xy+3o9sAEYRYzfszBxhdJXcRljTIN9M9m+GPrH/7FQsYXSJ7GJSDFwDvBw\np9eO6vsVL4lgFLA76PYewv+hRJsBXheRZSJyjX3fMGPMXrD+sIHobVAaXqg4+sN7eL2IrLa7jpzm\ncUziEpES4Eisb5L95j3rFBfE+D2zuzlWApXAG8aYfvN+hYgNYvue/Q74PhBcgDPq71e8JAJxuS+W\n82bnG2NmA2cB14nI8TGMJVKxfg/vB8ZjFSjcC/zGvr/P4xKRLODfwE3GmLpwh7rcF7XYXOKK+Xtm\njPEZY2YBxcA8EZke5vA+fb9CxBaz90xEzgUqjTHLIj3F5b5exRQviWAPMDrodjFW+euYMNaqaowx\nlcDzWM25fSIyAsD+WRmj8ELFEdP30Bizz/7D9QN/ob0J3KdxiUgy1oftU8aY5+y7Y/6eucXVX94z\nO5Ya4B2szapi/n6Fii3G79l84PMishOr+/pkEXmSPni/4iURLAEmishhYu2C9j/Ai7EIREQyRSTb\nuY5VZmOtHc8V9mFXAC/EIr4wcbwI/I+IpIrIYcBEuqkUeyg5fwi2L2C9Z30al4gI8AiwwRhzb9BD\nMX3PQsUV6/dMRIpEJM++ng6cCmykH/wfCxVbLN8zY8wPjTHFxpgSrM+ot40xl9IX71c0Rr374wU4\nG2s2xTbgRzGMYxzWSP8qYJ0TCzAEeAvYYv8s6INYnsZq/nqwvl18PVwcwI/s928TcFYfx/U3rA2M\nVtt/ACNiENexWE3v1Vh7aay0/1/F9D0LE1dM3zNgBrDCfv21wE+7+7/eh/+WoWKL+f8z+7VOpH3W\nUNTfLy0xoZRScS5euoaUUkqFoIlAKaXinCYCpZSKc5oIlFIqzmkiUEqpOKeJQMUdEZkVXFWyB+fd\nISKn2tdvEpGMQxjTBSIyze21lIo2nT6q4o6IXAmUGmOu/wzPsdN+jv09OCfRGOML8djjWPPGn+1t\nTEr1lrYI1IAkIiUislFEHhaRtSLylIicKiIf2nXb59mruB8VkSV2fffz7ZXldwBftuvNf9k+9iP7\nmI9EZHKI13xcRL4kIjcAI4GFIrLQfux0EflYRJaLyDN23R9n74mfisgHwEUi8g07nlUi8m8RyRCR\nY4DPA7+yYxrvvJb9HKfYsa2xf5/UoOf+f/ZrrhGRKfb9J0h7Pf0Vzkp2pUKK5uo4veglWhegBPAC\nR2B9oVkGPIpViOt84D/AL4BL7ePzsFaWZwJXAvcFPVcOkGRfPxX4d4jXfBz4kn19J/aeEkAh8B6Q\nad/+P9pXqu4Evh/0HEOCrt8JfLvzcwffBtKwKkxOsu9/AquonPPczvnfAh62r7+EVdgQIMv53fSi\nl1CXpF7kDqX6ix3GmDUAIrIOeMsYY0RkDVaiKMYq4vU9+/g0YIzL8+QCfxWRiVilGpJ7GMfRWJuE\nfGiV/SEF+Djo8X8GXZ8uIndiJaYsYEE3zz0Z6/fcbN/+K3AdVrliAKfw3TLgQvv6h8C9IvIU8Jwx\nZk8Pfx8VZzQRqIGsNei6P+i2H+v/tg/4ojFmU/BJInJUp+f5GbDQGPMFser5v2Mf9xhWbf9yY0y4\nwWXBqmd/SYjHG4OuPw5cYIxZZY9VnBjmeZ3nDsf5nX3Yf8/GmLtF5L9Y9YY+EZFTjTEbu3keFcd0\njEANZguAb9vVORGRI+
3767G2dHTkAmX29SudO40xVxljZoVIAsHP8QkwX0Qm2K+TISKTQsSUDewV\nq2z0V0M8X7CNQInz3MBlwLshnhv79ccbY9YYY34JLAWmhDteKU0EajD7GVY3z2oRWWvfBlgITHMG\ni7H2hL1LRD7E2t86Eg8Br4rIQmNMFVYCeVpEVmMlhlAfvj/B2j3sDawPecc/gFvswd3xzp3GmBbg\nKuAZu8vLDzzQTWw32QPoq4Bm4NUIfycVp3T6qFJKxTltESilVJzTRKCUUnFOE4FSSsU5TQRKKRXn\nNBEopVSc00SglFJxThOBUkrFuf8Phkl86nR/vksAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(meta_losses)\n", "plt.xlabel(\"meta-iterations\")\n", "plt.ylabel(\"meta-loss\")" ] }, { "cell_type": "markdown", "metadata": { "id": "-kz0dpBfJi5r" }, "source": [ "## Exercises\n", "For those interested in getting their feet wet, fork this notebook and try to implement any of the following!\n", "\n", "* Modify the meta-loss so that meta-training targets a validation loss rather than the training loss.\n", "* Add other features to the learned optimizer, such as rolling second-moment features (as used in Adam) or momentum accumulators at different timescales.\n", "* Replace the per-parameter MLP with a per-parameter RNN." ] }, { "cell_type": "markdown", "metadata": { "id": "H39tBazN8FP1" }, "source": [ "## Conclusion and relations to the `learned_optimization` package\n", "We hope this notebook has given you a useful introduction to learned optimizers.\n", "This is a deliberately minimal implementation, and as a result it is suboptimal in a number of ways.\n", "The `learned_optimization` library is built on the patterns used above, but extends them to be more general, more scalable, and more fully featured.\n", "\n", "The core design of how these components are implemented in JAX remains largely the same.\n", "Below we outline a few of the main differences.\n", "\n", "\n", "### PyTree all the things\n", "In the examples above, we used tuples and lists to store parameters and optimizer states. This is simple, but becomes unwieldy with more complex structures.\n", "In `learned_optimization`, every piece of data is stored as some kind of JAX pytree, often a dataclass registered as a pytree. Working with these structures relies on JAX's [pytree utilities](https://jax.readthedocs.io/en/latest/pytrees.html).\n", "\n", "### Tasks\n", "The task interface is quite similar to what we have shown here. There is one additional layer of abstraction in `learned_optimization`, namely `TaskFamily`. 
In this example, we meta-train on multiple tasks in parallel; the only difference between these tasks is their random initialization. A `TaskFamily` instead lets one vectorize over other aspects of the problem. Common examples include vectorizing over different kinds of initializations or different activation functions.\n", "\n", "### Optimizers\n", "The optimizer interface is also essentially the same. The main difference is that `learned_optimization` optimizers can accept additional arguments such as `num_steps` (the target number of steps, used for learning rate schedules and related features), a `jax.random.PRNGKey` for stochastic optimizers, and loss values.\n", "\n", "### LearnedOptimizers\n", "In this notebook, learned optimizers and optimizers have different signatures.\n", "In `learned_optimization`, a `LearnedOptimizer` contains a function of the meta-parameters that itself returns an instance of an `Optimizer`. For example, the update can be called as `lopt.opt_fn(meta_params).update(opt_state, grads)`.\n", "\n", "The learned optimizer implemented in this notebook is designed to be simple and easy to follow rather than performant or easy to meta-train. `learned_optimization` comes packaged with a number of learned optimizer implementations that perform much better.\n", "\n", "### Meta-training\n", "This is the biggest divergence. Meta-training algorithms are implemented as subclasses of `GradientEstimator`. Internally, these operate like the truncated training above in that they store state that is passed from iteration to iteration, but they are much more general.\n", "They implement two functions: one to initialize the state of the inner problems, and one to perform updates. This mirrors the two functions we needed to write for truncated training.\n", "When applying the meta-gradient updates, we make use of a `GradientLearner` class, which can be run on either a single machine or multiple machines."
] } ], "metadata": { "colab": { "collapsed_sections": [], "last_runtime": { "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook", "kind": "private" }, "name": "No dependency learned optimizer", "provenance": [ { "file_id": "1uJ2meEAIPuGPOQy4x7GmR__kJ4RG271y", "timestamp": 1644459557191 } ], "toc_visible": true }, "jupytext": { "formats": "ipynb,md:myst,py", "main_language": "python" }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }