def _collect_predictions(model, loader):
    """Run *model* over every batch in *loader*; return (y_true, y_pred) lists.

    Assumes each batch exposes ``.text`` (inputs) and ``.label`` (targets),
    as produced by get_dataloader.
    """
    y_true, y_pred = [], []
    for batch in tqdm(loader):
        inputs, targets = batch.text, batch.label
        output = model(inputs)
        # argmax over the class dimension -> predicted label ids
        pred = torch.max(output.data, dim=1)[1].cpu().numpy().tolist()
        y_pred.extend(pred)
        # FIX: convert targets to plain Python ints so y_true and y_pred
        # have consistent element types for sklearn.metrics (the original
        # extended y_true with raw tensors).
        y_true.extend(targets.data.cpu().numpy().tolist())
    return y_true, y_pred


def evaluate(model: TextCNN, data_path):
    """Evaluate *model* on both splits of the dataset at *data_path*.

    Puts the model in eval mode for inference and restores train mode
    before returning, so it is safe to call mid-training.

    Returns:
        (train_acc, train_f1, test_acc, test_f1) — accuracies and
        macro-averaged F1 scores.
    """
    print('evaluate ...')
    # NOTE(review): loaders are rebuilt on every call with fixed bs=32,
    # seq_len=50 — assumes get_dataloader is deterministic for a path.
    train_loader, test_loader, vocab = get_dataloader(data_path=data_path, bs=32, seq_len=50)
    model.eval()
    with torch.no_grad():
        # The two prediction loops were identical; factored into a helper.
        test_y_true, test_y_pred = _collect_predictions(model, test_loader)
        train_y_true, train_y_pred = _collect_predictions(model, train_loader)
    model.train()  # restore training mode for the caller
    test_acc = metrics.accuracy_score(test_y_true, test_y_pred)
    test_f1 = metrics.f1_score(test_y_true, test_y_pred, average='macro')
    train_acc = metrics.accuracy_score(train_y_true, train_y_pred)
    train_f1 = metrics.f1_score(train_y_true, train_y_pred, average='macro')
    print(f'Train Accuracy: {train_acc}, F1-Score: {train_f1}')
    print(f'Test Accuracy: {test_acc}, F1-Score: {test_f1}')
    return train_acc, train_f1, test_acc, test_f1
def train(data_path):
    """Train a TextCNN on the dataset at *data_path*.

    Logs loss every 500 steps and per-epoch accuracy/F1 to the
    module-level SummaryWriter ``writer``.

    Args:
        data_path: dataset directory, forwarded to get_dataloader and evaluate.
    """
    train_loader, test_loader, vocab = get_dataloader(data_path=data_path, bs=32, seq_len=50)
    model = TextCNN(ModelConfig())
    print(model)
    config = TrainConfig()
    optimizer = optim.Adam(model.parameters(), lr=config.lr)
    criterion = nn.CrossEntropyLoss(ignore_index=1)  # Ignoring <PAD> Token
    model.train()
    gs = 0  # global step counter for loss logging
    for epoch in tqdm(range(config.num_epochs)):
        for idx, batch in tqdm(enumerate(train_loader)):
            gs += 1
            inputs, targets = batch.text, batch.label
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            if gs % 500 == 0:
                writer.add_scalar('train/loss', loss.item(), gs)
                print(f'{gs} loss : {loss.item()}')
        # BUG FIX: evaluate on the dataset we actually trained on; the
        # original hard-coded './rsc/data/' and ignored data_path.
        train_acc, train_f1, test_acc, test_f1 = evaluate(model, data_path)
        writer.add_scalar('train/acc', train_acc, epoch)
        writer.add_scalar('train/f1', train_f1, epoch)
        writer.add_scalar('test/acc', test_acc, epoch)
        writer.add_scalar('test/f1', test_f1, epoch)
