Esempio n. 1
0
def train(args, model, train_loader, criterion, optimizer, lr_schedule, trlog):
    """Train ``model`` for one pass over ``train_loader``.

    For every batch this runs a forward pass, back-propagates the loss,
    steps the optimizer and then the LR scheduler. The loss (and, unless
    ``args.mixup`` is set, the accuracy) is written to the module-level
    ``writer`` against the module-level running counter ``global_count``.

    Args:
        args: options object; only ``args.mixup`` is read here.
        model: network to train; switched to training mode.
        train_loader: iterable of ``(data, label)`` batches supporting ``len``.
        criterion: callable ``criterion(logits, label)`` returning the loss.
        optimizer: optimizer whose first param group supplies the printed lr.
        lr_schedule: scheduler stepped once per batch.
        trlog: dict with ``'train_loss'`` / ``'train_acc'`` lists, or ``None``.

    Returns:
        ``(model, trlog)`` when ``trlog`` is not ``None``, else ``model``.

    NOTE(review): ``epoch`` is not a parameter; the progress messages read it
    from an enclosing scope (presumably a module-level global) -- confirm.
    """
    global global_count, writer

    model.train()
    loss_meter = Averager()
    acc_meter = Averager()

    for step, (data, label) in enumerate(train_loader, start=1):
        global_count += 1

        if torch.cuda.is_available():
            data = data.cuda()
            label = label.long().cuda()
        elif not args.mixup:
            # On CPU the labels are cast to long only when mixup is off.
            label = label.long()

        optimizer.zero_grad()
        logits = model(data)
        loss = criterion(logits, label)
        loss.backward()
        optimizer.step()
        lr_schedule.step()

        writer.add_scalar('data/loss', float(loss), global_count)
        loss_meter.add(loss.item())
        if not args.mixup:
            acc = count_acc(logits, label)
            acc_meter.add(acc)
            writer.add_scalar('data/acc', float(acc), global_count)

        # Report on batches 1, 101, 201, ... and on the final batch.
        if (step - 1) % 100 != 0 and step != len(train_loader):
            continue

        if args.mixup:
            # Accuracy is not computed under mixup, so it is left out.
            message = 'epoch {}, train {}/{}, lr={:.5f}, loss={:.4f}'.format(
                epoch, step, len(train_loader),
                optimizer.param_groups[0]['lr'], loss.item())
        else:
            message = (
                'epoch {}, train {}/{}, lr={:.5f}, loss={:.4f} acc={:.4f}'.
                format(epoch, step, len(train_loader),
                       optimizer.param_groups[0]['lr'], loss.item(), acc))
        print(message)

    if trlog is None:
        return model

    trlog['train_loss'].append(loss_meter.item())
    if args.mixup:
        # Placeholder keeps the two history lists the same length.
        trlog['train_acc'].append(0)
    else:
        trlog['train_acc'].append(acc_meter.item())
    return model, trlog
