class="hljs-ln-code"> class="hljs-ln-line"> class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="3"> class="hljs-ln-code"> class="hljs-ln-line">import tensorflow as tf class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="4"> class="hljs-ln-code"> class="hljs-ln-line">from tensorflow.keras.models import Model class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="5"> class="hljs-ln-code"> class="hljs-ln-line">from tensorflow.keras.layers import Dense, Input class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="6"> class="hljs-ln-code"> class="hljs-ln-line">from tensorflow.keras.optimizers import Adam class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="7"> class="hljs-ln-code"> class="hljs-ln-line"> class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="8"> class="hljs-ln-code"> class="hljs-ln-line"> class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="9"> class="hljs-ln-code"> class="hljs-ln-line"> class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="10"> class="hljs-ln-code"> class="hljs-ln-line">num_features = 10 class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="11"> class="hljs-ln-code"> class="hljs-ln-line">num_targets = 3 class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="12"> class="hljs-ln-code"> class="hljs-ln-line">batch_size = 64 class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="13"> class="hljs-ln-code"> class="hljs-ln-line">epochs = 10 class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="14"> class="hljs-ln-code"> class="hljs-ln-line"> class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="15"> class="hljs-ln-code"> class="hljs-ln-line"> class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="16"> class="hljs-ln-code"> class="hljs-ln-line">def create_model(input_shape, num_targets): class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="17"> class="hljs-ln-code"> class="hljs-ln-line"> inputs = Input(shape=(input_shape,)) class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="18"> class="hljs-ln-code"> class="hljs-ln-line"> x = Dense(128, activation='relu')(inputs) class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="19"> class="hljs-ln-code"> class="hljs-ln-line"> x = Dense(64, activation='relu')(x) class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="20"> class="hljs-ln-code"> class="hljs-ln-line"> outputs = [Dense(1)(x) for _ in range(num_targets)] class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="21"> class="hljs-ln-code"> class="hljs-ln-line"> model = Model(inputs=inputs, outputs=outputs) class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="22"> class="hljs-ln-code"> class="hljs-ln-line"> return model class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="23"> class="hljs-ln-code"> class="hljs-ln-line"> class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="24"> class="hljs-ln-code"> class="hljs-ln-line">model = create_model(input_shape=num_features, num_targets=num_targets) class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="25"> class="hljs-ln-code"> class="hljs-ln-line"> class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="26"> class="hljs-ln-code"> class="hljs-ln-line"> class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="27"> class="hljs-ln-code"> class="hljs-ln-line">def weighted_loss(weights): class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="28"> class="hljs-ln-code"> class="hljs-ln-line"> def loss(y_true_list, y_pred_list): class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="29"> class="hljs-ln-code"> class="hljs-ln-line"> total_loss = 0 class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="30"> class="hljs-ln-code"> class="hljs-ln-line"> for weight, y_true, y_pred in zip(weights, y_true_list, y_pred_list): class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="31"> class="hljs-ln-code"> class="hljs-ln-line"> class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="32"> class="hljs-ln-code"> class="hljs-ln-line"> single_loss = tf.reduce_mean(tf.square(y_true - y_pred)) class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="33"> class="hljs-ln-code"> class="hljs-ln-line"> total_loss += weight * single_loss class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="34"> class="hljs-ln-code"> class="hljs-ln-line"> return total_loss class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="35"> class="hljs-ln-code"> class="hljs-ln-line"> return loss class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="36"> class="hljs-ln-code"> class="hljs-ln-line"> class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="37"> class="hljs-ln-code"> class="hljs-ln-line"> class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="38"> class="hljs-ln-code"> class="hljs-ln-line">loss_weights = [0.5, 0.3, 0.2] class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="39"> class="hljs-ln-code"> class="hljs-ln-line">model.compile(optimizer=Adam(), loss=weighted_loss(loss_weights)) class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="40"> class="hljs-ln-code"> class="hljs-ln-line"> class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="41"> class="hljs-ln-code"> class="hljs-ln-line"> class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="42"> class="hljs-ln-code"> class="hljs-ln-line">import numpy as np class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="43"> class="hljs-ln-code"> class="hljs-ln-line"> class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="44"> class="hljs-ln-code"> class="hljs-ln-line">x_train = np.random.rand(1000, num_features) class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="45"> class="hljs-ln-code"> class="hljs-ln-line">y_train_list = [np.random.rand(1000, 1) for _ in range(num_targets)] class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="46"> class="hljs-ln-code"> class="hljs-ln-line"> class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="47"> class="hljs-ln-code"> class="hljs-ln-line"> class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="48"> class="hljs-ln-code"> class="hljs-ln-line">model.fit(x_train, y_train_list, batch_size=batch_size, epochs=epochs) class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}" onclick="hljs.signin(event)">
PyTorch版本:
代码结构
- 导入必要的库。
- 定义多目标 DNN 模型。
- 定义加权损失函数。
- 生成虚拟数据并进行训练。
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="1"> class="hljs-ln-code"> class="hljs-ln-line">import torch
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="2"> class="hljs-ln-code"> class="hljs-ln-line">import torch.nn as nn
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="3"> class="hljs-ln-code"> class="hljs-ln-line">import torch.optim as optim
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="4"> class="hljs-ln-code"> class="hljs-ln-line">from torch.utils.data import DataLoader, TensorDataset
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="5"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="6"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="7"> class="hljs-ln-code"> class="hljs-ln-line">class MultiTargetDNN(nn.Module):
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="8"> class="hljs-ln-code"> class="hljs-ln-line"> def __init__(self, input_dim, hidden_dim, output_dim):
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="9"> class="hljs-ln-code"> class="hljs-ln-line"> super(MultiTargetDNN, self).__init__()
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="10"> class="hljs-ln-code"> class="hljs-ln-line"> self.fc1 = nn.Linear(input_dim, hidden_dim)
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="11"> class="hljs-ln-code"> class="hljs-ln-line"> self.relu = nn.ReLU()
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="12"> class="hljs-ln-code"> class="hljs-ln-line"> self.fc2 = nn.Linear(hidden_dim, output_dim)
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="13"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="14"> class="hljs-ln-code"> class="hljs-ln-line"> def forward(self, x):
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="15"> class="hljs-ln-code"> class="hljs-ln-line"> x = self.fc1(x)
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="16"> class="hljs-ln-code"> class="hljs-ln-line"> x = self.relu(x)
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="17"> class="hljs-ln-code"> class="hljs-ln-line"> x = self.fc2(x)
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="18"> class="hljs-ln-code"> class="hljs-ln-line"> return x
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="19"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="20"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="21"> class="hljs-ln-code"> class="hljs-ln-line">def weighted_loss(weights):
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="22"> class="hljs-ln-code"> class="hljs-ln-line"> def loss(y_true_list, y_pred_list):
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="23"> class="hljs-ln-code"> class="hljs-ln-line"> total_loss = 0
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="24"> class="hljs-ln-code"> class="hljs-ln-line"> for weight, y_true, y_pred in zip(weights, y_true_list, y_pred_list):
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="25"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="26"> class="hljs-ln-code"> class="hljs-ln-line"> single_loss = nn.MSELoss()(y_true, y_pred)
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="27"> class="hljs-ln-code"> class="hljs-ln-line"> total_loss += weight * single_loss
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="28"> class="hljs-ln-code"> class="hljs-ln-line"> return total_loss
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="29"> class="hljs-ln-code"> class="hljs-ln-line"> return loss
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="30"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="31"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="32"> class="hljs-ln-code"> class="hljs-ln-line">def generate_data(num_samples, input_dim, num_targets):
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="33"> class="hljs-ln-code"> class="hljs-ln-line"> x_train = torch.randn(num_samples, input_dim)
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="34"> class="hljs-ln-code"> class="hljs-ln-line"> y_train_list = [torch.randn(num_samples, 1) for _ in range(num_targets)]
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="35"> class="hljs-ln-code"> class="hljs-ln-line"> return x_train, y_train_list
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="36"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="37"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="38"> class="hljs-ln-code"> class="hljs-ln-line">def train_model(x_train, y_train_list, weights, num_epochs=10, batch_size=8):
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="39"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="40"> class="hljs-ln-code"> class="hljs-ln-line"> dataset = TensorDataset(*([x_train] + y_train_list))
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="41"> class="hljs-ln-code"> class="hljs-ln-line"> dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="42"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="43"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="44"> class="hljs-ln-code"> class="hljs-ln-line"> input_dim = x_train.shape[1]
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="45"> class="hljs-ln-code"> class="hljs-ln-line"> hidden_dim = 64
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="46"> class="hljs-ln-code"> class="hljs-ln-line"> output_dim = len(y_train_list)
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="47"> class="hljs-ln-code"> class="hljs-ln-line"> model = MultiTargetDNN(input_dim, hidden_dim, output_dim)
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="48"> class="hljs-ln-code"> class="hljs-ln-line"> criterion = weighted_loss(weights)
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="49"> class="hljs-ln-code"> class="hljs-ln-line"> optimizer = optim.Adam(model.parameters(), lr=0.001)
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="50"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="51"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="52"> class="hljs-ln-code"> class="hljs-ln-line"> for epoch in range(num_epochs):
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="53"> class="hljs-ln-code"> class="hljs-ln-line"> total_loss = 0
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="54"> class="hljs-ln-code"> class="hljs-ln-line"> for batch in dataloader:
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="55"> class="hljs-ln-code"> class="hljs-ln-line"> x_batch = batch[0]
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="56"> class="hljs-ln-code"> class="hljs-ln-line"> y_true_list = batch[1:]
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="57"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="58"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="59"> class="hljs-ln-code"> class="hljs-ln-line"> optimizer.zero_grad()
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="60"> class="hljs-ln-code"> class="hljs-ln-line"> y_pred_list = model(x_batch).split(1, dim=1)
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="61"> class="hljs-ln-code"> class="hljs-ln-line"> loss_value = criterion(y_true_list, y_pred_list)
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="62"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="63"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="64"> class="hljs-ln-code"> class="hljs-ln-line"> loss_value.backward()
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="65"> class="hljs-ln-code"> class="hljs-ln-line"> optimizer.step()
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="66"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="67"> class="hljs-ln-code"> class="hljs-ln-line"> total_loss += loss_value.item()
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="68"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="69"> class="hljs-ln-code"> class="hljs-ln-line"> print(f"Epoch {epoch + 1}: Loss = {total_loss / len(dataloader)}")
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="70"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="71"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="72"> class="hljs-ln-code"> class="hljs-ln-line">if __name__ == '__main__':
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="73"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="74"> class="hljs-ln-code"> class="hljs-ln-line"> num_samples = 64
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="75"> class="hljs-ln-code"> class="hljs-ln-line"> input_dim = 10
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="76"> class="hljs-ln-code"> class="hljs-ln-line"> num_targets = 3
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="77"> class="hljs-ln-code"> class="hljs-ln-line"> weights = [0.5, 0.3, 0.2]
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="78"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="79"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="80"> class="hljs-ln-code"> class="hljs-ln-line"> x_train, y_train_list = generate_data(num_samples, input_dim, num_targets)
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="81"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="82"> class="hljs-ln-code"> class="hljs-ln-line">
- class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="83"> class="hljs-ln-code"> class="hljs-ln-line"> train_model(x_train, y_train_list, weights)
class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}" onclick="hljs.signin(event)">
关键点解释:
-
多目标 DNN 模型:
MultiTargetDNN
类定义了一个简单的两层全连接网络,输入维度为 input_dim
,隐藏层维度为 hidden_dim
,输出维度为 output_dim
(即目标数量)。
-
加权损失函数:
weighted_loss
函数接受权重列表,并返回一个损失计算函数。该函数通过循环遍历每个目标的预测值和真实值,计算均方误差并根据权重进行加权求和。
-
生成虚拟数据:
generate_data
函数生成随机输入数据 x_train
和多个目标的真实值 y_train_list
。
-
训练模型:
train_model
函数将数据加载到 DataLoader
中,并初始化模型、损失函数和优化器。- 在每个 epoch 中,遍历数据批次,进行前向传播、计算损失、反向传播和参数更新。

三、总结
本文从技术原理和技术优缺点方面对推荐系统深度学习多目标融合的“样本Loss加权”进行简要讲解,以达到对预期指标的强化,具有模型简单,成本较低的优点,但同时优化周期长、多个指标跷跷板问题也是该方法的缺点,业界针对该方法的缺点进行了一系列的升级,专栏中会逐步讲解,期待您的关注。
如果您还有时间,欢迎阅读本专栏的其他文章:
【深度学习】多目标融合算法(一):样本Loss加权(Sample Loss Reweight)
【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model)
id="blogVoteBox" style="width:400px;margin:auto;margin-top:12px" class="blog-vote-box"> class="vote-box csdn-vote" style="opacity: 1;">
class="pos-box pt0" style="height: 222px; visibility: visible;">
评论记录:
回复评论: