Minimal Example

import argparse
import sys
import time
from craftground.screen_encoding_modes import ScreenEncodingMode
import wandb
from craftground.wrappers.vision import VisionWrapper

try:
    from craftground.minecraft import no_op_v2
except ImportError:
    from craftground.environment.action_space import no_op_v2
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
from stable_baselines3 import PPO
from wandb.integration.sb3 import WandbCallback
from stable_baselines3.common import on_policy_algorithm
import platform
from craftground.initial_environment_config import DaylightMode
import gymnasium as gym
from typing import SupportsFloat, Any
from gymnasium.core import WrapperActType, WrapperObsType
import craftground
from craftground import InitialEnvironmentConfig, ActionSpaceVersion
import torch

# Utility function for headless environment setup
def check_vglrun():
    from shutil import which

    return which("vglrun") is not None

# Utility function to get the accelerator device
def get_device(dev_num: int = 0) -> torch.device:
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{dev_num}")
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_built():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    return device

# Utility function to create the CraftGround environment
def make_craftground_env(
    port: int,
    width: int,
    height: int,
    screen_encoding_mode: ScreenEncodingMode,
    verbose_python: bool = False,
    verbose_gradle: bool = False,
    verbose_jvm: bool = False,
) -> gym.Env:
    return craftground.make(
        port=port,
        initial_env_config=InitialEnvironmentConfig(
            image_width=width,
            image_height=height,
            hud_hidden=False,
            render_distance=11,
            screen_encoding_mode=screen_encoding_mode,
        ).set_daylight_cycle_mode(DaylightMode.ALWAYS_DAY),
        action_space_version=ActionSpaceVersion.V2_MINERL_HUMAN,
        use_vglrun=check_vglrun(),
        verbose_python=verbose_python,
        verbose_gradle=verbose_gradle,
        verbose_jvm=verbose_jvm,
    )

# A wrapper that converts the dictionary action space to a multidiscrete action space, to be used with stable-baselines3's PPO.
class TreeWrapper(gym.Wrapper):
    def __init__(self, env):
        self.env = env
        super().__init__(env)
        self.action_space = gymnasium.spaces.MultiDiscrete(
            [
                2,  # forward
                2,  # back
                2,  # left
                2,  # right
                2,  # jump
                2,  # sneak
                2,  # sprint
                2,  # attack
                25,  # pitch
                25,  # yaw
            ]
        )

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        action_v2 = no_op_v2()
        # print(f"{action=}")
        action_v2["forward"] = action[0]
        action_v2["back"] = action[1]
        action_v2["left"] = action[2]
        action_v2["right"] = action[3]
        action_v2["jump"] = action[4]
        action_v2["sneak"] = action[5]
        action_v2["sprint"] = action[6]
        action_v2["attack"] = action[7]
        action_v2["camera_pitch"] = (action[8] - 12) * 15
        action_v2["camera_yaw"] = (action[9] - 12) * 15
        obs, reward, terminated, truncated, info = self.env.step(action_v2)
        return obs, reward, terminated, truncated, info

# The main function that runs the PPO algorithm with the CraftGround environment
def ppo_check(
    run: wandb.Run,
    screen_encoding_mode: ScreenEncodingMode,
    vision_width: int,
    vision_height: int,
    port: int,
    device: str,
    max_steps: int = 10000,
):
    env = make_craftground_env(
        port=port,
        width=vision_width,
        height=vision_height,
        screen_encoding_mode=screen_encoding_mode,
    )
    # VisionWrapper selects only the visual information from the observation space.
    env = VisionWrapper(env, x_dim=vision_width, y_dim=vision_height)
    # Task specific wrappers
    env = TreeWrapper(env)
    env = DummyVecEnv([lambda: env])
    # Record video every 2000 steps and save the video
    env = VecVideoRecorder(
        env,
        f"videos/{run.id}",
        record_video_trigger=lambda x: x % 2000 == 0,
        video_length=2000,
    )
    model = PPO(
        "CnnPolicy",
        env,
        verbose=1,
        device=device,
        tensorboard_log=f"runs/{run.id}",
        gae_lambda=0.99,
        ent_coef=0.005,
        n_steps=512,
    )

    try:
        env.reset()
        model.learn(
            total_timesteps=max_steps,
            callback=[
                # CustomWandbCallback(),
                WandbCallback(
                    gradient_save_freq=500,
                    model_save_path=f"models/{run.id}",
                    verbose=2,
                ),
            ],
        )
        model.save(f"ckpts/{run.group}-{run.name}.ckpt")
    finally:
        env.close()
        run.finish()

if __name__ == "__main__":
    # Initialize wandb
    run = wandb.init(project="craftground", entity="your_entity")
    # Run the PPO algorithm
    ppo_check(
        run=run,
        screen_encoding_mode=ScreenEncodingMode.RAW,
        vision_width=64,
        vision_height=64,
        port=8023,
        device=get_device(),
        max_steps=10000,
    )