Esempio n. 2
0
class Trainer(object, metaclass=abc.ABCMeta):
    """Abstract base for training loops with periodic validation and logging.

    Subclasses implement ``train``, ``evaluate``, ``evaluate_test`` and
    ``final_record``. The helper methods below also read ``self.model``,
    ``self.optimizer`` and ``self.val_loader``, none of which are assigned in
    ``__init__`` here (presumably a subclass sets them -- confirm).
    """

    def __init__(self, args):
        """Store ``args`` and create the logger, counters, timers and ``trlog``."""
        self.args = args
        # NOTE: the ensure_path(...) call that used to prepare
        # ``args.save_path`` (and copy scripts into it) is disabled.
        self.logger = Logger(args, osp.join(args.save_path))

        # Progress counters.
        self.train_step = 0
        self.train_epoch = 0
        self.max_steps = args.episodes_per_epoch * args.max_epoch

        # Averagers reported by ``try_logging`` as data / forward /
        # backward / optimizer timings.
        self.dt = Averager()
        self.ft = Averager()
        self.bt = Averager()
        self.ot = Averager()
        self.timer = Timer()

        # Best-so-far validation statistics.
        self.trlog = {
            'max_acc': 0.0,
            'max_acc_epoch': 0,
            'max_acc_interval': 0.0,
        }

    @abc.abstractmethod
    def train(self):
        """Run the training procedure (implemented by subclasses)."""

    @abc.abstractmethod
    def evaluate(self, data_loader):
        """Evaluate on ``data_loader``.

        ``try_evaluate`` unpacks the result as ``(loss, acc, interval)``.
        """

    @abc.abstractmethod
    def evaluate_test(self, data_loader):
        """Evaluate on test data (implemented by subclasses)."""

    @abc.abstractmethod
    def final_record(self):
        """Record final results (implemented by subclasses)."""

    def try_evaluate(self, epoch):
        """Validate when ``train_epoch`` is a multiple of ``args.eval_interval``.

        Logs validation loss and accuracy, prints a summary labelled with
        ``epoch``, and saves a ``'max_acc'`` checkpoint whenever the accuracy
        matches or beats the best value recorded in ``trlog``.
        """
        if self.train_epoch % self.args.eval_interval != 0:
            return

        val_loss, val_acc, val_interval = self.evaluate(self.val_loader)
        self.logger.add_scalar('val_loss', float(val_loss), self.train_epoch)
        self.logger.add_scalar('val_acc', float(val_acc), self.train_epoch)
        summary = 'epoch {}, val, loss={:.4f} acc={:.4f}+{:.4f}'.format(
            epoch, val_loss, val_acc, val_interval)
        print(summary)

        # ``>=`` means a tie also replaces the stored checkpoint.
        if val_acc >= self.trlog['max_acc']:
            self.trlog['max_acc'] = val_acc
            self.trlog['max_acc_interval'] = val_interval
            self.trlog['max_acc_epoch'] = self.train_epoch
            self.save_model('max_acc')

    def try_logging(self, tl1, tl2, ta, tg=None):
        """Print and log training statistics every ``args.log_interval`` steps.

        ``tl1``, ``tl2`` and ``ta`` (total loss, loss, accuracy) and the
        optional ``tg`` (gradient norm) must expose ``.item()``.
        """
        if self.train_step % self.args.log_interval != 0:
            return

        step = self.train_step
        total_loss = tl1.item()
        loss = tl2.item()
        acc = ta.item()
        lr = self.optimizer.param_groups[0]['lr']

        print(
            'epoch {}, train {:06g}/{:06g}, total loss={:.4f}, loss={:.4f} acc={:.4f}, lr={:.4g}'
            .format(self.train_epoch, step, self.max_steps,
                    total_loss, loss, acc, lr))

        self.logger.add_scalar('train_total_loss', total_loss, step)
        self.logger.add_scalar('train_loss', loss, step)
        self.logger.add_scalar('train_acc', acc, step)
        if tg is not None:
            self.logger.add_scalar('grad_norm', tg.item(), step)

        timing_report = (
            'data_timer: {:.2f} sec, '
            'forward_timer: {:.2f} sec,'
            'backward_timer: {:.2f} sec, '
            'optim_timer: {:.2f} sec'
        ).format(self.dt.item(), self.ft.item(),
                 self.bt.item(), self.ot.item())
        print(timing_report)
        self.logger.dump()

    def save_model(self, name):
        """Save the model's ``state_dict`` to ``<save_path>/<name>.pth``."""
        checkpoint = dict(params=self.model.state_dict())
        destination = osp.join(self.args.save_path, name + '.pth')
        torch.save(checkpoint, destination)

    def __str__(self):
        """Return ``'<TrainerClass>(<ModelClass>)'``."""
        trainer_name = self.__class__.__name__
        model_name = self.model.__class__.__name__
        return "{}({})".format(trainer_name, model_name)
