Ejemplo n.º 1
0
            vl.add(loss.item())
            va.add(acc)

            proto = None
            logits = None
            loss = None

        vl = vl.item()
        va = va.item()
        print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(epoch, vl, va))

        if va > trlog['max_acc']:
            trlog['max_acc'] = va
            save_model('max-acc')

        trlog['train_loss'].append(tl)
        trlog['train_acc'].append(ta)
        trlog['val_loss'].append(vl)
        trlog['val_acc'].append(va)

        torch.save(trlog, osp.join(save_path, 'trlog'))

        save_model('epoch-last')

        if epoch % save_epoch == 0:
            save_model('epoch-{}'.format(epoch))

        print('ETA:{}/{}'.format(timer.measure(),
                                 timer.measure(epoch / max_epoch)))
        if va > trlog["max_acc"]:
            trlog["max_acc"] = va
            trlog["max_acc_epoch"] = epoch
            save_model("max_acc")

        trlog["train_loss"].append(tl)
        trlog["train_acc"].append(ta)
        trlog["val_loss"].append(vl)
        trlog["val_acc"].append(va)

        torch.save(trlog, osp.join(args.save_path, "trlog"))

        save_model("epoch-last")

        print(
            "ETA:{}/{}".format(timer.measure(), timer.measure(epoch / args.max_epoch))
        )
    writer.close()

    # Test Phase
    trlog = torch.load(osp.join(args.save_path, "trlog"))
    test_set = Dataset("test", args)
    sampler = CategoriesSampler(
        test_set.label, 10000, args.validation_way, args.shot + args.query
    )
    loader = DataLoader(test_set, batch_sampler=sampler, num_workers=8, pin_memory=True)
    test_acc_record = np.zeros((10000,))

    model.load_state_dict(
        torch.load(osp.join(args.save_path, "max_acc" + ".pth"))["params"]
    )
Ejemplo n.º 3
0
            val_accuracies.append(acc.item())
            val_losses.append(loss.item())

        val_acc_avg = np.mean(np.array(val_accuracies))
        val_acc_ci95 = 1.96 * np.std(np.array(val_accuracies)) / np.sqrt(
            opt.val_episode)

        val_loss_avg = np.mean(np.array(val_losses))

        if val_acc_avg > max_val_acc:
            max_val_acc = val_acc_avg
            torch.save({'embedding': embedding_net.state_dict(), 'head': cls_head.state_dict()},\
                       os.path.join(opt.save_path, 'best_model.pth'))
            log(log_file_path, 'Validation Epoch: {}\t\t\tLoss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} % (Best)'\
                  .format(epoch, val_loss_avg, val_acc_avg, val_acc_ci95))
        else:
            log(log_file_path, 'Validation Epoch: {}\t\t\tLoss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} %'\
                  .format(epoch, val_loss_avg, val_acc_avg, val_acc_ci95))

        torch.save({'embedding': embedding_net.state_dict(), 'head': cls_head.state_dict()}\
                   , os.path.join(opt.save_path, 'last_epoch.pth'))

        if epoch % opt.save_epoch == 0:
            torch.save({'embedding': embedding_net.state_dict(), 'head': cls_head.state_dict()}\
                       , os.path.join(opt.save_path, 'epoch_{}.pth'.format(epoch)))

        log(
            log_file_path, 'Elapsed Time: {}/{}\n'.format(
                timer.measure(), timer.measure(epoch / float(opt.num_epoch))))
