Class: MultiCompress::Inflater

Inherits: Object
Includes:
InflaterDefaults
Defined in:
ext/multi_compress/multi_compress.c

Instance Method Summary collapse

Constructor Details

#initialize(*args) ⇒ Object



2629
2630
2631
2632
2633
2634
2635
2636
2637
2638
2639
2640
2641
2642
2643
2644
2645
2646
2647
2648
2649
2650
2651
2652
2653
2654
2655
2656
2657
2658
2659
2660
2661
2662
2663
2664
2665
2666
2667
2668
2669
2670
2671
2672
2673
2674
2675
2676
2677
2678
2679
2680
2681
2682
2683
2684
2685
2686
2687
2688
2689
2690
2691
2692
2693
2694
2695
# File 'ext/multi_compress/multi_compress.c', line 2629

/*
 * Inflater#initialize(**opts) -> self
 *
 * Configures a streaming decompressor.  Recognized options:
 *   :algo       - :zstd (default), :brotli, or :lz4 (via sym_to_algo)
 *   :dictionary - a MultiCompress dictionary object; rejected for LZ4
 *   plus the output-size / ratio limit options parsed by
 *   parse_limits_from_opts().
 *
 * Raises eUnsupportedError when a dictionary is combined with LZ4,
 * eMemError when a codec context cannot be allocated, and eError when
 * zstd rejects the dictionary or fails stream initialization.
 */
static VALUE inflater_initialize(int argc, VALUE *argv, VALUE self) {
    VALUE opts;
    rb_scan_args(argc, argv, "0:", &opts);
    reject_algorithm_keyword(opts);

    inflater_t *inf;
    TypedData_Get_Struct(self, inflater_t, &inflater_type, inf);

    VALUE algo_sym = Qnil, dict_val = Qnil;
    limits_config_t limits;
    parse_limits_from_opts(opts, &limits);
    if (!NIL_P(opts)) {
        algo_sym = opt_get(opts, sym_cache.algo);
        dict_val = opt_get(opts, sym_cache.dictionary);
    }

    inf->algo = NIL_P(algo_sym) ? ALGO_ZSTD : sym_to_algo(algo_sym);
    inf->closed = 0;
    inf->finished = 0;
    inf->max_output_size = limits.max_output_size;
    inf->total_output = 0;
    inf->total_input = 0;
    inf->max_ratio_enabled = limits.max_ratio_enabled;
    inf->max_ratio = limits.max_ratio;

    dictionary_t *dict = NULL;
    if (!NIL_P(dict_val)) {
        if (inf->algo == ALGO_LZ4) {
            rb_raise(eUnsupportedError, "LZ4 does not support dictionaries");
        }
        dict = opt_dictionary(dict_val);
        /* Keep the dictionary object alive (and retrievable for #reset)
         * via an instance variable on self. */
        dictionary_ivar_set(self, dict_val);
    }

    switch (inf->algo) {
    case ALGO_ZSTD:
        inf->ctx.zstd = ZSTD_createDStream();
        if (!inf->ctx.zstd)
            rb_raise(eMemError, "zstd: failed to create dstream");
        if (dict) {
            ZSTD_DCtx_reset(inf->ctx.zstd, ZSTD_reset_session_only);
            size_t r = ZSTD_DCtx_loadDictionary(inf->ctx.zstd, dict->data, dict->size);
            if (ZSTD_isError(r))
                rb_raise(eError, "zstd dict load: %s", ZSTD_getErrorName(r));
        } else {
            /* ZSTD_initDStream returns an error code on failure; the
             * original code ignored it. */
            size_t r = ZSTD_initDStream(inf->ctx.zstd);
            if (ZSTD_isError(r))
                rb_raise(eError, "zstd dstream init: %s", ZSTD_getErrorName(r));
        }
        break;
    case ALGO_BROTLI:
        inf->ctx.brotli = BrotliDecoderCreateInstance(NULL, NULL, NULL);
        if (!inf->ctx.brotli)
            rb_raise(eMemError, "brotli: failed to create decoder");
        if (dict) {
            BrotliDecoderAttachDictionary(inf->ctx.brotli, BROTLI_SHARED_DICTIONARY_RAW, dict->size,
                                          dict->data);
        }
        break;
    case ALGO_LZ4:
        /* LZ4 path keeps its own accumulation buffer for re-framing
         * partial blocks; start with a 16 KiB capacity. */
        inf->lz4_buf.cap = 16 * 1024;
        inf->lz4_buf.buf = ALLOC_N(char, inf->lz4_buf.cap);
        inf->lz4_buf.len = 0;
        inf->lz4_buf.offset = 0;
        break;
    }

    return self;
}

