# Batched Seq2Seq Example

Based on the `seq2seq-translation-batched.ipynb` from practical-pytorch, but with extra features.
This example runs a grammatical error correction task, where the source sequence is a grammatically erroneous English sentence and the target sequence is a grammatically correct English sentence. The corpus and evaluation script can be downloaded at: https://github.com/keisks/jfleg.
## Extra features

- Cleaner codebase
- Very detailed comments for learners
- Implements a PyTorch-native dataset and dataloader for batching
- Correctly handles the hidden state from the bidirectional encoder and passes it to the decoder as the initial hidden state
- Fully batched attention-mechanism computation (only general attention is implemented, but it's sufficient; see the sketch after this list). Note: the original code still uses a for-loop to compute attention, which is very slow
- Supports LSTM instead of only GRU
- Shared embeddings (encoder's input embedding and decoder's input embedding)
- Pretrained GloVe embeddings
- Fixed embeddings
- Tied embeddings (decoder's input embedding and decoder's output embedding)
- Tensorboard visualization
- Load and save checkpoints
- Replaces unknown words by selecting the source token with the highest attention score (translation)
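Below is a minimal sketch (not this repo's exact code) of two of the bullets above: merging a bidirectional encoder's final hidden state into a decoder initial state, and Luong-style general attention computed for the whole batch with `torch.bmm` instead of a Python loop. All tensor names and shapes are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def encoder_final_to_decoder_init(h_n):
    """Merge a bidirectional encoder's final states into a decoder init state.

    h_n: (num_layers * 2, batch, hidden) -> (num_layers, batch, hidden),
    summing forward and backward directions (one common choice).
    """
    layers_x2, batch, hidden = h_n.size()
    return h_n.view(layers_x2 // 2, 2, batch, hidden).sum(dim=1)

class GeneralAttention(nn.Module):
    """Luong 'general' attention, fully batched (no per-position for-loop)."""
    def __init__(self, hidden_size):
        super().__init__()
        self.linear_in = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, decoder_state, encoder_outputs):
        # decoder_state:   (batch, 1, hidden)       current decoder hidden state
        # encoder_outputs: (batch, src_len, hidden)
        # score(h_t, h_s) = h_t^T W h_s, computed for all source positions at once
        scores = torch.bmm(self.linear_in(decoder_state),
                           encoder_outputs.transpose(1, 2))  # (batch, 1, src_len)
        attn_weights = F.softmax(scores, dim=-1)
        context = torch.bmm(attn_weights, encoder_outputs)   # (batch, 1, hidden)
        return context, attn_weights
```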
## Cons

Compared to the state-of-the-art seq2seq library OpenNMT-py, some things aren't optimized in this codebase:

- Use cuDNN when possible (always on the encoder; on the decoder when `input_feed=0`)
- Always avoid indexing/loops and use torch primitives
- When possible, batch softmax operations across time (this is the second most complicated part of the code; see the sketch after this list)
- Batched inference and beam search for translation (this is the most complicated part of the code)
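To illustrate the "batch softmax operations across time" item, one common trick (a sketch assuming the decoder states for all timesteps have been stacked into one tensor; not OpenNMT-py's exact code) is to merge the time and batch dimensions, so a single matmul and a single softmax cover every timestep:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def project_all_timesteps(decoder_outputs, out_proj):
    # decoder_outputs: (seq_len, batch, hidden) stacked decoder hidden states
    # out_proj:        nn.Linear(hidden, vocab_size)
    seq_len, batch, hidden = decoder_outputs.size()
    flat = decoder_outputs.view(seq_len * batch, hidden)  # merge time and batch
    log_probs = F.log_softmax(out_proj(flat), dim=-1)     # one softmax for all steps
    return log_probs.view(seq_len, batch, -1)

# Toy usage: 10 timesteps, batch of 4, hidden size 8, vocab size 100.
out_proj = nn.Linear(8, 100)
log_probs = project_all_timesteps(torch.randn(10, 4, 8), out_proj)
```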
## How to speed up RNN training?

Several ways to speed up RNN training:

- Batching
- Static padding
- Dynamic padding (see the sketch below)
- Bucketing
- Truncated BPTT

See "Sequence Models and the RNN API (TensorFlow Dev Summit 2017)" for an explanation of these techniques.
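As a concrete example of dynamic padding (with a crude, within-batch form of bucketing via length sorting), a `collate_fn` can pad each mini-batch only to that batch's longest sequence. This is a sketch assuming a dataset of variable-length token-ID tensors; the names are made up:

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def collate_batch(examples, pad_id=0):
    """Pad each batch only to its own longest sequence (dynamic padding)."""
    # Sorting by length keeps similar lengths together and satisfies
    # pack_padded_sequence's sorted-input requirement on older PyTorch.
    examples = sorted(examples, key=len, reverse=True)
    lengths = torch.tensor([len(seq) for seq in examples])
    padded = pad_sequence(examples, batch_first=True, padding_value=pad_id)
    return padded, lengths

data = [torch.randint(1, 100, (n,)) for n in (12, 7, 9, 15)]
loader = DataLoader(data, batch_size=2, collate_fn=collate_batch)
for padded, lengths in loader:
    pass  # padded: (batch, max_len_in_batch); lengths: (batch,)
```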
You can use torchtext or OpenNMT's data iterator to speed up training. It can be 7x faster! (e.g., 7 hours per epoch -> 1 hour!)
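For reference, here is a hedged sketch of the bucketing approach using the legacy (0.x-era) torchtext API, where `BucketIterator` groups similar-length examples to minimize padding. The file `train.tsv` and the field names are hypothetical:

```python
from torchtext.legacy.data import Field, TabularDataset, BucketIterator
# On torchtext <= 0.8 these classes live under torchtext.data instead.

SRC = Field(tokenize=str.split, init_token="<s>", eos_token="</s>", lower=True)
TGT = Field(tokenize=str.split, init_token="<s>", eos_token="</s>", lower=True)

# Hypothetical file: two tab-separated columns (source sentence, target sentence).
train = TabularDataset(path="train.tsv", format="tsv",
                       fields=[("src", SRC), ("tgt", TGT)])
SRC.build_vocab(train, min_freq=2)
TGT.build_vocab(train, min_freq=2)

# Batches are drawn from buckets of similar-length examples to minimize padding.
train_iter = BucketIterator(train, batch_size=64,
                            sort_key=lambda ex: len(ex.src),
                            sort_within_batch=True)
```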
## Acknowledgement

Thanks to the author of OpenNMT-py, @srush, for answering my questions! See https://github.com/OpenNMT/OpenNMT-py/issues/552.