Unity 2021.3.1f1 で「Unity ML-Agents実践ゲームプログラミング」のRollerAgents.csを書いた

Unity

概要

  • 「Unity ML-Agents実践ゲームプログラミング」では、Unity 2021.3.2f1に対応していない。
  • RollerAgentsのところでソースコードエラーがが出るので、出ないように改修した。

ソースコード

  • p.65のHeuristicのところまで直してみた。
  • エラーが出る場所は、OnActionReceivedとHeuristicの2か所なので、その部分を改修
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;

public class RollerAgent : Agent
{
    public Transform target;
    Rigidbody rBody;

    public override void Initialize()
    {
        this.rBody = GetComponent<Rigidbody>();
    }

    public override void OnEpisodeBegin()
    {
        if (this.transform.position.y < 0)
        {
            this.rBody.angularVelocity = Vector3.zero;
            this.rBody.velocity = Vector3.zero;
            this.transform.position = new Vector3(0.0f, 0.5f, 0.0f);
        }

        target.position = new Vector3(Random.value * 8 - 4, 0.5f, Random.value * 8 - 4);
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(target.position);
        sensor.AddObservation(this.transform.position);
        sensor.AddObservation(rBody.velocity.x);
        sensor.AddObservation(rBody.velocity.z);
    }

    public override void OnActionReceived(ActionBuffers actions)
    {
        Vector3 controlSignal = Vector3.zero;
        controlSignal.x = actions.ContinuousActions[0];
        controlSignal.z = actions.ContinuousActions[1];
        rBody.AddForce(controlSignal * 10);

        float distanceToTarget = Vector3.Distance(
            this.transform.position, target.position
        );

        if (distanceToTarget < 1.42f){
            AddReward(1.0F);
            EndEpisode();
        }

        if (this.transform.position.y < 0){
            EndEpisode();
        }
    }

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        // 手動実行時の行動
        var continuousActionsOut = actionsOut.ContinuousActions;
        continuousActionsOut[0] = Input.GetAxis("Horizontal");
        continuousActionsOut[1] = Input.GetAxis("Vertical");
    }

}

コメント

タイトルとURLをコピーしました