def new_model(is_generator):
    '''
    Creates a new model instance, interactively populates a fresh params
    dict (run name, architecture, state variables, dataloaders), and saves
    an initial checkpoint.

    Keyword arguments:
    > is_generator (bool) -- True to create a generator, False for a
      classifier (controls which architecture list is offered).

    Returns: params (dict) -- the fully initialized state variable
    '''
    params = param_factory(is_generator=is_generator)
    # BUG FIX: the trailing '.' previously bound only to the classifier
    # branch of the conditional expression, so the generator message was
    # printed without a period.
    print('You are initializing a new',
          (bolded('generator') if is_generator else bolded('classifier')) + '.')
    model_list = constants.GENERATORS if is_generator else constants.CLASSIFIERS

    # Name
    params['run_name'] = input('Please type the current model run name -> ')

    # Architecture. Slightly hacky - allows constants.py to enforce
    # which models are generators vs. classifiers.
    model_string = train_utils.input_from_list(model_list, 'model')
    # Dispatch table replaces the long if/elif chain; same error on miss.
    model_factories = {
        'Classifier_A': models.Classifier_A,
        'Classifier_B': models.Classifier_B,
        'Classifier_C': models.Classifier_C,
        'Classifier_D': models.Classifier_D,
        'Classifier_E': models.Classifier_E,
        'VANILLA_VAE': models.VAE,
        'DEFENSE_VAE': models.Defense_VAE,
    }
    if model_string not in model_factories:
        raise Exception(model_string, 'does not exist as a model (yet)!')
    params['model'] = model_factories[model_string]()

    # Kaiming initialization for weights
    models.initialize_model(params['model'])

    # Setup other state variables
    for state_var in constants.SETUP_STATE_VARS:
        train_utils.store_user_choice(params, state_var)
        print()

    # Grabs dataloaders. TODO: Prompt for val split/randomize val indices
    params['train_dataloader'], params['val_dataloader'], params[
        'test_dataloader'] = get_dataloader(dataset_name=params['dataset'],
                                            batch_sz=params['batch_size'],
                                            num_threads=params['num_threads'])

    # Saves an initial copy
    save_dir = 'models/' + model_type(params) + '/' + params['run_name'] + '/'
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    train_utils.save_checkpoint(params, 0)
    return params
def test(network):
    """Run a single validation pass of *network* over the test split and log accuracy."""
    loader = get_dataloader(cfg.DATASET.NAME,
                            cfg.DATASET.PATH,
                            0,
                            None,
                            smoothing=cfg.DATASET.TARGET_SMOOTHING,
                            normalize=cfg.DATASET.NORMALIZE,
                            test=True)
    # A Trainer built with no optimizer / train loader acts as a pure evaluator.
    evaluator = Trainer(network, None, None, loader)
    evaluator.validate()
    acc = evaluator.val_accuracy[-1]
    logger.info(f'TEST Accuracy: {acc:.4f}')
def train(network):
    """Build loaders and an optimizer from the global cfg, then train *network*."""
    logger.info('Loading dataset...')
    loaders = get_dataloader(cfg.DATASET.NAME,
                             cfg.DATASET.PATH,
                             cfg.TRAINING.HOLDOUT,
                             cfg.TRAINING.BATCH_SIZE,
                             smoothing=cfg.DATASET.TARGET_SMOOTHING,
                             normalize=cfg.DATASET.NORMALIZE)
    train_loader, val_loader = loaders
    logger.info('Creating optimizer...')
    opt = Optimizer(network,
                    cfg.TRAINING.LOSS,
                    cfg.TRAINING.LR,
                    cfg.TRAINING.LR_SCHEDULE,
                    cfg.TRAINING.MOMENTUM,
                    cfg.TRAINING.WEIGHT_DECAY)
    Trainer(network, opt, train_loader, val_loader).train(cfg.TRAINING.EPOCHS)
def main():
    """Entry point: train a DCGAN or draw samples from a saved one, per CLI args."""
    args = read_args()
    model = DCGAN(device=args.device)
    mode = args.train_or_sample
    if mode == "train":
        loader = get_dataloader(args.data,
                                data_dir=args.data_dir,
                                batch_size=args.batch_size)
        # Optionally resume from an existing checkpoint before training.
        if args.resume_path != "":
            model.load_model(dict_path=args.resume_path)
        model.train(loader, epochs=100, log_dir=args.log_dir)
    elif mode == "sample":
        model.load_model(args.model_path)
        save_n_samples(model, args.samples_dir, args.n_samples)
