circular_replay_buffer.OutOfGraphReplayBuffer
Class OutOfGraphReplayBuffer
A simple out-of-graph Replay Buffer.
Stores transitions (state, action, reward, next_state, terminal, and any extra
contents specified) in a circular buffer and provides a uniform transition
sampling function.
When the states consist of stacks of observations, storing the states directly
is inefficient, because consecutive states share most of their frames. This
class instead writes individual observations and constructs the stacked states
at sample time.
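The idea of storing single frames and stacking them only when sampling can be sketched as follows. This is a minimal illustration, not the class's actual implementation; the `FrameStore` class and its method names are hypothetical, and the frame-axis-last layout is an assumption.

```python
import numpy as np


class FrameStore:
    """Hypothetical minimal store: keeps single frames, stacks on demand."""

    def __init__(self, capacity, frame_shape, stack_size):
        self.capacity = capacity
        self.stack_size = stack_size
        # One slot per frame rather than one per stacked state.
        self.frames = np.zeros((capacity,) + frame_shape, dtype=np.uint8)
        self.add_count = 0

    def add(self, frame):
        # Overwrite the oldest slot once the buffer is full.
        self.frames[self.add_count % self.capacity] = frame
        self.add_count += 1

    def stacked_state(self, index):
        # Gather the stack_size frames ending at `index`, wrapping around
        # the start of the circular array when needed.
        idx = [(index - k) % self.capacity
               for k in reversed(range(self.stack_size))]
        # Move the stack axis last, e.g. (84, 84, 4) for Atari-style input.
        return np.moveaxis(self.frames[idx], 0, -1)
```

Storing stacked states directly would cost roughly `stack_size` times as much memory, since each frame would appear in `stack_size` different states.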
Attributes:
add_count
: int, counter of how many transitions have been added
(including the blank ones at the beginning of an episode).
Methods
__init__
__init__(
observation_shape,
stack_size,
replay_capacity,
batch_size,
update_horizon=1,
gamma=0.99,
max_sample_attempts=MAX_SAMPLE_ATTEMPTS,
extra_storage_types=None,
observation_dtype=np.uint8
)
Initializes OutOfGraphReplayBuffer.
Args:
observation_shape
: tuple or int. If int, the observation is assumed
to be a 2D square.
stack_size
: int, number of frames to use in state stack.
replay_capacity
: int, number of transitions to keep in memory.
batch_size
: int, default number of transitions per sampled batch.
update_horizon
: int, length of update (‘n’ in n-step update).
gamma
: float, the discount factor.
max_sample_attempts
: int, the maximum number of attempts allowed to
get a sample.
extra_storage_types
: list of ReplayElements defining the type of
the extra contents that will be stored and returned by
sample_transition_batch.
observation_dtype
: np.dtype, type of the observations. Defaults to
np.uint8 for Atari 2600.
Raises:
ValueError
: If replay_capacity is too small to hold at least one
transition.
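The two checks below sketch how the constructor arguments might be validated. The function names are hypothetical, and the capacity rule (room for at least `stack_size + update_horizon` entries, i.e. one full frame stack plus the steps that follow it) is an assumption inferred from the ValueError description rather than a quote of the library's code.

```python
def normalize_observation_shape(observation_shape):
    # An int n is read as an n x n square observation, per the Args above.
    if isinstance(observation_shape, int):
        return (observation_shape, observation_shape)
    return tuple(observation_shape)


def check_capacity(replay_capacity, stack_size, update_horizon):
    # Assumed rule: there must be room for one full frame stack plus the
    # update_horizon entries that follow it.
    if replay_capacity < update_horizon + stack_size:
        raise ValueError('There is not enough capacity to cover '
                         'update_horizon and stack_size.')
```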
add
add(
observation,
action,
reward,
terminal,
*args
)
Adds a transition to the replay memory.
This function checks the types and handles the padding at the beginning of an
episode. Then it calls the _add function.
Since the next_observation in the transition will be the observation added
next, there is no need to pass it.
If the replay memory is at capacity, the oldest transition will be discarded.
Args:
observation
: np.array with shape observation_shape.
action
: int, the action in the transition.
reward
: float, the reward received in the transition.
terminal
: A uint8 acting as a boolean indicating whether the
transition was terminal (1) or not (0).
*args
: extra contents with shapes and dtypes according to
extra_storage_types.
cursor
cursor()
Index to the location where the next transition will be written.
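The padding behavior of add and the cursor arithmetic can be sketched with a tiny buffer. The `TinyCircularBuffer` class is hypothetical and omits type checking and extra contents; the assumptions are that `cursor()` is `add_count % replay_capacity` and that an episode starts with `stack_size - 1` blank (all-zero) transitions, which are counted in `add_count` as the Attributes section describes.

```python
import numpy as np


class TinyCircularBuffer:
    """Hypothetical sketch of add()/cursor(); not the library implementation."""

    def __init__(self, replay_capacity, stack_size, observation_shape):
        self.replay_capacity = replay_capacity
        self.stack_size = stack_size
        self.observation_shape = observation_shape
        self.observation = np.zeros((replay_capacity,) + observation_shape, np.uint8)
        self.action = np.zeros(replay_capacity, np.int32)
        self.reward = np.zeros(replay_capacity, np.float32)
        self.terminal = np.zeros(replay_capacity, np.uint8)
        self.add_count = 0

    def cursor(self):
        # Next write position: wraps to 0 once capacity is reached.
        return self.add_count % self.replay_capacity

    def _last_was_terminal(self):
        return self.terminal[(self.cursor() - 1) % self.replay_capacity] == 1

    def add(self, observation, action, reward, terminal):
        # At the start of an episode, write stack_size - 1 blank transitions so
        # the first real state can still be stacked to full depth.
        if self.add_count == 0 or self._last_was_terminal():
            for _ in range(self.stack_size - 1):
                self._add(np.zeros(self.observation_shape, np.uint8), 0, 0.0, 0)
        self._add(observation, action, reward, terminal)

    def _add(self, observation, action, reward, terminal):
        i = self.cursor()  # overwrites the oldest entry when full
        self.observation[i] = observation
        self.action[i] = action
        self.reward[i] = reward
        self.terminal[i] = terminal
        self.add_count += 1
```

With `stack_size=4`, the first call to `add` advances `add_count` by 4: three blank transitions plus the real one.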
get_add_args_signature
get_add_args_signature()
The signature of the add function.
Note - Derived classes may return a different signature.
Returns:
list of ReplayElements defining the type of the argument signature needed by the
add function.
get_observation_stack
get_observation_stack(index)
get_range
get_range(
array,
start_index,
end_index
)
Returns the range of array between start_index and end_index, handling
wraparound if necessary.
Args:
array
: np.array, the array to get the stack from.
start_index
: int, index to the start of the range to be returned.
Range will wrap around if start_index is smaller than 0.
end_index
: int, exclusive end index. Range will wrap around if
end_index exceeds replay_capacity.
Returns:
np.array, with shape [end_index - start_index, array.shape[1:]].
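A minimal sketch of the wraparound logic, assuming the range is taken along the first axis of a circular array. Here `replay_capacity` is passed explicitly as a hypothetical extra parameter, whereas the method reads it from the buffer itself.

```python
import numpy as np


def get_range(array, start_index, end_index, replay_capacity):
    """Sketch: rows [start_index, end_index) of a circular array."""
    assert end_index > start_index, 'end_index must be larger than start_index'
    if start_index >= 0 and end_index <= replay_capacity:
        # Contiguous case: an ordinary slice (returns a view).
        return array[start_index:end_index, ...]
    # Wraparound case: map each position back into [0, replay_capacity).
    indices = [(start_index + i) % replay_capacity
               for i in range(end_index - start_index)]
    # Fancy indexing returns a copy with shape [len(indices), array.shape[1:]].
    return array[indices, ...]
```

For example, with a capacity of 5, `get_range(np.arange(5), 3, 7, 5)` yields the rows at positions 3, 4, 0, 1.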
get_storage_signature
get_storage_signature()
Returns a default list of elements to be stored in this replay memory.
Note - Derived classes may return a different signature.
Returns:
list of ReplayElements defining the type of the contents stored.
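The following sketch shows how a list of ReplayElements could describe the stored contents and drive array allocation. `ReplayElement` is redefined here as a stand-in record with `name`, `shape`, and `type` fields; the default element list and dtypes in `default_storage_signature`, as well as the helper `allocate_store`, are assumptions for illustration.

```python
import collections

import numpy as np

# Hypothetical stand-in for ReplayElement: a (name, shape, type) record.
ReplayElement = collections.namedtuple('ReplayElement', ['name', 'shape', 'type'])


def default_storage_signature(observation_shape, observation_dtype=np.uint8,
                              extra_storage_types=None):
    # Assumed defaults: one entry per per-step quantity written by add().
    signature = [
        ReplayElement('observation', observation_shape, observation_dtype),
        ReplayElement('action', (), np.int32),
        ReplayElement('reward', (), np.float32),
        ReplayElement('terminal', (), np.uint8),
    ]
    # Extra contents are appended after the defaults, in the same order as
    # the *args passed to add().
    signature.extend(extra_storage_types or [])
    return signature


def allocate_store(signature, replay_capacity):
    # One array of shape (replay_capacity,) + element.shape per element.
    return {e.name: np.empty((replay_capacity,) + tuple(e.shape), dtype=e.type)
            for e in signature}
```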
get_terminal_stack
get_terminal_stack(index)
get_transition_elements
get_transition_elements(batch_size=None)
Returns a ‘type signature’ for sample_transition_batch.
Args:
batch_size
: int, number of transitions returned. If None, the
default batch_size will be used.
Returns:
signature
: A namedtuple describing the method’s return type
signature.
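The sketch below shows one plausible shape layout for such a signature and how a caller could pre-allocate a batch from it. The names `transition_elements` and `empty_batch` are hypothetical, the frame-last state layout is an assumption, and the element list is a reduced set; the actual signature may contain more elements (and any extra storage types).

```python
import collections

import numpy as np

ReplayElement = collections.namedtuple('ReplayElement', ['name', 'shape', 'type'])


def transition_elements(batch_size, observation_shape, stack_size,
                        observation_dtype=np.uint8):
    # Assumed layout: stacked states put the frame axis last.
    state_shape = (batch_size,) + tuple(observation_shape) + (stack_size,)
    return [
        ReplayElement('state', state_shape, observation_dtype),
        ReplayElement('action', (batch_size,), np.int32),
        ReplayElement('reward', (batch_size,), np.float32),
        ReplayElement('next_state', state_shape, observation_dtype),
        ReplayElement('terminal', (batch_size,), np.uint8),
        ReplayElement('indices', (batch_size,), np.int32),
    ]


def empty_batch(elements):
    # Pre-allocate one array per element; a sampler then fills each in place.
    return tuple(np.empty(e.shape, dtype=e.type) for e in elements)
```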
is_empty
is_empty()
Is the Replay Buffer empty?
is_full
is_full()
Is the Replay Buffer full?
is_valid_transition
is_valid_transition(index)
Checks if the index contains a valid transition.
Checks for collisions with the end of episodes and the current position of the
cursor.
Args:
index
: int, the index to the state in the transition.
Returns:
Boolean, whether the index contains a valid transition.
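The validity rules described above (episode boundaries and the cursor) can be sketched as free functions over a circular `terminal` array. This is an illustrative reconstruction under stated assumptions, not the library's code: the "invalid region" is assumed to span `stack_size + update_horizon` slots starting `update_horizon` slots before the cursor, and all function signatures here are hypothetical.

```python
import numpy as np


def is_empty(add_count):
    return add_count == 0


def is_full(add_count, replay_capacity):
    return add_count >= replay_capacity


def is_valid_transition(index, terminal, add_count, stack_size, update_horizon):
    """Sketch of the validity rules; `terminal` is the circular terminal array."""
    capacity = len(terminal)
    cursor = add_count % capacity
    if index < 0 or index >= capacity:
        return False
    if not is_full(add_count, capacity):
        # The transition's n-step successor must already have been written.
        if index >= cursor - update_horizon:
            return False
        # The first stack_size - 1 slots hold the padding of the first episode.
        if index < stack_size - 1:
            return False
    # Slots around the cursor mix old and new data, so they are excluded.
    invalid = {(cursor - update_horizon + i) % capacity
               for i in range(stack_size + update_horizon)}
    if index in invalid:
        return False
    # A terminal flag in any frame but the last means the stack spans two episodes.
    stack = [terminal[(index - k) % capacity] for k in range(stack_size - 1, -1, -1)]
    if any(stack[:-1]):
        return False
    return True
```

A terminal flag on the last frame of the stack is allowed: that is simply a terminal transition.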
load
load(
checkpoint_dir,
suffix
)
Restores the object from bundle_dictionary and numpy checkpoints.
Args:
checkpoint_dir
: str, the directory from which to read the numpy
checkpoint files.
suffix
: str, the suffix to use in numpy checkpoint files.
Raises:
NotFoundError
: If not all expected files are found in directory.
sample_index_batch
sample_index_batch(batch_size)
Returns a batch of valid indices sampled uniformly.
Args:
batch_size
: int, number of indices returned.
Returns:
list of ints, a batch of valid indices sampled uniformly.
Raises:
RuntimeError
: If the batch was not constructed after maximum number
of tries.
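Uniform sampling with a bounded number of attempts can be sketched as rejection sampling: draw an index uniformly, keep it if it is valid, and count each rejection against the budget. Since every valid index is equally likely to be drawn, the accepted indices are uniform over the valid ones. The function signature, the `is_valid` predicate argument, and the default budget are hypothetical; for simplicity this sketch draws from the whole buffer rather than a narrower candidate range.

```python
import numpy as np


def sample_index_batch(batch_size, replay_capacity, is_valid,
                       max_sample_attempts=1000, rng=None):
    """Uniform rejection sampling of valid indices (illustrative sketch)."""
    rng = np.random.default_rng() if rng is None else rng
    indices = []
    attempts = 0
    while len(indices) < batch_size and attempts < max_sample_attempts:
        index = int(rng.integers(0, replay_capacity))
        if is_valid(index):
            indices.append(index)  # duplicates allowed: sampling with replacement
        else:
            attempts += 1  # only rejected draws count against the budget
    if len(indices) != batch_size:
        raise RuntimeError('Sampled only {} valid indices after {} failed '
                           'attempts; batch size is {}.'.format(
                               len(indices), attempts, batch_size))
    return indices
```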
sample_transition_batch
sample_transition_batch(
batch_size=None,
indices=None
)
Returns a batch of transitions (including any extra contents).
If get_transition_elements has been overridden and defines elements not stored
in self._store, an empty array will be returned and it will be left to the child
class to fill it. For example, for the child class
OutOfGraphPrioritizedReplayBuffer, the contents of the sampling_probabilities
are stored separately in a sum tree.
When the transition is terminal, next_state_batch has undefined contents.
NOTE: This transition contains the indices of the sampled elements. These are
only valid during the call to sample_transition_batch, i.e. they may be used by
subclasses of this replay buffer but may point to different data as soon as
sampling is done.
Args:
batch_size
: int, number of transitions returned. If None, the
default batch_size will be used.
indices
: None or list of ints, the indices of every transition in
the batch. If None, sample the indices uniformly.
Returns:
transition_batch
: tuple of np.arrays with the shape and type as in
get_transition_elements().
Raises:
ValueError
: If an element to be sampled is missing from the replay
buffer.
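For a single sampled index, the n-step reward implied by update_horizon and gamma can be sketched as below: sum the next `update_horizon` rewards with discounts 1, gamma, gamma², ..., truncating at the first terminal step. The function name and return tuple are hypothetical, and this is an illustration of the standard n-step return rather than a copy of the method's internals.

```python
import numpy as np


def n_step_target(rewards, terminals, index, update_horizon, gamma):
    """Sketch: discounted n-step reward for the transition at `index`."""
    capacity = len(rewards)
    # Circular positions covered by the n-step window.
    window = [(index + j) % capacity for j in range(update_horizon)]
    window_terminals = terminals[window].astype(bool)
    is_terminal = bool(window_terminals.any())
    # Truncate the window at the first terminal step, inclusive.
    if is_terminal:
        length = int(np.argmax(window_terminals)) + 1
    else:
        length = update_horizon
    discounts = gamma ** np.arange(length)  # [1, gamma, gamma**2, ...]
    reward = float(np.sum(discounts * rewards[window[:length]]))
    # When is_terminal is True, the state at next_index is not meaningful,
    # which matches next_state_batch having undefined contents.
    next_index = (index + length) % capacity
    return reward, next_index, is_terminal
```

With `gamma=0.5`, `update_horizon=3`, and unit rewards, a non-terminal window gives 1 + 0.5 + 0.25 = 1.75.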
save
save(
checkpoint_dir,
iteration_number
)
Saves the OutOfGraphReplayBuffer attributes to numpy checkpoint files.
This method saves all of the replay buffer’s state.
Args:
checkpoint_dir
: str, the directory where numpy checkpoint files
should be saved.
iteration_number
: int, iteration_number to use as a suffix in
naming numpy checkpoint files.
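A save/load round trip for per-array numpy checkpoints can be sketched as follows. The helpers `save_arrays` and `load_arrays` and the `'{name}_ckpt.{suffix}.gz'` file-naming scheme are hypothetical, and the sketch raises Python's built-in FileNotFoundError where the method documents a NotFoundError.

```python
import gzip
import os
import tempfile

import numpy as np


def _path(checkpoint_dir, name, suffix):
    # Hypothetical naming scheme: one gzip-compressed .npy file per array.
    return os.path.join(checkpoint_dir, '{}_ckpt.{}.gz'.format(name, suffix))


def save_arrays(store, checkpoint_dir, iteration_number):
    for name, array in store.items():
        with gzip.open(_path(checkpoint_dir, name, iteration_number), 'wb') as f:
            np.save(f, array, allow_pickle=False)


def load_arrays(names, checkpoint_dir, suffix):
    paths = {n: _path(checkpoint_dir, n, suffix) for n in names}
    # Check every file first so a partial checkpoint is never half-loaded.
    missing = [p for p in paths.values() if not os.path.exists(p)]
    if missing:
        raise FileNotFoundError('Missing checkpoint files: {}'.format(missing))
    store = {}
    for name, path in paths.items():
        with gzip.open(path, 'rb') as f:
            store[name] = np.load(f, allow_pickle=False)
    return store
```

Checking for every expected file before reading any of them mirrors the documented behavior of failing when not all expected files are present.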