Example #1
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    torch.cuda.set_device(args.gpu)
    cudnn.enabled = True
    logging.info("args = %s", args)

    genotype = eval("genotypes.%s" % args.arch)
    if args.dataset in LARGE_DATASETS:
        model = NetworkLarge(args.init_channels, CLASSES, args.layers,
                             args.auxiliary, genotype)
    else:
        model = Network(args.init_channels, CLASSES, args.layers,
                        args.auxiliary, genotype)
    model = model.cuda()
    utils.load(model, args.model_path)

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    _, test_transform = utils.data_transforms(args.dataset, args.cutout,
                                              args.cutout_length)
    if args.dataset == "CIFAR100":
        test_data = dset.CIFAR100(root=args.data,
                                  train=False,
                                  download=True,
                                  transform=test_transform)
    elif args.dataset == "CIFAR10":
        test_data = dset.CIFAR10(root=args.data,
                                 train=False,
                                 download=True,
                                 transform=test_transform)
    elif args.dataset == "sport8":
        dset_cls = dset.ImageFolder
        val_path = '%s/Sport8/test' % args.data
        test_data = dset_cls(root=val_path, transform=test_transform)
    elif args.dataset == "mit67":
        dset_cls = dset.ImageFolder
        val_path = '%s/MIT67/test' % args.data
        test_data = dset_cls(root=val_path, transform=test_transform)
    elif args.dataset == "flowers102":
        dset_cls = dset.ImageFolder
        val_path = '%s/flowers102/test' % args.tmp_data_dir
        test_data = dset_cls(root=val_path, transform=test_transform)
    test_queue = torch.utils.data.DataLoader(test_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=False,
                                             num_workers=2)

    model.drop_path_prob = 0.0
    test_acc, test_obj = infer(test_queue, model, criterion)
    logging.info('Test_acc %f', test_acc)
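Note: every snippet on this page relies on a repository-specific data_transforms helper. As a point of reference, here is a minimal, hypothetical sketch of such a helper for CIFAR-style inputs (the signature matches the call in Example #1; the actual utilities in each repository differ):

# Hypothetical sketch only -- not the exact utility used by the repositories above.
import torch
import torchvision.transforms as transforms


class Cutout:
    """Mask out a random square patch of a (C, H, W) image tensor."""

    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        y = torch.randint(h, (1,)).item()
        x = torch.randint(w, (1,)).item()
        mask = torch.ones(h, w)
        mask[max(0, y - self.length // 2):min(h, y + self.length // 2),
             max(0, x - self.length // 2):min(w, x + self.length // 2)] = 0.0
        return img * mask.unsqueeze(0)


def data_transforms(dataset, cutout, cutout_length):
    # CIFAR-10 channel statistics; other datasets would use their own values.
    mean, std = [0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]
    train_ops = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    if cutout:
        train_ops.append(Cutout(cutout_length))
    valid_ops = [transforms.ToTensor(), transforms.Normalize(mean, std)]
    return transforms.Compose(train_ops), transforms.Compose(valid_ops)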
Example #2
def main():
    # args & device
    args = config.get_args()
    if torch.cuda.is_available():
        print('Train on GPU!')
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # dataset
    assert args.dataset in ['cifar10', 'imagenet']
    train_transform, valid_transform = data_transforms(args)
    if args.dataset == 'cifar10':
        trainset = torchvision.datasets.CIFAR10(root=os.path.join(args.data_dir, 'cifar'), train=True,
                                                download=True, transform=train_transform)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
                                                   shuffle=True, pin_memory=True, num_workers=8)
        valset = torchvision.datasets.CIFAR10(root=os.path.join(args.data_dir, 'cifar'), train=False,
                                              download=True, transform=valid_transform)
        val_loader = torch.utils.data.DataLoader(valset, batch_size=args.batch_size,
                                                 shuffle=False, pin_memory=True, num_workers=8)
    elif args.dataset == 'imagenet':
        train_data_set = datasets.ImageNet(os.path.join(args.data_dir, 'ILSVRC2012', 'train'), train_transform)
        val_data_set = datasets.ImageNet(os.path.join(args.data_dir, 'ILSVRC2012', 'valid'), valid_transform)
        train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size, shuffle=True,
                                                   num_workers=8, pin_memory=True, sampler=None)
        val_loader = torch.utils.data.DataLoader(val_data_set, batch_size=args.batch_size, shuffle=False,
                                                 num_workers=8, pin_memory=True)

    # SinglePath_OneShot
    choice = [3, 1, 2, 1, 0, 1, 3, 3, 1, 3, 0, 1, 0, 3, 3, 3, 3, 3, 0, 3]  # candidate block index for each of the 20 layers
    # alternative choice: [2, 0, 2, 3, 2, 2, 3, 1, 2, 1, 0, 1, 0, 3, 1, 0, 0, 2, 3, 2]
    model = SinglePath_Network(args.dataset, args.resize, args.classes, args.layers, choice)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(), args.learning_rate,
                                momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 1 - (epoch / args.epochs))

    # flops & params & structure
    flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32),) if args.dataset == 'cifar10'
                            else (torch.randn(1, 3, 224, 224),), verbose=False)
    # print(model)
    print('Random Path of the Supernet: Params: %.2fM, Flops:%.2fM' % ((params / 1e6), (flops / 1e6)))
    model = model.to(device)
    summary(model, (3, 32, 32) if args.dataset == 'cifar10' else (3, 224, 224))

    # train supernet
    start = time.time()
    for epoch in range(args.epochs):
        train(args, epoch, train_loader, device, model, criterion, optimizer, scheduler, supernet=False)
        scheduler.step()
        if (epoch + 1) % args.val_interval == 0:
            validate(args, epoch, val_loader, device, model, criterion, supernet=False)
            utils.save_checkpoint({'state_dict': model.state_dict(), }, epoch + 1, tag=args.exp_name)
    utils.time_record(start)
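The LambdaLR scheduler above multiplies the base learning rate by 1 - epoch/args.epochs, i.e. a linear decay to zero over training. A tiny self-contained check of that schedule (hypothetical values, unrelated to the repository above):

import torch

base_lr, epochs = 0.1, 5
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=base_lr)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda epoch: 1 - epoch / epochs)
for epoch in range(epochs):
    print(epoch, sched.get_last_lr()[0])  # ~0.1, 0.08, 0.06, 0.04, 0.02
    opt.step()
    sched.step()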
Example #3
def main():
    args = build_args()
    image_datasets, dataloaders, dataset_sizes, class_names, device = ut.data_transforms(
        args.dir, args.device, args.batch_size)
    model_ft, criterion, optimizer_ft, exp_lr_scheduler = tr.finetune_convnet(
        device)
    model_trained_ = tr.train_model(model_ft,
                                    criterion,
                                    optimizer_ft,
                                    exp_lr_scheduler,
                                    dataloaders,
                                    device,
                                    dataset_sizes,
                                    num_epochs=args.num_epochs)

    model_save(args.savepath, model_trained_)
Example #4
def main():
    train_data, validation_data, test_data = utils.data_transforms(pa.data_dir)
    trainloader, validationloader, testloader = utils.data_loaders(pa.data_dir)

    model, criterion, optimizer = utils.network_setup(
        pa.architecture, pa.dropout, pa.input_units, pa.hidden_units,
        pa.learning_rate, pa.device)

    utils.network_training(model, trainloader, validationloader, criterion,
                           optimizer, pa.epochs, pa.print_every, pa.device)

    utils.save_checkpoint(model, train_data, optimizer, pa.architecture,
                          pa.dropout, pa.input_units, pa.hidden_units,
                          pa.learning_rate, pa.epochs, pa.save_dir)

    print("Finished training!")
def main():
    train_data, validation_data, test_data = utils.data_transforms(pa.data_dir)
    trainloader, validationloader, testloader = utils.data_loaders(pa.data_dir)

    model = utils.load_checkpoint(pa.save_dir)

    with open(pa.category_names) as json_file:
        cat_to_name = json.load(json_file)

    probs, classes = utils.predict(pa.image_path, model, pa.topk, pa.device)

    probs = probs.type(torch.FloatTensor).to('cpu').numpy()
    classes = classes.type(torch.FloatTensor).to('cpu').numpy()
    classes = classes.astype(int)
    classes = classes.astype(str)

    class_names = [cat_to_name[i] for i in classes[0]]

    print(probs)
    print(classes)
    print(class_names)

    print("Finsihed predicting!")
def generate_caption_visualization(encoder,
                                   decoder,
                                   img_path,
                                   word_dict,
                                   beam_size=3):
    '''
    Visualize the step-by-step development of the caption along with the corresponding attention map for each predicted word.
    
    Arguments:
        encoder: Instance of the trained Encoder for encoding of images
        decoder: Instance of the trained Decoder for caption prediction from encoded image
        img_path (str): Complete path of the image to be visualized
        word_dict (dict): Dictionary of words (vocabulary)
        beam_size (int): Number of top candidates to consider for beam search. Default = 3
    '''

    # Load the image and transform it
    img = pil_loader(img_path)
    img = data_transforms(img)
    img = torch.FloatTensor(img)
    img = img.unsqueeze(0)

    # Get the caption and the corresponding attention weights from the trained network
    img_features = encoder(img)
    img_features = img_features.expand(beam_size, img_features.size(1),
                                       img_features.size(2))
    sentence, alpha = decoder.caption(img_features, beam_size)

    # Using the dictionary, convert the encoded caption to normal words
    token_dict = {idx: word for word, idx in word_dict.items()}
    sentence_tokens = []
    for word_idx in sentence:
        sentence_tokens.append(token_dict[word_idx])
        if word_idx == word_dict['<eos>']:
            break

    # Resizing image for a standard display
    img = Image.open(img_path)
    w, h = img.size
    if w > h:
        w = w * 256 / h
        h = 256
    else:
        h = h * 256 / w
        w = 256
    left = (w - 224) / 2
    top = (h - 224) / 2
    resized_img = img.resize((int(w), int(h)), Image.BICUBIC).crop(
        (left, top, left + 224, top + 224))
    img = np.array(resized_img.convert('RGB').getdata()).reshape(224, 224, 3)
    img = img.astype('float32') / 255

    num_words = len(sentence_tokens)
    w = np.round(np.sqrt(num_words))
    h = np.ceil(np.float32(num_words) / w)
    alpha = torch.tensor(alpha)

    # Plot the different attention weighted versions of the original image along with the resultant caption word prediction
    f = plt.figure(figsize=(8, 9))
    plot_height = ceil((num_words + 3) / 4.0)
    ax1 = f.add_subplot(4, plot_height, 1)
    plt.imshow(img)
    plt.axis('off')
    for idx in range(num_words):
        ax2 = f.add_subplot(4, plot_height, idx + 2)
        label = sentence_tokens[idx]
        plt.text(0, 1, label, backgroundcolor='white', fontsize=13)
        plt.text(0, 1, label, color='black', fontsize=13)
        plt.imshow(img)

        if encoder.network == 'vgg19':
            shape_size = 14
        else:
            shape_size = 7

        alpha_img = skimage.transform.pyramid_expand(alpha[idx, :].reshape(
            shape_size, shape_size),
                                                     upscale=16,
                                                     sigma=20)

        plt.imshow(alpha_img, alpha=0.8)
        plt.set_cmap(cm.Greys_r)
        plt.axis('off')
    plt.show()
def generate_image_caption(encoder,
                           decoder,
                           img_path,
                           word_dict,
                           beam_size=3,
                           ax=plt):
    '''
    Display the image along with the predicted caption.
    
    Arguments:
        encoder: Instance of the trained Encoder for encoding of images
        decoder: Instance of the trained Decoder for caption prediction from encoded image
        img_path (str): Complete path of the image to be visualized
        word_dict (dict): Dictionary of words (vocabulary)
        beam_size (int): Number of top candidates to consider for beam search. Default = 3
        ax: axes for plotting
    '''

    # Load the image and transform it
    img = pil_loader(img_path)
    img = data_transforms(img)
    img = torch.FloatTensor(img)
    img = img.unsqueeze(0)

    # Get the caption from the trained network
    img_features = encoder(img)
    img_features = img_features.expand(beam_size, img_features.size(1),
                                       img_features.size(2))
    sentence, alpha = decoder.caption(img_features, beam_size)

    # Using the dictionary, convert the encoded caption to normal words
    token_dict = {idx: word for word, idx in word_dict.items()}
    sentence_tokens = []
    for word_idx in sentence:
        if word_idx == word_dict['<start>']:
            continue
        if word_idx == word_dict['<eos>']:
            break
        sentence_tokens.append(token_dict[word_idx])

    # Resizing image for a standard display
    img = Image.open(img_path)
    w, h = img.size
    if w > h:
        w = w * 256 / h
        h = 256
    else:
        h = h * 256 / w
        w = 256
    left = (w - 224) / 2
    top = (h - 224) / 2
    resized_img = img.resize((int(w), int(h)), Image.BICUBIC).crop(
        (left, top, left + 224, top + 224))
    img = np.array(resized_img.convert('RGB').getdata()).reshape(224, 224, 3)
    img = img.astype('float32') / 255

    # Create the sentence from the list of words and end it with a period
    caption = ' '.join(sentence_tokens) + '.'

    ax.imshow(img)
    ax.set_title(caption.capitalize())
    ax.axis('off')
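The manual resize-and-center-crop block used in both captioning helpers above (shorter side to 256, central 224x224 patch) is roughly what torchvision's standard preprocessing does. A hypothetical equivalent for display purposes (the image path is illustrative):

import numpy as np
from PIL import Image
import torchvision.transforms as transforms

display_preprocess = transforms.Compose([
    transforms.Resize(256),      # shorter side -> 256, aspect ratio preserved
    transforms.CenterCrop(224),  # central 224x224 patch
])

img = display_preprocess(Image.open('some_image.jpg').convert('RGB'))  # illustrative path
img = np.asarray(img, dtype='float32') / 255.0  # HWC array in [0, 1], as plt.imshow expects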
def main():
    # Initialize the model for this run
    model_ft, input_size = initialize_model(model_name,
                                            num_classes,
                                            feature_extract,
                                            use_pretrained=True)
    model_ft.to(device)

    # Temporary header
    # directory - normal, bacteria, TB, COVID-19, virus
    dir_test = '/home/ubuntu/segmentation/output/COVID-19/'
    label = 3  # set 3 for COVID-19 for virus class

    # Data loader
    test_masked_images = sorted(glob.glob(dir_test + '*.npz'))
    #test_masks = sorted(glob.glob(dir_test + '*.mask.npy'))

    for masked_img in test_masked_images:

        test_masked_img = np.load(masked_img)
        #test_mask = np.load(mask)

        test_masked_img = Image.fromarray(test_masked_img).resize((1024, 1024))
        #test_mask = Image.fromarray(test_mask).resize((1024,1024))

        #test_img = np.asarray(test_img)
        #test_mask = np.round(np.asarray(test_mask))

        #test_masked = np.multiply(test_img, test_mask)

        test_normalized = np.asarray(test_masked_img)  # back to a numpy array for the indexing below

        h_whole = test_normalized.shape[0]  # original w
        w_whole = test_normalized.shape[1]  # original h

        background = np.zeros((h_whole, w_whole))
        background_indicer = np.zeros((h_whole, w_whole))

        sum_prob_wt = 0.0

        for i in range(header.repeat):

            non_zero_list = np.nonzero(test_normalized)

            random_index = random.randint(0, len(non_zero_list[0]) - 1)

            non_zero_row = non_zero_list[0][
                random_index]  # random non-zero row index
            non_zero_col = non_zero_list[1][
                random_index]  # random non-zero col index

            X_patch = test_normalized[
                int(max(0, non_zero_row - (header.img_size / 2))
                    ):int(min(h_whole, non_zero_row + (header.img_size / 2))),
                int(max(0, non_zero_col - (header.img_size / 2))
                    ):int(min(w_whole, non_zero_col + (header.img_size / 2)))]

            X_patch_img = data_transforms(
                augmentation(Image.fromarray(X_patch), rand_p=0.0,
                             mode='test'))
            X_patch_img_ = np.squeeze(np.asarray(X_patch_img))

            X_patch_1 = np.expand_dims(X_patch_img_, axis=0)
            X_patch_2 = np.expand_dims(X_patch_img_, axis=0)
            X_patch_3 = np.expand_dims(X_patch_img_, axis=0)

            X_ = np.concatenate((X_patch_1, X_patch_2, X_patch_3), axis=0)
            X_ = np.expand_dims(X_, axis=0)

            X = torch.from_numpy(X_)
            X = X.to(device)

            checkpoint = torch.load(
                os.path.join(header.save_dir,
                             str(header.inference_epoch) + '.pth'))
            model_ft.load_state_dict(checkpoint['model_state_dict'])
            model_ft.eval()
            outputs = model_ft(X)
            outputs_prob = F.softmax(outputs, dim=1)

            prob = outputs_prob[0][label]
            prob_wt = prob.detach().cpu().numpy()

            gradcam = GradCAM.from_config(model_type='resnet',
                                          arch=model_ft,
                                          layer_name='layer4')

            mask, logit = gradcam(X, class_idx=label)
            mask_np = np.squeeze(mask.detach().cpu().numpy())
            indicer = np.ones((224, 224))

            mask_np = np.asarray(
                cv2.resize(
                    mask_np,
                    dsize=(
                        int(min(w_whole, non_zero_col +
                                (header.img_size / 2))) -
                        int(max(0, non_zero_col - (header.img_size / 2))),
                        int(min(h_whole, non_zero_row +
                                (header.img_size / 2))) -
                        int(max(0, non_zero_row - (header.img_size / 2))))))

            indicer = np.asarray(
                cv2.resize(
                    indicer,
                    dsize=(
                        int(min(w_whole, non_zero_col +
                                (header.img_size / 2))) -
                        int(max(0, non_zero_col - (header.img_size / 2))),
                        int(min(h_whole, non_zero_row +
                                (header.img_size / 2))) -
                        int(max(0, non_zero_row - (header.img_size / 2))))))

            mask_add = np.zeros((1024, 1024))
            mask_add[
                int(max(0, non_zero_row - (header.img_size / 2))
                    ):int(min(h_whole, non_zero_row + (header.img_size / 2))),
                int(max(0, non_zero_col - (header.img_size / 2))
                    ):int(min(w_whole, non_zero_col +
                              (header.img_size / 2)))] = mask_np
            mask_add = mask_add * prob_wt

            indicer_add = np.zeros((1024, 1024))
            indicer_add[
                int(max(0, non_zero_row - (header.img_size / 2))
                    ):int(min(h_whole, non_zero_row + (header.img_size / 2))),
                int(max(0, non_zero_col - (header.img_size / 2))
                    ):int(min(w_whole, non_zero_col +
                              (header.img_size / 2)))] = indicer

            background = background + mask_add
            # counts how many times each pixel was covered by a sampled patch
            background_indicer = background_indicer + indicer_add

            sum_prob_wt = sum_prob_wt + prob_wt

        final_mask = np.divide(background, background_indicer + 1e-7)

        final_mask = np.expand_dims(np.expand_dims(final_mask, axis=0), axis=0)
        torch_final_mask = torch.from_numpy(final_mask)

        test_img = np.asarray(test_masked_img, dtype='float32')  # already resized to 1024x1024 above
        test_img = (test_img - test_img.min()) / test_img.max()
        test_img = np.expand_dims(test_img, axis=0)
        test_img = np.concatenate((test_img, test_img, test_img), axis=0)
        torch_final_img = torch.from_numpy(np.expand_dims(test_img, axis=0))

        final_cam, cam_result = visualize_cam(torch_final_mask,
                                              torch_final_img)

        final_cam = (final_cam - final_cam.min()) / final_cam.max()
        final_cam_np = np.swapaxes(np.swapaxes(np.asarray(final_cam), 0, 2), 0,
                                   1)
        test_img_np = np.swapaxes(np.swapaxes(test_img, 0, 2), 0, 1)

        final_combined = test_img_np + final_cam_np
        final_combined = (final_combined -
                          final_combined.min()) / final_combined.max()

        plt.imshow(final_combined)
        plt.savefig(masked_img.split('.npz')[0] + '.patch.heatmap.png')
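The loop above aggregates per-patch Grad-CAM maps: each patch CAM is weighted by its predicted class probability, accumulated on a full-size canvas, and finally divided by how many patches covered each pixel. A toy, self-contained illustration of that aggregation (made-up shapes and values):

import numpy as np

background = np.zeros((8, 8))  # weighted CAM accumulator
coverage = np.zeros((8, 8))    # how many patches touched each pixel
patches = [(np.ones((4, 4)), 0.9, (0, 4, 0, 4)),   # (cam, class prob, patch box)
           (np.ones((4, 4)), 0.5, (2, 6, 2, 6))]
for cam, prob, (r0, r1, c0, c1) in patches:
    background[r0:r1, c0:c1] += prob * cam
    coverage[r0:r1, c0:c1] += 1.0
final_mask = background / (coverage + 1e-7)  # per-pixel probability-weighted mean CAM
print(final_mask[:4, :4])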
Example #9
def main():
    # Check Checkpoints Direction
    if not os.path.exists(args.ckpt_dir):
        os.mkdir(args.ckpt_dir)

    # Define Data
    assert args.dataset in ['cifar10', 'imagenet']
    train_transform, valid_transform = utils.data_transforms(args)
    if args.dataset == 'cifar10':
        trainset = torchvision.datasets.CIFAR10(root=os.path.join(
            args.data_root, args.dataset),
                                                train=True,
                                                download=True,
                                                transform=train_transform)
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   pin_memory=True,
                                                   num_workers=8)
        valset = torchvision.datasets.CIFAR10(root=os.path.join(
            args.data_root, args.dataset),
                                              train=False,
                                              download=True,
                                              transform=valid_transform)
        val_loader = torch.utils.data.DataLoader(valset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=8)
    elif args.dataset == 'imagenet':
        train_data_set = datasets.ImageNet(
            os.path.join(args.data_root, args.dataset, 'train'),
            train_transform)
        val_data_set = datasets.ImageNet(
            os.path.join(args.data_root, args.dataset, 'valid'),
            valid_transform)
        train_loader = torch.utils.data.DataLoader(train_data_set,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=8,
                                                   pin_memory=True,
                                                   sampler=None)
        val_loader = torch.utils.data.DataLoader(val_data_set,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=8,
                                                 pin_memory=True)
    else:
        raise ValueError('Undefined dataset !!!')

    # Define Supernet
    model = SinglePath_OneShot(args.dataset, args.resize, args.classes,
                               args.layers)
    logging.info(model)
    model = model.to(args.device)
    criterion = nn.CrossEntropyLoss().to(args.device)
    optimizer = torch.optim.SGD(model.parameters(), args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.epochs)
    print('\n')

    # Running
    start = time.time()
    best_val_acc = 0.0
    for epoch in range(args.epochs):
        # Supernet Training
        train_loss, train_acc = train(args, epoch, train_loader, model,
                                      criterion, optimizer)
        scheduler.step()
        logging.info(
            '[Supernet Training] epoch: %03d, train_loss: %.3f, train_acc: %.3f'
            % (epoch + 1, train_loss, train_acc))
        # Supernet Validation
        val_loss, val_acc = validate(args, val_loader, model, criterion)
        # Save Best Supernet Weights
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_ckpt = os.path.join(args.ckpt_dir,
                                     '%s_%s' % (args.exp_name, 'best.pth'))
            torch.save(model.state_dict(), best_ckpt)
            logging.info('Saved best checkpoint to %s' % best_ckpt)
        logging.info(
            '[Supernet Validation] epoch: %03d, val_loss: %.3f, val_acc: %.3f, best_acc: %.3f'
            % (epoch + 1, val_loss, val_acc, best_val_acc))
        print('\n')

    # Record Time
    utils.time_record(start)
Example #10
def main():
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # one-shot
    model = SinglePath_OneShot(args.dataset, args.resize, args.classes,
                               args.layers).to(device)
    ckpt_path = os.path.join(
        'snapshots',
        args.exp_name + '_ckpt_' + "{:0>4d}".format(args.epochs) + '.pth.tar')
    print('Load checkpoint from:', ckpt_path)
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'], strict=True)
    criterion = nn.CrossEntropyLoss().to(device)

    # dataset
    _, valid_transform = utils.data_transforms(args)
    valset = torchvision.datasets.CIFAR10(root=os.path.join(
        args.data_dir, 'cifar'),
                                          train=False,
                                          download=False,
                                          transform=valid_transform)
    val_loader = torch.utils.data.DataLoader(valset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=8)

    # random search
    start = time.time()
    best_acc = 0.0
    acc_list = list()
Example #11
def main():
  if not torch.cuda.is_available():
    logging.info('no gpu device available')
    sys.exit(1)

  np.random.seed(args.seed)
  torch.cuda.set_device(args.gpu)
  cudnn.benchmark = True
  torch.manual_seed(args.seed)
  cudnn.enabled=True
  torch.cuda.manual_seed(args.seed)
  logging.info('gpu device = %d' % args.gpu)
  logging.info("args = %s", args)

  genotype = eval("genotypes.%s" % args.arch)
  if args.dataset in utils.LARGE_DATASETS:
    model = NetworkLarge(args.init_channels, CLASSES, args.layers, args.auxiliary, genotype)
  else:
    model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary, genotype)
  model = model.cuda()

  logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

  criterion = nn.CrossEntropyLoss()
  criterion = criterion.cuda()
  optimizer = torch.optim.SGD(
      model.parameters(),
      args.learning_rate,
      momentum=args.momentum,
      weight_decay=args.weight_decay
      )

  train_transform, valid_transform = utils.data_transforms(args.dataset, args.cutout, args.cutout_length)
  if args.dataset == "CIFAR100":
    train_data = dset.CIFAR100(root=args.datapath, train=True, download=True, transform=train_transform)
    valid_data = dset.CIFAR100(root=args.datapath, train=False, download=True, transform=valid_transform)
  elif args.dataset == "CIFAR10":
    train_data = dset.CIFAR10(root=args.datapath, train=True, download=True, transform=train_transform)
    valid_data = dset.CIFAR10(root=args.datapath, train=False, download=True, transform=valid_transform)
  elif args.dataset == 'MIT67':
    dset_cls = dset.ImageFolder
    data_path = '%s/MIT67/train' % args.datapath  # 'data/MIT67/train'
    val_path = '%s/MIT67/test' % args.datapath  # 'data/MIT67/val'
    train_data = dset_cls(root=data_path, transform=train_transform)
    valid_data = dset_cls(root=val_path, transform=valid_transform)
  elif args.dataset == 'Sport8':
    dset_cls = dset.ImageFolder
    data_path = '%s/Sport8/train' % args.datapath  # 'data/Sport8/train'
    val_path = '%s/Sport8/test' % args.datapath  # 'data/Sport8/val'
    train_data = dset_cls(root=data_path, transform=train_transform)
    valid_data = dset_cls(root=val_path, transform=valid_transform)
  elif args.dataset == "flowers102":
    dset_cls = dset.ImageFolder
    data_path = '%s/flowers102/train' % args.datapath
    val_path = '%s/flowers102/test' % args.datapath
    train_data = dset_cls(root=data_path, transform=train_transform)
    valid_data = dset_cls(root=val_path, transform=valid_transform)

  train_queue = torch.utils.data.DataLoader(
      train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=2)

  valid_queue = torch.utils.data.DataLoader(
      valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=2)

  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))
  best_acc = 0.0
  for epoch in range(args.epochs):
    scheduler.step()
    logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
    model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

    train_acc, train_obj = train(train_queue, model, criterion, optimizer)
    logging.info('train_acc %f', train_acc)

    valid_acc, valid_obj = infer(valid_queue, model, criterion)
    if valid_acc > best_acc:
        best_acc = valid_acc
    logging.info('valid_acc %f, best_acc %f', valid_acc, best_acc)

    utils.save(model, os.path.join(args.save, 'weights.pt'))
Example #12
def main():
    if not torch.cuda.is_available():
        logging.info('No GPU device available')
        sys.exit(1)
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled=True
    torch.cuda.manual_seed(args.seed)
    logging.info("args = %s", args)
    logging.info("unparsed args = %s", unparsed)
    num_gpus = torch.cuda.device_count()
    
    genotype = eval("genotypes.%s" % args.arch)
    print('---------Genotype---------')
    logging.info(genotype)
    print('--------------------------')
    if args.dataset in utils.LARGE_DATASETS:
        model = NetworkLarge(args.init_channels, CLASSES, args.layers, args.auxiliary, genotype)
    else:
        model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary, genotype)
    if num_gpus > 1:
        model = torch.nn.DataParallel(model)
    model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay
        )
    train_transform, valid_transform = utils.data_transforms(args.dataset,args.cutout,args.cutout_length)
    if args.dataset == "CIFAR100":
        train_data = dset.CIFAR100(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)
        valid_data = dset.CIFAR100(root=args.tmp_data_dir, train=False, download=True, transform=valid_transform)
    elif args.dataset == "CIFAR10":
        train_data = dset.CIFAR10(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)
        valid_data = dset.CIFAR10(root=args.tmp_data_dir, train=False, download=True, transform=valid_transform)
    elif args.dataset == 'mit67':
        dset_cls = dset.ImageFolder
        data_path = '%s/MIT67/train' % args.tmp_data_dir  
        val_path = '%s/MIT67/test' % args.tmp_data_dir 
        train_data = dset_cls(root=data_path, transform=train_transform)
        valid_data = dset_cls(root=val_path, transform=valid_transform)
    elif args.dataset == 'sport8':
        dset_cls = dset.ImageFolder
        data_path = '%s/Sport8/train' % args.tmp_data_dir 
        val_path = '%s/Sport8/test' % args.tmp_data_dir  
        train_data = dset_cls(root=data_path, transform=train_transform)
        valid_data = dset_cls(root=val_path, transform=valid_transform)
    elif args.dataset == "flowers102":
        dset_cls = dset.ImageFolder
        data_path = '%s/flowers102/train' % args.tmp_data_dir
        val_path = '%s/flowers102/test' % args.tmp_data_dir
        train_data = dset_cls(root=data_path, transform=train_transform)
        valid_data = dset_cls(root=val_path, transform=valid_transform)
    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=args.workers)

    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=args.workers)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))
    best_acc = 0.0
    for epoch in range(args.epochs):
        scheduler.step()
        logging.info('Epoch: %d lr %e', epoch, scheduler.get_lr()[0])
        if num_gpus > 1:
            model.module.drop_path_prob = args.drop_path_prob * epoch / args.epochs
        else:
            model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
        start_time = time.time()
        train_acc, train_obj = train(train_queue, model, criterion, optimizer)
        logging.info('Train_acc: %f', train_acc)

        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        if valid_acc > best_acc:
            best_acc = valid_acc
        logging.info('Valid_acc: %f', valid_acc)
        end_time = time.time()
        duration = end_time - start_time
        print('Epoch time: %ds.' % duration )
        utils.save(model, os.path.join(args.save, 'weights.pt'))
Example #13
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    model = Network(
        args.init_channels,
        CLASSES,
        args.layers,
        criterion,
        largemode=True if args.dataset in utils.LARGE_DATASETS else False)
    model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    train_transform, valid_transform = utils.data_transforms(
        args.dataset, args.cutout, args.cutout_length)
    if args.dataset == "CIFAR100":
        train_data = dset.CIFAR100(root=args.datapath,
                                   train=True,
                                   download=True,
                                   transform=train_transform)
    elif args.dataset == "CIFAR10":
        train_data = dset.CIFAR10(root=args.datapath,
                                  train=True,
                                  download=True,
                                  transform=train_transform)
    elif args.dataset == 'MIT67':
        dset_cls = dset.ImageFolder
        data_path = '%s/MIT67/train' % args.datapath  # 'data/MIT67/train'
        val_path = '%s/MIT67/test' % args.datapath  # 'data/MIT67/val'
        train_data = dset_cls(root=data_path, transform=train_transform)
        valid_data = dset_cls(root=val_path, transform=valid_transform)
    elif args.dataset == 'Sport8':
        dset_cls = dset.ImageFolder
        data_path = '%s/Sport8/train' % args.datapath  # 'data/Sport8/train'
        val_path = '%s/Sport8/test' % args.datapath  # 'data/Sport8/val'
        train_data = dset_cls(root=data_path, transform=train_transform)
        valid_data = dset_cls(root=val_path, transform=valid_transform)
    elif args.dataset == "flowers102":
        dset_cls = dset.ImageFolder
        data_path = '%s/flowers102/train' % args.datapath
        val_path = '%s/flowers102/test' % args.datapath
        train_data = dset_cls(root=data_path, transform=train_transform)
        valid_data = dset_cls(root=val_path, transform=valid_transform)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))
    import random
    random.shuffle(indices)

    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=2)

    valid_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(
            indices[split:num_train]),
        pin_memory=True,
        num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min)

    architect = Architect(model, args)

    for epoch in range(args.epochs):
        scheduler.step()
        lr = scheduler.get_lr()[0]
        logging.info('epoch %d lr %e', epoch, lr)

        genotype = model.genotype()
        logging.info('genotype = %s', genotype)

        #print(F.softmax(model.alphas_normal, dim=-1))
        #print(F.softmax(model.alphas_reduce, dim=-1))

        # training
        train_acc, train_obj = train(train_queue, valid_queue, model,
                                     architect, criterion, optimizer, lr,
                                     epoch)
        logging.info('train_acc %f', train_acc)

        # validation
        if args.epochs - epoch <= 1:
            with open(args.save + "/best_genotype.txt", "w") as f:
                f.write(str(genotype))
            valid_acc, valid_obj = infer(valid_queue, model, criterion)
            logging.info('valid_acc %f', valid_acc)

        utils.save(model, os.path.join(args.save, 'weights.pt'))
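Example #13 splits a single training set into architecture-search train/valid halves with SubsetRandomSampler instead of touching the official test set. A minimal, self-contained illustration of that split (dummy data and hypothetical sizes):

import torch

num_train, train_portion, batch_size = 100, 0.5, 16
indices = list(range(num_train))
split = int(train_portion * num_train)

dummy = torch.utils.data.TensorDataset(torch.randn(num_train, 3, 32, 32),
                                       torch.randint(0, 10, (num_train,)))
train_queue = torch.utils.data.DataLoader(
    dummy, batch_size=batch_size,
    sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]))
valid_queue = torch.utils.data.DataLoader(
    dummy, batch_size=batch_size,
    sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]))
print(len(train_queue), len(valid_queue))  # 4 batches each: ceil(50 / 16)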
Example #14
from utils import data_transforms
from torchvision import datasets
from thop import profile
from torchsummary import summary

if __name__ == '__main__':
    args = get_args()
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # dataset
    assert args.dataset in ['cifar10', 'imagenet']
    train_transform, valid_transform = data_transforms(args)
    if args.dataset == 'cifar10':
        trainset = torchvision.datasets.CIFAR10(root=os.path.join(
            args.data_dir, 'cifar'),
                                                train=True,
                                                download=True,
                                                transform=train_transform)
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   pin_memory=True,
                                                   num_workers=8)
        valset = torchvision.datasets.CIFAR10(root=os.path.join(
            args.data_dir, 'cifar'),
                                              train=False,
                                              download=True,
Example #15
def main():
    if not torch.cuda.is_available():
        logging.info('No GPU device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('GPU device = %d' % args.gpu)
    logging.info("args = %s", args)
    gpu_logger = GpuLogThread([args.gpu], writer, seconds=15 if not args.test else 1)
    gpu_logger.start()
    logging.debug(locals())
    model = None

    # prepare dataset
    train_transform, valid_transform = utils.data_transforms(args.dataset, args.cutout, args.cutout_length)
    if args.dataset == "CIFAR100":
        train_data = dset.CIFAR100(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)
    elif args.dataset == "CIFAR10":
        train_data = dset.CIFAR10(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)
    elif args.dataset == 'MIT67':
        dset_cls = dset.ImageFolder
        data_path = '%s/MIT67/train' % args.tmp_data_dir  # 'data/MIT67/train'
        val_path = '%s/MIT67/test' % args.tmp_data_dir  # 'data/MIT67/val'
        train_data = dset_cls(root=data_path, transform=train_transform)
        valid_data = dset_cls(root=val_path, transform=valid_transform)
    elif args.dataset == 'Sport8':
        dset_cls = dset.ImageFolder
        data_path = '%s/Sport8/train' % args.tmp_data_dir  # 'data/Sport8/train'
        val_path = '%s/Sport8/test' % args.tmp_data_dir  # 'data/Sport8/val'
        train_data = dset_cls(root=data_path, transform=train_transform)
        valid_data = dset_cls(root=val_path, transform=valid_transform)
    elif args.dataset == "flowers102":
        dset_cls = dset.ImageFolder
        data_path = '%s/flowers102/train' % args.tmp_data_dir
        val_path = '%s/flowers102/test' % args.tmp_data_dir
        train_data = dset_cls(root=data_path, transform=train_transform)
        valid_data = dset_cls(root=val_path, transform=valid_transform)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))
    import random
    random.shuffle(indices)

    train_iterator = utils.DynamicBatchSizeLoader(torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_multiples,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True, num_workers=args.workers), args.batch_size_min)

    valid_iterator = utils.DynamicBatchSizeLoader(torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_multiples,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
        pin_memory=True, num_workers=args.workers), args.batch_size_min)

    # build Network
    logging.debug('building network')
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    num_graph_edges = sum(list(range(2, 2 + args.blocks)))
    switches_normal = SwitchManager(num_graph_edges, copy.deepcopy(PRIMITIVES), 'normal')
    switches_reduce = SwitchManager(num_graph_edges, copy.deepcopy(PRIMITIVES), 'reduce')

    total_epochs = 0
    for cycle in parse_cycles():
        logging.debug('new cycle %s' % repr(cycle))
        print('\n' * 3, '-' * 100)
        print(cycle)
        print('', '-' * 100, '\n')

        writer.add_scalar('cycle/net_layers', cycle.net_layers, cycle.num)
        writer.add_scalar('cycle/net_init_c', cycle.net_init_c, cycle.num)
        writer.add_scalar('cycle/net_dropout', cycle.net_dropout, cycle.num)
        writer.add_scalar('cycle/ops_keep', cycle.ops_keep, cycle.num)
        writer.add_scalar('cycle/epochs', cycle.epochs, cycle.num)
        writer.add_scalar('cycle/grace_epochs', cycle.grace_epochs, cycle.num)
        writer.add_scalar('cycle/morphs', cycle.morphs, cycle.num)
        switches_normal.plot_ops(logging.info, writer, cycle.num)
        switches_reduce.plot_ops(logging.info, writer, cycle.num)

        # rebuild the model in each cycle, clean up the cache...
        logging.debug('building model')
        del model
        torch.cuda.empty_cache()
        if args.dataset == "CIFAR100":
            CLASSES = 100
        elif args.dataset == "CIFAR10":
            CLASSES = 10
        elif args.dataset == 'MIT67':
            dset_cls = dset.ImageFolder
            CLASSES = 67
        elif args.dataset == 'Sport8':
            dset_cls = dset.ImageFolder
            CLASSES = 8
        elif args.dataset == "flowers102":
            dset_cls = dset.ImageFolder
            CLASSES = 102
        model = Network(cycle.net_init_c,
                        CLASSES,
                        cycle.net_layers,
                        criterion,
                        switches_normal=switches_normal,
                        switches_reduce=switches_reduce,
                        steps=args.blocks,
                        p=cycle.net_dropout,
                        largemode=True if args.dataset in utils.LARGE_DATASETS else False)
        gpu_logger.reset_recent()
        if cycle.load:
            utils.load(model, model_path)
            if args.reset_alphas:
                model.reset_alphas()
        if args.test:
            model.randomize_alphas()
        if cycle.init_morphed:
            model.init_morphed(switches_normal, switches_reduce)
        model = model.cuda()
        logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

        logging.debug('building optimizers')
        optimizer = torch.optim.SGD(model.net_parameters,
                                    args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        optimizer_a = torch.optim.Adam(model.arch_parameters,
                                       lr=args.arch_learning_rate, betas=(0.5, 0.999),
                                       weight_decay=args.arch_weight_decay)
        logging.debug('building scheduler')
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, int(cycle.epochs),
                                                               eta_min=args.learning_rate_min)

        if args.batch_size_max > args.batch_size_min:
            train_iterator.set_batch_size(args.batch_size_min)
            valid_iterator.set_batch_size(args.batch_size_min)

        sm_dim = -1
        scale_factor = 0.2
        for epoch in range(cycle.epochs):
            lr = scheduler.get_lr()[0]
            logging.info('Epoch: %d lr: %e', epoch, lr)
            epoch_start = time.time()
            # training
            if epoch < cycle.grace_epochs:
                model.update_p(cycle.net_dropout * (cycle.epochs - epoch - 1) / cycle.epochs)
            else:
                model.update_p(cycle.net_dropout * np.exp(-(epoch - cycle.grace_epochs) * scale_factor))
            train_acc, train_obj = train(train_iterator, valid_iterator, model, criterion, optimizer, optimizer_a,
                                         gpu_logger, train_arch=epoch >= cycle.grace_epochs)
            epoch_duration = time.time() - epoch_start

            # log info
            logging.info('Train_acc %f', train_acc)
            logging.info('Epoch time: %ds', epoch_duration)
            writer.add_scalar('train/accuracy', train_acc, total_epochs)
            writer.add_scalar('train/loss', train_obj, total_epochs)
            writer.add_scalar('epoch/lr', lr, total_epochs)
            writer.add_scalar('epoch/seconds', epoch_duration, total_epochs)
            writer.add_scalar('epoch/model.p', model.p, total_epochs)
            writer.add_scalar('epoch/batch_size', train_iterator.batch_size, total_epochs)

            # validation, only for the last 5 epochs in a cycle
            if cycle.epochs - epoch < 5:
                valid_acc, valid_obj = infer(valid_iterator, model, criterion)
                logging.info('Valid_acc %f', valid_acc)
                writer.add_scalar('valid/accuracy', valid_acc, total_epochs)
                writer.add_scalar('valid/loss', valid_obj, total_epochs)

            total_epochs += 1
            gpu_logger.reset_recent()
            scheduler.step()

        utils.save(model, model_path)

        print('\n' * 2, '------Dropping/morphing paths------')
        # Save switches info for s-c refinement.
        if cycle.is_last:
            switches_normal_copy = switches_normal.copy()
            switches_reduce_copy = switches_reduce.copy()

        # drop operations with low architecture weights, add morphed ones
        arch_param = model.arch_parameters
        normal_prob = F.softmax(arch_param[0], dim=sm_dim).data.cpu().numpy()
        switches_normal.drop_and_morph(normal_prob, cycle.ops_keep, writer, cycle.num, num_morphs=cycle.morphs,
                                       no_zero=cycle.is_last and args.restrict_zero, keep_morphable=not cycle.is_last)
        reduce_prob = F.softmax(arch_param[1], dim=sm_dim).data.cpu().numpy()
        switches_reduce.drop_and_morph(reduce_prob, cycle.ops_keep, writer, cycle.num, num_morphs=cycle.morphs,
                                       no_zero=cycle.is_last and args.restrict_zero, keep_morphable=not cycle.is_last)
        logging.info('switches_normal = \n%s', switches_normal)
        logging.info('switches_reduce = \n%s', switches_reduce)

        # end last cycle with shortcut/zero pruning and save the genotype
        if cycle.is_last:
            #import ipdb;ipdb.set_trace()
            arch_param = model.arch_parameters
            normal_prob = F.softmax(arch_param[0], dim=sm_dim).data.cpu().numpy()
            reduce_prob = F.softmax(arch_param[1], dim=sm_dim).data.cpu().numpy()
            normal_final = [0 for _ in range(num_graph_edges)]
            reduce_final = [0 for _ in range(num_graph_edges)]

            # Generate Architecture
            keep_normal = [0, 1]
            keep_reduce = [0, 1]
            n = 3
            start = 2
            for i in range(3):
                end = start + n
                tbsn = normal_final[start:end]
                tbsr = reduce_final[start:end]
                edge_n = sorted(range(n), key=lambda x: tbsn[x])
                keep_normal.append(edge_n[-1] + start)
                keep_normal.append(edge_n[-2] + start)
                edge_r = sorted(range(n), key=lambda x: tbsr[x])
                keep_reduce.append(edge_r[-1] + start)
                keep_reduce.append(edge_r[-2] + start)
                start = end
                n = n + 1
            for i in range(num_graph_edges):
                if i not in keep_normal:
                    for j in range(len(switches_normal.current_ops)):
                        switches_normal[i][j] = False
                if i not in keep_reduce:
                    for j in range(len(switches_reduce.current_ops)):
                        switches_reduce[i][j] = False

            switches_normal.keep_2_branches(normal_prob)
            switches_reduce.keep_2_branches(reduce_prob)
            switches_normal.plot_ops(logging.info, writer, cycle.num + 1)
            switches_reduce.plot_ops(logging.info, writer, cycle.num + 1)
            genotype = parse_network(switches_normal, switches_reduce)
            logging.info(genotype)
            save_genotype(args.save + 'genotype.json', genotype)
            with open(args.save + "/best_genotype.txt", "w") as f:
                f.write(str(genotype))
    gpu_logger.stop()
Example #16
                    type=float,
                    default=0.001,
                    help='Learning Rate')
parser.add_argument('--hidden_units',
                    type=int,
                    default=4096,
                    help='Neurons in the Hidden Layer')
parser.add_argument('--epochs', type=int, default=5, help='Epochs')
parser.add_argument('--gpu', type=str, default='cuda', help='GPU or CPU')
parser.add_argument('--save_dir',
                    type=str,
                    default='checkpoint.pth',
                    help='Path to checkpoint')
arg, unknown = parser.parse_known_args()

train_transforms, valid_transforms, test_transforms = data_transforms()
train_data, valid_data, test_data = data_loader(train_transforms,
                                                valid_transforms,
                                                test_transforms)
trainloader, validloader, testloader = model_data(train_data, valid_data,
                                                  test_data)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if arg.arch == 'vgg':
    input_size = 25088
    model = models.vgg16(pretrained=True)
elif arg.arch == 'densenet':
    input_size = 1024  # densenet121's classifier takes 1024 input features
    model = models.densenet121(pretrained=True)
Example #17
def main():
    if not torch.cuda.is_available():
        logging.info('No GPU device available')
        sys.exit(1)
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info("args = %s", args)
    logging.info("unparsed args = %s", unparsed)
    num_gpus = torch.cuda.device_count()
    gpu_logger = GpuLogThread(list(range(num_gpus)),
                              writer,
                              seconds=10 if args.test else 300)
    gpu_logger.start()

    genotype = genotypes.load_genotype(args.arch, skip_cons=args.arch_pref_sc)
    print('---------Genotype---------')
    logging.info(genotype)
    print('--------------------------')
    if args.dataset == "CIFAR100":
        CLASSES = 100
    elif args.dataset == "CIFAR10":
        CLASSES = 10
    elif args.dataset == 'MIT67':
        dset_cls = dset.ImageFolder
        CLASSES = 67
    elif args.dataset == 'Sport8':
        dset_cls = dset.ImageFolder
        CLASSES = 8
    elif args.dataset == "flowers102":
        dset_cls = dset.ImageFolder
        CLASSES = 102
    if args.dataset in utils.LARGE_DATASETS:
        model = NetworkLarge(args.init_channels, CLASSES, args.layers,
                             args.auxiliary, genotype)
    else:
        model = Network(args.init_channels, CLASSES, args.layers,
                        args.auxiliary, genotype)
    if num_gpus > 1:
        model = nn.DataParallel(model)
    model = model.cuda()
    logging.info("param count = %d", utils.count_parameters(model))
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    train_transform, valid_transform = utils.data_transforms(
        args.dataset, args.cutout, args.cutout_length)
    if args.dataset == "CIFAR100":
        train_data = dset.CIFAR100(root=args.tmp_data_dir,
                                   train=True,
                                   download=True,
                                   transform=train_transform)
        valid_data = dset.CIFAR100(root=args.tmp_data_dir,
                                   train=False,
                                   download=True,
                                   transform=valid_transform)
    elif args.dataset == "CIFAR10":
        train_data = dset.CIFAR10(root=args.tmp_data_dir,
                                  train=True,
                                  download=True,
                                  transform=train_transform)
        valid_data = dset.CIFAR10(root=args.tmp_data_dir,
                                  train=False,
                                  download=True,
                                  transform=valid_transform)
    elif args.dataset == 'MIT67':
        dset_cls = dset.ImageFolder
        data_path = '%s/MIT67/train' % args.tmp_data_dir  # 'data/MIT67/train'
        val_path = '%s/MIT67/test' % args.tmp_data_dir  # 'data/MIT67/val'
        train_data = dset_cls(root=data_path, transform=train_transform)
        valid_data = dset_cls(root=val_path, transform=valid_transform)
    elif args.dataset == 'Sport8':
        dset_cls = dset.ImageFolder
        data_path = '%s/Sport8/train' % args.tmp_data_dir  # 'data/Sport8/train'
        val_path = '%s/Sport8/test' % args.tmp_data_dir  # 'data/Sport8/val'
        train_data = dset_cls(root=data_path, transform=train_transform)
        valid_data = dset_cls(root=val_path, transform=valid_transform)
    elif args.dataset == "flowers102":
        dset_cls = dset.ImageFolder
        data_path = '%s/flowers102/train' % args.tmp_data_dir
        val_path = '%s/flowers102/test' % args.tmp_data_dir
        train_data = dset_cls(root=data_path, transform=train_transform)
        valid_data = dset_cls(root=val_path, transform=valid_transform)

    train_iterator = utils.DynamicBatchSizeLoader(
        torch.utils.data.DataLoader(train_data,
                                    batch_size=args.batch_multiples,
                                    shuffle=True,
                                    pin_memory=True,
                                    num_workers=args.workers),
        args.batch_size_min)
    test_iterator = utils.DynamicBatchSizeLoader(
        torch.utils.data.DataLoader(valid_data,
                                    batch_size=args.batch_multiples,
                                    shuffle=False,
                                    pin_memory=True,
                                    num_workers=args.workers),
        args.batch_size_min)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs))
    best_acc = 0.0
    for epoch in range(args.epochs):
        lr = scheduler.get_lr()[0]
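        # drop-path probability is ramped linearly from 0 up to args.drop_path_prob over training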
        drop_path_prob = args.drop_path_prob * epoch / args.epochs
        logging.info('Epoch: %d lr %e', epoch, lr)
        if num_gpus > 1:
            model.module.drop_path_prob = drop_path_prob
        else:
            model.drop_path_prob = drop_path_prob
        epoch_start_time = time.time()
        train_acc, train_obj = train(train_iterator, test_iterator, model,
                                     criterion, optimizer, gpu_logger)
        logging.info('Train_acc: %f', train_acc)

        test_acc, test_obj = infer(test_iterator, model, criterion)
        if test_acc > best_acc:
            best_acc = test_acc
        logging.info('Valid_acc: %f', test_acc)
        epoch_duration = time.time() - epoch_start_time
        utils.save(model, os.path.join(args.save, 'weights.pt'))

        # log info
        print('Epoch time: %ds.' % epoch_duration)
        writer.add_scalar('epoch/lr', lr, epoch)
        writer.add_scalar('epoch/drop_path_prob', drop_path_prob, epoch)
        writer.add_scalar('epoch/seconds', epoch_duration, epoch)
        writer.add_scalar('epoch/batch_size', train_iterator.batch_size, epoch)
        writer.add_scalar('train/accuracy', train_acc, epoch)
        writer.add_scalar('train/loss', train_obj, epoch)
        writer.add_scalar('test/accuracy', test_acc, epoch)
        writer.add_scalar('test/loss', test_obj, epoch)

        scheduler.step()
    gpu_logger.stop()
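
The training loop above wraps both DataLoaders in utils.DynamicBatchSizeLoader, whose implementation is not shown here; it starts from args.batch_size_min, exposes the current effective size as .batch_size (logged each epoch), and builds batches out of fixed args.batch_multiples-sized chunks. A minimal sketch of such a wrapper, assuming that behaviour, could look like the following; it is illustrative only, not the repository's actual class.

import torch


class DynamicBatchSizeLoader:
    """Illustrative sketch: groups fixed-size chunks from an underlying DataLoader
    into a larger effective batch whose size can be changed between epochs."""

    def __init__(self, loader, initial_batch_size):
        self.loader = loader
        self.multiple = loader.batch_size      # size of one underlying chunk
        self.batch_size = initial_batch_size   # current effective batch size

    def set_batch_size(self, new_size):
        # keep the effective size at least one chunk large
        self.batch_size = max(self.multiple, new_size)

    def __iter__(self):
        xs, ys = [], []
        for x, y in self.loader:
            xs.append(x)
            ys.append(y)
            if sum(t.size(0) for t in xs) >= self.batch_size:
                yield torch.cat(xs), torch.cat(ys)
                xs, ys = [], []
        if xs:  # flush the remainder as a final, smaller batch
            yield torch.cat(xs), torch.cat(ys)

Growing the effective batch in whole chunks keeps per-step GPU memory bounded by args.batch_multiples while still letting the optimizer see larger batches.
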
Example No. 18
0
def main():
    if not torch.cuda.is_available():
        logging.info('No GPU device available')
        sys.exit(1)
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled=True
    torch.cuda.manual_seed(args.seed)
    logging.info('GPU device = %d' % args.gpu)
    logging.info("args = %s", args)
    #  prepare dataset
    train_transform, valid_transform = utils.data_transforms(args.dataset, args.cutout, args.cutout_length)
    if args.dataset == "CIFAR100":
        train_data = dset.CIFAR100(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)
    elif args.dataset == "CIFAR10":
        train_data = dset.CIFAR10(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)
    elif args.dataset == 'mit67':
        dset_cls = dset.ImageFolder
        data_path = '%s/MIT67/train' % args.tmp_data_dir  # 'data/MIT67/train'
        val_path = '%s/MIT67/test' % args.tmp_data_dir  # 'data/MIT67/test'
        train_data = dset_cls(root=data_path, transform=train_transform)
        valid_data = dset_cls(root=val_path, transform=valid_transform)
    elif args.dataset == 'sport8':
        dset_cls = dset.ImageFolder
        data_path = '%s/Sport8/train' % args.tmp_data_dir  # 'data/Sport8/train'
        val_path = '%s/Sport8/test' % args.tmp_data_dir  # 'data/Sport8/test'
        train_data = dset_cls(root=data_path, transform=train_transform)
        valid_data = dset_cls(root=val_path, transform=valid_transform)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))
    random.shuffle(indices)
    
    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True, num_workers=args.workers)

    valid_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
        pin_memory=True, num_workers=args.workers)
    
    # build Network
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
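    # a DARTS-style cell has 14 edges (its 4 intermediate nodes receive 2, 3, 4 and 5 inputs);
    # each switches row records which candidate ops are still enabled on one edge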
    switches = []
    for i in range(14):
        switches.append([True for j in range(len(PRIMITIVES))])
    switches_normal = copy.deepcopy(switches)
    switches_reduce = copy.deepcopy(switches)
    # To be moved to args
    num_to_keep = [5, 3, 1]
    num_to_drop = [3, 2, 2]
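    # three progressive search stages: each edge keeps 5, then 3, then 1 of its candidate
    # ops, dropping the 3 / 2 / 2 weakest ones at the end of each stage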
    if len(args.add_width) == 3:
        add_width = args.add_width
    else:
        add_width = [0, 0, 0]
    if len(args.add_layers) == 3:
        add_layers = args.add_layers
    else:
        add_layers = [0, 3, 6]
    if len(args.dropout_rate) == 3:
        drop_rate = args.dropout_rate
    else:
        drop_rate = [0.0, 0.0, 0.0]
    eps_no_archs = [10, 10, 10]
    for sp in range(len(num_to_keep)):
        model = Network(args.init_channels + int(add_width[sp]), CLASSES,
                        args.layers + int(add_layers[sp]), criterion,
                        switches_normal=switches_normal,
                        switches_reduce=switches_reduce,
                        p=float(drop_rate[sp]),
                        largemode=args.dataset in utils.LARGE_DATASETS)
        
        model = model.cuda()
        logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
        network_params = []
        for k, v in model.named_parameters():
            if not (k.endswith('alphas_normal') or k.endswith('alphas_reduce')):
                network_params.append(v)       
        # network weights are optimised with SGD; architecture parameters (alphas) with Adam
        optimizer = torch.optim.SGD(network_params,
                                    args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        optimizer_a = torch.optim.Adam(model.arch_parameters(),
                                       lr=args.arch_learning_rate,
                                       betas=(0.5, 0.999),
                                       weight_decay=args.arch_weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, float(args.epochs), eta_min=args.learning_rate_min)
        sm_dim = -1
        epochs = args.epochs
        eps_no_arch = eps_no_archs[sp]
        scale_factor = 0.2
        for epoch in range(epochs):
            scheduler.step()
            lr = scheduler.get_lr()[0]
            logging.info('Epoch: %d lr: %e', epoch, lr)
            epoch_start = time.time()
            # training
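            # warm-up: for the first eps_no_arch epochs only the network weights are trained;
            # afterwards the architecture parameters are updated as well, while the op-level
            # dropout rate p decays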
            if epoch < eps_no_arch:
                model.p = float(drop_rate[sp]) * (epochs - epoch - 1) / epochs
                model.update_p()
                train_acc, train_obj = train(train_queue, valid_queue, model, network_params, criterion, optimizer, optimizer_a, lr, train_arch=False)
            else:
                model.p = float(drop_rate[sp]) * np.exp(-(epoch - eps_no_arch) * scale_factor) 
                model.update_p()                
                train_acc, train_obj = train(train_queue, valid_queue, model, network_params, criterion, optimizer, optimizer_a, lr, train_arch=True)
            logging.info('Train_acc %f', train_acc)
            epoch_duration = time.time() - epoch_start
            logging.info('Epoch time: %ds', epoch_duration)
            # validation
            if epochs - epoch < 5:
                valid_acc, valid_obj = infer(valid_queue, model, criterion)
                logging.info('Valid_acc %f', valid_acc)
        utils.save(model, os.path.join(args.save, 'weights.pt'))
        print('------Dropping %d paths------' % num_to_drop[sp])
        # Save switches for the later skip-connect (s-c) refinement step.
        if sp == len(num_to_keep) - 1:
            switches_normal_2 = copy.deepcopy(switches_normal)
            switches_reduce_2 = copy.deepcopy(switches_reduce)
        # drop operations with low architecture weights
        arch_param = model.arch_parameters()
        normal_prob = F.softmax(arch_param[0], dim=sm_dim).data.cpu().numpy()        
        for i in range(14):
            idxs = []
            for j in range(len(PRIMITIVES)):
                if switches_normal[i][j]:
                    idxs.append(j)
            if sp == len(num_to_keep) - 1:
                drop = get_min_k_no_zero(normal_prob[i, :], idxs, num_to_drop[sp])
            else:
                drop = get_min_k(normal_prob[i, :], num_to_drop[sp])
            for idx in drop:
                switches_normal[i][idxs[idx]] = False
        reduce_prob = F.softmax(arch_param[1], dim=-1).data.cpu().numpy()
        for i in range(14):
            idxs = []
            for j in range(len(PRIMITIVES)):
                if switches_reduce[i][j]:
                    idxs.append(j)
            if sp == len(num_to_keep) - 1:
                drop = get_min_k_no_zero(reduce_prob[i, :], idxs, num_to_drop[sp])
            else:
                drop = get_min_k(reduce_prob[i, :], num_to_drop[sp])
            for idx in drop:
                switches_reduce[i][idxs[idx]] = False
        logging.info('switches_normal = %s', switches_normal)
        logging_switches(switches_normal)
        logging.info('switches_reduce = %s', switches_reduce)
        logging_switches(switches_reduce)
        
        if sp == len(num_to_keep) - 1:
            arch_param = model.arch_parameters()
            normal_prob = F.softmax(arch_param[0], dim=sm_dim).data.cpu().numpy()
            reduce_prob = F.softmax(arch_param[1], dim=sm_dim).data.cpu().numpy()
            normal_final = [0 for idx in range(14)]
            reduce_final = [0 for idx in range(14)]
            # zero out the probability of the 'none' op (where still enabled) so it never
            # wins when ranking edges
            for i in range(14):
                if switches_normal_2[i][0]:
                    normal_prob[i][0] = 0
                normal_final[i] = max(normal_prob[i])
                if switches_reduce_2[i][0]:
                    reduce_prob[i][0] = 0
                reduce_final[i] = max(reduce_prob[i])
            # Generate Architecture
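            # edge selection: the first intermediate node's 2 edges are kept outright; each of
            # the remaining 3 nodes (with 3, 4 and 5 incoming edges) keeps its 2 strongest
            # edges, ranked by the highest-weighted op left on each edge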
            keep_normal = [0, 1]
            keep_reduce = [0, 1]
            n = 3
            start = 2
            for i in range(3):
                end = start + n
                tbsn = normal_final[start:end]
                tbsr = reduce_final[start:end]
                edge_n = sorted(range(n), key=lambda x: tbsn[x])
                keep_normal.append(edge_n[-1] + start)
                keep_normal.append(edge_n[-2] + start)
                edge_r = sorted(range(n), key=lambda x: tbsr[x])
                keep_reduce.append(edge_r[-1] + start)
                keep_reduce.append(edge_r[-2] + start)
                start = end
                n = n + 1
            for i in range(14):
                if i not in keep_normal:
                    for j in range(len(PRIMITIVES)):
                        switches_normal[i][j] = False
                if i not in keep_reduce:
                    for j in range(len(PRIMITIVES)):
                        switches_reduce[i][j] = False
            # translate switches into genotype
            genotype = parse_network(switches_normal, switches_reduce)
            logging.info(genotype)
            ## restrict skipconnect (normal cell only)
            logging.info('Restricting skipconnect...')
            for sks in range(0, len(PRIMITIVES)+1):
                max_sk = len(PRIMITIVES) - sks
                num_sk = check_sk_number(switches_normal)
                if num_sk < max_sk:
                    continue
                while num_sk > max_sk:
                    normal_prob = delete_min_sk_prob(switches_normal, switches_normal_2, normal_prob)
                    switches_normal = keep_1_on(switches_normal_2, normal_prob)
                    switches_normal = keep_2_branches(switches_normal, normal_prob)
                    num_sk = check_sk_number(switches_normal)
                logging.info('Number of skip-connect: %d', max_sk)
                genotype = parse_network(switches_normal, switches_reduce)
                logging.info(genotype)
    with open(args.save + "/best_genotype.txt", "w") as f:
        f.write(str(genotype))
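
The pruning step above calls get_min_k and get_min_k_no_zero, which are not defined in this snippet. Judging from how they are used (their return values index into the list of still-enabled ops on an edge), a sketch consistent with that usage is shown below; the helper bodies are assumptions, not the original implementation.

import numpy as np


def get_min_k(values, k):
    # positions of the k smallest entries in `values`
    return list(np.argsort(values)[:k])


def get_min_k_no_zero(values, idxs, k):
    # same as get_min_k, but if the 'none' op (global index 0) is still enabled
    # on this edge, drop it first and pick the remaining k - 1 weakest ops
    if 0 in idxs:
        rest = np.argsort(values[1:])[:k - 1] + 1
        return [0] + list(rest)
    return list(np.argsort(values)[:k])
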
def main():
    # Define Dataset
    assert args.dataset in ['cifar10', 'imagenet']
    train_transform, valid_transform = data_transforms(args)
    if args.dataset == 'cifar10':
        trainset = torchvision.datasets.CIFAR10(root=os.path.join(
            args.data_root, args.dataset),
                                                train=True,
                                                download=True,
                                                transform=train_transform)
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   pin_memory=True,
                                                   num_workers=8)
        valset = torchvision.datasets.CIFAR10(root=os.path.join(
            args.data_root, args.dataset),
                                              train=False,
                                              download=True,
                                              transform=valid_transform)
        val_loader = torch.utils.data.DataLoader(valset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=8)
    elif args.dataset == 'imagenet':
        # the data is laid out as train/ and valid/ class folders, so ImageFolder is used;
        # torchvision's datasets.ImageNet would interpret the second positional argument
        # as `split`, not as a transform
        train_data_set = datasets.ImageFolder(
            os.path.join(args.data_root, args.dataset, 'train'),
            train_transform)
        val_data_set = datasets.ImageFolder(
            os.path.join(args.data_root, args.dataset, 'valid'),
            valid_transform)
        train_loader = torch.utils.data.DataLoader(train_data_set,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=8,
                                                   pin_memory=True,
                                                   sampler=None)
        val_loader = torch.utils.data.DataLoader(val_data_set,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=8,
                                                 pin_memory=True)
    else:
        raise ValueError('Undefined dataset: %s' % args.dataset)

    # Define Choice Model
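    # `choice` presumably fixes one candidate-block index per supernet layer; the 20
    # indices below are one example architecture sampled from the search space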
    choice = [1, 0, 3, 1, 3, 0, 3, 0, 0, 3, 3, 0, 1, 0, 1, 2, 2, 1, 1, 3]
    model = SinglePath_Network(args.dataset, args.resize, args.classes,
                               args.layers, choice)
    criterion = nn.CrossEntropyLoss().to(args.device)
    # momentum and weight_decay are passed as keywords: positionally, SGD's fourth
    # argument is dampening, so args.weight_decay would otherwise be misapplied
    optimizer = torch.optim.SGD(model.parameters(), args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda epoch: 1 - (epoch / args.epochs))

    # Print Model Information
    flops, params = profile(
        model,
        inputs=(torch.randn(1, 3, 32, 32), ) if args.dataset == 'cifar10' else
        (torch.randn(1, 3, 224, 224), ),
        verbose=False)
    model = model.to(args.device)
    logging.info(model)
    logging.info('Choice Model Information: params: %.2fM, flops:%.2fM' %
                 ((params / 1e6), (flops / 1e6)))
    print('\n')

    # Running
    start = time.time()
    best_val_acc = 0.0
    for epoch in range(args.epochs):
        # Choice Model Training
        train_loss, train_acc = train(args, epoch, train_loader, model,
                                      criterion, optimizer)
        scheduler.step()
        logging.info(
            '[Model Training] epoch: %03d, train_loss: %.3f, train_acc: %.3f' %
            (epoch + 1, train_loss, train_acc))
        # Choice Model Validation
        val_loss, val_acc = validate(args, val_loader, model, criterion)
        # Save Best Supernet Weights
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_ckpt = os.path.join(args.ckpt_dir,
                                     '%s_%s' % (args.exp_name, 'best.pth'))
            torch.save(model.state_dict(), best_ckpt)
            logging.info('Save best checkpoints to %s' % best_ckpt)
        logging.info(
            '[Model Validation] epoch: %03d, val_loss: %.3f, val_acc: %.3f, best_acc: %.3f'
            % (epoch + 1, val_loss, val_acc, best_val_acc))
        print('\n')

    # Record Time
    utils.time_record(start)
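
For reference, the LambdaLR factor 1 - (epoch / args.epochs) used above decays the learning rate linearly to zero over training. A standalone sanity check with made-up values (5 epochs, base lr 0.1) is shown below.

import torch

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
epochs = 5
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda epoch: 1 - (epoch / epochs))
for epoch in range(epochs):
    # ... one epoch of training would go here ...
    scheduler.step()
    print(epoch, scheduler.get_last_lr())  # approx. 0.08, 0.06, 0.04, 0.02, 0.0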