Exemplo n.º 1
0
def main():
    """Evaluate a checkpointed model on a validation image list and log predictions.

    Parses the command line into the global ``args``, loads the checkpoint
    named by ``args.resume`` (copying the architecture-related options listed
    in ``arch_resume_names`` from the checkpoint's saved args), runs
    ``inference`` on every image of the selected validation loader and appends
    one ``image_name,prediction,truth`` line per image to
    ``<args.save>/val_result.csv``.

    Raises:
        NotImplementedError: if ``args.data`` is neither ``'val_2cls'`` nor
            ``'val_3cls'``.
    """
    # parse arg and start experiment
    global args

    args = arg_parser.parse_args()
    args.config_of_data = config.datasets[args.data]
    args.num_classes = config.datasets[args.data]['num_classes']

    # resume from a checkpoint
    print("=> loading checkpoint '{}'".format(args.resume))
    checkpoint = torch.load(args.resume)
    old_args = checkpoint['args']
    print('Old args:')
    print(old_args)
    # Restore architecture-defining args so the rebuilt model matches the
    # saved weights.
    for name in arch_resume_names:
        if name in vars(args) and name in vars(old_args):
            setattr(args, name, getattr(old_args, name))

    model = getModel(**vars(args))
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}'".format(args.resume))

    cudnn.benchmark = True

    # check if the folder exists
    create_save_folder(args.save, args.force)

    # create dataloader
    if args.data == 'val_2cls':
        loader = validation_2cls.val_2cls()
    elif args.data == 'val_3cls':
        loader = val_3cls.val_3cls()
    else:
        # `raise NotImplemented` would fail with a TypeError: NotImplemented
        # is a constant, not an exception class.
        raise NotImplementedError(
            "unsupported evaluation dataset: '{}'".format(args.data))

    img_list = loader.get_img_list()
    log_path = os.path.join(args.save, 'val_result.csv')

    # Open the result file once (append mode, as before) rather than once
    # per image.
    with open(log_path, 'a') as result_file:
        for i, (img_name, _) in enumerate(img_list):
            img, truth = loader.get_item(i)
            batch = img[np.newaxis, :]  # prepend a batch axis of size 1
            predict = inference(batch, model)
            result_file.write(','.join(
                [img_name, str(predict.data[0]), str(truth)]) + '\n')
