def mp_data(args, idx, queue_data):
    """Data-feeder worker: loads the training arrays onto device ``idx`` and
    streams random mini-batches into ``queue_data`` for a trainer process.

    Protocol: first puts ``train_len`` (int), then a sequence of
    ``(inp, lab)`` tensor pairs following the same epoch/batch schedule the
    trainer iterates, and finally waits for the queue to drain before exiting.
    """
    args.init_before_training(if_main=False)

    data_path = args.data_path
    img_shape = args.img_shape
    if_one_hot = args.if_one_hot
    batch_sizes = args.batch_sizes
    train_epochs = args.train_epochs
    del args  # every needed field has been copied out

    device = torch.device(f"cuda:{idx}" if torch.cuda.is_available() else 'cpu')

    '''load data'''
    train_imgs, train_labs, eval__imgs, eval__labs = load_data_ary(data_path, img_shape, if_one_hot)
    train_imgs = torch.as_tensor(train_imgs, dtype=torch.float32, device=device)

    lab_dtype = torch.float32 if if_one_hot else torch.long
    train_labs = torch.as_tensor(train_labs, dtype=lab_dtype, device=device)
    del lab_dtype

    train_len = train_imgs.shape[0]
    del eval__imgs, eval__labs  # this worker never evaluates

    queue_data.put(train_len)  # consumer side: train_len = queue_data.get()

    '''training loop'''
    # Pad the epoch/batch schedules (repeating their last entry) so they
    # match the length of the learning-rate schedule used by the trainer.
    learning_rates = [1e-3] * len(train_epochs) + [1e-4] * 2 + [1e-5] * 2
    train_epochs.extend(train_epochs[-1:] * (len(learning_rates) - len(train_epochs)))
    batch_sizes.extend(batch_sizes[-1:] * (len(learning_rates) - len(batch_sizes)))

    for n_epoch, batch_size, _lr in zip(train_epochs, batch_sizes, learning_rates):
        batches_per_epoch = int(train_len / batch_size)
        for _epoch in range(n_epoch):
            for _step in range(batches_per_epoch):
                ids = rd.randint(train_len, size=batch_size)
                queue_data.put((train_imgs[ids], train_labs[ids]))
                # consumer side: inp, lab = queue_data.get()

    '''safe exit'''
    # Stay alive until the consumer has drained the queue, then terminate.
    while queue_data.qsize() > 0:
        time.sleep(2)
