import os

import ray
import ray.rllib.agents.ppo as ppo
from ray.tune.logger import pretty_print

'''
The following switch allows the program to run locally or on the Agit
distributed cluster without modification: when the CLOUD_PROVIDER
environment variable is set to 'Agit', the cluster's ray_init() is used;
otherwise ray.init() starts Ray locally.
'''
if os.environ.get('CLOUD_PROVIDER') == 'Agit':
    from agit import ray_init
    ray_init()
else:
    ray.init()

# Start from RLlib's default PPO configuration and keep the run small.
config = ppo.DEFAULT_CONFIG.copy()
config["num_gpus"] = 0      # CPU-only training
config["num_workers"] = 1   # a single rollout worker
config["eager"] = False     # keep TensorFlow in graph mode

# Create a trainer that holds the PPO policy for environment interaction.
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")

# With an empty path, trainer.save() falls back to the trainer's default
# log directory (under ~/ray_results).
checkpoint_path = ""

# Optionally call trainer.restore(checkpoint_path) to resume from a checkpoint.
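# For example, to resume from a checkpoint produced by an earlier run
# (the path below is hypothetical; use one printed by trainer.save()):
#
#   trainer.restore("/home/user/ray_results/PPO_CartPole-v0/checkpoint_100/checkpoint-100")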

for i in range(1000):
    # Perform one iteration of training the policy with PPO.
    result = trainer.train()
    print(pretty_print(result))

    # Save a checkpoint every 100 iterations.
    if i % 100 == 0:
        checkpoint = trainer.save(checkpoint_path)
        print("checkpoint saved at", checkpoint)
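
# As a quick sanity check after training, the learned policy can be rolled
# out directly. This is a minimal sketch, not part of the original script,
# assuming the classic gym API where env.step() returns four values:
#
#   import gym
#   env = gym.make("CartPole-v0")
#   obs, done, total_reward = env.reset(), False, 0.0
#   while not done:
#       action = trainer.compute_action(obs)
#       obs, reward, done, _ = env.step(action)
#       total_reward += reward
#   print("rollout reward:", total_reward)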