def train(args):
    """Train a DenseVAE, keeping exponential running averages of the loss and
    per-latent KL, and periodically dumping latent-traversal images.

    Args:
        args: namespace with no_cuda, nb_latents, eta, beta, epochs,
            batch_size, log_interval, save_interval.
    """
    _make_results_dir()
    dataloader = get_dataloader(args.batch_size)
    # Fixed data point used for latent-traversal visualizations.
    testpoint = torch.Tensor(dataloader.dataset[0])
    if not args.no_cuda:
        testpoint = testpoint.cuda()
    model = DenseVAE(args.nb_latents)
    model.train()
    if not args.no_cuda:
        model.cuda()
    optimizer = optim.Adagrad(model.parameters(), lr=args.eta)
    runloss, runkld = None, np.array([])
    start_time = time.time()
    for epoch_nb in range(1, args.epochs + 1):
        for batch_idx, data in enumerate(dataloader):
            if not args.no_cuda:
                data = data.cuda()
            recon_batch, mu, logvar = model(data)
            kld, loss = _loss_function(recon_batch, data, mu, logvar, args.beta)

            # param update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # BUG FIX: detach to a Python float — keeping the loss *tensor*
            # in runloss retained the whole autograd graph across iterations.
            batch_loss = loss.item() / len(data)
            # BUG FIX: the original `not runloss` also reset the running
            # average whenever the loss was exactly 0; test None explicitly.
            runloss = batch_loss if runloss is None else runloss * 0.99 + batch_loss * 0.01
            # First iteration seeds the KL average with zeros (original behavior).
            runkld = np.zeros(args.nb_latents) if not len(runkld) else runkld * 0.99 + kld.data.cpu().numpy() * 0.01

            if not batch_idx % args.log_interval:
                print("Epoch {}, batch: {}/{} ({:.2f} s), loss: {:.2f}, kl: [{}]".format(
                    epoch_nb, batch_idx, len(dataloader), time.time() - start_time,
                    runloss, ", ".join("{:.2f}".format(kl) for kl in runkld)))
                start_time = time.time()
            if not batch_idx % args.save_interval:
                _traverse_latents(model, testpoint, args.nb_latents, epoch_nb, batch_idx)
                model.train()
def main():
    """Compare a trained DCGAN and VAE side by side: save sample grids,
    probe selected latent dimensions for disentanglement, and contrast
    data-space vs latent-space interpolations.
    """
    args = get_args()
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    gan = DCGAN()
    gan.load_model(dict_path=args.gan_model)
    vae = VAE()
    vae.load_model(dict_path=args.vae_model)
    # ----------------------------------------------------------------------------------
    # first save some random samples from both the models and also from original dataset
    samples_dir = os.path.join(args.out_dir, "visual_samples/")
    os.makedirs(samples_dir, exist_ok=True)
    # # draw 3 8X8 grid of images from each of 3 sources
    for i in range(1, 4):
        # original svhn dataset samples (un-normalize from [-1,1] back to [0,1])
        svhn_data_loader = get_dataloader("svhn_train", batch_size=64)
        orig_imgs, _ = next(iter(svhn_data_loader))
        save_image((orig_imgs * 0.5 + 0.5), samples_dir + f"orig_image_grid{i}.png")
        # gan samples
        gan_imgs = gan.sample(num_images=64)
        save_image(gan_imgs, samples_dir + f"gan_image_grid{i}.png")
        # vae samples
        vae_imgs = vae.sample(num_images=64)
        save_image(vae_imgs, samples_dir + f"vae_image_grid{i}.png")
    # ----------------------------------------------------------------------------------
    # # next we want to see if the model has learned a disentangled representation in the latent space
    disentg_dir = os.path.join(args.out_dir, "disentangled_repr/")
    os.makedirs(disentg_dir, exist_ok=True)
    imgs_per_row = 12
    eps = 15  # perturbation added to one latent dimension at a time
    noise = torch.randn(imgs_per_row, 100)
    for tag, model in [("gan", gan), ("vae", vae)]:
        imgs_orig = model.sample(noise=noise)
        # blank row separates the originals from the perturbed rows
        imgs_list = [imgs_orig, torch.zeros(imgs_per_row, 3, 32, 32)]
        # hand-picked dimensions per model (found by the commented-out sweep below)
        interesting_dims = [14, 46, 51] if tag == "gan" else [12, 18, 70]
        # for i in tqdm(range(100)):
        for i in interesting_dims:
            noise_perturbed = noise.clone()
            noise_perturbed[:, i] += eps
            imgs_list.append(model.sample(noise=noise_perturbed))
        imgs_joined = torch.cat(imgs_list, dim=0)
        save_image(
            imgs_joined,
            disentg_dir + f"{tag}_disentang_3dims_seed{args.seed}_eps{eps}.png",
            nrow=imgs_per_row,
        )
    # ----------------------------------------------------------------------------------
    # Compare between interpolations in the data space and in the latent space
    interpolations_dir = os.path.join(args.out_dir, "interpolations/")
    os.makedirs(interpolations_dir, exist_ok=True)
    z = torch.randn(2, 100)  # two noises which will be interpolated
    alpha = torch.linspace(0.0, 1.0, 11)  # .unsqueeze(1) # unsqueeze for mat-mul
    # outer products give the 11 convex combinations alpha*z0 + (1-alpha)*z1
    z_interpolations = torch.ger(alpha, z[0]) + torch.ger((1 - alpha), z[1])
    alpha = alpha.view(-1, 1, 1, 1)  # so as to broadcast across 3-dimensional images
    for tag, model in [("gan", gan), ("vae", vae)]:
        x = model.sample(noise=z)
        # row 1: pixel-space blend of the two images; row 2: decode of blended latents
        imgs_x_interpolations = alpha * x[0] + (1 - alpha) * x[1]
        imgs_z_interpolations = model.sample(noise=z_interpolations)
        imgs_joined = torch.cat([imgs_x_interpolations, imgs_z_interpolations], dim=0)
        save_image(
            imgs_joined,
            interpolations_dir + f"{tag}_interpolations_s{args.seed}.png",
            nrow=11,
        )
