def main():
    train_transform = trn.Compose([
        trn.RandomHorizontalFlip(),
        trn.RandomCrop(32, padding=4),
        trn.ToTensor(),
        trn.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    test_transform = trn.Compose([
        trn.ToTensor(),
        trn.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    if args.dataset == 'cifar10':
        print("Using CIFAR 10")
        train_data_in = dset.CIFAR10('/data/sauravkadavath/cifar10-dataset', train=True, transform=train_transform)
        test_data = dset.CIFAR10('/data/sauravkadavath/cifar10-dataset', train=False, transform=test_transform)
        num_classes = 10
    else:
        print("Using CIFAR100")
        train_data_in = dset.CIFAR100('/data/sauravkadavath/cifar10-dataset', train=True, transform=train_transform)
        test_data = dset.CIFAR100('/data/sauravkadavath/cifar10-dataset', train=False, transform=test_transform)
        num_classes = 100

    train_loader_in = torch.utils.data.DataLoader(
        train_data_in,
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.prefetch, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.prefetch, pin_memory=True)

    net = WideResNet(args.layers, num_classes, args.widen_factor, dropRate=args.droprate)
    net.cuda()

    optimizer = torch.optim.SGD(
        net.parameters(), state['learning_rate'],
        momentum=state['momentum'], weight_decay=state['decay'], nesterov=True)

    def cosine_annealing(step, total_steps, lr_max, lr_min):
        return lr_min + (lr_max - lr_min) * 0.5 * (
            1 + np.cos(step / total_steps * np.pi))

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: cosine_annealing(
            step,
            args.epochs * len(train_loader_in),
            1,  # since lr_lambda computes multiplicative factor
            1e-6 / args.learning_rate))

    # Make save directory
    if not os.path.exists(args.save):
        os.makedirs(args.save)
    if not os.path.isdir(args.save):
        raise Exception('%s is not a dir' % args.save)

    print('Beginning Training\n')

    # Create (truncate) an empty training log file
    with open(os.path.join(args.save, "training_log.csv"), 'w') as f:
        pass

    # Main loop
    for epoch in range(0, args.epochs):
        state['epoch'] = epoch

        begin_epoch = time.time()

        train(net, state, train_loader_in, optimizer, lr_scheduler)
        test(net, state, test_loader)

        # Save model
        torch.save(
            net.state_dict(),
            os.path.join(
                args.save,
                '{0}_{1}_layers_{2}_widenfactor_{3}_transform_epoch_{4}.pt'.format(
                    args.dataset, args.model, str(args.layers), str(args.widen_factor), str(epoch))))

        # Let us not waste space and delete the previous model
        prev_path = os.path.join(
            args.save,
            '{0}_{1}_layers_{2}_widenfactor_{3}_transform_epoch_{4}.pt'.format(
                args.dataset, args.model, str(args.layers), str(args.widen_factor), str(epoch - 1)))
        if os.path.exists(prev_path):
            os.remove(prev_path)

        # Show results
        print(
            'Epoch {0:3d} | Time {1:5d} | Train Loss {2:.4f} | Test Loss {3:.3f} | Test Error {4:.2f}'
            .format((epoch + 1),
                    int(time.time() - begin_epoch),
                    state['train_loss'],
                    state['test_loss'],
                    100 - 100. * state['test_accuracy']))
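# Note on the schedule above: LambdaLR multiplies the optimizer's base learning rate by the
# factor returned for each step, so with lr_max=1 and lr_min=1e-6/args.learning_rate the
# effective learning rate anneals from `learning_rate` down to roughly 1e-6 over training.
# A small self-contained check of that behaviour; the base LR of 0.1 and 1000 total steps
# below are example values, not values taken from the script.
import numpy as np

def _cosine_annealing(step, total_steps, lr_max, lr_min):
    return lr_min + (lr_max - lr_min) * 0.5 * (1 + np.cos(step / total_steps * np.pi))

_base_lr, _total_steps = 0.1, 1000
_factor_first = _cosine_annealing(0, _total_steps, 1, 1e-6 / _base_lr)
_factor_last = _cosine_annealing(_total_steps, _total_steps, 1, 1e-6 / _base_lr)
print(_base_lr * _factor_first)  # ~0.1   (starts at the base learning rate)
print(_base_lr * _factor_last)   # ~1e-06 (anneals down to the floor)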
    print('Beginning Training\n')

    # Main loop
    best_test_accuracy = 0
    for epoch in range(start_epoch, opt.epochs + 1):
        state['epoch'] = epoch

        begin_epoch = time.time()

        train()
        test()

        # Save model
        if epoch > 10 and epoch % 10 == 0:
            torch.save(
                net.state_dict(),
                os.path.join(
                    opt.save,
                    opt.dataset + opt.model + '_epoch_' + str(epoch) + '.pt'))

        if state['test_accuracy'] > best_test_accuracy:
            best_test_accuracy = state['test_accuracy']
            torch.save(
                net.state_dict(),
                os.path.join(opt.save, opt.dataset + opt.model + '_epoch_best.pt'))

        # Show results: append this epoch's numbers to the CSV log
        with open(
                os.path.join(
                    opt.save,
                    "log_" + opt.dataset + opt.model + '_training_results.csv'),
                'a') as f:
            f.write('%03d,%05d,%0.6f,%0.5f,%0.2f\n' % (
                (epoch + 1),
                time.time() - begin_epoch,
                state['train_loss'],
                state['test_loss'],
                100 - 100. * state['test_accuracy'],
            ))
def main(index, args):
    if xm.is_master_ordinal():
        print(state)

    # Acquires the (unique) Cloud TPU core corresponding to this process's index
    xla_device = xm.xla_device()

    if args.dataset == 'cifar10':
        train_data = dset.CIFAR10('~/cifarpy/', train=True, download=True, transform=train_transform)
        test_data = dset.CIFAR10('~/cifarpy/', train=False, download=True, transform=test_transform)
        num_classes = 10
    else:
        train_data = dset.CIFAR100('~/cifarpy/', train=True, download=True, transform=train_transform)
        test_data = dset.CIFAR100('~/cifarpy/', train=False, download=True, transform=test_transform)
        num_classes = 100

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_data,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    test_sampler = torch.utils.data.distributed.DistributedSampler(
        test_data,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False)

    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        num_workers=args.prefetch,
        drop_last=True,
        sampler=train_sampler)
    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=args.test_bs,
        num_workers=args.prefetch,
        drop_last=True,
        sampler=test_sampler)

    # Create model
    if args.model == 'wrn':
        net = WideResNet(args.layers, num_classes, args.widen_factor, dropRate=args.droprate).train().to(xla_device)
    else:
        raise NotImplementedError()

    start_epoch = 0

    optimizer = torch.optim.SGD(
        net.parameters(), args.learning_rate,
        momentum=args.momentum, weight_decay=args.decay, nesterov=True)

    def cosine_annealing(step, total_steps, lr_max, lr_min):
        return lr_min + (lr_max - lr_min) * 0.5 * (
            1 + np.cos(step / total_steps * np.pi))

    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: cosine_annealing(
            step,
            args.epochs * len(train_loader),
            1,  # since lr_lambda computes multiplicative factor
            1e-6 / args.learning_rate))

    print('Beginning Training')

    # Main loop
    for epoch in range(start_epoch, args.epochs):
        state['epoch'] = epoch

        begin_epoch = time.time()

        # Each of the spawned processes (one per TPU core) trains on its own data shard
        train_loss = train(train_loader, net, optimizer, scheduler, xla_device, args)

        # Calculate test loss
        test_results = test(test_loader, net, xla_device, args)

        # Save model. Does sync between all processes
        xm.save(
            net.state_dict(),
            os.path.join(
                args.save,
                args.dataset + args.model + '_baseline_epoch_' + str(epoch) + '.pt'))

        # Record stuff: exchange per-process results at a rendezvous point
        all_train_losses = xm.rendezvous("calc_train_loss", payload=str(train_loss))
        all_test_results = xm.rendezvous("calc_test_results", payload=str(test_results))
        all_test_results = parse_test_results(all_test_results)

        if xm.is_master_ordinal():
            all_train_losses = [float(L) for L in all_train_losses]
            train_loss = sum(all_train_losses) / float(len(all_train_losses))
            state['train_loss'] = train_loss

            test_loss = sum([r[0] for r in all_test_results]) / sum(
                [r[2] for r in all_test_results])
            test_acc = sum([r[1] for r in all_test_results]) / sum(
                [r[2] for r in all_test_results])
            state['test_loss'] = test_loss
            state['test_accuracy'] = test_acc

            # Let us not waste space and delete the previous model
            prev_path = os.path.join(
                args.save,
                args.dataset + args.model + '_baseline_epoch_' + str(epoch - 1) + '.pt')
            if os.path.exists(prev_path):
                os.remove(prev_path)

            # Show results
            with open(
                    os.path.join(
                        args.save,
                        args.dataset + args.model + '_baseline_training_results.csv'),
                    'a') as f:
                f.write('%03d,%05d,%0.6f,%0.5f,%0.2f\n' % (
                    (epoch + 1),
                    time.time() - begin_epoch,
                    state['train_loss'],
                    state['test_loss'],
                    100 - 100. * state['test_accuracy'],
                ))

            print(
                'Epoch {0:3d} | Time {1:5d} | Train Loss {2:.4f} | Test Loss {3:.3f} | Test Error {4:.2f}'
                .format((epoch + 1),
                        int(time.time() - begin_epoch),
                        state['train_loss'],
                        state['test_loss'],
                        100 - 100. * state['test_accuracy']))

            writer.add_scalar("test_loss", state["test_loss"], epoch + 1)
            writer.add_scalar("test_accuracy", state["test_accuracy"], epoch + 1)

        # Wait for master to finish Disk I/O above
        print("Finished with one epoch")
        xm.rendezvous("epoch_finish")
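# A minimal sketch of how `main(index, args)` could be launched, one process per TPU core,
# using torch_xla's multiprocessing helper. The original launcher is not shown in this
# listing; `nprocs=8` (a v3-8 slice) and the module-level argparse `parser` are assumptions.
import torch_xla.distributed.xla_multiprocessing as xmp

if __name__ == '__main__':
    args = parser.parse_args()  # assumed argparse parser defined elsewhere in the script
    # xmp.spawn runs main(index, args) in each child process,
    # passing the per-process index as the first argument.
    xmp.spawn(main, args=(args,), nprocs=8)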