Esempio n. 3
0
class Trainer(object, metaclass=abc.ABCMeta):
    """Abstract trainer that can validate under several (way, shot) settings.

    Subclasses implement ``train``, ``evaluate``, ``evaluate_test`` and
    ``final_record``. The helpers below also read ``self.model``,
    ``self.optimizer`` and ``self.valset``, none of which are assigned in
    ``__init__`` here (presumably a subclass sets them -- confirm).
    """

    def __init__(self, args):
        """Store ``args`` and create settings, logger, counters and ``trlog``."""
        # (way, shot) pairs used for validation / test. 'CUB' gets no
        # 50-shot setting; the reason is not visible in this file.
        if args.dataset == 'CUB':
            self.VAL_SETTING = [(5, 1), (5, 5), (5, 20)]
        else:
            self.VAL_SETTING = [(5, 1), (5, 5), (5, 20), (5, 50)]
        if args.eval_dataset == 'CUB':
            self.TEST_SETTINGS = [(5, 1), (5, 5), (5, 20)]
        else:
            self.TEST_SETTINGS = [(5, 1), (5, 5), (5, 20), (5, 50)]
        self.args = args
        # NOTE: the ensure_path(...) call that used to prepare
        # ``args.save_path`` (and copy scripts into it) is disabled.
        self.logger = Logger(args, osp.join(args.save_path))

        self.train_step = 0
        self.train_epoch = 0
        self.max_steps = args.episodes_per_epoch * args.max_epoch
        # Averagers reported by ``try_logging`` as data / forward /
        # backward / optimizer timings.
        self.dt, self.ft = Averager(), Averager()
        self.bt, self.ot = Averager(), Averager()
        self.timer = Timer()

        # Best-so-far validation statistics.
        self.trlog = {}
        self.trlog['max_acc'] = 0.0
        self.trlog['max_acc_epoch'] = 0
        self.trlog['max_acc_interval'] = 0.0

    @abc.abstractmethod
    def train(self):
        """Run the training procedure (implemented by subclasses)."""

    @abc.abstractmethod
    def evaluate(self, data_loader):
        """Evaluate on ``data_loader``.

        ``eval_process`` unpacks the result as ``(loss, acc, interval)``.
        """

    @abc.abstractmethod
    def evaluate_test(self, data_loader):
        """Evaluate on test data (implemented by subclasses)."""

    @abc.abstractmethod
    def final_record(self):
        """Record final results (implemented by subclasses)."""

    def try_evaluate(self, epoch):
        """Validate when ``train_epoch`` is a multiple of ``args.eval_interval``.

        With ``args.eval_all`` every pair in ``self.VAL_SETTING`` is evaluated,
        and the result of the first pair drives best-model tracking. Otherwise
        a single evaluation runs with the current ``args.eval_way`` /
        ``args.eval_shot``. A ``'max_acc'`` checkpoint is saved whenever the
        accuracy matches or beats the recorded best.

        ``args.eval_way`` / ``args.eval_shot`` are overridden only for the
        duration of each evaluation and are restored afterwards, so the
        caller's configuration is not left at the last ``VAL_SETTING`` entry.
        """
        args = self.args
        if self.train_epoch % args.eval_interval != 0:
            return

        if args.eval_all:
            saved_way, saved_shot = args.eval_way, args.eval_shot
            try:
                for i, (way, shot) in enumerate(self.VAL_SETTING):
                    # eval_process reads the setting from ``args``.
                    args.eval_way, args.eval_shot = way, shot
                    result = self.eval_process(args, epoch)
                    if i == 0:
                        _, va, vap = result
            finally:
                # Restore even if an evaluation raises.
                args.eval_way, args.eval_shot = saved_way, saved_shot
        else:
            _, va, vap = self.eval_process(args, epoch)

        # ``>=`` means a tie also replaces the stored checkpoint.
        if va >= self.trlog['max_acc']:
            self.trlog['max_acc'] = va
            self.trlog['max_acc_interval'] = vap
            self.trlog['max_acc_epoch'] = self.train_epoch
            self.save_model('max_acc')
        print('best epoch {}, best val acc={:.4f} + {:.4f}'.format(
            self.trlog['max_acc_epoch'],
            self.trlog['max_acc'],
            self.trlog['max_acc_interval']))

    def eval_process(self, args, epoch):
        """Build a validation loader for the current setting and evaluate it.

        The episode sampler is created from ``args.eval_way`` and
        ``args.eval_shot + args.eval_query``. Results are logged under
        ``'<way>w<shot>s_val_loss'`` / ``'<way>w<shot>s_val_acc'`` and printed.

        Returns:
            The ``(loss, acc, interval)`` tuple produced by ``self.evaluate``.
        """
        valset = self.valset
        samples_per_class = args.eval_shot + args.eval_query
        if args.model_class in ['QsimProtoNet', 'QsimMatchNet']:
            val_sampler = NegativeSampler(args, valset.label,
                                          args.num_eval_episodes,
                                          args.eval_way, samples_per_class)
        else:
            val_sampler = CategoriesSampler(valset.label,
                                            args.num_eval_episodes,
                                            args.eval_way, samples_per_class)
        val_loader = DataLoader(dataset=valset,
                                batch_sampler=val_sampler,
                                num_workers=args.num_workers,
                                pin_memory=True)
        vl, va, vap = self.evaluate(val_loader)
        self.logger.add_scalar('%dw%ds_val_loss' % (args.eval_way, args.eval_shot), float(vl),
                               self.train_epoch)
        self.logger.add_scalar('%dw%ds_val_acc' % (args.eval_way, args.eval_shot), float(va),
                               self.train_epoch)
        print('epoch {},{} way {} shot, val, loss={:.4f} acc={:.4f}+{:.4f}'.format(
            epoch, args.eval_way, args.eval_shot, vl, va, vap))
        return vl, va, vap

    def try_logging(self, tl1, tl2, ta, tg=None):
        """Print and log training statistics every ``args.log_interval`` steps.

        ``tl1``, ``tl2`` and ``ta`` (total loss, loss, accuracy) and the
        optional ``tg`` (gradient norm) must expose ``.item()``.
        """
        args = self.args
        if self.train_step % args.log_interval == 0:
            print('epoch {}, train {:06g}/{:06g}, total loss={:.4f}, loss={:.4f} acc={:.4f}, lr={:.4g}'
                  .format(self.train_epoch,
                          self.train_step,
                          self.max_steps,
                          tl1.item(), tl2.item(), ta.item(),
                          self.optimizer.param_groups[0]['lr']))
            self.logger.add_scalar('train_total_loss', tl1.item(), self.train_step)
            self.logger.add_scalar('train_loss', tl2.item(), self.train_step)
            self.logger.add_scalar('train_acc', ta.item(), self.train_step)
            if tg is not None:
                self.logger.add_scalar('grad_norm', tg.item(), self.train_step)
            print(('data_timer: {:.2f} sec, '
                   'forward_timer: {:.2f} sec,'
                   'backward_timer: {:.2f} sec, '
                   'optim_timer: {:.2f} sec').format(
                       self.dt.item(), self.ft.item(),
                       self.bt.item(), self.ot.item()))
            self.logger.dump()

    def save_model(self, name):
        """Save the model's ``state_dict`` to ``<save_path>/<name>.pth``."""
        torch.save(
            dict(params=self.model.state_dict()),
            osp.join(self.args.save_path, name + '.pth')
        )

    def __str__(self):
        """Return ``'<TrainerClass>(<ModelClass>)'``."""
        return "{}({})".format(
            self.__class__.__name__,
            self.model.__class__.__name__
        )