Exemplo n.º 2
0
def main():
    """Train a classifier (or evaluate it on one split) with per-class metrics.

    Parses the command line into the global ``args``, restricts the visible
    GPUs to ``args.gpu_id`` and builds the model with ``getModel``. The model is
    restored from ``args.resume``, initialised from ``args.pretrain``, or
    created fresh, and is then wrapped in ``DataParallel``.

    If ``args.eval`` is ``'train'``, ``'val'`` or ``'test'``, that split is
    evaluated once and the function returns. Otherwise the model is trained
    for epochs ``args.start_epoch .. args.epochs``. Each epoch:

    * logs loss, accuracy and per-class recall/precision/f1/accuracy to
      ``<args.save>/log.txt`` (and to tensorboard when available);
    * writes ``checkpoint_<epoch>.pth.tar`` via ``save_checkpoint``.

    Training stops early once ``args.patience`` epochs pass without a better
    validation accuracy (only when ``args.patience > 0``).
    """
    # parse arg and start experiment
    global args
    best_acc = 0.
    best_epoch = 0

    args = arg_parser.parse_args()
    args.config_of_data = config.datasets[args.data]
    args.num_classes = config.datasets[args.data]['num_classes']

    # limit the gpu id to use
    # WARNING: This assignment should be done at the beginning, in case of
    # different assignment for different parts
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    # Once masked, the visible GPUs are renumbered 0..k-1.
    gpu_id = list(range(len(args.gpu_id.split(','))))

    if configure is None:
        args.tensorboard = False
        print(Fore.RED +
              'WARNING: you don\'t have tensorboard_logger installed' +
              Fore.RESET)

    # optionally resume from a checkpoint
    if args.resume:
        if not os.path.isfile(args.resume):
            print(
                "=> no checkpoint found at '{}'".format(
                    Fore.RED +
                    args.resume +
                    Fore.RESET),
                file=sys.stderr)
            return
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        old_args = checkpoint['args']
        print('Old args:')
        print(old_args)
        # set args based on checkpoint
        if args.start_epoch <= 0:
            args.start_epoch = checkpoint['epoch'] + 1
        best_epoch = args.start_epoch - 1
        print('Epoch recovered:%d' % checkpoint['epoch'])
        best_acc = checkpoint['best_acc']
        # Restore architecture-defining args so the rebuilt model matches
        # the saved weights.
        for name in arch_resume_names:
            if name in vars(args) and name in vars(old_args):
                setattr(args, name, getattr(old_args, name))

        model = getModel(**vars(args))
        # Checkpoints written by this function hold the unwrapped module's
        # weights (see save below), so they load before DataParallel wrapping.
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}'"
              .format(args.resume))
    elif args.pretrain:
        # create model
        print("=> creating model '{}'".format(args.arch))
        model = getModel(**vars(args))
        model = load_pretrained_diff_parameter(model, args.pretrain)

        print("=> pre-train weights loaded")
    else:
        # create model
        print("=> creating model '{}'".format(args.arch))
        model = getModel(**vars(args))

    model = torch.nn.DataParallel(model, device_ids=gpu_id).cuda()

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = get_optimizer(model, args)

    # set random seed
    torch.manual_seed(args.seed)

    Trainer = import_module(args.trainer).Trainer
    trainer = Trainer(model, criterion, optimizer, args)

    # create dataloader; `splits` must be a tuple -- ('train') without a
    # trailing comma is just the string 'train'.
    if args.eval == 'train':
        train_loader, _, _ = getDataloaders(args.data, splits=('train',),
                                            batch_size=args.batch_size)
        trainer.test(train_loader, best_epoch)
        return
    elif args.eval == 'val':
        _, val_loader, _ = getDataloaders(args.data, splits=('val',),
                                          batch_size=args.batch_size)
        trainer.test(val_loader, best_epoch)
        return
    elif args.eval == 'test':
        _, _, test_loader = getDataloaders(args.data, splits=('test',),
                                           batch_size=args.batch_size)
        trainer.test(test_loader, best_epoch)
        return
    train_loader, val_loader, _ = getDataloaders(args.data,
                                                 splits=('train', 'val'),
                                                 batch_size=args.batch_size)

    def metric_at(values, idx):
        """Return ``values[idx]``, or 0 when class ``idx`` has no entry."""
        try:
            return values[idx]
        except (IndexError, KeyError, TypeError):
            return 0

    # check if the folder exists
    create_save_folder(args.save, args.force)

    # set up logging
    global log_print, f_log
    f_log = open(os.path.join(args.save, 'log.txt'), 'w')

    def log_print(*values):
        print(*values)
        print(*values, file=f_log)

    try:
        log_print('args:')
        log_print(args)
        f_log.flush()
        log_print('# of params:',
                  str(sum([p.numel() for p in model.parameters()])))

        torch.save(args, os.path.join(args.save, 'args.pth'))
        if args.tensorboard:
            configure(args.save, flush_secs=5)

        for epoch in range(args.start_epoch, args.epochs + 1):

            # train for one epoch
            train_loss, train_acc, lr = trainer.train(train_loader, epoch)

            if args.tensorboard:
                log_value('lr', lr, epoch)
                log_value('train_loss', train_loss, epoch)
                log_value('train_acc', train_acc, epoch)

            # evaluate on validation set
            val_loss, val_acc, recall, precision, f1, acc = trainer.test(
                val_loader, epoch, silence=True)

            if args.tensorboard:
                log_value('val_loss', val_loss, epoch)
                log_value('val_acc', val_acc, epoch)
                # log recall, precision, f1 and accuracy for every class;
                # labels should be sequential natural numbers like 0,1,2....
                # NOTE: the accuracy tag lacks the '_' the other tags use; kept
                # as-is so existing tensorboard runs stay comparable.
                for i in range(args.num_classes):
                    log_value('cls_' + str(i) + '_recall', metric_at(recall, i), epoch)
                    log_value('cls_' + str(i) + '_precision', metric_at(precision, i), epoch)
                    log_value('cls_' + str(i) + '_f1', metric_at(f1, i), epoch)
                    log_value('cls_' + str(i) + 'acc', metric_at(acc, i), epoch)

            # write this epoch's summary and per-class metrics to the log file
            print(('epoch:{}\tlr:{}\ttrain_loss:{:.4f}\ttrain_acc:{:.4f}\tval_loss:{:.4f}\tval_acc:{:.4f}')
                  .format(epoch, lr, train_loss, train_acc, val_loss, val_acc), file=f_log)
            for i in range(args.num_classes):
                for metric_name, values in (('recall', recall),
                                            ('precision', precision),
                                            ('f1', f1),
                                            ('acc', acc)):
                    print('cls_{}_{}: {:.4f}'.format(i, metric_name, metric_at(values, i)),
                          file=f_log)
            f_log.flush()

            # remember best validation accuracy and save checkpoint
            is_best = val_acc > best_acc
            if is_best:
                best_acc = val_acc
                best_epoch = epoch
                print(Fore.GREEN + 'Best var_acc {}'.format(best_acc) + Fore.RESET, file=f_log)
            f_log.flush()

            # Save the unwrapped module's weights: model.state_dict() on the
            # DataParallel wrapper would prefix every key with "module.", and
            # such a checkpoint could only be loaded into an already-wrapped model.
            state = {
                'args': args,
                'epoch': epoch,
                'best_epoch': best_epoch,
                'arch': args.arch,
                'state_dict': model.module.state_dict(),
                'best_acc': best_acc,
            }
            save_checkpoint(state, is_best, args.save,
                            filename='checkpoint_' + str(epoch) + '.pth.tar')
            if not is_best and epoch - best_epoch >= args.patience > 0:
                break
        print('Best best_acc: {:.4f} at epoch {}'.format(best_acc, best_epoch), file=f_log)
    finally:
        f_log.close()
