构建数据读取器

至此我们已经分别处理了用户、电影和评分数据,接下来我们要利用这些处理好的数据,构建一个数据读取器,方便在训练神经网络时直接调用。

首先,构造一个函数,把读取并处理后的数据整合到一起,即在rating数据中补齐用户和电影的所有特征字段。

  1. def get_dataset(usr_info, rating_info, movie_info):
  2. trainset = []
  3. # 按照评分数据的key值索引数据
  4. for usr_id in rating_info.keys():
  5. usr_ratings = rating_info[usr_id]
  6. for movie_id in usr_ratings:
  7. trainset.append({'usr_info': usr_info[usr_id],
  8. 'mov_info': movie_info[movie_id],
  9. 'scores': usr_ratings[movie_id]})
  10. return trainset
  11. dataset = get_dataset(usr_info, rating_info, movie_info)
  12. print("数据集总数据数:", len(dataset))
  1. 数据集总数据数: 1000209

接下来构建数据读取器函数load_data(),先看一下整体结构:

  1. import random
  2. def load_data(dataset=None, mode='train'):
  3. """定义一些超参数等等"""
  4. # 定义数据迭代加载器
  5. def data_generator():
  6. """ 定义数据的处理过程"""
  7. data = None
  8. yield data
  9. # 返回数据迭代加载器
  10. return data_generator

我们来看一下完整的数据读取器函数实现,核心是将多个样本数据合并到一个列表(batch),当该列表达到batchsize后,以yield的方式返回(Python数据迭代器)。

在进行批次数据拼合的同时,完成数据格式和数据尺寸的转换:

  • 由于飞桨框架的网络接入层要求将数据先转换成np.array的类型,再转换成框架内置变量variable的类型。所以在数据返回前,需将所有数据均转换成np.array的类型,方便后续处理。
  • 每个特征字段的尺寸也需要根据网络输入层的设计进行调整。根据之前的分析,用户和电影的所有原始特征可以分为四类,ID类(用户ID,电影ID,性别,年龄,职业)、列表类(电影类别)、文本类(电影名称)和图像类(电影海报)。因为每种特征后续接入的网络层方案不同,所以要求他们的数据尺寸也不同。这里我们先初步的了解即可,待后续阅读了模型设计章节后,将对输入输出尺寸有更好的理解。

