Exemplo n.º 1
0
def _episode_to_inputs(support, s_labels, query, q_labels, num_way, num_shot,
                       num_query):
    """Convert one sampled episode into the input list passed to the model.

    Args:
        support: numpy array of support samples; its two leading axes are
            merged into one before conversion to a tensor.
        s_labels: numpy array of support labels (flattened here).
        query: numpy array of query samples, handled like ``support``.
        q_labels: numpy array of query labels (flattened here).
        num_way: number of classes in the episode (width of the one-hot rows).
        num_shot: support samples per class.
        num_query: query samples per class.

    Returns:
        ``[support, s_onehot, query, q_onehot]``, each moved to CUDA device 0.
    """
    # Merge the two leading axes, then move the last axis to position 1
    # (presumably channel-last images -> channel-first; confirm with loader).
    support = np.reshape(
        support, (support.shape[0] * support.shape[1], ) + support.shape[2:])
    support = torch.from_numpy(np.transpose(support, (0, 3, 1, 2)))
    query = np.reshape(
        query, (query.shape[0] * query.shape[1], ) + query.shape[2:])
    query = torch.from_numpy(np.transpose(query, (0, 3, 1, 2)))

    # scatter_ needs int64 indices, hence the LongTensor cast.
    s_labels = torch.from_numpy(np.reshape(s_labels, (-1, )))
    q_labels = torch.from_numpy(np.reshape(q_labels, (-1, )))
    s_labels = s_labels.type(torch.LongTensor)
    q_labels = q_labels.type(torch.LongTensor)

    # One-hot encode the labels: one row per sample, one column per class.
    s_onehot = torch.zeros(num_way * num_shot,
                           num_way).scatter_(1, s_labels.view(-1, 1), 1)
    q_onehot = torch.zeros(num_way * num_query,
                           num_way).scatter_(1, q_labels.view(-1, 1), 1)

    return [support.cuda(0), s_onehot.cuda(0), query.cuda(0), q_onehot.cuda(0)]


