def main(config):
    svname = args.name
    if svname is None:
        svname = 'classifier_{}'.format(config['train_dataset'])
        svname += '_' + config['model_args']['encoder']
        clsfr = config['model_args']['classifier']
        if clsfr != 'linear-classifier':
            svname += '-' + clsfr
    if args.tag is not None:
        svname += '_' + args.tag
    save_path = os.path.join('./save', svname)
    utils.ensure_path(save_path)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    #### Dataset ####

    # train
    train_dataset = datasets.make(config['train_dataset'],
                                  **config['train_dataset_args'])
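    # Candidate augmentation pipelines; config['_a'] selects one by index below
    # (0: horizontal flip only, up to 7: random crop + rotation + color jitter combined).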
    augmentations = [
        transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomResizedCrop(size=(80, 80),
                                         scale=(0.08, 1.0),
                                         ratio=(0.75, 1.3333)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomRotation(35),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomResizedCrop(size=(80, 80),
                                         scale=(0.08, 1.0),
                                         ratio=(0.75, 1.3333)),
            transforms.RandomRotation(35),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomRotation(35),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomResizedCrop(size=(80, 80),
                                         scale=(0.08, 1.0),
                                         ratio=(0.75, 1.3333)),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomRotation(35),
            transforms.RandomResizedCrop(size=(80, 80),
                                         scale=(0.08, 1.0),
                                         ratio=(0.75, 1.3333)),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    ]
    train_dataset.transform = augmentations[int(config['_a'])]
    print(train_dataset.transform)
    print("_a", config['_a'])
    input("Continue with these augmentations?")

    train_loader = DataLoader(train_dataset,
                              config['batch_size'],
                              shuffle=True,
                              num_workers=0,
                              pin_memory=True)
    utils.log('train dataset: {} (x{}), {}'.format(train_dataset[0][0].shape,
                                                   len(train_dataset),
                                                   train_dataset.n_classes))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(train_dataset, 'train_dataset', writer)

    # val
    if config.get('val_dataset'):
        eval_val = True
        val_dataset = datasets.make(config['val_dataset'],
                                    **config['val_dataset_args'])
        val_loader = DataLoader(val_dataset,
                                config['batch_size'],
                                num_workers=0,
                                pin_memory=True)
        utils.log('val dataset: {} (x{}), {}'.format(val_dataset[0][0].shape,
                                                     len(val_dataset),
                                                     val_dataset.n_classes))
        if config.get('visualize_datasets'):
            utils.visualize_dataset(val_dataset, 'val_dataset', writer)
    else:
        eval_val = False

    # few-shot eval
    if config.get('fs_dataset'):
        ef_epoch = config.get('eval_fs_epoch')
        if ef_epoch is None:
            ef_epoch = 5
        eval_fs = True

        fs_dataset = datasets.make(config['fs_dataset'],
                                   **config['fs_dataset_args'])
        utils.log('fs dataset: {} (x{}), {}'.format(fs_dataset[0][0].shape,
                                                    len(fs_dataset),
                                                    fs_dataset.n_classes))
        if config.get('visualize_datasets'):
            utils.visualize_dataset(fs_dataset, 'fs_dataset', writer)

        n_way = 5
        n_query = 15
        n_shots = [1, 5]
        fs_loaders = []
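        # One episodic loader per shot setting: 200 batches of 5-way episodes,
        # 4 episodes per batch, n_shot + 15 query images per class.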
        for n_shot in n_shots:
            fs_sampler = CategoriesSampler(fs_dataset.label,
                                           200,
                                           n_way,
                                           n_shot + n_query,
                                           ep_per_batch=4)
            fs_loader = DataLoader(fs_dataset,
                                   batch_sampler=fs_sampler,
                                   num_workers=0,
                                   pin_memory=True)
            fs_loaders.append(fs_loader)
    else:
        eval_fs = False

    ########

    #### Model and Optimizer ####

    if config.get('load'):
        model_sv = torch.load(config['load'])
        model = models.load(model_sv)
    else:
        model = models.make(config['model'], **config['model_args'])

    if eval_fs:
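        # Reuse the trained classifier's encoder inside a meta-baseline head for few-shot eval.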
        fs_model = models.make('meta-baseline', encoder=None)
        fs_model.encoder = model.encoder

    if config.get('_parallel'):
        model = nn.DataParallel(model)
        if eval_fs:
            fs_model = nn.DataParallel(fs_model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(model.parameters(),
                                                   config['optimizer'],
                                                   **config['optimizer_args'])

    ########

    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    for epoch in range(1, max_epoch + 1 + 1):
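        # The extra pass (epoch == max_epoch + 1) is the optional 'ex' epoch, which
        # trains once more with the dataset's default (un-augmented) transform.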
        if epoch == max_epoch + 1:
            if not config.get('epoch_ex'):
                break
            train_dataset.transform = train_dataset.default_transform
            print(train_dataset.transform)
            train_loader = DataLoader(train_dataset,
                                      config['batch_size'],
                                      shuffle=True,
                                      num_workers=0,
                                      pin_memory=True)

        timer_epoch.s()
        aves_keys = ['tl', 'ta', 'vl', 'va']
        if eval_fs:
            for n_shot in n_shots:
                aves_keys += ['fsa-' + str(n_shot)]
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        for data, label in tqdm(train_loader, desc='train', leave=False):
            # for data, label in train_loader:
            data, label = data.cuda(), label.cuda()
            logits = model(data)
            loss = F.cross_entropy(logits, label)
            acc = utils.compute_acc(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None
            loss = None

        # eval
        if eval_val:
            model.eval()
            for data, label in tqdm(val_loader, desc='val', leave=False):
                data, label = data.cuda(), label.cuda()
                with torch.no_grad():
                    logits = model(data)
                    loss = F.cross_entropy(logits, label)
                    acc = utils.compute_acc(logits, label)

                aves['vl'].add(loss.item())
                aves['va'].add(acc)

        if eval_fs and (epoch % ef_epoch == 0 or epoch == max_epoch + 1):
            fs_model.eval()
            for i, n_shot in enumerate(n_shots):
                np.random.seed(0)
                for data, _ in tqdm(fs_loaders[i],
                                    desc='fs-' + str(n_shot),
                                    leave=False):
                    x_shot, x_query = fs.split_shot_query(data.cuda(),
                                                          n_way,
                                                          n_shot,
                                                          n_query,
                                                          ep_per_batch=4)
                    label = fs.make_nk_label(n_way, n_query,
                                             ep_per_batch=4).cuda()
                    with torch.no_grad():
                        logits = fs_model(x_shot, x_query).view(-1, n_way)
                        acc = utils.compute_acc(logits, label)
                    aves['fsa-' + str(n_shot)].add(acc)

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)

        if epoch <= max_epoch:
            epoch_str = str(epoch)
        else:
            epoch_str = 'ex'
        log_str = 'epoch {}, train {:.4f}|{:.4f}'.format(
            epoch_str, aves['tl'], aves['ta'])
        writer.add_scalars('loss', {'train': aves['tl']}, epoch)
        writer.add_scalars('acc', {'train': aves['ta']}, epoch)

        if eval_val:
            log_str += ', val {:.4f}|{:.4f}'.format(aves['vl'], aves['va'])
            writer.add_scalars('loss', {'val': aves['vl']}, epoch)
            writer.add_scalars('acc', {'val': aves['va']}, epoch)

        if eval_fs and (epoch % ef_epoch == 0 or epoch == max_epoch + 1):
            log_str += ', fs'
            for n_shot in n_shots:
                key = 'fsa-' + str(n_shot)
                log_str += ' {}: {:.4f}'.format(n_shot, aves[key])
                writer.add_scalars('acc', {key: aves[key]}, epoch)

        if epoch <= max_epoch:
            log_str += ', {} {}/{}'.format(t_epoch, t_used, t_estimate)
        else:
            log_str += ', {}'.format(t_epoch)
        utils.log(log_str)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,
            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),
            'training': training,
        }
        if epoch <= max_epoch:
            torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))

            if (save_epoch is not None) and epoch % save_epoch == 0:
                torch.save(
                    save_obj,
                    os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

            if aves['va'] > max_va:
                max_va = aves['va']
                torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))
        else:
            torch.save(save_obj, os.path.join(save_path, 'epoch-ex.pth'))

        writer.flush()
Example #2
def main(config):
    svname = args.name
    if svname is None:
        svname = 'meta_{}-{}shot'.format(config['train_dataset'],
                                         config['n_shot'])
        svname += '_' + config['model']
        if config['model_args'].get('encoder'):
            svname += '-' + config['model_args']['encoder']
        if config['model_args'].get('prog_synthesis'):
            svname += '-' + config['model_args']['prog_synthesis']
    svname += '-seed' + str(args.seed)
    if args.tag is not None:
        svname += '_' + args.tag

    save_path = os.path.join(args.save_dir, svname)
    utils.ensure_path(save_path, remove=False)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    logger = utils.Logger(file_name=os.path.join(save_path, "log_sdout.txt"),
                          file_mode="a+",
                          should_flush=True)

    #### Dataset ####

    n_way, n_shot = config['n_way'], config['n_shot']
    n_query = config['n_query']

    if config.get('n_train_way') is not None:
        n_train_way = config['n_train_way']
    else:
        n_train_way = n_way
    if config.get('n_train_shot') is not None:
        n_train_shot = config['n_train_shot']
    else:
        n_train_shot = n_shot
    if config.get('ep_per_batch') is not None:
        ep_per_batch = config['ep_per_batch']
    else:
        ep_per_batch = 1

    random_state = np.random.RandomState(args.seed)
    print('seed:', args.seed)

    # train
    train_dataset = datasets.make(config['train_dataset'],
                                  **config['train_dataset_args'])
    utils.log('train dataset: {} (x{})'.format(train_dataset[0][0].shape,
                                               len(train_dataset)))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(train_dataset, 'train_dataset', writer)
    train_sampler = BongardSampler(train_dataset.n_tasks,
                                   config['train_batches'], ep_per_batch,
                                   random_state.randint(2**31))
    train_loader = DataLoader(train_dataset,
                              batch_sampler=train_sampler,
                              num_workers=8,
                              pin_memory=True)

    # tvals
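    # Each test-val split gets its own episodic loader; splits missing from the config map to None.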
    tval_loaders = {}
    tval_name_ntasks_dict = {
        'tval': 2000,
        'tval_ff': 600,
        'tval_bd': 480,
        'tval_hd_comb': 400,
        'tval_hd_novel': 320
    }  # numbers depend on dataset
    for tval_type in tval_name_ntasks_dict.keys():
        if config.get('{}_dataset'.format(tval_type)):
            tval_dataset = datasets.make(
                config['{}_dataset'.format(tval_type)],
                **config['{}_dataset_args'.format(tval_type)])
            utils.log('{} dataset: {} (x{})'.format(tval_type,
                                                    tval_dataset[0][0].shape,
                                                    len(tval_dataset)))
            if config.get('visualize_datasets'):
                utils.visualize_dataset(tval_dataset,
                                        '{}_dataset'.format(tval_type), writer)
            tval_sampler = BongardSampler(
                tval_dataset.n_tasks,
                n_batch=tval_name_ntasks_dict[tval_type] // ep_per_batch,
                ep_per_batch=ep_per_batch,
                seed=random_state.randint(2**31))
            tval_loader = DataLoader(tval_dataset,
                                     batch_sampler=tval_sampler,
                                     num_workers=8,
                                     pin_memory=True)
            tval_loaders.update({tval_type: tval_loader})
        else:
            tval_loaders.update({tval_type: None})

    # val
    val_dataset = datasets.make(config['val_dataset'],
                                **config['val_dataset_args'])
    utils.log('val dataset: {} (x{})'.format(val_dataset[0][0].shape,
                                             len(val_dataset)))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(val_dataset, 'val_dataset', writer)
    val_sampler = BongardSampler(val_dataset.n_tasks,
                                 n_batch=900 // ep_per_batch,
                                 ep_per_batch=ep_per_batch,
                                 seed=random_state.randint(2**31))
    val_loader = DataLoader(val_dataset,
                            batch_sampler=val_sampler,
                            num_workers=8,
                            pin_memory=True)

    ########

    #### Model and optimizer ####

    if config.get('load'):
        print('loading pretrained model: ', config['load'])
        model = models.load(torch.load(config['load']))
    else:
        model = models.make(config['model'], **config['model_args'])

        if config.get('load_encoder'):
            print('loading pretrained encoder: ', config['load_encoder'])
            encoder = models.load(torch.load(config['load_encoder'])).encoder
            model.encoder.load_state_dict(encoder.state_dict())

        if config.get('load_prog_synthesis'):
            print('loading pretrained program synthesis model: ',
                  config['load_prog_synthesis'])
            prog_synthesis = models.load(
                torch.load(config['load_prog_synthesis']))
            model.prog_synthesis.load_state_dict(prog_synthesis.state_dict())

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(model.parameters(),
                                                   config['optimizer'],
                                                   **config['optimizer_args'])

    ########

    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    aves_keys = ['tl', 'ta', 'vl', 'va']
    tval_tuple_lst = []
    for k, v in tval_loaders.items():
        if v is not None:
            loss_key = 'tvl' + k.split('tval')[-1]
            acc_key = 'tva' + k.split('tval')[-1]
            aves_keys.append(loss_key)
            aves_keys.append(acc_key)
            tval_tuple_lst.append((k, v, loss_key, acc_key))

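    # trlog accumulates the per-epoch averages for every tracked key across training.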
    trlog = dict()
    for k in aves_keys:
        trlog[k] = []

    for epoch in range(1, max_epoch + 1):
        timer_epoch.s()
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        if config.get('freeze_bn'):
            utils.freeze_bn(model)
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        for data, label in tqdm(train_loader, desc='train', leave=False):

            x_shot, x_query = fs.split_shot_query(data.cuda(),
                                                  n_train_way,
                                                  n_train_shot,
                                                  n_query,
                                                  ep_per_batch=ep_per_batch)
            label_query = fs.make_nk_label(n_train_way,
                                           n_query,
                                           ep_per_batch=ep_per_batch).cuda()

            if config['model'] == 'snail':  # only use one selected label_query
                query_dix = random_state.randint(n_train_way * n_query)
                label_query = label_query.view(ep_per_batch, -1)[:, query_dix]
                x_query = x_query[:, query_dix:query_dix + 1]

            if config['model'] == 'maml':  # need grad in maml
                model.zero_grad()

            logits = model(x_shot, x_query).view(-1, n_train_way)
            loss = F.cross_entropy(logits, label_query)
            acc = utils.compute_acc(logits, label_query)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None
            loss = None

        # eval
        model.eval()

        for name, loader, name_l, name_a in [('val', val_loader, 'vl', 'va')
                                             ] + tval_tuple_lst:

            if config.get('{}_dataset'.format(name)) is None:
                continue

            np.random.seed(0)
            for data, _ in tqdm(loader, desc=name, leave=False):
                x_shot, x_query = fs.split_shot_query(
                    data.cuda(),
                    n_way,
                    n_shot,
                    n_query,
                    ep_per_batch=ep_per_batch)
                label_query = fs.make_nk_label(
                    n_way, n_query, ep_per_batch=ep_per_batch).cuda()

                if config['model'] == 'snail':  # only use one randomly selected label_query
                    query_dix = random_state.randint(n_train_way)
                    label_query = label_query.view(ep_per_batch, -1)[:, query_dix]
                    x_query = x_query[:, query_dix:query_dix + 1]

                if config['model'] == 'maml':  # need grad in maml
                    model.zero_grad()
                    logits = model(x_shot, x_query, eval=True).view(-1, n_way)
                    loss = F.cross_entropy(logits, label_query)
                    acc = utils.compute_acc(logits, label_query)
                else:
                    with torch.no_grad():
                        logits = model(x_shot, x_query,
                                       eval=True).view(-1, n_way)
                        loss = F.cross_entropy(logits, label_query)
                        acc = utils.compute_acc(logits, label_query)

                aves[name_l].add(loss.item())
                aves[name_a].add(acc)

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()
            trlog[k].append(aves[k])

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)
        log_str = 'epoch {}, train {:.4f}|{:.4f}, val {:.4f}|{:.4f}'.format(
            epoch, aves['tl'], aves['ta'], aves['vl'], aves['va'])
        for tval_name, _, loss_key, acc_key in tval_tuple_lst:
            log_str += ', {} {:.4f}|{:.4f}'.format(tval_name, aves[loss_key],
                                                   aves[acc_key])
            writer.add_scalars('loss', {tval_name: aves[loss_key]}, epoch)
            writer.add_scalars('acc', {tval_name: aves[acc_key]}, epoch)
        log_str += ', {} {}/{}'.format(t_epoch, t_used, t_estimate)
        utils.log(log_str)

        writer.add_scalars('loss', {
            'train': aves['tl'],
            'val': aves['vl'],
        }, epoch)
        writer.add_scalars('acc', {
            'train': aves['ta'],
            'val': aves['va'],
        }, epoch)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,
            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),
            'training': training,
        }
        torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))
        torch.save(trlog, os.path.join(save_path, 'trlog.pth'))

        if (save_epoch is not None) and epoch % save_epoch == 0:
            torch.save(save_obj,
                       os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

        if aves['va'] > max_va:
            max_va = aves['va']
            torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))

        writer.flush()

    print('finished training!')
    logger.close()
Example #3
def train(config):
    #### set the save and log path ####
    # svname = args.name
    # if svname is None:
    #     svname = config['train_dataset_type'] + '_' + config['model']
    #     # svname += '_' + config['model_args']['encoder']
    #     # if config['model_args']['classifier'] != 'linear-classifier':
    #     #     svname += '-' + config['model_args']['classifier']
    # if args.tag is not None:
    #     svname += '_' + args.tag
    # save_path = os.path.join('./save', svname)
    save_path = config['save_path']
    utils.set_save_path(save_path)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(config['save_path'], 'tensorboard'))
    yaml.dump(config, open(os.path.join(config['save_path'], 'classifier_config.yaml'), 'w'))

    #### make datasets ####
    # train
    train_folder = config['dataset_path'] + config['train_dataset_type'] + "/training/frames"
    test_folder = config['dataset_path'] + config['train_dataset_type'] + "/testing/frames"

    # Loading dataset
    train_dataset_args = config['train_dataset_args']
    test_dataset_args = config['test_dataset_args']
    train_dataset = VadDataset(train_folder, transforms.Compose([
        transforms.ToTensor(),
    ]), resize_height=train_dataset_args['h'], resize_width=train_dataset_args['w'], time_step=train_dataset_args['t_length'] - 1)

    test_dataset = VadDataset(test_folder, transforms.Compose([
        transforms.ToTensor(),
    ]), resize_height=test_dataset_args['h'], resize_width=test_dataset_args['w'], time_step=test_dataset_args['t_length'] - 1)

    train_dataloader = DataLoader(train_dataset, batch_size=train_dataset_args['batch_size'],
                                  shuffle=True, num_workers=train_dataset_args['num_workers'], drop_last=True)
    test_dataloader = DataLoader(test_dataset, batch_size=test_dataset_args['batch_size'],
                                 shuffle=False, num_workers=test_dataset_args['num_workers'], drop_last=False)

    # for test---- prepare labels
    labels = scipy.io.loadmat(config['label_path'])
    if config['test_dataset_type'] == 'shanghai':
        labels = np.expand_dims(labels, 0)
    videos = OrderedDict()
    videos_list = sorted(glob.glob(os.path.join(test_folder, '*')))
    labels_list = []
    label_length = 0
    psnr_list = {}
    for video in sorted(videos_list):
        video_name = video.split('/')[-1]
        videos[video_name] = {}
        videos[video_name]['path'] = video
        videos[video_name]['frame'] = glob.glob(os.path.join(video, '*.jpg'))
        videos[video_name]['frame'].sort()
        videos[video_name]['length'] = len(videos[video_name]['frame'])
        labels_list = np.append(labels_list, labels[video_name][0][4:])
        label_length += videos[video_name]['length']
        psnr_list[video_name] = []

    # Model setting
    if config['generator'] == 'cycle_generator_convlstm':
        ngf = 64
        netG = 'resnet_6blocks'
        norm = 'instance'
        no_dropout = False
        init_type = 'normal'
        init_gain = 0.02
        gpu_ids = []
        model = define_G(train_dataset_args['c'], train_dataset_args['c'],
                             ngf, netG, norm, not no_dropout, init_type, init_gain, gpu_ids)
    elif config['generator'] == 'unet':
        # TODO
        num_unet_layers = 4  # assumed depth; this variable is not defined elsewhere in this example
        model = UNet(n_channels=train_dataset_args['c'] * (train_dataset_args['t_length'] - 1),
                     layer_nums=num_unet_layers, output_channel=train_dataset_args['c'])
    else:
        raise Exception('The generator is not implemented')

    # optimizer setting
    # the generator is a single module here, so one parameter list covers encoder and decoder
    params = list(model.parameters())
    optimizer, lr_scheduler = utils.make_optimizer(
        params, config['optimizer'], config['optimizer_args'])

    # set loss, different range with the source version, should change
    lam_int = 1.0 * 2
    lam_gd = 1.0 * 2
    # TODO here we use no flow loss
    # lam_op = 0  # 2.0
    # op_loss = Flow_Loss()
    
    adversarial_loss = Adversarial_Loss()
    # # TODO if use adv
    # lam_adv = 0.05
    # discriminate_loss = Discriminate_Loss()
    alpha = 1
    l_num = 2
    gd_loss = Gradient_Loss(alpha, train_dataset_args['c'])    
    int_loss = Intensity_Loss(l_num)
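    # Generator objective used below: lam_gd * gradient loss + lam_int * intensity loss
    # (flow and adversarial terms are left disabled in this example).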

    # parallel if multi-GPU
    if torch.cuda.is_available():
        model.cuda()
    if config.get('_parallel'):
        model = nn.DataParallel(model)

    # Training
    utils.log('Start train')
    max_frame_AUC, max_roi_AUC = 0,0
    base_channel_num  = train_dataset_args['c'] * (train_dataset_args['t_length'] - 1)
    save_epoch = 5 if config['save_epoch'] is None else config['save_epoch']
    for epoch in range(config['epochs']):

        model.train()
        for j, imgs in enumerate(tqdm(train_dataloader, desc='train', leave=False)):
            imgs = imgs.cuda()
            # input = imgs[:, :-1, ].view(imgs.shape[0], -1, imgs.shape[-2], imgs.shape[-1])
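            # predict the last frame of each clip from the preceding t_length - 1 frames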
            input = imgs[:, :-1, ]
            target = imgs[:, -1, ]
            outputs = model(input)
            optimizer.zero_grad()

            g_int_loss = int_loss(outputs, target)
            g_gd_loss = gd_loss(outputs, target)
            loss = lam_gd * g_gd_loss + lam_int * g_int_loss

            loss.backward()
            optimizer.step()
        lr_scheduler.step()

        utils.log('----------------------------------------')
        utils.log('Epoch:' + str(epoch + 1))
        utils.log('----------------------------------------')
        utils.log('Loss: Reconstruction {:.6f}'.format(loss.item()))

        # Testing
        utils.log('Evaluation of ' + config['test_dataset_type'])   

        # Save the model
        if epoch % save_epoch == 0 or epoch == config['epochs'] - 1:
            if not os.path.exists(save_path):
                os.makedirs(save_path) 
            if not os.path.exists(os.path.join(save_path, "models")):
                os.makedirs(os.path.join(save_path, "models")) 
            # TODO
            frame_AUC, roi_AUC = evaluate(test_dataloader, model, labels_list, videos, int_loss, config['test_dataset_type'], test_bboxes=config['test_bboxes'],
                frame_height = train_dataset_args['h'], frame_width=train_dataset_args['w'], 
                is_visual=False, mask_labels_path = config['mask_labels_path'], save_path = os.path.join(save_path, "./final"), labels_dict=labels) 
            
            torch.save(model.state_dict(), os.path.join(save_path, 'models/model-epoch-{}.pth'.format(epoch)))
        else:
            frame_AUC, roi_AUC = evaluate(test_dataloader, model, labels_list, videos, int_loss, config['test_dataset_type'], test_bboxes=config['test_bboxes'],
                frame_height = train_dataset_args['h'], frame_width=train_dataset_args['w']) 

        utils.log('The result of ' + config['test_dataset_type'])
        utils.log("AUC: {}%, roi AUC: {}%".format(frame_AUC*100, roi_AUC*100))

        if frame_AUC > max_frame_AUC:
            max_frame_AUC = frame_AUC
            # TODO
            torch.save(model.state_dict(), os.path.join(save_path, 'models/max-frame_auc-model.pth'))
            evaluate(test_dataloader, model, labels_list, videos, int_loss, config['test_dataset_type'], test_bboxes=config['test_bboxes'],
                frame_height = train_dataset_args['h'], frame_width=train_dataset_args['w'], 
                is_visual=True, mask_labels_path = config['mask_labels_path'], save_path = os.path.join(save_path, "./frame_best"), labels_dict=labels) 
        if roi_AUC > max_roi_AUC:
            max_roi_AUC = roi_AUC
            torch.save(model.state_dict(), os.path.join(save_path, 'models/max-roi_auc-model.pth'))
            evaluate(test_dataloader, model, labels_list, videos, int_loss, config['test_dataset_type'], test_bboxes=config['test_bboxes'],
                frame_height = train_dataset_args['h'], frame_width=train_dataset_args['w'], 
                is_visual=True, mask_labels_path = config['mask_labels_path'], save_path = os.path.join(save_path, "./roi_best"), labels_dict=labels) 

        utils.log('----------------------------------------')

    utils.log('Training is finished')
    utils.log('max_frame_AUC: {}, max_roi_AUC: {}'.format(max_frame_AUC, max_roi_AUC))
Example #4
def train(config):
    #### set the save and log path ####
    save_path = config['save_path']
    utils.set_save_path(save_path)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(config['save_path'], 'tensorboard'))
    yaml.dump(config, open(os.path.join(config['save_path'], 'classifier_config.yaml'), 'w'))

    device = torch.device('cuda:' + args.gpu)

    #### make datasets ####
    # train
    train_folder = config['dataset_path'] + config['train_dataset_type'] + "/training/frames"
    test_folder = config['dataset_path'] + config['train_dataset_type'] + "/testing/frames"

    # Loading dataset
    train_dataset_args = config['train_dataset_args']
    test_dataset_args = config['test_dataset_args']

    train_dataset = VadDataset(args, video_folder=train_folder, bbox_folder=config['train_bboxes_path'],
                               flow_folder=config['train_flow_path'],
                               transform=transforms.Compose([transforms.ToTensor()]),
                               resize_height=train_dataset_args['h'], resize_width=train_dataset_args['w'],
                               dataset=config['train_dataset_type'], time_step=train_dataset_args['t_length'] - 1,
                               device=device)

    test_dataset = VadDataset(args, video_folder=test_folder, bbox_folder=config['test_bboxes_path'],
                              flow_folder=config['test_flow_path'],
                              transform=transforms.Compose([transforms.ToTensor()]),
                              resize_height=train_dataset_args['h'], resize_width=train_dataset_args['w'],
                              dataset=config['train_dataset_type'], time_step=train_dataset_args['t_length'] - 1,
                              device=device)


    train_dataloader = DataLoader(train_dataset, batch_size=train_dataset_args['batch_size'],
                                  shuffle=True, num_workers=train_dataset_args['num_workers'], drop_last=True)
    test_dataloader = DataLoader(test_dataset, batch_size=test_dataset_args['batch_size'],
                                 shuffle=False, num_workers=test_dataset_args['num_workers'], drop_last=False)

    # for test---- prepare labels
    labels = np.load('./data/frame_labels_' + config['test_dataset_type'] + '.npy')
    if config['test_dataset_type'] == 'shanghai':
        labels = np.expand_dims(labels, 0)
    videos = OrderedDict()
    videos_list = sorted(glob.glob(os.path.join(test_folder, '*')))
    labels_list = []
    label_length = 0
    psnr_list = {}
    for video in sorted(videos_list):
        video_name = video.split('/')[-1]
        videos[video_name] = {}
        videos[video_name]['path'] = video
        videos[video_name]['frame'] = glob.glob(os.path.join(video, '*.jpg'))
        videos[video_name]['frame'].sort()
        videos[video_name]['length'] = len(videos[video_name]['frame'])
        labels_list = np.append(labels_list, labels[0][4 + label_length:videos[video_name]['length'] + label_length])
        label_length += videos[video_name]['length']
        psnr_list[video_name] = []

    # Model setting
    num_unet_layers = 4
    discriminator_num_filters = [128, 256, 512, 512]

    # for gradient loss
    alpha = 1
    # for int loss
    l_num = 2
    pretrain = False

    if config['generator'] == 'cycle_generator_convlstm':
        ngf = 64
        netG = 'resnet_6blocks'
        norm = 'instance'
        no_dropout = False
        init_type = 'normal'
        init_gain = 0.02
        gpu_ids = []
        model = define_G(train_dataset_args['c'], train_dataset_args['c'],
                             ngf, netG, norm, not no_dropout, init_type, init_gain, gpu_ids)
    elif config['generator'] == 'unet':
        # generator = UNet(n_channels=train_dataset_args['c']*(train_dataset_args['t_length']-1),
        #                  layer_nums=num_unet_layers, output_channel=train_dataset_args['c'])
        model = PreAE(train_dataset_args['c'], train_dataset_args['t_length'], **config['model_args'])
    else:
        raise Exception('The generator is not implemented')

    # generator = torch.load('save/avenue_cycle_generator_convlstm_flownet2_0103/generator-epoch-199.pth')
    if config['use_D']:
        discriminator = PixelDiscriminator(train_dataset_args['c'], discriminator_num_filters, use_norm=False)
        optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.00002)

    # optimizer setting
    # the generator is a single module here, so one parameter list covers encoder and decoder
    params = list(model.parameters())
    optimizer_G, lr_scheduler = utils.make_optimizer(
        params, config['optimizer'], config['optimizer_args'])    


    # set loss, different range with the source version, should change
    lam_int = 1.0 * 2
    lam_gd = 1.0 * 2
    # TODO here we use no flow loss
    # lam_op = 0  # 2.0
    # op_loss = Flow_Loss()
    
    adversarial_loss = Adversarial_Loss()
    # TODO if use adv
    lam_adv = 0.05
    discriminate_loss = Discriminate_Loss()
    alpha = 1
    l_num = 2
    gd_loss = Gradient_Loss(alpha, train_dataset_args['c'])    
    int_loss = Intensity_Loss(l_num)
    object_loss = ObjectLoss(device, l_num)
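    # object_loss replaces the plain intensity term in this variant; it compares the
    # predicted and target frames using the per-object bboxes and optical flow (see the training loop).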

    # parallel if multi-GPU
    if torch.cuda.is_available():
        model.cuda()
        if config['use_D']:
            discriminator.cuda()
    if config.get('_parallel'):
        model = nn.DataParallel(model)
        if config['use_D']:
            discriminator = nn.DataParallel(discriminator)
    # Training
    utils.log('Start train')
    max_frame_AUC, max_roi_AUC = 0,0
    base_channel_num  = train_dataset_args['c'] * (train_dataset_args['t_length'] - 1)
    save_epoch = 5 if config['save_epoch'] is None else config['save_epoch']
    for epoch in range(config['epochs']):

        model.train()
        for j, (imgs, bbox, flow) in enumerate(tqdm(train_dataloader, desc='train', leave=False)):
            imgs = imgs.cuda()
            flow = flow.cuda()
            # input = imgs[:, :-1, ].view(imgs.shape[0], -1, imgs.shape[-2], imgs.shape[-1])
            input = imgs[:, :-1, ]
            target = imgs[:, -1, ]
            outputs = model(input)

            if config['use_D']:
                g_adv_loss = adversarial_loss(discriminator(outputs))
            else:
                g_adv_loss = 0 

            g_object_loss = object_loss(outputs, target, flow, bbox)
            # g_int_loss = int_loss(outputs, target)
            g_gd_loss = gd_loss(outputs, target)
            g_loss = lam_adv * g_adv_loss + lam_gd * g_gd_loss + lam_int * g_object_loss

            optimizer_G.zero_grad()
            g_loss.backward()

            optimizer_G.step()

            train_psnr = utils.psnr_error(outputs,target)

            # ----------- update optim_D -------
            if config['use_D']:
                optimizer_D.zero_grad()
                d_loss = discriminate_loss(discriminator(target), discriminator(outputs.detach()))
                d_loss.backward()
                optimizer_D.step()
        lr_scheduler.step()

        utils.log('----------------------------------------')
        utils.log('Epoch:' + str(epoch + 1))
        utils.log('----------------------------------------')
        utils.log('Loss: Reconstruction {:.6f}'.format(g_loss.item()))

        # Testing
        utils.log('Evaluation of ' + config['test_dataset_type'])   


        # Save the model
        if epoch % save_epoch == 0 or epoch == config['epochs'] - 1:
            if not os.path.exists(save_path):
                os.makedirs(save_path) 
            if not os.path.exists(os.path.join(save_path, "models")):
                os.makedirs(os.path.join(save_path, "models")) 
            # TODO 
            frame_AUC = ObjectLoss_evaluate(test_dataloader, model, labels_list, videos, dataset=config['test_dataset_type'],device = device,
                frame_height = train_dataset_args['h'], frame_width=train_dataset_args['w'],
                is_visual=False, mask_labels_path = config['mask_labels_path'], save_path = os.path.join(save_path, "./final"), labels_dict=labels) 
            
            torch.save(model.state_dict(), os.path.join(save_path, 'models/model-epoch-{}.pth'.format(epoch)))
            if config['use_D']:
                torch.save(discriminator.state_dict(), os.path.join(save_path, 'models/discriminator-epoch-{}.pth'.format(epoch)))
        else:
            frame_AUC = ObjectLoss_evaluate(test_dataloader, model, labels_list, videos, dataset=config['test_dataset_type'],device=device,
                frame_height = train_dataset_args['h'], frame_width=train_dataset_args['w']) 

        utils.log('The result of ' + config['test_dataset_type'])
        utils.log("AUC: {}%".format(frame_AUC*100))

        if frame_AUC > max_frame_AUC:
            max_frame_AUC = frame_AUC
            # TODO
            torch.save(model.state_dict(), os.path.join(save_path, 'models/max-frame_auc-model.pth'))
            if config['use_D']:
                torch.save(discriminator.state_dict(), os.path.join(save_path, 'models/discriminator-epoch-{}.pth'.format(epoch)))
            # evaluate(test_dataloader, model, labels_list, videos, int_loss, config['test_dataset_type'], test_bboxes=config['test_bboxes'],
            #     frame_height = train_dataset_args['h'], frame_width=train_dataset_args['w'], 
            #     is_visual=True, mask_labels_path = config['mask_labels_path'], save_path = os.path.join(save_path, "./frame_best"), labels_dict=labels) 
        
        utils.log('----------------------------------------')

    utils.log('Training is finished')
    utils.log('max_frame_AUC: {}'.format(max_frame_AUC))
Example #5
def main(config):
    svname = args.name
    if svname is None:
        svname = 'meta_{}-{}shot'.format(
                config['train_dataset'], config['n_shot'])
        svname += '_' + config['model'] + '-' + config['model_args']['encoder']
    if args.tag is not None:
        svname += '_' + args.tag
    save_path = os.path.join('./save', svname)
    utils.ensure_path(save_path)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    #### Dataset ####

    n_way, n_shot = config['n_way'], config['n_shot']
    n_query = config['n_query']

    if config.get('n_train_way') is not None:
        n_train_way = config['n_train_way']
    else:
        n_train_way = n_way
    if config.get('n_train_shot') is not None:
        n_train_shot = config['n_train_shot']
    else:
        n_train_shot = n_shot
    if config.get('ep_per_batch') is not None:
        ep_per_batch = config['ep_per_batch']
    else:
        ep_per_batch = 1

    # train
    train_dataset = datasets.make(config['train_dataset'],
                                  **config['train_dataset_args'])
    utils.log('train dataset: {} (x{}), {}'.format(
            train_dataset[0][0].shape, len(train_dataset),
            train_dataset.n_classes))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(train_dataset, 'train_dataset', writer)
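    # Episodic sampler: config['train_batches'] batches of n_train_way-way episodes,
    # n_train_shot + n_query images per class, ep_per_batch episodes per batch.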
    train_sampler = CategoriesSampler(
            train_dataset.label, config['train_batches'],
            n_train_way, n_train_shot + n_query,
            ep_per_batch=ep_per_batch)
    train_loader = DataLoader(train_dataset, batch_sampler=train_sampler,
                              num_workers=8, pin_memory=True)

    # tval
    if config.get('tval_dataset'):
        tval_dataset = datasets.make(config['tval_dataset'],
                                     **config['tval_dataset_args'])
        utils.log('tval dataset: {} (x{}), {}'.format(
                tval_dataset[0][0].shape, len(tval_dataset),
                tval_dataset.n_classes))
        if config.get('visualize_datasets'):
            utils.visualize_dataset(tval_dataset, 'tval_dataset', writer)
        tval_sampler = CategoriesSampler(
                tval_dataset.label, 200,
                n_way, n_shot + n_query,
                ep_per_batch=4)
        tval_loader = DataLoader(tval_dataset, batch_sampler=tval_sampler,
                                 num_workers=8, pin_memory=True)
    else:
        tval_loader = None

    # val
    val_dataset = datasets.make(config['val_dataset'],
                                **config['val_dataset_args'])
    utils.log('val dataset: {} (x{}), {}'.format(
            val_dataset[0][0].shape, len(val_dataset),
            val_dataset.n_classes))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(val_dataset, 'val_dataset', writer)
    val_sampler = CategoriesSampler(
            val_dataset.label, 200,
            n_way, n_shot + n_query,
            ep_per_batch=4)
    val_loader = DataLoader(val_dataset, batch_sampler=val_sampler,
                            num_workers=8, pin_memory=True)

    ########

    #### Model and optimizer ####

    if config.get('load'):
        model_sv = torch.load(config['load'])
        model = models.load(model_sv)
    else:
        model = models.make(config['model'], **config['model_args'])

        if config.get('load_encoder'):
            encoder = models.load(torch.load(config['load_encoder'])).encoder
            model.encoder.load_state_dict(encoder.state_dict())

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(
            model.parameters(),
            config['optimizer'], **config['optimizer_args'])

    ########
    
    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    aves_keys = ['tl', 'ta', 'tvl', 'tva', 'vl', 'va']
    trlog = dict()
    for k in aves_keys:
        trlog[k] = []

    for epoch in range(1, max_epoch + 1):
        timer_epoch.s()
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        if config.get('freeze_bn'):
            utils.freeze_bn(model) 
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        np.random.seed(epoch)
        for data, _ in tqdm(train_loader, desc='train', leave=False):
            x_shot, x_query = fs.split_shot_query(
                    data.cuda(), n_train_way, n_train_shot, n_query,
                    ep_per_batch=ep_per_batch)
            label = fs.make_nk_label(n_train_way, n_query,
                    ep_per_batch=ep_per_batch).cuda()

            logits = model(x_shot, x_query).view(-1, n_train_way)
            loss = F.cross_entropy(logits, label)
            acc = utils.compute_acc(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None; loss = None 

        # eval
        model.eval()

        for name, loader, name_l, name_a in [
                ('tval', tval_loader, 'tvl', 'tva'),
                ('val', val_loader, 'vl', 'va')]:

            if (config.get('tval_dataset') is None) and name == 'tval':
                continue

            np.random.seed(0)
            for data, _ in tqdm(loader, desc=name, leave=False):
                x_shot, x_query = fs.split_shot_query(
                        data.cuda(), n_way, n_shot, n_query,
                        ep_per_batch=4)
                label = fs.make_nk_label(n_way, n_query,
                        ep_per_batch=4).cuda()

                with torch.no_grad():
                    logits = model(x_shot, x_query).view(-1, n_way)
                    loss = F.cross_entropy(logits, label)
                    acc = utils.compute_acc(logits, label)
                
                aves[name_l].add(loss.item())
                aves[name_a].add(acc)

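        # _sig records the last label of the final eval batch, apparently as a quick
        # sanity value; it is echoed as '@{...}' at the end of the log line below.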
        _sig = int(_[-1])

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()
            trlog[k].append(aves[k])

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)
        utils.log('epoch {}, train {:.4f}|{:.4f}, tval {:.4f}|{:.4f}, '
                'val {:.4f}|{:.4f}, {} {}/{} (@{})'.format(
                epoch, aves['tl'], aves['ta'], aves['tvl'], aves['tva'],
                aves['vl'], aves['va'], t_epoch, t_used, t_estimate, _sig))

        writer.add_scalars('loss', {
            'train': aves['tl'],
            'tval': aves['tvl'],
            'val': aves['vl'],
        }, epoch)
        writer.add_scalars('acc', {
            'train': aves['ta'],
            'tval': aves['tva'],
            'val': aves['va'],
        }, epoch)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,

            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),

            'training': training,
        }
        torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))
        torch.save(trlog, os.path.join(save_path, 'trlog.pth'))

        if (save_epoch is not None) and epoch % save_epoch == 0:
            torch.save(save_obj,
                    os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

        if aves['va'] > max_va:
            max_va = aves['va']
            torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))

        writer.flush()
Example #6
def main(config):
    svname = config.get('sv_name')
    if args.tag is not None:
        svname += '_' + args.tag
    config['sv_name'] = svname
    save_path = os.path.join('./save', svname)
    utils.ensure_path(save_path)
    utils.set_log_path(save_path)
    utils.log(svname)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))
    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    #### Dataset ####

    n_way, n_shot = config['n_way'], config['n_shot']
    n_query = config['n_query']
    n_pseudo = config['n_pseudo']
    ep_per_batch = config['ep_per_batch']
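    # n_pseudo extra images are drawn per episode and passed to the model alongside
    # the shot and query images (see fs.split_shot_query below).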

    if config.get('test_batches') is not None:
        test_batches = config['test_batches']
    else:
        test_batches = config['train_batches']

    for s in ['train', 'val', 'tval']:
        if config.get(f"{s}_dataset_args") is not None:
            config[f"{s}_dataset_args"]['data_dir'] = os.path.join(os.getcwd(), os.pardir, 'data_root')

    # train
    train_dataset = CustomDataset(config['train_dataset'], save_dir=config.get('load_encoder'),
                                  **config['train_dataset_args'])

    if config['train_dataset_args']['split'] == 'helper':
        with open(os.path.join(save_path, 'train_helper_cls.pkl'), 'wb') as f:
            pkl.dump(train_dataset.dataset_classes, f)

    train_sampler = EpisodicSampler(train_dataset, config['train_batches'], n_way, n_shot, n_query,
                                    n_pseudo, episodes_per_batch=ep_per_batch)
    train_loader = DataLoader(train_dataset, batch_sampler=train_sampler,
                                  num_workers=4, pin_memory=True)

    # tval
    if config.get('tval_dataset'):
        tval_dataset = CustomDataset(config['tval_dataset'],
                                     **config['tval_dataset_args'])

        tval_sampler = EpisodicSampler(tval_dataset, test_batches, n_way, n_shot, n_query,
                                       n_pseudo, episodes_per_batch=ep_per_batch)
        tval_loader = DataLoader(tval_dataset, batch_sampler=tval_sampler,
                                 num_workers=4, pin_memory=True)
    else:
        tval_loader = None

    # val
    val_dataset = CustomDataset(config['val_dataset'],
                                **config['val_dataset_args'])
    val_sampler = EpisodicSampler(val_dataset, test_batches, n_way, n_shot, n_query,
                                  n_pseudo, episodes_per_batch=ep_per_batch)
    val_loader = DataLoader(val_dataset, batch_sampler=val_sampler,
                            num_workers=4, pin_memory=True)


    #### Model and optimizer ####

    if config.get('load'):
        model_sv = torch.load(config['load'])
        model = models.load(model_sv)
    else:
        model = models.make(config['model'], **config['model_args'])
        if config.get('load_encoder'):
            encoder = models.load(torch.load(config['load_encoder'])).encoder
            model.encoder.load_state_dict(encoder.state_dict())
            if config.get('freeze_encoder'):
                for param in model.encoder.parameters():
                    param.requires_grad = False

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(
        model.parameters(),
        config['optimizer'], **config['optimizer_args'])

    ########

    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    aves_keys = ['tl', 'ta', 'tvl', 'tva', 'vl', 'va']
    trlog = dict()
    for k in aves_keys:
        trlog[k] = []

    for epoch in range(1, max_epoch + 1):
        timer_epoch.s()
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        if config.get('freeze_bn'):
            utils.freeze_bn(model)
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
        np.random.seed(epoch)

        for data in tqdm(train_loader, desc='train', leave=False):
            x_shot, x_query, x_pseudo = fs.split_shot_query(
                data.cuda(), n_way, n_shot, n_query, n_pseudo,
                ep_per_batch=ep_per_batch)
            label = fs.make_nk_label(n_way, n_query,
                                     ep_per_batch=ep_per_batch).cuda()

            logits = model(x_shot, x_query, x_pseudo)
            logits = logits.view(-1, n_way)
            loss = F.cross_entropy(logits, label)
            acc = utils.compute_acc(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None; loss = None

        # eval
        model.eval()
        for name, loader, name_l, name_a in [
            ('tval', tval_loader, 'tvl', 'tva'),
            ('val', val_loader, 'vl', 'va')]:

            if (config.get('tval_dataset') is None) and name == 'tval':
                continue

            np.random.seed(0)
            for data in tqdm(loader, desc=name, leave=False):
                x_shot, x_query, x_pseudo = fs.split_shot_query(
                    data.cuda(), n_way, n_shot, n_query, n_pseudo,
                    ep_per_batch=ep_per_batch)
                label = fs.make_nk_label(n_way, n_query,
                                         ep_per_batch=ep_per_batch).cuda()

                with torch.no_grad():
                    logits = model(x_shot, x_query, x_pseudo)
                    logits = logits.view(-1, n_way)
                    loss = F.cross_entropy(logits, label)
                    acc = utils.compute_acc(logits, label)

                aves[name_l].add(loss.item())
                aves[name_a].add(acc)

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()
            trlog[k].append(aves[k])

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)
        utils.log('epoch {}, train {:.4f}|{:.4f}, tval {:.4f}|{:.4f}, '
                  'val {:.4f}|{:.4f}, {} {}/{}'.format(
            epoch, aves['tl'], aves['ta'], aves['tvl'], aves['tva'],
            aves['vl'], aves['va'], t_epoch, t_used, t_estimate))

        writer.add_scalars('loss', {
            'train': aves['tl'],
            'tval': aves['tvl'],
            'val': aves['vl'],
        }, epoch)
        writer.add_scalars('acc', {
            'train': aves['ta'],
            'tval': aves['tva'],
            'val': aves['va'],
        }, epoch)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,

            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),

            'training': training,
        }
        torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))
        torch.save(trlog, os.path.join(save_path, 'trlog.pth'))

        if (save_epoch is not None) and epoch % save_epoch == 0:
            torch.save(save_obj,
                       os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

        if aves['va'] > max_va:
            max_va = aves['va']
            torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))

        writer.flush()
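The episodic loop above relies on two helpers that are not shown in this snippet: fs.make_nk_label, which builds the query labels for an n-way episode, and utils.compute_acc. A minimal sketch consistent with how they are called here (names and exact behavior are assumptions, not the repository's verified code):

import torch

def make_nk_label(n_way, n_query, ep_per_batch=1):
    # per-episode label pattern: 0,...,0, 1,...,1, ..., n_way-1 (n_query copies each),
    # repeated once per episode in the batch
    label = torch.arange(n_way).unsqueeze(1).expand(n_way, n_query).reshape(-1)
    return label.repeat(ep_per_batch)

def compute_acc(logits, label):
    # fraction of query samples whose argmax prediction matches the label
    return (torch.argmax(logits, dim=-1) == label).float().mean().item()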
Example #7
def train(config):
    #### set the save and log path ####
    svname = args.name
    if svname is None:
        svname = config['train_dataset_type'] + '_' + config[
            'generator'] + '_' + config['flow_model']

    if args.tag is not None:
        svname += '_' + args.tag
    save_path = os.path.join('./save', svname)
    utils.set_save_path(save_path)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))
    yaml.dump(config,
              open(os.path.join(save_path, 'classifier_config.yaml'), 'w'))

    #### make datasets ####
    # train
    train_folder = config['dataset_path'] + config[
        'train_dataset_type'] + "/training/frames"
    test_folder = config['dataset_path'] + config[
        'train_dataset_type'] + "/testing/frames"

    # Loading dataset
    train_dataset_args = config['train_dataset_args']
    test_dataset_args = config['test_dataset_args']
    train_dataset = VadDataset(train_folder,
                               transforms.Compose([
                                   transforms.ToTensor(),
                               ]),
                               resize_height=train_dataset_args['h'],
                               resize_width=train_dataset_args['w'],
                               time_step=train_dataset_args['t_length'] - 1)

    test_dataset = VadDataset(test_folder,
                              transforms.Compose([
                                  transforms.ToTensor(),
                              ]),
                              resize_height=test_dataset_args['h'],
                              resize_width=test_dataset_args['w'],
                              time_step=test_dataset_args['t_length'] - 1)

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=train_dataset_args['batch_size'],
        shuffle=True,
        num_workers=train_dataset_args['num_workers'],
        drop_last=True)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=test_dataset_args['batch_size'],
                                 shuffle=False,
                                 num_workers=test_dataset_args['num_workers'],
                                 drop_last=False)

    # for test---- prepare labels
    labels = np.load('./data/frame_labels_' + config['test_dataset_type'] +
                     '.npy')
    if config['test_dataset_type'] == 'shanghai':
        labels = np.expand_dims(labels, 0)
    videos = OrderedDict()
    videos_list = sorted(glob.glob(os.path.join(test_folder, '*')))
    labels_list = []
    label_length = 0
    psnr_list = {}
    for video in sorted(videos_list):
        # video_name = video.split('/')[-1]

        # windows
        video_name = os.path.split(video)[-1]
        videos[video_name] = {}
        videos[video_name]['path'] = video
        videos[video_name]['frame'] = glob.glob(os.path.join(video, '*.jpg'))
        videos[video_name]['frame'].sort()
        videos[video_name]['length'] = len(videos[video_name]['frame'])
        labels_list = np.append(
            labels_list,
            labels[0][4 + label_length:videos[video_name]['length'] +
                      label_length])
        label_length += videos[video_name]['length']
        psnr_list[video_name] = []

    # Model setting
    num_unet_layers = 4
    discriminator_num_filters = [128, 256, 512, 512]

    # for gradient loss
    alpha = 1
    # for int loss
    l_num = 2
    pretrain = False

    if config['generator'] == 'cycle_generator_convlstm':
        ngf = 64
        netG = 'resnet_6blocks'
        norm = 'instance'
        no_dropout = False
        init_type = 'normal'
        init_gain = 0.02
        gpu_ids = []
        generator = define_G(train_dataset_args['c'], train_dataset_args['c'],
                             ngf, netG, norm, not no_dropout, init_type,
                             init_gain, gpu_ids)
    elif config['generator'] == 'unet':
        # generator = UNet(n_channels=train_dataset_args['c']*(train_dataset_args['t_length']-1),
        #                  layer_nums=num_unet_layers, output_channel=train_dataset_args['c'])
        generator = PreAE(train_dataset_args['c'], train_dataset_args['t_length'],
                          **config['model_args'])
    else:
        raise Exception('The generator is not implemented')

    # generator = torch.load('save/avenue_cycle_generator_convlstm_flownet2_0103/generator-epoch-199.pth')

    discriminator = PixelDiscriminator(train_dataset_args['c'],
                                       discriminator_num_filters,
                                       use_norm=False)
    # discriminator = torch.load('save/avenue_cycle_generator_convlstm_flownet2_0103/discriminator-epoch-199.pth')

    # if not pretrain:
    #     generator.apply(weights_init_normal)
    #     discriminator.apply(weights_init_normal)

    # if use flownet
    # if config['flow_model'] == 'flownet2':
    #     flownet2SD_model_path = 'flownet2/FlowNet2_checkpoint.pth.tar'
    #     flow_network = FlowNet2(args).eval()
    #     flow_network.load_state_dict(torch.load(flownet2SD_model_path)['state_dict'])
    # elif config['flow_model'] == 'liteflownet':
    #     lite_flow_model_path = 'liteFlownet/network-sintel.pytorch'
    #     flow_network = Network().eval()
    #     flow_network.load_state_dict(torch.load(lite_flow_model_path))

    # loss weights use a different range than the source version; adjust if needed
    lam_int = 1.0 * 2
    lam_gd = 1.0 * 2
    # here we use no flow loss
    lam_op = 0  # 2.0
    lam_adv = 0.05
    adversarial_loss = Adversarial_Loss()
    discriminate_loss = Discriminate_Loss()
    gd_loss = Gradient_Loss(alpha, train_dataset_args['c'])
    op_loss = Flow_Loss()
    int_loss = Intensity_Loss(l_num)
    step = 0

    utils.log('initializing the model with Generator-Unet {} layers, '
              'PixelDiscriminator with filters {}'.format(
                  num_unet_layers, discriminator_num_filters))

    g_lr = 0.0002
    d_lr = 0.00002
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=g_lr)
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=d_lr)

    # # optimizer setting
    # params_encoder = list(generator.encoder.parameters())
    # params_decoder = list(generator.decoder.parameters())
    # params = params_encoder + params_decoder
    # optimizer, lr_scheduler = utils.make_optimizer(
    #     params, config['optimizer'], config['optimizer_args'])
    #
    # loss_func_mse = nn.MSELoss(reduction='none')

    # move to GPU; wrapped in DataParallel below when multiple GPUs are used
    if torch.cuda.is_available():
        generator.cuda()
        discriminator.cuda()
        # # if use flownet
        # flow_network.cuda()
        adversarial_loss.cuda()
        discriminate_loss.cuda()
        gd_loss.cuda()
        op_loss.cuda()
        int_loss.cuda()

    if config.get('_parallel'):
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        # if use flownet
        # flow_network = nn.DataParallel(flow_network)
        adversarial_loss = nn.DataParallel(adversarial_loss)
        discriminate_loss = nn.DataParallel(discriminate_loss)
        gd_loss = nn.DataParallel(gd_loss)
        op_loss = nn.DataParallel(op_loss)
        int_loss = nn.DataParallel(int_loss)

    # Training
    utils.log('Start training')
    max_accuracy = 0
    base_channel_num = train_dataset_args['c'] * (
        train_dataset_args['t_length'] - 1)
    save_epoch = 5 if config.get('save_epoch') is None else config['save_epoch']
    for epoch in range(config['epochs']):

        generator.train()
        for j, imgs in enumerate(
                tqdm(train_dataloader, desc='train', leave=False)):
            imgs = imgs.cuda()
            input = imgs[:, :-1, ]
            input_last = input[:, -1, ]
            target = imgs[:, -1, ]
            # input = input.view(input.shape[0], -1, input.shape[-2],input.shape[-1])

            # only for debug
            # input0=imgs[:, 0,]
            # input1 = imgs[:, 1, ]
            # gt_flow_esti_tensor = torch.cat([input0, input1], 1)
            # flow_gt = batch_estimate(gt_flow_esti_tensor, flow_network)[0]
            # objectOutput = open('./out_train.flo', 'wb')
            # np.array([80, 73, 69, 72], np.uint8).tofile(objectOutput)
            # np.array([flow_gt.size(2), flow_gt.size(1)], np.int32).tofile(objectOutput)
            # np.array(flow_gt.detach().cpu().numpy().transpose(1, 2, 0), np.float32).tofile(objectOutput)
            # objectOutput.close()
            # break

            # ------- update optim_G --------------
            outputs = generator(input)
            # pred_flow_tensor = torch.cat([input_last, outputs], 1)
            # gt_flow_tensor = torch.cat([input_last, target], 1)
            # flow_pred = batch_estimate(pred_flow_tensor, flow_network)
            # flow_gt = batch_estimate(gt_flow_tensor, flow_network)

            # to use FlowNet2SD, switch to the commented flow-estimation block below

            # #### if use flownet ####
            # pred_flow_esti_tensor = torch.cat([input_last.view(-1,3,1,input.shape[-2],input.shape[-1]),
            #                                    outputs.view(-1,3,1,input.shape[-2],input.shape[-1])], 2)
            # gt_flow_esti_tensor = torch.cat([input_last.view(-1,3,1,input.shape[-2],input.shape[-1]),
            #                                  target.view(-1,3,1,input.shape[-2],input.shape[-1])], 2)

            # flow_gt=flow_network(gt_flow_esti_tensor*255.0)
            # flow_pred=flow_network(pred_flow_esti_tensor*255.0)
            ##############################
            # g_op_loss = op_loss(flow_pred, flow_gt) ## flow loss
            g_op_loss = 0
            g_adv_loss = adversarial_loss(discriminator(outputs))

            g_int_loss = int_loss(outputs, target)
            g_gd_loss = gd_loss(outputs, target)
            g_loss = lam_adv * g_adv_loss + lam_gd * g_gd_loss + lam_op * g_op_loss + lam_int * g_int_loss

            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()

            train_psnr = utils.psnr_error(outputs, target)

            # ----------- update optim_D -------
            optimizer_D.zero_grad()
            d_loss = discriminate_loss(discriminator(target),
                                       discriminator(outputs.detach()))
            d_loss.backward()
            optimizer_D.step()
            # break
        # lr_scheduler.step()

        utils.log('----------------------------------------')
        utils.log('Epoch:' + str(epoch + 1))
        utils.log('----------------------------------------')
        utils.log("g_loss: {} d_loss {}".format(g_loss, d_loss))
        utils.log('\t gd_loss {}, op_loss {}, int_loss {} ,'.format(
            g_gd_loss, g_op_loss, g_int_loss))
        utils.log('\t train psnr{}'.format(train_psnr))

        # Testing
        utils.log('Evaluation of ' + config['test_dataset_type'])
        for video in sorted(videos_list):
            # video_name = video.split('/')[-1]
            video_name = os.path.split(video)[-1]
            psnr_list[video_name] = []

        generator.eval()
        video_num = 0
        # label_length += videos[videos_list[video_num].split('/')[-1]]['length']
        label_length = videos[os.path.split(
            videos_list[video_num])[1]]['length']
        for k, imgs in enumerate(
                tqdm(test_dataloader, desc='test', leave=False)):
            if k == label_length - 4 * (video_num + 1):
                video_num += 1
                label_length += videos[os.path.split(
                    videos_list[video_num])[1]]['length']
            imgs = imgs.cuda()
            input = imgs[:, :-1, ]
            target = imgs[:, -1, ]
            # input = input.view(input.shape[0], -1, input.shape[-2], input.shape[-1])

            outputs = generator(input)
            mse_imgs = int_loss((outputs + 1) / 2, (target + 1) / 2).item()
            # psnr_list[videos_list[video_num].split('/')[-1]].append(utils.psnr(mse_imgs))
            psnr_list[os.path.split(videos_list[video_num])[1]].append(
                utils.psnr(mse_imgs))

        # Measuring the abnormality score and the AUC
        anomaly_score_total_list = []
        for video in sorted(videos_list):
            # video_name = video.split('/')[-1]
            video_name = os.path.split(video)[1]
            anomaly_score_total_list += utils.anomaly_score_list(
                psnr_list[video_name])

        anomaly_score_total_list = np.asarray(anomaly_score_total_list)
        accuracy = utils.AUC(anomaly_score_total_list,
                             np.expand_dims(1 - labels_list, 0))

        utils.log('The result of ' + config['test_dataset_type'])
        utils.log('AUC: ' + str(accuracy * 100) + '%')

        # Save the model
        if epoch % save_epoch == 0 or epoch == config['epochs'] - 1:
            # torch.save(model, os.path.join(
            #     save_path, 'model-epoch-{}.pth'.format(epoch)))

            torch.save(
                generator,
                os.path.join(save_path,
                             'generator-epoch-{}.pth'.format(epoch)))
            torch.save(
                discriminator,
                os.path.join(save_path,
                             'discriminator-epoch-{}.pth'.format(epoch)))

        if accuracy > max_accuracy:
            torch.save(generator, os.path.join(save_path, 'generator-max'))
            torch.save(discriminator,
                       os.path.join(save_path, 'discriminator-max'))

        utils.log('----------------------------------------')

    utils.log('Training is finished')
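The evaluation above turns each frame's prediction error into a PSNR, min-max normalizes the PSNRs within each video into anomaly scores, and computes a frame-level AUC against the ground-truth labels. A rough sketch of what utils.psnr, utils.anomaly_score_list, and utils.AUC could look like (assumed implementations for illustration; frames are rescaled to [0, 1] by the (x + 1) / 2 step above, and scikit-learn is only one way to compute the AUC):

import numpy as np
from sklearn import metrics

def psnr(mse, max_val=1.0):
    # peak signal-to-noise ratio for images scaled to [0, max_val]
    return 10 * np.log10(max_val ** 2 / mse)

def anomaly_score_list(psnr_list):
    # min-max normalize PSNR within one video; lower PSNR -> lower (more anomalous) score
    arr = np.asarray(psnr_list, dtype=np.float64)
    return list((arr - arr.min()) / (arr.max() - arr.min() + 1e-8))

def AUC(scores, labels):
    # frame-level ROC AUC; the caller passes 1 - labels_list, i.e. 1 for normal frames,
    # which matches the "higher score = more normal" convention of the scores
    return metrics.roc_auc_score(np.squeeze(labels), np.squeeze(scores))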
Example #8
def main(config):
    svname = args.name
    if svname is None:
        svname = 'meta'
    if args.tag is not None:
        svname += '_' + args.tag
    save_path = os.path.join('./save', svname)
    utils.ensure_path(save_path)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    #### Dataset ####

    if args.dataset == 'all':
        train_lst = ['ilsvrc_2012', 'omniglot', 'aircraft', 'cu_birds', 'dtd',
                'quickdraw', 'fungi', 'vgg_flower']
        eval_lst = ['ilsvrc_2012']
    else:
        train_lst = [args.dataset]
        eval_lst = [args.dataset]

    if config.get('no_train') == True:
        train_iter = None
    else:
        trainset = make_md(train_lst, 'episodic', split='train', image_size=126)
        train_iter = trainset.make_one_shot_iterator().get_next()

    if config.get('no_val') == True:
        val_iter = None
    else:
        valset = make_md(eval_lst, 'episodic', split='val', image_size=126)
        val_iter = valset.make_one_shot_iterator().get_next()

    testset = make_md(eval_lst, 'episodic', split='test', image_size=126)
    test_iter = testset.make_one_shot_iterator().get_next()

    sess = tf.Session()

    ########

    #### Model and optimizer ####

    if config.get('load'):
        model_sv = torch.load(config['load'])
        model = models.load(model_sv)
    else:
        model = models.make(config['model'], **config['model_args'])

        if config.get('load_encoder'):
            encoder = models.load(torch.load(config['load_encoder'])).encoder
            model.encoder.load_state_dict(encoder.state_dict())

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(
            model.parameters(),
            config['optimizer'], **config['optimizer_args'])

    ########
    
    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    aves_keys = ['tl', 'ta', 'tvl', 'tva', 'vl', 'va']
    trlog = dict()
    for k in aves_keys:
        trlog[k] = []

    def process_data(e):
        e = list(e[0])
        transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(146),
            transforms.CenterCrop(128),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])
        ])
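        # Assumed Meta-Dataset episode layout, based on how the tuple is used below:
        # e[0]/e[1] are support images/labels and e[3]/e[4] query images/labels;
        # images arrive in [-1, 1] and are rescaled to uint8 before the PIL transforms.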
        for ii in [0, 3]:
            e[ii] = ((e[ii] + 1.0) * 0.5 * 255).astype('uint8')
            tmp = torch.zeros(len(e[ii]), 3, 128, 128).float()
            for i in range(len(e[ii])):
                tmp[i] = transform(e[ii][i])
            e[ii] = tmp.cuda()

        e[1] = torch.from_numpy(e[1]).long().cuda()
        e[4] = torch.from_numpy(e[4]).long().cuda()

        return e

    for epoch in range(1, max_epoch + 1):
        timer_epoch.s()
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        if config.get('freeze_bn'):
            utils.freeze_bn(model) 
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        if config.get('no_train') == True:
            pass
        else:
            for i_ep in tqdm(range(config['n_train'])):

                e = process_data(sess.run(train_iter))
                loss, acc = model(e[0], e[1], e[3], e[4])

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                aves['tl'].add(loss.item())
                aves['ta'].add(acc)

                loss = None 

        # eval
        model.eval()

        for name, ds_iter, name_l, name_a in [
                ('tval', val_iter, 'tvl', 'tva'),
                ('val', test_iter, 'vl', 'va')]:
            if config.get('no_val') == True and name == 'tval':
                continue

            for i_ep in tqdm(range(config['n_eval'])):

                e = process_data(sess.run(ds_iter))

                with torch.no_grad():
                    loss, acc = model(e[0], e[1], e[3], e[4])
                
                aves[name_l].add(loss.item())
                aves[name_a].add(acc)

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()
            trlog[k].append(aves[k])

        _sig = 0

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)
        utils.log('epoch {}, train {:.4f}|{:.4f}, tval {:.4f}|{:.4f}, '
                'val {:.4f}|{:.4f}, {} {}/{} (@{})'.format(
                epoch, aves['tl'], aves['ta'], aves['tvl'], aves['tva'],
                aves['vl'], aves['va'], t_epoch, t_used, t_estimate, _sig))

        writer.add_scalars('loss', {
            'train': aves['tl'],
            'tval': aves['tvl'],
            'val': aves['vl'],
        }, epoch)
        writer.add_scalars('acc', {
            'train': aves['ta'],
            'tval': aves['tva'],
            'val': aves['va'],
        }, epoch)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,

            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),

            'training': training,
        }
        torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))
        torch.save(trlog, os.path.join(save_path, 'trlog.pth'))

        if (save_epoch is not None) and epoch % save_epoch == 0:
            torch.save(save_obj,
                    os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

        if aves['va'] > max_va:
            max_va = aves['va']
            torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))

        writer.flush()
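Several of these training loops accumulate per-epoch metrics through utils.Averager via .add(...) and .item(). A minimal running-mean sketch with that interface, assuming every call carries equal weight (the actual utility may differ):

class Averager:

    def __init__(self):
        self.n = 0
        self.v = 0.0

    def add(self, x):
        # incremental mean update
        self.v = (self.v * self.n + x) / (self.n + 1)
        self.n += 1

    def item(self):
        return self.v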
Example #9
def main(config, args):
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    wandb_auth()
    try:
        __IPYTHON__
        wandb.init(project="NAS", group="maml")
    except NameError:
        wandb.init(project="NAS", group="maml", config=config)

    ckpt_name = args.name
    if ckpt_name is None:
        ckpt_name = config['encoder']
        ckpt_name += '_' + config['dataset'].replace('meta-', '')
        ckpt_name += '_{}_way_{}_shot'.format(config['train']['n_way'],
                                              config['train']['n_shot'])
    if args.tag is not None:
        ckpt_name += '_' + args.tag

    ckpt_path = os.path.join('./save', ckpt_name)
    utils.ensure_path(ckpt_path)
    utils.set_log_path(ckpt_path)
    writer = SummaryWriter(os.path.join(ckpt_path, 'tensorboard'))
    yaml.dump(config, open(os.path.join(ckpt_path, 'config.yaml'), 'w'))

    ##### Dataset #####
    # meta-train
    train_set = datasets.make(config['dataset'], **config['train'])
    utils.log('meta-train set: {} (x{}), {}'.format(train_set[0][0].shape,
                                                    len(train_set),
                                                    train_set.n_classes))

    # meta-val
    eval_val = False
    if config.get('val'):
        eval_val = True
        val_set = datasets.make(config['dataset'], **config['val'])
        utils.log('meta-val set: {} (x{}), {}'.format(val_set[0][0].shape,
                                                      len(val_set),
                                                      val_set.n_classes))
        val_loader = DataLoader(val_set,
                                config['val']['n_episode'],
                                collate_fn=datasets.collate_fn,
                                num_workers=1,
                                pin_memory=True)

    # if args.split == "traintrain" and config.get('val'): # TODO I dont think this is what they meant by train-train :D
    #   train_set = torch.utils.data.ConcatDataset([train_set, val_set])
    train_loader = DataLoader(train_set,
                              config['train']['n_episode'],
                              collate_fn=datasets.collate_fn,
                              num_workers=1,
                              pin_memory=True)

    ##### Model and Optimizer #####

    inner_args = utils.config_inner_args(config.get('inner_args'))
    if config.get('load') or (args.load is True and
                              os.path.exists(ckpt_path + '/epoch-last.pth')):
        if config.get('load') is None:
            config['load'] = ckpt_path + '/epoch-last.pth'
        ckpt = torch.load(config['load'])
        config['encoder'] = ckpt['encoder']
        config['encoder_args'] = ckpt['encoder_args']
        config['classifier'] = ckpt['classifier']
        config['classifier_args'] = ckpt['classifier_args']
        model = models.load(ckpt,
                            load_clf=(not inner_args['reset_classifier']))
        optimizer, lr_scheduler = optimizers.load(ckpt, model.parameters())
        start_epoch = ckpt['training']['epoch'] + 1
        max_va = ckpt['training']['max_va']
    else:
        config['encoder_args'] = config.get('encoder_args') or dict()
        config['classifier_args'] = config.get('classifier_args') or dict()
        config['encoder_args']['bn_args']['n_episode'] = config['train'][
            'n_episode']
        config['classifier_args']['n_way'] = config['train']['n_way']
        model = models.make(config['encoder'], config['encoder_args'],
                            config['classifier'], config['classifier_args'])
        optimizer, lr_scheduler = optimizers.make(config['optimizer'],
                                                  model.parameters(),
                                                  **config['optimizer_args'])
        start_epoch = 1
        max_va = 0.

    if args.efficient:
        model.go_efficient()

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))
    timer_elapsed, timer_epoch = utils.Timer(), utils.Timer()

    ##### Training and evaluation #####

    # 'tl': meta-train loss
    # 'ta': meta-train accuracy
    # 'vl': meta-val loss
    # 'va': meta-val accuracy
    aves_keys = ['tl', 'ta', 'vl', 'va']
    trlog = dict()
    for k in aves_keys:
        trlog[k] = []

    for epoch in tqdm(range(start_epoch, config['epoch'] + 1),
                      desc="Iterating over epochs"):
        timer_epoch.start()
        aves = {k: utils.AverageMeter() for k in aves_keys}

        # meta-train
        model.train()
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
        np.random.seed(epoch)

        all_sotls = 0
        all_sovls = 0
        for data_idx, data in enumerate(
                tqdm(train_loader, desc='meta-train', leave=False)):
            x_shot, x_query, y_shot, y_query = data
            x_shot, y_shot = x_shot.cuda(), y_shot.cuda()
            x_query, y_query = x_query.cuda(), y_query.cuda()

            if inner_args['reset_classifier']:
                if config.get('_parallel'):
                    model.module.reset_classifier()
                else:
                    model.reset_classifier()

            if args.split == "traintrain":
                x_query = x_shot
                y_query = y_shot

            logits, sotl, all_losses = model(x_shot,
                                             x_query,
                                             y_shot,
                                             inner_args,
                                             meta_train=True)
            # print("HAHHA", data_idx, all_losses)
            # sotl = sum([l[-1] for l in all_losses])
            # for l in all_losses[:-1]:
            #   for i in range(len(l)-1):
            #     l[i] = l[i].detach()

            logits = logits.flatten(0, 1)
            labels = y_query.flatten()

            all_sotls += sotl

            pred = torch.argmax(logits, dim=-1)
            acc = utils.compute_acc(pred, labels)
            loss = F.cross_entropy(logits, labels)

            # all_sovls += loss # TODO I think this causes blowup because it creates new tensors that never get discarded and it maintains the computational graph after?
            if args.split == "trainval" or (
                    args.split == "sovl"
                    and not data_idx % args.sotl_freq == 0):

                aves['tl'].update(loss.item(), 1)
                aves['ta'].update(acc, 1)

                optimizer.zero_grad()
                loss.backward()
                for param in optimizer.param_groups[0]['params']:
                    nn.utils.clip_grad_value_(param, 10)
                optimizer.step()
            elif args.split == "traintrain":

                aves['tl'].update(loss.item(), 1)
                aves['ta'].update(acc, 1)

                # sotl = sum(sotl) + loss
                optimizer.zero_grad()
                # sotl.backward()
                loss.backward()
                for param in optimizer.param_groups[0]['params']:
                    nn.utils.clip_grad_value_(param, 10)
                optimizer.step()

            elif args.split == "sotl" and data_idx % args.sotl_freq == 0:
                # TODO: doesn't work whatsoever

                aves['tl'].update(loss.item(), 1)
                aves['ta'].update(acc, 1)
                optimizer.zero_grad()
                all_sotls.backward()
                for param in optimizer.param_groups[0]['params']:
                    nn.utils.clip_grad_value_(param, 10)
                optimizer.step()
                all_sotls = 0  # detach
            elif args.split == "sovl" and data_idx % args.sotl_freq == 0:
                # TODO: doesn't work whatsoever

                aves['tl'].update(loss.item(), 1)
                aves['ta'].update(acc, 1)
                optimizer.zero_grad()
                all_sovls.backward()
                for param in optimizer.param_groups[0]['params']:
                    nn.utils.clip_grad_value_(param, 10)
                optimizer.step()
                all_sovls = 0  # detach

        # meta-val
        if eval_val:
            model.eval()
            np.random.seed(0)

            for data in tqdm(val_loader, desc='meta-val', leave=False):
                x_shot, x_query, y_shot, y_query = data
                x_shot, y_shot = x_shot.cuda(), y_shot.cuda()
                x_query, y_query = x_query.cuda(), y_query.cuda()

                if inner_args['reset_classifier']:
                    if config.get('_parallel'):
                        model.module.reset_classifier()
                    else:
                        model.reset_classifier()

                logits, sotl, all_losses = model(x_shot,
                                                 x_query,
                                                 y_shot,
                                                 inner_args,
                                                 meta_train=False)
                logits = logits.flatten(0, 1)
                labels = y_query.flatten()

                pred = torch.argmax(logits, dim=-1)
                acc = utils.compute_acc(pred, labels)
                loss = F.cross_entropy(logits, labels)
                aves['vl'].update(loss.item(), 1)
                aves['va'].update(acc, 1)

        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, avg in aves.items():
            aves[k] = avg.item()
            trlog[k].append(aves[k])

        t_epoch = utils.time_str(timer_epoch.end())
        t_elapsed = utils.time_str(timer_elapsed.end())
        t_estimate = utils.time_str(timer_elapsed.end() /
                                    (epoch - start_epoch + 1) *
                                    (config['epoch'] - start_epoch + 1))

        # formats output
        log_str = 'epoch {}, meta-train {:.4f}|{:.4f}'.format(
            str(epoch), aves['tl'], aves['ta'])
        writer.add_scalars('loss', {'meta-train': aves['tl']}, epoch)
        writer.add_scalars('acc', {'meta-train': aves['ta']}, epoch)

        if eval_val:
            log_str += ', meta-val {:.4f}|{:.4f}'.format(
                aves['vl'], aves['va'])
            writer.add_scalars('loss', {'meta-val': aves['vl']}, epoch)
            writer.add_scalars('acc', {'meta-val': aves['va']}, epoch)

        wandb.log({
            "train_loss": aves['tl'],
            "train_acc": aves['ta'],
            "val_loss": aves['vl'],
            "val_acc": aves['va']
        })
        log_str += ', {} {}/{}'.format(t_epoch, t_elapsed, t_estimate)
        utils.log(log_str)

        # saves model and meta-data
        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch':
            epoch,
            'max_va':
            max(max_va, aves['va']),
            'optimizer':
            config['optimizer'],
            'optimizer_args':
            config['optimizer_args'],
            'optimizer_state_dict':
            optimizer.state_dict(),
            'lr_scheduler_state_dict':
            lr_scheduler.state_dict() if lr_scheduler is not None else None,
        }
        ckpt = {
            'file': __file__,
            'config': config,
            'encoder': config['encoder'],
            'encoder_args': config['encoder_args'],
            'encoder_state_dict': model_.encoder.state_dict(),
            'classifier': config['classifier'],
            'classifier_args': config['classifier_args'],
            'classifier_state_dict': model_.classifier.state_dict(),
            'training': training,
        }

        # 'epoch-last.pth': saved at the latest epoch
        # 'max-va.pth': saved when validation accuracy is at its maximum
        torch.save(ckpt, os.path.join(ckpt_path, 'epoch-last.pth'))
        torch.save(trlog, os.path.join(ckpt_path, 'trlog.pth'))

        if aves['va'] > max_va:
            max_va = aves['va']
            torch.save(ckpt, os.path.join(ckpt_path, 'max-va.pth'))

        writer.flush()
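The meta-training loop above clips gradients by value before each optimizer step: nn.utils.clip_grad_value_(param, 10) clamps every gradient element into [-10, 10]. Roughly the same effect, written out by hand over the same parameter list:

for p in optimizer.param_groups[0]['params']:
    if p.grad is not None:
        p.grad.data.clamp_(-10, 10)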
Example #10
def main(config):
    svname = args.name
    if svname is None:
        svname = 'pretrain-multi'
    if args.tag is not None:
        svname += '_' + args.tag
    save_path = os.path.join('./save', svname)
    utils.ensure_path(save_path)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    #### Dataset ####

    def make_dataset(name):
        dataset = make_md([name],
            'batch', split='train', image_size=126, batch_size=256)
        return dataset

    ds_names = ['ilsvrc_2012', 'omniglot', 'aircraft', 'cu_birds', 'dtd', \
            'quickdraw', 'fungi', 'vgg_flower']
    datasets = []
    for name in ds_names:
        datasets.append(make_dataset(name))
    iters = []
    for d in datasets:
        iters.append(d.make_one_shot_iterator().get_next())

    to_torch_labels = lambda a: torch.from_numpy(a).long()

    to_pil = transforms.ToPILImage()
    augmentation = transforms.Compose([
        transforms.Resize(146),
        transforms.RandomResizedCrop(128),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])
    ])
    ########

    #### Model and Optimizer ####

    if config.get('load'):
        model_sv = torch.load(config['load'])
        model = models.load(model_sv)
    else:
        model = models.make(config['model'], **config['model_args'])

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(
            model.parameters(),
            config['optimizer'], **config['optimizer_args'])

    ########
    
    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    for epoch in range(1, max_epoch + 1):
        timer_epoch.s()
        aves_keys = ['tl', 'ta', 'vl', 'va']
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        n_batch = 915547 // 256
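        # (assumption) 915547 is presumably the total image count of the mixed
        # training sources, and 256 the batch size used when building the datasets
        # above, so n_batch roughly corresponds to one pass over the data per epoch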
        with tf.Session() as sess:
            for i_batch in tqdm(range(n_batch)):
                if random.randint(0, 1) == 0:
                    ds_id = 0
                else:
                    ds_id = random.randint(1, len(datasets) - 1)

                next_element = iters[ds_id]
                e, cfr_id = sess.run(next_element)

                data_, label = e[0], to_torch_labels(e[1])
                data_ = ((data_ + 1.0) * 0.5 * 255).astype('uint8')
                data = torch.zeros(256, 3, 128, 128).float()
                for i in range(len(data_)):
                    x = data_[i]
                    x = to_pil(x)
                    x = augmentation(x)
                    data[i] = x

                data = data.cuda()
                label = label.cuda()

                logits = model(data, cfr_id=ds_id)
                loss = F.cross_entropy(logits, label)
                acc = utils.compute_acc(logits, label)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                aves['tl'].add(loss.item())
                aves['ta'].add(acc)

                logits = None; loss = None

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)

        if epoch <= max_epoch:
            epoch_str = str(epoch)
        else:
            epoch_str = 'ex'
        log_str = 'epoch {}, train {:.4f}|{:.4f}'.format(
                epoch_str, aves['tl'], aves['ta'])
        writer.add_scalars('loss', {'train': aves['tl']}, epoch)
        writer.add_scalars('acc', {'train': aves['ta']}, epoch)

        if epoch <= max_epoch:
            log_str += ', {} {}/{}'.format(t_epoch, t_used, t_estimate)
        else:
            log_str += ', {}'.format(t_epoch)
        utils.log(log_str)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,

            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),

            'training': training,
        }
        if epoch <= max_epoch:
            torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))

            if (save_epoch is not None) and epoch % save_epoch == 0:
                torch.save(save_obj, os.path.join(
                    save_path, 'epoch-{}.pth'.format(epoch)))

            if aves['va'] > max_va:
                max_va = aves['va']
                torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))
        else:
            torch.save(save_obj, os.path.join(save_path, 'epoch-ex.pth'))

        writer.flush()
Example #11
def main(config):
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    ckpt_name = args.name
    if ckpt_name is None:
        ckpt_name = config['encoder']
        ckpt_name += '_' + config['dataset'].replace('meta-', '')
        ckpt_name += '_{}_way_{}_shot'.format(config['train']['n_way'],
                                              config['train']['n_shot'])
    if args.tag is not None:
        ckpt_name += '_' + args.tag

    ckpt_path = os.path.join('./save', ckpt_name)
    utils.ensure_path(ckpt_path)
    utils.set_log_path(ckpt_path)
    writer = SummaryWriter(os.path.join(ckpt_path, 'tensorboard'))
    yaml.dump(config, open(os.path.join(ckpt_path, 'config.yaml'), 'w'))

    ##### Dataset #####

    # meta-train
    train_set = datasets.make(config['dataset'], **config['train'])
    utils.log('meta-train set: {} (x{}), {}'.format(train_set[0][0].shape,
                                                    len(train_set),
                                                    train_set.n_classes))
    train_loader = DataLoader(train_set,
                              config['train']['n_episode'],
                              collate_fn=datasets.collate_fn,
                              num_workers=1,
                              pin_memory=True)

    # meta-val
    eval_val = False
    if config.get('val'):
        eval_val = True
        val_set = datasets.make(config['dataset'], **config['val'])
        utils.log('meta-val set: {} (x{}), {}'.format(val_set[0][0].shape,
                                                      len(val_set),
                                                      val_set.n_classes))
        val_loader = DataLoader(val_set,
                                config['val']['n_episode'],
                                collate_fn=datasets.collate_fn,
                                num_workers=1,
                                pin_memory=True)

    ##### Model and Optimizer #####

    inner_args = utils.config_inner_args(config.get('inner_args'))
    if config.get('load'):
        ckpt = torch.load(config['load'])
        config['encoder'] = ckpt['encoder']
        config['encoder_args'] = ckpt['encoder_args']
        config['classifier'] = ckpt['classifier']
        config['classifier_args'] = ckpt['classifier_args']
        model = models.load(ckpt,
                            load_clf=(not inner_args['reset_classifier']))
        optimizer, lr_scheduler = optimizers.load(ckpt, model.parameters())
        start_epoch = ckpt['training']['epoch'] + 1
        max_va = ckpt['training']['max_va']
    else:
        config['encoder_args'] = config.get('encoder_args') or dict()
        config['classifier_args'] = config.get('classifier_args') or dict()
        config['encoder_args']['bn_args']['n_episode'] = config['train'][
            'n_episode']
        config['classifier_args']['n_way'] = config['train']['n_way']
        model = models.make(config['encoder'], config['encoder_args'],
                            config['classifier'], config['classifier_args'])
        optimizer, lr_scheduler = optimizers.make(config['optimizer'],
                                                  model.parameters(),
                                                  **config['optimizer_args'])
        start_epoch = 1
        max_va = 0.

    if args.efficient:
        model.go_efficient()

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))
    timer_elapsed, timer_epoch = utils.Timer(), utils.Timer()

    ##### Training and evaluation #####

    # 'tl': meta-train loss
    # 'ta': meta-train accuracy
    # 'vl': meta-val loss
    # 'va': meta-val accuracy
    aves_keys = ['tl', 'ta', 'vl', 'va']
    trlog = dict()
    for k in aves_keys:
        trlog[k] = []

    for epoch in range(start_epoch, config['epoch'] + 1):
        timer_epoch.start()
        aves = {k: utils.AverageMeter() for k in aves_keys}

        # meta-train
        model.train()
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
        np.random.seed(epoch)

        for data in tqdm(train_loader, desc='meta-train', leave=False):
            x_shot, x_query, y_shot, y_query = data
            x_shot, y_shot = x_shot.cuda(), y_shot.cuda()
            x_query, y_query = x_query.cuda(), y_query.cuda()

            if inner_args['reset_classifier']:
                if config.get('_parallel'):
                    model.module.reset_classifier()
                else:
                    model.reset_classifier()

            logits = model(x_shot,
                           x_query,
                           y_shot,
                           inner_args,
                           meta_train=True)
            logits = logits.flatten(0, 1)
            labels = y_query.flatten()

            pred = torch.argmax(logits, dim=-1)
            acc = utils.compute_acc(pred, labels)
            loss = F.cross_entropy(logits, labels)
            aves['tl'].update(loss.item(), 1)
            aves['ta'].update(acc, 1)

            optimizer.zero_grad()
            loss.backward()
            for param in optimizer.param_groups[0]['params']:
                nn.utils.clip_grad_value_(param, 10)
            optimizer.step()

        # meta-val
        if eval_val:
            model.eval()
            np.random.seed(0)

            for data in tqdm(val_loader, desc='meta-val', leave=False):
                x_shot, x_query, y_shot, y_query = data
                x_shot, y_shot = x_shot.cuda(), y_shot.cuda()
                x_query, y_query = x_query.cuda(), y_query.cuda()

                if inner_args['reset_classifier']:
                    if config.get('_parallel'):
                        model.module.reset_classifier()
                    else:
                        model.reset_classifier()

                logits = model(x_shot,
                               x_query,
                               y_shot,
                               inner_args,
                               meta_train=False)
                logits = logits.flatten(0, 1)
                labels = y_query.flatten()

                pred = torch.argmax(logits, dim=-1)
                acc = utils.compute_acc(pred, labels)
                loss = F.cross_entropy(logits, labels)
                aves['vl'].update(loss.item(), 1)
                aves['va'].update(acc, 1)

        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, avg in aves.items():
            aves[k] = avg.item()
            trlog[k].append(aves[k])

        t_epoch = utils.time_str(timer_epoch.end())
        t_elapsed = utils.time_str(timer_elapsed.end())
        t_estimate = utils.time_str(timer_elapsed.end() /
                                    (epoch - start_epoch + 1) *
                                    (config['epoch'] - start_epoch + 1))

        # formats output
        log_str = 'epoch {}, meta-train {:.4f}|{:.4f}'.format(
            str(epoch), aves['tl'], aves['ta'])
        writer.add_scalars('loss', {'meta-train': aves['tl']}, epoch)
        writer.add_scalars('acc', {'meta-train': aves['ta']}, epoch)

        if eval_val:
            log_str += ', meta-val {:.4f}|{:.4f}'.format(
                aves['vl'], aves['va'])
            writer.add_scalars('loss', {'meta-val': aves['vl']}, epoch)
            writer.add_scalars('acc', {'meta-val': aves['va']}, epoch)

        log_str += ', {} {}/{}'.format(t_epoch, t_elapsed, t_estimate)
        utils.log(log_str)

        # saves model and meta-data
        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch':
            epoch,
            'max_va':
            max(max_va, aves['va']),
            'optimizer':
            config['optimizer'],
            'optimizer_args':
            config['optimizer_args'],
            'optimizer_state_dict':
            optimizer.state_dict(),
            'lr_scheduler_state_dict':
            lr_scheduler.state_dict() if lr_scheduler is not None else None,
        }
        ckpt = {
            'file': __file__,
            'config': config,
            'encoder': config['encoder'],
            'encoder_args': config['encoder_args'],
            'encoder_state_dict': model_.encoder.state_dict(),
            'classifier': config['classifier'],
            'classifier_args': config['classifier_args'],
            'classifier_state_dict': model_.classifier.state_dict(),
            'training': training,
        }

        # 'epoch-last.pth': saved at the latest epoch
        # 'max-va.pth': saved when validation accuracy is at its maximum
        torch.save(ckpt, os.path.join(ckpt_path, 'epoch-last.pth'))
        torch.save(trlog, os.path.join(ckpt_path, 'trlog.pth'))

        if aves['va'] > max_va:
            max_va = aves['va']
            torch.save(ckpt, os.path.join(ckpt_path, 'max-va.pth'))

        writer.flush()
Example #12
def main(config):
    svname = args.name
    if svname is None:
        svname = 'moco_{}'.format(config['train_dataset'])
        svname += '_' + config['model_args']['encoder']
        out_dim = config['model_args']['encoder_args']['out_dim']
        svname += '-out_dim' + str(out_dim)
    svname += '-seed' + str(args.seed)
    if args.tag is not None:
        svname += '_' + args.tag
    save_path = os.path.join(args.save_dir, svname)
    utils.ensure_path(save_path, remove=False)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    random_state = np.random.RandomState(args.seed)
    print('seed:', args.seed)

    logger = utils.Logger(file_name=os.path.join(save_path, "log_sdout.txt"),
                          file_mode="a+",
                          should_flush=True)

    #### Dataset ####

    # train
    train_dataset = datasets.make(config['train_dataset'],
                                  **config['train_dataset_args'])
    train_loader = DataLoader(train_dataset,
                              config['batch_size'],
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True,
                              drop_last=True)
    utils.log('train dataset: {} (x{})'.format(train_dataset[0][0][0].shape,
                                               len(train_dataset)))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(train_dataset, 'train_dataset', writer)

    # val
    if config.get('val_dataset'):
        eval_val = True
        val_dataset = datasets.make(config['val_dataset'],
                                    **config['val_dataset_args'])
        val_loader = DataLoader(val_dataset,
                                config['batch_size'],
                                num_workers=8,
                                pin_memory=True,
                                drop_last=True)
        utils.log('val dataset: {} (x{})'.format(val_dataset[0][0][0].shape,
                                                 len(val_dataset)))
        if config.get('visualize_datasets'):
            utils.visualize_dataset(val_dataset, 'val_dataset', writer)
    else:
        eval_val = False

    # few-shot eval
    if config.get('eval_fs'):
        ef_epoch = config.get('eval_fs_epoch')
        if ef_epoch is None:
            ef_epoch = 5
        eval_fs = True
        n_way = 2
        n_query = 1
        n_shot = 6

        if config.get('ep_per_batch') is not None:
            ep_per_batch = config['ep_per_batch']
        else:
            ep_per_batch = 1

        # tvals
        fs_loaders = {}
        tval_name_ntasks_dict = {
            'tval': 2000,
            'tval_ff': 600,
            'tval_bd': 480,
            'tval_hd_comb': 400,
            'tval_hd_novel': 320
        }  # numbers depend on dataset
        for tval_type in tval_name_ntasks_dict.keys():
            if config.get('{}_dataset'.format(tval_type)):
                tval_dataset = datasets.make(
                    config['{}_dataset'.format(tval_type)],
                    **config['{}_dataset_args'.format(tval_type)])
                utils.log('{} dataset: {} (x{})'.format(
                    tval_type, tval_dataset[0][0][0].shape, len(tval_dataset)))
                if config.get('visualize_datasets'):
                    utils.visualize_dataset(tval_dataset, 'tval_ff_dataset',
                                            writer)
                tval_sampler = BongardSampler(
                    tval_dataset.n_tasks,
                    n_batch=tval_name_ntasks_dict[tval_type] // ep_per_batch,
                    ep_per_batch=ep_per_batch,
                    seed=random_state.randint(2**31))
                tval_loader = DataLoader(tval_dataset,
                                         batch_sampler=tval_sampler,
                                         num_workers=8,
                                         pin_memory=True)
                fs_loaders.update({tval_type: tval_loader})
            else:
                fs_loaders.update({tval_type: None})

    else:
        eval_fs = False

    ########

    #### Model and Optimizer ####

    if config.get('load'):
        model_sv = torch.load(config['load'])
        model = models.load(model_sv)
    else:
        model = models.make(config['model'], **config['model_args'])

    if eval_fs:
        fs_model = models.make('meta-baseline', encoder=None)
        fs_model.encoder = model.encoder

    if config.get('_parallel'):
        model = nn.DataParallel(model)
        if eval_fs:
            fs_model = nn.DataParallel(fs_model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(model.parameters(),
                                                   config['optimizer'],
                                                   **config['optimizer_args'])

    ########

    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    for epoch in range(1, max_epoch + 1 + 1):
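        # note: the loop runs one extra epoch (epoch == max_epoch + 1); its results
        # are logged with epoch 'ex' and its checkpoint saved as 'epoch-ex.pth' below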

        timer_epoch.s()
        aves_keys = ['tl', 'ta', 'vl', 'va', 'tvl', 'tva']
        if eval_fs:
            for k, v in fs_loaders.items():
                if v is not None:
                    aves_keys += ['fsa' + k.split('tval')[-1]]
        aves = {ave_k: utils.Averager() for ave_k in aves_keys}

        # train
        model.train()
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        for data, _ in tqdm(train_loader, desc='train', leave=False):
            logits, label = model(im_q=data[0].cuda(), im_k=data[1].cuda())

            loss = F.cross_entropy(logits, label)
            acc = utils.compute_acc(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None
            loss = None

        # val
        if eval_val:
            model.eval()
            for data, _ in tqdm(val_loader, desc='val', leave=False):
                with torch.no_grad():
                    logits, label = model(im_q=data[0].cuda(),
                                          im_k=data[1].cuda())
                    loss = F.cross_entropy(logits, label)
                    acc = utils.compute_acc(logits, label)

                aves['vl'].add(loss.item())
                aves['va'].add(acc)

        if eval_fs and (epoch % ef_epoch == 0 or epoch == max_epoch + 1):
            fs_model.eval()
            for k, v in fs_loaders.items():
                if v is not None:
                    ave_key = 'fsa' + k.split('tval')[-1]
                    np.random.seed(0)
                    for data, _ in tqdm(v, desc=ave_key, leave=False):
                        x_shot, x_query = fs.split_shot_query(
                            data[0].cuda(),
                            n_way,
                            n_shot,
                            n_query,
                            ep_per_batch=ep_per_batch)
                        label_query = fs.make_nk_label(
                            n_way, n_query, ep_per_batch=ep_per_batch).cuda()
                        with torch.no_grad():
                            logits = fs_model(x_shot, x_query).view(-1, n_way)
                            acc = utils.compute_acc(logits, label_query)
                        aves[ave_key].add(acc)

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)

        if epoch <= max_epoch:
            epoch_str = str(epoch)
        else:
            epoch_str = 'ex'
        log_str = 'epoch {}, train {:.4f}|{:.4f}'.format(
            epoch_str, aves['tl'], aves['ta'])
        writer.add_scalars('loss', {'train': aves['tl']}, epoch)
        writer.add_scalars('acc', {'train': aves['ta']}, epoch)

        if eval_val:
            log_str += ', val {:.4f}|{:.4f}, tval {:.4f}|{:.4f}'.format(
                aves['vl'], aves['va'], aves['tvl'], aves['tva'])
            writer.add_scalars('loss', {'val': aves['vl']}, epoch)
            writer.add_scalars('loss', {'tval': aves['tvl']}, epoch)
            writer.add_scalars('acc', {'val': aves['va']}, epoch)
            writer.add_scalars('acc', {'tval': aves['tva']}, epoch)

        if eval_fs and (epoch % ef_epoch == 0 or epoch == max_epoch + 1):
            log_str += ', fs'
            for ave_key in aves_keys:
                if 'fsa' in ave_key:
                    log_str += ' {}: {:.4f}'.format(ave_key, aves[ave_key])
                    writer.add_scalars('acc', {ave_key: aves[ave_key]}, epoch)

        if epoch <= max_epoch:
            log_str += ', {} {}/{}'.format(t_epoch, t_used, t_estimate)
        else:
            log_str += ', {}'.format(t_epoch)
        utils.log(log_str)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,
            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),
            'training': training,
        }
        if epoch <= max_epoch:
            torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))

            if (save_epoch is not None) and epoch % save_epoch == 0:
                torch.save(
                    save_obj,
                    os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

            if aves['va'] > max_va:
                max_va = aves['va']
                torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))
        else:
            torch.save(save_obj, os.path.join(save_path, 'epoch-ex.pth'))

        writer.flush()

    print('finished training!')
    writer.close()
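
# The 'aves' dictionaries above rely on a small running-average helper,
# utils.Averager. A minimal sketch of such a helper is given below, purely for
# reference; the actual implementation in utils may differ.
class Averager:

    def __init__(self):
        self.n = 0.0
        self.v = 0.0

    def add(self, x, n=1.0):
        # incremental update of the running mean
        self.v = (self.v * self.n + x * n) / (self.n + n)
        self.n += n

    def item(self):
        return self.v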
Example #13
def train(config):
    #### set the save and log path ####
    svname = args.name
    if svname is None:
        svname = 'classifier_{}'.format(config['train_dataset'])
        svname += '_' + config['model_args']['encoder']
        clsfr = config['model_args']['classifier']
        if clsfr != 'linear-classifier':
            svname += '-' + clsfr
    if args.tag is not None:
        svname += '_' + args.tag
    save_path = os.path.join('./save', svname)
    utils.set_save_path(save_path)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))
    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    #### make datasets ####
    # train
    dataloader_train = get_standard_loader(config['dataset_path'],
                                           config['train_dataset'],
                                           **config['train_dataset_args'])
    utils.log('train dataset: {}'.format(config['train_dataset']))
    # val
    if config.get('val_dataset'):
        eval_val = True
        dataloader_val = get_standard_loader(config['dataset_path'],
                                             config['val_dataset'],
                                             **config['val_dataset_args'])
    else:
        eval_val = False
    # few-shot test
    if config.get('test_dataset'):
        ef_epoch = config.get('test_epoch')
        if ef_epoch is None:
            ef_epoch = 5
        test_fs = True

        n_way = 5
        n_query = 15
        n_shots = [1, 5]
        dataloaders_test = []
        for n_shot in n_shots:
            dataloader_test = get_meta_loader(config['dataset_path'],
                                              config['test_dataset'],
                                              ways=n_way,
                                              shots=n_shot,
                                              query_shots=n_query,
                                              **config['test_dataset_args'])
            dataloaders_test.append(dataloader_test)
    else:
        test_fs = False

    #### Model and Optimizer ####
    if config.get('load'):
        model_sv = torch.load(config['load'])
        model = models.load(model_sv)
    else:
        model = models.make(config['model'], **config['model_args'])

    if test_fs:
        fs_model = models.make('meta-baseline', encoder=None)
        fs_model.encoder = model.encoder

    if config.get('_parallel'):
        model = nn.DataParallel(model)
        if test_fs:
            fs_model = nn.DataParallel(fs_model)

    optimizer, lr_scheduler = utils.make_optimizer(model.parameters(),
                                                   config['optimizer'],
                                                   **config['optimizer_args'])

    #### train and test ####

    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.

    for epoch in range(1, max_epoch + 1):

        aves_keys = ['tl', 'ta', 'vl', 'va']
        if test_fs:
            for n_shot in n_shots:
                aves_keys += ['fsa-' + str(n_shot)]
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        for data, label in tqdm(dataloader_train, desc='train', leave=False):
            data, label = data.cuda(), label.cuda()
            logits = model(data)
            loss = F.cross_entropy(logits, label)
            acc = utils.compute_acc(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None
            loss = None

        # eval
        if eval_val:
            model.eval()
            for data, label in tqdm(dataloader_val, desc='val', leave=False):
                data, label = data.cuda(), label.cuda()
                with torch.no_grad():
                    logits = model(data)
                    loss = F.cross_entropy(logits, label)
                    acc = utils.compute_acc(logits, label)

                aves['vl'].add(loss.item())
                aves['va'].add(acc)

        if test_fs and (epoch % ef_epoch == 0 or epoch == max_epoch):
            fs_model.eval()
            for i, n_shot in enumerate(n_shots):
                np.random.seed(0)
                for data in tqdm(dataloaders_test[i],
                                 desc='fs-' + str(n_shot),
                                 leave=False):
                    train_inputs, train_targets = data["train"]
                    train_inputs = train_inputs.cuda()
                    train_targets = train_targets.cuda()
                    query_inputs, query_targets = data["test"]
                    query_inputs = query_inputs.cuda()
                    query_targets = query_targets.cuda()

                    with torch.no_grad():
                        logits = fs_model(
                            train_inputs.view(train_inputs.shape[0], n_way,
                                              n_shot,
                                              *train_inputs.shape[-3:]),
                            query_inputs).view(-1, n_way)
                        acc = utils.compute_acc(logits,
                                                query_targets.view(-1))
                    aves['fsa-' + str(n_shot)].add(acc)
        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()
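
# utils.compute_acc is used in every example as argmax accuracy over the logits.
# A minimal sketch, assuming logits of shape (N, n_classes) and integer labels
# of shape (N,); the project's own helper may differ in details.
import torch

def compute_acc(logits, label, reduction='mean'):
    # fraction of samples whose argmax prediction matches the label
    ret = (torch.argmax(logits, dim=1) == label).float()
    if reduction == 'none':
        return ret.detach()
    return ret.mean().item()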
Example #14
def main(config):
    svname = args.name
    if svname is None:
        svname = f"classifier-{config['train_dataset']}-{config['model_args']['encoder']}"
        clsfr = config['model_args']['classifier']
        if clsfr != 'linear-classifier':
            svname += '-' + clsfr

    svname += '-aux' + str(args.aux_level)

    if args.topk is not None:
        svname += f"-top{args.topk}"

    if args.tag is not None:
        svname += '_' + args.tag

    save_path = os.path.join('./save', svname)
    utils.ensure_path(save_path)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    #### Dataset ####

    for s in ['train', 'val', 'tval', 'fs', 'fs_val']:
        if config.get(f"{s}_dataset_args") is not None:
            config[f"{s}_dataset_args"]['data_dir'] = os.path.join(
                os.getcwd(), os.pardir, 'data_root')

    # train
    train_dataset = TrainDataset(name=config['train_dataset'],
                                 **config['train_dataset_args'])
    train_loader = DataLoader(train_dataset,
                              config['batch_size'],
                              shuffle=True,
                              num_workers=16,
                              pin_memory=True,
                              drop_last=True)

    with open(os.path.join(save_path, 'training_classes.pkl'), 'wb') as f:
        pkl.dump(train_dataset.separated_training_classes, f)

    # val
    if config.get('val_dataset'):
        eval_val = True
        val_dataset = TrainDataset(config['val_dataset'],
                                   **config['val_dataset_args'])
        val_loader = DataLoader(val_dataset,
                                config['batch_size'],
                                num_workers=16,
                                pin_memory=True,
                                drop_last=True)
    else:
        eval_val = False

    # few-shot eval
    fs_loaders = {'fs_dataset': list(), 'fs_val_dataset': list()}
    for key in fs_loaders.keys():
        if config.get(key):
            ef_epoch = config.get('eval_fs_epoch')
            if ef_epoch is None:
                ef_epoch = 5
            eval_fs = True

            fs_dataset = CustomDataset(config[key], **config[key + '_args'])

            n_way = config['n_way'] if config.get('n_way') else 5
            n_query = config['n_query'] if config.get('n_query') else 15
            if config.get('n_pseudo') is not None:
                n_pseudo = config['n_pseudo']
            else:
                n_pseudo = 15
            n_batches = config['n_batches'] if config.get('n_batches') else 200
            ep_per_batch = config['ep_per_batch'] if config.get(
                'ep_per_batch') else 4
            n_shots = [1, 5]
            for n_shot in n_shots:
                fs_sampler = EpisodicSampler(fs_dataset,
                                             n_batches,
                                             n_way,
                                             n_shot,
                                             n_query,
                                             n_pseudo,
                                             episodes_per_batch=ep_per_batch)
                fs_loader = DataLoader(fs_dataset,
                                       batch_sampler=fs_sampler,
                                       num_workers=16,
                                       pin_memory=True)
                fs_loaders[key].append(fs_loader)
        else:
            eval_fs = False

    eval_fs = False
    for key in fs_loaders.keys():
        if config.get(key):
            eval_fs = True

    #### Model and Optimizer ####

    config['model_args']['classifier_args'][
        'n_classes'] = train_dataset.n_classes
    model = models.make(config['model'], **config['model_args'])

    if eval_fs:
        fs_model = models.make('meta-baseline', encoder=None)
        fs_model.encoder = model.encoder

    if config.get('_parallel'):
        model = nn.DataParallel(model)
        if eval_fs:
            fs_model = nn.DataParallel(fs_model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(model.parameters(),
                                                   config['optimizer'],
                                                   **config['optimizer_args'])

    ########

    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    for epoch in range(1, max_epoch + 1 + 1):
        timer_epoch.s()
        aves_keys = ['tl', 'ta', 'vl', 'va']
        if eval_fs:
            for n_shot in n_shots:
                aves_keys += ['fsa-' + str(n_shot)]
                if config.get('fs_val_dataset'):
                    aves_keys += ['fsav-' + str(n_shot)]
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        for data, label in tqdm(train_loader, desc='train', leave=False):
            data, label = data.cuda(), label.cuda()
            logits = model(data)
            loss = F.cross_entropy(logits, label)
            acc = utils.compute_acc(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None
            loss = None

        # eval
        if eval_val:
            model.eval()
            for data, label in tqdm(val_loader, desc='val', leave=False):
                data, label = data.cuda(), label.cuda()
                with torch.no_grad():
                    logits = model(data)
                    loss = F.cross_entropy(logits, label)
                    acc = utils.compute_acc(logits, label)

                aves['vl'].add(loss.item())
                aves['va'].add(acc)

        if eval_fs and (epoch % ef_epoch == 0 or epoch == max_epoch + 1):
            fs_model.eval()
            for key in fs_loaders.keys():
                if len(fs_loaders[key]) == 0:
                    continue
                tag = 'v' if key == 'fs_val_dataset' else ''
                for i, n_shot in enumerate(n_shots):
                    np.random.seed(0)
                    for data in tqdm(fs_loaders[key][i],
                                     desc='fs' + tag + '-' + str(n_shot),
                                     leave=False):
                        x_shot, x_query, x_pseudo = fs.split_shot_query(
                            data.cuda(),
                            n_way,
                            n_shot,
                            n_query,
                            pseudo=n_pseudo,
                            ep_per_batch=ep_per_batch)
                        label = fs.make_nk_label(
                            n_way, n_query, ep_per_batch=ep_per_batch).cuda()
                        with torch.no_grad():
                            logits = fs_model(x_shot, x_query, x_pseudo)
                            logits = logits.view(-1, n_way)
                            acc = utils.compute_acc(logits, label)
                        aves['fsa' + tag + '-' + str(n_shot)].add(acc)

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)

        if epoch <= max_epoch:
            epoch_str = str(epoch)
        else:
            epoch_str = 'ex'
        log_str = 'epoch {}, train {:.4f}|{:.4f}'.format(
            epoch_str, aves['tl'], aves['ta'])
        writer.add_scalars('loss', {'train': aves['tl']}, epoch)
        writer.add_scalars('acc', {'train': aves['ta']}, epoch)

        if eval_val:
            log_str += ', val {:.4f}|{:.4f}'.format(aves['vl'], aves['va'])
            writer.add_scalars('loss', {'val': aves['vl']}, epoch)
            writer.add_scalars('acc', {'val': aves['va']}, epoch)

        if eval_fs and (epoch % ef_epoch == 0 or epoch == max_epoch + 1):
            for key in fs_loaders.keys():
                if len(fs_loaders[key]) == 0:
                    continue
                tag = 'v' if key == 'fs_val_dataset' else ''
                log_str += ', fs' + tag
                for n_shot in n_shots:
                    key = 'fsa' + tag + '-' + str(n_shot)
                    log_str += ' {}: {:.4f}'.format(n_shot, aves[key])
                    writer.add_scalars('acc', {key: aves[key]}, epoch)

        if epoch <= max_epoch:
            log_str += ', {} {}/{}'.format(t_epoch, t_used, t_estimate)
        else:
            log_str += ', {}'.format(t_epoch)
        utils.log(log_str)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,
            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),
            'training': training,
        }
        if epoch <= max_epoch:
            torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))

            if (save_epoch is not None) and epoch % save_epoch == 0:
                torch.save(
                    save_obj,
                    os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

            if aves['va'] > max_va:
                max_va = aves['va']
                torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))
        else:
            torch.save(save_obj, os.path.join(save_path, 'epoch-ex.pth'))

        writer.flush()
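
# The few-shot evaluation above builds its query labels with fs.make_nk_label:
# for each episode, 0..n_way-1 each repeated n_query times, tiled once per
# episode in the batch. A minimal sketch under that assumption; the ordering
# must match how fs.split_shot_query lays out the query set.
import torch

def make_nk_label(n_way, n_query, ep_per_batch=1):
    # per-episode labels: [0]*n_query + [1]*n_query + ... + [n_way-1]*n_query
    label = torch.arange(n_way).unsqueeze(1).expand(n_way, n_query).reshape(-1)
    # repeat the same pattern for every episode in the batch
    return label.repeat(ep_per_batch)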