# Example 1
def main(opt):
    """Evaluate a trained few-shot model on the test split.

    Rebuilds the model from ``opt``, restores its weights from
    ``opt['model.model_path']``, merges the training-time options
    (``opt.json`` saved next to the checkpoint) with any non-zero
    test-time overrides, then prints the mean +/- 95% CI of each
    logged field over the test episodes.
    """
    # load model: rebuild the architecture, then restore saved weights
    model = model_utils.load(opt)
    state_dict = torch.load(opt['model.model_path'])
    model.load_state_dict(state_dict)
    model.eval()

    # load the options saved at training time, next to the checkpoint
    model_opt_file = os.path.join(os.path.dirname(opt['model.model_path']), 'opt.json')
    with open(model_opt_file, 'r') as f:
        model_opt = json.load(f)

    # Postprocess arguments. Wrap map() in list(): under Python 3 map()
    # returns a one-shot iterator, which would be silently exhausted the
    # first time model_opt['model.x_dim'] is consumed.
    model_opt['model.x_dim'] = list(map(int, model_opt['model.x_dim'].split(',')))
    model_opt['log.fields'] = model_opt['log.fields'].split(',')

    # construct data options from the 'data.*' subset of the training opts
    data_opt = { 'data.' + k: v for k,v in filter_opt(model_opt, 'data').items() }

    # test-time fields and their training-time fallbacks
    episode_fields = {
        'data.test_way': 'data.way',
        'data.test_shot': 'data.shot',
        'data.test_query': 'data.query',
        'data.test_episodes': 'data.train_episodes'
    }

    # precedence: explicit test-time override (non-zero) > training-time
    # test value (non-zero) > corresponding training-time value
    for k,v in episode_fields.items():
        if opt[k] != 0:
            data_opt[k] = opt[k]
        elif model_opt[k] != 0:
            data_opt[k] = model_opt[k]
        else:
            data_opt[k] = model_opt[v]

    print("Evaluating {:d}-way, {:d}-shot with {:d} query examples/class over {:d} episodes".format(
        data_opt['data.test_way'], data_opt['data.test_shot'],
        data_opt['data.test_query'], data_opt['data.test_episodes']))

    # fixed seed for reproducible episode sampling
    torch.manual_seed(1234)
    if data_opt['data.cuda']:
        torch.cuda.manual_seed(1234)

    data = data_utils.load(data_opt, ['test'])

    if data_opt['data.cuda']:
        model.cuda()

    # one running-average meter per logged field
    meters = { field: tnt.meter.AverageValueMeter() for field in model_opt['log.fields'] }

    model_utils.evaluate(model, data['test'], meters, desc="test")

    # report mean +/- 1.96 * std / sqrt(n): a 95% confidence interval
    for field,meter in meters.items():
        mean, std = meter.value()
        print("test {:s}: {:0.6f} +/- {:0.6f}".format(field, mean, 1.96 * std / math.sqrt(data_opt['data.test_episodes'])))
# Example 2
def main(opt):
    """Train a few-shot model using a torchnet-style Engine.

    Saves ``opt.json`` and a per-epoch ``trace.txt`` under
    ``opt['log.exp_dir']``, builds train/val loaders, and registers
    engine hooks that reset meters, run validation, and checkpoint the
    model at the end of every epoch.
    """
    if not os.path.isdir(opt['log.exp_dir']):
        os.makedirs(opt['log.exp_dir'])

    # save opts
    with open(os.path.join(opt['log.exp_dir'], 'opt.json'), 'w') as f:
        json.dump(opt, f)
        f.write('\n')

    trace_file = os.path.join(opt['log.exp_dir'], 'trace.txt')

    # Postprocess arguments
    opt['model.x_dim'] = list(map(int, opt['model.x_dim'].split(',')))
    opt['log.fields'] = opt['log.fields'].split(',')

    # fixed seeds for reproducibility
    np.random.seed(4321)
    torch.manual_seed(1234)
    if opt['data.cuda']:
        torch.cuda.manual_seed(1234)

    # trainval mode trains on the combined split and skips validation
    if opt['data.trainval']:
        data = data_utils.load(opt, ['trainval'])
        train_loader = data['trainval']
        val_loader = None
    else:
        data = data_utils.load(opt, ['train', 'val'])
        train_loader = data['train']
        val_loader = data['val']

    model = model_utils.load(opt)

    if opt['data.cuda']:
        model.cuda()

    engine = Engine()

    # one running-average meter per logged field, per split
    meters = { 'train': { field: tnt.meter.AverageValueMeter() for field in opt['log.fields'] } }

    if val_loader is not None:
        meters['val'] = { field: tnt.meter.AverageValueMeter() for field in opt['log.fields'] }

    def on_start(state):
        # start from a clean trace file and attach the LR scheduler
        if os.path.isfile(trace_file):
            os.remove(trace_file)
        # NOTE(review): gamma=1 makes this StepLR a no-op (the learning
        # rate never decays) -- confirm this is intentional
        state['scheduler'] = lr_scheduler.StepLR(state['optimizer'], opt['train.decay_every'], gamma=1)
    engine.hooks['on_start'] = on_start

    def on_start_epoch(state):
        # reset all meters and advance the LR schedule once per epoch
        for split, split_meters in meters.items():
            for field, meter in split_meters.items():
                meter.reset()
        state['scheduler'].step()
    engine.hooks['on_start_epoch'] = on_start_epoch

    def on_update(state):
        # accumulate per-batch outputs into the training meters
        for field, meter in meters['train'].items():
            meter.add(state['output'][field])
    engine.hooks['on_update'] = on_update

    def on_end_epoch(hook_state, state):
        # hook_state persists across epochs (bound below via partial)
        if val_loader is not None:
            if 'best_loss' not in hook_state:
                hook_state['best_loss'] = np.inf
            if 'wait' not in hook_state:
                hook_state['wait'] = 0

        if val_loader is not None:
            model_utils.evaluate(state['model'],
                                 val_loader,
                                 meters['val'],
                                 desc="Epoch {:d} valid".format(state['epoch']))

        # print meter values and append them as one JSON line to trace.txt
        meter_vals = log_utils.extract_meter_values(meters)
        print("Epoch {:02d}: {:s}".format(state['epoch'], log_utils.render_meter_values(meter_vals)))
        meter_vals['epoch'] = state['epoch']
        with open(trace_file, 'a') as f:
            json.dump(meter_vals, f)
            f.write('\n')

        # The 'and False' deliberately disables best-model tracking and
        # early stopping: the else branch below always runs, saving the
        # model unconditionally every epoch.
        if val_loader is not None and False:    # disable this block
            if meter_vals['val']['loss'] < hook_state['best_loss']:
                hook_state['best_loss'] = meter_vals['val']['loss']
                print("==> best model (loss = {:0.6f}), saving model...".format(hook_state['best_loss']))

                # move to CPU before pickling so the checkpoint loads anywhere
                state['model'].cpu()
                torch.save(state['model'], os.path.join(opt['log.exp_dir'], 'best_model.pt'))
                if opt['data.cuda']:
                    state['model'].cuda()

                hook_state['wait'] = 0
            else:
                hook_state['wait'] += 1

                if hook_state['wait'] > opt['train.patience']:
                    print("==> patience {:d} exceeded".format(opt['train.patience']))
                    state['stop'] = True
        else:
            # unconditional checkpoint (see note above)
            state['model'].cpu()
            torch.save(state['model'], os.path.join(opt['log.exp_dir'], 'best_model.pt'))
            if opt['data.cuda']:
                state['model'].cuda()

    engine.hooks['on_end_epoch'] = partial(on_end_epoch, { })

    engine.train(
        model = model,
        loader = train_loader,
        optim_method = getattr(optim, opt['train.optim_method']),
        optim_config = { 'lr': opt['train.learning_rate'],
                         'weight_decay': opt['train.weight_decay'] },
        max_epoch = opt['train.epochs']
    )