def train_and_evaluate(args):
    """Single-process training loop: load data, train the network under a
    staged learning-rate schedule, evaluate periodically, and save the model.

    :param args: config object; consumed fields are copied to locals and
        ``args`` is deleted so the object can be garbage-collected.
    """
    net_class = args.net_class
    data_path = args.data_path
    img_shape = args.img_shape

    if_amp = args.if_amp          # use automatic mixed precision
    if_one_hot = args.if_one_hot  # one-hot labels -> regression-style loss

    mid_dim = args.mid_dim
    mod_dir = args.mod_dir
    gpu_id = args.gpu_id
    train_epochs = args.train_epochs
    batch_sizes = args.batch_sizes
    show_gap = args.show_gap
    eval_gap = args.eval_gap
    del args

    whether_remove_history(mod_dir, remove=True)

    '''init env'''
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)  # choose GPU:0
    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
    np.random.seed(1943 + int(time.time()))
    torch.manual_seed(1943 + rd.randint(0, int(time.time())))

    '''load data'''
    train_imgs, train_labs, eval__imgs, eval__labs = load_data_ary(
        data_path, img_shape, if_one_hot)
    train_imgs = torch.as_tensor(train_imgs,
                                 dtype=torch.float32,
                                 device=device)
    eval__imgs = torch.as_tensor(eval__imgs,
                                 dtype=torch.float32,
                                 device=device)

    label_data_type = torch.float32 if if_one_hot else torch.long
    train_labs = torch.as_tensor(train_labs,
                                 dtype=label_data_type,
                                 device=device)
    eval__labs = torch.as_tensor(eval__labs,
                                 dtype=label_data_type,
                                 device=device)
    del label_data_type
    train_len = train_imgs.shape[0]
    eval__len = eval__imgs.shape[0]
    eval_size = min(2**12, eval__len)

    '''train model'''
    model = net_class(mid_dim, img_shape).to(device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=1e-3,
                                 weight_decay=1e-5)
    from torch.nn import functional
    # one-hot labels -> SmoothL1 regression on probabilities;
    # class-index labels -> NLL loss (expects log-probabilities, see below).
    criterion = torch.nn.SmoothL1Loss() if if_one_hot else functional.nll_loss
    amp_scale = torch.cuda.amp.GradScaler()

    '''evaluator'''
    evaluator = Evaluator(eval__imgs, eval__labs, eval_size, eval_gap,
                          show_gap, criterion)
    save_path = f'{mod_dir}/net.pth'

    '''gradient descent, optionally with automatic mixed precision'''
    def gradient_descent_original():
        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), max_norm=3.0)
        optimizer.step()

    def gradient_descent_amp():  # automatic mixed precision
        optimizer.zero_grad()
        amp_scale.scale(loss).backward()  # loss.backward()
        amp_scale.unscale_(optimizer)  # so clip_grad_norm_ sees true grads
        clip_grad_norm_(model.parameters(),
                        max_norm=3.0)  # amp, clip_grad_norm_
        amp_scale.step(optimizer)  # optimizer.step()
        amp_scale.update()  # optimizer.step()

    gradient_descent = gradient_descent_amp if if_amp else gradient_descent_original

    print("Train Loop:")
    # LR schedule: 1e-3 for the configured stages, then 2 stages at 1e-4 and
    # 2 at 1e-5; the epoch/batch lists are padded by repeating their last entry.
    learning_rates = [1e-3, ] * len(train_epochs) + [1e-4, ] * 2 + [1e-5, ] * 2
    train_epochs.extend(train_epochs[-1:] *
                        (len(learning_rates) - len(train_epochs)))
    batch_sizes.extend(batch_sizes[-1:] *
                       (len(learning_rates) - len(batch_sizes)))

    for train_epoch, batch_size, learning_rate in zip(train_epochs,
                                                      batch_sizes,
                                                      learning_rates):
        optimizer.param_groups[0]['lr'] = learning_rate
        for epoch in range(train_epoch):
            loss_sum = 0
            model.train()
            '''train_it'''
            train_time = int(train_len / batch_size)
            for i in range(train_time):
                ids = rd.randint(train_len, size=batch_size)
                inp = train_imgs[ids]
                lab = train_labs[ids]

                out = model(inp)

                if if_one_hot:
                    loss = criterion(torch.softmax(out, dim=1), lab)
                else:
                    # FIX: nll_loss expects log-probabilities. The original
                    # passed plain softmax output, which yields an incorrect
                    # (non-negative-log-likelihood) loss.
                    loss = criterion(functional.log_softmax(out, dim=1), lab)
                gradient_descent()
                loss_sum += loss.item()
            loss_avg = loss_sum / train_time

            evaluator.evaluate(model, batch_size, train_epoch, epoch, loss_avg)
    evaluator.final_print()

    torch.save(model.state_dict(), save_path)
    file_size = os.path.getsize(save_path) / (2**20)  # bytes --> MB
    print(f"\nSave: {mod_dir} | {file_size:.2f} MB")
