def main(learner):

    device = "cuda"

    ## load pretrained model ##
    global model
    # checkpoint filename of the pretrained loc-1 network (set model here)
    model = "loc-1-extensive-extracted_output_x_x_x_x_x_resized_x_x_0820_2220_66.pth"
    learner.load_state_dict(
        torch.load(
            '/mnt/CrowdData/dataset/resized/loc-1/loc-1-extensive-extracted_output_x_x_x_x_x_resized_x_x/log/0820_2220/models/'
            + model))
    # replace the final fully connected layer with a fresh 4-output regression head
    learner.full_conn = nn.Linear(in_features=640, out_features=4, bias=True)
    learner.variance4pool = 12
    ###########################

    learner = torch.nn.DataParallel(learner, device_ids=[0, 1])  # run on two GPUs

    train_data, test_data = get_data(learner.module.batch_size)
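    # get_data presumably returns the train and test loaders for this dataset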

    global now
    now = datetime.datetime.now()
    # per-run log directory stamped with the start time, e.g. .../log/0820_2220/
    log_dir = (dataset4loc_2.dataset_directory + dataset4loc_2.dataset_folder +
               '/log/{0:%m%d}_{0:%H%M}/'.format(now))
    if not os.path.exists(log_dir + 'models'):
        os.makedirs(log_dir + 'models')
        # snapshot the training and dataset scripts alongside the run logs
        shutil.copyfile(
            "./train_loc_2.py",
            log_dir + 'train_loc_2_{0:%m%d}_{0:%H%M}.py'.format(now))
        shutil.copyfile(
            "./dataset4loc_2.py",
            log_dir + 'dataset4loc_2_{0:%m%d}_{0:%H%M}.py'.format(now))

    learner = learner.to(device)
    cudnn.benchmark = True  # let cuDNN auto-tune convolution algorithms for fixed input sizes

    optimizer = optim.SGD(learner.parameters(),
                          lr=learner.module.lr,
                          momentum=learner.module.momentum,
                          weight_decay=learner.module.weight_decay,
                          nesterov=True)

    loss_mse = nn.MSELoss().cuda()  # mean squared error (note: not RMSE)

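    # MultiStepLR multiplies the learning rate by lr_decay at each epoch in lr_step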
    milestones = learner.module.lr_step
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=milestones,
                                               gamma=learner.module.lr_decay)

    rsl_keys = ["lr", "epoch", "TrainLoss", "TestLoss", "Time"]
    rsl = []

    print_result(rsl_keys)

    global patience, counter, best_loss, epoch
    patience = 200  # set any adequate number here
    counter = 0
    earlystopper = 0
    best_loss = None
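    # Early stopping: track the best test loss; if it fails to improve for
    # `patience` consecutive epochs, stop training after finishing the current
    # epoch's logging. A checkpoint is written only when the loss improves.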

    for epoch in range(learner.module.epochs):

        lr = optimizer.param_groups[0]["lr"]
        train_loss = train(device, optimizer, learner, train_data, loss_mse)

        learner.eval()  # evaluation mode: batch norm and dropout switch to inference behavior

        with torch.no_grad():
            test_loss = test(device, optimizer, learner, test_data, loss_mse)
            s = 'Test Loss: %.2f' % (test_loss)
            print(s)

            ### early stopping ###
            if best_loss is None:
                best_loss = test_loss
            elif test_loss > best_loss:
                counter += 1
                print("EarlyStopping: %i / %i" % (counter, patience))
                if counter >= patience:
                    print("EarlyStopping: Stop training")
                    earlystopper = 1
            else:
                best_loss = test_loss
                counter = 0
                torch.save(
                    learner.module.state_dict(),
                    log_dir + 'models/' +
                    os.path.basename(dataset4loc_2.dataset_folder) +
                    '_{0:%m%d}_{0:%H%M}_{1}.pth'.format(now, epoch))
            ######################

        time_now = str(datetime.datetime.today())
        rsl.append(
            dict(zip(rsl_keys, [lr, epoch + 1, train_loss, test_loss, time_now])))

        draw_graph.draw_graph_regress(
            learner.module.epochs,
            epoch,
            train_loss,
            test_loss,
            os.path.basename(dataset4loc_2.dataset_folder),
            save_place=log_dir,
            ymax=300.0)

        # hyperparameters to record alongside this epoch's results
        hp_for_record = get_arguments()
        otherparams = [
            hp_for_record[k] for k in (
                "batch_size", "lr", "momentum", "weight_decay",
                "width_coef1", "width_coef2", "width_coef3",
                "n_blocks1", "n_blocks2", "n_blocks3",
                "drop_rates1", "drop_rates2", "drop_rates3",
                "lr_decay")
        ]
        otherparams.append(time_now)

        save_place = log_dir
        #write_gspread.update_gspread(dataset4loc_2.dataset_folder, 'WRN', dataset4loc_2.dataset_directory, now, 'N/A(regress)', train_loss, 'N/A(regress)', test_loss, epoch+1, learner.module.epochs, False, save_place, otherparams)

        print_result(rsl[-1].values())
        scheduler.step()

        if earlystopper == 1:
            #write_gspread.update_gspread(dataset4loc_2.dataset_folder, 'WRN', dataset4loc_2.dataset_directory, now, 'N/A(regress)', train_loss, 'N/A(regress)', test_loss, epoch+1, learner.module.epochs, True, save_place, otherparams)
            break
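

# ----------------------------------------------------------------------
# Variant of main() for the generic regression dataset (dataset4regress).
# Unlike the version above, it does not load a pretrained checkpoint,
# uses a shorter early-stopping patience, and overwrites a single
# checkpoint file every epoch. If both definitions are kept in one
# module, this second definition shadows the first.
# ----------------------------------------------------------------------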
def main(learner):

    device = "cuda"
    learner = torch.nn.DataParallel(learner, device_ids=[0, 1])  # run on two GPUs

    train_data, test_data = get_data(learner.module.batch_size)

    global now
    now = datetime.datetime.now()
    # per-run log directory stamped with the start time
    log_dir = (dataset4regress.dataset_directory + dataset4regress.dataset_folder +
               '/log/{0:%m%d}_{0:%H%M}/'.format(now))
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    learner = learner.to(device)
    cudnn.benchmark = True

    optimizer = optim.SGD(learner.parameters(),
                          lr=learner.module.lr,
                          momentum=learner.module.momentum,
                          weight_decay=learner.module.weight_decay,
                          nesterov=True)

    loss_mse = nn.MSELoss().cuda()  # mean squared error (note: not RMSE)

    milestones = learner.module.lr_step
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=milestones,
                                               gamma=learner.module.lr_decay)

    rsl_keys = ["lr", "epoch", "TrainLoss", "TestLoss", "Time"]
    rsl = []

    print_result(rsl_keys)

    global patience, counter, best_loss, epoch
    patience = 20  # set any adequate number here
    counter = 0
    earlystopper = 0
    best_loss = None

    for epoch in range(learner.module.epochs):

        lr = optimizer.param_groups[0]["lr"]
        train_loss = train(device, optimizer, learner, train_data, loss_mse)

        learner.eval()  # evaluation mode: batch norm and dropout switch to inference behavior

        with torch.no_grad():
            test_loss = test(device, optimizer, learner, test_data, loss_mse)
            s = 'Test Loss: %.2f' % (test_loss)
            print(s)

            ### early stopping ###
            if best_loss is None:
                best_loss = test_loss
            elif test_loss > best_loss:
                counter += 1
                print("EarlyStopping: %i / %i" % (counter, patience))
                if counter >= patience:
                    print("EarlyStopping: Stop training")
                    earlystopper = 1
            else:
                best_loss = test_loss
                counter = 0
            ######################

        time_now = str(datetime.datetime.today())
        rsl.append(
            dict(zip(rsl_keys, [lr, epoch + 1, train_loss, test_loss, time_now])))

        draw_graph.draw_graph_regress(
            learner.module.epochs,
            epoch,
            train_loss,
            test_loss,
            os.path.basename(dataset4regress.dataset_folder),
            save_place=log_dir)

        hp_for_record = get_arguments()
        otherparams = [
            hp_for_record[k] for k in (
                "batch_size", "lr", "momentum", "weight_decay",
                "width_coef1", "width_coef2", "width_coef3",
                "n_blocks1", "n_blocks2", "n_blocks3",
                "drop_rates1", "drop_rates2", "drop_rates3",
                "lr_decay")
        ]
        otherparams.append(time_now)

        save_place = log_dir
        #write_gspread.update_gspread(dataset4regress.dataset_folder, 'WRN', dataset4regress.dataset_directory, now, 'N/A(regress)', train_loss, 'N/A(regress)', test_loss, epoch+1, learner.module.epochs, False, save_place, otherparams)

        print_result(rsl[-1].values())
        scheduler.step()

        # save the current weights; the filename carries no epoch suffix, so
        # this checkpoint is overwritten every epoch
        torch.save(
            learner.module.state_dict(),
            log_dir + os.path.basename(dataset4regress.dataset_folder) +
            '_{0:%m%d}_{0:%H%M}.pth'.format(now))

        if earlystopper == 1:
            #write_gspread.update_gspread(dataset4regress.dataset_folder, 'WRN', dataset4regress.dataset_directory, now, 'N/A(regress)', train_loss, 'N/A(regress)', test_loss, epoch+1, learner.module.epochs, True, save_place, otherparams)
            break
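
# A minimal entry-point sketch (hypothetical; not part of this file). The
# object passed to main() must expose batch_size, lr, momentum, weight_decay,
# lr_step, lr_decay, and epochs as attributes, since main() reads them
# through learner.module.*:
#
#     if __name__ == "__main__":
#         learner = build_learner()  # hypothetical constructor
#         main(learner)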