polyaxon.estimators.agents.PGAgent(model_fn, memory, optimizer_params=None, model_dir=None, config=None, params=None)
The `PGAgent` class is the basic reinforcement learning policy-gradient model trainer/evaluator.
Constructs a `PGAgent` instance.

- __Args__:
    - __model_fn__: Model function. Follows the signature:
        * Args:
            * `features`: single `Tensor` or `dict` of `Tensor`s (depending on data passed to `fit`),
            * `labels`: `Tensor` or `dict` of `Tensor`s (for multi-head models). If mode is `Modes.PREDICT`, `labels=None` will be passed. If the `model_fn`'s signature does not accept `mode`, the `model_fn` must still be able to handle `labels=None`.
            * `mode`: Specifies if this is training, evaluation, or prediction. See `Modes`.
            * `params`: Optional `dict` of hyperparameters. Will receive what is passed to the Estimator in the `params` parameter. This allows configuring Estimators from hyperparameter tuning.
            * `config`: Optional configuration object. Will receive what is passed to the Estimator in the `config` parameter, or the default `config`. Allows updating things in your `model_fn` based on configuration, such as `num_ps_replicas`.
            * `model_dir`: Optional directory where model parameters, graph, etc. are saved. Will receive what is passed to the Estimator in the `model_dir` parameter, or the default `model_dir`. Allows updating things in your `model_fn` that expect `model_dir`, such as training hooks.
        * Returns: `EstimatorSpec`

        Supports the following signatures for the function:

        * `(features, labels, mode)`
        * `(features, labels, mode, params)`
        * `(features, labels, mode, params, config)`
        * `(features, labels, mode, params, config, model_dir)`
    - __memory__: An instance of a subclass of `BatchMemory`.
    - __model_dir__: Directory to save model parameters, graph, etc. This can also be used to load checkpoints from the directory into an estimator to continue training a previously saved model.
    - __config__: Configuration object.
    - __params__: `dict` of hyperparameters that will be passed into `model_fn`. Keys are names of parameters, values are basic Python types.
- __Raises__:
    - __ValueError__: if parameters of `model_fn` don't match `params`.
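The four supported signatures differ only in how many optional arguments the estimator passes in. As a minimal sketch (not Polyaxon's actual dispatch code; `call_model_fn` and `my_model_fn` are hypothetical names), a caller can inspect a `model_fn` and forward only the arguments its signature accepts:

```python
import inspect

def call_model_fn(model_fn, features, labels, mode,
                  params=None, config=None, model_dir=None):
    """Call model_fn with only the arguments its signature accepts,
    mirroring support for the four documented signatures."""
    accepted = inspect.signature(model_fn).parameters
    kwargs = {}
    if 'params' in accepted:
        kwargs['params'] = params
    if 'config' in accepted:
        kwargs['config'] = config
    if 'model_dir' in accepted:
        kwargs['model_dir'] = model_dir
    return model_fn(features, labels, mode, **kwargs)

# A model_fn using the (features, labels, mode, params) signature;
# a real model_fn would build the graph and return an `EstimatorSpec`.
def my_model_fn(features, labels, mode, params):
    return {'mode': mode, 'lr': params['learning_rate']}

spec = call_model_fn(my_model_fn, features={'x': [1.0]}, labels=None,
                     mode='train', params={'learning_rate': 0.01})
```

A `model_fn` declaring only `(features, labels, mode)` would be called without `params`, `config`, or `model_dir`.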
train(self, env, episodes=None, steps=None, hooks=None, max_steps=None, max_episodes=None)
Trains a model given an environment.
- __steps__: Number of steps for which to train the model. If `None`, train forever. `steps` works incrementally: if you call `fit(steps=10)` twice, training occurs for 20 steps in total. If you don't want this incremental behaviour, set `max_steps` instead. If set, `max_steps` must be `None`.
- __hooks__: List of `BaseMonitor` subclass instances. Used for callbacks inside the training loop.
- __max_steps__: Number of total steps for which to train the model. If `None`, train forever. If set, `steps` must be `None`.
- __max_episodes__: Number of total episodes for which to train the model. If `None`, train forever. If set, `episodes` must be `None`.

Two calls to `fit(steps=100)` means 200 training iterations. On the other hand, two calls to `fit(max_steps=100)` means that the second call will not do any iterations, since the first call did all 100 steps.

- __Returns__: `self`, for chaining.
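The incremental `steps` vs. absolute `max_steps` semantics can be sketched with a tiny stand-in class (`StepCounter` is hypothetical, not Polyaxon code; the real `train` also runs the environment and hooks):

```python
class StepCounter:
    """Minimal sketch of `steps` (incremental) vs `max_steps` (absolute)."""

    def __init__(self):
        self.global_step = 0

    def fit(self, steps=None, max_steps=None):
        if steps is not None and max_steps is not None:
            raise ValueError("Can not provide both `steps` and `max_steps`.")
        if steps is not None:
            target = self.global_step + steps  # incremental: relative to current step
        else:
            target = max_steps                 # absolute total step count
        while self.global_step < target:
            self.global_step += 1              # one training iteration
        return self                            # for chaining, like `train`

trainer = StepCounter()
trainer.fit(steps=100).fit(steps=100)
# global_step is now 200: each call adds 100 iterations

absolute = StepCounter()
absolute.fit(max_steps=100).fit(max_steps=100)
# second call does no iterations; global_step stays at 100
```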
run_episode(self, env, sess, features, labels, no_run_hooks, global_step, update_episode_op, update_timestep_op, estimator_spec)
We need to differentiate between the `global_step` and the `global_timestep`. The `global_step` gets updated directly by the `train_op` and has an effect on the training learning rate, especially if it gets decayed. The `global_timestep`, on the other hand, is related to the episode and how many times our agent acted. It has an effect on the exploration rate and how it's annealed.

- __Returns__: statistics about the episode.
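To illustrate the two counters' distinct roles, here is a sketch with illustrative schedules (exponential decay for the learning rate, linear annealing for exploration); the function names and schedule shapes are assumptions, not Polyaxon's implementation:

```python
def decayed_learning_rate(base_lr, global_step,
                          decay_rate=0.96, decay_steps=1000):
    """Learning rate driven by `global_step`, which the train_op increments."""
    return base_lr * decay_rate ** (global_step / decay_steps)

def annealed_exploration(global_timestep, start=1.0, end=0.1,
                         anneal_timesteps=10000):
    """Exploration rate driven by `global_timestep`: how often the agent acted.
    Linearly anneals from `start` to `end` over `anneal_timesteps`."""
    fraction = min(global_timestep / anneal_timesteps, 1.0)
    return start + fraction * (end - start)
```

Running many train steps per episode moves `global_step` (and the learning rate) quickly, while `global_timestep` (and the exploration rate) advances only as the agent acts in the environment.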