Code example #1
    def checkAverager(self):
        acc = utils.Averager()
        acc.add(Variable(torch.Tensor([1, 2])))
        acc.add(Variable(torch.Tensor([[5, 6]])))
        assert acc.val() == 3.5

        acc = utils.Averager()
        acc.add(torch.Tensor([1, 2]))
        acc.add(torch.Tensor([[5, 6]]))
        assert acc.val() == 3.5
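
The test above pins down the core contract of utils.Averager: add() accepts plain numbers or tensors of any shape, and val() returns the mean over every element accumulated so far. The exact class differs from project to project, so the following is only a minimal sketch under those assumptions (the reset/add/val names and the item() alias are inferred from the usage in examples #1 and #2, nothing else):

import torch


class Averager:
    """Running mean over everything passed to add() (sketch, not any project's actual class)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.n_count = 0

    def add(self, v):
        # tensors contribute every element; plain numbers count as one sample
        if isinstance(v, torch.Tensor):
            self.sum += v.sum().item()
            self.n_count += v.numel()
        else:
            self.sum += float(v)
            self.n_count += 1

    def val(self):
        return self.sum / self.n_count if self.n_count else 0.0

    # some of the projects above (e.g. example #2) call this item() instead
    def item(self):
        return self.val()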
Code example #2
File: train_liif.py Project: zt706/liif
def train(train_loader, model, optimizer):
    model.train()
    loss_fn = nn.L1Loss()
    train_loss = utils.Averager()

    data_norm = config['data_norm']
    t = data_norm['inp']
    inp_sub = torch.FloatTensor(t['sub']).view(1, -1, 1, 1).cuda()
    inp_div = torch.FloatTensor(t['div']).view(1, -1, 1, 1).cuda()
    t = data_norm['gt']
    gt_sub = torch.FloatTensor(t['sub']).view(1, 1, -1).cuda()
    gt_div = torch.FloatTensor(t['div']).view(1, 1, -1).cuda()

    for batch in tqdm(train_loader, leave=False, desc='train'):
        for k, v in batch.items():
            batch[k] = v.cuda()

        inp = (batch['inp'] - inp_sub) / inp_div
        pred = model(inp, batch['coord'], batch['cell'])

        gt = (batch['gt'] - gt_sub) / gt_div
        loss = loss_fn(pred, gt)

        train_loss.add(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pred = None
        loss = None

    return train_loss.item()
Code example #3
def train(train_loader, model, optimizer):
    model.train()
    loss_fn = nn.L1Loss()
    train_loss = utils.Averager()

    inp_data_norm = train_args.inp_data_norm
    inp_sub, inp_div = list(map(float, inp_data_norm.split(',')))
    inp_sub = torch.FloatTensor([inp_sub]).view(1, -1, 1, 1).cuda()
    inp_div = torch.FloatTensor([inp_div]).view(1, -1, 1, 1).cuda()
    gt_data_norm = train_args.gt_data_norm
    gt_sub, gt_div = list(map(float, gt_data_norm.split(',')))
    gt_sub = torch.FloatTensor([gt_sub]).view(1, 1, -1).cuda()
    gt_div = torch.FloatTensor([gt_div]).view(1, 1, -1).cuda()

    for batch in tqdm(train_loader, leave=False, desc='train'):
        for k, v in batch.items():
            batch[k] = v.cuda()

        inp = (batch['inp'] - inp_sub) / inp_div
        pred = model(inp, batch['coord'], batch['cell'])

        gt = (batch['gt'] - gt_sub) / gt_div
        loss = loss_fn(pred, gt)

        train_loss.add(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pred = None; loss = None

    return train_loss.item()
Code example #4
File: train.py Project: wondervictor/dpl.pytorch
def test(net, criterion, output_dir):
    # output_dir = 'devkit/results/VOC2012/Main/comp2_cls_val_xxxx.txt'
    net.eval()
    test_iter = iter(test_loader)
    test_averager = utils.Averager()
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    i = 0
    while i < len(test_loader):
        img, lbl, box = test_iter.next()
        load_data(images, img)
        load_data(labels, lbl)
        boxes = Variable(torch.FloatTensor(box)).cuda()
        output = net(images, boxes).squeeze(0)
        loss = criterion(output, labels)
        test_averager.add(loss)

        for m in xrange(opt.num_class):
            cls_file = os.path.join(
                output_dir, 'cls_val_' + val_dataset.classes[m] + '.txt')
            with open(cls_file, 'a') as f:
                f.write(val_dataset.image_index[i] + ' ' + str(output[m]) +
                        '\n')

        print 'im_cls: {:d}/{:d}: {}'.format(i + 1, len(test_loader),
                                             val_dataset.image_index[i])
        i = i + 1

    val_dataset.do_python_eval(output_dir)
Code example #5
def test(net, criterion, output_dir):
    # output_dir = 'devkit/results/VOC2012/Main/comp2_cls_val_xxxx.txt'
    test_iter = iter(test_loader)
    test_averager = utils.Averager()
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    i = 0
    while i < len(test_loader):
        img, lbl, box, shapes = test_iter.next()
        load_data(images, img)
        load_data(labels, lbl)
        boxes = Variable(torch.FloatTensor(box)).cuda()
        shapes = Variable(torch.FloatTensor(shapes)).cuda()
        cls_score1, cls_score2, _ = net(images, shapes, boxes)
        loss1 = criterion(cls_score1, labels)
        loss2 = criterion(cls_score2, labels)
        loss = loss1 + loss2
        test_averager.add(loss)
        cls_score = cls_score1 + cls_score2
        cls_score = cls_score.cpu().squeeze(0).data.numpy()
        for m in xrange(opt.num_class):
            cls_file = os.path.join(
                output_dir, 'comp2_cls_{}_'.format('val') +
                val_dataset.classes[m] + '.txt')
            with open(cls_file, 'a') as f:
                f.write(val_dataset.image_index[i] + ' ' + str(cls_score[m]) +
                        '\n')

        print 'im_cls: {:d}/{:d}: {}'.format(i + 1, len(test_loader),
                                             val_dataset.image_index[i])
        i = i + 1
    print("Avg Loss: {}".format(test_averager.val()))
Code example #6
File: training.py Project: jrieke/learning-algos
def evaluate(net, evaluation_data, params):
    """Evaluate the network on the evaluation_data."""
    averager = utils.Averager()
    for i_sample, (x, y_true) in enumerate(evaluation_data):
        x = x.flatten()
        y_true = y_true.flatten()
        y_pred = net.forward(x, params)

        true_label = y_true[
            0]  # target values in validation_data are labels (in training_data they are one-hot vectors)
        averager.add('loss', cross_entropy(y_pred, true_label))
        averager.add('acc', true_label == y_pred.argmax())

    print('Eval set average:\t', averager)
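
Examples #6 and #9 (below) use a different, keyed interface: add(key, value) keeps one running mean per metric name, and printing the object reports all of them. A minimal sketch of that interface (the class name here is hypothetical; it is not the implementation in jrieke/learning-algos):

from collections import defaultdict


class KeyedAverager:
    """Per-key running means; str() reports every tracked metric (sketch)."""

    def __init__(self):
        self.sums = defaultdict(float)
        self.counts = defaultdict(int)

    def add(self, key, value):
        # booleans (e.g. the 'acc' entries above) are averaged as 0/1
        self.sums[key] += float(value)
        self.counts[key] += 1

    def __str__(self):
        return ' - '.join(
            '{}: {:.4f}'.format(k, self.sums[k] / self.counts[k])
            for k in self.sums)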
Code example #7
def val(net, _dataset, criterion, max_iter=100):

    for p in crnn.parameters():
        p.requires_grad = False

    net.eval()
    data_loader = torch.utils.data.DataLoader(
        _dataset,
        shuffle=True,
        batch_size=opt.batch_size,
        num_workers=int(opt.workers),
        collate_fn=dataset.AlignCollate(img_height=opt.imgH, img_weight=opt.imgW, keep_ratio=opt.keep_ratio)
    )
    val_iter = iter(data_loader)

    i = 0
    n_correct = 0
    loss_avg = utils.Averager()

    max_iter = min(max_iter, len(data_loader))
    for i in range(max_iter):
        data = val_iter.next()
        i += 1
        cpu_images, cpu_texts = data
        batch_size = cpu_images.size(0)
        utils.load_data(image, cpu_images)
        t, l = converter.encode(cpu_texts)
        utils.load_data(text, t)
        utils.load_data(length, l)

        preds, _ = crnn(image)
        preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
        cost = criterion(preds, text, preds_size, length) / batch_size
        loss_avg.add(cost)

        _, preds = preds.max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
        for pred, target in zip(sim_preds, cpu_texts):
            if pred == target.lower():
                n_correct += 1

    raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:opt.n_test_disp]
    # for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts):
    #     print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

    accuracy = n_correct / float(max_iter * opt.batch_size)
    logger.log('Test loss: %f, accuracy: %f' % (loss_avg.val(), accuracy))
Code example #8
    def validate(self, logger):
        if self.val_dataloader is None:
            return 0
        logger.info('Start validate.')
        losses = utils.Averager()
        self.net.eval()
        n_correct = 0
        with torch.no_grad():
            for i, (imgs, labels) in enumerate(self.val_dataloader):
                batch_size = imgs.size()[0]
                imgs = imgs.cuda()
                preds = self.net(imgs).cpu()
                text, length = self.converter.encode(
                    labels
                )  # length: character count of each sample in the batch; text: indices of every character in the batch
                preds_size = torch.IntTensor(
                    [preds.size(0)] * batch_size)  # timestep * batchsize
                loss_avg = self.criterion(preds, text, preds_size,
                                          length) / batch_size

                losses.update(loss_avg.item(), batch_size)

                _, preds = preds.max(2)
                preds = preds.transpose(1, 0).contiguous().view(-1)
                sim_preds = self.converter.decode(preds.data,
                                                  preds_size.data,
                                                  raw=False)
                for pred, target in zip(sim_preds, labels):
                    if pred == target:
                        n_correct += 1

        accuracy = n_correct / float(losses.count)

        logger.info(
            'Evaling loss: {:.3f}, accuracy: {:.3f}, [#correct:{} / #total:{}]'
            .format(losses.val(), accuracy, n_correct, losses.count))

        return accuracy
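
Example #8 exercises yet another variant: update(value, n) records a per-sample average together with its sample count, count exposes the total number of samples seen, and val() returns the weighted mean. A minimal sketch under those assumptions (hypothetical, not the project's actual utils.Averager):

class WeightedAverager:
    """Weighted running mean; count is the total number of samples (sketch)."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value, n=1):
        # value is assumed to already be an average over n samples
        self.sum += float(value) * n
        self.count += n

    def val(self):
        return self.sum / self.count if self.count else 0.0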
Code example #9
File: training.py Project: jrieke/learning-algos
def train_epoch(net, training_data, params):
    """Train the network for one epoch on the training_data."""
    random.shuffle(training_data)
    averager = utils.Averager()
    start_time = time.time()
    for i_sample, (x, y_true) in enumerate(training_data):
        x = x.flatten()
        y_true = y_true.flatten()

        y_pred = net.update(x, y_true, params, averager)

        true_label = y_true.argmax(
        )  # target values in training_data are one-hot vectors (in validation_data they are labels)
        averager.add('loss', cross_entropy(y_pred, true_label))
        averager.add('acc', true_label == y_pred.argmax())

        if i_sample % 1000 == 0:
            print(
                f'{i_sample} / {len(training_data)} samples ({time.time()-start_time:.0f} s) - {averager}'
            )

    print('Took {:.0f} seconds'.format(time.time() - start_time))
    print('Train set average:\t', averager)
Code example #10
File: train_moco.py Project: tce/Bongard-LOGO
def main(config):
    svname = args.name
    if svname is None:
        svname = 'moco_{}'.format(config['train_dataset'])
        svname += '_' + config['model_args']['encoder']
        out_dim = config['model_args']['encoder_args']['out_dim']
        svname += '-out_dim' + str(out_dim)
    svname += '-seed' + str(args.seed)
    if args.tag is not None:
        svname += '_' + args.tag
    save_path = os.path.join(args.save_dir, svname)
    utils.ensure_path(save_path, remove=False)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    random_state = np.random.RandomState(args.seed)
    print('seed:', args.seed)

    logger = utils.Logger(file_name=os.path.join(save_path, "log_sdout.txt"),
                          file_mode="a+",
                          should_flush=True)

    #### Dataset ####

    # train
    train_dataset = datasets.make(config['train_dataset'],
                                  **config['train_dataset_args'])
    train_loader = DataLoader(train_dataset,
                              config['batch_size'],
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True,
                              drop_last=True)
    utils.log('train dataset: {} (x{})'.format(train_dataset[0][0][0].shape,
                                               len(train_dataset)))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(train_dataset, 'train_dataset', writer)

    # val
    if config.get('val_dataset'):
        eval_val = True
        val_dataset = datasets.make(config['val_dataset'],
                                    **config['val_dataset_args'])
        val_loader = DataLoader(val_dataset,
                                config['batch_size'],
                                num_workers=8,
                                pin_memory=True,
                                drop_last=True)
        utils.log('val dataset: {} (x{})'.format(val_dataset[0][0][0].shape,
                                                 len(val_dataset)))
        if config.get('visualize_datasets'):
            utils.visualize_dataset(val_dataset, 'val_dataset', writer)
    else:
        eval_val = False

    # few-shot eval
    if config.get('eval_fs'):
        ef_epoch = config.get('eval_fs_epoch')
        if ef_epoch is None:
            ef_epoch = 5
        eval_fs = True
        n_way = 2
        n_query = 1
        n_shot = 6

        if config.get('ep_per_batch') is not None:
            ep_per_batch = config['ep_per_batch']
        else:
            ep_per_batch = 1

        # tvals
        fs_loaders = {}
        tval_name_ntasks_dict = {
            'tval': 2000,
            'tval_ff': 600,
            'tval_bd': 480,
            'tval_hd_comb': 400,
            'tval_hd_novel': 320
        }  # numbers depend on dataset
        for tval_type in tval_name_ntasks_dict.keys():
            if config.get('{}_dataset'.format(tval_type)):
                tval_dataset = datasets.make(
                    config['{}_dataset'.format(tval_type)],
                    **config['{}_dataset_args'.format(tval_type)])
                utils.log('{} dataset: {} (x{})'.format(
                    tval_type, tval_dataset[0][0][0].shape, len(tval_dataset)))
                if config.get('visualize_datasets'):
                    utils.visualize_dataset(tval_dataset, 'tval_ff_dataset',
                                            writer)
                tval_sampler = BongardSampler(
                    tval_dataset.n_tasks,
                    n_batch=tval_name_ntasks_dict[tval_type] // ep_per_batch,
                    ep_per_batch=ep_per_batch,
                    seed=random_state.randint(2**31))
                tval_loader = DataLoader(tval_dataset,
                                         batch_sampler=tval_sampler,
                                         num_workers=8,
                                         pin_memory=True)
                fs_loaders.update({tval_type: tval_loader})
            else:
                fs_loaders.update({tval_type: None})

    else:
        eval_fs = False

    ########

    #### Model and Optimizer ####

    if config.get('load'):
        model_sv = torch.load(config['load'])
        model = models.load(model_sv)
    else:
        model = models.make(config['model'], **config['model_args'])

    if eval_fs:
        fs_model = models.make('meta-baseline', encoder=None)
        fs_model.encoder = model.encoder

    if config.get('_parallel'):
        model = nn.DataParallel(model)
        if eval_fs:
            fs_model = nn.DataParallel(fs_model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(model.parameters(),
                                                   config['optimizer'],
                                                   **config['optimizer_args'])

    ########

    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    for epoch in range(1, max_epoch + 1 + 1):

        timer_epoch.s()
        aves_keys = ['tl', 'ta', 'vl', 'va', 'tvl', 'tva']
        if eval_fs:
            for k, v in fs_loaders.items():
                if v is not None:
                    aves_keys += ['fsa' + k.split('tval')[-1]]
        aves = {ave_k: utils.Averager() for ave_k in aves_keys}

        # train
        model.train()
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        for data, _ in tqdm(train_loader, desc='train', leave=False):
            logits, label = model(im_q=data[0].cuda(), im_k=data[1].cuda())

            loss = F.cross_entropy(logits, label)
            acc = utils.compute_acc(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None
            loss = None

        # val
        if eval_val:
            model.eval()
            for data, _ in tqdm(val_loader, desc='val', leave=False):
                with torch.no_grad():
                    logits, label = model(im_q=data[0].cuda(),
                                          im_k=data[1].cuda())
                    loss = F.cross_entropy(logits, label)
                    acc = utils.compute_acc(logits, label)

                aves['vl'].add(loss.item())
                aves['va'].add(acc)

        if eval_fs and (epoch % ef_epoch == 0 or epoch == max_epoch + 1):
            fs_model.eval()
            for k, v in fs_loaders.items():
                if v is not None:
                    ave_key = 'fsa' + k.split('tval')[-1]
                    np.random.seed(0)
                    for data, _ in tqdm(v, desc=ave_key, leave=False):
                        x_shot, x_query = fs.split_shot_query(
                            data[0].cuda(),
                            n_way,
                            n_shot,
                            n_query,
                            ep_per_batch=ep_per_batch)
                        label_query = fs.make_nk_label(
                            n_way, n_query, ep_per_batch=ep_per_batch).cuda()
                        with torch.no_grad():
                            logits = fs_model(x_shot, x_query).view(-1, n_way)
                            acc = utils.compute_acc(logits, label_query)
                        aves[ave_key].add(acc)

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)

        if epoch <= max_epoch:
            epoch_str = str(epoch)
        else:
            epoch_str = 'ex'
        log_str = 'epoch {}, train {:.4f}|{:.4f}'.format(
            epoch_str, aves['tl'], aves['ta'])
        writer.add_scalars('loss', {'train': aves['tl']}, epoch)
        writer.add_scalars('acc', {'train': aves['ta']}, epoch)

        if eval_val:
            log_str += ', val {:.4f}|{:.4f}, tval {:.4f}|{:.4f}'.format(
                aves['vl'], aves['va'], aves['tvl'], aves['tva'])
            writer.add_scalars('loss', {'val': aves['vl']}, epoch)
            writer.add_scalars('loss', {'tval': aves['tvl']}, epoch)
            writer.add_scalars('acc', {'val': aves['va']}, epoch)
            writer.add_scalars('acc', {'tval': aves['tva']}, epoch)

        if eval_fs and (epoch % ef_epoch == 0 or epoch == max_epoch + 1):
            log_str += ', fs'
            for ave_key in aves_keys:
                if 'fsa' in ave_key:
                    log_str += ' {}: {:.4f}'.format(ave_key, aves[ave_key])
                    writer.add_scalars('acc', {ave_key: aves[ave_key]}, epoch)

        if epoch <= max_epoch:
            log_str += ', {} {}/{}'.format(t_epoch, t_used, t_estimate)
        else:
            log_str += ', {}'.format(t_epoch)
        utils.log(log_str)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,
            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),
            'training': training,
        }
        if epoch <= max_epoch:
            torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))

            if (save_epoch is not None) and epoch % save_epoch == 0:
                torch.save(
                    save_obj,
                    os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

            if aves['va'] > max_va:
                max_va = aves['va']
                torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))
        else:
            torch.save(save_obj, os.path.join(save_path, 'epoch-ex.pth'))

        writer.flush()

    print('finished training!')
    logger.close()
Code example #11
log_dir = expr_dir + 'log/'
if not os.path.exists(log_dir):
    os.mkdir(log_dir)

logger = utils.Logger(stdio=True, log_file=log_dir + "testing.log")
images = Variable(torch.FloatTensor(batch_size, 3, opt.img_size, opt.img_size))
labels = Variable(torch.FloatTensor(batch_size, opt.num_class))

if opt.cuda:
    criterion = criterion.cuda()
    dpl = dpl.cuda()
    images = images.cuda()
    labels = labels.cuda()

averager = utils.Averager()


def load_data(v, data):
    v.data.resize_(data.size()).copy_(data)


def test(net, criterion, output_dir):
    # output_dir = 'devkit/results/VOC2012/Main/comp2_cls_val_xxxx.txt'
    test_iter = iter(test_loader)
    test_averager = utils.Averager()
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    i = 0
    while i < len(test_loader):
Code example #12
image = torch.FloatTensor(opt.batch_size, 3, opt.imgH, opt.imgH)
text = torch.IntTensor(opt.batch_size * 5)
length = torch.IntTensor(opt.batch_size)

if opt.cuda:
    crnn.cuda()
    crnn = torch.nn.DataParallel(crnn, device_ids=range(opt.ngpu))
    image = image.cuda()
    criterion = criterion.cuda()

image = Variable(image)
text = Variable(text)
length = Variable(length)

# loss averager
loss_avg = utils.Averager()

# setup optimizer
if opt.adam:
    optimizer = optim.Adam(crnn.parameters(), lr=opt.lr,
                           betas=(opt.beta1, 0.999))
elif opt.adadelta:
    optimizer = optim.Adadelta(crnn.parameters(), lr=opt.lr)
else:
    optimizer = optim.RMSprop(crnn.parameters(), lr=opt.lr)


def val(net, _dataset, criterion, max_iter=100):

    for p in crnn.parameters():
        p.requires_grad = False
Code example #13
def main(config):
    svname = config.get('sv_name')
    if args.tag is not None:
        svname += '_' + args.tag
    config['sv_name'] = svname
    save_path = os.path.join('./save', svname)
    utils.ensure_path(save_path)
    utils.set_log_path(save_path)
    utils.log(svname)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))
    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    #### Dataset ####

    n_way, n_shot = config['n_way'], config['n_shot']
    n_query = config['n_query']
    n_pseudo = config['n_pseudo']
    ep_per_batch = config['ep_per_batch']

    if config.get('test_batches') is not None:
        test_batches = config['test_batches']
    else:
        test_batches = config['train_batches']

    for s in ['train', 'val', 'tval']:
        if config.get(f"{s}_dataset_args") is not None:
            config[f"{s}_dataset_args"]['data_dir'] = os.path.join(os.getcwd(), os.pardir, 'data_root')

    # train
    train_dataset = CustomDataset(config['train_dataset'], save_dir=config.get('load_encoder'),
                                  **config['train_dataset_args'])

    if config['train_dataset_args']['split'] == 'helper':
        with open(os.path.join(save_path, 'train_helper_cls.pkl'), 'wb') as f:
            pkl.dump(train_dataset.dataset_classes, f)

    train_sampler = EpisodicSampler(train_dataset, config['train_batches'], n_way, n_shot, n_query,
                                    n_pseudo, episodes_per_batch=ep_per_batch)
    train_loader = DataLoader(train_dataset, batch_sampler=train_sampler,
                                  num_workers=4, pin_memory=True)

    # tval
    if config.get('tval_dataset'):
        tval_dataset = CustomDataset(config['tval_dataset'],
                                     **config['tval_dataset_args'])

        tval_sampler = EpisodicSampler(tval_dataset, test_batches, n_way, n_shot, n_query,
                                       n_pseudo, episodes_per_batch=ep_per_batch)
        tval_loader = DataLoader(tval_dataset, batch_sampler=tval_sampler,
                                 num_workers=4, pin_memory=True)
    else:
        tval_loader = None

    # val
    val_dataset = CustomDataset(config['val_dataset'],
                                **config['val_dataset_args'])
    val_sampler = EpisodicSampler(val_dataset, test_batches, n_way, n_shot, n_query,
                                  n_pseudo, episodes_per_batch=ep_per_batch)
    val_loader = DataLoader(val_dataset, batch_sampler=val_sampler,
                            num_workers=4, pin_memory=True)


    #### Model and optimizer ####

    if config.get('load'):
        model_sv = torch.load(config['load'])
        model = models.load(model_sv)
    else:
        model = models.make(config['model'], **config['model_args'])
        if config.get('load_encoder'):
            encoder = models.load(torch.load(config['load_encoder'])).encoder
            model.encoder.load_state_dict(encoder.state_dict())
            if config.get('freeze_encoder'):
                for param in model.encoder.parameters():
                    param.requires_grad = False

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(
        model.parameters(),
        config['optimizer'], **config['optimizer_args'])

    ########

    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    aves_keys = ['tl', 'ta', 'tvl', 'tva', 'vl', 'va']
    trlog = dict()
    for k in aves_keys:
        trlog[k] = []

    for epoch in range(1, max_epoch + 1):
        timer_epoch.s()
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        if config.get('freeze_bn'):
            utils.freeze_bn(model)
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
        np.random.seed(epoch)

        for data in tqdm(train_loader, desc='train', leave=False):
            x_shot, x_query, x_pseudo = fs.split_shot_query(
                data.cuda(), n_way, n_shot, n_query, n_pseudo,
                ep_per_batch=ep_per_batch)
            label = fs.make_nk_label(n_way, n_query,
                                     ep_per_batch=ep_per_batch).cuda()

            logits = model(x_shot, x_query, x_pseudo)
            logits = logits.view(-1, n_way)
            loss = F.cross_entropy(logits, label)
            acc = utils.compute_acc(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None; loss = None

        # eval
        model.eval()
        for name, loader, name_l, name_a in [
            ('tval', tval_loader, 'tvl', 'tva'),
            ('val', val_loader, 'vl', 'va')]:

            if (config.get('tval_dataset') is None) and name == 'tval':
                continue

            np.random.seed(0)
            for data in tqdm(loader, desc=name, leave=False):
                x_shot, x_query, x_pseudo = fs.split_shot_query(
                    data.cuda(), n_way, n_shot, n_query, n_pseudo,
                    ep_per_batch=ep_per_batch)
                label = fs.make_nk_label(n_way, n_query,
                                         ep_per_batch=ep_per_batch).cuda()

                with torch.no_grad():
                    logits = model(x_shot, x_query, x_pseudo)
                    logits = logits.view(-1, n_way)
                    loss = F.cross_entropy(logits, label)
                    acc = utils.compute_acc(logits, label)

                aves[name_l].add(loss.item())
                aves[name_a].add(acc)

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()
            trlog[k].append(aves[k])

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)
        utils.log('epoch {}, train {:.4f}|{:.4f}, tval {:.4f}|{:.4f}, '
                  'val {:.4f}|{:.4f}, {} {}/{}'.format(
            epoch, aves['tl'], aves['ta'], aves['tvl'], aves['tva'],
            aves['vl'], aves['va'], t_epoch, t_used, t_estimate))

        writer.add_scalars('loss', {
            'train': aves['tl'],
            'tval': aves['tvl'],
            'val': aves['vl'],
        }, epoch)
        writer.add_scalars('acc', {
            'train': aves['ta'],
            'tval': aves['tva'],
            'val': aves['va'],
        }, epoch)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,

            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),

            'training': training,
        }
        torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))
        torch.save(trlog, os.path.join(save_path, 'trlog.pth'))

        if (save_epoch is not None) and epoch % save_epoch == 0:
            torch.save(save_obj,
                       os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

        if aves['va'] > max_va:
            max_va = aves['va']
            torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))

        writer.flush()
Code example #14
def main(config):
    svname = args.name
    if svname is None:
        svname = 'meta'
    if args.tag is not None:
        svname += '_' + args.tag
    save_path = os.path.join('./save', svname)
    utils.ensure_path(save_path)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    #### Dataset ####

    if args.dataset == 'all':
        train_lst = ['ilsvrc_2012', 'omniglot', 'aircraft', 'cu_birds', 'dtd',
                'quickdraw', 'fungi', 'vgg_flower']
        eval_lst = ['ilsvrc_2012']
    else:
        train_lst = [args.dataset]
        eval_lst = [args.dataset]

    if config.get('no_train') == True:
        train_iter = None
    else:
        trainset = make_md(train_lst, 'episodic', split='train', image_size=126)
        train_iter = trainset.make_one_shot_iterator().get_next()

    if config.get('no_val') == True:
        val_iter = None
    else:
        valset = make_md(eval_lst, 'episodic', split='val', image_size=126)
        val_iter = valset.make_one_shot_iterator().get_next()

    testset = make_md(eval_lst, 'episodic', split='test', image_size=126)
    test_iter = testset.make_one_shot_iterator().get_next()

    sess = tf.Session()

    ########

    #### Model and optimizer ####

    if config.get('load'):
        model_sv = torch.load(config['load'])
        model = models.load(model_sv)
    else:
        model = models.make(config['model'], **config['model_args'])

        if config.get('load_encoder'):
            encoder = models.load(torch.load(config['load_encoder'])).encoder
            model.encoder.load_state_dict(encoder.state_dict())

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(
            model.parameters(),
            config['optimizer'], **config['optimizer_args'])

    ########
    
    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    aves_keys = ['tl', 'ta', 'tvl', 'tva', 'vl', 'va']
    trlog = dict()
    for k in aves_keys:
        trlog[k] = []

    def process_data(e):
        e = list(e[0])
        transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(146),
            transforms.CenterCrop(128),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])
        ])
        for ii in [0, 3]:
            e[ii] = ((e[ii] + 1.0) * 0.5 * 255).astype('uint8')
            tmp = torch.zeros(len(e[ii]), 3, 128, 128).float()
            for i in range(len(e[ii])):
                tmp[i] = transform(e[ii][i])
            e[ii] = tmp.cuda()

        e[1] = torch.from_numpy(e[1]).long().cuda()
        e[4] = torch.from_numpy(e[4]).long().cuda()

        return e

    for epoch in range(1, max_epoch + 1):
        timer_epoch.s()
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        if config.get('freeze_bn'):
            utils.freeze_bn(model) 
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        if config.get('no_train') == True:
            pass
        else:
            for i_ep in tqdm(range(config['n_train'])):

                e = process_data(sess.run(train_iter))
                loss, acc = model(e[0], e[1], e[3], e[4])

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                aves['tl'].add(loss.item())
                aves['ta'].add(acc)

                loss = None 

        # eval
        model.eval()

        for name, ds_iter, name_l, name_a in [
                ('tval', val_iter, 'tvl', 'tva'),
                ('val', test_iter, 'vl', 'va')]:
            if config.get('no_val') == True and name == 'tval':
                continue

            for i_ep in tqdm(range(config['n_eval'])):

                e = process_data(sess.run(ds_iter))

                with torch.no_grad():
                    loss, acc = model(e[0], e[1], e[3], e[4])
                
                aves[name_l].add(loss.item())
                aves[name_a].add(acc)

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()
            trlog[k].append(aves[k])

        _sig = 0

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)
        utils.log('epoch {}, train {:.4f}|{:.4f}, tval {:.4f}|{:.4f}, '
                'val {:.4f}|{:.4f}, {} {}/{} (@{})'.format(
                epoch, aves['tl'], aves['ta'], aves['tvl'], aves['tva'],
                aves['vl'], aves['va'], t_epoch, t_used, t_estimate, _sig))

        writer.add_scalars('loss', {
            'train': aves['tl'],
            'tval': aves['tvl'],
            'val': aves['vl'],
        }, epoch)
        writer.add_scalars('acc', {
            'train': aves['ta'],
            'tval': aves['tva'],
            'val': aves['va'],
        }, epoch)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,

            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),

            'training': training,
        }
        torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))
        torch.save(trlog, os.path.join(save_path, 'trlog.pth'))

        if (save_epoch is not None) and epoch % save_epoch == 0:
            torch.save(save_obj,
                    os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

        if aves['va'] > max_va:
            max_va = aves['va']
            torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))

        writer.flush()
Code example #15
File: train_classifier.py Project: apu6/pseudo-shots
def main(config):
    svname = args.name
    if svname is None:
        svname = f"classifier-{config['train_dataset']}-{config['model_args']['encoder']}"
        clsfr = config['model_args']['classifier']
        if clsfr != 'linear-classifier':
            svname += '-' + clsfr

    svname += '-aux' + str(args.aux_level)

    if args.topk is not None:
        svname += f"-top{args.topk}"

    if args.tag is not None:
        svname += '_' + args.tag

    save_path = os.path.join('./save', svname)
    utils.ensure_path(save_path)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    #### Dataset ####

    for s in ['train', 'val', 'tval', 'fs', 'fs_val']:
        if config.get(f"{s}_dataset_args") is not None:
            config[f"{s}_dataset_args"]['data_dir'] = os.path.join(
                os.getcwd(), os.pardir, 'data_root')

    # train
    train_dataset = TrainDataset(name=config['train_dataset'],
                                 **config['train_dataset_args'])
    train_loader = DataLoader(train_dataset,
                              config['batch_size'],
                              shuffle=True,
                              num_workers=16,
                              pin_memory=True,
                              drop_last=True)

    with open(os.path.join(save_path, 'training_classes.pkl'), 'wb') as f:
        pkl.dump(train_dataset.separated_training_classes, f)

    # val
    if config.get('val_dataset'):
        eval_val = True
        val_dataset = TrainDataset(config['val_dataset'],
                                   **config['val_dataset_args'])
        val_loader = DataLoader(val_dataset,
                                config['batch_size'],
                                num_workers=16,
                                pin_memory=True,
                                drop_last=True)
    else:
        eval_val = False

    # few-shot eval
    fs_loaders = {'fs_dataset': list(), 'fs_val_dataset': list()}
    for key in fs_loaders.keys():
        if config.get(key):
            ef_epoch = config.get('eval_fs_epoch')
            if ef_epoch is None:
                ef_epoch = 5
            eval_fs = True

            fs_dataset = CustomDataset(config[key], **config[key + '_args'])

            n_way = config['n_way'] if config.get('n_way') else 5
            n_query = config['n_query'] if config.get('n_query') else 15
            if config.get('n_pseudo') is not None:
                n_pseudo = config['n_pseudo']
            else:
                n_pseudo = 15
            n_batches = config['n_batches'] if config.get('n_batches') else 200
            ep_per_batch = config['ep_per_batch'] if config.get(
                'ep_per_batch') else 4
            n_shots = [1, 5]
            for n_shot in n_shots:
                fs_sampler = EpisodicSampler(fs_dataset,
                                             n_batches,
                                             n_way,
                                             n_shot,
                                             n_query,
                                             n_pseudo,
                                             episodes_per_batch=ep_per_batch)
                fs_loader = DataLoader(fs_dataset,
                                       batch_sampler=fs_sampler,
                                       num_workers=16,
                                       pin_memory=True)
                fs_loaders[key].append(fs_loader)
        else:
            eval_fs = False

    eval_fs = False
    for key in fs_loaders.keys():
        if config.get(key):
            eval_fs = True

    #### Model and Optimizer ####

    config['model_args']['classifier_args'][
        'n_classes'] = train_dataset.n_classes
    model = models.make(config['model'], **config['model_args'])

    if eval_fs:
        fs_model = models.make('meta-baseline', encoder=None)
        fs_model.encoder = model.encoder

    if config.get('_parallel'):
        model = nn.DataParallel(model)
        if eval_fs:
            fs_model = nn.DataParallel(fs_model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(model.parameters(),
                                                   config['optimizer'],
                                                   **config['optimizer_args'])

    ########

    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    for epoch in range(1, max_epoch + 1 + 1):
        timer_epoch.s()
        aves_keys = ['tl', 'ta', 'vl', 'va']
        if eval_fs:
            for n_shot in n_shots:
                aves_keys += ['fsa-' + str(n_shot)]
                if config.get('fs_val_dataset'):
                    aves_keys += ['fsav-' + str(n_shot)]
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        for data, label in tqdm(train_loader, desc='train', leave=False):
            data, label = data.cuda(), label.cuda()
            logits = model(data)
            loss = F.cross_entropy(logits, label)
            acc = utils.compute_acc(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None
            loss = None

        # eval
        if eval_val:
            model.eval()
            for data, label in tqdm(val_loader, desc='val', leave=False):
                data, label = data.cuda(), label.cuda()
                with torch.no_grad():
                    logits = model(data)
                    loss = F.cross_entropy(logits, label)
                    acc = utils.compute_acc(logits, label)

                aves['vl'].add(loss.item())
                aves['va'].add(acc)

        if eval_fs and (epoch % ef_epoch == 0 or epoch == max_epoch + 1):
            fs_model.eval()
            for key in fs_loaders.keys():
                if len(fs_loaders[key]) == 0:
                    continue
                tag = 'v' if key == 'fs_val_dataset' else ''
                for i, n_shot in enumerate(n_shots):
                    np.random.seed(0)
                    for data in tqdm(fs_loaders[key][i],
                                     desc='fs' + tag + '-' + str(n_shot),
                                     leave=False):
                        x_shot, x_query, x_pseudo = fs.split_shot_query(
                            data.cuda(),
                            n_way,
                            n_shot,
                            n_query,
                            pseudo=n_pseudo,
                            ep_per_batch=ep_per_batch)
                        label = fs.make_nk_label(
                            n_way, n_query, ep_per_batch=ep_per_batch).cuda()
                        with torch.no_grad():
                            logits = fs_model(x_shot, x_query, x_pseudo)
                            logits = logits.view(-1, n_way)
                            acc = utils.compute_acc(logits, label)
                        aves['fsa' + tag + '-' + str(n_shot)].add(acc)

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)

        if epoch <= max_epoch:
            epoch_str = str(epoch)
        else:
            epoch_str = 'ex'
        log_str = 'epoch {}, train {:.4f}|{:.4f}'.format(
            epoch_str, aves['tl'], aves['ta'])
        writer.add_scalars('loss', {'train': aves['tl']}, epoch)
        writer.add_scalars('acc', {'train': aves['ta']}, epoch)

        if eval_val:
            log_str += ', val {:.4f}|{:.4f}'.format(aves['vl'], aves['va'])
            writer.add_scalars('loss', {'val': aves['vl']}, epoch)
            writer.add_scalars('acc', {'val': aves['va']}, epoch)

        if eval_fs and (epoch % ef_epoch == 0 or epoch == max_epoch + 1):
            for key in fs_loaders.keys():
                if len(fs_loaders[key]) == 0:
                    continue
                tag = 'v' if key == 'fs_val_dataset' else ''
                log_str += ', fs' + tag
                for n_shot in n_shots:
                    key = 'fsa' + tag + '-' + str(n_shot)
                    log_str += ' {}: {:.4f}'.format(n_shot, aves[key])
                    writer.add_scalars('acc', {key: aves[key]}, epoch)

        if epoch <= max_epoch:
            log_str += ', {} {}/{}'.format(t_epoch, t_used, t_estimate)
        else:
            log_str += ', {}'.format(t_epoch)
        utils.log(log_str)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,
            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),
            'training': training,
        }
        if epoch <= max_epoch:
            torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))

            if (save_epoch is not None) and epoch % save_epoch == 0:
                torch.save(
                    save_obj,
                    os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

            if aves['va'] > max_va:
                max_va = aves['va']
                torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))
        else:
            torch.save(save_obj, os.path.join(save_path, 'epoch-ex.pth'))

        writer.flush()
Code example #16
File: train.py Project: cybworkspace/video-repr
def main(config):
    # Environment setup
    save_dir = config['save_dir']
    utils.ensure_path(save_dir)
    with open(osp.join(save_dir, 'config.yaml'), 'w') as f:
        yaml.dump(config, f, sort_keys=False)
    global log, writer
    logger = set_logger(osp.join(save_dir, 'log.txt'))
    log = logger.info
    writer = SummaryWriter(osp.join(save_dir, 'tensorboard'))

    os.environ['WANDB_NAME'] = config['exp_name']
    os.environ['WANDB_DIR'] = config['save_dir']
    if not config.get('wandb_upload', False):
        os.environ['WANDB_MODE'] = 'dryrun'
    t = config['wandb']
    os.environ['WANDB_API_KEY'] = t['api_key']
    wandb.init(project=t['project'], entity=t['entity'], config=config)

    log('logging init done.')
    log(f'wandb id: {wandb.run.id}')

    # Dataset, model and optimizer
    train_dataset = datasets.make((config['train_dataset']))
    test_dataset = datasets.make((config['test_dataset']))

    model = models.make(config['model'], args=None).cuda()
    log(f'model #params: {utils.compute_num_params(model)}')

    n_gpus = len(os.environ['CUDA_VISIBLE_DEVICES'].split(','))
    if n_gpus > 1:
        model = nn.DataParallel(model)

    optimizer = utils.make_optimizer(model.parameters(), config['optimizer'])

    train_loader = DataLoader(train_dataset, config['batch_size'], shuffle=True,
                              num_workers=8, pin_memory=True)
    test_loader = DataLoader(test_dataset, config['batch_size'],
                             num_workers=8, pin_memory=True)

    # Ready for training
    max_epoch = config['max_epoch']
    n_milestones = config.get('n_milestones', 1)
    milestone_epoch = max_epoch // n_milestones
    min_test_loss = 1e18

    sample_batch_train = sample_data_batch(train_dataset).cuda()
    sample_batch_test = sample_data_batch(test_dataset).cuda()

    epoch_timer = utils.EpochTimer(max_epoch)
    for epoch in range(1, max_epoch + 1):
        log_text = f'epoch {epoch}'

        # Train
        model.train()

        adjust_lr(optimizer, epoch, max_epoch, config)
        log_temp_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        ave_scalars = {k: utils.Averager() for k in ['loss']}

        pbar = tqdm(train_loader, desc='train', leave=False)
        for data in pbar:
            data = data.cuda()
            t = train_step(model, data, data, optimizer)
            for k, v in t.items():
                ave_scalars[k].add(v, len(data))
            pbar.set_description(desc=f"train loss:{t['loss']:.4f}")

        log_text += ', train:'
        for k, v in ave_scalars.items():
            v = v.item()
            log_text += f' {k}={v:.4f}'
            log_temp_scalar('train/' + k, v, epoch)

        # Test
        model.eval()

        ave_scalars = {k: utils.Averager() for k in ['loss']}

        pbar = tqdm(test_loader, desc='test', leave=False)
        for data in pbar:
            data = data.cuda()
            t = eval_step(model, data, data)
            for k, v in t.items():
                ave_scalars[k].add(v, len(data))
            pbar.set_description(desc=f"test loss:{t['loss']:.4f}")

        log_text += ', test:'
        for k, v in ave_scalars.items():
            v = v.item()
            log_text += f' {k}={v:.4f}'
            log_temp_scalar('test/' + k, v, epoch)

        test_loss = ave_scalars['loss'].item()

        if epoch % milestone_epoch == 0:
            with torch.no_grad():
                pred = model(sample_batch_train).clamp(0, 1)
                video_batch = torch.cat([sample_batch_train, pred], dim=0)
                log_temp_videos('train/videos', video_batch, epoch)
                img_batch = video_batch[:, :, 3, :, :]
                log_temp_images('train/images', img_batch, epoch)

                pred = model(sample_batch_test).clamp(0, 1)
                video_batch = torch.cat([sample_batch_test, pred], dim=0)
                log_temp_videos('test/videos', video_batch, epoch)
                img_batch = video_batch[:, :, 3, :, :]
                log_temp_images('test/images', img_batch, epoch)

        # Summary and save
        log_text += ', {} {}/{}'.format(*epoch_timer.step())
        log(log_text)

        model_ = model.module if n_gpus > 1 else model
        model_spec = config['model']
        model_spec['sd'] = model_.state_dict()
        optimizer_spec = config['optimizer']
        optimizer_spec['sd'] = optimizer.state_dict()
        pth_file = {
            'model': model_spec,
            'optimizer': optimizer_spec,
            'epoch': epoch,
        }

        if test_loss < min_test_loss:
            min_test_loss = test_loss
            wandb.run.summary['min_test_loss'] = min_test_loss
            torch.save(pth_file, osp.join(save_dir, 'min-test-loss.pth'))

        torch.save(pth_file, osp.join(save_dir, 'epoch-last.pth'))

        writer.flush()
Code example #17
def train(opt):
    model = www_model_jamo_vertical.STR(opt, device)
    print(
        'model parameters. height {}, width {}, num of fiducial {}, input channel {}, output channel {}, hidden size {},     batch max length {}'
        .format(opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel,
                opt.output_channel, opt.hidden_size, opt.batch_max_length))

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)

        except Exception as e:
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # load pretrained model
    if opt.saved_model != '':
        base_path = './models'
        print(
            f'looking for pretrained model from {os.path.join(base_path, opt.saved_model)}'
        )

        try:
            model.load_state_dict(
                torch.load(os.path.join(base_path, opt.saved_model)))
            print('loading complete ')
        except Exception as e:
            print(e)
            print('could not find model')

    #data parallel for multi GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()

    # loss
    criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
        device)  #ignore [GO] token = ignore index 0
    loss_avg = utils.Averager()

    # keep only the parameters that require gradients
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params : ', sum(params_num))

    # optimizer

    #     base_opt = optim.Adadelta(filtered_parameters, lr= opt.lr, rho = opt.rho, eps = opt.eps)
    base_opt = torch.optim.Adam(filtered_parameters, lr=0.001)
    optimizer = SWA(base_opt)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='max',
                                                           verbose=True,
                                                           patience=2,
                                                           factor=0.5)
    #     optimizer = adabound.AdaBound(filtered_parameters, lr=1e-3, final_lr=0.1)

    # opt log
    with open(f'./models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '---------------------Options-----------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)} : {str(v)}\n'
        opt_log += '---------------------------------------------\n'
        opt_file.write(opt_log)

    #start training
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    swa_count = 0

    for n_epoch, epoch in enumerate(range(opt.num_epoch)):
        for n_iter, data_point in enumerate(data_loader):

            image_tensors, top, mid, bot = data_point

            image = image_tensors.to(device)
            text_top, length_top = top_converter.encode(
                top, batch_max_length=opt.batch_max_length)
            text_mid, length_mid = middle_converter.encode(
                mid, batch_max_length=opt.batch_max_length)
            text_bot, length_bot = bottom_converter.encode(
                bot, batch_max_length=opt.batch_max_length)
            batch_size = image.size(0)

            pred_top, pred_mid, pred_bot = model(image, text_top[:, :-1],
                                                 text_mid[:, :-1],
                                                 text_bot[:, :-1])

            #             cost_top = criterion(pred_top.view(-1, pred_top.shape[-1]), text_top[:, 1:].contiguous().view(-1))
            #             cost_mid = criterion(pred_mid.view(-1, pred_mid.shape[-1]), text_mid[:, 1:].contiguous().view(-1))
            #             cost_bot = criterion(pred_bot.view(-1, pred_bot.shape[-1]), text_bot[:, 1:].contiguous().view(-1))
            if n_iter % 2 == 0:

                cost_top = utils.reduced_focal_loss(
                    pred_top.view(-1, pred_top.shape[-1]),
                    text_top[:, 1:].contiguous().view(-1),
                    ignore_index=0,
                    gamma=2,
                    alpha=0.25,
                    threshold=0.5)
                cost_mid = utils.reduced_focal_loss(
                    pred_mid.view(-1, pred_mid.shape[-1]),
                    text_mid[:, 1:].contiguous().view(-1),
                    ignore_index=0,
                    gamma=2,
                    alpha=0.25,
                    threshold=0.5)
                cost_bot = utils.reduced_focal_loss(
                    pred_bot.view(-1, pred_bot.shape[-1]),
                    text_bot[:, 1:].contiguous().view(-1),
                    ignore_index=0,
                    gamma=2,
                    alpha=0.25,
                    threshold=0.5)
            else:
                cost_top = utils.CB_loss(text_top[:, 1:].contiguous().view(-1),
                                         pred_top.view(-1, pred_top.shape[-1]),
                                         top_per_cls, opt.top_n_cls, 'focal',
                                         0.999, 0.5)
                cost_mid = utils.CB_loss(text_mid[:, 1:].contiguous().view(-1),
                                         pred_mid.view(-1, pred_mid.shape[-1]),
                                         mid_per_cls, opt.middle_n_cls,
                                         'focal', 0.999, 0.5)
                cost_bot = utils.CB_loss(text_bot[:, 1:].contiguous().view(-1),
                                         pred_bot.view(-1, pred_bot.shape[-1]),
                                         bot_per_cls, opt.bottom_n_cls,
                                         'focal', 0.999, 0.5)
            cost = cost_top * 0.33 + cost_mid * 0.33 + cost_bot * 0.33

            loss_avg.add(cost)

            model.zero_grad()
            cost.backward()
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), opt.grad_clip)  # gradient clipping at opt.grad_clip
            optimizer.step()
            print(loss_avg.val())

            #validation
            if (n_iter % opt.valInterval == 0) and (n_iter != 0):
                elapsed_time = time.time() - start_time
                with open(f'./models/{opt.experiment_name}/log_train.txt',
                          'a') as log:
                    model.eval()
                    with torch.no_grad():
                        valid_loss, current_accuracy, current_norm_ED, pred_top_str, pred_mid_str, pred_bot_str, label_top, label_mid, label_bot, infer_time, length_of_data = evaluate.validation_jamo(
                            model, criterion, valid_loader, top_converter,
                            middle_converter, bottom_converter, opt)
                    scheduler.step(current_accuracy)
                    model.train()

                    present_time = time.localtime()
                    loss_log = f'[epoch : {n_epoch}/{opt.num_epoch}] [iter : {n_iter*opt.batch_size} / {int(len(data) * 0.95)}]\n' + f'Train loss : {loss_avg.val():0.5f}, Valid loss : {valid_loss:0.5f}, Elapsed time : {elapsed_time:0.5f}, Present time : {present_time[1]}/{present_time[2]}, {present_time[3]+9} : {present_time[4]}'
                    loss_avg.reset()

                    current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"current_norm_ED":17s}: {current_norm_ED:0.2f}'

                    #keep the best
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        torch.save(
                            model.module.state_dict(),
                            f'./models/{opt.experiment_name}/best_accuracy_{round(current_accuracy,2)}.pth'
                        )

                    if current_norm_ED > best_norm_ED:
                        best_norm_ED = current_norm_ED
                        torch.save(
                            model.module.state_dict(),
                            f'./models/{opt.experiment_name}/best_norm_ED.pth')

                    best_model_log = f'{"Best accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'
                    loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                    print(loss_model_log)
                    log.write(loss_model_log + '\n')

                    dashed_line = '-' * 80
                    head = f'{"Ground Truth":25s} | {"Prediction" :25s}| T/F'
                    predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'

                    random_idx = np.random.choice(range(len(label_top)),
                                                  size=5,
                                                  replace=False)
                    label_concat = np.concatenate([
                        np.asarray(label_top).reshape(1, -1),
                        np.asarray(label_mid).reshape(1, -1),
                        np.asarray(label_bot).reshape(1, -1)
                    ],
                                                  axis=0).reshape(3, -1)
                    pred_concat = np.concatenate([
                        np.asarray(pred_top_str).reshape(1, -1),
                        np.asarray(pred_mid_str).reshape(1, -1),
                        np.asarray(pred_bot_str).reshape(1, -1)
                    ],
                                                 axis=0).reshape(3, -1)

                    for i in random_idx:
                        label_sample = label_concat[:, i]
                        pred_sample = pred_concat[:, i]

                        gt_str = utils.str_combine(label_sample[0],
                                                   label_sample[1],
                                                   label_sample[2])
                        pred_str = utils.str_combine(pred_sample[0],
                                                     pred_sample[1],
                                                     pred_sample[2])
                        predicted_result_log += f'{gt_str:25s} | {pred_str:25s} | \t{str(pred_str == gt_str)}\n'
                    predicted_result_log += f'{dashed_line}'
                    print(predicted_result_log)
                    log.write(predicted_result_log + '\n')

                # Stochastic weight averaging
                optimizer.update_swa()
                swa_count += 1
                if swa_count % 5 == 0:
                    optimizer.swap_swa_sgd()
                    torch.save(
                        model.module.state_dict(),
                        f'./models/{opt.experiment_name}/swa_{swa_count}.pth')

        if (n_epoch) % 5 == 0:
            torch.save(model.module.state_dict(),
                       f'./models/{opt.experiment_name}/{n_epoch}.pth')
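
utils.Averager is used throughout these examples but never defined in the excerpts. Based only on how it is called (add with an optional weight, val/item to read the mean, reset to clear it), a minimal sketch might look like the code below; the repository's actual implementation may differ.

# Hypothetical minimal sketch of utils.Averager, inferred from its call sites above.
class Averager:

    def __init__(self):
        self.n = 0.0  # total weight seen so far
        self.v = 0.0  # running (weighted) mean

    def add(self, v, n=1.0):
        # incremental weighted mean: (old_mean * old_n + v * n) / (old_n + n)
        self.v = (self.v * self.n + v * n) / (self.n + n)
        self.n += n

    def val(self):   # some examples read the mean via .val()
        return self.v

    def item(self):  # others read it via .item()
        return self.v

    def reset(self):
        self.n = 0.0
        self.v = 0.0
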
Code example #18
0
def main(config):
    # dataset
    dataset = datasets.make(config['dataset'], **config['dataset_args'])
    utils.log('dataset: {} (x{}), {}'.format(dataset[0][0].shape, len(dataset),
                                             dataset.n_classes))
    if not args.sauc:
        n_way = 5
    else:
        n_way = 2
    n_shot, n_unlabel, n_query = args.shot, 30, 15
    n_batch = 200
    ep_per_batch = 4
    batch_sampler = CategoriesSampler_Semi(dataset.label,
                                           n_batch,
                                           n_way,
                                           n_shot,
                                           n_unlabel,
                                           n_query,
                                           ep_per_batch=ep_per_batch)
    loader = DataLoader(dataset,
                        batch_sampler=batch_sampler,
                        num_workers=8,
                        pin_memory=True)

    # model
    if config.get('load') is None:
        model = models.make('meta-baseline', encoder=None)
    else:
        model = models.load(torch.load(config['load']))

    if config.get('load_encoder') is not None:
        encoder = models.load(torch.load(config['load_encoder'])).encoder
        model.encoder = encoder

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    model.eval()
    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    # testing
    aves_keys = ['vl', 'va']
    aves = {k: utils.Averager() for k in aves_keys}

    test_epochs = args.test_epochs
    np.random.seed(0)
    va_lst = []
    for epoch in range(1, test_epochs + 1):
        for data, _ in tqdm(loader, leave=False):
            x_shot, x_unlabel, x_query = fs.split_shot_query_semi(
                data.cuda(),
                n_way,
                n_shot,
                n_unlabel,
                n_query,
                ep_per_batch=ep_per_batch)

            with torch.no_grad():
                if not args.sauc:
                    logits = model(x_shot, x_unlabel, x_query).view(-1, n_way)
                    label = fs.make_nk_label(n_way,
                                             n_query,
                                             ep_per_batch=ep_per_batch).cuda()
                    loss = F.cross_entropy(logits, label)
                    acc = utils.compute_acc(logits, label)

                    aves['vl'].add(loss.item(), len(data))
                    aves['va'].add(acc, len(data))
                    va_lst.append(acc)
                else:
                    x_shot = x_shot[:, 0, :, :, :, :].contiguous()
                    shot_shape = x_shot.shape[:-3]
                    img_shape = x_shot.shape[-3:]
                    bs = shot_shape[0]
                    p = model.encoder(x_shot.view(-1, *img_shape)).reshape(
                        *shot_shape, -1).mean(dim=1, keepdim=True)
                    q = model.encoder(x_query.view(-1, *img_shape)).view(
                        bs, -1, p.shape[-1])
                    p = F.normalize(p, dim=-1)
                    q = F.normalize(q, dim=-1)
                    s = torch.bmm(q, p.transpose(2, 1)).view(bs, -1).cpu()
                    for i in range(bs):
                        k = s.shape[1] // 2
                        y_true = [1] * k + [0] * k
                        acc = roc_auc_score(y_true, s[i])
                        aves['va'].add(acc, len(data))
                        va_lst.append(acc)

        print('test epoch {}: acc={:.2f} +- {:.2f} (%), loss={:.4f} (@{})'.
              format(epoch, aves['va'].item() * 100,
                     mean_confidence_interval(va_lst) * 100, aves['vl'].item(),
                     _[-1]))
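
mean_confidence_interval is also referenced without being shown. A common way to compute the half-width of a 95% Student-t confidence interval over the per-episode accuracies is sketched below; treat it as an assumption, not necessarily the authors' exact code.

import numpy as np
import scipy.stats

def mean_confidence_interval(data, confidence=0.95):
    # half-width of the t-distribution confidence interval for the sample mean
    a = np.asarray(data, dtype=np.float64)
    se = scipy.stats.sem(a)  # standard error of the mean
    return se * scipy.stats.t.ppf((1.0 + confidence) / 2.0, len(a) - 1)
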
Code example #19
0
def main_worker(gpu, *args):
    ngpus_per_node, opt = args

    num_worker = 20
    batch_size = int(opt.batch_size / ngpus_per_node)
    num_worker = int(num_worker / ngpus_per_node)

    #     os.environ['MASTER_ADDR'] = '127.0.0.1'
    #     os.environ['MASTER_PORT'] = '2222'

    torch.distributed.init_process_group(
        backend='nccl',
        #         init_method = 'tcp://127.0.0.1:2222',
        init_method='env://',
        world_size=ngpus_per_node,
        rank=gpu)

    model = BaseModel.model(opt, device)

    # load pretrained model
    if opt.saved_model != '':
        base_path = './models'
        print(
            f'looking for pretrained model from {os.path.join(base_path, opt.saved_model)}'
        )
        try:
            model.load_state_dict(
                torch.load(
                    os.path.join(base_path, opt.saved_model),
                    map_location='cpu' if device.type == 'cpu' else None))
            print('loading complete ')
        except Exception as e:
            print(e)
            print('could not load model')

    model = model.cuda(gpu)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    train_loader, test_loader = get_train_loader(opt, num_worker)

    # loss
    criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
        device)  #ignore [GO] token = ignore index 0
    log_avg = utils.Averager()

    # filter that only require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params : ', sum(params_num))

    # optimizer
    optimizer = optim.Adadelta(filtered_parameters,
                               lr=opt.lr,
                               rho=opt.rho,
                               eps=opt.eps)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='max',
                                                           verbose=True,
                                                           patience=2,
                                                           factor=0.5)

    #start training
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    swa_count = 0

    for n_epoch, epoch in enumerate(range(opt.num_epoch)):
        for n_iter, data_point in enumerate(train_loader):

            image_tensors, top, mid, bot = data_point

            image = image_tensors.to(gpu)
            text_top, length_top = top_converter.encode(
                top, batch_max_length=opt.batch_max_length)
            text_mid, length_mid = middle_converter.encode(
                mid, batch_max_length=opt.batch_max_length)
            text_bot, length_bot = bottom_converter.encode(
                bot, batch_max_length=opt.batch_max_length)

            text_top, length_top = text_top.to(gpu), length_top.to(gpu)
            text_mid, length_mid = text_mid.to(gpu), length_mid.to(gpu)
            text_bot, length_bot = text_bot.to(gpu), length_bot.to(gpu)

            batch_size = image.size(0)

            pred_top, pred_mid, pred_bot = model(image, text_top[:, :-1],
                                                 text_mid[:, :-1],
                                                 text_bot[:, :-1])

            #             cost_top = criterion(pred_top.view(-1, pred_top.shape[-1]), text_top[:, 1:].contiguous().view(-1))
            #             cost_mid = criterion(pred_mid.view(-1, pred_mid.shape[-1]), text_mid[:, 1:].contiguous().view(-1))
            #             cost_bot = criterion(pred_bot.view(-1, pred_bot.shape[-1]), text_bot[:, 1:].contiguous().view(-1))

            cost_top = utils.reduced_focal_loss(
                pred_top.view(-1, pred_top.shape[-1]),
                text_top[:, 1:].contiguous().view(-1),
                gamma=2,
                threshold=0.5)
            cost_mid = utils.reduced_focal_loss(
                pred_mid.view(-1, pred_mid.shape[-1]),
                text_mid[:, 1:].contiguous().view(-1),
                gamma=2,
                threshold=0.5)
            cost_bot = utils.reduced_focal_loss(
                pred_bot.view(-1, pred_bot.shape[-1]),
                text_bot[:, 1:].contiguous().view(-1),
                gamma=2,
                threshold=0.5)

            #             cost_top = utils.CB_loss(text_top[:, 1:].contiguous().view(-1), pred_top.view(-1, pred_top.shape[-1]), top_per_cls, opt.top_n_cls, 'focal', 0.99, 2)
            #             cost_mid = utils.CB_loss(text_mid[:, 1:].contiguous().view(-1), pred_mid.view(-1, pred_mid.shape[-1]), mid_per_cls, opt.mid_n_cls, 'focal', 0.99, 2)
            #             cost_bot = utils.CB_loss(text_bot[:, 1:].contiguous().view(-1), pred_bot.view(-1, pred_bot.shape[-1]), bot_per_cls, opt.bot_n_cls, 'focal', 0.99, 2)
            cost = cost_top + cost_mid + cost_bot

            #             print('Cost top : ', cost_top)
            #             print('Cost mid : ', cost_mid)
            #             print('Cost bot : ', cost_bot)
            loss_avg = utils.Averager()
            loss_avg.add(cost)

            model.zero_grad()
            cost.backward()
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), opt.grad_clip)  # gradient clipping at opt.grad_clip
            optimizer.step()

            print(
                f'epoch : {epoch} | step : {n_iter} / {len(train_loader)} | mp : {gpu}'
            )
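
main_worker initializes torch.distributed with init_method='env://' and uses the GPU index as the process rank, so the (unshown) entry point has to set MASTER_ADDR/MASTER_PORT and spawn one process per GPU. A typical launcher, with parse_arguments as a placeholder for the script's actual option parsing, might be:

import os
import torch
import torch.multiprocessing as mp

if __name__ == '__main__':
    opt = parse_arguments()  # placeholder for the script's real argument parsing
    ngpus_per_node = torch.cuda.device_count()
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '2222')
    # spawn one worker per GPU; each worker receives its gpu index as the first argument
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt))
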
Code example #20
0
def main(config):

    config['dataset_args']['data_dir'] = os.path.join(os.getcwd(), os.pardir,
                                                      'data_root')
    dataset = CustomDataset(name=config['dataset'], **config['dataset_args'])
    n_way = 5
    n_shot = config['n_shot']
    n_query = config.get('n_query') if config.get(
        'n_query') is not None else 15
    n_pseudo = config['n_pseudo'] if config.get('n_pseudo') is not None else 15
    n_batch = config['train_batches'] if config.get(
        'train_batches') is not None else 200
    ep_per_batch = config['ep_per_batch'] if config.get(
        'ep_per_batch') is not None else 4

    batch_sampler = EpisodicSampler(dataset,
                                    n_batch,
                                    n_way,
                                    n_shot,
                                    n_query,
                                    n_pseudo,
                                    episodes_per_batch=ep_per_batch)
    loader = DataLoader(dataset,
                        batch_sampler=batch_sampler,
                        num_workers=4,
                        pin_memory=True)

    model_sv = torch.load(config['load'])
    model = models.load(model_sv)
    if config.get('fs_dataset'):
        fs_model = models.make('meta-baseline', encoder=None)
        fs_model.encoder = model.encoder
        model = fs_model

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    model.eval()

    # testing
    aves_keys = ['vl', 'va']
    aves = {k: utils.Averager() for k in aves_keys}

    test_epochs = args.test_epochs
    np.random.seed(0)
    va_lst = []
    for epoch in range(1, test_epochs + 1):
        for data in tqdm(loader, desc=f"eval: {epoch}", leave=False):
            x_shot, x_query, x_pseudo = fs.split_shot_query(
                data.cuda(),
                n_way,
                n_shot,
                n_query,
                n_pseudo,
                ep_per_batch=ep_per_batch)

            with torch.no_grad():
                logits = model(x_shot, x_query, x_pseudo)
                logits = logits.view(-1, n_way)
                label = fs.make_nk_label(n_way,
                                         n_query,
                                         ep_per_batch=ep_per_batch).cuda()

                loss = F.cross_entropy(logits, label)
                acc = utils.compute_acc(logits, label)

                aves['vl'].add(loss.item(), len(data))
                aves['va'].add(acc, len(data))
                va_lst.append(acc)

        utils.log(
            'test epoch {}: acc={:.2f} +- {:.2f} (%), loss={:.4f}'.format(
                epoch, aves['va'].item() * 100,
                mean_confidence_interval(va_lst) * 100, aves['vl'].item()),
            filename='test_log.txt')
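
fs.make_nk_label builds the query labels for an episodic batch. A sketch consistent with its use here (classes 0..n_way-1, each repeated n_query times, tiled once per episode in the batch) is given below; the exact implementation may differ.

import torch

def make_nk_label(n, k, ep_per_batch=1):
    # labels 0..n-1, each repeated k times, then tiled once per episode
    label = torch.arange(n).unsqueeze(1).expand(n, k).reshape(-1)
    return label.repeat(ep_per_batch)
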
Code example #21
0
    def train(self):
        logging.basicConfig()
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        log_file_path = args.save_prefix + '_train.log'
        log_dir = os.path.dirname(log_file_path)
        if log_dir and not os.path.exists(log_dir):
            os.mkdir(log_dir)
        fh = logging.FileHandler(log_file_path)
        logger.addHandler(fh)
        logger.info(args)
        logger.info('Start training from [Epoch {}]'.format(args.start_epoch +
                                                            1))

        losses = utils.Averager()
        train_accuracy = utils.Averager()

        for epoch in range(args.start_epoch, args.nepoch):
            self.net.train()
            btic = time.time()
            for i, (imgs, labels) in enumerate(self.train_dataloader):
                batch_size = imgs.size()[0]
                imgs = imgs.cuda()
                preds = self.net(imgs).cpu()
                text, length = self.converter.encode(
                    labels
                )  # length: character length of each sample in the batch; text: indices of all target characters in the batch
                preds_size = torch.IntTensor([preds.size(0)] * batch_size)
                loss_avg = self.criterion(preds, text, preds_size,
                                          length) / batch_size

                self.optimizer.zero_grad()
                loss_avg.backward()
                self.optimizer.step()

                losses.update(loss_avg.item(), batch_size)

                _, preds_m = preds.max(2)
                preds_m = preds_m.transpose(1, 0).contiguous().view(-1)
                sim_preds = self.converter.decode(preds_m.data,
                                                  preds_size.data,
                                                  raw=False)
                n_correct = 0
                for pred, target in zip(sim_preds, labels):
                    if pred == target:
                        n_correct += 1
                train_accuracy.update(n_correct, batch_size, MUL_n=False)

                if args.log_interval and not (i + 1) % args.log_interval:
                    logger.info(
                        '[Epoch {}/{}][Batch {}/{}], Speed: {:.3f} samples/sec, Loss:{:.3f}'
                        .format(epoch + 1, args.nepoch, i + 1,
                                len(self.train_dataloader),
                                batch_size / (time.time() - btic),
                                losses.val()))
                    losses.reset()

            logger.info(
                'Training accuracy: {:.3f}, [#correct:{} / #total:{}]'.format(
                    train_accuracy.val(), train_accuracy.sum,
                    train_accuracy.count))
            train_accuracy.reset()

            if args.val_interval and not (epoch + 1) % args.val_interval:
                acc = self.validate(logger)
                if acc > self.best_acc:
                    self.best_acc = acc
                    save_path = '{:s}_best.pth'.format(args.save_prefix)
                    torch.save(
                        {
                            'epoch': epoch,
                            'model_state_dict': self.net.state_dict(),
                            # 'optimizer_state_dict': self.optimizer.state_dict(),
                            'best_acc': self.best_acc,
                        },
                        save_path)
                logging.info("best acc is:{:.3f}".format(self.best_acc))
                if args.save_interval and not (epoch + 1) % args.save_interval:
                    save_path = '{:s}_{:04d}_{:.3f}.pth'.format(
                        args.save_prefix, epoch + 1, acc)
                    torch.save(
                        {
                            'epoch': epoch,
                            'model_state_dict': self.net.state_dict(),
                            # 'optimizer_state_dict': self.optimizer.state_dict(),
                            'best_acc': self.best_acc,
                        },
                        save_path)
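
This example's Averager exposes a different interface: update(val, n, MUL_n=...) plus sum and count attributes. A minimal sketch matching those calls follows; the meaning of MUL_n (whether val is multiplied by n before being added to sum) is inferred from the accuracy counter being fed n_correct with MUL_n=False.

class Averager:

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1, MUL_n=True):
        # losses are per-sample means, so they are weighted by the batch size;
        # the accuracy counter passes raw correct counts with MUL_n=False
        self.sum += val * n if MUL_n else val
        self.count += n

    def val(self):
        return self.sum / max(self.count, 1)
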
Code example #22
0
def main(config):
    svname = args.name
    if svname is None:
        svname = 'meta_{}-{}shot'.format(
                config['train_dataset'], config['n_shot'])
        svname += '_' + config['model'] + '-' + config['model_args']['encoder']
    if args.tag is not None:
        svname += '_' + args.tag
    save_path = os.path.join('./save', svname)
    utils.ensure_path(save_path)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    #### Dataset ####

    n_way, n_shot = config['n_way'], config['n_shot']
    n_query = config['n_query']

    if config.get('n_train_way') is not None:
        n_train_way = config['n_train_way']
    else:
        n_train_way = n_way
    if config.get('n_train_shot') is not None:
        n_train_shot = config['n_train_shot']
    else:
        n_train_shot = n_shot
    if config.get('ep_per_batch') is not None:
        ep_per_batch = config['ep_per_batch']
    else:
        ep_per_batch = 1

    # train
    train_dataset = datasets.make(config['train_dataset'],
                                  **config['train_dataset_args'])
    utils.log('train dataset: {} (x{}), {}'.format(
            train_dataset[0][0].shape, len(train_dataset),
            train_dataset.n_classes))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(train_dataset, 'train_dataset', writer)
    train_sampler = CategoriesSampler(
            train_dataset.label, config['train_batches'],
            n_train_way, n_train_shot + n_query,
            ep_per_batch=ep_per_batch)
    train_loader = DataLoader(train_dataset, batch_sampler=train_sampler,
                              num_workers=8, pin_memory=True)

    # tval
    if config.get('tval_dataset'):
        tval_dataset = datasets.make(config['tval_dataset'],
                                     **config['tval_dataset_args'])
        utils.log('tval dataset: {} (x{}), {}'.format(
                tval_dataset[0][0].shape, len(tval_dataset),
                tval_dataset.n_classes))
        if config.get('visualize_datasets'):
            utils.visualize_dataset(tval_dataset, 'tval_dataset', writer)
        tval_sampler = CategoriesSampler(
                tval_dataset.label, 200,
                n_way, n_shot + n_query,
                ep_per_batch=4)
        tval_loader = DataLoader(tval_dataset, batch_sampler=tval_sampler,
                                 num_workers=8, pin_memory=True)
    else:
        tval_loader = None

    # val
    val_dataset = datasets.make(config['val_dataset'],
                                **config['val_dataset_args'])
    utils.log('val dataset: {} (x{}), {}'.format(
            val_dataset[0][0].shape, len(val_dataset),
            val_dataset.n_classes))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(val_dataset, 'val_dataset', writer)
    val_sampler = CategoriesSampler(
            val_dataset.label, 200,
            n_way, n_shot + n_query,
            ep_per_batch=4)
    val_loader = DataLoader(val_dataset, batch_sampler=val_sampler,
                            num_workers=8, pin_memory=True)

    ########

    #### Model and optimizer ####

    if config.get('load'):
        model_sv = torch.load(config['load'])
        model = models.load(model_sv)
    else:
        model = models.make(config['model'], **config['model_args'])

        if config.get('load_encoder'):
            encoder = models.load(torch.load(config['load_encoder'])).encoder
            model.encoder.load_state_dict(encoder.state_dict())

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(
            model.parameters(),
            config['optimizer'], **config['optimizer_args'])

    ########
    
    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    aves_keys = ['tl', 'ta', 'tvl', 'tva', 'vl', 'va']
    trlog = dict()
    for k in aves_keys:
        trlog[k] = []

    for epoch in range(1, max_epoch + 1):
        timer_epoch.s()
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        if config.get('freeze_bn'):
            utils.freeze_bn(model) 
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        np.random.seed(epoch)
        for data, _ in tqdm(train_loader, desc='train', leave=False):
            x_shot, x_query = fs.split_shot_query(
                    data.cuda(), n_train_way, n_train_shot, n_query,
                    ep_per_batch=ep_per_batch)
            label = fs.make_nk_label(n_train_way, n_query,
                    ep_per_batch=ep_per_batch).cuda()

            logits = model(x_shot, x_query).view(-1, n_train_way)
            loss = F.cross_entropy(logits, label)
            acc = utils.compute_acc(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None; loss = None 

        # eval
        model.eval()

        for name, loader, name_l, name_a in [
                ('tval', tval_loader, 'tvl', 'tva'),
                ('val', val_loader, 'vl', 'va')]:

            if (config.get('tval_dataset') is None) and name == 'tval':
                continue

            np.random.seed(0)
            for data, _ in tqdm(loader, desc=name, leave=False):
                x_shot, x_query = fs.split_shot_query(
                        data.cuda(), n_way, n_shot, n_query,
                        ep_per_batch=4)
                label = fs.make_nk_label(n_way, n_query,
                        ep_per_batch=4).cuda()

                with torch.no_grad():
                    logits = model(x_shot, x_query).view(-1, n_way)
                    loss = F.cross_entropy(logits, label)
                    acc = utils.compute_acc(logits, label)
                
                aves[name_l].add(loss.item())
                aves[name_a].add(acc)

        _sig = int(_[-1])

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()
            trlog[k].append(aves[k])

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)
        utils.log('epoch {}, train {:.4f}|{:.4f}, tval {:.4f}|{:.4f}, '
                'val {:.4f}|{:.4f}, {} {}/{} (@{})'.format(
                epoch, aves['tl'], aves['ta'], aves['tvl'], aves['tva'],
                aves['vl'], aves['va'], t_epoch, t_used, t_estimate, _sig))

        writer.add_scalars('loss', {
            'train': aves['tl'],
            'tval': aves['tvl'],
            'val': aves['vl'],
        }, epoch)
        writer.add_scalars('acc', {
            'train': aves['ta'],
            'tval': aves['tva'],
            'val': aves['va'],
        }, epoch)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,

            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),

            'training': training,
        }
        torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))
        torch.save(trlog, os.path.join(save_path, 'trlog.pth'))

        if (save_epoch is not None) and epoch % save_epoch == 0:
            torch.save(save_obj,
                    os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

        if aves['va'] > max_va:
            max_va = aves['va']
            torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))

        writer.flush()
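
utils.make_optimizer returns an (optimizer, lr_scheduler) pair built from config['optimizer'] and config['optimizer_args']. Its supported options are not visible in these excerpts, so the following is only a plausible sketch covering two common choices:

import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR

def make_optimizer(params, name, lr, weight_decay=0.0, milestones=None, **kwargs):
    # 'name', 'lr', etc. are assumed to be supplied via config['optimizer_args']
    if name == 'sgd':
        optimizer = optim.SGD(params, lr=lr, momentum=0.9, weight_decay=weight_decay)
    elif name == 'adam':
        optimizer = optim.Adam(params, lr=lr, weight_decay=weight_decay)
    else:
        raise ValueError('unknown optimizer: {}'.format(name))
    lr_scheduler = MultiStepLR(optimizer, milestones) if milestones else None
    return optimizer, lr_scheduler
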
Code example #23
0
def main(config):
    svname = args.name
    if svname is None:
        svname = 'pretrain-multi'
    if args.tag is not None:
        svname += '_' + args.tag
    save_path = os.path.join('./save', svname)
    utils.ensure_path(save_path)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    #### Dataset ####

    def make_dataset(name):
        dataset = make_md([name],
            'batch', split='train', image_size=126, batch_size=256)
        return dataset

    ds_names = ['ilsvrc_2012', 'omniglot', 'aircraft', 'cu_birds', 'dtd', \
            'quickdraw', 'fungi', 'vgg_flower']
    datasets = []
    for name in ds_names:
        datasets.append(make_dataset(name))
    iters = []
    for d in datasets:
        iters.append(d.make_one_shot_iterator().get_next())

    to_torch_labels = lambda a: torch.from_numpy(a).long()

    to_pil = transforms.ToPILImage()
    augmentation = transforms.Compose([
        transforms.Resize(146),
        transforms.RandomResizedCrop(128),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])
    ])
    ########

    #### Model and Optimizer ####

    if config.get('load'):
        model_sv = torch.load(config['load'])
        model = models.load(model_sv)
    else:
        model = models.make(config['model'], **config['model_args'])

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(
            model.parameters(),
            config['optimizer'], **config['optimizer_args'])

    ########
    
    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    for epoch in range(1, max_epoch + 1):
        timer_epoch.s()
        aves_keys = ['tl', 'ta', 'vl', 'va']
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        n_batch = 915547 // 256
        with tf.Session() as sess:
            for i_batch in tqdm(range(n_batch)):
                if random.randint(0, 1) == 0:
                    ds_id = 0
                else:
                    ds_id = random.randint(1, len(datasets) - 1)

                next_element = iters[ds_id]
                e, cfr_id = sess.run(next_element)

                data_, label = e[0], to_torch_labels(e[1])
                data_ = ((data_ + 1.0) * 0.5 * 255).astype('uint8')
                data = torch.zeros(256, 3, 128, 128).float()
                for i in range(len(data_)):
                    x = data_[i]
                    x = to_pil(x)
                    x = augmentation(x)
                    data[i] = x

                data = data.cuda()
                label = label.cuda()

                logits = model(data, cfr_id=ds_id)
                loss = F.cross_entropy(logits, label)
                acc = utils.compute_acc(logits, label)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                aves['tl'].add(loss.item())
                aves['ta'].add(acc)

                logits = None; loss = None

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)

        if epoch <= max_epoch:
            epoch_str = str(epoch)
        else:
            epoch_str = 'ex'
        log_str = 'epoch {}, train {:.4f}|{:.4f}'.format(
                epoch_str, aves['tl'], aves['ta'])
        writer.add_scalars('loss', {'train': aves['tl']}, epoch)
        writer.add_scalars('acc', {'train': aves['ta']}, epoch)

        if epoch <= max_epoch:
            log_str += ', {} {}/{}'.format(t_epoch, t_used, t_estimate)
        else:
            log_str += ', {}'.format(t_epoch)
        utils.log(log_str)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,

            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),

            'training': training,
        }
        if epoch <= max_epoch:
            torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))

            if (save_epoch is not None) and epoch % save_epoch == 0:
                torch.save(save_obj, os.path.join(
                    save_path, 'epoch-{}.pth'.format(epoch)))

            if aves['va'] > max_va:
                max_va = aves['va']
                torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))
        else:
            torch.save(save_obj, os.path.join(save_path, 'epoch-ex.pth'))

        writer.flush()
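
utils.compute_acc is called in every training and evaluation loop above. A one-line sketch of the assumed behavior (mean top-1 accuracy over the batch, returned as a Python float) is:

import torch

def compute_acc(logits, label):
    # fraction of rows whose argmax matches the target class
    return (torch.argmax(logits, dim=1) == label).float().mean().item()
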
Code example #24
0
File: train_meta.py  Project: tce/Bongard-LOGO
def main(config):
    svname = args.name
    if svname is None:
        svname = 'meta_{}-{}shot'.format(config['train_dataset'],
                                         config['n_shot'])
        svname += '_' + config['model']
        if config['model_args'].get('encoder'):
            svname += '-' + config['model_args']['encoder']
        if config['model_args'].get('prog_synthesis'):
            svname += '-' + config['model_args']['prog_synthesis']
    svname += '-seed' + str(args.seed)
    if args.tag is not None:
        svname += '_' + args.tag

    save_path = os.path.join(args.save_dir, svname)
    utils.ensure_path(save_path, remove=False)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    logger = utils.Logger(file_name=os.path.join(save_path, "log_sdout.txt"),
                          file_mode="a+",
                          should_flush=True)

    #### Dataset ####

    n_way, n_shot = config['n_way'], config['n_shot']
    n_query = config['n_query']

    if config.get('n_train_way') is not None:
        n_train_way = config['n_train_way']
    else:
        n_train_way = n_way
    if config.get('n_train_shot') is not None:
        n_train_shot = config['n_train_shot']
    else:
        n_train_shot = n_shot
    if config.get('ep_per_batch') is not None:
        ep_per_batch = config['ep_per_batch']
    else:
        ep_per_batch = 1

    random_state = np.random.RandomState(args.seed)
    print('seed:', args.seed)

    # train
    train_dataset = datasets.make(config['train_dataset'],
                                  **config['train_dataset_args'])
    utils.log('train dataset: {} (x{})'.format(train_dataset[0][0].shape,
                                               len(train_dataset)))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(train_dataset, 'train_dataset', writer)
    train_sampler = BongardSampler(train_dataset.n_tasks,
                                   config['train_batches'], ep_per_batch,
                                   random_state.randint(2**31))
    train_loader = DataLoader(train_dataset,
                              batch_sampler=train_sampler,
                              num_workers=8,
                              pin_memory=True)

    # tvals
    tval_loaders = {}
    tval_name_ntasks_dict = {
        'tval': 2000,
        'tval_ff': 600,
        'tval_bd': 480,
        'tval_hd_comb': 400,
        'tval_hd_novel': 320
    }  # numbers depend on dataset
    for tval_type in tval_name_ntasks_dict.keys():
        if config.get('{}_dataset'.format(tval_type)):
            tval_dataset = datasets.make(
                config['{}_dataset'.format(tval_type)],
                **config['{}_dataset_args'.format(tval_type)])
            utils.log('{} dataset: {} (x{})'.format(tval_type,
                                                    tval_dataset[0][0].shape,
                                                    len(tval_dataset)))
            if config.get('visualize_datasets'):
                utils.visualize_dataset(tval_dataset, 'tval_ff_dataset',
                                        writer)
            tval_sampler = BongardSampler(
                tval_dataset.n_tasks,
                n_batch=tval_name_ntasks_dict[tval_type] // ep_per_batch,
                ep_per_batch=ep_per_batch,
                seed=random_state.randint(2**31))
            tval_loader = DataLoader(tval_dataset,
                                     batch_sampler=tval_sampler,
                                     num_workers=8,
                                     pin_memory=True)
            tval_loaders.update({tval_type: tval_loader})
        else:
            tval_loaders.update({tval_type: None})

    # val
    val_dataset = datasets.make(config['val_dataset'],
                                **config['val_dataset_args'])
    utils.log('val dataset: {} (x{})'.format(val_dataset[0][0].shape,
                                             len(val_dataset)))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(val_dataset, 'val_dataset', writer)
    val_sampler = BongardSampler(val_dataset.n_tasks,
                                 n_batch=900 // ep_per_batch,
                                 ep_per_batch=ep_per_batch,
                                 seed=random_state.randint(2**31))
    val_loader = DataLoader(val_dataset,
                            batch_sampler=val_sampler,
                            num_workers=8,
                            pin_memory=True)

    ########

    #### Model and optimizer ####

    if config.get('load'):
        print('loading pretrained model: ', config['load'])
        model = models.load(torch.load(config['load']))
    else:
        model = models.make(config['model'], **config['model_args'])

        if config.get('load_encoder'):
            print('loading pretrained encoder: ', config['load_encoder'])
            encoder = models.load(torch.load(config['load_encoder'])).encoder
            model.encoder.load_state_dict(encoder.state_dict())

        if config.get('load_prog_synthesis'):
            print('loading pretrained program synthesis model: ',
                  config['load_prog_synthesis'])
            prog_synthesis = models.load(
                torch.load(config['load_prog_synthesis']))
            model.prog_synthesis.load_state_dict(prog_synthesis.state_dict())

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(model.parameters(),
                                                   config['optimizer'],
                                                   **config['optimizer_args'])

    ########

    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    aves_keys = ['tl', 'ta', 'vl', 'va']
    tval_tuple_lst = []
    for k, v in tval_loaders.items():
        if v is not None:
            loss_key = 'tvl' + k.split('tval')[-1]
            acc_key = ' tva' + k.split('tval')[-1]
            aves_keys.append(loss_key)
            aves_keys.append(acc_key)
            tval_tuple_lst.append((k, v, loss_key, acc_key))

    trlog = dict()
    for k in aves_keys:
        trlog[k] = []

    for epoch in range(1, max_epoch + 1):
        timer_epoch.s()
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        if config.get('freeze_bn'):
            utils.freeze_bn(model)
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        for data, label in tqdm(train_loader, desc='train', leave=False):

            x_shot, x_query = fs.split_shot_query(data.cuda(),
                                                  n_train_way,
                                                  n_train_shot,
                                                  n_query,
                                                  ep_per_batch=ep_per_batch)
            label_query = fs.make_nk_label(n_train_way,
                                           n_query,
                                           ep_per_batch=ep_per_batch).cuda()

            if config['model'] == 'snail':  # only use one selected label_query
                query_dix = random_state.randint(n_train_way * n_query)
                label_query = label_query.view(ep_per_batch, -1)[:, query_dix]
                x_query = x_query[:, query_dix:query_dix + 1]

            if config['model'] == 'maml':  # need grad in maml
                model.zero_grad()

            logits = model(x_shot, x_query).view(-1, n_train_way)
            loss = F.cross_entropy(logits, label_query)
            acc = utils.compute_acc(logits, label_query)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None
            loss = None

        # eval
        model.eval()

        for name, loader, name_l, name_a in [('val', val_loader, 'vl', 'va')
                                             ] + tval_tuple_lst:

            if config.get('{}_dataset'.format(name)) is None:
                continue

            np.random.seed(0)
            for data, _ in tqdm(loader, desc=name, leave=False):
                x_shot, x_query = fs.split_shot_query(
                    data.cuda(),
                    n_way,
                    n_shot,
                    n_query,
                    ep_per_batch=ep_per_batch)
                label_query = fs.make_nk_label(
                    n_way, n_query, ep_per_batch=ep_per_batch).cuda()

                if config[
                        'model'] == 'snail':  # only use one randomly selected label_query
                    query_dix = random_state.randint(n_train_way)
                    label_query = label_query.view(ep_per_batch, -1)[:,
                                                                     query_dix]
                    x_query = x_query[:, query_dix:query_dix + 1]

                if config['model'] == 'maml':  # need grad in maml
                    model.zero_grad()
                    logits = model(x_shot, x_query, eval=True).view(-1, n_way)
                    loss = F.cross_entropy(logits, label_query)
                    acc = utils.compute_acc(logits, label_query)
                else:
                    with torch.no_grad():
                        logits = model(x_shot, x_query,
                                       eval=True).view(-1, n_way)
                        loss = F.cross_entropy(logits, label_query)
                        acc = utils.compute_acc(logits, label_query)

                aves[name_l].add(loss.item())
                aves[name_a].add(acc)

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()
            trlog[k].append(aves[k])

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)
        log_str = 'epoch {}, train {:.4f}|{:.4f}, val {:.4f}|{:.4f}'.format(
            epoch, aves['tl'], aves['ta'], aves['vl'], aves['va'])
        for tval_name, _, loss_key, acc_key in tval_tuple_lst:
            log_str += ', {} {:.4f}|{:.4f}'.format(tval_name, aves[loss_key],
                                                   aves[acc_key])
            writer.add_scalars('loss', {tval_name: aves[loss_key]}, epoch)
            writer.add_scalars('acc', {tval_name: aves[acc_key]}, epoch)
        log_str += ', {} {}/{}'.format(t_epoch, t_used, t_estimate)
        utils.log(log_str)

        writer.add_scalars('loss', {
            'train': aves['tl'],
            'val': aves['vl'],
        }, epoch)
        writer.add_scalars('acc', {
            'train': aves['ta'],
            'val': aves['va'],
        }, epoch)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,
            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),
            'training': training,
        }
        torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))
        torch.save(trlog, os.path.join(save_path, 'trlog.pth'))

        if (save_epoch is not None) and epoch % save_epoch == 0:
            torch.save(save_obj,
                       os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

        if aves['va'] > max_va:
            max_va = aves['va']
            torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))

        writer.flush()

    print('finished training!')
    logger.close()
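
fs.split_shot_query reshapes the flat episodic batch into support and query tensors. A sketch consistent with the sampler layout used in these examples (ep_per_batch episodes, each with n_way classes and shot+query images per class) follows; treat it as an assumption about the helper, not its verbatim source.

def split_shot_query(data, way, shot, query, ep_per_batch=1):
    # data: (ep_per_batch * way * (shot + query), C, H, W)
    img_shape = data.shape[1:]
    data = data.view(ep_per_batch, way, shot + query, *img_shape)
    x_shot, x_query = data.split([shot, query], dim=2)
    x_shot = x_shot.contiguous()
    # queries are flattened per episode so they line up with make_nk_label
    x_query = x_query.contiguous().view(ep_per_batch, way * query, *img_shape)
    return x_shot, x_query
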
Code example #25
0
File: test.py  Project: zt706/liif
def eval_psnr(loader,
              model,
              data_norm=None,
              eval_type=None,
              eval_bsize=None,
              verbose=False):
    model.eval()

    if data_norm is None:
        data_norm = {
            'inp': {
                'sub': [0],
                'div': [1]
            },
            'gt': {
                'sub': [0],
                'div': [1]
            }
        }
    t = data_norm['inp']
    inp_sub = torch.FloatTensor(t['sub']).view(1, -1, 1, 1).cuda()
    inp_div = torch.FloatTensor(t['div']).view(1, -1, 1, 1).cuda()
    t = data_norm['gt']
    gt_sub = torch.FloatTensor(t['sub']).view(1, 1, -1).cuda()
    gt_div = torch.FloatTensor(t['div']).view(1, 1, -1).cuda()

    if eval_type is None:
        metric_fn = utils.calc_psnr
    elif eval_type.startswith('div2k'):
        scale = int(eval_type.split('-')[1])
        metric_fn = partial(utils.calc_psnr, dataset='div2k', scale=scale)
    elif eval_type.startswith('benchmark'):
        scale = int(eval_type.split('-')[1])
        metric_fn = partial(utils.calc_psnr, dataset='benchmark', scale=scale)
    else:
        raise NotImplementedError

    val_res = utils.Averager()

    pbar = tqdm(loader, leave=False, desc='val')
    for batch in pbar:
        for k, v in batch.items():
            batch[k] = v.cuda()

        inp = (batch['inp'] - inp_sub) / inp_div
        if eval_bsize is None:
            with torch.no_grad():
                pred = model(inp, batch['coord'], batch['cell'])
        else:
            pred = batched_predict(model, inp, batch['coord'], batch['cell'],
                                   eval_bsize)
        pred = pred * gt_div + gt_sub
        pred.clamp_(0, 1)

        if eval_type is not None:  # reshape for shaving-eval
            ih, iw = batch['inp'].shape[-2:]
            s = math.sqrt(batch['coord'].shape[1] / (ih * iw))
            shape = [batch['inp'].shape[0], round(ih * s), round(iw * s), 3]
            pred = pred.view(*shape) \
                .permute(0, 3, 1, 2).contiguous()
            batch['gt'] = batch['gt'].view(*shape) \
                .permute(0, 3, 1, 2).contiguous()

        res = metric_fn(pred, batch['gt'])
        val_res.add(res.item(), inp.shape[0])

        if verbose:
            pbar.set_description('val {:.4f}'.format(val_res.item()))

    return val_res.item()
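
When eval_bsize is set, batched_predict evaluates the coordinate grid in chunks to keep memory bounded. Its implementation is not included here; a generic chunked version matching the call signature above is sketched below (the real code likely caches the encoder features rather than re-running the full model per chunk).

import torch

def batched_predict(model, inp, coord, cell, bsize):
    # query the model on slices of the coordinate grid and concatenate the results
    preds = []
    with torch.no_grad():
        n = coord.shape[1]
        for ql in range(0, n, bsize):
            qr = min(ql + bsize, n)
            preds.append(model(inp, coord[:, ql:qr, :], cell[:, ql:qr, :]))
    return torch.cat(preds, dim=1)
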
Code example #26
0
def main(config):
    svname = args.name
    if svname is None:
        svname = 'classifier_{}'.format(config['train_dataset'])
        svname += '_' + config['model_args']['encoder']
        clsfr = config['model_args']['classifier']
        if clsfr != 'linear-classifier':
            svname += '-' + clsfr
    if args.tag is not None:
        svname += '_' + args.tag
    save_path = os.path.join('./save', svname)
    utils.ensure_path(save_path)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    #### Dataset ####

    # train
    train_dataset = datasets.make(config['train_dataset'],
                                  **config['train_dataset_args'])
    augmentations = [
        transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomResizedCrop(size=(80, 80),
                                         scale=(0.08, 1.0),
                                         ratio=(0.75, 1.3333)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomRotation(35),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomResizedCrop(size=(80, 80),
                                         scale=(0.08, 1.0),
                                         ratio=(0.75, 1.3333)),
            transforms.RandomRotation(35),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomRotation(35),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomResizedCrop(size=(80, 80),
                                         scale=(0.08, 1.0),
                                         ratio=(0.75, 1.3333)),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomRotation(35),
            transforms.RandomResizedCrop(size=(80, 80),
                                         scale=(0.08, 1.0),
                                         ratio=(0.75, 1.3333)),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    ]
    train_dataset.transform = augmentations[int(config['_a'])]
    print(train_dataset.transform)
    print("_a", config['_a'])
    input("Continue with these augmentations?")

    train_loader = DataLoader(train_dataset,
                              config['batch_size'],
                              shuffle=True,
                              num_workers=0,
                              pin_memory=True)
    utils.log('train dataset: {} (x{}), {}'.format(train_dataset[0][0].shape,
                                                   len(train_dataset),
                                                   train_dataset.n_classes))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(train_dataset, 'train_dataset', writer)

    # val
    if config.get('val_dataset'):
        eval_val = True
        val_dataset = datasets.make(config['val_dataset'],
                                    **config['val_dataset_args'])
        val_loader = DataLoader(val_dataset,
                                config['batch_size'],
                                num_workers=0,
                                pin_memory=True)
        utils.log('val dataset: {} (x{}), {}'.format(val_dataset[0][0].shape,
                                                     len(val_dataset),
                                                     val_dataset.n_classes))
        if config.get('visualize_datasets'):
            utils.visualize_dataset(val_dataset, 'val_dataset', writer)
    else:
        eval_val = False

    # few-shot eval
    if config.get('fs_dataset'):
        ef_epoch = config.get('eval_fs_epoch')
        if ef_epoch is None:
            ef_epoch = 5
        eval_fs = True

        fs_dataset = datasets.make(config['fs_dataset'],
                                   **config['fs_dataset_args'])
        utils.log('fs dataset: {} (x{}), {}'.format(fs_dataset[0][0].shape,
                                                    len(fs_dataset),
                                                    fs_dataset.n_classes))
        if config.get('visualize_datasets'):
            utils.visualize_dataset(fs_dataset, 'fs_dataset', writer)

        n_way = 5
        n_query = 15
        n_shots = [1, 5]
        fs_loaders = []
        for n_shot in n_shots:
            fs_sampler = CategoriesSampler(fs_dataset.label,
                                           200,
                                           n_way,
                                           n_shot + n_query,
                                           ep_per_batch=4)
            fs_loader = DataLoader(fs_dataset,
                                   batch_sampler=fs_sampler,
                                   num_workers=0,
                                   pin_memory=True)
            fs_loaders.append(fs_loader)
    else:
        eval_fs = False

    ########

    #### Model and Optimizer ####

    if config.get('load'):
        model_sv = torch.load(config['load'])
        model = models.load(model_sv)
    else:
        model = models.make(config['model'], **config['model_args'])

    if eval_fs:
        fs_model = models.make('meta-baseline', encoder=None)
        fs_model.encoder = model.encoder
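        # The few-shot evaluator wraps the classifier's encoder in a
        # meta-baseline head, so both models share the same backbone weights.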

    if config.get('_parallel'):
        model = nn.DataParallel(model)
        if eval_fs:
            fs_model = nn.DataParallel(fs_model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(model.parameters(),
                                                   config['optimizer'],
                                                   **config['optimizer_args'])
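    # make_optimizer builds the optimizer and an optional LR scheduler from the
    # config; the scheduler may be None (see the step() guard further below).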

    ########

    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()
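    # Runs max_epoch training epochs; if config['epoch_ex'] is set, one extra
    # 'ex' epoch follows that re-trains with the dataset's default transform.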

    for epoch in range(1, max_epoch + 1 + 1):
        if epoch == max_epoch + 1:
            if not config.get('epoch_ex'):
                break
            train_dataset.transform = train_dataset.default_transform
            print(train_dataset.transform)
            train_loader = DataLoader(train_dataset,
                                      config['batch_size'],
                                      shuffle=True,
                                      num_workers=0,
                                      pin_memory=True)

        timer_epoch.s()
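        # Per-epoch running averages: 'tl'/'ta' = train loss/accuracy,
        # 'vl'/'va' = validation loss/accuracy, 'fsa-N' = N-shot few-shot accuracy.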
        aves_keys = ['tl', 'ta', 'vl', 'va']
        if eval_fs:
            for n_shot in n_shots:
                aves_keys += ['fsa-' + str(n_shot)]
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        for data, label in tqdm(train_loader, desc='train', leave=False):
            # for data, label in train_loader:
            data, label = data.cuda(), label.cuda()
            logits = model(data)
            loss = F.cross_entropy(logits, label)
            acc = utils.compute_acc(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None
            loss = None

        # eval
        if eval_val:
            model.eval()
            for data, label in tqdm(val_loader, desc='val', leave=False):
                data, label = data.cuda(), label.cuda()
                with torch.no_grad():
                    logits = model(data)
                    loss = F.cross_entropy(logits, label)
                    acc = utils.compute_acc(logits, label)

                aves['vl'].add(loss.item())
                aves['va'].add(acc)

        if eval_fs and (epoch % ef_epoch == 0 or epoch == max_epoch + 1):
            fs_model.eval()
            for i, n_shot in enumerate(n_shots):
                np.random.seed(0)
                for data, _ in tqdm(fs_loaders[i],
                                    desc='fs-' + str(n_shot),
                                    leave=False):
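                    # Each episodic batch holds ep_per_batch * n_way * (n_shot + n_query)
                    # images; split it into support (shot) and query sets and build
                    # the matching n_way x n_query labels.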
                    x_shot, x_query = fs.split_shot_query(data.cuda(),
                                                          n_way,
                                                          n_shot,
                                                          n_query,
                                                          ep_per_batch=4)
                    label = fs.make_nk_label(n_way, n_query,
                                             ep_per_batch=4).cuda()
                    with torch.no_grad():
                        logits = fs_model(x_shot, x_query).view(-1, n_way)
                        acc = utils.compute_acc(logits, label)
                    aves['fsa-' + str(n_shot)].add(acc)

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)

        if epoch <= max_epoch:
            epoch_str = str(epoch)
        else:
            epoch_str = 'ex'
        log_str = 'epoch {}, train {:.4f}|{:.4f}'.format(
            epoch_str, aves['tl'], aves['ta'])
        writer.add_scalars('loss', {'train': aves['tl']}, epoch)
        writer.add_scalars('acc', {'train': aves['ta']}, epoch)

        if eval_val:
            log_str += ', val {:.4f}|{:.4f}'.format(aves['vl'], aves['va'])
            writer.add_scalars('loss', {'val': aves['vl']}, epoch)
            writer.add_scalars('acc', {'val': aves['va']}, epoch)

        if eval_fs and (epoch % ef_epoch == 0 or epoch == max_epoch + 1):
            log_str += ', fs'
            for n_shot in n_shots:
                key = 'fsa-' + str(n_shot)
                log_str += ' {}: {:.4f}'.format(n_shot, aves[key])
                writer.add_scalars('acc', {key: aves[key]}, epoch)

        if epoch <= max_epoch:
            log_str += ', {} {}/{}'.format(t_epoch, t_used, t_estimate)
        else:
            log_str += ', {}'.format(t_epoch)
        utils.log(log_str)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,
            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),
            'training': training,
        }
        if epoch <= max_epoch:
            torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))

            if (save_epoch is not None) and epoch % save_epoch == 0:
                torch.save(
                    save_obj,
                    os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

            if aves['va'] > max_va:
                max_va = aves['va']
                torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))
        else:
            torch.save(save_obj, os.path.join(save_path, 'epoch-ex.pth'))

        writer.flush()
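
Note: both this example and the next only consume utils.Averager through add(), item()/val(), and reset(); neither snippet shows the implementation. A minimal sketch consistent with that usage, assumed here rather than taken from either project, is a count-weighted running mean:

class Averager:
    """Count-weighted running mean (sketch only, not the projects' code)."""

    def __init__(self):
        self.n = 0.0
        self.v = 0.0

    def add(self, v, n=1.0):
        # fold a new value (e.g. a per-batch loss) into the running mean
        self.v = (self.v * self.n + v * n) / (self.n + n)
        self.n += n

    def item(self):
        return self.v

    val = item  # some of the snippets call val() instead of item()

    def reset(self):
        self.n = 0.0
        self.v = 0.0
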
Code example #27
0
File: efifstr_torch.py Project: ximzzzzz/Food_CAMERA
def train(opt):
    model = Model.Basemodel(opt, device)
    
    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:
            if 'weight' in name:
                param.data.fill_(1)
            continue
            
    # load pretrained model
    if opt.saved_model != '':
        base_path = './models'
        print(f'looking for pretrained model from {os.path.join(base_path, opt.saved_model)}')
        
        try:
            model.load_state_dict(torch.load(os.path.join(base_path, opt.saved_model)))
            print('loading complete')
        except Exception as e:
            print(e)
            print('could not find model')
            
    #data parallel for multi GPU
    model = torch.nn.DataParallel(model, device_ids=[0]).to(device)
    model.train() 
    
    # keep only the parameters that require gradients
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params : ', sum(params_num))

    loss_avg = utils.Averager()
    loss_avg_glyph = utils.Averager()
    
    # optimizer
    optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
#     optimizer = torch.optim.Adam(filtered_parameters, lr=0.0001)
#     optimizer = SWA(base_opt)
#     optimizer = torch.optim.AdamW(filtered_parameters)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', verbose=True, patience=2, factor=0.5)
#     optimizer = adabound.AdaBound(filtered_parameters, lr=1e-3, final_lr=0.1)
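    # Note: with mode='max', ReduceLROnPlateau is normally stepped on the
    # validation metric (e.g. scheduler.step(current_accuracy)); that call is
    # not shown in this snippet.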

    # opt log
    with open(f'./models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '---------------------Options-----------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)} : {str(v)}\n'
        opt_log += '---------------------------------------------\n'
        opt_file.write(opt_log)
        
    # start training
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    swa_count = 0
    
    for n_epoch in range(opt.num_epoch):
        for n_iter, data_point in enumerate(data_loader):
            
            image, labels = data_point 
            image = image.to(device)
            try:
                target, length = converter.encode(labels, batch_max_length = opt.max_length)
                batch_size = image.size(0)
            except Exception as e:
                print(f'{e}')
                continue

            logits, glyphs, embedding_ids = model(image, (target, length), is_train = True)
            
            recognition_loss = model.module.decoder.recognition_loss(logits.view(-1, opt.num_classes+2), target.view(-1))
            generation_loss = model.module.generator.glyph_loss(glyphs, target, length, embedding_ids, opt)
            
            cost = recognition_loss + generation_loss

            loss_avg.add(recognition_loss.item())        # store floats so the autograd graph is not retained
            loss_avg_glyph.add(generation_loss.item())
            
            model.zero_grad()
            cost.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping at opt.grad_clip
            optimizer.step()
            
            # validation
            if n_iter % opt.val_interval == 0 and n_iter != 0:
                elapsed_time = time.time() - start_time
                with open(f'./models/{opt.experiment_name}/log_train.txt', 'a') as log:
                    model.eval()
                    with torch.no_grad():
                        (valid_loss_recog, valid_loss_glyph, current_accuracy,
                         current_norm_ED, preds, confidence_score, labels,
                         infer_time, length_of_data) = evaluate.validation_efifstr(
                             model, valid_loader, converter, opt)
                    model.train()

                    present_time = time.localtime()
                    loss_log = (
                        f'[epoch : {n_epoch}/{opt.num_epoch}] '
                        f'[iter : {n_iter*opt.batch_size} / {int(len(data) * 0.998)}]\n'
                        f'Train recognition loss : {loss_avg.val():0.5f}, '
                        f'Glyph loss : {loss_avg_glyph.val():0.5f}\n'
                        f'Valid recognition loss : {valid_loss_recog:0.5f}, '
                        f'Glyph loss : {valid_loss_glyph:0.5f}, '
                        f'Elapsed time : {elapsed_time:0.5f}, '
                        f'Present time : {present_time[1]}/{present_time[2]}, '
                        f'{present_time[3]+9} : {present_time[4]}'
                    )
                    loss_avg.reset()
                    loss_avg_glyph.reset()

                    current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"current_norm_ED":17s}: {current_norm_ED:0.2f}'

                    #keep the best
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        torch.save(model.module.state_dict(), f'./models/{opt.experiment_name}/best_accuracy_{round(current_accuracy,2)}.pth')

                    if current_norm_ED > best_norm_ED:
                        best_norm_ED = current_norm_ED
                        torch.save(model.module.state_dict(), f'./models/{opt.experiment_name}/best_norm_ED.pth')

                    best_model_log = f'{"Best accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'
                    loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                    print(loss_model_log)
                    log.write(loss_model_log+'\n')

                    dashed_line = '-'*80
                    head = f'{"Ground Truth":25s} | {"Prediction" :25s}| Confidence Score & T/F'
                    predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'

                    random_idx = np.random.choice(range(len(labels)), size=5, replace=False)
                    for gt, pred, confidence in zip(np.asarray(labels)[random_idx],
                                                    np.asarray(preds)[random_idx],
                                                    np.asarray(confidence_score)[random_idx]):
                        gt = gt[: gt.find('[s]')]
                        pred = pred[: pred.find('[s]')]

                        predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                    predicted_result_log += f'{dashed_line}'
                    print(predicted_result_log)
                    log.write(predicted_result_log+'\n')

#                 # Stochastic weight averaging
#                 optimizer.update_swa()
#                 swa_count+=1
#                 if swa_count % 3 ==0:
#                     optimizer.swap_swa_sgd()
#                     torch.save(model.module.state_dict(), f'./models/{opt.experiment_name}/swa_{swa_count}.pth')

        if n_epoch % 5 == 0:
            torch.save(model.module.state_dict(), f'./models/{opt.experiment_name}/{n_epoch}.pth')
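
The logging rhythm used above (accumulate the training loss with an Averager, report it and reset every val_interval iterations) reduces to the following toy sketch; the loss values are fabricated placeholders and Averager refers to the minimal sketch given earlier, not the project's utils module:

loss_avg = Averager()          # minimal sketch class from the note above
val_interval = 100
for n_iter in range(1, 501):
    fake_loss = 1.0 / n_iter   # stand-in for recognition_loss.item()
    loss_avg.add(fake_loss)
    if n_iter % val_interval == 0:
        print(f'iter {n_iter}: mean train loss over window {loss_avg.val():0.5f}')
        loss_avg.reset()       # start a fresh averaging window for the next interval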