import os
import socket
from time import sleep, time
import sys
import tempfile
from urllib.parse import urlparse

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.utils.data import TensorDataset
from torch.distributed.optim import ZeroRedundancyOptimizer
import torch.optim as optim

from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

#from torch.profiler import profile, ProfilerActivity, record_function

# From UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication
# available but not enabled.
#torch.set_float32_matmul_precision('high')

# Allow triton to choose precompiled kernels from aotriton?
#torch._inductor.config.max_autotune = True

# Reproducibility: seed every RNG torch exposes with one fixed value and ask
# cuDNN for deterministic behaviour so repeated runs are as comparable as
# possible.
init_seed = 19311

# CPU generator first, then the current CUDA device, then all CUDA devices.
torch.manual_seed(init_seed)
torch.cuda.manual_seed(init_seed)
torch.cuda.manual_seed_all(init_seed)

# No cuDNN autotuning, and deterministic cuDNN algorithms only.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


class ToyModel(nn.Module):
    """Tiny MLP: Linear(10, 10) -> ReLU -> Linear(10, 5).

    Args:
        apply_regional_compilation: When true, only the ReLU submodule is
            wrapped in ``torch.compile``; both linear layers are left as
            plain ``nn.Linear`` modules.
    """

    def __init__(self, apply_regional_compilation=False):
        super().__init__()
        # Submodules are registered in the order net1, relu, net2.
        self.net1 = nn.Linear(10, 10)
        activation = nn.ReLU()
        self.relu = (
            torch.compile(activation) if apply_regional_compilation else activation
        )
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        """Apply net1, the activation and net2 in sequence to ``x``."""
        hidden = self.net1(x)
        activated = self.relu(hidden)
        return self.net2(activated)