Exemplo n.º 3
0
def main():
    """Train the Spatial_MLP model with a BCE-with-logits loss, or evaluate it.

    Parses the command line into the global ``args``. The model is either
    rebuilt and restored from ``args.resume`` or created fresh; in both cases
    it is wrapped in ``DataParallel`` on GPUs 0-3.

    If ``args.evaluate`` is ``'train'`` or ``'val'``, that split is evaluated
    once and the function returns. Otherwise the model is trained for
    ``args.epochs`` epochs starting at ``args.start_epoch``. Each epoch appends
    ``epoch, lr, train_loss, val_loss`` to ``<args.save>/scores.tsv`` and saves
    a checkpoint.

    Training stops early once ``args.patience`` epochs pass without a lower
    validation loss (only when ``args.patience > 0``).
    """
    global args
    best_loss = 1.e12
    best_epoch = 0

    args = arg_parser.parse_args()
    args.config_of_data = config.datasets[args.data]
    args.num_classes = config.datasets[args.data]['num_classes']

    if configure is None:
        args.tensorboard = False
        print(Fore.RED +
              'WARNING: you don\'t have tensorboard_logger installed' +
              Fore.RESET)

    def build_model():
        """Create Spatial_MLP wrapped in DataParallel on GPUs 0-3 and print it."""
        print("=> creating model")
        net = smlp.Spatial_MLP()
        net = nn.DataParallel(net, device_ids=[0, 1, 2, 3]).cuda()
        print(net)
        return net

    # optionally resume from a checkpoint
    if args.resume:
        if not os.path.isfile(args.resume):
            print("=> no checkpoint found at '{}'".format(Fore.RED +
                                                          args.resume +
                                                          Fore.RESET),
                  file=sys.stderr)
            return
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        old_args = checkpoint['args']
        print('Old args:')
        print(old_args)
        # set args based on checkpoint
        # TODO: necessary?
        if args.start_epoch <= 0:
            args.start_epoch = checkpoint['epoch'] + 1
        best_epoch = checkpoint['best_epoch']
        best_loss = checkpoint['best_loss']
        for name in arch_resume_names:
            if name in vars(args) and name in vars(old_args):
                setattr(args, name, getattr(old_args, name))
        # TODO: build via getModel(**vars(args)) instead of the fixed model.
        model = build_model()
        # The checkpoint stores the DataParallel wrapper's state dict (see the
        # save below), so it is loaded after wrapping.
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))
    else:
        model = build_model()

    # define loss function (criterion)
    criterion = nn.BCEWithLogitsLoss().cuda()

    # define optimizer
    optimizer = get_optimizer(model, args)

    # Trainer
    trainer = Trainer(model, criterion, optimizer, args)

    # Load data
    print("=> loading data")
    dtrain, dval = dataloader(args.batch_size)

    # Evaluate
    stt = time.time()
    if args.evaluate == 'train':
        print("=> evaluating model for training set from epoch {}".format(
            best_epoch))
        trainer.test(dtrain, best_epoch)
        print('Done in {:.3f}s'.format(time.time() - stt))
        return
    elif args.evaluate == 'val':
        print("=> evaluating model for validation set from epoch {}".format(
            best_epoch))
        trainer.test(dval, best_epoch)
        print('Done in {:.3f}s'.format(time.time() - stt))
        return

    # check if the folder exists
    create_save_folder(args.save, args.force)

    # set up logging
    global log_print, f_log
    f_log = open(os.path.join(args.save, 'log.txt'), 'w')

    def log_print(*values):
        print(*values)
        print(*values, file=f_log)

    try:
        log_print('Task: ', CLASS)
        log_print('args:')
        log_print(args)
        print('model:', file=f_log)
        print(model, file=f_log)
        log_print('# of params:',
                  str(sum([p.numel() for p in model.parameters()])))

        f_log.flush()
        torch.save(args, os.path.join(args.save, 'args.pth'))
        # The header must list exactly the columns each row below writes.
        scores = ['epoch\tlr\ttrain_loss\tval_loss']

        if args.tensorboard:
            configure(args.save, flush_secs=5)

        print("=> training")
        # Here args.epochs is the number of epochs to run from start_epoch.
        for epoch in range(args.start_epoch, args.epochs + args.start_epoch):

            # train for one epoch
            train_loss, lr = trainer.train(dtrain, epoch)

            if args.tensorboard:
                log_value('lr', lr, epoch)
                log_value('train_loss', train_loss, epoch)

            # evaluate on validation set
            val_loss, _ = trainer.test(dval, epoch)

            if args.tensorboard:
                log_value('val_loss', val_loss, epoch)

            # save scores to a tsv file, rewriting the whole file each epoch
            # to prevent accidental deletion
            scores.append(
                ('{}\t{}' + '\t{:.4f}' * 2).format(epoch, lr, train_loss,
                                                   val_loss))
            with open(os.path.join(args.save, 'scores.tsv'), 'w') as f:
                print('\n'.join(scores), file=f)

            # remember best validation loss and save checkpoint
            is_best = val_loss < best_loss
            if is_best:
                best_loss = val_loss
                best_epoch = epoch
                print(Fore.GREEN + 'Best val_loss {}'.format(best_loss) +
                      Fore.RESET)
            save_checkpoint(
                {
                    'args': args,
                    'epoch': epoch,
                    'best_epoch': best_epoch,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_loss': best_loss,
                }, is_best, args.save)
            if not is_best and epoch - best_epoch >= args.patience > 0:
                break

        print('Best val_loss: {:.4f} at epoch {}'.format(best_loss, best_epoch))
    finally:
        f_log.close()