def mp_evaluate(args, pipe_eva1):
    """Evaluator process: receives model snapshots from the trainer over
    ``pipe_eva1``, evaluates them on CPU, and saves the final model.

    Expected messages (see the sender comments below):
      (idx, model_dict, batch_size, train_epoch, epoch, loss_avg) — evaluate;
      'break' — leave the evaluation loop and finish up.
    On exit, sends 'stop' to each of the ``num_worker`` peer processes.
    """
    args.init_before_training(if_main=True)

    net_class = args.net_class
    data_path = args.data_path
    img_shape = args.img_shape

    if_one_hot = args.if_one_hot
    num_worker = args.num_worker

    mid_dim = args.mid_dim
    mod_dir = args.mod_dir
    train_epochs = args.train_epochs
    batch_sizes = args.batch_sizes
    show_gap = args.show_gap
    eval_gap = args.eval_gap
    del args  # all needed fields copied out

    # Evaluation runs on CPU so it does not compete with the training GPUs.
    device = torch.device('cpu')
    '''load data'''
    train_imgs, train_labs, eval__imgs, eval__labs = load_data_ary(
        data_path, img_shape, if_one_hot)
    # train_imgs = torch.as_tensor(train_imgs, dtype=torch.float32, device=device)
    eval__imgs = torch.as_tensor(eval__imgs,
                                 dtype=torch.float32,
                                 device=device)

    label_data_type = torch.float32 if if_one_hot else torch.long
    # train_labs = torch.as_tensor(train_labs, dtype=label_data_type, device=device)
    eval__labs = torch.as_tensor(eval__labs,
                                 dtype=label_data_type,
                                 device=device)
    del label_data_type

    # train_len = train_imgs.shape[0]
    eval__len = eval__imgs.shape[0]
    eval_size = min(2**12, eval__len)  # evaluate on at most 4096 samples
    del train_imgs, train_labs  # training split is unused in this process
    '''train model'''
    model = net_class(mid_dim, img_shape).to(device)
    # NOTE(review): Module.to() moves in place and returns the same object,
    # so model_cpu aliases model; the del below just drops the extra name.
    model_cpu = model.to(torch.device("cpu"))  # for pipe1_eva
    [
        setattr(param, 'requires_grad', False)  # inference only: freeze params
        for param in model_cpu.parameters()
    ]
    del model

    from torch.nn import functional
    criterion = torch.nn.SmoothL1Loss() if if_one_hot else functional.nll_loss
    '''init evaluate'''
    evaluator = Evaluator(eval__imgs, eval__labs, eval_size, eval_gap,
                          show_gap, criterion)
    save_path = f'{mod_dir}/net.pth'

    # Mirror the trainer's schedule padding (extends the shared lists the
    # same way the trainer does; the lists themselves are unused below).
    learning_rates = [
        1e-3,
    ] * len(train_epochs) + [
        1e-4,
    ] * 2 + [
        1e-5,
    ] * 2
    train_epochs.extend(train_epochs[-1:] *
                        (len(learning_rates) - len(train_epochs)))
    batch_sizes.extend(batch_sizes[-1:] *
                       (len(learning_rates) - len(batch_sizes)))

    # pipe_eva2.send((idx, model_dict, batch_size, train_epoch, epoch, loss_avg))
    # pipe_eva2.send('break')
    pipe_receive = pipe_eva1.recv()  # block until the first snapshot arrives

    print("Train Loop:")
    with torch.no_grad():
        while True:
            # pipe_eva2.send((idx, model_dict, batch_size, train_epoch, epoch, loss_avg))
            # pipe_eva2.send('break')
            # Drain the pipe so only the most recent snapshot is evaluated.
            # NOTE(review): if nothing new arrived, the previous pipe_receive
            # is re-evaluated — presumably Evaluator rate-limits via eval_gap;
            # confirm this loop is not a busy re-evaluation by design.
            while pipe_eva1.poll():
                pipe_receive = pipe_eva1.recv()
                if pipe_receive == 'break':
                    break
            if pipe_receive == 'break':
                break

            idx, model_dict, batch_size, train_epoch, epoch, loss_avg = pipe_receive
            model_cpu.load_state_dict(model_dict)

            evaluator.evaluate(model_cpu, batch_size, train_epoch, epoch,
                               loss_avg)
    evaluator.final_print()

    torch.save(model_cpu.state_dict(), save_path)
    file_size = os.path.getsize(save_path) / (2**20)  # Byte --> KB --> MB
    print(f"\nSave: {mod_dir} | {file_size:.2f} MB")

    # Tell every worker we are done, then linger briefly so the sends land.
    for _ in range(num_worker):
        pipe_eva1.send('stop')
        # receive_signal = pipe_eva2.recv()
    time.sleep(2)