Agent

Agent

polyaxon.estimators.agents.Agent(model_fn, memory, model_dir=None, config=None, params=None)

Agent is a trainer/evaluator for reinforcement learning Q-models.

Constructs an Agent instance.

  • Args:

    • model_fn: Model function. Follows the signature:

      • Args:
      • features: A single Tensor or dict of Tensors (depending on the data provided during training).
      • labels: Tensor or dict of Tensors (for multi-head models). If mode is Modes.PREDICT, labels=None will be passed. If the model_fn's signature does not accept mode, the model_fn must still be able to handle labels=None.
      • mode: Specifies whether this is training, evaluation, or prediction. See Modes.
      • params: Optional dict of hyperparameters. Receives whatever is passed to the Estimator in the params argument. This allows configuring Estimators from hyperparameter tuning.
      • config: Optional configuration object. Receives whatever is passed to the Estimator in the config argument, or the default config. Allows updating things in your model_fn based on configuration, such as num_ps_replicas.
      • model_dir: Optional directory where model parameters, graph, etc. are saved. Receives whatever is passed to the Estimator in the model_dir argument, or the default model_dir. Allows updating things in your model_fn that expect model_dir, such as training hooks.

      • Returns: EstimatorSpec

      Supports the following four signatures for the function:

      • (features, labels, mode)
      • (features, labels, mode, params)
      • (features, labels, mode, params, config)
      • (features, labels, mode, params, config, model_dir)
    • memory: An instance of a subclass of Memory.

    • model_dir: Directory to save model parameters, graph, etc. This can also be used to load checkpoints from the directory into an estimator to continue training a previously saved model.
    • config: Configuration object.
    • params: dict of hyperparameters that will be passed into model_fn. Keys are parameter names, values are basic Python types.
  • Raises:

    • ValueError: If the parameters of model_fn don't match params.
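The signature rules for model_fn can be illustrated with a small dispatcher. This is a minimal sketch, not Polyaxon's implementation: `call_model_fn` is a hypothetical helper that uses Python's standard `inspect.signature` to pass only the arguments a model_fn declares, and it raises ValueError when params are supplied to a model_fn that cannot receive them. The tuples returned by the example model functions are placeholders for an EstimatorSpec.

```python
import inspect


def call_model_fn(model_fn, features, labels, mode,
                  params=None, config=None, model_dir=None):
    """Hypothetical helper: call model_fn with only the arguments it declares."""
    accepted = set(inspect.signature(model_fn).parameters)

    # Params were supplied but model_fn has nowhere to receive them.
    if params is not None and "params" not in accepted:
        raise ValueError("model_fn does not accept params, but params were passed.")

    kwargs = {}
    if "mode" in accepted:
        kwargs["mode"] = mode
    # Forward the optional arguments only when the signature names them.
    for name, value in (("params", params), ("config", config), ("model_dir", model_dir)):
        if name in accepted:
            kwargs[name] = value
    return model_fn(features, labels, **kwargs)


# Placeholder model functions; real ones would build the graph and return an EstimatorSpec.
def simple_fn(features, labels, mode):
    return ("spec", mode)


def tuned_fn(features, labels, mode, params):
    return ("spec", params["learning_rate"])


print(call_model_fn(simple_fn, {}, None, "train"))  # -> ('spec', 'train')
print(call_model_fn(tuned_fn, {}, None, "train", params={"learning_rate": 0.1}))  # -> ('spec', 0.1)
```

Inspecting the signature up front lets a simple model function ignore configuration it does not need, while a params mismatch is still caught before any training starts.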

train

train(self, env, first_update=35, update_frequency=10, episodes=None, steps=None, hooks=None, max_steps=None, max_episodes=None)

Trains a model given an environment.

  • Args:

    • env: Environment instance.
    • first_update: int. The first timestep at which to compute the loss and run the train_op for the model.
    • update_frequency: int. How often, in timesteps, to compute the loss and run the train_op.
    • episodes: Number of episodes for which to train the model. If None, train forever. If set, max_episodes must be None.
    • steps: Number of steps for which to train the model. If None, train forever. 'steps' works incrementally: calling train(steps=10) twice trains for 20 steps in total. If you don't want incremental behaviour, set max_steps instead. If set, max_steps must be None.
    • hooks: List of BaseMonitor subclass instances. Used for callbacks inside the training loop.
    • max_steps: Number of total steps for which to train the model. If None, train forever. If set, steps must be None.
    • max_episodes: Number of total episodes for which to train the model. If None, train forever. If set, episodes must be None.

    Two calls to train(steps=100) mean 200 training iterations. On the other hand, two calls to train(max_steps=100) mean that the second call does no iterations, since the first call already did all 100 steps.

  • Returns: self, for chaining.
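The difference between steps and max_steps can be shown with a counter-only stand-in for train(). This is a sketch, not the Agent's implementation: `StepBudgetDemo` is hypothetical and runs no model, and raising ValueError when both bounds are given is an assumption based on the rule that only one of them may be set.

```python
class StepBudgetDemo:
    """Hypothetical stand-in that only tracks the global step across train() calls."""

    def __init__(self):
        self.global_step = 0

    def train(self, steps=None, max_steps=None):
        if steps is not None and max_steps is not None:
            # Assumption: mirrors the documented rule that only one bound may be set.
            raise ValueError("Set either steps or max_steps, not both.")
        if steps is not None:
            target = self.global_step + steps  # incremental: `steps` more iterations
        elif max_steps is not None:
            target = max_steps  # absolute: stop once the global step reaches max_steps
        else:
            # The real trainer would run forever here; the demo refuses instead.
            raise ValueError("This demo needs steps or max_steps.")
        while self.global_step < target:
            self.global_step += 1  # one training iteration
        return self  # returned for chaining, as documented for train()


print(StepBudgetDemo().train(steps=100).train(steps=100).global_step)  # -> 200
print(StepBudgetDemo().train(max_steps=100).train(max_steps=100).global_step)  # -> 100
```

Because train() returns self, the chained calls above share one global step, which is exactly what makes the incremental and absolute bounds diverge.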


run_episode

run_episode(self, env, sess, features, labels, no_run_hooks, global_step, update_episode_op, update_timestep_op, first_update, update_frequency, estimator_spec)

We need to differentiate between the global_timestep and global_step.

The global_step is incremented directly by the train_op and affects the learning rate, especially when the learning rate is decayed.

The global_timestep, on the other hand, is tied to the episode and counts how many times the agent has acted. It affects the exploration rate and how it is annealed.
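The two counters can be made concrete with two schedule functions. This is a sketch under assumptions: exponential learning-rate decay and linear epsilon annealing are common choices picked here for illustration, and `decayed_learning_rate` and `annealed_epsilon` are hypothetical names, not Polyaxon functions. The point is only which counter each schedule reads.

```python
def decayed_learning_rate(initial_lr, global_step, decay_rate, decay_steps):
    """Exponential decay keyed on global_step, which advances once per train_op run."""
    return initial_lr * decay_rate ** (global_step / decay_steps)


def annealed_epsilon(global_timestep, start=1.0, end=0.1, anneal_timesteps=1000):
    """Linear exploration annealing keyed on global_timestep (one tick per agent action)."""
    fraction = min(global_timestep / anneal_timesteps, 1.0)
    # Interpolate from `start` down to `end`, then hold at `end`.
    return end + (1.0 - fraction) * (start - end)


# The same moment in training seen through both counters (illustrative numbers only):
# the agent has acted 1000 times, but the train_op has run far fewer times.
print(annealed_epsilon(global_timestep=1000))
print(decayed_learning_rate(0.1, global_step=100, decay_rate=0.5, decay_steps=100))
```

Keying exploration on actions rather than on optimizer updates keeps the exploration schedule independent of first_update and update_frequency.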

  • Args:

    • env: Environment instance.
    • sess: MonitoredTrainingSession instance.
    • first_update: The first timestep at which to run the train_op and compute the model loss.
    • update_frequency: How often to compute the model loss and run the train_op.
    • estimator_spec: EstimatorSpec instance.
  • Returns: Statistics about the episode.
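How first_update and update_frequency gate the train_op, and why the two counters drift apart, can be sketched as a counter-only loop. This assumes the train_op runs at timestep first_update and then every update_frequency timesteps after it; `run_episode_sketch` is hypothetical, omits the env, session and hooks, and performs no learning.

```python
def run_episode_sketch(num_actions, global_timestep, global_step,
                       first_update=35, update_frequency=10):
    """Hypothetical counter bookkeeping for one episode; no environment or learning."""
    updates = 0
    for _ in range(num_actions):
        global_timestep += 1  # the agent acted once: drives exploration annealing
        t = global_timestep
        # Assumed schedule: first update at `first_update`, then every `update_frequency` timesteps.
        if t >= first_update and (t - first_update) % update_frequency == 0:
            global_step += 1  # the train_op ran: drives learning-rate decay
            updates += 1
    return {"global_timestep": global_timestep, "global_step": global_step, "updates": updates}


stats = run_episode_sketch(num_actions=100, global_timestep=0, global_step=0)
print(stats)  # -> {'global_timestep': 100, 'global_step': 7, 'updates': 7}
# Counters carry over into the next episode.
stats = run_episode_sketch(50, stats["global_timestep"], stats["global_step"])
print(stats)  # -> {'global_timestep': 150, 'global_step': 12, 'updates': 5}
```

With the train() defaults (first_update=35, update_frequency=10), a 100-action episode advances global_timestep by 100 but global_step by only 7 under this assumed schedule.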