import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import mrcfile
import numpy as np
# --- Dataset preparation ---
class PeptideDataset(Dataset):
    """Dataset yielding ``(density_map, centroid)`` pairs.

    Each item is a 3D density map as a 1-channel float32 tensor plus the
    centroid (mean) of the coordinates read from the same file.
    """

    def __init__(self, pdb_files, map_size):
        # map_size is stored but never used here; presumably the maps are
        # expected to already have this shape -- TODO confirm.
        self.pdb_files = pdb_files
        self.map_size = map_size

    def __len__(self):
        return len(self.pdb_files)

    def __getitem__(self, idx):
        pdb_file = self.pdb_files[idx]
        # NOTE(review): the same file is opened both as an MRC density map and
        # as a whitespace-delimited numeric table below. A standard fixed-width
        # .pdb file satisfies neither format -- verify the input file layout.
        with mrcfile.open(pdb_file, mode='r', permissive=True) as mrc:
            density_map = np.array(mrc.data, dtype=np.float32)
        density_map = np.expand_dims(density_map, axis=0)  # add channel dim
        density_map = torch.from_numpy(density_map)
        # Centroid of the coordinates in columns 6-8 (intended: C-alpha atoms).
        coords = np.loadtxt(pdb_file, usecols=(6, 7, 8))
        centroid = np.mean(coords, axis=0)
        centroid = torch.from_numpy(centroid.astype(np.float32))
        return density_map, centroid
# --- U-Net model ---
class UNet(nn.Module):
    """3D U-Net: three 2x max-pool stages down, three transposed-conv stages up,
    with skip connections concatenated on the channel dimension.

    Input spatial dimensions must be divisible by 8 (three pooling stages),
    otherwise the decoder's ``torch.cat`` skip connections fail.
    Output has the same spatial shape as the input, with ``out_channels``
    channels.
    """

    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()
        # Encoder
        self.enc_conv1 = self.conv_block(in_channels, 64)
        self.enc_conv2 = self.conv_block(64, 128)
        self.enc_conv3 = self.conv_block(128, 256)
        self.enc_conv4 = self.conv_block(256, 512)
        # Decoder (input channels doubled by skip concatenation)
        self.dec_conv1 = self.conv_block(512, 256)
        self.dec_conv2 = self.conv_block(256, 128)
        self.dec_conv3 = self.conv_block(128, 64)
        self.dec_conv4 = self.conv_block(64, out_channels)
        self.upconv1 = nn.ConvTranspose3d(512, 256, kernel_size=2, stride=2)
        self.upconv2 = nn.ConvTranspose3d(256, 128, kernel_size=2, stride=2)
        self.upconv3 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)

    def conv_block(self, in_channels, out_channels):
        """Two 3x3x3 convolutions (padding 1, shape-preserving), each + ReLU."""
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Encoder
        enc1 = self.enc_conv1(x)
        enc2 = self.enc_conv2(self.pool(enc1))
        enc3 = self.enc_conv3(self.pool(enc2))
        enc4 = self.enc_conv4(self.pool(enc3))
        # Decoder with skip connections
        dec1 = self.upconv1(enc4)
        dec1 = torch.cat((dec1, enc3), dim=1)
        dec1 = self.dec_conv1(dec1)
        dec2 = self.upconv2(dec1)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.dec_conv2(dec2)
        dec3 = self.upconv3(dec2)
        dec3 = torch.cat((dec3, enc1), dim=1)
        dec3 = self.dec_conv3(dec3)
        out = self.dec_conv4(dec3)
        return out
# (removed stray text "dd" -- paste artifact; it would raise NameError)
import os
import glob
from sklearn.model_selection import train_test_split
# --- Training data preparation ---
pdb_files = glob.glob('path/to/pdb/files/*.pdb')  # set path to the PDB files
train_files, val_files = train_test_split(pdb_files, test_size=0.2, random_state=42)
train_dataset = PeptideDataset(train_files, map_size=(64, 64, 64))
val_dataset = PeptideDataset(val_files, map_size=(64, 64, 64))
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

# --- Model initialization ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(in_channels=1, out_channels=3).to(device)

