class="hljs-ln-code"> class="hljs-ln-line">
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="3"> class="hljs-ln-code"> class="hljs-ln-line">import tensorflow as tf
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="4"> class="hljs-ln-code"> class="hljs-ln-line">from tensorflow.keras.models import Model
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="5"> class="hljs-ln-code"> class="hljs-ln-line">from tensorflow.keras.layers import Dense, Input
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="6"> class="hljs-ln-code"> class="hljs-ln-line">from tensorflow.keras.optimizers import Adam
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="7"> class="hljs-ln-code"> class="hljs-ln-line">
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="8"> class="hljs-ln-code"> class="hljs-ln-line"># 假设我们有一个多目标的数据集,包含多个标签
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="9"> class="hljs-ln-code"> class="hljs-ln-line"># 数据准备部分在实际应用中可能更加复杂,这里简化处理
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="10"> class="hljs-ln-code"> class="hljs-ln-line">num_features = 10
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="11"> class="hljs-ln-code"> class="hljs-ln-line">num_targets = 3 # 假设有三个不同的预测目标
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="12"> class="hljs-ln-code"> class="hljs-ln-line">batch_size = 64
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="13"> class="hljs-ln-code"> class="hljs-ln-line">epochs = 10
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="14"> class="hljs-ln-code"> class="hljs-ln-line">
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="15"> class="hljs-ln-code"> class="hljs-ln-line"># 定义一个简单的DNN模型
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="16"> class="hljs-ln-code"> class="hljs-ln-line">def create_model(input_shape, num_targets):
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="17"> class="hljs-ln-code"> class="hljs-ln-line"> inputs = Input(shape=(input_shape,))
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="18"> class="hljs-ln-code"> class="hljs-ln-line"> x = Dense(128, activation='relu')(inputs)
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="19"> class="hljs-ln-code"> class="hljs-ln-line"> x = Dense(64, activation='relu')(x)
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="20"> class="hljs-ln-code"> class="hljs-ln-line"> outputs = [Dense(1)(x) for _ in range(num_targets)] # 每个目标一个输出
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="21"> class="hljs-ln-code"> class="hljs-ln-line"> model = Model(inputs=inputs, outputs=outputs)
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="22"> class="hljs-ln-code"> class="hljs-ln-line"> return model
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="23"> class="hljs-ln-code"> class="hljs-ln-line">
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="24"> class="hljs-ln-code"> class="hljs-ln-line">model = create_model(input_shape=num_features, num_targets=num_targets)
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="25"> class="hljs-ln-code"> class="hljs-ln-line">
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="26"> class="hljs-ln-code"> class="hljs-ln-line"># 自定义的多目标加权损失函数
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="27"> class="hljs-ln-code"> class="hljs-ln-line">def weighted_loss(weights):
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="28"> class="hljs-ln-code"> class="hljs-ln-line"> def loss(y_true_list, y_pred_list):
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="29"> class="hljs-ln-code"> class="hljs-ln-line"> total_loss = 0
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="30"> class="hljs-ln-code"> class="hljs-ln-line"> for weight, y_true, y_pred in zip(weights, y_true_list, y_pred_list):
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="31"> class="hljs-ln-code"> class="hljs-ln-line"> # 这里假设使用均方误差作为损失函数
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="32"> class="hljs-ln-code"> class="hljs-ln-line"> single_loss = tf.reduce_mean(tf.square(y_true - y_pred))
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="33"> class="hljs-ln-code"> class="hljs-ln-line"> total_loss += weight * single_loss
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="34"> class="hljs-ln-code"> class="hljs-ln-line"> return total_loss
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="35"> class="hljs-ln-code"> class="hljs-ln-line"> return loss
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="36"> class="hljs-ln-code"> class="hljs-ln-line">
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="37"> class="hljs-ln-code"> class="hljs-ln-line"># 假设我们给每个目标分配的权重分别是0.5, 0.3, 0.2
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="38"> class="hljs-ln-code"> class="hljs-ln-line">loss_weights = [0.5, 0.3, 0.2]
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="39"> class="hljs-ln-code"> class="hljs-ln-line">model.compile(optimizer=Adam(), loss=weighted_loss(loss_weights))
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="40"> class="hljs-ln-code"> class="hljs-ln-line">
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="41"> class="hljs-ln-code"> class="hljs-ln-line"># 生成一些虚拟数据以进行训练
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="42"> class="hljs-ln-code"> class="hljs-ln-line">import numpy as np
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="43"> class="hljs-ln-code"> class="hljs-ln-line">
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="44"> class="hljs-ln-code"> class="hljs-ln-line">x_train = np.random.rand(1000, num_features) # 随机生成特征
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="45"> class="hljs-ln-code"> class="hljs-ln-line">y_train_list = [np.random.rand(1000, 1) for _ in range(num_targets)] # 每个目标的标签
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="46"> class="hljs-ln-code"> class="hljs-ln-line">
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="47"> class="hljs-ln-code"> class="hljs-ln-line"># 训练模型
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="48"> class="hljs-ln-code"> class="hljs-ln-line">model.fit(x_train, y_train_list, batch_size=batch_size, epochs=epochs)
  • class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}" onclick="hljs.signin(event)">

    PyTorch版本:

    代码结构

    1. 导入必要的库
    2. 定义多目标 DNN 模型
    3. 定义加权损失函数
    4. 生成虚拟数据并进行训练
    1. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="1"> class="hljs-ln-code"> class="hljs-ln-line">import torch
    2. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="2"> class="hljs-ln-code"> class="hljs-ln-line">import torch.nn as nn
    3. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="3"> class="hljs-ln-code"> class="hljs-ln-line">import torch.optim as optim
    4. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="4"> class="hljs-ln-code"> class="hljs-ln-line">from torch.utils.data import DataLoader, TensorDataset
    5. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="5"> class="hljs-ln-code"> class="hljs-ln-line">
    6. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="6"> class="hljs-ln-code"> class="hljs-ln-line"># 定义多目标 DNN 模型
    7. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="7"> class="hljs-ln-code"> class="hljs-ln-line">class MultiTargetDNN(nn.Module):
    8. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="8"> class="hljs-ln-code"> class="hljs-ln-line"> def __init__(self, input_dim, hidden_dim, output_dim):
    9. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="9"> class="hljs-ln-code"> class="hljs-ln-line"> super(MultiTargetDNN, self).__init__()
    10. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="10"> class="hljs-ln-code"> class="hljs-ln-line"> self.fc1 = nn.Linear(input_dim, hidden_dim)
    11. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="11"> class="hljs-ln-code"> class="hljs-ln-line"> self.relu = nn.ReLU()
    12. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="12"> class="hljs-ln-code"> class="hljs-ln-line"> self.fc2 = nn.Linear(hidden_dim, output_dim)
    13. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="13"> class="hljs-ln-code"> class="hljs-ln-line">
    14. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="14"> class="hljs-ln-code"> class="hljs-ln-line"> def forward(self, x):
    15. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="15"> class="hljs-ln-code"> class="hljs-ln-line"> x = self.fc1(x)
    16. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="16"> class="hljs-ln-code"> class="hljs-ln-line"> x = self.relu(x)
    17. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="17"> class="hljs-ln-code"> class="hljs-ln-line"> x = self.fc2(x)
    18. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="18"> class="hljs-ln-code"> class="hljs-ln-line"> return x
    19. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="19"> class="hljs-ln-code"> class="hljs-ln-line">
    20. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="20"> class="hljs-ln-code"> class="hljs-ln-line"># 定义加权损失函数
    21. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="21"> class="hljs-ln-code"> class="hljs-ln-line">def weighted_loss(weights):
    22. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="22"> class="hljs-ln-code"> class="hljs-ln-line"> def loss(y_true_list, y_pred_list):
    23. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="23"> class="hljs-ln-code"> class="hljs-ln-line"> total_loss = 0
    24. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="24"> class="hljs-ln-code"> class="hljs-ln-line"> for weight, y_true, y_pred in zip(weights, y_true_list, y_pred_list):
    25. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="25"> class="hljs-ln-code"> class="hljs-ln-line"> # 使用均方误差作为损失函数
    26. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="26"> class="hljs-ln-code"> class="hljs-ln-line"> single_loss = nn.MSELoss()(y_true, y_pred)
    27. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="27"> class="hljs-ln-code"> class="hljs-ln-line"> total_loss += weight * single_loss
    28. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="28"> class="hljs-ln-code"> class="hljs-ln-line"> return total_loss
    29. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="29"> class="hljs-ln-code"> class="hljs-ln-line"> return loss
    30. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="30"> class="hljs-ln-code"> class="hljs-ln-line">
    31. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="31"> class="hljs-ln-code"> class="hljs-ln-line"># 生成虚拟数据
    32. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="32"> class="hljs-ln-code"> class="hljs-ln-line">def generate_data(num_samples, input_dim, num_targets):
    33. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="33"> class="hljs-ln-code"> class="hljs-ln-line"> x_train = torch.randn(num_samples, input_dim)
    34. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="34"> class="hljs-ln-code"> class="hljs-ln-line"> y_train_list = [torch.randn(num_samples, 1) for _ in range(num_targets)]
    35. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="35"> class="hljs-ln-code"> class="hljs-ln-line"> return x_train, y_train_list
    36. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="36"> class="hljs-ln-code"> class="hljs-ln-line">
    37. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="37"> class="hljs-ln-code"> class="hljs-ln-line"># 主训练函数
    38. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="38"> class="hljs-ln-code"> class="hljs-ln-line">def train_model(x_train, y_train_list, weights, num_epochs=10, batch_size=8):
    39. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="39"> class="hljs-ln-code"> class="hljs-ln-line"> # 将数据转换为 DataLoader
    40. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="40"> class="hljs-ln-code"> class="hljs-ln-line"> dataset = TensorDataset(*([x_train] + y_train_list))
    41. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="41"> class="hljs-ln-code"> class="hljs-ln-line"> dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    42. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="42"> class="hljs-ln-code"> class="hljs-ln-line">
    43. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="43"> class="hljs-ln-code"> class="hljs-ln-line"> # 初始化模型、损失函数和优化器
    44. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="44"> class="hljs-ln-code"> class="hljs-ln-line"> input_dim = x_train.shape[1]
    45. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="45"> class="hljs-ln-code"> class="hljs-ln-line"> hidden_dim = 64
    46. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="46"> class="hljs-ln-code"> class="hljs-ln-line"> output_dim = len(y_train_list)
    47. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="47"> class="hljs-ln-code"> class="hljs-ln-line"> model = MultiTargetDNN(input_dim, hidden_dim, output_dim)
    48. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="48"> class="hljs-ln-code"> class="hljs-ln-line"> criterion = weighted_loss(weights)
    49. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="49"> class="hljs-ln-code"> class="hljs-ln-line"> optimizer = optim.Adam(model.parameters(), lr=0.001)
    50. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="50"> class="hljs-ln-code"> class="hljs-ln-line">
    51. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="51"> class="hljs-ln-code"> class="hljs-ln-line"> # 训练模型
    52. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="52"> class="hljs-ln-code"> class="hljs-ln-line"> for epoch in range(num_epochs):
    53. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="53"> class="hljs-ln-code"> class="hljs-ln-line"> total_loss = 0
    54. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="54"> class="hljs-ln-code"> class="hljs-ln-line"> for batch in dataloader:
    55. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="55"> class="hljs-ln-code"> class="hljs-ln-line"> x_batch = batch[0]
    56. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="56"> class="hljs-ln-code"> class="hljs-ln-line"> y_true_list = batch[1:]
    57. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="57"> class="hljs-ln-code"> class="hljs-ln-line">
    58. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="58"> class="hljs-ln-code"> class="hljs-ln-line"> # 前向传播
    59. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="59"> class="hljs-ln-code"> class="hljs-ln-line"> optimizer.zero_grad()
    60. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="60"> class="hljs-ln-code"> class="hljs-ln-line"> y_pred_list = model(x_batch).split(1, dim=1)
    61. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="61"> class="hljs-ln-code"> class="hljs-ln-line"> loss_value = criterion(y_true_list, y_pred_list)
    62. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="62"> class="hljs-ln-code"> class="hljs-ln-line">
    63. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="63"> class="hljs-ln-code"> class="hljs-ln-line"> # 反向传播和优化
    64. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="64"> class="hljs-ln-code"> class="hljs-ln-line"> loss_value.backward()
    65. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="65"> class="hljs-ln-code"> class="hljs-ln-line"> optimizer.step()
    66. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="66"> class="hljs-ln-code"> class="hljs-ln-line">
    67. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="67"> class="hljs-ln-code"> class="hljs-ln-line"> total_loss += loss_value.item()
    68. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="68"> class="hljs-ln-code"> class="hljs-ln-line">
    69. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="69"> class="hljs-ln-code"> class="hljs-ln-line"> print(f"Epoch {epoch + 1}: Loss = {total_loss / len(dataloader)}")
    70. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="70"> class="hljs-ln-code"> class="hljs-ln-line">
    71. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="71"> class="hljs-ln-code"> class="hljs-ln-line"># 主函数
    72. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="72"> class="hljs-ln-code"> class="hljs-ln-line">if __name__ == '__main__':
    73. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="73"> class="hljs-ln-code"> class="hljs-ln-line"> # 设置参数
    74. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="74"> class="hljs-ln-code"> class="hljs-ln-line"> num_samples = 64
    75. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="75"> class="hljs-ln-code"> class="hljs-ln-line"> input_dim = 10
    76. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="76"> class="hljs-ln-code"> class="hljs-ln-line"> num_targets = 3
    77. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="77"> class="hljs-ln-code"> class="hljs-ln-line"> weights = [0.5, 0.3, 0.2]
    78. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="78"> class="hljs-ln-code"> class="hljs-ln-line">
    79. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="79"> class="hljs-ln-code"> class="hljs-ln-line"> # 生成虚拟数据
    80. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="80"> class="hljs-ln-code"> class="hljs-ln-line"> x_train, y_train_list = generate_data(num_samples, input_dim, num_targets)
    81. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="81"> class="hljs-ln-code"> class="hljs-ln-line">
    82. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="82"> class="hljs-ln-code"> class="hljs-ln-line"> # 训练模型
    83. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="83"> class="hljs-ln-code"> class="hljs-ln-line"> train_model(x_train, y_train_list, weights)
    class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}" onclick="hljs.signin(event)">

    关键点解释:

    1. 多目标 DNN 模型:

      • MultiTargetDNN 类定义了一个简单的两层全连接网络,输入维度为 input_dim,隐藏层维度为 hidden_dim,输出维度为 output_dim(即目标数量)。
    2. 加权损失函数:

      • weighted_loss 函数接受权重列表,并返回一个损失计算函数。该函数通过循环遍历每个目标的预测值和真实值,计算均方误差并根据权重进行加权求和。
    3. 生成虚拟数据:

      • generate_data 函数生成随机输入数据 x_train 和多个目标的真实值 y_train_list
    4. 训练模型:

      • train_model 函数将数据加载到 DataLoader 中,并初始化模型、损失函数和优化器。
      • 在每个 epoch 中,遍历数据批次,进行前向传播、计算损失、反向传播和参数更新。

    三、总结

    本文从技术原理和技术优缺点方面对推荐系统深度学习多目标融合的“样本Loss加权”进行简要讲解,以达到对预期指标的强化,具有模型简单,成本较低的优点,但同时优化周期长、多个指标跷跷板问题也是该方法的缺点,业界针对该方法的缺点进行了一系列的升级,专栏中会逐步讲解,期待您的关注。

    如果您还有时间,欢迎阅读本专栏的其他文章:

    【深度学习】多目标融合算法(一):样本Loss加权(Sample Loss Reweight)

    【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model)​​​​​​​

    id="blogVoteBox" style="width:400px;margin:auto;margin-top:12px" class="blog-vote-box"> class="vote-box csdn-vote" style="opacity: 1;"> class="pos-box pt0" style="height: 222px; visibility: visible;">
    注:本文转载自blog.csdn.net的$海阔天空$的文章"https://blog.csdn.net/ywyngq/article/details/118158608"。版权归原作者所有,此博客不拥有其著作权,亦不承担相应法律责任。如有侵权,请联系我们删除。
    复制链接

    评论记录:

    未查询到任何数据!