* In Linux I can't type in Korean, so this walkthrough is written in English.
Code link: https://github.com/thu-ml/SRPO
1. Initial setup
A fresh start!! Let's go... :D
(1) Install Anaconda (skip this if it is already installed)
https://ieworld.tistory.com/12
(2) Open a terminal and run the commands below.
# first, clone the SRPO repository
git clone https://github.com/thu-ml/SRPO
cd SRPO/
# create and activate a conda environment
conda update -n base -c defaults conda
conda create -n python38 python=3.8
conda activate python38
# install PyTorch, MuJoCo, D4RL
# run 'pip list' now and then to confirm each package actually got installed
# if one or two errors pop up, googling the exact error message usually turns up a fix right away..!
# note: the SRPO scripts move models to "cuda", so on a GPU machine install a CUDA build of PyTorch instead of cpuonly
conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cpuonly -c pytorch
pip install git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl
pip install gym==0.24.1
pip install mujoco==2.3.7
# train the SRPO diffusion behavior model (train_behavior.py)
TASK="halfcheetah-medium-v2"; seed=0; python3 -u train_behavior.py --expid ${TASK}-baseline-seed${seed} --env $TASK --seed ${seed}
# conda deactivate
conda deactivate
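While still inside the python38 env (before deactivating), it is worth sanity-checking the installation. The script below is my own minimal check, not part of the SRPO repo: it imports the key packages and builds the halfcheetah-medium-v2 environment, which exercises the MuJoCo/D4RL bindings.
# sanity_check.py -- my own snippet, not from the SRPO repo
import torch
import gym
import d4rl  # importing d4rl registers the offline-RL environments with gym

print("torch:", torch.__version__, "| cuda available:", torch.cuda.is_available())
env = gym.make("halfcheetah-medium-v2")   # fails here if MuJoCo / D4RL are broken
obs = env.reset()
print("observation shape:", obs.shape, "| action dim:", env.action_space.shape[0])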
(3) wandb API key
3a8a41d449090ec59d378031df37506c9dd2d12c
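If training stops at a wandb login prompt (see error [4] below), installing wandb and logging in with the key above fixes it; wandb login asks for the API key interactively:
pip install wandb
wandb login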
(4) Errors I ran into (reference links)
[1] https://github.com/apple/turicreate/issues/3383
[2] https://github.com/openai/mujoco-py/issues/773
[3] https://stackoverflow.com/questions/76958656/how-can-i-solve-the-subprocess-exited-with-error
[4] https://blog.finxter.com/fixed-modulenotfounderror-no-module-named-wandb/ -> use the wandb API key from (3)
2. Simulation results
(1) Execution results
(screenshot)
(2) Code analysis
model.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# Performs a Gaussian Fourier projection into an 'embed_dim'-dimensional embedding.
class GaussianFourierProjection(nn.Module):
def __init__(self, embed_dim, scale=30.):
super().__init__()
# Randomly sample weights during initialization. These weights are fixed
# during optimization and are not trainable.
self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)
def forward(self, x):
x_proj = x[..., None] * self.W[None, :] * 2 * np.pi
return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
# Builds a multi-layer perceptron (MLP).
def mlp(dims, activation=nn.ReLU, output_activation=None):
n_dims = len(dims)
assert n_dims >= 2, 'MLP requires at least two dims (input and output)'
layers = []
for i in range(n_dims - 2):
layers.append(nn.Linear(dims[i], dims[i+1]))
layers.append(activation())
layers.append(nn.Linear(dims[-2], dims[-1]))
if output_activation is not None:
layers.append(output_activation())
net = nn.Sequential(*layers)
net.to(dtype=torch.float32)
return net
# Twin Q network: takes a state and an action and predicts the Q value.
class TwinQ(nn.Module):
def __init__(self, action_dim, state_dim, layers=2):
super().__init__()
dims = [state_dim + action_dim] +[256]*layers +[1]
# dims = [state_dim + action_dim, 256, 256, 1] # TODO
self.q1 = mlp(dims)
self.q2 = mlp(dims)
def both(self, action, condition=None):
as_ = torch.cat([action, condition], -1) if condition is not None else action
return self.q1(as_), self.q2(as_)
def forward(self, action, condition=None):
return torch.min(*self.both(action, condition))
# Predicts the value of a given state.
class ValueFunction(nn.Module):
def __init__(self, state_dim):
super().__init__()
dims = [state_dim, 256, 256, 1]
self.v = mlp(dims)
def forward(self, state):
return self.v(state)
# Deterministic policy: takes a state as input and outputs an action.
class Dirac_Policy(nn.Module):
def __init__(self, action_dim, state_dim, layer=2):
super().__init__()
self.net = mlp([state_dim] + [256]*layer + [action_dim], output_activation=nn.Tanh)
def forward(self, state):
return self.net(state)
def select_actions(self, state):
return self(state)
# Defines a residual block and an MLP ResNet built from it.
class MLPResNetBlock(nn.Module):
def __init__(self, features, act, dropout_rate=None, use_layer_norm=False):
super(MLPResNetBlock, self).__init__()
self.features = features
self.act = act
self.dropout_rate = dropout_rate
self.use_layer_norm = use_layer_norm
if self.use_layer_norm:
self.layer_norm = nn.LayerNorm(features)
self.fc1 = nn.Linear(features, features * 4)
self.fc2 = nn.Linear(features * 4, features)
self.residual = nn.Linear(features, features)
self.dropout = nn.Dropout(dropout_rate) if dropout_rate is not None and dropout_rate > 0.0 else None
def forward(self, x, training=False):
residual = x
if self.dropout is not None:
x = self.dropout(x)
if self.use_layer_norm:
x = self.layer_norm(x)
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
if residual.shape != x.shape:
residual = self.residual(residual)
return residual + x
class MLPResNet(nn.Module):
def __init__(self, num_blocks, input_dim, out_dim, dropout_rate=None, use_layer_norm=False, hidden_dim=256, activations=F.relu):
super(MLPResNet, self).__init__()
self.num_blocks = num_blocks
self.out_dim = out_dim
self.dropout_rate = dropout_rate
self.use_layer_norm = use_layer_norm
self.hidden_dim = hidden_dim
self.activations = activations
self.fc = nn.Linear(input_dim+128, self.hidden_dim)
self.blocks = nn.ModuleList([MLPResNetBlock(self.hidden_dim, self.activations, self.dropout_rate, self.use_layer_norm)
for _ in range(self.num_blocks)])
self.out_fc = nn.Linear(self.hidden_dim, self.out_dim)
def forward(self, x, training=False):
x = self.fc(x)
for block in self.blocks:
x = block(x, training=training)
x = self.activations(x)
x = self.out_fc(x)
return x
# Score network: predicts the score for a given input and condition.
class ScoreNet_IDQL(nn.Module):
def __init__(self, input_dim, output_dim, marginal_prob_std, embed_dim=64, args=None):
super().__init__()
self.output_dim = output_dim
self.embed = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim))
self.device=args.device
self.marginal_prob_std = marginal_prob_std
self.args=args
self.main = MLPResNet(args.actor_blocks, input_dim, output_dim, dropout_rate=0.1, use_layer_norm=True, hidden_dim=256, activations=nn.Mish())
        # Condition model: an MLP with input dim 64 and hidden dim 128 for the time embedding.
self.cond_model = mlp([64, 128, 128], output_activation=None, activation=nn.Mish)
# The swish activation function
# self.act = lambda x: x * torch.sigmoid(x)
def forward(self, x, t, condition):
        # Predicts the score for the given input and condition.
embed = self.cond_model(self.embed(t))
all = torch.cat([x, condition, embed], dim=-1)
h = self.main(all)
return h
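Before moving on to SRPO.py, here is a small shape check I wrote (not in the repo) for the time embedding used by ScoreNet_IDQL: a batch of diffusion times t goes through GaussianFourierProjection into a 64-dimensional sin/cos feature, and cond_model maps that to the 128-dimensional conditioning vector that gets concatenated with [x, condition].
# time-embedding shape check (my own snippet, assuming model.py is on the path)
import torch
import torch.nn as nn
from model import GaussianFourierProjection, mlp

embed = GaussianFourierProjection(embed_dim=64)
cond_model = mlp([64, 128, 128], activation=nn.Mish)

t = torch.rand(256)                    # one diffusion time per batch element
print(embed(t).shape)                  # torch.Size([256, 64])  -> sin/cos features
print(cond_model(embed(t)).shape)      # torch.Size([256, 128]) -> concatenated with [x, condition]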
SRPO.py
import copy
import torch
import torch.nn as nn
from model import *
class SRPO(nn.Module):
def __init__(self, input_dim, output_dim, marginal_prob_std, args=None):
super().__init__()
self.diffusion_behavior = ScoreNet_IDQL(input_dim, output_dim, marginal_prob_std, embed_dim=64, args=args)
self.diffusion_optimizer = torch.optim.AdamW(self.diffusion_behavior.parameters(), lr=3e-4)
self.SRPO_policy = Dirac_Policy(output_dim, input_dim-output_dim, layer=args.policy_layer).to("cuda")
self.SRPO_policy_optimizer = torch.optim.Adam(self.SRPO_policy.parameters(), lr=3e-4)
self.SRPO_policy_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.SRPO_policy_optimizer, T_max=args.n_policy_epochs * 10000, eta_min=0.)
self.marginal_prob_std = marginal_prob_std
self.args = args
self.output_dim = output_dim
self.step = 0
self.q = []
self.q.append(IQL_Critic(adim=output_dim, sdim=input_dim-output_dim, args=args))
def update_SRPO_policy(self, data):
s = data['s']
self.diffusion_behavior.eval()
        a = self.SRPO_policy(s)  # get the action a from the deterministic policy for state s
t = torch.rand(a.shape[0], device=s.device) * 0.96 + 0.02
alpha_t, std = self.marginal_prob_std(t)
z = torch.randn_like(a)
perturbed_a = a * alpha_t[..., None] + z * std[..., None]
with torch.no_grad():
episilon = self.diffusion_behavior(perturbed_a, t, s).detach()
        # compute the predicted noise (epsilon) for the perturbed action 'perturbed_a'
        if "noise" in self.args.WT:  # if the weighting type includes noise, subtract the injected noise z
episilon = episilon - z
if "VDS" in self.args.WT:
wt = std ** 2
elif "stable" in self.args.WT:
wt = 1.0
elif "score" in self.args.WT:
wt = alpha_t / std
else:
assert False
detach_a = a.detach().requires_grad_(True)
qs = self.q[0].q0_target.both(detach_a , s)
q = (qs[0].squeeze() + qs[1].squeeze()) / 2.0
self.SRPO_policy.q = torch.mean(q)
# TODO be aware that there is a small std gap term here, this seem won't affect final performance though
# guidance = torch.autograd.grad(torch.sum(q), detach_a)[0].detach() * std[..., None]
guidance = torch.autograd.grad(torch.sum(q), detach_a)[0].detach()
if self.args.regq:
guidance_norm = torch.mean(guidance ** 2, dim=-1, keepdim=True).sqrt()
guidance = guidance / guidance_norm
loss = (episilon * a).sum(-1) * wt - (guidance * a).sum(-1) * self.args.beta
        # compute the loss
loss = loss.mean()
self.SRPO_policy_optimizer.zero_grad(set_to_none=True)
loss.backward()
self.SRPO_policy_optimizer.step()
self.SRPO_policy_lr_scheduler.step()
        self.diffusion_behavior.train()  # switch the diffusion model back to training mode
class SRPO_Behavior(nn.Module):  # trains the diffusion behavior policy
def __init__(self, input_dim, output_dim, marginal_prob_std, args=None):
super().__init__()
self.diffusion_behavior = ScoreNet_IDQL(input_dim, output_dim, marginal_prob_std, embed_dim=64, args=args)
self.diffusion_optimizer = torch.optim.AdamW(self.diffusion_behavior.parameters(), lr=3e-4)
self.marginal_prob_std = marginal_prob_std
self.args = args
self.output_dim = output_dim
self.step = 0
def update_behavior(self, data):
self.step += 1
        all_a = data['a']  # extract all actions and states from the batch
all_s = data['s']
        # Update diffusion behavior: switch to training mode.
self.diffusion_behavior.train()
random_t = torch.rand(all_a.shape[0], device=all_a.device) * (1. - 1e-3) + 1e-3
z = torch.randn_like(all_a)
alpha_t, std = self.marginal_prob_std(random_t)
perturbed_x = all_a * alpha_t[:, None] + z * std[:, None]
        # obtain the perturbed actions
episilon = self.diffusion_behavior(perturbed_x, random_t, all_s)
loss = torch.mean(torch.sum((episilon - z)**2, dim=(1,)))
        # loss: mean of the sum of squared differences between the predicted noise (episilon) and z
self.loss =loss
self.diffusion_optimizer.zero_grad()
loss.backward()
self.diffusion_optimizer.step()
class SRPO_IQL(nn.Module):
def __init__(self, input_dim, output_dim, args=None):
super().__init__()
self.deter_policy = Dirac_Policy(output_dim, input_dim-output_dim).to("cuda")
self.deter_policy_optimizer = torch.optim.Adam(self.deter_policy.parameters(), lr=3e-4)
self.deter_policy_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.deter_policy_optimizer, T_max=1500000, eta_min=0.)
self.args = args
self.output_dim = output_dim
self.step = 0
self.q = []
self.q.append(IQL_Critic(adim=output_dim, sdim=input_dim-output_dim, args=args))
def update_iql(self, data):
a = data['a']
s = data['s']
        self.q[0].update_q0(data)  # update the Q function via the 'update_q0' method
# evaluate iql policy part, can be deleted
with torch.no_grad():
target_q = self.q[0].q0_target(a, s).detach()
v = self.q[0].vf(s).detach()
adv = target_q - v
temp = 10.0 if "maze" in self.args.env else 3.0
exp_adv = torch.exp(temp * adv.detach()).clamp(max=100.0)
        # compute the exponential of the advantage (used to weight the BC loss)
policy_out = self.deter_policy(s)
bc_losses = torch.sum((policy_out - a)**2, dim=1)
policy_loss = torch.mean(exp_adv.squeeze() * bc_losses)
self.deter_policy_optimizer.zero_grad(set_to_none=True)
policy_loss.backward()
self.deter_policy_optimizer.step()
self.deter_policy_lr_scheduler.step()
self.policy_loss = policy_loss
def update_target(new, target, tau):
# Update the frozen target models
for param, target_param in zip(new.parameters(), target.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
def asymmetric_l2_loss(u, tau):
return torch.mean(torch.abs(tau - (u < 0).float()) * u**2)
class IQL_Critic(nn.Module):
def __init__(self, adim, sdim, args) -> None:
super().__init__()
self.q0 = TwinQ(adim, sdim, layers=args.q_layer).to(args.device)
print(args.q_layer)
self.q0_target = copy.deepcopy(self.q0).to(args.device)
self.vf = ValueFunction(sdim).to("cuda")
self.q_optimizer = torch.optim.Adam(self.q0.parameters(), lr=3e-4)
self.v_optimizer = torch.optim.Adam(self.vf.parameters(), lr=3e-4)
self.discount = 0.99
self.args = args
self.tau = 0.9 if "maze" in args.env else 0.7
print(self.tau)
def update_q0(self, data):
s = data["s"]
a = data["a"]
r = data["r"]
s_ = data["s_"]
d = data["d"]
with torch.no_grad():
target_q = self.q0_target(a, s).detach()
next_v = self.vf(s_).detach()
# Update value function
v = self.vf(s)
adv = target_q - v
v_loss = asymmetric_l2_loss(adv, self.tau)
self.v_optimizer.zero_grad(set_to_none=True)
v_loss.backward()
self.v_optimizer.step()
# Update Q function
targets = r + (1. - d.float()) * self.discount * next_v.detach()
qs = self.q0.both(a, s)
self.v = v.mean()
q_loss = sum(torch.nn.functional.mse_loss(q, targets) for q in qs) / len(qs)
self.q_optimizer.zero_grad(set_to_none=True)
q_loss.backward()
self.q_optimizer.step()
self.v_loss = v_loss
self.q_loss = q_loss
self.q = target_q.mean()
self.v = next_v.mean()
# Update target
update_target(self.q0, self.q0_target, 0.005)
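asymmetric_l2_loss is the expectile regression loss from IQL: residuals u = Q - V are squared and weighted by tau when positive and by 1 - tau when negative, so with tau = 0.7 (or 0.9 on the maze tasks) the value function is pulled toward the upper expectile of the Q values. A tiny numeric check I find helpful (my own example, not from the repo):
# expectile loss behaviour (illustrative only)
import torch
from SRPO import asymmetric_l2_loss

u = torch.tensor([1.0, -2.0])          # one positive and one negative residual
print(asymmetric_l2_loss(u, tau=0.5))  # 1.25 -> half the ordinary mean squared error
print(asymmetric_l2_loss(u, tau=0.7))  # 0.95 -> negative residuals are down-weighted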
utils.py
import argparse
import d4rl
import gym
import numpy as np
import torch
temperature_coefficients = {"antmaze-medium-play-v2": 0.08, "antmaze-umaze-v2": 0.02, "antmaze-umaze-diverse-v2": 0.04, "antmaze-medium-diverse-v2": 0.05, "antmaze-large-diverse-v2": 0.05, "antmaze-large-play-v2": 0.06, "hopper-medium-expert-v2": 0.01, "hopper-medium-v2": 0.05, "hopper-medium-replay-v2": 0.2, "walker2d-medium-expert-v2": 0.1, "walker2d-medium-v2": 0.05, "walker2d-medium-replay-v2": 0.5, "halfcheetah-medium-expert-v2": 0.01, "halfcheetah-medium-v2": 0.2, "halfcheetah-medium-replay-v2": 0.2}
def marginal_prob_std(t, device="cuda",beta_1=20.0,beta_0=0.1):
"""Compute the mean and standard deviation of $p_{0t}(x(t) | x(0))$.
"""
t = torch.tensor(t, device=device)
log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
alpha_t = torch.exp(log_mean_coeff)
std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
return alpha_t, std
def simple_eval_policy(policy_fn, env_name, seed, eval_episodes=20):
env = gym.make(env_name)
env.seed(seed+561)
all_rewards = []
for _ in range(eval_episodes):
obs = env.reset()
total_reward = 0.
done = False
while not done:
with torch.no_grad():
action = policy_fn(torch.Tensor(obs).unsqueeze(0).to("cuda")).cpu().numpy().squeeze()
next_obs, reward, done, info = env.step(action)
total_reward += reward
if done:
break
else:
obs = next_obs
all_rewards.append(d4rl.get_normalized_score(env_name, total_reward))
return np.mean(all_rewards), np.std(all_rewards)
def pallaral_simple_eval_policy(policy_fn, env_name, seed, eval_episodes=20):
eval_envs = []
for i in range(eval_episodes):
env = gym.make(env_name)
eval_envs.append(env)
env.seed(seed + 1001 + i)
env.buffer_state = env.reset()
env.buffer_return = 0.0
ori_eval_envs = [env for env in eval_envs]
import time
t = time.time()
while len(eval_envs) > 0:
new_eval_envs = []
states = np.stack([env.buffer_state for env in eval_envs])
states = torch.Tensor(states).to("cuda")
with torch.no_grad():
actions = policy_fn(states).detach().cpu().numpy()
for i, env in enumerate(eval_envs):
state, reward, done, info = env.step(actions[i])
env.buffer_return += reward
env.buffer_state = state
if not done:
new_eval_envs.append(env)
eval_envs = new_eval_envs
for i in range(eval_episodes):
ori_eval_envs[i].buffer_return = d4rl.get_normalized_score(env_name, ori_eval_envs[i].buffer_return)
mean = np.mean([ori_eval_envs[i].buffer_return for i in range(eval_episodes)])
std = np.std([ori_eval_envs[i].buffer_return for i in range(eval_episodes)])
return mean, std
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--env", default="halfcheetah-medium-expert-v2") # OpenAI gym environment name
parser.add_argument("--seed", default=0, type=int) # Sets Gym, PyTorch and Numpy seeds
parser.add_argument("--expid", default="default", type=str)
parser.add_argument("--device", default="cuda", type=str)
parser.add_argument("--save_model", default=1, type=int)
parser.add_argument('--debug', type=int, default=0)
parser.add_argument('--beta', type=float, default=None)
parser.add_argument('--actor_load_path', type=str, default=None)
parser.add_argument('--critic_load_path', type=str, default=None)
parser.add_argument('--policy_batchsize', type=int, default=256)
parser.add_argument('--actor_blocks', type=int, default=3)
parser.add_argument('--z_noise', type=int, default=1)
parser.add_argument('--WT', type=str, default="VDS")
parser.add_argument('--q_layer', type=int, default=2)
parser.add_argument('--n_policy_epochs', type=int, default=100)
parser.add_argument('--policy_layer', type=int, default=None)
parser.add_argument('--critic_load_epochs', type=int, default=150)
parser.add_argument('--regq', type=int, default=0)
print("**************************")
args = parser.parse_known_args()[0]
if args.debug:
args.actor_epoch =1
args.critic_epoch =1
if args.policy_layer is None:
args.policy_layer = 4 if "maze" in args.env else 2
if "maze" in args.env:
args.regq = 1
if args.beta is None:
args.beta = temperature_coefficients[args.env]
print(args)
return args
if __name__ == "__main__":
args = get_args()
print(args)
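marginal_prob_std defines the VP-SDE perturbation kernel used throughout: x(t) = alpha_t * x(0) + std_t * z, with alpha_t = exp(-0.25 t^2 (beta_1 - beta_0) - 0.5 t beta_0) and std_t = sqrt(1 - alpha_t^2), so alpha_t^2 + std_t^2 = 1 for every t (variance preserving). A quick check of that identity (my own snippet, not in the repo):
# VP-SDE schedule check (illustrative only)
import torch
from utils import marginal_prob_std

t = torch.linspace(0.02, 0.98, 5)
alpha_t, std = marginal_prob_std(t, device="cpu")
print(alpha_t)                  # decays from ~1 toward ~0 as t grows
print(alpha_t**2 + std**2)      # ~1 everywhere, as expected for a variance-preserving SDE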
train_policy.py
import functools
import os
import d4rl
import gym
import numpy as np
import torch
import tqdm
import wandb
from dataset import D4RL_dataset
from SRPO import SRPO
from utils import get_args, marginal_prob_std, pallaral_simple_eval_policy
def train_policy(args, score_model, data_loader, start_epoch=0):
n_epochs = args.n_policy_epochs
tqdm_epoch = tqdm.trange(start_epoch, n_epochs)
evaluation_inerval = 2
for epoch in tqdm_epoch:
avg_loss = 0.
num_items = 0
for _ in range(10000):
data = data_loader.sample(args.policy_batchsize)
loss2 = score_model.update_SRPO_policy(data)
avg_loss += 0.0
num_items += 1
tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
if (epoch % evaluation_inerval == (evaluation_inerval -1)) or epoch==0:
mean, std = pallaral_simple_eval_policy(score_model.SRPO_policy.select_actions,args.env,00)
args.run.log({"eval/rew{}".format("deter"): mean}, step=epoch+1)
args.run.log({"info/policy_q": score_model.SRPO_policy.q.detach().cpu().numpy()}, step=epoch+1)
args.run.log({"info/lr": score_model.SRPO_policy_optimizer.state_dict()['param_groups'][0]['lr']}, step=epoch+1)
def critic(args):
for dir in ["./SRPO_policy_models"]:
if not os.path.exists(dir):
os.makedirs(dir)
if not os.path.exists(os.path.join("./SRPO_policy_models", str(args.expid))):
os.makedirs(os.path.join("./SRPO_policy_models", str(args.expid)))
run = wandb.init(project="SRPO_policy", name=str(args.expid))
wandb.config.update(args)
env = gym.make(args.env)
env.seed(args.seed)
env.action_space.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
args.run = run
marginal_prob_std_fn = functools.partial(marginal_prob_std, device=args.device,beta_1=20.0)
args.marginal_prob_std_fn = marginal_prob_std_fn
score_model= SRPO(input_dim=state_dim+action_dim, output_dim=action_dim, marginal_prob_std=marginal_prob_std_fn, args=args).to(args.device)
score_model.q[0].to(args.device)
# args.actor_load_path = "path/to/yout/ckpt/file"
if args.actor_load_path is not None:
print("loading actor...")
ckpt = torch.load(args.actor_load_path, map_location=args.device)
score_model.load_state_dict({k:v for k,v in ckpt.items() if "diffusion_behavior" in k}, strict=False)
else:
assert False
# args.critic_load_path = "path/to/yout/ckpt/file"
if args.critic_load_path is not None:
print("loadind critic...")
ckpt = torch.load(args.critic_load_path, map_location=args.device)
score_model.q[0].load_state_dict(ckpt)
else:
assert False
dataset = D4RL_dataset(args)
print("training critic")
train_policy(args, score_model, dataset, start_epoch=0)
print("finished")
run.finish()
if __name__ == "__main__":
args = get_args()
critic(args)
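Note that train_policy.py hits assert False when --actor_load_path or --critic_load_path is missing, so it is the last stage of the pipeline and needs the checkpoints written by train_behavior.py and train_critic.py (listed below). A run would look roughly like this; the expid directories and epoch numbers are just my examples, so point them at whatever --expid values you actually trained with:
TASK="halfcheetah-medium-v2"; seed=0; python3 -u train_policy.py --expid ${TASK}-policy-seed${seed} --env $TASK --seed ${seed} --actor_load_path ./SRPO_model_factory/${TASK}-baseline-seed${seed}/behavior_ckpt200.pth --critic_load_path ./SRPO_model_factory/${TASK}-critic-seed${seed}/critic_ckpt150.pth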
train_critic.py
import os
import d4rl
import gym
import numpy as np
import torch
import tqdm
import wandb
from dataset import D4RL_dataset
from SRPO import SRPO_IQL
from utils import get_args, pallaral_simple_eval_policy
def train_critic(args, score_model, data_loader, start_epoch=0):
n_epochs = 150
tqdm_epoch = tqdm.trange(start_epoch, n_epochs)
# evaluation_inerval = 4
evaluation_inerval = 1
save_interval = 10
for epoch in tqdm_epoch:
avg_loss = 0.
num_items = 0
for _ in range(10000):
data = data_loader.sample(256)
loss2 = score_model.update_iql(data)
avg_loss += 0.0
num_items += 1
tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
if (epoch % evaluation_inerval == (evaluation_inerval -1)) or epoch==0:
if (epoch % 5 == 4) or epoch==0:
mean, std = pallaral_simple_eval_policy(score_model.deter_policy.select_actions,args.env,00)
args.run.log({"eval/rew{}".format("deter"): mean}, step=epoch+1)
args.run.log({"loss/v_loss": score_model.q[0].v_loss.detach().cpu().numpy()}, step=epoch+1)
args.run.log({"loss/q_loss": score_model.q[0].q_loss.detach().cpu().numpy()}, step=epoch+1)
args.run.log({"loss/q": score_model.q[0].q.detach().cpu().numpy()}, step=epoch+1)
args.run.log({"loss/v": score_model.q[0].v.detach().cpu().numpy()}, step=epoch+1)
args.run.log({"loss/policy_loss": score_model.policy_loss.detach().cpu().numpy()}, step=epoch+1)
args.run.log({"info/lr": score_model.deter_policy_optimizer.state_dict()['param_groups'][0]['lr']}, step=epoch+1)
if args.save_model and ((epoch % save_interval == (save_interval - 1)) or epoch==0):
torch.save(score_model.q[0].state_dict(), os.path.join("./SRPO_model_factory", str(args.expid), "critic_ckpt{}.pth".format(epoch+1)))
def critic(args):
for dir in ["./SRPO_model_factory"]:
if not os.path.exists(dir):
os.makedirs(dir)
if not os.path.exists(os.path.join("./SRPO_model_factory", str(args.expid))):
os.makedirs(os.path.join("./SRPO_model_factory", str(args.expid)))
run = wandb.init(project="SRPO_model_factory", name=str(args.expid))
wandb.config.update(args)
env = gym.make(args.env)
env.seed(args.seed)
env.action_space.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
args.run = run
score_model= SRPO_IQL(input_dim=state_dim+action_dim, output_dim=action_dim, args=args).to(args.device)
score_model.q[0].to(args.device)
dataset = D4RL_dataset(args)
print("training critic")
train_critic(args, score_model, dataset, start_epoch=0)
print("finished")
run.finish()
if __name__ == "__main__":
args = get_args()
critic(args)
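train_critic.py is launched the same way as train_behavior.py; the critic checkpoints are saved to ./SRPO_model_factory/<expid>/critic_ckpt*.pth every 10 epochs. For example (the expid here is my own choice):
TASK="halfcheetah-medium-v2"; seed=0; python3 -u train_critic.py --expid ${TASK}-critic-seed${seed} --env $TASK --seed ${seed}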
train_behavior.py
import functools
import os
import d4rl
import gym
import numpy as np
import torch
import tqdm
import wandb
from dataset import D4RL_dataset
from SRPO import SRPO_Behavior
from utils import get_args, marginal_prob_std
def train_behavior(args, score_model, data_loader, start_epoch=0):
n_epochs = 200
tqdm_epoch = tqdm.trange(start_epoch, n_epochs)
# evaluation_inerval = 4
evaluation_inerval = 1
save_interval = 20
for epoch in tqdm_epoch:
avg_loss = 0.
num_items = 0
for _ in range(10000):
data = data_loader.sample(2048)
loss2 = score_model.update_behavior(data)
avg_loss += score_model.loss.detach().cpu().numpy()
num_items += 1
tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
if (epoch % evaluation_inerval == (evaluation_inerval -1)) or epoch==0:
args.run.log({"loss/diffusion": score_model.loss.detach().cpu().numpy()}, step=epoch+1)
if args.save_model and ((epoch % save_interval == (save_interval - 1)) or epoch==0):
torch.save(score_model.state_dict(), os.path.join("./SRPO_model_factory", str(args.expid), "behavior_ckpt{}.pth".format(epoch+1)))
def behavior(args):
for dir in ["./SRPO_model_factory"]:
if not os.path.exists(dir):
os.makedirs(dir)
if not os.path.exists(os.path.join("./SRPO_model_factory", str(args.expid))):
os.makedirs(os.path.join("./SRPO_model_factory", str(args.expid)))
run = wandb.init(project="SRPO_model_factory", name=str(args.expid))
wandb.config.update(args)
env = gym.make(args.env)
env.seed(args.seed)
env.action_space.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
args.run = run
marginal_prob_std_fn = functools.partial(marginal_prob_std, device=args.device,beta_1=20.0)
args.marginal_prob_std_fn = marginal_prob_std_fn
score_model= SRPO_Behavior(input_dim=state_dim+action_dim, output_dim=action_dim, marginal_prob_std=marginal_prob_std_fn, args=args).to(args.device)
dataset = D4RL_dataset(args)
print("training behavior")
train_behavior(args, score_model, dataset, start_epoch=0)
print("finished")
run.finish()
if __name__ == "__main__":
args = get_args()
behavior(args)