class LRFinder():
    """Learning-rate range test.

    Sweeps the learning rate from ``low_lr`` to ``high_lr`` over
    ``num_iter`` batches while recording an exponentially smoothed loss,
    so a good learning rate can be read off the loss-vs-lr curve.
    Training stops early once the smoothed loss diverges (exceeds 4x the
    best seen value) or becomes NaN.
    """

    def __init__(self, optim, num_iter, low_lr=1e-7, high_lr=10):
        """
        Args:
            optim: the optimizer whose learning rate is swept.
            num_iter: total number of batches the sweep spans.
            low_lr: starting learning rate of the sweep.
            high_lr: final learning rate of the sweep.
        """
        self.optim = optim
        self.num_iter = num_iter
        # Multiplicative lr factor for batch index `batch`, interpolating
        # low_lr -> high_lr over num_iter steps (see get_running_factor).
        self.lambda_func = lambda batch: get_running_factor(
            low_lr, high_lr, self.num_iter, batch)
        self.scheduler = LambdaLR(optim, lr_lambda=self.lambda_func)
        self.learning_rates = []
        self.losses = []

    def find(self, model, input, loss_fn):
        """Run the sweep.

        Args:
            model: the model to train during the sweep.
            input: iterable of (images, masks) batches.  (Parameter name
                kept for interface compatibility even though it shadows
                the builtin.)
            loss_fn: callable mapping (prediction, target) to a scalar loss.

        Side effects: populates ``self.learning_rates`` / ``self.losses``
        as flat numpy arrays.
        """
        smooth_loss = None
        best_loss = np.inf  # np.Inf was removed in NumPy 2.0; use np.inf
        stop_training = False
        num_epochs = int(np.floor(self.num_iter / len(input)))
        for epoch in range(num_epochs):
            if stop_training:
                break
            model.train()
            self.optim.zero_grad()
            for i, batch in enumerate(input):
                # get_last_lr() is the supported accessor; get_lr() is
                # deprecated outside of scheduler.step().
                print("Batch {} has learning rate {}".format(
                    i, self.scheduler.get_last_lr()))
                images, masks = batch
                # Call the module (runs hooks) instead of model.forward();
                # Variable() wrappers are no-ops since PyTorch 0.4 and
                # have been dropped.
                pred = model(images)
                loss = loss_fn(pred, masks)
                if smooth_loss is not None:
                    smooth_loss = ewma(smooth_loss, loss.item())
                else:
                    smooth_loss = loss.item()
                if smooth_loss < best_loss:
                    best_loss = smooth_loss
                # Stop once the smoothed loss has clearly diverged.
                if smooth_loss > 4 * best_loss or np.isnan(smooth_loss):
                    stop_training = True
                    break
                self.learning_rates.append(self.scheduler.get_last_lr())
                self.losses.append(loss.item())
                loss.backward()
                self.optim.step()
                self.scheduler.step()
                self.optim.zero_grad()
        self.learning_rates = np.array(self.learning_rates).flatten()
        self.losses = np.array(self.losses)

    def get_learning_rates(self):
        """Return the recorded learning rates (flat numpy array after find())."""
        return self.learning_rates

    def get_losses(self):
        """Return the recorded raw losses (numpy array after find())."""
        return self.losses
def train(classifier, train_loader, test_loader, args):
    """Train ``classifier`` with Nesterov SGD and a LambdaLR decay schedule.

    Runs ``args.epochs`` epochs; after each epoch evaluates on
    ``test_loader`` and, whenever the training loss improves, saves the
    model weights to ``resnet18_wd{weight_decay}.pth``.
    """
    optimizer = torch.optim.SGD(classifier.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)
    best_train_loss = np.inf
    # lr_lambda computes multiplicative factor
    scheduler = LambdaLR(
        optimizer,
        lr_lambda=lambda step: get_lr(  # pylint: disable=g-long-lambda
            step, args.epochs * len(train_loader), 1,
            1e-6 / args.learning_rate))

    for epoch in range(1, args.epochs + 1):
        # One optimization pass over the training set; run_epoch steps
        # the optimizer/scheduler itself when they are passed in.
        train_loss, train_acc = run_epoch(classifier,
                                          train_loader,
                                          args,
                                          optimizer=optimizer,
                                          scheduler=scheduler)
        current_lr = scheduler.get_lr()[0]
        logger.info(
            'Epoch: {}, lr: {:.4f}, training loss: {:.4f}, acc: {:.4f}.'.
            format(epoch, current_lr, train_loss, train_acc))

        # Evaluation pass: no optimizer given, so run_epoch only measures.
        test_loss, test_acc = run_epoch(classifier, test_loader, args)
        logger.info("Test loss: {:.4f}, acc: {:.4f}".format(
            test_loss, test_acc))

        # Checkpoint only on a new best training loss.
        if train_loss >= best_train_loss:
            continue
        best_train_loss = train_loss
        save_name = 'resnet18_wd{}.pth'.format(args.weight_decay)
        torch.save(classifier.state_dict(), save_name)
        logger.info(
            "==> New optimal training loss & saving checkpoint ...")
class Trainer(object):
    ''' An object that encapsulates model training '''

    def __init__(self, config, model, dataloader, device):
        # NOTE(review): validation reuses the training dataloader here —
        # presumably intentional, but confirm against the caller.
        self.model = model
        self.config = config
        self.device = device
        self.stopped_early = False
        self.dataloader = dataloader
        self.validation_dataloader = dataloader
        self.last_checkpoint_time = time.time()

        if 'cuda' in device.type:
            self.model = nn.DataParallel(model.cuda())

        # Build the optimizer and its learning-rate scheduler from config.
        if self.config.optimizer == "adam":
            self.optimizer = optim.Adam(model.parameters(), config.base_lr,
                                        betas=(0.9, 0.98), eps=1e-9)
            # self.optimizer = optim.Adam(model.parameters(), 1e-7, betas=(0.9, 0.98), eps=1e-9)
            if config.lr_scheduler == 'warmup':
                self.lr_scheduler = LambdaLR(
                    self.optimizer, WarmupLRSchedule(config.warmup_steps))
            elif config.lr_scheduler == 'warmup2':
                self.lr_scheduler = LambdaLR(
                    self.optimizer, WarmupLRSchedule2(config.warmup_steps))
            elif config.lr_scheduler == 'linear':
                self.lr_scheduler = LambdaLR(
                    self.optimizer,
                    LinearLRSchedule(config.base_lr, config.final_lr,
                                     config.max_steps))
            elif config.lr_scheduler == 'exponential':
                self.lr_scheduler = ExponentialLR(self.optimizer,
                                                  config.lr_decay)
            else:
                raise ValueError('Unknown learning rate scheduler!')
        elif self.config.optimizer == "sgd":
            print("using optimizer: SGD")
            self.optimizer = optim.SGD(model.parameters(), lr=config.base_lr,
                                       momentum=0.9)
            self.lr_scheduler = LambdaLR(self.optimizer,
                                         DummyLRSchedule(config.base_lr))
        elif self.config.optimizer == "adam-fixed":
            print("using optimizer: adam with fixed learning rate")
            self.optimizer = optim.Adam(model.parameters(), config.base_lr,
                                        betas=(0.9, 0.98), eps=1e-9)
            self.lr_scheduler = LambdaLR(self.optimizer,
                                         DummyLRSchedule(config.base_lr))
        else:
            raise ValueError('Unknown optimizer!')

        # Initialize the metrics
        metrics_path = os.path.join(self.config.checkpoint_directory,
                                    'train_metrics.pt')
        self.metric_store = metrics.MetricStore(metrics_path)
        self.metric_store.add(metrics.Metric('oom', metrics.format_int, 't'))
        self.metric_store.add(
            metrics.Metric('nll', metrics.format_float, max_history=1000))
        self.metric_store.add(
            metrics.Metric('lr', metrics.format_scientific, 'g',
                           max_history=1))
        self.metric_store.add(
            metrics.Metric('num_tok', metrics.format_int, 'a',
                           max_history=1000))
        # self.metric_store.add(metrics.Metric('time_per_batch', metrics.format_float, 'g', max_history=100000))
        # self.metric_store.add(metrics.Metric('time_total', metrics.format_float, 'g', max_history=1))
        if self.config.early_stopping:
            self.metric_store.add(
                metrics.Metric('vnll', metrics.format_float, 'g'))

        # Everything whose state must be captured in a checkpoint.
        self.modules = {
            'model': model,
            'optimizer': self.optimizer,
            'lr_scheduler': self.lr_scheduler
        }

    @property
    def dataset(self):
        ''' Get the dataset '''
        return self.dataloader.dataset

    def train_epoch(self, epoch, experiment, verbose=0):
        ''' Run one training epoch '''
        oom = self.metric_store['oom']
        learning_rate = self.metric_store['lr']
        num_tokens = self.metric_store['num_tok']
        neg_log_likelihood = self.metric_store['nll']

        def try_optimize(i, last=False):
            # optimize if:
            #  1) last and remainder
            #  2) not last and not remainder
            # i.e. step every `accumulate_steps` batches, plus once at the
            # end of the epoch if a partial accumulation remains.
            remainder = bool(i % self.config.accumulate_steps)
            if not last ^ remainder:
                next_lr = self.optimize()
                learning_rate.update(next_lr)
                experiment.log_metric('learning_rate', next_lr)
                return True
            return False

        def get_description():
            # Progress-bar text; verbosity adds metrics and memory stats.
            description = f'Train #{epoch}'
            if verbose > 0:
                description += f' {self.metric_store}'
            if verbose > 1:
                description += f' [{profile.mem_stat_string(["allocated"])}]'
            return description

        batches = tqdm(
            self.dataloader,
            unit='batch',
            dynamic_ncols=True,
            desc=get_description(),
            file=sys.stdout  # needed to make tqdm_wrap_stdout work
        )

        with tqdm_wrap_stdout():
            i = 1
            nll_per_update = 0.
            length_per_update = 0
            num_tokens_per_update = 0
            for i, batch in enumerate(batches, 1):
                try:
                    nll, length = self.calculate_gradient(batch)
                    did_optimize = try_optimize(i)

                    # record the effective number of tokens
                    num_tokens_per_update += int(sum(batch['input_lens']))
                    num_tokens_per_update += int(sum(batch['target_lens']))

                    if length:
                        # record length and nll
                        nll_per_update += nll
                        length_per_update += length

                    if did_optimize:
                        # advance the experiment step
                        experiment.set_step(experiment.curr_step + 1)
                        num_tokens.update(num_tokens_per_update)
                        neg_log_likelihood.update(nll_per_update /
                                                  length_per_update)
                        experiment.log_metric('num_tokens',
                                              num_tokens_per_update)
                        experiment.log_metric('nll',
                                              neg_log_likelihood.last_value)
                        # experiment.log_metric('max_memory_alloc', torch.cuda.max_memory_allocated()//1024//1024)
                        # experiment.log_metric('max_memory_cache', torch.cuda.max_memory_cached()//1024//1024)
                        nll_per_update = 0.
                        length_per_update = 0
                        num_tokens_per_update = 0
                except RuntimeError as rte:
                    # CUDA OOM is treated as recoverable: drop the batch,
                    # free the cache, and keep counting occurrences.
                    if 'out of memory' in str(rte):
                        torch.cuda.empty_cache()
                        oom.update(1)
                        experiment.log_metric('oom', oom.total)
                        #exit(-1)
                    else:
                        batches.close()
                        raise rte

                if self.should_checkpoint():
                    new_best = False
                    if self.config.early_stopping:
                        with tqdm_unwrap_stdout():
                            new_best = self.evaluate(experiment, epoch,
                                                     verbose)
                    self.checkpoint(epoch, experiment.curr_step, new_best)

                batches.set_description_str(get_description())
                if self.is_done(experiment, epoch):
                    batches.close()
                    break

            # Flush any partially accumulated gradients at epoch end.
            try_optimize(i, last=True)

    def should_checkpoint(self):
        ''' Function which determines if a new checkpoint should be saved '''
        return time.time(
        ) - self.last_checkpoint_time > self.config.checkpoint_interval

    def checkpoint(self, epoch, step, best=False):
        ''' Save a checkpoint '''
        # NOTE(review): `checkpoint(...)` below resolves to a module-level
        # helper (methods are not in scope inside method bodies) — confirm
        # that helper exists at module level.
        checkpoint_path = checkpoint(
            epoch,
            step,
            self.modules,
            self.config.checkpoint_directory,
            max_checkpoints=self.config.max_checkpoints)

        if best:
            dirname = os.path.dirname(checkpoint_path)
            basename = os.path.basename(checkpoint_path)
            best_checkpoint_path = os.path.join(dirname, f'best_{basename}')
            shutil.copy2(checkpoint_path, best_checkpoint_path)

        self.metric_store.save()
        self.last_checkpoint_time = time.time()

    def evaluate(self, experiment, epoch, verbose=0):
        ''' Evaluate the current model and determine if it is a new best '''
        model = self.modules['model']
        evaluator = Evaluator(args.ArgGroup(None), model,
                              self.validation_dataloader, self.device)
        vnll = evaluator(epoch, experiment, verbose)
        metric = self.metric_store['vnll']
        full_history = metric.values
        metric.update(vnll)
        self.metric_store.save()
        # New best iff strictly better than every previous validation nll.
        return all(vnll < nll for nll in full_history[:-1])

    def is_done(self, experiment, epoch):
        ''' Has training completed '''
        if self.config.max_steps and experiment.curr_step >= self.config.max_steps:
            return True

        if self.config.max_epochs and epoch >= self.config.max_epochs:
            return True

        if self.config.early_stopping:
            # Stop when the latest validation nll is worse than all of the
            # previous `early_stopping` values.
            history = self.metric_store['vnll'].values[
                -self.config.early_stopping - 1:]
            if len(history) == self.config.early_stopping + 1:
                self.stopped_early = all(history[-1] > nll
                                         for nll in history[:-1])
                return self.stopped_early

        return False

    def optimize(self):
        ''' Calculate an optimization step '''
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.lr_scheduler.step()
        return self.lr_scheduler.get_lr()[0]

    def calculate_gradient(self, batch):
        ''' Runs one step of optimization '''
        # run the data through the model
        self.model.train()
        loss, nll = self.model(batch)

        # nn.DataParallel wants to gather rather than doing a reduce_add,
        # so the output here will be a tensor of values that must be summed
        nll = nll.sum()
        loss = loss.sum()

        # calculate gradients then run an optimization step
        loss.backward()

        # need to use .item() which converts to Python scalar
        # because as a Tensor it accumulates gradients
        return nll.item(), torch.sum(batch['target_lens']).item()

    def __call__(self, start_epoch, experiment, verbose=0):
        ''' Execute training '''
        with ExitStack() as stack:
            stack.enter_context(chunked_scattering())
            stack.enter_context(experiment.train())

            if start_epoch > 0 or experiment.curr_step > 0:
                # TODO: Hacky approach to decide if the metric store should be loaded. Revisit later
                self.metric_store = self.metric_store.load()

            epoch = start_epoch
            experiment.log_current_epoch(epoch)
            while not self.is_done(experiment, epoch):
                experiment.log_current_epoch(epoch)
                self.train_epoch(epoch, experiment, verbose)
                experiment.log_epoch_end(epoch)
                epoch += 1

            if self.stopped_early:
                print('Stopping early!')
            else:
                # Final evaluation + checkpoint after normal completion.
                new_best = False
                if self.config.early_stopping:
                    new_best = self.evaluate(experiment, epoch, verbose)
                self.checkpoint(epoch, experiment.curr_step, new_best)