# --- Loss and optimizer ---
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# --- Training loop ---
# NOTE(review): the model emits a full (B, 3, D, H, W) map while the target
# centroid is (B, 3); MSELoss on these shapes will broadcast incorrectly or
# raise. A pooling/flattening head (see the revised UNet below) is needed for
# direct coordinate regression -- confirm intended output format.
num_epochs = 100
best_loss = float('inf')
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for density_maps, centroids in train_loader:
        density_maps = density_maps.to(device)
        centroids = centroids.to(device)
        optimizer.zero_grad()
        outputs = model(density_maps)
        loss = criterion(outputs.squeeze(), centroids)
        loss.backward()
        optimizer.step()
        # Accumulate sum of per-sample losses; divided by dataset size below.
        train_loss += loss.item() * density_maps.size(0)
    train_loss = train_loss / len(train_dataset)

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for density_maps, centroids in val_loader:
            density_maps = density_maps.to(device)
            centroids = centroids.to(device)
            outputs = model(density_maps)
            loss = criterion(outputs.squeeze(), centroids)
            val_loss += loss.item() * density_maps.size(0)
    val_loss = val_loss / len(val_dataset)

    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    # Checkpoint the best model by validation loss.
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')

print('Training completed. Best model saved.')
# --- Inference ---
# Path to the density map to run inference on
test_map_file = 'path/to/test/map/file.mrc'

# Build a single-item dataset/loader for the test map
test_dataset = PeptideDataset([test_map_file], map_size=(64, 64, 64))
test_loader = DataLoader(test_dataset, batch_size=1)

# Load the best checkpoint into a fresh model
best_model = UNet(in_channels=1, out_channels=3).to(device)
best_model.load_state_dict(torch.load('best_model.pth'))
best_model.eval()

# Run inference
with torch.no_grad():
    for density_map, _ in test_loader:
        density_map = density_map.to(device)
        output = best_model(density_map)
        predicted_centroid = output.squeeze().cpu().numpy()
        print('Predicted Centroid:', predicted_centroid)

# Save the prediction (optional).
# NOTE(review): this writes whatever array the model produced; it is only a
# valid MRC volume if the output is actually map-shaped -- confirm.
with mrcfile.new('predicted_centroid.mrc', overwrite=True) as mrc:
    mrc.set_data(predicted_centroid.astype(np.float32))
# --- Revised U-Net: regress 11 C-alpha coordinates directly ---
# NOTE(review): this redefines UNet from above; only the changed lines were
# kept in the original snippet (elisions marked "omitted").
class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()
        # ... (omitted) ...
        self.dec_conv4 = self.conv_block(64, out_channels)
        # ... (omitted) ...

    def forward(self, x):
        # ... (omitted) ...
        out = self.dec_conv4(dec3)
        # Reshape the output to (batch_size, 11, 3): 11 atoms x xyz.
        # NOTE(review): view(-1, 11, 3) only works if the decoder output has
        # exactly 33 elements per sample -- confirm the elided layers reduce
        # the spatial dimensions accordingly.
        out = out.view(-1, 11, 3)
        return out
# --- Revised training loop for coordinate regression ---
# Sum-reduction, then divide by the batch size: mean loss per sample.
criterion = nn.MSELoss(reduction='sum')
for epoch in range(num_epochs):
    # ... (omitted) ...
    for density_maps, ca_coords in train_loader:
        density_maps = density_maps.to(device)
        ca_coords = ca_coords.to(device)
        optimizer.zero_grad()
        outputs = model(density_maps)
        loss = criterion(outputs, ca_coords) / outputs.size(0)  # average over the batch
        # ... (omitted) ...
    # ... (omitted) ...
    with torch.no_grad():
        for density_maps, ca_coords in val_loader:
            density_maps = density_maps.to(device)
            ca_coords = ca_coords.to(device)
            outputs = model(density_maps)
            loss = criterion(outputs, ca_coords) / outputs.size(0)  # average over the batch
            # ... (omitted) ...
# --- Revised inference: predicted coordinates and best-fit plane ---
with torch.no_grad():
    for density_map, _ in test_loader:
        density_map = density_map.to(device)
        output = best_model(density_map)
        predicted_coords = output.squeeze().cpu().numpy()
        print('Predicted Cα Coordinates:')
        print(predicted_coords)

# Best-fit plane through the 11 atoms: after centering, the right-singular
# vector belonging to the smallest singular value (last row of vh) is the
# normal of the least-squares plane.
centroid = np.mean(predicted_coords, axis=0)
centered_coords = predicted_coords - centroid
_, _, vh = np.linalg.svd(centered_coords)
normal_vector = vh[2, :]
print('Plane Normal Vector:', normal_vector)
# (removed stray trailing text "コメント" [= "comment"] -- paste artifact)