polyaxon.estimators.agents.TRPOAgent(model_fn, memory, optimizer_params=None, model_dir=None, config=None, params=None)
The TRPOAgent class is the reinforcement learning Trust Region Policy Optimization (TRPO) model trainer/evaluator.
- model_fn: Model function. Follows the signature:
    - Args:
        - features: single Tensor or dict of Tensors (depending on data passed to fit).
        - labels: Tensor or dict of Tensors (for multi-head models). If mode is Modes.PREDICT, labels=None will be passed. If the model_fn's signature does not accept mode, the model_fn must still be able to handle labels=None.
        - mode: Specifies if this is training, evaluation or prediction. See Modes.
        - params: Optional dict of hyperparameters. Will receive what is passed to Estimator in the params parameter. This allows configuring Estimators from hyperparameter tuning.
        - config: Optional configuration object. Will receive what is passed to Estimator in the config parameter, or the default config. Allows updating things in your model_fn based on configuration, such as num_ps_replicas.
        - model_dir: Optional directory where model parameters, graph, etc. are saved. Will receive what is passed to Estimator in the model_dir parameter, or the default model_dir. Allows updating things in your model_fn that expect model_dir, such as training hooks.
    - Supports the following signatures for the function (argument dispatch is illustrated in the sketch below):
        - (features, labels, mode)
        - (features, labels, mode, params)
        - (features, labels, mode, params, config)
        - (features, labels, mode, params, config, model_dir)
- memory: An instance of a subclass of Memory.
- model_dir: Directory to save model parameters, graph, etc. This can also be used to load checkpoints from the directory into an estimator to continue training a previously saved model.
- config: Configuration object.
- params: dict of hyperparameters that will be passed into model_fn. Keys are names of parameters, values are basic Python types.
Raises:
- ValueError: if the parameters of model_fn don't match params.
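Since several model_fn signatures are supported, the estimator has to inspect the function and pass only the arguments it declares. The sketch below shows one library-agnostic way such dispatch can work in plain Python; call_model_fn is a hypothetical helper written for this page, not Polyaxon's actual internals.

```python
import inspect

def call_model_fn(model_fn, features, labels, mode,
                  params=None, config=None, model_dir=None):
    """Call model_fn with only the arguments its signature declares
    (hypothetical helper, for illustration only)."""
    available = {'features': features, 'labels': labels, 'mode': mode,
                 'params': params, 'config': config, 'model_dir': model_dir}
    accepted = inspect.signature(model_fn).parameters
    return model_fn(**{name: value for name, value in available.items()
                       if name in accepted})

# A model_fn using the (features, labels, mode, params) signature.
def model_fn(features, labels, mode, params):
    return mode, params

print(call_model_fn(model_fn, features=[1.0], labels=None, mode='train',
                    params={'num_actions': 2}))
# -> ('train', {'num_actions': 2})
```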
train(self, env, episodes=None, steps=None, hooks=None, max_steps=None, max_episodes=None)
Trains a model given an environment.
- steps: Number of steps for which to train the model. If None, train forever. steps works incrementally: if you call fit(steps=10) twice, training occurs for 20 steps in total. If you don't want this incremental behaviour, set max_steps instead. If set, max_steps must be None.
- hooks: List of BaseMonitor subclass instances. Used for callbacks inside the training loop.
- max_steps: Number of total steps for which to train the model. If None, train forever. If set, steps must be None.
- max_episodes: Number of total episodes for which to train the model. If None, train forever. If set, episodes must be None.
Two calls to fit(steps=100) mean 200 training iterations. On the other hand, two calls to fit(max_steps=100) mean that the second call will not do any iterations, since the first call did all 100 steps. The sketch below illustrates this difference.
Returns: self, for chaining.
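The incremental behaviour of steps versus the absolute ceiling of max_steps can be reproduced with a small standalone counter. The Trainer class below is a toy written for this page, not the library's estimator.

```python
class Trainer:
    """Toy trainer mimicking the documented steps/max_steps semantics."""

    def __init__(self):
        self.global_step = 0

    def fit(self, steps=None, max_steps=None):
        if (steps is None) == (max_steps is None):
            raise ValueError('Set exactly one of steps or max_steps.')
        # steps is incremental; max_steps is an absolute ceiling.
        target = self.global_step + steps if steps is not None else max_steps
        while self.global_step < target:
            self.global_step += 1  # one training iteration
        return self  # for chaining, as train() does

t = Trainer()
t.fit(steps=100).fit(steps=100)
print(t.global_step)  # 200: steps accumulates across calls

t = Trainer()
t.fit(max_steps=100).fit(max_steps=100)
print(t.global_step)  # 100: the second call is a no-op
```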
run_episode(self, env, sess, features, labels, no_run_hooks, global_step, update_episode_op, update_timestep_op, estimator_spec)
We need to differentiate between the global_step and the global_timestep. The global_step gets updated directly by the train_op and has an effect on the training learning rate, especially if it gets decayed. The global_timestep, on the other hand, is related to the episode and how many times our agent acted; it has an effect on the exploration rate and how it is annealed (see the sketch below).
Returns: statistics about the episode.
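As a rough, framework-free illustration of why the two counters are kept separate, the snippet below decays a learning rate on global_step while annealing exploration on global_timestep. The schedules and constants are arbitrary assumptions, not the library's defaults.

```python
def learning_rate(global_step, base_lr=0.1, decay=0.99):
    # The learning rate decays with optimizer updates (global_step).
    return base_lr * decay ** global_step

def exploration_rate(global_timestep, start=1.0, end=0.01, anneal_steps=10000):
    # Exploration anneals with the number of actions taken (global_timestep).
    fraction = min(global_timestep / anneal_steps, 1.0)
    return start + fraction * (end - start)

global_step = global_timestep = 0
for episode in range(3):
    for _ in range(200):      # the agent acts: only the timestep advances
        global_timestep += 1
    global_step += 1          # one train_op update per episode in this toy
    print(episode, learning_rate(global_step), exploration_rate(global_timestep))
```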