(CCF BDCI) Taxi Invoice Recognition


Reference articles:

百度AI攻略:出租车票识别_才能我浪费的博客-CSDN博客

(附完整python源码)基于tensorflow、opencv的入门案例_发票识别一:关键区域定位_小白来搬家-CSDN博客_python发票识别

Note: many thanks to the teammates who worked on this project together.

Competition site: 出租车发票识别 Competitions - DataFountain

I. Problem Description

1. Background

Taxi invoices are common in day-to-day expense reimbursement. Because these invoices come in many layouts with strong regional variation, and often contain blurred or misaligned print, accurately locating the text fields, recognizing the characters, and producing structured output are all essential.

2. Task

The task is to train text-detection and text-recognition models for taxi invoices using image processing, machine learning, and deep learning methods, and to output the recognition results in a structured form.

3. Data overview

The data consists of taxi invoices collected from real-world reimbursement scenarios.

4. Data description

More than 10 sample invoices from different regions are provided so participants can familiarize themselves with taxi-invoice layouts; the validation set contains no fewer than 500 images.

The data has two parts: full invoice images, used to test the text-detection algorithm and the final structured output; and field crops cut from the invoice images, used to test the text-recognition algorithm (recognition is run on the cropped images).

5. Evaluation metrics

Text detection is scored by precision, recall, and F-measure; text recognition is scored by field-level and character-level recognition accuracy; structured output is scored by field-level and character-level accuracy of the output fields.
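As a rough illustration of the two recognition metrics, field-level and character-level accuracy can be computed as below. This is a simplified sketch with made-up field values; the official scorer's exact rules (e.g. whether edit distance is used) are not specified here.

```python
# Hypothetical illustration of field-level vs character-level accuracy.
def field_accuracy(preds, gts):
    """Fraction of fields predicted exactly right (whole field must match)."""
    correct = sum(1 for p, g in zip(preds, gts) if p == g)
    return correct / len(gts)

def char_accuracy(preds, gts):
    """Fraction of ground-truth characters matched position-by-position
    (a simplification; a real scorer would likely use edit distance)."""
    matched = sum(
        sum(1 for a, b in zip(p, g) if a == b)
        for p, g in zip(preds, gts)
    )
    total = sum(len(g) for g in gts)
    return matched / total

# Toy predictions vs ground truth for three invoice fields (hypothetical)
preds = ["京B12345", "2021-05-31", "23.5元"]
gts   = ["京B12345", "2021-05-31", "28.5元"]
print(field_accuracy(preds, gts))  # 2 of 3 fields exactly right
print(char_accuracy(preds, gts))
```

The gap between the two numbers shows why both are reported: one wrong character fails the whole field at the field level but costs little at the character level.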

II. Approach

Deep learning framework: PyTorch. Text-detection algorithm: CTPN. Text-recognition algorithm: CRNN + CTC.

The pipeline is as follows: a CNN first extracts convolutional features from the image, an LSTM then extracts sequence features from those convolutional features, and finally CTC is introduced to handle the fact that characters cannot be aligned frame-by-frame during training.
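CTC sidesteps the alignment problem by letting the network emit a blank symbol and repeated labels at every frame, which are collapsed at decode time. A minimal best-path (greedy) decoder, using a toy label map that is purely illustrative, might look like this:

```python
BLANK = 0  # CTC reserves one class index for the blank symbol

def ctc_greedy_decode(frame_labels, id_to_char):
    """Best-path CTC decoding: merge consecutive duplicates, then drop blanks."""
    out = []
    prev = None
    for lab in frame_labels:
        if lab != prev and lab != BLANK:
            out.append(id_to_char[lab])
        prev = lab
    return ''.join(out)

id_to_char = {1: 'a', 2: 'b'}  # toy label map (hypothetical)
# frames: a a <blank> a b b <blank>  ->  "aab"
print(ctc_greedy_decode([1, 1, 0, 1, 2, 2, 0], id_to_char))
```

The blank between the two `a` frames is what lets CTC represent a genuine double letter, which is exactly the ambiguity frame-wise alignment cannot express.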

III. Experiments

1. Setting up the PyTorch environment

1) Look up the matching versions of torch, torchvision, CUDA, and Python:

NVIDIA英伟达GPU显卡算力一览(包含Tesla和GeForce、TITAN及RTX系列等)_竹风寒的博客-CSDN博客_显卡算力

2) Preparation:

- Download and install Anaconda and PyCharm (see: Anaconda详细安装及环境变量配置(图文)_ZoomToday-CSDN博客_anaconda安装教程环境变量);
- Download and install CUDA and cuDNN;
- Download torch and torchvision;
- Switch the conda and pip package sources to the Tsinghua mirrors (see: conda和pip更新为国内源_lazy_boy的博客-CSDN博客);
- Create a conda virtual environment with Python 3.8 (see: conda建立python36虚拟环境_lazy_boy的博客-CSDN博客).

3) Install and test torch:

Open the new virtual environment from the Anaconda prompt, install the wheel with pip, and verify that it imports.

4) Install PyTorch (GPU version) in the virtual environment:

基础pytorch安装gpu版本--保姆级教程_lazy_boy的博客-CSDN博客_pytorch安装gpu

5) Create a test project in PyCharm to confirm the CUDA and cuDNN versions.
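Step 2) above switches conda and pip to the Tsinghua mirrors and creates the Python 3.8 environment; the commands typically used for this are (environment name `invoice` is made up for illustration):

```shell
# Point pip at the Tsinghua PyPI mirror
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple

# Add the Tsinghua Anaconda mirrors as conda channels
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
conda config --set show_channel_urls yes

# Create and activate the Python 3.8 virtual environment
conda create -n invoice python=3.8
conda activate invoice
```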

 

2. Text detection

Core idea: VGG extracts features, a bidirectional LSTM (BLSTM) fuses context information, and detection is completed by an RPN.
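The key CTPN twist on the RPN is that every anchor has a fixed width of 16 pixels (one feature-map stride) and only the height varies, so the network regresses only vertical position and height. A sketch of per-location anchor generation; the height list follows values common in CTPN implementations and should be treated as an assumption:

```python
# Sketch of CTPN-style anchor generation (pure Python, hypothetical).
# All anchors share the fixed 16 px width; only the height varies.
def gen_anchors_at(cx, cy, heights=(11, 16, 23, 33, 48, 68, 97, 139, 198, 283)):
    """Return (x1, y1, x2, y2) anchor boxes of width 16 centred at (cx, cy)."""
    boxes = []
    for h in heights:
        boxes.append((cx - 8, cy - h / 2, cx + 8, cy + h / 2))
    return boxes

anchors = gen_anchors_at(8, 8)
print(len(anchors))   # 10 anchors per feature-map location
print(anchors[0])
```

Ten such anchors at every feature-map cell is why the classification and regression heads in the model below output `10 * 2` channels each.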

The steps are as follows:

1) Write dataset.py to handle data preprocessing.

import os
import xml.etree.ElementTree as ET
import numpy as np
import cv2
from torch.utils.data import Dataset
import torch
from config import IMAGE_MEAN
from ctpn_utils import cal_rpn


def readxml(path):
    gtboxes = []
    imgfile = ''
    xml = ET.parse(path)
    for elem in xml.iter():
        if 'filename' in elem.tag:
            imgfile = elem.text
        if 'object' in elem.tag:
            for attr in list(elem):
                if 'bndbox' in attr.tag:
                    xmin = int(round(float(attr.find('xmin').text)))
                    ymin = int(round(float(attr.find('ymin').text)))
                    xmax = int(round(float(attr.find('xmax').text)))
                    ymax = int(round(float(attr.find('ymax').text)))
                    gtboxes.append((xmin, ymin, xmax, ymax))
    return np.array(gtboxes), imgfile


# for ctpn text detection
class VOCDataset(Dataset):
    def __init__(self, datadir, labelsdir):
        '''
        :param datadir: image's directory
        :param labelsdir: annotations' directory
        '''
        if not os.path.isdir(datadir):
            raise Exception('[ERROR] {} is not a directory'.format(datadir))
        if not os.path.isdir(labelsdir):
            raise Exception('[ERROR] {} is not a directory'.format(labelsdir))
        self.datadir = datadir
        self.img_names = os.listdir(self.datadir)
        self.labelsdir = labelsdir

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_name = self.img_names[idx]
        img_path = os.path.join(self.datadir, img_name)
        print(img_path)
        xml_path = os.path.join(self.labelsdir, img_name.replace('.jpg', '.xml'))
        gtbox, _ = readxml(xml_path)
        img = cv2.imread(img_path)
        h, w, c = img.shape

        # random horizontal flip of image and boxes
        if np.random.randint(2) == 1:
            img = img[:, ::-1, :]
            newx1 = w - gtbox[:, 2] - 1
            newx2 = w - gtbox[:, 0] - 1
            gtbox[:, 0] = newx1
            gtbox[:, 2] = newx2

        [cls, regr], _ = cal_rpn((h, w), (int(h / 16), int(w / 16)), 16, gtbox)

        m_img = img - IMAGE_MEAN
        regr = np.hstack([cls.reshape(cls.shape[0], 1), regr])
        cls = np.expand_dims(cls, axis=0)

        # transform to torch tensor
        m_img = torch.from_numpy(m_img.transpose([2, 0, 1])).float()
        cls = torch.from_numpy(cls).float()
        regr = torch.from_numpy(regr).float()

        return m_img, cls, regr


class ICDARDataset(Dataset):
    def __init__(self, datadir, labelsdir):
        '''
        :param datadir: image's directory
        :param labelsdir: annotations' directory
        '''
        if not os.path.isdir(datadir):
            raise Exception('[ERROR] {} is not a directory'.format(datadir))
        if not os.path.isdir(labelsdir):
            raise Exception('[ERROR] {} is not a directory'.format(labelsdir))
        self.datadir = datadir
        self.img_names = os.listdir(self.datadir)
        self.labelsdir = labelsdir

    def __len__(self):
        return len(self.img_names)

    def box_transfer(self, coor_lists, rescale_fac=1.0):
        gtboxes = []
        for coor_list in coor_lists:
            coors_x = [int(coor_list[2 * i]) for i in range(4)]
            coors_y = [int(coor_list[2 * i + 1]) for i in range(4)]
            xmin = min(coors_x)
            xmax = max(coors_x)
            ymin = min(coors_y)
            ymax = max(coors_y)
            if rescale_fac > 1.0:
                xmin = int(xmin / rescale_fac)
                xmax = int(xmax / rescale_fac)
                ymin = int(ymin / rescale_fac)
                ymax = int(ymax / rescale_fac)
            gtboxes.append((xmin, ymin, xmax, ymax))
        return np.array(gtboxes)

    def box_transfer_v2(self, coor_lists, rescale_fac=1.0):
        # split each ground-truth box into fixed-width (16 px) proposals
        gtboxes = []
        for coor_list in coor_lists:
            coors_x = [int(coor_list[2 * i]) for i in range(4)]
            coors_y = [int(coor_list[2 * i + 1]) for i in range(4)]
            xmin = min(coors_x)
            xmax = max(coors_x)
            ymin = min(coors_y)
            ymax = max(coors_y)
            if rescale_fac > 1.0:
                xmin = int(xmin / rescale_fac)
                xmax = int(xmax / rescale_fac)
                ymin = int(ymin / rescale_fac)
                ymax = int(ymax / rescale_fac)
            prev = xmin
            for i in range(xmin // 16 + 1, xmax // 16 + 1):
                next = 16 * i - 0.5
                gtboxes.append((prev, ymin, next, ymax))
                prev = next
            gtboxes.append((prev, ymin, xmax, ymax))
        return np.array(gtboxes)

    def parse_gtfile(self, gt_path, rescale_fac=1.0):
        coor_lists = list()
        with open(gt_path) as f:
            content = f.readlines()
            for line in content:
                coor_list = line.split(',')[:8]
                if len(coor_list) == 8:
                    coor_lists.append(coor_list)
        return self.box_transfer_v2(coor_lists, rescale_fac)

    def draw_boxes(self, img, cls, base_anchors, gt_box):
        for i in range(len(cls)):
            if cls[i] == 1:
                pt1 = (int(base_anchors[i][0]), int(base_anchors[i][1]))
                pt2 = (int(base_anchors[i][2]), int(base_anchors[i][3]))
                img = cv2.rectangle(img, pt1, pt2, (200, 100, 100))
        for i in range(gt_box.shape[0]):
            pt1 = (int(gt_box[i][0]), int(gt_box[i][1]))
            pt2 = (int(gt_box[i][2]), int(gt_box[i][3]))
            img = cv2.rectangle(img, pt1, pt2, (100, 200, 100))
        return img

    def __getitem__(self, idx):
        img_name = self.img_names[idx]
        img_path = os.path.join(self.datadir, img_name)
        # print(img_path)
        img = cv2.imread(img_path)

        ##### for read error, use default image #####
        if img is None:
            print(img_path)
            with open('error_imgs.txt', 'a') as f:
                f.write('{}\n'.format(img_path))
            img_name = 'img_2647.jpg'
            img_path = os.path.join(self.datadir, img_name)
            img = cv2.imread(img_path)
        ##### for read error, use default image #####

        h, w, c = img.shape
        rescale_fac = max(h, w) / 1600
        if rescale_fac > 1.0:
            h = int(h / rescale_fac)
            w = int(w / rescale_fac)
            img = cv2.resize(img, (w, h))

        gt_path = os.path.join(self.labelsdir, 'gt_' + img_name.split('.')[0] + '.txt')
        gtbox = self.parse_gtfile(gt_path, rescale_fac)

        # random horizontal flip of image and boxes
        if np.random.randint(2) == 1:
            img = img[:, ::-1, :]
            newx1 = w - gtbox[:, 2] - 1
            newx2 = w - gtbox[:, 0] - 1
            gtbox[:, 0] = newx1
            gtbox[:, 2] = newx2

        [cls, regr], base_anchors = cal_rpn((h, w), (int(h / 16), int(w / 16)), 16, gtbox)
        # debug_img = self.draw_boxes(img.copy(), cls, base_anchors, gtbox)
        # cv2.imwrite('debug/{}'.format(img_name), debug_img)

        m_img = img - IMAGE_MEAN
        regr = np.hstack([cls.reshape(cls.shape[0], 1), regr])
        cls = np.expand_dims(cls, axis=0)

        # transform to torch tensor
        m_img = torch.from_numpy(m_img.transpose([2, 0, 1])).float()
        cls = torch.from_numpy(cls).float()
        regr = torch.from_numpy(regr).float()

        return m_img, cls, regr


if __name__ == '__main__':
    xmin = 15
    xmax = 95
    for i in range(xmin // 16 + 1, xmax // 16 + 1):
        print(16 * i - 0.5)

2) Define the CTPN network architecture and its losses, and initialize them.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

import config


class RPN_REGR_Loss(nn.Module):
    def __init__(self, device, sigma=9.0):
        super(RPN_REGR_Loss, self).__init__()
        self.sigma = sigma
        self.device = device

    def forward(self, input, target):
        '''
        smooth L1 loss
        :param input: y_preds
        :param target: y_true
        :return:
        '''
        try:
            cls = target[0, :, 0]
            regr = target[0, :, 1:3]
            # apply regression to positive samples only
            regr_keep = (cls == 1).nonzero()[:, 0]
            regr_true = regr[regr_keep]
            regr_pred = input[0][regr_keep]
            diff = torch.abs(regr_true - regr_pred)
            less_one = (diff < 1.0 / self.sigma).float()
            loss = less_one * 0.5 * diff ** 2 * self.sigma + torch.abs(1 - less_one) * (diff - 0.5 / self.sigma)
            loss = torch.sum(loss, 1)
            loss = torch.mean(loss) if loss.numel() > 0 else torch.tensor(0.0)
        except Exception as e:
            print('RPN_REGR_Loss Exception:', e)
            loss = torch.tensor(0.0)
        return loss.to(self.device)


class RPN_CLS_Loss(nn.Module):
    def __init__(self, device):
        super(RPN_CLS_Loss, self).__init__()
        self.device = device
        self.L_cls = nn.CrossEntropyLoss(reduction='none')
        self.pos_neg_ratio = 3

    def forward(self, input, target):
        if config.OHEM:
            cls_gt = target[0][0]
            num_pos = 0
            loss_pos_sum = 0
            if len((cls_gt == 1).nonzero()) != 0:  # avoid the case of zero positive samples
                cls_pos = (cls_gt == 1).nonzero()[:, 0]
                gt_pos = cls_gt[cls_pos].long()
                cls_pred_pos = input[0][cls_pos]
                loss_pos = self.L_cls(cls_pred_pos.view(-1, 2), gt_pos.view(-1))
                loss_pos_sum = loss_pos.sum()
                num_pos = len(loss_pos)
            cls_neg = (cls_gt == 0).nonzero()[:, 0]
            gt_neg = cls_gt[cls_neg].long()
            cls_pred_neg = input[0][cls_neg]
            loss_neg = self.L_cls(cls_pred_neg.view(-1, 2), gt_neg.view(-1))
            # online hard example mining: keep only the hardest negatives
            loss_neg_topK, _ = torch.topk(loss_neg, min(len(loss_neg), config.RPN_TOTAL_NUM - num_pos))
            loss_cls = loss_pos_sum + loss_neg_topK.sum()
            loss_cls = loss_cls / config.RPN_TOTAL_NUM
            return loss_cls.to(self.device)
        else:
            y_true = target[0][0]
            cls_keep = (y_true != -1).nonzero()[:, 0]
            cls_true = y_true[cls_keep].long()
            cls_pred = input[0][cls_keep]
            loss = F.nll_loss(F.log_softmax(cls_pred, dim=-1), cls_true)  # original is sparse_softmax_cross_entropy_with_logits
            loss = torch.clamp(torch.mean(loss), 0, 10) if loss.numel() > 0 else torch.tensor(0.0)
            return loss.to(self.device)


class basic_conv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, relu=True, bn=True, bias=True):
        super(basic_conv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU(inplace=True) if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class CTPN_Model(nn.Module):
    def __init__(self):
        super().__init__()
        base_model = models.vgg16(pretrained=False)
        layers = list(base_model.features)[:-1]
        self.base_layers = nn.Sequential(*layers)  # block5_conv3 output
        self.rpn = basic_conv(512, 512, 3, 1, 1, bn=False)
        self.brnn = nn.GRU(512, 128, bidirectional=True, batch_first=True)
        self.lstm_fc = basic_conv(256, 512, 1, 1, relu=True, bn=False)
        self.rpn_class = basic_conv(512, 10 * 2, 1, 1, relu=False, bn=False)
        self.rpn_regress = basic_conv(512, 10 * 2, 1, 1, relu=False, bn=False)

    def forward(self, x):
        x = self.base_layers(x)
        # rpn
        x = self.rpn(x)  # [b, c, h, w]

        x1 = x.permute(0, 2, 3, 1).contiguous()  # channels last  [b, h, w, c]
        b = x1.size()  # b, h, w, c
        x1 = x1.view(b[0] * b[1], b[2], b[3])
        x2, _ = self.brnn(x1)

        xsz = x.size()
        x3 = x2.view(xsz[0], xsz[2], xsz[3], 256)  # e.g. torch.Size([4, 20, 20, 256])
        x3 = x3.permute(0, 3, 1, 2).contiguous()  # channels first [b, c, h, w]
        x3 = self.lstm_fc(x3)
        x = x3

        cls = self.rpn_class(x)
        regr = self.rpn_regress(x)

        cls = cls.permute(0, 2, 3, 1).contiguous()
        regr = regr.permute(0, 2, 3, 1).contiguous()
        cls = cls.view(cls.size(0), cls.size(1) * cls.size(2) * 10, 2)
        regr = regr.view(regr.size(0), regr.size(1) * regr.size(2) * 10, 2)

        return cls, regr

3) Generate text proposal boxes.
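The fixed-width proposals then have to be chained into full text lines. The real TextProposalConnectorOriented used below also checks vertical overlap and proposal scores; this hypothetical simplification merges only horizontally adjacent boxes, just to show the idea:

```python
# Toy sketch of chaining CTPN's fixed-width proposals into text lines
# (hypothetical simplification of TextProposalConnectorOriented).
def connect_proposals(boxes, max_gap=16):
    """boxes: iterable of (x1, y1, x2, y2). Returns merged text-line boxes."""
    lines = []
    for box in sorted(boxes):
        # extend the current line if this proposal starts close to its right edge
        if lines and box[0] - lines[-1][2] <= max_gap:
            x1, y1, x2, y2 = lines[-1]
            lines[-1] = (x1, min(y1, box[1]), max(x2, box[2]), max(y2, box[3]))
        else:
            lines.append(box)
    return lines

props = [(0, 10, 16, 40), (16, 12, 32, 38), (100, 10, 116, 40)]
print(connect_proposals(props))  # two text lines
```

The first two 16-px proposals fuse into one line; the third is too far away and starts a new line.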

4) Train the model; the script below then loads the trained checkpoint and runs detection:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from ctpn_model import CTPN_Model
from ctpn_utils import gen_anchor, bbox_transfor_inv, clip_box, filter_bbox, nms, TextProposalConnectorOriented
from ctpn_utils import resize
import config

prob_thresh = 0.5
width = 960
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
weights = os.path.join(config.checkpoints_dir, 'v3_ctpn_ep30_0.3699_0.0929_0.4628.pth')

model = CTPN_Model()
model.load_state_dict(torch.load(weights, map_location=device)['model_state_dict'])
model.to(device)
model.eval()


def dis(image):
    cv2.imshow('image', image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


def get_det_boxes(image, display=True):
    image = resize(image, height=720)
    image_c = image.copy()
    h, w = image.shape[:2]
    image = image.astype(np.float32) - config.IMAGE_MEAN
    image = torch.from_numpy(image.transpose(2, 0, 1)).unsqueeze(0).float()

    with torch.no_grad():
        image = image.to(device)
        cls, regr = model(image)
        cls_prob = F.softmax(cls, dim=-1).cpu().numpy()
        regr = regr.cpu().numpy()
        anchor = gen_anchor((int(h / 16), int(w / 16)), 16)
        bbox = bbox_transfor_inv(anchor, regr)
        bbox = clip_box(bbox, [h, w])

        fg = np.where(cls_prob[0, :, 1] > prob_thresh)[0]
        select_anchor = bbox[fg, :]
        select_score = cls_prob[0, fg, 1]
        select_anchor = select_anchor.astype(np.int32)
        keep_index = filter_bbox(select_anchor, 16)

        # nms
        select_anchor = select_anchor[keep_index]
        select_score = select_score[keep_index]
        select_score = np.reshape(select_score, (select_score.shape[0], 1))
        nmsbox = np.hstack((select_anchor, select_score))
        keep = nms(nmsbox, 0.3)
        select_anchor = select_anchor[keep]
        select_score = select_score[keep]

        # connect proposals into text lines
        textConn = TextProposalConnectorOriented()
        text = textConn.get_text_lines(select_anchor, select_score, [h, w])
        print(text)

        if display:
            for i in text:
                s = str(round(i[-1] * 100, 2)) + '%'
                i = [int(j) for j in i]
                cv2.line(image_c, (i[0], i[1]), (i[2], i[3]), (0, 0, 255), 2)
                cv2.line(image_c, (i[0], i[1]), (i[4], i[5]), (0, 0, 255), 2)
                cv2.line(image_c, (i[6], i[7]), (i[2], i[3]), (0, 0, 255), 2)
                cv2.line(image_c, (i[4], i[5]), (i[6], i[7]), (0, 0, 255), 2)
                cv2.putText(image_c, s, (i[0] + 13, i[1] + 13),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2, cv2.LINE_AA)

        return text, image_c


if __name__ == '__main__':
    img_path = 'images/t1.png'
    image = cv2.imread(img_path)
    text, image = get_det_boxes(image)
    cv2.imwrite('results/t.jpg', image)
    # dis(image)

3. Text recognition

Text recognition uses the CRNN algorithm, which consists of three parts: CNN, RNN, and CTC, corresponding to the convolutional, recurrent, and transcription layers. The CNN first extracts features from the image, the RNN then predicts over the resulting feature sequence, and finally a CTC transcription layer produces the result.

The CNN follows the classic VGG16 design, and the RNN part uses a bidirectional LSTM. Note that the LSTM module in PyTorch expects 3-D tensors as input, and each dimension has a distinct meaning (sequence length, batch, features).

CRNN code:

import torch.nn as nn
from collections import OrderedDict


class BidirectionalLSTM(nn.Module):
    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)
        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)
        return output


class CRNN(nn.Module):
    def __init__(self, imgH, nc, nclass, nh, leakyRelu=False):
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        # 1x32x128
        self.conv1 = nn.Conv2d(nc, 64, 3, 1, 1)
        self.relu1 = nn.ReLU(True)
        self.pool1 = nn.MaxPool2d(2, 2)

        # 64x16x64
        self.conv2 = nn.Conv2d(64, 128, 3, 1, 1)
        self.relu2 = nn.ReLU(True)
        self.pool2 = nn.MaxPool2d(2, 2)

        # 128x8x32
        self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(256)
        self.relu3_1 = nn.ReLU(True)
        self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 1)
        self.relu3_2 = nn.ReLU(True)
        self.pool3 = nn.MaxPool2d((2, 2), (2, 1), (0, 1))

        # 256x4x16
        self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(512)
        self.relu4_1 = nn.ReLU(True)
        self.conv4_2 = nn.Conv2d(512, 512, 3, 1, 1)
        self.relu4_2 = nn.ReLU(True)
        self.pool4 = nn.MaxPool2d((2, 2), (2, 1), (0, 1))

        # 512x2x16
        self.conv5 = nn.Conv2d(512, 512, 2, 1, 0)
        self.bn5 = nn.BatchNorm2d(512)
        self.relu5 = nn.ReLU(True)

        # 512x1x16
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        # conv features
        x = self.pool1(self.relu1(self.conv1(input)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.pool3(self.relu3_2(self.conv3_2(self.relu3_1(self.bn3(self.conv3_1(x))))))
        x = self.pool4(self.relu4_2(self.conv4_2(self.relu4_1(self.bn4(self.conv4_1(x))))))
        conv = self.relu5(self.bn5(self.conv5(x)))

        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        conv = conv.permute(2, 0, 1)  # [w, b, c]

        # rnn features
        output = self.rnn(conv)
        return output


class CRNN_v2(nn.Module):
    def __init__(self, imgH, nc, nclass, nh, leakyRelu=False):
        super(CRNN_v2, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        # 1x32x128
        self.conv1_1 = nn.Conv2d(nc, 32, 3, 1, 1)
        self.bn1_1 = nn.BatchNorm2d(32)
        self.relu1_1 = nn.ReLU(True)
        self.conv1_2 = nn.Conv2d(32, 64, 3, 1, 1)
        self.bn1_2 = nn.BatchNorm2d(64)
        self.relu1_2 = nn.ReLU(True)
        self.pool1 = nn.MaxPool2d(2, 2)

        # 64x16x64
        self.conv2_1 = nn.Conv2d(64, 64, 3, 1, 1)
        self.bn2_1 = nn.BatchNorm2d(64)
        self.relu2_1 = nn.ReLU(True)
        self.conv2_2 = nn.Conv2d(64, 128, 3, 1, 1)
        self.bn2_2 = nn.BatchNorm2d(128)
        self.relu2_2 = nn.ReLU(True)
        self.pool2 = nn.MaxPool2d(2, 2)

        # 128x8x32
        self.conv3_1 = nn.Conv2d(128, 96, 3, 1, 1)
        self.bn3_1 = nn.BatchNorm2d(96)
        self.relu3_1 = nn.ReLU(True)
        self.conv3_2 = nn.Conv2d(96, 192, 3, 1, 1)
        self.bn3_2 = nn.BatchNorm2d(192)
        self.relu3_2 = nn.ReLU(True)
        self.pool3 = nn.MaxPool2d((2, 2), (2, 1), (0, 1))

        # 192x4x32
        self.conv4_1 = nn.Conv2d(192, 128, 3, 1, 1)
        self.bn4_1 = nn.BatchNorm2d(128)
        self.relu4_1 = nn.ReLU(True)
        self.conv4_2 = nn.Conv2d(128, 256, 3, 1, 1)
        self.bn4_2 = nn.BatchNorm2d(256)
        self.relu4_2 = nn.ReLU(True)
        self.pool4 = nn.MaxPool2d((2, 2), (2, 1), (0, 1))

        # 256x2x32
        self.bn5 = nn.BatchNorm2d(256)

        # 256x2x32
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        # conv features
        x = self.pool1(self.relu1_2(self.bn1_2(self.conv1_2(self.relu1_1(self.bn1_1(self.conv1_1(input)))))))
        x = self.pool2(self.relu2_2(self.bn2_2(self.conv2_2(self.relu2_1(self.bn2_1(self.conv2_1(x)))))))
        x = self.pool3(self.relu3_2(self.bn3_2(self.conv3_2(self.relu3_1(self.bn3_1(self.conv3_1(x)))))))
        x = self.pool4(self.relu4_2(self.bn4_2(self.conv4_2(self.relu4_1(self.bn4_1(self.conv4_1(x)))))))
        conv = self.bn5(x)

        b, c, h, w = conv.size()
        assert h == 2, "the height of conv must be 2"
        conv = conv.reshape([b, c * h, w])
        conv = conv.permute(2, 0, 1)  # [w, b, c]

        # rnn features
        output = self.rnn(conv)
        return output


def conv3x3(nIn, nOut, stride=1):
    # 3x3 convolution with padding
    return nn.Conv2d(nIn, nOut, kernel_size=3, stride=stride, padding=1, bias=False)


class basic_res_block(nn.Module):
    def __init__(self, nIn, nOut, stride=1, downsample=None):
        super(basic_res_block, self).__init__()
        m = OrderedDict()
        m['conv1'] = conv3x3(nIn, nOut, stride)
        m['bn1'] = nn.BatchNorm2d(nOut)
        m['relu1'] = nn.ReLU(inplace=True)
        m['conv2'] = conv3x3(nOut, nOut)
        m['bn2'] = nn.BatchNorm2d(nOut)
        self.group1 = nn.Sequential(m)
        self.relu = nn.Sequential(nn.ReLU(inplace=True))
        self.downsample = downsample

    def forward(self, x):
        if self.downsample is not None:
            residual = self.downsample(x)
        else:
            residual = x
        out = self.group1(x) + residual
        out = self.relu(out)
        return out


class CRNN_res(nn.Module):
    def __init__(self, imgH, nc, nclass, nh):
        super(CRNN_res, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        self.conv1 = nn.Conv2d(nc, 64, 3, 1, 1)
        self.relu1 = nn.ReLU(True)
        self.res1 = basic_res_block(64, 64)

        # 1x32x128
        down1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
            nn.BatchNorm2d(128))
        self.res2_1 = basic_res_block(64, 128, 2, down1)
        self.res2_2 = basic_res_block(128, 128)

        # 64x16x64
        down2 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=1, stride=2, bias=False),
            nn.BatchNorm2d(256))
        self.res3_1 = basic_res_block(128, 256, 2, down2)
        self.res3_2 = basic_res_block(256, 256)
        self.res3_3 = basic_res_block(256, 256)

        # 128x8x32
        down3 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=1, stride=(2, 1), bias=False),
            nn.BatchNorm2d(512))
        self.res4_1 = basic_res_block(256, 512, (2, 1), down3)
        self.res4_2 = basic_res_block(512, 512)
        self.res4_3 = basic_res_block(512, 512)

        # 256x4x16
        self.pool = nn.AvgPool2d((2, 2), (2, 1), (0, 1))

        # 512x2x16
        self.conv5 = nn.Conv2d(512, 512, 2, 1, 0)
        self.bn5 = nn.BatchNorm2d(512)
        self.relu5 = nn.ReLU(True)

        # 512x1x16
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        # conv features
        x = self.res1(self.relu1(self.conv1(input)))
        x = self.res2_2(self.res2_1(x))
        x = self.res3_3(self.res3_2(self.res3_1(x)))
        x = self.res4_3(self.res4_2(self.res4_1(x)))
        x = self.pool(x)
        conv = self.relu5(self.bn5(self.conv5(x)))

        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        conv = conv.permute(2, 0, 1)  # [w, b, c]

        # rnn features
        output = self.rnn(conv)
        return output


if __name__ == '__main__':
    pass

For the complete code, see: https://github.com/breezedeus/cnstd