def main(args):
    """Train a classifier (maxent/conv/conv2d) with SGD, evaluating on a
    validation split and early-stopping on validation accuracy.

    Saves the final model/optimizer state to args.ckpt_dir and prints the
    last accuracy and confusion matrix.
    """
    # load data
    with open(args.data, "rb") as f:
        data = pickle.load(f)
    # get dataloader
    train_dataloader = get_dataloader(data, 'train', bsz=args.bsz, freq_dom=args.freq_dom)
    val_dataloader = get_dataloader(data, 'test', bsz=200, freq_dom=args.freq_dom)
    # build model
    if args.mdl == 'maxent':
        model = SimpleClassifier()
    elif args.mdl == 'conv':
        model = ConvClassifier()
    elif args.mdl == 'conv2d':
        model = Conv2dClassifier()
    else:
        # BUG FIX: an unknown model name previously left `model` undefined
        # and crashed later with a confusing NameError.
        raise ValueError(f"unknown model type: {args.mdl}")
    if args.gpu:
        model = model.cuda()
    crit = nn.CrossEntropyLoss()
    # optimizer = optim.Adam(model.parameters(), lr=1e-6, betas=(0.9, 0.98), eps=1e-9)
    optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
    # train
    losses = 0
    step = 0
    val_acc = []
    stop = False  # set by the early-stopping check below
    for epoch in range(args.max_epoch):
        for i, batch in enumerate(train_dataloader):
            x, y = batch[:, :-1], batch[:, -1].long()
            if args.mdl == "conv2d":
                x = x.view(-1, 129, 158)
            optimizer.zero_grad()
            out = model(x)
            loss = crit(out, y)
            loss.backward()
            optimizer.step()
            losses += loss.item()
            if step % args.log_every == 0:
                print('-epoch {:3} step {:5} train loss {:5}'.format(
                    epoch, i, loss.item()))
            if step % args.eval_every == 0:
                model.eval()
                true_y, pred_y = [], []
                # FIX: no gradients needed during evaluation
                with torch.no_grad():
                    # FIX: loop variable renamed so it no longer shadows the
                    # train-loop index `i` used in the log line below.
                    for b in val_dataloader:
                        xv, yv = b[:, :-1], b[:, -1].long()
                        if args.mdl == "conv2d":
                            xv = xv.view(-1, 129, 158)
                        true_y.append(yv)
                        pred_y.append(model(xv).argmax(dim=1))
                true_y, pred_y = torch.cat(true_y), torch.cat(pred_y)
                acc = (true_y == pred_y).sum().item() / float(true_y.shape[0])
                print('-epoch {:3} step {:5} eval acc {:5}'.format(
                    epoch, step, acc))
                val_acc.append(acc)
                model.train()
                if len(val_acc) > args.early_stop:
                    val_acc.pop(0)
                    if acc < min(val_acc):
                        # BUG FIX: the original `break` only exited the batch
                        # loop, so training resumed at the next epoch; flag
                        # and break out of both loops.
                        stop = True
                        break
            step += 1
        if stop:
            break
    torch.save({
        'model': model.state_dict(),
        'optim': optimizer.state_dict()
    }, os.path.join(args.ckpt_dir, f"model_{args.mdl}.pt"))
    print('acc:', acc)
    print('confusion matrix', confusion_matrix(true_y.tolist(), pred_y.tolist()))
