nvkartik commited on
Commit
21beb18
·
1 Parent(s): 31a0783

added test for env loading

Browse files
Files changed (1) hide show
  1. tests/test_env_load_arena.py +44 -0
tests/test_env_load_arena.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from dataclasses import asdict, dataclass
3
+ from pprint import pformat
4
+ import torch
5
+ import tqdm
6
+ from lerobot import envs
7
+ from lerobot.configs import parser
8
+ from lerobot.configs.eval import EvalPipelineConfig
9
+
10
+
11
+ @parser.wrap()
12
+ def main(cfg: EvalPipelineConfig):
13
+ """Run zero action rollout for IsaacLab Arena environment."""
14
+ logging.info(pformat(asdict(cfg)))
15
+
16
+ from lerobot.envs.factory import make_env
17
+
18
+ # hub_path = cfg.env.hub_path
19
+ env_dict = make_env(
20
+ cfg.env,
21
+ n_envs=cfg.env.num_envs,
22
+ trust_remote_code=True,
23
+ )
24
+ env = next(iter(env_dict.values()))[0]
25
+
26
+ try:
27
+ env.reset()
28
+
29
+ for _ in tqdm.tqdm(range(cfg.env.episode_length)):
30
+ with torch.inference_mode():
31
+ action_dim = env.action_space.shape[-1]
32
+ actions = torch.zeros(
33
+ (env.num_envs, action_dim), device=env.device
34
+ )
35
+ obs, rewards, terminated, truncated, info = env.step(actions)
36
+ print(obs.keys())
37
+ print(obs["policy"].keys())
38
+ print(obs["camera_obs"].keys())
39
+ finally:
40
+ env.close()
41
+
42
+
43
+ if __name__ == "__main__":
44
+ main()