数据尺寸的说明:

  • ID类(用户ID,电影ID,性别,年龄,职业)处理成(256,1)的尺寸,以便后续接入Embedding层。第一个维度256是batchsize,第二个维度是1,因为Embedding层要求输入数据的最后一维为1。
  • 列表类(电影类别)处理成(256,6,1)的尺寸,6是电影最多的类比个数,以便后续接入全连接层。
  • 文本类(电影名称)处理成(256,1,15,1)的尺寸,15是电影名称的最大单词数,以便接入2D卷积层。2D卷积层要求输入数据为四维,对应图像数据是【批次大小,通道数、图像的长、图像的宽】,其中RGB的彩色图像是3通道,灰度图像是单通道。
  • 图像类(电影海报)处理成(256,3,64,64)的尺寸, 以便接入2D卷积层。图像的原始尺寸是180270彩色图像,使用resize函数压缩成6464的尺寸,减少网络计算。
  1. import random
  2. use_poster = False
  3. def load_data(dataset=None, mode='train'):
  4. # 定义数据迭代Batch大小
  5. BATCHSIZE = 256
  6. data_length = len(dataset)
  7. index_list = list(range(data_length))
  8. # 定义数据迭代加载器
  9. def data_generator():
  10. # 训练模式下,打乱训练数据
  11. if mode == 'train':
  12. random.shuffle(index_list)
  13. # 声明每个特征的列表
  14. usr_id_list,usr_gender_list,usr_age_list,usr_job_list = [], [], [], []
  15. mov_id_list,mov_tit_list,mov_cat_list,mov_poster_list = [], [], [], []
  16. score_list = []
  17. # 索引遍历输入数据集
  18. for idx, i in enumerate(index_list):
  19. # 获得特征数据保存到对应特征列表中
  20. usr_id_list.append(dataset[i]['usr_info']['usr_id'])
  21. usr_gender_list.append(dataset[i]['usr_info']['gender'])
  22. usr_age_list.append(dataset[i]['usr_info']['age'])
  23. usr_job_list.append(dataset[i]['usr_info']['job'])
  24. mov_id_list.append(dataset[i]['mov_info']['mov_id'])
  25. mov_tit_list.append(dataset[i]['mov_info']['title'])
  26. mov_cat_list.append(dataset[i]['mov_info']['category'])
  27. mov_id = dataset[i]['mov_info']['mov_id']
  28. if use_poster:
  29. # 不使用图像特征时,不读取图像数据,加快数据读取速度
  30. poster = Image.open(poster_path+'mov_id{}.jpg'.format(str(mov_id)))
  31. poster = poster.resize([64, 64])
  32. if len(poster.size) <= 2:
  33. poster = poster.convert("RGB")
  34. mov_poster_list.append(np.array(poster))
  35. score_list.append(int(dataset[i]['scores']))
  36. # 如果读取的数据量达到当前的batch大小,就返回当前批次
  37. if len(usr_id_list)==BATCHSIZE:
  38. # 转换列表数据为数组形式,reshape到固定形状
  39. usr_id_arr = np.array(usr_id_list)
  40. usr_gender_arr = np.array(usr_gender_list)
  41. usr_age_arr = np.array(usr_age_list)
  42. usr_job_arr = np.array(usr_job_list)
  43. mov_id_arr = np.array(mov_id_list)
  44. mov_cat_arr = np.reshape(np.array(mov_cat_list), [BATCHSIZE, 6]).astype(np.int64)
  45. mov_tit_arr = np.reshape(np.array(mov_tit_list), [BATCHSIZE, 1, 15]).astype(np.int64)
  46. if use_poster:
  47. mov_poster_arr = np.reshape(np.array(mov_poster_list)/127.5 - 1, [BATCHSIZE, 3, 64, 64]).astype(np.float32)
  48. else:
  49. mov_poster_arr = np.array([0.])
  50. scores_arr = np.reshape(np.array(score_list), [-1, 1]).astype(np.float32)
  51. # 返回当前批次数据
  52. yield [usr_id_arr, usr_gender_arr, usr_age_arr, usr_job_arr], \
  53. [mov_id_arr, mov_cat_arr, mov_tit_arr, mov_poster_arr], scores_arr
  54. # 清空数据
  55. usr_id_list, usr_gender_list, usr_age_list, usr_job_list = [], [], [], []
  56. mov_id_list, mov_tit_list, mov_cat_list, score_list = [], [], [], []
  57. mov_poster_list = []
  58. return data_generator

load_data()函数通过输入的数据集,处理数据并返回一个数据迭代器。

我们将数据集按照8:2的比例划分训练集和验证集,可以分别得到训练数据迭代器和验证数据迭代器。

  1. dataset = get_dataset(usr_info, rating_info, movie_info)
  2. print("数据集总数量:", len(dataset))
  3. trainset = dataset[:int(0.8*len(dataset))]
  4. train_loader = load_data(trainset, mode="train")
  5. print("训练集数量:", len(trainset))
  6. validset = dataset[int(0.8*len(dataset)):]
  7. valid_loader = load_data(validset, mode='valid')
  8. print("验证集数量:", len(validset))
  1. 数据集总数量: 1000209
  2. 训练集数量: 800167
  3. 验证集数量: 200042

数据迭代器的使用方式如下:

  1. for idx, data in enumerate(train_loader()):
  2. usr_data, mov_data, score = data
  3. usr_id_arr, usr_gender_arr, usr_age_arr, usr_job_arr = usr_data
  4. mov_id_arr, mov_cat_arr, mov_tit_arr, mov_poster_arr = mov_data
  5. print("用户ID数据尺寸", usr_id_arr.shape)
  6. print("电影ID数据尺寸", mov_id_arr.shape, ", 电影类别genres数据的尺寸", mov_cat_arr.shape, ", 电影名字title的尺寸", mov_tit_arr.shape)
  7. break
  1. 用户ID数据尺寸 (256, 1)
  2. 电影ID数据尺寸 (256, 1) , 电影类别genres数据的尺寸 (256, 6, 1) , 电影名字title的尺寸 (256, 1, 15, 1)