# Example 3
def main(opt):
    """Evaluate a pickled model on the test split.

    Unlike the state_dict variant, this loads the whole pickled module
    from ``opt['model.model_path']``, merges training-time options with
    any non-zero test-time overrides, and prints mean +/- 95% CI of each
    logged field over the test episodes.
    """
    # load model (entire pickled module, not just a state dict)
    model = torch.load(opt['model.model_path'])
    model.eval()

    # load the options saved at training time, next to the checkpoint
    model_opt_file = os.path.join(os.path.dirname(opt['model.model_path']),
                                  'opt.json')
    with open(model_opt_file, 'r') as f:
        model_opt = json.load(f)

    # Postprocess arguments. Wrap map() in list(): under Python 3 map()
    # returns a one-shot iterator that would be exhausted on first use.
    model_opt['model.x_dim'] = list(
        map(int, model_opt['model.x_dim'].split(',')))
    model_opt['log.fields'] = model_opt['log.fields'].split(',')

    # construct data options from the 'data.*' subset of the training opts
    data_opt = {
        'data.' + k: v
        for k, v in filter_opt(model_opt, 'data').items()
    }

    # test-time fields and their training-time fallbacks
    episode_fields = {
        'data.test_way': 'data.way',
        'data.test_shot': 'data.shot',
        'data.test_query': 'data.query',
        'data.test_episodes': 'data.train_episodes'
    }

    # precedence: explicit test-time override (non-zero) > training-time
    # test value (non-zero) > corresponding training-time value
    for k, v in episode_fields.items():
        if opt[k] != 0:
            data_opt[k] = opt[k]
        elif model_opt[k] != 0:
            data_opt[k] = model_opt[k]
        else:
            data_opt[k] = model_opt[v]

    print(
        "Evaluating {:d}-way, {:d}-shot with {:d} query examples/class over {:d} episodes"
        .format(data_opt['data.test_way'], data_opt['data.test_shot'],
                data_opt['data.test_query'], data_opt['data.test_episodes']))

    # fixed seed for reproducible episode sampling
    torch.manual_seed(1234)
    if data_opt['data.cuda']:
        torch.cuda.manual_seed(1234)
    print(data_opt)
    data = data_utils.load(data_opt, ['test'])

    if data_opt['data.cuda']:
        model.cuda()

    # one running-average meter per logged field
    meters = {
        field: tnt.meter.AverageValueMeter()
        for field in model_opt['log.fields']
    }

    model_utils.evaluate(model, data['test'], meters, desc="test")

    # report mean +/- 1.96 * std / sqrt(n): a 95% confidence interval
    for field, meter in meters.items():
        mean, std = meter.value()
        print("test {:s}: {:0.6f} +/- {:0.6f}".format(
            field, mean,
            1.96 * std / math.sqrt(data_opt['data.test_episodes'])))