Instance Method Details

#close ⇒ Object



3023
3024
3025
3026
3027
3028
3029
3030
3031
3032
3033
3034
3035
3036
3037
3038
3039
3040
3041
3042
3043
3044
3045
3046
3047
# File 'ext/multi_compress/multi_compress.c', line 3023

/*
 * Inflater#close -> nil
 *
 * Releases the underlying codec context (zstd dstream or brotli decoder)
 * and marks the stream closed.  Idempotent: a second call is a no-op.
 * The LZ4 accumulation buffer is left for the object's free function.
 */
static VALUE inflater_close(VALUE self) {
    inflater_t *state;
    TypedData_Get_Struct(self, inflater_t, &inflater_type, state);

    if (!state->closed) {
        if (state->algo == ALGO_ZSTD && state->ctx.zstd) {
            ZSTD_freeDStream(state->ctx.zstd);
            state->ctx.zstd = NULL;
        } else if (state->algo == ALGO_BROTLI && state->ctx.brotli) {
            BrotliDecoderDestroyInstance(state->ctx.brotli);
            state->ctx.brotli = NULL;
        }
        state->closed = 1;
    }
    return Qnil;
}

#closed? ⇒ Boolean

Returns:

  • (Boolean)


3049
3050
3051
3052
3053
# File 'ext/multi_compress/multi_compress.c', line 3049

/* Inflater#closed? -> true | false. Reports whether #close has been called. */
static VALUE inflater_closed_p(VALUE self) {
    inflater_t *state;
    TypedData_Get_Struct(self, inflater_t, &inflater_type, state);
    if (state->closed)
        return Qtrue;
    return Qfalse;
}

#finish ⇒ Object



2969
2970
2971
2972
2973
2974
2975
2976
# File 'ext/multi_compress/multi_compress.c', line 2969

/*
 * Inflater#finish -> ""
 *
 * Marks the stream finished and returns an empty binary string.  No
 * flushing is needed on the decompression side; this exists for API
 * symmetry with the deflater.  Raises eStreamError if already closed.
 */
static VALUE inflater_finish(VALUE self) {
    inflater_t *state;
    TypedData_Get_Struct(self, inflater_t, &inflater_type, state);
    if (state->closed)
        rb_raise(eStreamError, "stream is closed");
    state->finished = 1;
    return rb_binary_str_new("", 0);
}

#reset ⇒ Object



2978
2979
2980
2981
2982
2983
2984
2985
2986
2987
2988
2989
2990
2991
2992
2993
2994
2995
2996
2997
2998
2999
3000
3001
3002
3003
3004
3005
3006
3007
3008
3009
3010
3011
3012
3013
3014
3015
3016
3017
3018
3019
3020
3021
# File 'ext/multi_compress/multi_compress.c', line 2978

/*
 * Inflater#reset -> self
 *
 * Rewinds the inflater to its initial state so a new stream can be
 * decompressed: resets/recreates the codec context, re-applies any
 * dictionary stored at #initialize time, clears the LZ4 buffer window,
 * and zeroes the input/output counters and flags.
 *
 * Fix: the original skipped NULL contexts (freed by #close) but still
 * set closed = 0, so a write after close+reset dereferenced a NULL
 * zstd/brotli context.  We now lazily recreate a missing context,
 * making reset-after-close safe for every algorithm.
 */
