def trainNet(model, train_loader, val_loader, device, mask=None):

    print("=" * 30)

    # define the optimizer & learning rate
    # optim = torch.optim.SGD(
    #     model.parameters(), lr=config["learning_rate"], weight_decay=0.0001, momentum=0.9, nesterov=True
    # )
    optim = torch.optim.AdamW(model.parameters(), lr=config["learning_rate"])

    scheduler = StepLR(optim, step_size=config["lr_step_size"], gamma=config["lr_gamma"])

    log_dir = "../runs/" + networkName + str(int(datetime.now().timestamp()))
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    writer = Visualizer(log_dir)

    # Time for printing
    training_start_time = time.time()
    globaliter = 0
    scheduler_count = 0
    scaler = GradScaler()

    # initialize the early_stopping object
    early_stopping = EarlyStopping(log_dir, patience=config["patience"], verbose=True)

    # Loop for n_epochs
    for epoch in range(config["num_epochs"]):
        writer.write_lr(optim, globaliter)

        # train for one epoch
        globaliter = train(model, train_loader, optim, device, writer, epoch, globaliter, scaler, mask)

        # At the end of the epoch, do a pass on the validation set
        val_loss = validate(model, val_loader, device, writer, epoch, mask)

        # early_stopping needs the validation loss to check if it has decreased,
        # and if it has, it will make a checkpoint of the current model
        early_stopping(val_loss, model)

        if early_stopping.early_stop:
            print("Early stopping")
            if scheduler_count == 2:
                break
            model.load_state_dict(torch.load(log_dir + "/checkpoint.pt"))
            early_stopping.early_stop = False
            early_stopping.counter = 0
            scheduler.step()
            scheduler_count += 1
        if config["debug"] == True:
            break
    print("Training finished, took {:.2f}s".format(time.time() - training_start_time))

    model.load_state_dict(torch.load(log_dir + "/checkpoint.pt"))

    # remember to close the writer
    writer.close()
    def train(self, num_epochs, training_path, patience, metric_kind="Mcc"):
        #Inits
        self.writer = SummaryWriter(training_path)
        early_stopping = EarlyStopping(patience=patience,
                                       metric_kind=metric_kind)
        #Training /Eval Loop
        for epoch in tqdm(range(num_epochs)):
            #Init running metrics(Train/Test)
            metrics_train = self.init_metrics()
            metrics_eval = self.init_metrics()

            for i, batch in enumerate(self.train_loader):
                metrics_train = self.train_batch(batch, metrics_train)
                metrics_eval = self.evaluate_set(metrics_eval)
                self.step_lr_schedulers()

            self.log_metrics(metrics_train, metrics_eval, epoch, i)

            validation_info = metrics_train["total_loss"] / i
            #Check if we are overfitting
            early_stopping(metrics_eval[metric_kind] / i, self.models,
                           training_path)

            if early_stopping.early_stop:
                print("Early stopping")
                break

        self.writer.close()
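
Note: both snippets above assume an EarlyStopping helper that tracks the validation loss, counts epochs without improvement, and checkpoints the best weights. The class itself is not shown in these examples; the following is a minimal illustrative sketch of the interface used by the first snippet (constructor arguments, attribute names, and the checkpoint file name are assumptions, not the actual implementation from either repository).

import os

import numpy as np
import torch


class EarlyStopping:
    """Minimal sketch: stop once the validation loss has not improved for `patience` epochs."""

    def __init__(self, log_dir=".", patience=7, verbose=False, delta=0.0):
        self.log_dir = log_dir
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.counter = 0
        self.early_stop = False
        self.val_loss_min = np.inf

    def __call__(self, val_loss, model):
        if val_loss < self.val_loss_min - self.delta:
            # Improvement: checkpoint the model and reset the patience counter.
            if self.verbose:
                print(f"Validation loss decreased ({self.val_loss_min:.6f} -> {val_loss:.6f}). Saving model...")
            torch.save(model.state_dict(), os.path.join(self.log_dir, "checkpoint.pt"))
            self.val_loss_min = val_loss
            self.counter = 0
        else:
            # No improvement: raise the flag once patience runs out.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
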
Example #3
def main():
    is_training = True
    model = MODEL_DISPATCHER[BASE_MODEL](is_training)
    model.to(DEVICE)
    EarlyStoppingObject = EarlyStopping()

    Training_Dataset = BengaliAiDataset(
        folds=TRAINING_FOLDS,
        img_height=IMG_HEIGHT,
        img_width=IMG_WIDTH,
        mean=MODEL_MEAN,
        std=MODEL_STD)

    Train_DataLoader = torch.utils.data.DataLoader(dataset=Training_Dataset,
                                                   batch_size=TRAIN_BATCH_SIZE,
                                                   shuffle=True,
                                                   num_workers=4)

    Validation_Dataset = BengaliAiDataset(
        folds=VALIDATION_FOLDS,
        img_height=IMG_HEIGHT,
        img_width=IMG_WIDTH,
        mean=MODEL_MEAN,
        std=MODEL_STD)

    Validation_DataLoader = torch.utils.data.DataLoader(
        dataset=Validation_Dataset,
        batch_size=TEST_BATCH_SIZE,
        shuffle=False,
        num_workers=4)

    optimiser = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimiser, mode="min", patience=5, factor=0.3, verbose=True)

    for epoch in range(EPOCHS):
        train(Training_Dataset, Train_DataLoader, model, optimiser)
        validationScore = evaluate(Validation_Dataset, Validation_DataLoader,
                                   model, optimiser)
        scheduler.step(validationScore)
        print(f"EPOCH : {epoch} VALIDATION SCORE : {validationScore}")
        # torch.save(model.state_dict(), f"../input/output_models/{BASE_MODEL}_fold{VALIDATION_FOLDS[0]}.bin")
        EarlyStoppingObject(
            validationScore, model,
            f"../input/output_models/{BASE_MODEL}_fold{VALIDATION_FOLDS[0]}.bin"
        )
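
Note: Example #3 above calls its EarlyStopping object with the monitored value, the model, and an explicit checkpoint path, and feeds the same value to ReduceLROnPlateau in "min" mode. A minimal illustrative sketch of a helper with that call signature, treating the value as lower-is-better to stay consistent with the scheduler (class layout and defaults are assumptions, not the kernel's actual implementation):

import numpy as np
import torch


class EarlyStopping:
    """Minimal sketch: checkpoint to a caller-supplied path whenever the monitored value improves."""

    def __init__(self, patience=5, delta=0.0):
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.best = np.inf
        self.early_stop = False

    def __call__(self, value, model, path):
        if value < self.best - self.delta:
            # Improvement: persist the weights at the path chosen by the caller.
            self.best = value
            torch.save(model.state_dict(), path)
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
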
Example #4
def trainModel(train_loader,
               validation_loader,
               params,
               model,
               savedModelName,
               test_loader=None,
               device=None,
               detailed_reporting=False):
    n_epochs = 500
    patience = 5
    learning_rate = params["learning_rate"]
    modelType = params["modelType"]
    lambda_coarse = params["lambda"]
    lambda_fine = 1
    weight_decay = 0.0001

    df = pd.DataFrame()

    if device is None:
        print("training model on CPU!")
    if not os.path.exists(savedModelName):
        os.makedirs(savedModelName)
    saved_models_per_iteration = os.path.join(
        savedModelName, saved_models_per_iteration_folder)
    if not os.path.exists(saved_models_per_iteration):
        os.makedirs(saved_models_per_iteration)

    # Adaptive smoothing
    adaptive_smoothing_enabled = params["adaptive_smoothing"]
    adaptive_lambda = params["adaptive_lambda"] if adaptive_smoothing_enabled else None
    adaptive_alpha = params["adaptive_alpha"] if adaptive_smoothing_enabled else None
    if adaptive_smoothing_enabled and detailed_reporting:
        df_adaptive_smoothing = pd.DataFrame()

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate,
                                 weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()
    if device is not None:
        criterion = criterion.cuda()

    # early stopping
    early_stopping = EarlyStopping(path=savedModelName, patience=patience)

    print("Training started...")
    start = time.time()

    with tqdm(total=n_epochs, desc="iteration") as bar:
        epochs = 0
        for epoch in range(n_epochs):

            model.train()
            for i, batch in enumerate(train_loader):
                absolute_batch = i + epoch * len(train_loader)

                if device is not None:
                    batch["image"] = batch["image"].cuda()
                    batch["fine"] = batch["fine"].cuda()
                    batch["coarse"] = batch["coarse"].cuda()

                optimizer.zero_grad()
                with torch.set_grad_enabled(True):
                    z = applyModel(batch["image"], model)

                    loss_coarse = 0
                    if z["coarse"] is not None:
                        loss_coarse = criterion(z["coarse"], batch["coarse"])
                    loss_fine = criterion(z["fine"], batch["fine"])

                    adaptive_info = {
                        'batch': absolute_batch,
                        'epoch': epoch,
                        'loss_fine': loss_fine.item() if torch.is_tensor(loss_fine) else loss_fine,
                        'loss_coarse': loss_coarse.item() if torch.is_tensor(loss_coarse) else loss_coarse,
                        'lambda_fine': lambda_fine,
                        'lambda_coarse': lambda_coarse,
                    }
                    if adaptive_smoothing_enabled and z["coarse"] is not None:
                        # lambda_fine, lambda_coarse = get_total_adaptive_loss(adaptive_alpha, adaptive_lambda, loss_fine, loss_coarse)
                        lambda_fine, lambda_coarse = get_total_adaptive_loss(
                            adaptive_alpha, adaptive_lambda,
                            loss_fine / len(train_loader.dataset.csv_processor.
                                            getFineList()),
                            loss_coarse / len(train_loader.dataset.
                                              csv_processor.getCoarseList()))
                        adaptive_info['lambda_fine'] = lambda_fine
                        adaptive_info['lambda_coarse'] = lambda_coarse
                        if detailed_reporting:
                            df_adaptive_smoothing = df_adaptive_smoothing.append(
                                pd.DataFrame(adaptive_info, index=[0]),
                                ignore_index=True)

                    loss = lambda_fine * loss_fine + lambda_coarse * loss_coarse
                    loss.backward()

                    optimizer.step()

            model.eval()

            getCoarse = (modelType != "BB")

            # Get statistics
            predlist_val, lbllist_val = getLoaderPredictionProbabilities(
                validation_loader, model, params, device=device)
            validation_loss = getCrossEntropy(predlist_val, lbllist_val)
            predlist_val, lbllist_val = getPredictions(predlist_val,
                                                       lbllist_val)
            validation_fine_f1 = get_f1(predlist_val,
                                        lbllist_val,
                                        device=device)

            if getCoarse:
                predlist_val, lbllist_val = getLoaderPredictionProbabilities(
                    validation_loader, model, params, 'coarse', device=device)
                validation_coarse_loss = getCrossEntropy(
                    predlist_val, lbllist_val)
                predlist_val, lbllist_val = getPredictions(
                    predlist_val, lbllist_val)
                validation_coarse_f1 = get_f1(predlist_val,
                                              lbllist_val,
                                              device=device)

            predlist_train, lbllist_train = getLoaderPredictionProbabilities(
                train_loader, model, params, device=device)
            training_loss = getCrossEntropy(predlist_train, lbllist_train)
            predlist_train, lbllist_train = getPredictions(
                predlist_train, lbllist_train)
            train_fine_f1 = get_f1(predlist_train,
                                   lbllist_train,
                                   device=device)

            if detailed_reporting:
                if getCoarse:
                    predlist_train, lbllist_train = getLoaderPredictionProbabilities(
                        train_loader, model, params, 'coarse', device=device)
                    training_coarse_loss = getCrossEntropy(
                        predlist_train, lbllist_train)
                    predlist_train, lbllist_train = getPredictions(
                        predlist_train, lbllist_train)
                    training_coarse_f1 = get_f1(predlist_train,
                                                lbllist_train,
                                                device=device)

            if test_loader and detailed_reporting:
                predlist_test, lbllist_test = getLoaderPredictionProbabilities(
                    test_loader, model, params, device=device)
                predlist_test, lbllist_test = getPredictions(
                    predlist_test, lbllist_test)
                test_fine_f1 = get_f1(predlist_test,
                                      lbllist_test,
                                      device=device)

                predlist_test, lbllist_test = getLoaderPredictionProbabilities(
                    test_loader, model, params, 'coarse', device=device)
                predlist_test, lbllist_test = getPredictions(
                    predlist_test, lbllist_test)
                test_coarse_f1 = get_f1(predlist_test,
                                        lbllist_test,
                                        device=device)

            row_information = {
                'epoch': epoch,
                'validation_fine_f1': validation_fine_f1,
                'training_fine_f1': train_fine_f1,
                'test_fine_f1': test_fine_f1 if test_loader and detailed_reporting else None,
                'validation_loss': validation_loss,
                'training_loss': training_loss if detailed_reporting else None,
                'training_coarse_loss': training_coarse_loss if getCoarse and detailed_reporting else None,
                'validation_coarse_loss': validation_coarse_loss if getCoarse and detailed_reporting else None,
                'training_coarse_f1': training_coarse_f1 if getCoarse and detailed_reporting else None,
                'validation_coarse_f1': validation_coarse_f1 if getCoarse and detailed_reporting else None,
                'test_coarse_f1': test_coarse_f1 if test_loader and detailed_reporting else None,
            }

            df = df.append(pd.DataFrame(row_information, index=[0]),
                           ignore_index=True)

            # Update the bar
            bar.set_postfix(val=row_information["validation_fine_f1"],
                            train=row_information["training_fine_f1"],
                            val_loss=row_information["validation_loss"],
                            min_val_loss=early_stopping.val_loss_min)
            bar.update()

            # Save model
            if detailed_reporting and (
                    epochs % saved_models_per_iteration_frequency == 0):
                model_name_path = os.path.join(
                    savedModelName, saved_models_per_iteration_folder,
                    saved_models_per_iteration_name).format(epochs)
                try:
                    torch.save(model.state_dict(), model_name_path)
                except Exception:
                    print("model", model_name_path, "could not be saved!")

            # early stopping
            early_stopping(1 / row_information['validation_fine_f1'], epoch,
                           model)

            epochs = epochs + 1
            if early_stopping.early_stop:
                print("Early stopping")
                print("total number of epochs: ", epoch)

                # save the final model if it has not been saved already
                if detailed_reporting:
                    torch.save(
                        model.state_dict(),
                        os.path.join(savedModelName,
                                     saved_models_per_iteration_folder,
                                     modelFinalCheckpoint))

                break

        # Register time
        end = time.time()
        time_elapsed = end - start

        # load the last checkpoint with the best model
        model.load_state_dict(early_stopping.getBestModel())

        # save information
        if savedModelName is not None:
            # save model
            torch.save(model.state_dict(),
                       os.path.join(savedModelName, modelFinalCheckpoint))
            # save results
            df.to_csv(os.path.join(savedModelName, statsFileName))

            if adaptive_smoothing_enabled and detailed_reporting:
                df_adaptive_smoothing.to_csv(
                    os.path.join(savedModelName, adaptiveSmoothingFileName))

            with open(os.path.join(savedModelName, timeFileName),
                      'w',
                      newline='') as myfile:
                wr = csv.writer(myfile)
                wr.writerow([time_elapsed])
            with open(os.path.join(savedModelName, epochsFileName),
                      'w',
                      newline='') as myfile:
                wr = csv.writer(myfile)
                wr.writerow([epochs])
            # save params
            j = json.dumps(params)
            f = open(os.path.join(savedModelName, paramsFileName), "w")
            f.write(j)
            f.close()

    return df, epochs, time_elapsed
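
Note: Example #4 relies on a different EarlyStopping interface: it is constructed with a save path, called with (metric, epoch, model), exposes val_loss_min for the progress bar, and returns the best weights through getBestModel(). A minimal illustrative sketch of such a helper that keeps the best state_dict in memory (a guess at the interface, not the project's actual class):

import copy

import numpy as np


class EarlyStopping:
    """Minimal sketch: track a lower-is-better metric and remember the best weights in memory."""

    def __init__(self, path=".", patience=5):
        self.path = path
        self.patience = patience
        self.counter = 0
        self.early_stop = False
        self.val_loss_min = np.inf
        self.best_state_dict = None
        self.best_epoch = -1

    def __call__(self, metric, epoch, model):
        if metric < self.val_loss_min:
            # Improvement: keep a copy of the weights and the epoch they came from.
            self.val_loss_min = metric
            self.best_epoch = epoch
            self.best_state_dict = copy.deepcopy(model.state_dict())
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

    def getBestModel(self):
        # Returns the state_dict of the best epoch seen so far.
        return self.best_state_dict
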
Example #5
def train(exp=None):
    """
    main function to run the training
    """
    encoder = Encoder(encoder_params[0], encoder_params[1]).cuda()
    decoder = Decoder(decoder_params[0], decoder_params[1]).cuda()
    net = ED(encoder, decoder)
    run_dir = "./runs/" + TIMESTAMP
    if not os.path.isdir(run_dir):
        os.makedirs(run_dir)
    # tb = SummaryWriter(run_dir)
    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=20, verbose=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    if os.path.exists(os.path.join(save_dir, "checkpoint.pth.tar")):
        # load existing model
        print("==> loading existing model")
        model_info = torch.load(os.path.join(save_dir, "checkpoint.pth.tar"))
        net.load_state_dict(model_info["state_dict"])
        optimizer = torch.optim.Adam(net.parameters())
        optimizer.load_state_dict(model_info["optimizer"])
        cur_epoch = model_info["epoch"] + 1
    else:
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        cur_epoch = 0
    lossfunction = nn.MSELoss().cuda()
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    pla_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      factor=0.5,
                                                      patience=4,
                                                      verbose=True)

    # to track the training loss as the model trains
    train_losses = []
    # to track the validation loss as the model trains
    valid_losses = []
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = []
    # mini_val_loss = np.inf
    for epoch in range(cur_epoch, args.epochs + 1):
        if exp is not None:
            exp.log_metric("epoch", epoch)
        ###################
        # train the model #
        ###################
        t = tqdm(trainLoader, leave=False, total=len(trainLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
            inputs = inputVar.to(device)  # B,S,C,H,W
            label = targetVar.to(device)  # B,S,C,H,W
            optimizer.zero_grad()
            net.train()
            pred = net(inputs)  # B,S,C,H,W
            loss = lossfunction(pred, label)
            loss_aver = loss.item() / args.batch_size
            train_losses.append(loss_aver)
            loss.backward()
            torch.nn.utils.clip_grad_value_(net.parameters(), clip_value=10.0)
            optimizer.step()
            t.set_postfix({
                "trainloss": "{:.6f}".format(loss_aver),
                "epoch": "{:02d}".format(epoch),
            })
        # tb.add_scalar('TrainLoss', loss_aver, epoch)
        ######################
        # validate the model #
        ######################
        with torch.no_grad():
            net.eval()
            t = tqdm(validLoader, leave=False, total=len(validLoader))
            for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
                if i == 3000:
                    break
                inputs = inputVar.to(device)
                label = targetVar.to(device)
                pred = net(inputs)
                loss = lossfunction(pred, label)
                loss_aver = loss.item() / args.batch_size
                # record validation loss
                valid_losses.append(loss_aver)
                # print ("validloss: {:.6f},  epoch : {:02d}".format(loss_aver,epoch),end = '\r', flush=True)
                t.set_postfix({
                    "validloss": "{:.6f}".format(loss_aver),
                    "epoch": "{:02d}".format(epoch),
                })
        # tb.add_scalar('ValidLoss', loss_aver, epoch)
        torch.cuda.empty_cache()
        # print training/validation statistics
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        epoch_len = len(str(args.epochs))

        print_msg = (f"[{epoch:>{epoch_len}}/{args.epochs:>{epoch_len}}] " +
                     f"train_loss: {train_loss:.6f} " +
                     f"valid_loss: {valid_loss:.6f}")

        # print(print_msg)
        # clear lists to track next epoch
        if exp is not None:
            exp.log_metric("TrainLoss", train_loss)
            exp.log_metric("ValidLoss", valid_loss)
        train_losses = []
        valid_losses = []
        pla_lr_scheduler.step(valid_loss)  # lr_scheduler
        model_dict = {
            "epoch": epoch,
            "state_dict": net.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        if epoch % args.save_every == 0 and epoch != 0:
            torch.save(
                model_dict,
                save_dir + "/" + "checkpoint_{}_{:.6f}.pth.tar".format(
                    epoch, valid_loss.item()),
            )
        early_stopping(valid_loss.item(), model_dict, epoch, save_dir)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    with open("avg_train_losses.txt", "wt") as f:
        for i in avg_train_losses:
            print(i, file=f)

    with open("avg_valid_losses.txt", "wt") as f:
        for i in avg_valid_losses:
            print(i, file=f)
def train(X_train, y_train, X_dev, y_dev, model, embedding, args):
    # The minibatches below are moved to the GPU unconditionally, so the
    # model is placed on the GPU regardless of args.cuda.
    model.cuda()
    if args.Adam is True:
        print("Adam Training......")
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.init_weight_decay)
    elif args.SGD is True:
        print("SGD Training.......")
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    weight_decay=args.init_weight_decay,
                                    momentum=args.momentum_value)
    elif args.Adadelta is True:
        print("Adadelta Training.......")
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         lr=args.lr,
                                         weight_decay=args.init_weight_decay)

    criterion = nn.NLLLoss()

    steps = 0
    epoch_step = 0
    model_count = 0
    loss_full = []
    best_accuracy = Best_Result()
    model.train()
    print("###training model.... :", model)
    early_stopping = EarlyStopping(patience=args.patience, verbose=True)
    for epoch in range(1, args.epochs + 1):
        steps = 0
        print("\n## The {} Epoch, All {} Epochs ! ##".format(
            epoch, args.epochs))
        loss_epoch = []

        g = gen_minibatch(X_train, y_train, args.batch_size, shuffle=True)
        for tokens, labels in g:
            tokens = embedding(tokens.long())

            optimizer.zero_grad()

            logit = model(tokens.cuda())
            loss = criterion(logit, labels)
            loss.backward()
            #if args.init_clip_max_norm is not None:
            #    utils.clip_grad_norm_(model.parameters(), max_norm=args.init_clip_max_norm)
            optimizer.step()

            loss_full.append(loss.item())
            loss_epoch.append(loss.item())

        torch.cuda.empty_cache()
        print('Average training loss over minibatches this epoch:',
              np.mean(loss_epoch))
        model.eval()
        val_loss = []
        g = gen_minibatch(X_dev, y_dev, 4, shuffle=False)
        for tokens, labels in g:
            tokens = embedding(tokens.long())

            optimizer.zero_grad()

            logit = model(tokens.cuda())
            loss = criterion(logit, labels)
            #print(loss)
            val_loss.append(loss.data.cpu())

    # print("val_loss#:", val_loss)
        vlos = np.mean(val_loss)

        print('dev Loss at ', epoch, ' is ', vlos)

        torch.cuda.empty_cache()
        early_stopping(vlos, model)
        if early_stopping.early_stop:
            print("Early stopping")
            model_count += 1
            break

        model.train()

    return model_count
def train():
    '''
    main function to run the training
    '''
    # Instantiate the Encoder and Decoder
    encoder = Encoder(encoder_params[0], encoder_params[1]).cuda()
    decoder = Decoder(decoder_params[0], decoder_params[1]).cuda()

    # Instantiate the ED (encoder-decoder) model
    net = ED(encoder, decoder)
    # Run directory
    run_dir = './runs/' + TIMESTAMP
    # Create run_dir if it does not exist
    if not os.path.isdir(run_dir):
        os.makedirs(run_dir)
    tb = SummaryWriter(run_dir)
    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=20, verbose=True)

    # Check whether CUDA is available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Use data parallelism if more than one GPU is available
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    # Check whether checkpoint.pth.tar exists
    if os.path.exists(os.path.join(save_dir, 'checkpoint.pth.tar')):
        # Load the saved model
        print('==> loading existing model')
        model_info = torch.load(os.path.join(save_dir, 'checkpoint.pth.tar'))
        net.load_state_dict(model_info['state_dict'])
        optimizer = torch.optim.Adam(net.parameters())
        optimizer.load_state_dict(model_info['optimizer'])
        cur_epoch = model_info['epoch'] + 1
    else:
        # If checkpoint.pth.tar does not exist, create save_dir if needed
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        # Initialize the current epoch to 0
        cur_epoch = 0
    # Use MSELoss as the loss function
    lossfunction = nn.MSELoss().cuda()
    # Use the Adam optimizer
    optimizer = optim.Adam(net.parameters(), lr=args.lr)

    pla_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      factor=0.5,
                                                      patience=4,
                                                      verbose=True)

    # to track the training loss as the model trains
    train_losses = []
    # to track the validation loss as the model trains
    valid_losses = []
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = []
    # mini_val_loss = np.inf
    for epoch in range(cur_epoch, args.epochs + 1):
        ###################
        # train the model #
        ###################
        t = tqdm(trainLoader, leave=False, total=len(trainLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
            inputs = inputVar.to(device)  # B,S,C,H,W
            label = targetVar.to(device)  # B,S,C,H,W
            optimizer.zero_grad()
            net.train()
            pred = net(inputs)  # B,S,C,H,W
            loss = lossfunction(pred, label)
            loss_aver = loss.item() / args.batch_size
            train_losses.append(loss_aver)
            loss.backward()
            torch.nn.utils.clip_grad_value_(net.parameters(), clip_value=10.0)
            optimizer.step()
            t.set_postfix({
                'trainloss': '{:.6f}'.format(loss_aver),
                'epoch': '{:02d}'.format(epoch)
            })
        tb.add_scalar('TrainLoss', loss_aver, epoch)
        ######################
        # validate the model #
        ######################
        with torch.no_grad():
            net.eval()
            t = tqdm(validLoader, leave=False, total=len(validLoader))
            for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
                if i == 3000:
                    break
                inputs = inputVar.to(device)
                label = targetVar.to(device)
                pred = net(inputs)
                loss = lossfunction(pred, label)
                loss_aver = loss.item() / args.batch_size
                # record validation loss
                valid_losses.append(loss_aver)
                #print ("validloss: {:.6f},  epoch : {:02d}".format(loss_aver,epoch),end = '\r', flush=True)
                t.set_postfix({
                    'validloss': '{:.6f}'.format(loss_aver),
                    'epoch': '{:02d}'.format(epoch)
                })

        tb.add_scalar('ValidLoss', loss_aver, epoch)
        torch.cuda.empty_cache()
        # print training/validation statistics
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        epoch_len = len(str(args.epochs))

        print_msg = (f'[{epoch:>{epoch_len}}/{args.epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.6f} ' +
                     f'valid_loss: {valid_loss:.6f}')

        print(print_msg)
        # clear lists to track next epoch
        train_losses = []
        valid_losses = []
        pla_lr_scheduler.step(valid_loss)  # lr_scheduler
        model_dict = {
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        early_stopping(valid_loss.item(), model_dict, epoch, save_dir)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    with open("avg_train_losses.txt", 'wt') as f:
        for i in avg_train_losses:
            print(i, file=f)

    with open("avg_valid_losses.txt", 'wt') as f:
        for i in avg_valid_losses:
            print(i, file=f)
def train_one_folder(opt, folder):
    # Use specific GPU
    device = torch.device(opt.gpu_num)

    opt.folder = folder

    # Dataloaders
    train_dataset_file_path = os.path.join('../dataset', opt.source_domain,
                                           str(opt.folder), 'train.csv')
    train_loader = get_dataloader(train_dataset_file_path, 'train', opt)

    test_dataset_file_path = os.path.join('../dataset', opt.source_domain,
                                          str(opt.folder), 'test.csv')
    test_loader = get_dataloader(test_dataset_file_path, 'test', opt)

    # Model, optimizer and loss function
    emotion_recognizer = models.Model(opt)
    models.init_weights(emotion_recognizer)
    for param in emotion_recognizer.parameters():
        param.requires_grad = True
    emotion_recognizer.to(device)

    optimizer = torch.optim.Adam(emotion_recognizer.parameters(),
                                 lr=opt.learning_rate)
    lr_schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                             patience=1)

    criterion = torch.nn.CrossEntropyLoss()

    best_acc = 0.
    best_uar = 0.
    es = EarlyStopping(patience=opt.patience)

    # Train and validate
    for epoch in range(opt.epochs_num):
        if opt.verbose:
            print('epoch: {}/{}'.format(epoch + 1, opt.epochs_num))

        train_loss, train_acc = train(train_loader, emotion_recognizer,
                                      optimizer, criterion, device, opt)
        test_loss, test_acc, test_uar = test(test_loader, emotion_recognizer,
                                             criterion, device, opt)

        if opt.verbose:
            print('train_loss: {0:.5f}'.format(train_loss),
                  'train_acc: {0:.3f}'.format(train_acc),
                  'test_loss: {0:.5f}'.format(test_loss),
                  'test_acc: {0:.3f}'.format(test_acc),
                  'test_uar: {0:.3f}'.format(test_uar))

        lr_schedule.step(test_loss)

        os.makedirs(os.path.join(opt.logger_path, opt.source_domain),
                    exist_ok=True)

        model_file_name = os.path.join(opt.logger_path, opt.source_domain,
                                       'checkpoint.pth.tar')
        state = {
            'epoch': epoch + 1,
            'emotion_recognizer': emotion_recognizer.state_dict(),
            'opt': opt
        }
        torch.save(state, model_file_name)

        if test_acc > best_acc:
            model_file_name = os.path.join(opt.logger_path, opt.source_domain,
                                           'model.pth.tar')
            torch.save(state, model_file_name)

            best_acc = test_acc

        if test_uar > best_uar:
            best_uar = test_uar

        if es.step(test_loss):
            break

    return best_acc, best_uar
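
Note: train_one_folder above uses yet another convention: es.step(metric) returns True when training should stop, and the caller handles all checkpointing itself. A minimal illustrative sketch of that boolean-returning interface (names are assumptions, not the repository's actual class):

import numpy as np


class EarlyStopping:
    """Minimal sketch: step() returns True once the metric has not improved for `patience` calls."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = np.inf
        self.num_bad_steps = 0

    def step(self, metric):
        if metric < self.best - self.min_delta:
            self.best = metric
            self.num_bad_steps = 0
        else:
            self.num_bad_steps += 1
        # The caller breaks out of its training loop when this returns True.
        return self.num_bad_steps > self.patience
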
Example #9
def train(exp=None):
    """
    main function to run the training
    """
    encoder = Encoder(encoder_params[0], encoder_params[1]).cuda()
    decoder = Decoder(decoder_params[0], decoder_params[1]).cuda()
    net = ED(encoder, decoder)
    run_dir = "./runs/" + TIMESTAMP
    if not os.path.isdir(run_dir):
        os.makedirs(run_dir)
    # tb = SummaryWriter(run_dir)
    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=20, verbose=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    if os.path.exists(args.checkpoint) and args.continue_train:
        # load existing model
        print("==> loading existing model")
        model_info = torch.load(args.checkpoint)
        net.load_state_dict(model_info["state_dict"])
        optimizer = torch.optim.Adam(net.parameters())
        optimizer.load_state_dict(model_info["optimizer"])
        cur_epoch = model_info["epoch"] + 1
    else:
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        cur_epoch = 0
    lossfunction = nn.MSELoss().cuda()
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    pla_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      factor=0.5,
                                                      patience=4,
                                                      verbose=True)

    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = []
    # pnsr ssim
    avg_psnrs = {}
    avg_ssims = {}
    for j in range(args.frames_output):
        avg_psnrs[j] = []
        avg_ssims[j] = []
    if args.checkdata:
        # Checking dataloader
        print("Checking Dataloader!")
        t = tqdm(trainLoader, leave=False, total=len(trainLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
            assert targetVar.shape == torch.Size([
                args.batchsize, args.frames_output, 1, args.data_h, args.data_w
            ])
            assert inputVar.shape == torch.Size([
                args.batchsize, args.frames_input, 1, args.data_h, args.data_w
            ])
        print("TrainLoader checking is complete!")
        t = tqdm(validLoader, leave=False, total=len(validLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
            assert targetVar.shape == torch.Size([
                args.batchsize, args.frames_output, 1, args.data_h, args.data_w
            ])
            assert inputVar.shape == torch.Size([
                args.batchsize, args.frames_input, 1, args.data_h, args.data_w
            ])
        print("ValidLoader checking is complete!")
        # mini_val_loss = np.inf
    for epoch in range(cur_epoch, args.epochs + 1):
        # to track the training loss as the model trains
        train_losses = []
        # to track the validation loss as the model trains
        valid_losses = []
        psnr_dict = {}
        ssim_dict = {}
        for j in range(args.frames_output):
            psnr_dict[j] = 0
            ssim_dict[j] = 0
        image_log = []
        if exp is not None:
            exp.log_metric("epoch", epoch)
        ###################
        # train the model #
        ###################
        t = tqdm(trainLoader, leave=False, total=len(trainLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
            inputs = inputVar.to(device)  # B,S,C,H,W
            label = targetVar.to(device)  # B,S,C,H,W
            optimizer.zero_grad()
            net.train()
            pred = net(inputs)  # B,S,C,H,W
            loss = lossfunction(pred, label)
            loss_aver = loss.item() / args.batchsize
            train_losses.append(loss_aver)
            loss.backward()
            torch.nn.utils.clip_grad_value_(net.parameters(), clip_value=10.0)
            optimizer.step()
            t.set_postfix({
                "trainloss": "{:.6f}".format(loss_aver),
                "epoch": "{:02d}".format(epoch),
            })
        # tb.add_scalar('TrainLoss', loss_aver, epoch)
        ######################
        # validate the model #
        ######################
        with torch.no_grad():
            net.eval()
            t = tqdm(validLoader, leave=False, total=len(validLoader))
            for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
                inputs = inputVar.to(device)
                label = targetVar.to(device)
                pred = net(inputs)
                loss = lossfunction(pred, label)
                loss_aver = loss.item() / args.batchsize
                # record validation loss
                valid_losses.append(loss_aver)

                for j in range(args.frames_output):
                    psnr_dict[j] += psnr(pred[:, j], label[:, j])
                    ssim_dict[j] += ssim(pred[:, j], label[:, j])
                # print ("validloss: {:.6f},  epoch : {:02d}".format(loss_aver,epoch),end = '\r', flush=True)
                t.set_postfix({
                    "validloss": "{:.6f}".format(loss_aver),
                    "epoch": "{:02d}".format(epoch),
                })
                if i % 500 == 499:
                    for k in range(args.frames_output):
                        image_log.append(label[0, k].unsqueeze(0).repeat(
                            1, 3, 1, 1))
                        image_log.append(pred[0, k].unsqueeze(0).repeat(
                            1, 3, 1, 1))
                    upload_images(
                        image_log,
                        epoch,
                        exp=exp,
                        im_per_row=2,
                        rows_per_log=int(len(image_log) / 2),
                    )
        # tb.add_scalar('ValidLoss', loss_aver, epoch)
        torch.cuda.empty_cache()
        # print training/validation statistics
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)
        for j in range(args.frames_output):
            avg_psnrs[j].append(psnr_dict[j] / i)
            avg_ssims[j].append(ssim_dict[j] / i)
        epoch_len = len(str(args.epochs))

        print_msg = (f"[{epoch:>{epoch_len}}/{args.epochs:>{epoch_len}}] " +
                     f"train_loss: {train_loss:.6f} " +
                     f"valid_loss: {valid_loss:.6f}" +
                     f"PSNR_1: {psnr_dict[0] / i:.6f}" +
                     f"SSIM_1: {ssim_dict[0] / i:.6f}")

        # print(print_msg)
        # clear lists to track next epoch
        if exp is not None:
            exp.log_metric("TrainLoss", train_loss)
            exp.log_metric("ValidLoss", valid_loss)
            exp.log_metric("PSNR_1", psnr_dict[0] / i)
            exp.log_metric("SSIM_1", ssim_dict[0] / i)
        pla_lr_scheduler.step(valid_loss)  # lr_scheduler
        model_dict = {
            "epoch": epoch,
            "state_dict": net.state_dict(),
            "optimizer": optimizer.state_dict(),
            "avg_psnrs": avg_psnrs,
            "avg_ssims": avg_ssims,
            "avg_valid_losses": avg_valid_losses,
            "avg_train_losses": avg_train_losses,
        }
        save_flag = False
        if epoch % args.save_every == 0:
            torch.save(
                model_dict,
                save_dir + "/" +
                "checkpoint_{}_{:.6f}.pth".format(epoch, valid_loss.item()),
            )
            print("Saved" +
                  "checkpoint_{}_{:.6f}.pth".format(epoch, valid_loss.item()))
            save_flag = True
        if avg_psnrs[0][-1] == max(avg_psnrs[0]) and not save_flag:
            torch.save(
                model_dict,
                save_dir + "/" + "bestpsnr_1.pth",
            )
            print("Best psnr found and saved")
            save_flag = True
        if avg_ssims[0][-1] == max(avg_ssims[0]) and not save_flag:
            torch.save(
                model_dict,
                save_dir + "/" + "bestssim_1.pth",
            )
            print("Best ssim found and saved")
            save_flag = True
        if avg_valid_losses[-1] == min(avg_valid_losses) and not save_flag:
            torch.save(
                model_dict,
                save_dir + "/" + "bestvalidloss.pth",
            )
            print("Best validloss found and saved")
            save_flag = True
        if not save_flag:
            torch.save(
                model_dict,
                save_dir + "/" + "checkpoint.pth",
            )
            print("The latest normal checkpoint saved")
        early_stopping(valid_loss.item(), model_dict, epoch, save_dir)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    with open("avg_train_losses.txt", "wt") as f:
        for i in avg_train_losses:
            print(i, file=f)

    with open("avg_valid_losses.txt", "wt") as f:
        for i in avg_valid_losses:
            print(i, file=f)
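
Note: the encoder-decoder examples above pass a full checkpoint dict plus the epoch and a save directory into the early-stopping call, so the helper both tracks the validation loss and writes the .pth.tar checkpoint itself. A minimal illustrative sketch of that interface (file name and attributes are assumptions, not the original implementation):

import os

import numpy as np
import torch


class EarlyStopping:
    """Minimal sketch: save the best checkpoint dict, stop after `patience` epochs without improvement."""

    def __init__(self, patience=20, verbose=False):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.early_stop = False
        self.val_loss_min = np.inf

    def __call__(self, valid_loss, model_dict, epoch, save_dir):
        if valid_loss < self.val_loss_min:
            # Improvement: write the whole checkpoint dict (epoch, model and optimizer state).
            if self.verbose:
                print(f"Validation loss decreased ({self.val_loss_min:.6f} -> {valid_loss:.6f}). Saving checkpoint...")
            torch.save(model_dict, os.path.join(save_dir, "checkpoint.pth.tar"))
            self.val_loss_min = valid_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
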
    def train(self,
              save_model=False,
              save_loss=False,
              early_stopping=True,
              shuffle=False):
        """
        This function performs the training of our ConvNets. It compiles the
        Theano functions and performs parameter updates (by calling
        compile_functions), saves several useful pieces of information during
        training, and stops via early stopping, at which point the model
        parameters are also saved. All the basic components are described
        below as well as their respective modules/functions:
            1) functions compilation: Training.compile_functions(module:
                trainingtesting). Here you can also find optimization details
                such as regularization term in the loss for autoencoder
            2) load/save weights, early stopping: SaveWeights,
            LoadWeights(module: saveloadweights)
            3) networks definitions: module: networks.py. Here you can find
            details related with network design choices as well as
            regularization layers(e.g. dropout) or other techniques such as
            tied weights in the autoencoder.
        """
        dataset = os.path.join(self._datasets_dir, self._dataset)
        dataset += '.hdf5'
        dset = h5py.File(dataset, 'r')

        fn_train, fn_val, net, lr, lr_decay = self._compile_functions()

        if type(save_model) is not bool:
            raise TypeError('save_model should be boolean')
        if save_model:
            models_dir = './models'
            if not os.path.exists(models_dir):
                os.mkdir(models_dir)

            if self._network_type == self.SIMPLE:
                if self._input_channels == 1:
                    input_type = 'depth'
                elif self._input_channels == 4:
                    input_type = 'rgb'
                save_dir = '{0:s}/{1:s}/{2:s}/{3:f}'.format(
                    self._dataset, self._network_type, input_type,
                    self.convnet._model_hp_dict['p'])
                sw = SaveWeights(os.path.join(models_dir, save_dir), net,
                                 self._patience, 'loss')
            elif self._network_type == self.CONV_FUSING:
                save_dir = '{0:s}/{1:s}/{2:s}/{3:d}/{4:f}'.format(
                    self._dataset, self._network_type, self._fusion_type,
                    self._fusion_level, self.convnet._model_hp_dict['p'])
                sw = SaveWeights(os.path.join(models_dir, save_dir), net,
                                 self._patience, 'loss')
            elif self._network_type == self.DENSE_FUSING:
                save_dir = '{0:s}/{1:s}/{2:s}/{3:d}/{4:f}'.format(
                    self._dataset, self._network_type, self._fusion_type,
                    self._fusion_level, self.convnet._model_hp_dict['p'])
                sw = SaveWeights(os.path.join(models_dir, save_dir), net,
                                 self._patience, 'loss')
            elif self._network_type == self.SCORE_FUSING:
                save_dir = '{0:s}/{1:s}/{2:s}/{3:f}'.format(
                    self._dataset, self._network_type, self._fusion_type,
                    self.convnet._model_hp_dict['p'])
                sw = SaveWeights(os.path.join(models_dir, save_dir), net,
                                 self._patience, 'loss')
            elif self._network_type == self.INPUT_FUSING:
                save_dir = '{0:s}/{1:s}/{2:s}/{3:f}'.format(
                    self._dataset, self._network_type, self._fusion_type,
                    self.convnet._model_hp_dict['p'])
                sw = SaveWeights(os.path.join(models_dir, save_dir), net,
                                 self._patience, 'loss')
        elif early_stopping:
            es = EarlyStopping(net, self._patience, 'loss')
        if self._validate:
            idx_train, idx_val = load_dsets_trainval(
                './train_test_splits/nyu_split.npz')
            bg_train = BatchGenerator(dset,
                                      self._dataset,
                                      self._group,
                                      iterable=idx_train,
                                      shuffle=shuffle)
            bg_val = BatchGenerator(dset,
                                    self._dataset,
                                    self._group,
                                    iterable=idx_val,
                                    shuffle=shuffle)
        else:
            bg_train = BatchGenerator(dset,
                                      self._dataset,
                                      self._group,
                                      shuffle=shuffle)
        print('Training started...\n')
        if save_model:
            training_information = self._training_loop(bg_train,
                                                       bg_val,
                                                       fn_train,
                                                       fn_val,
                                                       lr,
                                                       lr_decay,
                                                       sw=sw)
        elif early_stopping:
            training_information = self._training_loop(bg_train,
                                                       bg_val,
                                                       fn_train,
                                                       fn_val,
                                                       lr,
                                                       lr_decay,
                                                       es=es)
        else:
            training_information = self._training_loop(bg_train, bg_val,
                                                       fn_train, fn_val, lr,
                                                       lr_decay)
        if self._save_settings:
            settings_dir = './settings'
            if not os.path.exists(settings_dir):
                os.mkdir(settings_dir)
            val_loss_array = np.array(training_information['val_loss'])
            best_loss = np.amin(val_loss_array)
            if self._network_type == self.SIMPLE:
                if self._input_channels == 1:
                    input_type = 'depth'
                elif self._input_channels == 4:
                    input_type = 'rgb'
                save_dir = '{0:s}/{1:s}/{2:s}/{3:f}'.format(
                    self._dataset, self._network_type, input_type,
                    self.convnet._model_hp_dict['p'])
                save_hyperparams(os.path.join(settings_dir,
                                              save_dir), self._opt_hp_dict,
                                 self._model_hp_dict, best_loss)
            elif self._network_type == self.CONV_FUSING:
                save_dir = '{0:s}/{1:s}/{2:s}/{3:d}/{4:f}'.format(
                    self._dataset, self._network_type, self._fusion_type,
                    self._fusion_level, self.convnet._model_hp_dict['p'])
                save_hyperparams(os.path.join(settings_dir,
                                              save_dir), self._opt_hp_dict,
                                 self._model_hp_dict, best_loss)
            elif self._network_type == self.DENSE_FUSING:
                save_dir = '{0:s}/{1:s}/{2:s}/{3:d}/{4:f}'.format(
                    self._dataset, self._network_type, self._fusion_type,
                    self._fusion_level, self.convnet._model_hp_dict['p'])
                save_hyperparams(os.path.join(settings_dir,
                                              save_dir), self._opt_hp_dict,
                                 self._model_hp_dict, best_loss)
            elif self._network_type == self.SCORE_FUSING:
                save_dir = '{0:s}/{1:s}/{2:s}/{3:f}'.format(
                    self._dataset, self._network_type, self._fusion_type,
                    self.convnet._model_hp_dict['p'])
                save_hyperparams(os.path.join(settings_dir,
                                              save_dir), self._opt_hp_dict,
                                 self._model_hp_dict, best_loss)
            elif self._network_type == self.INPUT_FUSING:
                save_dir = '{0:s}/{1:s}/{2:s}/{3:f}'.format(
                    self._dataset, self._network_type, self._fusion_type,
                    self.convnet._model_hp_dict['p'])
                save_hyperparams(os.path.join(settings_dir,
                                              save_dir), self._opt_hp_dict,
                                 self._model_hp_dict, best_loss)
        if save_loss:
            train_val_loss_dir = './train_val_loss'
            if not os.path.exists(train_val_loss_dir):
                os.mkdir(train_val_loss_dir)
            if self._network_type == self.SIMPLE:
                if self._input_channels == 1:
                    input_type = 'depth'
                elif self._input_channels == 4:
                    input_type = 'rgb'
                save_dir = '{0:s}/{1:s}/{2:s}/{3:f}'.format(
                    self._dataset, self._network_type, input_type,
                    self.convnet._model_hp_dict['p'])
                save_dir = os.path.join(train_val_loss_dir, save_dir)
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                with open(os.path.join(save_dir, 'train_val_loss.pkl'), 'wb')\
                        as f:
                    pickle.dump(training_information,
                                f,
                                protocol=pickle.HIGHEST_PROTOCOL)
            elif self._network_type == self.CONV_FUSING:
                save_dir = '{0:s}/{1:s}/{2:s}/{3:d}/{4:f}'.format(
                    self._dataset, self._network_type, self._fusion_type,
                    self._fusion_level, self.convnet._model_hp_dict['p'])
                save_dir = os.path.join(train_val_loss_dir, save_dir)
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                with open(os.path.join(save_dir, 'train_val_loss.pkl'), 'wb')\
                        as f:
                    pickle.dump(training_information,
                                f,
                                protocol=pickle.HIGHEST_PROTOCOL)
            elif self._network_type == self.DENSE_FUSING:
                save_dir = '{0:s}/{1:s}/{2:s}/{3:d}/{4:f}'.format(
                    self._dataset, self._network_type, self._fusion_type,
                    self._fusion_level, self.convnet._model_hp_dict['p'])
                save_dir = os.path.join(train_val_loss_dir, save_dir)
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                with open(os.path.join(save_dir, 'train_val_loss.pkl'), 'wb')\
                        as f:
                    pickle.dump(training_information,
                                f,
                                protocol=pickle.HIGHEST_PROTOCOL)
            elif self._network_type == self.SCORE_FUSING:
                save_dir = '{0:s}/{1:s}/{2:s}/{3:f}'.format(
                    self._dataset, self._network_type, self._fusion_type,
                    self.convnet._model_hp_dict['p'])
                save_dir = os.path.join(train_val_loss_dir, save_dir)
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                with open(os.path.join(save_dir, 'train_val_loss.pkl'), 'wb')\
                        as f:
                    pickle.dump(training_information,
                                f,
                                protocol=pickle.HIGHEST_PROTOCOL)
            elif self._network_type == self.INPUT_FUSING:
                save_dir = '{0:s}/{1:s}/{2:s}/{3:f}'.format(
                    self._dataset, self._network_type, self._fusion_type,
                    self.convnet._model_hp_dict['p'])
                save_dir = os.path.join(train_val_loss_dir, save_dir)
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                with open(os.path.join(save_dir, 'train_val_loss.pkl'), 'wb')\
                        as f:
                    pickle.dump(training_information,
                                f,
                                protocol=pickle.HIGHEST_PROTOCOL)
        return training_information
Example #11
    def __init__(self, text, windowSize=5, negWords=15, embedDim=200, vocabSize=None, 
                 nOccur=10, phMinCount=5, phThresh=10, phDepth=2,
                 wInit='scaled-uniform', epochs=50, batchSize= 1024, 
                 optimizer='SGD', lr=0.01, patience=5, epsilon=1e-5, raw=False, 
                 tShuff=False, saveFreq=-1, restoreBest=True, outPath='./'):

        """ Args:
                text (nested list): input text as list of sentences.
                windowSize (int): size of the context window.
                negWords (int): number of negative words used in training.
                embedDim (int): dimensionality of the embedded space (default 200).
                vocabSize (int): size of the the vocabulary (default None)
                nOccur (int): minimum number of occurrencies to keep a word in the dictionary,
                          can be overwritten by vocabSiz (default 10).
                phMinCount (int): minimum number of occurrences to keep a phrase (default 5).
                phThresh (float): minimum score to keep a phrase (default 10).
                phDepth (int): number of recursions during phrase search (1 = bi-grams, default 2).
                wInit (string): distribution from which to draw initial node weights (only 'scaled-uniform'
                        and 'xavier' are currently available, default 'scaled-uniform').
                epochs (int): number of epochs  (default 50).
                batchSize (int): size of batches (default 1024).
                optimizer (str): optimizer choice, 'SGD' and 'Adagrad' only
                        (default 'SGD').
                lr (float): learning rate (default .01).
                patience (int): early stop patience (default 5).
                epsilon (float): early stop epsilon (default 1e-5).
                raw (bool): if True, clean the raw input text (default False).
                tShuff (bool): shuffle the training set at each epoch (default False).
                saveFreq (int): frequency of model checkpoints, if < 0 don't save checkpoints (default -1).
                restoreBest (bool): restore and save best model by early stopping.
                outPath (string): path to directory where to save the trained models.
            """

        """ Set up training dataset and batches. """

        self.trainDs = textDataset(text, windowSize, negWords, vocabSize=vocabSize, nOccur=nOccur,
                                    phMinCount=phMinCount, phThresh=phThresh, phDepth=phDepth,  raw=raw)
        self.trainBatch = DataLoader(self.trainDs, batch_size = batchSize, shuffle = tShuff)
        
        """ Set up model """

        self.model = skipGram(int(self.trainDs.wDict.shape[0]), embedDim, wInit)

        """ Send model to GPU if available. """

        if torch.cuda.is_available():
            self.model.cuda()

        self.epochs = epochs
        

        if optimizer == 'SGD':
             # no momentum allowed with sparse matrices :(
            self.optimizer = SGD(self.model.parameters(), lr=lr)

        elif optimizer == 'Adagrad':
            self.optimizer = Adagrad(self.model.parameters(), lr=lr)

        else:
            print ('ERROR: '+optimizer+' is not available, please select SGD or Adagrad.')
            sys.exit(1)


        self.losses = []

        """ Set up early stopping. """

        self.earlStop = EarlyStopping(patience=patience, epsilon=epsilon, keepBest=True)
        self.restoreBest = restoreBest

        self.saveFreq = saveFreq
        if self.saveFreq < 0:
            self.saveFreq = self.epochs + 1 


        self.outPath = outPath
        if not os.path.exists(self.outPath):
            os.makedirs(self.outPath)
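
The constructor above only wires together the dataset, skip-gram model, optimizer, and early stopping; training itself happens elsewhere. A hypothetical instantiation could look like the sketch below; the class name W2VTrainer and the toy corpus are assumptions for illustration and do not appear in the snippet:

# Hypothetical usage sketch of the trainer whose __init__ is shown above.
# The class name (W2VTrainer) and the toy corpus are assumptions, not part of the snippet.
corpus = [["the", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"],
          ["early", "stopping", "keeps", "the", "best", "embedding", "weights"]]
trainer = W2VTrainer(corpus, windowSize=5, negWords=15, embedDim=200,
                     optimizer='Adagrad', lr=0.01, patience=5, epsilon=1e-5,
                     outPath='./w2v_out/')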
Exemple #12
0
def train_model(
        idx_np, name: str, model_class: Type[nn.Module], graph: SparseGraph, model_args: dict,
        learning_rate: float, reg_lambda: float,
        stopping_args: dict = stopping_args,
        test: bool = True, device: str = 'cuda',
        torch_seed: int = None, print_interval: int = 10) -> Tuple[nn.Module, dict]:

    labels_all = graph.labels
    idx_all = {key: torch.LongTensor(val) for key, val in idx_np.items()}

    logging.log(21, f"{model_class.__name__}: {model_args}")
    if torch_seed is None:
        torch_seed = gen_seeds()
    torch.manual_seed(seed=torch_seed)
    logging.log(22, f"PyTorch seed: {torch_seed}")

    nfeatures = graph.attr_matrix.shape[1]
    nclasses = max(labels_all) + 1
    model = model_class(nfeatures, nclasses, **model_args).to(device)

    reg_lambda = torch.tensor(reg_lambda, device=device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    dataloaders = get_dataloaders(idx_all, labels_all)
    early_stopping = EarlyStopping(model, **stopping_args)
    attr_mat_norm_np = normalize_attributes(graph.attr_matrix)
    attr_mat_norm = matrix_to_torch(attr_mat_norm_np).to(device)

    epoch_stats = {'train': {}, 'stopping': {}}

    start_time = time.time()
    last_time = start_time
    for epoch in range(early_stopping.max_epochs):
        for phase in epoch_stats.keys():

            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0
            running_corrects = 0

            for idx, labels in dataloaders[phase]:
                idx = idx.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):


                    log_preds = model(attr_mat_norm, idx)
                    preds = torch.argmax(log_preds, dim=1)

                    # Calculate loss
                    cross_entropy_mean = F.nll_loss(log_preds, labels)
                    l2_reg = sum((torch.sum(param ** 2) for param in model.reg_params))
                    loss = cross_entropy_mean + reg_lambda / 2 * l2_reg

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                    # Collect statistics
                    running_loss += loss.item() * idx.size(0)
                    running_corrects += torch.sum(preds == labels)


            # Collect statistics
            epoch_stats[phase]['loss'] = running_loss / len(dataloaders[phase].dataset)
            epoch_stats[phase]['acc'] = running_corrects.item() / len(dataloaders[phase].dataset)

        if epoch % print_interval == 0:
            duration = time.time() - last_time
            last_time = time.time()
            print(f"Epoch {epoch}: "
                         f"Train loss = {epoch_stats['train']['loss']:.2f}, "
                         f"train acc = {epoch_stats['train']['acc'] * 100:.1f}, "
                         f"early stopping loss = {epoch_stats['stopping']['loss']:.2f}, "
                         f"early stopping acc = {epoch_stats['stopping']['acc'] * 100:.1f} "
                         f"({duration:.3f} sec)")

            logging.info(f"Epoch {epoch}: "
                         f"Train loss = {epoch_stats['train']['loss']:.2f}, "
                         f"train acc = {epoch_stats['train']['acc'] * 100:.1f}, "
                         f"early stopping loss = {epoch_stats['stopping']['loss']:.2f}, "
                         f"early stopping acc = {epoch_stats['stopping']['acc'] * 100:.1f} "
                         f"({duration:.3f} sec)")

        if len(early_stopping.stop_vars) > 0:
            stop_vars = [epoch_stats['stopping'][key]
                         for key in early_stopping.stop_vars]
            if early_stopping.check(stop_vars, epoch):
                break
    runtime = time.time() - start_time
    runtime_perepoch = runtime / (epoch + 1)
    logging.log(22, f"Last epoch: {epoch}, best epoch: {early_stopping.best_epoch} ({runtime:.3f} sec)")

    # Load best model weights

    model.load_state_dict(early_stopping.best_state, strict=False)

    train_preds = get_predictions(model, attr_mat_norm, idx_all['train'])
    train_acc = (train_preds == labels_all[idx_all['train']]).mean()

    stopping_preds = get_predictions(model, attr_mat_norm, idx_all['stopping'])
    stopping_acc = (stopping_preds == labels_all[idx_all['stopping']]).mean()
    logging.log(21, f"Early stopping accuracy: {stopping_acc * 100:.1f}%")

    valtest_preds = get_predictions(model, attr_mat_norm, idx_all['valtest'])
    valtest_acc = (valtest_preds == labels_all[idx_all['valtest']]).mean()
    valtest_name = 'Test' if test else 'Validation'
    logging.log(22, f"{valtest_name} accuracy: {valtest_acc * 100:.1f}%")
    
    result = {}
    result['predictions'] = get_predictions(model, attr_mat_norm, torch.arange(len(labels_all)))
    result['train'] = {'accuracy': train_acc}
    result['early_stopping'] = {'accuracy': stopping_acc}
    result['valtest'] = {'accuracy': valtest_acc}
    result['runtime'] = runtime
    result['runtime_perepoch'] = runtime_perepoch

    return model, result
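
train_model above calls a get_predictions helper that is not included in the snippet. Based on its call sites (model, normalized attribute matrix, node indices, with the result compared against NumPy labels), a minimal sketch might be:

import torch

# Minimal sketch of get_predictions (an assumption inferred from its call sites):
# run the model on the given node indices without gradients and return the
# predicted class ids as a NumPy array.
def get_predictions(model, attr_matrix, idx):
    model.eval()
    device = next(model.parameters()).device
    with torch.no_grad():
        log_preds = model(attr_matrix, idx.to(device))
    return torch.argmax(log_preds, dim=1).cpu().numpy()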
                              batch_size=VALBATCHSIZE,
                              pin_memory=True)
testloader = DataLoader(testset, batch_size=VALBATCHSIZE, pin_memory=True)

model = Net(network.edge_index.to(device), network.num_nodes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# to track the training loss as the model trains
train_losses = []
# to track the validation loss as the model trains
valid_losses = []
# to track the average training loss per epoch as the model trains
avg_train_losses = []
# to track the average validation loss per epoch as the model trains
avg_valid_losses = []
early_stopping = EarlyStopping(patience=PATIENCE, verbose=True)

for epoch in range(EPOCHS):
    for batch in loader:
        model.train()
        batch = batch.to(device)
        optimizer.zero_grad()
        trainmask, testmask = masks[randint(
            0, MASKCOUNT - 1
        )]  # choose a random mask from the set of masks to use for this batch
        trainmask = trainmask[0:len(batch)].to(device)
        testmask = testmask[0:len(batch)].to(device)
        batch *= trainmask  # mask (set to zero) the features of the training nodes in this batch
        out = model(batch)
        loss = F.mse_loss(
            out[testmask],
Exemple #14
0
def train():
    '''
    main function to run the training
    '''
    restore = False
    #TIMESTAMP = "2020-03-09T00-00-00"
    if args.timestamp == "NA":
        TIMESTAMP = datetime.now().strftime("%b%d-%H%M%S")
        print('TIMESTAMP', TIMESTAMP)
    else:
        # restore
        restore = True
        TIMESTAMP = args.timestamp
    save_dir = './save_model/' + TIMESTAMP

    if restore:
        # restore args
        with open(os.path.join(save_dir, 'cmd_args.txt'), 'r') as f:
            args.__dict__ = json.load(f)

    encoder = Encoder(encoder_params[0], encoder_params[1]).cuda()
    decoder = Decoder(decoder_params[0], decoder_params[1],
                      args.frames_output).cuda()
    net = ED(encoder, decoder)
    run_dir = './runs/' + TIMESTAMP
    if not os.path.isdir(run_dir):
        os.makedirs(run_dir)
    tb = SummaryWriter(run_dir)
    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=30, verbose=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    if os.path.exists(os.path.join(save_dir, 'checkpoint.pth.tar')):
        # load existing model
        print('==> loading existing model')
        model_info = torch.load(os.path.join(save_dir, 'checkpoint.pth.tar'))
        net.load_state_dict(model_info['state_dict'])
        optimizer = torch.optim.Adam(net.parameters())
        optimizer.load_state_dict(model_info['optimizer'])
        cur_epoch = model_info['epoch'] + 1
    else:
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        cur_epoch = 0
    lossfunction = nn.MSELoss().cuda()
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    pla_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      factor=0.5,
                                                      patience=4,
                                                      verbose=True)

    # to track the training loss as the model trains
    train_losses = []
    # to track the validation loss as the model trains
    valid_losses = []
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = []
    # mini_val_loss = np.inf
    for epoch in range(cur_epoch, args.epochs + 1):
        ###################
        # train the model #
        ###################
        t = tqdm(trainLoader, leave=False, total=len(trainLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):

            inputs = inputVar.to(device)  # B,S,C,H,W
            label = targetVar.to(device)  # B,S,C,H,W
            optimizer.zero_grad()
            net.train()
            pred = net(inputs)  # B,S,C,H,W
            loss = lossfunction(pred, label)
            loss_aver = loss.item() / args.batch_size
            train_losses.append(loss_aver)
            loss.backward()
            torch.nn.utils.clip_grad_value_(net.parameters(), clip_value=10.0)
            optimizer.step()
            t.set_postfix({
                'trainloss': '{:.6f}'.format(loss_aver),
                'epoch': '{:02d}'.format(epoch)
            })
        tb.add_scalar('TrainLoss', loss_aver, epoch)
        ######################
        # validate the model #
        ######################
        with torch.no_grad():
            net.eval()
            t = tqdm(validLoader, leave=False, total=len(validLoader))
            for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
                if i == 3000:
                    break
                inputs = inputVar.to(device)
                label = targetVar.to(device)
                pred = net(inputs)
                loss = lossfunction(pred, label)
                loss_aver = loss.item() / args.batch_size
                # record validation loss
                valid_losses.append(loss_aver)
                #print ("validloss: {:.6f},  epoch : {:02d}".format(loss_aver,epoch),end = '\r', flush=True)
                t.set_postfix({
                    'validloss': '{:.6f}'.format(loss_aver),
                    'epoch': '{:02d}'.format(epoch)
                })

        tb.add_scalar('ValidLoss', loss_aver, epoch)
        torch.cuda.empty_cache()
        # print training/validation statistics
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        epoch_len = len(str(args.epochs))

        print_msg = (f'[{epoch:>{epoch_len}}/{args.epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.6f} ' +
                     f'valid_loss: {valid_loss:.6f}')

        print(print_msg)
        # clear lists to track next epoch
        train_losses = []
        valid_losses = []
        pla_lr_scheduler.step(valid_loss)  # lr_scheduler
        model_dict = {
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        early_stopping(valid_loss.item(), model_dict, epoch, save_dir)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    with open("avg_train_losses.txt", 'wt') as f:
        for i in avg_train_losses:
            print(i, file=f)

    with open("avg_valid_losses.txt", 'wt') as f:
        for i in avg_valid_losses:
            print(i, file=f)

    # save args
    if not restore:
        with open(os.path.join(save_dir, 'cmd_args.txt'), 'w+') as f:
            json.dump(args.__dict__, f, indent=2)
Exemple #15
0
def main(args):
    # fix random seeds
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
    print(device)
    criterion = nn.CrossEntropyLoss()
    cluster_log = Logger(os.path.join(args.exp, 'clusters.pickle'))

    # CNN
    if args.verbose:
        print('Architecture: {}'.format(args.arch))

    '''
    ##########################################
    ##########################################
    # Model definition
    ##########################################
    ##########################################'''
    model = models.__dict__[args.arch](bn=True, num_cluster=args.nmb_cluster, num_category=args.nmb_category)
    fd = int(model.cluster_layer[0].weight.size()[1])  # due to transpose, fd is input dim of W (in dim, out dim)
    model.cluster_layer = None
    model.category_layer = None
    model.features = torch.nn.DataParallel(model.features)
    model = model.double()
    model.to(device)
    cudnn.benchmark = True

    if args.optimizer == 'Adam':
        print('Adam optimizer: conv')
        optimizer_body = torch.optim.Adam(
            filter(lambda x: x.requires_grad, model.parameters()),
            lr=args.lr_Adam,
            betas=(0.9, 0.999),
            weight_decay=10 ** args.wd,
        )
    else:
        print('SGD optimizer: conv')
        optimizer_body = torch.optim.SGD(
            filter(lambda x: x.requires_grad, model.parameters()),
            lr=args.lr_SGD,
            momentum=args.momentum,
            weight_decay=10 ** args.wd,
        )
    '''
    ###############
    ###############
    category_layer
    ###############
    ###############
    '''
    model.category_layer = nn.Sequential(
        nn.Linear(fd, args.nmb_category),
        nn.Softmax(dim=1),
    )
    model.category_layer[0].weight.data.normal_(0, 0.01)
    model.category_layer[0].bias.data.zero_()
    model.category_layer = model.category_layer.double()
    model.category_layer.to(device)

    '''
    ############################
    ############################
    # EarlyStopping (test_accuracy_bal, 100)
    ############################
    ############################
    '''
    early_stopping = EarlyStopping(model, **stopping_args)
    stop_vars = []

    if args.optimizer == 'Adam':
        print('Adam optimizer: conv')
        optimizer_category = torch.optim.Adam(
            filter(lambda x: x.requires_grad, model.category_layer.parameters()),
            lr=args.lr_Adam,
            betas=(0.9, 0.999),
            weight_decay=10 ** args.wd,
        )
    else:
        print('SGD optimizer: conv')
        optimizer_category = torch.optim.SGD(
            filter(lambda x: x.requires_grad, model.category_layer.parameters()),
            lr=args.lr_SGD,
            momentum=args.momentum,
            weight_decay=10 ** args.wd,
        )
    '''
    ########################################
    ########################################
    Create echogram sampling index
    ########################################
    ########################################'''

    print('Sample echograms.')
    dataset_cp, dataset_semi = sampling_echograms_full(args)
    dataloader_cp = torch.utils.data.DataLoader(dataset_cp,
                                                shuffle=False,
                                                batch_size=args.batch,
                                                num_workers=args.workers,
                                                drop_last=False,
                                                pin_memory=True)

    dataloader_semi = torch.utils.data.DataLoader(dataset_semi,
                                                shuffle=False,
                                                batch_size=args.batch,
                                                num_workers=args.workers,
                                                drop_last=False,
                                                pin_memory=True)

    dataset_test_bal, dataset_test_unbal = sampling_echograms_test(args)
    dataloader_test_bal = torch.utils.data.DataLoader(dataset_test_bal,
                                                shuffle=False,
                                                batch_size=args.batch,
                                                num_workers=args.workers,
                                                drop_last=False,
                                                pin_memory=True)

    dataloader_test_unbal = torch.utils.data.DataLoader(dataset_test_unbal,
                                                shuffle=False,
                                                batch_size=args.batch,
                                                num_workers=args.workers,
                                                drop_last=False,
                                                pin_memory=True)

    # clustering algorithm to use
    deepcluster = clustering.__dict__[args.clustering](args.nmb_cluster, args.pca)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # remove top located layer parameters from checkpoint
            copy_checkpoint_state_dict = checkpoint['state_dict'].copy()
            for key in list(copy_checkpoint_state_dict):
                if 'cluster_layer' in key:
                    del copy_checkpoint_state_dict[key]
                # if 'category_layer' in key:
                #     del copy_checkpoint_state_dict[key]
            checkpoint['state_dict'] = copy_checkpoint_state_dict
            model.load_state_dict(checkpoint['state_dict'])
            optimizer_body.load_state_dict(checkpoint['optimizer_body'])
            optimizer_category.load_state_dict(checkpoint['optimizer_category'])
            category_save = os.path.join(args.exp,  'category_layer.pth.tar')
            if os.path.isfile(category_save):
                category_layer_param = torch.load(category_save)
                model.category_layer.load_state_dict(category_layer_param)
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # creating checkpoint repo
    exp_check = os.path.join(args.exp, 'checkpoints')
    if not os.path.isdir(exp_check):
        os.makedirs(exp_check)

    exp_bal = os.path.join(args.exp, 'bal')
    exp_unbal = os.path.join(args.exp, 'unbal')
    for dir_bal in [exp_bal, exp_unbal]:
        for dir_2 in ['features', 'pca_features', 'pred']:
            dir_to_make = os.path.join(dir_bal, dir_2)
            if not os.path.isdir(dir_to_make):
                os.makedirs(dir_to_make)

    if os.path.isfile(os.path.join(args.exp, 'loss_collect.pickle')):
        with open(os.path.join(args.exp, 'loss_collect.pickle'), "rb") as f:
            loss_collect = pickle.load(f)
    else:
        loss_collect = [[], [], [], [], [], [], [], [], []]

    if os.path.isfile(os.path.join(args.exp, 'nmi_collect.pickle')):
        with open(os.path.join(args.exp, 'nmi_collect.pickle'), "rb") as ff:
            nmi_save = pickle.load(ff)
    else:
        nmi_save = []
    '''
    #######################
    #######################
    MAIN TRAINING
    #######################
    #######################'''
    for epoch in range(args.start_epoch, early_stopping.max_epochs):
        end = time.time()
        print('#####################  Start training at Epoch %d ################'% epoch)
        model.classifier = nn.Sequential(*list(model.classifier.children())[:-1]) # remove ReLU at classifier [:-1]
        model.cluster_layer = None
        model.category_layer = None

        '''
        #######################
        #######################
        PSEUDO-LABEL GENERATION
        #######################
        #######################
        '''
        print('Cluster the features')
        features_train, input_tensors_train, labels_train = compute_features(dataloader_cp, model, len(dataset_cp), device=device, args=args)
        clustering_loss, pca_features = deepcluster.cluster(features_train, verbose=args.verbose)

        nan_location = np.isnan(pca_features)
        inf_location = np.isinf(pca_features)
        if (not np.allclose(nan_location, 0)) or (not np.allclose(inf_location, 0)):
            print('PCA: Feature NaN or Inf found. Nan count: ', np.sum(nan_location), ' Inf count: ', np.sum(inf_location))
            print('Skip epoch ', epoch)
            torch.save(pca_features, 'tr_pca_NaN_%d.pth.tar' % epoch)
            torch.save(features_train, 'tr_feature_NaN_%d.pth.tar' % epoch)
            continue

        print('Assign pseudo labels')
        size_cluster = np.zeros(len(deepcluster.images_lists))
        for i,  _list in enumerate(deepcluster.images_lists):
            size_cluster[i] = len(_list)
        print('size in clusters: ', size_cluster)
        img_label_pair_train = zip_img_label(input_tensors_train, labels_train)
        train_dataset = clustering.cluster_assign(deepcluster.images_lists,
                                                  img_label_pair_train)  # Reassigned pseudolabel

        # uniformly sample per target
        sampler_train = UnifLabelSampler(int(len(train_dataset)),
                                   deepcluster.images_lists)

        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch,
            shuffle=False,
            num_workers=args.workers,
            sampler=sampler_train,
            pin_memory=True,
        )
        '''
        ####################################################################
        ####################################################################
        TRANSFORM MODEL FOR SELF-SUPERVISION // SEMI-SUPERVISION
        ####################################################################
        ####################################################################
        '''
        # Recover classifier with ReLU (that is not used in clustering)
        mlp = list(model.classifier.children()) # classifier that ends with linear(512 * 128). No ReLU at the end
        mlp.append(nn.ReLU(inplace=True).to(device))
        model.classifier = nn.Sequential(*mlp)
        model.classifier.to(device)

        '''SELF-SUPERVISION (PSEUDO-LABELS)'''
        model.category_layer = None
        model.cluster_layer = nn.Sequential(
            nn.Linear(fd, args.nmb_cluster),  # nn.Linear(4096, num_cluster),
            nn.Softmax(dim=1),  # should be removed and replaced by ReLU for category_layer
        )
        model.cluster_layer[0].weight.data.normal_(0, 0.01)
        model.cluster_layer[0].bias.data.zero_()
        model.cluster_layer = model.cluster_layer.double()
        model.cluster_layer.to(device)

        ''' train network with clusters as pseudo-labels '''
        with torch.autograd.set_detect_anomaly(True):
            pseudo_loss, semi_loss, semi_accuracy = semi_train(train_dataloader, dataloader_semi, model, fd, criterion,
                                                               optimizer_body, optimizer_category, epoch, device=device, args=args)

        # save checkpoint
        if (epoch + 1) % args.checkpoints == 0:
            path = os.path.join(
                args.exp,
                'checkpoints',
                'checkpoint_' + str(epoch) + '.pth.tar',
            )
            if args.verbose:
                print('Save checkpoint at: {0}'.format(path))
            torch.save({'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer_body': optimizer_body.state_dict(),
                        'optimizer_category': optimizer_category.state_dict(),
                        }, path)


        '''
        ##############
        ##############
        # TEST phase
        ##############
        ##############
        '''
        test_loss_bal, test_accuracy_bal, test_pred_bal, test_label_bal = test(dataloader_test_bal, model, criterion, device, args)
        test_loss_unbal, test_accuracy_unbal, test_pred_unbal, test_label_unbal = test(dataloader_test_unbal, model, criterion, device, args)

        '''Save prediction of the test set'''
        if (epoch % args.save_epoch == 0):
            with open(os.path.join(args.exp, 'bal', 'pred', 'sup_epoch_%d_te_bal.pickle' % epoch), "wb") as f:
                pickle.dump([test_pred_bal, test_label_bal], f)
            with open(os.path.join(args.exp, 'unbal', 'pred', 'sup_epoch_%d_te_unbal.pickle' % epoch), "wb") as f:
                pickle.dump([test_pred_unbal, test_label_unbal], f)

        if args.verbose:
            print('###### Epoch [{0}] ###### \n'
                  'Time: {1:.3f} s\n'
                  'Pseudo tr_loss: {2:.3f} \n'
                  'SEMI tr_loss: {3:.3f} \n'
                  'TEST_bal loss: {4:.3f} \n'
                  'TEST_unbal loss: {5:.3f} \n'
                  'Clustering loss: {6:.3f} \n\n'
                  'SEMI accu: {7:.3f} \n'
                  'TEST_bal accu: {8:.3f} \n'
                  'TEST_unbal accu: {9:.3f} \n'
                  .format(epoch, time.time() - end, pseudo_loss, semi_loss,
                          test_loss_bal, test_loss_unbal, clustering_loss, semi_accuracy, test_accuracy_bal, test_accuracy_unbal))
            try:
                nmi = normalized_mutual_info_score(
                    clustering.arrange_clustering(deepcluster.images_lists),
                    clustering.arrange_clustering(cluster_log.data[-1])
                )
                nmi_save.append(nmi)
                print('NMI against previous assignment: {0:.3f}'.format(nmi))
                with open(os.path.join(args.exp, 'nmi_collect.pickle'), "wb") as ff:
                    pickle.dump(nmi_save, ff)
            except IndexError:
                pass
            print('####################### \n')

        # save cluster assignments
        cluster_log.log(deepcluster.images_lists)

        # save running checkpoint
        torch.save({'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer_body': optimizer_body.state_dict(),
                    'optimizer_category': optimizer_category.state_dict(),
                    },
                   os.path.join(args.exp, 'checkpoint.pth.tar'))
        torch.save(model.category_layer.state_dict(), os.path.join(args.exp, 'category_layer.pth.tar'))

        loss_collect[0].append(epoch)
        loss_collect[1].append(pseudo_loss)
        loss_collect[2].append(semi_loss)
        loss_collect[3].append(clustering_loss)
        loss_collect[4].append(test_loss_bal)
        loss_collect[5].append(test_loss_unbal)
        loss_collect[6].append(semi_accuracy)
        loss_collect[7].append(test_accuracy_bal)
        loss_collect[8].append(test_accuracy_unbal)
        with open(os.path.join(args.exp, 'loss_collect.pickle'), "wb") as f:
            pickle.dump(loss_collect, f)

        if (epoch % args.save_epoch == 0):
            out = produce_test_result_bal(epoch, model, dataloader_test_bal, dataset_test_bal, device, args, deepcluster)
            out = produce_test_result_unbal(epoch, model, dataloader_test_unbal, dataset_test_unbal, device, args, deepcluster)

        '''EarlyStopping'''
        if early_stopping.check(loss_collect[7], epoch):
            break

    out = produce_test_result_bal(epoch, model, dataloader_test_bal, dataset_test_bal, device, args, deepcluster)
    out = produce_test_result_unbal(epoch, model, dataloader_test_unbal, dataset_test_unbal, device, args,
                                        deepcluster)


Exemple #16
0
    #optimizers
    ae_opt = torch.optim.Adam(auto_encoder.parameters(), lr=lr)
    mlp_opt = torch.optim.Adam(mlp.parameters(), lr=lr)

    #LR schedulers
    ae_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(ae_opt, 'min')
    mlp_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(mlp_opt, 'min')

    #MSE loss
    loss_fnAE = nn.MSELoss()
    loss_fnMLP = nn.MSELoss()

    #Earlystopping
    model_weights = f'{CACHE_PATH}/model_{_fold}.pkl'
    es = EarlyStopping(patience=10, mode='min')

    #train-test split
    trainDataset = MarketDataset(train.loc[tr], feat_cols, target_cols)
    valDataset = MarketDataset(train.loc[te], feat_cols, target_cols)

    trainLoader = DataLoader(trainDataset, batch_size=128)
    valLoader = DataLoader(valDataset, batch_size=128)

    for epoch in (t := trange(EPOCHS)):
        train_lossAE = train_ae(auto_encoder, ae_opt, trainLoader, loss_fnAE,
                                device)
        ae_scheduler.step(train_lossAE)

        train_lossMLP = train_mlp(mlp, auto_encoder, mlp_opt, trainLoader,
                                  loss_fnMLP, device)
Exemple #17
0
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=50, factor=0.618)
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[200, 300, 400, 500, 600, 700], gamma=0.6)
# convert to cuda
if args.cuda:
    model.cuda()

# For the mix mode, labels and indexes are on CUDA.
if args.cuda or args.mixmode:
    labels = labels.cuda()
    idx_train = idx_train.cuda()
    idx_val = idx_val.cuda()
    idx_test = idx_test.cuda()

if args.warm_start is not None and args.warm_start != "":
    early_stopping = EarlyStopping(fname=args.warm_start, verbose=False)
    print("Restore checkpoint from %s" % (early_stopping.fname))
    model.load_state_dict(early_stopping.load_checkpoint())

# set early_stopping
if args.early_stopping > 0:
    early_stopping = EarlyStopping(patience=args.early_stopping, verbose=False)
    print("Model is saving to: %s" % (early_stopping.fname))

if args.no_tensorboard is False:
    tb_writer = SummaryWriter(
        comment=f"-dataset_{args.dataset}-type_{args.type}")


def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']
def trainNet(model, train_loader, val_loader, val_loader_ttimes, \
                device, adj, nn_ixs, edge_index, coords=None):

    # Print all of the hyperparameters of the training iteration:
    print("===== HYPERPARAMETERS =====")
    print("batch_size=", config['dataloader']['batch_size'])
    print("epochs=", config['num_epochs'])
    print("learning_rate=", config['optimizer']['lr'])
    print("mask_threshold=", config['mask_threshold'])
    print("nh1=", config['model']['KipfNet']['nh1'])
    print("K=", config['model']['KipfNet']['K'])
    print("K_mix=", config['model']['KipfNet']['K_mix'])
    print("inout_skipconn=", config['model']['KipfNet']['inout_skipconn'])
    print("=" * 30)
    # define the optimizer & learning rate
    #optim = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])#
    if config['optimizer_name'] == 'SGD':
        optim = torch.optim.SGD(model.parameters(), **config['optimizer'])
    elif config['optimizer_name'] == 'ADAM':
        optim = torch.optim.Adam(model.parameters(), **config['optimizer'])

    # scheduler = StepLR(optim, step_size=config['lr_step_size'], gamma=config['lr_gamma'])

    nh1 = config['model']['KipfNet']['nh1']
    K = config['model']['KipfNet']['K']
    K_mix = config['model']['KipfNet']['K_mix']
    inout_skipconn = config['model']['KipfNet']['inout_skipconn']

    log_dir = 'runs/graphs/' + 'KipfNet' + '_nh1=' + str(nh1) \
                + '_K=' + str(K) + '_Kmix=' + str(K_mix) \
                + '_skip_conn' + str(inout_skipconn) \
                + '_' + datetime.now().strftime("%Y-%m-%d-%H-%M-%S-") \
                + '-'.join(config['dataset']['cities'])
    writer = Visualizer(log_dir)

    # dump config file
    with open(os.path.join(log_dir, 'config.json'), 'w') as fp:
        json.dump(config, fp)

    # Time for printing
    training_start_time = time.time()
    globaliter = 0

    # initialize the early_stopping object
    early_stopping = EarlyStopping(log_dir,
                                   patience=config['patience'],
                                   verbose=True)
    #    adj = adj.to(device)

    # Loop for n_epochs
    for epoch_idx, epoch in enumerate(range(config['num_epochs'])):

        writer.write_lr(optim, globaliter)

        # train for one epoch
        globaliter = train(model,
                           train_loader,
                           optim,
                           device,
                           writer,
                           epoch,
                           globaliter,
                           adj,
                           nn_ixs,
                           edge_index,
                           coords=coords)

        # At the end of the epoch, do a pass on the validation set
        # val_loss = validate(model, val_loader, device, writer, globaliter, adj, nn_ixs, edge_index)
        val_loss_testtimes = validate(model,
                                      val_loader_ttimes,
                                      device,
                                      writer,
                                      globaliter,
                                      adj,
                                      nn_ixs,
                                      edge_index,
                                      if_testtimes=True,
                                      coords=coords)
        # early_stopping needs the validation loss to check if it has decreased,
        # and if it has, it will make a checkpoint of the current model
        early_stopping(val_loss_testtimes, model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

        if config['debug'] and epoch_idx >= 0:
            break

        # scheduler.step()

    print("Training finished, took {:.2f}s".format(time.time() -
                                                   training_start_time))

    #    model.load_state_dict(torch.load('runs/checkpoint.pt'))

    # remember to close
    writer.close()
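
Most snippets in this collection construct an EarlyStopping helper, call it with the current validation loss and the model, and then test its early_stop flag. A minimal sketch consistent with the interface used in the trainNet functions above (constructor taking a log directory, patience, and verbose flag; checkpointing the model whenever validation loss improves) is shown below; this is an assumption for illustration, not the exact class used by these scripts:

import os
import numpy as np
import torch

# Minimal EarlyStopping sketch (assumption, not the original implementation):
# save a checkpoint whenever validation loss improves, and set early_stop
# after `patience` consecutive epochs without improvement.
class EarlyStopping:
    def __init__(self, log_dir, patience=7, verbose=False):
        self.log_dir = log_dir
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_loss = np.inf
        self.early_stop = False

    def __call__(self, val_loss, model):
        if val_loss < self.best_loss:
            if self.verbose:
                print("Validation loss decreased ({:.6f} -> {:.6f}); saving model.".format(self.best_loss, val_loss))
            self.best_loss = val_loss
            torch.save(model.state_dict(), os.path.join(self.log_dir, "checkpoint.pt"))
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True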
Exemple #19
0
# The model
classifier_model = LinearClassifier(input_dim=args.hidden, output_dim=nclass)

unsupervised_model = SAGE_Full(nfeat,
                      args.hidden,
                      args.nhiddenlayer,
                      F.relu,
                      args.dropout,
                      'gcn')


optimizer = optim.Adam(classifier_model.parameters(),
                       lr=args.lr, weight_decay=args.weight_decay)

early_stopping = EarlyStopping(patience=args.early_stopping, fname='best_classifier.model',
                               save_model_pth=model_save_pth)

# # define contrastive model
# contrast = NCEAverage(args.hidden, n_nodes, args.nce_k, args.nce_t, args.nce_m, args.softmax)
# criterion_v1 = NCESoftmaxLoss() if args.softmax else NCECriterion(n_nodes)
# criterion_v2 = NCESoftmaxLoss() if args.softmax else NCECriterion(n_nodes)

scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[200, 300, 400, 500, 600, 700], gamma=0.5)
# convert to cuda
if args.cuda:
    classifier_model.cuda()
    unsupervised_model.cuda()
    # contrast.cuda()
    # criterion_v1.cuda()
    # criterion_v2.cuda()
Exemple #20
0
testWords = data_preprocess_test(testEval)

#augmenting the dataset size
trainWords, trainLabels, trainTags = data_augment(trainWords, trainTags,
                                                  trainLabels)

layers = 1030
dropout = 0.2

tokenizer = erine_larg_tok
pretrained = erine_larg_mod

model = Model(pretrained, tokenizer, layers, dropout)

model_path = 'bert_no_aug.pth'
early_stopping = EarlyStopping(model_path, 4, True)
optimizer = optim.Adamax(model.parameters(), lr=0.001)
loss_func = nn.MSELoss(reduction='mean')

batch = 100
folds = 5
epoch = 6

combined_data = trainWords
combined_tags = trainTags
combined_labels = trainLabels
data = np.asarray(combined_data)
tags = np.asarray(combined_tags)
labels = np.asarray(combined_labels)
kf = KFold(n_splits=folds, random_state=0, shuffle=True)
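
The snippet stops right after constructing the KFold splitter. A typical continuation iterates over kf.split to obtain per-fold train/validation indices; the sketch below shows standard scikit-learn usage over the arrays built above and is not the snippet's own loop:

# Sketch of standard scikit-learn KFold usage over the arrays defined above;
# the per-fold slicing is illustrative, not part of the original snippet.
for fold, (train_idx, val_idx) in enumerate(kf.split(data)):
    fold_train_data, fold_val_data = data[train_idx], data[val_idx]
    fold_train_tags, fold_val_tags = tags[train_idx], tags[val_idx]
    fold_train_labels, fold_val_labels = labels[train_idx], labels[val_idx]
    print("Fold {}: {} train / {} val samples".format(fold, len(train_idx), len(val_idx)))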
Exemple #21
0
def train():
    '''
    main function to run the training
    '''
    encoder = Encoder(encoder_params[0], encoder_params[1]).cuda()
    decoder = Decoder(decoder_params[0], decoder_params[1]).cuda()
    net = ED(encoder, decoder)
    run_dir = './runs/' + TIMESTAMP + args.mname
    if not os.path.isdir(run_dir):
        os.makedirs(run_dir)
    tb = SummaryWriter(run_dir)
    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=200, verbose=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    if os.path.exists(os.path.join(save_dir, 'checkpoint.pth.tar')):
        # load existing model
        print('==> loading existing model')
        model_info = torch.load(os.path.join(save_dir, 'checkpoint.pth.tar'))
        net.load_state_dict(model_info['state_dict'])
        optimizer = torch.optim.Adam(net.parameters())
        optimizer.load_state_dict(model_info['optimizer'])
        cur_epoch = model_info['epoch'] + 1
    else:
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        cur_epoch = 0

    class_weights = torch.FloatTensor([1.0, 15.0]).cuda()    
    lossfunction = nn.CrossEntropyLoss(weight=class_weights).cuda()
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    pla_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      factor=0.5,
                                                      patience=5,
                                                      verbose=True)

    # to track the training loss as the model trains
    train_losses = []
    # to track the validation loss as the model trains
    valid_losses = []
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = []
    min_train_loss = np.inf
    for epoch in range(cur_epoch, args.epochs + 1):
        print(time.strftime("now time: %Y%m%d_%H:%M", time.localtime(time.time())))
        ###################
        # train the model #
        ###################
        # tqdm progress bar
        t = tqdm(trainLoader, total=len(trainLoader))
        for i, (seq_len, scan_seq, _, mask_seq, _) in enumerate(t):
            # variable-length sequences: at least the first 2 frames are used as input,
            # and the last 3 frames are always the prediction targets
            inputs = torch.cat((scan_seq, mask_seq.float()), dim=2).to(device)[:,:-3,...]   # B,S,C,H,W
            label = mask_seq.to(device)[:,(seq_len-3):,...]     # B,S,C,H,W
            optimizer.zero_grad()
            net.train()         # set the module to training mode (only affects dropout and batch norm)
            pred = net(inputs)  # B,S,C,H,W

            # draw visualization results in TensorBoard
            if i % 100 == 0:
                grid_ri_lab, grid_pred = get_visualization_example(scan_seq.to(device), mask_seq.to(device), pred, device)
                tb.add_image('visualization/train/rangeImage_gtMask', grid_ri_lab, global_step=epoch)
                tb.add_image('visualization/train/prediction', grid_pred, global_step=epoch)
                
            seq_number, batch_size, input_channel, height, width = pred.size() 
            pred = pred.reshape(-1, input_channel, height, width)  # reshape to B*S,C,H,W
            seq_number, batch_size, input_channel, height, width = label.size() 
            label = label.reshape(-1, height, width) # reshape to B*S,H,W
            label = label.to(device=device, dtype=torch.long)
            # compute the loss
            loss = lossfunction(pred, label)
            loss_aver = loss.item() / (label.shape[0] * batch_size) 
            train_losses.append(loss_aver)
            loss.backward()
            # clip gradients to prevent explosion; with clip_value set, gradients are clipped to [-clip_value, clip_value]
            torch.nn.utils.clip_grad_value_(net.parameters(), clip_value=30.0)
            optimizer.step()
            t.set_postfix({
                'trainloss': '{:.6f}'.format(loss_aver),
                'epoch': '{:02d}'.format(epoch)
            })
        tb.add_scalar('TrainLoss', np.average(train_losses), epoch)
        ######################
        # validate the model #
        ######################
        with torch.no_grad():
            # set the module to eval mode (only affects dropout and batch norm)
            net.eval()
            # tqdm progress bar
            t = tqdm(validLoader, total=len(validLoader))
            for i, (seq_len, scan_seq, _, mask_seq, _) in enumerate(t):
                if i == 300:    # limit the number of validation batches
                    break
                # variable-length sequences: at least the first 2 frames are used as input, the last 3 frames are the prediction targets
                inputs = torch.cat((scan_seq, mask_seq.float()), dim=2).to(device)   # B,S,C,H,W
                label = mask_seq.to(device)[:,(seq_len-3):,...]     # B,S,C,H,W    
                pred = net(inputs)

                # draw visualization results in TensorBoard
                if i % 100 == 0:
                    grid_ri_lab, grid_pred = get_visualization_example(scan_seq.to(device), mask_seq.to(device), pred, device)
                    tb.add_image('visualization/valid/rangeImage_gtMask', grid_ri_lab, global_step=epoch)
                    tb.add_image('visualization/valid/prediction', grid_pred, global_step=epoch)
                    
                seq_number, batch_size, input_channel, height, width = pred.size() 
                pred = pred.reshape(-1, input_channel, height, width)  # reshape to B*S,C,H,W
                seq_number, batch_size, input_channel, height, width = label.size() 
                label = label.reshape(-1, height, width) # reshape to B*S,H,W
                label = label.to(device=device, dtype=torch.long)
                loss = lossfunction(pred, label)
                loss_aver = loss.item() / (label.shape[0] * batch_size)
                # record validation loss
                valid_losses.append(loss_aver)
                #print ("validloss: {:.6f},  epoch : {:02d}".format(loss_aver,epoch),end = '\r', flush=True)
                t.set_postfix({
                    'validloss': '{:.6f}'.format(loss_aver),
                    'epoch': '{:02d}'.format(epoch)
                })
                # get_visualization_example(inputs, label, pred)

        tb.add_scalar('ValidLoss', np.average(valid_losses), epoch)
        torch.cuda.empty_cache()
        # print training/validation statistics
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        epoch_len = len(str(args.epochs))

        print_msg = (f'[{epoch:>{epoch_len}}/{args.epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.6f} ' +
                     f'valid_loss: {valid_loss:.6f}')

        print(print_msg)
        # clear lists to track next epoch
        train_losses = []
        valid_losses = []
        pla_lr_scheduler.step(valid_loss)  # lr_scheduler
        model_dict = {
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        # save the model with the lowest training loss
        if (train_loss < min_train_loss):
            torch.save(model_dict, save_dir + "/" + "best_train_checkpoint.pth.tar")
            min_train_loss = train_loss
        # save the model with the lowest validation loss
        early_stopping(valid_loss.item(), model_dict, epoch, save_dir)
        if early_stopping.early_stop:
            print("Early stopping")
            break
    # end for

    with open("avg_train_losses.txt", 'wt') as f:
        for i in avg_train_losses:
            print(i, file=f)

    with open("avg_valid_losses.txt", 'wt') as f:
        for i in avg_valid_losses:
            print(i, file=f)
def train_SGCN(treeDic, x_test, x_train, TDdroprate, BUdroprate, lr,
               weight_decay, patience, n_epochs, batchsize, dataname, iter):
    model = Net(5000, 64, 64).to(device)
    BU_params = list(map(id, model.BUrumorSGCN.conv1.parameters()))
    BU_params += list(map(id, model.BUrumorSGCN.conv2.parameters()))
    base_params = filter(lambda p: id(p) not in BU_params, model.parameters())
    optimizer = th.optim.Adam([{
        'params': base_params
    }, {
        'params': model.BUrumorSGCN.conv1.parameters(),
        'lr': lr / 5
    }, {
        'params': model.BUrumorSGCN.conv2.parameters(),
        'lr': lr / 5
    }],
                              lr=lr,
                              weight_decay=weight_decay)
    model.train()
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    early_stopping = EarlyStopping(patience=patience, verbose=True)
    for epoch in range(n_epochs):
        traindata_list, testdata_list = loadBiData(dataname, treeDic, x_train,
                                                   x_test, TDdroprate,
                                                   BUdroprate)
        train_loader = DataLoader(traindata_list,
                                  batch_size=batchsize,
                                  shuffle=True,
                                  num_workers=0)
        test_loader = DataLoader(testdata_list,
                                 batch_size=batchsize,
                                 shuffle=True,
                                 num_workers=0)
        avg_loss = []
        avg_acc = []
        batch_idx = 0
        tqdm_train_loader = tqdm(train_loader)
        for Batch_data in tqdm_train_loader:
            Batch_data.to(device)
            out_labels = model(Batch_data)
            finalloss = F.nll_loss(out_labels, Batch_data.y)
            loss = finalloss
            optimizer.zero_grad()
            loss.backward()
            avg_loss.append(loss.item())
            optimizer.step()
            _, pred = out_labels.max(dim=-1)
            correct = pred.eq(Batch_data.y).sum().item()
            train_acc = correct / len(Batch_data.y)
            avg_acc.append(train_acc)
            print(
                "Iter {:03d} | Epoch {:05d} | Batch{:02d} | Train_Loss {:.4f}| Train_Accuracy {:.4f}"
                .format(iter, epoch, batch_idx, loss.item(), train_acc))
            batch_idx = batch_idx + 1

        train_losses.append(np.mean(avg_loss))
        train_accs.append(np.mean(avg_acc))

        temp_val_losses = []
        temp_val_accs = []
        temp_val_Acc_all, temp_val_Acc1, temp_val_Prec1, temp_val_Recll1, temp_val_F1, \
        temp_val_Acc2, temp_val_Prec2, temp_val_Recll2, temp_val_F2, \
        temp_val_Acc3, temp_val_Prec3, temp_val_Recll3, temp_val_F3, \
        temp_val_Acc4, temp_val_Prec4, temp_val_Recll4, temp_val_F4 = [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], []
        model.eval()
        tqdm_test_loader = tqdm(test_loader)
        for Batch_data in tqdm_test_loader:
            Batch_data.to(device)
            val_out = model(Batch_data)
            val_loss = F.nll_loss(val_out, Batch_data.y)
            temp_val_losses.append(val_loss.item())
            _, val_pred = val_out.max(dim=1)
            correct = val_pred.eq(Batch_data.y).sum().item()
            val_acc = correct / len(Batch_data.y)
            Acc_all, Acc1, Prec1, Recll1, F1, Acc2, Prec2, Recll2, F2, Acc3, Prec3, Recll3, F3, Acc4, Prec4, Recll4, F4 = evaluation4class(
                val_pred, Batch_data.y)
            temp_val_Acc_all.append(Acc_all), temp_val_Acc1.append(Acc1), temp_val_Prec1.append(
                Prec1), temp_val_Recll1.append(Recll1), temp_val_F1.append(F1), \
            temp_val_Acc2.append(Acc2), temp_val_Prec2.append(Prec2), temp_val_Recll2.append(
                Recll2), temp_val_F2.append(F2), \
            temp_val_Acc3.append(Acc3), temp_val_Prec3.append(Prec3), temp_val_Recll3.append(
                Recll3), temp_val_F3.append(F3), \
            temp_val_Acc4.append(Acc4), temp_val_Prec4.append(Prec4), temp_val_Recll4.append(
                Recll4), temp_val_F4.append(F4)
            temp_val_accs.append(val_acc)
        val_losses.append(np.mean(temp_val_losses))
        val_accs.append(np.mean(temp_val_accs))
        print("Epoch {:05d} | Val_Loss {:.4f}| Val_Accuracy {:.4f}".format(
            epoch, np.mean(temp_val_losses), np.mean(temp_val_accs)))

        res = [
            'acc:{:.4f}'.format(np.mean(temp_val_Acc_all)),
            'C1:{:.4f},{:.4f},{:.4f},{:.4f}'.format(np.mean(temp_val_Acc1),
                                                    np.mean(temp_val_Prec1),
                                                    np.mean(temp_val_Recll1),
                                                    np.mean(temp_val_F1)),
            'C2:{:.4f},{:.4f},{:.4f},{:.4f}'.format(np.mean(temp_val_Acc2),
                                                    np.mean(temp_val_Prec2),
                                                    np.mean(temp_val_Recll2),
                                                    np.mean(temp_val_F2)),
            'C3:{:.4f},{:.4f},{:.4f},{:.4f}'.format(np.mean(temp_val_Acc3),
                                                    np.mean(temp_val_Prec3),
                                                    np.mean(temp_val_Recll3),
                                                    np.mean(temp_val_F3)),
            'C4:{:.4f},{:.4f},{:.4f},{:.4f}'.format(np.mean(temp_val_Acc4),
                                                    np.mean(temp_val_Prec4),
                                                    np.mean(temp_val_Recll4),
                                                    np.mean(temp_val_F4))
        ]
        print('results:', res)
        early_stopping(np.mean(temp_val_losses), np.mean(temp_val_accs),
                       np.mean(temp_val_F1), np.mean(temp_val_F2),
                       np.mean(temp_val_F3), np.mean(temp_val_F4), model,
                       'BiSGCN', dataname)
        accs = np.mean(temp_val_accs)
        F1 = np.mean(temp_val_F1)
        F2 = np.mean(temp_val_F2)
        F3 = np.mean(temp_val_F3)
        F4 = np.mean(temp_val_F4)
        if early_stopping.early_stop:
            print("Early stopping")
            accs = early_stopping.accs
            F1 = early_stopping.F1
            F2 = early_stopping.F2
            F3 = early_stopping.F3
            F4 = early_stopping.F4
            break
    return train_losses, val_losses, train_accs, val_accs, accs, F1, F2, F3, F4
def train(X_train, y_train, X_dev, y_dev, X_test, y_test, model, weights,
          args):
    if args.cuda:
        model.cuda()

    if args.Adam is True:
        print("Adam Training......")
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.init_weight_decay)
    elif args.SGD is True:
        print("SGD Training.......")
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    weight_decay=args.init_weight_decay,
                                    momentum=args.momentum_value)
    elif args.Adadelta is True:
        print("Adadelta Training.......")
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         lr=args.lr,
                                         weight_decay=args.init_weight_decay)

    criterion = nn.NLLLoss()

    steps = 0
    epoch_step = 0
    model_count = 0
    loss_full = []
    best_accuracy = Best_Result()
    model.train()
    early_stopping = EarlyStopping(patience=args.patience, verbose=True)
    for epoch in range(1, args.epochs + 1):
        steps = 0
        print("\n## The {} Epoch, All {} Epochs ! ##".format(
            epoch, args.epochs))
        loss_epoch = []
        if (model in 'HCL'):
            g = gen_minibatch_HAN(X_train,
                                  y_train,
                                  mini_batch_size,
                                  shuffle=True)
            for tokens, labels in g:
                embedding = nn.Embedding.from_pretrained(weights)
                tokens = embedding(tokens.long())

                optimizer.zero_grad()

                logit = model(tokens)
                loss = criterion(logit, labels)
                loss.backward()
                if args.init_clip_max_norm is not None:
                    utils.clip_grad_norm_(model.parameters(),
                                          max_norm=args.init_clip_max_norm)
                optimizer.step()

                loss_full.append(loss.item())
                loss_epoch.append(loss.item())

            torch.cuda.empty_cache()
            model.eval()
            print('Average training loss at this epoch (per minibatch): ',
                  np.mean(loss_epoch))
            vlos = check_val_loss(X_test, y_test, mini_batch_size, model)
            print('Test loss at epoch ', epoch, ' is ', vlos)

            torch.cuda.empty_cache()
            early_stopping(vlos, model)
            if early_stopping.early_stop:
                print("Early stopping")
                model_count += 1
                break

            model.train()

    return model_count