def main(opt):
    """Meta-train a clustering few-shot model with periodic evaluation.

    Runs a manual training loop for ``opt['iterations']`` iterations:
    embeds sampled episodes, computes supervised and clustering losses,
    backprops the loss selected by ``opt['train_loss']``, logs running
    averages to a Summary, periodically evaluates on val/test, steps the
    LR scheduler at epoch boundaries, and checkpoints the model and logs.

    Note: this block was written as Python 2 (print statements, xrange,
    iterator .next()); it is ported here to Python 3 with identical
    behavior, consistent with the sibling examples in this file.
    """
    ###########################################
    # Boilerplate
    ###########################################

    #assert not (opt['clustering'] != 'wasserstein' and opt['train_loss'] in ['sinkhorn', 'twostep']),\
    #    'Only Wasserstein clustering is compatible with Sinkhorn and Twostep meta-training losses'

    if not os.path.isdir(opt['log.exp_dir']):
        os.makedirs(opt['log.exp_dir'])

    # save opts
    with open(os.path.join(opt['log.exp_dir'], 'opt.json'), 'w') as f:
        json.dump(opt, f, indent=4)
        f.write('\n')

    trace_file = os.path.join(opt['log.exp_dir'], 'trace.txt')

    # Adapt model input size to dataset
    if opt['data.dataset'] == 'omniglot':
        opt['model.x_dim'] = '1,28,28'
    elif opt['data.dataset'] == 'miniimagenet':
        opt['model.x_dim'] = '3,84,84'

    # Postprocess arguments
    opt['model.x_dim'] = list(map(int, opt['model.x_dim'].split(',')))
    opt['log.fields'] = opt['log.fields'].split(',')

    # fixed seed for reproducibility
    torch.manual_seed(1234)
    if opt['data.cuda']:
        torch.cuda.manual_seed(1234)

    ###########################################
    # Data
    ###########################################
    if opt['data.trainval']:
        data = data_utils.load(opt, ['trainval'])
        train_loader = data['trainval']
        val_loader = None

        # Prepare datasets
        train_iter = data_adapter(make_infinite(train_loader), opt, train=True)
        val_iter = None
    else:
        data = data_utils.load(opt, ['train', 'val', 'test'])
        train_loader = data['train']
        val_loader = data['val']
        test_loader = data['test']

        # Prepare datasets
        train_iter = data_adapter(make_infinite(train_loader), opt, train=True)
        val_iter = data_adapter(make_infinite(val_loader), opt, train=False)
        test_iter = data_adapter(make_infinite(test_loader), opt, train=False)

    ###########################################
    # Create model and optimizer
    ###########################################

    model = model_utils.load(opt)

    # a full checkpoint replaces the freshly built model entirely
    if opt['checkpoint']:
        print('Loading from checkpoint', opt['checkpoint'])
        model = torch.load(opt['checkpoint'])

    # a state-dict checkpoint only restores the weights
    if opt['checkpoint_state']:
        print('Loading state from checkpoint', opt['checkpoint_state'])
        model.load_state_dict(
            torch.load(opt['checkpoint_state'],
                       map_location=lambda storage, loc: storage))

    if opt['data.cuda']:
        model.cuda()

    Optimizer = getattr(optim, opt['train.optim_method'])
    optimizer = Optimizer(model.parameters(),
                          lr=opt['train.learning_rate'],
                          weight_decay=opt['train.weight_decay'])

    scheduler = lr_scheduler.StepLR(optimizer,
                                    opt['train.decay_every'],
                                    gamma=0.5)

    ###########################################
    # Training loop
    ###########################################

    summary = Summary()

    #### Start of training loop
    softmax_regularization = 1. / opt['temperature']
    sinkhorn_regularizations = [
        float(x) for x in opt['regularizations'].split(',')
    ]
    print('Sinkhorn regularizations will take parameters',
          sinkhorn_regularizations)
    for iteration in range(opt['iterations']):

        # Sample from training
        with Timer() as train_load_timer:

            sample_train, new_epoch = next(train_iter)
            # For debug
            #plt.imshow(0.5 + 0.5 * np.rollaxis(sample_train['xs'].numpy(), 2, 5)[0].reshape((5 * 84, 84, 3)))

        # Compute loss; backprop
        with Timer() as train_backprop_timer:

            model.train()  # batchnorm train mode

            # z = h(x)
            embedding_train = model.embed(sample_train,
                                          raw_input=opt['rawinput'])

            if iteration == 0:
                print('Debug: Tensor sizes')
                print('xs', sample_train['xs'].size())
                print('xq', sample_train['xq'].size())
                print('zs', embedding_train['zs'].size())
                print('zq', embedding_train['zq'].size())
                # Should be 64 for omniglot and 1600 for miniimagenet
            del sample_train  # save memory

            # Supervised and Clustering Losses
            train_supervised_info = model.supervised_loss(
                embedding_train, regularization=softmax_regularization)
            # one clustering pass per requested Sinkhorn regularization;
            # train_clustering_info keeps the value from the LAST one
            for skr in sinkhorn_regularizations:
                gamma = 1. / skr
                train_clustering_info = model.clustering_loss(
                    embedding_train,
                    regularization=gamma,
                    clustering_type=opt['clustering'],
                    sanity_check=opt['sanity_check'])
                # unsupervised losses
                summary.log(
                    iteration,
                    'train/SupportClusteringAcc_softmax_reg{}'.format(skr),
                    train_clustering_info['SupportClusteringAcc_softmax'])
                summary.log(
                    iteration,
                    'train/SupportClusteringAcc_sinkhorn_reg{}'.format(skr),
                    train_clustering_info['SupportClusteringAcc_sinkhorn'])
                summary.log(
                    iteration,
                    'train/QueryClusteringAcc_softmax_reg{}'.format(skr),
                    train_clustering_info['QueryClusteringAcc_softmax'])
                summary.log(
                    iteration,
                    'train/QueryClusteringAcc_sinkhorn_reg{}'.format(skr),
                    train_clustering_info['QueryClusteringAcc_sinkhorn'])

            # select the meta-training loss
            if opt['train_loss'] == 'softmax':  # softmax
                total_loss = train_supervised_info['SupervisedLoss_softmax']
            elif opt['train_loss'] == 'sinkhorn':
                total_loss = train_supervised_info['SupervisedLoss_sinkhorn']
            elif opt['train_loss'] == 'twostep':
                total_loss = train_supervised_info['SupervisedLoss_twostep']
            elif opt['train_loss'] == 'end2end':
                #total_loss = train_clustering_info['SupportClusteringLoss_sinkhorn']
                total_loss = train_clustering_info[
                    'SupportClusteringLoss_softmax']
            elif opt['train_loss'] == 'evalonly':
                total_loss = torch.zeros([])
            else:
                raise Exception('Unknown meta-training loss {}'.format(
                    opt['train_loss']))

            # optional centroid-variance penalty
            if opt['centroid_loss'] > 0.:
                centroid_loss = opt['centroid_loss'] * train_supervised_info[
                    'ClassVariance']
                total_loss = total_loss + centroid_loss
                summary.log(iteration, 'train/CentroidLoss',
                            centroid_loss.item())
            summary.log(iteration, 'train/CentroidLossUnscaled',
                        train_supervised_info['ClassVariance'].item())

            if not opt['rawinput'] and opt['train_loss'] != 'evalonly':
                # No need to backprop in rawinput mode
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

        # supervised losses
        summary.log(iteration, 'train/SupervisedAcc_softmax',
                    train_supervised_info['SupervisedAcc_softmax'].item())
        summary.log(iteration, 'train/SupervisedAcc_sinkhorn',
                    train_supervised_info['SupervisedAcc_sinkhorn'].item())
        summary.log(iteration, 'train/SupervisedAcc_twostep',
                    train_supervised_info['SupervisedAcc_twostep'].item())
        summary.log(iteration, 'train/SupervisedLoss_softmax',
                    train_supervised_info['SupervisedLoss_softmax'].item())
        summary.log(iteration, 'train/SupervisedLoss_sinkhorn',
                    train_supervised_info['SupervisedLoss_sinkhorn'].item())
        summary.log(iteration, 'train/SupervisedLoss_twostep',
                    train_supervised_info['SupervisedLoss_twostep'].item())

        # unsupervised losses
        summary.log(iteration, 'train/SupportClusteringAcc_softmax',
                    train_clustering_info['SupportClusteringAcc_softmax'])
        summary.log(iteration, 'train/SupportClusteringAcc_sinkhorn',
                    train_clustering_info['SupportClusteringAcc_sinkhorn'])
        summary.log(iteration, 'train/QueryClusteringAcc_softmax',
                    train_clustering_info['QueryClusteringAcc_softmax'])
        summary.log(iteration, 'train/QueryClusteringAcc_sinkhorn',
                    train_clustering_info['QueryClusteringAcc_sinkhorn'])

        # end to end losses
        summary.log(
            iteration, 'train/SupportClusteringLoss_softmax',
            train_clustering_info['SupportClusteringLoss_softmax'].item())
        summary.log(
            iteration, 'train/SupportClusteringLoss_sinkhorn',
            train_clustering_info['SupportClusteringLoss_sinkhorn'].item())
        summary.log(
            iteration, 'train/QueryClusteringLoss_softmax',
            train_clustering_info['QueryClusteringLoss_softmax'].item())
        summary.log(
            iteration, 'train/QueryClusteringLoss_sinkhorn',
            train_clustering_info['QueryClusteringLoss_sinkhorn'].item())

        summary.log(iteration, 'train/_TimeLoad', train_load_timer.interval)
        summary.log(iteration, 'train/_TimeBackprop',
                    train_backprop_timer.interval)
        summary.log(iteration, 'train/TotalLoss', total_loss.item())

        # Sample from validation and test
        if iteration % opt['validate_interval'] == 0 and val_iter is not None:

            for subset, subset_iter in [('val', val_iter),
                                        ('test', test_iter)]:

                with Timer() as val_load_timer:

                    sample_val, __ = next(subset_iter)

                with Timer() as val_eval_timer:

                    # Weird? deactivate batchnorm train mode
                    model.eval()

                    # z = h(x)
                    embedding_val = model.embed(sample_val,
                                                raw_input=opt['rawinput'])
                    detach_(embedding_val)  # save memory
                    del sample_val  # save memory

                    val_supervised_info = model.supervised_loss(
                        embedding_val, regularization=softmax_regularization)

                    for skr in sinkhorn_regularizations:
                        gamma = 1. / skr
                        val_clustering_info = model.clustering_loss(
                            embedding_val,
                            regularization=gamma,
                            clustering_type=opt['clustering'],
                            sanity_check=opt['sanity_check'])

                        # log unsupervised losses
                        summary.log(
                            iteration,
                            '{}/SupportClusteringAcc_softmax_reg{}'.format(
                                subset, skr),
                            val_clustering_info['SupportClusteringAcc_softmax']
                        )
                        summary.log(
                            iteration,
                            '{}/SupportClusteringAcc_sinkhorn_reg{}'.format(
                                subset, skr), val_clustering_info[
                                    'SupportClusteringAcc_sinkhorn'])
                        summary.log(
                            iteration,
                            '{}/QueryClusteringAcc_softmax_reg{}'.format(
                                subset, skr),
                            val_clustering_info['QueryClusteringAcc_softmax'])
                        summary.log(
                            iteration,
                            '{}/QueryClusteringAcc_sinkhorn_reg{}'.format(
                                subset, skr),
                            val_clustering_info['QueryClusteringAcc_sinkhorn'])

                # supervised losses
                summary.log(
                    iteration, '{}/SupervisedAcc_softmax'.format(subset),
                    val_supervised_info['SupervisedAcc_softmax'].item())
                summary.log(
                    iteration, '{}/SupervisedAcc_sinkhorn'.format(subset),
                    val_supervised_info['SupervisedAcc_sinkhorn'].item())
                summary.log(
                    iteration, '{}/SupervisedAcc_twostep'.format(subset),
                    val_supervised_info['SupervisedAcc_twostep'].item())
                summary.log(
                    iteration, '{}/SupervisedLoss_softmax'.format(subset),
                    val_supervised_info['SupervisedLoss_softmax'].item())
                summary.log(
                    iteration, '{}/SupervisedLoss_sinkhorn'.format(subset),
                    val_supervised_info['SupervisedLoss_sinkhorn'].item())
                summary.log(
                    iteration, '{}/SupervisedLoss_twostep'.format(subset),
                    val_supervised_info['SupervisedLoss_twostep'].item())

                # unsupervised losses
                summary.log(
                    iteration,
                    '{}/SupportClusteringAcc_softmax'.format(subset),
                    val_clustering_info['SupportClusteringAcc_softmax'])
                summary.log(
                    iteration,
                    '{}/SupportClusteringAcc_sinkhorn'.format(subset),
                    val_clustering_info['SupportClusteringAcc_sinkhorn'])
                summary.log(iteration,
                            '{}/QueryClusteringAcc_softmax'.format(subset),
                            val_clustering_info['QueryClusteringAcc_softmax'])
                summary.log(iteration,
                            '{}/QueryClusteringAcc_sinkhorn'.format(subset),
                            val_clustering_info['QueryClusteringAcc_sinkhorn'])

                # end-to-end losses
                summary.log(
                    iteration,
                    '{}/SupportClusteringLoss_softmax'.format(subset),
                    val_clustering_info['SupportClusteringLoss_softmax'].item())
                summary.log(
                    iteration,
                    '{}/SupportClusteringLoss_sinkhorn'.format(subset),
                    val_clustering_info['SupportClusteringLoss_sinkhorn'].item())
                summary.log(
                    iteration, '{}/QueryClusteringLoss_softmax'.format(subset),
                    val_clustering_info['QueryClusteringLoss_softmax'].item())
                summary.log(
                    iteration,
                    '{}/QueryClusteringLoss_sinkhorn'.format(subset),
                    val_clustering_info['QueryClusteringLoss_sinkhorn'].item())

                summary.log(iteration, '{}/_TimeLoad'.format(subset),
                            val_load_timer.interval)
                summary.log(iteration, '{}/_TimeEval'.format(subset),
                            val_eval_timer.interval)

        # End of epoch? -> schedule new learning rate
        if new_epoch and iteration > 0:
            print('End of epoch, scheduling new learning rate')
            scheduler.step()

            summary.log(iteration, 'other/_LR', scheduler.get_lr())

        # Save model
        if iteration % 200 == 0:

            if opt['rawinput'] or opt['train_loss'] == 'evalonly':
                print('No model to save in raw_input mode')
            else:
                print('Saving current model')
                # move to CPU before pickling so the checkpoint loads anywhere
                model.cpu()

                torch.save(
                    model, os.path.join(opt['log.exp_dir'],
                                        'current_model.pt'))

                # additionally keep a permanent snapshot every 2000 iterations
                if iteration % 2000 == 0:
                    print('Saving model at iteration', iteration)
                    torch.save(
                        model,
                        os.path.join(opt['log.exp_dir'],
                                     'model_{}.pt'.format(iteration)))

                if opt['data.cuda']:
                    model.cuda()

        # Log

        if iteration % 10 == 0:
            print('Iteration', iteration)
            if opt['train_loss'] == 'evalonly':
                print('*' * 32)
                print('Full summary. Iteration {}'.format(iteration))
                print('*' * 32)
                summary.print_full_summary()
            else:
                print('!' * 32)
                print('Running averages. Iteration {}'.format(iteration))
                print('!' * 32)
                summary.print_full_summary()
                if opt['hide_test']:
                    summary.print_summary(exclude='test/')
                else:
                    summary.print_summary()

        #### Save log
        if iteration % 500 == 0 or iteration < 10 or (
                iteration % 100 == 0 and opt['train_loss'] == 'evalonly'):
            try:
                # text mode: json.dump and the writes below produce str
                with open(os.path.join(opt['log.exp_dir'], 'log.json'),
                          'w') as fp:
                    json.dump(summary.logs, fp)
                # Dump full summary as well, although this mostly makes sense in evalonly mode
                with open(os.path.join(opt['log.exp_dir'], 'summary.txt'),
                          'w') as fp:
                    fp.write('Iteration {}/{}\n'.format(
                        iteration, opt['iterations']))
                    fp.write(summary.get_full_summary())
            except Exception as e:
                # best-effort logging: never kill training over a log write
                print('Could not dump log file! Ignoring for now', e)
