From 2c28809b1ea606597be71204cc1949e51ff97637 Mon Sep 17 00:00:00 2001 From: jjxu <428192774@qq.com> Date: Fri, 12 Jan 2024 13:12:19 +0800 Subject: [PATCH] =?UTF-8?q?1.=E6=96=B0=E5=A2=9EOPTION=E5=8F=82=E6=95=B0=20?= =?UTF-8?q?2.=E4=BC=98=E5=8C=96StringEntryType=E5=92=8CBytesEntryType?= =?UTF-8?q?=E7=9A=84body=E6=95=B0=E6=8D=AE=E8=A7=A3=E6=9E=90=E6=96=B9?= =?UTF-8?q?=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.MD | 13 +++-- body.go | 18 +++++++ options.go | 17 ++++++ options_test.go | 49 ++++++++++++++++++ parser.go | 37 ++++++++++--- parser_test.go | 118 ++++++++++++++++++++++++++++++++++++++++++ simpleRequest.go | 103 +++++++++++++++++++++--------------- simpleRequest_test.go | 61 ++++++++++++++++++++-- 8 files changed, 361 insertions(+), 55 deletions(-) create mode 100644 options_test.go create mode 100644 parser_test.go diff --git a/README.MD b/README.MD index 6707653..f4b4c97 100644 --- a/README.MD +++ b/README.MD @@ -101,11 +101,15 @@ r.Headers().ConentType_textPlain() //r.Headers().Set("Content-Type","text/plain; charset=utf-8") ``` -#### 2.4.5 忽略指定请求头 -在默认情况下,`Content-Type`将被自动添加到请求头中,但对于某些场景来说这会导致请求失败。使用下面的方法可以忽略指定的请求头 +#### 2.4.5 忽略指定请求头/禁用默认content-Type +在默认情况下,`Content-Type`将被自动添加到请求头中,值为`application/x-www-form-urlencoded`。但对于某些场景来说这会导致请求失败。使用下面的方法可以忽略指定的请求头 ```go r.Headers().Omit("Content-Type") ``` +或者使用以下方法禁用**默认content-Type**。禁用后当你不主动设置`Content-Type`时,请求头中将不会包含`Content-Type`。 +```go +var r = simpleRequest.NewRequest(simpleRequest.OptionDisableDefaultContentType()) +```` ### 2.5 添加queryParams #### 2.5.1 单个赋值 @@ -311,7 +315,10 @@ if err != nil { ```go requestContext:=r.Request ``` - +为了让用户能够便于分析调试,在进行http请求后r.Request.Body中的数据仍旧可读,但是会丧失部分性能,如果要禁用此功能请使用以下方法。 +```go +var r = simpleRequest.NewRequest(simpleRequest.OptionDisableDefaultContentType()) +``` #### 2.10.2 获取返回的上下文对象 ```go responseContext:=r.Response diff --git a/body.go b/body.go index a750f49..ef0a19e 100644 --- a/body.go +++ b/body.go @@ -8,7 +8,10 @@ package simpleRequest import ( + "bytes" + "io" "mime/multipart" + "strings" ) // EntryMark 请求体条目标记,用于标记输入的body内容格式 @@ -27,6 +30,21 @@ const ( FormFilePathKey EntryMark = "__FORM_FILE_PATH_KEY__" ) +func GetStringEntryTypeBody(bodyEntries map[string]any) io.Reader { + data, ok := bodyEntries[StringEntryType.string()] + if !ok || data == nil { + return nil + } + return strings.NewReader(data.(string)) +} +func GetBytesEntryTypeBody(bodyEntries map[string]any) io.Reader { + data, ok := bodyEntries[BytesEntryType.string()] + if !ok || data == "" { + return nil + } + return bytes.NewReader(data.([]byte)) +} + type BodyConf struct { simpleReq *SimpleRequest } diff --git a/options.go b/options.go index 9bf36ca..688bc6c 100644 --- a/options.go +++ b/options.go @@ -15,3 +15,20 @@ func OptionNewBodyEntryParser(contentType string, parser IBodyEntryParser) OPTIO return r } } + +// OptionDisableDefaultContentType 禁用默认的ContentType +// 当未指定ContentType时,将不会使用默认的ContentType +func OptionDisableDefaultContentType() OPTION { + return func(r *SimpleRequest) *SimpleRequest { + r.disableDefaultContentType = true + return r + } +} + +// OptionDisableCopyRequestBody 禁用复制RequestBody +func OptionDisableCopyRequestBody() OPTION { + return func(r *SimpleRequest) *SimpleRequest { + r.disableCopyRequestBody = true + return r + } +} diff --git a/options_test.go b/options_test.go new file mode 100644 index 0000000..59fa6b0 --- /dev/null +++ b/options_test.go @@ -0,0 +1,49 @@ +// Package simpleRequest ----------------------------- +// @file : options_test.go +// @author : JJXu +// @contact : wavingbear@163.com +// @time : 2024/1/12 11:27 +// ------------------------------------------- +package simpleRequest + +import ( + "io" + "testing" +) + +func TestOptionDisableDefaultContentType(t *testing.T) { + go httpserver() + var r = NewRequest(OptionDisableDefaultContentType()) + r.Headers() + bodyData := "{'a'=123,'b'=56}" + r.Body().SetString(bodyData) + _, err := r.POST("http://localhost:8989") + if err != nil { + t.Error(err) + return + } + if r.Request.Header.Get(hdrContentTypeKey) != "" { + t.Errorf("query params want '%s' but get '%s'", "", r.Request.Header.Get(hdrContentTypeKey)) + } +} + +func TestOptionOptionDisableCopyRequestBody(t *testing.T) { + go httpserver() + var r = NewRequest(OptionDisableCopyRequestBody()) + r.Headers() + bodyData := "{'a'=123,'b'=56}" + r.Body().SetString(bodyData) + _, err := r.POST("http://localhost:8989") + if err != nil { + t.Error(err) + return + } + body, err := io.ReadAll(r.Request.Body) + if err != nil { + t.Error(err) + return + } + if string(body) != "" { + t.Errorf("query params want '%s' but get '%s'", "", body) + } +} diff --git a/parser.go b/parser.go index cae457d..90caaee 100644 --- a/parser.go +++ b/parser.go @@ -13,6 +13,7 @@ import ( "fmt" "io" "mime/multipart" + "net/url" "os" "path/filepath" "strings" @@ -34,9 +35,9 @@ type JsonParser struct{} func (JsonParser) Unmarshal(bodyType EntryMark, BodyEntry map[string]any) io.Reader { switch bodyType { case StringEntryType: - return strings.NewReader(BodyEntry[StringEntryType.string()].(string)) + return GetStringEntryTypeBody(BodyEntry) case BytesEntryType: - return bytes.NewReader(BodyEntry[BytesEntryType.string()].([]byte)) + return GetBytesEntryTypeBody(BodyEntry) case ModelEntryType: jsonData, err := json.Marshal(BodyEntry[ModelEntryType.string()]) if err == nil { @@ -82,9 +83,9 @@ func (f *FormDataParser) Unmarshal(bodyType EntryMark, BodyEntry map[string]any) } body, f.ContentType = multipartCommonParse(mapper) case StringEntryType: - return strings.NewReader(BodyEntry[StringEntryType.string()].(string)) + return GetStringEntryTypeBody(BodyEntry) case BytesEntryType: - return bytes.NewReader(BodyEntry[BytesEntryType.string()].([]byte)) + return GetBytesEntryTypeBody(BodyEntry) default: body, f.ContentType = multipartCommonParse(BodyEntry) } @@ -161,10 +162,34 @@ func (f XmlParser) Unmarshal(bodyType EntryMark, BodyEntry map[string]any) (body return strings.NewReader("") } case StringEntryType: - return strings.NewReader(BodyEntry[StringEntryType.string()].(string)) + return GetStringEntryTypeBody(BodyEntry) case BytesEntryType: - return bytes.NewReader(BodyEntry[BytesEntryType.string()].([]byte)) + return GetBytesEntryTypeBody(BodyEntry) default: return strings.NewReader("") } } + +type CommonParser struct { +} + +func (f CommonParser) Unmarshal(bodyType EntryMark, BodyEntry map[string]any) (body io.Reader) { + tmpData := url.Values{} + for k, v := range BodyEntry { + switch k { + case StringEntryType.string(): + body = GetStringEntryTypeBody(BodyEntry) + break + case BytesEntryType.string(): + body = GetBytesEntryTypeBody(BodyEntry) + break + default: + if strings.HasPrefix(k, FormFilePathKey.string()) { + k = k[len(FormFilePathKey):] + } + tmpData.Set(k, fmt.Sprintf("%v", v)) + } + } + body = strings.NewReader(tmpData.Encode()) + return +} diff --git a/parser_test.go b/parser_test.go new file mode 100644 index 0000000..879aa87 --- /dev/null +++ b/parser_test.go @@ -0,0 +1,118 @@ +// Package simpleRequest ----------------------------- +// @file : parser_test.go +// @author : JJXu +// @contact : wavingbear@163.com +// @time : 2024/1/11 18:35 +// ------------------------------------------- +package simpleRequest + +import ( + "io" + "reflect" + "testing" +) + +func TestFormDataParser_Unmarshal(t *testing.T) { + type fields struct { + ContentType string + } + type args struct { + bodyType EntryMark + BodyEntry map[string]any + } + tests := []struct { + name string + fields fields + args args + wantBody io.Reader + }{ + { + name: "StringEntryType", + fields: fields{}, + args: args{ + bodyType: StringEntryType, + BodyEntry: map[string]any{StringEntryType.string(): "test"}, + }, + wantBody: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &FormDataParser{ + ContentType: tt.fields.ContentType, + } + if gotBody := f.Unmarshal(tt.args.bodyType, tt.args.BodyEntry); !reflect.DeepEqual(gotBody, tt.wantBody) { + t.Errorf("Unmarshal() = %v, want %v", gotBody, tt.wantBody) + } + }) + } +} + +func TestJsonParser_Unmarshal(t *testing.T) { + type args struct { + bodyType EntryMark + BodyEntry map[string]any + } + tests := []struct { + name string + args args + want io.Reader + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + js := JsonParser{} + if got := js.Unmarshal(tt.args.bodyType, tt.args.BodyEntry); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Unmarshal() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestXmlParser_Unmarshal(t *testing.T) { + type args struct { + bodyType EntryMark + BodyEntry map[string]any + } + tests := []struct { + name string + args args + wantBody io.Reader + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := XmlParser{} + if gotBody := f.Unmarshal(tt.args.bodyType, tt.args.BodyEntry); !reflect.DeepEqual(gotBody, tt.wantBody) { + t.Errorf("Unmarshal() = %v, want %v", gotBody, tt.wantBody) + } + }) + } +} + +func Test_multipartCommonParse(t *testing.T) { + type args struct { + BodyEntry map[string]any + } + tests := []struct { + name string + args args + wantReader io.Reader + wantContentType string + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotReader, gotContentType := multipartCommonParse(tt.args.BodyEntry) + if !reflect.DeepEqual(gotReader, tt.wantReader) { + t.Errorf("multipartCommonParse() gotReader = %v, want %v", gotReader, tt.wantReader) + } + if gotContentType != tt.wantContentType { + t.Errorf("multipartCommonParse() gotContentType = %v, want %v", gotContentType, tt.wantContentType) + } + }) + } +} diff --git a/simpleRequest.go b/simpleRequest.go index f369790..6695b22 100644 --- a/simpleRequest.go +++ b/simpleRequest.go @@ -10,6 +10,7 @@ package simpleRequest import ( "bytes" "crypto/tls" + "errors" "fmt" "io" "net/http" @@ -49,14 +50,16 @@ type SimpleRequest struct { omitHeaderKeys []string transport *http.Transport - BodyEntryMark EntryMark - BodyEntries map[string]any - bodyEntryParsers map[string]IBodyEntryParser + BodyEntryMark EntryMark // 用于标记输入的body内容格式 + BodyEntries map[string]any // 输入的body中的内容 + bodyEntryParsers map[string]IBodyEntryParser // 用于将BodyEntries中的内容解析后传入request body中 + disableDefaultContentType bool // 禁用默认的ContentType + disableCopyRequestBody bool // 禁用默认的ContentType,在进行http请求后SimpleRequest.Request.Body中内容会无法读取 timeout time.Duration - Response http.Response //用于获取完整的返回内容。请注意要在请求之后才能获取 - Request http.Request //用于获取完整的请求内容。请注意要在请求之后才能获取 + Response *http.Response //用于获取完整的返回内容。请注意要在请求之后才能获取 + Request *http.Request //用于获取完整的请求内容。请注意要在请求之后才能获取 //cookies map[string]string //data any //cli *http.Client @@ -144,23 +147,17 @@ func (s *SimpleRequest) do(request *http.Request) (body []byte, err error) { //3.1 发送数据 resp, err := client.Do(request) if err != nil { - fmt.Println("【Request Error】:", err.Error()) + err = errors.New("[client.Do Err]:" + err.Error()) return } //v0.0.2更新,将request和response内容返回,便于用户进行分析 JJXu 03-11-2022 + s.Request = request if resp != nil { - s.Response = *resp - } else { - return + s.Response = resp + defer resp.Body.Close() + body, err = io.ReadAll(resp.Body) } - if request != nil { - s.Request = *request - } - - defer resp.Body.Close() - //3.2 获取数据 - body, err = io.ReadAll(resp.Body) return } @@ -178,15 +175,27 @@ func (s *SimpleRequest) GET(urls string) (body []byte, err error) { func (s *SimpleRequest) LaunchTo(urls, method string) (body []byte, err error) { // body s.initBody() - r, err := http.NewRequest(method, urls, s.body) + var r *http.Request + copBody, err := io.ReadAll(s.body) if err != nil { - return nil, err + return + } + if !s.disableCopyRequestBody { + // 使r.body在请求后仍旧可读,便于使用者对请求过程进行分析 + r, err = http.NewRequest(method, urls, io.NopCloser(bytes.NewBuffer(copBody))) + if err != nil { + return nil, err + } + defer func() { + r.Body = io.NopCloser(bytes.NewBuffer(copBody)) + }() + } else { + r, err = http.NewRequest(method, urls, s.body) + if err != nil { + return nil, err + } } //headers - //for k := range s.headers { - // r.Header[k] = append(r.Header[k], s.headers[k]...) - // s.headers.Del(k) - //} r.Header = s.headers for _, k := range s.omitHeaderKeys { r.Header.Del(k) @@ -248,7 +257,7 @@ func (s *SimpleRequest) TRACE(url string) (body []byte, err error) { // ------------------------------------------------------ // -// 这里数据 +// Automatically parses the request body based on the content-type type defined in the request header func (s *SimpleRequest) initBody() { contentTypeData := s.headers.Get(hdrContentTypeKey) if contentTypeData != "" { @@ -273,8 +282,7 @@ func (s *SimpleRequest) initBody() { //application/soap+xml ,application/xml var parser, ok = s.bodyEntryParsers[xmlDataType] if !ok { - data, _ := s.BodyEntries[StringEntryType.string()].(string) - s.body = strings.NewReader(data) + s.body = GetStringEntryTypeBody(s.BodyEntries) return } s.body = parser.Unmarshal(s.BodyEntryMark, s.BodyEntries) @@ -282,8 +290,7 @@ func (s *SimpleRequest) initBody() { case strings.Contains(contentTypeData, "text") || strings.Contains(contentTypeData, javaScriptType): var parser, ok = s.bodyEntryParsers[textPlainType] if !ok { - data, _ := s.BodyEntries[StringEntryType.string()].(string) - s.body = strings.NewReader(data) + s.body = GetStringEntryTypeBody(s.BodyEntries) return } s.body = parser.Unmarshal(s.BodyEntryMark, s.BodyEntries) @@ -292,32 +299,40 @@ func (s *SimpleRequest) initBody() { //default header type is "x-www-form-urlencoded" var parser, ok = s.bodyEntryParsers["form-urlencoded"] if !ok { - tmpData := url.Values{} for k, v := range s.BodyEntries { - tmpData.Set(k, fmt.Sprintf("%v", v)) + if v == nil { + break + } + switch k { + case StringEntryType.string(): + s.body = GetStringEntryTypeBody(s.BodyEntries) + break + case BytesEntryType.string(): + s.body = GetBytesEntryTypeBody(s.BodyEntries) + break + default: + tmpData := url.Values{} + if strings.HasPrefix(k, FormFilePathKey.string()) { + k = k[len(FormFilePathKey):] + } + tmpData.Set(k, fmt.Sprintf("%v", v)) + s.body = strings.NewReader(tmpData.Encode()) + } } - s.body = strings.NewReader(tmpData.Encode()) - s.Headers().ConentType_formUrlencoded() + //s.Headers().ConentType_formUrlencoded() return } s.body = parser.Unmarshal(s.BodyEntryMark, s.BodyEntries) default: - //todo Automatically determine the data type - tmpData := url.Values{} - for k, v := range tmpData { - if strings.HasPrefix(k, FormFilePathKey.string()) { - k = k[len(FormFilePathKey):] - } - tmpData.Set(k, fmt.Sprintf("%v", v)) - } - s.body = strings.NewReader(tmpData.Encode()) + // 自动处理body数据 + s.body = new(CommonParser).Unmarshal(s.BodyEntryMark, s.BodyEntries) } } else { switch s.BodyEntryMark { case BytesEntryType: - s.body = bytes.NewReader(s.BodyEntries[BytesEntryType.string()].([]byte)) + s.body = GetBytesEntryTypeBody(s.BodyEntries) case StringEntryType: - s.body = strings.NewReader(s.BodyEntries[BytesEntryType.string()].(string)) + s.body = GetStringEntryTypeBody(s.BodyEntries) default: var parser, ok = s.bodyEntryParsers["form-urlencoded"] if !ok { @@ -329,7 +344,9 @@ func (s *SimpleRequest) initBody() { tmpData.Set(k, fmt.Sprintf("%v", v)) } s.body = strings.NewReader(tmpData.Encode()) - s.Headers().ConentType_formUrlencoded() + if !s.disableDefaultContentType { + s.Headers().ConentType_formUrlencoded() + } return } s.body = parser.Unmarshal(s.BodyEntryMark, s.BodyEntries) diff --git a/simpleRequest_test.go b/simpleRequest_test.go index be60c56..b1fb60c 100644 --- a/simpleRequest_test.go +++ b/simpleRequest_test.go @@ -8,7 +8,6 @@ package simpleRequest import ( "encoding/json" - "fmt" "io" "net/http" "testing" @@ -35,7 +34,7 @@ func httpserver() { io.WriteString(w, "false") } }) - fmt.Println("http服务启动了") + //fmt.Println("http服务启动了") http.ListenAndServe(":8989", nil) } @@ -98,8 +97,64 @@ func TestQueryUrl2(t *testing.T) { t.Error(err.Error()) } else { if r.Request.URL.RawQuery != "a=123&b=456&c=3" { - t.Errorf("query params wangt '%s' but get '%s'", "a=123&b=456&c=3", r.Request.URL.RawQuery) + t.Errorf("query params want '%s' but get '%s'", "a=123&b=456&c=3", r.Request.URL.RawQuery) } } } + +// 请求后,r.Request.Body中的内容仍旧可读 +func TestQueryUseStringBody(t *testing.T) { + go httpserver() + var r = NewRequest() + r.Headers().ConentType_json() + bodyData := "{'a'=123,'b'=56}" + r.Body().SetString(bodyData) + _, err := r.POST("http://localhost:8989") + if err != nil { + t.Error(err) + return + } + body, err := io.ReadAll(r.Request.Body) + if err != nil { + t.Error(err) + return + } + if string(body) != bodyData { + t.Errorf("request body want '%s' but get '%s'", bodyData, body) + } +} + +func TestEmptyBody(t *testing.T) { + go httpserver() + var r = NewRequest() + r.Headers().ConentType_json() + _, err := r.POST("http://localhost:8989") + if err != nil { + t.Error(err) + return + } + body, err := io.ReadAll(r.Request.Body) + if err != nil { + t.Error(err) + return + } + if string(body) != "{}" { + t.Errorf("request body want '%s' but get '%s'", "{}", body) + } + + r.Headers().ConentType_formUrlencoded() + _, err = r.POST("http://localhost:8989") + if err != nil { + t.Error(err) + return + } + body, err = io.ReadAll(r.Request.Body) + if err != nil { + t.Error(err) + return + } + if string(body) != "" { + t.Errorf("request body want '%s' but get '%s'", "", body) + } +}