Example #1
import math
import os
import time

import torch
import torch.utils.data
import torch.optim as optim
from torchvision import datasets, transforms

# mdl (the RotNet / classifier definitions) and the train() / test() helpers
# come from the project's own modules and are not reproduced here.


def main(args):
    # hard coded values
    in_channels = 3  # RGB channels of the original image fed to RotNet
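    # in_features = flattened size of the RotNet feature map fed into the classifier,
    # depending on which layer is tapped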
    if args.layer == 1:
        in_features = int(96 * 16 * 16)
    else:
        in_features = int(192 * 8 * 8)
    rot_classes = 4
    out_classes = 10
    lr_decay_rate = 0.2  # lr is multiplied by decay rate after a milestone epoch is reached
    mult = 1  # factor by which the dataset is effectively enlarged
    ####################

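    # both transforms normalize with the standard CIFAR-10 per-channel mean/std
    # (0-255 values scaled to [0, 1])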
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((125.3 / 255, 123.0 / 255, 113.9 / 255),
                             (63.0 / 255, 62.1 / 255, 66.7 / 255))
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((125.3 / 255, 123.0 / 255, 113.9 / 255),
                             (63.0 / 255, 62.1 / 255, 66.7 / 255))
    ])

    trainset = datasets.CIFAR10(root='results/',
                                train=True,
                                download=True,
                                transform=train_transform)
    testset = datasets.CIFAR10(root='results/',
                               train=False,
                               download=True,
                               transform=test_transform)

    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=0)
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=0)

    rot_network = mdl.RotNet(in_channels=in_channels,
                             num_nin_blocks=args.nins,
                             out_classes=rot_classes).to(args.device)
    class_network = mdl.NonLinearClassifier(in_channels=in_features,
                                            out_classes=out_classes).to(
                                                args.device)

    if args.opt == 'adam':
        rot_optimizer = optim.Adam(rot_network.parameters(),
                                   lr=args.lr,
                                   weight_decay=args.weight_decay)
        class_optimizer = optim.Adam(class_network.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    else:
        rot_optimizer = optim.SGD(rot_network.parameters(),
                                  lr=args.lr,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
        class_optimizer = optim.SGD(class_network.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    rot_scheduler = optim.lr_scheduler.MultiStepLR(rot_optimizer,
                                                   milestones=args.milestones,
                                                   gamma=lr_decay_rate)
    class_scheduler = optim.lr_scheduler.MultiStepLR(
        class_optimizer, milestones=args.milestones, gamma=lr_decay_rate)

    ####################################### Saving information
    results_dict = {}
    # These store the metrics of the model with the best test accuracy
    results_dict['train_loss'] = -1
    results_dict['train_acc'] = -1
    results_dict['test_loss'] = -1
    results_dict['test_acc'] = -1
    results_dict['best_acc_epoch'] = -1
    # For storing training history
    results_dict['train_loss_hist'] = []
    results_dict['train_acc_hist'] = []
    results_dict['test_loss_hist'] = []
    results_dict['test_acc_hist'] = []

    # file paths for saving models
    checkpoint_path = os.path.join(args.results_dir, 'model.pth')
    checkpoint_path_best_acc = os.path.join(args.results_dir,
                                            'model_best_acc.pth')

    #########
    test_acc_max = -math.inf
    loop_start_time = time.time()
    checkpoint = {}
    for epoch in range(args.epochs):
        train(args, rot_network, class_network, train_loader, rot_optimizer,
              class_optimizer, mult, rot_scheduler, class_scheduler, epoch,
              in_features)

        train_loss, train_acc = test(args, rot_network, class_network,
                                     train_loader, mult, 'Train', in_features)
        results_dict['train_loss_hist'].append(train_loss)
        results_dict['train_acc_hist'].append(train_acc)

        test_loss, test_acc = test(args, rot_network, class_network,
                                   test_loader, mult, 'Test', in_features)
        results_dict['test_loss_hist'].append(test_loss)
        results_dict['test_acc_hist'].append(test_acc)
        print(
            'Epoch {} finished --------------------------------------------------------------------------'
            .format(epoch + 1))

        checkpoint = {
            'class_model_state_dict': class_network.state_dict(),
            'class_optimizer_state_dict': class_optimizer.state_dict(),
            'rot_model_state_dict': rot_network.state_dict(),
            'rot_optimizer_state_dict': rot_optimizer.state_dict(),
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'test_loss': test_loss,
            'test_acc': test_acc
        }

        if test_acc > test_acc_max:
            test_acc_max = test_acc
            if os.path.isfile(checkpoint_path_best_acc):
                os.remove(checkpoint_path_best_acc)

            torch.save(checkpoint, checkpoint_path_best_acc)

            results_dict['best_acc_epoch'] = epoch + 1
            results_dict['train_loss'] = train_loss
            results_dict['train_acc'] = train_acc
            results_dict['test_loss'] = test_loss
            results_dict['test_acc'] = test_acc

    torch.save(checkpoint, checkpoint_path)

    print('Total time for training loop = ', time.time() - loop_start_time)

    return results_dict
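
The snippet above receives an already-parsed args object and never builds it itself. As a rough sketch of how the attributes this main() actually reads could be supplied (flag names, defaults, and the device-selection logic are assumptions, not taken from the original project):

# Hypothetical argument parser covering the attributes read by main() above.
# Flag names and default values are illustrative assumptions only.
import argparse
import torch

def parse_args():
    parser = argparse.ArgumentParser(description='RotNet + non-linear classifier training')
    parser.add_argument('--layer', type=int, default=1)    # RotNet block whose features feed the classifier
    parser.add_argument('--nins', type=int, default=4)     # number of NIN blocks in RotNet
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--opt', choices=['adam', 'sgd'], default='sgd')
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--milestones', type=int, nargs='+', default=[30, 60, 80])
    parser.add_argument('--results_dir', default='results/')
    args = parser.parse_args()
    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return args

results = main(parse_args())
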
Example #2
def main(args):
    # hard coded values
    in_channels = 3  # RGB channels of the input image
    out_classes = out_size  # length of the output vector d (out_size is defined elsewhere in the project)
    lr_decay_rate = 0.2  # lr is multiplied by decay rate after a milestone epoch is reached
    mult = 1  # factor by which the dataset is effectively enlarged
    ####################

    #train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor()])
    #test_transform = transforms.ToTensor()

    trainset = MyTrainset()
    testset = MyTestset()
    #testset = datasets.CIFAR10(root='.', train=False, download=True, transform=test_transform)

    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=0)
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=0)

    network = mdl.RotNet(in_channels=in_channels,
                         num_nin_blocks=args.nins,
                         out_classes=out_classes).to(args.device)

    if args.opt == 'adam':
        optimizer = optim.Adam(network.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
    else:
        optimizer = optim.SGD(network.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.milestones,
                                               gamma=lr_decay_rate)

    ####################################### Saving information
    results_dict = {}
    # These store the metrics of the model with the best test accuracy
    results_dict['train_loss'] = -1
    results_dict['train_acc'] = -1
    results_dict['test_loss'] = -1
    results_dict['test_acc'] = -1
    results_dict['best_acc_epoch'] = -1
    # For storing training history
    results_dict['train_loss_hist'] = []
    results_dict['train_acc_hist'] = []
    results_dict['test_loss_hist'] = []
    results_dict['test_acc_hist'] = []

    # file paths for saving models
    checkpoint_path = os.path.join(args.results_dir, 'model.pth')
    checkpoint_path_best_acc = os.path.join(args.results_dir,
                                            'model_best_acc.pth')

    test_acc_max = -math.inf
    loop_start_time = time.time()
    checkpoint = {}
    for epoch in range(args.epochs):
        train(args, network, train_loader, optimizer, mult, scheduler, epoch)

        train_loss, train_acc = test(args, network, train_loader, mult,
                                     'Train')
        results_dict['train_loss_hist'].append(train_loss)
        results_dict['train_acc_hist'].append(train_acc)

        test_loss, test_acc = test(args, network, test_loader, mult, 'Test')
        results_dict['test_loss_hist'].append(test_loss)
        results_dict['test_acc_hist'].append(test_acc)
        print(
            'Epoch {} finished --------------------------------------------------------------------------'
            .format(epoch + 1))

        checkpoint = {
            'model_state_dict': network.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'test_loss': test_loss,
            'test_acc': test_acc
        }

        if test_acc > test_acc_max:
            test_acc_max = test_acc
            if os.path.isfile(checkpoint_path_best_acc):
                os.remove(checkpoint_path_best_acc)

            torch.save(checkpoint, checkpoint_path_best_acc)

            results_dict['best_acc_epoch'] = epoch + 1
            results_dict['train_loss'] = train_loss
            results_dict['train_acc'] = train_acc
            results_dict['test_loss'] = test_loss
            results_dict['test_acc'] = test_acc

        if epoch + 1 in args.epochs_to_save:
            torch.save(
                checkpoint,
                os.path.join(args.results_dir,
                             'model_epoch_' + str(epoch + 1) + '.pth'))

    torch.save(checkpoint, checkpoint_path)

    print('Total time for training loop = ', time.time() - loop_start_time)

    return results_dict
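
Both examples persist checkpoints as plain dicts of state_dicts via torch.save. Restoring a checkpoint written by this second variant, for evaluation or to resume training, could look roughly like the following; the module name mdl, the constructor arguments, and the file path are assumptions that must match the original training run:

# Sketch of reloading a checkpoint saved by the loop above.
# The import path, constructor arguments, and file path are assumptions; they
# must match whatever was used when the checkpoint was written.
import torch
import torch.optim as optim
import model as mdl  # hypothetical import path for the project's model definitions

checkpoint = torch.load('results/model_best_acc.pth', map_location='cpu')

network = mdl.RotNet(in_channels=3,      # same hard-coded value as in main()
                     num_nin_blocks=4,   # must match args.nins used for training (assumed)
                     out_classes=4)      # must match out_size used for training (assumed)
network.load_state_dict(checkpoint['model_state_dict'])
network.eval()  # sufficient for inference / evaluation

# To resume training, the optimizer state can be restored as well:
optimizer = optim.SGD(network.parameters(), lr=0.1, momentum=0.9)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
print('Restored epoch', start_epoch, 'with test accuracy', checkpoint['test_acc'])
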