static VALUE inflater_reset(VALUE self) {
    inflater_t *inf;
    TypedData_Get_Struct(self, inflater_t, &inflater_type, inf);

    /* Dictionary (if any) was pinned on self during #initialize. */
    VALUE dict_val = dictionary_ivar_get(self);
    dictionary_t *dict = NULL;
    if (!NIL_P(dict_val)) {
        TypedData_Get_Struct(dict_val, dictionary_t, &dictionary_type, dict);
    }

    switch (inf->algo) {
    case ALGO_ZSTD:
        if (!inf->ctx.zstd) {
            /* Context was freed by #close; rebuild it. */
            inf->ctx.zstd = ZSTD_createDStream();
            if (!inf->ctx.zstd)
                rb_raise(eMemError, "zstd: failed to recreate dstream");
        }
        ZSTD_DCtx_reset(inf->ctx.zstd, ZSTD_reset_session_only);
        if (dict) {
            size_t r = ZSTD_DCtx_loadDictionary(inf->ctx.zstd, dict->data, dict->size);
            if (ZSTD_isError(r))
                rb_raise(eError, "zstd dict reload on reset: %s", ZSTD_getErrorName(r));
        }
        break;
    case ALGO_BROTLI:
        /* Brotli decoders are not resettable; destroy (if alive) and
         * recreate, then reattach the dictionary. */
        if (inf->ctx.brotli) {
            BrotliDecoderDestroyInstance(inf->ctx.brotli);
        }
        inf->ctx.brotli = BrotliDecoderCreateInstance(NULL, NULL, NULL);
        if (!inf->ctx.brotli)
            rb_raise(eMemError, "brotli: failed to recreate decoder");
        if (dict) {
            BrotliDecoderAttachDictionary(inf->ctx.brotli, BROTLI_SHARED_DICTIONARY_RAW,
                                          dict->size, dict->data);
        }
        break;
    case ALGO_LZ4:
        /* Keep the allocation, just empty the window. */
        inf->lz4_buf.len = 0;
        inf->lz4_buf.offset = 0;
        break;
    }
    inf->closed = 0;
    inf->finished = 0;
    inf->total_output = 0;
    inf->total_input = 0;
    return self;
}

#write(chunk) ⇒ Object



2697
2698
2699
2700
2701
2702
2703
2704
2705
2706
2707
2708
2709
2710
2711
2712
2713
2714
2715
2716
2717
2718
2719
2720
2721
2722
2723
2724
2725
2726
2727
2728
2729
2730
2731
2732
2733
2734
2735
2736
2737
2738
2739
2740
2741
2742
2743
2744
2745
2746
2747
2748
2749
2750
2751
2752
2753
2754
2755
2756
2757
2758
2759
2760
2761
2762
2763
2764
2765
2766
2767
2768
2769
2770
2771
2772
2773
2774
2775
2776
2777
2778
2779
2780
2781
2782
2783
2784
2785
2786
2787
2788
2789
2790
2791
2792
2793
2794
2795
2796
2797
2798
2799
2800
2801
2802
2803
2804
2805
2806
2807
2808
2809
2810
2811
2812
2813
2814
2815
2816
2817
2818
2819
2820
2821
2822
2823
2824
2825
2826
2827
2828
2829
2830
2831
2832
2833
2834
2835
2836
2837
2838
2839
2840
2841
2842
2843
2844
2845
2846
2847
2848
2849
2850
2851
2852
2853
2854
2855
2856
2857
2858
2859
2860
2861
2862
2863
2864
2865
2866
2867
2868
2869
2870
2871
2872
2873
2874
2875
2876
2877
2878
2879
2880
2881
2882
2883
2884
2885
2886
2887
2888
2889
2890
2891
2892
2893
2894
2895
2896
2897
2898
2899
2900
2901
2902
2903
2904
2905
2906
2907
2908
2909
2910
2911
2912
2913
2914
2915
2916
2917
2918
2919
2920
2921
2922
2923
2924
2925
2926
2927
2928
2929
2930
2931
2932
2933
2934
2935
2936
2937
2938
2939
2940
2941
2942
2943
2944
2945
2946
2947
2948
2949
2950
2951
2952
2953
2954
2955
2956
2957
2958
2959
2960
2961
2962
2963
2964
2965
2966
2967
# File 'ext/multi_compress/multi_compress.c', line 2697

/*
 * Inflater#write(chunk) -> String
 *
 * Feeds +chunk+ (compressed bytes) to the streaming decompressor and
 * returns whatever decompressed output is available so far as a binary
 * string.  Partial frames are carried across calls: zstd/brotli buffer
 * internally in the codec context, LZ4 accumulates raw bytes in
 * inf->lz4_buf until a whole block is present.
 *
 * Limits: inf->max_output_size caps cumulative decompressed output and
 * (together with inf->max_ratio) is enforced *inside* the loops via
 * enforce_output_and_ratio_limits(), so a decompression bomb raises
 * eDataError before the output is fully materialized.
 *
 * Raises eStreamError if closed, eDataError on corrupt data or
 * exceeded limits.
 */