Ejemplo n.º 4
0
Archivo: train.py Proyecto: yf1291/nlp3
            loss = x_entropy(logit_query.reshape(-1, opt.test_way), labels_query.reshape(-1))
            acc = count_accuracy(logit_query.reshape(-1, opt.test_way), labels_query.reshape(-1))

            val_accuracies.append(acc.item())
            val_losses.append(loss.item())
            
        val_acc_avg = np.mean(np.array(val_accuracies))
        val_acc_ci95 = 1.96 * np.std(np.array(val_accuracies)) / np.sqrt(opt.val_episode)

        val_loss_avg = np.mean(np.array(val_losses))

        if val_acc_avg > max_val_acc:
            max_val_acc = val_acc_avg
            torch.save({'embedding': embedding_net.state_dict(), 'head': cls_head.state_dict()},\
                       os.path.join(opt.save_path, 'best_model.pth'))
            log(log_file_path, 'Validation Epoch: {}\t\t\tLoss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} % (Best)'\
                  .format(epoch, val_loss_avg, val_acc_avg, val_acc_ci95))
        else:
            log(log_file_path, 'Validation Epoch: {}\t\t\tLoss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} %'\
                  .format(epoch, val_loss_avg, val_acc_avg, val_acc_ci95))

        torch.save({'embedding': embedding_net.state_dict(), 'head': cls_head.state_dict()}\
                   , os.path.join(opt.save_path, 'last_epoch.pth'))

        if epoch % opt.save_epoch == 0:
            torch.save({'embedding': embedding_net.state_dict(), 'head': cls_head.state_dict()}\
                       , os.path.join(opt.save_path, 'epoch_{}.pth'.format(epoch)))

        log(log_file_path, 'Elapsed Time: {}/{}\n'.format(timer.measure(), timer.measure(epoch / float(opt.num_epoch))))
Ejemplo n.º 5
0
def main(args):
    """Train a prototypical network and keep checkpoints by validation accuracy.

    Every epoch performs one pass over the training episodes, steps the
    learning-rate scheduler, evaluates on the validation episodes and then
    writes, under ``args.save_path``:

    * ``max-acc.pth``    -- ``model.state_dict()`` whenever the validation
      accuracy beats the best seen so far,
    * ``epoch-last.pth`` -- the weights after the latest epoch,
    * ``epoch-{n}.pth``  -- the weights every ``args.save_epoch`` epochs,
    * ``trlog``          -- the run's args plus per-epoch train/val loss and
      accuracy, saved with ``torch.save``.

    Args:
        args: parsed options. Reads ``device``, ``save_path``, ``dataset``,
            ``n_batches``, ``train_way``, ``test_way``, ``shot``, ``query``,
            ``max_epoch`` and ``save_epoch``.
    """
    device = torch.device(args.device)
    ensure_path(args.save_path)

    data = Data(args.dataset, args.n_batches, args.train_way, args.test_way, args.shot, args.query)
    train_loader = data.train_loader
    val_loader = data.valid_loader

    model = Convnet(x_dim=2).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Halve the learning rate every 20 epochs.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    def save_model(name):
        """Save the model weights to ``<save_path>/<name>.pth``."""
        torch.save(model.state_dict(), osp.join(args.save_path, name + '.pth'))

    trlog = dict(
        args=vars(args),
        train_loss=[],
        val_loss=[],
        train_acc=[],
        val_acc=[],
        max_acc=0.0,
    )

    timer = Timer()

    for epoch in range(1, args.max_epoch + 1):
        model.train()

        tl = Averager()
        ta = Averager()

        for i, batch in enumerate(train_loader, 1):
            # The loader's labels are ignored; episode labels are rebuilt below
            # from the sampling order.
            images, _ = [t.to(device) for t in batch]
            images = images.reshape(-1, 2, 105, 105)
            # The first shot * train_way samples form the support set.
            p = args.shot * args.train_way
            embedded = model(images)
            embedded_shot, embedded_query = embedded[:p], embedded[p:]

            # Class prototypes: mean support embedding of each class.
            proto = embedded_shot.reshape(args.shot, args.train_way, -1).mean(dim=0)

            label = torch.arange(args.train_way).repeat(args.query).to(device)

            logits = euclidean_metric(embedded_query, proto)
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)
            print('epoch {}, train {}/{}, loss={:.4f} acc={:.4f}'
                  .format(epoch, i, len(train_loader), loss.item(), acc))

            tl.add(loss.item())
            ta.add(acc)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Step the scheduler after this epoch's optimizer updates (the ordering
        # PyTorch >= 1.1 expects); stepping first skips the initial LR value.
        lr_scheduler.step()

        tl = tl.item()
        ta = ta.item()

        model.eval()

        vl = Averager()
        va = Averager()

        # Validation only reads loss/accuracy values, so no autograd graph is needed.
        with torch.no_grad():
            for batch in val_loader:
                # Move to the configured device, as the training loop does,
                # so CPU-only runs work.
                images, _ = [t.to(device) for t in batch]
                images = images.reshape(-1, 2, 105, 105)
                p = args.shot * args.test_way
                data_shot, data_query = images[:p], images[p:]

                proto = model(data_shot)
                proto = proto.reshape(args.shot, args.test_way, -1).mean(dim=0)

                label = torch.arange(args.test_way).repeat(args.query).to(device)

                logits = euclidean_metric(model(data_query), proto)
                loss = F.cross_entropy(logits, label)
                acc = count_acc(logits, label)

                vl.add(loss.item())
                va.add(acc)

        vl = vl.item()
        va = va.item()
        print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(epoch, vl, va))

        if va > trlog['max_acc']:
            trlog['max_acc'] = va
            save_model('max-acc')

        trlog['train_loss'].append(tl)
        trlog['train_acc'].append(ta)
        trlog['val_loss'].append(vl)
        trlog['val_acc'].append(va)

        torch.save(trlog, osp.join(args.save_path, 'trlog'))

        save_model('epoch-last')

        if epoch % args.save_epoch == 0:
            save_model('epoch-{}'.format(epoch))

        print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / args.max_epoch)))
