train=False), batch_size=args.batch_size, shuffle=True) data_loaders['sup'] = train_loader_sup data_loaders['unsup'] = train_loader_unsup # how often would a supervised batch be encountered during inference periodic_interval_batches = 5 # number of unsupervised examples sup_num = len(data_loaders['sup']) * args.batch_size unsup_num = len(data_loaders['unsup']) * args.batch_size # setup the VAE model = DIVA(args).to(device) # setup the optimizer optimizer = optim.Adam(model.parameters(), lr=args.lr) # init val_total_loss = [] val_class_err_d = [] val_class_err_y = [] best_loss = 1000. best_y_acc = 0. early_stopping_counter = 1 max_early_stopping = 100
torch.backends.cudnn.benchmark = False
np.random.seed(args.seed)

# Supervised training loader over the rotated-MNIST training domains.
train_loader = data_utils.DataLoader(
    MnistRotated(
        args.list_train_domains,
        args.list_test_domain,
        args.num_supervised,
        args.seed,
        './../dataset/',
        train=True,
    ),
    batch_size=args.batch_size,
    shuffle=True,
    **kwargs
)

# Build the DIVA model on the target device, then its Adam optimizer.
model = DIVA(args).to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)

# Best-so-far metrics and early-stopping bookkeeping.
best_loss, best_y_acc = 1000.0, 0.0
early_stopping_counter, max_early_stopping = 1, 100

# training loop
print('\nStart training:', args)
for epoch in range(1, args.epochs + 1):
    # Warm-up: scale beta_d by epoch / warmup, capped at args.beta_d.
    # NOTE(review): args.warmup == 0 raises ZeroDivisionError here -- confirm
    # the argument parser guarantees a positive value.
    model.beta_d = min(args.beta_d, args.beta_d * float(epoch) / args.warmup)