Example #1
def train(args, model, train_loader, criterion, optimizer, lr_schedule, trlog, epoch):
    model.train()
    tl = Averager()
    ta = Averager()
    global global_count, writer
    for i, (data, label) in enumerate(train_loader, 1):
        global_count = global_count + 1
        if torch.cuda.is_available():
            data, label = data.cuda(), label.long().cuda()
        elif not args.mixup:
            label = label.long()

        optimizer.zero_grad()
        logits = model(data)
        loss = criterion(logits, label)
        loss.backward()
        optimizer.step()
        lr_schedule.step()

        writer.add_scalar('data/loss', float(loss), global_count)
        tl.add(loss.item())
        if not args.mixup:
            acc = count_acc(logits, label)
            ta.add(acc)
            writer.add_scalar('data/acc', float(acc), global_count)

        if (i - 1) % 100 == 0 or i == len(train_loader):
            if not args.mixup:
                print(
                    'epoch {}, train {}/{}, lr={:.5f}, loss={:.4f} acc={:.4f}'.
                    format(epoch, i, len(train_loader),
                           optimizer.param_groups[0]['lr'], loss.item(), acc))
            else:
                print('epoch {}, train {}/{}, lr={:.5f}, loss={:.4f}'.format(
                    epoch, i, len(train_loader),
                    optimizer.param_groups[0]['lr'], loss.item()))

    if trlog is not None:
        tl = tl.item()
        trlog['train_loss'].append(tl)
        if not args.mixup:
            ta = ta.item()
            trlog['train_acc'].append(ta)
        else:
            trlog['train_acc'].append(0)
        return model, trlog
    else:
        return model
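
These examples lean on two small helpers, `Averager` and `count_acc`, whose definitions are not shown. A minimal sketch consistent with how they are used here (running means via `add()`/`item()`, mean top-1 accuracy from logits); the actual implementations may differ:

import torch

class Averager:
    """Running mean over scalar values."""
    def __init__(self):
        self.n = 0
        self.v = 0.0

    def add(self, x):
        self.v = (self.v * self.n + x) / (self.n + 1)
        self.n += 1

    def item(self):
        return self.v


def count_acc(logits, label):
    # fraction of samples whose arg-max class matches the label
    pred = torch.argmax(logits, dim=1)
    return (pred == label).float().mean().item()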
Example #2
    def train_tst(self):
        args = self.args
        self.model.train()
        if self.args.fix_BN:
            self.model.encoder.eval()

        # Clear evaluation file
        with open(osp.join(self.args.save_path, 'eval.jl'), 'w') as fp:
            pass

        # start FSL training
        label, label_aux = self.prepare_label()
        for epoch in range(1, args.max_epoch + 1):
            self.train_epoch += 1
            self.model.train()
            if self.args.fix_BN:
                self.model.encoder.eval()

            tl1 = Averager()
            tl2 = Averager()
            ta = Averager()

            start_tm = time.time()
            for batch in self.train_loader:
                self.train_step += 1

                if torch.cuda.is_available():
                    data, gt_label = [_.cuda() for _ in batch]
                else:
                    data, gt_label = batch[0], batch[1]

                data_tm = time.time()
                self.dt.add(data_tm - start_tm)

                # get saved centers
                logits, reg_logits = self.para_model(data)
                if reg_logits is not None:
                    loss = F.cross_entropy(logits, label)
                    total_loss = loss + args.balance * F.cross_entropy(reg_logits, label_aux)
                else:
                    loss = F.cross_entropy(logits, label)
                    total_loss = loss  # no auxiliary term; reuse the loss already computed

                tl2.add(loss.item())  # log a detached scalar, not the graph-bearing tensor
                forward_tm = time.time()
                self.ft.add(forward_tm - data_tm)
                acc = count_acc(logits, label)

                tl1.add(total_loss.item())
                ta.add(acc)

                self.optimizer.zero_grad()
                total_loss.backward()
                backward_tm = time.time()
                self.bt.add(backward_tm - forward_tm)

                self.optimizer.step()
                optimizer_tm = time.time()
                self.ot.add(optimizer_tm - backward_tm)

                # refresh start_tm
                start_tm = time.time()

                if args.debug_fast:
                    print('Debug fast, breaking training after 1 mini-batch')
                    break

            self.lr_scheduler.step()
            self.try_evaluate_tst(epoch)

            print('ETA:{}/{}'.format(
                self.timer.measure(),
                self.timer.measure(self.train_epoch / args.max_epoch))
            )

        torch.save(self.trlog, osp.join(args.save_path, 'trlog'))
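
Example #2 calls `self.prepare_label()` without showing it. Given the episode label construction visible in Example #4 (`torch.arange(num_class).repeat(query)`), a plausible sketch is the one below; the attribute names `args.way`, `args.shot`, and `args.query` are assumptions:

import torch

# Hypothetical reconstruction of prepare_label: labels for an N-way episode
# are just 0..way-1 tiled over the query (and shot+query) instances.
def prepare_label(args):
    label = torch.arange(args.way).repeat(args.query).long()
    label_aux = torch.arange(args.way).repeat(args.shot + args.query).long()
    if torch.cuda.is_available():
        label, label_aux = label.cuda(), label_aux.cuda()
    return label, label_aux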
Example #3
    def train(self):
        args = self.args
        # start GFSL training
        for epoch in range(1, args.max_epoch + 1):
            self.train_epoch += 1
            self.model.train()

            tl1 = Averager()
            tl2 = Averager()
            ta = Averager()

            start_tm = time.time()
            for batch in zip(self.train_fsl_loader, self.train_gfsl_loader):
                self.train_step += 1

                if torch.cuda.is_available():
                    support_data, support_label = batch[0][0].cuda(), batch[0][1].cuda()
                    query_data, query_label = batch[1][0].cuda(), batch[1][1].cuda()
                else:
                    support_data, support_label = batch[0][0], batch[0][1]
                    query_data, query_label = batch[1][0], batch[1][1]

                data_tm = time.time()
                self.dt.add(data_tm - start_tm)

                logits = self.model(support_data, query_data, support_label)
                # each query is scored by num_tasks classifiers, so tile the
                # labels once to match the flattened logits
                tiled_label = query_label.view(-1, 1).repeat(1, args.num_tasks).view(-1)
                loss = F.cross_entropy(logits, tiled_label)
                tl2.add(loss.item())

                forward_tm = time.time()
                self.ft.add(forward_tm - data_tm)
                acc = count_acc(logits, tiled_label)

                tl1.add(loss.item())
                ta.add(acc)

                self.optimizer.zero_grad()
                loss.backward()
                backward_tm = time.time()
                self.bt.add(backward_tm - forward_tm)

                self.optimizer.step()
                self.lr_scheduler.step()
                optimizer_tm = time.time()
                self.ot.add(optimizer_tm - backward_tm)
                self.try_logging(tl1, tl2, ta)

                # refresh start_tm
                start_tm = time.time()
                del logits, loss
                torch.cuda.empty_cache()

                self.try_evaluate(epoch)

            print('ETA:{}/{}'.format(
                self.timer.measure(),
                self.timer.measure(self.train_epoch / args.max_epoch)))

        torch.save(self.trlog, osp.join(args.save_path, 'trlog'))
        self.save_model('epoch-last')
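
The `view(-1, 1).repeat(1, args.num_tasks).view(-1)` idiom above tiles each query label once per task head. A tiny standalone illustration (values are only for demonstration):

import torch

# Each query label is duplicated num_tasks times, so row k * num_tasks + t
# of the flattened logits is scored against query k.
query_label = torch.tensor([0, 1, 2])
num_tasks = 2
tiled = query_label.view(-1, 1).repeat(1, num_tasks).view(-1)
print(tiled)  # tensor([0, 0, 1, 1, 2, 2])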
Example #4
def validate(args, model, val_loader, epoch, trlog=None):
    model.eval()
    global writer
    vl_dist, va_dist = Averager(), Averager()
    vl_sim, va_sim = Averager(), Averager()
    if trlog is not None:
        print('[Dist] best epoch {}, current best val acc={:.4f}'.format(
            trlog['max_acc_dist_epoch'], trlog['max_acc_dist']))
        print('[Sim] best epoch {}, current best val acc={:.4f}'.format(
            trlog['max_acc_sim_epoch'], trlog['max_acc_sim']))
    # test performance with Few-Shot
    label = torch.arange(args.num_val_class).repeat(args.query).long()
    if torch.cuda.is_available():
        label = label.cuda()
    with torch.no_grad():
        for i, batch in tqdm(enumerate(val_loader, 1), total=len(val_loader)):
            if torch.cuda.is_available():
                data, _ = [_.cuda() for _ in batch]
            else:
                data, _ = batch
            # 16-way test: the first num_val_class items are the shots (one per class)
            data_shot = data[:args.num_val_class]
            data_query = data[args.num_val_class:]
            logits_dist, logits_sim = model.forward_proto(
                data_shot, data_query, args.num_val_class)
            loss_dist = F.cross_entropy(logits_dist, label)
            acc_dist = count_acc(logits_dist, label)
            loss_sim = F.cross_entropy(logits_sim, label)
            acc_sim = count_acc(logits_sim, label)
            vl_dist.add(loss_dist.item())
            va_dist.add(acc_dist)
            vl_sim.add(loss_sim.item())
            va_sim.add(acc_sim)

    vl_dist = vl_dist.item()
    va_dist = va_dist.item()
    vl_sim = vl_sim.item()
    va_sim = va_sim.item()

    print(
        'epoch {}, val, loss_dist={:.4f} acc_dist={:.4f} loss_sim={:.4f} acc_sim={:.4f}'
        .format(epoch, vl_dist, va_dist, vl_sim, va_sim))

    if trlog is not None:
        writer.add_scalar('data/val_loss_dist', float(vl_dist), epoch)
        writer.add_scalar('data/val_acc_dist', float(va_dist), epoch)
        writer.add_scalar('data/val_loss_sim', float(vl_sim), epoch)
        writer.add_scalar('data/val_acc_sim', float(va_sim), epoch)

        if va_dist > trlog['max_acc_dist']:
            trlog['max_acc_dist'] = va_dist
            trlog['max_acc_dist_epoch'] = epoch
            save_model('max_acc_dist', model, args)
            save_checkpoint(True)

        if va_sim > trlog['max_acc_sim']:
            trlog['max_acc_sim'] = va_sim
            trlog['max_acc_sim_epoch'] = epoch
            save_model('max_acc_sim', model, args)
            save_checkpoint(True)

        trlog['val_loss_dist'].append(vl_dist)
        trlog['val_acc_dist'].append(va_dist)
        trlog['val_loss_sim'].append(vl_sim)
        trlog['val_acc_sim'].append(va_sim)
        return trlog
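
`save_model` and `save_checkpoint` are defined outside this snippet. Under the assumption that checkpoints live in `args.save_path` and are named after the metric that peaked, a minimal sketch of `save_model` might be:

import os.path as osp
import torch

# Hypothetical sketch: persist the weights under a name recording which
# validation metric peaked (e.g. 'max_acc_dist'). File layout is an assumption.
def save_model(name, model, args):
    torch.save(dict(params=model.state_dict()),
               osp.join(args.save_path, name + '.pth'))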
Example #5
            if torch.cuda.is_available():
                data, label = [_.cuda() for _ in batch]
                label = label.type(torch.cuda.LongTensor)
            else:
                data, label = batch
                label = label.type(torch.LongTensor)
            logits = model(data)
            loss = criterion(logits, label)
            acc = count_acc(logits, label)
            writer.add_scalar('data/loss', float(loss), global_count)
            writer.add_scalar('data/acc', float(acc), global_count)
            if (i - 1) % 100 == 0:
                print('epoch {}, train {}/{}, loss={:.4f} acc={:.4f}'.format(
                    epoch, i, len(train_loader), loss.item(), acc))

            tl.add(loss.item())
            ta.add(acc)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        tl = tl.item()
        ta = ta.item()

        # validate every 5 epochs early on, and every epoch once past epoch 100
        if epoch > 100 or (epoch - 1) % 5 == 0:
            model.eval()
            vl_dist = Averager()
            va_dist = Averager()
            vl_sim = Averager()
Example #6
    def train(self):
        args = self.args
        self.model.train()
        if self.args.fix_BN:
            self.model.encoder.eval()
        
        # start FSL training
        label, label_aux = self.prepare_label()
        for epoch in range(1, args.max_epoch + 1):
            self.train_epoch += 1
            self.model.train()
            if self.args.fix_BN:
                self.model.encoder.eval()
            
            tl1 = Averager()
            tl2 = Averager()
            ta = Averager()

            start_tm = time.time()

            for batch in tqdm(self.train_loader):

                data, gt_label = [_.cuda() for _ in batch]
                data_tm = time.time()
                self.dt.add(data_tm - start_tm)

                # get saved centers
                logits, reg_logits = self.para_model(data)
                logits = logits.view(-1, args.way)
                oh_query = torch.nn.functional.one_hot(label, args.way)

                sims = logits
                # similarity of each query to its true class, measured
                # against the learnable margin
                true_sim = (sims * oh_query).sum(-1)
                e_sim_p = true_sim - self.model.margin
                e_sim_p_pos = F.relu(e_sim_p)   # slack above the margin
                e_sim_p_neg = F.relu(-e_sim_p)  # deficit below the margin

                l_open_margin = e_sim_p_pos.mean(-1)
                l_open = e_sim_p_neg.mean(-1)
                if reg_logits is not None:
                    loss = F.cross_entropy(logits, label)
                    total_loss = loss + args.balance * F.cross_entropy(reg_logits, label_aux)
                else:
                    loss = F.cross_entropy(logits, label)
                    total_loss = loss  # no auxiliary term
                total_loss = total_loss + args.open_balance * l_open
                tl2.add(loss.item())
                forward_tm = time.time()
                self.ft.add(forward_tm - data_tm)
                acc = count_acc(logits, label)

                tl1.add(total_loss.item())
                ta.add(acc)

                # the two losses share one forward graph, hence retain_graph=True
                self.optimizer.zero_grad()
                total_loss.backward(retain_graph=True)

                self.optimizer_margin.zero_grad()
                l_open_margin.backward()  # gradient for the learnable margin

                self.optimizer.step()
                self.optimizer_margin.step()

                backward_tm = time.time()
                self.bt.add(backward_tm - forward_tm)

                optimizer_tm = time.time()
                self.ot.add(optimizer_tm - backward_tm)
                # refresh start_tm
                start_tm = time.time()

            print('lr: {:.4f} Total_loss: {:.4f} ce_loss: {:.4f} l_open: {:.4f} R: {:.4f}\n'.format(
                self.optimizer_margin.param_groups[0]['lr'],
                total_loss.item(), loss.item(), l_open.item(),
                self.model.margin.item()))
            self.lr_scheduler.step()
            self.lr_scheduler_margin.step()

            self.try_evaluate(epoch)

            print('ETA:{}/{}'.format(
                    self.timer.measure(),
                    self.timer.measure(self.train_epoch / args.max_epoch))
            )

        torch.save(self.trlog, osp.join(args.save_path, 'trlog'))
        self.save_model('epoch-last')
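
The margin terms in Example #6 split the signed violation `sims - margin` into a non-overlapping slack (`relu`) and deficit (`relu` of the negation). A small numeric check, with illustrative values:

import torch
import torch.nn.functional as F

sims = torch.tensor([0.9, 0.4, 0.6])   # similarity to the true class
margin = torch.tensor(0.5)
e = sims - margin                      # signed margin violations
print(F.relu(e))                       # tensor([0.4000, 0.0000, 0.1000])
print(F.relu(-e))                      # tensor([0.0000, 0.1000, 0.0000])
# l_open penalizes queries whose true-class similarity falls below the
# margin; l_open_margin (minimized by its own optimizer over the margin
# alone) pulls the margin up toward the positive similarities.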