Ejemplo n.º 6
0
def main():
    """Build data, model and optimizer from ``init()``, then train or test.

    With ``args.test`` set, only ``test(...)`` is run. Otherwise the model is
    optionally resumed (``args.resume``) or warm-started from pretrained
    weights (``args.warmup``). It is then trained for epochs
    ``args.start_epoch .. args.max_epoch - 1`` with validation after each
    epoch. Checkpoints are saved as ``'max_acc'`` (best validation accuracy,
    ties included) and ``'epoch-last'``, and ``test(...)`` is run at the end.

    Raises:
        NotImplementedError: ``args.model_type`` is ``Model_type.ConvNet``,
            for which no encoder is constructed here.
        ValueError: ``args.model_type`` or ``args.method_type`` is unsupported.
    """
    timer = Timer()
    args, writer = init()

    train_file = args.dataset_dir + 'train.json'
    val_file = args.dataset_dir + 'val.json'

    few_shot_params = dict(n_way=args.n_way, n_support=args.n_shot, n_query=args.n_query)
    # Debug runs use fewer episodes per epoch.
    n_episode = 10 if args.debug else 100
    if args.method_type is Method_type.baseline:
        # The baseline trains on plain mini-batches rather than episodes.
        train_datamgr = SimpleDataManager(train_file, args.dataset_dir, args.image_size, batch_size=64)
    else:
        train_datamgr = SetDataManager(train_file, args.dataset_dir, args.image_size,
                                       n_episode=n_episode, mode='train', **few_shot_params)
    train_loader = train_datamgr.get_data_loader(aug=True)

    val_datamgr = SetDataManager(val_file, args.dataset_dir, args.image_size,
                                 n_episode=n_episode, mode='val', **few_shot_params)
    val_loader = val_datamgr.get_data_loader(aug=False)

    if args.model_type is Model_type.ConvNet:
        # No ConvNet encoder is built here. Fail early instead of reaching the
        # method constructors below with ``encoder`` unbound.
        raise NotImplementedError('model_type ConvNet is not supported: no encoder is constructed for it')
    elif args.model_type is Model_type.ResNet12:
        from methods.backbone import ResNet12
        encoder = ResNet12()
    else:
        raise ValueError('unsupported model_type: {}'.format(args.model_type))

    if args.method_type is Method_type.baseline:
        from methods.baselinetrain import BaselineTrain
        model = BaselineTrain(encoder, args)
    elif args.method_type is Method_type.protonet:
        from methods.protonet import ProtoNet
        model = ProtoNet(encoder, args)
    else:
        raise ValueError('unsupported method_type: {}'.format(args.method_type))

    from torch.optim import SGD, lr_scheduler
    # Only ``model.encoder``'s parameters are handed to the optimizer.
    # NOTE(review): for the baseline, any parameters outside the encoder
    # (e.g. a classifier head, if BaselineTrain has one) are not updated -- confirm.
    if args.method_type is Method_type.baseline:
        optimizer = SGD(model.encoder.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epoch, eta_min=0, last_epoch=-1)
    else:
        # NOTE(review): weight decay is hard-coded here, while the baseline
        # branch uses args.weight_decay -- confirm this is intended.
        optimizer = SGD(model.encoder.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4,
                        nesterov=True)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5)

    args.ngpu = torch.cuda.device_count()
    torch.backends.cudnn.benchmark = True
    model = model.cuda()

    # Query labels for one episode: each class index repeated n_query times,
    # grouped by class.
    label = torch.from_numpy(np.repeat(range(args.n_way), args.n_query))
    label = label.cuda()

    if args.test:
        test(model, label, args, few_shot_params)
        return

    resume_OK = resume_model(model, optimizer, args, scheduler) if args.resume else False
    # Warm-start from pretrained weights only when no checkpoint was resumed.
    if not resume_OK and args.warmup is not None:
        load_pretrained_weights(model, args)

    if args.debug:
        args.max_epoch = args.start_epoch + 1

    for epoch in range(args.start_epoch, args.max_epoch):
        train_one_epoch(model, optimizer, args, train_loader, label, writer, epoch)
        scheduler.step()

        vl, va = val(model, args, val_loader, label)
        if writer is not None:
            writer.add_scalar('data/val_acc', float(va), epoch)
        print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(epoch, vl, va))

        # ``>=``: a tie with the best accuracy also refreshes 'max_acc'.
        if va >= args.max_acc:
            args.max_acc = va
            print('saving the best model! acc={:.4f}'.format(va))
            save_model(model, optimizer, args, epoch, args.max_acc, 'max_acc', scheduler)
        save_model(model, optimizer, args, epoch, args.max_acc, 'epoch-last', scheduler)
        # NOTE(review): skipped at epoch 0, presumably because
        # timer.measure(0) would divide by zero -- confirm against Timer.
        if epoch != 0:
            print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / args.max_epoch)))
    if writer is not None:
        writer.close()
    test(model, label, args, few_shot_params)
