def main():
    # os.system('shutdown -c')  # cancel previous shutdown command

    if write_log:
        utils.makedirs(args.save)
        logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'),
                                  filepath=os.path.abspath(__file__))
        logger.info(args)

        args_file_path = os.path.join(args.save, 'args.yaml')
        with open(args_file_path, 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

    if args.distributed:
        if write_log:
            logger.info('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        distributed.init_process_group(backend=args.dist_backend,
                                       init_method=args.dist_url,
                                       world_size=dist_utils.env_world_size(),
                                       rank=env_rank())
        assert (dist_utils.env_world_size() == distributed.get_world_size())
        if write_log:
            logger.info("Distributed: success (%d/%d)" %
                        (args.local_rank, distributed.get_world_size()))
        device = torch.device("cuda:%d" % torch.cuda.current_device()
                              if torch.cuda.is_available() else "cpu")
    else:
        device = torch.cuda.current_device()

    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    trainlog = os.path.join(args.save, 'training.csv')
    testlog = os.path.join(args.save, 'test.csv')
    traincolumns = ['itr', 'wall', 'itr_time', 'loss', 'bpd', 'fe',
                    'total_time', 'grad_norm']
    testcolumns = ['wall', 'epoch', 'eval_time', 'bpd', 'fe', 'total_time',
                   'transport_cost']

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = create_model(args, data_shape, regularization_fns).cuda()
    if args.distributed:
        model = dist_utils.DDP(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank)

    traincolumns = append_regularization_keys_header(traincolumns,
                                                     regularization_fns)

    if not args.resume and write_log:
        with open(trainlog, 'w') as f:
            csvlogger = csv.DictWriter(f, traincolumns)
            csvlogger.writeheader()
        with open(testlog, 'w') as f:
            csvlogger = csv.DictWriter(f, testcolumns)
            csvlogger.writeheader()

    set_cnf_options(args, model)

    if write_log:
        logger.info(model)
        logger.info("Number of trainable parameters: {}".format(
            count_parameters(model)))
        logger.info('Iters per train epoch: {}'.format(len(train_loader)))
        logger.info('Iters per test: {}'.format(len(test_loader)))

    # optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay,
                              momentum=0.9,
                              nesterov=False)

    # restore parameters
    if args.resume is not None:
        print('resume from checkpoint')
        checkpt = torch.load(
            args.resume,
            map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    # For visualization.
    if write_log:
        fixed_z = cvt(torch.randn(min(args.test_batch_size, 100), *data_shape))

    if write_log:
        time_meter = utils.RunningAverageMeter(0.97)
        bpd_meter = utils.RunningAverageMeter(0.97)
        loss_meter = utils.RunningAverageMeter(0.97)
        steps_meter = utils.RunningAverageMeter(0.97)
        grad_meter = utils.RunningAverageMeter(0.97)
        tt_meter = utils.RunningAverageMeter(0.97)

    if not args.resume:
        best_loss = float("inf")
        itr = 0
        wall_clock = 0.
        begin_epoch = 1
        chkdir = args.save
    # elif args.resume and args.validate:
    #     chkdir = os.path.dirname(args.resume)
    #     wall_clock = 0
    #     itr = 0
    #     best_loss = 0.0
    #     begin_epoch = 0
    else:
        chkdir = os.path.dirname(args.resume)
        filename = os.path.join(chkdir, 'test.csv')
        print(filename)
        tedf = pd.read_csv(os.path.join(chkdir, 'test.csv'))
        trdf = pd.read_csv(os.path.join(chkdir, 'training.csv'))
        wall_clock = trdf['wall'].to_numpy()[-1]
        itr = trdf['itr'].to_numpy()[-1]
        best_loss = tedf['bpd'].min()
        begin_epoch = int(tedf['epoch'].to_numpy()[-1] + 1)  # not exactly correct

    if args.distributed:
        if write_log:
            logger.info('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    for epoch in range(begin_epoch, begin_epoch + 1):
        # compute test loss
        print('Evaluating')
        model.eval()

        if args.local_rank == 0:
            utils.makedirs(args.save)
            if hasattr(model, 'module'):
                _state = model.module.state_dict()
            else:
                _state = model.state_dict()
            torch.save({
                "args": args,
                "state_dict": _state,
                "optim_state_dict": optimizer.state_dict(),
                "fixed_z": fixed_z.cpu()
            }, os.path.join(args.save, "checkpt_%d.pth" % epoch))

        # save real images and generate with different temperatures
        fig_num = 64

        if True:  # args.save_real:
            for i, (x, y) in enumerate(test_loader):
                if i < 100:
                    pass
                elif i == 100:
                    real = x.size(0)
                else:
                    break
            if x.shape[0] > fig_num:
                x = x[:fig_num, ...]
            fig_filename = os.path.join(chkdir, "real.jpg")
            save_image(x.float() / 255.0, fig_filename, nrow=8)

        if True:  # args.generate:
            print('\nGenerating images... ')
            fixed_z = cvt(torch.randn(fig_num, *data_shape))
            nb = int(np.ceil(np.sqrt(float(fixed_z.size(0)))))
            for t in [1.0, 0.99, 0.98, 0.97, 0.96, 0.95, 0.93, 0.92, 0.90,
                      0.85, 0.8, 0.75, 0.7, 0.65, 0.6]:
                # visualize samples and density
                fig_filename = os.path.join(chkdir, "generated-T%g.jpg" % t)
                utils.makedirs(os.path.dirname(fig_filename))
                generated_samples = model(t * fixed_z, reverse=True)
                x = unshift(generated_samples[0].view(-1, *data_shape), 8)
                save_image(x, fig_filename, nrow=nb)
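# Note on the temperature sweep above: since the flow's prior is a standard
# Gaussian, scaling a fixed latent batch by t before the reverse pass is the
# same as sampling from N(0, t^2 I), so lower temperatures trade diversity for
# sample fidelity. A minimal standalone sketch of that idea follows; `decoder`
# is a placeholder for any flow run in reverse (e.g. model(z, reverse=True)),
# not a function defined in this repo.
import torch


def sample_at_temperature(decoder, data_shape, t=0.7, n=64):
    """Draw n latents from N(0, t^2 I) and decode them."""
    z = t * torch.randn(n, *data_shape)  # equivalent to sampling with std t
    return decoder(z)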
def main():
    # os.system('shutdown -c')  # cancel previous shutdown command

    if write_log:
        utils.makedirs(args.save)
        logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'),
                                  filepath=os.path.abspath(__file__))
        logger.info(args)

        args_file_path = os.path.join(args.save, 'args.yaml')
        with open(args_file_path, 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

    if args.distributed:
        if write_log:
            logger.info('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        distributed.init_process_group(backend=args.dist_backend,
                                       init_method=args.dist_url,
                                       world_size=dist_utils.env_world_size(),
                                       rank=env_rank())
        assert (dist_utils.env_world_size() == distributed.get_world_size())
        if write_log:
            logger.info("Distributed: success (%d/%d)" %
                        (args.local_rank, distributed.get_world_size()))

    # get device
    # device = torch.device("cuda:%d" % torch.cuda.current_device() if torch.cuda.is_available() else "cpu")
    device = "cpu"
    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    trainlog = os.path.join(args.save, 'training.csv')
    testlog = os.path.join(args.save, 'test.csv')
    traincolumns = ['itr', 'wall', 'itr_time', 'loss', 'bpd', 'fe',
                    'total_time', 'grad_norm']
    testcolumns = ['wall', 'epoch', 'eval_time', 'bpd', 'fe', 'total_time',
                   'transport_cost']

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = create_model(args, data_shape, regularization_fns)
    # model = model.cuda()
    if args.distributed:
        model = dist_utils.DDP(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank)

    traincolumns = append_regularization_keys_header(traincolumns,
                                                     regularization_fns)

    if not args.resume and write_log:
        with open(trainlog, 'w') as f:
            csvlogger = csv.DictWriter(f, traincolumns)
            csvlogger.writeheader()
        with open(testlog, 'w') as f:
            csvlogger = csv.DictWriter(f, testcolumns)
            csvlogger.writeheader()

    set_cnf_options(args, model)

    if write_log:
        logger.info(model)
        logger.info("Number of trainable parameters: {}".format(
            count_parameters(model)))
        logger.info('Iters per train epoch: {}'.format(len(train_loader)))
        logger.info('Iters per test: {}'.format(len(test_loader)))

    # optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay,
                              momentum=0.9,
                              nesterov=False)

    # restore parameters
    if args.resume is not None:
        checkpt = torch.load(
            args.resume,
            map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    # For visualization.
    if write_log:
        fixed_z = cvt(torch.randn(min(args.test_batch_size, 100), *data_shape))

    if write_log:
        time_meter = utils.RunningAverageMeter(0.97)
        bpd_meter = utils.RunningAverageMeter(0.97)
        loss_meter = utils.RunningAverageMeter(0.97)
        steps_meter = utils.RunningAverageMeter(0.97)
        grad_meter = utils.RunningAverageMeter(0.97)
        tt_meter = utils.RunningAverageMeter(0.97)

    if not args.resume:
        best_loss = float("inf")
        itr = 0
        wall_clock = 0.
        begin_epoch = 1
    else:
        chkdir = os.path.dirname(args.resume)
        tedf = pd.read_csv(os.path.join(chkdir, 'test.csv'))
        trdf = pd.read_csv(os.path.join(chkdir, 'training.csv'))
        wall_clock = trdf['wall'].to_numpy()[-1]
        itr = trdf['itr'].to_numpy()[-1]
        best_loss = tedf['bpd'].min()
        begin_epoch = int(tedf['epoch'].to_numpy()[-1] + 1)  # not exactly correct

    if args.distributed:
        if write_log:
            logger.info('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    for epoch in range(begin_epoch, args.num_epochs + 1):
        if not args.validate:
            model.train()

            with open(trainlog, 'a') as f:
                if write_log:
                    csvlogger = csv.DictWriter(f, traincolumns)

                for _, (x, y) in enumerate(train_loader):
                    start = time.time()
                    update_lr(optimizer, itr)
                    optimizer.zero_grad()

                    # cast data and move to device
                    x = add_noise(cvt(x), nbits=args.nbits)
                    # x = x.clamp_(min=0, max=1)

                    # compute loss
                    bpd, (x, z), reg_states = compute_bits_per_dim(x, model)
                    if np.isnan(bpd.data.item()):
                        raise ValueError('model returned nan during training')
                    elif np.isinf(bpd.data.item()):
                        raise ValueError('model returned inf during training')

                    loss = bpd
                    if regularization_coeffs:
                        reg_loss = sum(
                            reg_state * coeff
                            for reg_state, coeff in zip(reg_states,
                                                        regularization_coeffs)
                            if coeff != 0)
                        loss = loss + reg_loss
                    total_time = count_total_time(model)

                    loss.backward()
                    nfe_opt = count_nfe(model)
                    if write_log:
                        steps_meter.update(nfe_opt)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.max_grad_norm)
                    optimizer.step()

                    itr_time = time.time() - start
                    wall_clock += itr_time

                    batch_size = x.size(0)
                    metrics = torch.tensor([
                        1., batch_size,
                        loss.item(),
                        bpd.item(), nfe_opt, grad_norm, *reg_states
                    ]).float()

                    rv = tuple(torch.tensor(0.) for r in reg_states)
                    total_gpus, batch_total, r_loss, r_bpd, r_nfe, r_grad_norm, *rv = dist_utils.sum_tensor(
                        metrics).cpu().numpy()

                    if write_log:
                        time_meter.update(itr_time)
                        bpd_meter.update(r_bpd / total_gpus)
                        loss_meter.update(r_loss / total_gpus)
                        grad_meter.update(r_grad_norm / total_gpus)
                        tt_meter.update(total_time)

                        fmt = '{:.4f}'
                        logdict = {
                            'itr': itr,
                            'wall': fmt.format(wall_clock),
                            'itr_time': fmt.format(itr_time),
                            'loss': fmt.format(r_loss / total_gpus),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'total_time': fmt.format(total_time),
                            'fe': r_nfe / total_gpus,
                            'grad_norm': fmt.format(r_grad_norm / total_gpus),
                        }
                        if regularization_coeffs:
                            rv = tuple(v_ / total_gpus for v_ in rv)
                            logdict = append_regularization_csv_dict(
                                logdict, regularization_fns, rv)
                        csvlogger.writerow(logdict)

                        if itr % args.log_freq == 0:
                            log_message = (
                                "Itr {:06d} | Wall {:.3e}({:.2f}) | "
                                "Time/Itr {:.2f}({:.2f}) | BPD {:.2f}({:.2f}) | "
                                "Loss {:.2f}({:.2f}) | "
                                "FE {:.0f}({:.0f}) | Grad Norm {:.3e}({:.3e}) | "
                                "TT {:.2f}({:.2f})".format(
                                    itr, wall_clock, wall_clock / (itr + 1),
                                    time_meter.val, time_meter.avg,
                                    bpd_meter.val, bpd_meter.avg,
                                    loss_meter.val, loss_meter.avg,
                                    steps_meter.val, steps_meter.avg,
                                    grad_meter.val, grad_meter.avg,
                                    tt_meter.val, tt_meter.avg))
                            if regularization_coeffs:
                                log_message = append_regularization_to_log(
                                    log_message, regularization_fns, rv)
                            logger.info(log_message)

                    itr += 1

        # compute test loss
        model.eval()
        if args.local_rank == 0:
            utils.makedirs(args.save)
            torch.save(
                {
                    "args": args,
                    "state_dict": model.module.state_dict()
                    if torch.cuda.is_available() else model.state_dict(),
                    "optim_state_dict": optimizer.state_dict(),
                    "fixed_z": fixed_z.cpu()
                }, os.path.join(args.save, "checkpt.pth"))

        if epoch % args.val_freq == 0 or args.validate:
            with open(testlog, 'a') as f:
                if write_log:
                    csvlogger = csv.DictWriter(f, testcolumns)
                with torch.no_grad():
                    start = time.time()
                    if write_log:
                        logger.info("validating...")

                    lossmean = 0.
                    meandist = 0.
                    steps = 0
                    tt = 0.
                    for i, (x, y) in enumerate(test_loader):
                        sh = x.shape
                        x = shift(cvt(x), nbits=args.nbits)
                        loss, (x, z), _ = compute_bits_per_dim(x, model)

                        dist = (x.view(x.size(0), -1) - z).pow(2).mean(dim=-1).mean()
                        # running means over the test set
                        meandist = i / (i + 1) * meandist + dist / (i + 1)
                        lossmean = i / (i + 1) * lossmean + loss / (i + 1)
                        tt = i / (i + 1) * tt + count_total_time(model) / (i + 1)
                        steps = i / (i + 1) * steps + count_nfe(model) / (i + 1)

                    loss = lossmean.item()
                    metrics = torch.tensor([1., loss, meandist, steps]).float()

                    total_gpus, r_bpd, r_mdist, r_steps = dist_utils.sum_tensor(
                        metrics).cpu().numpy()
                    eval_time = time.time() - start

                    if write_log:
                        fmt = '{:.4f}'
                        logdict = {
                            'epoch': epoch,
                            'eval_time': fmt.format(eval_time),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'wall': fmt.format(wall_clock),
                            'total_time': fmt.format(tt),
                            'transport_cost': fmt.format(r_mdist / total_gpus),
                            'fe': '{:.2f}'.format(r_steps / total_gpus)
                        }
                        csvlogger.writerow(logdict)
                        logger.info(
                            "Epoch {:04d} | Time {:.4f}, Bit/dim {:.4f}, "
                            "Steps {:.4f}, TT {:.2f}, Transport Cost {:.2e}".format(
                                epoch, eval_time, r_bpd / total_gpus,
                                r_steps / total_gpus, tt, r_mdist / total_gpus))

                    loss = r_bpd / total_gpus

                    if loss < best_loss and args.local_rank == 0:
                        best_loss = loss
                        shutil.copyfile(os.path.join(args.save, "checkpt.pth"),
                                        os.path.join(args.save, "best.pth"))

            # visualize samples and density
            if write_log:
                with torch.no_grad():
                    fig_filename = os.path.join(args.save, "figs",
                                                "{:04d}.jpg".format(epoch))
                    utils.makedirs(os.path.dirname(fig_filename))
                    generated_samples, _, _ = model(fixed_z, reverse=True)
                    generated_samples = generated_samples.view(-1, *data_shape)
                    nb = int(np.ceil(np.sqrt(float(fixed_z.size(0)))))
                    save_image(unshift(generated_samples, nbits=args.nbits),
                               fig_filename,
                               nrow=nb)

        if args.validate:
            break
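# The training and evaluation loops above rely on add_noise, shift and unshift,
# which live elsewhere in the repo. The sketch below shows one common
# nbits-based uniform-dequantization scheme consistent with how they are called
# (uint8 images in [0, 255], outputs roughly in [0, 1]); the exact conventions
# are assumptions, not the repo's implementation.
import torch


def add_noise(x, nbits=8):
    # Training-time dequantization: reduce to 2**nbits levels, add uniform
    # noise within each bin, rescale to [0, 1).
    if nbits < 8:
        x = torch.floor(x / 2 ** (8 - nbits))
    noise = x.new_empty(x.shape).uniform_()
    return (x + noise) / 2 ** nbits


def shift(x, nbits=8):
    # Eval-time variant: use the bin centre instead of random noise so the
    # reported bits/dim is deterministic.
    if nbits < 8:
        x = torch.floor(x / 2 ** (8 - nbits))
    return (x + 0.5) / 2 ** nbits


def unshift(x, nbits=8):
    # Map generated samples back into [0, 1] before save_image; whether the
    # repo also undoes the bin centring is an assumption.
    return x.clamp(min=0, max=1)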
def main():
    os.system('shutdown -c')  # cancel previous shutdown command

    log.console(args)
    tb.log('sizes/world', dist_utils.env_world_size())

    # need to index validation directory before we start counting the time
    dataloader.sort_ar(args.data + '/validation')

    if args.distributed:
        log.console('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=dist_utils.env_world_size())
        assert (dist_utils.env_world_size() == dist.get_world_size())
        log.console("Distributed: success (%d/%d)" %
                    (args.local_rank, dist.get_world_size()))

    log.console("Loading model")
    model = resnet.resnet50(bn0=args.init_bn0).cuda()
    if args.fp16:
        model = network_to_half(model)
    if args.distributed:
        model = dist_utils.DDP(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank)
    best_top5 = 93  # only save models over 93%. Otherwise it stops to save every time

    global model_params, master_params
    if args.fp16:
        model_params, master_params = prep_param_lists(model)
    else:
        model_params = master_params = model.parameters()

    optim_params = experimental_utils.bnwd_optim_params(
        model, model_params, master_params) if args.no_bn_wd else master_params

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        optim_params, 0, momentum=args.momentum,
        weight_decay=args.weight_decay)  # start with 0 lr. Scheduler will change this later

    if args.resume:
        checkpoint = torch.load(
            args.resume,
            map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch']
        best_top5 = checkpoint['best_top5']
        optimizer.load_state_dict(checkpoint['optimizer'])

    # save script so we can reproduce from logs
    shutil.copy2(os.path.realpath(__file__), f'{args.logdir}')

    log.console(
        "Creating data loaders (this could take up to 10 minutes if volume needs to be warmed up)"
    )
    phases = eval(args.phases)
    dm = DataManager([copy.deepcopy(p) for p in phases if 'bs' in p])
    scheduler = Scheduler(optimizer,
                          [copy.deepcopy(p) for p in phases if 'lr' in p])

    start_time = datetime.now()  # Loading start to after everything is loaded
    if args.evaluate:
        return validate(dm.val_dl, model, criterion, 0, start_time)

    if args.distributed:
        log.console('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    log.event("~~epoch\thours\ttop1\ttop5\n")
    for epoch in range(args.start_epoch, scheduler.tot_epochs):
        dm.set_epoch(epoch)

        train(dm.trn_dl, model, criterion, optimizer, scheduler, epoch)
        top1, top5 = validate(dm.val_dl, model, criterion, epoch, start_time)

        time_diff = (datetime.now() - start_time).total_seconds() / 3600.0
        log.event(f'~~{epoch}\t{time_diff:.5f}\t\t{top1:.3f}\t\t{top5:.3f}\n')

        is_best = top5 > best_top5
        best_top5 = max(top5, best_top5)
        if args.local_rank == 0:
            if is_best:
                save_checkpoint(epoch,
                                model,
                                best_top5,
                                optimizer,
                                is_best=True,
                                filename='model_best.pth.tar')
            phase = dm.get_phase(epoch)
            if phase:
                save_checkpoint(
                    epoch,
                    model,
                    best_top5,
                    optimizer,
                    filename=f'sz{phase["bs"]}_checkpoint.path.tar')
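# The resume branch above reads back 'state_dict', 'epoch', 'best_top5' and
# 'optimizer', which pins down the fields save_checkpoint must write. A hedged
# sketch of such a helper follows; the save_dir argument and default filename
# are assumptions, and the later variant of this script also passes the current
# data-loading phase as a leading argument.
import os
import torch


def save_checkpoint(epoch, model, best_top5, optimizer,
                    is_best=False, filename='checkpoint.pth.tar',
                    save_dir='.'):
    # Field names mirror what the resume branch loads back. is_best is accepted
    # for signature compatibility; here the caller already encodes it in the
    # filename it passes.
    state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'best_top5': best_top5,
        'optimizer': optimizer.state_dict(),
    }
    torch.save(state, os.path.join(save_dir, filename))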
def main():
    # os.system('sudo shutdown -c')  # cancel previous shutdown command

    log.console(args)
    tb.log('sizes/world', dist_utils.env_world_size())
    print(args.data)
    assert os.path.exists(args.data)

    # need to index validation directory before we start counting the time
    dataloader.sort_ar(args.data + '/val')

    if args.distributed:
        log.console('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=dist_utils.env_world_size())
        assert (dist_utils.env_world_size() == dist.get_world_size())
        # todo(y): use global_rank instead of local_rank here
        log.console("Distributed: success (%d/%d)" %
                    (args.local_rank, dist.get_world_size()))

    log.console("Loading model")
    # from mobilenetv3 import MobileNetV3
    # model = MobileNetV3(mode='small', num_classes=1000).cuda()
    if args.network == 'resnet50':
        model = resnet.resnet50(bn0=args.init_bn0).cuda()
    elif args.network == 'resnet50friendlyv1':
        model = resnet.resnet50friendly(bn0=args.init_bn0, hybrid=True).cuda()
    elif args.network == 'resnet50friendlyv2':
        model = resnet.resnet50friendly2(bn0=args.init_bn0, hybrid=True).cuda()
    elif args.network == 'resnet50friendlyv3':
        model = resnet.resnet50friendly3(bn0=args.init_bn0, hybrid=True).cuda()
    elif args.network == 'resnet50friendlyv4':
        model = resnet.resnet50friendly4(bn0=args.init_bn0, hybrid=True).cuda()
    # import resnet_friendly
    # model = resnet_friendly.ResNet50Friendly().cuda()
    # model = torchvision.models.mobilenet_v2(pretrained=False).cuda()

    if args.fp16:
        model = network_to_half(model)
    if args.distributed:
        model = dist_utils.DDP(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank)
    best_top5 = 93  # only save models over 93%. Otherwise it stops to save every time

    global model_params, master_params
    if args.fp16:
        model_params, master_params = prep_param_lists(model)
    else:
        model_params = master_params = model.parameters()

    optim_params = experimental_utils.bnwd_optim_params(
        model, model_params, master_params) if args.no_bn_wd else master_params

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        optim_params, 0, momentum=args.momentum,
        weight_decay=args.weight_decay)  # start with 0 lr. Scheduler will change this later

    if args.resume:
        checkpoint = torch.load(
            args.resume,
            map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch']
        current_phase = checkpoint['current_phase']
        best_top5 = checkpoint['best_top5']
        optimizer.load_state_dict(checkpoint['optimizer'])

    # save script so we can reproduce from logs
    shutil.copy2(os.path.realpath(__file__), f'{args.logdir}')

    log.console(
        "Creating data loaders (this could take up to 10 minutes if volume needs to be warmed up)"
    )

    # phases = util.text_unpickle(args.phases)
    lr = 0.9
    scale_224 = 224 / 512
    scale_288 = 128 / 512
    one_machine = [
        {'ep': 0, 'sz': 128, 'bs': 512, 'trndir': ''},  # Will this work? -- No idea! Should we try with mv2 baseline?
        {'ep': (0, 5), 'lr': (lr, lr * 2)},  # lr warmup is better with --init-bn0
        {'ep': 5, 'lr': lr},
        {'ep': 14, 'sz': 224, 'bs': 224, 'lr': lr * scale_224},
        {'ep': 16, 'lr': lr / 10 * scale_224},
        {'ep': 32, 'lr': lr / 100 * scale_224},
        {'ep': 37, 'lr': lr / 100 * scale_224},
        {'ep': 39, 'sz': 288, 'bs': 128, 'min_scale': 0.5, 'rect_val': True,
         'lr': lr / 100 * scale_288},
        {'ep': (40, 44), 'lr': lr / 1000 * scale_288},
        # {'ep': (36, 40), 'lr': lr / 1000 * scale_288},
        {'ep': (45, 48), 'lr': lr / 10000 * scale_288},
        {'ep': (49, 52), 'sz': 288, 'bs': 224, 'lr': lr / 10000 * scale_224}
        # {'ep': (46, 50), 'sz': 320, 'bs': 64, 'lr': lr / 10000 * scale_320}
    ]

    phases = util.text_pickle(one_machine)  # Ok? Unpickle?
    phases = util.text_unpickle(phases)

    dm = DataManager([copy.deepcopy(p) for p in phases if 'bs' in p])
    scheduler = Scheduler(optimizer,
                          [copy.deepcopy(p) for p in phases if 'lr' in p])

    start_time = datetime.now()  # Loading start to after everything is loaded
    if args.evaluate:
        return validate(dm.val_dl, model, criterion, 0, start_time)

    if args.distributed:
        log.console('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    log.event("~~epoch\thours\ttop1\ttop5\n")
    for epoch in range(args.start_epoch, scheduler.tot_epochs):
        print(" The start epoch:", args.start_epoch)
        dm.set_epoch(epoch)

        train(dm.trn_dl, model, criterion, optimizer, scheduler, epoch)
        top1, top5 = validate(dm.val_dl, model, criterion, epoch, start_time)

        time_diff = (datetime.now() - start_time).total_seconds() / 3600.0
        log.event(f'~~{epoch}\t{time_diff:.5f}\t\t{top1:.3f}\t\t{top5:.3f}\n')

        is_best = top5 > best_top5
        best_top5 = max(top5, best_top5)
        phase_save = dm.get_phase(epoch)
        if args.local_rank == 0:
            if is_best:
                save_checkpoint(phase_save, epoch, model, best_top5, optimizer,
                                is_best=True,
                                filename='model_best_' + args.network +
                                args.name + '.pth.tar')
            else:
                save_checkpoint(phase_save, epoch, model, top5, optimizer,
                                is_best=False,
                                filename='model_epoch_latest_' + args.network +
                                args.name + '.pth.tar')
            phase = dm.get_phase(epoch)
            if phase:
                save_checkpoint(phase_save, epoch, model, best_top5, optimizer,
                                filename=f'sz{phase["bs"]}_checkpoint.path.tar')
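# The phase list above mixes data phases ('sz', 'bs') and lr phases ('lr'); the
# Scheduler consumes the lr phases and, for an 'ep': (start, end) range paired
# with an (lr_start, lr_end) tuple, presumably interpolates between the two
# endpoints -- that is what makes the warmup entry work. The sketch below is an
# illustration of that interpolation only, not the repo's Scheduler.
def lr_at(phases, epoch, epoch_frac=0.0):
    """Return the learning rate implied by a phase list at a fractional epoch."""
    t = epoch + epoch_frac
    current = None
    for p in phases:
        if 'lr' not in p:
            continue
        ep = p['ep']
        lo, hi = ep if isinstance(ep, (tuple, list)) else (ep, ep)
        if lo <= t:
            current = (lo, hi, p['lr'])
    if current is None:
        raise ValueError('no lr phase covers epoch %s' % t)
    lo, hi, lr = current
    if isinstance(lr, (tuple, list)):
        if hi <= lo:
            return lr[-1]
        frac = min(max((t - lo) / (hi - lo), 0.0), 1.0)
        return lr[0] + frac * (lr[1] - lr[0])
    return lr


# Example: with lr = 0.9 and the warmup phase {'ep': (0, 5), 'lr': (lr, 2 * lr)},
# lr_at(one_machine, 2, 0.5) returns 0.9 + 0.5 * 0.9 = 1.35.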