prioritized_replay_buffer.OutOfGraphPrioritizedReplayBuffer
Class OutOfGraphPrioritizedReplayBuffer
Inherits From: OutOfGraphReplayBuffer
An out-of-graph Replay Buffer for Prioritized Experience Replay.
See circular_replay_buffer.py for details.
Methods
__init__
__init__(
observation_shape,
stack_size,
replay_capacity,
batch_size,
update_horizon=1,
gamma=0.99,
max_sample_attempts=circular_replay_buffer.MAX_SAMPLE_ATTEMPTS,
extra_storage_types=None,
observation_dtype=np.uint8
)
Initializes OutOfGraphPrioritizedReplayBuffer.
Args:
observation_shape
: tuple or int. If int, the observation is assumed
to be a 2D square with sides equal to observation_shape.
stack_size
: int, number of frames to use in the state stack.
replay_capacity
: int, number of transitions to keep in memory.
batch_size
: int, the default number of transitions returned per sampled batch.
update_horizon
: int, length of the update (‘n’ in n-step update).
gamma
: float, the discount factor.
max_sample_attempts
: int, the maximum number of attempts allowed to
get a sample.
extra_storage_types
: list of ReplayElements defining the type of
the extra contents that will be stored and returned by
sample_transition_batch.
observation_dtype
: np.dtype, type of the observations. Defaults to
np.uint8 for Atari 2600.
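The update_horizon and gamma arguments together define an n-step return. A minimal sketch of that return, assuming the common convention that rewards are weighted by gamma**k and accumulation stops at the first terminal transition; n_step_return is a hypothetical helper, not a method of this class:

```python
def n_step_return(rewards, terminals, gamma=0.99, update_horizon=1):
    """Hypothetical helper: truncated n-step discounted return.

    rewards, terminals: sequences for consecutive transitions starting at
    the sampled index. Accumulation stops after the first terminal.
    """
    total = 0.0
    for k in range(min(update_horizon, len(rewards))):
        total += (gamma ** k) * rewards[k]
        if terminals[k]:
            # The episode ended here, so later rewards do not count.
            break
    return total
```

With update_horizon=1 this reduces to the immediate reward of the transition.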
add
add(
observation,
action,
reward,
terminal,
*args
)
Adds a transition to the replay memory.
This function checks the types and handles the padding at the beginning of an
episode. Then it calls the _add function.
Since the next_observation in the transition will be the observation added next,
there is no need to pass it.
If the replay memory is at capacity the oldest transition will be discarded.
Args:
observation
: np.array with shape observation_shape.
action
: int, the action in the transition.
reward
: float, the reward received in the transition.
terminal
: A uint8 acting as a boolean indicating whether the
transition was terminal (1) or not (0).
*args
: extra contents with shapes and dtypes according to
extra_storage_types, followed by the priority (see get_add_args_signature).
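The padding at the beginning of an episode exists so that early states still consist of stack_size frames. A minimal sketch of the idea, assuming zero-valued padding frames and stacking along the last axis; stacked_state is a hypothetical helper, not part of this class:

```python
import numpy as np


def stacked_state(episode_frames, t, stack_size):
    """Hypothetical helper: the stack_size most recent frames up to step t.

    Positions before the start of the episode are filled with zero frames,
    mirroring the padding added at the beginning of an episode.
    """
    frame_shape = episode_frames[0].shape
    dtype = episode_frames[0].dtype
    stack = []
    for i in range(t - stack_size + 1, t + 1):
        if i < 0:
            stack.append(np.zeros(frame_shape, dtype=dtype))
        else:
            stack.append(episode_frames[i])
    # Stack along the last axis, e.g. (84, 84) frames -> (84, 84, stack_size).
    return np.stack(stack, axis=-1)
```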
cursor
cursor()
Index to the location where the next transition will be written.
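The cursor behaves like the write position of a circular array, which is also why the oldest transition is discarded once the memory is at capacity. A minimal sketch under that assumption; TinyCircularBuffer and its add_count attribute are hypothetical names, and only a reward column is stored:

```python
import numpy as np


class TinyCircularBuffer:
    """Hypothetical illustration of a fixed-capacity circular store."""

    def __init__(self, replay_capacity):
        self._replay_capacity = replay_capacity
        self._rewards = np.zeros(replay_capacity, dtype=np.float32)
        self.add_count = 0  # Total number of transitions ever added.

    def cursor(self):
        # Index where the next transition will be written.
        return self.add_count % self._replay_capacity

    def is_full(self):
        return self.add_count >= self._replay_capacity

    def add(self, reward):
        # Once the buffer is full, this overwrites the oldest transition.
        self._rewards[self.cursor()] = reward
        self.add_count += 1
```

After four additions to a capacity-3 buffer, the cursor is back at index 1 and the first reward has been overwritten.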
get_add_args_signature
get_add_args_signature()
The signature of the add function.
The signature is the same as the one for OutOfGraphReplayBuffer, with an added
priority.
Returns:
list of ReplayElements defining the type of the argument signature needed by the
add function.
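The relationship between the two signatures can be sketched as list concatenation. This assumes a ReplayElement that carries a name, shape, and dtype; the namedtuple below is a stand-in, and the base-class elements and their dtypes are illustrative assumptions:

```python
import collections

import numpy as np

# Stand-in for ReplayElement: (name, shape, dtype). An assumption for this sketch.
ReplayElement = collections.namedtuple('ReplayElement', ['name', 'shape', 'type'])


def parent_add_signature(observation_shape=(84, 84), observation_dtype=np.uint8):
    """Hypothetical base-class signature: observation, action, reward, terminal."""
    return [
        ReplayElement('observation', observation_shape, observation_dtype),
        ReplayElement('action', (), np.int32),
        ReplayElement('reward', (), np.float32),
        ReplayElement('terminal', (), np.uint8),
    ]


def prioritized_add_signature():
    # Same as the parent signature, with a scalar priority appended.
    return parent_add_signature() + [ReplayElement('priority', (), np.float32)]
```

In this sketch the priority sits after the base elements, so a caller would supply it as the last positional argument.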
get_observation_stack
get_observation_stack(index)
get_priority
get_priority(indices)
Fetches the priorities corresponding to a batch of memory indices.
For any memory location not yet used, the corresponding priority is 0.
Args:
indices
: np.array with dtype int32, of indices in range [0,
replay_capacity).
Returns:
priorities
: np.array of floats, the corresponding priorities.
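A minimal sketch of the lookup, assuming priorities are kept in a float array of length replay_capacity that starts at zero, which is what makes unused slots report a priority of 0. The standalone get_priority function and its priorities argument are illustrative, not the class's actual storage:

```python
import numpy as np


def get_priority(priorities, indices):
    """Sketch: look up stored priorities for a batch of int32 indices.

    priorities: float array of length replay_capacity, initialized to zero,
    so never-written slots report a priority of 0.
    """
    indices = np.asarray(indices)
    assert indices.dtype == np.int32, 'indices must be int32'
    return priorities[indices].astype(np.float32)
```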
get_range
get_range(
array,
start_index,
end_index
)
Returns the rows of array from start_index up to end_index, handling wraparound if necessary.
Args:
array
: np.array, the array to get the stack from.
start_index
: int, index to the start of the range to be returned.
Range will wrap around if start_index is smaller than 0.
end_index
: int, exclusive end index. Range will wrap around if
end_index exceeds replay_capacity.
Returns:
np.array, with shape [end_index - start_index, array.shape[1:]].
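The wraparound can be expressed by taking indices modulo the array length. A minimal sketch, assuming the first axis of array has length replay_capacity:

```python
import numpy as np


def get_range(array, start_index, end_index):
    """Sketch: rows start_index..end_index-1 of a circular array.

    Indices are taken modulo the first-axis length (the replay capacity),
    so start_index < 0 or end_index > capacity wrap to the other end.
    """
    assert end_index > start_index, 'end_index must be larger than start_index'
    capacity = array.shape[0]
    indices = np.arange(start_index, end_index) % capacity
    return array[indices, ...]
```

For a capacity of 5, a range of -2 to 2 and a range of 3 to 7 both select rows 3, 4, 0, 1.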
get_storage_signature
get_storage_signature()
Returns a default list of elements to be stored in this replay memory.
Note - Derived classes may return a different signature.
Returns:
list of ReplayElements defining the type of the contents stored.
get_terminal_stack
get_terminal_stack(index)
get_transition_elements
get_transition_elements(batch_size=None)
Returns a ‘type signature’ for sample_transition_batch.
Args:
batch_size
: int, number of transitions returned. If None, the
default batch_size will be used.
Returns:
signature
: A namedtuple describing the method’s return type
signature.
is_empty
is_empty()
Is the Replay Buffer empty?
is_full
is_full()
Is the Replay Buffer full?
is_valid_transition
is_valid_transition(index)
Checks if the index contains a valid transition.
Checks for collisions with the end of episodes and the current position of the
cursor.
Args:
index
: int, the index to the state in the transition.
Returns:
bool, whether the index contains a valid transition.
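The two checks described above can be sketched as a free function that receives the buffer state explicitly. This is a simplified sketch, not the library's implementation: the invalid window is assumed to cover update_horizon slots before the cursor and stack_size slots from the cursor onward, and all arguments are hypothetical stand-ins for buffer state:

```python
import numpy as np


def is_valid_transition(index, terminals, cursor, is_full,
                        stack_size, update_horizon):
    """Simplified sketch of the validity checks.

    terminals: uint8 array of length replay_capacity (1 = episode ended).
    """
    capacity = len(terminals)
    if index < 0 or index >= capacity:
        return False
    if not is_full:
        # Need stack_size - 1 earlier frames and update_horizon later ones.
        if index < stack_size - 1 or index >= cursor - update_horizon:
            return False
    # Slots around the cursor, where old and new data meet (mod capacity).
    invalid = {(cursor - update_horizon + i) % capacity
               for i in range(stack_size + update_horizon)}
    if index in invalid:
        return False
    # The frames stacked before `index` must not cross an episode boundary.
    earlier = [(index - k) % capacity for k in range(1, stack_size)]
    if any(terminals[j] for j in earlier):
        return False
    return True
```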
load
load(
checkpoint_dir,
suffix
)
Restores the object from bundle_dictionary and numpy checkpoints.
Args:
checkpoint_dir
: str, the directory from which to read the numpy
checkpointed files.
suffix
: str, the suffix to use in numpy checkpoint files.
Raises:
NotFoundError
: If not all expected files are found in directory.
sample_index_batch
sample_index_batch(batch_size)
Returns a batch of valid indices sampled as in Schaul et al. (2015).
Args:
batch_size
: int, number of indices returned.
Returns:
list of ints, a batch of valid indices sampled in proportion to their priorities.
Raises:
Exception
: If the batch was not constructed after the maximum number of
tries.
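Schaul et al. (2015) sample index i with probability proportional to its priority. The sketch below does this with a cumulative sum and np.searchsorted rather than a sum tree. It splits the total priority mass into batch_size equal strata so that each draw comes from a different segment, and it resamples invalid indices until a shared budget of attempts runs out. The is_valid callback (for example a wrapper around is_valid_transition) and the choice of RuntimeError are assumptions made for illustration:

```python
import numpy as np


def sample_index_batch(priorities, batch_size, is_valid,
                       max_sample_attempts=1000, rng=None):
    """Sketch of proportional, stratified index sampling.

    priorities: float array of length replay_capacity (0 for unused slots).
    is_valid: callable(index) -> bool deciding whether an index is usable.
    """
    rng = np.random.default_rng() if rng is None else rng
    cumulative = np.cumsum(priorities)
    total = float(cumulative[-1])
    if total <= 0:
        raise ValueError('Cannot sample: all priorities are zero.')
    last_positive = int(np.flatnonzero(priorities)[-1])

    def draw(low, high):
        # Find the slot whose cumulative-priority interval contains the draw.
        u = rng.uniform(low, high)
        idx = int(np.searchsorted(cumulative, u, side='right'))
        # Guard against u landing exactly on total through rounding.
        return min(idx, last_positive)

    # Stratified: split [0, total) into batch_size equal segments.
    bounds = np.linspace(0.0, total, batch_size + 1)
    indices = [draw(bounds[i], bounds[i + 1]) for i in range(batch_size)]

    attempts = max_sample_attempts
    for i, index in enumerate(indices):
        while not is_valid(index):
            if attempts == 0:
                raise RuntimeError('Max sample attempts: tried {} times.'.format(
                    max_sample_attempts))
            index = draw(0.0, total)  # Resample from the full distribution.
            attempts -= 1
        indices[i] = index
    return indices
```

Slots with zero priority, including never-written ones, are never selected.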
sample_transition_batch
sample_transition_batch(
batch_size=None,
indices=None
)
Returns a batch of transitions with extra storage and the priorities.
The extra storage is defined through the extra_storage_types constructor
argument.
When the transition is terminal, next_state_batch has undefined contents.
Args:
batch_size
: int, number of transitions returned. If None, the
default batch_size will be used.
indices
: None or list of ints, the indices of every transition in
the batch. If None, the indices are drawn with sample_index_batch,
i.e. in proportion to their priorities.
Returns:
transition_batch
: tuple of np.arrays with the shape and type as in
get_transition_elements().
save
save(
checkpoint_dir,
iteration_number
)
Saves the OutOfGraphReplayBuffer attributes to numpy checkpoint files.
This method saves all of the replay buffer’s state in checkpoint_dir.
Args:
checkpoint_dir
: str, the directory where numpy checkpoint files
should be saved.
iteration_number
: int, iteration_number to use as a suffix in
naming numpy checkpoint files.
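The save/load pairing, in which iteration_number on save becomes the suffix on load, can be sketched with plain np.save and np.load. The filename pattern and the helper names are hypothetical, and FileNotFoundError stands in for the NotFoundError raised by load:

```python
import os

import numpy as np


def save_arrays(checkpoint_dir, iteration_number, arrays):
    """Sketch: write each named array to '<name>_ckpt.<iteration>.npy'.

    The filename pattern is a hypothetical choice for this illustration.
    """
    for name, value in arrays.items():
        path = os.path.join(checkpoint_dir, '{}_ckpt.{}.npy'.format(
            name, iteration_number))
        np.save(path, value)


def load_arrays(checkpoint_dir, suffix, names):
    """Sketch: read the arrays back, failing if an expected file is missing."""
    restored = {}
    for name in names:
        path = os.path.join(checkpoint_dir, '{}_ckpt.{}.npy'.format(name, suffix))
        if not os.path.exists(path):
            raise FileNotFoundError(path)
        restored[name] = np.load(path)
    return restored
```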
set_priority
set_priority(
indices,
priorities
)
Sets the priority of the given elements according to Schaul et al.
Args:
indices
: np.array with dtype int32, of indices in range [0,
replay_capacity).
priorities
: float, the corresponding priorities.
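The buffer stores whatever priorities it is given; computing them is up to the caller. For the proportional variant, Schaul et al. (2015) use p_i = |delta_i| + epsilon and sample with P(i) = p_i^alpha / sum_k p_k^alpha, where delta_i is the TD error. The sketch below bakes the exponent into the stored value, so sampling proportional to the stored number matches that distribution. The alpha and epsilon defaults and both function names are illustrative assumptions:

```python
import numpy as np


def td_errors_to_priorities(td_errors, alpha=0.6, epsilon=1e-6):
    """Proportional priorities (|delta| + eps) ** alpha; values illustrative."""
    return (np.abs(td_errors) + epsilon) ** alpha


def set_priority(priorities, indices, new_priorities):
    """Sketch: overwrite stored priorities at the given int32 indices."""
    indices = np.asarray(indices)
    assert indices.dtype == np.int32, 'indices must be int32'
    priorities[indices] = new_priorities
```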