def set_model(args, n_data):
    """Build the multiview AlexNet encoder, its NCE memory and both criteria.

    ``args.model`` is intentionally not consulted. The original selection
    test was short-circuited (``if True or ...``), which made its
    ``ValueError`` branch unreachable; the dead branch has been removed so
    the code states what it actually does: a ``MyMultiviewAlexNetCMC`` is
    always built.

    Args:
        args: parsed options. Reads ``feat_dim``, ``nce_k``, ``nce_t``,
            ``nce_m`` and ``softmax``.
        n_data: forwarded to ``NCEAverageMultiview`` and ``NCECriterion``
            (presumably the number of training samples -- confirm with the
            caller).

    Returns:
        tuple: ``(model, contrast, criterion_v2, criterion_v1)``. Note that
        the second-view criterion comes first.
    """
    n_views_per_sample = 2

    model = MyMultiviewAlexNetCMC(args.feat_dim)
    contrast = NCEAverageMultiview(args.feat_dim, n_data, n_views_per_sample,
                                   args.nce_k, args.nce_t, args.nce_m,
                                   args.softmax)

    # One independent criterion per view; both are of the same kind.
    criterion_v1 = NCESoftmaxLoss() if args.softmax else NCECriterion(n_data)
    criterion_v2 = NCESoftmaxLoss() if args.softmax else NCECriterion(n_data)

    if torch.cuda.is_available():
        model = model.cuda()
        contrast = contrast.cuda()
        criterion_v2 = criterion_v2.cuda()
        criterion_v1 = criterion_v1.cuda()
        cudnn.benchmark = True

    return model, contrast, criterion_v2, criterion_v1
def set_model(args, n_data):
    """Create the CMC encoder, the NCE memory module and the two criteria.

    Args:
        args: parsed options. Reads ``model``, ``feat_dim``, ``nce_k``,
            ``nce_t``, ``nce_m`` and ``softmax``.
        n_data: forwarded to ``NCEAverage`` and ``NCECriterion``.

    Returns:
        tuple: ``(model, contrast, criterion_ab, criterion_l)``.

    Raises:
        ValueError: if ``args.model`` is neither ``'alexnet'`` nor a name
            starting with ``'resnet'``.
    """
    arch = args.model
    if arch == 'alexnet':
        model = MyAlexNetCMC(args.feat_dim)
    elif arch.startswith('resnet'):
        model = MyResNetsCMC(arch)
    else:
        raise ValueError('model not supported yet {}'.format(arch))

    contrast = NCEAverage(args.feat_dim, n_data, args.nce_k, args.nce_t,
                          args.nce_m, args.softmax)

    def make_criterion():
        # Softmax-based loss when requested, NCE criterion otherwise.
        if args.softmax:
            return NCESoftmaxLoss()
        return NCECriterion(n_data)

    criterion_l = make_criterion()
    criterion_ab = make_criterion()

    if not torch.cuda.is_available():
        return model, contrast, criterion_ab, criterion_l

    model, contrast = model.cuda(), contrast.cuda()
    criterion_ab, criterion_l = criterion_ab.cuda(), criterion_l.cuda()
    cudnn.benchmark = True
    return model, contrast, criterion_ab, criterion_l
