import logging
import os
import sys

import numpy as np
import torch
from omegaconf import DictConfig
from torch.optim import Adam

# Project-local helpers (get_model, cal_parameters, SDIM, Loader,
# AverageMeter, accuracy) are assumed to be importable from this repository.


def train(hps: DictConfig) -> None:
    # This enables a Ctrl-C without triggering errors.
    import signal
    signal.signal(signal.SIGINT, lambda x, y: sys.exit(0))

    logger = logging.getLogger(__name__)

    cuda_available = torch.cuda.is_available()
    device = "cuda" if cuda_available and hps.device == 'cuda' else "cpu"

    torch.manual_seed(hps.seed)
    np.random.seed(hps.seed)

    # Models: base classifier wrapped by the SDIM generative classifier.
    local_channel = hps.get(hps.base_classifier).last_conv_channel
    classifier = get_model(model_name=hps.base_classifier,
                           in_size=local_channel,
                           out_size=hps.rep_size).to(device)
    logger.info('Base classifier name: {}, # parameters: {}'.format(
        hps.base_classifier, cal_parameters(classifier)))

    sdim = SDIM(disc_classifier=classifier,
                mi_units=hps.mi_units,
                n_classes=hps.n_classes,
                margin=hps.margin,
                rep_size=hps.rep_size,
                local_channel=local_channel).to(device)

    # Log the SDIM description.
    for desc in sdim.desc():
        logger.info(desc)

    train_loader = Loader('train', batch_size=hps.n_batch_train, device=device)

    if cuda_available and hps.n_gpu > 1:
        sdim = torch.nn.DataParallel(sdim, device_ids=list(range(hps.n_gpu)))

    optimizer = Adam(filter(lambda p: p.requires_grad, sdim.parameters()), lr=hps.lr)

    # Create the log dir.
    logdir = os.path.abspath(hps.log_dir)
    os.makedirs(logdir, exist_ok=True)

    loss_optimal = 1e5
    n_iters = 0

    losses = AverageMeter('Loss')
    MIs = AverageMeter('MI')
    nlls = AverageMeter('NLL')
    margins = AverageMeter('Margin')
    top1 = AverageMeter('Acc@1')
    top5 = AverageMeter('Acc@5')

    for x, y in train_loader:
        n_iters += 1
        if n_iters > hps.training_iters:
            break

        # Forward/backward. Under DataParallel each loss comes back as a
        # per-GPU vector, so reduce with mean() before backward() and item().
        optimizer.zero_grad()
        loss, mi_loss, nll_loss, ll_margin, log_lik = sdim(x, y)
        loss, mi_loss, nll_loss, ll_margin = (
            loss.mean(), mi_loss.mean(), nll_loss.mean(), ll_margin.mean())
        loss.backward()
        optimizer.step()

        acc1, acc5 = accuracy(log_lik, y, topk=(1, 5))
        losses.update(loss.item(), x.size(0))
        top1.update(acc1, x.size(0))
        top5.update(acc5, x.size(0))

        MIs.update(mi_loss.item(), x.size(0))
        nlls.update(nll_loss.item(), x.size(0))
        margins.update(ll_margin.item(), x.size(0))

        if n_iters % hps.log_interval == hps.log_interval - 1:
            logger.info('Train loss: {:.4f}, mi: {:.4f}, nll: {:.4f}, ll_margin: {:.4f}'.format(
                losses.avg, MIs.avg, nlls.avg, margins.avg))
            logger.info('Train Acc@1: {:.3f}, Acc@5: {:.3f}'.format(top1.avg, top5.avg))

            # Checkpoint whenever the windowed average loss improves.
            if losses.avg < loss_optimal:
                loss_optimal = losses.avg
                model_path = 'SDIM_{}.pth'.format(hps.base_classifier)

                if cuda_available and hps.n_gpu > 1:
                    state = sdim.module.state_dict()
                else:
                    state = sdim.state_dict()

                check_point = {'model_state': state,
                               'train_acc_top1': top1.avg,
                               'train_acc_top5': top5.avg}
                torch.save(check_point, os.path.join(logdir, model_path))

            # Reset the meters for the next logging window.
            losses.reset()
            MIs.reset()
            nlls.reset()
            margins.reset()
            top1.reset()
            top5.reset()
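# The loops above depend on the project's `AverageMeter` and `accuracy`
# helpers. The following is a minimal sketch of what they are assumed to
# look like, based on how they are used here (common PyTorch-examples
# idioms); the repository's actual implementations may differ.

class AverageMeter:
    """Track a batch-size-weighted running average of a scalar metric."""

    def __init__(self, name: str):
        self.name = name
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # `val` is a per-batch mean, so weight it by the batch size `n`.
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Top-k accuracy (in percent) of `output` scores against `target` labels."""
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk highest-scoring classes per sample, one row per rank.
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append((correct_k * (100.0 / batch_size)).item())
    return res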
def inference(hps: DictConfig) -> None:
    # This enables a Ctrl-C without triggering errors.
    import signal
    signal.signal(signal.SIGINT, lambda x, y: sys.exit(0))

    logger = logging.getLogger(__name__)

    cuda_available = torch.cuda.is_available()
    device = "cuda" if cuda_available and hps.device == 'cuda' else "cpu"

    torch.manual_seed(hps.seed)
    np.random.seed(hps.seed)

    # Models: same architecture as in train().
    local_channel = hps.get(hps.base_classifier).last_conv_channel
    classifier = get_model(model_name=hps.base_classifier,
                           in_size=local_channel,
                           out_size=hps.rep_size).to(device)
    logger.info('Base classifier name: {}, # parameters: {}'.format(
        hps.base_classifier, cal_parameters(classifier)))

    sdim = SDIM(disc_classifier=classifier,
                mi_units=hps.mi_units,
                n_classes=hps.n_classes,
                margin=hps.margin,
                rep_size=hps.rep_size,
                local_channel=local_channel).to(device)

    # Load the checkpoint saved by train(); map_location makes the load
    # work on CPU-only machines as well.
    model_path = 'SDIM_{}.pth'.format(hps.base_classifier)
    base_dir = '/userhome/cs/u3003679/generative-classification-with-rejection'
    path = os.path.join(base_dir, model_path)
    sdim.load_state_dict(torch.load(path, map_location=device)['model_state'])

    # Log the SDIM description.
    for desc in sdim.desc():
        logger.info(desc)

    eval_loader = Loader('eval', batch_size=hps.n_batch_test, device=device)

    if cuda_available and hps.n_gpu > 1:
        sdim = torch.nn.DataParallel(sdim, device_ids=list(range(hps.n_gpu)))

    # DataParallel only parallelizes forward(); custom methods such as
    # infer() have to be reached through the wrapped module.
    model = sdim.module if isinstance(sdim, torch.nn.DataParallel) else sdim

    top1 = AverageMeter('Acc@1')
    top5 = AverageMeter('Acc@5')

    sdim.eval()
    with torch.no_grad():
        for x, y in eval_loader:
            log_lik = model.infer(x)
            acc1, acc5 = accuracy(log_lik, y, topk=(1, 5))
            top1.update(acc1, x.size(0))
            top5.update(acc5, x.size(0))

    logger.info('Test Acc@1: {:.3f}, Acc@5: {:.3f}'.format(top1.avg, top5.avg))
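# A sketch of how train() and inference() could be wired into a Hydra entry
# point, since `hps` is an omegaconf DictConfig. The config location
# ('configs'/'config') and the `inference` flag are hypothetical; adapt them
# to the repository's actual config layout.
import hydra


@hydra.main(config_path='configs', config_name='config')
def main(hps: DictConfig) -> None:
    if hps.get('inference', False):
        inference(hps)
    else:
        train(hps)


if __name__ == '__main__':
    main()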