def main(args: argparse.Namespace):
    """Entry point: domain-adaptation regression training with MDD.

    Builds source/target loaders, an ImageRegressor (optionally with
    instance-norm bottleneck/heads), then trains, or — depending on
    ``args.phase`` — runs t-SNE/A-distance analysis or a test pass.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    # NOTE(review): train and val transforms are identical (no augmentation).
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    train_transform = T.Compose(
        [T.Resize(args.resize_size),
         T.ToTensor(), normalize])
    val_transform = T.Compose(
        [T.Resize(args.resize_size),
         T.ToTensor(), normalize])

    dataset = datasets.__dict__[args.data]
    train_source_dataset = dataset(root=args.root,
                                   task=args.source,
                                   split='train',
                                   download=True,
                                   transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    train_target_dataset = dataset(root=args.root,
                                   task=args.target,
                                   split='train',
                                   download=True,
                                   transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    val_dataset = dataset(root=args.root,
                          task=args.target,
                          split='test',
                          download=True,
                          transform=val_transform)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers)

    # Endless iterators so source/target batches can be drawn in lockstep.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    num_factors = train_source_dataset.num_factors
    backbone = models.__dict__[args.arch](pretrained=True)
    bottleneck_dim = args.bottleneck_dim
    if args.normalization == 'IN':
        # Instance-norm variant: custom conv bottleneck + two parallel
        # regression heads (main and adversarial) with identical structure.
        backbone = convert_model(backbone)
        bottleneck = nn.Sequential(
            nn.Conv2d(backbone.out_features,
                      bottleneck_dim,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.InstanceNorm2d(bottleneck_dim),
            nn.ReLU(),
        )
        head = nn.Sequential(
            nn.Conv2d(bottleneck_dim,
                      bottleneck_dim,
                      kernel_size=3,
                      stride=1,
                      padding=1), nn.InstanceNorm2d(bottleneck_dim),
            nn.ReLU(),
            nn.Conv2d(bottleneck_dim,
                      bottleneck_dim,
                      kernel_size=3,
                      stride=1,
                      padding=1), nn.InstanceNorm2d(bottleneck_dim),
            nn.ReLU(), nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(), nn.Linear(bottleneck_dim, num_factors),
            nn.Sigmoid())
        for layer in head:
            if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.constant_(layer.bias, 0)
        adv_head = nn.Sequential(
            nn.Conv2d(bottleneck_dim,
                      bottleneck_dim,
                      kernel_size=3,
                      stride=1,
                      padding=1), nn.InstanceNorm2d(bottleneck_dim),
            nn.ReLU(),
            nn.Conv2d(bottleneck_dim,
                      bottleneck_dim,
                      kernel_size=3,
                      stride=1,
                      padding=1), nn.InstanceNorm2d(bottleneck_dim),
            nn.ReLU(), nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(), nn.Linear(bottleneck_dim, num_factors),
            nn.Sigmoid())
        for layer in adv_head:
            if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.constant_(layer.bias, 0)
        regressor = ImageRegressor(backbone,
                                   num_factors,
                                   bottleneck=bottleneck,
                                   head=head,
                                   adv_head=adv_head,
                                   bottleneck_dim=bottleneck_dim,
                                   width=bottleneck_dim)
    else:
        regressor = ImageRegressor(backbone,
                                   num_factors,
                                   bottleneck_dim=bottleneck_dim,
                                   width=bottleneck_dim)
    regressor = regressor.to(device)
    print(regressor)
    mdd = MarginDisparityDiscrepancy(args.margin).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(regressor.get_parameters(),
                    args.lr,
                    momentum=args.momentum,
                    weight_decay=args.wd,
                    nesterov=True)
    # Inverse-decay schedule: lr * (1 + gamma * step)^(-decay).
    lr_scheduler = LambdaLR(
        optimizer, lambda x: args.lr *
        (1. + args.lr_gamma * float(x))**(-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                map_location='cpu')
        regressor.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        train_source_loader = DataLoader(train_source_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.workers,
                                         drop_last=True)
        train_target_loader = DataLoader(train_target_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.workers,
                                         drop_last=True)
        # extract features from both domains
        feature_extractor = nn.Sequential(regressor.backbone,
                                          regressor.bottleneck,
                                          regressor.head[:-2]).to(device)
        source_feature = collect_feature(train_source_loader,
                                         feature_extractor, device)
        target_feature = collect_feature(train_target_loader,
                                         feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature,
                                          device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        mae = validate(val_loader, regressor, args,
                       train_source_dataset.factors, device)
        print(mae)
        return

    # start training
    best_mae = 100000.
    for epoch in range(args.epochs):
        # train for one epoch
        print("lr", lr_scheduler.get_lr())
        train(train_source_iter, train_target_iter, regressor, mdd, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        mae = validate(val_loader, regressor, args,
                       train_source_dataset.factors, device)

        # remember best mae and save checkpoint
        torch.save(regressor.state_dict(),
                   logger.get_checkpoint_path('latest'))
        if mae < best_mae:
            shutil.copy(logger.get_checkpoint_path('latest'),
                        logger.get_checkpoint_path('best'))
        best_mae = min(mae, best_mae)
        print("mean MAE {:6.3f} best MAE {:6.3f}".format(mae, best_mae))

    print("best_mae = {:6.3f}".format(best_mae))
    logger.close()
def main(args: argparse.Namespace):
    """Entry point: adversarial domain-adaptation classification training.

    Builds source/target loaders, an ImageClassifier plus a domain
    discriminator, then trains with a domain-adversarial loss, or —
    depending on ``args.phase`` — runs t-SNE/A-distance analysis or a
    test pass.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(
        args.train_resizing,
        random_horizontal_flip=not args.no_hflip,
        random_color_jitter=False,
        resize_size=args.resize_size,
        norm_mean=args.norm_mean,
        norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing,
                                            resize_size=args.resize_size,
                                            norm_mean=args.norm_mean,
                                            norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.workers)

    # Endless iterators so source/target batches can be drawn in lockstep.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone,
                                 num_classes,
                                 bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer,
                                 finetune=not args.scratch).to(device)
    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim,
                                        hidden_size=1024).to(device)

    # define loss function
    domain_adv = DomainAdversarialLoss().to(device)
    # Gradient layer that warms the adversarial signal up over max_iters.
    gl = WarmStartGradientLayer(alpha=1.,
                                lo=0.,
                                hi=1.,
                                max_iters=1000,
                                auto_step=True)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(),
                    args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay,
                    nesterov=True)
    optimizer_d = SGD(domain_discri.get_parameters(),
                      args.lr_d,
                      momentum=args.momentum,
                      weight_decay=args.weight_decay,
                      nesterov=True)
    # Inverse-decay schedule: lr * (1 + gamma * step)^(-decay).
    lr_scheduler = LambdaLR(
        optimizer, lambda x: args.lr *
        (1. + args.lr_gamma * float(x))**(-args.lr_decay))
    lr_scheduler_d = LambdaLR(
        optimizer_d, lambda x: args.lr_d *
        (1. + args.lr_gamma * float(x))**(-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone,
                                          classifier.pool_layer,
                                          classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader,
                                         feature_extractor, device)
        target_feature = collect_feature(train_target_loader,
                                         feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature,
                                          device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        print("lr classifier:", lr_scheduler.get_lr())
        print("lr discriminator:", lr_scheduler_d.get_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, domain_discri,
              domain_adv, gl, optimizer, lr_scheduler, optimizer_d,
              lr_scheduler_d, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(),
                   logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'),
                        logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
def main(args: argparse.Namespace):
    """Entry point: regression-disparity keypoint domain adaptation.

    Builds augmented source/target pose datasets, a RegDAPoseResNet with
    a main and an adversarial head, optionally pretrains (or resumes)
    the model, then alternates adaptation training and validation.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_transform = T.Compose([
        T.RandomRotation(args.rotation),
        T.RandomResizedCrop(size=args.image_size, scale=args.resize_scale),
        T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25),
        T.GaussianBlur(),
        T.ToTensor(), normalize
    ])
    val_transform = T.Compose(
        [T.Resize(args.image_size),
         T.ToTensor(), normalize])
    image_size = (args.image_size, args.image_size)
    heatmap_size = (args.heatmap_size, args.heatmap_size)

    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(root=args.source_root,
                                          transforms=train_transform,
                                          image_size=image_size,
                                          heatmap_size=heatmap_size)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_source_dataset = source_dataset(root=args.source_root,
                                        split='test',
                                        transforms=val_transform,
                                        image_size=image_size,
                                        heatmap_size=heatmap_size)
    val_source_loader = DataLoader(val_source_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   pin_memory=True)
    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(root=args.target_root,
                                          transforms=train_transform,
                                          image_size=image_size,
                                          heatmap_size=heatmap_size)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_target_dataset = target_dataset(root=args.target_root,
                                        split='test',
                                        transforms=val_transform,
                                        image_size=image_size,
                                        heatmap_size=heatmap_size)
    val_target_loader = DataLoader(val_target_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   pin_memory=True)
    print("Source train:", len(train_source_loader))
    print("Target train:", len(train_target_loader))
    print("Source test:", len(val_source_loader))
    print("Target test:", len(val_target_loader))

    # Endless iterators so source/target batches can be drawn in lockstep.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    backbone = models.__dict__[args.arch](pretrained=True)
    upsampling = Upsampling(backbone.out_features)
    num_keypoints = train_source_dataset.num_keypoints
    model = RegDAPoseResNet(backbone,
                            upsampling,
                            256,
                            num_keypoints,
                            num_head_layers=args.num_head_layers,
                            finetune=True).to(device)
    # define loss function
    criterion = JointsKLLoss()
    pseudo_label_generator = PseudoLabelGenerator(num_keypoints,
                                                  args.heatmap_size,
                                                  args.heatmap_size)
    regression_disparity = RegressionDisparity(pseudo_label_generator,
                                               JointsKLLoss(epsilon=1e-7))

    # define optimizer and lr scheduler
    # Separate optimizers: feature extractor (f), head (h), adversarial
    # head (h_adv), each with its own copy of the same decay schedule.
    optimizer_f = SGD([
        {
            'params': backbone.parameters(),
            'lr': 0.1
        },
        {
            'params': upsampling.parameters(),
            'lr': 0.1
        },
    ],
                      lr=0.1,
                      momentum=args.momentum,
                      weight_decay=args.wd,
                      nesterov=True)
    optimizer_h = SGD(model.head.parameters(),
                      lr=1.,
                      momentum=args.momentum,
                      weight_decay=args.wd,
                      nesterov=True)
    optimizer_h_adv = SGD(model.head_adv.parameters(),
                          lr=1.,
                          momentum=args.momentum,
                          weight_decay=args.wd,
                          nesterov=True)
    # Inverse-decay schedule: lr * (1 + gamma * step)^(-decay).
    lr_decay_function = lambda x: args.lr * (1. + args.lr_gamma * float(x))**(
        -args.lr_decay)
    lr_scheduler_f = LambdaLR(optimizer_f, lr_decay_function)
    lr_scheduler_h = LambdaLR(optimizer_h, lr_decay_function)
    lr_scheduler_h_adv = LambdaLR(optimizer_h_adv, lr_decay_function)
    start_epoch = 0

    if args.resume is None:
        if args.pretrain is None:
            # first pretrain the backbone and upsampling
            print("Pretraining the model on source domain.")
            args.pretrain = logger.get_checkpoint_path('pretrain')
            pretrained_model = PoseResNet(backbone, upsampling, 256,
                                          num_keypoints, True).to(device)
            optimizer = SGD(pretrained_model.get_parameters(lr=args.lr),
                            momentum=args.momentum,
                            weight_decay=args.wd,
                            nesterov=True)
            lr_scheduler = MultiStepLR(optimizer, args.lr_step,
                                       args.lr_factor)
            best_acc = 0
            for epoch in range(args.pretrain_epochs):
                lr_scheduler.step()
                print(lr_scheduler.get_lr())
                pretrain(train_source_iter, pretrained_model, criterion,
                         optimizer, epoch, args)
                source_val_acc = validate(val_source_loader, pretrained_model,
                                          criterion, None, args)
                # remember best acc and save checkpoint
                if source_val_acc['all'] > best_acc:
                    best_acc = source_val_acc['all']
                    torch.save({'model': pretrained_model.state_dict()},
                               args.pretrain)
                print("Source: {} best: {}".format(source_val_acc['all'],
                                                   best_acc))
        # load from the pretrained checkpoint
        pretrained_dict = torch.load(args.pretrain,
                                     map_location='cpu')['model']
        model_dict = model.state_dict()
        # remove keys from pretrained dict that doesn't appear in model dict
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model.load_state_dict(pretrained_dict, strict=False)
    else:
        # optionally resume from a checkpoint
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer_f.load_state_dict(checkpoint['optimizer_f'])
        optimizer_h.load_state_dict(checkpoint['optimizer_h'])
        optimizer_h_adv.load_state_dict(checkpoint['optimizer_h_adv'])
        lr_scheduler_f.load_state_dict(checkpoint['lr_scheduler_f'])
        lr_scheduler_h.load_state_dict(checkpoint['lr_scheduler_h'])
        lr_scheduler_h_adv.load_state_dict(checkpoint['lr_scheduler_h_adv'])
        start_epoch = checkpoint['epoch'] + 1

    # define visualization function
    tensor_to_image = Compose([
        Denormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ToPILImage()
    ])

    def visualize(image, keypoint2d, name, heatmaps=None):
        """
        Args:
            image (tensor): image in shape 3 x H x W
            keypoint2d (tensor): keypoints in shape K x 2
            name: name of the saving image
        """
        train_source_dataset.visualize(
            tensor_to_image(image), keypoint2d,
            logger.get_image_path("{}.jpg".format(name)))

    if args.phase == 'test':
        # evaluate on validation set
        source_val_acc = validate(val_source_loader, model, criterion, None,
                                  args)
        target_val_acc = validate(val_target_loader, model, criterion,
                                  visualize, args)
        print("Source: {:4.3f} Target: {:4.3f}".format(
            source_val_acc['all'], target_val_acc['all']))
        for name, acc in target_val_acc.items():
            print("{}: {:4.3f}".format(name, acc))
        return

    # start training
    best_acc = 0
    print("Start regression domain adaptation.")
    for epoch in range(start_epoch, args.epochs):
        logger.set_epoch(epoch)
        print(lr_scheduler_f.get_lr(), lr_scheduler_h.get_lr(),
              lr_scheduler_h_adv.get_lr())

        # train for one epoch
        train(train_source_iter, train_target_iter, model, criterion,
              regression_disparity, optimizer_f, optimizer_h, optimizer_h_adv,
              lr_scheduler_f, lr_scheduler_h, lr_scheduler_h_adv, epoch,
              visualize if args.debug else None, args)

        # evaluate on validation set
        source_val_acc = validate(val_source_loader, model, criterion, None,
                                  args)
        target_val_acc = validate(val_target_loader, model, criterion,
                                  visualize if args.debug else None, args)

        # remember best acc and save checkpoint
        torch.save(
            {
                'model': model.state_dict(),
                'optimizer_f': optimizer_f.state_dict(),
                'optimizer_h': optimizer_h.state_dict(),
                'optimizer_h_adv': optimizer_h_adv.state_dict(),
                'lr_scheduler_f': lr_scheduler_f.state_dict(),
                'lr_scheduler_h': lr_scheduler_h.state_dict(),
                'lr_scheduler_h_adv': lr_scheduler_h_adv.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch))
        if target_val_acc['all'] > best_acc:
            shutil.copy(logger.get_checkpoint_path(epoch),
                        logger.get_checkpoint_path('best'))
            best_acc = target_val_acc['all']
        print("Source: {:4.3f} Target: {:4.3f} Target(best): {:4.3f}".format(
            source_val_acc['all'], target_val_acc['all'], best_acc))
        for name, acc in target_val_acc.items():
            print("{}: {:4.3f}".format(name, acc))

    logger.close()
def main(args: argparse.Namespace):
    """Entry point: MDD domain adaptation on digit-style datasets.

    Builds grayscale or RGB loaders depending on ``args.num_channels``,
    assembles a GeneralModule classifier with main/adversarial heads,
    then trains with the margin disparity discrepancy, or — depending on
    ``args.phase`` — runs t-SNE/A-distance analysis or a test pass.
    """
    logger = CompleteLogger(args.log, args.phase)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    if args.num_channels == 3:
        mode = 'RGB'
        mean = std = [0.5, 0.5, 0.5]
    else:
        mode = 'L'
        mean = std = [
            0.5,
        ]
    normalize = T.Normalize(mean=mean, std=std)
    train_transform = T.Compose([
        ResizeImage(args.image_size),
        # T.RandomRotation(10),  # TODO need results
        T.ToTensor(),
        normalize
    ])
    val_transform = T.Compose(
        [ResizeImage(args.image_size),
         T.ToTensor(), normalize])

    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(root=args.source_root,
                                          mode=mode,
                                          download=True,
                                          transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(root=args.target_root,
                                          mode=mode,
                                          download=True,
                                          transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    val_dataset = target_dataset(root=args.target_root,
                                 mode=mode,
                                 split='test',
                                 download=True,
                                 transform=val_transform)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers)

    # Endless iterators so source/target batches can be drawn in lockstep.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    arch = models.__dict__[args.arch]()
    bottleneck = nn.Sequential(
        nn.Flatten(), nn.Linear(arch.bottleneck_dim, arch.bottleneck_dim),
        nn.BatchNorm1d(arch.bottleneck_dim), nn.ReLU(), nn.Dropout(0.5))
    head = arch.head()
    adv_head = arch.head()
    # NOTE(review): classifier is not moved with .to(device) here —
    # presumably handled elsewhere; confirm.
    classifier = GeneralModule(arch.backbone(),
                               arch.num_classes,
                               bottleneck,
                               head,
                               adv_head,
                               finetune=False)
    mdd = MarginDisparityDiscrepancy(args.margin).to(device)

    # define optimizer and lr scheduler
    optimizer = Adam(classifier.get_parameters(),
                     args.lr,
                     betas=args.betas,
                     weight_decay=args.wd)
    # Inverse-decay schedule: lr * (1 + gamma * step)^(-decay).
    lr_scheduler = LambdaLR(
        optimizer, lambda x: args.lr *
        (1. + args.lr_gamma * float(x))**(-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = torch.nn.Sequential(
            classifier.backbone, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader,
                                         feature_extractor, device, 10)
        target_feature = collect_feature(val_loader, feature_extractor,
                                         device, 10)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature,
                                          device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = validate(val_loader, classifier, args)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, mdd,
              optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, classifier, args)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(),
                   logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'),
                        logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))
    logger.close()
def train(model, optimizer, train_data, val_data, params,
          metric=accuracy_score, criterion=None,
          variable_created_by_model=True):
    """Train `model` for params["epochs"] epochs with a step-decayed LR.

    The learning rate is halved every params["lr_ep_step"] epochs via a
    LambdaLR scheduler.  After each epoch the running loss/metric curves
    are re-plotted (notebook use: clear_output + matplotlib) and printed.

    Args:
        model: network to train.
        optimizer: torch optimizer wrapping the model's parameters.
        train_data, val_data: data iterables consumed by
            train_one_epoch / validate (project helpers).
        params: dict with at least "epochs" and "lr_ep_step".
        metric: callable(targets, preds) -> float; default accuracy_score.
        criterion: loss module; None (default) means a fresh
            nn.CrossEntropyLoss().  Using a None sentinel avoids sharing
            one module instance across calls (mutable-default pitfall);
            callers that passed a criterion explicitly are unaffected.
        variable_created_by_model: forwarded to train_one_epoch / validate.

    Returns:
        Tuple (mean_train_loss, mean_val_loss, mean_train_metric,
        mean_val_metric): per-epoch history lists.
    """
    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    mean_train_loss = []
    mean_val_loss = []
    mean_train_metric = []
    mean_val_metric = []
    scheduler = LambdaLR(
        optimizer,
        lr_lambda=lambda epoch: 0.5**(epoch // params["lr_ep_step"]))
    for epoch in range(params["epochs"]):
        start_time = time.time()
        print("current lr = {}".format(scheduler.get_lr()[0]))
        train_loss, train_preds, train_targets = train_one_epoch(
            model, optimizer, train_data, params, criterion,
            variable_created_by_model)
        val_loss, val_preds, val_targets = validate(model, val_data, params,
                                                    criterion,
                                                    variable_created_by_model)
        # BUGFIX: scheduler.step() used to run at the *start* of the epoch,
        # before any optimizer.step().  Since PyTorch 1.1 the scheduler must
        # step after the optimizer; stepping first skipped the initial LR
        # (epoch 0 effectively trained with epoch 1's decay factor) and
        # triggers a UserWarning on modern PyTorch.
        scheduler.step()
        # print the results for this epoch:
        mean_train_loss.append(np.mean(train_loss))
        mean_val_loss.append(np.mean(val_loss))
        mean_train_metric.append(metric(train_targets, train_preds))
        mean_val_metric.append(metric(val_targets, val_preds))
        clear_output(True)
        plt.figure(figsize=(10, 5))
        plt.subplot(121)
        plt.plot(mean_train_loss)
        plt.plot(mean_val_loss)
        plt.subplot(122)
        plt.plot(mean_train_metric)
        plt.plot(mean_val_metric)
        plt.gca().set_ylim([0, 1])
        plt.show()
        print("Epoch {} of {} took {:.3f}s".format(epoch + 1,
                                                   params["epochs"],
                                                   time.time() - start_time))
        print(" training loss (in-iteration): \t{:.6f}".format(
            mean_train_loss[-1]))
        print(" validation loss: \t\t\t{:.6f}".format(mean_val_loss[-1]))
        print(" training metric: \t\t\t{:.2f}".format(mean_train_metric[-1]))
        print(" validation metric: \t\t\t{:.2f}".format(mean_val_metric[-1]))
        # if mean_train_loss[-1] < epsilon:
        #     break
    return mean_train_loss, mean_val_loss, mean_train_metric, mean_val_metric


# ? def cross_val_trains
def train(model, tokenizer, train_data, valid_data, args):
    """Fine-tune a BERT-style corrector that maps noisy to clean sentences.

    Builds train/valid dataloaders over TextDataset, trains with AdamW and
    a Noam-style LR schedule (linear warmup, then inverse-sqrt decay), logs
    every args.log_interval steps, evaluates EM/GLEU on the validation
    split every args.eval_interval steps, and saves the best-GLEU model via
    nsml.  Training stops after args.max_steps optimizer steps.

    Assumes each element of train_data/valid_data is a dict with 'noisy'
    and 'clean' keys (demonstrated for valid_data below; presumed the same
    for train_data) -- TODO confirm against the caller.
    """
    model.train()
    train_dataset = TextDataset(train_data)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=RandomSampler(train_dataset),
                                  batch_size=args.train_batch_size,
                                  num_workers=args.num_workers,
                                  collate_fn=lambda x: collate_fn_bert(
                                      x, tokenizer, args.max_seq_length))
    valid_dataset = TextDataset(valid_data)
    valid_dataloader = DataLoader(valid_dataset,
                                  sampler=SequentialSampler(valid_dataset),
                                  batch_size=args.eval_batch_size,
                                  num_workers=args.num_workers,
                                  collate_fn=lambda x: collate_fn_bert(
                                      x, tokenizer, args.max_seq_length))
    valid_noisy = [x['noisy'] for x in valid_data]
    valid_clean = [x['clean'] for x in valid_data]
    # Enough epochs to reach max_steps: ceil(max_steps / batches-per-epoch).
    epochs = (args.max_steps - 1) // len(train_dataloader) + 1
    # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
    #                              betas=eval(args.adam_betas), eps=args.eps,
    #                              weight_decay=args.weight_decay)
    # NOTE(review): the learning rate is hard-coded to 1e-5 here; args.lr
    # (used by the commented-out Adam above) is ignored -- confirm intended.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    # Warmup factor grows linearly to 1 at num_warmup_steps, then decays as
    # (step / num_warmup_steps) ** -0.5 (inverse square root).
    lr_lambda = lambda x: x / args.num_warmup_steps if x <= args.num_warmup_steps else (
        x / args.num_warmup_steps)**-0.5
    scheduler = LambdaLR(optimizer, lr_lambda)
    step = 0
    best_val_gleu = -float("inf")
    meter = Meter()
    for epoch in range(1, epochs + 1):
        for batch in train_dataloader:
            step += 1
            batch = tuple(t.to(args.device) for t in batch)
            noise_input_ids, clean_input_ids, noise_mask, clean_mask = batch
            #print("noise shape: {}, clean shape: {}".format(noise_input_ids.shape, clean_input_ids.shape))
            outputs = model(noise_input_ids,
                            labels=clean_input_ids,
                            attention_mask=noise_mask)
            loss = outputs[0]
            predict_score = outputs[1]
            bsz = clean_input_ids.size(0)
            # (sentence loss, batch size, #target tokens) for Meter averaging.
            items = [loss.data.item(), bsz, clean_mask.sum().item()]
            #print("items: ", items)
            meter.add(*items)
            loss.backward()
            if args.max_grad_norm > 0:
                nn.utils.clip_grad_norm_(model.parameters(),
                                         args.max_grad_norm)
            optimizer.step()
            model.zero_grad()
            scheduler.step()  # per-step LR schedule (see lr_lambda above)
            if step % args.log_interval == 0:
                lr = scheduler.get_lr()[0]
                loss_sent, loss_token = meter.average()
                logger.info(
                    f' [{step:5d}] lr {lr:.6f} | {meter.print_str(True)}')
                nsml.report(step=step,
                            scope=locals(),
                            summary=True,
                            train__lr=lr,
                            train__loss_sent=loss_sent,
                            train__token_ppl=math.exp(loss_token))
                meter.init()
            if step % args.eval_interval == 0:
                start_eval = time.time()
                (val_loss, val_loss_token), valid_str = evaluate_kcBert(
                    model, valid_dataloader, args)
                prediction = correct_kcBert(model,
                                            tokenizer,
                                            valid_noisy,
                                            args,
                                            length_limit=0.1)
                val_em = em(prediction, valid_clean)
                cnt = 0
                for noisy, pred, clean in zip(valid_noisy, prediction,
                                              valid_clean):
                    print(f'[{noisy}], [{pred}], [{clean}]')
                    # print only a small sample of validation triples
                    # (original comment said 10; the loop stops at 20)
                    cnt += 1
                    if cnt == 20:
                        break
                # print("len of prediction: {}, len of valid_clean: {}", len(prediction), len(valid_clean))
                val_gleu = gleu(prediction, valid_clean)
                logger.info('-' * 89)
                logger.info(
                    f' [{step:6d}] valid | {valid_str} | em {val_em:5.2f} | gleu {val_gleu:5.2f}'
                )
                logger.info('-' * 89)
                nsml.report(step=step,
                            scope=locals(),
                            summary=True,
                            valid__loss_sent=val_loss,
                            valid__token_ppl=math.exp(val_loss_token),
                            valid__em=val_em,
                            valid__gleu=val_gleu)
                if val_gleu > best_val_gleu:
                    best_val_gleu = val_gleu
                    nsml.save("best")
                # exclude evaluation time from the meter's wall-clock timing
                meter.start += time.time() - start_eval
            if step >= args.max_steps:
                break
        if step >= args.max_steps:
            break
class PPOAgent(BaseAgent):
    """PPO agent with separate (or shared-body) actor/critic networks.

    Implements the clipped-surrogate PPO objective with optional value-loss
    clipping, linearly decayed learning rates and clip range.  All
    hyperparameters are read from the global ``cfg.alg`` namespace.
    """
    # Dataclass-style fields (BaseAgent presumably a dataclass -- TODO confirm).
    actor: nn.Module
    critic: nn.Module
    # NOTE(review): semantically a bool despite the `float` annotation.
    same_body: float = False

    def __post_init__(self):
        """Move nets to device; build value-loss criterion, optimizer and
        per-group LR scheduler from cfg.alg."""
        move_to([self.actor, self.critic], device=cfg.alg.device)
        if cfg.alg.vf_loss_type == 'mse':
            self.val_loss_criterion = nn.MSELoss().to(cfg.alg.device)
        elif cfg.alg.vf_loss_type == 'smoothl1':
            self.val_loss_criterion = nn.SmoothL1Loss().to(cfg.alg.device)
        else:
            raise TypeError(
                f'Unknown value loss type: {cfg.alg.vf_loss_type}!')
        all_params = list(self.actor.parameters()) + list(
            self.critic.parameters())
        # Keep unique elements only (actor/critic may share parameters when
        # same_body).  dict.fromkeys preserves order on Python >= 3.7; for
        # earlier versions an OrderedDict would be needed.
        self.all_params = dict.fromkeys(all_params).keys()
        if (cfg.alg.linear_decay_lr or cfg.alg.linear_decay_clip_range) and \
                cfg.alg.max_steps > cfg.alg.max_decay_steps:
            logger.warning(
                'max_steps should not be greater than max_decay_steps.')
            cfg.alg.max_decay_steps = int(cfg.alg.max_steps * 1.5)
            logger.warning(
                f'Resetting max_decay_steps to {cfg.alg.max_decay_steps}!')
        # One scheduler "epoch" = one rollout of num_envs * episode_steps
        # environment steps.
        total_epochs = int(
            np.ceil(cfg.alg.max_decay_steps /
                    (cfg.alg.num_envs * cfg.alg.episode_steps)))
        if cfg.alg.linear_decay_clip_range:
            self.clip_range_decay_rate = cfg.alg.clip_range / float(
                total_epochs)
        p_lr_lambda = partial(linear_decay_percent,
                              total_epochs=total_epochs)
        optim_args = dict(lr=cfg.alg.policy_lr,
                          weight_decay=cfg.alg.weight_decay)
        if not cfg.alg.sgd:
            optim_args['amsgrad'] = cfg.alg.use_amsgrad
            optim_func = optim.Adam
        else:
            optim_args['nesterov'] = True if cfg.alg.momentum > 0 else False
            optim_args['momentum'] = cfg.alg.momentum
            optim_func = optim.SGD
        if self.same_body:
            # Shared backbone: a single param group driven by policy_lr.
            optim_args['params'] = self.all_params
        else:
            # Separate param groups so actor and critic get different LRs.
            optim_args['params'] = [{
                'params': self.actor.parameters(),
                'lr': cfg.alg.policy_lr
            }, {
                'params': self.critic.parameters(),
                'lr': cfg.alg.value_lr
            }]
        self.optimizer = optim_func(**optim_args)
        if self.same_body:
            self.lr_scheduler = LambdaLR(optimizer=self.optimizer,
                                         lr_lambda=[p_lr_lambda])
        else:
            v_lr_lambda = partial(linear_decay_percent,
                                  total_epochs=total_epochs)
            self.lr_scheduler = LambdaLR(optimizer=self.optimizer,
                                         lr_lambda=[p_lr_lambda, v_lr_lambda])

    @torch.no_grad()
    def get_action(self, ob, sample=True, *args, **kwargs):
        """Select an action for observation `ob` (sampled or distribution
        mode).  Returns (action, info) with log_prob/entropy/val as numpy."""
        self.eval_mode()
        if type(ob) is dict:
            t_ob = {
                key: torch_float(ob[key], device=cfg.alg.device)
                for key in ob
            }
        else:
            t_ob = torch_float(ob, device=cfg.alg.device)
        act_dist, val = self.get_act_val(t_ob)
        action = action_from_dist(act_dist, sample=sample)
        log_prob = action_log_prob(action, act_dist)
        entropy = action_entropy(act_dist, log_prob)
        action_info = dict(log_prob=torch_to_np(log_prob),
                           entropy=torch_to_np(entropy),
                           val=torch_to_np(val))
        return torch_to_np(action), action_info

    def get_act_val(self, ob, *args, **kwargs):
        """Return (action distribution, value estimate) for `ob`."""
        if type(ob) is dict:
            ob = {
                key: torch_float(ob[key], device=cfg.alg.device)
                for key in ob
            }
        else:
            ob = torch_float(ob, device=cfg.alg.device)
        act_dist, body_out = self.actor(ob)
        if self.same_body:
            # Reuse the actor's body features for the critic head.
            val, body_out = self.critic(body_x=body_out)
        else:
            val, body_out = self.critic(x=ob)
        val = val.squeeze(-1)
        return act_dist, val

    @torch.no_grad()
    def get_val(self, ob, *args, **kwargs):
        """Return the critic's value estimate for `ob` (no grad)."""
        self.eval_mode()
        if type(ob) is dict:
            ob = {
                key: torch_float(ob[key], device=cfg.alg.device)
                for key in ob
            }
        else:
            ob = torch_float(ob, device=cfg.alg.device)
        val, body_out = self.critic(x=ob)
        val = val.squeeze(-1)
        return val

    def optimize(self, data, *args, **kwargs):
        """Run one PPO update on batch `data`; return diagnostic scalars."""
        pre_res = self.optim_preprocess(data)
        processed_data = pre_res
        processed_data['entropy'] = torch.mean(processed_data['entropy'])
        loss_res = self.cal_loss(**processed_data)
        loss, pg_loss, vf_loss, ratio = loss_res
        self.optimizer.zero_grad()
        loss.backward()
        grad_norm = clip_grad(self.all_params, cfg.alg.max_grad_norm)
        self.optimizer.step()
        with torch.no_grad():
            # 0.5 * E[(log pi_old - log pi)^2]: cheap approximation of the
            # policy KL used for diagnostics only.
            approx_kl = 0.5 * torch.mean(
                torch.pow(
                    processed_data['old_log_prob'] -
                    processed_data['log_prob'], 2))
            # Fraction of samples whose ratio left the clip interval.
            clip_frac = np.mean(
                np.abs(torch_to_np(ratio) - 1.0) > cfg.alg.clip_range)
        optim_info = dict(pg_loss=pg_loss.item(),
                          vf_loss=vf_loss.item(),
                          total_loss=loss.item(),
                          entropy=processed_data['entropy'].item(),
                          approx_kl=approx_kl.item(),
                          clip_frac=clip_frac)
        optim_info['grad_norm'] = grad_norm
        return optim_info

    def optim_preprocess(self, data):
        """Move batch tensors to device, re-evaluate the current policy on
        them, and package everything cal_loss needs.

        Raises:
            ValueError: if val/entropy/log_prob are not 1-dimensional.
        """
        self.train_mode()
        for key, val in data.items():
            data[key] = torch_float(val, device=cfg.alg.device)
        ob = data['ob']
        state = data['state']
        action = data['action']
        ret = data['ret']
        adv = data['adv']
        old_log_prob = data['log_prob']
        old_val = data['val']
        act_dist, val = self.get_act_val({"ob": ob, "state": state})
        log_prob = action_log_prob(action, act_dist)
        entropy = action_entropy(act_dist, log_prob)
        if not all([x.ndim == 1 for x in [val, entropy, log_prob]]):
            raise ValueError('val, entropy, log_prob should be 1-dim!')
        processed_data = dict(val=val,
                              old_val=old_val,
                              ret=ret,
                              log_prob=log_prob,
                              old_log_prob=old_log_prob,
                              adv=adv,
                              entropy=entropy)
        return processed_data

    def cal_loss(self, val, old_val, ret, log_prob, old_log_prob, adv,
                 entropy):
        """PPO clipped-surrogate loss.

        Returns (total_loss, pg_loss, vf_loss, ratio) where total =
        pg - ent_coef * entropy + vf_coef * vf.
        """
        vf_loss = self.cal_val_loss(val=val, old_val=old_val, ret=ret)
        ratio = torch.exp(log_prob - old_log_prob)
        surr1 = adv * ratio
        surr2 = adv * torch.clamp(ratio, 1 - cfg.alg.clip_range,
                                  1 + cfg.alg.clip_range)
        pg_loss = -torch.mean(torch.min(surr1, surr2))
        loss = pg_loss - entropy * cfg.alg.ent_coef + \
            vf_loss * cfg.alg.vf_coef
        return loss, pg_loss, vf_loss, ratio

    def cal_val_loss(self, val, old_val, ret):
        """Value loss, optionally PPO-style clipped around old_val."""
        if cfg.alg.clip_vf_loss:
            clipped_val = old_val + torch.clamp(
                val - old_val, -cfg.alg.clip_range, cfg.alg.clip_range)
            vf_loss1 = torch.pow(val - ret, 2)
            vf_loss2 = torch.pow(clipped_val - ret, 2)
            vf_loss = 0.5 * torch.mean(torch.max(vf_loss1, vf_loss2))
        else:
            # val = torch.squeeze(val)
            vf_loss = 0.5 * self.val_loss_criterion(val, ret)
        return vf_loss

    def train_mode(self):
        """Put both networks in training mode."""
        self.actor.train()
        self.critic.train()

    def eval_mode(self):
        """Put both networks in eval mode."""
        self.actor.eval()
        self.critic.eval()

    def decay_lr(self):
        """Advance the LR scheduler by one epoch."""
        self.lr_scheduler.step()

    def get_lr(self):
        """Return current LRs: {'policy_lr': ...} plus 'value_lr' when the
        actor and critic use separate param groups."""
        cur_lr = self.lr_scheduler.get_lr()
        lrs = {'policy_lr': cur_lr[0]}
        if len(cur_lr) > 1:
            lrs['value_lr'] = cur_lr[1]
        return lrs

    def decay_clip_range(self):
        """Linearly shrink the global PPO clip range by one epoch's step."""
        cfg.alg.clip_range -= self.clip_range_decay_rate

    def save_model(self, is_best=False, step=None):
        """Checkpoint env state, networks, optimizer and scheduler (plus
        clip-range state when it is being decayed)."""
        self.save_env(cfg.alg.model_dir)
        data_to_save = {
            'step': step,
            'actor_state_dict': self.actor.state_dict(),
            'critic_state_dict': self.critic.state_dict(),
            'optim_state_dict': self.optimizer.state_dict(),
            'lr_scheduler_state_dict': self.lr_scheduler.state_dict()
        }
        if cfg.alg.linear_decay_clip_range:
            data_to_save['clip_range'] = cfg.alg.clip_range
            data_to_save['clip_range_decay_rate'] = self.clip_range_decay_rate
        save_model(data_to_save, cfg.alg, is_best=is_best, step=step)

    def load_model(self, step=None, pretrain_model=None):
        """Restore a checkpoint.

        With `pretrain_model` set, only network weights are loaded (returns
        None); otherwise optimizer/scheduler/clip-range state is restored
        too and the checkpoint's step is returned.
        """
        self.load_env(cfg.alg.model_dir)
        ckpt_data = load_ckpt_data(cfg.alg,
                                   step=step,
                                   pretrain_model=pretrain_model)
        load_state_dict(self.actor, ckpt_data['actor_state_dict'])
        load_state_dict(self.critic, ckpt_data['critic_state_dict'])
        if pretrain_model is not None:
            return
        self.optimizer.load_state_dict(ckpt_data['optim_state_dict'])
        self.lr_scheduler.load_state_dict(ckpt_data['lr_scheduler_state_dict'])
        if cfg.alg.linear_decay_clip_range:
            self.clip_range_decay_rate = ckpt_data['clip_range_decay_rate']
            cfg.alg.clip_range = ckpt_data['clip_range']
        return ckpt_data['step']

    def print_param_grad_status(self):
        """Log the requires_grad flag of every actor/critic parameter."""
        logger.info('Requires Grad?')
        logger.info('================== Actor ================== ')
        for name, param in self.actor.named_parameters():
            print(f'{name}: {param.requires_grad}')
        logger.info('================== Critic ================== ')
        for name, param in self.critic.named_parameters():
            print(f'{name}: {param.requires_grad}')
def main(args: argparse.Namespace):
    """Self-ensembling (mean-teacher) domain adaptation driver.

    Builds source/target dataloaders and the classifier; depending on
    args.phase either analyzes features (t-SNE + A-distance), tests a saved
    checkpoint, or trains: optional source-only pretraining, then
    teacher-student training with consistency and class-balance losses,
    checkpointing the best validation accuracy and reporting test accuracy.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    train_transform = T.Compose([
        ResizeImage(256),
        T.RandomCrop(224),
        T.RandomHorizontalFlip(),
        T.ColorJitter(brightness=0.7, contrast=0.7, saturation=0.7, hue=0.5),
        T.RandomGrayscale(),
        T.ToTensor(), normalize
    ])
    val_transform = T.Compose(
        [ResizeImage(256),
         T.CenterCrop(224),
         T.ToTensor(), normalize])
    # MultipleApply yields two views per target image (student/teacher inputs).
    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target,
                          train_transform, val_transform,
                          MultipleApply([train_transform, val_transform]))
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.workers)
    # Endless iterators: an "epoch" is a fixed number of iterations.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone,
                                 num_classes,
                                 bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer,
                                 finetune=not args.scratch).to(device)

    # define optimizer and lr scheduler
    optimizer = Adam(classifier.get_parameters(), args.lr)
    # Inverse-decay schedule: lr(x) = args.lr * (1 + gamma*x)^(-decay).
    # NOTE(review): the lambda multiplies by args.lr, so the param groups
    # returned by get_parameters() presumably use a base lr of 1 -- confirm.
    lr_scheduler = LambdaLR(
        optimizer, lambda x: args.lr *
        (1. + args.lr_gamma * float(x))**(-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analyze the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone,
                                          classifier.pool_layer,
                                          classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader,
                                         feature_extractor, device)
        target_feature = collect_feature(train_target_loader,
                                         feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature,
                                          device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    if args.pretrain is None:
        # first pretrain the classifier with source data
        print("Pretraining the model on source domain.")
        args.pretrain = logger.get_checkpoint_path('pretrain')
        pretrain_model = ImageClassifier(backbone,
                                         num_classes,
                                         bottleneck_dim=args.bottleneck_dim,
                                         pool_layer=pool_layer,
                                         finetune=not args.scratch).to(device)
        pretrain_optimizer = Adam(pretrain_model.get_parameters(),
                                  args.pretrain_lr)
        pretrain_lr_scheduler = LambdaLR(
            pretrain_optimizer, lambda x: args.pretrain_lr *
            (1. + args.lr_gamma * float(x))**(-args.lr_decay))

        # start pretraining
        for epoch in range(args.pretrain_epochs):
            # pretrain for one epoch
            utils.pretrain(train_source_iter, pretrain_model,
                           pretrain_optimizer, pretrain_lr_scheduler, epoch,
                           args, device)
            # validate to show pretrain process
            utils.validate(val_loader, pretrain_model, args, device)

        torch.save(pretrain_model.state_dict(), args.pretrain)
        print("Pretraining process is done.")

    # Initialize the student from the (pre)trained weights; the teacher is
    # an exponential moving average of the student.
    checkpoint = torch.load(args.pretrain, map_location='cpu')
    classifier.load_state_dict(checkpoint)
    teacher = EmaTeacher(classifier, alpha=args.alpha)
    consistent_loss = L2ConsistencyLoss().to(device)
    class_balance_loss = ClassBalanceLoss(num_classes).to(device)

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, teacher,
              consistent_loss, class_balance_loss, optimizer, lr_scheduler,
              epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(),
                   logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'),
                        logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
def train(self) -> None:
    r"""Main method for training PPO.

    Constructs vectorized environments, rollout storage and the PPO agent,
    then alternates rollout collection and agent updates for NUM_UPDATES
    iterations, logging metrics to tensorboard and checkpointing
    periodically.  Supports resuming from an interrupted (requeued) run
    via load_interrupted_state / save_interrupted_state.

    Returns:
        None
    """
    logger.info(f"config: {self.config}")
    random.seed(self.config.SEED)
    np.random.seed(self.config.SEED)
    torch.manual_seed(self.config.SEED)

    # add_signal_handlers()

    self.envs = construct_envs(self.config,
                               get_env_class(self.config.ENV_NAME),
                               workers_ignore_signals=True)

    ppo_cfg = self.config.RL.PPO
    self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                   if torch.cuda.is_available() else torch.device("cpu"))
    if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
        os.makedirs(self.config.CHECKPOINT_FOLDER)
    self._setup_actor_critic_agent(ppo_cfg)
    logger.info("agent number of parameters: {}".format(
        sum(param.numel() for param in self.agent.parameters())))

    if ppo_cfg.use_external_memory:
        memory_dim = self.actor_critic.net.memory_dim
    else:
        memory_dim = None

    rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        self.envs.num_envs,
        self.envs.observation_spaces[0],
        self.envs.action_spaces[0],
        ppo_cfg.hidden_size,
        ppo_cfg.use_external_memory,
        ppo_cfg.SCENE_MEMORY_TRANSFORMER.memory_size + ppo_cfg.num_steps,
        ppo_cfg.SCENE_MEMORY_TRANSFORMER.memory_size,
        memory_dim,
    )
    rollouts.to(self.device)

    observations = self.envs.reset()
    batch = batch_obs(observations)
    if self.config.RL.PPO.use_belief_predictor:
        self.belief_predictor.update(batch, None)
    # Seed the first rollout step with the initial observations.
    for sensor in rollouts.observations:
        rollouts.observations[sensor][0].copy_(batch[sensor])

    # batch and observations may contain shared PyTorch CUDA
    # tensors. We must explicitly clear them here otherwise
    # they will be kept in memory for the entire duration of training!
    batch = None
    observations = None

    current_episode_reward = torch.zeros(self.envs.num_envs, 1)
    running_episode_stats = dict(
        count=torch.zeros(self.envs.num_envs, 1),
        reward=torch.zeros(self.envs.num_envs, 1),
    )
    window_episode_stats = defaultdict(
        lambda: deque(maxlen=ppo_cfg.reward_window_size))

    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps = 0
    count_checkpoints = 0
    start_update = 0
    prev_time = 0

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
    )

    # Resume bookkeeping and optimizer/scheduler state if this run was
    # interrupted and requeued by the cluster scheduler.
    interrupted_state = load_interrupted_state(
        model_dir=self.config.MODEL_DIR)
    if interrupted_state is not None:
        self.agent.load_state_dict(interrupted_state["state_dict"])
        self.agent.optimizer.load_state_dict(
            interrupted_state["optimizer_state"])
        lr_scheduler.load_state_dict(
            interrupted_state["lr_scheduler_state"])

        requeue_stats = interrupted_state["requeue_stats"]
        env_time = requeue_stats["env_time"]
        pth_time = requeue_stats["pth_time"]
        count_steps = requeue_stats["count_steps"]
        count_checkpoints = requeue_stats["count_checkpoints"]
        start_update = requeue_stats["start_update"]
        prev_time = requeue_stats["prev_time"]

    with TensorboardWriter(self.config.TENSORBOARD_DIR,
                           flush_secs=self.flush_secs) as writer:
        for update in range(start_update, self.config.NUM_UPDATES):
            if ppo_cfg.use_linear_lr_decay:
                # NOTE(review): the scheduler is stepped at the top of the
                # update, before the agent's optimizer step -- the pre-1.1
                # PyTorch convention; confirm this matches the intended
                # decay offset.
                lr_scheduler.step()

            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    update, self.config.NUM_UPDATES)

            # Graceful shutdown: on EXIT close the envs; additionally save
            # a resumable snapshot and requeue the job when REQUEUE is set.
            if EXIT.is_set():
                self.envs.close()

                if REQUEUE.is_set():
                    requeue_stats = dict(
                        env_time=env_time,
                        pth_time=pth_time,
                        count_steps=count_steps,
                        count_checkpoints=count_checkpoints,
                        start_update=update,
                        prev_time=(time.time() - t_start) + prev_time,
                    )
                    save_interrupted_state(dict(
                        state_dict=self.agent.state_dict(),
                        optimizer_state=self.agent.optimizer.state_dict(),
                        lr_scheduler_state=lr_scheduler.state_dict(),
                        config=self.config,
                        requeue_stats=requeue_stats,
                    ),
                                           model_dir=self.config.MODEL_DIR)
                    requeue_job()
                return

            # Collect one rollout of num_steps environment steps.
            for step in range(ppo_cfg.num_steps):
                delta_pth_time, delta_env_time, delta_steps = self._collect_rollout_step(
                    rollouts, current_episode_reward, running_episode_stats)
                pth_time += delta_pth_time
                env_time += delta_env_time
                count_steps += delta_steps

            delta_pth_time, value_loss, action_loss, dist_entropy = self._update_agent(
                ppo_cfg, rollouts)
            pth_time += delta_pth_time

            # Windowed episode statistics: difference between newest and
            # oldest snapshot in each deque (or the single snapshot).
            deltas = {
                k: ((v[-1] - v[0]).sum().item()
                    if len(v) > 1 else v[0].sum().item())
                for k, v in window_episode_stats.items()
            }
            deltas["count"] = max(deltas["count"], 1.0)

            writer.add_scalar("Metrics/reward",
                              deltas["reward"] / deltas["count"],
                              count_steps)

            # Check to see if there are any metrics
            # that haven't been logged yet
            metrics = {
                k: v / deltas["count"]
                for k, v in deltas.items() if k not in {"reward", "count"}
            }
            if len(metrics) > 0:
                # writer.add_scalars("metrics", metrics, count_steps)
                for metric, value in metrics.items():
                    writer.add_scalar(f"Metrics/{metric}", value,
                                      count_steps)

            writer.add_scalar("Policy/value_loss", value_loss, count_steps)
            writer.add_scalar("Policy/policy_loss", action_loss, count_steps)
            writer.add_scalar("Policy/entropy_loss", dist_entropy,
                              count_steps)
            writer.add_scalar('Policy/learning_rate',
                              lr_scheduler.get_lr()[0], count_steps)

            # log stats
            if update > 0 and update % self.config.LOG_INTERVAL == 0:
                logger.info("update: {}\tfps: {:.3f}\t".format(
                    update, count_steps / (time.time() - t_start)))
                logger.info(
                    "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                    "frames: {}".format(update, env_time, pth_time,
                                        count_steps))
                logger.info("Average window size: {} {}".format(
                    len(window_episode_stats["count"]),
                    " ".join("{}: {:.3f}".format(k, v / deltas["count"])
                             for k, v in deltas.items() if k != "count"),
                ))

            # checkpoint model
            if update % self.config.CHECKPOINT_INTERVAL == 0:
                self.save_checkpoint(f"ckpt.{count_checkpoints}.pth")
                count_checkpoints += 1

        self.envs.close()
def train(data_path: str, data_directory: str, generate_vocabularies: bool,
          input_vocab_path: str, target_vocab_path: str,
          embedding_dimension: int, num_encoder_layers: int,
          encoder_dropout_p: float, encoder_bidirectional: bool,
          training_batch_size: int, test_batch_size: int,
          max_decoding_steps: int, num_decoder_layers: int,
          decoder_dropout_p: float, cnn_kernel_size: int,
          cnn_dropout_p: float, cnn_hidden_num_channels: int,
          simple_situation_representation: bool, decoder_hidden_size: int,
          encoder_hidden_size: int, learning_rate: float, adam_beta_1: float,
          adam_beta_2: float, lr_decay: float, lr_decay_steps: int,
          resume_from_file: str, max_training_iterations: int,
          output_directory: str, print_every: int, evaluate_every: int,
          conditional_attention: bool, auxiliary_task: bool,
          weight_target_loss: float, attention_type: str,
          max_training_examples=None, seed=42, **kwargs):
    """Train a grounded-SCAN sequence model.

    Loads train/test GroundedScanDataset splits (optionally generating and
    saving vocabularies), builds the Model from the captured argument dict,
    and trains with Adam + exponential LR decay until
    `max_training_iterations`.  Metrics are logged every `print_every`
    iterations; every `evaluate_every` iterations the test split is
    evaluated and, on a new best exact match, a checkpoint is written.

    `kwargs` must contain "max_testing_examples" (read during evaluation).
    Relies on the module-level `use_cuda`, `logger`, `log_parameters`,
    `evaluate`, `GroundedScanDataset` and `Model` names.

    Fixes vs. the previous version:
      * `"checkpoint.pth.tar".format(str(training_iteration))` was a no-op
        `.format()` with no replacement field: the iteration number was
        silently discarded.  The dead call is removed; the literal filename
        is kept so downstream loading is unchanged.
      * Dead locals `best_iteration`, `best_loss` and `best_accuracy`
        (written but never read) are removed; checkpoint selection keys on
        exact match only, as before.
    """
    device = torch.device(type='cuda') if use_cuda else torch.device(
        type='cpu')
    # Snapshot the call arguments (plus `device`) BEFORE any other local is
    # defined: `cfg` is forwarded verbatim to the Model constructor.
    cfg = locals().copy()

    torch.manual_seed(seed)

    logger.info("Loading Training set...")
    training_set = GroundedScanDataset(
        data_path,
        data_directory,
        split="train",
        input_vocabulary_file=input_vocab_path,
        target_vocabulary_file=target_vocab_path,
        generate_vocabulary=generate_vocabularies)
    training_set.read_dataset(
        max_examples=max_training_examples,
        simple_situation_representation=simple_situation_representation)
    logger.info("Done Loading Training set.")
    logger.info(" Loaded {} training examples.".format(
        training_set.num_examples))
    logger.info(" Input vocabulary size training set: {}".format(
        training_set.input_vocabulary_size))
    logger.info(" Most common input words: {}".format(
        training_set.input_vocabulary.most_common(5)))
    logger.info(" Output vocabulary size training set: {}".format(
        training_set.target_vocabulary_size))
    logger.info(" Most common target words: {}".format(
        training_set.target_vocabulary.most_common(5)))

    if generate_vocabularies:
        training_set.save_vocabularies(input_vocab_path, target_vocab_path)
        logger.info(
            "Saved vocabularies to {} for input and {} for target.".format(
                input_vocab_path, target_vocab_path))

    logger.info("Loading Test set...")
    test_set = GroundedScanDataset(
        data_path,
        data_directory,
        split="test",  # TODO: use dev set here
        input_vocabulary_file=input_vocab_path,
        target_vocabulary_file=target_vocab_path,
        generate_vocabulary=False)
    test_set.read_dataset(
        max_examples=None,
        simple_situation_representation=simple_situation_representation)
    # Shuffle the test set to make sure that if we only evaluate max_testing_examples we get a random part of the set.
    test_set.shuffle_data()
    logger.info("Done Loading Test set.")

    model = Model(input_vocabulary_size=training_set.input_vocabulary_size,
                  target_vocabulary_size=training_set.target_vocabulary_size,
                  num_cnn_channels=training_set.image_channels,
                  input_padding_idx=training_set.input_vocabulary.pad_idx,
                  target_pad_idx=training_set.target_vocabulary.pad_idx,
                  target_eos_idx=training_set.target_vocabulary.eos_idx,
                  **cfg)
    model = model.cuda() if use_cuda else model
    log_parameters(model)
    trainable_parameters = [
        parameter for parameter in model.parameters()
        if parameter.requires_grad
    ]
    optimizer = torch.optim.Adam(trainable_parameters,
                                 lr=learning_rate,
                                 betas=(adam_beta_1, adam_beta_2))
    # Exponential decay: LR factor lr_decay every lr_decay_steps iterations.
    scheduler = LambdaLR(optimizer,
                         lr_lambda=lambda t: lr_decay**(t / lr_decay_steps))

    # Load model and vocabularies if resuming.
    start_iteration = 1
    best_exact_match = 0
    if resume_from_file:
        assert os.path.isfile(
            resume_from_file), "No checkpoint found at {}".format(
                resume_from_file)
        logger.info(
            "Loading checkpoint from file at '{}'".format(resume_from_file))
        optimizer_state_dict = model.load_model(resume_from_file)
        optimizer.load_state_dict(optimizer_state_dict)
        start_iteration = model.trained_iterations
        logger.info("Loaded checkpoint '{}' (iter {})".format(
            resume_from_file, start_iteration))

    logger.info("Training starts..")
    training_iteration = start_iteration
    while training_iteration < max_training_iterations:
        # Shuffle the dataset and loop over it.
        training_set.shuffle_data()
        for (input_batch, input_lengths, _, situation_batch, _, target_batch,
             target_lengths, agent_positions,
             target_positions) in training_set.get_data_iterator(
                 batch_size=training_batch_size):
            is_best = False
            model.train()

            # Forward pass.
            target_scores, target_position_scores = model(
                commands_input=input_batch,
                commands_lengths=input_lengths,
                situations_input=situation_batch,
                target_batch=target_batch,
                target_lengths=target_lengths)
            loss = model.get_loss(target_scores, target_batch)
            if auxiliary_task:
                target_loss = model.get_auxiliary_loss(
                    target_position_scores, target_positions)
            else:
                target_loss = 0
            loss += weight_target_loss * target_loss

            # Backward pass and update model parameters.
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            model.update_state(is_best=is_best)

            # Print current metrics.
            if training_iteration % print_every == 0:
                accuracy, exact_match = model.get_metrics(
                    target_scores, target_batch)
                if auxiliary_task:
                    auxiliary_accuracy_target = model.get_auxiliary_accuracy(
                        target_position_scores, target_positions)
                else:
                    auxiliary_accuracy_target = 0.
                learning_rate = scheduler.get_lr()[0]
                logger.info(
                    "Iteration %08d, loss %8.4f, accuracy %5.2f, exact match %5.2f, learning_rate %.5f,"
                    " aux. accuracy target pos %5.2f" %
                    (training_iteration, loss, accuracy, exact_match,
                     learning_rate, auxiliary_accuracy_target))

            # Evaluate on test set.
            if training_iteration % evaluate_every == 0:
                with torch.no_grad():
                    model.eval()
                    logger.info("Evaluating..")
                    accuracy, exact_match, target_accuracy = evaluate(
                        test_set.get_data_iterator(batch_size=1),
                        model=model,
                        max_decoding_steps=max_decoding_steps,
                        pad_idx=test_set.target_vocabulary.pad_idx,
                        sos_idx=test_set.target_vocabulary.sos_idx,
                        eos_idx=test_set.target_vocabulary.eos_idx,
                        max_examples_to_evaluate=kwargs["max_testing_examples"]
                    )
                    logger.info(
                        " Evaluation Accuracy: %5.2f Exact Match: %5.2f "
                        " Target Accuracy: %5.2f" %
                        (accuracy, exact_match, target_accuracy))
                    if exact_match > best_exact_match:
                        is_best = True
                        best_exact_match = exact_match
                        model.update_state(accuracy=accuracy,
                                           exact_match=exact_match,
                                           is_best=is_best)
                    # Fixed filename (the old .format() call here had no
                    # placeholder and was a silent no-op).
                    file_name = "checkpoint.pth.tar"
                    if is_best:
                        model.save_checkpoint(
                            file_name=file_name,
                            is_best=is_best,
                            optimizer_state_dict=optimizer.state_dict())

            training_iteration += 1
            if training_iteration > max_training_iterations:
                break
    logger.info("Finished training.")
def train(args, data_loader):
    """MoCo-style contrastive training of a query/key encoder pair on
    streamed (audio-frame) batches.

    Builds a prefetching pipeline over `data_loader`, two encoders
    (query model_q trained by SGD, key model_k updated as a momentum
    average of model_q), and a MemoryMoCo negative queue.  Epochs are
    measured in processed frames (args.frames_per_epoch); checkpoints are
    written every args.checkpoint_period frames and symlinked per epoch.
    """
    device = torch.device('cuda:0')
    torch.cuda.set_device(device)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    stream = torch.cuda.Stream(device)

    def to_device(args):
        # Copy a (x1, x2) pair to the GPU on the side stream.
        # NOTE(review): this parameter shadows the outer `args` namespace.
        x1, x2 = args
        with torch.cuda.stream(stream):
            x1 = x1.to(device, non_blocking=True)
            x2 = x2.to(device, non_blocking=True)
        return x1, x2

    train_data = DataQueue(data_loader, max_queue_size=30, nb_worker=10)
    train_loader = Prefetcher(train_data,
                              postprocess=to_device,
                              buffer_size=1,
                              stream=stream)
    model_cfg = ConfigParser()
    model_cfg.read(args.conf)
    max_grad_value = model_cfg.getfloat(args.model_type, "max_grad_value")
    max_grad_norm = model_cfg.getfloat(args.model_type, "max_grad_norm")
    beta = args.beta
    model_q = build_model(args.conf,
                          model_type=args.model_type,
                          write_back=True)
    model_k = build_model(args.conf,
                          model_type=args.model_type,
                          write_back=True)
    # Initialize the key encoder as a copy of the query encoder.
    momentum_update(model_q, model_k, 1 - beta)
    embedding_size = model_cfg.getint(args.model_type, "embedding_size")
    memory = MemoryMoCo(embedding_size, args.mem_queue_size)
    if torch.cuda.device_count() > 1:
        model_q = nn.DataParallel(model_q, dim=0)
        model_k = nn.DataParallel(model_k, dim=0)
    model_q.to(device)
    model_k.to(device)
    memory.to(device)

    num_epochs = args.num_epochs
    warmup_lr = args.warmup_lr
    initial_lr = args.initial_lr
    final_lr = args.final_lr
    # Per-epoch multiplicative factor (exponential anneal) and per-epoch
    # subtractive step (linear anneal), both reaching final_lr at the end.
    lr_anneal = (final_lr / initial_lr)**(1. / num_epochs)
    lr_decline = (initial_lr - final_lr) / num_epochs

    # Set up learning rate scheduler
    def get_learning_rate(epoch):
        """Compute the learning rate for `epoch`.

        Linear or exponential anneal (per args.linear_decay) from
        initial_lr down to final_lr, with an optional warmup LR used on
        epoch 0.  Relies on the enclosing locals final_lr, lr_decline,
        lr_anneal, initial_lr and warmup_lr being assigned beforehand.
        """
        if args.linear_decay:
            this_lr = max(final_lr, initial_lr - lr_decline * epoch)
        else:
            this_lr = max(final_lr, initial_lr * lr_anneal**epoch)
        if epoch == 0 and warmup_lr > 0:
            # use warmup lr instead
            this_lr = warmup_lr
        return this_lr

    momentum = 0.9
    nesterov = False
    weight_decay = 1e-5
    # Optimizer
    optimizer = torch.optim.SGD(model_q.parameters(),
                                lr=initial_lr,
                                momentum=momentum,
                                nesterov=nesterov,
                                weight_decay=weight_decay)
    # LambdaLR multiplies the base LR, so normalize by initial_lr.
    lr_lambda = lambda epoch: get_learning_rate(epoch) / initial_lr
    lr_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
    model_dir = args.model_dir
    log_dir = os.path.join(model_dir, f'log/dist{torch.cuda.device_count()}')
    writer = SummaryWriter(log_dir)

    # Train the Model
    checkpoint_period = args.checkpoint_period
    for epoch in range(args.start_epoch, num_epochs):
        if args.exit_epoch > 0 and args.exit_epoch == epoch:
            break
        # On the first epoch of this run, try to resume from the previous
        # epoch's checkpoint symlink (model, optimizer and LR state).
        ckpt_filename = os.path.join(model_dir,
                                     f'checkpoint_e{epoch-1:03d}.pkl')
        if epoch == args.start_epoch and os.path.isfile(ckpt_filename):
            ckpt = load_checkpoint(model_q, ckpt_filename,
                                   map_location='cpu')
            if args.start_epoch > 0:
                if isinstance(ckpt, dict) and 'optimizer' in ckpt:
                    optimizer.load_state_dict(ckpt['optimizer'])
                    print(f'load optimizer states from {ckpt_filename}')
                if isinstance(
                        ckpt, dict
                ) and 'meta' in ckpt and 'lr_state' in ckpt['meta']:
                    lr_state = ckpt['meta'].get('lr_state')
                    lr_scheduler.load_state_dict(lr_state)
                    print(f'load scheduler states from {ckpt_filename}')
            print(f'load model from {ckpt_filename}')
        if epoch == args.start_epoch:
            # Global sample counter, resumed from the epoch offset.
            nb_samples = epoch * args.frames_per_epoch
            writer.add_scalar('train/lr',
                              lr_scheduler.get_lr()[0], nb_samples)
        total_loss = 0
        acc_samples = 0
        total_processed = 0
        steps = 0
        target_samples = checkpoint_period
        while total_processed < args.frames_per_epoch:
            inputs, dis_inputs = train_loader.get()
            inputs = inputs.to(device)
            dis_inputs = dis_inputs.to(device)
            b, t = inputs.size()[:2]
            # Forward + Backward + Optimize
            optimizer.zero_grad()  # zero the gradient buffer
            with torch.no_grad():
                # Shuffle BN: permute the key batch before encoding and
                # un-permute afterwards so BatchNorm statistics differ
                # between query and key views.
                shf_ids, rev_ids = get_shuffle_ids(b, device)
                dis_inputs = dis_inputs[shf_ids]
                key = model_k(dis_inputs)[rev_ids].detach()
            query = model_q(inputs)
            loss = memory(query, key)
            loss.backward()
            if max_grad_value > 0:
                cur_max_value = max_grad_value
                clip_grad_value_(model_q.parameters(),
                                 clip_value=cur_max_value)
            if max_grad_norm > 0:
                cur_max_norm = max_grad_norm
                norm = clip_grad_norm_(model_q.parameters(),
                                       max_norm=cur_max_norm)
                if norm > cur_max_norm:
                    print(
                        "grad norm {0:.2f} exceeds {1:.2f}, clip to {1:.2f}.".
                        format(norm, cur_max_norm))
            optimizer.step()
            # EMA update of the key encoder, then enqueue the new keys.
            momentum_update(model_q, model_k, beta)
            memory.update(key)
            loss_val = loss.item()
            # Drop tensor references promptly to bound GPU memory.
            del loss, key, query, inputs, dis_inputs
            total_processed += b * t
            steps += 1
            nb_samples += b * t
            writer.add_scalar('train/loss', loss_val, nb_samples)
            total_loss += loss_val * b
            acc_samples += b
            if steps % 10 == 0:
                print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f' %
                      (epoch + 1, num_epochs, total_processed,
                       args.frames_per_epoch, total_loss / acc_samples))
                total_loss = 0
                acc_samples = 0
            if checkpoint_period > 0 and total_processed >= target_samples:
                target_samples += checkpoint_period
                ckpt_filename = os.path.join(
                    model_dir,
                    'checkpoint_s{:06d}M.pkl'.format(nb_samples // 1000000))
                meta = {}
                meta['lr_state'] = lr_scheduler.state_dict()
                save_checkpoint(model_q,
                                ckpt_filename,
                                optimizer=optimizer,
                                meta=meta)
        lr_scheduler.step()
        # End-of-epoch checkpoint plus a stable per-epoch symlink.
        ckpt_filename = os.path.join(
            model_dir,
            'checkpoint_s{:06d}M.pkl'.format(nb_samples // 1000000))
        meta = {}
        meta['lr_state'] = lr_scheduler.state_dict()
        save_checkpoint(model_q,
                        ckpt_filename,
                        optimizer=optimizer,
                        meta=meta)
        ckpt_linkname = os.path.join(model_dir,
                                     'checkpoint_e{:03d}.pkl'.format(epoch))
        cmd = "ln -sf ./checkpoint_s{:06d}M.pkl {}"
        cmd = cmd.format(nb_samples // 1000000, ckpt_linkname)
        subprocess.call(cmd, shell=True)
def main(args: argparse.Namespace):
    """Entry point: train (or test) a domain-adaptation regressor.

    Builds source/target/val data loaders, a pretrained backbone with a
    regression head, SGD + inverse-decay LR schedule, then either evaluates
    the best checkpoint (``args.phase == 'test'``) or runs the epoch loop,
    tracking the best validation MAE.
    """
    logger = CompleteLogger(args.log, args.phase)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    train_transform = T.Compose([T.Resize(128), T.ToTensor(), normalize])
    val_transform = T.Compose([T.Resize(128), T.ToTensor(), normalize])

    dataset = datasets.__dict__[args.data]
    # Source-domain training data (labels used for the regression loss).
    train_source_dataset = dataset(root=args.root,
                                   task=args.source,
                                   split='train',
                                   download=True,
                                   transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    # Target-domain training data (used unlabeled by the adaptation method).
    train_target_dataset = dataset(root=args.root,
                                   task=args.target,
                                   split='train',
                                   download=True,
                                   transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    # Target-domain test split for validation.
    val_dataset = dataset(root=args.root,
                          task=args.target,
                          split='test',
                          download=True,
                          transform=val_transform)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers)

    # Infinite iterators so train() can draw a fixed number of batches
    # per epoch regardless of loader length.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = models.__dict__[args.arch](pretrained=True)
    num_factors = train_source_dataset.num_factors
    regressor = Regressor(backbone=backbone,
                          num_factors=num_factors).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(regressor.get_parameters(),
                    args.lr,
                    momentum=args.momentum,
                    weight_decay=args.wd,
                    nesterov=True)
    # Inverse-decay schedule: the lambda returns args.lr * (1 + gamma*x)^-p
    # as the multiplicative factor; regressor.get_parameters() is expected to
    # set per-group base lrs accordingly (library convention).
    lr_scheduler = LambdaLR(
        optimizer, lambda x: args.lr *
        (1. + args.lr_gamma * float(x))**(-args.lr_decay))

    if args.phase == 'test':
        # Evaluate the best checkpoint and exit.
        regressor.load_state_dict(
            torch.load(logger.get_checkpoint_path('best')))
        mae = validate(val_loader, regressor, args,
                       train_source_dataset.factors)
        print(mae)
        return

    # start training
    best_mae = 100000.
    for epoch in range(args.epochs):
        # train for one epoch
        # NOTE(review): get_lr() is deprecated in favor of get_last_lr() on
        # torch >= 1.4 and warns when called outside step() — confirm the
        # pinned torch version before changing.
        print("lr", lr_scheduler.get_lr())
        train(train_source_iter, train_target_iter, regressor, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        mae = validate(val_loader, regressor, args,
                       train_source_dataset.factors)

        # remember best mae and save checkpoint
        torch.save(regressor.state_dict(),
                   logger.get_checkpoint_path('latest'))
        if mae < best_mae:
            shutil.copy(logger.get_checkpoint_path('latest'),
                        logger.get_checkpoint_path('best'))
        best_mae = min(mae, best_mae)
        print("mean MAE {:6.3f} best MAE {:6.3f}".format(mae, best_mae))

    print("best_mae = {:6.3f}".format(best_mae))
    logger.close()
class Trainer(object): """ Trainer encapsulates all the logic necessary for training the Recurrent Attention Model. All hyperparameters are provided by the user in the config file. """ def __init__(self, config, data_loader): """ Construct a new Trainer instance. Args ---- - config: object containing command line arguments. - data_loader: data iterator """ self.config = config # glimpse network params self.patch_size = config.patch_size self.glimpse_scale = config.glimpse_scale self.num_patches = config.num_patches self.loc_hidden = config.loc_hidden self.glimpse_hidden = config.glimpse_hidden # core network params self.num_glimpses = config.num_glimpses self.hidden_size = config.hidden_size # reinforce params self.std = config.std self.M = config.M # data params if config.is_train: self.train_loader = data_loader[0] self.valid_loader = data_loader[1] self.num_train = len(self.train_loader.sampler.indices) self.num_valid = len(self.valid_loader.sampler.indices) else: self.test_loader = data_loader self.num_test = len(self.test_loader.dataset) self.num_classes = 10 self.num_channels = 1 # training params self.epochs = config.epochs self.start_epoch = 0 self.momentum = config.momentum self.lr = config.init_lr # misc params self.no_tqdm = config.no_tqdm self.use_gpu = config.use_gpu self.best = config.best self.ckpt_dir = config.ckpt_dir self.logs_dir = config.logs_dir self.best_valid_acc = 0. 
self.counter = 0 self.lr_patience = config.lr_patience self.train_patience = config.train_patience self.use_tensorboard = config.use_tensorboard self.resume = config.resume self.print_freq = config.print_freq self.plot_freq = config.plot_freq self.model_name = 'ram_{}_{}x{}_{}'.format(config.num_glimpses, config.patch_size, config.patch_size, config.glimpse_scale) if config.uncertainty == True: self.model_name += '_uncertainty_1' else: self.model_name += '_uncertainty_0' if config.intrinsic == True: self.model_name += '_intrinsic_1' else: self.model_name += '_intrinsic_0' self.plot_dir = './plots/' + self.model_name + '/' if not os.path.exists(self.plot_dir): os.makedirs(self.plot_dir) # configure tensorboard logging if self.use_tensorboard: tensorboard_dir = self.logs_dir + self.model_name print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir)) if not os.path.exists(tensorboard_dir): os.makedirs(tensorboard_dir) configure(tensorboard_dir) # build RAM model self.model = RecurrentAttention(self.patch_size, self.num_patches, self.glimpse_scale, self.num_channels, self.loc_hidden, self.glimpse_hidden, self.std, self.hidden_size, self.num_classes, self.config) if self.use_gpu: self.model.cuda() self.dtypeFloat = (torch.cuda.FloatTensor if self.use_gpu else torch.FloatTensor) self.dtypeLong = (torch.cuda.LongTensor if self.use_gpu else torch.LongTensor) print('[*] Number of model parameters: {:,}'.format( sum([p.data.nelement() for p in self.model.parameters()]))) # # initialize optimizer and scheduler self.optimizer = optim.Adam( self.model.parameters(), lr=self.config.init_lr, ) lambda_of_lr = lambda epoch: 0.95**epoch self.scheduler = LambdaLR(self.optimizer, lr_lambda=lambda_of_lr) # self.scheduler = StepLR(self.optimizer,step_size=20,gamma=0.1) # self.scheduler = ReduceLROnPlateau( # self.optimizer, 'min', patience=self.lr_patience # ) def reset(self): """ Initialize the hidden state of the core network and the location vector. 
This is called once every time a new minibatch `x` is introduced. """ dtype = (torch.cuda.FloatTensor if self.use_gpu else torch.FloatTensor) h_t = torch.zeros(self.batch_size, self.hidden_size) h_t = Variable(h_t).type(dtype) l_t = torch.Tensor(self.batch_size, 2).uniform_(-1, 1) l_t = Variable(l_t).type(dtype) return h_t, l_t def train(self): """ Train the model on the training set. A checkpoint of the model is saved after each epoch and if the validation accuracy is improved upon, a separate ckpt is created for use on the test set. """ # load the most recent checkpoint if self.resume: self.load_checkpoint(best=False) print( "\n[*] Train on {} samples, validate on {} samples, learn rate {}". format(self.num_train, self.num_valid, self.scheduler.get_lr())) for epoch in range(self.start_epoch, self.epochs): print('\nEpoch: {}/{} . lr: {:.4e} '.format( epoch + 1, self.epochs, self.scheduler.get_lr()[0])) # train for 1 epoch train_loss, train_acc = self.train_one_epoch(epoch) # evaluate on validation set valid_loss, valid_acc = self.validate(epoch) self.scheduler.step() is_best = valid_acc > self.best_valid_acc msg1 = "train loss: {:.3f} - train acc: {:.3f} " msg2 = "- val loss: {:.3f} - val acc: {:.3f}" if is_best: self.counter = 0 msg2 += " [*]" msg = msg1 + msg2 print(msg.format(train_loss, train_acc, valid_loss, valid_acc)) # check for improvement if not is_best: self.counter += 1 if self.counter > self.train_patience: print("[!] No improvement in a while, stopping training.") return self.best_valid_acc = max(valid_acc, self.best_valid_acc) self.save_checkpoint( { 'epoch': epoch + 1, 'model_state': self.model.state_dict(), 'optim_state': self.optimizer.state_dict(), 'best_valid_acc': self.best_valid_acc, }, is_best) def train_one_epoch(self, epoch): """ Train the model for 1 epoch of the training set. An epoch corresponds to one full pass through the entire training set in successive mini-batches. This is used by train() and should not be called manually. 
""" batch_time = AverageMeter() losses = AverageMeter() accs = AverageMeter() tic = time.time() with tqdm(total=self.num_train, disable=self.no_tqdm) as pbar: for i, (x, y) in enumerate(self.train_loader): if self.config.use_translate: x = translate_function(x, original_dataset=x) if self.use_gpu: x, y = x.cuda(), y.cuda() x, y = Variable(x), Variable(y) plot = False if (epoch % self.plot_freq == 0) and (i == 0): plot = True # initialize location vector and hidden state self.batch_size = x.shape[0] h_t, l_t = self.reset() # save images imgs = [] imgs.append(x[0:9]) # extract the glimpses locs = [] log_pi = [] baselines = [] all_log_probas = [] # the prediction at each glimpse step uncertainities = [ ] # the self-uncertainty at each glimpse step uncertainities_baseline = [ ] # the self-uncertainty at each glimpse step, but this baseline is only used for the loss of training self-uncertainty, which only involves the error network. # by default it needs to run `self.num_glimpse` times num_glimpses_taken = [ self.num_glimpses - 1 for _ in range(self.batch_size) ] for t in range(self.num_glimpses): # forward pass through model h_t, l_t, b_t, log_probas, p, diff_uncertainty, diff_uncertainty_baseline = self.model( x, l_t, h_t, last=True) # store locs.append(l_t[0:9]) baselines.append(b_t) log_pi.append(p) all_log_probas.append(log_probas) uncertainities.append(diff_uncertainty) uncertainities_baseline.append(diff_uncertainty_baseline) # convert list to tensors and reshape baselines = torch.stack(baselines).transpose(1, 0) log_pi = torch.stack(log_pi).transpose(1, 0) # if self.config.uncertainty == True: if self.config.uncertainty == True: uncertainities = torch.stack(uncertainities).transpose( 1, 0) uncertainities_baseline = torch.stack( uncertainities_baseline).transpose(1, 0) all_log_probas = torch.stack(all_log_probas).transpose(1, 0) # calculate reward num_glimpses_taken_indices = torch.LongTensor( num_glimpses_taken).type(self.dtypeLong) log_probas = torch.cat([ 
torch.index_select(a, 0, i).unsqueeze(0) for a, i in zip(all_log_probas, num_glimpses_taken_indices) ]).squeeze() predicted = torch.max(log_probas, 1)[1] R = (predicted.detach() == y).float() R = R.unsqueeze(1).repeat(1, self.num_glimpses) # compute losses for differentiable modules num_glimpses_taken = Variable( torch.LongTensor(num_glimpses_taken), requires_grad=False).type(self.dtypeLong) # the mask is used to take only the result of the last glimpse mask = _sequence_mask(sequence_length=num_glimpses_taken, max_len=self.num_glimpses) loss_action = F.nll_loss(log_probas, y, reduction='none') loss_action = torch.mean(loss_action) loss_baseline = F.mse_loss(baselines, R, reduction='none') loss_baseline = torch.mean(loss_baseline * mask) # loss_baseline = torch.mean( loss_baseline ) # compute reinforce loss # summed over timesteps and averaged across batch adjusted_reward = R - baselines.detach() loss_reinforce = torch.sum(-log_pi * adjusted_reward * mask, dim=1) loss_reinforce = torch.mean(loss_reinforce, dim=0) # sum up into a hybrid loss loss = loss_action + loss_baseline + loss_reinforce if self.config.uncertainty == True: y_real_value = F.one_hot( y, self.num_classes).float().detach() diff_ = Variable(torch.abs( y_real_value.unsqueeze(1).expand( -1, self.num_glimpses, -1).data - torch.exp(all_log_probas).data), requires_grad=False) # loss_self_uncertaintiy_baseline = F.mse_loss(uncertainities_baseline, diff_) loss_self_uncertaintiy_baseline = F.mse_loss( uncertainities_baseline, diff_, reduction='none').mean() loss_self_uncertaintiy_baseline = torch.mean( loss_self_uncertaintiy_baseline) loss += loss_self_uncertaintiy_baseline if self.config.intrinsic == True: # the intrinsic sparsity belief reg = self.config.lambda_intrinsic intrinsic_term = torch.sum(-(1.0 / self.num_classes) * log_probas) loss_intrinsic = reg * intrinsic_term loss += loss_intrinsic if self.config.uncertainty == True: # the second reinforce loss: minimizing the uncertainty reg = 
self.config.lambda_uncertainty loss_self_uncertaintiy_minimizing = reg * torch.sum( uncertainities) loss += loss_self_uncertaintiy_minimizing # compute accuracy correct = (predicted == y).float() acc = 100 * (correct.sum() / len(y)) # store losses.update(loss.data, list(x.size())[0]) accs.update(acc.data, list(x.size())[0]) # compute gradients and update SGD self.optimizer.zero_grad() loss.backward() self.optimizer.step() # measure elapsed time toc = time.time() batch_time.update(toc - tic) if self.no_tqdm is not True: pbar.set_description( ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format( (toc - tic), loss.data, acc.data))) pbar.update(self.batch_size) # dump the glimpses and locs if plot: if self.use_gpu: imgs = [g.cpu().data.numpy().squeeze() for g in imgs] locs = [l.cpu().data.numpy() for l in locs] else: imgs = [g.data.numpy().squeeze() for g in imgs] locs = [l.data.numpy() for l in locs] pickle.dump( imgs, open(self.plot_dir + "g_{}.p".format(epoch + 1), "wb")) pickle.dump( locs, open(self.plot_dir + "l_{}.p".format(epoch + 1), "wb")) # log to tensorboard if self.use_tensorboard: iteration = epoch * len(self.train_loader) + i log_value('train_loss', losses.avg, iteration) log_value('train_acc', accs.avg, iteration) return losses.avg, accs.avg def validate(self, epoch, M=1): """ Evaluate the model on the validation set. 
""" losses = AverageMeter() accs = AverageMeter() for i, (x, y) in enumerate(self.valid_loader): if self.config.use_translate: x = translate_function(x, original_dataset=x) if self.use_gpu: x, y = x.cuda(), y.cuda() x, y = Variable(x), Variable(y) # duplicate M times x = x.repeat(M, 1, 1, 1) # initialize location vector and hidden state self.batch_size = x.shape[0] h_t, l_t = self.reset() # extract the glimpses locs = [] log_pi = [] baselines = [] all_log_probas = [] uncertainities = [] uncertainities_baseline = [] # by default it needs to run `self.num_glimpse` times num_glimpses_taken = [ self.num_glimpses - 1 for _ in range(self.batch_size) ] for t in range(self.num_glimpses): # forward pass through model h_t, l_t, b_t, log_probas, p, diff_uncertainty, diff_uncertainty_baseline = self.model( x, l_t, h_t, last=True) # store locs.append(l_t[0:9]) baselines.append(b_t) log_pi.append(p) all_log_probas.append(log_probas) uncertainities.append(diff_uncertainty) uncertainities_baseline.append(diff_uncertainty_baseline) # convert list to tensors and reshape baselines = torch.stack(baselines).transpose(1, 0) log_pi = torch.stack(log_pi).transpose(1, 0) if self.config.uncertainty == True: uncertainities = torch.stack(uncertainities).transpose(1, 0) uncertainities_baseline = torch.stack( uncertainities_baseline).transpose(1, 0) all_log_probas = torch.stack(all_log_probas).transpose(1, 0) # calculate reward num_glimpses_taken_indices = torch.LongTensor( num_glimpses_taken).type(self.dtypeLong) log_probas = torch.cat([ torch.index_select(a, 0, i).unsqueeze(0) for a, i in zip(all_log_probas, num_glimpses_taken_indices) ]).squeeze() # average the `self.M` times of prediction log_probas = log_probas.view(M, -1, log_probas.shape[-1]) log_probas = torch.mean(log_probas, dim=0) predicted = torch.max(log_probas, 1)[1] R = (predicted.detach() == y).float() R = R.unsqueeze(1).repeat(M, self.num_glimpses) # compute losses for differentiable modules num_glimpses_taken = 
Variable(torch.LongTensor(num_glimpses_taken), requires_grad=False).type( self.dtypeLong) mask = _sequence_mask(sequence_length=num_glimpses_taken, max_len=self.num_glimpses) loss_action = F.nll_loss(log_probas, y, reduction='none') loss_action = torch.mean(loss_action) loss_baseline = F.mse_loss(baselines, R, reduction='none') loss_baseline = torch.mean(loss_baseline * mask) adjusted_reward = R - baselines.detach() loss_reinforce = torch.sum(-log_pi * adjusted_reward * mask, dim=1) loss_reinforce = torch.mean(loss_reinforce, dim=0) # sum up into a hybrid loss loss = loss_action + loss_baseline + loss_reinforce if self.config.uncertainty == True: y_real_value = F.one_hot(y, self.num_classes).float().detach() diff_ = Variable(torch.abs( y_real_value.unsqueeze(1).expand(-1, self.num_glimpses, -1).data - torch.exp(all_log_probas).data), requires_grad=False) loss_self_uncertaintiy_baseline = F.mse_loss( uncertainities_baseline, diff_, reduction='none').mean() loss_self_uncertaintiy_baseline = torch.mean( loss_self_uncertaintiy_baseline) loss += loss_self_uncertaintiy_baseline if self.config.intrinsic == True: # the intrinsic sparsity belief reg = self.config.lambda_intrinsic loss_intrinsic = reg * torch.sum( -(1.0 / self.num_classes) * log_probas) loss += loss_intrinsic if self.config.uncertainty == True: # the second reinforce loss: minimizing the uncertainty reg = self.config.lambda_uncertainty loss_self_uncertaintiy_minimizing = reg * torch.sum( uncertainities) loss += loss_self_uncertaintiy_minimizing # compute accuracy correct = (predicted == y).float() acc = 100 * (correct.sum() / len(y)) # store losses.update(loss.data, list(x.size())[0]) accs.update(acc.data, list(x.size())[0]) # log to tensorboard if self.use_tensorboard: iteration = epoch * len(self.valid_loader) + i log_value('valid_loss', losses.avg, iteration) log_value('valid_acc', accs.avg, iteration) return losses.avg, accs.avg def test(self): """ Test the model on the held-out test data. 
This function should only be called at the very end once the model has finished training. """ correct = 0 # load the best checkpoint self.load_checkpoint(best=self.best) self.num_test = len(self.test_loader.sampler) all_num_glimpses_taken = [] for i, (x, y) in enumerate(self.test_loader): torch.manual_seed(self.config.random_seed) if self.use_gpu: x, y = x.cuda(), y.cuda() x, y = Variable(x), Variable(y) # duplicate 10 times x = x.repeat(self.M, 1, 1, 1) # initialize location vector and hidden state self.batch_size = x.shape[0] h_t, l_t = self.reset() # extract the glimpses locs = [] log_pi = [] baselines = [] all_log_probas = [] uncertainities = [] # by default it needs to run `self.num_glimpse` times num_glimpses_taken = [ self.config.num_glimpses - 1 for _ in range(self.batch_size) ] for t in range(self.config.num_glimpses): # forward pass through model h_t, l_t, b_t, log_probas, p, diff_uncertainty, diff_uncertainty_baseline = self.model( x, l_t, h_t, last=True) # store locs.append(l_t[0:9]) baselines.append(b_t) log_pi.append(p) all_log_probas.append(log_probas) uncertainities.append(diff_uncertainty) if self.config.dynamic == True: # determine if it has achieve a threshold uncertainty probs_data = torch.exp(log_probas).data.tolist() diff_uncertainty_data = diff_uncertainty.data.tolist() for instance_idx, (prediction, uncertainty) in enumerate( zip(probs_data, diff_uncertainty_data)): a_star_idx = max(enumerate(prediction), key=lambda x: x[1])[0] a_prime_idx = max( [(idx, pred + self.config.exploration_rate * uncertainty[idx]) for idx, pred in enumerate(prediction) if idx != a_star_idx], key=lambda x: x[1])[0] a_star_lower_bound = prediction[ a_star_idx] - self.config.exploration_rate * uncertainty[ a_star_idx] a_prime_upper_bound = prediction[ a_prime_idx] - self.config.exploration_rate * uncertainty[ a_prime_idx] if a_star_lower_bound >= a_prime_upper_bound: num_glimpses_taken[instance_idx] = t if all([ num < self.config.num_glimpses - 1 for num in 
num_glimpses_taken ]): # print(num_glimpses_taken) break # print('strange! end now!:',t) # convert list to tensors and reshape baselines = torch.stack(baselines).transpose(1, 0) log_pi = torch.stack(log_pi).transpose(1, 0) if self.config.uncertainty == True or self.config.dynamic == True: uncertainities = torch.stack(uncertainities).transpose(1, 0) all_log_probas = torch.stack(all_log_probas).transpose(1, 0) all_num_glimpses_taken.extend(num_glimpses_taken) # calculate reward num_glimpses_taken_indices = torch.LongTensor( num_glimpses_taken).type(self.dtypeLong) log_probas = torch.cat([ torch.index_select(a, 0, i).unsqueeze(0) for a, i in zip(all_log_probas, num_glimpses_taken_indices) ]).squeeze() # average the `self.M` times of prediction log_probas = log_probas.view(self.M, -1, log_probas.shape[-1]) log_probas = torch.mean(log_probas, dim=0) pred = log_probas.data.max(1, keepdim=True)[1] correct += pred.eq(y.data.view_as(pred)).cpu().sum() perc = (100. * correct) / (self.num_test) error = 100 - perc print('[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)'.format( correct, self.num_test, perc, error)) if self.config.dynamic == True: print('use dynamic') avg_num_glimpses_taken = sum(all_num_glimpses_taken) / len( all_num_glimpses_taken) + 1 return (avg_num_glimpses_taken, 1.0 * correct.tolist() / self.num_test) return 1.0 * correct.tolist() / self.num_test # return perc.tolist() def test_for_all( self, range_all=100, ): """ Test the model on the held-out test data. 
This is used to run the model under different number of glimpses """ correct = [] for _ in range(range_all): correct.append(0) # load the best checkpoint self.load_checkpoint(best=self.best) self.num_test = len(self.test_loader.sampler) all_num_glimpses_taken = [] for i, (x, y) in enumerate(tqdm(self.test_loader)): torch.manual_seed(self.config.random_seed) if self.use_gpu: x, y = x.cuda(), y.cuda() x, y = Variable(x), Variable(y) # duplicate 10 times x = x.repeat(self.M, 1, 1, 1) # initialize location vector and hidden state self.batch_size = x.shape[0] h_t, l_t = self.reset() # extract the glimpses locs = [] log_pi = [] baselines = [] all_log_probas = [] uncertainities = [] # by default it needs to run `self.num_glimpse` times num_glimpses_taken = [ range_all - 1 for _ in range(self.batch_size) ] for t in range(self.config.num_glimpses): # forward pass through model h_t, l_t, b_t, log_probas, p, diff_uncertainty, diff_uncertainty_baseline = self.model( x, l_t, h_t, last=True) # store locs.append(l_t[0:9]) baselines.append(b_t) log_pi.append(p) all_log_probas.append(log_probas) uncertainities.append(diff_uncertainty) if self.config.dynamic == True: # determine if it has achieve a threshold uncertainty probs_data = torch.exp(log_probas).data.tolist() diff_uncertainty_data = diff_uncertainty.data.tolist() for instance_idx, (prediction, uncertainty) in enumerate( zip(probs_data, diff_uncertainty_data)): a_star_idx = max(enumerate(prediction), key=lambda x: x[1])[0] a_prime_idx = max( [(idx, pred + self.config.exploration_rate * uncertainty[idx]) for idx, pred in enumerate(prediction) if idx != a_star_idx], key=lambda x: x[1])[0] a_star_lower_bound = prediction[ a_star_idx] - self.config.exploration_rate * uncertainty[ a_star_idx] a_prime_upper_bound = prediction[ a_prime_idx] - self.config.exploration_rate * uncertainty[ a_prime_idx] if a_star_lower_bound >= a_prime_upper_bound: num_glimpses_taken[instance_idx] = t if all([ num < self.config.num_glimpses - 1 for num 
in num_glimpses_taken ]): # print(num_glimpses_taken) break # convert list to tensors and reshape baselines = torch.stack(baselines).transpose(1, 0) log_pi = torch.stack(log_pi).transpose(1, 0) if self.config.uncertainty == True or self.config.dynamic == True: uncertainities = torch.stack(uncertainities).transpose(1, 0) all_log_probas = torch.stack(all_log_probas).transpose(1, 0) all_num_glimpses_taken.extend(num_glimpses_taken) # calculate reward for num in range(range_all): num_glimpses_taken = [num for _ in range(self.batch_size)] num_glimpses_taken_indices = torch.LongTensor( num_glimpses_taken).type(self.dtypeLong) # log_probas = torch.cat([ torch.index_select(a, 0, i).unsqueeze(0) for a, i in zip(all_log_probas, num_glimpses_taken_indices) ]).squeeze() log_probas = all_log_probas[:, num] # print(all_log_probas.size(),log_probas.size()) # average the `self.M` times of prediction log_probas = log_probas.view(self.M, -1, log_probas.shape[-1]) log_probas = torch.mean(log_probas, dim=0) pred = log_probas.data.max(1, keepdim=True)[1] correct[num] += pred.eq(y.data.view_as(pred)).cpu().sum() return [1.0 * cor.tolist() / self.num_test for cor in correct] # return 1.0 * correct.tolist() / self.num_test def save_checkpoint(self, state, is_best): """ Save a copy of the model so that it can be loaded at a future date. This function is used when the model is being evaluated on the test data. If this model has reached the best validation accuracy thus far, a seperate file with the suffix `best` is created. """ # print("[*] Saving model to {}".format(self.ckpt_dir)) filename = self.model_name + '_ckpt.pth.tar' ckpt_path = os.path.join(self.ckpt_dir, filename) torch.save(state, ckpt_path) if is_best: filename = self.model_name + '_model_best.pth.tar' shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename)) def load_checkpoint(self, best=False): """ Load the best copy of a model. This is useful for 2 cases: - Resuming training with the most recent model checkpoint. 
- Loading the best validation model to evaluate on the test data. Params ------ - best: if set to True, loads the best model. Use this if you want to evaluate your model on the test data. Else, set to False in which case the most recent version of the checkpoint is used. """ print("[*] Loading model from {}".format(self.ckpt_dir)) filename = self.model_name + '_ckpt.pth.tar' if best: filename = self.model_name + '_model_best.pth.tar' ckpt_path = os.path.join(self.ckpt_dir, filename) ckpt = torch.load(ckpt_path) # load variables from checkpoint self.start_epoch = ckpt['epoch'] self.best_valid_acc = ckpt['best_valid_acc'] self.model.load_state_dict(ckpt['model_state']) self.optimizer.load_state_dict(ckpt['optim_state']) if best: print("[*] Loaded {} checkpoint @ epoch {} " "with best valid acc of {:.3f}".format( filename, ckpt['epoch'], ckpt['best_valid_acc'])) else: print("[*] Loaded {} checkpoint @ epoch {}".format( filename, ckpt['epoch']))
def train(mode='train',
          train_path='train.conllx',
          model='dozat',
          dataset='conllx',
          dev_path='dev.conllx',
          test_path='test.conllx',
          ud=True,
          output_dir='output',
          emb_dim=0,
          char_emb_dim=0,
          char_model=None,
          tagger=None,
          batch_size=5000,
          n_iters=10,
          dropout_p=0.33,
          num_layers=1,
          print_every=1,
          eval_every=100,
          bi=True,
          var_drop=False,
          upos_pred=False,
          lr=0.001,
          adam_beta1=0.9,
          adam_beta2=0.999,
          weight_decay=0.,
          plateau=False,
          resume=False,
          lr_decay=1.0,
          lr_decay_steps=5000,
          clip=5.,
          momentum=0,
          optimizer='adam',
          glove=True,
          seed=42,
          dim=0,
          window_size=0,
          num_filters=0,
          **kwargs):
    """Train a POS tagger on CoNLL-X/CoNLL-U data.

    Builds vocabularies (optionally initialized with GloVe vectors), a
    `Tagger` model, an Adam/SGD optimizer and an LR scheduler, then runs
    `n_iters` training iterations with periodic logging and dev/test
    evaluation, checkpointing after every evaluation.

    All keyword arguments are also forwarded to `Tagger` via `cfg`
    (captured with `locals()` right below, so no new locals may be
    introduced before that call). Negative `n_iters`/`eval_every` are
    interpreted as numbers of epochs.
    """
    device = torch.device(type='cuda') if use_cuda else torch.device(
        type='cpu')

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Snapshot of all arguments for the Tagger constructor; must stay the
    # first statement after output_dir handling so no extra locals leak in.
    cfg = locals().copy()

    torch.manual_seed(seed)
    np.random.seed(seed)

    # load data component
    if dataset == "conllx":
        dataset_obj = ConllXDataset
        fields = get_data_fields()
        _upos = None
        ud = False
    elif dataset == "conllu":
        dataset_obj = ConllUDataset
        fields = get_data_fields_conllu()
        _upos = fields['upos'][-1]
        ud = True
    else:
        raise NotImplementedError()

    _form = fields['form'][-1]
    _pos = fields['pos'][-1]
    _chars = fields['chars'][-1]

    train_dataset = dataset_obj(train_path, fields)
    dev_dataset = dataset_obj(dev_path, fields)
    test_dataset = dataset_obj(test_path, fields)

    logger.info("Loaded %d train examples" % len(train_dataset))
    logger.info("Number of train tokens: %d" % train_dataset.n_tokens)
    logger.info("Loaded %d dev examples" % len(dev_dataset))
    # fix: these two messages previously claimed "train tokens"
    logger.info("Number of dev tokens: %d" % dev_dataset.n_tokens)
    logger.info("Loaded %d test examples" % len(test_dataset))
    logger.info("Number of test tokens: %d" % test_dataset.n_tokens)

    form_vocab_path = os.path.join(output_dir, 'vocab.form.pth.tar')
    pos_vocab_path = os.path.join(output_dir, 'vocab.pos.pth.tar')
    char_vocab_path = os.path.join(output_dir, 'vocab.char.pth.tar')

    if not resume:
        # build vocabularies
        # words have a min frequency of 2 to be included; others become <unk>
        # words without a Glove vector are initialized ~ N(0, 0.5) mimicking Glove
        # Note: this requires the latest torchtext development version from Github.
        # - git clone https://github.com/pytorch/text.git torchtext
        # - cd torchtext
        # - python setup.py build
        # - python setup.py install

        def unk_init(x):
            # return 0.01 * torch.randn(x)
            return torch.zeros(x)

        if glove:
            logger.info("Using Glove vectors")
            glove_vectors = GloVe(name='6B', dim=100)
            _form.build_vocab(train_dataset,
                              min_freq=2,
                              unk_init=unk_init,
                              vectors=glove_vectors)
            n_unks = 0
            unk_set = set()
            # for now, set UNK words manually
            # (torchtext does not seem to support it yet)
            for i, token in enumerate(_form.vocab.itos):
                if token not in glove_vectors.stoi:
                    n_unks += 1
                    unk_set.add(token)
                    _form.vocab.vectors[i] = unk_init(emb_dim)
            # print(n_unks, unk_set)
        else:
            _form.build_vocab(train_dataset, min_freq=2)

        _pos.build_vocab(train_dataset)
        if ud:
            _upos.build_vocab(train_dataset)
        _chars.build_vocab(train_dataset)

        # save vocabularies
        torch.save(_form.vocab, form_vocab_path)
        torch.save(_pos.vocab, pos_vocab_path)
        torch.save(_chars.vocab, char_vocab_path)
    else:
        # load vocabularies
        _form.vocab = torch.load(form_vocab_path)
        _pos.vocab = torch.load(pos_vocab_path)
        _chars.vocab = torch.load(char_vocab_path)

    print("First 10 vocabulary entries, words: ",
          " ".join(_form.vocab.itos[:10]))
    print("First 10 vocabulary entries, pos tags: ",
          " ".join(_pos.vocab.itos[:10]))
    print("First 10 vocabulary entries, chars: ",
          " ".join(_chars.vocab.itos[:10]))
    if upos_pred:
        print("First 10 vocabulary entries, upos tags: ",
              " ".join(_upos.vocab.itos[:10]))

    n_words = len(_form.vocab)
    n_tags = len(_pos.vocab)
    if upos_pred:
        n_utags = len(_upos.vocab)
    else:
        n_utags = 0
    n_chars = len(_chars.vocab)

    def batch_size_fn(new, count, sofar):
        # token-based batching: batches are capped by token count, not
        # by number of sentences
        return len(new.form) + 1 + sofar

    # iterators
    train_iter = Iterator(train_dataset,
                          batch_size,
                          train=True,
                          sort_within_batch=True,
                          batch_size_fn=batch_size_fn,
                          device=device)
    dev_iter = Iterator(dev_dataset,
                        32,
                        train=False,
                        sort_within_batch=True,
                        device=device)
    test_iter = Iterator(test_dataset,
                         32,
                         train=False,
                         sort_within_batch=True,
                         device=device)

    # if n_iters or eval_every are negative, we set them to that many
    # number of epochs
    iters_per_epoch = (len(train_dataset) // batch_size) + 1
    if eval_every < 0:
        logger.info("Setting eval_every to %d epoch(s) = %d iters" %
                    (-1 * eval_every, -1 * eval_every * iters_per_epoch))
        # stays negative; `iter_i % eval_every == 0` still fires on multiples
        eval_every = iters_per_epoch * eval_every

    if n_iters < 0:
        logger.info("Setting n_iters to %d epoch(s) = %d iters" %
                    (-1 * n_iters, -1 * n_iters * iters_per_epoch))
        n_iters = -1 * n_iters * iters_per_epoch

    # load up the model
    if upos_pred:
        upos_vocab = _upos.vocab
    else:
        upos_vocab = None
    model = Tagger(n_words=n_words,
                   n_tags=n_tags,
                   n_utags=n_utags,
                   n_chars=n_chars,
                   form_vocab=_form.vocab,
                   char_vocab=_chars.vocab,
                   pos_vocab=_pos.vocab,
                   upos_vocab=upos_vocab,
                   **cfg)

    # set word vectors
    if glove:
        # normalize the embedding matrix to unit standard deviation
        _form.vocab.vectors = _form.vocab.vectors / torch.std(
            _form.vocab.vectors)
        # print(torch.std(_form.vocab.vectors))
        model.encoder.embedding.weight.data.copy_(_form.vocab.vectors)
        model.encoder.embedding.weight.requires_grad = True

    model = model.cuda() if use_cuda else model

    start_iter = 1
    best_iter = 0
    best_pos_acc = -1.
    test_pos_acc = -1.

    # optimizer and learning rate scheduler
    trainable_parameters = [p for p in model.parameters() if p.requires_grad]
    if optimizer == 'sgd':
        optimizer = torch.optim.SGD(trainable_parameters,
                                    lr=lr,
                                    momentum=momentum)
    else:
        optimizer = torch.optim.Adam(trainable_parameters,
                                     lr=lr,
                                     betas=(adam_beta1, adam_beta2))

    # learning rate schedulers
    if not plateau:
        scheduler = LambdaLR(
            optimizer, lr_lambda=lambda t: lr_decay**(t / lr_decay_steps))
    else:
        scheduler = ReduceLROnPlateau(optimizer,
                                      mode='max',
                                      factor=0.75,
                                      patience=5,
                                      min_lr=1e-4)

    # load model and vocabularies if resuming
    if resume:
        if os.path.isfile(resume):
            print("=> loading checkpoint '{}'".format(resume))
            checkpoint = torch.load(resume)
            start_iter = checkpoint['iter_i']
            best_pos_acc = checkpoint['best_pos_acc']
            test_pos_acc = checkpoint['test_pos_acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (iter {})".format(
                resume, checkpoint['iter_i']))
        else:
            print("=> no checkpoint found at '{}'".format(resume))

    print_parameters(model)

    # print some stuff just for fun
    logger.info("Most common words: %s" % _form.vocab.freqs.most_common(20))
    logger.info("Word vocab size: %s" % n_words)
    logger.info("Most common XPOS-tags: %s" % _pos.vocab.freqs.most_common())
    logger.info("POS vocab size: %s" % n_tags)
    # logger.info("Most common chars: %s" % _chars.nesting_field.vocab.freqs.most_common())
    logger.info("Chars vocab size: %s" % n_chars)

    print("First training example:")
    print_example(train_dataset[0])
    print("First dev example:")
    print_example(dev_dataset[0])
    print("First test example:")
    print_example(test_dataset[0])

    logger.info("Training starts..")
    upos_var, morph_var = None, None
    for iter_i in range(start_iter, n_iters + 1):

        # fix: the iteration-based scheduler must not be stepped when the
        # plateau scheduler is active — ReduceLROnPlateau.step() requires a
        # metrics argument and is stepped with _dev_pos_acc after evaluation
        # below; the old unconditional call raised TypeError for plateau=True.
        # (A previously dead, identical-branch computation of `epoch_done`
        # guarded by a "# TODO: fix" per-epoch step was removed here.)
        if not plateau:
            scheduler.step()

        model.train()
        # draw one (reshuffled) batch per iteration from the training iterator
        batch = next(iter(train_iter))
        form_var, lengths = batch.form
        pos_var, pos_lengths = batch.pos
        if upos_pred:
            upos_var, _ = batch.upos
        else:
            upos_var = None
        char_var, sentence_lengths, word_lengths = batch.chars
        lengths = lengths.view(-1).tolist()

        result = model(form_var=form_var,
                       char_var=char_var,
                       pos_var=pos_var,
                       lengths=lengths,
                       word_lengths=word_lengths,
                       pos_lengths=pos_lengths)
        if upos_pred:
            targets = dict(pos=batch.pos, upos=batch.upos)
        else:
            targets = dict(pos=batch.pos, upos=None)
        all_losses = model.get_loss(scores=result, targets=targets)
        loss = all_losses['loss']
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        optimizer.zero_grad()

        if iter_i % print_every == 0:
            # get scores for this batch
            upos_predictions = []
            if model.tagger == "linear" or model.tagger == "mlp":
                if model.upos_pred:
                    upos_predictions = result['output']['upos'].max(2)[1]
                    pos_predictions = result['output']['xpos'].max(2)[1]
                else:
                    pos_predictions = result['output']['xpos'].max(2)[1]
            else:
                pos_predictions = result['sequence']
            predictions = dict(pos=pos_predictions, upos=upos_predictions)
            if model.upos_pred:
                targets = dict(pos=batch.pos, upos=batch.upos)
            else:
                targets = dict(pos=batch.pos, upos=None)
            pos_acc, upos_acc = model.get_accuracy(predictions=predictions,
                                                   targets=targets)

            if not plateau:
                lr = scheduler.get_lr()[0]
            else:
                lr = [group['lr'] for group in optimizer.param_groups][0]
            fmt = "Iter %08d loss %8.4f pos-acc %5.2f upos-acc %5.2f lr %.5f"
            logger.info(fmt % (iter_i, loss, pos_acc, upos_acc, lr))

        if iter_i % eval_every == 0:
            # parse dev set and save to file for official evaluation
            dev_out_path = 'dev.iter%08d.conll' % iter_i
            dev_out_path = os.path.join(output_dir, dev_out_path)
            predict_and_save(dataset=dev_dataset,
                             model=model,
                             dataset_path=dev_path,
                             out_path=dev_out_path)
            _dev_pos_acc, _dev_upos_acc = get_pos_acc(dev_path, dev_out_path,
                                                      ud)
            logger.info("Evaluation dev Iter %08d "
                        "pos-acc %5.2f upos-acc %5.2f" %
                        (iter_i, _dev_pos_acc, _dev_upos_acc))

            # parse test set and save to file for official evaluation
            test_out_path = 'test.iter%08d.conll' % iter_i
            test_out_path = os.path.join(output_dir, test_out_path)
            predict_and_save(dataset=test_dataset,
                             model=model,
                             dataset_path=test_path,
                             out_path=test_out_path)
            _test_pos_acc, _test_upos_acc = get_pos_acc(
                test_path, test_out_path, ud)
            logger.info("Evaluation test Iter %08d "
                        "pos-acc %5.2f upos-acc %5.2f" %
                        (iter_i, _test_pos_acc, _test_upos_acc))

            # the plateau scheduler is driven by dev accuracy
            if plateau:
                scheduler.step(_dev_pos_acc)

            if _dev_pos_acc > best_pos_acc:
                best_iter = iter_i
                best_pos_acc = _dev_pos_acc
                test_pos_acc = _test_pos_acc
                is_best = True
            else:
                is_best = False

            save_checkpoint(
                output_dir,
                {
                    'iter_i': iter_i,
                    'state_dict': model.state_dict(),
                    'best_iter': best_iter,
                    # fix: resume reads checkpoint['best_pos_acc'] — it was
                    # never saved, so every resume raised KeyError
                    'best_pos_acc': best_pos_acc,
                    'test_pos_acc': test_pos_acc,
                    'optimizer': optimizer.state_dict(),
                },
                # NOTE(review): `is_best` is computed but False is passed —
                # looks intentional (no "best" copy kept); confirm with
                # save_checkpoint's signature before changing.
                False)

    logger.info("Done Training")
    logger.info(
        "Best model Iter %08d Dev POS-acc %12.4f Test POS-acc %12.4f " %
        (best_iter, best_pos_acc, test_pos_acc))
class FNN:
    """Feed-forward network trainer with an epoch-decay LambdaLR schedule.

    Workflow: ``set_data`` -> ``train(config)`` (builds model/loaders/optim and
    runs ``train_start``) -> ``evaluate``/``save``/``load``.  Integrates with
    ray.tune when available; falls back to plain training otherwise.
    """

    def __init__(self):
        ## Device configuration
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def set_data(self, features, targets, D, denom_sq):
        """Store numpy data splits and the scaling matrix / normalization.

        `features`/`targets` are dicts keyed by split name ('train',
        'validate'); `D` is a weighting matrix applied inside the loss;
        `denom_sq` is the (squared) normalization denominator.
        """
        self.features_np = features
        self.targets_np = targets
        self.D_np = D
        # Precompute the reciprocal once; reused on every loss evaluation.
        self.inv_denom_sq = denom_sq**-1

    def train(self, config):
        """Build model, loaders, optimizer and LR schedule, then train.

        `config` (dict or falsy) overrides the internal defaults below.
        """
        ## Internal config (defaults; overridden by the external config)
        self.config = {}
        self.config['num_epochs'] = 5000
        self.config['n_hidden'] = 2
        self.config['hidden_size'] = 40
        self.config['batch_size'] = 10
        self.config['lr'] = 1e-2
        self.config['regularization'] = 1e-10
        # Overwrite internal config values given in the external config
        if config:
            for key in config.keys():
                self.config[key] = config[key]
        # Assume we're using ray.tune at first; train_start flips this off
        # the first time tune reporting fails.
        self.tuning = True
        ## Model — input/output sizes inferred from the training arrays.
        self.config['input_size'] = self.features_np['train'].shape[1]
        self.config['output_size'] = self.targets_np['train'].shape[1]
        self.model = Model(self.config).to(self.device)
        ## Data loaders
        self.batch_size = self.config['batch_size']
        self.train_loader = data_loader.create_loader(
            self.features_np['train'], self.targets_np['train'],
            self.batch_size, True)
        self.validate_loader = data_loader.create_loader(
            self.features_np['validate'], self.targets_np['validate'],
            self.features_np['validate'].shape[0],  # use all test samples
            False)  # don't shuffle
        ## Hyperparameters
        self.num_epochs = self.config['num_epochs']
        self.learning_rate = self.config['lr']
        ## Loss and optimizer
        self.criterion = self.eps_reg_sq
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.learning_rate, eps=1e-8,
            weight_decay=self.config['regularization'])
        # Hyperbolic decay: lr multiplier = 1 / (1 + 0.005 * epoch).
        lambdaLR = lambda epoch: 1 / (1 + 0.005*epoch)
        self.scheduler = LambdaLR(self.optimizer, lr_lambda=lambdaLR)
        self.train_start()

    def train_start(self):
        """Run the training loop; stops early if the loss goes NaN.

        Returns self for chaining.
        """
        early_stop = False
        self.D = torch.from_numpy(self.D_np).float().to(self.device)
        for epoch in range(self.num_epochs):
            for i, (features, targets) in enumerate(self.train_loader):
                self.model.train()
                self.optimizer.zero_grad()
                # Move tensors to the configured device
                features = features.to(self.device)
                targets = targets.to(self.device)
                # Forward pass
                outputs = self.model(features)
                loss = self.criterion(outputs, targets) ** 0.5
                if torch.isnan(loss):
                    print('Something went nan, stopping')
                    early_stop = True
                    break  # break out of this batch
                # Backward and optimize
                loss.backward()
                self.optimizer.step()
            if early_stop:
                break  # break out of this epoch
            self.scheduler.step()
            if epoch%10==0 or epoch==self.num_epochs-1:
                validate_loss = self.get_loss(self.validate_loader)
                train_loss = self.get_loss(self.train_loader)
                print('eps_reg: Epoch [{}/{}], LR: {:.2e}, Train loss: {:.2e}, Validate loss: {:.2e}'
                      .format(epoch+1, self.num_epochs, self.scheduler.get_lr()[0],
                              train_loss.item()**0.5, validate_loss.item()**0.5))
                if self.tuning:
                    try:
                        tune.track.log(mean_loss = validate_loss.item(), episodes_this_iter = 10)
                    # FIX: was a bare `except:` which also swallows
                    # KeyboardInterrupt/SystemExit; only tune failures should
                    # disable tuning mode.
                    except Exception:
                        self.tuning = False
        return self

    def eps_reg_sq(self, outputs, targets):
        """Squared, D-weighted, normalized error averaged over the batch."""
        return torch.sum((self.D*(targets - outputs)) ** 2) * self.inv_denom_sq / targets.shape[0]

    def get_loss(self, loader):
        """Average criterion over all batches of `loader` (no grad)."""
        with torch.no_grad():
            self.model.eval()
            loss = 0.0
            for features, targets in loader:
                features = features.to(self.device)
                targets = targets.to(self.device)
                outputs = self.model(features)
                loss += self.criterion(outputs, targets)
            return loss/len(loader)

    def evaluate(self, features):
        """Run inference on raw `features` and return a numpy array.

        NOTE(review): input is not moved to self.device and the output is
        .numpy()'d directly — presumably only used with a CPU model; confirm.
        """
        with torch.no_grad():
            self.model.eval()
            output = self.model(torch.tensor(features).float())
            u_rb = output.numpy()
            return u_rb

    def save(self, model_dir, component):
        """Pickle the config and save the model state dict.

        Prefers the ray.tune trial directory; falls back to
        model_dir/FNN/<component> when not running under tune.
        """
        try:
            path_config = os.path.join(tune.track.trial_dir(),'config')
            path_state_dict = os.path.join(tune.track.trial_dir(),'state_dict')
        # FIX: narrowed from bare `except:` — any tune failure means "not
        # tuning"; interrupts should still propagate.
        except Exception:  # not tuning
            path_config = os.path.join(model_dir, 'FNN', component,'config')
            path_state_dict = os.path.join(model_dir, 'FNN', component,'state_dict')
        with open(path_config, 'wb+') as f:
            pickle.dump(self.config, f)
        torch.save(self.model.state_dict(), path_state_dict)

    def load(self, model_dir, component):
        '''
        Find and loads the best model from ray.tune analysis results.
        '''
        try:
            path_analysis = os.path.join(model_dir,'FNN',component)
            analysis = tune.Analysis(path_analysis)
            df_temp = analysis.dataframe()
            # Pick the trial with the lowest reported mean_loss.
            idx = df_temp['mean_loss'].idxmin()
            logdir = df_temp.loc[idx]['logdir']
            path_config = os.path.join(logdir,'config')
            path_state_dict = os.path.join(logdir,'state_dict')
        # FIX: narrowed from bare `except:` — fall back to the fixed layout
        # only when tune records are genuinely unavailable.
        except Exception:  # no tuning records
            path_config = os.path.join(model_dir, 'FNN', component,'config')
            path_state_dict = os.path.join(model_dir, 'FNN', component,'state_dict')
        with open(path_config, 'rb') as f:
            config = pickle.load(f)
        self.model = Model(config).to(self.device)
        # map_location='cpu' so GPU-trained checkpoints load on CPU-only hosts.
        state_dict = torch.load(path_state_dict, map_location=torch.device('cpu'))
        self.model.load_state_dict(state_dict)
def train(train_data_path: str, val_data_paths: dict, use_cuda: bool,
          model_name: str, is_baseline: bool, resume_from_file=None):
    """Train a gSCAN model with exponential LR decay and periodic evaluation.

    Args:
        train_data_path: path to the training split.
        val_data_paths: mapping split_name -> path for validation splits.
        use_cuda: move model to GPU when True.
        model_name: used for the output directory and checkpoint file names.
        is_baseline: passed through to the model constructor.
        resume_from_file: optional checkpoint path to resume from.

    Note: the outer `while` counter is named "iteration" but each pass is one
    full epoch over the training iterator (see inline warnings below).
    """
    logger.info("Loading Training set...")
    logger.info(model_name)
    train_iter, train_input_vocab, train_target_vocab = dataloader(
        train_data_path, batch_size=cfg.TRAIN.BATCH_SIZE, use_cuda=use_cuda)
    # Validation iterators share the training vocabularies.
    val_iters = {}
    for split_name, path in val_data_paths.items():
        val_iters[split_name], _, _ = dataloader(
            path, batch_size=cfg.VAL_BATCH_SIZE, use_cuda=use_cuda,
            input_vocab=train_input_vocab, target_vocab=train_target_vocab)
    pad_idx, sos_idx, eos_idx = train_target_vocab.stoi['<pad>'], train_target_vocab.stoi['<sos>'], \
        train_target_vocab.stoi['<eos>']
    train_input_vocab_size, train_target_vocab_size = len(
        train_input_vocab.itos), len(train_target_vocab.itos)
    ''' Input (command) [0]: batch_size x max_cmd_len [1]: batch_size x 0 (len for each cmd) Situation: batch_size x grid x grid x feat_size Target (action) [0]: batch_size x max_action_len [1]: batch_size x 0 (len for each action sequence) max_cmd_len = 6, max_action_len = 16 '''
    logger.info("Done Loading Training set.")
    # if generate_vocabularies:
    #     training_set.save_vocabularies(input_vocab_path, target_vocab_path)
    #     logger.info("Saved vocabularies to {} for input and {} for target.".format(input_vocab_path, target_vocab_path))
    logger.info("Loading Dev. set...")
    # val_input_vocab_size, val_target_vocab_size = train_input_vocab_size, train_target_vocab_size
    # Shuffle the test set to make sure that if we only evaluate max_testing_examples we get a random part of the set.
    # val_set.shuffle_data()
    logger.info("Done Loading Dev. set.")
    model = GSCAN_model(pad_idx, eos_idx, train_input_vocab_size,
                        train_target_vocab_size, is_baseline=is_baseline,
                        output_directory=os.path.join(os.getcwd(),
                                                      cfg.OUTPUT_DIRECTORY,
                                                      model_name))
    model = model.cuda() if use_cuda else model
    log_parameters(model)
    trainable_parameters = [
        parameter for parameter in model.parameters()
        if parameter.requires_grad
    ]
    optimizer = torch.optim.Adam(trainable_parameters,
                                 lr=cfg.TRAIN.SOLVER.LR,
                                 betas=(cfg.TRAIN.SOLVER.ADAM_BETA1,
                                        cfg.TRAIN.SOLVER.ADAM_BETA2))
    # Exponential decay: multiplier = LR_DECAY ** (step / LR_DECAY_STEP).
    scheduler = LambdaLR(optimizer,
                         lr_lambda=lambda t: cfg.TRAIN.SOLVER.LR_DECAY**
                         (t / cfg.TRAIN.SOLVER.LR_DECAY_STEP))
    start_iteration = 1
    best_exact_match = 0
    if resume_from_file:
        assert os.path.isfile(
            resume_from_file), "No checkpoint found at {}".format(
                resume_from_file)
        logger.info(
            "Loading checkpoint from file at '{}'".format(resume_from_file))
        optimizer_state_dict = model.load_model(resume_from_file)
        optimizer.load_state_dict(optimizer_state_dict)
        start_iteration = model.trained_iterations
        logger.info("Loaded checkpoint '{}' (iter {})".format(
            resume_from_file, start_iteration))
    logger.info("Training starts..")
    training_iteration = start_iteration
    while training_iteration < cfg.TRAIN.MAX_EPOCH:  # iterations here actually means "epoch"
        # Shuffle the dataset and loop over it.
        # training_set.shuffle_data()
        num_batch = 0
        for x in train_iter:
            is_best = False
            model.train()
            target_scores, target_position_scores = model(
                x.input, x.situation, x.target)
            loss = model.get_loss(target_scores, x.target[0])
            target_loss = 0
            if cfg.AUXILIARY_TASK:
                target_loss = model.get_auxiliary_loss(target_position_scores,
                                                       x.target)
            loss += cfg.TRAIN.WEIGHT_TARGET_LOSS * target_loss
            # Backward pass and update model parameters.
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            model.update_state(is_best=is_best)
            # Print current metrics.
            if num_batch % cfg.PRINT_EVERY == 0:
                accuracy, exact_match = model.get_metrics(
                    target_scores, x.target[0])
                if cfg.AUXILIARY_TASK:
                    auxiliary_accuracy_target = model.get_auxiliary_accuracy(
                        target_position_scores, x.target)
                else:
                    auxiliary_accuracy_target = 0.
                learning_rate = scheduler.get_lr()[0]
                logger.info(
                    "Iteration %08d, loss %8.4f, accuracy %5.2f, exact match %5.2f, learning_rate %.5f,"
                    " aux. accuracy target pos %5.2f" %
                    (training_iteration, loss, accuracy, exact_match,
                     learning_rate, auxiliary_accuracy_target))
            num_batch += 1
        # Per-epoch evaluation on all validation splits; the 'dev' split
        # drives best-model selection.
        if training_iteration % cfg.EVALUATE_EVERY == 0:
            with torch.no_grad():
                model.eval()
                logger.info("Evaluating..")
                test_exact_match = 0
                test_accuracy = 0
                try:
                    for split_name, val_iter in val_iters.items():
                        accuracy, exact_match, target_accuracy = evaluate(
                            val_iter, model=model, max_decoding_steps=30,
                            pad_idx=pad_idx, sos_idx=sos_idx, eos_idx=eos_idx,
                            max_examples_to_evaluate=None)
                        if split_name == 'dev':
                            test_exact_match = exact_match
                            test_accuracy = accuracy
                        logger.info(" %s Accuracy: %5.2f Exact Match: %5.2f "
                                    " Target Accuracy: %5.2f " %
                                    (split_name, accuracy, exact_match,
                                     target_accuracy))
                # FIX: was a bare `except:` which also swallowed
                # KeyboardInterrupt/SystemExit; evaluation failures are still
                # tolerated (best-effort) but interrupts now propagate.
                except Exception:
                    print("Exception!")
                if test_exact_match > best_exact_match:
                    is_best = True
                    best_accuracy = test_accuracy
                    best_exact_match = test_exact_match
                    model.update_state(accuracy=test_accuracy,
                                       exact_match=test_exact_match,
                                       is_best=is_best)
                file_name = model_name + "checkpoint.{}th.tar".format(
                    str(training_iteration))
                # file_name = os.path.join(os.getcwd(), cfg.OUTPUT_DIRECTORY, model_name, file_name)
                if is_best:
                    logger.info("saving best model...")
                    model.save_checkpoint(
                        file_name=file_name, is_best=is_best,
                        optimizer_state_dict=optimizer.state_dict())
        if training_iteration % cfg.SAVE_EVERY == 0:
            logger.info("forcing to save model every several epochs...")
            file_name = model_name + " checkpoint_force.{}th.tar".format(
                str(training_iteration))
            # file_name = os.path.join(os.getcwd(), cfg.OUTPUT_DIRECTORY, model_name, file_name)
            model.save_checkpoint(file_name=file_name, is_best=False,
                                  optimizer_state_dict=optimizer.state_dict())
        training_iteration += 1  # warning: iteratin represents epochs here
    logger.info("Finished training.")
def train(
        data_path: str,
        data_directory: str,
        generate_vocabularies: bool,
        input_vocab_path: str,
        target_vocab_path: str,
        embedding_dimension: int,
        num_encoder_layers: int,
        encoder_dropout_p: float,
        encoder_bidirectional: bool,
        training_batch_size: int,
        test_batch_size: int,
        max_decoding_steps: int,
        num_decoder_layers: int,
        decoder_dropout_p: float,
        cnn_kernel_size: int,
        cnn_dropout_p: float,
        cnn_hidden_num_channels: int,
        simple_situation_representation: bool,
        decoder_hidden_size: int,
        encoder_hidden_size: int,
        learning_rate: float,
        adam_beta_1: float,
        adam_beta_2: float,
        lr_decay: float,
        lr_decay_steps: int,
        resume_from_file: str,
        max_training_iterations: int,
        output_directory: str,
        print_every: int,
        evaluate_every: int,
        conditional_attention: bool,
        auxiliary_task: bool,
        weight_target_loss: float,
        attention_type: str,
        k: int,
        max_training_examples,
        max_testing_examples,
        # SeqGAN params begin
        pretrain_gen_path,
        pretrain_gen_epochs,
        pretrain_disc_path,
        pretrain_disc_epochs,
        rollout_trails,
        rollout_update_rate,
        disc_emb_dim,
        disc_hid_dim,
        load_tensors_from_path,
        # SeqGAN params end
        seed=42,
        **kwargs):
    """Adversarially train a grounded-SCAN generator against a discriminator
    (SeqGAN-style policy-gradient training).

    Per batch: sample sequences from the generator, score them with Monte-Carlo
    rollouts against the discriminator, turn rollout scores into rewards, and
    optimize the generator with a GAN loss; then refresh the rollout module and
    give the discriminator one epoch of training.  Both networks' final state
    dicts are saved at the end.  Generator and discriminator may be pretrained
    or loaded from the given paths.
    """
    device = torch.device("cpu")
    # Snapshot of every argument; forwarded wholesale to Model(**cfg) below.
    cfg = locals().copy()
    torch.manual_seed(seed)
    logger.info("Loading Training set...")
    training_set = GroundedScanDataset(
        data_path,
        data_directory,
        split="train",
        input_vocabulary_file=input_vocab_path,
        target_vocabulary_file=target_vocab_path,
        generate_vocabulary=generate_vocabularies,
        k=k)
    training_set.read_dataset(
        max_examples=max_training_examples,
        simple_situation_representation=simple_situation_representation,
        load_tensors_from_path=load_tensors_from_path
    )  # set this to False if no pickle file available
    logger.info("Done Loading Training set.")
    logger.info("  Loaded {} training examples.".format(
        training_set.num_examples))
    logger.info("  Input vocabulary size training set: {}".format(
        training_set.input_vocabulary_size))
    logger.info("  Most common input words: {}".format(
        training_set.input_vocabulary.most_common(5)))
    logger.info("  Output vocabulary size training set: {}".format(
        training_set.target_vocabulary_size))
    logger.info("  Most common target words: {}".format(
        training_set.target_vocabulary.most_common(5)))
    if generate_vocabularies:
        training_set.save_vocabularies(input_vocab_path, target_vocab_path)
        logger.info(
            "Saved vocabularies to {} for input and {} for target.".format(
                input_vocab_path, target_vocab_path))
    # logger.info("Loading Dev. set...")
    # test_set = GroundedScanDataset(data_path, data_directory, split="dev",
    #                                input_vocabulary_file=input_vocab_path,
    #                                target_vocabulary_file=target_vocab_path, generate_vocabulary=False, k=0)
    # test_set.read_dataset(max_examples=max_testing_examples,
    #                       simple_situation_representation=simple_situation_representation)
    #
    # # Shuffle the test set to make sure that if we only evaluate max_testing_examples we get a random part of the set.
    # test_set.shuffle_data()
    # logger.info("Done Loading Dev. set.")
    generator = Model(
        input_vocabulary_size=training_set.input_vocabulary_size,
        target_vocabulary_size=training_set.target_vocabulary_size,
        num_cnn_channels=training_set.image_channels,
        input_padding_idx=training_set.input_vocabulary.pad_idx,
        target_pad_idx=training_set.target_vocabulary.pad_idx,
        target_eos_idx=training_set.target_vocabulary.eos_idx,
        **cfg)
    # Discriminator vocabulary is the union of input and target vocabularies.
    total_vocabulary = set(
        list(training_set.input_vocabulary._word_to_idx.keys()) +
        list(training_set.target_vocabulary._word_to_idx.keys()))
    total_vocabulary_size = len(total_vocabulary)
    discriminator = Discriminator(embedding_dim=disc_emb_dim,
                                  hidden_dim=disc_hid_dim,
                                  vocab_size=total_vocabulary_size,
                                  max_seq_len=max_decoding_steps)
    # NOTE(review): `use_cuda` is not defined in this function — presumably a
    # module-level flag; confirm, otherwise this raises NameError at runtime
    # (note `device` above is hard-coded to CPU).
    generator = generator.cuda() if use_cuda else generator
    discriminator = discriminator.cuda() if use_cuda else discriminator
    rollout = Rollout(generator, rollout_update_rate)
    log_parameters(generator)
    trainable_parameters = [
        parameter for parameter in generator.parameters()
        if parameter.requires_grad
    ]
    optimizer = torch.optim.Adam(trainable_parameters,
                                 lr=learning_rate,
                                 betas=(adam_beta_1, adam_beta_2))
    # Exponential decay: multiplier = lr_decay ** (step / lr_decay_steps).
    scheduler = LambdaLR(optimizer,
                         lr_lambda=lambda t: lr_decay**(t / lr_decay_steps))
    # Load model and vocabularies if resuming.
    start_iteration = 1
    best_iteration = 1
    best_accuracy = 0
    best_exact_match = 0
    best_loss = float('inf')
    if resume_from_file:
        assert os.path.isfile(
            resume_from_file), "No checkpoint found at {}".format(
                resume_from_file)
        logger.info(
            "Loading checkpoint from file at '{}'".format(resume_from_file))
        optimizer_state_dict = generator.load_model(resume_from_file)
        optimizer.load_state_dict(optimizer_state_dict)
        start_iteration = generator.trained_iterations
        logger.info("Loaded checkpoint '{}' (iter {})".format(
            resume_from_file, start_iteration))
    # Generator pretraining: MLE from scratch, or load saved weights.
    if pretrain_gen_path is None:
        print('Pretraining generator with MLE...')
        pre_train_generator(training_set, training_batch_size, generator,
                            seed, pretrain_gen_epochs,
                            name='pretrained_generator')
    else:
        print('Load pretrained generator weights')
        generator_weights = torch.load(pretrain_gen_path)
        generator.load_state_dict(generator_weights)
    # Discriminator pretraining: train from scratch, or load saved weights.
    if pretrain_disc_path is None:
        print('Pretraining Discriminator....')
        train_discriminator(training_set, discriminator, training_batch_size,
                            generator, seed, pretrain_disc_epochs,
                            name="pretrained_discriminator")
    else:
        print('Loading Discriminator....')
        discriminator_weights = torch.load(pretrain_disc_path)
        discriminator.load_state_dict(discriminator_weights)
    logger.info("Training starts..")
    training_iteration = start_iteration
    # Anomaly detection aids debugging of the policy-gradient backward pass
    # (it is expensive; presumably left on deliberately during development).
    torch.autograd.set_detect_anomaly(True)
    while training_iteration < max_training_iterations:
        # Shuffle the dataset and loop over it.
        training_set.shuffle_data()
        for (input_batch, input_lengths, _, situation_batch, _, target_batch,
             target_lengths, agent_positions, target_positions) in \
                training_set.get_data_iterator(batch_size=training_batch_size):
            is_best = False
            generator.train()
            # Forward pass.
            samples = generator.sample(
                batch_size=training_batch_size,
                max_seq_len=max(target_lengths).astype(int),
                commands_input=input_batch,
                commands_lengths=input_lengths,
                situations_input=situation_batch,
                target_batch=target_batch,
                sos_idx=training_set.input_vocabulary.sos_idx,
                eos_idx=training_set.input_vocabulary.eos_idx)
            # Monte-Carlo rollout rewards per generated token.
            rewards = rollout.get_reward(
                samples, rollout_trails, input_batch, input_lengths,
                situation_batch, target_batch,
                training_set.input_vocabulary.sos_idx,
                training_set.input_vocabulary.eos_idx, discriminator)
            assert samples.shape == rewards.shape
            # calculate rewards
            rewards = torch.exp(rewards).contiguous().view((-1, ))
            if use_cuda:
                rewards = rewards.cuda()
            # get generator scores for sequence
            target_scores = generator.get_normalized_logits(
                commands_input=input_batch,
                commands_lengths=input_lengths,
                situations_input=situation_batch,
                samples=samples,
                sample_lengths=target_lengths,
                sos_idx=training_set.input_vocabulary.sos_idx)
            del samples
            # calculate loss on the generated sequence given the rewards
            loss = generator.get_gan_loss(target_scores, target_batch, rewards)
            del rewards
            # Backward pass and update model parameters.
            loss.backward()
            optimizer.step()
            # NOTE(review): passing the iteration to step() is the deprecated
            # epoch-argument form of LRScheduler.step.
            scheduler.step(training_iteration)
            optimizer.zero_grad()
            generator.update_state(is_best=is_best)
            # Print current metrics.
            if training_iteration % print_every == 0:
                # accuracy, exact_match = generator.get_metrics(target_scores, target_batch)
                learning_rate = scheduler.get_lr()[0]
                logger.info("Iteration %08d, loss %8.4f, learning_rate %.5f," %
                            (training_iteration, loss, learning_rate))
                # logger.info("Iteration %08d, loss %8.4f, accuracy %5.2f, exact match %5.2f, learning_rate %.5f,"
                #             % (training_iteration, loss, accuracy, exact_match, learning_rate))
            del target_scores, target_batch
            # # Evaluate on test set.
            # if training_iteration % evaluate_every == 0:
            #     with torch.no_grad():
            #         generator.eval()
            #         logger.info("Evaluating..")
            #         accuracy, exact_match, target_accuracy = evaluate(
            #             test_set.get_data_iterator(batch_size=1), model=generator,
            #             max_decoding_steps=max_decoding_steps, pad_idx=test_set.target_vocabulary.pad_idx,
            #             sos_idx=test_set.target_vocabulary.sos_idx,
            #             eos_idx=test_set.target_vocabulary.eos_idx,
            #             max_examples_to_evaluate=kwargs["max_testing_examples"])
            #         logger.info("  Evaluation Accuracy: %5.2f Exact Match: %5.2f "
            #                     " Target Accuracy: %5.2f" % (accuracy, exact_match, target_accuracy))
            #     if exact_match > best_exact_match:
            #         is_best = True
            #         best_accuracy = accuracy
            #         best_exact_match = exact_match
            #         generator.update_state(accuracy=accuracy, exact_match=exact_match, is_best=is_best)
            #     file_name = "checkpoint.pth.tar".format(str(training_iteration))
            #     if is_best:
            #         generator.save_checkpoint(file_name=file_name, is_best=is_best,
            #                                   optimizer_state_dict=optimizer.state_dict())
            # Sync the rollout copy of the generator, then give the
            # discriminator one epoch on fresh samples.
            rollout.update_params()
            train_discriminator(training_set, discriminator,
                                training_batch_size, generator, seed,
                                epochs=1, name="training_discriminator")
            training_iteration += 1
            if training_iteration > max_training_iterations:
                break
            del loss
    # Persist final weights for both networks, tagged with iteration and seed.
    torch.save(
        generator.state_dict(),
        '{}/{}'.format(output_directory,
                       'gen_{}_{}.ckpt'.format(training_iteration, seed)))
    torch.save(
        discriminator.state_dict(),
        '{}/{}'.format(output_directory,
                       'dis_{}_{}.ckpt'.format(training_iteration, seed)))
    logger.info("Finished training.")
def train_and_evaluate(model, data_loader, train_data, val_data, test_data,
                       optimizer, metrics, params, model_dir, data_encoder,
                       label_encoder, restore_file=None, save_model=True,
                       eval=True):
    """Train an NER model for params.num_epochs with early stopping.

    Per epoch: step the hyperbolic-decay LR schedule, train one epoch, and —
    when `eval` is True — evaluate on validation and test sets, track the best
    F1 score, plot metrics, and checkpoint.  Early-stops after params.patience
    epochs without improvement.  Returns the last (effective) epoch index.

    NOTE(review): the `eval` parameter shadows the builtin; kept as-is since
    it is part of the public signature.  With eval=False, `val_score`/`is_best`
    come from the training metrics instead.
    """
    from src.ner.utils import SummaryWriter, Label, plot  # plotting tools
    train_summary_writer = SummaryWriter([*metrics] + ['loss'], name='train')
    val_summary_writer = SummaryWriter([*metrics] + ['loss'], name='val')
    test_summary_writer = SummaryWriter([*metrics] + ['loss'], name='test')
    writers = [train_summary_writer, val_summary_writer, test_summary_writer]
    # Labeller tags the best epoch, anchored on validation F1.
    labeller = Label(anchor_metric='f1_score', anchor_writer='val')
    plots_dir = os.path.join(model_dir, 'plots')
    if not os.path.exists(plots_dir):
        os.makedirs(plots_dir)
    # start_epoch = -1 makes LambdaLR start fresh unless we restore below.
    start_epoch = -1
    if restore_file is not None:
        logging.info("Restoring parameters from {}".format(restore_file))
        checkpoint = utils.load_checkpoint(restore_file, model, optimizer)
        start_epoch = checkpoint['epoch']
    # save the snapshot of parameters fro reproducibility
    utils.save_dict_to_json(params.dict,
                            os.path.join(model_dir, 'train_snapshot.json'))
    # variable initialization
    best_val_score = 0.0
    patience = 0
    early_stopping_metric = 'f1_score'
    # set the Learning rate Scheduler (hyperbolic decay per epoch;
    # last_epoch=start_epoch resumes the schedule after a restore)
    lambda_lr = lambda epoch: 1 / (1 + (params.lr_decay_rate * epoch))
    lr_scheduler = LambdaLR(optimizer, lr_lambda=lambda_lr,
                            last_epoch=start_epoch)
    # train over epochs
    for epoch in range(start_epoch + 1, params.num_epochs):
        # NOTE(review): stepping before the optimizer updates is the legacy
        # (pre-1.1) PyTorch ordering; recent versions warn about it.
        lr_scheduler.step()
        # Run one epoch
        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))
        logging.info("Learning Rate : {}".format(lr_scheduler.get_lr()))
        # compute number of batches in one epoch (one full pass over the training set)
        # num_steps = (params.train_size + 1) // params.batch_size
        num_steps = (train_data['size'] + 1) // params.batch_size
        train_data_iterator = data_loader.batch_iterator(
            train_data, batch_size=params.batch_size, shuffle=True)
        train_metrics = train(model, optimizer, train_data_iterator, metrics,
                              params, num_steps, data_encoder, label_encoder)
        # Default best-model signal from training metrics; overridden by
        # validation metrics below when eval is enabled.
        val_score = train_metrics[early_stopping_metric]
        is_best = val_score >= best_val_score
        train_summary_writer.update(train_metrics)
        if eval:
            # Evaluate for one epoch on validation set
            # num_steps = (params.val_size + 1) // params.batch_size
            num_steps = (val_data['size'] + 1) // params.batch_size
            val_data_iterator = data_loader.batch_iterator(
                val_data, batch_size=params.batch_size, shuffle=False)
            val_metrics = evaluate(model, val_data_iterator, metrics,
                                   num_steps, label_encoder, mode='val')
            val_score = val_metrics[early_stopping_metric]
            is_best = val_score >= best_val_score
            val_summary_writer.update(val_metrics)
            ### TEST
            # num_steps = (params.test_size + 1) // params.batch_size
            num_steps = (test_data['size'] + 1) // params.batch_size
            test_data_iterator = data_loader.batch_iterator(
                test_data, batch_size=params.batch_size, shuffle=False)
            test_metrics = evaluate(model, test_data_iterator, metrics,
                                    num_steps, label_encoder, mode='test')
            test_summary_writer.update(test_metrics)
            labeller.update(writers=writers)
            plot(writers=writers, plot_dir=plots_dir, save=True)
        # Save weights
        if save_model:
            utils.save_checkpoint({'epoch': epoch,
                                   'state_dict': model.state_dict(),
                                   'optim_dict': optimizer.state_dict()},
                                  is_best=is_best,
                                  checkpoint=model_dir)
            # save encoders only if they do not exist yet
            if not os.path.exists(os.path.join(model_dir, 'data_encoder.pkl')):
                utils.save_obj(data_encoder,
                               os.path.join(model_dir, 'data_encoder.pkl'))
            if not os.path.exists(os.path.join(model_dir,
                                               'label_encoder.pkl')):
                utils.save_obj(label_encoder,
                               os.path.join(model_dir, 'label_encoder.pkl'))
        # If best_eval, best_save_path
        if is_best:
            patience = 0
            logging.info("- Found new best F1 score")
            best_val_score = val_score
            # Save best metrics in a json file in the model directory
            if eval:
                utils.save_dict_to_json(
                    val_metrics,
                    os.path.join(model_dir, 'plots',
                                 "metrics_val_best_weights.json"))
                utils.save_dict_to_json(
                    test_metrics,
                    os.path.join(model_dir, 'plots',
                                 "metrics_test_best_weights.json"))
                utils.save_dict_to_json(
                    train_metrics,
                    os.path.join(model_dir, 'plots',
                                 "metrics_train_best_weights.json"))
        else:
            # Patience counting only applies when validation is running.
            if eval:
                patience += 1
                logging.info('current patience: {} ; max patience: {}'.format(
                    patience, params.patience))
                if patience == params.patience:
                    logging.info(
                        'patience reached. Exiting at epoch: {}'.format(
                            epoch + 1))
                    # Save latest metrics in a json file in the model directory before exiting
                    if eval:
                        utils.save_dict_to_json(
                            val_metrics,
                            os.path.join(model_dir, 'plots',
                                         "metrics_val_last_weights.json"))
                        utils.save_dict_to_json(
                            test_metrics,
                            os.path.join(model_dir, 'plots',
                                         "metrics_test_last_weights.json"))
                        utils.save_dict_to_json(
                            train_metrics,
                            os.path.join(model_dir, 'plots',
                                         "metrics_train_last_weights.json"))
                    # Roll the reported epoch back to the last improvement.
                    epoch = epoch - patience
                    break
        # Save latest metrics in a json file in the model directory at end of epoch
        if eval:
            utils.save_dict_to_json(
                val_metrics,
                os.path.join(model_dir, 'plots',
                             "metrics_val_last_weights.json"))
            utils.save_dict_to_json(
                test_metrics,
                os.path.join(model_dir, 'plots',
                             "metrics_test_last_weights.json"))
            utils.save_dict_to_json(
                train_metrics,
                os.path.join(model_dir, 'plots',
                             "metrics_train_last_weights.json"))
    return epoch
def train(model, tokenizer, train_data, valid_data, args, eos=False):
    """Step-based training of a text-correction model with warmup+inverse-sqrt
    LR schedule, periodic logging/evaluation via NSML, and best-GLEU saving.

    Runs until args.max_steps optimizer steps; evaluates every
    args.eval_interval steps (loss, exact match, GLEU against the clean
    references) and saves the model under "best" when GLEU improves.
    """
    model.train()
    train_dataset = TextDataset(train_data)
    train_dataloader = DataLoader(
        train_dataset,
        sampler=RandomSampler(train_dataset),
        batch_size=args.train_batch_size,
        num_workers=args.num_workers,
        collate_fn=lambda x: collate_fn(x, tokenizer, args.max_seq_length,
                                        eos=eos,
                                        tokenizer_type=args.tokenizer))
    valid_dataset = TextDataset(valid_data)
    valid_dataloader = DataLoader(
        valid_dataset,
        sampler=SequentialSampler(valid_dataset),
        batch_size=args.eval_batch_size,
        num_workers=args.num_workers,
        collate_fn=lambda x: collate_fn(x, tokenizer, args.max_seq_length,
                                        eos=eos,
                                        tokenizer_type=args.tokenizer))
    valid_noisy = [x['noisy'] for x in valid_data]
    valid_clean = [x['clean'] for x in valid_data]
    # Enough epochs to reach max_steps optimizer steps (ceil division).
    epochs = (args.max_steps - 1) // len(train_dataloader) + 1
    # NOTE(review): eval() executes arbitrary text from args.adam_betas —
    # acceptable only because args comes from a trusted CLI; consider
    # ast.literal_eval instead.
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 betas=eval(args.adam_betas), eps=args.eps,
                                 weight_decay=args.weight_decay)
    # Linear warmup to num_warmup_steps, then inverse-sqrt decay.
    lr_lambda = lambda x: x / args.num_warmup_steps if x <= args.num_warmup_steps else (x / args.num_warmup_steps) ** -0.5
    scheduler = LambdaLR(optimizer, lr_lambda)
    step = 0
    best_val_gleu = -float("inf")
    meter = Meter()
    for epoch in range(1, epochs + 1):
        print("===EPOCH: ", epoch)
        for batch in train_dataloader:
            step += 1
            batch = tuple(t.to(args.device) for t in batch)
            loss, items = calc_loss(model, batch)
            meter.add(*items)
            loss.backward()
            if args.max_grad_norm > 0:
                nn.utils.clip_grad_norm_(model.parameters(),
                                         args.max_grad_norm)
            optimizer.step()
            model.zero_grad()
            scheduler.step()
            if step % args.log_interval == 0:
                lr = scheduler.get_lr()[0]
                loss_sent, loss_token = meter.average()
                logger.info(f' [{step:5d}] lr {lr:.6f} | {meter.print_str(True)}')
                nsml.report(step=step, scope=locals(), summary=True,
                            train__lr=lr, train__loss_sent=loss_sent,
                            train__token_ppl=math.exp(loss_token))
                meter.init()
            if step % args.eval_interval == 0:
                start_eval = time.time()
                (val_loss, val_loss_token), valid_str = evaluate(
                    model, valid_dataloader, args)
                prediction = correct(model, tokenizer, valid_noisy, args,
                                     eos=eos, length_limit=0.1)
                val_em = em(prediction, valid_clean)
                cnt = 0
                for noisy, pred, clean in zip(valid_noisy, prediction,
                                              valid_clean):
                    print(f'[{noisy}], [{pred}], [{clean}]')  # print only 10 examples
                    cnt += 1
                    if cnt == 20:
                        break
                val_gleu = gleu(prediction, valid_clean)
                logger.info('-' * 89)
                logger.info(f' [{step:6d}] valid | {valid_str} | em {val_em:5.2f} | gleu {val_gleu:5.2f}')
                logger.info('-' * 89)
                nsml.report(step=step, scope=locals(), summary=True,
                            valid__loss_sent=val_loss,
                            valid__token_ppl=math.exp(val_loss_token),
                            valid__em=val_em, valid__gleu=val_gleu)
                if val_gleu > best_val_gleu:
                    best_val_gleu = val_gleu
                    nsml.save("best")
                # Exclude evaluation time from the meter's throughput clock.
                meter.start += time.time() - start_eval
            if step >= args.max_steps:
                break
        #nsml.save(epoch)
        if step >= args.max_steps:
            break
def train_and_valid_(net, criterion, optimizer, train_loader, valid_loader, cfg,
                     is_lr_adjust=True, is_lr_warmup=False):
    """Train and validate a classifier for ``cfg.epochs`` epochs.

    Supports optional checkpoint resume, StepLR decay and/or linear LR warmup,
    gradient accumulation for small batch sizes, and TensorBoard logging.
    The best model (by validation accuracy) is saved to ``cfg.checkpoints``.

    Args:
        net: the network to train.
        criterion: loss function applied to (outputs, labels).
        optimizer: optimizer wrapping ``net``'s parameters.
        train_loader / valid_loader: training / validation DataLoaders.
        cfg: config object (device, epochs, checkpoint path, LR options, ...).
        is_lr_adjust: enable per-epoch StepLR decay.
        is_lr_warmup: enable linear LR warmup over ``cfg.lr_warmup_step`` epochs.
    """
    # ------------------ setup ------------------------------
    # Resume from checkpoint when it exists and is allowed.
    if os.path.exists(cfg.checkpoints) and cfg.use_checkpoints:
        net.load_state_dict(torch.load(cfg.checkpoints))
        print('加载权重信息...')
    # Create LR schedulers. BUGFIX: the original used `if ... elif ...`, so
    # when both flags were True the warmup scheduler was never created even
    # though the epoch loop calls it -> NameError. Create each independently.
    if is_lr_adjust:
        # Step decay every cfg.lr_decay_step epochs.
        lr_scheduler_step = StepLR(optimizer=optimizer, step_size=cfg.lr_decay_step)
    if is_lr_warmup:
        # Linear warmup: factor epoch / lr_warmup_step.
        lr_lambda = lambda epoch: epoch / cfg.lr_warmup_step
        lr_scheduler_warmup = LambdaLR(optimizer=optimizer, lr_lambda=lr_lambda)
        # Advance once so the first epoch does not run with a zero LR.
        lr_scheduler_warmup.step()
    # TensorBoard writer.
    writer = SummaryWriter(cfg.log_dir)

    # ------------------ inner helpers --------------------
    def _train(train_loader, num_step):
        """One training epoch; returns (accuracy, mean loss, updated num_step)."""
        print(' training stage....')
        net.train()
        optimizer.zero_grad()
        # Running totals for accuracy / loss over the epoch.
        train_acc = 0.0
        train_loss = 0.0
        num_batch = 0
        num_samples = 0
        for index, data in enumerate(train_loader, start=0):
            images, labels = data
            images = images.to(cfg.device)
            labels = labels.to(cfg.device)
            outputs = net(images)
            # NOTE(review): softmax is applied before `criterion`; if criterion
            # is nn.CrossEntropyLoss this double-applies softmax — confirm the
            # criterion expects probabilities, not logits.
            outputs = F.softmax(outputs, dim=1)
            loss = criterion(outputs, labels)
            preds = torch.argmax(outputs, dim=1)
            acc = torch.sum(preds == labels).item()
            train_acc += acc
            # BUGFIX: accumulate a detached float. The original summed the
            # live tensor, chaining autograd graphs across the whole epoch
            # and steadily growing memory.
            train_loss += loss.item()
            num_batch += 1
            num_samples += images.size(0)
            # Gradient accumulation for small batches; otherwise a normal step.
            if cfg.grad_accuml is True and cfg.batch_size < 128:
                # Scale so the accumulated gradient averages over the
                # accumulation window.
                loss = loss / cfg.batch_accumulate_size
                loss.backward()
                if (index + 1) % cfg.batch_accumulate_size == 0:
                    optimizer.step()
                    optimizer.zero_grad()
            else:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            # Periodic batch-level logging.
            if (index + 1) % cfg.print_rate == 0:
                print(' batch:{}, batch_loss:{:.4f}, batch_acc:{:.4f}\n'.
                      format(index, loss, acc / images.size(0)))
                # writer.add_scalar('Train/Loss', scalar_value=loss, global_step=index)
                writer.add_scalars(main_tag='Train(batch)',
                                   tag_scalar_dict={
                                       'batch_loss': loss,
                                       'batch_accuracy': acc / images.size(0)
                                   },
                                   global_step=num_step)
            # Global TensorBoard step advances once per batch.
            num_step += 1
        # Epoch averages.
        train_acc = train_acc / num_samples
        train_loss = train_loss / num_batch
        return train_acc, train_loss, num_step

    def _valid(valid_loader):
        """One validation pass; returns (accuracy, mean loss)."""
        print(' valid stage...')
        net.eval()
        valid_acc = 0.0
        valid_loss = 0.0
        num_batch = 0
        num_samples = 0
        with torch.no_grad():  # no autograd -> lower memory use
            for index, data in enumerate(valid_loader, start=0):
                images, labels = data
                images, labels = images.to(cfg.device), labels.to(cfg.device)
                outputs = net(images)
                outputs = F.softmax(outputs, dim=1)
                pred = torch.argmax(outputs, dim=1)
                loss = criterion(outputs, labels)
                valid_acc += torch.sum((pred == labels)).item()
                # BUGFIX: accumulate a detached float (see _train).
                valid_loss += loss.item()
                num_batch += 1
                num_samples += images.size(0)
        valid_acc = valid_acc / num_samples
        valid_loss = valid_loss / num_batch
        return valid_acc, valid_loss

    # ---------------------------- epoch loop --------------------------------
    start_time = time.time()
    best_acc = 0.0
    num_step = 0
    for epoch in range(cfg.epochs):
        epoch_start_time = time.time()
        print('Epoch {}/{}'.format(epoch, cfg.epochs - 1))
        print('-' * 20)
        # Train then validate.
        train_acc, train_loss, num_step = _train(train_loader, num_step)
        valid_acc, valid_loss = _valid(valid_loader)
        # LR schedule: warmup during the first lr_warmup_step epochs,
        # otherwise (or afterwards) step decay.
        if is_lr_warmup is True and epoch < cfg.lr_warmup_step:
            lr_scheduler_warmup.step()
            print(' epoch:{}/{}, learning rate warmup...{}'.format(
                epoch, cfg.lr_warmup_step - 1, lr_scheduler_warmup.get_lr()))
        elif is_lr_adjust:
            lr_scheduler_step.step()
        # Per-epoch reporting.
        epoch_time = time.time() - epoch_start_time
        print(' epoch:{}/{}, time:{:.0f}m {:.0f}s'.format(
            epoch, cfg.epochs, epoch_time // 60, epoch_time % 60))
        print(
            ' train_loss:{:.4f}, train_acc:{:.4f}\n valid_loss:{:.4f}, valid_acc:{:.4f}'
            .format(train_loss, train_acc, valid_loss, valid_acc))
        writer.add_scalars(main_tag='Train(epoch)',
                           tag_scalar_dict={
                               'train_loss': train_loss,
                               'train_acc': train_acc,
                               'valid_loss': valid_loss,
                               'valid_acc': valid_acc
                           },
                           global_step=epoch)
        # Keep the best model by validation accuracy.
        if valid_acc > best_acc:
            best_acc = valid_acc
            torch.save(net.state_dict(), cfg.checkpoints)
            print(' epoch:{}, update model...'.format(epoch))
        print()
    # Final summary.
    end_time = time.time() - start_time
    print('Training complete in {:.0f}m {:.0f}s'.format(
        end_time // 60, end_time % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    writer.close()
def train(self) -> None:
    r"""Main method for training PPO.

    Seeds RNGs, builds the vectorized envs, the actor-critic agent and the
    rollout storage, then alternates rollout collection and PPO updates for
    ``NUM_UPDATES`` iterations while logging episode metrics and losses to
    TensorBoard and periodically saving checkpoints.

    Returns:
        None
    """
    # NOTE(review): rebinding lr_lambda at module scope; presumably other
    # module code reads it — confirm, otherwise this `global` is unnecessary.
    global lr_lambda
    logger.info(f"config: {self.config}")
    # Seed all RNG sources for reproducibility.
    random.seed(self.config.SEED)
    np.random.seed(self.config.SEED)
    torch.manual_seed(self.config.SEED)
    self.envs = construct_envs(self.config,
                               get_env_class(self.config.ENV_NAME),
                               auto_reset_done=False)
    ppo_cfg = self.config.RL.PPO
    self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                   if torch.cuda.is_available() else torch.device("cpu"))
    if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
        os.makedirs(self.config.CHECKPOINT_FOLDER)
    self._setup_actor_critic_agent(ppo_cfg)
    logger.info("agent number of parameters: {}".format(
        sum(param.numel() for param in self.agent.parameters())))
    rollouts = RolloutStorage(ppo_cfg.num_steps, self.envs.num_envs,
                              self.envs.observation_spaces[0],
                              self.envs.action_spaces[0], ppo_cfg.hidden_size)
    rollouts.to(self.device)
    observations = self.envs.reset()
    batch = batch_obs(observations)
    # Seed the first rollout slot with the initial observations.
    for sensor in rollouts.observations:
        rollouts.observations[sensor][0].copy_(batch[sensor])
    # batch and observations may contain shared PyTorch CUDA
    # tensors. We must explicitly clear them here otherwise
    # they will be kept in memory for the entire duration of training!
    batch = None
    observations = None
    # episode_rewards and episode_counts accumulates over the entire training course
    episode_rewards = torch.zeros(self.envs.num_envs, 1)
    episode_spls = torch.zeros(self.envs.num_envs, 1)
    episode_steps = torch.zeros(self.envs.num_envs, 1)
    episode_counts = torch.zeros(self.envs.num_envs, 1)
    episode_distances = torch.zeros(self.envs.num_envs, 1)
    current_episode_reward = torch.zeros(self.envs.num_envs, 1)
    current_episode_step = torch.zeros(self.envs.num_envs, 1)
    # Sliding windows of the cumulative stats; window deltas give per-window
    # averages below.
    window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size)
    window_episode_spl = deque(maxlen=ppo_cfg.reward_window_size)
    window_episode_step = deque(maxlen=ppo_cfg.reward_window_size)
    window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size)
    window_episode_distances = deque(maxlen=ppo_cfg.reward_window_size)
    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps = 0
    count_checkpoints = 0
    start_update = 0
    prev_time = 0
    # Choose the LR decay shape; the constant lambda keeps LambdaLR a no-op.
    if ppo_cfg.use_linear_lr_decay:
        def lr_lambda(x):
            return linear_decay(x, self.config.NUM_UPDATES)
    elif ppo_cfg.use_exponential_lr_decay:
        def lr_lambda(x):
            return exponential_decay(x, self.config.NUM_UPDATES,
                                     ppo_cfg.exp_decay_lambda)
    else:
        def lr_lambda(x):
            return 1
    lr_scheduler = LambdaLR(optimizer=self.agent.optimizer,
                            lr_lambda=lr_lambda)
    with TensorboardWriter(self.config.TENSORBOARD_DIR,
                           flush_secs=self.flush_secs) as writer:
        for update in range(start_update, self.config.NUM_UPDATES):
            if ppo_cfg.use_linear_lr_decay or ppo_cfg.use_exponential_lr_decay:
                lr_scheduler.step()
            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    update, self.config.NUM_UPDATES)
            # Collect one rollout of num_steps transitions per env.
            for step in range(ppo_cfg.num_steps):
                delta_pth_time, delta_env_time, delta_steps = self._collect_rollout_step(
                    rollouts, current_episode_reward, current_episode_step,
                    episode_rewards, episode_spls, episode_counts,
                    episode_steps, episode_distances)
                pth_time += delta_pth_time
                env_time += delta_env_time
                count_steps += delta_steps
            # PPO update on the collected rollout.
            delta_pth_time, value_loss, action_loss, dist_entropy = self._update_agent(
                ppo_cfg, rollouts)
            pth_time += delta_pth_time
            # Snapshot cumulative stats into the sliding windows.
            window_episode_reward.append(episode_rewards.clone())
            window_episode_spl.append(episode_spls.clone())
            window_episode_step.append(episode_steps.clone())
            window_episode_counts.append(episode_counts.clone())
            window_episode_distances.append(episode_distances.clone())
            losses = [value_loss, action_loss, dist_entropy]
            stats = zip(
                ["count", "reward", "step", 'spl', 'distance'],
                [
                    window_episode_counts, window_episode_reward,
                    window_episode_step, window_episode_spl,
                    window_episode_distances
                ],
            )
            # Window delta (newest minus oldest) of each cumulative stat.
            deltas = {
                k: ((v[-1] - v[0]).sum().item()
                    if len(v) > 1 else v[0].sum().item())
                for k, v in stats
            }
            deltas["count"] = max(deltas["count"], 1.0)
            # this reward is averaged over all the episodes happened during window_size updates
            # approximately number of steps is window_size * num_steps
            writer.add_scalar("Environment/Reward",
                              deltas["reward"] / deltas["count"], count_steps)
            writer.add_scalar("Environment/SPL",
                              deltas["spl"] / deltas["count"], count_steps)
            logging.debug('Number of steps: {}'.format(deltas["step"] /
                                                       deltas["count"]))
            writer.add_scalar("Environment/Episode_length",
                              deltas["step"] / deltas["count"], count_steps)
            writer.add_scalar("Environment/Distance_to_goal",
                              deltas["distance"] / deltas["count"],
                              count_steps)
            # writer.add_scalars(
            #     "losses",
            #     {k: l for l, k in zip(losses, ["value", "policy"])},
            #     count_steps,
            # )
            writer.add_scalar('Policy/Value_Loss', value_loss, count_steps)
            writer.add_scalar('Policy/Action_Loss', action_loss, count_steps)
            writer.add_scalar('Policy/Entropy', dist_entropy, count_steps)
            # NOTE(review): get_lr() is deprecated in newer torch in favor of
            # get_last_lr(); left as-is to match the installed version.
            writer.add_scalar('Policy/Learning_Rate',
                              lr_scheduler.get_lr()[0], count_steps)
            # log stats
            if update > 0 and update % self.config.LOG_INTERVAL == 0:
                logger.info("update: {}\tfps: {:.3f}\t".format(
                    update,
                    count_steps / ((time.time() - t_start) + prev_time)))
                logger.info(
                    "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                    "frames: {}".format(update, env_time, pth_time,
                                        count_steps))
                window_rewards = (window_episode_reward[-1] -
                                  window_episode_reward[0]).sum()
                window_counts = (window_episode_counts[-1] -
                                 window_episode_counts[0]).sum()
                if window_counts > 0:
                    logger.info(
                        "Average window size {} reward: {:3f}".format(
                            len(window_episode_reward),
                            (window_rewards / window_counts).item(),
                        ))
                else:
                    logger.info("No episodes finish in current window")
            # checkpoint model
            if update % self.config.CHECKPOINT_INTERVAL == 0:
                self.save_checkpoint(f"ckpt.{count_checkpoints}.pth")
                count_checkpoints += 1
    self.envs.close()
def train():
    """Train ACNet on SUN RGB-D semantic segmentation.

    Builds the augmented training pipeline, optionally resumes from
    ``args.last_ckpt``, trains with SGD plus epoch-wise exponential LR decay,
    and logs losses, parameter histograms and prediction grids to TensorBoard.
    Reads module-level ``args``, ``device``, ``image_h``/``image_w`` and
    ``nyuv2_frq``.
    """
    train_data = ACNet_data.SUNRGBD(transform=transforms.Compose([
        ACNet_data.scaleNorm(),
        ACNet_data.RandomScale((1.0, 1.4)),
        ACNet_data.RandomHSV((0.9, 1.1), (0.9, 1.1), (25, 25)),
        ACNet_data.RandomCrop(image_h, image_w),
        ACNet_data.RandomFlip(),
        ACNet_data.ToTensor(),
        ACNet_data.Normalize()
    ]),
                                    phase_train=True,
                                    data_dir=args.data_dir)
    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=False)
    num_train = len(train_data)
    # When resuming, skip ImageNet pretraining — weights come from the ckpt.
    if args.last_ckpt:
        model = ACNet_models_V1.ACNet(num_class=40, pretrained=False)
    else:
        model = ACNet_models_V1.ACNet(num_class=40, pretrained=True)
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    # Class-frequency-weighted 2D cross-entropy for NYUv2 label imbalance.
    CEL_weighted = utils.CrossEntropyLoss2d(weight=nyuv2_frq)
    model.train()
    model.to(device)
    CEL_weighted.to(device)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    global_step = 0
    if args.last_ckpt:
        global_step, args.start_epoch = load_ckpt(model, optimizer,
                                                  args.last_ckpt, device)
    # Exponential step decay: multiply LR by lr_decay_rate every
    # lr_epoch_per_decay epochs.
    lr_decay_lambda = lambda epoch: args.lr_decay_rate ** (epoch // args.lr_epoch_per_decay)
    scheduler = LambdaLR(optimizer, lr_lambda=lr_decay_lambda)
    writer = SummaryWriter(args.summary_dir)
    for epoch in range(int(args.start_epoch), args.epochs):
        # NOTE(review): scheduler.step(epoch) with an explicit epoch (and
        # stepping before optimizer.step()) is a deprecated pattern in newer
        # torch; left unchanged to match the installed version.
        scheduler.step(epoch)
        local_count = 0
        last_count = 0
        end_time = time.time()
        # Periodic checkpoint at epoch boundaries (skip the resume epoch).
        if epoch % args.save_epoch_freq == 0 and epoch != args.start_epoch:
            save_ckpt(args.ckpt_dir, model, optimizer, global_step, epoch,
                      local_count, num_train)
        for batch_idx, sample in enumerate(train_loader):
            image = sample['image'].to(device)
            depth = sample['depth'].to(device)
            # Multi-scale supervision targets.
            target_scales = [
                sample[s].to(device)
                for s in ['label', 'label2', 'label3', 'label4', 'label5']
            ]
            optimizer.zero_grad()
            pred_scales = model(image, depth, args.checkpoint)
            loss = CEL_weighted(pred_scales, target_scales)
            loss.backward()
            optimizer.step()
            local_count += image.data.shape[0]
            global_step += 1
            # Periodic console + TensorBoard logging.
            if global_step % args.print_freq == 0 or global_step == 1:
                time_inter = time.time() - end_time
                count_inter = local_count - last_count
                print_log(global_step, epoch, local_count, count_inter,
                          num_train, loss, time_inter)
                end_time = time.time()
                for name, param in model.named_parameters():
                    writer.add_histogram(name,
                                         param.clone().cpu().data.numpy(),
                                         global_step,
                                         bins='doane')
                grid_image = make_grid(image[:3].clone().cpu().data, 3,
                                       normalize=True)
                writer.add_image('image', grid_image, global_step)
                grid_image = make_grid(depth[:3].clone().cpu().data, 3,
                                       normalize=True)
                writer.add_image('depth', grid_image, global_step)
                # +1 shifts predictions into the colormap's label range.
                grid_image = make_grid(utils.color_label(
                    torch.max(pred_scales[0][:3], 1)[1] + 1), 3,
                                       normalize=False, range=(0, 255))
                writer.add_image('Predicted label', grid_image, global_step)
                grid_image = make_grid(utils.color_label(
                    target_scales[0][:3]), 3, normalize=False, range=(0, 255))
                writer.add_image('Groundtruth label', grid_image, global_step)
                writer.add_scalar('CrossEntropyLoss', loss.data,
                                  global_step=global_step)
                writer.add_scalar('Learning rate', scheduler.get_lr()[0],
                                  global_step=global_step)
                last_count = local_count
    # Final checkpoint after all epochs.
    save_ckpt(args.ckpt_dir, model, optimizer, global_step, args.epochs, 0,
              num_train)
    print("Training completed ")
def main(args):
    """Entry point for CycleGAN-based domain translation training.

    Builds source/target data pipelines, the two generators and two
    discriminators, trains with linear LR decay after ``args.epochs`` epochs,
    checkpoints every epoch, and (optionally) translates the source dataset
    with the trained S->T generator. In ``phase == 'test'`` it only runs the
    translation and returns.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    cudnn.benchmark = True
    # Data loading code
    train_transform = T.Compose([
        T.RandomResizedCrop(size=args.train_size,
                            ratio=args.resize_ratio,
                            scale=(0.5, 1.)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(root=args.source_root,
                                          transforms=train_transform)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(root=args.target_root,
                                          transforms=train_transform)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    # Endless iterators so one "epoch" is defined by iteration count, not
    # dataset length.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)
    # define networks (both generators and discriminators)
    netG_S2T = cyclegan.generator.__dict__[args.netG](
        ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netG_T2S = cyclegan.generator.__dict__[args.netG](
        ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netD_S = cyclegan.discriminator.__dict__[args.netD](
        ndf=args.ndf, norm=args.norm).to(device)
    netD_T = cyclegan.discriminator.__dict__[args.netD](
        ndf=args.ndf, norm=args.norm).to(device)
    # create image buffer to store previously generated images
    fake_S_pool = ImagePool(args.pool_size)
    fake_T_pool = ImagePool(args.pool_size)
    # define optimizer and lr scheduler; both generators (and both
    # discriminators) share a single optimizer via chained parameters.
    optimizer_G = Adam(itertools.chain(netG_S2T.parameters(),
                                       netG_T2S.parameters()),
                       lr=args.lr,
                       betas=(args.beta1, 0.999))
    optimizer_D = Adam(itertools.chain(netD_S.parameters(),
                                       netD_T.parameters()),
                       lr=args.lr,
                       betas=(args.beta1, 0.999))
    # Constant LR for the first `epochs` epochs, then linear decay to zero
    # over `epochs_decay` epochs (standard CycleGAN schedule).
    lr_decay_function = lambda epoch: 1.0 - max(0, epoch - args.epochs
                                                ) / float(args.epochs_decay)
    lr_scheduler_G = LambdaLR(optimizer_G, lr_lambda=lr_decay_function)
    lr_scheduler_D = LambdaLR(optimizer_D, lr_lambda=lr_decay_function)
    # optionally resume from a checkpoint
    if args.resume:
        print("Resume from", args.resume)
        checkpoint = torch.load(args.resume, map_location='cpu')
        netG_S2T.load_state_dict(checkpoint['netG_S2T'])
        netG_T2S.load_state_dict(checkpoint['netG_T2S'])
        netD_S.load_state_dict(checkpoint['netD_S'])
        netD_T.load_state_dict(checkpoint['netD_T'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D'])
        lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G'])
        lr_scheduler_D.load_state_dict(checkpoint['lr_scheduler_D'])
        args.start_epoch = checkpoint['epoch'] + 1
    # Test phase: only translate the source dataset and exit.
    if args.phase == 'test':
        transform = T.Compose([
            T.Resize(image_size=args.test_input_size),
            T.wrapper(cyclegan.transform.Translation)(netG_S2T, device),
        ])
        train_source_dataset.translate(transform, args.translated_root)
        return
    # define loss function
    criterion_gan = cyclegan.LeastSquaresGenerativeAdversarialLoss()
    criterion_cycle = nn.L1Loss()
    criterion_identity = nn.L1Loss()
    # define visualization function
    tensor_to_image = Compose(
        [Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
         ToPILImage()])

    def visualize(image, name):
        """
        Args:
            image (tensor): image in shape 3 x H x W
            name: name of the saving image
        """
        tensor_to_image(image).save(
            logger.get_image_path("{}.png".format(name)))

    # start training
    for epoch in range(args.start_epoch, args.epochs + args.epochs_decay):
        logger.set_epoch(epoch)
        print(lr_scheduler_G.get_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, netG_S2T, netG_T2S,
              netD_S, netD_T, criterion_gan, criterion_cycle,
              criterion_identity, optimizer_G, optimizer_D, fake_S_pool,
              fake_T_pool, epoch, visualize, args)
        # update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D.step()
        # save checkpoint
        torch.save(
            {
                'netG_S2T': netG_S2T.state_dict(),
                'netG_T2S': netG_T2S.state_dict(),
                'netD_S': netD_S.state_dict(),
                'netD_T': netD_T.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'optimizer_D': optimizer_D.state_dict(),
                'lr_scheduler_G': lr_scheduler_G.state_dict(),
                'lr_scheduler_D': lr_scheduler_D.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch))
    # Optionally translate the source dataset with the final generator.
    if args.translated_root is not None:
        transform = T.Compose([
            T.Resize(image_size=args.test_input_size),
            T.wrapper(cyclegan.transform.Translation)(netG_S2T, device),
        ])
        train_source_dataset.translate(transform, args.translated_root)
    logger.close()
criterion = CrossEntropyLoss2d() metrics = Metrics() if store.metrics: metrics.load_state_dict(store.metrics) if FAKE: print('STOP TRAINING') exit(0) # LOOP print(f'Starting ({now_str()})') iter_count = len(data_set) // BATCH_SIZE while epoch < first_epoch + EPOCH_COUNT: iter_metrics = Metrics() lr = scheduler.get_lr()[0] for i, (inputs, labels) in enumerate(data_loader): inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() outputs = model(inputs).to(device) loss = criterion(outputs, labels) coef = Coef.calc(outputs, labels) iter_metrics.append_loss(loss.item()) iter_metrics.append_coef(coef) pp('epoch[{ep}]:{i}/{I} iou:{c.pjac:.4f} acc:{c.pdice:.4f} loss:{loss:.4f} lr:{lr:.4f} ({t})'.format( ep=epoch, i=i+1, I=iter_count, lr=lr, t=now_str(), loss=loss.item(), c=coef)) loss.backward() optimizer.step() pp('epoch[{ep}]:Done. iou:{c.pjac:.4f} acc:{c.pdice:.4f} gsi:{c.gsensi:.4f} gsp:{c.gspec:.4f} tsi:{c.tsensi:.4f} tsp:{c.tspec:.4f} loss:{loss:.4f} lr:{lr:.4f} ({t})'.format( ep=epoch, t=now_str(), lr=lr, loss=iter_metrics.avg('losses'), c=iter_metrics.avg_coef()
def train(self, model, train_loader, val_loader=None, num_epochs=10, log_nth=0, model_args={}):
    """
    Train a given model with the provided data.

    Inputs:
    - model: model object initialized from a torch.nn.Module
    - train_loader: train data in torch.utils.data.DataLoader
    - val_loader: val data in torch.utils.data.DataLoader
    - num_epochs: total number of training epochs
    - log_nth: log training accuracy and loss every nth iteration
    - model_args: extra keyword arguments forwarded to model.sum_losses()
      (NOTE(review): mutable default argument — shared across calls if the
      callee ever mutates it; left unchanged to preserve the interface)
    """
    self.writer = tb.SummaryWriter(self.tb_dir)
    self.val_writer = tb.SummaryWriter(self.tb_val_dir)
    # filter out frcnn if this is added to the module
    parameters = [
        param for name, param in model.named_parameters()
        if 'frcnn' not in name
    ]
    optim = self.optim(parameters, **self.optim_args)
    if self.lr_scheduler_lambda:
        scheduler = LambdaLR(optim, lr_lambda=self.lr_scheduler_lambda)
    else:
        scheduler = None
    self._reset_histories()
    iter_per_epoch = len(train_loader)
    print('START TRAIN.')
    ############################################################################
    # TODO:                                                                    #
    # Write your own personal training method for our solver. In Each epoch   #
    # iter_per_epoch shuffled training batches are processed. The loss for    #
    # each batch is stored in self.train_loss_history. Every log_nth iteration#
    # the loss is logged. After one epoch the training accuracy of the last   #
    # mini batch is logged and stored in self.train_acc_history.              #
    # We validate at the end of each epoch, log the result and store the      #
    # accuracy of the entire validation set in self.val_acc_history.          #
    #                                                                          #
    # Your logging should like something like:                                 #
    #   ...                                                                    #
    #   [Iteration 700/4800] TRAIN loss: 1.452                                 #
    #   [Iteration 800/4800] TRAIN loss: 1.409                                 #
    #   [Iteration 900/4800] TRAIN loss: 1.374                                 #
    #   [Epoch 1/5] TRAIN acc/loss: 0.560/1.374                                #
    #   [Epoch 1/5] VAL acc/loss: 0.539/1.310                                  #
    #   ...                                                                    #
    ############################################################################
    for epoch in range(num_epochs):
        # TRAINING
        # NOTE(review): scheduler stepped at the start of each epoch, before
        # any optim.step() — newer torch warns about this order; left as-is.
        if scheduler:
            scheduler.step()
            print("[*] New learning rate(s): {}".format(
                scheduler.get_lr()))
        now = time.time()
        for i, batch in enumerate(train_loader, 1):
            #inputs, labels = Variable(batch[0]), Variable(batch[1])
            optim.zero_grad()
            # The model computes its own loss dict; 'total_loss' drives the
            # backward pass.
            losses = model.sum_losses(batch, **model_args)
            losses['total_loss'].backward()
            optim.step()
            # Accumulate every reported loss component for logging.
            for k, v in losses.items():
                if k not in self._losses.keys():
                    self._losses[k] = []
                self._losses[k].append(v.data.cpu().numpy())
            if log_nth and i % log_nth == 0:
                next_now = time.time()
                print('[Iteration %d/%d] %.3f s/it' %
                      (i + epoch * iter_per_epoch,
                       iter_per_epoch * num_epochs,
                       (next_now - now) / log_nth))
                now = next_now
                # Mean of the last log_nth values of each loss component.
                for k, v in self._losses.items():
                    last_log_nth_losses = self._losses[k][-log_nth:]
                    train_loss = np.mean(last_log_nth_losses)
                    print('%s: %.3f' % (k, train_loss))
                    self.writer.add_scalar(k, train_loss,
                                           i + epoch * iter_per_epoch)
        # VALIDATION
        if val_loader and log_nth:
            model.eval()
            # Evaluate on at most log_nth+1 validation batches.
            for i, batch in enumerate(val_loader):
                losses = model.sum_losses(batch, **model_args)
                for k, v in losses.items():
                    if k not in self._val_losses.keys():
                        self._val_losses[k] = []
                    self._val_losses[k].append(v.data.cpu().numpy())
                if i >= log_nth:
                    break
            model.train()
            # NOTE(review): iterates keys of self._losses but reads values
            # from self._val_losses — presumably the key sets match; confirm.
            for k, v in self._losses.items():
                last_log_nth_losses = self._val_losses[k][-log_nth:]
                val_loss = np.mean(last_log_nth_losses)
                self.val_writer.add_scalar(k, val_loss,
                                           (epoch + 1) * iter_per_epoch)
            #blobs_val = data_layer_val.forward()
            #tracks_val = model.val_predict(blobs_val)
            #im = plot_tracks(blobs_val, tracks_val)
            #self.val_writer.add_image('val_tracks', im, (epoch+1) * iter_per_epoch)
        self.snapshot(model, (epoch + 1) * iter_per_epoch)
        self._reset_histories()
    self.writer.close()
    self.val_writer.close()
    ############################################################################
    #                             END OF YOUR CODE                             #
    ############################################################################
    print('FINISH.')
def train(self) -> None:
    r"""Main method for DD-PPO.

    Distributed (SLURM) PPO training: each worker collects rollouts, may
    preempt itself if it becomes a straggler, all-reduces episode stats and
    losses across workers, and rank 0 handles TensorBoard logging,
    checkpointing and requeue-on-interrupt state saving.

    Returns:
        None
    """
    self.local_rank, tcp_store = init_distrib_slurm(
        self.config.RL.DDPPO.distrib_backend)
    add_signal_handlers()
    # Stores the number of workers that have finished their rollout
    num_rollouts_done_store = distrib.PrefixStore("rollout_tracker",
                                                  tcp_store)
    num_rollouts_done_store.set("num_done", "0")
    self.world_rank = distrib.get_rank()
    self.world_size = distrib.get_world_size()
    self.config.defrost()
    # Each worker drives its local GPU.
    self.config.TORCH_GPU_ID = self.local_rank
    self.config.SIMULATOR_GPU_ID = self.local_rank
    # Multiply by the number of simulators to make sure they also get unique seeds
    self.config.TASK_CONFIG.SEED += (self.world_rank *
                                     self.config.NUM_PROCESSES)
    self.config.freeze()
    random.seed(self.config.TASK_CONFIG.SEED)
    np.random.seed(self.config.TASK_CONFIG.SEED)
    torch.manual_seed(self.config.TASK_CONFIG.SEED)
    if torch.cuda.is_available():
        self.device = torch.device("cuda", self.local_rank)
        torch.cuda.set_device(self.device)
    else:
        self.device = torch.device("cpu")
    self.envs = construct_envs(self.config,
                               get_env_class(self.config.ENV_NAME))
    ppo_cfg = self.config.RL.PPO
    # Only rank 0 creates the checkpoint directory.
    if (not os.path.isdir(self.config.CHECKPOINT_FOLDER)
            and self.world_rank == 0):
        os.makedirs(self.config.CHECKPOINT_FOLDER)
    self._setup_actor_critic_agent(ppo_cfg)
    self.agent.init_distributed(find_unused_params=True)
    if ppo_cfg.use_belief_predictor and ppo_cfg.BELIEF_PREDICTOR.online_training:
        self.belief_predictor.init_distributed(find_unused_params=True)
    if self.world_rank == 0:
        logger.info("agent number of trainable parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters()
                if param.requires_grad)))
        if ppo_cfg.use_belief_predictor:
            logger.info(
                "belief predictor number of trainable parameters: {}".
                format(
                    sum(param.numel()
                        for param in self.belief_predictor.parameters()
                        if param.requires_grad)))
        logger.info(f"config: {self.config}")
    observations = self.envs.reset()
    batch = batch_obs(observations, device=self.device)
    obs_space = self.envs.observation_spaces[0]
    if ppo_cfg.use_external_memory:
        memory_dim = self.actor_critic.net.memory_dim
    else:
        memory_dim = None
    rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        self.envs.num_envs,
        obs_space,
        self.action_space,
        ppo_cfg.hidden_size,
        ppo_cfg.use_external_memory,
        ppo_cfg.SCENE_MEMORY_TRANSFORMER.memory_size + ppo_cfg.num_steps,
        ppo_cfg.SCENE_MEMORY_TRANSFORMER.memory_size,
        memory_dim,
        num_recurrent_layers=self.actor_critic.net.num_recurrent_layers,
    )
    rollouts.to(self.device)
    if self.config.RL.PPO.use_belief_predictor:
        self.belief_predictor.update(batch, None)
    # Seed the first rollout slot with the initial observations.
    for sensor in rollouts.observations:
        rollouts.observations[sensor][0].copy_(batch[sensor])
    # batch and observations may contain shared PyTorch CUDA
    # tensors. We must explicitly clear them here otherwise
    # they will be kept in memory for the entire duration of training!
    batch = None
    observations = None
    current_episode_reward = torch.zeros(self.envs.num_envs,
                                         1,
                                         device=self.device)
    running_episode_stats = dict(
        count=torch.zeros(self.envs.num_envs, 1, device=self.device),
        reward=torch.zeros(self.envs.num_envs, 1, device=self.device),
    )
    window_episode_stats = defaultdict(
        lambda: deque(maxlen=ppo_cfg.reward_window_size))
    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps = 0
    count_checkpoints = 0
    start_update = 0
    prev_time = 0
    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
    )
    # Try to resume at previous checkpoint (independent of interrupted states)
    count_steps_start, count_checkpoints, start_update = self.try_to_resume_checkpoint(
    )
    count_steps = count_steps_start
    # Restore full optimizer/scheduler/counter state after a SLURM requeue.
    interrupted_state = load_interrupted_state()
    if interrupted_state is not None:
        self.agent.load_state_dict(interrupted_state["state_dict"])
        if self.config.RL.PPO.use_belief_predictor:
            self.belief_predictor.load_state_dict(
                interrupted_state["belief_predictor"])
        self.agent.optimizer.load_state_dict(
            interrupted_state["optim_state"])
        lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])
        requeue_stats = interrupted_state["requeue_stats"]
        env_time = requeue_stats["env_time"]
        pth_time = requeue_stats["pth_time"]
        count_steps = requeue_stats["count_steps"]
        count_checkpoints = requeue_stats["count_checkpoints"]
        start_update = requeue_stats["start_update"]
        prev_time = requeue_stats["prev_time"]
    # Only rank 0 writes TensorBoard; other ranks get a no-op context.
    with (TensorboardWriter(self.config.TENSORBOARD_DIR,
                            flush_secs=self.flush_secs)
          if self.world_rank == 0 else contextlib.suppress()) as writer:
        for update in range(start_update, self.config.NUM_UPDATES):
            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()
            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    update, self.config.NUM_UPDATES)
            # Handle SLURM preemption: save state, requeue, and exit.
            if EXIT.is_set():
                self.envs.close()
                if REQUEUE.is_set() and self.world_rank == 0:
                    requeue_stats = dict(
                        env_time=env_time,
                        pth_time=pth_time,
                        count_steps=count_steps,
                        count_checkpoints=count_checkpoints,
                        start_update=update,
                        prev_time=(time.time() - t_start) + prev_time,
                    )
                    state_dict = dict(
                        state_dict=self.agent.state_dict(),
                        optim_state=self.agent.optimizer.state_dict(),
                        lr_sched_state=lr_scheduler.state_dict(),
                        config=self.config,
                        requeue_stats=requeue_stats,
                    )
                    if self.config.RL.PPO.use_belief_predictor:
                        state_dict[
                            'belief_predictor'] = self.belief_predictor.state_dict(
                            )
                    save_interrupted_state(state_dict)
                    requeue_job()
                return
            count_steps_delta = 0
            self.agent.eval()
            if self.config.RL.PPO.use_belief_predictor:
                self.belief_predictor.eval()
            for step in range(ppo_cfg.num_steps):
                (
                    delta_pth_time,
                    delta_env_time,
                    delta_steps,
                ) = self._collect_rollout_step(rollouts,
                                               current_episode_reward,
                                               running_episode_stats)
                pth_time += delta_pth_time
                env_time += delta_env_time
                count_steps_delta += delta_steps
                # This is where the preemption of workers happens. If a
                # worker detects it will be a straggler, it preempts itself!
                if (step >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                    ) and int(num_rollouts_done_store.get("num_done")) > (
                        self.config.RL.DDPPO.sync_frac * self.world_size):
                    break
            num_rollouts_done_store.add("num_done", 1)
            self.agent.train()
            if self.config.RL.PPO.use_belief_predictor:
                self.belief_predictor.train()
                self.belief_predictor.set_eval_encoders()
            if self._static_smt_encoder:
                self.actor_critic.net.set_eval_encoders()
            # Optionally train the belief predictor online on this rollout.
            if ppo_cfg.use_belief_predictor and ppo_cfg.BELIEF_PREDICTOR.online_training:
                location_predictor_loss, prediction_accuracy = self.train_belief_predictor(
                    rollouts)
            else:
                location_predictor_loss = 0
                prediction_accuracy = 0
            (
                delta_pth_time,
                value_loss,
                action_loss,
                dist_entropy,
            ) = self._update_agent(ppo_cfg, rollouts)
            pth_time += delta_pth_time
            # All-reduce the per-env cumulative episode stats across workers.
            stats_ordering = list(sorted(running_episode_stats.keys()))
            stats = torch.stack(
                [running_episode_stats[k] for k in stats_ordering], 0)
            distrib.all_reduce(stats)
            for i, k in enumerate(stats_ordering):
                window_episode_stats[k].append(stats[i].clone())
            # All-reduce losses and the frame count in one tensor.
            stats = torch.tensor(
                [
                    value_loss, action_loss, dist_entropy,
                    location_predictor_loss, prediction_accuracy,
                    count_steps_delta
                ],
                device=self.device,
            )
            distrib.all_reduce(stats)
            count_steps += stats[5].item()
            if self.world_rank == 0:
                num_rollouts_done_store.set("num_done", "0")
                # Reduced sums -> per-worker averages.
                losses = [
                    stats[0].item() / self.world_size,
                    stats[1].item() / self.world_size,
                    stats[2].item() / self.world_size,
                    stats[3].item() / self.world_size,
                    stats[4].item() / self.world_size,
                ]
                # Window delta (newest minus oldest) of each cumulative stat.
                deltas = {
                    k: ((v[-1] - v[0]).sum().item()
                        if len(v) > 1 else v[0].sum().item())
                    for k, v in window_episode_stats.items()
                }
                deltas["count"] = max(deltas["count"], 1.0)
                writer.add_scalar("Metrics/reward",
                                  deltas["reward"] / deltas["count"],
                                  count_steps)
                # Check to see if there are any metrics
                # that haven't been logged yet
                metrics = {
                    k: v / deltas["count"]
                    for k, v in deltas.items()
                    if k not in {"reward", "count"}
                }
                if len(metrics) > 0:
                    for metric, value in metrics.items():
                        writer.add_scalar(f"Metrics/{metric}", value,
                                          count_steps)
                writer.add_scalar("Policy/value_loss", losses[0],
                                  count_steps)
                writer.add_scalar("Policy/policy_loss", losses[1],
                                  count_steps)
                writer.add_scalar("Policy/entropy_loss", losses[2],
                                  count_steps)
                writer.add_scalar("Policy/predictor_loss", losses[3],
                                  count_steps)
                writer.add_scalar("Policy/predictor_accuracy", losses[4],
                                  count_steps)
                # NOTE(review): get_lr() is deprecated in newer torch in
                # favor of get_last_lr(); left as-is for the pinned version.
                writer.add_scalar('Policy/learning_rate',
                                  lr_scheduler.get_lr()[0], count_steps)
                # log stats
                if update > 0 and update % self.config.LOG_INTERVAL == 0:
                    logger.info("update: {}\tfps: {:.3f}\t".format(
                        update,
                        (count_steps - count_steps_start) /
                        ((time.time() - t_start) + prev_time),
                    ))
                    logger.info(
                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(update, env_time, pth_time,
                                            count_steps))
                    logger.info("Average window size: {} {}".format(
                        len(window_episode_stats["count"]),
                        " ".join("{}: {:.3f}".format(k, v / deltas["count"])
                                 for k, v in deltas.items()
                                 if k != "count"),
                    ))
                # checkpoint model
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(
                        f"ckpt.{count_checkpoints}.pth",
                        dict(step=count_steps),
                    )
                    count_checkpoints += 1
    self.envs.close()
def main(args: argparse.Namespace):
    """Entry point: adversarial domain-adaptive semantic segmentation.

    Builds source/target datasets and loaders, a segmentation model plus a
    domain discriminator, optionally resumes from a checkpoint, then either
    runs a single validation pass (``args.phase == 'test'``) or trains for
    ``args.epochs`` epochs, checkpointing every epoch and copying the
    checkpoint with the best mean IoU to ``'best'``.

    Args:
        args: parsed command-line namespace; fields used here include
            ``log``, ``phase``, ``seed``, ``source``/``source_root``,
            ``target``/``target_root``, ``train_size``, ``resize_ratio``,
            ``batch_size``, ``workers``, ``test_input_size``,
            ``test_output_size``, ``arch``, ``lr``, ``lr_d``, ``momentum``,
            ``weight_decay``, ``lr_power``, ``epochs``, ``iters_per_epoch``,
            ``resume``, ``start_epoch``, ``ignore_label``, ``debug``.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Seeding + cudnn.deterministic trades speed for reproducibility.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    # Datasets/models are resolved by name from the project registries.
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(
        root=args.source_root,
        transforms=T.Compose([
            T.RandomResizedCrop(size=args.train_size,
                                ratio=args.resize_ratio,
                                scale=(0.5, 1.)),
            # Color jitter is applied on the source domain only.
            T.ColorJitter(brightness=0.3, contrast=0.3),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(
        root=args.target_root,
        transforms=T.Compose([
            # NOTE(review): target crops use a fixed 2:1 aspect ratio rather
            # than args.resize_ratio — presumably intentional for this
            # target dataset; confirm against the experiment config.
            T.RandomResizedCrop(size=args.train_size,
                                ratio=(2., 2.),
                                scale=(0.5, 1.)),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_target_dataset = target_dataset(
        root=args.target_root, split='val',
        transforms=T.Compose([
            T.Resize(image_size=args.test_input_size,
                     label_size=args.test_output_size),
            T.NormalizeAndTranspose(),
        ]),
    )
    val_target_loader = DataLoader(val_target_dataset,
                                   batch_size=1,
                                   shuffle=False,
                                   pin_memory=True)

    # Infinite iterators: training is driven by iteration count, not by
    # loader exhaustion.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    num_classes = train_source_dataset.num_classes
    model = models.__dict__[args.arch](num_classes=num_classes).to(device)
    discriminator = Discriminator(num_classes=num_classes).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(model.get_parameters(),
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    optimizer_d = Adam(discriminator.parameters(),
                       lr=args.lr_d,
                       betas=(0.9, 0.99))
    # Polynomial ("poly") decay over all epochs * iters_per_epoch steps.
    # NOTE(review): LambdaLR multiplies the optimizer's base lr by this
    # factor, and the factor itself includes `args.lr` while the optimizer
    # was also constructed with lr=args.lr — so the effective model lr
    # scales as args.lr**2 unless model.get_parameters() normalizes the
    # per-group lr. Confirm this matches the intended schedule (the
    # discriminator's lambda below omits the extra `args.lr` factor).
    lr_scheduler = LambdaLR(
        optimizer,
        lambda x: args.lr *
        (1. - float(x) / args.epochs / args.iters_per_epoch)**(args.lr_power))
    lr_scheduler_d = LambdaLR(
        optimizer_d,
        lambda x:
        (1. - float(x) / args.epochs / args.iters_per_epoch)**(args.lr_power))

    # optionally resume from a checkpoint
    if args.resume:
        # map_location='cpu' avoids GPU-device mismatches; modules are
        # already on `device`, so load_state_dict moves tensors as needed.
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        discriminator.load_state_dict(checkpoint['discriminator'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        optimizer_d.load_state_dict(checkpoint['optimizer_d'])
        lr_scheduler_d.load_state_dict(checkpoint['lr_scheduler_d'])
        args.start_epoch = checkpoint['epoch'] + 1

    # define loss function (criterion)
    # ignore_index drops unlabeled pixels from the segmentation loss.
    criterion = torch.nn.CrossEntropyLoss(
        ignore_index=args.ignore_label).to(device)
    dann = DomainAdversarialEntropyLoss(discriminator)
    # NOTE(review): sizes are reversed with [::-1] — presumably the args
    # store (W, H) while nn.Upsample expects (H, W); confirm.
    interp_train = nn.Upsample(size=args.train_size[::-1],
                               mode='bilinear',
                               align_corners=True)
    interp_val = nn.Upsample(size=args.test_output_size[::-1],
                             mode='bilinear',
                             align_corners=True)

    # define visualization function
    decode = train_source_dataset.decode_target

    def visualize(image, pred, label, prefix):
        """Save the input image, ground-truth label map and prediction.

        Writes three PNGs ("<prefix>_image/label/pred.png") via the logger.

        Args:
            image (tensor): 3 x H x W
            pred (tensor): C x H x W
            label (tensor): H x W
            prefix: prefix of the saving image
        """
        image = image.detach().cpu().numpy()
        # argmax over the class dimension -> H x W class-id map
        pred = pred.detach().max(dim=0)[1].cpu().numpy()
        label = label.cpu().numpy()
        for tensor, name in [
            (Image.fromarray(np.uint8(DeNormalizeAndTranspose()(image))),
             "image"),
            (decode(label), "label"),
            (decode(pred), "pred")
        ]:
            tensor.save(logger.get_image_path("{}_{}.png".format(prefix,
                                                                 name)))

    if args.phase == 'test':
        # Evaluation-only mode: validate once (with visualization) and exit.
        confmat = validate(val_target_loader, model, interp_val, criterion,
                           visualize, args)
        print(confmat)
        return

    # start training
    best_iou = 0.
    for epoch in range(args.start_epoch, args.epochs):
        logger.set_epoch(epoch)
        # NOTE(review): Scheduler.get_lr() is deprecated in recent PyTorch;
        # get_last_lr() is the supported accessor — consider switching.
        print(lr_scheduler.get_lr(), lr_scheduler_d.get_lr())

        # train for one epoch
        # Per-iteration visualization only when debugging.
        train(train_source_iter, train_target_iter, model, interp_train,
              criterion, dann, optimizer, lr_scheduler, optimizer_d,
              lr_scheduler_d, epoch, visualize if args.debug else None, args)

        # evaluate on validation set
        confmat = validate(val_target_loader, model, interp_val, criterion,
                           None, args)
        print(confmat.format(train_source_dataset.classes))
        acc_global, acc, iu = confmat.compute()

        # calculate the mean iou over partial classes
        # Only the dataset's `evaluate_classes` subset counts toward mIoU.
        indexes = [
            train_source_dataset.classes.index(name)
            for name in train_source_dataset.evaluate_classes
        ]
        iu = iu[indexes]
        mean_iou = iu.mean()

        # remember best acc@1 and save checkpoint
        # Every epoch is checkpointed; the best-mIoU one is also copied to
        # the 'best' checkpoint path.
        torch.save(
            {
                'model': model.state_dict(),
                'discriminator': discriminator.state_dict(),
                'optimizer': optimizer.state_dict(),
                'optimizer_d': optimizer_d.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'lr_scheduler_d': lr_scheduler_d.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch))
        if mean_iou > best_iou:
            shutil.copy(logger.get_checkpoint_path(epoch),
                        logger.get_checkpoint_path('best'))
        best_iou = max(best_iou, mean_iou)
        print("Target: {} Best: {}".format(mean_iou, best_iou))

    logger.close()