数据处理完整代码

到这里,我们已完成了ml-1m数据读取和处理,接下来,我们将数据处理的代码封装到一个Python类中,完整实现如下:

  1. import random
  2. import numpy as np
  3. from PIL import Image
  4. class MovieLen(object):
  5. def __init__(self, use_poster):
  6. self.use_poster = use_poster
  7. # 声明每个数据文件的路径
  8. usr_info_path = "./work/ml-1m/users.dat"
  9. if use_poster:
  10. rating_path = "./work/ml-1m/new_rating.txt"
  11. else:
  12. rating_path = "./work/ml-1m/ratings.dat"
  13. movie_info_path = "./work/ml-1m/movies.dat"
  14. self.poster_path = "./work/ml-1m/posters/"
  15. # 得到电影数据
  16. self.movie_info, self.movie_cat, self.movie_title = self.get_movie_info(movie_info_path)
  17. # 记录电影的最大ID
  18. self.max_mov_cat = np.max([self.movie_cat[k] for k in self.movie_cat])
  19. self.max_mov_tit = np.max([self.movie_title[k] for k in self.movie_title])
  20. self.max_mov_id = np.max(list(map(int, self.movie_info.keys())))
  21. # 记录用户数据的最大ID
  22. self.max_usr_id = 0
  23. self.max_usr_age = 0
  24. self.max_usr_job = 0
  25. # 得到用户数据
  26. self.usr_info = self.get_usr_info(usr_info_path)
  27. # 得到评分数据
  28. self.rating_info = self.get_rating_info(rating_path)
  29. # 构建数据集
  30. self.dataset = self.get_dataset(usr_info=self.usr_info,
  31. rating_info=self.rating_info,
  32. movie_info=self.movie_info)
  33. # 划分数据集,获得数据加载器
  34. self.train_dataset = self.dataset[:int(len(self.dataset)*0.9)]
  35. self.valid_dataset = self.dataset[int(len(self.dataset)*0.9):]
  36. print("##Total dataset instances: ", len(self.dataset))
  37. print("##MovieLens dataset information: \nusr num: {}\n"
  38. "movies num: {}".format(len(self.usr_info),len(self.movie_info)))
  39. # 得到电影数据
  40. def get_movie_info(self, path):
  41. # 打开文件,编码方式选择ISO-8859-1,读取所有数据到data中
  42. with open(path, 'r', encoding="ISO-8859-1") as f:
  43. data = f.readlines()
  44. # 建立三个字典,分别用户存放电影所有信息,电影的名字信息、类别信息
  45. movie_info, movie_titles, movie_cat = {}, {}, {}
  46. # 对电影名字、类别中不同的单词计数
  47. t_count, c_count = 1, 1
  48. count_tit = {}
  49. # 按行读取数据并处理
  50. for item in data:
  51. item = item.strip().split("::")
  52. v_id = item[0]
  53. v_title = item[1][:-7]
  54. cats = item[2].split('|')
  55. v_year = item[1][-5:-1]
  56. titles = v_title.split()
  57. # 统计电影名字的单词,并给每个单词一个序号,放在movie_titles中
  58. for t in titles:
  59. if t not in movie_titles:
  60. movie_titles[t] = t_count
  61. t_count += 1
  62. # 统计电影类别单词,并给每个单词一个序号,放在movie_cat中
  63. for cat in cats:
  64. if cat not in movie_cat:
  65. movie_cat[cat] = c_count
  66. c_count += 1
  67. # 补0使电影名称对应的列表长度为15
  68. v_tit = [movie_titles[k] for k in titles]
  69. while len(v_tit)<15:
  70. v_tit.append(0)
  71. # 补0使电影种类对应的列表长度为6
  72. v_cat = [movie_cat[k] for k in cats]
  73. while len(v_cat)<6:
  74. v_cat.append(0)
  75. # 保存电影数据到movie_info中
  76. movie_info[v_id] = {'mov_id': int(v_id),
  77. 'title': v_tit,
  78. 'category': v_cat,
  79. 'years': int(v_year)}
  80. return movie_info, movie_cat, movie_titles
  81. def get_usr_info(self, path):
  82. # 性别转换函数,M-0, F-1
  83. def gender2num(gender):
  84. return 1 if gender == 'F' else 0
  85. # 打开文件,读取所有行到data中
  86. with open(path, 'r') as f:
  87. data = f.readlines()
  88. # 建立用户信息的字典
  89. use_info = {}
  90. max_usr_id = 0
  91. #按行索引数据
  92. for item in data:
  93. # 去除每一行中和数据无关的部分
  94. item = item.strip().split("::")
  95. usr_id = item[0]
  96. # 将字符数据转成数字并保存在字典中
  97. use_info[usr_id] = {'usr_id': int(usr_id),
  98. 'gender': gender2num(item[1]),
  99. 'age': int(item[2]),
  100. 'job': int(item[3])}
  101. self.max_usr_id = max(self.max_usr_id, int(usr_id))
  102. self.max_usr_age = max(self.max_usr_age, int(item[2]))
  103. self.max_usr_job = max(self.max_usr_job, int(item[3]))
  104. return use_info
  105. # 得到评分数据
  106. def get_rating_info(self, path):
  107. # 读取文件里的数据
  108. with open(path, 'r') as f:
  109. data = f.readlines()
  110. # 将数据保存在字典中并返回
  111. rating_info = {}
  112. for item in data:
  113. item = item.strip().split("::")
  114. usr_id,movie_id,score = item[0],item[1],item[2]
  115. if usr_id not in rating_info.keys():
  116. rating_info[usr_id] = {movie_id:float(score)}
  117. else:
  118. rating_info[usr_id][movie_id] = float(score)
  119. return rating_info
  120. # 构建数据集
  121. def get_dataset(self, usr_info, rating_info, movie_info):
  122. trainset = []
  123. for usr_id in rating_info.keys():
  124. usr_ratings = rating_info[usr_id]
  125. for movie_id in usr_ratings:
  126. trainset.append({'usr_info': usr_info[usr_id],
  127. 'mov_info': movie_info[movie_id],
  128. 'scores': usr_ratings[movie_id]})
  129. return trainset
  130. def load_data(self, dataset=None, mode='train'):
  131. use_poster = False
  132. # 定义数据迭代Batch大小
  133. BATCHSIZE = 256
  134. data_length = len(dataset)
  135. index_list = list(range(data_length))
  136. # 定义数据迭代加载器
  137. def data_generator():
  138. # 训练模式下,打乱训练数据
  139. if mode == 'train':
  140. random.shuffle(index_list)
  141. # 声明每个特征的列表
  142. usr_id_list,usr_gender_list,usr_age_list,usr_job_list = [], [], [], []
  143. mov_id_list,mov_tit_list,mov_cat_list,mov_poster_list = [], [], [], []
  144. score_list = []
  145. # 索引遍历输入数据集
  146. for idx, i in enumerate(index_list):
  147. # 获得特征数据保存到对应特征列表中
  148. usr_id_list.append(dataset[i]['usr_info']['usr_id'])
  149. usr_gender_list.append(dataset[i]['usr_info']['gender'])
  150. usr_age_list.append(dataset[i]['usr_info']['age'])
  151. usr_job_list.append(dataset[i]['usr_info']['job'])
  152. mov_id_list.append(dataset[i]['mov_info']['mov_id'])
  153. mov_tit_list.append(dataset[i]['mov_info']['title'])
  154. mov_cat_list.append(dataset[i]['mov_info']['category'])
  155. mov_id = dataset[i]['mov_info']['mov_id']
  156. if use_poster:
  157. # 不使用图像特征时,不读取图像数据,加快数据读取速度
  158. poster = Image.open(self.poster_path+'mov_id{}.jpg'.format(str(mov_id[0])))
  159. poster = poster.resize([64, 64])
  160. if len(poster.size) <= 2:
  161. poster = poster.convert("RGB")
  162. mov_poster_list.append(np.array(poster))
  163. score_list.append(int(dataset[i]['scores']))
  164. # 如果读取的数据量达到当前的batch大小,就返回当前批次
  165. if len(usr_id_list)==BATCHSIZE:
  166. # 转换列表数据为数组形式,reshape到固定形状
  167. usr_id_arr = np.array(usr_id_list)
  168. usr_gender_arr = np.array(usr_gender_list)
  169. usr_age_arr = np.array(usr_age_list)
  170. usr_job_arr = np.array(usr_job_list)
  171. mov_id_arr = np.array(mov_id_list)
  172. mov_cat_arr = np.reshape(np.array(mov_cat_list), [BATCHSIZE, 6]).astype(np.int64)
  173. mov_tit_arr = np.reshape(np.array(mov_tit_list), [BATCHSIZE, 1, 15]).astype(np.int64)
  174. if use_poster:
  175. mov_poster_arr = np.reshape(np.array(mov_poster_list)/127.5 - 1, [BATCHSIZE, 3, 64, 64]).astype(np.float32)
  176. else:
  177. mov_poster_arr = np.array([0.])
  178. scores_arr = np.reshape(np.array(score_list), [-1, 1]).astype(np.float32)
  179. # 放回当前批次数据
  180. yield [usr_id_arr, usr_gender_arr, usr_age_arr, usr_job_arr], \
  181. [mov_id_arr, mov_cat_arr, mov_tit_arr, mov_poster_arr], scores_arr
  182. # 清空数据
  183. usr_id_list, usr_gender_list, usr_age_list, usr_job_list = [], [], [], []
  184. mov_id_list, mov_tit_list, mov_cat_list, score_list = [], [], [], []
  185. mov_poster_list = []
  186. return data_generator
  187. # 声明数据读取类
  188. dataset = MovieLen(False)
  189. # 定义数据读取器
  190. train_loader = dataset.load_data(dataset=dataset.train_dataset, mode='train')
  191. # 迭代的读取数据, Batchsize = 256
  192. for idx, data in enumerate(train_loader()):
  193. usr, mov, score = data
  194. print("打印用户ID,性别,年龄,职业数据的维度:")
  195. for v in usr:
  196. print(v.shape)
  197. print("打印电影ID,名字,类别数据的维度:")
  198. for v in mov:
  199. print(v.shape)
  200. break
  1. ##Total dataset instances: 1000209
  2. ##MovieLens dataset information:
  3. usr num: 6040
  4. movies num: 3883
  5. 打印用户ID,性别,年龄,职业数据的维度:
  6. (256,)
  7. (256,)
  8. (256,)
  9. (256,)
  10. 打印电影ID,名字,类别数据的维度:
  11. (256,)
  12. (256, 6)
  13. (256, 1, 15)
  14. (1,)