Exemplo n.º 4
0
def train_model():
    """Train a classifier (or evaluate it on one split) from command-line args.

    Parses the command line into the global ``args`` and builds the model with
    ``getModel``, restoring it from ``args.resume`` when given.

    If ``args.evaluate`` is ``'train'``, ``'val'`` or ``'test'``, that split is
    evaluated once and the function returns. Otherwise the model is trained for
    epochs ``args.start_epoch .. args.epochs``. Each epoch logs
    loss/err1/err5 to ``<args.save>/scores.tsv`` (and to tensorboard when
    available) and saves a checkpoint.

    Training stops early once ``args.patience`` epochs pass without a lower
    validation err@1 (only when ``args.patience > 0``).
    """
    # parse arg and start experiment
    global args
    best_err1 = 100.
    best_epoch = 0

    args = arg_parser.parse_args()
    args.config_of_data = config.datasets[args.data]
    args.num_classes = config.datasets[args.data]['num_classes']
    if configure is None:
        args.tensorboard = False
        print(Fore.RED +
              'WARNING: you don\'t have tensorboard_logger installed' +
              Fore.RESET)

    # optionally resume from a checkpoint
    if args.resume:
        if not os.path.isfile(args.resume):
            print(
                "=> no checkpoint found at '{}'".format(
                    Fore.RED +
                    args.resume +
                    Fore.RESET),
                file=sys.stderr)
            return
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        old_args = checkpoint['args']
        print('Old args:')
        print(old_args)
        # set args based on checkpoint
        if args.start_epoch <= 0:
            args.start_epoch = checkpoint['epoch'] + 1
        best_epoch = args.start_epoch - 1
        best_err1 = checkpoint['best_err1']
        # Restore architecture-defining args so the rebuilt model matches
        # the saved weights.
        for name in arch_resume_names:
            if name in vars(args) and name in vars(old_args):
                setattr(args, name, getattr(old_args, name))
        model = getModel(**vars(args))
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))
    else:
        # create model
        print("=> creating model '{}'".format(args.arch))
        model = getModel(**vars(args))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = get_optimizer(model, args)

    # set random seed
    torch.manual_seed(args.seed)

    Trainer = import_module(args.trainer).Trainer
    trainer = Trainer(model, criterion, optimizer, args)

    # create dataloader; `splits` must be a tuple -- ('train') without a
    # trailing comma is just the string 'train'.
    if args.evaluate == 'train':
        train_loader, _, _ = getDataloaders(
            splits=('train',), **vars(args))
        trainer.test(train_loader, best_epoch)
        return
    elif args.evaluate == 'val':
        _, val_loader, _ = getDataloaders(
            splits=('val',), **vars(args))
        trainer.test(val_loader, best_epoch)
        return
    elif args.evaluate == 'test':
        _, _, test_loader = getDataloaders(
            splits=('test',), **vars(args))
        trainer.test(test_loader, best_epoch)
        return
    train_loader, val_loader, _ = getDataloaders(
        splits=('train', 'val'), **vars(args))

    # check if the folder exists
    create_save_folder(args.save, args.force)

    # set up logging
    global log_print, f_log
    f_log = open(os.path.join(args.save, 'log.txt'), 'w')

    def log_print(*values):
        print(*values)
        print(*values, file=f_log)

    try:
        log_print('args:')
        log_print(args)
        print('model:', file=f_log)
        print(model, file=f_log)
        log_print('# of params:',
                  str(sum([p.numel() for p in model.parameters()])))
        f_log.flush()
        torch.save(args, os.path.join(args.save, 'args.pth'))
        # One header column per value written in each row below.
        scores = ['epoch\tlr\ttrain_loss\tval_loss\ttrain_err1'
                  '\tval_err1\ttrain_err5\tval_err5']
        if args.tensorboard:
            configure(args.save, flush_secs=5)

        for epoch in range(args.start_epoch, args.epochs + 1):

            # train for one epoch
            train_loss, train_err1, train_err5, lr = trainer.train(
                train_loader, epoch)

            if args.tensorboard:
                log_value('lr', lr, epoch)
                log_value('train_loss', train_loss, epoch)
                log_value('train_err1', train_err1, epoch)
                log_value('train_err5', train_err5, epoch)

            # evaluate on validation set
            val_loss, val_err1, val_err5 = trainer.test(val_loader, epoch)

            if args.tensorboard:
                log_value('val_loss', val_loss, epoch)
                log_value('val_err1', val_err1, epoch)
                log_value('val_err5', val_err5, epoch)

            # save scores to a tsv file, rewriting the whole file each epoch
            # to prevent accidental deletion
            scores.append(('{}\t{}' + '\t{:.4f}' * 6)
                          .format(epoch, lr, train_loss, val_loss,
                                  train_err1, val_err1, train_err5, val_err5))
            with open(os.path.join(args.save, 'scores.tsv'), 'w') as f:
                print('\n'.join(scores), file=f)

            # remember best err@1 and save checkpoint
            is_best = val_err1 < best_err1
            if is_best:
                best_err1 = val_err1
                best_epoch = epoch
                print(Fore.GREEN + 'Best var_err1 {}'.format(best_err1) +
                      Fore.RESET)
            save_checkpoint({
                'args': args,
                'epoch': epoch,
                'best_epoch': best_epoch,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_err1': best_err1,
            }, is_best, args.save)
            if not is_best and epoch - best_epoch >= args.patience > 0:
                break
        print('Best val_err1: {:.4f} at epoch {}'.format(best_err1, best_epoch))
    finally:
        f_log.close()
