Example no. 1
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# SupervisedDataset and LogisticRegression (and the commented-out LSTM) are
# assumed to be defined elsewhere in the project.


def train(batch_size=64, window_size=3, epochs=100):
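    """Train a logistic-regression classifier on keypoint windows and validate it
    after every epoch, checkpointing whenever the best train or validation
    accuracy improves."""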

    train_dataset = SupervisedDataset(mode='train',
                                      window_size=window_size,
                                      log_reg=True)
    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True)

    val_dataset = SupervisedDataset(mode='val',
                                    window_size=window_size,
                                    log_reg=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

    base_lr_rate = 1e-3
    weight_decay = 1e-6  # currently unused: weight decay is disabled in the optimizer below

    device = torch.device('cuda:0')

    #model = LSTM(input_size=78, hidden_size=128, num_classes=170, n_layers=2).to(device)
    model = LogisticRegression(num_keypoints=78, num_features=2,
                               num_classes=170).to(device)
    '''
    for name, param in model.named_parameters():
        if 'bias' in name:
            nn.init.constant_(param, 0.0)
        elif 'weight' in name:
            nn.init.xavier_uniform_(param)
    '''

    criterion = nn.BCEWithLogitsLoss()      # Use this for Logistic Regression training
    #criterion = nn.BCELoss()               # Use this for LSTM training (with Softmax)

    optimizer = optim.Adam(model.parameters(), lr=base_lr_rate)  #, weight_decay=weight_decay, amsgrad=True

    best_epoch_train_accuracy = 0.0
    best_epoch_val_accuracy = 0.0

    for current_epoch in range(epochs):

        current_train_iter = 0
        current_val_iter = 0

        running_train_loss = 0.0
        current_average_train_loss = 0.0
        running_val_loss = 0.0
        current_average_val_loss = 0.0

        num_train_data = 0
        num_val_data = 0

        running_train_correct_preds = 0
        running_train_correct_classwise_preds = [0] * 170

        running_val_correct_preds = 0
        running_val_correct_classwise_preds = [0] * 170

        for phase in ['train', 'val']:

            # Train loop
            if phase == 'train':
                train_epoch_since = time.time()

                model.train()

                for train_batch_window, train_batch_label in train_loader:

                    current_train_iter += 1

                    # move the batch to the same device as the model
                    # (a no-op if the dataset already returns CUDA tensors)
                    train_batch_window = train_batch_window.to(device)
                    train_batch_label = train_batch_label.to(device)

                    outs = model(train_batch_window)

                    #scheduler = poly_lr_scheduler(optimizer = optimizer, init_lr = base_lr_rate, iter = current_iter, lr_decay_iter = 1,
                    #                          max_iter = max_iter, power = power)                                                          # max_iter = len(train_loader)

                    optimizer.zero_grad()

                    loss = criterion(outs, train_batch_label)
                    gt_confidence, gt_index = torch.max(train_batch_label,
                                                        dim=1)

                    #loss = criterion(outs, gt_index)

                    running_train_loss += loss.item()
                    current_average_train_loss = running_train_loss / current_train_iter

                    loss.backward(retain_graph=False)

                    optimizer.step()

                    pred_confidence, pred_index = torch.max(outs, dim=1)
                    #gt_confidence, gt_index = torch.max(train_batch_label, dim=1)
                    batch_correct_preds = torch.eq(
                        pred_index, gt_index).long().sum().item()
                    batch_accuracy = (batch_correct_preds /
                                      train_batch_window.shape[0]) * 100

                    num_train_data += train_batch_window.shape[0]
                    running_train_correct_preds += batch_correct_preds

                    if current_train_iter % 10 == 0:
                        #print(outs)
                        print(
                            f"\nITER#{current_train_iter} ({current_epoch+1}) BATCH TRAIN ACCURACY: {batch_accuracy:.4f}, RUNNING TRAIN LOSS: {loss.item():.8f}"
                        )
                        print(
                            f"Predicted / GT index:\n{pred_index}\n{gt_index}\n"
                        )

                last_epoch_average_train_loss = current_average_train_loss
                epoch_accuracy = (running_train_correct_preds /
                                  num_train_data) * 100

                print(
                    f"\n\nEPOCH#{current_epoch+1} EPOCH TRAIN ACCURACY (BEST): {epoch_accuracy:.4f} ({best_epoch_train_accuracy:.4f}), AVERAGE TRAIN LOSS: {last_epoch_average_train_loss:.8f}\n\n"
                )

                if epoch_accuracy > best_epoch_train_accuracy:
                    best_epoch_train_accuracy = epoch_accuracy
                    torch.save(
                        model.state_dict(),
                        f"./model_params/logistic_regression/train/train_logreg_epoch{current_epoch+1}_acc{best_epoch_train_accuracy:.4f}.pth"
                    )
                    print("\n\nSAVED BASED ON TRAIN ACCURACY!\n\n")

                train_time_elapsed = time.time() - train_epoch_since

            # Validation loop
            elif phase == 'val':
                val_epoch_since = time.time()

                model.eval()

                with torch.no_grad():
                    for val_batch_window, val_batch_label in val_loader:

                        current_val_iter += 1

                        # move the batch to the same device as the model
                        # (a no-op if the dataset already returns CUDA tensors)
                        val_batch_window = val_batch_window.to(device)
                        val_batch_label = val_batch_label.to(device)

                        outs = model(val_batch_window)

                        gt_confidence, gt_index = torch.max(val_batch_label,
                                                            dim=1)
                        val_loss = criterion(outs, val_batch_label)
                        #val_loss = criterion(outs, gt_index)

                        running_val_loss += val_loss.item()
                        current_average_val_loss = running_val_loss / current_val_iter

                        pred_confidence, pred_index = torch.max(outs, dim=1)
                        #gt_confidence, gt_index = torch.max(val_batch_label, dim=1)
                        batch_correct_preds = torch.eq(
                            pred_index, gt_index).long().sum().item()
                        batch_accuracy = (batch_correct_preds /
                                          val_batch_window.shape[0]) * 100

                        num_val_data += val_batch_window.shape[0]
                        running_val_correct_preds += batch_correct_preds

                        if current_val_iter % 10 == 0:
                            print(
                                f"\nITER#{current_val_iter} ({current_epoch+1}) BATCH VALIDATION ACCURACY: {batch_accuracy:.4f}, RUNNING VALIDATION LOSS: {val_loss.item():.8f}"
                            )
                            print(
                                f"Predicted / GT index:\n{pred_index}\n{gt_index}\n"
                            )

                    last_epoch_average_val_loss = current_average_val_loss
                    epoch_accuracy = (running_val_correct_preds /
                                      num_val_data) * 100
                    print(
                        f"\n\nEPOCH#{current_epoch+1} EPOCH VALIDATION ACCURACY (BEST): {epoch_accuracy:.4f} ({best_epoch_val_accuracy:.4f}), AVERAGE VALIDATION LOSS: {last_epoch_average_val_loss:.8f}\n\n"
                    )

                    if epoch_accuracy > best_epoch_val_accuracy:
                        best_epoch_val_accuracy = epoch_accuracy
                        torch.save(
                            model.state_dict(),
                            f"./model_params/logistic_regression/val/val_logreg_epoch{current_epoch+1}_acc{best_epoch_val_accuracy:.4f}.pth"
                        )
                        print("\n\nSAVED BASED ON VAL ACCURACY!\n\n")

                    val_time_elapsed = time.time() - val_epoch_since
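

# Minimal usage sketch (assumption: this module is run directly and the default
# hyper-parameters above are appropriate for the dataset):
if __name__ == '__main__':
    train()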
Example no. 2
    writer.add_scalar('Stat. Parity/%s' % split, tot_stat_par.mean(), epoch)
    writer.add_scalar('Equalized Odds/%s' % split, tot_eq_odds.mean(), epoch)
    writer.add_histogram('L-inf Differences/%s' % split, l_inf_diffs, epoch)

    return tot_mix_loss


print('saving model to', models_dir)
writer = SummaryWriter(log_dir)

# one pass over the training set and one over the validation set per epoch
for epoch in range(args.num_epochs):
    # restore train mode (the models are switched to eval mode below and would
    # otherwise stay in eval mode from the previous epoch, unless run() resets it)
    autoencoder.train()
    classifier.train()
    run(autoencoder, classifier, optimizer, train_loader, 'train', epoch)

    autoencoder.eval()
    classifier.eval()
    valid_mix_loss = run(
        autoencoder, classifier, optimizer, val_loader, 'valid', epoch
    )
    scheduler.step(valid_mix_loss.mean())

    torch.save(
        autoencoder.state_dict(),
        path.join(models_dir, f'autoencoder_{epoch}.pt')
    )
    torch.save(
        classifier.state_dict(),
        path.join(models_dir, f'classifier_{epoch}.pt')
    )

writer.close()
Example no. 3
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.utils.data

# The argparse `parser`, `LABELS`, the module-level `best_acc1`, the model and
# dataset classes, `summary`, and the train/validate/save_checkpoint/
# print_learning_rate helpers are assumed to be defined elsewhere in this script.


def main():
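    """Build the model, criterion, optimizer and data loaders from the parsed
    command-line arguments, optionally resume from a checkpoint, then train
    and validate for the requested number of epochs."""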
    args = parser.parse_args()

    # seed everything to ensure reproducible results from different runs
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True

    ###########################################################################
    # Model
    ###########################################################################
    global best_acc1
    # create model
    if args.arch == 'LogisticRegression':
        model = LogisticRegression(input_size=13, n_classes=args.classes)
    elif args.arch == 'NeuralNet':
        model = NeuralNet(input_size=13, hidden_size=[32, 16],
                          n_classes=args.classes)  # hidden_size=[64, 32]
    else:
        raise ValueError("Unsupported architecture: {}".format(args.arch))

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))
        torch.cuda.set_device(args.gpu)
        torch.backends.cudnn.benchmark = True
        model = model.cuda(args.gpu)

    # print(model)
    if args.train_file:
        print(30 * '=')
        print(summary(model, input_size=(1, 13),
                      batch_size=args.batch_size, device='cpu'))
        print(30 * '=')

    ###########################################################################
    # save directory
    ###########################################################################
    save_dir = os.path.join(os.getcwd(), args.save_dir)
    save_dir += ('/arch[{}]_optim[{}]_lr[{}]_lrsch[{}]_batch[{}]_'
                 'WeightedSampling[{}]').format(args.arch,
                                                args.optim,
                                                args.lr,
                                                args.lr_scheduler,
                                                args.batch_size,
                                                args.weighted_sampling)
    if args.suffix:
        save_dir += '_{}'.format(args.suffix)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    ###########################################################################
    # Criterion and optimizer
    ###########################################################################
    # Initialise criterion and optimizer
    if args.gpu is not None:
        criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    else:
        criterion = nn.CrossEntropyLoss()

    # define optimizer
    print("=> using '{}' optimizer".format(args.optim))
    if args.optim == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=True)
    else:  # default is adam
        optimizer = torch.optim.Adam(model.parameters(), args.lr,
                                     betas=(0.9, 0.999), eps=1e-08,
                                     weight_decay=args.weight_decay,
                                     amsgrad=False)

    ###########################################################################
    # Resume training and load a checkpoint
    ###########################################################################
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    ###########################################################################
    # Data Augmentation
    ###########################################################################
    # TODO

    ###########################################################################
    # Learning rate scheduler
    ###########################################################################
    print("=> using '{}' initial learning rate (lr)".format(args.lr))
    # define learning rate scheduler
    scheduler = args.lr_scheduler  # keeps the raw scheduler name if none of the cases below matches
    if args.lr_scheduler == 'reduce':
        print("=> using '{}' lr_scheduler".format(args.lr_scheduler))
        # Reduce learning rate when a metric has stopped improving.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               mode='min',
                                                               factor=0.5,
                                                               patience=10)
    elif args.lr_scheduler == 'cyclic':
        print("=> using '{}' lr_scheduler".format(args.lr_scheduler))
        scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,
                                                      base_lr=0.00005,
                                                      max_lr=0.005)
    elif args.lr_scheduler == 'cosine':
        print("=> using '{}' lr_scheduler".format(args.lr_scheduler))
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                                T_max=100,
                                                                eta_min=0,
                                                                last_epoch=-1)

    ###########################################################################
    # load train data
    ###########################################################################
    if args.train_file:
        train_dataset = HeartDiseaseDataset(csv=args.train_file, label_names=LABELS)
        if args.weighted_sampling:
            train_sampler = torch.utils.data.WeightedRandomSampler(train_dataset.sample_weights,
                                                                   len(train_dataset),
                                                                   replacement=True)
        else:
            train_sampler = None

        ###########################################################################
        # update criterion
        print('class_sample_count ', train_dataset.class_sample_count)
        print('class_probability ', train_dataset.class_probability)
        print('class_weights ', train_dataset.class_weights)
        print('sample_weights ', train_dataset.sample_weights)

        if args.weighted_loss:
            if args.gpu is not None:
                criterion = nn.CrossEntropyLoss(weight=train_dataset.class_weights).cuda(args.gpu)
            else:
                criterion = nn.CrossEntropyLoss(weight=train_dataset.class_weights)

        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size, shuffle=(train_sampler is None),
                                                   num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    ###########################################################################
    # load validation data
    ###########################################################################
    if args.valid_file:
        valid_dataset = HeartDiseaseDataset(csv=args.valid_file, label_names=LABELS)
        val_loader = torch.utils.data.DataLoader(valid_dataset,
                                                 batch_size=args.batch_size, shuffle=False,
                                                 num_workers=args.workers, pin_memory=True)

        if args.evaluate:
            # retrieve correct save path from saved model
            save_dir = os.path.split(args.resume)[0]
            validate(val_loader, model, criterion, save_dir, args)
            return

    ###########################################################################
    # Train the model
    ###########################################################################
    for epoch in range(args.start_epoch, args.epochs):
        # adjust_learning_rate(optimizer, epoch, args)
        print_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer,
              scheduler, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, save_dir, args)

        # update learning rate based on lr_scheduler
        if args.lr_scheduler == 'reduce':
            scheduler.step(acc1)
        elif args.lr_scheduler == 'cosine':
            scheduler.step()

        # remember best acc@1 and save checkpoint
        is_best = acc1 >= best_acc1
        best_acc1 = max(acc1, best_acc1)

        print("Saving model [{}]...".format(save_dir))
        save_checkpoint({'epoch': epoch + 1,
                         'arch': args.arch,
                         'state_dict': model.state_dict(),
                         'best_acc1': best_acc1,
                         'optimizer': optimizer.state_dict(),
                         'criterion': criterion, },
                        is_best,
                        save_dir=save_dir)
        print(30 * '=')
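

# Standard script entry point (assumption: this file is intended to be run directly):
if __name__ == '__main__':
    main()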