Example #1
    # (reconstructed head: the listing starts mid-call; a GCNII-style model is
    # inferred from the alpha/lamda arguments and the params1/params2 weight
    # decay groups below. Assumes: import time, import torch as th.)
    model = GCNII(num_feats,
                  num_classes,
                  args.num_layers,
                  dropout=args.dropout,
                  alpha=args.alpha,
                  lamda=args.lamda)
    labels = labels.squeeze()
    # set_seed(args.seed)

    # Two optimizer parameter groups, each with its own weight decay.
    optimizer = th.optim.Adam(
        [{'params': model.params1, 'weight_decay': args.weight_decay1},
         {'params': model.params2, 'weight_decay': args.weight_decay2}],
        lr=args.learn_rate)
    early_stopping = EarlyStopping(args.patience, file_name='tmp')

    device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
    graph = graph.to(device)
    features = features.to(device)
    labels = labels.to(device)
    train_mask = train_mask.to(device)
    val_mask = val_mask.to(device)
    test_mask = test_mask.to(device)
    print(model)
    model = model.to(device)

    dur = []
    for epoch in range(args.num_epochs):
        if epoch >= 3:
            t0 = time.time()
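
The listing breaks off inside the epoch loop. Below is a hedged sketch of the full-batch step such loops typically contain; the evaluate helper, the model(graph, features) forward signature, and the early-stopping call interface (modeled on Example #3 below) are assumptions, not part of the original:

        # -- assumed continuation of the epoch loop --
        model.train()
        logits = model(graph, features)  # forward signature assumed
        loss = th.nn.functional.cross_entropy(
            logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)

        # evaluate() is hypothetical; the EarlyStopping call matches the
        # accuracy-based interface used in Example #3
        val_acc = evaluate(model, graph, features, labels, val_mask)
        early_stopping(model, accuracy=val_acc)
        if early_stopping.early_stop:
            break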
Example #2
    model = VSGCNetMulti(num_feats,
                         num_classes,
                         k=args.k,
                         dropout=args.dropout,
                         propagation=args.propagation)
    labels = labels.squeeze()

    # set_seed(args.seed)

    optimizer = th.optim.Adam(model.parameters(),
                              lr=args.learn_rate,
                              weight_decay=args.weight_decay)

    early_stopping = EarlyStopping(
        args.patience, file_name=f"{args.filename}_{args.dataset}")

    device = th.device(
        f"cuda:{args.cuda}" if th.cuda.is_available() else "cpu")

    graph = graph.to(device)
    features = features.to(device)
    labels = labels.to(device)
    train_mask = train_mask.to(device)
    val_mask = val_mask.to(device)
    test_mask = test_mask.to(device)
    model = model.to(device)

    dur = []
    for epoch in range(args.num_epochs):
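        # (epoch body elided here; see the sketch after Example #1 for the
        # typical full-batch step)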
Example #3
import torch
from copy import deepcopy

# EarlyStopping, _update_optimizer and _EXTERN_PARAM (the optimizer-config
# keys that are not optimizer kwargs, e.g. "patience") are assumed to be
# defined elsewhere in this module.
def train(model,
          train_data,
          valid_data,
          device="cpu",
          save_path=None,
          early_stopping=False,
          opti=torch.optim.Adam,
          loss=torch.nn.CrossEntropyLoss(),
          max_epoch=20,
          static_opti_conf=None,
          scheduled_opti_conf=None,
          after_update=None,
          after_epoch=None,
          before_update=None,
          accuracy_method=None,
          verbose=False):

    if static_opti_conf is None and scheduled_opti_conf is None:
        raise RuntimeError("No optimizer configuration given!")

    history_acc_train = list()
    history_acc_valid = list()
    # Move the model to the target device
    model = model.to(device)

    best_acc = 0.0

    # Instantiate the optimizer and the loss
    init_conf = (static_opti_conf if static_opti_conf is not None
                 else scheduled_opti_conf.get(0, dict()))
    init_conf = deepcopy(init_conf)
    for element in _EXTERN_PARAM:
        init_conf.pop(element, None)  # strip non-optimizer keys (e.g. patience)
    optimizer = opti(model.parameters(), **init_conf)
    loss = loss.to(device)

    # Init early stopping if needed
    if early_stopping:
        if static_opti_conf is not None:
            patience = static_opti_conf.get("patience", 7)
        else:
            # with no static config, the patience must come from the
            # epoch-0 scheduled config
            patience = scheduled_opti_conf.get(0, {"patience": 7}).get(
                "patience", 7)
        early_control = EarlyStopping(patience, verbose=True)

    mean_train_loss = 0
    mean_train_accuracy = 0

    mean_valid_loss = 0
    mean_valid_accuracy = 0

    for curr_epoch in range(max_epoch):
        print("epoch : ", curr_epoch)
        if scheduled_opti_conf is not None:
            _update_optimizer(optimizer,
                              scheduled_opti_conf.get(curr_epoch, None))
        # Train part (gradients are zeroed per batch inside the loop)
        model.train()
        curr_batch = 0
        for x, y in train_data:
            curr_batch += 1
            # Switch params
            x = x.to(device)
            y = y.to(device)

            # Model forward
            outputs = model(x)

            optimizer.zero_grad()

            # Loss backward
            loss_val = loss(outputs, y)

            # Running mean of the per-batch loss: m_n = ((n-1)*m_{n-1} + v_n) / n
            mean_train_loss = ((curr_batch - 1) * mean_train_loss +
                               loss_val.item()) / curr_batch
            # accuracy if needed (same running-mean update)
            if accuracy_method is not None:
                mean_train_accuracy = ((curr_batch - 1) * mean_train_accuracy +
                                       accuracy_method(outputs, y)) / curr_batch

            loss_val.backward()

            # Do before-update hooks
            if "before_update" in model.__class__.__dict__:
                model.before_update()
            if before_update is not None:
                before_update()

            optimizer.step()

            # Do after-update hooks
            if "after_update" in model.__class__.__dict__:
                model.after_update()
            if after_update is not None:
                after_update()

        # Eval part
        if "after_epoch" in model.__class__.__dict__:
            model.after_epoch(epoch=curr_epoch)

        if after_epoch is not None:
            after_epoch(epoch=curr_epoch, model=model)
        model.eval()  # eval mode so dropout/batch-norm layers behave correctly
        with torch.no_grad():
            curr_batch = 0
            for x, y in valid_data:
                curr_batch += 1
                # Switch params
                x = x.to(device)
                y = y.to(device)

                # Model forward
                outputs = model(x)

                # Validation loss (no backward pass here)
                loss_val = loss(outputs, y)
                mean_valid_loss = ((curr_batch - 1) * mean_valid_loss +
                                   loss_val.item()) / curr_batch

                # accuracy if needed (same running-mean update)
                if accuracy_method is not None:
                    batch_accuracy = accuracy_method(outputs, y)
                    mean_valid_accuracy = (
                        (curr_batch - 1) * mean_valid_accuracy +
                        batch_accuracy) / curr_batch

        best_acc = max(best_acc, mean_valid_accuracy)
        print("train acc : ", mean_train_accuracy)
        print("train loss: ", mean_train_loss)
        print("valid acc : ", mean_valid_accuracy)
        history_acc_train.append(mean_train_accuracy)
        history_acc_valid.append(mean_valid_accuracy)
        # print("valid loss:", mean_valid_loss)
        if early_stopping:
            if accuracy_method is not None:
                early_control(model, accuracy=mean_valid_accuracy)
            else:
                early_control(model, loss=mean_valid_loss)

            if early_control.early_stop:
                break

    if early_stopping:
        if early_control.early_stop:
            mean_valid_accuracy = early_control.best_score  #TODO Reload best model
    print("final acc : ", best_acc)
    return history_acc_train, history_acc_valid
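
For reference, a minimal usage sketch of the train function above on synthetic data; the two-layer MLP, the tensor shapes, and the batch_accuracy helper are illustrative choices, not part of the original module:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic 10-class classification data (shapes are illustrative).
x_train, y_train = torch.randn(512, 32), torch.randint(0, 10, (512,))
x_valid, y_valid = torch.randn(128, 32), torch.randint(0, 10, (128,))
train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=64)
valid_loader = DataLoader(TensorDataset(x_valid, y_valid), batch_size=64)

model = torch.nn.Sequential(torch.nn.Linear(32, 64),
                            torch.nn.ReLU(),
                            torch.nn.Linear(64, 10))

def batch_accuracy(outputs, y):
    # Fraction of correct predictions in the batch.
    return (outputs.argmax(dim=1) == y).float().mean().item()

acc_train, acc_valid = train(
    model,
    train_loader,
    valid_loader,
    device="cuda" if torch.cuda.is_available() else "cpu",
    static_opti_conf={"lr": 1e-3},
    accuracy_method=batch_accuracy,
    max_epoch=5)

Keeping early_stopping at its default (False) means no patience key is needed in static_opti_conf, so the sketch does not depend on what _EXTERN_PARAM contains.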