Example No. 1
def infer(valid_queue, model, criterion, MSELoss, local_rank, epoch):
    model.eval()
    loss_avg = AvgrageMeter()
    Acc_avg = AvgrageMeter()
    infer_time = AvgrageMeter()
    class_acc = ClassAcc(GESTURE_CLASSES)
    for step, (inputs, target, heatmap) in enumerate(valid_queue):
        n = inputs.size(0)
        inputs, target = map(lambda x: x.cuda(local_rank, non_blocking=True),
                             [inputs, target])
        end = time.time()
        if args.Network == 'RAAR3D':
            logits, sk, feature = model(inputs)
            loss_mse = MSELoss(sk, heatmap.cuda(local_rank, non_blocking=True))
            loss_ce = criterion(logits, target)
            loss = loss_ce + args.mse_weight * loss_mse
        else:
            logits, feature = model(inputs)
            loss = criterion(logits, target)

        infer_time.update(time.time() - end)

        accuracy = calculate_accuracy(logits, target)
        if args.distp:
            torch.distributed.barrier()
            reduced_loss = reduce_mean(loss, args.nprocs)
            reduced_acc = reduce_mean(accuracy, args.nprocs)
        else:
            reduced_loss, reduced_acc = loss, accuracy
            class_acc.update(logits, target)

        loss_avg.update(reduced_loss.item(), n)
        Acc_avg.update(reduced_acc.item(), n)

        if step % args.report_freq == 0 and local_rank == 0:
            log_info = {
                'Epoch': epoch + 1,
                'Mini-Batch': '{:0>4d}/{:0>4d}'.format(
                    step + 1,
                    len(valid_queue.dataset) // (args.batch_size * args.nprocs)),
                'Inference time': round(infer_time.avg, 4),
                'LossEC': round(loss_avg.avg, 3),
                'Acc': round(Acc_avg.avg, 4)
            }
            print_func(log_info)

    if args.show_class_acc and not args.distp:
        import matplotlib.pyplot as plt
        os.makedirs(args.demo_dir, exist_ok=True)
        with open(args.demo_dir + '/class_acc.txt', 'a') as f:
            txt = str(class_acc.result()) + '\n'
            f.writelines(txt)

        with open(args.demo_dir + '/class_acc.txt', 'r') as f:
            # Each line holds one serialized per-class accuracy list written above.
            data = [eval(l.strip()) for l in f.readlines()]
        fig, ax = plt.subplots()
        # 'k' (black) replaces 'o', which is a marker code rather than a color.
        plot_color = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
        plot_shape = ['-', '--', '-.', ':']
        for i, d in enumerate(data):
            ax.plot(list(map(str, range(GESTURE_CLASSES))),
                    d,
                    '{}{}'.format(random.choice(plot_color),
                                  random.choice(plot_shape)),
                    label='curve{}'.format(i))
        ax.set(xlabel='class',
               ylabel='Acc',
               title='The accuracy rate of each class.')
        ax.grid()
        ax.legend()
        fig.savefig(args.demo_dir + '/class_num.png')
        logging.info('Save done!')

    return Acc_avg.avg, loss_avg.avg
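
The loop above relies on helpers (AvgrageMeter, reduce_mean, ClassAcc) defined elsewhere in the repository. The sketch below gives plausible minimal implementations, written only as an assumption about their behavior so the example is easier to follow; the project's own definitions may differ.

# Minimal sketches, assuming standard behavior; not the repository's actual code.
import torch
import torch.distributed as dist

class AvgrageMeter:
    """Running (weighted) average, as used for the loss/accuracy meters above."""
    def __init__(self):
        self.sum, self.cnt, self.avg = 0.0, 0, 0.0
    def update(self, val, n=1):
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt

def reduce_mean(tensor, nprocs):
    """Average a tensor over all processes (requires an initialized process group)."""
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    return rt / nprocs

class ClassAcc:
    """Per-class accuracy tracker; result() returns one accuracy value per class."""
    def __init__(self, num_classes):
        self.correct = [0] * num_classes
        self.total = [0] * num_classes
    def update(self, logits, target):
        pred = logits.argmax(dim=1)
        for p, t in zip(pred.tolist(), target.tolist()):
            self.total[t] += 1
            self.correct[t] += int(p == t)
    def result(self):
        return [c / t if t else 0.0 for c, t in zip(self.correct, self.total)]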
Example No. 2
def train(train_queue, model, criterion, MSELoss, optimizer, lr, epoch,
          local_rank):
    model.train()
    loss_avg = AvgrageMeter()
    Acc_avg = AvgrageMeter()
    data_time = AvgrageMeter()
    end = time.time()
    for step, (inputs, target, heatmap) in enumerate(train_queue):
        data_time.update(time.time() - end)
        inputs, target = map(lambda x: x.cuda(local_rank, non_blocking=True),
                             [inputs, target])
        if args.Network == 'RAAR3D':
            logits, sk, feature = model(inputs)
            loss_mse = MSELoss(sk, heatmap.cuda(local_rank, non_blocking=True))
            loss_ce = criterion(logits, target)
            loss = loss_ce + args.mse_weight * loss_mse
        else:
            logits, feature = model(inputs)
            loss = criterion(logits, target)

        n = inputs.size(0)
        accuracy = calculate_accuracy(logits, target)
        if args.distp:
            torch.distributed.barrier()
            reduced_loss = reduce_mean(loss, args.nprocs)
            reduced_acc = reduce_mean(accuracy, args.nprocs)
        else:
            reduced_loss, reduced_acc = loss, accuracy

        loss_avg.update(reduced_loss.item(), n)
        Acc_avg.update(reduced_acc.item(), n)

        optimizer.zero_grad()
        loss.backward()
        # model is expected to be wrapped (e.g. in DistributedDataParallel), hence .module
        nn.utils.clip_grad_norm_(model.module.parameters(), args.grad_clip)
        optimizer.step()

        if step % args.report_freq == 0 and local_rank == 0:
            log_info = {
                'Epoch': '{}/{}'.format(epoch + 1, args.epochs),
                'Mini-Batch': '{:0>5d}/{:0>5d}'.format(
                    step + 1,
                    len(train_queue.dataset) // (args.batch_size * args.nprocs)),
                'Data time': round(data_time.avg, 4),
                'Lr': lr,
                'Total Loss': round(loss_avg.avg, 3),
                'Acc': round(Acc_avg.avg, 4)
            }
            print_func(log_info)
            if args.Network == 'RAAR3D':
                visual = FeatureMap2Heatmap(inputs, feature, heatmap, sk)
                vis.featuremap('Input', visual[0])
                vis.featuremap('HEATMAP', visual[2])
                vis.featuremap('SK', visual[3])
            else:
                visual = FeatureMap2Heatmap(inputs, feature)
                vis.featuremap('Input', visual[0])
            for i, feat in enumerate(visual[1]):
                vis.featuremap('feature{}'.format(i + 1), feat)
        end = time.time()
    return Acc_avg.avg, loss_avg.avg
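
calculate_accuracy is not shown in these examples. Since the loops call reduce_mean(accuracy, ...) and .item() on the result, it presumably returns a scalar tensor; a minimal sketch under that assumption:

import torch

def calculate_accuracy(logits, target):
    """Top-1 accuracy as a scalar tensor (so it can be all-reduced and .item()'d)."""
    with torch.no_grad():
        pred = logits.argmax(dim=1)
        return pred.eq(target).float().mean()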
Example No. 3
def main(result_path, epoch_num):
    with open(os.path.join(result_path, 'config.json')) as f:
        config = json.load(f)

    fout_path = os.path.join(result_path, 'test_info.txt')
    fout_file = open(fout_path, 'a+')
    print_func(config, fout_file)

    trsfms = transforms.Compose([
        transforms.Resize((config['general']['image_size'],
                           config['general']['image_size'])),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    model = ALTNet(**config['arch'])
    print_func(model, fout_file)

    state_dict = torch.load(
        os.path.join(result_path,
                     '{}_best_model.pth'.format(config['data_name'])))
    model.load_state_dict(state_dict)

    if config['train']['loss']['name'] == 'CrossEntropyLoss':
        criterion = nn.CrossEntropyLoss(**config['train']['loss']['args'])
    else:
        raise RuntimeError('Unsupported loss: {}'.format(
            config['train']['loss']['name']))

    device, _ = prepare_device(config['n_gpu'])
    model = model.to(device)
    criterion = criterion.to(device)

    total_accuracy = 0.0
    total_h = np.zeros(epoch_num)
    total_accuracy_vector = []
    for epoch_idx in range(epoch_num):
        test_dataset = ImageFolder(
            data_root=config['general']['data_root'],
            mode='test',
            episode_num=600,
            way_num=config['general']['way_num'],
            shot_num=config['general']['shot_num'],
            query_num=config['general']['query_num'],
            transform=trsfms,
        )

        print_func('The num of the test_dataset: {}'.format(len(test_dataset)),
                   fout_file)

        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=config['test']['batch_size'],
            shuffle=True,
            num_workers=config['general']['workers_num'],
            drop_last=True,
            pin_memory=True)

        print_func('============ Testing on the test set ============',
                   fout_file)
        _, accuracies = validate(test_loader, model, criterion, epoch_idx,
                                 device, fout_file,
                                 config['general']['image2level'],
                                 config['general']['print_freq'])
        test_accuracy, h = mean_confidence_interval(accuracies)
        print_func("Test Accuracy: {}\t h: {}".format(test_accuracy, h[0]),
                   fout_file)

        total_accuracy += test_accuracy
        total_accuracy_vector.extend(accuracies)
        total_h[epoch_idx] = h

    aver_accuracy, _ = mean_confidence_interval(total_accuracy_vector)
    print_func(
        'Aver Accuracy: {:.3f}\t Aver h: {:.3f}'.format(
            aver_accuracy, total_h.mean()), fout_file)
    print_func('............Testing is finished............', fout_file)
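
mean_confidence_interval is also not shown. Judging from the calls above, it returns a mean accuracy plus a 95% confidence half-width h, and h is indexed as h[0]; the sketch below is one way to write such a helper under those assumptions.

import numpy as np
import scipy.stats as st

def mean_confidence_interval(data, confidence=0.95):
    """Mean and 95% confidence half-width of a sequence of accuracy values."""
    a = np.asarray([float(x) for x in data])      # accepts floats or 1-element tensors
    m, se = a.mean(), st.sem(a)
    h = se * st.t.ppf((1 + confidence) / 2.0, len(a) - 1)
    return m, np.atleast_1d(h)                    # array-like h, matching the h[0] use above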
Example No. 4
def validate(val_loader, model, criterion, epoch_index, device, fout_file,
             image2level, print_freq):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()
    accuracies = []

    end = time.time()
    with torch.no_grad():
        for episode_index, (query_images, query_targets, support_images,
                            support_targets) in enumerate(val_loader):

            way_num = len(support_images)
            shot_num = len(support_images[0])
            query_input = torch.cat(query_images, 0)
            query_targets = torch.cat(query_targets, 0)

            if image2level == 'image2task':
                image_list = []
                for images in support_images:
                    image_list.extend(images)
                support_input = [torch.cat(image_list, 0)]
            else:
                raise RuntimeError(
                    'Unsupported image2level: {}'.format(image2level))

            query_input = query_input.to(device)
            query_targets = query_targets.to(device)
            support_input = [item.to(device) for item in support_input]
            # support_targets = support_targets.to(device)

            # calculate the output
            _, output, _ = model(query_input, support_input)
            output = torch.mean(output.view(-1, way_num, shot_num), dim=2)
            loss = criterion(output, query_targets)

            # measure accuracy and record loss
            prec1, _ = accuracy(output, query_targets, topk=(1, 3))
            losses.update(loss.item(), query_input.size(0))
            top1.update(prec1[0], query_input.size(0))
            accuracies.append(prec1)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # print the intermediate results
            if episode_index % print_freq == 0 and episode_index != 0:
                info_str = (
                    'Test-({0}): [{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Loss {loss.val:.3f} ({loss.avg:.3f})\t'
                    'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                        epoch_index,
                        episode_index,
                        len(val_loader),
                        batch_time=batch_time,
                        loss=losses,
                        top1=top1))
                print_func(info_str, fout_file)
    return top1.avg, accuracies
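
AverageMeter and accuracy() appear to follow the usual PyTorch ImageNet-example pattern (accuracy returns top-k precision in percent as one-element tensors, which is why top1.update(prec1[0], ...) works). Assumed definitions, shown only to make the loops above self-contained:

import torch

class AverageMeter:
    """Tracks the latest value and the running average of a metric."""
    def __init__(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Top-k precision (in %) for each k in topk, each as a 1-element tensor."""
    with torch.no_grad():
        maxk = max(topk)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        return [correct[:k].reshape(-1).float().sum(0, keepdim=True)
                    .mul_(100.0 / target.size(0)) for k in topk]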
Example No. 5
def train(train_loader, model, criterion, optimizer, epoch_index, device,
          fout_file, image2level, print_freq):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    model.train()

    end = time.time()
    for episode_index, (query_images, query_targets, support_images,
                        support_targets) in enumerate(train_loader):
        data_time.update(time.time() - end)

        way_num = len(support_images)
        shot_num = len(support_images[0])
        query_input = torch.cat(query_images, 0)
        query_targets = torch.cat(query_targets, 0)
        # support_targets = torch.cat(support_targets, 0)

        if image2level == 'image2task':
            image_list = []
            for images in support_images:
                image_list.extend(images)
            support_input = [torch.cat(image_list, 0)]
        else:
            raise RuntimeError(
                'Unsupported image2level: {}'.format(image2level))

        query_input = query_input.to(device)
        query_targets = query_targets.to(device)
        support_input = [item.to(device) for item in support_input]
        # support_targets = support_targets.to(device)

        # calculate the output
        _, output, _ = model(query_input, support_input)
        output = torch.sum(output.view(-1, way_num, shot_num), dim=2)
        loss = criterion(output, query_targets)

        # compute gradients
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure accuracy and record loss
        prec1, _ = accuracy(output, query_targets, topk=(1, 3))
        losses.update(loss.item(), query_input.size(0))
        top1.update(prec1[0], query_input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print the intermediate results
        if episode_index % print_freq == 0 and episode_index != 0:
            info_str = ('Episode-({0}): [{1}/{2}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                        'Loss {loss.val:.3f} ({loss.avg:.3f})\t'
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                            epoch_index,
                            episode_index,
                            len(train_loader),
                            batch_time=batch_time,
                            data_time=data_time,
                            loss=losses,
                            top1=top1))
            print_func(info_str, fout_file)
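
To make the way/shot bookkeeping above concrete, here is a small, hypothetical illustration of how the episodic batch is laid out and why output.view(-1, way_num, shot_num) groups the scores per class before reducing over the shot dimension (all shapes are invented for illustration only):

import torch

way_num, shot_num, query_per_way = 5, 1, 15

# support_images: a list of way_num lists, each with shot_num image tensors;
# query_images: one tensor of queries per class (assumed shapes).
support_images = [[torch.randn(1, 3, 84, 84) for _ in range(shot_num)]
                  for _ in range(way_num)]
query_images = [torch.randn(query_per_way, 3, 84, 84) for _ in range(way_num)]

query_input = torch.cat(query_images, 0)            # (way_num * query_per_way, 3, 84, 84)
image_list = [img for images in support_images for img in images]
support_input = [torch.cat(image_list, 0)]          # [(way_num * shot_num, 3, 84, 84)]

# Assume the model emits one similarity score per (query, support image) pair.
num_query = query_input.size(0)
scores = torch.randn(num_query * way_num * shot_num)
per_class = scores.view(-1, way_num, shot_num).sum(dim=2)   # (num_query, way_num)
print(query_input.shape, support_input[0].shape, per_class.shape)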
Example No. 6
def main(config):
    result_name = '{}_{}_{}way_{}shot'.format(
        config['data_name'],
        config['arch']['base_model'],
        config['general']['way_num'],
        config['general']['shot_num'],
    )
    save_path = os.path.join(config['general']['save_root'], result_name)
    if not os.path.exists(save_path):
        os.mkdir(save_path)
    fout_path = os.path.join(save_path, 'train_info.txt')
    fout_file = open(fout_path, 'a+')
    with open(os.path.join(save_path, 'config.json'), 'w') as handle:
        json.dump(config, handle, indent=4, sort_keys=True)
    print_func(config, fout_file)

    train_trsfms = transforms.Compose([
        transforms.Resize((config['general']['image_size'],
                           config['general']['image_size'])),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    val_trsfms = transforms.Compose([
        transforms.Resize((config['general']['image_size'],
                           config['general']['image_size'])),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    model = ALTNet(**config['arch'])
    print_func(model, fout_file)

    optimizer = optim.Adam(model.parameters(), lr=config['train']['optim_lr'])

    if config['train']['lr_scheduler']['name'] == 'StepLR':
        lr_scheduler = optim.lr_scheduler.StepLR(
            optimizer=optimizer, **config['train']['lr_scheduler']['args'])
    elif config['train']['lr_scheduler']['name'] == 'MultiStepLR':
        lr_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer=optimizer, **config['train']['lr_scheduler']['args'])
    else:
        raise RuntimeError('Unsupported lr_scheduler: {}'.format(
            config['train']['lr_scheduler']['name']))

    if config['train']['loss']['name'] == 'CrossEntropyLoss':
        criterion = nn.CrossEntropyLoss(**config['train']['loss']['args'])
    else:
        raise RuntimeError('Unsupported loss: {}'.format(
            config['train']['loss']['name']))

    device, _ = prepare_device(config['n_gpu'])
    model = model.to(device)
    criterion = criterion.to(device)

    best_val_prec1 = 0
    best_test_prec1 = 0
    for epoch_index in range(config['train']['epochs']):
        print_func('{} Epoch {} {}'.format('=' * 35, epoch_index, '=' * 35),
                   fout_file)
        train_dataset = ImageFolder(
            data_root=config['general']['data_root'],
            mode='train',
            episode_num=config['train']['episode_num'],
            way_num=config['general']['way_num'],
            shot_num=config['general']['shot_num'],
            query_num=config['general']['query_num'],
            transform=train_trsfms,
        )
        val_dataset = ImageFolder(
            data_root=config['general']['data_root'],
            mode='val',
            episode_num=config['test']['episode_num'],
            way_num=config['general']['way_num'],
            shot_num=config['general']['shot_num'],
            query_num=config['general']['query_num'],
            transform=val_trsfms,
        )
        test_dataset = ImageFolder(
            data_root=config['general']['data_root'],
            mode='test',
            episode_num=config['test']['episode_num'],
            way_num=config['general']['way_num'],
            shot_num=config['general']['shot_num'],
            query_num=config['general']['query_num'],
            transform=val_trsfms,
        )

        print_func(
            'The num of the train_dataset: {}'.format(len(train_dataset)),
            fout_file)
        print_func('The num of the val_dataset: {}'.format(len(val_dataset)),
                   fout_file)
        print_func('The num of the test_dataset: {}'.format(len(test_dataset)),
                   fout_file)

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=config['train']['batch_size'],
            shuffle=True,
            num_workers=config['general']['workers_num'],
            drop_last=True,
            pin_memory=True)
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=config['test']['batch_size'],
            shuffle=True,
            num_workers=config['general']['workers_num'],
            drop_last=True,
            pin_memory=True)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=config['test']['batch_size'],
            shuffle=True,
            num_workers=config['general']['workers_num'],
            drop_last=True,
            pin_memory=True)

        # train for 5000 episodes in each epoch
        print_func('============ Train on the train set ============',
                   fout_file)
        train(train_loader, model, criterion, optimizer, epoch_index, device,
              fout_file, config['general']['image2level'],
              config['general']['print_freq'])

        print_func('============ Validation on the val set ============',
                   fout_file)
        val_prec1 = validate(val_loader, model, criterion, epoch_index, device,
                             fout_file, config['general']['image2level'],
                             config['general']['print_freq'])
        print_func(
            ' * Prec@1 {:.3f} Best Prec1 {:.3f}'.format(
                val_prec1, best_val_prec1), fout_file)

        print_func('============ Testing on the test set ============',
                   fout_file)
        test_prec1 = validate(test_loader, model, criterion, epoch_index,
                              device, fout_file,
                              config['general']['image2level'],
                              config['general']['print_freq'])
        print_func(
            ' * Prec@1 {:.3f} Best Prec1 {:.3f}'.format(
                test_prec1, best_test_prec1), fout_file)

        if val_prec1 > best_val_prec1:
            best_val_prec1 = val_prec1
            best_test_prec1 = test_prec1
            save_model(model,
                       save_path,
                       config['data_name'],
                       epoch_index,
                       is_best=True)

        if epoch_index % config['general'][
                'save_freq'] == 0 and epoch_index != 0:
            save_model(model,
                       save_path,
                       config['data_name'],
                       epoch_index,
                       is_best=False)

        lr_scheduler.step()

    print_func('............Training is finished............', fout_file)
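
prepare_device and save_model are used here but not shown. The sketches below are assumptions consistent with how they are called (and with the '{}_best_model.pth' checkpoint loaded in the test script of Example No. 3); the real helpers may differ.

import os
import torch

def prepare_device(n_gpu):
    """Pick a device and the list of usable GPU ids."""
    n_use = min(n_gpu, torch.cuda.device_count())
    device = torch.device('cuda:0' if n_use > 0 else 'cpu')
    return device, list(range(n_use))

def save_model(model, save_path, data_name, epoch_index, is_best=False):
    """Save a checkpoint; the best model overwrites '<data_name>_best_model.pth'."""
    if is_best:
        filename = '{}_best_model.pth'.format(data_name)
    else:
        filename = '{}_epoch_{}.pth'.format(data_name, epoch_index)
    torch.save(model.state_dict(), os.path.join(save_path, filename))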