Example #1
def train_network():
    """
    Main training loop for BSDR
    """
    network = BSDR_Net()

    model_save_path = os.path.join(model_save_dir, 'train2')
    if not os.path.exists(model_save_path):
        os.makedirs(model_save_path)
        os.makedirs(os.path.join(model_save_path, 'snapshots'))
        os.makedirs(os.path.join(model_save_dir, 'dump'))
        os.makedirs(os.path.join(model_save_dir, 'dump_test'))
    global f
    snapshot_path = os.path.join(model_save_path, 'snapshots')
    f = open(os.path.join(model_save_path, 'train0.log'), 'w')

    # -- Logging Parameters
    log(f, 'args: ' + str(args))
    log(f, 'model: ' + str(network), False)

    network = load_net(network, 'models_BSDR/train2/snapshots',
                       str(args.model_name))

    log(f, 'Testing...')
    epoch_test_losses, mae = test_network(dataset, 'test', network, False)
    log(
        f, 'TEST epoch: ' + str(-1) + ' test loss1, mae:' +
        str(epoch_test_losses))

    return
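The examples on this page lean on a few helpers that are not shown here, notably log (write a message to the open log file and, optionally, to stdout) and load_net (restore saved weights into a network). A minimal sketch of what they might look like, with hypothetical signatures inferred from the calls above:

import os
import torch

def log(f, txt, do_print=True):
    # Append one line to the training log; echo it to stdout unless do_print is False.
    if do_print:
        print(txt)
    f.write(str(txt) + '\n')
    f.flush()

def load_net(network, snapshot_path, filename):
    # Load a checkpoint written by save_checkpoint() and copy its weights into the model.
    checkpoint = torch.load(os.path.join(snapshot_path, filename), map_location='cpu')
    network.load_state_dict(checkpoint['state_dict'])
    return network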
Example #2
def train_network():
    network = Stage2CountingNet()
    model_save_dir = './models_stage_2'
    model_save_path = os.path.join(model_save_dir, 'train2')
    if not os.path.exists(model_save_path):
        os.makedirs(model_save_path)
        os.makedirs(os.path.join(model_save_path, 'snapshots'))
        os.makedirs(os.path.join(model_save_dir,'dump'))
        os.makedirs(os.path.join(model_save_dir,'dump_test'))
    global f
    snapshot_path = os.path.join(model_save_path, 'snapshots')

    network = load_net(network, snapshot_path, get_filename(network.name, args.best_model_name))
    print(network)
    epoch_test_losses, mae = test_network(dataset, 'test', network, print_output=os.path.join(model_save_dir,'dump_test'))
    print('TEST mae, mse:' + str(epoch_test_losses))
    return
Example #3
def train_network():
    """
    Main training loop for BSDR
    """
    network = BSDR_Net()
    nrn_networks = []
    nrn_snapshot_path = 'models_NRN/train2/snapshots'
    nrn_models_list = pickle.load(
        open(os.path.join(nrn_snapshot_path, 'best_model.pkl'), 'rb'))

    before_nrn_sum = []
    # bp()
    for i in range(num_density_categories):
        nrn_net = NRN()
        nrn_model_file = nrn_models_list[i]
        nrn_net = load_net(nrn_net, nrn_snapshot_path, nrn_model_file)
        nrn_net = nrn_net.cuda()
        nrn_net.eval()
        nrn_networks.append(nrn_net)
        before_nrn_sum.append(check_conv_gradient_change(nrn_net))

    # load_model_VGG16(network)
    model_save_path = os.path.join(model_save_dir, 'train2')
    if not os.path.exists(model_save_path):
        os.makedirs(model_save_path)
        os.makedirs(os.path.join(model_save_path, 'snapshots'))
        os.makedirs(os.path.join(model_save_dir, 'dump'))
        os.makedirs(os.path.join(model_save_dir, 'dump_test'))
    global f
    snapshot_path = os.path.join(model_save_path, 'snapshots')
    f = open(os.path.join(model_save_path, 'train0.log'), 'w')

    # -- Logging Parameters
    log(f, 'args: ' + str(args))
    log(f, 'model: ' + str(network), False)
    log(f, 'Training0...')
    log(f, 'LR: %.12f.' % (args.lr))
    log(f, 'NRN folder {}'.format(nrn_snapshot_path))
    log(f, 'NRN models : {}'.format(nrn_models_list))

    start_epoch = 0
    num_epochs = args.epochs
    valid_losses = {}
    test_losses = {}
    train_losses = {}
    for metric in ['loss1', 'new_mae', 'mse']:
        valid_losses[metric] = []
        test_losses[metric] = []

    for metric in ['loss1']:
        train_losses[metric] = []

    batch_size = args.batch_size
    num_train_images = len(dataset.data_files['train'])
    num_patches_per_image = args.patches
    assert (batch_size < (num_patches_per_image * num_train_images))
    num_batches_per_epoch = num_patches_per_image * num_train_images // batch_size
    assert (num_batches_per_epoch >= 1)

    optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                 network.parameters()),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    network = load_rot_model_blocks(
        network,
        snapshot_path='models_rot_net/train2/snapshots',
        excluded_layers=excluded_layers)

    # -- Main Training Loop
    all_epoch_test_accs = []

    global sampled_GT

    log(f, 'Testing...')
    epoch_test_losses, mae = test_network(dataset, 'test', network, False)
    log(
        f, 'TEST epoch: ' + str(-1) + ' test loss1, mae:' +
        str(epoch_test_losses))

    for e_i, epoch in enumerate(range(start_epoch, num_epochs)):
        avg_loss = []

        global blur_sigma
        for b_i in range(num_batches_per_epoch):
            # Generate next training sample
            Xs, Ys, Ys_full_counts = dataset.train_get_data(
                batch_size=args.batch_size)
            category_labels = np.digitize(Ys_full_counts,
                                          count_density_threshold,
                                          right=True)
            if args.use_noisygt:

                before_noisy_gt_maps = []

                for i in range(batch_size):
                    image = Xs[i].transpose(
                        (1, 2, 0)).astype('uint8')  #(224,224,3)
                    noisy_gt_map = create_noisy_gt(image, output_downscale,
                                                   blur_sigma)
                    before_noisy_gt_maps.append(noisy_gt_map[None, ...])
                before_noisy_gt_maps = np.array(before_noisy_gt_maps)
                # bp()
                assert (Ys.shape == before_noisy_gt_maps.shape)
                before_noisy_X = torch.autograd.Variable(
                    torch.from_numpy(before_noisy_gt_maps)).cuda().float()
                before_noisy_X.requires_grad = False
                Ys_counts = Ys.reshape(
                    (Ys.shape[0], -1)).sum(axis=1).astype('int')
                assert (Ys_counts.shape == Ys_full_counts.shape)

                factor_arr = []
                upsample = nn.Upsample(scale_factor=network_output_downscale,
                                       mode='nearest')
                for i in range(batch_size):
                    factor = nrn_networks[category_labels[i]](
                        before_noisy_X[i][None, ...])[0]
                    factor_arr.append(factor.detach().cpu().numpy())

                factor_arr = np.array(factor_arr)
                factor_arr = torch.autograd.Variable(
                    torch.from_numpy(factor_arr)).cuda().float()
                # bp()
                after_noisy_X = upsample(factor_arr) * before_noisy_X
                after_noisy_gt_maps = after_noisy_X.detach().cpu().numpy()
                assert (Ys.shape == after_noisy_gt_maps.shape ==
                        before_noisy_gt_maps.shape)
                Ys = after_noisy_gt_maps

            train_loss = train_function(Xs, Ys, network,
                                        optimizer)  #sampled_GT
            avg_loss.append(train_loss)
            for i in range(num_density_categories):
                after_sum = check_conv_gradient_change(nrn_networks[i])
                assert (np.all(before_nrn_sum[i] == after_sum))
            # Log losses every 10 batches.
            if b_i % 10 == 0:
                log(
                    f, 'Epoch %d [%d]: %s loss: %s.' %
                    (epoch, b_i, [network.name], train_loss))

        avg_loss = np.mean(np.array(avg_loss))
        train_losses['loss1'].append(avg_loss)
        log(
            f, 'TRAIN epoch: ' + str(epoch) + ' train mean loss1:' +
            str(avg_loss))

        torch.cuda.empty_cache()
        log(f, 'Testing...')
        epoch_test_losses, mae = test_network(dataset, 'test', network, False)
        log(
            f, 'TEST epoch: ' + str(epoch) + ' test loss1, mae:' +
            str(epoch_test_losses))

        epoch_val_losses, valid_mae = test_network(dataset, 'test_valid',
                                                   network, False)
        log(
            f, 'TEST valid epoch: ' + str(epoch) + ' test valid loss1, mae' +
            str(epoch_val_losses))
        # exit(0)

        for metric in ['loss1', 'new_mae', 'mse']:
            valid_losses[metric].append(epoch_val_losses[metric])
            test_losses[metric].append(epoch_test_losses[metric])

        min_valid_epoch = np.argmin(valid_losses['new_mae'])
        min_test_epoch = np.argmin(test_losses['new_mae'])

        log(
            f,
            'Best valid so far epoch: {}, valid mae: {},test (mae:{},mse:{}), min_test epoch {}, (mae:{},mse:{})'
            .format(min_valid_epoch, valid_losses['new_mae'][min_valid_epoch],
                    test_losses['new_mae'][min_valid_epoch],
                    test_losses['mse'][min_valid_epoch], min_test_epoch,
                    test_losses['new_mae'][min_test_epoch],
                    test_losses['mse'][min_test_epoch]))
        # Save networks
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': network.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, snapshot_path, get_filename(network.name, epoch + 1))

        print('saving graphs...')
        with open(os.path.join(snapshot_path, 'losses.pkl'), 'wb') as lossfile:
            pickle.dump((train_losses, valid_losses, test_losses),
                        lossfile,
                        protocol=2)

        for metric in train_losses.keys():
            if "maxima_split" not in metric:
                if isinstance(train_losses[metric][0], list):
                    for i in range(len(train_losses[metric][0])):
                        plt.plot([a[i] for a in train_losses[metric]])
                        plt.savefig(
                            os.path.join(snapshot_path,
                                         'train_%s_%d.png' % (metric, i)))
                        plt.clf()
                        plt.close()
                # print(metric, "METRIC", train_losses[metric])
                plt.plot(train_losses[metric])
                plt.savefig(
                    os.path.join(snapshot_path, 'train_%s.png' % metric))
                plt.clf()
                plt.close()

        for metric in valid_losses.keys():
            if isinstance(valid_losses[metric][0], list):
                for i in range(len(valid_losses[metric][0])):
                    plt.plot([a[i] for a in valid_losses[metric]])
                    plt.savefig(
                        os.path.join(snapshot_path,
                                     'valid_%s_%d.png' % (metric, i)))
                    plt.clf()
                    plt.close()
            plt.plot(valid_losses[metric])
            plt.savefig(os.path.join(snapshot_path, 'valid_%s.png' % metric))
            plt.clf()
            plt.close()

        for metric in test_losses.keys():
            if isinstance(test_losses[metric][0], list):
                for i in range(len(test_losses[metric][0])):
                    plt.plot([a[i] for a in test_losses[metric]])
                    plt.savefig(
                        os.path.join(snapshot_path,
                                     'test_%s_%d.png' % (metric, i)))
                    plt.clf()
                    plt.close()
            plt.plot(test_losses[metric])
            plt.savefig(os.path.join(snapshot_path, 'test_%s.png' % metric))
            plt.clf()
            plt.close()

    min_valid_epoch = np.argmin(valid_losses['new_mae'])
    network = load_net(network, snapshot_path,
                       get_filename(network.name, min_valid_epoch + 1))
    log(f, 'Testing on best model {}'.format(min_valid_epoch))
    epoch_test_losses, mae = test_network(dataset,
                                          'test',
                                          network,
                                          print_output=os.path.join(
                                              model_save_dir, 'dump_test'))
    log(
        f, 'TEST epoch: ' + str(epoch) + ' test loss1, mae:' +
        str(epoch_test_losses))
    log(f, 'Exiting train...')
    f.close()
    return
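Examples 3 and 4 sanity-check frozen sub-networks with functions like check_conv_gradient_change, asserting that the values they return are identical before and after each optimizer step. The implementation is not shown on this page; one plausible sketch is a simple fingerprint of the convolutional weights:

import numpy as np
import torch.nn as nn

def check_conv_gradient_change(network):
    # Fingerprint the conv weights; comparing two fingerprints reveals
    # whether any (supposedly frozen) convolution was updated in between.
    sums = []
    for module in network.modules():
        if isinstance(module, nn.Conv2d):
            sums.append(module.weight.detach().abs().sum().item())
    return np.array(sums)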
Example #4
def train_network():
    network = Stage2CountingNet()
    model_save_dir = './models_stage_2'
    model_save_path = os.path.join(model_save_dir, 'train2')
    if not os.path.exists(model_save_path):
        os.makedirs(model_save_path)
        os.makedirs(os.path.join(model_save_path, 'snapshots'))
        os.makedirs(os.path.join(model_save_dir, 'dump'))
        os.makedirs(os.path.join(model_save_dir, 'dump_test'))
    global f
    snapshot_path = os.path.join(model_save_path, 'snapshots')
    f = open(os.path.join(model_save_path, 'train0.log'), 'w')

    # -- Logging Parameters
    log(f, 'args: ' + str(args))
    log(f, 'model: ' + str(network), False)
    log(f, 'Stage2...')
    log(f, 'LR: %.12f.' % (args.lr))

    start_epoch = 0
    num_epochs = args.epochs
    valid_losses = {}
    train_losses = {}
    for metric in ['loss1', 'new_mae']:
        valid_losses[metric] = []

    for metric in ['loss1']:
        train_losses[metric] = []

    batch_size = args.batch_size
    num_train_images = len(dataset.data_files['train'])
    num_patches_per_image = args.patches
    assert (batch_size < (num_patches_per_image * num_train_images))
    num_batches_per_epoch = num_patches_per_image * num_train_images // batch_size
    assert (num_batches_per_epoch >= 1)

    optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                 network.parameters()),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    network = load_rot_model_blocks(
        network,
        snapshot_path='models_stage_1/train2/snapshots/',
        excluded_layers=excluded_layers)

    shift_thresh = get_shift_thresh()
    Lambda = get_lambda()
    log(f, "Shift Thresh: {}, Lambda: {}".format(shift_thresh, Lambda))

    # -- Main Training Loop
    min_valid_loss = 100.
    min_valid_epoch = -1

    before_BN_weights_sum = check_BN_no_gradient_change(
        network, exclude_list=excluded_layers)
    before_conv_weights_sum = check_conv_no_gradient_change(
        network, exclude_list=excluded_layers)

    stop_training = False
    count = 0  # epochs since the last improvement in the smoothed validation loss

    global sampled_GT

    for e_i, epoch in enumerate(range(start_epoch, num_epochs)):
        avg_loss = []

        # b_i - batch index
        for b_i in range(num_batches_per_epoch):
            # Generate next training sample
            Xs, _ = dataset.train_get_data(batch_size=args.batch_size)

            after_conv_weights_sum = check_conv_no_gradient_change(
                network, exclude_list=excluded_layers)
            assert (np.all(before_conv_weights_sum == after_conv_weights_sum))

            sampled_GT = None
            sampled_GT_shape = args.sbs * 7 * 7 * \
                (8 // args.kernel_size) * (8 // args.kernel_size)

            sampling_parameters = [args.alpha, Lambda]
            sampled_GT = powerlaw.Truncated_Power_Law(
                parameters=sampling_parameters).generate_random(
                    sampled_GT_shape)

            for s_i, s_val in enumerate(sampled_GT):
                if s_val < shift_thresh:
                    sampled_GT[s_i] = np.random.uniform(0, shift_thresh)
            assert (sampled_GT.shape[0] == (sampled_GT_shape)
                    and sampled_GT.ndim == 1)

            train_loss = train_function(Xs, sampled_GT, network, optimizer)
            avg_loss.append(train_loss)

            # Logging losses after each iteration.
            if b_i % 1 == 0:
                log(
                    f, 'Epoch %d [%d]: %s loss: %s.' %
                    (epoch, b_i, [network.name], train_loss))
            after_BN_weights_sum = check_BN_no_gradient_change(
                network, exclude_list=excluded_layers)
            after_conv_weights_sum = check_conv_no_gradient_change(
                network, exclude_list=excluded_layers)

            assert (np.all(before_BN_weights_sum == after_BN_weights_sum))
            assert (np.all(before_conv_weights_sum == after_conv_weights_sum))

        # -- Stats update
        avg_loss = np.mean(np.array(avg_loss))
        train_losses['loss1'].append(avg_loss)
        log(
            f, 'TRAIN epoch: ' + str(epoch) + ' train mean loss1:' +
            str(avg_loss))

        torch.cuda.empty_cache()

        log(f, 'Validating...')

        epoch_val_losses, valid_mae = test_network(dataset, 'test_valid',
                                                   network, True)
        log(
            f, 'TEST valid epoch: ' + str(epoch) + ' test valid loss1, mae' +
            str(epoch_val_losses))

        for metric in ['loss1', 'new_mae']:
            valid_losses[metric].append(epoch_val_losses[metric])

        if e_i > args.ma_window:
            valid_losses_smooth = np.mean(
                valid_losses['loss1'][-args.ma_window:])
            if valid_losses_smooth < min_valid_loss:
                min_valid_loss = valid_losses_smooth
                min_valid_epoch = e_i
                count = 0
            else:
                count = count + 1
                if count > args.patience:
                    stop_training = True

        log(
            f, 'Best valid so far epoch: {}, valid_loss: {}'.format(
                min_valid_epoch, valid_losses['loss1'][min_valid_epoch]))
        # Save networks
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': network.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, snapshot_path, get_filename(network.name, epoch + 1))

        print('saving graphs...')
        with open(os.path.join(snapshot_path, 'losses.pkl'), 'wb') as lossfile:
            pickle.dump((train_losses, valid_losses), lossfile, protocol=2)

        for metric in train_losses.keys():
            if "maxima_split" not in metric:
                if isinstance(train_losses[metric][0], list):
                    for i in range(len(train_losses[metric][0])):
                        plt.plot([a[i] for a in train_losses[metric]])
                        plt.savefig(
                            os.path.join(snapshot_path,
                                         'train_%s_%d.png' % (metric, i)))
                        plt.clf()
                        plt.close()
                plt.plot(train_losses[metric])
                plt.savefig(
                    os.path.join(snapshot_path, 'train_%s.png' % metric))
                plt.clf()
                plt.close()

        for metric in valid_losses.keys():
            if isinstance(valid_losses[metric][0], list):
                for i in range(len(valid_losses[metric][0])):
                    plt.plot([a[i] for a in valid_losses[metric]])
                    plt.savefig(
                        os.path.join(snapshot_path,
                                     'valid_%s_%d.png' % (metric, i)))
                    plt.clf()
                    plt.close()
            plt.plot(valid_losses[metric])
            plt.savefig(os.path.join(snapshot_path, 'valid_%s.png' % metric))
            plt.clf()
            plt.close()

        if stop_training:
            break

    network = load_net(network, snapshot_path,
                       get_filename(network.name, min_valid_epoch + 1))
    log(f, 'Testing on best model {}'.format(min_valid_epoch))
    epoch_test_losses, mae = test_network(dataset,
                                          'test',
                                          network,
                                          print_output=os.path.join(
                                              model_save_dir, 'dump_test'))
    log(
        f, 'TEST epoch: ' + str(epoch) + ' test loss1, mae:' +
        str(epoch_test_losses) + ", " + str(mae))
    log(f, 'Exiting train...')
    f.close()
    return
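Examples 2 through 4 persist snapshots with save_checkpoint and name them through get_filename; both are assumed rather than defined here. A minimal sketch consistent with how the loops above save an epoch-indexed file and later reload it via load_net:

import os
import torch

def get_filename(net_name, epoch_or_tag):
    # Hypothetical naming scheme, e.g. get_filename('BSDR_Net', 5) -> 'BSDR_Net_epoch_5.pth'.
    return '{}_epoch_{}.pth'.format(net_name, epoch_or_tag)

def save_checkpoint(state, snapshot_path, filename):
    # 'state' is the dict assembled in the training loops above:
    # {'epoch': ..., 'state_dict': ..., 'optimizer': ...}.
    torch.save(state, os.path.join(snapshot_path, filename))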
Example #5
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='Supervised training')
    parser.add_argument(
        '--autoaugment',
        action='store_true',
        default=False,
        help='Use autoaugment policy, only for CIFAR10 (Default: False)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=128,
                        metavar='N',
                        help='Input batch size for training (default: 128)')
    parser.add_argument('--dataset',
                        type=str,
                        default='cifar10',
                        help='Dataset name (default: CIFAR10)')
    parser.add_argument('--epochs',
                        type=int,
                        default=200,
                        metavar='N',
                        help='Number of epochs to train (default: 200)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.1,
                        metavar='LR',
                        help='Learning rate (default: 0.1)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument(
        '--network',
        type=str,
        default='ResNet-18',
        help=
        'Network model (default: ResNet-18), choose between (ResNet-18, TempEns, RevNet-18)'
    )
    parser.add_argument('--no_cuda',
                        action='store_true',
                        default=False,
                        help='Disables CUDA training')
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='Number of data loading workers')
    parser.add_argument('--rotnet_dir',
                        type=str,
                        default='',
                        help='RotNet saved directory')
    parser.add_argument('--save_dir',
                        type=str,
                        default='./data/supervised/',
                        help='Directory to save models')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        help='Random seed (default: 1)')

    args = parser.parse_args()
    args.name = 'supervised_%s_%s_seed%u' % (args.dataset.lower(),
                                             args.network.lower(), args.seed)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    dataset_train = dataset.GenericDataset(dataset_name=args.dataset,
                                           split='train',
                                           autoaugment=args.autoaugment)
    dataset_test = dataset.GenericDataset(dataset_name=args.dataset,
                                          split='test')

    dloader_train = dataset.DataLoader(dataset=dataset_train,
                                       batch_size=args.batch_size,
                                       num_workers=args.num_workers,
                                       shuffle=True)

    dloader_test = dataset.DataLoader(dataset=dataset_test,
                                      batch_size=args.batch_size,
                                      num_workers=args.num_workers,
                                      shuffle=False)

    # Load model
    model = models.load_net(args.network, dataset_train.n_classes)

    # Use rotnet pretraining
    if args.rotnet_dir:
        # Load rotNet model, manually delete layers > 2
        state_dict_rotnet = torch.load(
            os.path.join(
                args.rotnet_dir, 'rotNet_%s_%s_lr_best.pth' %
                (args.dataset, args.network.lower())))
        for key in state_dict_rotnet.copy().keys():
            if 'fc' in key or 'layer3' in key or 'layer4' in key:
                del state_dict_rotnet[key]
        model.load_state_dict(state_dict_rotnet, strict=False)

        # Only finetune lower layers (>2)
        for name, param in model.named_parameters():
            if 'fc' not in name and 'layer3' not in name and 'layer4' not in name:
                param.requires_grad = False

    model = model.to(device)

    # Init optimizer and loss
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=5e-4,
                          nesterov=True)
    exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer,
                                                milestones=[60, 120, 160, 200],
                                                gamma=0.2)
    criterion = nn.CrossEntropyLoss()

    best_acc = 0
    for epoch in range(args.epochs + 1):
        loss_record = train(epoch, model, device, dloader_train, optimizer,
                            exp_lr_scheduler, criterion, args)
        acc_record = test(model, device, dloader_test, args)

        is_best = acc_record.avg > best_acc
        best_acc = max(acc_record.avg, best_acc)
        utils.save_checkpoint(
            model.state_dict(),
            is_best,
            args.save_dir,
            checkpoint=args.name + 'supervised_training_ckpt.pth',
            best_model=args.name + 'supervised_training_best.pth')
Example #6
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='RotNet')
    parser.add_argument(
        '--autoaugment',
        action='store_true',
        default=False,
        help='Use autoaugment policy, only for CIFAR10 (Default: False)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=128,
                        metavar='N',
                        help='Input batch size for training (default: 128)')
    parser.add_argument('--dataset',
                        type=str,
                        default='cifar10',
                        help='Dataset name (default: CIFAR10)')
    parser.add_argument('--epochs',
                        type=int,
                        default=200,
                        metavar='N',
                        help='Number of epochs to train (default: 200)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.1,
                        metavar='LR',
                        help='Learning rate (default: 0.1)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument(
        '--network',
        type=str,
        default='ResNet-18',
        help=
        'Network model (default: ResNet-18), choose between (ResNet-18, TempEns, RevNet-18)'
    )
    parser.add_argument('--no_cuda',
                        action='store_true',
                        default=False,
                        help='Disables CUDA training')
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='Number of data loading workers')
    parser.add_argument('--save_dir',
                        type=str,
                        default='./data/rotNet',
                        help='Directory to save models')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        help='Random seed (default: 1)')

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    torch.manual_seed(args.seed)

    dataset_train = dataset.GenericDataset(dataset_name=args.dataset,
                                           split='train',
                                           autoaugment=args.autoaugment)
    dataset_test = dataset.GenericDataset(dataset_name=args.dataset,
                                          split='test')

    dloader_train = dataset.DataLoader(rotnet=True,
                                       dataset=dataset_train,
                                       batch_size=args.batch_size,
                                       num_workers=args.num_workers,
                                       shuffle=True)

    dloader_test = dataset.DataLoader(rotnet=True,
                                      dataset=dataset_test,
                                      batch_size=args.batch_size,
                                      num_workers=args.num_workers,
                                      shuffle=False)

    model = models.load_net(args.network)
    model = model.to(device)

    # follow the same setting as RotNet paper
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=5e-4,
                          nesterov=True)
    exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer,
                                                milestones=[60, 120, 160, 200],
                                                gamma=0.2)
    criterion = nn.CrossEntropyLoss()

    best_acc = 0
    for epoch in range(args.epochs + 1):
        loss_record = train(epoch, model, device, dloader_train, optimizer,
                            exp_lr_scheduler, criterion, args)
        acc_record = test(model, device, dloader_test, args)

        is_best = acc_record.avg > best_acc
        best_acc = max(acc_record.avg, best_acc)
        utils.save_checkpoint(model.state_dict(),
                              is_best,
                              args.save_dir,
                              checkpoint='rotNet_%s_%s_lr_checkpoint.pth' %
                              (args.dataset, args.network.lower()),
                              best_model='rotNet_%s_%s_lr_best.pth' %
                              (args.dataset, args.network.lower()))

        # Saving milestones only
        if epoch in [59, 119, 159, 199]:
            print('Saving model at milestone: %u' % (epoch))
            utils.save_checkpoint(model.state_dict(),
                                  False,
                                  args.save_dir,
                                  checkpoint='rotNet_%s_%s_%u_checkpoint.pth' %
                                  (args.dataset, args.network.lower(), epoch))
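Examples 5 and 6 call utils.save_checkpoint(state_dict, is_best, save_dir, checkpoint=..., best_model=...), which is also not defined on this page. A minimal sketch under the assumption that it writes the latest weights and, when is_best is true, keeps a separate copy as the best model:

import os
import shutil
import torch

def save_checkpoint(state_dict, is_best, save_dir,
                    checkpoint='checkpoint.pth', best_model='best.pth'):
    os.makedirs(save_dir, exist_ok=True)
    ckpt_path = os.path.join(save_dir, checkpoint)
    torch.save(state_dict, ckpt_path)
    if is_best:
        # Preserve the best-performing weights under a stable filename.
        shutil.copyfile(ckpt_path, os.path.join(save_dir, best_model))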
Example #7
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='Alternative Training for Semi-supervised learning')
    parser.add_argument(
        '--autoaugment',
        action='store_true',
        default=False,
        help='Use AutoAugment data augmentation (default: False)')
    parser.add_argument('--dataset',
                        type=str,
                        default='cifar10',
                        help='Dataset (default: cifar10)')
    parser.add_argument('--epochs_refine',
                        type=int,
                        default=100,
                        help='Refinement epochs on labelled set')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=100,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        help='Learning rate (default 0.01)')
    parser.add_argument('--milestones_outer',
                        nargs='+',
                        type=int,
                        default=[60, 100],
                        help='Outer loop milestones')
    parser.add_argument(
        '--milestones_inner',
        nargs='+',
        type=int,
        default=[7, 10],
        help='Inner loop milestones (change of lr and number of epochs)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        help='SGD momentum (default: 0.9)')
    parser.add_argument(
        '--nb_labels_per_class',
        type=int,
        default=10,
        help='Number of labelled samples per class (default: 10)')
    parser.add_argument('--network',
                        type=str,
                        default='ResNet-18',
                        help='Network (default: ResNet-18)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='Disables CUDA training (default: False)')
    parser.add_argument('--proportion_CE',
                        type=float,
                        default=0.5,
                        help='Weight of cross entropy loss')
    parser.add_argument('--rotnet_dir',
                        type=str,
                        default='',
                        help='RotNet saved directory')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        help='Random seed (default: 1)')
    parser.add_argument('--save_dir',
                        type=str,
                        default='./data/alternative_supervised/',
                        help='Directory to save models')
    args = parser.parse_args()

    global logger_module
    logger_module = args
    logger_module.time_start = datetime.datetime.now().strftime(
        '%Y-%m-%d %H:%M:%S')

    # Path to file
    os.makedirs(args.save_dir, exist_ok=True)
    args.name = 'alternative_%s_%s_seed%u' % (logger_module.dataset.lower(),
                                              logger_module.network.lower(),
                                              args.seed)
    logger_module.net_path = os.path.join(args.save_dir, args.name + '.pth')
    logger_module.pkl_path = os.path.join(args.save_dir, args.name + '.pkl')

    logger_module.train_loss = []
    logger_module.train_acc = []
    logger_module.test_loss = []
    logger_module.test_acc = []
    logger_module.test_acc5 = []
    logger_module.percentage_correct_training = []
    logger_module.number_training = []

    train_data = 'train_data' if args.dataset != 'svhn' else 'data'
    train_labels = 'train_labels' if args.dataset != 'svhn' else 'labels'

    with open(logger_module.pkl_path, "wb") as output_file:
        pickle.dump(vars(logger_module), output_file)

    # Set up seed and GPU usage
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Initialize the dataset
    train_set = dataset.GenericDataset(args.dataset, 'train')
    test_set = dataset.GenericDataset(args.dataset, 'test')

    # Build meta set containing only the restricted labeled samples
    meta_set = dataset.GenericDataset(args.dataset, 'train')
    index_meta = []
    for target in range(train_set.n_classes):
        index_meta.extend(
            np.random.choice(
                np.argwhere(
                    np.array(getattr(train_set.data, train_labels)) == target)
                [:, 0], args.nb_labels_per_class, False))

    setattr(
        meta_set.data, train_labels,
        list(itemgetter(*index_meta)(getattr(train_set.data, train_labels))))
    setattr(meta_set.data, train_data,
            list(itemgetter(*index_meta)(getattr(train_set.data, train_data))))

    # Copy train set for future reassignment
    trainset_targets_save = np.copy(getattr(train_set.data, train_labels))
    trainset_data_save = np.copy(getattr(train_set.data, train_data))

    # Dataloader iterators # TODO Autoaugment
    trainloader = dataset.DataLoader(train_set,
                                     batch_size=128,
                                     shuffle=True,
                                     num_workers=2)
    metaloader = dataset.DataLoader(meta_set,
                                    batch_size=128,
                                    shuffle=True,
                                    num_workers=2)
    testloader = dataset.DataLoader(test_set,
                                    batch_size=1000,
                                    shuffle=False,
                                    num_workers=1)

    # First network initialization
    model = models.load_net(logger_module.network, train_set.n_classes)

    # Load model
    if args.rotnet_dir:
        state_dict_rotnet = torch.load(
            os.path.join(
                args.rotnet_dir,
                'rotNet_%s_%s_lr_best.pth' % (logger_module.dataset.lower(),
                                              logger_module.network.lower())))
        del state_dict_rotnet['fc.weight']
        del state_dict_rotnet['fc.bias']
        model.load_state_dict(state_dict_rotnet, strict=False)
    model = model.to(device)

    global thought_targets
    global meta_labels_total
    for outer_loop in range(0, args.milestones_outer[1]):
        print('Entering outer loop %u' % (outer_loop))

        # Step 1: Fine-tune network and assign Labels
        fine_tune_and_assign_labels(args, model, metaloader, trainloader, testloader, device, train_set, trainset_data_save, trainset_targets_save,\
             index_meta, outer_loop)

        # Self distillation starts from a uniform distribution
        meta_labels_total = torch.ones(len(trainloader.dataset),
                                       trainloader.dataset.n_classes) / float(
                                           trainloader.dataset.n_classes)

        # Step 1.5: Reinitialize net
        model = models.load_net(logger_module.network, train_set.n_classes)
        # Load model
        if args.rotnet_dir:
            state_dict_rotnet = torch.load(
                os.path.join(
                    args.rotnet_dir, 'rotNet_%s_%s_lr_best.pth' %
                    (logger_module.dataset.lower(),
                     logger_module.network.lower())))
            del state_dict_rotnet['fc.weight']
            del state_dict_rotnet['fc.bias']
            model.load_state_dict(state_dict_rotnet, strict=False)
        model = model.to(device)

        # Freeze net first two blocks
        for name, param in model.named_parameters():
            if 'fc' not in name and 'layer3' not in name and 'layer4' not in name:
                param.requires_grad = False

        # Optimizer and LR scheduler
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=5e-4,
                              nesterov=True)
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[args.milestones_inner[0]], gamma=0.1)

        # Step 2: Training using predicted labels
        print('Labels assignment done. Entering inner loop')
        for epoch in range(args.milestones_inner[1]):
            scheduler.step()
            train(args, model, device, trainloader, optimizer, epoch, 'train',
                  outer_loop)
            test(args, model, device, testloader)
            logger_module.epoch = epoch

        with open(logger_module.pkl_path, "wb") as output_file:
            pickle.dump(vars(logger_module), output_file)

        torch.save(model.state_dict(), logger_module.net_path)
    test(args, model, device, testloader, True)
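The meta-set construction in Example 7 (np.argwhere combined with itemgetter) draws a fixed number of labelled samples from every class. The same idea as a small standalone sketch, with hypothetical names:

import numpy as np

def sample_per_class_indices(labels, n_per_class, n_classes, seed=1):
    # Pick n_per_class indices for each class, without replacement.
    rng = np.random.RandomState(seed)
    labels = np.asarray(labels)
    index_meta = []
    for target in range(n_classes):
        class_indices = np.argwhere(labels == target)[:, 0]
        index_meta.extend(rng.choice(class_indices, n_per_class, replace=False))
    return index_meta

# e.g. 10 labelled images per class on a 10-class dataset:
# index_meta = sample_per_class_indices(train_labels, n_per_class=10, n_classes=10)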
Example #8
    default="",
    metavar="FILE",
    help="path to config file",
    type=str,
)
parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

args = parser.parse_args()

if args.config_file:
    cfg.merge_from_file(args.config_file)

cfg.merge_from_list(args.opts)
cfg.freeze()

if len(args.sequences) == 0:
    args.sequences = None

net = models.load_net(cfg.MODEL.NET, cfg)
tracker = trackers.load_tracker(net, args.checkpoint, cfg)

experiment = experiments.ExperimentOTB(cfg, version=args.version, sequences=args.sequences)

experiment.run(tracker, visualize=args.visualize)
experiment.report([tracker.name], args=args)
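Example 8 follows the yacs-style configuration pattern: a global cfg node is overridden first from a YAML file, then from the leftover command-line tokens in args.opts, and finally frozen to make it read-only. A minimal sketch of such a cfg (assuming the yacs package; the repository's real defaults are not shown here):

from yacs.config import CfgNode as CN

cfg = CN()
cfg.MODEL = CN()
cfg.MODEL.NET = 'SiamFC'  # hypothetical default for the network name

# Typical override order, mirroring the snippet above:
# cfg.merge_from_file('experiment.yaml')           # YAML overrides the defaults
# cfg.merge_from_list(['MODEL.NET', 'OtherNet'])   # command-line overrides the YAML
# cfg.freeze()                                     # no further modification allowed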