예제 #1
0
        # Fresh-run bookkeeping (see the message printed below): the max
        # accuracy, the epoch it was recorded at, and the last trained epoch
        # all start from zero.
        trlog['max_acc'] = 0.0
        trlog['max_acc_epoch'] = 0
        trlog['cur_train_epoch'] = 0
        print("start new training")

    # Continue one epoch after the last epoch recorded in the training log.
    start_epoch = trlog['cur_train_epoch'] + 1
    print(f"start training from epoch: {start_epoch}")

    timer = Timer()
    # Running count of training batches across all epochs (incremented once
    # per batch in the inner loop).
    global_count = 0
    # NOTE(review): the `logdir` keyword matches tensorboardX's SummaryWriter;
    # torch.utils.tensorboard spells it `log_dir` -- confirm which is imported.
    writer = SummaryWriter(logdir=args.save_path)

    for epoch in range(start_epoch, args.max_epoch + 1):
        # NOTE(review): the scheduler is stepped at the top of every epoch.
        # PyTorch >= 1.1 expects lr_scheduler.step() to follow
        # optimizer.step(); confirm the order for the PyTorch version in use.
        lr_scheduler.step()
        model.train()
        # Fresh accumulators for this epoch; presumably training loss and
        # training accuracy -- TODO confirm against their use further down.
        tl = Averager()
        ta = Averager()

        # Labels 0..way-1 tiled `query` times, e.g. way=3, query=2 gives
        # [0, 1, 2, 0, 1, 2].
        label = torch.arange(args.way).repeat(args.query)
        # Cast to int64, on the GPU when one is available, else on the CPU.
        if torch.cuda.is_available():
            label = label.type(torch.cuda.LongTensor)
        else:
            label = label.type(torch.LongTensor)

        # Batch index `i` counts from 1.
        for i, batch in enumerate(train_loader, 1):
            global_count = global_count + 1
            if torch.cuda.is_available():
                # Moves every element of the batch to the GPU and keeps only
                # the first; this unpacking requires exactly two elements.
                data, _ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
            # shot * way; presumably the number of support samples at the
            # front of `data` -- TODO confirm against the split that follows.
            p = args.shot * args.way
예제 #2
0
        # Turn on cuDNN benchmark mode and move the model to the GPU.
        torch.backends.cudnn.benchmark = True
        model = model.cuda()

    # Build the test split and a loader driven by a category-based batch
    # sampler. The sampler receives the test labels, the literal 10000, `way`,
    # and `shot + query`; presumably 10000 is the number of sampled batches
    # -- TODO confirm against CategoriesSampler.
    test_set = Dataset('test', args)
    sampler = CategoriesSampler(test_set.label, 10000, args.way,
                                args.shot + args.query)
    loader = DataLoader(test_set,
                        batch_sampler=sampler,
                        num_workers=8,
                        pin_memory=True)
    # One slot per sampled batch (same 10000 as passed to the sampler).
    test_acc_record = np.zeros((10000, ))

    # Deserialize onto the CPU so that a checkpoint saved from a GPU run can
    # still be read on a CPU-only host (this script has explicit CPU
    # fallbacks). load_state_dict then copies the values into the model's
    # parameters on whichever device the model already lives on.
    model.load_state_dict(
        torch.load(args.model_path, map_location='cpu')['params'])
    model.eval()

    ave_acc = Averager()
    # Labels 0..way-1 tiled `query` times, e.g. way=3, query=2 gives
    # [0, 1, 2, 0, 1, 2].
    label = torch.arange(args.way).repeat(args.query)
    # Cast to int64, on the GPU when one is available, else on the CPU.
    if torch.cuda.is_available():
        label = label.type(torch.cuda.LongTensor)
    else:
        label = label.type(torch.LongTensor)

    # NOTE(review): no torch.no_grad() is visible around this loop, so autograd
    # state is tracked during evaluation unless it is disabled by the caller
    # -- confirm, and consider wrapping the loop.
    for i, batch in enumerate(loader, 1):  # batch index counts from 1
        if torch.cuda.is_available():
            # Moves every element of the batch to the GPU and keeps only the
            # first; this unpacking requires exactly two elements.
            data, _ = [_.cuda() for _ in batch]
        else:
            data = batch[0]
        # The first way * shot samples go to `data_shot`, the rest to
        # `data_query`.
        k = args.way * args.shot
        data_shot, data_query = data[:k], data[k:]
        # The model returns two values; only the first (logits) is used.
        logits, _ = model(data_shot, data_query)
        acc = count_acc(logits, label)
