|  | // Copyright 2020 The Go Authors. All rights reserved. | 
|  | // Use of this source code is governed by a BSD-style | 
|  | // license that can be found in the LICENSE file. | 
|  |  | 
|  | package os_test | 
|  |  | 
|  | import ( | 
|  | "bytes" | 
|  | "internal/poll" | 
|  | "io" | 
|  | "math/rand" | 
|  | "os" | 
|  | . "os" | 
|  | "path/filepath" | 
|  | "strconv" | 
|  | "syscall" | 
|  | "testing" | 
|  | "time" | 
|  | ) | 
|  |  | 
|  | func TestCopyFileRange(t *testing.T) { | 
|  | sizes := []int{ | 
|  | 1, | 
|  | 42, | 
|  | 1025, | 
|  | syscall.Getpagesize() + 1, | 
|  | 32769, | 
|  | } | 
|  | t.Run("Basic", func(t *testing.T) { | 
|  | for _, size := range sizes { | 
|  | t.Run(strconv.Itoa(size), func(t *testing.T) { | 
|  | testCopyFileRange(t, int64(size), -1) | 
|  | }) | 
|  | } | 
|  | }) | 
|  | t.Run("Limited", func(t *testing.T) { | 
|  | t.Run("OneLess", func(t *testing.T) { | 
|  | for _, size := range sizes { | 
|  | t.Run(strconv.Itoa(size), func(t *testing.T) { | 
|  | testCopyFileRange(t, int64(size), int64(size)-1) | 
|  | }) | 
|  | } | 
|  | }) | 
|  | t.Run("Half", func(t *testing.T) { | 
|  | for _, size := range sizes { | 
|  | t.Run(strconv.Itoa(size), func(t *testing.T) { | 
|  | testCopyFileRange(t, int64(size), int64(size)/2) | 
|  | }) | 
|  | } | 
|  | }) | 
|  | t.Run("More", func(t *testing.T) { | 
|  | for _, size := range sizes { | 
|  | t.Run(strconv.Itoa(size), func(t *testing.T) { | 
|  | testCopyFileRange(t, int64(size), int64(size)+7) | 
|  | }) | 
|  | } | 
|  | }) | 
|  | }) | 
|  | t.Run("DoesntTryInAppendMode", func(t *testing.T) { | 
|  | dst, src, data, hook := newCopyFileRangeTest(t, 42) | 
|  |  | 
|  | dst2, err := OpenFile(dst.Name(), O_RDWR|O_APPEND, 0755) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | defer dst2.Close() | 
|  |  | 
|  | if _, err := io.Copy(dst2, src); err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | if hook.called { | 
|  | t.Fatal("called poll.CopyFileRange for destination in O_APPEND mode") | 
|  | } | 
|  | mustSeekStart(t, dst2) | 
|  | mustContainData(t, dst2, data) // through traditional means | 
|  | }) | 
|  | t.Run("NotRegular", func(t *testing.T) { | 
|  | t.Run("BothPipes", func(t *testing.T) { | 
|  | hook := hookCopyFileRange(t) | 
|  |  | 
|  | pr1, pw1, err := Pipe() | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | defer pr1.Close() | 
|  | defer pw1.Close() | 
|  |  | 
|  | pr2, pw2, err := Pipe() | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | defer pr2.Close() | 
|  | defer pw2.Close() | 
|  |  | 
|  | // The pipe is empty, and PIPE_BUF is large enough | 
|  | // for this, by (POSIX) definition, so there is no | 
|  | // need for an additional goroutine. | 
|  | data := []byte("hello") | 
|  | if _, err := pw1.Write(data); err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | pw1.Close() | 
|  |  | 
|  | n, err := io.Copy(pw2, pr1) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | if n != int64(len(data)) { | 
|  | t.Fatalf("transferred %d, want %d", n, len(data)) | 
|  | } | 
|  | if !hook.called { | 
|  | t.Fatalf("should have called poll.CopyFileRange") | 
|  | } | 
|  | pw2.Close() | 
|  | mustContainData(t, pr2, data) | 
|  | }) | 
|  | t.Run("DstPipe", func(t *testing.T) { | 
|  | dst, src, data, hook := newCopyFileRangeTest(t, 255) | 
|  | dst.Close() | 
|  |  | 
|  | pr, pw, err := Pipe() | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | defer pr.Close() | 
|  | defer pw.Close() | 
|  |  | 
|  | n, err := io.Copy(pw, src) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | if n != int64(len(data)) { | 
|  | t.Fatalf("transferred %d, want %d", n, len(data)) | 
|  | } | 
|  | if !hook.called { | 
|  | t.Fatalf("should have called poll.CopyFileRange") | 
|  | } | 
|  | pw.Close() | 
|  | mustContainData(t, pr, data) | 
|  | }) | 
|  | t.Run("SrcPipe", func(t *testing.T) { | 
|  | dst, src, data, hook := newCopyFileRangeTest(t, 255) | 
|  | src.Close() | 
|  |  | 
|  | pr, pw, err := Pipe() | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | defer pr.Close() | 
|  | defer pw.Close() | 
|  |  | 
|  | // The pipe is empty, and PIPE_BUF is large enough | 
|  | // for this, by (POSIX) definition, so there is no | 
|  | // need for an additional goroutine. | 
|  | if _, err := pw.Write(data); err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | pw.Close() | 
|  |  | 
|  | n, err := io.Copy(dst, pr) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | if n != int64(len(data)) { | 
|  | t.Fatalf("transferred %d, want %d", n, len(data)) | 
|  | } | 
|  | if !hook.called { | 
|  | t.Fatalf("should have called poll.CopyFileRange") | 
|  | } | 
|  | mustSeekStart(t, dst) | 
|  | mustContainData(t, dst, data) | 
|  | }) | 
|  | }) | 
|  | t.Run("Nil", func(t *testing.T) { | 
|  | var nilFile *File | 
|  | anyFile, err := os.CreateTemp("", "") | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | defer Remove(anyFile.Name()) | 
|  | defer anyFile.Close() | 
|  |  | 
|  | if _, err := io.Copy(nilFile, nilFile); err != ErrInvalid { | 
|  | t.Errorf("io.Copy(nilFile, nilFile) = %v, want %v", err, ErrInvalid) | 
|  | } | 
|  | if _, err := io.Copy(anyFile, nilFile); err != ErrInvalid { | 
|  | t.Errorf("io.Copy(anyFile, nilFile) = %v, want %v", err, ErrInvalid) | 
|  | } | 
|  | if _, err := io.Copy(nilFile, anyFile); err != ErrInvalid { | 
|  | t.Errorf("io.Copy(nilFile, anyFile) = %v, want %v", err, ErrInvalid) | 
|  | } | 
|  |  | 
|  | if _, err := nilFile.ReadFrom(nilFile); err != ErrInvalid { | 
|  | t.Errorf("nilFile.ReadFrom(nilFile) = %v, want %v", err, ErrInvalid) | 
|  | } | 
|  | if _, err := anyFile.ReadFrom(nilFile); err != ErrInvalid { | 
|  | t.Errorf("anyFile.ReadFrom(nilFile) = %v, want %v", err, ErrInvalid) | 
|  | } | 
|  | if _, err := nilFile.ReadFrom(anyFile); err != ErrInvalid { | 
|  | t.Errorf("nilFile.ReadFrom(anyFile) = %v, want %v", err, ErrInvalid) | 
|  | } | 
|  | }) | 
|  | } | 
|  |  | 
|  | func testCopyFileRange(t *testing.T, size int64, limit int64) { | 
|  | dst, src, data, hook := newCopyFileRangeTest(t, size) | 
|  |  | 
|  | // If we have a limit, wrap the reader. | 
|  | var ( | 
|  | realsrc io.Reader | 
|  | lr      *io.LimitedReader | 
|  | ) | 
|  | if limit >= 0 { | 
|  | lr = &io.LimitedReader{N: limit, R: src} | 
|  | realsrc = lr | 
|  | if limit < int64(len(data)) { | 
|  | data = data[:limit] | 
|  | } | 
|  | } else { | 
|  | realsrc = src | 
|  | } | 
|  |  | 
|  | // Now call ReadFrom (through io.Copy), which will hopefully call | 
|  | // poll.CopyFileRange. | 
|  | n, err := io.Copy(dst, realsrc) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  |  | 
|  | // If we didn't have a limit, we should have called poll.CopyFileRange | 
|  | // with the right file descriptor arguments. | 
|  | if limit > 0 && !hook.called { | 
|  | t.Fatal("never called poll.CopyFileRange") | 
|  | } | 
|  | if hook.called && hook.dstfd != int(dst.Fd()) { | 
|  | t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, dst.Fd()) | 
|  | } | 
|  | if hook.called && hook.srcfd != int(src.Fd()) { | 
|  | t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd()) | 
|  | } | 
|  |  | 
|  | // Check that the offsets after the transfer make sense, that the size | 
|  | // of the transfer was reported correctly, and that the destination | 
|  | // file contains exactly the bytes we expect it to contain. | 
|  | dstoff, err := dst.Seek(0, io.SeekCurrent) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | srcoff, err := src.Seek(0, io.SeekCurrent) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | if dstoff != srcoff { | 
|  | t.Errorf("offsets differ: dstoff = %d, srcoff = %d", dstoff, srcoff) | 
|  | } | 
|  | if dstoff != int64(len(data)) { | 
|  | t.Errorf("dstoff = %d, want %d", dstoff, len(data)) | 
|  | } | 
|  | if n != int64(len(data)) { | 
|  | t.Errorf("short ReadFrom: wrote %d bytes, want %d", n, len(data)) | 
|  | } | 
|  | mustSeekStart(t, dst) | 
|  | mustContainData(t, dst, data) | 
|  |  | 
|  | // If we had a limit, check that it was updated. | 
|  | if lr != nil { | 
|  | if want := limit - n; lr.N != want { | 
|  | t.Fatalf("didn't update limit correctly: got %d, want %d", lr.N, want) | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | // newCopyFileRangeTest initializes a new test for copy_file_range. | 
|  | // | 
|  | // It creates source and destination files, and populates the source file | 
|  | // with random data of the specified size. It also hooks package os' call | 
|  | // to poll.CopyFileRange and returns the hook so it can be inspected. | 
|  | func newCopyFileRangeTest(t *testing.T, size int64) (dst, src *File, data []byte, hook *copyFileRangeHook) { | 
|  | t.Helper() | 
|  |  | 
|  | hook = hookCopyFileRange(t) | 
|  | tmp := t.TempDir() | 
|  |  | 
|  | src, err := Create(filepath.Join(tmp, "src")) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | t.Cleanup(func() { src.Close() }) | 
|  |  | 
|  | dst, err = Create(filepath.Join(tmp, "dst")) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | t.Cleanup(func() { dst.Close() }) | 
|  |  | 
|  | // Populate the source file with data, then rewind it, so it can be | 
|  | // consumed by copy_file_range(2). | 
|  | prng := rand.New(rand.NewSource(time.Now().Unix())) | 
|  | data = make([]byte, size) | 
|  | prng.Read(data) | 
|  | if _, err := src.Write(data); err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | if _, err := src.Seek(0, io.SeekStart); err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  |  | 
|  | return dst, src, data, hook | 
|  | } | 
|  |  | 
// mustContainData fails the test unless reading f from its current offset
// yields exactly data, immediately followed by io.EOF.
func mustContainData(t *testing.T, f *File, data []byte) {
	t.Helper()

	buf := make([]byte, len(data))
	_, err := io.ReadFull(f, buf)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(buf, data) {
		t.Fatalf("didn't get the same data back from %s", f.Name())
	}

	// Anything left after data means the file holds extra bytes.
	var probe [1]byte
	if _, err = f.Read(probe[:]); err != io.EOF {
		t.Fatalf("not at EOF")
	}
}
|  |  | 
// mustSeekStart rewinds f to offset 0 and fails the test if the seek fails.
// It is marked as a helper so failures are reported at the caller's line.
func mustSeekStart(t *testing.T, f *File) {
	t.Helper()

	if _, err := f.Seek(0, io.SeekStart); err != nil {
		t.Fatal(err)
	}
}
|  |  | 
|  | func hookCopyFileRange(t *testing.T) *copyFileRangeHook { | 
|  | h := new(copyFileRangeHook) | 
|  | h.install() | 
|  | t.Cleanup(h.uninstall) | 
|  | return h | 
|  | } | 
|  |  | 
|  | type copyFileRangeHook struct { | 
|  | called bool | 
|  | dstfd  int | 
|  | srcfd  int | 
|  | remain int64 | 
|  |  | 
|  | original func(dst, src *poll.FD, remain int64) (int64, bool, error) | 
|  | } | 
|  |  | 
|  | func (h *copyFileRangeHook) install() { | 
|  | h.original = *PollCopyFileRangeP | 
|  | *PollCopyFileRangeP = func(dst, src *poll.FD, remain int64) (int64, bool, error) { | 
|  | h.called = true | 
|  | h.dstfd = dst.Sysfd | 
|  | h.srcfd = src.Sysfd | 
|  | h.remain = remain | 
|  | return h.original(dst, src, remain) | 
|  | } | 
|  | } | 
|  |  | 
|  | func (h *copyFileRangeHook) uninstall() { | 
|  | *PollCopyFileRangeP = h.original | 
|  | } | 
|  |  | 
// TestProcCopy copies /proc/self/cmdline into a regular file with io.Copy.
// It checks that the copy matches what os.ReadFile returns for the source.
// On some kernels copy_file_range fails on files in /proc, so this guards
// against ReadFrom returning a wrong result for such sources. The test is
// skipped when the /proc file cannot be read.
func TestProcCopy(t *testing.T) {
	const cmdlineFile = "/proc/self/cmdline"
	cmdline, err := os.ReadFile(cmdlineFile)
	if err != nil {
		t.Skipf("can't read /proc file: %v", err)
	}
	in, err := os.Open(cmdlineFile)
	if err != nil {
		t.Fatal(err)
	}
	defer in.Close()
	outFile := filepath.Join(t.TempDir(), "cmdline")
	out, err := os.Create(outFile)
	if err != nil {
		t.Fatal(err)
	}
	if _, err := io.Copy(out, in); err != nil {
		// Close before bailing out so the descriptor is not leaked.
		out.Close()
		t.Fatal(err)
	}
	if err := out.Close(); err != nil {
		t.Fatal(err)
	}
	// Named got (not copy) to avoid shadowing the builtin.
	got, err := os.ReadFile(outFile)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(cmdline, got) {
		t.Errorf("copy of %q got %q want %q\n", cmdlineFile, got, cmdline)
	}
}