import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from args import get_opt
import h5py
import os

opt = get_opt()
std = 0.01


def normal_init(m, mean, std):
    # Initialize (de)convolution weights from N(mean, std) and zero the biases.
    if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()


class dataset_mnist:
    def __init__(self):
        # Open the pre-downloaded MNIST HDF5 file from the working directory.
        filename = 'data.hdf5'
        file = os.path.join(os.getcwd(), filename)
        try:
            self.data = h5py.File(file, 'r+')
        except OSError:
            raise IOError(
                'Dataset not found. Please make sure the dataset was downloaded.'
            )
        print("Reading Done: %s" % file)
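# Usage sketch (an assumption, not part of the original repo): normal_init is
# meant to be applied to every submodule of a network via nn.Module.apply.
# The small Sequential net below is illustrative only.
def _demo_normal_init():
    net = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(),
                        nn.ConvTranspose2d(8, 1, 3, padding=1))
    net.apply(lambda m: normal_init(m, mean=0.0, std=std))
    return net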
import sys
import logging
import numpy as np
import torch.backends.cudnn as cudnn
# NOTE: tb_logger, models, utils, get_loaders, OptimizerFactory, train and
# untrain are project-local names; their definitions/imports live elsewhere
# in the repo and are not shown in this excerpt.


def main():
    opt = get_opt()
    tb_logger.configure(opt.logger_name, flush_secs=5, opt=opt)
    logfname = os.path.join(opt.logger_name, 'log.txt')
    logging.basicConfig(filename=logfname,
                        format='%(asctime)s %(message)s',
                        level=logging.INFO)
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(opt.d))

    # Seed all RNGs for reproducibility.
    torch.manual_seed(opt.seed)
    if opt.cuda:
        # TODO: remove deterministic
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed(opt.seed)
        np.random.seed(opt.seed)
        # helps with wide-resnet by reducing memory and time 2x
        cudnn.benchmark = True

    train_loader, test_loader, train_test_loader = get_loaders(opt)

    # Derive iteration counts from the dataset size if not given explicitly.
    if opt.epoch_iters == 0:
        opt.epoch_iters = int(
            np.ceil(1. * len(train_loader.dataset) / opt.batch_size))
    opt.maxiter = opt.epoch_iters * opt.epochs
    if opt.g_epoch:
        opt.gvar_start *= opt.epoch_iters
        opt.g_optim_start = (opt.g_optim_start * opt.epoch_iters) + 1

    model = models.init_model(opt)
    optimizer = OptimizerFactory(model, train_loader, tb_logger, opt)

    epoch = 0
    save_checkpoint = utils.SaveCheckpoint()

    # Optionally resume from a checkpoint.
    if not opt.noresume:
        model_path = os.path.join(opt.logger_name, opt.ckpt_name)
        if os.path.isfile(model_path):
            print("=> loading checkpoint '{}'".format(model_path))
            checkpoint = torch.load(model_path)
            best_prec1 = checkpoint['best_prec1']
            optimizer.gvar.load_state_dict(checkpoint['gvar'])
            optimizer.niters = checkpoint['niters']
            epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            save_checkpoint.best_prec1 = best_prec1
            print("=> loaded checkpoint '{}' (epoch {}, best_prec {})".format(
                model_path, epoch, best_prec1))
        else:
            print("=> no checkpoint found at '{}'".format(model_path))

    if opt.niters > 0:
        max_iters = opt.niters
    else:
        max_iters = opt.epochs * opt.epoch_iters

    if opt.untrain_steps > 0:
        untrain(model, optimizer.gvar, opt)

    # Main training loop: one call to train() per epoch until the iteration
    # budget is exhausted or train() signals an early stop with -1.
    while optimizer.niters < max_iters:
        optimizer.epoch = epoch
        utils.adjust_lr(optimizer, opt)
        ecode = train(tb_logger, epoch, train_loader, model, optimizer, opt,
                      test_loader, save_checkpoint, train_test_loader)
        if ecode == -1:
            break
        epoch += 1
    tb_logger.save_log()
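# Standard script entry point (a minimal sketch; the original file's entry
# point is not shown in this excerpt).
if __name__ == '__main__':
    main()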