def set_model(args, n_data):
    """Build the encoder for the requested view, the NCE memory and criteria.

    Args:
        args: parsed options. Reads ``model``, ``view`` (AlexNet only),
            ``feat_dim``, ``nce_k``, ``nce_t`` and ``nce_m``.
        n_data: forwarded to ``NCEAverage`` and ``NCECriterion``.

    Returns:
        tuple: ``(model, contrast, criterion_ab, criterion_l)``. With more
        than one visible CUDA device the model is wrapped in
        ``torch.nn.DataParallel``.

    Raises:
        NotImplementedError: if ``args.model == 'alexnet'`` and ``args.view``
            is not one of ``'Lab'``, ``'Rot'`` or ``'LabRot'``.
        ValueError: if ``args.model`` is neither ``'alexnet'`` nor a name
            starting with ``'resnet'``.
    """
    # Input-channel split handed to ``alexnet`` for each supported view.
    alexnet_in_channels = {
        'Lab': (1, 2),
        'Rot': (3, 3),
        'LabRot': (1, 2),
    }

    if args.model == 'alexnet':
        if args.view not in alexnet_in_channels:
            # ``NotImplemented`` is a comparison sentinel, not an exception;
            # calling it raised ``TypeError`` and hid the intended message.
            raise NotImplementedError(
                'view not implemented {}'.format(args.view))
        model = alexnet(in_channel=alexnet_in_channels[args.view],
                        feat_dim=args.feat_dim)
    elif args.model.startswith('resnet'):
        model = ResNetV2(args.model)
    else:
        raise ValueError('model not supported yet {}'.format(args.model))

    contrast = NCEAverage(args.feat_dim, n_data, args.nce_k, args.nce_t,
                          args.nce_m)
    criterion_l = NCECriterion(n_data)
    criterion_ab = NCECriterion(n_data)

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model).cuda()
        else:
            model = model.cuda()
        contrast = contrast.cuda()
        criterion_ab = criterion_ab.cuda()
        criterion_l = criterion_l.cuda()
        cudnn.benchmark = True

    return model, contrast, criterion_ab, criterion_l
def set_model(n_data):
    """Create the NCE memory module and two NCE criteria with fixed settings.

    No encoder is built here; only the contrast module and the criteria.

    Args:
        n_data: forwarded to ``NCEAverage`` and ``NCECriterion``.

    Returns:
        tuple: ``(contrast, criterion_ab, criterion_l)``, moved to the GPU
        when CUDA is available.
    """
    # Hard-coded positional arguments of NCEAverage, in call order.
    feat_dim, nce_k, nce_t, nce_m = 128, 4096, 0.1, 0.5

    contrast = NCEAverage(feat_dim, n_data, nce_k, nce_t, nce_m)
    criterion_l = NCECriterion(n_data)
    criterion_ab = NCECriterion(n_data)

    if not torch.cuda.is_available():
        return contrast, criterion_ab, criterion_l

    contrast = contrast.cuda()
    criterion_ab, criterion_l = criterion_ab.cuda(), criterion_l.cuda()
    cudnn.benchmark = True
    return contrast, criterion_ab, criterion_l
def set_model(args, n_data):
    """Create the AlexNet encoder, the NCE memory module and two criteria.

    Args:
        args: parsed options. Reads ``model``, ``feat_dim``, ``nce_k``,
            ``nce_t`` and ``nce_m``.
        n_data: forwarded to ``NCEAverage`` and ``NCECriterion``.

    Returns:
        tuple: ``(model, contrast, criterion_ab, criterion_l)``.

    Raises:
        ValueError: if ``args.model`` is not ``'alexnet'``.
    """
    # Guard clause: only one architecture is supported here.
    if args.model != 'alexnet':
        raise ValueError('model not supported yet {}'.format(args.model))

    model = alexnet(args.feat_dim)
    contrast = NCEAverage(args.feat_dim, n_data, args.nce_k, args.nce_t,
                          args.nce_m)
    criterion_l = NCECriterion(n_data)
    criterion_ab = NCECriterion(n_data)

    if not torch.cuda.is_available():
        return model, contrast, criterion_ab, criterion_l

    model, contrast = model.cuda(), contrast.cuda()
    criterion_ab, criterion_l = criterion_ab.cuda(), criterion_l.cuda()
    cudnn.benchmark = True
    return model, contrast, criterion_ab, criterion_l