def main():
    """Train the label-propagation model, validating after every epoch.

    All configuration (``dataset``, ``n_way``, ``lr``, ``args``, ...) is read
    from names defined outside this function. A checkpoint and a log line are
    written whenever the mean validation accuracy or loss improves, and
    additionally on non-improving epochs where ``ep % 100 == 0``.

    Raises:
        ValueError: if ``dataset`` is neither ``'mini'`` nor ``'tiered'``.
    """
    # Step 1: init dataloader
    print("init data loader")
    args_data = {}
    args_data['x_dim'] = x_dim
    args_data['ratio'] = ratio
    args_data['seed'] = seed
    if dataset == 'mini':
        loader_train = dataset_mini(n_examples, n_episodes, 'train', args_data)
        loader_val = dataset_mini(n_examples, n_episodes, 'val', args_data)
    elif dataset == 'tiered':
        loader_train = dataset_tiered(n_examples, n_episodes, 'train',
                                      args_data)
        loader_val = dataset_tiered(n_examples, n_episodes, 'val', args_data)
    else:
        # Fail here with a clear message instead of an unbound-name error
        # at the first use of the loaders below.
        raise ValueError('unknown dataset: %r' % (dataset, ))

    if not pkl:
        loader_train.load_data()
        loader_val.load_data()
    else:
        loader_train.load_data_pkl()
        loader_val.load_data_pkl()

    # Step 2: init neural networks
    print("init neural networks")

    # construct the model
    model = models.LabelPropagation(args)
    model.cuda(0)

    # optimizer
    model_optim = torch.optim.Adam(model.parameters(), lr=lr)
    model_scheduler = StepLR(model_optim, step_size=step_size, gamma=gamma)

    # load the saved model
    # NOTE(review): `iters` names a checkpoint here (checkpoints are saved
    # under (ep + 1) * n_episodes) but is also used as the first epoch index
    # in the training loop below -- confirm which unit is intended.
    if iters > 0:
        model.load_state_dict(
            torch.load('checkpoints/%s/models/%s_%d_model.t7' %
                       (args['exp_name'], alg, iters)))
        print('Loading Parameters from %s: %d' % (args['exp_name'], iters))

    # Step 3: Train and validation
    print("Training...")

    best_acc = 0.0
    best_loss = np.inf
    wait = 0
    for ep in range(iters, n_epochs):
        loss_tr = []
        acc_tr = []
        loss_val = []
        acc_val = []

        # ---- training episodes ----
        model.train()
        for epi in tqdm(range(n_episodes), desc='train_epoc:{}'.format(ep)):

            # The scheduler is driven by the global episode index.
            model_scheduler.step(ep * n_episodes + epi)

            # sample data for next batch (the fifth returned value is unused)
            support, s_labels, query, q_labels, _ = loader_train.next_data(
                n_way, n_shot, n_query)
            inputs = _episode_to_inputs(support, s_labels, query, q_labels,
                                        n_way, n_shot, n_query)

            loss, acc = model(inputs)
            loss_tr.append(loss.item())
            acc_tr.append(acc.item())

            model.zero_grad()
            loss.backward()
            model_optim.step()

        # ---- validation episodes ----
        model.eval()
        for epi in tqdm(range(n_episodes), desc='val epoc:{}'.format(ep)):
            # sample data for next batch (the fifth returned value is unused)
            support, s_labels, query, q_labels, _ = loader_val.next_data(
                n_test_way, n_test_shot, n_test_query)
            inputs = _episode_to_inputs(support, s_labels, query, q_labels,
                                        n_test_way, n_test_shot, n_test_query)

            with torch.no_grad():
                loss, acc = model(inputs)

            loss_val.append(loss.item())
            acc_val.append(acc.item())

        # Compute the epoch means once; they are reused for printing,
        # model selection and logging.
        mean_loss_val = np.mean(loss_val)
        mean_acc_val = np.mean(acc_val)

        print(
            'epoch:{}, loss_tr:{:.5f}, acc_tr:{:.5f}, loss_val:{:.5f}, acc_val:{:.5f}'
            .format(ep, np.mean(loss_tr), np.mean(acc_tr), mean_loss_val,
                    mean_acc_val))

        # Model Save and Stop Criterion
        # An epoch counts as improved when EITHER metric beats its best value;
        # both "best" values are then overwritten with this epoch's means.
        improved = (mean_acc_val > best_acc) or (mean_loss_val < best_loss)

        if improved:
            best_acc = mean_acc_val
            best_loss = mean_loss_val
            print('best val loss:{:.5f}, acc:{:.5f}'.format(
                best_loss, best_acc))
            wait = 0
        else:
            wait += 1

        # Checkpoint on improvement, and periodically on non-improving epochs.
        if improved or ep % 100 == 0:
            step = (ep + 1) * n_episodes

            # save model
            torch.save(
                model.state_dict(), 'checkpoints/%s/models/%s_%d_model.t7' %
                (args['exp_name'], alg, step))

            # `with` closes the log file even if the write raises.
            with open('checkpoints/' + args['exp_name'] + '/log.txt',
                      'a') as f:
                print('{} {:.5f} {:.5f}'.format(step, mean_loss_val,
                                                mean_acc_val),
                      file=f)

        # NOTE(review): `ep` comes from range(iters, n_epochs), so
        # `ep > n_epochs` is never true and this early stop cannot fire.
        # Left unchanged so training length is not altered; confirm the
        # intended condition before enabling it.
        if wait > patience and ep > n_epochs:
            break
