专注于 JetBrains IDEA 全家桶,永久激活,教程
持续更新 PyCharm,IDEA,WebStorm,PhpStorm,DataGrip,RubyMine,CLion,AppCode 永久激活教程

RNN模型使用的names2countries数据集处理

数据格式

数据使用csv保存,并压缩为gz格式的文件,X为第一列,表示名字,Y为第二列表示对应的国家

数据处理的目的是根据名字预测对应的国家

64_1.png64_2.png

名字的处理

首先对名字进行分割,然后使用ASCII码进行编码。

64_3.png

接下来进行padding操作,由于名字的长度不尽相同,因此需要在尾部添加0进行补充,使得可以送入RNN时seqLen是相同的,随后根据名字长度进行又长到短进行排序操作。

64_4.png

国家名的处理

根据索引编号直接制作成字典格式即可。

64_5.png


import torch from torch.utils.data import DataLoader, Dataset import numpy as np import gzip import csv import time import matplotlib.pyplot as plt import math HIDDEN_SIZE = 100 BATCH_SIZE = 256 N_LAYER = 2 N_EPOCHS = 100 N_CHARS = 128 USE_GPU = True # 数据存储到GPU上 def create_tensor(tensor): if USE_GPU: device = torch.device("cuda:0") tensor = tensor.to(device) return tensor """ X: name type: str y: country number """ class NameDataset(Dataset): def __init__(self, is_train_set=True): filename = 'data/names_train.csv.gz' if is_train_set else 'data/names_test.csv.gz' with gzip.open(filename, 'rt') as f: reader = csv.reader(f) rows = list(reader) self.names = [row[0] for row in rows] self.len = len(self.names) self.countries = [row[1] for row in rows] self.country_list = list(sorted(set(self.countries))) self.country_dict = self.getCountryDict() self.country_num = len(self.country_list) def __getitem__(self, item): # 返回名字为字符串(str), 国家为索引值(int) return self.names[item], self.country_dict[self.countries[item]] def __len__(self): return self.len def getCountryDict(self): country_dict = dict() # 遍历国家列表,从0开始制作字典 for idx, country_name in enumerate(self.country_list, 0): country_dict[country_name] = idx return country_dict def idx2country(self, index): return self.country_list[index] def getCountriesNum(self): return self.country_num trainset = NameDataset(is_train_set=True) trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True) testset = NameDataset(is_train_set=False) testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False) N_COUNTRY = trainset.getCountriesNum() def name2list(name): # 将名字转换成ASSIC的列表 arr = [ord(c) for c in name] return arr, len(arr) # names为字符串构成的列表 # countries为数字构成的列表 # 返回的namesg格式为batch*inputSize def make_tensors(names, countries): # 将名字split然后使用ASCII码编号 sequences_and_lengths = [name2list(name) for name in names] name_sequences = [sl[0] for sl in sequences_and_lengths] seq_lengths = torch.LongTensor([s1[1] for s1 in sequences_and_lengths]) countries = countries.long() # make padding, append 0 to name_sequences seq_tensor = torch.zeros(len(name_sequences), seq_lengths.max()).long() for idx, (seq, seq_len) in enumerate(zip(name_sequences, seq_lengths), 0): seq_tensor[idx, :seq_len] = torch.LongTensor(seq) # sort by length to use pack_padded_sequence seq_lengths, perm_idx = seq_lengths.sort(dim=0, descending=True) seq_tensor = seq_tensor[perm_idx] countries = countries[perm_idx] return create_tensor(seq_tensor), create_tensor(seq_lengths), create_tensor(countries)

文章永久链接:https://tech.souyunku.com/23483

未经允许不得转载:搜云库技术团队 » RNN模型使用的names2countries数据集处理

JetBrains 全家桶,激活、破解、教程

提供 JetBrains 全家桶激活码、注册码、破解补丁下载及详细激活教程,支持 IntelliJ IDEA、PyCharm、WebStorm 等工具的永久激活。无论是破解教程,还是最新激活码,均可免费获得,帮助开发者解决常见激活问题,确保轻松破解并快速使用 JetBrains 软件。获取免费的破解补丁和激活码,快速解决激活难题,全面覆盖 2024/2025 版本!

联系我们联系我们