Ejemplo n.º 7
0
    def train(self):
        """The function for the pre-train phase.

        For each epoch from ``self.Learner.epoch`` to ``self.max_epoch_pre``
        (inclusive), this method:

        * trains on ``self.train_loader`` in the ``'pretrain'`` phase with
          ``optimizer_pre``,
        * steps ``scheduler_pre``,
        * validates episodically on ``self.val_loader`` in the ``'preval'`` phase,
        * logs to TensorBoard under ``self.Learner.pretrain_save_dir`` and
          appends the results to ``self.train_log``.

        The best-validation model and every 10th epoch's model are saved via
        ``self.Learner.save_pretrained_model``.
        """

        # Set the timer
        timer = Timer()
        # TensorBoard global step, continued from the starting epoch
        global_count = self.Learner.epoch * len(self.train_loader)
        # Set tensorboardX
        writer = SummaryWriter(log_dir=self.Learner.pretrain_save_dir)

        # Start pretrain
        for epoch in range(self.Learner.epoch, self.max_epoch_pre + 1):
            # set averager classes to record training losses and accuracies
            train_loss_averager = Averager()
            train_acc_averager = Averager()

            # using tqdm to read samples from train loader
            tqdm_gen = tqdm.tqdm(self.train_loader)
            for i, batch in enumerate(tqdm_gen, 1):
                # update global count number
                global_count += 1

                if self.use_gpu:
                    batch = [_.cuda() for _ in batch]

                data, label = batch
                # set to right phase
                self.Learner.set_phase('pretrain')
                # output logits for model
                logits = self.Learner(data)
                # calculate train loss
                loss = F.cross_entropy(logits, label)
                # calculate train accuracy
                acc = count_acc(logits, label)
                # write the tensorboardX records
                writer.add_scalar('train/loss', float(loss), global_count)
                writer.add_scalar('train/acc', float(acc), global_count)
                # Print loss and accuracy for this step
                tqdm_gen.set_description('Epoch {}, Loss={:.4f} Acc={:.4f}'.format(epoch, loss.item(), acc))

                # Add loss and accuracy for the averagers
                train_loss_averager.add(loss.item())
                train_acc_averager.add(acc)

                # Loss backwards and optimizer updates
                self.Learner.optimizer_pre.zero_grad()
                loss.backward()
                self.Learner.optimizer_pre.step()

            # Update the learning rate after this epoch's optimizer steps (the
            # ordering PyTorch >= 1.1 expects); stepping first skips the
            # schedule's initial learning rate.
            self.Learner.scheduler_pre.step()

            # replace the averagers by their mean values
            train_loss_averager = train_loss_averager.item()
            train_acc_averager = train_acc_averager.item()

            # Start validation for this epoch, set model to eval mode
            self.Learner.set_phase('preval')
            # Set averager classes to record validation losses and accuracies
            val_loss_averager = Averager()
            val_acc_averager = Averager()

            # run meta-validation
            # NOTE(review): not wrapped in torch.no_grad(); the 'preval' forward
            # may rely on autograd (e.g. inner-loop adaptation) -- confirm
            # before adding it.
            for i, batch in enumerate(self.val_loader, 1):
                data, _ = batch
                if self.use_gpu:
                    data = data.cuda()
                # the first shot_pre * way_pre samples are the support set
                p = self.shot_pre * self.way_pre
                data_shot, data_query = data[:p], data[p:]

                logits = self.Learner(data_shot, self.label_shot, data_query)
                loss = F.cross_entropy(logits, self.label_query)
                acc = count_acc(logits, self.label_query)
                val_loss_averager.add(loss.item())
                val_acc_averager.add(acc)

            # replace the validation averagers by their mean values
            val_loss_averager = val_loss_averager.item()
            val_acc_averager = val_acc_averager.item()
            # write the tensorboardX records
            writer.add_scalar('val/loss', float(val_loss_averager), epoch)
            writer.add_scalar('val/acc', float(val_acc_averager), epoch)
            # print loss and accuracy for this epoch
            print('Epoch {}, Val, Loss={:.4f} Acc={:.4f}'.format(epoch, val_loss_averager, val_acc_averager))

            # update best saved model
            if val_acc_averager > self.train_log['max_acc']:
                self.train_log['max_acc'] = val_acc_averager
                self.train_log['max_acc_epoch'] = epoch
                self.Learner.save_pretrained_model(epoch=epoch, max_metric=val_acc_averager, is_best=True)

            # save model every 10 epochs
            if epoch % 10 == 0:
                self.Learner.save_pretrained_model(epoch=epoch, max_metric=val_acc_averager, is_best=False)

            # update the logs
            self.train_log['train_loss'].append(train_loss_averager)
            self.train_log['train_acc'].append(train_acc_averager)
            self.train_log['val_loss'].append(val_loss_averager)
            self.train_log['val_acc'].append(val_acc_averager)

            # Every 10 epochs, print the best result so far and timing information
            if epoch % 10 == 0:
                print('Best Epoch {}, Best Val acc={:.4f}'.format(self.train_log['max_acc_epoch'],
                                                                  self.train_log['max_acc']))
                print('Running Time: {}, Estimated Time: {}'.format(timer.measure(),
                                                                    timer.measure(epoch / self.max_epoch_pre)))
        writer.close()