Exemplo n.º 5
0
def eval_superpixel(
        checkpoint_path="./saved_checkpoints/cifar10+-resnet-56/model_best.pth.tar",
        image_index=5, num_trials=1000, segments_per_mask=5):
    """Probe a classifier by masking random superpixels of one CIFAR-10 test image.

    Loads the model weights from ``checkpoint_path`` and takes the
    ``image_index``-th (1-based) CIFAR-10 test image. That image is segmented
    with Felzenszwalb's algorithm.

    Each of ``num_trials`` trials blanks ``segments_per_mask`` randomly chosen
    segments (capped at the number of segments), re-classifies the masked
    image, and counts whether the prediction still equals the true label.

    Files written:

    * the original image: ``original_img_index<i>_label_<y>.png``;
    * every mask: ``./masks/mask_<trial>_<1 if correct else 0>.png``;
    * every masked image: ``./mask_on_img/masked_imgs_<trial>.png``.

    The default arguments reproduce the previously hard-coded behaviour.
    ``image_index`` must be >= 1.
    """
    import itertools
    import os

    # parse arg and start experiment
    global args
    args = arg_parser.parse_args()
    args.config_of_data = config.datasets[args.data]
    args.num_classes = config.datasets[args.data]['num_classes']
    if configure is None:
        args.tensorboard = False
        print(Fore.RED +
              'WARNING: you don\'t have tensorboard_logger installed' +
              Fore.RESET)

    model = getModel(**vars(args))
    saved_checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(saved_checkpoint['state_dict'])
    model.eval()

    # get test images
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=1,
                                              shuffle=False, num_workers=2)

    # cv2.imwrite returns False instead of creating missing directories, so
    # make sure the output folders exist.
    for out_dir in ('./masks', './mask_on_img'):
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)

    # Only the `image_index`-th test image is analysed; earlier ones are skipped.
    for images, labels in itertools.islice(test_loader, image_index - 1, image_index):
        if use_cuda:
            images, labels = images.cuda(), labels.cuda()

        images, labels = Variable(images, volatile=True), Variable(labels)
        true_label = labels[0].cpu().data.numpy()[0]

        # Rescale the image to [0, 255] in place. `img` is a transposed view
        # of `org_img`, so `org_img` is rescaled too -- and that rescaled
        # array is what gets masked below.
        org_img = images[0].type(torch.FloatTensor).data.numpy()
        img = org_img.transpose(1, 2, 0)
        img -= img.min()
        img /= img.max()
        img *= 255
        img = img.astype(np.uint8)

        cv2.imwrite('original_img_index{}_label_{}.png'.format(image_index, true_label), img)

        segments = felzenszwalb(img_as_float(img), scale=100, sigma=0.5, min_size=10)
        segment_ids = np.unique(segments)
        print("Felzenszwalb number of segments: {}".format(len(segment_ids)))

        # Sample from the actual segment labels (the former
        # range(first, last) never picked the last segment), and never ask
        # for more segments than exist.
        candidates = segment_ids.tolist()
        n_masked = min(segments_per_mask, len(candidates))

        correct_pred_count = 0
        wrong_pred_count = 0
        for i in range(num_trials):
            chosen = random.sample(candidates, n_masked)

            # 255 keeps a pixel, 0 blanks it.
            mask = np.full(img.shape[:2], 255, dtype="uint8")
            for seg_val in chosen:
                mask[segments == seg_val] = 0

            # The 2-D mask broadcasts over the leading channel axis of org_img.
            masked_img = org_img * mask
            masked_img -= masked_img.min()
            masked_img /= masked_img.max()
            masked_img *= 255
            masked_img = normalize_image(masked_img)

            masked_img_tensor = Variable(torch.from_numpy(masked_img[None, :, :, :]))
            # Follow the same device choice as the input images above.
            if use_cuda:
                masked_img_tensor = masked_img_tensor.cuda()
            mask_output = model(masked_img_tensor)

            pred_mask = mask_output.data.max(1, keepdim=True)[1]
            predicted = pred_mask[0].cpu().numpy()[0]
            print("pred_mask[0]", predicted)

            is_correct = predicted == true_label
            if is_correct:
                correct_pred_count += 1
                print("correct_pred_count: ", correct_pred_count)
            else:
                wrong_pred_count += 1
                print("wrong_pred_count: ", wrong_pred_count)
            cv2.imwrite('./masks/mask_{}_{}.png'.format(i, 1 if is_correct else 0), mask)
            cv2.imwrite('./mask_on_img/masked_imgs_{}.png'.format(i), masked_img.transpose(1, 2, 0))
