Source code for mushroom_rl_benchmark.utils.utils

import numpy as np
from inspect import signature

from mushroom_rl.utils.frames import LazyFrames


def get_init_states(dataset):
    """
    Collect the initial states contained in a MushroomRL dataset.

    Args:
        dataset (Dataset): a MushroomRL dataset.

    Returns:
        A numpy array built from the first element of the first sample and
        of every sample that follows one whose last element is truthy.

    """
    collected = list()
    starts_episode = True
    for transition in dataset:
        if starts_episode:
            state = transition[0]
            # LazyFrames are materialized into a concrete array before
            # being stored.
            if isinstance(state, LazyFrames):
                state = np.array(state)
            collected.append(state)
        # transition[-1] is presumably MushroomRL's "last" flag (end of an
        # episode), so the following sample starts a new one -- confirm
        # against the Dataset layout.
        starts_episode = transition[-1]
    return np.array(collected)
def extract_arguments(args, method):
    """
    Extract the arguments from a dictionary that fit to a method's parameters.

    Args:
        args (dict): dictionary of arguments;
        method (function): method for which the arguments should be extracted.

    Returns:
        A new dict containing only the entries of ``args`` whose key is the
        name of one of ``method``'s parameters, ordered as the parameters
        appear in the method's signature. Variadic parameters
        (``*args``/``**kwargs``) are matched by their name only; they do not
        cause any additional keys to be forwarded.

    """
    # Iterate over the signature rather than over ``args`` so the result
    # follows the method's parameter order.
    parameter_names = signature(method).parameters
    return {name: args[name] for name in parameter_names if name in args}