# Example 5
def main(opt):
    """Evaluate a pickled model on the test split with per-class metrics.

    Loads the whole pickled module from ``opt['model.model_path']``,
    merges training-time options with non-zero test-time overrides,
    runs evaluation in either 'protonet' or 'feat' stage, then prints
    per-field mean +/- 95% CI plus per-class accuracy/precision and
    macro/micro averaged precision.
    """
    # load model (entire pickled module, not just a state dict)
    model = torch.load(opt['model.model_path'])
    model.eval()

    # load the options saved at training time, next to the checkpoint
    model_opt_file = os.path.join(os.path.dirname(opt['model.model_path']),
                                  'opt.json')
    with open(model_opt_file, 'r') as f:
        model_opt = json.load(f)

    # Postprocess arguments. Wrap map() in list(): under Python 3 map()
    # returns a one-shot iterator that would be exhausted on first use.
    model_opt['model.x_dim'] = list(
        map(int, model_opt['model.x_dim'].split(',')))
    model_opt['log.fields'] = model_opt['log.fields'].split(',')

    # construct data options from the 'data.*' subset of the training opts
    data_opt = {
        'data.' + k: v
        for k, v in filter_opt(model_opt, 'data').items()
    }

    # test-time fields and their training-time fallbacks
    episode_fields = {
        'data.test_way': 'data.way',
        'data.test_shot': 'data.shot',
        'data.test_query': 'data.query',
        'data.test_episodes': 'data.train_episodes'
    }

    # precedence: explicit test-time override (non-zero) > training-time
    # test value (non-zero) > corresponding training-time value
    for k, v in episode_fields.items():
        if opt[k] != 0:
            data_opt[k] = opt[k]
        elif model_opt[k] != 0:
            data_opt[k] = model_opt[k]
        else:
            data_opt[k] = model_opt[v]

    print(
        "Evaluating {:d}-way, {:d}-shot with {:d} query examples/class over {:d} episodes"
        .format(data_opt['data.test_way'], data_opt['data.test_shot'],
                data_opt['data.test_query'], data_opt['data.test_episodes']))

    # fixed seed for reproducible episode sampling
    torch.manual_seed(1234)
    if data_opt['data.cuda']:
        torch.cuda.manual_seed(1234)

    data = data_utils.load(data_opt, ['test'])

    if data_opt['data.cuda']:
        model.cuda()

    # one running-average meter per logged field
    meters = {
        field: tnt.meter.AverageValueMeter()
        for field in model_opt['log.fields']
    }

    # the two original branches differed only in the stage argument
    stage = 'protonet' if opt['stage'] == 'protonet' else 'feat'
    _, class_acc, class_prec, prec_micro = model_utils.evaluate(
        model,
        data['test'],
        meters,
        stage=stage,
        desc="test",
        evaluation=True)

    # report mean +/- 1.96 * std / sqrt(n): a 95% confidence interval
    for field, meter in meters.items():
        mean, std = meter.value()
        print("test {:s}: {:0.6f} +/- {:0.6f}".format(
            field, mean,
            1.96 * std / math.sqrt(data_opt['data.test_episodes'])))

    # per-class accuracy
    for k in class_acc:
        print('class {} acc: {:0.4f}'.format(k, class_acc[k]))
    # per-class precision plus its macro average
    for k in class_prec:
        print('class {} prec: {:0.4f}'.format(k, class_prec[k]))
    mean_prec = sum(class_prec.values()) / len(class_prec)
    print('Average prec(macro): {:0.4f}; Average prec(micro): {:0.4f}'.format(
        mean_prec, prec_micro))