def train_poly(args, hparams):
    """Fine-tune only the polyphone-classification head of a G2PTransformerMask
    model (all other parameters frozen), validating each epoch and saving the
    best checkpoint by validation accuracy.

    Args:
        args: unused here except by the commented-out test loader (see below).
        hparams: hyperparameter object carrying file paths, batch size,
            max length, learning rate, epoch count, and use_output_mask flag.
    """
    # Fixed seeds for reproducibility.
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    np.random.seed(42)
    print('CHECK HERE train poly ONLY')
    train_dataloader = get_dataloader(hparams.use_output_mask, hparams.train_file,
                                      hparams.train_label, hparams,
                                      hparams.poly_batch_size,
                                      hparams.poly_max_length, shuffle=True)
    val_dataloader = get_dataloader(hparams.use_output_mask, hparams.val_file,
                                    hparams.val_label, hparams,
                                    hparams.poly_batch_size,
                                    hparams.poly_max_length, shuffle=True)
    # test_dataloader = get_dataloader(args.use_output_mask, args.test_file, args.test_label,
    #                                  args.class2idx, args.merge_cedict, args.poly_batch_size,
    #                                  args.max_length, shuffle=True)
    with codecs.open(hparams.class2idx, 'r', 'utf-8') as usernames:
        class2idx = json.load(usernames)
    print("num classes: {}".format(len(class2idx)))
    num_classes = len(class2idx)
    model = G2PTransformerMask(num_classes, hparams)
    device = torch.cuda.current_device()  # index of the GPU currently in use
    model = model.to(device)  # move the model onto that device
    for name, param in model.named_parameters():
        # Freeze everything except the polyphone head / adapter layers
        # (i.e. the frozen "syntax" backbone keeps its pretrained weights).
        if name.split('.')[0] != 'tree_shared_linear' and name.split('.')[0] != 'structure_cnn_poly' \
                and name.split('.')[0] != 'linear_pre' and name.split('.')[0] != 'poly_phoneme_classifier' \
                and name.split('.')[0] != 'linear_aft':
            param.requires_grad = False
    training_parameters_list = [
        p for p in model.parameters() if p.requires_grad
    ]
    optimizer = torch.optim.Adam(training_parameters_list, lr=hparams.poly_lr)
    # NLLLoss expects log-probabilities, which the masked softmax below produces.
    criterion = nn.NLLLoss()
    # mask_criterion = Mask_Softmax()
    mask_criterion = Gumbel_Softmax()
    model_dir = "./save/poly_only_syntax_frozen"
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    best_acc = 0
    for epoch in range(hparams.poly_epochs):
        model.train()
        for idx, batch in enumerate(train_dataloader, start=1):
            # print('CEHCK batch:', batch)
            # if idx > 200:
            #     break
            batch = tuple(t.to(device) for t in batch)
            if hparams.use_output_mask:
                input_ids, poly_ids, labels, output_mask = batch
                # attention mask: 1 for real tokens, 0 for padding
                mask = torch.sign(input_ids)
                inputs = {
                    "input_ids": input_ids,
                    "poly_ids": poly_ids,
                    "attention_mask": mask
                }
            else:
                input_ids, poly_ids, labels = batch
                # torch.sign returns the elementwise sign of the input
                # (1 for positive ids, 0 for padding) — used as attention mask.
                mask = torch.sign(input_ids)
                inputs = {
                    "input_ids": input_ids,
                    "poly_ids": poly_ids,
                    "attention_mask": mask
                }
            # inputs = {"input_ids": input_ids,
            #           "poly_ids": poly_ids,
            #           "attention_mask": mask}
            logits, _ = model(**inputs)
            batch_size = logits.size(0)
            # Select the logits at each sample's polyphonic-character position.
            logits = logits[torch.arange(batch_size), poly_ids]
            # logits = mask_criterion(logits, output_mask, True)
            # NOTE(review): output_mask is only bound on the use_output_mask
            # branch above — this line would raise NameError when
            # hparams.use_output_mask is False. Confirm the flag is always True.
            logits = mask_criterion(logits, output_mask)
            loss = criterion(logits, labels)
            loss.backward()
            # nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            model.zero_grad()
            if idx % 100 == 0:  # log every 100 batches
                print("loss : {:.4f}".format(loss.item()))
        # ---- per-epoch validation ----
        all_preds = []
        all_mask_preds = []
        all_labels = []
        model.eval()
        for batch in tqdm(val_dataloader, total=len(val_dataloader)):
            batch = tuple(t.to(device) for t in batch)
            # input_ids, poly_ids, labels = batch
            # mask = torch.sign(input_ids)
            # inputs = {"input_ids": input_ids,
            #           "poly_ids": poly_ids,
            #           "attention_mask": mask}
            if hparams.use_output_mask:
                input_ids, poly_ids, labels, output_mask = batch
                mask = torch.sign(input_ids)
                inputs = {
                    "input_ids": input_ids,
                    "poly_ids": poly_ids,
                    "attention_mask": mask
                }
            else:
                input_ids, poly_ids, labels = batch
                mask = torch.sign(input_ids)
                inputs = {
                    "input_ids": input_ids,
                    "poly_ids": poly_ids,
                    "attention_mask": mask
                }
            with torch.no_grad():
                logits, _ = model(**inputs)
                batch_size = logits.size(0)
                logits = logits[torch.arange(batch_size), poly_ids]
                # logits = logits.exp()
                # output_mask_false = 1.0 - output_mask
                # logits = logits - output_mask_false
                # logits = mask_criterion(logits, output_mask, True)
                logits = mask_criterion(logits, output_mask)
                preds = torch.argmax(logits, dim=1).cpu().numpy()
                # Argmax restricted to the valid (masked-in) candidate classes.
                mask_preds = masked_augmax(logits, output_mask, dim=1).cpu().numpy()
            # Debug dump whenever unrestricted and masked argmax disagree.
            if not (preds == mask_preds).all():
                print('CHECK preds:', preds)
                print('CHECK mask_preds:', mask_preds)
                print('CHECK labels:', labels)
                print('CHECK output_mask:', np.where(output_mask.cpu().numpy() == 1.0))
            all_preds.append(preds)
            all_mask_preds.append(mask_preds)
            all_labels.append(labels.cpu().numpy())
        preds = np.concatenate(all_preds, axis=0)
        mask_preds = np.concatenate(all_mask_preds, axis=0)
        labels = np.concatenate(all_labels, axis=0)
        # print('CHECK preds:', preds)
        # print('CHECK mask_preds:', mask_preds)
        # print('CHECK labels:', labels)
        val_acc = accuracy_score(labels, preds)
        mask_val_acc = accuracy_score(labels, mask_preds)
        # Agreement rate between the two prediction strategies.
        pred_diff_acc = accuracy_score(preds, mask_preds)
        print(
            "epoch :{}, acc: {:.2f}, mask acc: {:.2f}, pred_diff_acc: {:.2f}".
            format(epoch, val_acc * 100, mask_val_acc * 100,
                   pred_diff_acc * 100))
        # Checkpoint whenever validation accuracy improves.
        if val_acc > best_acc:
            best_acc = val_acc
            state_dict = model.state_dict()
            save_file = os.path.join(model_dir,
                                     "{:.2f}_model.pt".format(val_acc * 100))
            torch.save(state_dict, save_file)