Esempio n. 4
0
def validate(args, model, val_loader, epoch, trlog=None):
    """Evaluate ``model`` on few-shot validation episodes.

    Each batch is split into the first ``args.num_val_class`` samples (shot)
    and the remainder (query), and scored with ``model.forward_proto``, which
    yields two sets of logits ("dist" and "sim"). Cross-entropy loss and
    accuracy are accumulated for both.

    When ``trlog`` is given, the results are written to the module-level
    ``writer``, best-accuracy entries are updated (saving the model and a
    checkpoint on a strict improvement), and the per-epoch histories are
    extended.

    Returns:
        ``trlog`` when it is not ``None``; otherwise ``None``.
    """
    global writer

    model.eval()
    loss_dist_meter, acc_dist_meter = Averager(), Averager()
    loss_sim_meter, acc_sim_meter = Averager(), Averager()

    if trlog is not None:
        print('[Dist] best epoch {}, current best val acc={:.4f}'.format(
            trlog['max_acc_dist_epoch'], trlog['max_acc_dist']))
        print('[Sim] best epoch {}, current best val acc={:.4f}'.format(
            trlog['max_acc_sim_epoch'], trlog['max_acc_sim']))

    # Query labels: class indices 0..num_val_class-1, tiled ``query`` times.
    label = torch.arange(args.num_val_class).repeat(args.query).long()
    if torch.cuda.is_available():
        label = label.cuda()

    with torch.no_grad():
        episodes = tqdm(enumerate(val_loader, 1), total=len(val_loader))
        for _step, batch in episodes:
            if torch.cuda.is_available():
                data, _ = [tensor.cuda() for tensor in batch]
            else:
                data, _ = batch

            # The original comment on this split read "16-way test".
            n_way = args.num_val_class
            data_shot = data[:n_way]
            data_query = data[n_way:]
            logits_dist, logits_sim = model.forward_proto(
                data_shot, data_query, n_way)

            loss_dist = F.cross_entropy(logits_dist, label)
            acc_dist = count_acc(logits_dist, label)
            loss_sim = F.cross_entropy(logits_sim, label)
            acc_sim = count_acc(logits_sim, label)

            loss_dist_meter.add(loss_dist.item())
            acc_dist_meter.add(acc_dist)
            loss_sim_meter.add(loss_sim.item())
            acc_sim_meter.add(acc_sim)

    vl_dist = loss_dist_meter.item()
    va_dist = acc_dist_meter.item()
    vl_sim = loss_sim_meter.item()
    va_sim = acc_sim_meter.item()

    print(
        'epoch {}, val, loss_dist={:.4f} acc_dist={:.4f} loss_sim={:.4f} acc_sim={:.4f}'
        .format(epoch, vl_dist, va_dist, vl_sim, va_sim))

    if trlog is None:
        return None

    for tag, value in (('data/val_loss_dist', vl_dist),
                       ('data/val_acc_dist', va_dist),
                       ('data/val_loss_sim', vl_sim),
                       ('data/val_acc_sim', va_sim)):
        writer.add_scalar(tag, float(value), epoch)

    # Strict ``>``: a tie does not replace the stored best model.
    if va_dist > trlog['max_acc_dist']:
        trlog['max_acc_dist'] = va_dist
        trlog['max_acc_dist_epoch'] = epoch
        save_model('max_acc_dist', model, args)
        save_checkpoint(True)

    if va_sim > trlog['max_acc_sim']:
        trlog['max_acc_sim'] = va_sim
        trlog['max_acc_sim_epoch'] = epoch
        save_model('max_acc_sim', model, args)
        save_checkpoint(True)

    for key, value in (('val_loss_dist', vl_dist),
                       ('val_acc_dist', va_dist),
                       ('val_loss_sim', vl_sim),
                       ('val_acc_sim', va_sim)):
        trlog[key].append(value)
    return trlog