Ejemplo n.º 8
0
                'classifier': classifier.state_dict()
            }, os.path.join(save_path, '1_stage_best_model.pth'))
        log(log_file_path, 'Best model saving!!!')
        log(
            log_file_path,
            'Val_Epoch: [{}/{}]\tAccuracyBoth: {:.2f} +- {:.2f} %\tAccuracyBase: {:.2f} +- {:.2f} %\tAccuracyNovel: {:.2f} +- {:.2f} %'
            .format(epoch, params.num_epoch, val_acc_both, val_acc_both_ci95,
                    val_acc_base, val_acc_base_ci95, val_acc_novel,
                    val_acc_novel_ci95))
    else:
        log(
            log_file_path,
            'Val_Epoch: [{}/{}]\tAccuracyBoth: {:.2f} +- {:.2f} %\tAccuracyBase: {:.2f} +- {:.2f} %\tAccuracyNovel: {:.2f} +- {:.2f} %'
            .format(epoch, params.num_epoch, val_acc_both, val_acc_both_ci95,
                    val_acc_base, val_acc_base_ci95, val_acc_novel,
                    val_acc_novel_ci95))

    torch.save(
        {
            'embedding': embedding_model.state_dict(),
            'classifier': classifier.state_dict()
        }, os.path.join(save_path, '1_stage_last_model.pth'))

    log(
        log_file_path, 'Elapsed Time: {}/{}\n'.format(
            timer.measure(), timer.measure(epoch / float(params.num_epoch))))

log(
    log_file_path,
    'Best model saving!!!\tBest_Epoch: [{}/{}]\tAccuracyNovel: {:.2f} %'.
    format(max_val_epoch, params.num_epoch, max_val_acc))