Exemplo n.º 6
0
DEFAULT_IMAGE_SIZE = 224
NUM_CHANNELS = 3
# Presumably 1000 ImageNet classes plus one background class -- confirm.
NUM_CLASSES = 1001

NUM_IMAGES = {
    'train': 1281167,
    'validation': 50000,
}

_NUM_TRAIN_FILES = 1024
_SHUFFLE_BUFFER = 10000

DATASET_NAME = 'ImageNet'

# Class counts for the non-ImageNet datasets; any other value of args.data
# falls back to NUM_CLASSES.
_CLASSES_BY_DATASET = {'cifar10': 10, 'cifar100': 100}

args = arg_parser.parse_args()

# The factor options are dash-separated integer lists on the command line.
args.grFactor = [int(factor) for factor in args.grFactor.split('-')]
args.bnFactor = [int(factor) for factor in args.bnFactor.split('-')]
args.nScales = len(args.grFactor)

args.num_classes = _CLASSES_BY_DATASET.get(args.data, NUM_CLASSES)

###############################################################################
# Data processing
###############################################################################
Exemplo n.º 7
0
def main():
    """Train a classifier, evaluate it, or run stochastic-depth test sweeps.

    Parses the command line into the global ``args`` and builds the model with
    ``getModel``, restoring it from ``args.resume`` when given.

    If ``args.evaluate`` is ``'train'`` or ``'val'``, that split is evaluated
    once and the function returns.

    For ``'test'``, behaviour depends on ``args.test_death_mode``:

    * ``'none'``: evaluate the test split once;
    * ``'stoch'``: rebuild the checkpointed model with the last 0..nblocks-1
      blocks dropped, and write ``(n_dropped, top1)`` rows to a CSV;
    * anything else: for 1..24 dropped blocks, test 10 random drop patterns
      each and print the minimum top-1 value.

    Otherwise the model is trained for epochs ``args.start_epoch ..
    args.epochs``, logging to ``<args.save>/scores.tsv`` and saving a
    checkpoint per epoch. Training stops early once ``args.patience`` epochs
    pass without a lower validation err@1 (only when ``args.patience > 0``).
    """
    # parse arg and start experiment
    global args
    # Module-level so the stochastic-depth test below can reload the weights.
    global checkpoint
    best_err1 = 100.
    best_epoch = 0

    args = arg_parser.parse_args()
    args.config_of_data = config.datasets[args.data]
    args.num_classes = config.datasets[args.data]['num_classes']
    if configure is None:
        args.tensorboard = False
        print(Fore.RED +
              'WARNING: you don\'t have tensorboard_logger installed' +
              Fore.RESET)

    # optionally resume from a checkpoint
    if args.resume:
        if not os.path.isfile(args.resume):
            print(
                "=> no checkpoint found at '{}'".format(
                    Fore.RED +
                    args.resume +
                    Fore.RESET),
                file=sys.stderr)
            return
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        old_args = checkpoint['args']
        # set args based on checkpoint
        if args.start_epoch <= 0:
            args.start_epoch = checkpoint['epoch'] + 1
        best_epoch = args.start_epoch - 1
        best_err1 = checkpoint['best_err1']
        # Restore architecture-defining args (e.g. 'arch', 'depth',
        # 'death_mode', 'death_rate', 'growth_rate', 'bn_size', 'compression')
        # so the rebuilt model matches the saved weights.
        for name in arch_resume_names:
            if name in vars(args) and name in vars(old_args):
                setattr(args, name, getattr(old_args, name))
        model = getModel(**vars(args))
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))
    else:
        # create model
        print("=> creating model '{}'".format(args.arch))
        model = getModel(**vars(args))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = get_optimizer(model, args)

    Trainer = import_module(args.trainer).Trainer
    trainer = Trainer(model, criterion, optimizer, args)

    # create dataloader; `splits` must be a tuple -- ('train') without a
    # trailing comma is just the string 'train'.
    if args.evaluate == 'train':
        train_loader, _, _ = getDataloaders(
            splits=('train',), **vars(args))
        trainer.test(train_loader, best_epoch)
        return
    elif args.evaluate == 'val':
        _, val_loader, _ = getDataloaders(
            splits=('val',), **vars(args))
        trainer.test(val_loader, best_epoch)
        return
    elif args.evaluate == 'test':
        _, _, test_loader = getDataloaders(
            splits=('test',), **vars(args))
        if args.test_death_mode == 'none':
            trainer.test(test_loader, best_epoch)
            return

        print("Stochastic depth testing...")
        nblocks = (args.depth - 2) // 2
        blocks_per_section = (args.depth - 2) // 6
        section_reps = [blocks_per_section] * 3

        def test_with_death_rates(death_rates_list):
            """Rebuild the checkpointed model with per-block death rates and
            return the second value reported by ``trainer.test`` (top-1)."""
            # Regroup the flat per-block list into one sub-list per section.
            test_death_rate = []
            start = 0
            for reps in section_reps:
                test_death_rate.append(death_rates_list[start:start + reps])
                start += reps
            net = getModel(test_death_rate=test_death_rate, **vars(args))
            net.load_state_dict(checkpoint['state_dict'])
            net_optimizer = get_optimizer(net, args)
            net_trainer = Trainer(net, criterion, net_optimizer, args)
            _, top1, _ = net_trainer.test(test_loader, best_epoch)
            return top1

        if args.test_death_mode == 'stoch':
            all_top1 = []
            # drop the last 0, 1, 2, ..., nblocks-1 blocks
            for n_dropped in range(nblocks):
                print("Dropping " + str(n_dropped) + " blocks")
                death_rates_list = [0] * (nblocks - n_dropped) + [1] * n_dropped
                all_top1.append(test_with_death_rates(death_rates_list))

            with open(args.resume.split('/')[1] + '.csv', 'w') as f:
                writer = csv.writer(f)
                writer.writerows(zip(range(0, nblocks), all_top1))
        else:
            for n_dropped in range(1, 25):
                all_top1 = []
                print("Dropping " + str(n_dropped) + " random blocks")
                # randomly remove n_dropped blocks, 10 times
                for _ in range(10):
                    random_ind = random.sample(range(nblocks), n_dropped)
                    print(random_ind)
                    death_rates_list = [0] * nblocks
                    for ind in random_ind:
                        death_rates_list[ind] = 1
                    all_top1.append(test_with_death_rates(death_rates_list))
                print(min(all_top1))
        return

    train_loader, val_loader, _ = getDataloaders(
        splits=('train', 'val'), **vars(args))

    # check if the folder exists
    create_save_folder(args.save, args.force)

    # set up logging
    global log_print, f_log
    f_log = open(os.path.join(args.save, 'log.txt'), 'w')

    def log_print(*values):
        print(*values)
        print(*values, file=f_log)

    try:
        log_print('args:')
        log_print(args)
        print('model:', file=f_log)
        print(model, file=f_log)
        log_print('# of params:',
                  str(sum([p.numel() for p in model.parameters()])))
        f_log.flush()
        torch.save(args, os.path.join(args.save, 'args.pth'))
        # One header column per value written in each row below.
        scores = ['epoch\tlr\ttrain_loss\tval_loss\ttrain_err1'
                  '\tval_err1\ttrain_err5\tval_err5']
        if args.tensorboard:
            configure(args.save, flush_secs=5)

        for epoch in range(args.start_epoch, args.epochs + 1):

            # train for one epoch
            train_loss, train_err1, train_err5, lr = trainer.train(
                train_loader, epoch)

            if args.tensorboard:
                log_value('lr', lr, epoch)
                log_value('train_loss', train_loss, epoch)
                log_value('train_err1', train_err1, epoch)
                log_value('train_err5', train_err5, epoch)

            # evaluate on validation set
            val_loss, val_err1, val_err5 = trainer.test(val_loader, epoch)

            if args.tensorboard:
                log_value('val_loss', val_loss, epoch)
                log_value('val_err1', val_err1, epoch)
                log_value('val_err5', val_err5, epoch)

            # save scores to a tsv file, rewriting the whole file each epoch
            # to prevent accidental deletion
            scores.append(('{}\t{}' + '\t{:.4f}' * 6)
                          .format(epoch, lr, train_loss, val_loss,
                                  train_err1, val_err1, train_err5, val_err5))
            with open(os.path.join(args.save, 'scores.tsv'), 'w') as f:
                print('\n'.join(scores), file=f)

            # remember best err@1 and save checkpoint
            is_best = val_err1 < best_err1
            if is_best:
                best_err1 = val_err1
                best_epoch = epoch
                print(Fore.GREEN + 'Best var_err1 {}'.format(best_err1) +
                      Fore.RESET)
            save_checkpoint({
                'args': args,
                'epoch': epoch,
                'best_epoch': best_epoch,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_err1': best_err1,
            }, is_best, args.save)
            if not is_best and epoch - best_epoch >= args.patience > 0:
                break
        print('Best val_err1: {:.4f} at epoch {}'.format(best_err1, best_epoch))
    finally:
        f_log.close()