static VALUE inflater_write(VALUE self, VALUE chunk) {
    inflater_t *inf;
    TypedData_Get_Struct(self, inflater_t, &inflater_type, inf);
    if (inf->closed)
        rb_raise(eStreamError, "stream is closed");
    StringValue(chunk);

    const char *src = RSTRING_PTR(chunk);
    size_t slen = RSTRING_LEN(chunk);
    const algo_policy_t *policy = algo_policy(inf->algo);
    if (slen == 0)
        return rb_binary_str_new("", 0);

    /* Totals are committed to inf-> only at the end of the call; the
     * in-loop limit checks recompute provisional totals from this base
     * so a raise mid-loop leaves the persisted counters consistent. */
    size_t input_accounted_before = inf->total_input;

    switch (inf->algo) {
    case ALGO_ZSTD: {
        ZSTD_inBuffer input = {src, slen, 0};
        size_t out_cap = ZSTD_DStreamOutSize();
        /* Initial result capacity: at least one recommended output
         * chunk, heuristically 2x the input, clamped to the remaining
         * output budget. */
        size_t result_cap = out_cap > slen * 2 ? out_cap : slen * 2;
        size_t remaining_total_budget =
            inf->max_output_size > inf->total_output ? inf->max_output_size - inf->total_output : 0;
        if (remaining_total_budget == 0)
            rb_raise(eDataError, "decompressed output exceeds limit (%zu bytes)",
                     inf->max_output_size);
        if (result_cap > remaining_total_budget)
            result_cap = remaining_total_budget;
        VALUE result = rb_binary_str_buf_reserve(result_cap);
        size_t result_len = 0;
        VALUE scheduler = current_fiber_scheduler();

        while (input.pos < input.size) {
            size_t remaining_budget = inf->max_output_size - inf->total_output - result_len;
            if (remaining_budget == 0)
                rb_raise(eDataError, "decompressed output exceeds limit (%zu bytes)",
                         inf->max_output_size);

            /* Grow the result string (doubling, budget-clamped) so the
             * next output window fits. */
            if (result_len + out_cap > result_cap) {
                size_t next_cap = result_cap * 2;
                if (next_cap > inf->max_output_size - inf->total_output)
                    next_cap = inf->max_output_size - inf->total_output;
                result_cap = next_cap;
                rb_str_resize(result, result_cap);
            }

            size_t current_out_cap = out_cap;
            if (current_out_cap > remaining_budget)
                current_out_cap = remaining_budget;

            ZSTD_outBuffer output = {RSTRING_PTR(result) + result_len, current_out_cap, 0};
            size_t ret;

            /* Large inputs under a fiber scheduler are offloaded to a
             * worker so the scheduler is not blocked while zstd runs
             * without the GVL; small inputs decompress inline. */
            if (select_fiber_or_direct_mode(scheduler, input.size - input.pos,
                                            policy->fiber_stream_threshold) == WORK_EXEC_FIBER) {
                zstd_decompress_stream_chunk_fiber_t args = {
                    .dstream = inf->ctx.zstd,
                    .output = output,
                    .input = input,
                    .result = 0,
                };
                RUN_VIA_FIBER_WORKER(zstd_decompress_stream_chunk_fiber_nogvl, args);
                output.pos = args.output.pos;
                input.pos = args.input.pos;
                ret = args.result;
            } else {
                ret = ZSTD_decompressStream(inf->ctx.zstd, &output, &input);
            }

            if (ZSTD_isError(ret))
                rb_raise(eDataError, "zstd decompress stream: %s", ZSTD_getErrorName(ret));
            result_len = checked_add_size(result_len, output.pos,
                                          "decompressed output exceeds representable size");
            /* Provisional totals for limit enforcement (overflow-checked). */
            size_t total_output = checked_add_size(
                inf->total_output, result_len, "decompressed output exceeds representable size");
            size_t total_input = checked_add_size(input_accounted_before, input.pos,
                                                  "compressed input exceeds representable size");
            enforce_output_and_ratio_limits(total_output, total_input, inf->max_output_size,
                                            inf->max_ratio_enabled, inf->max_ratio);
            /* ret == 0 means a frame completed exactly here. */
            if (ret == 0)
                break;
        }
        /* Commit counters only after the loop succeeded. */
        inf->total_input = checked_add_size(input_accounted_before, input.pos,
                                            "compressed input exceeds representable size");
        inf->total_output = checked_add_size(inf->total_output, result_len,
                                             "decompressed output exceeds representable size");
        rb_str_set_len(result, result_len);
        RB_GC_GUARD(chunk);
        return result;
    }
    case ALGO_BROTLI: {
        size_t available_in = slen;
        const uint8_t *next_in = (const uint8_t *)src;
        size_t remaining_total_budget =
            inf->max_output_size > inf->total_output ? inf->max_output_size - inf->total_output : 0;
        if (remaining_total_budget == 0)
            rb_raise(eDataError, "decompressed output exceeds limit (%zu bytes)",
                     inf->max_output_size);
        size_t result_cap = slen * 2;
        if (result_cap < 1024)
            result_cap = 1024;
        if (result_cap > remaining_total_budget)
            result_cap = remaining_total_budget;
        VALUE result = rb_binary_str_buf_reserve(result_cap);
        size_t result_len = 0;
        VALUE scheduler = current_fiber_scheduler();

        while (available_in > 0 || BrotliDecoderHasMoreOutput(inf->ctx.brotli)) {
            /* available_out == 0 / next_out == NULL tells brotli to keep
             * output in its internal buffer; we drain it afterwards with
             * BrotliDecoderTakeOutput. */
            size_t available_out = 0;
            uint8_t *next_out = NULL;
            BrotliDecoderResult res;

            /* Same fiber-offload policy as the zstd path. */
            if (select_fiber_or_direct_mode(scheduler, available_in,
                                            policy->fiber_stream_threshold) == WORK_EXEC_FIBER) {
                brotli_decompress_stream_fiber_t sargs = {
                    .dec = inf->ctx.brotli,
                    .available_in = available_in,
                    .next_in = next_in,
                    .available_out = available_out,
                    .next_out = next_out,
                    .result = BROTLI_DECODER_RESULT_ERROR,
                };
                RUN_VIA_FIBER_WORKER(brotli_decompress_stream_fiber_nogvl, sargs);
                available_in = sargs.available_in;
                next_in = sargs.next_in;
                available_out = sargs.available_out;
                next_out = sargs.next_out;
                res = sargs.result;
            } else {
                res = BrotliDecoderDecompressStream(inf->ctx.brotli, &available_in, &next_in,
                                                    &available_out, &next_out, NULL);
            }
            if (res == BROTLI_DECODER_RESULT_ERROR)
                rb_raise(eDataError, "brotli decompress stream: %s",
                         BrotliDecoderErrorString(BrotliDecoderGetErrorCode(inf->ctx.brotli)));
            const uint8_t *out_data;
            size_t out_size = 0;
            out_data = BrotliDecoderTakeOutput(inf->ctx.brotli, &out_size);
            if (out_size > 0) {
                /* Enforce limits BEFORE copying the new output. */
                size_t total_output = checked_add_size(
                    inf->total_output,
                    checked_add_size(result_len, out_size,
                                     "decompressed output exceeds representable size"),
                    "decompressed output exceeds representable size");
                size_t total_input =
                    checked_add_size(input_accounted_before, slen - available_in,
                                     "compressed input exceeds representable size");
                enforce_output_and_ratio_limits(total_output, total_input, inf->max_output_size,
                                                inf->max_ratio_enabled, inf->max_ratio);

                if (result_len + out_size > result_cap) {
                    result_cap = result_len + out_size;
                    rb_str_resize(result, result_cap);
                }

                memcpy(RSTRING_PTR(result) + result_len, out_data, out_size);
                result_len += out_size;
            }
            if (res == BROTLI_DECODER_RESULT_SUCCESS)
                break;
            /* Decoder wants more input and we have none buffered: stop
             * and wait for the next #write. */
            if (res == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT && available_in == 0)
                break;
        }
        inf->total_input = checked_add_size(input_accounted_before, slen - available_in,
                                            "compressed input exceeds representable size");
        inf->total_output = checked_add_size(inf->total_output, result_len,
                                             "decompressed output exceeds representable size");
        rb_str_set_len(result, result_len);
        RB_GC_GUARD(chunk);
        return result;
    }
    case ALGO_LZ4: {
        /* LZ4 uses a private block framing (little-endian):
         *   [u32 orig_size][u32 comp_size][comp_size bytes]
         * with orig_size == 0 acting as the end-of-stream marker.
         * Incoming bytes are appended to lz4_buf; complete blocks are
         * decoded from it, with `offset` marking consumed bytes. */
        size_t data_len = inf->lz4_buf.len - inf->lz4_buf.offset;
        size_t needed = data_len + slen;
        // TODO(v0.4): optional standard LZ4 frame format support via lz4frame.h

        /* Compact the buffer (shift unconsumed bytes to the front)
         * either when appending would overflow capacity, or when more
         * than half the buffer is dead space. */
        if (inf->lz4_buf.offset > 0 && needed > inf->lz4_buf.cap) {
            if (data_len > 0)
                memmove(inf->lz4_buf.buf, inf->lz4_buf.buf + inf->lz4_buf.offset, data_len);
            inf->lz4_buf.offset = 0;
            inf->lz4_buf.len = data_len;
        } else if (inf->lz4_buf.offset > inf->lz4_buf.cap / 2) {
            if (data_len > 0)
                memmove(inf->lz4_buf.buf, inf->lz4_buf.buf + inf->lz4_buf.offset, data_len);
            inf->lz4_buf.offset = 0;
            inf->lz4_buf.len = data_len;
        }

        /* Recompute after possible compaction, then grow and append. */
        needed = inf->lz4_buf.len + slen;
        if (needed > inf->lz4_buf.cap) {
            inf->lz4_buf.cap = needed * 2;
            REALLOC_N(inf->lz4_buf.buf, char, inf->lz4_buf.cap);
        }
        memcpy(inf->lz4_buf.buf + inf->lz4_buf.len, src, slen);
        inf->lz4_buf.len += slen;

        size_t remaining_total_budget =
            inf->max_output_size > inf->total_output ? inf->max_output_size - inf->total_output : 0;
        if (remaining_total_budget == 0)
            rb_raise(eDataError, "decompressed output exceeds limit (%zu bytes)",
                     inf->max_output_size);
        size_t result_cap = slen * 4;
        if (result_cap < 256)
            result_cap = 256;
        if (result_cap > remaining_total_budget)
            result_cap = remaining_total_budget;
        VALUE result = rb_binary_str_buf_new(result_cap);
        size_t result_len = 0;
        int use_fiber = has_fiber_scheduler();
        size_t fiber_counter = 0;

        /* Decode every complete block currently buffered. */
        size_t pos = inf->lz4_buf.offset;
        while (pos + 4 <= inf->lz4_buf.len) {
            const uint8_t *p = (const uint8_t *)(inf->lz4_buf.buf + pos);
            uint32_t orig_size = read_le_u32(p);
            if (orig_size == 0) {
                /* End-of-stream marker. */
                inf->finished = 1;
                pos += 4;
                break;
            }
            /* Incomplete header or incomplete block body: wait for more
             * input on a later #write. */
            if (pos + 8 > inf->lz4_buf.len)
                break;
            uint32_t comp_size = read_le_u32(p + 4);
            if (pos + 8 + comp_size > inf->lz4_buf.len)
                break;
            /* Hard per-block cap guards against hostile headers. */
            if (orig_size > 64 * 1024 * 1024)
                rb_raise(eDataError, "lz4 stream: block too large (%u)", orig_size);

            /* Limits are checked against the block's DECLARED size
             * before any decompression work happens. */
            size_t total_output =
                checked_add_size(inf->total_output,
                                 checked_add_size(result_len, orig_size,
                                                  "decompressed output exceeds representable size"),
                                 "decompressed output exceeds representable size");
            size_t total_input = checked_add_size(
                input_accounted_before, (pos + 8 + (size_t)comp_size) - inf->lz4_buf.offset,
                "compressed input exceeds representable size");
            enforce_output_and_ratio_limits(total_output, total_input, inf->max_output_size,
                                            inf->max_ratio_enabled, inf->max_ratio);

            if (result_len + orig_size > result_cap) {
                result_cap = result_len + orig_size;
                rb_str_resize(result, result_cap);
            }

            int dsize =
                LZ4_decompress_safe(inf->lz4_buf.buf + pos + 8, RSTRING_PTR(result) + result_len,
                                    (int)comp_size, (int)orig_size);
            if (dsize < 0)
                rb_raise(eDataError, "lz4 stream decompress block failed");

            result_len += dsize;
            pos += 8 + comp_size;
            /* Cooperatively yield to the fiber scheduler every
             * policy->fiber_yield_chunk decompressed bytes. */
            if (use_fiber) {
                int did_yield = 0;
                fiber_counter = fiber_maybe_yield(fiber_counter, (size_t)dsize,
                                                  policy->fiber_yield_chunk, &did_yield);
                (void)did_yield;
            }
        }

        /* Only bytes actually consumed (pos advance) count as input. */
        inf->total_input = checked_add_size(input_accounted_before, pos - inf->lz4_buf.offset,
                                            "compressed input exceeds representable size");
        inf->lz4_buf.offset = pos;
        inf->total_output = checked_add_size(inf->total_output, result_len,
                                             "decompressed output exceeds representable size");
        rb_str_set_len(result, result_len);
        RB_GC_GUARD(chunk);
        return result;
    }
    }
    return rb_binary_str_new("", 0);
}