Example #1
    trainset = DiabeticRetinopathy('train')  # assumed: train split, mirroring the 'val' split below
    train_sampler = PrototypicalBatchSampler(trainset.label, 3, train_way,
                                             shot + query)
    train_loader = DataLoader(dataset=trainset,
                              batch_sampler=train_sampler,
                              num_workers=8)

    valset = DiabeticRetinopathy('val')
    val_sampler = PrototypicalBatchSampler(valset.label, 4, test_way,
                                           shot + query)
    val_loader = DataLoader(dataset=valset,
                            batch_sampler=val_sampler,
                            num_workers=8,
                            pin_memory=True)

    model = Convnet()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=20,
                                                   gamma=0.5)

    def save_model(name):
        torch.save(model.state_dict(), osp.join(save_path, name + '.pth'))

    trlog = {}
    trlog['args'] = vars(args)
    trlog['train_loss'] = []
    trlog['val_loss'] = []
    trlog['train_acc'] = []
    trlog['val_acc'] = []
    trlog['max_acc'] = 0.0
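
The sampler in this snippet yields episodic batches: each of its iterations draws `train_way` classes and `shot + query` indices per class. A minimal sketch of such an episodic sampler (an illustration of the idea under those assumptions, not the repo's exact PrototypicalBatchSampler):

import numpy as np
import torch

class EpisodicSampler:
    """Yields index batches of n_cls classes with n_per samples each."""

    def __init__(self, label, n_batch, n_cls, n_per):
        self.n_batch = n_batch
        self.n_cls = n_cls
        self.n_per = n_per
        label = np.asarray(label)
        # per-class lists of dataset indices
        self.m_ind = [torch.from_numpy(np.argwhere(label == c).reshape(-1))
                      for c in np.unique(label)]

    def __len__(self):
        return self.n_batch

    def __iter__(self):
        for _ in range(self.n_batch):
            classes = torch.randperm(len(self.m_ind))[:self.n_cls].tolist()
            batch = []
            for c in classes:
                ind = self.m_ind[c]
                batch.append(ind[torch.randperm(len(ind))[:self.n_per]])
            # transpose so the first n_cls entries hold one sample per class,
            # matching the reshape(shot, way, -1) used by the training loops
            yield torch.stack(batch).t().reshape(-1)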
Example #2
    noise_dim = 128
    model_gen = Hallucinator(noise_dim).cuda()
    model_gen.load_state_dict(
        torch.load(
            './iterative_G3_trainval_lr150epoch_dataaugumentation_2epoch-175_gen.pth'
        ))

    global_proto = torch.load('./global_proto_all_new.pth')
    global_base = global_proto[:args.n_base_class, :]
    global_novel = global_proto[args.n_base_class:, :]

    global_base = [Variable(global_base.cuda(), requires_grad=True)]
    global_novel = [Variable(global_novel.cuda(), requires_grad=True)]
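    # NOTE: torch.autograd.Variable has been deprecated since PyTorch 0.4.
    # A modern equivalent that keeps the SGD calls below unchanged would be,
    # for example:
    #     global_base = [global_proto[:args.n_base_class].cuda()
    #                    .detach().requires_grad_(True)]
    # i.e. a leaf tensor with requires_grad=True can be handed to the
    # optimizer directly.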

    learning_rate = 0.001
    # model_cnn and model_reg are defined earlier in the full script; only
    # their optimizers appear in this excerpt
    optimizer_cnn = torch.optim.SGD(model_cnn.parameters(),
                                    lr=learning_rate,
                                    momentum=0.9)
    optimizer_atten = torch.optim.SGD(model_reg.parameters(),
                                      lr=learning_rate,
                                      momentum=0.9)
    optimizer_gen = torch.optim.SGD(model_gen.parameters(),
                                    lr=learning_rate,
                                    momentum=0.9)
    optimizer_global1 = torch.optim.SGD(global_base,
                                        lr=learning_rate,
                                        momentum=0.9)
    optimizer_global2 = torch.optim.SGD(global_novel,
                                        lr=learning_rate,
                                        momentum=0.9)
    valset = MiniImageNet('val')
    val_sampler = CategoriesSampler(valset.label, 400, args.test_way,
                                    args.shot + args.query)
    val_loader = DataLoader(dataset=valset,
                            batch_sampler=val_sampler,
                            num_workers=8,
                            pin_memory=True)

    if not args.multi:
        gen = generator(1600, 1).cuda()
    else:
        gen = generator(1600, 1600).cuda()
    model = Convnet().cuda()

    optimizer = torch.optim.Adam(list(model.parameters()) +
                                 list(gen.parameters()),
                                 lr=0.001)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=20,
                                                   gamma=0.5)

    def save_model(name):
        torch.save(model.state_dict(), osp.join(save_path, name + '.pth'))

    def save_gen(name):
        torch.save(gen.state_dict(), osp.join(save_path, name + '.pth'))

    trlog = {}
    trlog['args'] = vars(args)
    trlog['train_loss'] = []
Example #4
import os.path as osp

import torch
import torch.nn.functional as F

# Data, Convnet, ensure_path, Averager, Timer, euclidean_metric and count_acc
# are project-local helpers assumed to be importable within the repo.


def main(args):
    device = torch.device(args.device)
    ensure_path(args.save_path)

    data = Data(args.dataset, args.n_batches, args.train_way, args.test_way, args.shot, args.query)
    train_loader = data.train_loader
    val_loader = data.valid_loader

    model = Convnet(x_dim=2).to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    def save_model(name):
        torch.save(model.state_dict(), osp.join(args.save_path, name + '.pth'))
    
    trlog = dict(
        args=vars(args),
        train_loss=[],
        val_loss=[],
        train_acc=[],
        val_acc=[],
        max_acc=0.0,
    )

    timer = Timer()

    for epoch in range(1, args.max_epoch + 1):
        model.train()

        tl = Averager()
        ta = Averager()

        for i, batch in enumerate(train_loader, 1):
            data, _ = [x.to(device) for x in batch]
            data = data.reshape(-1, 2, 105, 105)
            p = args.shot * args.train_way
            embedded = model(data)
            embedded_shot, embedded_query = embedded[:p], embedded[p:]

            proto = embedded_shot.reshape(args.shot, args.train_way, -1).mean(dim=0)

            # episode-local class ids: [0..way-1] repeated `query` times; this
            # assumes the sampler interleaves classes the same way the
            # reshape(shot, way, -1) above expects
            label = torch.arange(args.train_way).repeat(args.query).to(device)

            logits = euclidean_metric(embedded_query, proto)
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)
            print('epoch {}, train {}/{}, loss={:.4f} acc={:.4f}'
                  .format(epoch, i, len(train_loader), loss.item(), acc))

            tl.add(loss.item())
            ta.add(acc)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # since PyTorch 1.1, the LR scheduler steps after the optimizer, once
        # per epoch
        lr_scheduler.step()

        tl = tl.item()
        ta = ta.item()

        model.eval()

        vl = Averager()
        va = Averager()

        with torch.no_grad():  # no gradients needed for validation
            for i, batch in enumerate(val_loader, 1):
                data, _ = [x.to(device) for x in batch]
                data = data.reshape(-1, 2, 105, 105)
                p = args.shot * args.test_way
                data_shot, data_query = data[:p], data[p:]

                proto = model(data_shot)
                proto = proto.reshape(args.shot, args.test_way, -1).mean(dim=0)

                label = torch.arange(args.test_way).repeat(args.query).to(device)

                logits = euclidean_metric(model(data_query), proto)
                loss = F.cross_entropy(logits, label)
                acc = count_acc(logits, label)

                vl.add(loss.item())
                va.add(acc)

        vl = vl.item()
        va = va.item()
        print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(epoch, vl, va))

        if va > trlog['max_acc']:
            trlog['max_acc'] = va
            save_model('max-acc')

        trlog['train_loss'].append(tl)
        trlog['train_acc'].append(ta)
        trlog['val_loss'].append(vl)
        trlog['val_acc'].append(va)

        torch.save(trlog, osp.join(args.save_path, 'trlog'))

        save_model('epoch-last')

        if epoch % args.save_epoch == 0:
            save_model('epoch-{}'.format(epoch))

        print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / args.max_epoch)))
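
A minimal sketch of the euclidean_metric helper these snippets assume: the negative squared Euclidean distance between each query embedding and each class prototype, so a larger logit means a closer prototype (this mirrors the common prototypical-network implementation; the repo's exact helper may differ):

import torch

def euclidean_metric(query, proto):
    # query: (n_query, d), proto: (n_way, d) -> logits: (n_query, n_way)
    n, m = query.size(0), proto.size(0)
    query = query.unsqueeze(1).expand(n, m, -1)
    proto = proto.unsqueeze(0).expand(n, m, -1)
    return -((query - proto) ** 2).sum(dim=2)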
Example #5
import os.path as osp
from pprint import pprint

import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter  # or: from tensorboardX import SummaryWriter

# MiniImageNet, MiniImageNetSampler, Convnet, protoPipe, ensure_path and
# mean_confidence_interval are project-local helpers assumed to be available.


def main(opts):
    pprint(vars(opts))
    ensure_path(opts.name)
    with open(osp.join(opts.name, 'settings.txt'), 'w') as f:
        dic = vars(opts)
        f.writelines(["{}: {}\n".format(k, dic[k]) for k in dic])

    if opts.seed is not None:
        torch.manual_seed(opts.seed)
        torch.cuda.manual_seed_all(opts.seed)
        np.random.seed(opts.seed)
        print("\nrandom seed: {}".format(opts.seed))
    writer = SummaryWriter('./{}/run'.format(opts.name))

    device = torch.device(opts.device)
    print('using device: {}'.format(device))

    train_set = MiniImageNet(mode='train')
    train_sampler = MiniImageNetSampler(train_set.label, 100, opts.c_sampled,
                                        opts.shot_k + opts.query_k)
    train_loader = DataLoader(dataset=train_set,
                              batch_sampler=train_sampler,
                              num_workers=8,
                              pin_memory=True)
    val_set = MiniImageNet(mode='val')
    val_sampler = MiniImageNetSampler(val_set.label, 400, opts.n_way,
                                      opts.shot_k + opts.query_k)
    val_loader = DataLoader(dataset=val_set,
                            batch_sampler=val_sampler,
                            num_workers=8,
                            pin_memory=True)

    encoder = Convnet().to(device)
    pipeline = protoPipe(encoder, opts.shot_k, opts.query_k)
    pipeline.train()

    optimizer = optim.Adam(encoder.parameters(), opts.lr)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=2000,
                                                   gamma=0.5)
    max_val_acc = 0
    for epoch in range(opts.epoch // 100):

        for episode, batch in enumerate(train_loader, 0):
            task = batch[0].view(opts.c_sampled * (opts.shot_k + opts.query_k),
                                 3, 84, 84)
            loss, acc = pipeline(task.to(device), opts.c_sampled)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # since PyTorch 1.1, the scheduler steps after the optimizer;
            # here it counts episodes (step_size=2000)
            lr_scheduler.step()

            if episode % 10 == 0:
                writer.add_scalar("train_loss", loss.item(),
                                  epoch * 100 + episode)
                print('epoch: {}, episode: {}, loss: {:.4f}, acc: {:.4f}'.format(
                    epoch, episode, loss.item(), acc))

        val_acc = []
        pipeline.eval()
        with torch.no_grad():  # no gradients needed for validation
            for episode, batch in enumerate(val_loader, 0):
                task = batch[0].view(opts.n_way * (opts.shot_k + opts.query_k),
                                     3, 84, 84)
                loss, acc = pipeline(task.to(device), opts.n_way)
                val_acc.append(acc)

        m, h = mean_confidence_interval(val_acc)
        writer.add_scalar("val_acc", m, epoch * 100)
        print('VAL set  acc: {:.4f}, h: {:.4f}'.format(m, h))
        if epoch % (opts.epoch // 20) == 0:
            torch.save(pipeline.encoder.state_dict(),
                       osp.join(opts.name, 'epoch_{}.pth'.format(epoch)))
        if m > max_val_acc:
            max_val_acc = m
            torch.save(pipeline.encoder.state_dict(),
                       osp.join(opts.name, 'max_acc.pth'))
        pipeline.train()
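
A plausible implementation of the mean_confidence_interval helper used above, assuming the common 95% Student-t interval (a scipy-based sketch; the repo's own helper may differ):

import numpy as np
import scipy.stats

def mean_confidence_interval(data, confidence=0.95):
    a = np.asarray(data, dtype=np.float64)
    m, se = a.mean(), scipy.stats.sem(a)
    h = se * scipy.stats.t.ppf((1 + confidence) / 2.0, len(a) - 1)
    return m, h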
Example #6
import numpy as np
import torch
import torch.nn as nn

# Convnet and the constants LR, EPSILON, GAMMA, N_ACTIONS, BATCH_SIZE and
# TARGET_REPLACE_ITER are defined elsewhere in the original script.


class Agent:
    def __init__(self, memory_size=1000, crop_size=192, z_size=32, lr=LR):
        self.dqn, self.target = Convnet().cuda(), Convnet().cuda()
        self.crop_size = crop_size
        self.z_size = z_size
        self.memory_size = memory_size
        self.mb_pool = np.zeros([memory_size, crop_size, crop_size, z_size])
        self.maft_pool = np.zeros([memory_size, crop_size, crop_size, z_size])
        self.reward_pool = np.zeros([memory_size, 1])
        self.action_pool = np.zeros([memory_size, 1])
        self.count = 0
        self.learn_step_counter = 0
        self.optimizer = torch.optim.Adam(self.dqn.parameters(), lr=lr)
        self.loss_func = nn.MSELoss()

    def setFixed(self, fixed):
        self.fixed = fixed

    def chooseAction(self, m):
        # input is a single sample; add a batch dimension
        m = torch.unsqueeze(torch.Tensor(m).cuda().float(), 0)
        if np.random.uniform() < EPSILON:  # greedy with probability EPSILON
            Q_pre = self.dqn(self.fixed, m)
            action = torch.argmax(Q_pre, dim=1).item()
        else:  # explore: random action
            action = np.random.randint(0, N_ACTIONS)
        return action

    def saveState(self, m_bf, a, reward, m_aft):
        # wrap around and overwrite from the start once the buffer is full
        if self.count == self.memory_size:
            self.flushMem()
        self.mb_pool[self.count, ...] = m_bf
        self.reward_pool[self.count, ...] = reward
        self.maft_pool[self.count, ...] = m_aft
        self.action_pool[self.count, ...] = a
        self.count += 1

    def flushMem(self):
        self.count = 0

    def saveCKPT(self, path):
        torch.save(self.dqn.state_dict(), path)

    def learnDQN(self):
        # periodically sync the target network with the online network
        if self.learn_step_counter % TARGET_REPLACE_ITER == 0:
            self.target.load_state_dict(self.dqn.state_dict())
        self.learn_step_counter += 1
        # sample a batch of transitions (assumes the buffer has been filled
        # at least once, otherwise all-zero rows may be drawn)
        sample_index = np.random.choice(self.memory_size, BATCH_SIZE)
        b_memory = self.mb_pool[sample_index]
        aft_memory = self.maft_pool[sample_index]
        r_memory = self.reward_pool[sample_index]
        a_memory = self.action_pool[sample_index]

        b_s = torch.from_numpy(b_memory).float().cuda()
        b_a = torch.from_numpy(a_memory.astype(int)).long().cuda()
        b_r = torch.from_numpy(r_memory).float().cuda()
        b_aft = torch.from_numpy(aft_memory).float().cuda()

        # q_eval w.r.t. the action taken in each stored transition
        # (NOTE: chooseAction calls self.dqn(self.fixed, m), while the
        # single-argument call here is kept as in the original snippet)
        q_eval = self.dqn(b_s).gather(1, b_a)  # shape (batch, 1)
        q_next = self.target(b_aft).detach()  # don't backprop through the target net
        # Bellman target: r + GAMMA * max_a' Q_target(s', a'), shape (batch, 1)
        q_target = b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1)
        loss = self.loss_func(q_eval, q_target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
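
A self-contained illustration of the Bellman target computed in learnDQN above (GAMMA is assumed to be the discount factor; the numbers are made up):

import torch

GAMMA = 0.99
b_r = torch.tensor([[1.0], [0.0]])               # rewards, shape (batch, 1)
q_next = torch.tensor([[0.2, 0.7], [0.5, 0.1]])  # target-net Q-values
q_target = b_r + GAMMA * q_next.max(1)[0].view(-1, 1)
print(q_target)  # tensor([[1.6930], [0.4950]])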