def main():
    """Render class activation maps (CAMs) for the 3-class validation set.

    Restores a model from the checkpoint at ``args.resume``. Then, for every
    image returned by ``val_3cls.val_3cls().get_img_list()``, it:

    * runs a forward pass and takes the top-1 class;
    * builds that class's CAM from the feature map captured at
      ``features.conv5_bn_ac``, weighted by the first parameter of
      ``classifier``;
    * writes two files into ``args.save``:
      ``<stem>_predict<p>_truth<t>.jpg`` (the source image re-encoded) and
      ``<stem>_predict<p>_truth<t>_CAM.jpg`` (50/50 blend with a JET heatmap).

    Raises:
        IOError: if an image listed by the loader cannot be read by OpenCV.
    """
    # parse arg and start experiment
    global args

    args = arg_parser.parse_args()
    args.config_of_data = config.datasets[args.data]
    args.num_classes = config.datasets[args.data]['num_classes']

    # resume from a checkpoint
    print("=> loading checkpoint '{}'".format(args.resume))
    checkpoint = torch.load(args.resume)
    old_args = checkpoint['args']
    print('Old args:')
    print(old_args)
    # set args based on checkpoint
    if args.start_epoch <= 0:
        args.start_epoch = checkpoint['epoch'] + 1
    for name in arch_resume_names:
        if name in vars(args) and name in vars(old_args):
            setattr(args, name, getattr(old_args, name))

    model = getModel(**vars(args))
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}'".format(args.resume))

    cudnn.benchmark = True

    # check if the folder exists
    create_save_folder(args.save, args.force)

    # create dataloader
    loader = val_3cls.val_3cls()
    img_list = loader.get_img_list()

    class SaveFeatures():
        """Forward hook keeping the latest output of module ``m`` as numpy."""
        features = None

        def __init__(self, m):
            self.hook = m.register_forward_hook(self.hook_fn)

        def hook_fn(self, module, input, output):
            self.features = ((output.cpu()).data).numpy()

        def remove(self):
            self.hook.remove()

    def returnCAM(feature_conv, weight_softmax, class_idx):
        """Return one uint8 CAM, resized to 256x256, per index in class_idx.

        feature_conv is the hooked 4-D feature map; the reshape to (nc, h*w)
        below only works when its batch dimension is 1. weight_softmax is
        indexed by class to get that class's per-channel weights.
        """
        size_upsample = (256, 256)
        _, nc, h, w = feature_conv.shape
        flat_features = feature_conv.reshape((nc, h * w))
        output_cam = []
        for idx in class_idx:
            # Weight the channels with the classifier row of *this* class.
            cam = weight_softmax[int(idx)].dot(flat_features)
            cam = cam.reshape(h, w)
            cam = cam - np.min(cam)
            peak = np.max(cam)
            # A constant map would otherwise be 0/0 -> NaN -> garbage uint8.
            cam_img = cam / peak if peak > 0 else cam
            cam_img = np.uint8(255 * cam_img)
            output_cam.append(cv2.resize(cam_img, size_upsample))
        return output_cam

    # NOTE(review): assumes getModel returns a wrapper (e.g. DataParallel)
    # exposing the real network as the 'module' submodule -- confirm.
    network = model._modules.get('module')
    final_layer = network.features.conv5_bn_ac

    # The hook must stay registered for every forward pass of the loop;
    # removing it early would leave `features` frozen on the first image.
    activated_features = SaveFeatures(final_layer)

    # Classifier weights do not change between images: extract them once.
    weight_softmax_params = list(network.classifier.parameters())
    weight_softmax = np.squeeze(weight_softmax_params[0].cpu().data.numpy())

    model.eval()

    try:
        for i, (img_name, _) in enumerate(img_list):
            img, truth = loader.get_item(i)
            input_tensor = img[np.newaxis, :]

            # NOTE(review): `volatile` is ignored on PyTorch >= 0.4, where
            # `torch.no_grad()` is the replacement -- kept for old versions.
            input_var = torch.autograd.Variable(input_tensor, volatile=True)
            output = model(input_var)
            _, predict = torch.max(output, dim=1)

            pred_probabilities = F.softmax(output, dim=1).data.squeeze()
            class_idx = topk(pred_probabilities, 1)[1].int()
            CAMs = returnCAM(activated_features.features, weight_softmax,
                             class_idx)

            img_path = os.path.join(loader.img_dir, img_name)
            img = cv2.imread(img_path)
            if img is None:
                raise IOError("cannot read image '{}'".format(img_path))
            height, width, _ = img.shape
            heatmap = cv2.applyColorMap(cv2.resize(CAMs[0], (width, height)),
                                        cv2.COLORMAP_JET)
            result = heatmap * 0.5 + img * 0.5

            # splitext keeps dotted stems intact (a.1.jpg -> a.1), avoiding
            # output collisions between such images.
            new_name = ('{}_predict{}_truth{}').format(
                os.path.splitext(img_name)[0], int(predict.data[0]), truth)
            output_path = os.path.join(args.save, new_name + '_CAM.jpg')
            origin_output_path = os.path.join(args.save, new_name + '.jpg')

            cv2.imwrite(origin_output_path, img)
            cv2.imwrite(output_path, result)
    finally:
        activated_features.remove()