import vapoursynth as vs
from vapoursynth import core

def FixFrameDurations(out, src, num, den):
    """Recompute ``_DurationNum``/``_DurationDen`` on a frame-rate-converted clip.

    ``out`` is expected to hold ``num / den`` frames for every frame of ``src``
    (for example the output of a frame interpolator).  Output frame ``n``
    covers the source-time span ``[n * den / num, (n + 1) * den / num)``,
    measured in source frames.  Its new duration is the matching share of the
    duration(s) of the source frame(s) it overlaps, so per-frame source
    durations are carried over.

    Args:
        out: clip whose frame durations are rewritten.
        src: source clip whose frame durations are the reference.
        num: integer numerator of the frame-count ratio ``out / src``.
        den: integer denominator of that ratio.

    Returns:
        ``out`` with updated duration properties.  A frame is copied unchanged
        when either matching source frame lacks ``_DurationNum`` or
        ``_DurationDen``.
    """
    from fractions import Fraction

    # src0[n] is the source frame in which output frame n starts, and src1[n]
    # is the source frame after it.  std.Interleave and std.SelectEvery rescale
    # the duration props by default (modify_duration=True).  As a result, each
    # frame of src0/src1 carries its source frame's duration multiplied by
    # den/num; source_duration() undoes that rescaling.
    src0 = core.std.Interleave([src] * num)
    src1 = core.std.Interleave([src.std.Trim(first=1)] * num)
    if den > 1:
        src0 = src0.std.SelectEvery(cycle=den, offsets=0)
        src1 = src1.std.SelectEvery(cycle=den, offsets=0)

    duration_keys = ('_DurationNum', '_DurationDen')

    def source_duration(props):
        # Exact source-frame duration in seconds, with the den/num rescaling removed.
        return Fraction(props['_DurationNum'], props['_DurationDen']) * num / den

    def set_duration(n, f):
        # f[0] is the frame of `out`; f[1] and f[2] are the matching src0/src1 frames.
        # NOTE: a disabled experiment (formerly commented-out code) picked a
        # different input frame when `_DoVi` differed between f[1] and f[2];
        # the output frame is currently always taken from `out`.
        fout = f[0].copy()
        props0, props1 = f[1].props, f[2].props
        if not all(key in props for props in (props0, props1) for key in duration_keys):
            return fout
        dur0 = source_duration(props0)
        dur1 = source_duration(props1)
        # Work in integer units of 1/num source frame, so that the position
        # inside the source frame and the boundary test are exact.  Float
        # arithmetic could misclassify frames because of rounding.
        start = (n * den) % num
        end = start + den
        if end <= num:
            # The output frame lies entirely inside source frame src0[n].
            dur = dur0 * den / num
        else:
            # It straddles into the next source frame; weight each frame by its overlap.
            dur = (dur0 * (num - start) + dur1 * (end - num)) / num
        # Round (rather than truncate) to the 1e-7 s time base written below.
        fout.props['_DurationNum'] = round(dur * 10000000)
        fout.props['_DurationDen'] = 10000000
        return fout

    return core.std.ModifyFrame(clip=out, clips=[out, src0, src1], selector=set_duration)

import re
from vsmlrt import RIFE

def RIFE_imp(clip: vs.VideoNode, multi, model, backend) -> vs.VideoNode:
    """Run vs-mlrt RIFE interpolation selected by a model version string.

    Args:
        clip: input clip.
        multi: frame-rate multiplier forwarded to ``RIFE``.
        model: model name of the form ``"<major>.<minor>[_...]"``, for example
            ``"4.6"``, ``"4.22_lite"`` or ``"4.6_v2"``.  A ``lite`` substring
            selects the lite variant, ``ensemble`` enables ensemble mode, and a
            ``_v2`` suffix selects implementation 2.
        backend: vs-mlrt backend object.  A shallow copy is configured, so the
            caller's object is left untouched.

    Returns:
        The clip returned by ``RIFE``.
    """
    import copy

    parts = re.split(r'[._]', model)
    major, minor = parts[0], parts[1]
    # "4.6" -> 46 and "4.22" -> 422: the minor version may have one or two digits.
    model_num = int(major) * (10 if len(minor) == 1 else 100) + int(minor)
    if 'lite' in model:
        model_num = model_num * 10 + 1  # e.g. "4.22_lite" -> 4221
    if model_num < 30:
        model_num = 46  # versions below 3.0 fall back to 4.6
    ensemble = 'ensemble' in model
    implementation = 2 if model.endswith('_v2') else 1

    # Configure a private copy.  Mutating the caller's backend would leak these
    # settings (e.g. fp16=True from a "_v2" model) into later calls reusing it.
    backend = copy.copy(backend)
    backend.force_fp16 = True
    backend.output_format = 1
    backend.use_cuda_graph = True
    backend.workspace = None
    if implementation == 2:
        backend.force_fp16 = False
        backend.fp16 = True

    # Pass by keyword so the call does not depend on RIFE's positional
    # parameter order.  Some vs-mlrt releases may place other parameters
    # (e.g. video_player) before _implementation.
    return RIFE(
        clip,
        multi=multi,
        scale=1.0,
        tiles=None,
        tilesize=None,
        overlap=None,
        model=model_num,
        backend=backend,
        ensemble=ensemble,
        _implementation=implementation,
    )