# Example 6
def main(opt):
    """Train a few-shot model with the Engine, plotting loss/acc curves.

    Saves ``opt.json`` and a per-epoch ``trace.txt`` under
    ``opt['log.exp_dir']``, registers engine hooks for meter resets,
    validation, live loss/accuracy plots, best-model checkpointing and
    patience-based early stopping.
    """
    # create the log directory if it does not exist
    if not os.path.isdir(opt['log.exp_dir']):
        os.makedirs(opt['log.exp_dir'])

    # save opts
    with open(os.path.join(opt['log.exp_dir'], 'opt.json'), 'w') as f:
        json.dump(opt, f)
        f.write('\n')

    trace_file = os.path.join(opt['log.exp_dir'], 'trace.txt')

    # Postprocess arguments
    opt['model.x_dim'] = list(map(int, opt['model.x_dim'].split(',')))
    opt['log.fields'] = opt['log.fields'].split(',')

    # fixed seed for reproducibility
    torch.manual_seed(1234)
    if opt['data.cuda']:
        torch.cuda.manual_seed(1234)

    # ??? what does trainval mean ???
    # (trains on the combined split with no separate validation)
    if opt['data.trainval']:
        data = data_utils.load(opt, ['trainval'])
        train_loader = data['trainval']
        val_loader = None
    else:
        data = data_utils.load(opt, ['train', 'val'])
        train_loader = data['train']
        val_loader = data['val']

    model = model_utils.load(opt)
    #model = torch.load("results/m5_5way5shot/pre.t7")

    if opt['data.cuda']:
        model.cuda()

    engine = Engine()

    # one running-average meter per logged field, per split
    meters = {
        'train':
        {field: tnt.meter.AverageValueMeter()
         for field in opt['log.fields']}
    }

    if val_loader is not None:
        meters['val'] = {
            field: tnt.meter.AverageValueMeter()
            for field in opt['log.fields']
        }

    # on_start hook: clear any stale trace file and attach the scheduler
    def on_start(state):
        if os.path.isfile(trace_file):
            os.remove(trace_file)
        state['scheduler'] = lr_scheduler.StepLR(state['optimizer'],
                                                 opt['train.decay_every'],
                                                 gamma=0.5)

    engine.hooks['on_start'] = on_start

    # at the start of every epoch: reset meters, advance the LR schedule
    def on_start_epoch(state):
        for split, split_meters in meters.items():
            for field, meter in split_meters.items():
                meter.reset()
        state['scheduler'].step()

    engine.hooks['on_start_epoch'] = on_start_epoch

    # feed each batch's outputs into the running-average meters
    def on_update(state):
        for field, meter in meters['train'].items():
            meter.add(state['output'][field])

    engine.hooks['on_update'] = on_update

    # At the end of each epoch: judge training progress and decide whether
    # to stop (patience? why not use the change in loss? in practice the
    # loss seems to fluctuate quite a lot during training)
    title = '%s, %s: %i_%iw_%is' % (opt['model.exp_name'], opt['data.dataset'],
                                    opt['data.way'], opt['data.test_way'],
                                    opt['data.test_shot'])
    lossPic = visual_utils.train_val_loss(title)
    accPic = visual_utils.train_val_acc(title)

    def on_end_epoch(hook_state, state):
        # hook_state persists across epochs (bound below via partial)
        if val_loader is not None:
            if 'best_loss' not in hook_state:
                hook_state['best_loss'] = np.inf
            if 'wait' not in hook_state:
                hook_state['wait'] = 0

        if val_loader is not None:
            model_utils.evaluate(state['model'],
                                 val_loader,
                                 meters['val'],
                                 desc="Epoch {:d} valid".format(
                                     state['epoch']))

        meter_vals = log_utils.extract_meter_values(meters)
        # NOTE(review): these plots read meter_vals['val'], which only
        # exists when val_loader is not None -- looks like trainval mode
        # would fail here; confirm
        lossPic(state['epoch'], meter_vals['train']['loss'],
                meter_vals['val']['loss'])
        accPic(state['epoch'], meter_vals['train']['acc'],
               meter_vals['val']['acc'])
        print("Epoch {:02d}: {:s}".format(
            state['epoch'], log_utils.render_meter_values(meter_vals)))
        meter_vals['epoch'] = state['epoch']
        with open(trace_file, 'a') as f:
            json.dump(meter_vals, f)
            f.write('\n')

        if val_loader is not None:
            # keep only the best model by validation loss; stop after
            # opt['train.patience'] epochs without improvement
            if meter_vals['val']['loss'] < hook_state['best_loss']:
                hook_state['best_loss'] = meter_vals['val']['loss']
                print(
                    "==> best model (loss = {:0.6f}), saving model...".format(
                        hook_state['best_loss']))

                # move to CPU before pickling so the checkpoint loads anywhere
                state['model'].cpu()
                torch.save(state['model'],
                           os.path.join(opt['log.exp_dir'], 'best_model.t7'))
                if opt['data.cuda']:
                    state['model'].cuda()

                hook_state['wait'] = 0
            else:
                hook_state['wait'] += 1

                if hook_state['wait'] > opt['train.patience']:
                    print("==> patience {:d} exceeded".format(
                        opt['train.patience']))
                    state['stop'] = True
        else:
            # no validation split: save unconditionally every epoch
            state['model'].cpu()
            torch.save(state['model'],
                       os.path.join(opt['log.exp_dir'], 'best_model.t7'))
            if opt['data.cuda']:
                state['model'].cuda()

    engine.hooks['on_end_epoch'] = partial(on_end_epoch, {})

    engine.train(model=model,
                 loader=train_loader,
                 optim_method=getattr(optim, opt['train.optim_method']),
                 optim_config={
                     'lr': opt['train.learning_rate'],
                     'weight_decay': opt['train.weight_decay']
                 },
                 max_epoch=opt['train.epochs'])