def main(args):
    """Fit an out-of-distribution detector on an in-distribution dataset and
    report TNR/AUROC/DTACC/AUIN/AUOUT against a list of OOD datasets.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_idx)
    os.environ["CUDA_DEVICE"] = str(args.gpu_idx)
    np.random.seed(0)
    torch.cuda.manual_seed(0)
    torch.cuda.set_device(args.gpu_idx)
    pretrained = os.path.join(args.net_dir,
                              '{}_{}.pth'.format(args.net_type, args.dataset))
    # set the out-of-distribution data
    out_dist_list = [
        'imagenet_crop', 'imagenet_resize', 'lsun_crop', 'lsun_resize', 'isun'
    ]
    # NOTE(review): input_stds / in_transform are only bound for the three
    # datasets below — any other args.dataset value would NameError later.
    if args.dataset == 'cifar10':
        out_dist_list = ['cifar100', 'svhn'] + out_dist_list
        input_stds = (0.2470, 0.2435, 0.2616)
        in_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), input_stds)
        ])
    elif args.dataset == 'cifar100':
        out_dist_list = ['cifar10', 'svhn'] + out_dist_list
        input_stds = (0.2673, 0.2564, 0.2762)
        in_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4865, 0.4409), input_stds)
        ])
    elif args.dataset == 'svhn':
        out_dist_list = ['cifar10', 'cifar100'] + out_dist_list
        input_stds = (1.0, 1.0, 1.0)  # identity std — no normalization for SVHN
        in_transform = transforms.Compose([transforms.ToTensor()])
    # load model (the checkpoint stores the full pickled model object)
    print('load model: ' + args.net_type)
    model = torch.load(pretrained, map_location="cuda:" + str(args.gpu_idx))
    model.cuda()
    model.eval()
    # load dataset
    print('load target data: ' + args.dataset)
    train_loader = data_utils.get_dataloader(args.dataset, args.data_root,
                                             'train', in_transform,
                                             args.batch_size)
    test_loader = data_utils.get_dataloader(args.dataset, args.data_root,
                                            'test', in_transform,
                                            args.batch_size)
    # fit detector
    print('fit detector')
    OOD_Detector = detector_dict[args.detector_type]
    main_detector = OOD_Detector(
        model,
        args.num_classes,
        ood_tuning=args.ood_tuning,
        net_type='' if args.naive_layer else args.net_type,
        # odin/mahalanobis need the input stds to scale gradient perturbations
        normalizer=input_stds
        if args.detector_type in ['odin', 'mahalanobis'] else None,
    )
    # Rename for reporting: malcom with OOD tuning is reported as malcom++.
    if args.detector_type == 'malcom' and args.ood_tuning:
        args.detector_type = 'malcom++'
    main_detector.fit(train_loader)
    # get scores
    print('get scores')
    results = []
    if not args.ood_tuning:
        # In-distribution scores can be computed once when no per-OOD tuning.
        in_scores = detectors.detect_utils.get_scores(main_detector, test_loader)
    for _, out_dist in enumerate(out_dist_list):
        print('\t...out-of-distribution: ' + out_dist)
        out_test_loader = data_utils.get_dataloader(out_dist, args.data_root,
                                                    'test', in_transform,
                                                    args.batch_size)
        if args.ood_tuning:
            # Tune on the first samples, then re-score both loaders.
            main_detector.tune_parameters(test_loader, out_test_loader,
                                          num_samples=1000)
            in_scores = detectors.detect_utils.get_scores(
                main_detector, test_loader)
            out_scores = detectors.detect_utils.get_scores(
                main_detector, out_test_loader)
        else:
            out_scores = detectors.detect_utils.get_scores(
                main_detector, out_test_loader)
        # Skip the first 1000 scores (presumably the tuning split) — confirm.
        test_results = callog.metric(-in_scores[1000:], -out_scores[1000:])
        results.append(test_results)
    # ---- formatted results table ----
    mtypes = ['', 'TNR', 'AUROC', 'DTACC', 'AUIN', 'AUOUT']
    print('=' * 78)
    print('{} detector (with {} trained on {} '.format(args.detector_type,
                                                       args.net_type,
                                                       args.dataset), end='')
    print('w/o using ood samples): '
          if not args.ood_tuning else 'with using ood samples): ')
    for mtype in mtypes:
        print(' {mtype:^12s}'.format(mtype=mtype), end='')
    for count_out, result in enumerate(results):
        print('\n {:12}'.format(out_dist_list[count_out][:10]), end='')
        print(' {val:^12.2f}'.format(val=100. * result['TNR']), end='')
        print(' {val:^12.2f}'.format(val=100. * result['AUROC']), end='')
        print(' {val:^12.2f}'.format(val=100. * result['DTACC']), end='')
        print(' {val:^12.2f}'.format(val=100. * result['AUIN']), end='')
        print(' {val:^12.2f}'.format(val=100. * result['AUOUT']), end='')
    print('')
    print('=' * 78)
def main(args):
    """Train a classifier (densenet / resnet / vanilla CNN) on
    cifar10 / cifar100 / svhn and checkpoint the full model to out_file.

    SVHN uses iteration-based validation checkpointing (best valid loss);
    the CIFAR datasets evaluate and save once per epoch.
    """
    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_idx)
    os.environ["CUDA_DEVICE"] = str(args.gpu_idx)
    np.random.seed(0)
    torch.cuda.manual_seed(0)
    torch.cuda.set_device(args.gpu_idx)
    out_file = os.path.join(args.output_dir,
                            '{}_{}.pth'.format(args.net_type, args.dataset))
    # set the transformations for training
    tfs_for_augmentation = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    if args.dataset == 'cifar10':
        train_transform = transforms.Compose(tfs_for_augmentation + [
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2470, 0.2435, 0.2616))
        ])
        # NOTE(review): test mean uses 0.4821 where train uses 0.4822 —
        # looks like a typo; confirm before changing.
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4821, 0.4465),
                                 (0.2470, 0.2435, 0.2616)),
        ])
    elif args.dataset == 'cifar100':
        train_transform = transforms.Compose(tfs_for_augmentation + [
            transforms.Normalize((0.5071, 0.4865, 0.4409),
                                 (0.2673, 0.2564, 0.2762)),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4865, 0.4409),
                                 (0.2673, 0.2564, 0.2762)),
        ])
    elif args.dataset == 'svhn':
        # SVHN: no augmentation, no normalization
        train_transform = transforms.Compose([transforms.ToTensor()])
        test_transform = transforms.Compose([transforms.ToTensor()])
    # load model
    if args.net_type == 'densenet':
        if args.dataset == 'svhn':
            model = densenet.DenseNet3(100, args.num_classes,
                                       growth_rate=12, dropRate=0.2)
        else:
            model = densenet.DenseNet3(100, args.num_classes, growth_rate=12)
    elif args.net_type == 'resnet':
        model = resnet.ResNet34(num_c=args.num_classes)
    elif args.net_type == 'vanilla':
        model = vanilla.VanillaCNN(args.num_classes)
    model.cuda()
    print('load model: ' + args.net_type)
    # load dataset
    print('load target data: ' + args.dataset)
    if args.dataset == 'svhn':
        # SVHN additionally gets a validation split for checkpoint selection.
        train_loader, valid_loader = data_utils.get_dataloader(
            args.dataset, args.data_root, 'train', train_transform,
            args.batch_size, valid_transform=test_transform)
    else:
        train_loader = data_utils.get_dataloader(args.dataset, args.data_root,
                                                 'train', train_transform,
                                                 args.batch_size)
    test_loader = data_utils.get_dataloader(args.dataset, args.data_root,
                                            'test', test_transform,
                                            args.batch_size)
    # define objective and optimizer
    criterion = nn.CrossEntropyLoss()
    # Per-architecture LR schedule and weight decay.
    if args.net_type == 'densenet' or args.net_type == 'vanilla':
        weight_decay = 1e-4
        milestones = [150, 225]
        gamma = 0.1
    elif args.net_type == 'resnet':
        weight_decay = 5e-4
        milestones = [60, 120, 160]
        gamma = 0.2
    if args.dataset == 'svhn' or args.net_type == 'vanilla':
        milestones = [20, 30]  # shorter schedule for these runs
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9,
                          weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones,
                                               gamma=gamma)
    # train
    best_loss = np.inf
    iter_cnt = 0  # iterations since the last SVHN validation pass
    for epoch in range(args.num_epochs):
        model.train()
        total, total_loss, total_step = 0, 0, 0
        for _, (data, labels) in enumerate(train_loader):
            data = data.cuda()
            labels = labels.cuda()
            total += data.size(0)
            outputs = model(data)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * data.size(0)
            iter_cnt += 1
            total_step += 1
            # SVHN: validate every ~200 iterations and keep the best model.
            if args.dataset == 'svhn' and iter_cnt >= 200:
                valid_loss, _ = evaluation(model, valid_loader, criterion)
                test_loss, acc = evaluation(model, test_loader, criterion)
                print(
                    'Epoch [{:03d}/{:03d}], step [{}/{}] train loss : {:.4f}, valid loss : {:.4f}, test loss : {:.4f}, test acc : {:.2f} %'
                    .format(epoch + 1, args.num_epochs, total_step,
                            len(train_loader), total_loss / total, valid_loss,
                            test_loss, 100 * acc))
                if valid_loss < best_loss:
                    best_loss = valid_loss
                    # Saves the entire model object (not just the state_dict).
                    torch.save(model, out_file)
                iter_cnt = 0
                model.train()  # evaluation() leaves the model in eval mode
        if args.dataset != 'svhn':
            # CIFAR: evaluate and overwrite the checkpoint once per epoch.
            test_loss, acc = evaluation(model, test_loader, criterion)
            print(
                '[{:03d}/{:03d}] train loss : {:.4f}, test loss : {:.4f}, test acc : {:.2f} %'
                .format(epoch + 1, args.num_epochs, total_loss / total,
                        test_loss, 100 * acc))
            # NOTE(review): placement inferred from the collapsed source —
            # this save appears to run each epoch for non-SVHN runs; confirm.
            torch.save(model, out_file)
        scheduler.step()