예제 #3
0
    # Training log: the attribute dict of `args`, empty histories for the
    # train/val loss and accuracy entries, and zeroed max-accuracy bookkeeping.
    trlog['args'] = vars(args)
    trlog['train_loss'] = []
    trlog['val_loss'] = []
    trlog['train_acc'] = []
    trlog['val_acc'] = []
    trlog['max_acc'] = 0.0
    trlog['max_acc_epoch'] = 0

    timer = Timer()
    # Running count of training batches across all epochs (incremented once
    # per batch in the inner loop).
    global_count = 0
    # NOTE(review): the `logdir` keyword matches tensorboardX's SummaryWriter;
    # torch.utils.tensorboard spells it `log_dir` -- confirm which is imported.
    writer = SummaryWriter(logdir=args.save_path)

    for epoch in range(1, args.max_epoch + 1):
        # NOTE(review): the scheduler is stepped at the top of every epoch.
        # PyTorch >= 1.1 expects lr_scheduler.step() to follow
        # optimizer.step(); confirm the order for the PyTorch version in use.
        lr_scheduler.step()
        model.train()
        # Fresh accumulators for this epoch; presumably training loss and
        # training accuracy -- TODO confirm against their use further down.
        tl = Averager()
        ta = Averager()

        # Labels 0..way-1 tiled `query` times, e.g. way=3, query=2 gives
        # [0, 1, 2, 0, 1, 2].
        label = torch.arange(args.way).repeat(args.query)
        # Cast to int64, on the GPU when one is available, else on the CPU.
        if torch.cuda.is_available():
            label = label.type(torch.cuda.LongTensor)
        else:
            label = label.type(torch.LongTensor)

        # Batch index `i` counts from 1.
        for i, batch in enumerate(train_loader, 1):
            # NOTE(review): this prints on every training batch and looks like
            # leftover debugging output -- confirm it is still wanted.
            print("\n\nbatch: ", len(batch))
            global_count = global_count + 1
            if torch.cuda.is_available():
                # Moves every element of the batch to the GPU and keeps only
                # the first; this unpacking requires exactly two elements.
                data, _ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
예제 #4
0
    # Query labels 0..way-1 tiled `query` times. They are created as int64
    # directly: the previous int8 range could not represent class indices
    # above 127 and was cast to int64 right afterwards anyway.
    label = torch.arange(args.way, dtype=torch.long).repeat(args.query)
    # label.shape[0] == way * query
    att_label = torch.zeros(label.shape[0], args.way + 1, args.way + 1)
    for i in range(att_label.shape[0]):
        # Each entry is a (way + 1) x (way + 1) matrix looked up in
        # `att_label_basis` by the sample's class index.
        att_label[i, :] = att_label_basis[label[i].item()]

    if torch.cuda.is_available():
        label = label.cuda()
        att_label = att_label.cuda()

    for epoch in range(1, args.max_epoch + 1):
        # NOTE(review): the scheduler is stepped at the top of every epoch.
        # PyTorch >= 1.1 expects lr_scheduler.step() to follow
        # optimizer.step(); confirm the order for the PyTorch version in use.
        lr_scheduler.step()
        model.train()
        # Fresh accumulators for this epoch; presumably training loss and
        # training accuracy -- TODO confirm against their use further down.
        tl = Averager()
        ta = Averager()

        for i, batch in enumerate(train_loader, 1):  # index starts from 1
            global_count = global_count + 1
            # Open question from the original author: is batch[1] the label?
            if torch.cuda.is_available():
                # Moves every element of the batch to the GPU and keeps only
                # the first; this unpacking requires exactly two elements.
                data, _ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
            # The first shot * way samples go to `data_shot`, the rest to
            # `data_query`.
            p = args.shot * args.way
            data_shot, data_query = data[:p], data[p:]
            # The model returns the logits plus a second output `att`.
            logits, att = model(data_shot, data_query)
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)
            # Log the scalar loss against the global batch counter.
            writer.add_scalar('data/loss', float(loss), global_count)
예제 #5
0
    else:
        # Fallback branch of a conditional whose earlier branches are not
        # shown here: reject the request with a ValueError.
        raise ValueError('Non-supported Dataset.')

    model = ProtoNet(args)
    if torch.cuda.is_available():
        # Turn on cuDNN benchmark mode and move the model to the GPU.
        torch.backends.cudnn.benchmark = True
        model = model.cuda()
    # Build the test split and a loader driven by a category-based batch
    # sampler. The sampler receives the test labels, the literal 10000, `way`,
    # and `shot + query`; presumably 10000 is the number of sampled batches
    # -- TODO confirm against CategoriesSampler.
    test_set = Dataset('test', args)
    sampler = CategoriesSampler(test_set.label, 10000, args.way, args.shot + args.query)
    loader = DataLoader(test_set, batch_sampler=sampler, num_workers=8, pin_memory=True)
    # One slot per sampled batch (same 10000 as passed to the sampler).
    test_acc_record = np.zeros((10000,))

    # Deserialize onto the CPU so that a checkpoint saved from a GPU run can
    # still be read on a CPU-only host (this script has explicit CPU
    # fallbacks). load_state_dict then copies the values into the model's
    # parameters on whichever device the model already lives on.
    model.load_state_dict(
        torch.load(args.model_path, map_location='cpu')['params'])
    model.eval()

    ave_acc = Averager()
    # Labels 0..way-1 tiled `query` times, e.g. way=3, query=2 gives
    # [0, 1, 2, 0, 1, 2].
    label = torch.arange(args.way).repeat(args.query)
    # Cast to int64, on the GPU when one is available, else on the CPU.
    if torch.cuda.is_available():
        label = label.type(torch.cuda.LongTensor)
    else:
        label = label.type(torch.LongTensor)

    # Evaluation only: gradient tracking is disabled for the whole loop.
    with torch.no_grad():
        for i, batch in enumerate(loader, 1):  # batch index counts from 1
            if torch.cuda.is_available():
                # Moves every element of the batch to the GPU and keeps only
                # the first; this unpacking requires exactly two elements.
                data, _ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
            # The first way * shot samples go to `data_shot`, the rest to
            # `data_query`.
            k = args.way * args.shot
            data_shot, data_query = data[:k], data[k:]