# Example 7
def main(opt, augment_stn):
    """Evaluate a saved few-shot model (optionally with its STN module) on the test split.

    Args:
        opt: dict of evaluation options; must contain 'model.model_path' and
            'seed', plus optional non-zero 'data.test_*' overrides.
        augment_stn: when falsy, any loaded STN model is discarded and
            evaluation runs without spatial-transformer augmentation.
    """
    # load the serialized model and switch to inference mode
    model = torch.load(opt['model.model_path'])
    model.eval()

    # load the options the model was trained with (saved next to the weights)
    model_opt_file = os.path.join(os.path.dirname(opt['model.model_path']),
                                  'opt.json')
    with open(model_opt_file, 'r') as f:
        model_opt = json.load(f)

    # Postprocess arguments.
    # FIX: wrap map() in list() — in Python 3 a bare map object is a one-shot
    # iterator, so any code that indexes x_dim or iterates it twice would
    # silently see it exhausted. This also matches the list(map(...)) form
    # used elsewhere in this file.
    model_opt['model.x_dim'] = list(
        map(int, model_opt['model.x_dim'].split(',')))
    model_opt['log.fields'] = model_opt['log.fields'].split(',')

    if model_opt['run_stn']:
        print("Loading STN here")
        try:
            stn_model = torch.load(opt['model.model_path'].replace(
                '.pt', '_stn.pt'))
            stn_model.eval()
        except Exception as e:
            # Best-effort fallback (was a bare `except:` that also swallowed
            # KeyboardInterrupt/SystemExit); report why the STN was skipped.
            print("Could not load STN model ({}); continuing without it"
                  .format(e))
            stn_model = None
    else:
        stn_model = None

    # Augment overrides it anyway
    if not augment_stn:
        stn_model = None

    # construct data options from the 'data.*' subset of the training options
    data_opt = {
        'data.' + k: v
        for k, v in filter_opt(model_opt, 'data').items()
    }

    episode_fields = {
        'data.test_way': 'data.way',
        'data.test_shot': 'data.shot',
        'data.test_query': 'data.query',
        'data.test_episodes': 'data.train_episodes'
    }

    # Resolution order per episode field: explicit CLI override (non-zero),
    # then the value saved at training time, then the training-split default.
    for k, v in episode_fields.items():
        if opt[k] != 0:
            data_opt[k] = opt[k]
        elif model_opt[k] != 0:
            data_opt[k] = model_opt[k]
        else:
            data_opt[k] = model_opt[v]

    print(
        "Evaluating {:d}-way, {:d}-shot with {:d} query examples/class over {:d} episodes"
        .format(data_opt['data.test_way'], data_opt['data.test_shot'],
                data_opt['data.test_query'], data_opt['data.test_episodes']))

    # seed both CPU and CUDA RNGs for reproducible episode sampling
    torch.manual_seed(opt['seed'])
    if data_opt['data.cuda']:
        torch.cuda.manual_seed(opt['seed'])

    data = data_utils.load(data_opt, ['test'])

    if data_opt['data.cuda']:
        model.cuda()
        if stn_model is not None:
            stn_model.cuda()

    # one running-average meter per logged field (e.g. loss, acc)
    meters = {
        field: tnt.meter.AverageValueMeter()
        for field in model_opt['log.fields']
    }

    model_utils.evaluate(model, stn_model, data['test'], meters, desc="test")

    # report mean +/- 95% confidence half-width over the test episodes
    for field, meter in meters.items():
        mean, std = meter.value()
        print("test {:s}: {:0.6f} +/- {:0.6f}".format(
            field, mean,
            1.96 * std / math.sqrt(data_opt['data.test_episodes'])))