def demo_basic(rank):
    """Run a small DDP training loop on this rank and print epoch timings.

    The function picks a device for the rank, wraps a ``torch.compile``-d
    ``ToyModel`` in ``DistributedDataParallel`` and trains it with SGD on
    random inputs, performing one optimisation step per batch drawn from a
    synthetic image ``DataLoader``.  A DDP-wrapped ``Discriminator`` with a
    ``ZeroRedundancyOptimizer`` is also constructed, but the loop does not
    run it (see the NOTE(review) in the body).

    The caller is expected to have initialised the default process group
    (``spmd_main`` does so before calling this function).

    Args:
        rank: Global rank of this process.  Rank 0 additionally prints the
            progress lines.
    """
    local_hostname = socket.gethostname()
    if rank == 0:
        print("Rank 0 on node {}".format(local_hostname))
    # compute our local rank on the node and select a corresponding gpu,
    # this assumes we started exactly one rank per gpu on the node
    ngpus_per_node = torch.cuda.device_count()
    print(f'ngpus per node {ngpus_per_node}')
    use_cuda = torch.cuda.is_available() and ngpus_per_node > 0
    # With no visible GPU the modulo would raise ZeroDivisionError and the CPU
    # fallback below could never run, so only compute it when GPUs exist.
    local_rank = rank % ngpus_per_node if use_cuda else 0
    print(f'Rank: {rank} on host {local_hostname} has local_rank: {local_rank}')
    #activities = [ProfilerActivity.CPU]
    if use_cuda:
        torch.cuda.set_device(local_rank)
        device = torch.device('cuda')
        ddp_device_ids = [local_rank]
        #activities += [ProfilerActivity.CUDA]
        print(f'Rank: {rank} Torch on host {local_hostname} has cuda device {torch.cuda.current_device()}')
    else:
        device = torch.device('cpu')
        # DDP requires device_ids=None for CPU modules.
        ddp_device_ids = None

    # --- ToyModel: the model the training loop below actually runs ---------
    #toy_model = ToyModel(apply_regional_compilation=True).to(device)
    toy_model = torch.compile(ToyModel().to(device), dynamic=True)
    #toy_model = ToyModel().to(device)
    ddp_model = DDP(toy_model, device_ids=ddp_device_ids)
    #my_auto_wrap_policy = size_based_auto_wrap_policy
    #ddp_model = FSDP(toy_model, auto_wrap_policy=my_auto_wrap_policy)
    #ddp_model.to_device(device_ids=[device])

    loss_fn = nn.MSELoss()
    toy_optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # --- Synthetic image data and Discriminator ----------------------------
    batch_size = 32
    n_data_per_rank = batch_size*600
    image_size = 256
    nc = 1
    ndf = 256
    num_epochs = 10
    lr = 0.0002

    class Discriminator(nn.Module):
        """Stack of stride-2 convolutions reducing an image to one sigmoid score."""

        def __init__(self):
            super().__init__()
            self.main = nn.Sequential(
                # input is (nc) x 256 x 256
                nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf) x 128 x 128
                nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 2),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf*2) x 64 x 64
                nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 4),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf*4) x 32 x 32
                nn.Conv2d(ndf * 4, ndf * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 4),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf*4) x 16 x 16
                nn.Conv2d(ndf * 4, ndf * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 4),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf*4) x 8 x 8
                nn.Conv2d(ndf * 4, ndf * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 4),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf*4) x 4 x 4
                nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
                nn.Sigmoid()
            )

        def forward(self, x):
            return self.main(x)

    # Random images and 0/1 labels held in host memory.  With the default
    # float32 dtype X is roughly 5 GB per rank (19200 x 1 x 256 x 256).
    X = torch.rand((n_data_per_rank, nc, image_size, image_size))
    Y = torch.randint(low=0, high=2, size=(n_data_per_rank, 1, 1, 1)).float()

    # NOTE(review): the Discriminator, its BCELoss criterion and its
    # ZeroRedundancyOptimizer are set up here, but the loop below never runs
    # them.  Previously the single name `optimizer` was rebound to this ZeRO
    # optimizer, so the loop stepped Discriminator parameters (which never
    # receive gradients) while ToyModel's gradients were neither applied nor
    # cleared.  Each model now keeps its own optimizer; to train the
    # Discriminator instead, feed `data`/`target` through `disc_model` and
    # `criterion` and step `disc_optimizer`.
    disc_model = Discriminator()
    disc_model.to(device)  # move to GPU
    disc_model = nn.parallel.DistributedDataParallel(
        disc_model, device_ids=ddp_device_ids
    )

    criterion = nn.BCELoss()

    disc_optimizer = ZeroRedundancyOptimizer(
        disc_model.parameters(),
        optimizer_class=torch.optim.Adam,
        lr=lr
    )

    train_dataset = TensorDataset(X, Y)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )

    if rank == 0:
        print("Starting Training Loop...")

    #with profile(activities=activities, profile_memory=True, record_shapes=True) as prof:
    #with profile(activities=[]) as prof:
    for epoch in range(num_epochs):
        t0 = time()
        #with record_function(f'training_iteration_{epoch:03d}'):

        for data, target in train_loader:
            toy_optimizer.zero_grad()
            # NOTE(review): the batch is copied to the device but not fed to
            # any model; the loader only sets the number of steps per epoch.
            data = data.to(device)
            target = target.to(device)
            outputs = ddp_model(torch.randn(20, 10).to(device))
            labels = torch.randn(20, 5).to(device)
            loss = loss_fn(outputs, labels)
            loss.backward()
            toy_optimizer.step()
        t1 = time()
        #print(f'Rank {rank} Epoch {epoch:03d} Loss {loss:0.4f} Time {t1-t0:0.4f}')
        if rank == 0:
            print(f'Epoch {epoch:03d} Loss {loss:0.4f} Time {t1-t0:0.4f}')

    #sort_by_keyword = str(device) + "_time_total"
    #print(f'Rank: {rank}\n' + prof.key_averages().table(sort_by=sort_by_keyword, row_limit=30))

    #print(f'Rank: {rank}\n' + prof.key_averages().table(sort_by="cpu_memory_usage", row_limit=30))

    #prof.export_chrome_trace("trace.json")


def spmd_main():
    """Initialise the process group, run ``demo_basic`` and tear it down.

    The group is created with ``init_method="env://"``, so the launcher must
    provide the rendezvous settings through environment variables.  NCCL is
    used when CUDA is available; otherwise the Gloo backend is selected so
    that the CPU path of ``demo_basic`` can be reached.
    """
    # equivalent to MPI init.
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend, init_method="env://")

    # look up our rank in the job and run the demo on it
    rank = dist.get_rank()
    demo_basic(rank)

    # Tear down the process group once every rank has finished
    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    # Script entry point: spmd_main() runs in this very process and nothing is
    # spawned from here, so an external launcher is expected to start one
    # process per rank (spmd_main initialises the group via env://).
    spmd_main()
