Example #1
def train(args, model, train_loader, criterion, optimizer, lr_schedule, trlog, epoch):
    model.train()
    tl = Averager()
    ta = Averager()
    global global_count, writer
    for i, (data, label) in enumerate(train_loader, 1):
        global_count = global_count + 1
        if torch.cuda.is_available():
            data, label = data.cuda(), label.long().cuda()
        elif not args.mixup:
            label = label.long()

        optimizer.zero_grad()
        logits = model(data)
        loss = criterion(logits, label)
        loss.backward()
        optimizer.step()
        lr_schedule.step()

        writer.add_scalar('data/loss', float(loss), global_count)
        tl.add(loss.item())
        if not args.mixup:
            acc = count_acc(logits, label)
            ta.add(acc)
            writer.add_scalar('data/acc', float(acc), global_count)

        if (i - 1) % 100 == 0 or i == len(train_loader):
            if not args.mixup:
                print(
                    'epoch {}, train {}/{}, lr={:.5f}, loss={:.4f} acc={:.4f}'.
                    format(epoch, i, len(train_loader),
                           optimizer.param_groups[0]['lr'], loss.item(), acc))
            else:
                print('epoch {}, train {}/{}, lr={:.5f}, loss={:.4f}'.format(
                    epoch, i, len(train_loader),
                    optimizer.param_groups[0]['lr'], loss.item()))

    if trlog is not None:
        tl = tl.item()
        trlog['train_loss'].append(tl)
        if not args.mixup:
            ta = ta.item()
            trlog['train_acc'].append(ta)
        else:
            trlog['train_acc'].append(0)
        return model, trlog
    else:
        return model
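
These examples all lean on a small Averager running-mean utility and a count_acc accuracy helper from the surrounding repo. A minimal sketch of the interface they assume (illustrative, not the repo's actual implementation):

import torch

class Averager:
    """Running mean: add() accumulates one value, item() returns the mean."""
    def __init__(self):
        self.n = 0
        self.v = 0.0

    def add(self, x):
        # incremental mean update
        self.v = (self.v * self.n + x) / (self.n + 1)
        self.n += 1

    def item(self):
        return self.v


def count_acc(logits, label):
    # fraction of rows whose argmax matches the target label
    pred = torch.argmax(logits, dim=1)
    return (pred == label).float().mean().item()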
Example #2
    def __init__(self, args):
        if args.dataset == 'CUB':
            self.VAL_SETTING = [(5, 1), (5, 5), (5, 20)]
        else:
            self.VAL_SETTING = [(5, 1), (5, 5), (5, 20), (5, 50)]
        if args.eval_dataset == 'CUB':
            self.TEST_SETTINGS = [(5, 1), (5, 5), (5, 20)]
        else:
            self.TEST_SETTINGS = [(5, 1), (5, 5), (5, 20), (5, 50)]
        self.args = args
        # ensure_path(
        #     self.args.save_path,
        #     scripts_to_save=['model/models', 'model/networks', __file__],
        # )
        self.logger = Logger(args, osp.join(args.save_path))

        self.train_step = 0
        self.train_epoch = 0
        self.max_steps = args.episodes_per_epoch * args.max_epoch
        self.dt, self.ft = Averager(), Averager()
        self.bt, self.ot = Averager(), Averager()
        self.timer = Timer()

        # train statistics
        self.trlog = {}
        self.trlog['max_acc'] = 0.0
        self.trlog['max_acc_epoch'] = 0
        self.trlog['max_acc_interval'] = 0.0
Example #3
    def __init__(self, args):
        self.args = args
        # ensure_path(
        #     self.args.save_path,
        #     scripts_to_save=['model/models', 'model/networks', __file__],
        # )
        self.logger = Logger(args, osp.join(args.save_path))

        self.train_step = 0
        self.train_epoch = 0
        self.max_steps = args.episodes_per_epoch * args.max_epoch
        self.dt, self.ft = Averager(), Averager()
        self.bt, self.ot = Averager(), Averager()
        self.timer = Timer()

        # train statistics
        self.trlog = {}
        self.trlog['max_acc'] = 0.0
        self.trlog['max_acc_epoch'] = 0
        self.trlog['max_acc_interval'] = 0.0

        # For tst
        if args.tst_free:
            self.trlog['max_tst_criterion'] = 0.0
            self.trlog['max_tst_criterion_interval'] = 0.
            self.trlog['max_tst_criterion_epoch'] = 0
            self.trlog['tst_criterion'] = args.tst_criterion
Example #4
    def __init__(self, args):
        self.args = args
        self.logger = Logger(args, osp.join(args.save_path))

        self.train_step = 0
        self.train_epoch = 0
        self.max_steps = args.episodes_per_epoch * args.max_epoch
        self.dt, self.ft = Averager(), Averager()
        self.bt, self.ot = Averager(), Averager()
        self.timer = Timer()

        # train statistics
        self.trlog = {}
        self.trlog['max_auc'] = 0.0
        self.trlog['max_auc_epoch'] = 0
        self.trlog['max_auc_interval'] = 0.0
Example #5
    def __init__(self, args):
        self.args = args
        ensure_path(
            self.args.save_path,
            scripts_to_save=['model/models', 'model/networks', __file__],
        )
        self.logger = Logger(args, osp.join(args.save_path))

        self.train_step = 0
        self.train_epoch = 0
        self.dt, self.ft = Averager(), Averager()
        self.bt, self.ot = Averager(), Averager()
        self.timer = Timer()

        # train statistics
        self.trlog = {}
        self.trlog['max_acc'] = 0.0
        self.trlog['max_acc_epoch'] = 0
        self.trlog['max_acc_interval'] = 0.0
Example #6
    def train_tst(self):
        args = self.args
        self.model.train()
        if self.args.fix_BN:
            self.model.encoder.eval()

        # Clear evaluation file
        with open(osp.join(self.args.save_path, 'eval.jl'), 'w') as fp:
            pass

        # start FSL training
        label, label_aux = self.prepare_label()
        for epoch in range(1, args.max_epoch + 1):
            self.train_epoch += 1
            self.model.train()
            if self.args.fix_BN:
                self.model.encoder.eval()

            tl1 = Averager()
            tl2 = Averager()
            ta = Averager()

            start_tm = time.time()
            for batch in self.train_loader:
                self.train_step += 1

                if torch.cuda.is_available():
                    data, gt_label = [_.cuda() for _ in batch]
                else:
                    data, gt_label = batch[0], batch[1]

                data_tm = time.time()
                self.dt.add(data_tm - start_tm)

                # get saved centers
                logits, reg_logits = self.para_model(data)
                if reg_logits is not None:
                    loss = F.cross_entropy(logits, label)
                    total_loss = loss + args.balance * F.cross_entropy(reg_logits, label_aux)
                else:
                    loss = F.cross_entropy(logits, label)
                    total_loss = loss

                tl2.add(loss.item())
                forward_tm = time.time()
                self.ft.add(forward_tm - data_tm)
                acc = count_acc(logits, label)

                tl1.add(total_loss.item())
                ta.add(acc)

                self.optimizer.zero_grad()
                total_loss.backward()
                backward_tm = time.time()
                self.bt.add(backward_tm - forward_tm)

                self.optimizer.step()
                optimizer_tm = time.time()
                self.ot.add(optimizer_tm - backward_tm)

                # refresh start_tm
                start_tm = time.time()

                if args.debug_fast:
                    print('Debug fast, breaking training after 1 mini-batch')
                    break

            self.lr_scheduler.step()
            self.try_evaluate_tst(epoch)

            print('ETA:{}/{}'.format(
                self.timer.measure(),
                self.timer.measure(self.train_epoch / args.max_epoch))
            )

        torch.save(self.trlog, osp.join(args.save_path, 'trlog'))
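
The ETA line above calls timer.measure() twice: once for the elapsed time, and once with the fraction of training completed, which scales the elapsed time into a projected total. A sketch of the Timer interface this assumes (hypothetical, for illustration):

import time

class Timer:
    def __init__(self):
        self.start = time.time()

    def measure(self, p=1):
        # elapsed time so far, divided by the completed fraction p,
        # formatted as a human-readable duration
        x = (time.time() - self.start) / p
        if x >= 3600:
            return '{:.1f}h'.format(x / 3600)
        if x >= 60:
            return '{}m'.format(round(x / 60))
        return '{}s'.format(round(x))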
Example #7
class Trainer(object, metaclass=abc.ABCMeta):
    def __init__(self, args):
        self.args = args
        # ensure_path(
        #     self.args.save_path,
        #     scripts_to_save=['model/models', 'model/networks', __file__],
        # )
        self.logger = Logger(args, osp.join(args.save_path))

        self.train_step = 0
        self.train_epoch = 0
        self.max_steps = args.episodes_per_epoch * args.max_epoch
        self.dt, self.ft = Averager(), Averager()
        self.bt, self.ot = Averager(), Averager()
        self.timer = Timer()

        # train statistics
        self.trlog = {}
        self.trlog['max_acc'] = 0.0
        self.trlog['max_acc_epoch'] = 0
        self.trlog['max_acc_interval'] = 0.0

    @abc.abstractmethod
    def train(self):
        pass

    @abc.abstractmethod
    def evaluate(self, data_loader):
        pass

    @abc.abstractmethod
    def evaluate_test(self, data_loader):
        pass

    @abc.abstractmethod
    def final_record(self):
        pass

    def try_evaluate(self, epoch):
        args = self.args
        if self.train_epoch % args.eval_interval == 0:
            vl, va, vap = self.evaluate(self.val_loader)
            self.logger.add_scalar('val_loss', float(vl), self.train_epoch)
            self.logger.add_scalar('val_acc', float(va), self.train_epoch)
            print('epoch {}, val, loss={:.4f} acc={:.4f}+{:.4f}'.format(
                epoch, vl, va, vap))

            if va >= self.trlog['max_acc']:
                self.trlog['max_acc'] = va
                self.trlog['max_acc_interval'] = vap
                self.trlog['max_acc_epoch'] = self.train_epoch
                self.save_model('max_acc')

    def try_logging(self, tl1, tl2, ta, tg=None):
        args = self.args
        if self.train_step % args.log_interval == 0:
            print(
                'epoch {}, train {:06g}/{:06g}, total loss={:.4f}, loss={:.4f} acc={:.4f}, lr={:.4g}'
                .format(self.train_epoch, self.train_step, self.max_steps,
                        tl1.item(), tl2.item(), ta.item(),
                        self.optimizer.param_groups[0]['lr']))
            self.logger.add_scalar('train_total_loss', tl1.item(),
                                   self.train_step)
            self.logger.add_scalar('train_loss', tl2.item(), self.train_step)
            self.logger.add_scalar('train_acc', ta.item(), self.train_step)
            if tg is not None:
                self.logger.add_scalar('grad_norm', tg.item(), self.train_step)
            print('data_timer: {:.2f} sec, '
                  'forward_timer: {:.2f} sec, '
                  'backward_timer: {:.2f} sec, '
                  'optim_timer: {:.2f} sec'.format(
                      self.dt.item(), self.ft.item(),
                      self.bt.item(), self.ot.item()))
            self.logger.dump()

    def save_model(self, name):
        torch.save(dict(params=self.model.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))

    def __str__(self):
        return "{}({})".format(self.__class__.__name__,
                               self.model.__class__.__name__)
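
Concrete trainers subclass this base and fill in the four abstract methods; try_evaluate, try_logging, and save_model then work unchanged. A hypothetical minimal subclass (MyTrainer and the constant returns are placeholders, not from the source):

class MyTrainer(Trainer):
    def train(self):
        for epoch in range(1, self.args.max_epoch + 1):
            self.train_epoch += 1
            # ... one pass over self.train_loader, calling try_logging ...
            self.try_evaluate(epoch)

    def evaluate(self, data_loader):
        # must return (val_loss, val_acc, acc_confidence_interval)
        return 0.0, 0.0, 0.0

    def evaluate_test(self, data_loader):
        return self.evaluate(data_loader)

    def final_record(self):
        pass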
Example #8
    def train(self):
        args = self.args
        # start GFSL training
        for epoch in range(1, args.max_epoch + 1):
            self.train_epoch += 1
            self.model.train()

            tl1 = Averager()
            tl2 = Averager()
            ta = Averager()

            start_tm = time.time()
            for batch in zip(self.train_fsl_loader, self.train_gfsl_loader):
                self.train_step += 1

                if torch.cuda.is_available():
                    support_data, support_label = batch[0][0].cuda(), batch[0][1].cuda()
                    query_data, query_label = batch[1][0].cuda(), batch[1][1].cuda()
                else:
                    support_data, support_label = batch[0][0], batch[0][1]
                    query_data, query_label = batch[1][0], batch[1][1]

                data_tm = time.time()
                self.dt.add(data_tm - start_tm)

                logits = self.model(support_data, query_data, support_label)
                loss = F.cross_entropy(
                    logits,
                    query_label.view(-1, 1).repeat(1, args.num_tasks).view(-1))
                tl2.add(loss.item())

                forward_tm = time.time()
                self.ft.add(forward_tm - data_tm)
                acc = count_acc(
                    logits,
                    query_label.view(-1, 1).repeat(1, args.num_tasks).view(-1))

                tl1.add(loss.item())
                ta.add(acc)

                self.optimizer.zero_grad()
                loss.backward()
                backward_tm = time.time()
                self.bt.add(backward_tm - forward_tm)

                self.optimizer.step()
                self.lr_scheduler.step()
                optimizer_tm = time.time()
                self.ot.add(optimizer_tm - backward_tm)
                self.try_logging(tl1, tl2, ta)

                # refresh start_tm
                start_tm = time.time()
                del logits, loss
                torch.cuda.empty_cache()

                self.try_evaluate(epoch)

            print('ETA:{}/{}'.format(
                self.timer.measure(),
                self.timer.measure(self.train_epoch / args.max_epoch)))

        torch.save(self.trlog, osp.join(args.save_path, 'trlog'))
        self.save_model('epoch-last')
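
The query_label.view(-1, 1).repeat(1, args.num_tasks).view(-1) pattern duplicates each label num_tasks times in place, so the targets stay aligned with logits that stack num_tasks predictions per query. A small demonstration of the reshape:

import torch

query_label = torch.tensor([0, 1, 2])          # three queries
num_tasks = 2
aligned = query_label.view(-1, 1).repeat(1, num_tasks).view(-1)
print(aligned)                 # tensor([0, 0, 1, 1, 2, 2])
print(query_label.repeat(2))   # tensor([0, 1, 2, 0, 1, 2]) -- plain tiling, by contrast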
Example #9
class Trainer(object, metaclass=abc.ABCMeta):
    def __init__(self, args):
        if args.dataset == 'CUB':
            self.VAL_SETTING = [(5, 1), (5, 5), (5, 20)]
        else:
            self.VAL_SETTING = [(5, 1), (5, 5), (5, 20), (5, 50)]
        if args.eval_dataset == 'CUB':
            self.TEST_SETTINGS = [(5, 1), (5, 5), (5, 20)]
        else:
            self.TEST_SETTINGS = [(5, 1), (5, 5), (5, 20), (5, 50)]
        self.args = args
        # ensure_path(
        #     self.args.save_path,
        #     scripts_to_save=['model/models', 'model/networks', __file__],
        # )
        self.logger = Logger(args, osp.join(args.save_path))

        self.train_step = 0
        self.train_epoch = 0
        self.max_steps = args.episodes_per_epoch * args.max_epoch
        self.dt, self.ft = Averager(), Averager()
        self.bt, self.ot = Averager(), Averager()
        self.timer = Timer()

        # train statistics
        self.trlog = {}
        self.trlog['max_acc'] = 0.0
        self.trlog['max_acc_epoch'] = 0
        self.trlog['max_acc_interval'] = 0.0

    @abc.abstractmethod
    def train(self):
        pass

    @abc.abstractmethod
    def evaluate(self, data_loader):
        pass

    @abc.abstractmethod
    def evaluate_test(self, data_loader):
        pass

    @abc.abstractmethod
    def final_record(self):
        pass

    def try_evaluate(self, epoch):
        args = self.args
        if self.train_epoch % args.eval_interval == 0:
            if args.eval_all:
                for i, (args.eval_way, args.eval_shot) in enumerate(self.VAL_SETTING):
                    if i == 0:
                        vl, va, vap = self.eval_process(args, epoch)
                    else:
                        self.eval_process(args, epoch)
            else:
                vl, va, vap = self.eval_process(args, epoch)
            if va >= self.trlog['max_acc']:
                self.trlog['max_acc'] = va
                self.trlog['max_acc_interval'] = vap
                self.trlog['max_acc_epoch'] = self.train_epoch
                self.save_model('max_acc')
            print('best epoch {}, best val acc={:.4f} + {:.4f}'.format(
                self.trlog['max_acc_epoch'],
                self.trlog['max_acc'],
                self.trlog['max_acc_interval']))

    def eval_process(self, args, epoch):
        valset = self.valset
        if args.model_class in ['QsimProtoNet', 'QsimMatchNet']:
            val_sampler = NegativeSampler(args, valset.label,
                                          args.num_eval_episodes,
                                          args.eval_way, args.eval_shot + args.eval_query)
        else:
            val_sampler = CategoriesSampler(valset.label,
                                            args.num_eval_episodes,
                                            args.eval_way, args.eval_shot + args.eval_query)
        val_loader = DataLoader(dataset=valset,
                                batch_sampler=val_sampler,
                                num_workers=args.num_workers,
                                pin_memory=True)
        vl, va, vap = self.evaluate(val_loader)
        self.logger.add_scalar('%dw%ds_val_loss' % (args.eval_way, args.eval_shot), float(vl),
                               self.train_epoch)
        self.logger.add_scalar('%dw%ds_val_acc' % (args.eval_way, args.eval_shot), float(va),
                               self.train_epoch)
        print('epoch {},{} way {} shot, val, loss={:.4f} acc={:.4f}+{:.4f}'.format(epoch, args.eval_way,
                                                                                   args.eval_shot, vl, va,
                                                                                   vap))
        return vl, va, vap

    def try_logging(self, tl1, tl2, ta, tg=None):
        args = self.args
        if self.train_step % args.log_interval == 0:
            print('epoch {}, train {:06g}/{:06g}, total loss={:.4f}, loss={:.4f} acc={:.4f}, lr={:.4g}'
                  .format(self.train_epoch,
                          self.train_step,
                          self.max_steps,
                          tl1.item(), tl2.item(), ta.item(),
                          self.optimizer.param_groups[0]['lr']))
            self.logger.add_scalar('train_total_loss', tl1.item(), self.train_step)
            self.logger.add_scalar('train_loss', tl2.item(), self.train_step)
            self.logger.add_scalar('train_acc', ta.item(), self.train_step)
            if tg is not None:
                self.logger.add_scalar('grad_norm', tg.item(), self.train_step)
            print('data_timer: {:.2f} sec, '
                  'forward_timer: {:.2f} sec, '
                  'backward_timer: {:.2f} sec, '
                  'optim_timer: {:.2f} sec'.format(
                      self.dt.item(), self.ft.item(),
                      self.bt.item(), self.ot.item()))
            self.logger.dump()

    def save_model(self, name):
        torch.save(
            dict(params=self.model.state_dict()),
            osp.join(self.args.save_path, name + '.pth')
        )

    def __str__(self):
        return "{}({})".format(
            self.__class__.__name__,
            self.model.__class__.__name__
        )
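
eval_process builds its DataLoader from a CategoriesSampler, which turns a flat labelled dataset into evaluation episodes. A hedged sketch of the sampler interface assumed here (the repo's actual implementation may differ in details):

import numpy as np
import torch

class CategoriesSampler:
    """One episode per batch: `way` random classes, `k` random samples per
    class, interleaved so consecutive items cycle through the classes."""
    def __init__(self, label, n_batch, way, k):
        self.n_batch, self.way, self.k = n_batch, way, k
        label = np.asarray(label)
        self.cls_idx = [np.flatnonzero(label == c) for c in np.unique(label)]

    def __len__(self):
        return self.n_batch

    def __iter__(self):
        for _ in range(self.n_batch):
            classes = np.random.permutation(len(self.cls_idx))[:self.way]
            picks = [self.cls_idx[c][np.random.permutation(len(self.cls_idx[c]))[:self.k]]
                     for c in classes]
            # transpose so the first `way` indices hold one sample per class
            yield torch.from_numpy(np.stack(picks).T.reshape(-1))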
Example #10
def validate(args, model, val_loader, epoch, trlog=None):
    model.eval()
    global writer
    vl_dist, va_dist = Averager(), Averager()
    vl_sim, va_sim = Averager(), Averager()
    if trlog is not None:
        print('[Dist] best epoch {}, current best val acc={:.4f}'.format(
            trlog['max_acc_dist_epoch'], trlog['max_acc_dist']))
        print('[Sim] best epoch {}, current best val acc={:.4f}'.format(
            trlog['max_acc_sim_epoch'], trlog['max_acc_sim']))
    # test performance with Few-Shot
    label = torch.arange(args.num_val_class).repeat(args.query).long()
    if torch.cuda.is_available():
        label = label.cuda()
    with torch.no_grad():
        for i, batch in tqdm(enumerate(val_loader, 1), total=len(val_loader)):
            if torch.cuda.is_available():
                data, _ = [_.cuda() for _ in batch]
            else:
                data, _ = batch
            # num_val_class-way (e.g. 16-way) test: one shot per class, the rest are queries
            data_shot, data_query = data[:args.num_val_class], data[args.num_val_class:]
            logits_dist, logits_sim = model.forward_proto(
                data_shot, data_query, args.num_val_class)
            loss_dist = F.cross_entropy(logits_dist, label)
            acc_dist = count_acc(logits_dist, label)
            loss_sim = F.cross_entropy(logits_sim, label)
            acc_sim = count_acc(logits_sim, label)
            vl_dist.add(loss_dist.item())
            va_dist.add(acc_dist)
            vl_sim.add(loss_sim.item())
            va_sim.add(acc_sim)

    vl_dist = vl_dist.item()
    va_dist = va_dist.item()
    vl_sim = vl_sim.item()
    va_sim = va_sim.item()

    print(
        'epoch {}, val, loss_dist={:.4f} acc_dist={:.4f} loss_sim={:.4f} acc_sim={:.4f}'
        .format(epoch, vl_dist, va_dist, vl_sim, va_sim))

    if trlog is not None:
        writer.add_scalar('data/val_loss_dist', float(vl_dist), epoch)
        writer.add_scalar('data/val_acc_dist', float(va_dist), epoch)
        writer.add_scalar('data/val_loss_sim', float(vl_sim), epoch)
        writer.add_scalar('data/val_acc_sim', float(va_sim), epoch)

        if va_dist > trlog['max_acc_dist']:
            trlog['max_acc_dist'] = va_dist
            trlog['max_acc_dist_epoch'] = epoch
            save_model('max_acc_dist', model, args)
            save_checkpoint(True)

        if va_sim > trlog['max_acc_sim']:
            trlog['max_acc_sim'] = va_sim
            trlog['max_acc_sim_epoch'] = epoch
            save_model('max_acc_sim', model, args)
            save_checkpoint(True)

        trlog['val_loss_dist'].append(vl_dist)
        trlog['val_acc_dist'].append(va_dist)
        trlog['val_loss_sim'].append(vl_sim)
        trlog['val_acc_sim'].append(va_sim)
        return trlog
Example #11
        trlog['max_acc_sim'] = 0.0
        trlog['max_acc_sim_epoch'] = 0
        initial_lr = args.lr
        global_count = 0

    timer = Timer()
    writer = SummaryWriter(logdir=args.save_path)
    for epoch in range(init_epoch, args.max_epoch + 1):
        # refine the step-size
        if epoch in args.schedule:
            initial_lr *= args.gamma
            for param_group in optimizer.param_groups:
                param_group['lr'] = initial_lr

        model.train()
        tl = Averager()
        ta = Averager()

        for i, batch in enumerate(train_loader, 1):
            global_count = global_count + 1
            if torch.cuda.is_available():
                data, label = [_.cuda() for _ in batch]
                label = label.type(torch.cuda.LongTensor)
            else:
                data, label = batch
                label = label.type(torch.LongTensor)
            logits = model(data)
            loss = criterion(logits, label)
            acc = count_acc(logits, label)
            writer.add_scalar('data/loss', float(loss), global_count)
            writer.add_scalar('data/acc', float(acc), global_count)
Example #12
                                    pin_memory=True)
            # test_set = Dataset('test', args.unsupervised, args)
            test_set = get_dataset(n, 'test', args.unsupervised, args)
            sampler = CategoriesSampler(test_set.label, args.num_test_episodes,
                                        args.way, args.shot + args.query)
            loader = DataLoader(dataset=test_set,
                                batch_sampler=sampler,
                                num_workers=8,
                                pin_memory=True)
            shot_label = torch.arange(min(args.way, valset.num_class)).repeat(
                args.shot).numpy()
            query_label = torch.arange(min(args.way, valset.num_class)).repeat(
                args.query).numpy()
            val_acc_record = np.zeros((500, len(c_list)))

            ave_acc = Averager()

            with torch.no_grad():
                for i, batch in tqdm(enumerate(val_loader, 1),
                                     total=len(val_loader),
                                     desc='val eval'):
                    if torch.cuda.is_available():
                        data, _ = [_.cuda() for _ in batch]
                    else:
                        data = batch[0]

                    data_emb = model(data, is_emb=True)
                    if args.centralize:
                        data_emb = data_emb - class_mean
                    if args.normalize:
                        data_emb = F.normalize(data_emb, dim=1, p=2)
Example #13
class Trainer(object, metaclass=abc.ABCMeta):
    def __init__(self, args):
        self.args = args
        # ensure_path(
        #     self.args.save_path,
        #     scripts_to_save=['model/models', 'model/networks', __file__],
        # )
        self.logger = Logger(args, osp.join(args.save_path))

        self.train_step = 0
        self.train_epoch = 0
        self.max_steps = args.episodes_per_epoch * args.max_epoch
        self.dt, self.ft = Averager(), Averager()
        self.bt, self.ot = Averager(), Averager()
        self.timer = Timer()

        # train statistics
        self.trlog = {}
        self.trlog['max_acc'] = 0.0
        self.trlog['max_acc_epoch'] = 0
        self.trlog['max_acc_interval'] = 0.0

        # For tst
        if args.tst_free:
            self.trlog['max_tst_criterion'] = 0.0
            self.trlog['max_tst_criterion_interval'] = 0.
            self.trlog['max_tst_criterion_epoch'] = 0
            self.trlog['tst_criterion'] = args.tst_criterion

    @abc.abstractmethod
    def train(self):
        pass

    @abc.abstractmethod
    def evaluate(self, data_loader):
        pass

    @abc.abstractmethod
    def evaluate_test(self, data_loader):
        pass

    @abc.abstractmethod
    def final_record(self):
        pass

    def print_metric_summaries(self, metric_summaries, prefix='\t'):
        for key, (mean, std) in metric_summaries.items():
            print('{}{}: {:.4f} +/- {:.4f}'.format(prefix, key, mean, std))

    def log_metric_summaries(self, metric_summaries, epoch, prefix=''):
        for key, (mean, std) in metric_summaries.items():
            self.logger.add_scalar('{}{}'.format(prefix, key), mean, epoch)

    def try_evaluate(self, epoch):
        args = self.args
        if self.train_epoch % args.eval_interval == 0:

            if not args.tst_free:
                vl, va, vap = self.evaluate(self.val_loader)
                self.logger.add_scalar('val_loss', float(vl), self.train_epoch)
                self.logger.add_scalar('val_acc', float(va), self.train_epoch)
                print('epoch {}, val, loss={:.4f} acc={:.4f}+{:.4f}'.format(
                    epoch, vl, va, vap))
            else:
                vl, va, vap, metrics = self.evaluate(self.val_loader)
                self.logger.add_scalar('val_loss', float(vl), self.train_epoch)
                self.logger.add_scalar('val_acc', float(va), self.train_epoch)
                print('epoch {}, val, loss={:.4f} acc={:.4f}+{:.4f}'.format(
                    epoch, vl, va, vap))
                self.print_metric_summaries(metrics, prefix='\tval_')
                self.log_metric_summaries(metrics, epoch=epoch, prefix='val_')

            if va >= self.trlog['max_acc']:
                self.trlog['max_acc'] = va
                self.trlog['max_acc_interval'] = vap
                self.trlog['max_acc_epoch'] = self.train_epoch
                self.save_model('max_acc')

            # Probably a different criterion for TST -> optimize here.
            if args.tst_free and args.tst_criterion:
                assert args.tst_criterion in metrics, 'Criterion {} not found in {}'.format(
                    args.tst_criterion, metrics.keys())
                criterion, criterion_interval = metrics[args.tst_criterion]
                if criterion >= self.trlog['max_tst_criterion']:
                    self.trlog['max_tst_criterion'] = criterion
                    self.trlog[
                        'max_tst_criterion_interval'] = criterion_interval
                    self.trlog['max_tst_criterion_epoch'] = self.train_epoch
                    self.save_model('max_tst_criterion')
                    print(
                        'Found new best model at Epoch {} : Validation {} = {:.4f} +/- {:.4f}'
                        .format(self.train_epoch, args.tst_criterion,
                                criterion, criterion_interval))

    def try_logging(self, tl1, tl2, ta, tg=None):
        args = self.args
        if self.train_step % args.log_interval == 0:
            print(
                'epoch {}, train {:06g}/{:06g}, total loss={:.4f}, loss={:.4f} acc={:.4f}, lr={:.4g}'
                .format(self.train_epoch, self.train_step, self.max_steps,
                        tl1.item(), tl2.item(), ta.item(),
                        self.optimizer.param_groups[0]['lr']))
            self.logger.add_scalar('train_total_loss', tl1.item(),
                                   self.train_step)
            self.logger.add_scalar('train_loss', tl2.item(), self.train_step)
            self.logger.add_scalar('train_acc', ta.item(), self.train_step)
            if tg is not None:
                self.logger.add_scalar('grad_norm', tg.item(), self.train_step)
            print('data_timer: {:.2f} sec, '
                  'forward_timer: {:.2f} sec, '
                  'backward_timer: {:.2f} sec, '
                  'optim_timer: {:.2f} sec'.format(
                      self.dt.item(), self.ft.item(),
                      self.bt.item(), self.ot.item()))
            self.logger.dump()

    def save_model(self, name):
        torch.save(dict(params=self.model.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))

    def __str__(self):
        return "{}({})".format(self.__class__.__name__,
                               self.model.__class__.__name__)
Example #14
    def train(self):
        args = self.args
        self.model.train()
        if self.args.fix_BN:
            self.model.encoder.eval()
        
        # start FSL training
        label, label_aux = self.prepare_label()
        for epoch in range(1, args.max_epoch + 1):
            self.train_epoch += 1
            self.model.train()
            if self.args.fix_BN:
                self.model.encoder.eval()
            
            tl1 = Averager()
            tl2 = Averager()
            ta = Averager()

            start_tm = time.time()

            for batch in tqdm(self.train_loader):

                data, gt_label = [_.cuda() for _ in batch]
                data_tm = time.time()
                self.dt.add(data_tm - start_tm)

                # get saved centers
                logits, reg_logits = self.para_model(data)
                logits = logits.view(-1, args.way)
                oh_query = torch.nn.functional.one_hot(label, args.way)

                sims = logits
                temp = (sims * oh_query).sum(-1)
                e_sim_p = temp - self.model.margin
                e_sim_p_pos = F.relu(e_sim_p)
                e_sim_p_neg = F.relu(-e_sim_p)

                l_open_margin = e_sim_p_pos.mean(-1)
                l_open = e_sim_p_neg.mean(-1)
                if reg_logits is not None:
                    loss = F.cross_entropy(logits, label)
                    total_loss = loss + args.balance * F.cross_entropy(reg_logits, label_aux)
                else:
                    loss = F.cross_entropy(logits, label)
                    total_loss = loss
                total_loss = total_loss + args.open_balance * l_open
                tl2.add(loss.item())
                forward_tm = time.time()
                self.ft.add(forward_tm - data_tm)
                acc = count_acc(logits, label)

                tl1.add(total_loss.item())
                ta.add(acc)

                self.optimizer.zero_grad()
                total_loss.backward(retain_graph=True)

                self.optimizer_margin.zero_grad()

                l_open_margin.backward()
                self.optimizer.step()
                self.optimizer_margin.step()

                backward_tm = time.time()
                self.bt.add(backward_tm - forward_tm)

                optimizer_tm = time.time()
                self.ot.add(optimizer_tm - backward_tm)
                # refresh start_tm
                start_tm = time.time()

            print('lr: {:.4f} Total_loss: {:.4f} ce_loss: {:.4f} l_open: {:.4f} R: {:.4f}\n'.format(
                self.optimizer_margin.param_groups[0]['lr'], total_loss.item(),
                loss.item(), l_open.item(), self.model.margin.item()))
            self.lr_scheduler.step()
            self.lr_scheduler_margin.step()

            self.try_evaluate(epoch)

            print('ETA:{}/{}'.format(
                    self.timer.measure(),
                    self.timer.measure(self.train_epoch / args.max_epoch))
            )

        torch.save(self.trlog, osp.join(args.save_path, 'trlog'))
        self.save_model('epoch-last')
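
The margin terms above split the gap between each positive similarity and the learned margin: F.relu(sims - margin) is the excess above the margin and F.relu(margin - sims) the shortfall below it, so the two losses pull the scores and the margin in opposite directions. A small numeric check:

import torch
import torch.nn.functional as F

sims = torch.tensor([0.2, 0.5, 0.9])
margin = 0.5
e = sims - margin
print(F.relu(e))    # tensor([0.0000, 0.0000, 0.4000]) -- excess above the margin
print(F.relu(-e))   # tensor([0.3000, 0.0000, 0.0000]) -- shortfall below it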