Esempio n. 5
0
            loss = criterion(logits, label)
            acc = count_acc(logits, label)
            writer.add_scalar('data/loss', float(loss), global_count)
            writer.add_scalar('data/acc', float(acc), global_count)
            if (i - 1) % 100 == 0:
                print('epoch {}, train {}/{}, loss={:.4f} acc={:.4f}'.format(
                    epoch, i, len(train_loader), loss.item(), acc))

            tl.add(loss.item())
            ta.add(acc)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        tl = tl.item()
        ta = ta.item()

        # validate only every 5th epoch (1, 6, 11, ...) through epoch 100, then every epoch
        if epoch > 100 or (epoch - 1) % 5 == 0:
            model.eval()
            vl_dist = Averager()
            va_dist = Averager()
            vl_sim = Averager()
            va_sim = Averager()
            print('[Dist] best epoch {}, current best val acc={:.4f}'.format(
                trlog['max_acc_dist_epoch'], trlog['max_acc_dist']))
            print('[Sim] best epoch {}, current best val acc={:.4f}'.format(
                trlog['max_acc_sim_epoch'], trlog['max_acc_sim']))
            # test performance with Few-Shot
            label = torch.arange(valset.num_class).repeat(args.query)
Esempio n. 6
0
class Trainer(object, metaclass=abc.ABCMeta):
    """Abstract trainer with optional extra-metric ("tst") model selection.

    Subclasses implement ``train``, ``evaluate``, ``evaluate_test`` and
    ``final_record``. The helpers below also read ``self.model``,
    ``self.optimizer`` and ``self.val_loader``, none of which are assigned in
    ``__init__`` here (presumably a subclass sets them -- confirm).
    """

    def __init__(self, args):
        """Store ``args`` and create the logger, counters, timers and ``trlog``."""
        self.args = args
        # NOTE: the ensure_path(...) call that used to prepare
        # ``args.save_path`` (and copy scripts into it) is disabled.
        self.logger = Logger(args, osp.join(args.save_path))

        self.train_step = 0
        self.train_epoch = 0
        self.max_steps = args.episodes_per_epoch * args.max_epoch
        # Averagers reported by ``try_logging`` as data / forward /
        # backward / optimizer timings.
        self.dt, self.ft = Averager(), Averager()
        self.bt, self.ot = Averager(), Averager()
        self.timer = Timer()

        # Best-so-far validation statistics.
        self.trlog = {}
        self.trlog['max_acc'] = 0.0
        self.trlog['max_acc_epoch'] = 0
        self.trlog['max_acc_interval'] = 0.0

        # Extra best-so-far entries, tracked only when ``args.tst_free`` is set.
        if args.tst_free:
            self.trlog['max_tst_criterion'] = 0.0
            self.trlog['max_tst_criterion_interval'] = 0.
            self.trlog['max_tst_criterion_epoch'] = 0
            self.trlog['tst_criterion'] = args.tst_criterion

    @abc.abstractmethod
    def train(self):
        """Run the training procedure (implemented by subclasses)."""

    @abc.abstractmethod
    def evaluate(self, data_loader):
        """Evaluate on ``data_loader``.

        ``try_evaluate`` unpacks the result as ``(loss, acc, interval)``, or
        as ``(loss, acc, interval, metrics)`` when ``args.tst_free`` is set,
        where ``metrics`` maps a name to a ``(mean, std)`` pair.
        """

    @abc.abstractmethod
    def evaluate_test(self, data_loader):
        """Evaluate on test data (implemented by subclasses)."""

    @abc.abstractmethod
    def final_record(self):
        """Record final results (implemented by subclasses)."""

    def print_metric_summaries(self, metric_summaries, prefix='\t'):
        """Print ``<prefix><name>: <mean> +/- <std>`` for every metric."""
        for key, (mean, std) in metric_summaries.items():
            print('{}{}: {:.4f} +/- {:.4f}'.format(prefix, key, mean, std))

    def log_metric_summaries(self, metric_summaries, epoch, prefix=''):
        """Log the mean of every metric as ``<prefix><name>`` at ``epoch``."""
        for key, (mean, _) in metric_summaries.items():
            self.logger.add_scalar('{}{}'.format(prefix, key), mean, epoch)

    def try_evaluate(self, epoch):
        """Validate when ``train_epoch`` is a multiple of ``args.eval_interval``.

        Logs and prints validation loss and accuracy (plus the extra metrics
        when ``args.tst_free`` is set) and saves a ``'max_acc'`` checkpoint
        whenever accuracy matches or beats the recorded best. If
        ``args.tst_free`` and ``args.tst_criterion`` are both set, the named
        metric is tracked the same way under ``'max_tst_criterion'``.

        Raises:
            AssertionError: if ``args.tst_criterion`` is not among the
                metrics returned by ``evaluate``.
        """
        args = self.args
        if self.train_epoch % args.eval_interval != 0:
            return

        results = self.evaluate(self.val_loader)
        if args.tst_free:
            vl, va, vap, metrics = results
        else:
            vl, va, vap = results

        self.logger.add_scalar('val_loss', float(vl), self.train_epoch)
        self.logger.add_scalar('val_acc', float(va), self.train_epoch)
        print('epoch {}, val, loss={:.4f} acc={:.4f}+{:.4f}'.format(
            epoch, vl, va, vap))
        if args.tst_free:
            self.print_metric_summaries(metrics, prefix='\tval_')
            self.log_metric_summaries(metrics, epoch=epoch, prefix='val_')

        # ``>=`` means a tie also replaces the stored checkpoint.
        if va >= self.trlog['max_acc']:
            self.trlog['max_acc'] = va
            self.trlog['max_acc_interval'] = vap
            self.trlog['max_acc_epoch'] = self.train_epoch
            self.save_model('max_acc')

        # Model selection on a separately chosen validation metric.
        if args.tst_free and args.tst_criterion:
            assert args.tst_criterion in metrics, 'Criterion {} not found in {}'.format(
                args.tst_criterion, metrics.keys())
            criterion, criterion_interval = metrics[args.tst_criterion]
            if criterion >= self.trlog['max_tst_criterion']:
                self.trlog['max_tst_criterion'] = criterion
                self.trlog['max_tst_criterion_interval'] = criterion_interval
                self.trlog['max_tst_criterion_epoch'] = self.train_epoch
                self.save_model('max_tst_criterion')
                # Interval uses ``{:.4f}`` (four decimals), matching the
                # format used by ``print_metric_summaries``.
                print(
                    'Found new best model at Epoch {} : Validation {} = {:.4f} +/- {:.4f}'
                    .format(self.train_epoch, args.tst_criterion,
                            criterion, criterion_interval))

    def try_logging(self, tl1, tl2, ta, tg=None):
        """Print and log training statistics every ``args.log_interval`` steps.

        ``tl1``, ``tl2`` and ``ta`` (total loss, loss, accuracy) and the
        optional ``tg`` (gradient norm) must expose ``.item()``.
        """
        args = self.args
        if self.train_step % args.log_interval == 0:
            print(
                'epoch {}, train {:06g}/{:06g}, total loss={:.4f}, loss={:.4f} acc={:.4f}, lr={:.4g}'
                .format(self.train_epoch, self.train_step, self.max_steps,
                        tl1.item(), tl2.item(), ta.item(),
                        self.optimizer.param_groups[0]['lr']))
            self.logger.add_scalar('train_total_loss', tl1.item(),
                                   self.train_step)
            self.logger.add_scalar('train_loss', tl2.item(), self.train_step)
            self.logger.add_scalar('train_acc', ta.item(), self.train_step)
            if tg is not None:
                self.logger.add_scalar('grad_norm', tg.item(), self.train_step)
            print(('data_timer: {:.2f} sec, '
                   'forward_timer: {:.2f} sec,'
                   'backward_timer: {:.2f} sec, '
                   'optim_timer: {:.2f} sec').format(
                       self.dt.item(), self.ft.item(),
                       self.bt.item(), self.ot.item()))
            self.logger.dump()

    def save_model(self, name):
        """Save the model's ``state_dict`` to ``<save_path>/<name>.pth``."""
        torch.save(dict(params=self.model.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))

    def __str__(self):
        """Return ``'<TrainerClass>(<ModelClass>)'``."""
        return "{}({})".format(self.__class__.__name__,
                               self.model.__class__.__name__)