def main():
    """Train an instance-discrimination (InsDis) or MoCo encoder.

    Pipeline: parse options, build the augmented training loader from
    ``<data_folder>/train``, create the encoder (plus a second ``model_ema``
    encoder when ``args.moco`` is set), the contrast memory, the criterion
    and the SGD optimizer, optionally resume from ``args.resume``, then run
    the epoch loop with tensorboard logging and checkpointing.

    Every epoch ``current.pth`` is overwritten in ``args.model_folder``;
    every ``args.save_freq`` epochs the same state is also written to
    ``ckpt_epoch_<epoch>.pth``.

    Raises:
        NotImplementedError: for an unsupported ``args.aug`` or
            ``args.model``.
    """
    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # set the data loader
    data_folder = os.path.join(args.data_folder, 'train')

    image_size = 224
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)

    if args.aug == 'NULL':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    elif args.aug == 'CJ':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        # ``NotImplemented`` is not an exception class; calling it raised a
        # misleading ``TypeError`` instead of reporting the bad option.
        raise NotImplementedError(
            'augmentation not supported: {}'.format(args.aug))

    train_dataset = ImageFolderInstance(data_folder,
                                        transform=train_transform,
                                        two_crop=args.moco)
    print(len(train_dataset))
    train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.num_workers, pin_memory=True,
        sampler=train_sampler)

    # create model and optimizer
    n_data = len(train_dataset)

    # Keyword arguments for InsResNet50 per supported model name.
    encoder_kwargs = {
        'resnet50': {},
        'resnet50x2': {'width': 2},
        'resnet50x4': {'width': 4},
    }
    if args.model not in encoder_kwargs:
        raise NotImplementedError(
            'model not supported {}'.format(args.model))
    model = InsResNet50(**encoder_kwargs[args.model])
    # The second encoder only exists in MoCo mode.
    model_ema = InsResNet50(**encoder_kwargs[args.model]) if args.moco else None

    # copy weights from `model' to `model_ema'
    if args.moco:
        moment_update(model, model_ema, 0)

    # set the contrast memory and criterion
    if args.moco:
        contrast = MemoryMoCo(128, n_data, args.nce_k, args.nce_t,
                              args.softmax).cuda(args.gpu)
    else:
        contrast = MemoryInsDis(128, n_data, args.nce_k, args.nce_t,
                                args.nce_m, args.softmax).cuda(args.gpu)

    criterion = NCESoftmaxLoss() if args.softmax else NCECriterion(n_data)
    criterion = criterion.cuda(args.gpu)

    model = model.cuda()
    if args.moco:
        model_ema = model_ema.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True

    if args.amp:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.opt_level)
        if args.moco:
            # amp.initialize needs an optimizer; this one never updates
            # anything (lr=0) and exists only to wrap model_ema.
            optimizer_ema = torch.optim.SGD(model_ema.parameters(),
                                            lr=0,
                                            momentum=0,
                                            weight_decay=0)
            model_ema, optimizer_ema = amp.initialize(
                model_ema, optimizer_ema, opt_level=args.opt_level)

    # optionally resume from a checkpoint
    args.start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # NOTE: torch.load unpickles arbitrary objects (the checkpoint
            # stores the full ``args`` namespace); only load trusted files.
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            contrast.load_state_dict(checkpoint['contrast'])
            if args.moco:
                model_ema.load_state_dict(checkpoint['model_ema'])

            if args.amp and checkpoint['opt'].amp:
                print('==> resuming amp state_dict')
                amp.load_state_dict(checkpoint['amp'])

            print("=> loaded successfully '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):

        adjust_learning_rate(epoch, args, optimizer)
        print("==> training...")

        time1 = time.time()
        if args.moco:
            loss, prob = train_moco(epoch, train_loader, model, model_ema,
                                    contrast, criterion, optimizer, args)
        else:
            loss, prob = train_ins(epoch, train_loader, model, contrast,
                                   criterion, optimizer, args)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        logger.log_value('ins_loss', loss, epoch)
        logger.log_value('ins_prob', prob, epoch)
        logger.log_value('learning_rate',
                         optimizer.param_groups[0]['lr'], epoch)

        # Build the checkpoint once per epoch and write it to each
        # destination; previously the periodic file was assembled and
        # written twice on save epochs.
        print('==> Saving...')
        state = {
            'opt': args,
            'model': model.state_dict(),
            'contrast': contrast.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        }
        if args.moco:
            state['model_ema'] = model_ema.state_dict()
        if args.amp:
            state['amp'] = amp.state_dict()

        torch.save(state, os.path.join(args.model_folder, 'current.pth'))
        if epoch % args.save_freq == 0:
            save_file = os.path.join(
                args.model_folder,
                'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)

        # help release GPU memory
        del state
        torch.cuda.empty_cache()
def main():
    """Pre-train a video encoder with the ``TemporalDis`` MoCo objective.

    Pipeline: parse options, overwrite a set of dataset/clip options with
    fixed values, build the data loaders, create the encoder and its
    ``model_ema`` twin for ``args.arch``, the MoCo memory, criterion and SGD
    optimizer, optionally resume from ``args.resume``, then train with
    ``train_moco`` and hand each epoch's results to ``saving``.

    Example invocations carried over from the original source:
        CUDA_VISIBLE_DEVICES=0,1 python train_temporal_dis.py \\
            --batch_size 16 --num_workers 8 --nce_k 3569 --softmax --moco
        CUDA_VISIBLE_DEVICES=1 python train_temporal_dis.py \\
            --batch_size 16 --num_workers 8 --nce_k 9536 --softmax --moco

    Raises:
        NotImplementedError: if ``args.arch`` is not ``'i3d'``, ``'r2p1d'``
            or ``'r3d'``.
    """
    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # == dataset config ==
    # These values override whatever parse_option() produced.
    args.dropout = 0.5
    args.clips = 1
    args.data_length = 16
    args.stride = 4
    args.spatial_size = 224
    args.root = ""
    args.mode = 'rgb'
    args.eval_indict = 'loss'
    args.pt_loss = 'TemporalDis'
    args.workers = 4

    num_class, data_length, image_tmpl = data_config(args)
    train_transforms, test_transforms, eval_transforms = augmentation_config(
        args)
    (train_loader, val_loader, eval_loader,
     train_samples, val_samples, eval_samples) = data_loader_init(
        args, data_length, image_tmpl,
        train_transforms, test_transforms, eval_transforms)

    # NOTE(review): len(train_loader) is the number of batches, not the
    # number of samples; ``train_samples`` may be the intended value for the
    # memory/criterion size -- confirm before changing.
    n_data = len(train_loader)

    # create model and its EMA twin
    if args.arch == 'i3d':
        model = I3D(num_classes=101, modality=args.mode,
                    dropout_prob=args.dropout, with_classifier=False)
        model_ema = I3D(num_classes=101, modality=args.mode,
                        dropout_prob=args.dropout, with_classifier=False)
    elif args.arch == 'r2p1d':
        model = R2Plus1DNet((1, 1, 1, 1), num_classes=num_class,
                            with_classifier=False)
        model_ema = R2Plus1DNet((1, 1, 1, 1), num_classes=num_class,
                                with_classifier=False)
    elif args.arch == 'r3d':
        from model.r3d import resnet18
        model = resnet18(num_classes=num_class, with_classifier=False)
        model_ema = resnet18(num_classes=num_class, with_classifier=False)
    else:
        # The exception used to be constructed but never raised, so an
        # unknown arch fell through to a NameError on ``model`` below.
        raise NotImplementedError(
            'arch not supported {}'.format(args.arch))

    model = torch.nn.DataParallel(model)
    model_ema = torch.nn.DataParallel(model_ema)

    # random initialization
    model.apply(weights_init)
    model_ema.apply(weights_init)

    # copy weights from `model' to `model_ema'
    moment_update(model, model_ema, 0)

    contrast = MemoryMoCo(128, n_data, args.nce_k, args.nce_t,
                          args.softmax).cuda(args.gpu)

    criterion = NCESoftmaxLoss() if args.softmax else NCECriterion(n_data)
    criterion = criterion.cuda(args.gpu)

    model = model.cuda()
    # model_ema is always passed to train_moco below, so it must live on the
    # GPU regardless of the --moco flag (it used to be moved only when the
    # flag was set).
    model_ema = model_ema.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True

    if args.amp:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.opt_level)
        if args.moco:
            # amp.initialize needs an optimizer; this one never updates
            # anything (lr=0) and exists only to wrap model_ema.
            optimizer_ema = torch.optim.SGD(model_ema.parameters(),
                                            lr=0,
                                            momentum=0,
                                            weight_decay=0)
            model_ema, optimizer_ema = amp.initialize(
                model_ema, optimizer_ema, opt_level=args.opt_level)

    # optionally resume from a checkpoint
    args.start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # NOTE: torch.load unpickles arbitrary objects; only load
            # checkpoints from trusted sources.
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            contrast.load_state_dict(checkpoint['contrast'])
            if args.moco:
                model_ema.load_state_dict(checkpoint['model_ema'])

            if args.amp and checkpoint['opt'].amp:
                print('==> resuming amp state_dict')
                amp.load_state_dict(checkpoint['amp'])

            print("=> loaded successfully '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)
    # Not used below, but kept because constructing it opens a log at
    # args.tb_folder2.
    logger2 = tb_logger.Logger(logdir=args.tb_folder2, flush_secs=2)

    # ==== our data augmentation method ====
    pos_aug = GenPositive()
    neg_aug = GenNegative()

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):

        adjust_learning_rate(epoch, args, optimizer)
        print("==> training...")

        time1 = time.time()
        loss, prob = train_moco(epoch, train_loader, model, model_ema,
                                contrast, criterion, optimizer, args,
                                pos_aug, neg_aug)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        saving(logger, loss, epoch, optimizer, args, model, contrast, prob,
               model_ema, 'TemporalDis')
def main():
    """Train a contrastive encoder (MoCo, SimCLR or InsDis) with kNN eval.

    Pipeline: parse options, build train/test transforms and loaders for
    ``args.dataset`` (``'imagenet'`` or ``'cifar'``), create the encoder
    (plus ``model_ema`` for MoCo), the contrast memory (none for SimCLR),
    the criterion and the SGD optimizer (optionally wrapped by LARS),
    optionally resume from ``args.resume``, then run the epoch loop.

    Every second epoch the models are switched to eval mode and, for the
    cifar dataset, kNN accuracy is computed and logged. Every epoch
    ``current.pth`` is overwritten in ``args.model_folder``; every
    ``args.save_freq`` epochs the same state is also written to
    ``ckpt_epoch_<epoch>.pth``.

    Raises:
        NotImplementedError: for an unsupported augmentation/dataset
            combination, dataset or model name.
    """
    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    is_moco = args.contrastive_model == 'moco'

    # set the data loader
    data_folder = os.path.join(args.data_folder, 'train')
    val_folder = os.path.join(args.data_folder, 'val')

    crop_padding = 32
    image_size = 224
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)
    cifar_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                           (0.2023, 0.1994, 0.2010))

    # Default evaluation transform (TODO in the original: currently follows
    # CMC). It is defined up front because the 'NULL' and 'CJ' branches
    # never set one, which caused a NameError when the test set was built.
    test_transform = transforms.Compose([
        transforms.Resize(image_size + crop_padding),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        normalize,
    ])

    if args.aug == 'NULL' and args.dataset == 'imagenet':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    elif args.aug == 'CJ':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    elif args.aug == 'simple' and args.dataset == 'imagenet':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomHorizontalFlip(),
            get_color_distortion(1.0),
            transforms.ToTensor(),
            normalize,
        ])
    elif args.aug == 'simple' and args.dataset == 'cifar':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(size=32),
            transforms.RandomHorizontalFlip(p=0.5),
            get_color_distortion(0.5),
            transforms.ToTensor(),
            cifar_normalize,
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            cifar_normalize,
        ])
    else:
        # ``NotImplemented`` is not an exception class; calling it raised a
        # misleading ``TypeError`` instead of reporting the bad option.
        raise NotImplementedError(
            'augmentation not supported: {}'.format(args.aug))

    # Get Datasets
    if args.dataset == "imagenet":
        # NOTE(review): this is the only place that reads ``args.moco``;
        # everything else uses ``args.contrastive_model``. Confirm that the
        # option parser still defines ``moco``.
        train_dataset = ImageFolderInstance(data_folder,
                                            transform=train_transform,
                                            two_crop=args.moco)
        print(len(train_dataset))
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.num_workers, pin_memory=True)

        # ImageFolder's keyword is ``transform``; ``transforms=`` raised
        # TypeError (unexpected keyword argument).
        test_dataset = datasets.ImageFolder(val_folder,
                                            transform=test_transform)
        test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=256, shuffle=False,
            num_workers=args.num_workers, pin_memory=True)
    elif args.dataset == 'cifar':
        # cifar-10 dataset
        if args.contrastive_model == 'simclr':
            train_dataset = CIFAR10Instance_double(
                root='./data', train=True, download=True,
                transform=train_transform, double=True)
        else:
            train_dataset = CIFAR10Instance(
                root='./data', train=True, download=True,
                transform=train_transform)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.num_workers, pin_memory=True, drop_last=True)

        test_dataset = CIFAR10Instance(root='./data', train=False,
                                       download=True,
                                       transform=test_transform)
        test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=100, shuffle=False,
            num_workers=args.num_workers)
    else:
        # Without this, an unknown dataset surfaced later as a NameError.
        raise NotImplementedError(
            'dataset not supported: {}'.format(args.dataset))

    # create model and optimizer
    n_data = len(train_dataset)

    if args.model == 'resnet50':
        encoder_cls, encoder_kwargs = InsResNet50, {}
    elif args.model == 'resnet50x2':
        encoder_cls, encoder_kwargs = InsResNet50, {'width': 2}
    elif args.model == 'resnet50x4':
        encoder_cls, encoder_kwargs = InsResNet50, {'width': 4}
    elif args.model == 'resnet50_cifar':
        encoder_cls, encoder_kwargs = InsResNet50_cifar, {}
    else:
        raise NotImplementedError(
            'model not supported {}'.format(args.model))
    model = encoder_cls(**encoder_kwargs)
    # The second encoder only exists in MoCo mode.
    model_ema = encoder_cls(**encoder_kwargs) if is_moco else None

    # copy weights from `model' to `model_ema'
    if is_moco:
        moment_update(model, model_ema, 0)

    # set the contrast memory and criterion
    if is_moco:
        contrast = MemoryMoCo(128, n_data, args.nce_k, args.nce_t,
                              args.softmax).cuda(args.gpu)
    elif args.contrastive_model == 'simclr':
        contrast = None
    else:
        contrast = MemoryInsDis(128, n_data, args.nce_k, args.nce_t,
                                args.nce_m, args.softmax).cuda(args.gpu)

    if args.softmax:
        criterion = NCESoftmaxLoss()
    elif args.contrastive_model == 'simclr':
        criterion = BatchCriterion(1, args.nce_t, args.batch_size)
    else:
        criterion = NCECriterion(n_data)
    criterion = criterion.cuda(args.gpu)

    model = model.cuda()
    if is_moco:
        model_ema = model_ema.cuda()

    # Exclude BN and bias if needed
    weight_decay = args.weight_decay
    if weight_decay and args.filter_weight_decay:
        parameters = add_weight_decay(model, weight_decay,
                                      args.filter_weight_decay)
        # Decay is now carried by the parameter groups, so the optimizer
        # default is zeroed.
        weight_decay = 0.
    else:
        parameters = model.parameters()

    optimizer = torch.optim.SGD(parameters,
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=weight_decay)

    cudnn.benchmark = True

    if args.amp:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.opt_level)
        if is_moco:
            # amp.initialize needs an optimizer; this one never updates
            # anything (lr=0) and exists only to wrap model_ema.
            optimizer_ema = torch.optim.SGD(model_ema.parameters(),
                                            lr=0,
                                            momentum=0,
                                            weight_decay=0)
            model_ema, optimizer_ema = amp.initialize(
                model_ema, optimizer_ema, opt_level=args.opt_level)

    if args.LARS:
        optimizer = LARS(optimizer=optimizer, eps=1e-8, trust_coef=0.001)

    # optionally resume from a checkpoint
    args.start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # NOTE: torch.load unpickles arbitrary objects (the checkpoint
            # stores the full ``args`` namespace); only load trusted files.
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if contrast is not None:
                if 'contrast' in checkpoint:
                    contrast.load_state_dict(checkpoint['contrast'])
                else:
                    # Checkpoints written by earlier versions of this
                    # script did not store the memory.
                    print("=> checkpoint has no 'contrast' state; "
                          "keeping freshly initialised memory")
            if is_moco:
                model_ema.load_state_dict(checkpoint['model_ema'])

            if args.amp and checkpoint['opt'].amp:
                print('==> resuming amp state_dict')
                amp.load_state_dict(checkpoint['amp'])

            print("=> loaded successfully '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # Evaluate every ``test_epoch`` epochs.
    test_epoch = 2

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):

        print("==> training...")

        time1 = time.time()
        if is_moco:
            loss, prob = train_moco(epoch, train_loader, model, model_ema,
                                    contrast, criterion, optimizer, args)
        elif args.contrastive_model == 'simclr':
            print("Train using simclr")
            loss, prob = train_simclr(epoch, train_loader, model, criterion,
                                      optimizer, args)
        else:
            print("Train using InsDis")
            loss, prob = train_ins(epoch, train_loader, model, contrast,
                                   criterion, optimizer, args)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        logger.log_value('ins_loss', loss, epoch)
        logger.log_value('ins_prob', prob, epoch)
        logger.log_value('learning_rate',
                         optimizer.param_groups[0]['lr'], epoch)

        if epoch % test_epoch == 0:
            # NOTE(review): nothing here switches back to train mode;
            # presumably the train_* functions do -- confirm.
            model.eval()
            if is_moco:
                model_ema.eval()

            print('----------Evaluation---------')
            start = time.time()

            if args.dataset == 'cifar':
                acc = kNN(epoch, model, train_loader, test_loader, 200,
                          args.nce_t, n_data, low_dim=128, memory_bank=None)

                # Reporting lives inside this branch because ``acc`` only
                # exists when kNN was actually run.
                print("Evaluation Time: '{}'s".format(time.time() - start))
                logger.log_value('Test accuracy', acc, epoch)
                print('[Epoch]: {}'.format(epoch))
                print('accuracy: {}%)'.format(acc))

        # Build the checkpoint once per epoch and write it to each
        # destination; previously the periodic file was assembled and
        # written twice on save epochs.
        print('==> Saving...')
        state = {
            'opt': args,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        }
        if contrast is not None:
            # Resume expects this key whenever a contrast memory exists.
            state['contrast'] = contrast.state_dict()
        if is_moco:
            state['model_ema'] = model_ema.state_dict()
        if args.amp:
            state['amp'] = amp.state_dict()

        torch.save(state, os.path.join(args.model_folder, 'current.pth'))
        if epoch % args.save_freq == 0:
            save_file = os.path.join(
                args.model_folder,
                'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)

        # help release GPU memory
        del state
        torch.cuda.empty_cache()