#!/usr/bin/env python3
import os
from sys import argv
from ofunctions import bisection
from typing import List, Tuple, Union, Iterable, Optional
from ipaddress import IPv6Address, AddressValueError
from command_runner import command_runner
from requests import get
import socket
import warnings
import logging

def ping(targets: Union[Iterable[str], str] = None, mtu: int = 1200, retries: int = 2,
         timeout: float = 4, interval: float = 1, ip_type: int = None, do_not_fragment: bool = False,
         all_targets_must_succeed: bool = False) -> bool:
    """
    Test whether ICMP echo (ping) reaches one or more targets.

    If all_targets_must_succeed is False, one reachable target is enough for a True
    result; otherwise every target must answer.

    targets: a single target string, or any iterable of target strings.
             Defaults to the Cloudflare, Google and OpenDNS DNS servers.
    mtu: packet size to test, in bytes; the ICMP payload actually sent is
         mtu minus 28 bytes (8-byte ICMP header + 20-byte IPv4 header).
    retries: number of ping attempts per target; each target gets its own budget.
    timeout: per-ping timeout, in seconds.
    interval: interval between packets, in seconds (non-Windows only).
    ip_type: 4 or 6 to force the IP version (-4/-6), None to let ping decide.
    do_not_fragment: set the don't-fragment flag (not applied when ip_type is 6).
    all_targets_must_succeed: require every target to answer instead of at least one.

    Returns True on success as defined above, False otherwise. With an empty
    targets iterable this returns all_targets_must_succeed.
    Raises ValueError if mtu is smaller than the 28-byte ICMP/IPv4 overhead.
    """
    # NOTE(review): the overhead assumes a 20-byte IPv4 header; IPv6 headers are
    # 40 bytes, so for ip_type=6 the packet on the wire is 20 bytes larger than mtu.
    icmp_overhead = 8 + 20
    mtu_encapsulated = mtu - icmp_overhead
    if mtu_encapsulated < 0:
        raise ValueError('MTU cannot be lower than {}.'.format(icmp_overhead))

    # Give the subprocess 5 extra seconds on top of the ping timeout before killing it
    command_timeout = int(timeout + 5)
    # Windows ping expects an integer timeout in milliseconds
    windows_timeout = int(timeout * 1000)

    if targets is None:
        # Cloudflare, Google and OpenDNS dns servers
        targets = ['1.1.1.1', '8.8.8.8', '208.67.222.222']
    elif isinstance(targets, str):
        # A single target given as a plain string (strings are iterable, so check first)
        targets = [targets]

    # The command prefix does not depend on the target, so build it once.
    if os.name == 'nt':
        # -n 1: one echo request; -l: payload size; -w: timeout (ms); -f: do not fragment
        base_command = 'ping -n 1 -l {} -w {}'.format(mtu_encapsulated, windows_timeout)
        dont_fragment_option = ' -f'
        encoding = 'cp437'
    else:
        # -c 1: one echo request; -s: payload size; -W: timeout (s);
        # -i: interval (s), only root can set less than .2 seconds; -M do: do not fragment
        base_command = 'ping -c 1 -s {} -W {} -i {}'.format(mtu_encapsulated, timeout, interval)
        dont_fragment_option = ' -M do'
        encoding = 'utf-8'

    # IPv6 does not allow to set fragmentation (no DF bit), so skip it for ip_type 6
    if do_not_fragment and ip_type != 6:
        base_command += dont_fragment_option

    # Add ip_type if specified
    if ip_type:
        base_command += ' -{}'.format(ip_type)

    def _try_server(target: str) -> bool:
        """Ping target up to `retries` times; True on the first zero exit code."""
        # Uses a fresh attempt counter per target instead of mutating the shared
        # `retries`, which previously starved every target after the first failure.
        for _ in range(retries):
            exit_code, _output = command_runner('{} {}'.format(base_command, target),
                                                timeout=command_timeout, encoding=encoding)
            if exit_code == 0:
                return True
        return False

    # all()/any() short-circuit exactly like the original break statements
    if all_targets_must_succeed:
        return all(_try_server(target) for target in targets)
    return any(_try_server(target) for target in targets)

def probe_mtu(target: str, method: str = 'ICMP', min: int = 1100, max: int = 9000):
    """
    Detect the MTU to target by bisecting over candidate MTUs.

    Every MTU in the inclusive range [min, max] is a candidate ping() call with
    do_not_fragment=True. ofunctions.bisection.bisect selects among them, so
    presumably only a logarithmic number of candidates are actually pinged.
    Probing can take a while: each failing candidate may cost up to 2 pings,
    each with a 4 second timeout.

    Standard values are
      1500 for ethernet over WAN
      1492 for ethernet over ADSL
      9000 for ethernet over LAN with jumbo frames
      13xx for ethernet over 3G/4G
    65535 bytes is the largest possible IPv4 packet.

    target: host name or IP address. IPv6 literals are pinged with -6, anything
            else (IPv4 literal or host name) with -4.
    method: probing method; only 'ICMP' is implemented.
    min, max: inclusive MTU search range. These names shadow builtins but are kept
              for backward compatibility with keyword callers.

    Returns the MTU of the argument tuple selected by bisection.bisect.
    Raises ValueError for an unsupported method or when min > max.
    """
    if method != 'ICMP':
        raise ValueError("Method {} not implemented yet".format(method))
    if min > max:
        # Otherwise an empty candidate list would be handed to bisect
        raise ValueError('min ({}) must not be greater than max ({})'.format(min, max))

    # ping() needs the IP version for -4/-6; anything that is not an IPv6 literal
    # is assumed to be IPv4.
    try:
        IPv6Address(target)
        ip_type = 6
    except AddressValueError:
        ip_type = 4

    # Positional ping() arguments: targets, mtu, retries=2 (guards against false
    # negatives), timeout=4, interval=1, ip_type, do_not_fragment=True
    ping_args = [(target, mtu, 2, 4, 1, ip_type, True) for mtu in range(min, max + 1)]

    # bisect returns the selected argument tuple; index 1 is the MTU
    return bisection.bisect(ping, ping_args, allow_all_expected=True)[1]

# Script entry point: probe the MTU to the target given as the first CLI argument.
# Guarded so that importing this module does not start a probe.
if __name__ == '__main__':
    if len(argv) < 2:
        # SystemExit with a string prints it to stderr and exits with status 1
        raise SystemExit('usage: {} <target>'.format(argv[0]))
    print('getting mtu for: ' + argv[1])
    print('please hold this will take a while \r')
    print(probe_mtu(argv[1]))