Exemplo n.º 2
0
def main():
    """Evaluate a saved label-propagation model on the test split.

    All configuration (``dataset``, ``iters``, ``repeat``, ``args``, ...) is
    read from names defined outside this function. For each of ``repeat``
    runs, ``n_test_episodes`` episodes are sampled and the mean accuracy,
    standard deviation and 95% confidence half-width are printed, followed by
    the best run and the average over all runs.

    Raises:
        ValueError: if ``dataset`` is neither ``'mini'`` nor ``'tiered'``.
    """
    # init dataloader
    print("init data loader")
    args_data = {}
    args_data['x_dim'] = '84,84,3'
    args_data['ratio'] = 1.0
    args_data['seed'] = seed
    print('seed:',seed)
    if dataset == 'mini':
        loader_test = dataset_mini(n_examples, n_episodes, 'test', args_data)
    elif dataset == 'tiered':
        loader_test = dataset_tiered(n_examples, n_episodes, 'test', args_data)
    else:
        # Fail here with a clear message instead of an unbound-name error
        # at the first use of the loader below.
        raise ValueError('unknown dataset: %r' % (dataset, ))

    if not pkl:
        loader_test.load_data()
    else:
        loader_test.load_data_pkl()

    # Step 2: init neural networks
    print("init neural networks")

    # construct the model
    model = models.LabelPropagation(args)
    model.cuda(0)

    # load the saved model: a specific checkpoint when `iters` is positive,
    # otherwise the "best" checkpoint for `part`. The file is read once.
    if iters > 0:
        ckpt_path = 'checkpoints/%s/models/%s_%d_model.t7' % (args['exp_name'], alg, iters)
    else:
        ckpt_path = 'checkpoints/%s/models/%s_model_best_%s.t7' %(args['exp_name'], alg, part)
    model.load_state_dict(torch.load(ckpt_path))
    print('Loading Parameters from %s' %(args['exp_name']))

    # Step 3: run the test episodes
    print("Testing...")

    all_acc = []
    all_std = []
    all_ci95 = []

    # Evaluation mode is set once; nothing below switches it back.
    model.eval()

    for rep in range(repeat):
        list_acc = []

        for epi in tqdm(range(n_test_episodes), desc='test:{}'.format(rep)):

            # sample data for next batch (the fifth returned value is unused)
            support, s_labels, query, q_labels, _ = loader_test.next_data(
                n_test_way, n_test_shot, n_test_query, train=False)

            # Merge the two leading axes, then move the last axis to
            # position 1 (presumably channel-last -> channel-first; confirm
            # with the loader).
            support = np.reshape(support, (support.shape[0]*support.shape[1],)+support.shape[2:])
            support = torch.from_numpy(np.transpose(support, (0,3,1,2)))
            query = np.reshape(query, (query.shape[0]*query.shape[1],)+query.shape[2:])
            query = torch.from_numpy(np.transpose(query, (0,3,1,2)))

            # scatter_ needs int64 indices, hence the LongTensor cast.
            s_labels = torch.from_numpy(np.reshape(s_labels,(-1,)))
            q_labels = torch.from_numpy(np.reshape(q_labels,(-1,)))
            s_labels = s_labels.type(torch.LongTensor)
            q_labels = q_labels.type(torch.LongTensor)

            # One-hot encode the labels: one row per sample, one column per class.
            s_onehot = torch.zeros(n_test_way*n_test_shot, n_test_way).scatter_(1, s_labels.view(-1,1), 1)
            q_onehot = torch.zeros(n_test_way*n_test_query, n_test_way).scatter_(1, q_labels.view(-1,1), 1)

            with torch.no_grad():
                inputs = [support.cuda(0), s_onehot.cuda(0), query.cuda(0), q_onehot.cuda(0)]
                loss, acc = model(inputs)

            list_acc.append(acc.item())

        mean_acc = np.mean(list_acc)
        std_acc = np.std(list_acc)
        # Normal-approximation 95% half-width over the episodes of this run.
        ci95 = 1.96*std_acc/np.sqrt(n_test_episodes)
        # Only the interval from the external helper is reported; its mean
        # is discarded.
        _, ci = mean_confidence_interval(list_acc)

        print('label, acc:{:.4f},std:{:.4f},ci95:{:.4f},ci:{:.4f}'.format(mean_acc, std_acc, ci95, ci))
        all_acc.append(mean_acc)
        all_std.append(std_acc)
        all_ci95.append(ci95)

    # Report the run with the highest mean accuracy, then the averages.
    ind = np.argmax(all_acc)
    print('Max acc:{:.5f}, std:{:.5f}, ci95: {:.5f}'.format(all_acc[ind], all_std[ind], all_ci95[ind]))
    print('Avg over {} runs: mean:{:.5f}, std:{:.5f}, ci95: {:.5f}'.format(repeat,np.mean(all_acc),np.mean(all_std),np.mean(all_ci95)))