import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR

# AttributeDataset, MLP, LinearProjection, ProxyNet, Base, ProxyLoss,
# train, validate, save_checkpoint and the parsed `args` are defined
# elsewhere in this repository.


def main():
    # load GAN-synthesized features alongside the real ones only when a
    # path to pre-computed GAN features is given
    both = args.gan_path is not None

    if args.validation:
        # validation protocol: train on 'train', evaluate on the seen and
        # unseen validation splits (generalized zero-shot setting)
        train_loader = torch.utils.data.DataLoader(
            AttributeDataset(args.data_dir, args.dataset,
                             features_path=args.gan_path,
                             mode='train', both=both,
                             normalize=args.normalize,
                             sentences=args.sentences),
            batch_size=args.batch_size, shuffle=True)
        val_seen_loader = torch.utils.data.DataLoader(
            AttributeDataset(args.data_dir, args.dataset,
                             features_path=args.gan_path,
                             mode='val_seen', generalized=True,
                             normalize=args.normalize,
                             sentences=args.sentences),
            batch_size=args.batch_size, shuffle=False)
        val_unseen_loader = torch.utils.data.DataLoader(
            AttributeDataset(args.data_dir, args.dataset,
                             features_path=args.gan_path,
                             mode='val_unseen', generalized=True,
                             normalize=args.normalize,
                             sentences=args.sentences),
            batch_size=args.batch_size, shuffle=False)
    else:
        # final protocol: train on 'trainval', evaluate on the test splits
        trainval_loader = torch.utils.data.DataLoader(
            AttributeDataset(args.data_dir, args.dataset,
                             features_path=args.gan_path,
                             mode='trainval', both=both,
                             normalize=args.normalize,
                             sentences=args.sentences),
            batch_size=args.batch_size, shuffle=True)
        test_seen_loader = torch.utils.data.DataLoader(
            AttributeDataset(args.data_dir, args.dataset,
                             features_path=args.gan_path,
                             mode='test_seen', generalized=True,
                             normalize=args.normalize,
                             sentences=args.sentences),
            batch_size=args.batch_size, shuffle=False)
        test_unseen_loader = torch.utils.data.DataLoader(
            AttributeDataset(args.data_dir, args.dataset,
                             features_path=args.gan_path,
                             mode='test_unseen', generalized=True,
                             normalize=args.normalize,
                             sentences=args.sentences),
            batch_size=args.batch_size, shuffle=False)

    # instantiate the models: a feature encoder followed by a linear
    # embedding into the semantic space
    if args.mlp:
        mlp = MLP(args.dim_input, [args.nhidden * 2], args.nhidden)
    else:
        mlp = LinearProjection(args.dim_input, args.nhidden)
    embed = LinearProjection(args.nhidden, args.dim_embed)

    # the class embeddings (sentence features or attribute vectors) serve
    # as fixed proxies, one per class
    cam_key = 'sentences' if args.sentences else 'emb'
    if args.validation:
        cam = torch.from_numpy(train_loader.dataset.data[cam_key].T)
    else:
        cam = torch.from_numpy(trainval_loader.dataset.data[cam_key].T)
    proxies = ProxyNet(args.n_classes, args.dim_embed, proxies=cam)

    model = Base(mlp, embed, proxies)
    criterion = ProxyLoss(temperature=args.temp)

    if args.cuda:
        mlp.cuda()
        embed.cuda()
        model.cuda()
        proxies.cuda()

    # optimize only the linear layers of the encoder and the embedding;
    # the proxies themselves stay fixed
    layers = []
    for c in mlp.children():
        if isinstance(c, (nn.Linear, nn.ModuleList)):
            layers.extend(list(c.parameters()))
    for c in embed.children():
        if isinstance(c, nn.Linear):
            layers.extend(list(c.parameters()))
    parameters_set = [{'params': layers, 'lr': args.lr}]

    optimizer = optim.SGD(parameters_set, lr=args.lr, momentum=0.9,
                          nesterov=True, weight_decay=5e-5)

    n_parameters = sum(p.data.nelement() for p in model.parameters())
    print(' + Number of params: {}'.format(n_parameters))

    scheduler = CosineAnnealingLR(optimizer, args.epochs)

    best_acc = 0

    # sanity check: evaluate the randomly initialized model
    print('Random results:')
    if args.validation:
        validate(val_seen_loader, val_unseen_loader, model, criterion)
    else:
        validate(test_seen_loader, test_unseen_loader, model, criterion)

    for epoch in range(args.start_epoch, args.epochs + 1):
        # update the learning rate (stepping at the start of the epoch
        # follows the pre-1.1 PyTorch scheduler convention)
        if args.lr_decay:
            scheduler.step()

        # train for one epoch, then evaluate on the matching splits
        if args.validation:
            train(train_loader, model, criterion, optimizer, epoch)
            validate(val_seen_loader, val_unseen_loader, model, criterion)
        else:
            train(trainval_loader, model, criterion, optimizer, epoch)
            validate(test_seen_loader, test_unseen_loader, model, criterion)

        # save a checkpoint after every epoch
        save_checkpoint({'epoch': epoch, 'state_dict': model.state_dict()})

    print('\nFinal evaluation on last epoch model:')
    # evaluate the final model on the splits that match the chosen protocol
    # (in validation mode the test loaders are never created)
    if args.validation:
        validate(val_seen_loader, val_unseen_loader, model, criterion)
    else:
        validate(test_seen_loader, test_unseen_loader, model, criterion)
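

# Entry-point guard so the script can be run directly; it assumes `args`
# is produced at module level by an argparse parser defined alongside
# main() (not shown in this section).
if __name__ == '__main__':
    main()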