summaryrefslogtreecommitdiff
path: root/zen/stream_buffer.h
blob: 8b8cd0d7d8539f1145096cb00f3e3b914c2b2bff (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
// *****************************************************************************
// * This file is part of the FreeFileSync project. It is distributed under    *
// * GNU General Public License: https://www.gnu.org/licenses/gpl-3.0          *
// * Copyright (C) Zenju (zenju AT freefilesync DOT org) - All Rights Reserved *
// *****************************************************************************

#ifndef STREAM_BUFFER_H_08492572089560298
#define STREAM_BUFFER_H_08492572089560298

#include <algorithm>
#include <atomic>
#include <cassert>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <mutex>
#include <stdexcept>
#include <string>
#include "ring_buffer.h"
#include "string_tools.h"


namespace zen
{
/*   implement streaming API on top of libcurl's icky callback-based design
        => support copying arbitrarily-large files: https://freefilesync.org/forum/viewtopic.php?t=4471
        => maximum performance through async processing (prefetching + output buffer!)
        => cost per worker thread creation ~ 1/20 ms

     Bounded byte pipe guarded by a single mutex; intended usage (see per-method "context" notes):
        - the output thread pushes data via write() and finishes with closeStream() or setWriteError()
        - the input thread pulls data via read() and can abort the output thread via setReadError()
     At most bufSize_ bytes are buffered: write() blocks while the buffer is full, read() blocks while it is empty. */
class AsyncStreamBuffer
{
public:
    //bufferSize: maximum number of bytes buffered at any time; must be > 0
    //  (with a zero-sized buffer write() could never make progress => reject up front instead of hanging later)
    explicit AsyncStreamBuffer(size_t bufferSize) : bufSize_(bufferSize)
    {
        if (bufSize_ == 0)
            throw std::logic_error("Contract violation! " + std::string(__FILE__) + ':' + numberTo<std::string>(__LINE__));

        ringBuf_.reserve(bufSize_);
    }

    //context of input thread, blocking
    //copies up to "bytesToRead" bytes into "buffer": returns "bytesToRead" bytes unless end of stream!
    //bytesToRead must be > 0, because a return value of 0 signals end of stream
    //rethrows the exception passed to setWriteError(), if any
    size_t read(void* buffer, size_t bytesToRead) //throw <write error>
    {
        if (bytesToRead == 0) //"read() with a count of 0 returns zero" => indistinguishable from end of file! => check!
            throw std::logic_error("Contract violation! " + std::string(__FILE__) + ':' + numberTo<std::string>(__LINE__));

        auto       it    = static_cast<std::byte*>(buffer);
        const auto itEnd = it + bytesToRead;

        for (std::unique_lock dummy(lockStream_); it != itEnd;)
        {
            assert(!errorRead_);
            //sleep until there is data, the stream was closed, or the writer failed
            conditionBytesWritten_.wait(dummy, [this] { return errorWrite_ || !ringBuf_.empty() || eof_; });

            if (errorWrite_)
                std::rethrow_exception(errorWrite_); //throw <write error>

            const size_t chunkSize = std::min(static_cast<size_t>(itEnd - it), ringBuf_.size());
            ringBuf_.extract_front(it, it + chunkSize);
            it += chunkSize;

            conditionBytesRead_.notify_all(); //buffer space was freed => wake up a blocked write()

            if (eof_) //end of file: buffer is now either drained or our request is satisfied
                break;
        }

        const size_t bytesRead = it - static_cast<std::byte*>(buffer);
        totalBytesRead_ += bytesRead;
        return bytesRead;
    }

    //context of output thread, blocking
    //appends all "bytesToWrite" bytes, waiting for free buffer space as needed
    //rethrows the exception passed to setReadError(), if any
    void write(const void* buffer, size_t bytesToWrite) //throw <read error>
    {
        //counted up front, i.e. even if this call is subsequently aborted by a read error
        totalBytesWritten_ += bytesToWrite; //bytes already processed as far as raw FTP access is concerned

        auto       it    = static_cast<const std::byte*>(buffer);
        const auto itEnd = it + bytesToWrite;

        for (std::unique_lock dummy(lockStream_); it != itEnd;)
        {
            assert(!eof_ && !errorWrite_);
            /*  => can't use InterruptibleThread's interruptibleWait() :(
                -> AsyncStreamBuffer is used for input and output streaming
                => both AsyncStreamBuffer::write()/read() would have to implement interruptibleWait()
                => one of these usually called from main thread
                => but interruptibleWait() cannot be called from main thread!          */
            conditionBytesRead_.wait(dummy, [this] { return errorRead_ || ringBuf_.size() < bufSize_; });

            if (errorRead_)
                std::rethrow_exception(errorRead_); //throw <read error>

            const size_t chunkSize = std::min(static_cast<size_t>(itEnd - it), bufSize_ - ringBuf_.size());
            ringBuf_.insert_back(it, it + chunkSize);
            it += chunkSize;

            conditionBytesWritten_.notify_all(); //new data available => wake up a blocked read()
        }
    }

    //context of output thread
    //marks end of stream: pending and future read() calls return the remaining buffered bytes, then 0
    void closeStream()
    {
        {
            std::lock_guard dummy(lockStream_);
            assert(!eof_ && !errorWrite_);
            eof_ = true;
        }
        conditionBytesWritten_.notify_all();
    }

    //context of input thread
    //makes pending and future write() calls (and checkReadErrors()) rethrow "error"; only the first error is kept
    void setReadError(const std::exception_ptr& error)
    {
        {
            std::lock_guard dummy(lockStream_);
            assert(!errorRead_);
            if (!errorRead_)
                errorRead_ = error;
        }
        conditionBytesRead_.notify_all();
    }

    //context of output thread
    //makes pending and future read() calls rethrow "error"; only the first error is kept
    void setWriteError(const std::exception_ptr& error)
    {
        {
            std::lock_guard dummy(lockStream_);
            assert(!errorWrite_);
            if (!errorWrite_)
                errorWrite_ = error;
        }
        conditionBytesWritten_.notify_all();
    }

    //context of *output* thread
    //non-blocking: rethrows the exception passed to setReadError(), if any
    void checkReadErrors() //throw <read error>
    {
        std::lock_guard dummy(lockStream_);
        if (errorRead_)
            std::rethrow_exception(errorRead_); //throw <read error>
    }

#if 0 //function not needed: when EOF is reached (without errors), reading is done => no further error can occur!
    void checkWriteErrors() //throw <write error>
    {
        std::lock_guard dummy(lockStream_);
        if (errorWrite_)
            std::rethrow_exception(errorWrite_); //throw <write error>
    }
#endif

    //atomic counters => safe to query from any thread
    uint64_t getTotalBytesWritten() const { return totalBytesWritten_; }
    uint64_t getTotalBytesRead   () const { return totalBytesRead_; }

private:
    AsyncStreamBuffer           (const AsyncStreamBuffer&) = delete;
    AsyncStreamBuffer& operator=(const AsyncStreamBuffer&) = delete;

    const size_t bufSize_; //capacity limit enforced by write()
    std::mutex lockStream_; //guards all members below except the atomic counters
    RingBuffer<std::byte> ringBuf_; //prefetch/output buffer
    bool eof_ = false;
    std::exception_ptr errorWrite_;
    std::exception_ptr errorRead_;
    std::condition_variable conditionBytesWritten_; //signaled by the output thread: data/eof/write error
    std::condition_variable conditionBytesRead_;    //signaled by the input thread: free space/read error

    std::atomic<uint64_t> totalBytesWritten_{0}; //std::atomic is uninitialized by default (before C++20)!
    std::atomic<uint64_t> totalBytesRead_   {0}; //
};
}

#endif //STREAM_BUFFER_H_08492572089560298
bgstack15