Esempio n. 8
0
def main(opt):
		
    if not os.path.isdir(opt['log.exp_dir']):
        os.makedirs(opt['log.exp_dir'])

    # save opts
    # 将opts加入文件中
    with open(os.path.join(opt['log.exp_dir'], 'opt.json'), 'w') as f:
        json.dump(opt, f)
        f.write('\n')

    trace_file = os.path.join(opt['log.exp_dir'], 'trace.txt')

    # Postprocess arguments
    opt['model.x_dim'] = list(map(int, opt['model.x_dim'].split(',')))
    opt['log.fields'] = opt['log.fields'].split(',')

    torch.manual_seed(1234)
    if opt['data.cuda']:
        torch.cuda.manual_seed(1234)
    
	  # 加载数据
    if opt['data.trainval']:
        # load Omniglot dataset
        data = data_utils.load(opt, ['trainval'])
        train_loader = data['trainval']
        val_loader = None
    else:
        data = data_utils.load(opt, ['train', 'val'])
        train_loader = data['train']
        val_loader = data['val']
    
    # 使用模型相关参数加载模型
    model = model_utils.load(opt)

    if opt['data.cuda']:
        model.cuda()

    engine = Engine()
    
    # torchnet.meter评估方法性能,这里用的平均值
    # 先建立一个每个指标及其对应评价值的字典
    meters = { 'train': { field: tnt.meter.AverageValueMeter() for field in opt['log.fields'] } }
    
    if val_loader is not None:
        meters['val'] = { field: tnt.meter.AverageValueMeter() for field in opt['log.fields'] }
    
    def on_start(state):
        if os.path.isfile(trace_file):
            os.remove(trace_file)
        # 定义学习率衰减机制
        state['scheduler'] = lr_scheduler.StepLR(state['optimizer'], opt['train.decay_every'], gamma=0.5)
    # 将上面的函数赋给匿名函数engine.hooks
    engine.hooks['on_start'] = on_start
    
    # 每个epoch开始时
    def on_start_epoch(state):
        for split, split_meters in meters.items():
            for field, meter in split_meters.items():
                # 重置评价指标
						meter.reset()

        # 调用optimizer的step()函数进行回传
			state['scheduler'].step()
    
		# 用匿名函数包装
    engine.hooks['on_start_epoch'] = on_start_epoch

    # 更新评价指标
    def on_update(state):
        # 对于所有训练指标,更新每个训练指标的值
        for field, meter in meters['train'].items():
            meter.add(state['output'][field])
    engine.hooks['on_update'] = on_update

    # 在每个epoch结束时,给出这个epoch的评价值
    def on_end_epoch(hook_state, state):
        if val_loader is not None:
            if 'best_loss' not in hook_state:
                hook_state['best_loss'] = np.inf
            if 'wait' not in hook_state:
                hook_state['wait'] = 0

        if val_loader is not None:
				# 使用val评价模型
            model_utils.evaluate(state['model'],
                                 val_loader,
                                 meters['val'],
                                 desc="Epoch {:d} valid".format(state['epoch']))

        meter_vals = log_utils.extract_meter_values(meters)
        print("Epoch {:02d}: {:s}".format(state['epoch'], log_utils.render_meter_values(meter_vals)))
        meter_vals['epoch'] = state['epoch']
        with open(trace_file, 'a') as f:
            json.dump(meter_vals, f)
            f.write('\n')
			
			# 更新目前最好的loss
        if val_loader is not None:
            if meter_vals['val']['loss'] < hook_state['best_loss']:
                hook_state['best_loss'] = meter_vals['val']['loss']
                print("==> best model (loss = {:0.6f}), saving model...".format(hook_state['best_loss']))

                state['model'].cpu()
                torch.save(state['model'], os.path.join(opt['log.exp_dir'], 'best_model.pt'))
                if opt['data.cuda']:
                    state['model'].cuda()

                hook_state['wait'] = 0
           
				# early stop 
				else:
                hook_state['wait'] += 1

                if hook_state['wait'] > opt['train.patience']:
                    print("==> patience {:d} exceeded".format(opt['train.patience']))
                    state['stop'] = True