def forward(self, x_shot, x_query, **kwargs):
    ep_per_batch, n_way, n_shot = list(x_shot.shape[:-3])
    query_shape = list(x_query.shape[:-3])
    img_shape = list(x_shot.shape[-3:])

    x_shot = x_shot.view(-1, *img_shape)
    x_query = x_query.view(-1, *img_shape)
    x_tot = self.encoder(torch.cat([x_shot, x_query], dim=0))
    x_shot, x_query = x_tot[:len(x_shot)], x_tot[-len(x_query):]
    x_shot = x_shot.view(ep_per_batch, n_way * n_shot, -1)  # [bs, n_way * n_shot, n_feat]
    x_query = x_query.view(*query_shape, -1)  # [bs, n_way * n_query, n_feat]

    labels_support = make_nk_label(n_way, n_shot, ep_per_batch)
    labels_support = labels_support.cuda().view(ep_per_batch, -1)

    logits = self.cls_head(x_query, x_shot, labels_support, n_way, n_shot,
                           normalize=self.normalize)
    return logits
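# `make_nk_label` is used by all of these forward passes but is not shown in this
# section; a minimal sketch of what it is assumed to return (class ids 0..n_way-1,
# each repeated n_k times, tiled over the batch of episodes):
import torch


def make_nk_label_sketch(n_way, n_k, ep_per_batch=1):
    label = torch.arange(n_way).unsqueeze(1).expand(n_way, n_k).reshape(-1)
    return label.repeat(ep_per_batch)  # [ep_per_batch * n_way * n_k]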
def forward(self, x_shot, x_query, **kwargs):
    ep_per_batch, n_way, n_shot = x_shot.shape[:-3]
    assert n_shot == self.n_shot and n_way == self.n_way
    img_shape = x_shot.shape[-3:]

    if kwargs.get('eval') is None:
        kwargs['eval'] = False
    first_order = self.first_order or kwargs['eval']

    x_shot = x_shot.view(-1, *img_shape)
    train_logit = self.model(x_shot)  # [bs * n_way * n_shot, n_way]
    labels_support = make_nk_label(n_way, n_shot, ep_per_batch)  # [bs * n_way * n_shot]
    inner_loss = F.cross_entropy(train_logit, labels_support.cuda())

    # self.zero_grad()
    params = gradient_update_parameters(self.model,
                                        inner_loss,
                                        step_size=self.step_size,
                                        first_order=first_order)

    x_query = x_query.view(-1, *img_shape)
    test_logit = self.model(x_query, params=params)  # [bs * n_way * n_query, n_way]
    return test_logit
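# `gradient_update_parameters` is imported from an external meta-learning helper
# (a torchmeta-style API) and is not defined in this section. A minimal sketch of the
# single inner-loop SGD update it is assumed to perform; the name and signature below
# are illustrative, not the library's actual code:
from collections import OrderedDict

import torch


def sgd_inner_step_sketch(model, inner_loss, step_size, first_order=False):
    # One gradient step on the support-set loss; the adapted weights are returned
    # (not written back into the model) so they can be passed to the next forward
    # call via `params=...`, as done above. `first_order=True` drops the
    # second-order terms by not building a graph through the gradients.
    params = OrderedDict(model.named_parameters())
    grads = torch.autograd.grad(inner_loss,
                                params.values(),
                                create_graph=not first_order)
    return OrderedDict(
        (name, param - step_size * grad)
        for (name, param), grad in zip(params.items(), grads))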
def forward(self, x_shot, x_query, **kwargs):
    ep_per_batch, n_way, n_shot = list(x_shot.shape[:-3])
    img_shape = list(x_shot.shape[-3:])

    if kwargs.get('eval') is None:
        kwargs['eval'] = False

    if self.dynamic_k and not kwargs['eval']:
        # Dynamically zero out up to K-1 training batches.
        x_shot = x_shot.view(ep_per_batch, -1, *img_shape)
        k = np.random.randint(0, n_shot - 1)
        x_shot[:, :k * n_way] = 0.
        x_shot = x_shot.reshape(ep_per_batch, n_way, n_shot, *img_shape)

    x_shot = x_shot.view(-1, *img_shape)
    x_last = x_query.view(-1, *img_shape)
    x_tot = self.encoder(torch.cat([x_shot, x_last], dim=0))  # [bs * (n_way * n_shot + 1), n_feat]
    x_tot = x_tot.view(ep_per_batch, n_way * n_shot + 1, -1)  # [bs, n_way * n_shot + 1, n_feat]

    labels_support = make_nk_label(n_way, n_shot, ep_per_batch)  # [bs * n_way * n_shot]
    labels_support = labels_support.cuda().unsqueeze(-1)  # [bs * n_way * n_shot, 1]
    labels_support_onehot = torch.FloatTensor(labels_support.size(0), 2).cuda()
    labels_support_onehot.zero_()
    labels_support_onehot.scatter_(1, labels_support, 1)  # [bs * n_way * n_shot, n_way]
    labels_support_onehot = labels_support_onehot.view(ep_per_batch, -1, n_way)
    labels_query_zero = torch.Tensor(np.zeros((ep_per_batch, 1, n_way))).cuda()
    labels = torch.cat([labels_support_onehot, labels_query_zero],
                       dim=1)  # [bs, n_way * n_shot + 1, n_way]

    x = torch.cat((x_tot, labels), dim=-1)  # [bs, n_way * n_shot + 1, n_feat + n_way]
    x = self.attention1(x)
    x = self.tc1(x)
    x = self.attention2(x)
    x = self.tc2(x)
    x = self.attention3(x)
    x = self.fc(x)  # [bs, n_way * n_shot + 1, n_way]
    return x[:, -1, :]  # [bs, n_way]
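# Note: the one-hot support labels above are built with FloatTensor/zero_/scatter_ and a
# hard-coded width of 2 (the label dimension equals n_way, which is 2 for these binary
# Bongard-style episodes). An equivalent, more compact construction using F.one_hot is
# sketched below for reference only; it is not the code used above.
import torch
import torch.nn.functional as F


def onehot_support_labels_sketch(labels_support, ep_per_batch, n_way=2):
    # labels_support: [bs * n_way * n_shot] integer class ids.
    onehot = F.one_hot(labels_support.view(-1), num_classes=n_way).float()
    return onehot.view(ep_per_batch, -1, n_way)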
def main(config):
    svname = args.name
    if svname is None:
        svname = 'classifier_{}'.format(config['train_dataset'])
        svname += '_' + config['model_args']['encoder']
        clsfr = config['model_args']['classifier']
        if clsfr != 'linear-classifier':
            svname += '-' + clsfr
    if args.tag is not None:
        svname += '_' + args.tag
    save_path = os.path.join('./save', svname)
    utils.ensure_path(save_path)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    #### Dataset ####

    # train
    train_dataset = datasets.make(config['train_dataset'],
                                  **config['train_dataset_args'])
    augmentations = [
        transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomResizedCrop(size=(80, 80),
                                         scale=(0.08, 1.0),
                                         ratio=(0.75, 1.3333)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomRotation(35),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomResizedCrop(size=(80, 80),
                                         scale=(0.08, 1.0),
                                         ratio=(0.75, 1.3333)),
            transforms.RandomRotation(35),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomRotation(35),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomResizedCrop(size=(80, 80),
                                         scale=(0.08, 1.0),
                                         ratio=(0.75, 1.3333)),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomRotation(35),
            transforms.RandomResizedCrop(size=(80, 80),
                                         scale=(0.08, 1.0),
                                         ratio=(0.75, 1.3333)),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    ]
    train_dataset.transform = augmentations[int(config['_a'])]
    print(train_dataset.transform)
    print("_a", config['_a'])
    input("Continue with these augmentations?")
    train_loader = DataLoader(train_dataset,
                              config['batch_size'],
                              shuffle=True,
                              num_workers=0,
                              pin_memory=True)
    utils.log('train dataset: {} (x{}), {}'.format(train_dataset[0][0].shape,
                                                   len(train_dataset),
                                                   train_dataset.n_classes))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(train_dataset, 'train_dataset', writer)

    # val
    if config.get('val_dataset'):
        eval_val = True
        val_dataset = datasets.make(config['val_dataset'],
                                    **config['val_dataset_args'])
        val_loader = DataLoader(val_dataset,
                                config['batch_size'],
                                num_workers=0,
                                pin_memory=True)
        utils.log('val dataset: {} (x{}), {}'.format(val_dataset[0][0].shape,
                                                     len(val_dataset),
                                                     val_dataset.n_classes))
        if config.get('visualize_datasets'):
            utils.visualize_dataset(val_dataset, 'val_dataset', writer)
    else:
        eval_val = False

    # few-shot eval
    if config.get('fs_dataset'):
        ef_epoch = config.get('eval_fs_epoch')
        if ef_epoch is None:
            ef_epoch = 5
        eval_fs = True

        fs_dataset = datasets.make(config['fs_dataset'],
                                   **config['fs_dataset_args'])
        utils.log('fs dataset: {} (x{}), {}'.format(fs_dataset[0][0].shape,
                                                    len(fs_dataset),
                                                    fs_dataset.n_classes))
        if config.get('visualize_datasets'):
            utils.visualize_dataset(fs_dataset, 'fs_dataset', writer)

        n_way = 5
        n_query = 15
        n_shots = [1, 5]
        fs_loaders = []
        for n_shot in n_shots:
            fs_sampler = CategoriesSampler(fs_dataset.label,
                                           200,
                                           n_way,
                                           n_shot + n_query,
                                           ep_per_batch=4)
            fs_loader = DataLoader(fs_dataset,
                                   batch_sampler=fs_sampler,
                                   num_workers=0,
                                   pin_memory=True)
            fs_loaders.append(fs_loader)
    else:
        eval_fs = False

    ########

    #### Model and Optimizer ####

    if config.get('load'):
        model_sv = torch.load(config['load'])
        model = models.load(model_sv)
    else:
        model = models.make(config['model'], **config['model_args'])

    if eval_fs:
        fs_model = models.make('meta-baseline', encoder=None)
        fs_model.encoder = model.encoder

    if config.get('_parallel'):
        model = nn.DataParallel(model)
        if eval_fs:
            fs_model = nn.DataParallel(fs_model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(model.parameters(),
                                                   config['optimizer'],
                                                   **config['optimizer_args'])

    ########

    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    for epoch in range(1, max_epoch + 1 + 1):
        if epoch == max_epoch + 1:
            if not config.get('epoch_ex'):
                break
            train_dataset.transform = train_dataset.default_transform
            print(train_dataset.transform)
            train_loader = DataLoader(train_dataset,
                                      config['batch_size'],
                                      shuffle=True,
                                      num_workers=0,
                                      pin_memory=True)

        timer_epoch.s()
        aves_keys = ['tl', 'ta', 'vl', 'va']
        if eval_fs:
            for n_shot in n_shots:
                aves_keys += ['fsa-' + str(n_shot)]
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        for data, label in tqdm(train_loader, desc='train', leave=False):
            # for data, label in train_loader:
            data, label = data.cuda(), label.cuda()
            logits = model(data)
            loss = F.cross_entropy(logits, label)
            acc = utils.compute_acc(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None
            loss = None

        # eval
        if eval_val:
            model.eval()
            for data, label in tqdm(val_loader, desc='val', leave=False):
                data, label = data.cuda(), label.cuda()
                with torch.no_grad():
                    logits = model(data)
                    loss = F.cross_entropy(logits, label)
                    acc = utils.compute_acc(logits, label)

                aves['vl'].add(loss.item())
                aves['va'].add(acc)

        if eval_fs and (epoch % ef_epoch == 0 or epoch == max_epoch + 1):
            fs_model.eval()
            for i, n_shot in enumerate(n_shots):
                np.random.seed(0)
                for data, _ in tqdm(fs_loaders[i],
                                    desc='fs-' + str(n_shot),
                                    leave=False):
                    x_shot, x_query = fs.split_shot_query(data.cuda(),
                                                          n_way,
                                                          n_shot,
                                                          n_query,
                                                          ep_per_batch=4)
                    label = fs.make_nk_label(n_way, n_query,
                                             ep_per_batch=4).cuda()
                    with torch.no_grad():
                        logits = fs_model(x_shot, x_query).view(-1, n_way)
                        acc = utils.compute_acc(logits, label)
                    aves['fsa-' + str(n_shot)].add(acc)

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)

        if epoch <= max_epoch:
            epoch_str = str(epoch)
        else:
            epoch_str = 'ex'
        log_str = 'epoch {}, train {:.4f}|{:.4f}'.format(
            epoch_str, aves['tl'], aves['ta'])
        writer.add_scalars('loss', {'train': aves['tl']}, epoch)
        writer.add_scalars('acc', {'train': aves['ta']}, epoch)

        if eval_val:
            log_str += ', val {:.4f}|{:.4f}'.format(aves['vl'], aves['va'])
            writer.add_scalars('loss', {'val': aves['vl']}, epoch)
            writer.add_scalars('acc', {'val': aves['va']}, epoch)

        if eval_fs and (epoch % ef_epoch == 0 or epoch == max_epoch + 1):
            log_str += ', fs'
            for n_shot in n_shots:
                key = 'fsa-' + str(n_shot)
                log_str += ' {}: {:.4f}'.format(n_shot, aves[key])
                writer.add_scalars('acc', {key: aves[key]}, epoch)

        if epoch <= max_epoch:
            log_str += ', {} {}/{}'.format(t_epoch, t_used, t_estimate)
        else:
            log_str += ', {}'.format(t_epoch)
        utils.log(log_str)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,
            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),
            'training': training,
        }
        if epoch <= max_epoch:
            torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))

            if (save_epoch is not None) and epoch % save_epoch == 0:
                torch.save(
                    save_obj,
                    os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

            if aves['va'] > max_va:
                max_va = aves['va']
                torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))
        else:
            torch.save(save_obj, os.path.join(save_path, 'epoch-ex.pth'))

        writer.flush()
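# Illustrative only: how a checkpoint written by the loop above can be reloaded.
# `models.load` is the repo's own loader (the same call used for the 'load' config
# option elsewhere in these scripts); exactly what it restores is assumed here.
import torch

import models  # the repo's model registry, as imported at the top of these scripts

checkpoint = torch.load('./save/<svname>/max-va.pth')  # path written by the loop above
model = models.load(checkpoint)  # assumed to rebuild 'model'/'model_args' and load 'model_sd'
model.eval()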
def main(config):
    svname = args.name
    if svname is None:
        svname = 'meta_{}-{}shot'.format(config['train_dataset'],
                                         config['n_shot'])
        svname += '_' + config['model']
        if config['model_args'].get('encoder'):
            svname += '-' + config['model_args']['encoder']
        if config['model_args'].get('prog_synthesis'):
            svname += '-' + config['model_args']['prog_synthesis']
    svname += '-seed' + str(args.seed)
    if args.tag is not None:
        svname += '_' + args.tag

    save_path = os.path.join(args.save_dir, svname)
    utils.ensure_path(save_path, remove=False)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    logger = utils.Logger(file_name=os.path.join(save_path, "log_sdout.txt"),
                          file_mode="a+",
                          should_flush=True)

    #### Dataset ####

    n_way, n_shot = config['n_way'], config['n_shot']
    n_query = config['n_query']

    if config.get('n_train_way') is not None:
        n_train_way = config['n_train_way']
    else:
        n_train_way = n_way
    if config.get('n_train_shot') is not None:
        n_train_shot = config['n_train_shot']
    else:
        n_train_shot = n_shot
    if config.get('ep_per_batch') is not None:
        ep_per_batch = config['ep_per_batch']
    else:
        ep_per_batch = 1

    random_state = np.random.RandomState(args.seed)
    print('seed:', args.seed)

    # train
    train_dataset = datasets.make(config['train_dataset'],
                                  **config['train_dataset_args'])
    utils.log('train dataset: {} (x{})'.format(train_dataset[0][0].shape,
                                               len(train_dataset)))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(train_dataset, 'train_dataset', writer)
    train_sampler = BongardSampler(train_dataset.n_tasks,
                                   config['train_batches'],
                                   ep_per_batch,
                                   random_state.randint(2**31))
    train_loader = DataLoader(train_dataset,
                              batch_sampler=train_sampler,
                              num_workers=8,
                              pin_memory=True)

    # tvals
    tval_loaders = {}
    tval_name_ntasks_dict = {
        'tval': 2000,
        'tval_ff': 600,
        'tval_bd': 480,
        'tval_hd_comb': 400,
        'tval_hd_novel': 320
    }  # numbers depend on dataset
    for tval_type in tval_name_ntasks_dict.keys():
        if config.get('{}_dataset'.format(tval_type)):
            tval_dataset = datasets.make(
                config['{}_dataset'.format(tval_type)],
                **config['{}_dataset_args'.format(tval_type)])
            utils.log('{} dataset: {} (x{})'.format(tval_type,
                                                    tval_dataset[0][0].shape,
                                                    len(tval_dataset)))
            if config.get('visualize_datasets'):
                utils.visualize_dataset(tval_dataset, 'tval_ff_dataset',
                                        writer)
            tval_sampler = BongardSampler(
                tval_dataset.n_tasks,
                n_batch=tval_name_ntasks_dict[tval_type] // ep_per_batch,
                ep_per_batch=ep_per_batch,
                seed=random_state.randint(2**31))
            tval_loader = DataLoader(tval_dataset,
                                     batch_sampler=tval_sampler,
                                     num_workers=8,
                                     pin_memory=True)
            tval_loaders.update({tval_type: tval_loader})
        else:
            tval_loaders.update({tval_type: None})

    # val
    val_dataset = datasets.make(config['val_dataset'],
                                **config['val_dataset_args'])
    utils.log('val dataset: {} (x{})'.format(val_dataset[0][0].shape,
                                             len(val_dataset)))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(val_dataset, 'val_dataset', writer)
    val_sampler = BongardSampler(val_dataset.n_tasks,
                                 n_batch=900 // ep_per_batch,
                                 ep_per_batch=ep_per_batch,
                                 seed=random_state.randint(2**31))
    val_loader = DataLoader(val_dataset,
                            batch_sampler=val_sampler,
                            num_workers=8,
                            pin_memory=True)

    ########

    #### Model and optimizer ####

    if config.get('load'):
        print('loading pretrained model: ', config['load'])
        model = models.load(torch.load(config['load']))
    else:
        model = models.make(config['model'], **config['model_args'])

        if config.get('load_encoder'):
            print('loading pretrained encoder: ',
                  config['load_encoder'])
            encoder = models.load(torch.load(config['load_encoder'])).encoder
            model.encoder.load_state_dict(encoder.state_dict())

        if config.get('load_prog_synthesis'):
            print('loading pretrained program synthesis model: ',
                  config['load_prog_synthesis'])
            prog_synthesis = models.load(
                torch.load(config['load_prog_synthesis']))
            model.prog_synthesis.load_state_dict(prog_synthesis.state_dict())

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(model.parameters(),
                                                   config['optimizer'],
                                                   **config['optimizer_args'])

    ########

    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    aves_keys = ['tl', 'ta', 'vl', 'va']
    tval_tuple_lst = []
    for k, v in tval_loaders.items():
        if v is not None:
            loss_key = 'tvl' + k.split('tval')[-1]
            acc_key = 'tva' + k.split('tval')[-1]
            aves_keys.append(loss_key)
            aves_keys.append(acc_key)
            tval_tuple_lst.append((k, v, loss_key, acc_key))

    trlog = dict()
    for k in aves_keys:
        trlog[k] = []

    for epoch in range(1, max_epoch + 1):
        timer_epoch.s()
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        if config.get('freeze_bn'):
            utils.freeze_bn(model)
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        for data, label in tqdm(train_loader, desc='train', leave=False):
            x_shot, x_query = fs.split_shot_query(data.cuda(),
                                                  n_train_way,
                                                  n_train_shot,
                                                  n_query,
                                                  ep_per_batch=ep_per_batch)
            label_query = fs.make_nk_label(n_train_way,
                                           n_query,
                                           ep_per_batch=ep_per_batch).cuda()

            if config['model'] == 'snail':  # only use one selected label_query
                query_dix = random_state.randint(n_train_way * n_query)
                label_query = label_query.view(ep_per_batch, -1)[:, query_dix]
                x_query = x_query[:, query_dix:query_dix + 1]

            if config['model'] == 'maml':  # need grad in maml
                model.zero_grad()

            logits = model(x_shot, x_query).view(-1, n_train_way)
            loss = F.cross_entropy(logits, label_query)
            acc = utils.compute_acc(logits, label_query)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None
            loss = None

        # eval
        model.eval()

        for name, loader, name_l, name_a in [('val', val_loader, 'vl', 'va')
                                             ] + tval_tuple_lst:
            if config.get('{}_dataset'.format(name)) is None:
                continue

            np.random.seed(0)
            for data, _ in tqdm(loader, desc=name, leave=False):
                x_shot, x_query = fs.split_shot_query(
                    data.cuda(),
                    n_way,
                    n_shot,
                    n_query,
                    ep_per_batch=ep_per_batch)
                label_query = fs.make_nk_label(
                    n_way, n_query, ep_per_batch=ep_per_batch).cuda()

                if config['model'] == 'snail':  # only use one randomly selected label_query
                    query_dix = random_state.randint(n_train_way)
                    label_query = label_query.view(ep_per_batch,
                                                   -1)[:, query_dix]
                    x_query = x_query[:, query_dix:query_dix + 1]

                if config['model'] == 'maml':  # need grad in maml
                    model.zero_grad()
                    logits = model(x_shot, x_query, eval=True).view(-1, n_way)
                    loss = F.cross_entropy(logits, label_query)
                    acc = utils.compute_acc(logits, label_query)
                else:
                    with torch.no_grad():
                        logits = model(x_shot, x_query,
                                       eval=True).view(-1, n_way)
                        loss = F.cross_entropy(logits, label_query)
                        acc = utils.compute_acc(logits, label_query)

                aves[name_l].add(loss.item())
                aves[name_a].add(acc)

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()
            trlog[k].append(aves[k])

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)
        log_str = 'epoch {}, train {:.4f}|{:.4f}, val {:.4f}|{:.4f}'.format(
            epoch, aves['tl'], aves['ta'], aves['vl'], aves['va'])
        for tval_name, _, loss_key, acc_key in tval_tuple_lst:
            log_str += ', {} {:.4f}|{:.4f}'.format(tval_name, aves[loss_key],
                                                   aves[acc_key])
            writer.add_scalars('loss', {tval_name: aves[loss_key]}, epoch)
            writer.add_scalars('acc', {tval_name: aves[acc_key]}, epoch)
        log_str += ', {} {}/{}'.format(t_epoch, t_used, t_estimate)
        utils.log(log_str)

        writer.add_scalars('loss', {
            'train': aves['tl'],
            'val': aves['vl'],
        }, epoch)
        writer.add_scalars('acc', {
            'train': aves['ta'],
            'val': aves['va'],
        }, epoch)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,
            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),
            'training': training,
        }
        torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))
        torch.save(trlog, os.path.join(save_path, 'trlog.pth'))

        if (save_epoch is not None) and epoch % save_epoch == 0:
            torch.save(save_obj,
                       os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

        if aves['va'] > max_va:
            max_va = aves['va']
            torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))

        writer.flush()

    print('finished training!')
    logger.close()
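# `utils.compute_acc` is used throughout these scripts but not shown in this section;
# a minimal sketch of the assumed behaviour (mean top-1 accuracy of logits vs. labels):
import torch


def compute_acc_sketch(logits, label):
    return (torch.argmax(logits, dim=1) == label).float().mean().item()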
def main(config):
    svname = args.name
    if svname is None:
        svname = 'meta_{}-{}shot'.format(config['train_dataset'],
                                         config['n_shot'])
        svname += '_' + config['model'] + '-' + config['model_args']['encoder']
    if args.tag is not None:
        svname += '_' + args.tag
    save_path = os.path.join('./save', svname)
    utils.ensure_path(save_path)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    #### Dataset ####

    n_way, n_shot = config['n_way'], config['n_shot']
    n_query = config['n_query']

    if config.get('n_train_way') is not None:
        n_train_way = config['n_train_way']
    else:
        n_train_way = n_way
    if config.get('n_train_shot') is not None:
        n_train_shot = config['n_train_shot']
    else:
        n_train_shot = n_shot
    if config.get('ep_per_batch') is not None:
        ep_per_batch = config['ep_per_batch']
    else:
        ep_per_batch = 1

    # train
    train_dataset = datasets.make(config['train_dataset'],
                                  **config['train_dataset_args'])
    utils.log('train dataset: {} (x{}), {}'.format(
        train_dataset[0][0].shape, len(train_dataset),
        train_dataset.n_classes))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(train_dataset, 'train_dataset', writer)
    train_sampler = CategoriesSampler(
        train_dataset.label, config['train_batches'],
        n_train_way, n_train_shot + n_query,
        ep_per_batch=ep_per_batch)
    train_loader = DataLoader(train_dataset, batch_sampler=train_sampler,
                              num_workers=8, pin_memory=True)

    # tval
    if config.get('tval_dataset'):
        tval_dataset = datasets.make(config['tval_dataset'],
                                     **config['tval_dataset_args'])
        utils.log('tval dataset: {} (x{}), {}'.format(
            tval_dataset[0][0].shape, len(tval_dataset),
            tval_dataset.n_classes))
        if config.get('visualize_datasets'):
            utils.visualize_dataset(tval_dataset, 'tval_dataset', writer)
        tval_sampler = CategoriesSampler(
            tval_dataset.label, 200,
            n_way, n_shot + n_query,
            ep_per_batch=4)
        tval_loader = DataLoader(tval_dataset, batch_sampler=tval_sampler,
                                 num_workers=8, pin_memory=True)
    else:
        tval_loader = None

    # val
    val_dataset = datasets.make(config['val_dataset'],
                                **config['val_dataset_args'])
    utils.log('val dataset: {} (x{}), {}'.format(
        val_dataset[0][0].shape, len(val_dataset),
        val_dataset.n_classes))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(val_dataset, 'val_dataset', writer)
    val_sampler = CategoriesSampler(
        val_dataset.label, 200,
        n_way, n_shot + n_query,
        ep_per_batch=4)
    val_loader = DataLoader(val_dataset, batch_sampler=val_sampler,
                            num_workers=8, pin_memory=True)

    ########

    #### Model and optimizer ####

    if config.get('load'):
        model_sv = torch.load(config['load'])
        model = models.load(model_sv)
    else:
        model = models.make(config['model'], **config['model_args'])

        if config.get('load_encoder'):
            encoder = models.load(torch.load(config['load_encoder'])).encoder
            model.encoder.load_state_dict(encoder.state_dict())

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(
        model.parameters(),
        config['optimizer'], **config['optimizer_args'])

    ########

    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    aves_keys = ['tl', 'ta', 'tvl', 'tva', 'vl', 'va']
    trlog = dict()
    for k in aves_keys:
        trlog[k] = []

    for epoch in range(1, max_epoch + 1):
        timer_epoch.s()
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        if config.get('freeze_bn'):
            utils.freeze_bn(model)
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        np.random.seed(epoch)
        for data, _ in tqdm(train_loader, desc='train', leave=False):
            x_shot, x_query = fs.split_shot_query(
                data.cuda(), n_train_way, n_train_shot, n_query,
                ep_per_batch=ep_per_batch)
            label = fs.make_nk_label(n_train_way, n_query,
                                     ep_per_batch=ep_per_batch).cuda()

            logits = model(x_shot, x_query).view(-1, n_train_way)
            loss = F.cross_entropy(logits, label)
            acc = utils.compute_acc(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None
            loss = None

        # eval
        model.eval()

        for name, loader, name_l, name_a in [
                ('tval', tval_loader, 'tvl', 'tva'),
                ('val', val_loader, 'vl', 'va')]:

            if (config.get('tval_dataset') is None) and name == 'tval':
                continue

            np.random.seed(0)
            for data, _ in tqdm(loader, desc=name, leave=False):
                x_shot, x_query = fs.split_shot_query(
                    data.cuda(), n_way, n_shot, n_query,
                    ep_per_batch=4)
                label = fs.make_nk_label(n_way, n_query,
                                         ep_per_batch=4).cuda()

                with torch.no_grad():
                    logits = model(x_shot, x_query).view(-1, n_way)
                    loss = F.cross_entropy(logits, label)
                    acc = utils.compute_acc(logits, label)

                aves[name_l].add(loss.item())
                aves[name_a].add(acc)

        _sig = int(_[-1])

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()
            trlog[k].append(aves[k])

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)
        utils.log('epoch {}, train {:.4f}|{:.4f}, tval {:.4f}|{:.4f}, '
                  'val {:.4f}|{:.4f}, {} {}/{} (@{})'.format(
                      epoch, aves['tl'], aves['ta'], aves['tvl'], aves['tva'],
                      aves['vl'], aves['va'], t_epoch, t_used, t_estimate,
                      _sig))

        writer.add_scalars('loss', {
            'train': aves['tl'],
            'tval': aves['tvl'],
            'val': aves['vl'],
        }, epoch)
        writer.add_scalars('acc', {
            'train': aves['ta'],
            'tval': aves['tva'],
            'val': aves['va'],
        }, epoch)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,
            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),
            'training': training,
        }
        torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))
        torch.save(trlog, os.path.join(save_path, 'trlog.pth'))

        if (save_epoch is not None) and epoch % save_epoch == 0:
            torch.save(save_obj,
                       os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

        if aves['va'] > max_va:
            max_va = aves['va']
            torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))

        writer.flush()
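# `CategoriesSampler` is not shown in this section. A rough sketch of the assumed
# behaviour, based on how it is constructed above: every batch holds `ep_per_batch`
# episodes, and each episode draws `n_cls` classes with `n_per` images per class.
# Sampling details (replacement, ordering) are guesses, not the repo's implementation.
import numpy as np
import torch


class CategoriesSamplerSketch:

    def __init__(self, label, n_batch, n_cls, n_per, ep_per_batch=1):
        self.n_batch = n_batch
        self.n_cls = n_cls
        self.n_per = n_per
        self.ep_per_batch = ep_per_batch
        label = np.array(label)
        self.catlocs = [
            np.argwhere(label == c).reshape(-1)
            for c in range(label.max() + 1)
        ]

    def __len__(self):
        return self.n_batch

    def __iter__(self):
        for _ in range(self.n_batch):
            batch = []
            for _ in range(self.ep_per_batch):
                classes = np.random.choice(len(self.catlocs),
                                           self.n_cls,
                                           replace=False)
                episode = [
                    np.random.choice(self.catlocs[c], self.n_per,
                                     replace=False) for c in classes
                ]
                batch.append(torch.from_numpy(np.stack(episode)))
            # Flat list of dataset indices for one batch of episodes.
            yield torch.stack(batch).reshape(-1).tolist()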
def main(config):
    # dataset
    dataset = datasets.make(config['dataset'], **config['dataset_args'])
    utils.log('dataset: {} (x{}), {}'.format(dataset[0][0].shape,
                                             len(dataset), dataset.n_classes))

    if not args.sauc:
        n_way = 5
    else:
        n_way = 2
    n_shot, n_unlabel, n_query = args.shot, 30, 15
    n_batch = 200
    ep_per_batch = 4
    batch_sampler = CategoriesSampler_Semi(dataset.label,
                                           n_batch,
                                           n_way,
                                           n_shot,
                                           n_unlabel,
                                           n_query,
                                           ep_per_batch=ep_per_batch)
    loader = DataLoader(dataset,
                        batch_sampler=batch_sampler,
                        num_workers=8,
                        pin_memory=True)

    # model
    if config.get('load') is None:
        model = models.make('meta-baseline', encoder=None)
    else:
        model = models.load(torch.load(config['load']))

    if config.get('load_encoder') is not None:
        encoder = models.load(torch.load(config['load_encoder'])).encoder
        model.encoder = encoder

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    model.eval()
    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    # testing
    aves_keys = ['vl', 'va']
    aves = {k: utils.Averager() for k in aves_keys}

    test_epochs = args.test_epochs
    np.random.seed(0)
    va_lst = []
    for epoch in range(1, test_epochs + 1):
        for data, _ in tqdm(loader, leave=False):
            x_shot, x_unlabel, x_query = fs.split_shot_query_semi(
                data.cuda(), n_way, n_shot, n_unlabel, n_query,
                ep_per_batch=ep_per_batch)

            with torch.no_grad():
                if not args.sauc:
                    logits = model(x_shot, x_unlabel, x_query).view(-1, n_way)
                    label = fs.make_nk_label(n_way, n_query,
                                             ep_per_batch=ep_per_batch).cuda()
                    loss = F.cross_entropy(logits, label)
                    acc = utils.compute_acc(logits, label)

                    aves['vl'].add(loss.item(), len(data))
                    aves['va'].add(acc, len(data))
                    va_lst.append(acc)
                else:
                    x_shot = x_shot[:, 0, :, :, :, :].contiguous()
                    shot_shape = x_shot.shape[:-3]
                    img_shape = x_shot.shape[-3:]
                    bs = shot_shape[0]
                    p = model.encoder(x_shot.view(-1, *img_shape)).reshape(
                        *shot_shape, -1).mean(dim=1, keepdim=True)
                    q = model.encoder(x_query.view(-1, *img_shape)).view(
                        bs, -1, p.shape[-1])
                    p = F.normalize(p, dim=-1)
                    q = F.normalize(q, dim=-1)
                    s = torch.bmm(q, p.transpose(2, 1)).view(bs, -1).cpu()
                    for i in range(bs):
                        k = s.shape[1] // 2
                        y_true = [1] * k + [0] * k
                        acc = roc_auc_score(y_true, s[i])
                        aves['va'].add(acc, len(data))
                        va_lst.append(acc)

        print('test epoch {}: acc={:.2f} +- {:.2f} (%), loss={:.4f} (@{})'.
              format(epoch, aves['va'].item() * 100,
                     mean_confidence_interval(va_lst) * 100,
                     aves['vl'].item(), _[-1]))
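# `mean_confidence_interval` is assumed to return the half-width of a 95% Student-t
# confidence interval over the per-episode accuracies collected in `va_lst`; a
# standard sketch of such a helper (this exact implementation is an assumption):
import numpy as np
import scipy.stats


def mean_confidence_interval_sketch(data, confidence=0.95):
    a = np.asarray(data, dtype=np.float64)
    se = scipy.stats.sem(a)
    return se * scipy.stats.t.ppf((1 + confidence) / 2., len(a) - 1)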
def main(config):
    svname = config.get('sv_name')
    if args.tag is not None:
        svname += '_' + args.tag
    config['sv_name'] = svname
    save_path = os.path.join('./save', svname)
    utils.ensure_path(save_path)
    utils.set_log_path(save_path)
    utils.log(svname)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    #### Dataset ####

    n_way, n_shot = config['n_way'], config['n_shot']
    n_query = config['n_query']
    n_pseudo = config['n_pseudo']
    ep_per_batch = config['ep_per_batch']

    if config.get('test_batches') is not None:
        test_batches = config['test_batches']
    else:
        test_batches = config['train_batches']

    for s in ['train', 'val', 'tval']:
        if config.get(f"{s}_dataset_args") is not None:
            config[f"{s}_dataset_args"]['data_dir'] = os.path.join(
                os.getcwd(), os.pardir, 'data_root')

    # train
    train_dataset = CustomDataset(config['train_dataset'],
                                  save_dir=config.get('load_encoder'),
                                  **config['train_dataset_args'])

    if config['train_dataset_args']['split'] == 'helper':
        with open(os.path.join(save_path, 'train_helper_cls.pkl'), 'wb') as f:
            pkl.dump(train_dataset.dataset_classes, f)

    train_sampler = EpisodicSampler(train_dataset,
                                    config['train_batches'],
                                    n_way,
                                    n_shot,
                                    n_query,
                                    n_pseudo,
                                    episodes_per_batch=ep_per_batch)
    train_loader = DataLoader(train_dataset,
                              batch_sampler=train_sampler,
                              num_workers=4,
                              pin_memory=True)

    # tval
    if config.get('tval_dataset'):
        tval_dataset = CustomDataset(config['tval_dataset'],
                                     **config['tval_dataset_args'])
        tval_sampler = EpisodicSampler(tval_dataset,
                                       test_batches,
                                       n_way,
                                       n_shot,
                                       n_query,
                                       n_pseudo,
                                       episodes_per_batch=ep_per_batch)
        tval_loader = DataLoader(tval_dataset,
                                 batch_sampler=tval_sampler,
                                 num_workers=4,
                                 pin_memory=True)
    else:
        tval_loader = None

    # val
    val_dataset = CustomDataset(config['val_dataset'],
                                **config['val_dataset_args'])
    val_sampler = EpisodicSampler(val_dataset,
                                  test_batches,
                                  n_way,
                                  n_shot,
                                  n_query,
                                  n_pseudo,
                                  episodes_per_batch=ep_per_batch)
    val_loader = DataLoader(val_dataset,
                            batch_sampler=val_sampler,
                            num_workers=4,
                            pin_memory=True)

    #### Model and optimizer ####

    if config.get('load'):
        model_sv = torch.load(config['load'])
        model = models.load(model_sv)
    else:
        model = models.make(config['model'], **config['model_args'])

        if config.get('load_encoder'):
            encoder = models.load(torch.load(config['load_encoder'])).encoder
            model.encoder.load_state_dict(encoder.state_dict())

    if config.get('freeze_encoder'):
        for param in model.encoder.parameters():
            param.requires_grad = False

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(
        model.parameters(),
        config['optimizer'], **config['optimizer_args'])

    ########

    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    aves_keys = ['tl', 'ta', 'tvl', 'tva', 'vl', 'va']
    trlog = dict()
    for k in aves_keys:
        trlog[k] = []

    for epoch in range(1, max_epoch + 1):
        timer_epoch.s()
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        if config.get('freeze_bn'):
            utils.freeze_bn(model)
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        np.random.seed(epoch)
        for data in tqdm(train_loader, desc='train', leave=False):
            x_shot, x_query, x_pseudo = fs.split_shot_query(
                data.cuda(), n_way, n_shot, n_query, n_pseudo,
                ep_per_batch=ep_per_batch)
            label = fs.make_nk_label(n_way, n_query,
                                     ep_per_batch=ep_per_batch).cuda()

            logits = model(x_shot, x_query, x_pseudo)
            logits = logits.view(-1, n_way)
            loss = F.cross_entropy(logits, label)
            acc = utils.compute_acc(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None
            loss = None

        # eval
        model.eval()

        for name, loader, name_l, name_a in [
                ('tval', tval_loader, 'tvl', 'tva'),
                ('val', val_loader, 'vl', 'va')]:

            if (config.get('tval_dataset') is None) and name == 'tval':
                continue

            np.random.seed(0)
            for data in tqdm(loader, desc=name, leave=False):
                x_shot, x_query, x_pseudo = fs.split_shot_query(
                    data.cuda(), n_way, n_shot, n_query, n_pseudo,
                    ep_per_batch=ep_per_batch)
                label = fs.make_nk_label(n_way, n_query,
                                         ep_per_batch=ep_per_batch).cuda()

                with torch.no_grad():
                    logits = model(x_shot, x_query, x_pseudo)
                    logits = logits.view(-1, n_way)
                    loss = F.cross_entropy(logits, label)
                    acc = utils.compute_acc(logits, label)

                aves[name_l].add(loss.item())
                aves[name_a].add(acc)

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()
            trlog[k].append(aves[k])

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)
        utils.log('epoch {}, train {:.4f}|{:.4f}, tval {:.4f}|{:.4f}, '
                  'val {:.4f}|{:.4f}, {} {}/{}'.format(
                      epoch, aves['tl'], aves['ta'], aves['tvl'], aves['tva'],
                      aves['vl'], aves['va'], t_epoch, t_used, t_estimate))

        writer.add_scalars('loss', {
            'train': aves['tl'],
            'tval': aves['tvl'],
            'val': aves['vl'],
        }, epoch)
        writer.add_scalars('acc', {
            'train': aves['ta'],
            'tval': aves['tva'],
            'val': aves['va'],
        }, epoch)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,
            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),
            'training': training,
        }
        torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))
        torch.save(trlog, os.path.join(save_path, 'trlog.pth'))

        if (save_epoch is not None) and epoch % save_epoch == 0:
            torch.save(save_obj,
                       os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

        if aves['va'] > max_va:
            max_va = aves['va']
            torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))

        writer.flush()
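# `utils.freeze_bn` (used in the training loop above when config['freeze_bn'] is set)
# is assumed to put all BatchNorm layers into eval mode so their running statistics
# are not updated during episodic training; a minimal sketch, not the repo's code:
import torch.nn as nn


def freeze_bn_sketch(model):
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.eval()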
def main(config):
    svname = args.name
    if svname is None:
        svname = 'moco_{}'.format(config['train_dataset'])
        svname += '_' + config['model_args']['encoder']
        out_dim = config['model_args']['encoder_args']['out_dim']
        svname += '-out_dim' + str(out_dim)
    svname += '-seed' + str(args.seed)
    if args.tag is not None:
        svname += '_' + args.tag

    save_path = os.path.join(args.save_dir, svname)
    utils.ensure_path(save_path, remove=False)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    random_state = np.random.RandomState(args.seed)
    print('seed:', args.seed)

    logger = utils.Logger(file_name=os.path.join(save_path, "log_sdout.txt"),
                          file_mode="a+",
                          should_flush=True)

    #### Dataset ####

    # train
    train_dataset = datasets.make(config['train_dataset'],
                                  **config['train_dataset_args'])
    train_loader = DataLoader(train_dataset,
                              config['batch_size'],
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True,
                              drop_last=True)
    utils.log('train dataset: {} (x{})'.format(train_dataset[0][0][0].shape,
                                               len(train_dataset)))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(train_dataset, 'train_dataset', writer)

    # val
    if config.get('val_dataset'):
        eval_val = True
        val_dataset = datasets.make(config['val_dataset'],
                                    **config['val_dataset_args'])
        val_loader = DataLoader(val_dataset,
                                config['batch_size'],
                                num_workers=8,
                                pin_memory=True,
                                drop_last=True)
        utils.log('val dataset: {} (x{})'.format(val_dataset[0][0][0].shape,
                                                 len(val_dataset)))
        if config.get('visualize_datasets'):
            utils.visualize_dataset(val_dataset, 'val_dataset', writer)
    else:
        eval_val = False

    # few-shot eval
    if config.get('eval_fs'):
        ef_epoch = config.get('eval_fs_epoch')
        if ef_epoch is None:
            ef_epoch = 5
        eval_fs = True

        n_way = 2
        n_query = 1
        n_shot = 6

        if config.get('ep_per_batch') is not None:
            ep_per_batch = config['ep_per_batch']
        else:
            ep_per_batch = 1

        # tvals
        fs_loaders = {}
        tval_name_ntasks_dict = {
            'tval': 2000,
            'tval_ff': 600,
            'tval_bd': 480,
            'tval_hd_comb': 400,
            'tval_hd_novel': 320
        }  # numbers depend on dataset
        for tval_type in tval_name_ntasks_dict.keys():
            if config.get('{}_dataset'.format(tval_type)):
                tval_dataset = datasets.make(
                    config['{}_dataset'.format(tval_type)],
                    **config['{}_dataset_args'.format(tval_type)])
                utils.log('{} dataset: {} (x{})'.format(
                    tval_type, tval_dataset[0][0][0].shape,
                    len(tval_dataset)))
                if config.get('visualize_datasets'):
                    utils.visualize_dataset(tval_dataset, 'tval_ff_dataset',
                                            writer)
                tval_sampler = BongardSampler(
                    tval_dataset.n_tasks,
                    n_batch=tval_name_ntasks_dict[tval_type] // ep_per_batch,
                    ep_per_batch=ep_per_batch,
                    seed=random_state.randint(2**31))
                tval_loader = DataLoader(tval_dataset,
                                         batch_sampler=tval_sampler,
                                         num_workers=8,
                                         pin_memory=True)
                fs_loaders.update({tval_type: tval_loader})
            else:
                fs_loaders.update({tval_type: None})
    else:
        eval_fs = False

    ########

    #### Model and Optimizer ####

    if config.get('load'):
        model_sv = torch.load(config['load'])
        model = models.load(model_sv)
    else:
        model = models.make(config['model'], **config['model_args'])

    if eval_fs:
        fs_model = models.make('meta-baseline', encoder=None)
        fs_model.encoder = model.encoder

    if config.get('_parallel'):
        model = nn.DataParallel(model)
        if eval_fs:
            fs_model = nn.DataParallel(fs_model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(model.parameters(),
                                                   config['optimizer'],
                                                   **config['optimizer_args'])

    ########

    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    for epoch in range(1, max_epoch + 1 + 1):
        timer_epoch.s()
        aves_keys = ['tl', 'ta', 'vl', 'va', 'tvl', 'tva']
        if eval_fs:
            for k, v in fs_loaders.items():
                if v is not None:
                    aves_keys += ['fsa' + k.split('tval')[-1]]
        aves = {ave_k: utils.Averager() for ave_k in aves_keys}

        # train
        model.train()
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        for data, _ in tqdm(train_loader, desc='train', leave=False):
            logits, label = model(im_q=data[0].cuda(), im_k=data[1].cuda())
            loss = F.cross_entropy(logits, label)
            acc = utils.compute_acc(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None
            loss = None

        # val
        if eval_val:
            model.eval()
            for data, _ in tqdm(val_loader, desc='val', leave=False):
                with torch.no_grad():
                    logits, label = model(im_q=data[0].cuda(),
                                          im_k=data[1].cuda())
                    loss = F.cross_entropy(logits, label)
                    acc = utils.compute_acc(logits, label)

                aves['vl'].add(loss.item())
                aves['va'].add(acc)

        if eval_fs and (epoch % ef_epoch == 0 or epoch == max_epoch + 1):
            fs_model.eval()
            for k, v in fs_loaders.items():
                if v is not None:
                    ave_key = 'fsa' + k.split('tval')[-1]
                    np.random.seed(0)
                    for data, _ in tqdm(v, desc=ave_key, leave=False):
                        x_shot, x_query = fs.split_shot_query(
                            data[0].cuda(), n_way, n_shot, n_query,
                            ep_per_batch=ep_per_batch)
                        label_query = fs.make_nk_label(
                            n_way, n_query, ep_per_batch=ep_per_batch).cuda()
                        with torch.no_grad():
                            logits = fs_model(x_shot, x_query).view(-1, n_way)
                            acc = utils.compute_acc(logits, label_query)
                        aves[ave_key].add(acc)

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)

        if epoch <= max_epoch:
            epoch_str = str(epoch)
        else:
            epoch_str = 'ex'
        log_str = 'epoch {}, train {:.4f}|{:.4f}'.format(
            epoch_str, aves['tl'], aves['ta'])
        writer.add_scalars('loss', {'train': aves['tl']}, epoch)
        writer.add_scalars('acc', {'train': aves['ta']}, epoch)

        if eval_val:
            log_str += ', val {:.4f}|{:.4f}, tval {:.4f}|{:.4f}'.format(
                aves['vl'], aves['va'], aves['tvl'], aves['tva'])
            writer.add_scalars('loss', {'val': aves['vl']}, epoch)
            writer.add_scalars('loss', {'tval': aves['tvl']}, epoch)
            writer.add_scalars('acc', {'val': aves['va']}, epoch)
            writer.add_scalars('acc', {'tval': aves['tva']}, epoch)

        if eval_fs and (epoch % ef_epoch == 0 or epoch == max_epoch + 1):
            log_str += ', fs'
            for ave_key in aves_keys:
                if 'fsa' in ave_key:
                    log_str += ' {}: {:.4f}'.format(ave_key, aves[ave_key])
                    writer.add_scalars('acc', {ave_key: aves[ave_key]}, epoch)

        if epoch <= max_epoch:
            log_str += ', {} {}/{}'.format(t_epoch, t_used, t_estimate)
        else:
            log_str += ', {}'.format(t_epoch)
        utils.log(log_str)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,
            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),
            'training': training,
        }
        if epoch <= max_epoch:
            torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))

            if (save_epoch is not None) and epoch % save_epoch == 0:
                torch.save(
                    save_obj,
                    os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

            if aves['va'] > max_va:
                max_va = aves['va']
                torch.save(save_obj,
                           os.path.join(save_path, 'max-va.pth'))
        else:
            torch.save(save_obj, os.path.join(save_path, 'epoch-ex.pth'))

        writer.flush()

    print('finished training!')
    logger.close()
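# The MoCo model trained above returns (logits, label) directly from its forward pass.
# For reference, the standard InfoNCE logit construction from the MoCo paper is sketched
# below, given an encoded query q [N, C], key k [N, C] and a negative queue [C, K];
# this is illustrative and not necessarily this repo's exact implementation.
import torch


def moco_logits_sketch(q, k, queue, T=0.07):
    l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)  # positives: [N, 1]
    l_neg = torch.einsum('nc,ck->nk', [q, queue])           # negatives: [N, K]
    logits = torch.cat([l_pos, l_neg], dim=1) / T
    # The positive key sits at column 0, so every target label is 0.
    label = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
    return logits, label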
def main(config):
    config['dataset_args']['data_dir'] = os.path.join(os.getcwd(), os.pardir,
                                                      'data_root')
    dataset = CustomDataset(name=config['dataset'], **config['dataset_args'])

    n_way = 5
    n_shot = config['n_shot']
    n_query = config['n_query'] if config.get('n_query') is not None else 15
    n_pseudo = config['n_pseudo'] if config.get('n_pseudo') is not None else 15
    n_batch = config['train_batches'] if config.get(
        'train_batches') is not None else 200
    ep_per_batch = config['ep_per_batch'] if config.get(
        'ep_per_batch') is not None else 4

    batch_sampler = EpisodicSampler(dataset,
                                    n_batch,
                                    n_way,
                                    n_shot,
                                    n_query,
                                    n_pseudo,
                                    episodes_per_batch=ep_per_batch)
    loader = DataLoader(dataset,
                        batch_sampler=batch_sampler,
                        num_workers=4,
                        pin_memory=True)

    model_sv = torch.load(config['load'])
    model = models.load(model_sv)

    if config.get('fs_dataset'):
        fs_model = models.make('meta-baseline', encoder=None)
        fs_model.encoder = model.encoder
        model = fs_model

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    model.eval()

    # testing
    aves_keys = ['vl', 'va']
    aves = {k: utils.Averager() for k in aves_keys}

    test_epochs = args.test_epochs
    np.random.seed(0)
    va_lst = []
    for epoch in range(1, test_epochs + 1):
        for data in tqdm(loader, desc=f"eval: {epoch}", leave=False):
            x_shot, x_query, x_pseudo = fs.split_shot_query(
                data.cuda(), n_way, n_shot, n_query, n_pseudo,
                ep_per_batch=ep_per_batch)

            with torch.no_grad():
                logits = model(x_shot, x_query, x_pseudo)
                logits = logits.view(-1, n_way)
                label = fs.make_nk_label(n_way, n_query,
                                         ep_per_batch=ep_per_batch).cuda()
                loss = F.cross_entropy(logits, label)
                acc = utils.compute_acc(logits, label)

                aves['vl'].add(loss.item(), len(data))
                aves['va'].add(acc, len(data))
                va_lst.append(acc)

        utils.log(
            'test epoch {}: acc={:.2f} +- {:.2f} (%), loss={:.4f}'.format(
                epoch, aves['va'].item() * 100,
                mean_confidence_interval(va_lst) * 100,
                aves['vl'].item()),
            filename='test_log.txt')
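# `utils.Averager` is assumed to keep a (weighted) running mean, matching the
# `.add(value)` / `.add(value, n)` / `.item()` calls used throughout these scripts;
# a minimal sketch of such a class (the actual implementation is not shown here):
class AveragerSketch:

    def __init__(self):
        self.n = 0.0
        self.v = 0.0

    def add(self, v, n=1.0):
        self.v = (self.v * self.n + v * n) / (self.n + n)
        self.n += n

    def item(self):
        return self.v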
def main(config):
    svname = args.name
    if svname is None:
        svname = f"classifier-{config['train_dataset']}-{config['model_args']['encoder']}"
        clsfr = config['model_args']['classifier']
        if clsfr != 'linear-classifier':
            svname += '-' + clsfr
    svname += '-aux' + str(args.aux_level)
    if args.topk is not None:
        svname += f"-top{args.topk}"
    if args.tag is not None:
        svname += '_' + args.tag
    save_path = os.path.join('./save', svname)
    utils.ensure_path(save_path)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    #### Dataset ####

    for s in ['train', 'val', 'tval', 'fs', 'fs_val']:
        if config.get(f"{s}_dataset_args") is not None:
            config[f"{s}_dataset_args"]['data_dir'] = os.path.join(
                os.getcwd(), os.pardir, 'data_root')

    # train
    train_dataset = TrainDataset(name=config['train_dataset'],
                                 **config['train_dataset_args'])
    train_loader = DataLoader(train_dataset,
                              config['batch_size'],
                              shuffle=True,
                              num_workers=16,
                              pin_memory=True,
                              drop_last=True)

    with open(os.path.join(save_path, 'training_classes.pkl'), 'wb') as f:
        pkl.dump(train_dataset.separated_training_classes, f)

    # val
    if config.get('val_dataset'):
        eval_val = True
        val_dataset = TrainDataset(config['val_dataset'],
                                   **config['val_dataset_args'])
        val_loader = DataLoader(val_dataset,
                                config['batch_size'],
                                num_workers=16,
                                pin_memory=True,
                                drop_last=True)
    else:
        eval_val = False

    # few-shot eval
    fs_loaders = {'fs_dataset': list(), 'fs_val_dataset': list()}
    for key in fs_loaders.keys():
        if config.get(key):
            ef_epoch = config.get('eval_fs_epoch')
            if ef_epoch is None:
                ef_epoch = 5
            eval_fs = True

            fs_dataset = CustomDataset(config[key], **config[key + '_args'])

            n_way = config['n_way'] if config.get('n_way') else 5
            n_query = config['n_query'] if config.get('n_query') else 15
            if config.get('n_pseudo') is not None:
                n_pseudo = config['n_pseudo']
            else:
                n_pseudo = 15
            n_batches = config['n_batches'] if config.get('n_batches') else 200
            ep_per_batch = config['ep_per_batch'] if config.get(
                'ep_per_batch') else 4

            n_shots = [1, 5]
            for n_shot in n_shots:
                fs_sampler = EpisodicSampler(fs_dataset,
                                             n_batches,
                                             n_way,
                                             n_shot,
                                             n_query,
                                             n_pseudo,
                                             episodes_per_batch=ep_per_batch)
                fs_loader = DataLoader(fs_dataset,
                                       batch_sampler=fs_sampler,
                                       num_workers=16,
                                       pin_memory=True)
                fs_loaders[key].append(fs_loader)
        else:
            eval_fs = False

    eval_fs = False
    for key in fs_loaders.keys():
        if config.get(key):
            eval_fs = True

    #### Model and Optimizer ####

    config['model_args']['classifier_args'][
        'n_classes'] = train_dataset.n_classes
    model = models.make(config['model'], **config['model_args'])

    if eval_fs:
        fs_model = models.make('meta-baseline', encoder=None)
        fs_model.encoder = model.encoder

    if config.get('_parallel'):
        model = nn.DataParallel(model)
        if eval_fs:
            fs_model = nn.DataParallel(fs_model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(model.parameters(),
                                                   config['optimizer'],
                                                   **config['optimizer_args'])

    ########

    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    for epoch in range(1, max_epoch + 1 + 1):
        timer_epoch.s()
        aves_keys = ['tl', 'ta', 'vl', 'va']
        if eval_fs:
            for n_shot in n_shots:
                aves_keys += ['fsa-' + str(n_shot)]
                if config.get('fs_val_dataset'):
                    aves_keys += ['fsav-' + str(n_shot)]
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        for data, label in tqdm(train_loader, desc='train', leave=False):
            data, label = data.cuda(), label.cuda()
            logits = model(data)
            loss = F.cross_entropy(logits, label)
            acc = utils.compute_acc(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None
            loss = None

        # eval
        if eval_val:
            model.eval()
            for data, label in tqdm(val_loader, desc='val', leave=False):
                data, label = data.cuda(), label.cuda()
                with torch.no_grad():
                    logits = model(data)
                    loss = F.cross_entropy(logits, label)
                    acc = utils.compute_acc(logits, label)

                aves['vl'].add(loss.item())
                aves['va'].add(acc)

        if eval_fs and (epoch % ef_epoch == 0 or epoch == max_epoch + 1):
            fs_model.eval()
            for key in fs_loaders.keys():
                if len(fs_loaders[key]) == 0:
                    continue
                tag = 'v' if key == 'fs_val_dataset' else ''
                for i, n_shot in enumerate(n_shots):
                    np.random.seed(0)
                    for data in tqdm(fs_loaders[key][i],
                                     desc='fs' + tag + '-' + str(n_shot),
                                     leave=False):
                        x_shot, x_query, x_pseudo = fs.split_shot_query(
                            data.cuda(),
                            n_way,
                            n_shot,
                            n_query,
                            pseudo=n_pseudo,
                            ep_per_batch=ep_per_batch)
                        label = fs.make_nk_label(
                            n_way, n_query, ep_per_batch=ep_per_batch).cuda()
                        with torch.no_grad():
                            logits = fs_model(x_shot, x_query, x_pseudo)
                            logits = logits.view(-1, n_way)
                            acc = utils.compute_acc(logits, label)
                        aves['fsa' + tag + '-' + str(n_shot)].add(acc)

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)

        if epoch <= max_epoch:
            epoch_str = str(epoch)
        else:
            epoch_str = 'ex'
        log_str = 'epoch {}, train {:.4f}|{:.4f}'.format(
            epoch_str, aves['tl'], aves['ta'])
        writer.add_scalars('loss', {'train': aves['tl']}, epoch)
        writer.add_scalars('acc', {'train': aves['ta']}, epoch)

        if eval_val:
            log_str += ', val {:.4f}|{:.4f}'.format(aves['vl'], aves['va'])
            writer.add_scalars('loss', {'val': aves['vl']}, epoch)
            writer.add_scalars('acc', {'val': aves['va']}, epoch)

        if eval_fs and (epoch % ef_epoch == 0 or epoch == max_epoch + 1):
            for key in fs_loaders.keys():
                if len(fs_loaders[key]) == 0:
                    continue
                tag = 'v' if key == 'fs_val_dataset' else ''
                log_str += ', fs' + tag
                for n_shot in n_shots:
                    key = 'fsa' + tag + '-' + str(n_shot)
                    log_str += ' {}: {:.4f}'.format(n_shot, aves[key])
                    writer.add_scalars('acc', {key: aves[key]}, epoch)

        if epoch <= max_epoch:
            log_str += ', {} {}/{}'.format(t_epoch, t_used, t_estimate)
        else:
            log_str += ', {}'.format(t_epoch)
        utils.log(log_str)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,
            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),
            'training': training,
        }
        if epoch <= max_epoch:
            torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))

            if (save_epoch is not None) and epoch % save_epoch == 0:
                torch.save(
                    save_obj,
                    os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))
            if aves['va'] > max_va:
                max_va = aves['va']
                torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))
        else:
            torch.save(save_obj, os.path.join(save_path, 'epoch-ex.pth'))

        writer.flush()
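# All of the main(config) scripts above assume a module-level `args` namespace
# (name, tag, seed, save_dir, gpu, ...) and a YAML-loaded `config` dict. A typical
# entry point is sketched below for completeness; the flag names are assumptions
# based on how `args` is used above, not the repo's exact CLI.
if __name__ == '__main__':
    import argparse
    import os

    import yaml

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True)
    parser.add_argument('--name', default=None)
    parser.add_argument('--tag', default=None)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--save_dir', default='./save')
    parser.add_argument('--gpu', default='0')
    args = parser.parse_args()

    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    if len(args.gpu.split(',')) > 1:
        config['_parallel'] = True
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    main(config)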