Integrating with Stable Baselines3
You can train agents in this project with Stable-Baselines3 (`stable-baselines3`). The following example shows how:
import os.path
import wandb
from craftground import craftground
from craftground.wrappers.action import ActionWrapper, Action
from craftground.wrappers.fast_reset import FastResetWrapper
from craftground.wrappers.time_limit import TimeLimitWrapper
from craftground.wrappers.vision import VisionWrapper
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
from wandb.integration.sb3 import WandbCallback
from craftground.craftground.screen_encoding_modes import ScreenEncodingMode
from check_vglrun import check_vglrun
from get_device import get_device
# Anchor the structure file to this script's own directory so that it is
# found no matter which working directory the script is launched from.
current_path = os.path.split(os.path.abspath(__file__))[0]
map_path = os.path.join(current_path, "custom_structure.nbt")
def example_exploration():
    """Train a PPO agent in a flat CraftGround world and log the run to W&B.

    Creates a 114x64 CraftGround environment that loads the structure file at
    ``map_path``, limits the agent to the forward / turn-left / turn-right
    actions, and trains a ``CnnPolicy`` PPO model for 6,000,000 timesteps.
    Videos, TensorBoard logs and model checkpoints are stored under
    directories named after the W&B run id, and the final model is saved to
    ``example-exploration.ckpt``.

    The CraftGround environment is terminated and the W&B run is finished on
    every exit path, including when environment wrapping or training raises.
    """
    group_name = "example-exploration"
    size_x = 114
    size_y = 64

    # Using the run as a context manager finishes it on success and on
    # failure alike, instead of only after a successful ``model.save``.
    with wandb.init(
        # set the wandb project where this run will be logged
        project="craftground-sb3",
        entity="your-name",
        # track hyperparameters and run metadata
        group=group_name,
        sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
        monitor_gym=True,  # auto-upload the videos of agents playing the game
        save_code=True,  # optional
    ) as run:
        base_env = craftground.make(
            port=8001,
            initialInventoryCommands=[],
            verbose=False,
            initialPosition=[5, 5, 5],  # nullable
            initialMobsCommands=[],
            imageSizeX=size_x,
            imageSizeY=size_y,
            visibleSizeX=size_x,
            visibleSizeY=size_y,
            seed=12345,  # nullable
            allowMobSpawn=False,
            alwaysDay=True,
            alwaysNight=False,
            initialWeather="clear",  # nullable
            isHardCore=False,
            isWorldFlat=True,  # superflat world
            obs_keys=[],  # No sound subtitles
            miscStatKeys=[],  # No stats
            initialExtraCommands=[
                "time set noon",
                "place template minecraft:custom_structure 0 0 0",
                "tp @p 3 1 1 -90 0",
            ],  # x y z yaw pitch
            isHudHidden=True,
            render_action=False,
            render_distance=5,
            simulation_distance=5,
            structure_paths=[
                map_path,
            ],
            no_pov_effect=True,
            screen_encoding_mode=ScreenEncodingMode.RAW,
            use_vglrun=check_vglrun(),
        )
        # From here on the environment exists, so every failure path must
        # terminate it.
        try:
            # Monitor is a single-environment wrapper, so it has to be applied
            # before the environment is vectorized, not around the VecEnv.
            monitored_env = Monitor(
                FastResetWrapper(
                    TimeLimitWrapper(
                        ActionWrapper(
                            VisionWrapper(
                                base_env,
                                x_dim=size_x,
                                y_dim=size_y,
                            ),
                            enabled_actions=[
                                Action.FORWARD,
                                Action.TURN_LEFT,
                                Action.TURN_RIGHT,
                            ],
                        ),
                        max_timesteps=20000,
                    ),
                )
            )
            # The factory closes over a name that is never rebound, so it
            # always returns the monitored environment.
            env = DummyVecEnv([lambda: monitored_env])
            env = VecVideoRecorder(
                env,
                f"videos/{run.id}",
                record_video_trigger=lambda x: x % 20000 == 0,
                video_length=20000,
            )
            model = PPO(
                "CnnPolicy",
                env,
                verbose=1,
                device=get_device(),
                tensorboard_log=f"runs/{run.id}",
                gae_lambda=0.99,
                ent_coef=0.005,
                n_steps=512,
            )
            model.learn(
                total_timesteps=6000000,
                callback=[
                    WandbCallback(
                        gradient_save_freq=500,
                        model_save_path=f"models/{run.id}",
                        verbose=2,
                    ),
                ],
            )
            model.save(f"{group_name}.ckpt")
        finally:
            base_env.terminate()
if __name__ == "__main__":
    # Start training only when this file is executed as a script, so that
    # importing it has no side effects.
    example_exploration()
The `get_device` function is a utility that returns the device to use for training. It is defined as follows:
import torch
def get_device() -> torch.device:
    """Return the torch device to use for training.

    Preference order: the first CUDA GPU, then the MPS backend, then the CPU.

    Returns:
        ``torch.device("cuda:0")``, ``torch.device("mps")`` or
        ``torch.device("cpu")``.
    """
    if torch.cuda.is_available():
        # Drop cached allocator blocks before training starts.
        torch.cuda.empty_cache()
        return torch.device("cuda:0")

    # ``is_built()`` only reports that this PyTorch build includes MPS
    # support; ``is_available()`` is the runtime check for a usable device.
    # The getattr guard keeps torch builds without ``torch.backends.mps``
    # on the CPU fallback instead of raising.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")

    return torch.device("cpu")
The `check_vglrun` function is a utility that reports whether the `vglrun` executable is installed on the system (that is, whether it can be found on `PATH`). It is defined as follows:
def check_vglrun():
    """Return True when a ``vglrun`` executable can be located, else False."""
    import shutil

    # shutil.which returns the executable's path, or None when it is not found.
    vglrun_path = shutil.which("vglrun")
    return vglrun_path is not None