train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=False,
                                           sampler=train_sampler,
                                           drop_last=True)
val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=eval_batch_size,
                                         shuffle=False, drop_last=True)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=eval_batch_size,
                                          shuffle=False, drop_last=True)

if args.enable_gavel_iterator:
    train_loader = GavelIterator(train_loader, args.checkpoint_dir,
                                 load_checkpoint, save_checkpoint)

state = None
if args.checkpoint_dir is not None:
    if not os.path.isdir(args.checkpoint_dir):
        os.mkdir(args.checkpoint_dir)
    else:
        checkpoint_path = os.path.join(args.checkpoint_dir, 'model.chkpt')
        if os.path.exists(checkpoint_path):
            if args.enable_gavel_iterator:
                state = train_loader.load_checkpoint(args, checkpoint_path)
            else:
                state = load_checkpoint(args, checkpoint_path)

if state is not None:
    model = state['model'].to(device)
if model is None:
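# Minimal usage sketch (assumption, based on the GavelIterator API exercised
# in the ImageNet listing later in this appendix: `done` and `complete()`):
# the wrapped train_loader iterates like the underlying DataLoader, sets
# `done` when the Gavel scheduler preempts the job, and expects `complete()`
# once training finishes on its own.
def _example_gavel_train_loop(train_loader, num_epochs):
    for epoch in range(num_epochs):
        for batch in train_loader:
            pass  # forward/backward/optimizer step
        if getattr(train_loader, 'done', False):
            break  # preempted: the iterator's checkpoint hook has run
    else:
        train_loader.complete()  # normal completion, notify the scheduler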
mp.set_start_method('spawn')
setup_json = read_config(args.env_config)
env_conf = setup_json["Default"]
for i in setup_json.keys():
    if i in args.env:
        env_conf = setup_json[i]
env = atari_env(args.env, env_conf, args)
shared_model = A3Clstm(env.observation_space.shape[0], env.action_space)

iters = {}
for rank in range(0, args.workers):
    iters[rank] = range(args.max_steps)
    if args.enable_gavel_iterator and rank == 0:
        iters[rank] = GavelIterator(iters[rank], args.checkpoint_dir,
                                    load_checkpoint, save_checkpoint,
                                    write_on_close=False)

if not os.path.isdir(args.checkpoint_dir):
    os.mkdir(args.checkpoint_dir)
checkpoint_path = os.path.join(args.checkpoint_dir, 'model.chkpt')
if os.path.exists(checkpoint_path):
    if args.enable_gavel_iterator:
        saved_state = iters[0].load_checkpoint(args, checkpoint_path)
    else:
        saved_state = load_checkpoint(args, checkpoint_path)
    shared_model.load_state_dict(saved_state)
shared_model.share_memory()

if args.shared_optimizer:
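# Illustrative sketch (hypothetical helper, not from the original source):
# each A3C worker consumes its step range; only rank 0's range is wrapped, so
# a single process reports progress to Gavel. write_on_close=False presumably
# suppresses the iterator's automatic checkpoint on close, because the shared
# model is saved explicitly by the training code instead.
def _example_a3c_worker(rank, iters, shared_model):
    for step in iters[rank]:
        pass  # one rollout and asynchronous gradient update on shared_model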
def main():
    ''' Main function '''
    parser = argparse.ArgumentParser()

    parser.add_argument('-data', required=True)
    parser.add_argument('-epoch', type=int, default=None)
    parser.add_argument('-step', type=int, default=None)
    parser.add_argument('-batch_size', type=int, default=64)
    #parser.add_argument('-d_word_vec', type=int, default=512)
    parser.add_argument('-d_model', type=int, default=512)
    parser.add_argument('-d_inner_hid', type=int, default=2048)
    parser.add_argument('-d_k', type=int, default=64)
    parser.add_argument('-d_v', type=int, default=64)
    parser.add_argument('-n_head', type=int, default=8)
    parser.add_argument('-n_layers', type=int, default=6)
    # NOTE(keshav2): This just refers to the learning rate schedule,
    # nothing performance related.
    parser.add_argument('-n_warmup_steps', type=int, default=4000)
    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-embs_share_weight', action='store_true')
    parser.add_argument('-proj_share_weight', action='store_true')
    parser.add_argument('-log', default=None)
    parser.add_argument('--checkpoint_dir', type=str,
                        default='/lfs/1/keshav2/checkpoints/transformer')
    parser.add_argument('-save_mode', type=str, choices=['all', 'best'],
                        default='all')
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-label_smoothing', action='store_true')
    parser.add_argument('--dist-url', default='env://', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str,
                        help='Distributed backend')
    parser.add_argument('--local_rank', default=0, type=int,
                        help='Local rank')
    parser.add_argument('--rank', default=None, type=int,
                        help='Rank')
    parser.add_argument('--world_size', default=None, type=int,
                        help='World size')
    parser.add_argument('--master_addr', default=None, type=str,
                        help='Master address to use for distributed run')
    parser.add_argument('--master_port', default=None, type=int,
                        help='Master port to use for distributed run')
    parser.add_argument('--throughput_estimation_interval', type=int,
                        default=None,
                        help='Steps between logging steps completed')
    parser.add_argument('--max_duration', type=int, default=None,
                        help='Maximum duration in seconds')
    parser.add_argument('--enable_gavel_iterator', action='store_true',
                        default=False, help='If set, use Gavel iterator')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda
    opt.d_word_vec = opt.d_model

    torch.cuda.set_device(opt.local_rank)

    if opt.epoch is not None and opt.step is not None:
        raise ValueError('Only one of epoch and step may be set')
    elif opt.epoch is None and opt.step is None:
        raise ValueError('One of epoch and step must be set')

    opt.distributed = False
    if opt.master_addr is not None:
        opt.distributed = True
        os.environ['MASTER_ADDR'] = opt.master_addr
        os.environ['MASTER_PORT'] = str(opt.master_port)
        dist.init_process_group(backend=opt.dist_backend,
                                init_method=opt.dist_url,
                                world_size=opt.world_size,
                                rank=opt.rank)

    #========= Loading Dataset =========#
    data = torch.load(opt.data)
    opt.max_token_seq_len = data['settings'].max_token_seq_len

    training_data, validation_data = prepare_dataloaders(
        data, opt, opt.master_addr is not None)

    opt.src_vocab_size = training_data.dataset.src_vocab_size
    opt.tgt_vocab_size = training_data.dataset.tgt_vocab_size

    #========= Preparing Model =========#
    if opt.embs_share_weight:
        assert training_data.dataset.src_word2idx == training_data.dataset.tgt_word2idx, \
            'The src/tgt word2idx tables are different but sharing word embeddings was requested.'
    print(opt)

    device = torch.device('cuda' if opt.cuda else 'cpu')
    transformer = Transformer(
        opt.src_vocab_size,
        opt.tgt_vocab_size,
        opt.max_token_seq_len,
        tgt_emb_prj_weight_sharing=opt.proj_share_weight,
        emb_src_tgt_weight_sharing=opt.embs_share_weight,
        d_k=opt.d_k,
        d_v=opt.d_v,
        d_model=opt.d_model,
        d_word_vec=opt.d_word_vec,
        d_inner=opt.d_inner_hid,
        n_layers=opt.n_layers,
        n_head=opt.n_head,
        dropout=opt.dropout).to(device)

    if opt.distributed:
        transformer = DDP(transformer, device_ids=[opt.local_rank],
                          output_device=opt.local_rank)

    if opt.enable_gavel_iterator:
        training_data = GavelIterator(training_data, opt.checkpoint_dir,
                                      load_checkpoint, save_checkpoint)

    optimizer = ScheduledOptim(
        optim.Adam(
            filter(lambda x: x.requires_grad, transformer.parameters()),
            betas=(0.9, 0.98), eps=1e-09),
        opt.d_model, opt.n_warmup_steps)

    train(transformer, training_data, validation_data, optimizer, device, opt)
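# Aside (assumption, sketched for clarity; not the verbatim ScheduledOptim
# source): the -n_warmup_steps flag above feeds the standard Transformer
# inverse-square-root schedule from "Attention Is All You Need", which is the
# "learning rate schedule" the NOTE in the argument list refers to.
def _example_transformer_lr(step, d_model=512, n_warmup_steps=4000):
    # Linear warmup for n_warmup_steps, then decay proportional to step**-0.5.
    return (d_model ** -0.5) * min(step ** -0.5, step * (n_warmup_steps ** -1.5))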
def train(self, local_rank, train_dataset, val_dataset=None, lr=0.001,
          weight_decay=0, num_epochs=1, iters_per_epoch=None, batch_size=64,
          lr_milestones=None, negative_sampling=False, num_sampling_users=0,
          num_data_workers=0, model_checkpoint_prefix=None, checkpoint_freq=0,
          eval_freq=0, eval_num_recommendations=None, eval_num_users=None,
          metrics=None, eval_batch_size=None):
    """
    Trains the model.

    Args:
      train_dataset (RecommendationDataset): train dataset.
      val_dataset (RecommendationDataset, optional): validation dataset.
      lr (float, optional): learning rate.
      weight_decay (float, optional): weight decay (L2 regularization).
      num_epochs (int, optional): number of epochs to train the model.
      iters_per_epoch (int, optional): number of training iterations per
        epoch. If None, one epoch is the full number of training samples in
        the dataset.
      batch_size (int, optional): batch size.
      lr_milestones (list, optional): optimizer learning rate epoch
        milestones (0.1 decay).
      negative_sampling (bool, optional): whether to apply mini-batch based
        negative sampling.
      num_sampling_users (int, optional): number of users to consider for
        sampling items. This is useful for increasing the number of negative
        samples in mini-batch based negative sampling while keeping the
        batch size small. If 0, num_sampling_users is set to batch_size.
      num_data_workers (int, optional): number of data workers to use for
        building the mini-batches.
      model_checkpoint_prefix (str, optional): model checkpoint save path
        prefix.
      checkpoint_freq (int, optional): epoch frequency of saving a model
        checkpoint.
      eval_freq (int, optional): epoch frequency of running an evaluation.
      eval_num_recommendations (int, optional): number of recommendations to
        generate on evaluation.
      eval_num_users (int, optional): number of users from the validation
        dataset to use for evaluation. If None, all users in the validation
        dataset are used.
      metrics (list[Metric], optional): list of ``Metric`` used to evaluate
        the model.
      eval_batch_size (int, optional): the size of the evaluation batch.
    """
    log.info('{} Mode'.format('CPU' if self.device.type == 'cpu' else 'GPU'))

    model_params = self.model.model_params()
    for param in model_params:
        log.info('Model {}: {}'.format(param, model_params[param]))
    log.info('Initial Learning Rate: {}'.format(lr))
    log.info('Weight decay: {}'.format(weight_decay))
    log.info('Batch Size: {}'.format(batch_size))
    log.info('Optimizer: {}'.format(self.optimizer_type))
    log.info('LR milestones: {}'.format(lr_milestones))
    log.info('Loss Function: {}'.format(self.loss))
    for param in self.loss_params:
        log.info('Loss {}: {}'.format(param, self.loss_params[param]))

    if num_sampling_users == 0:
        num_sampling_users = batch_size
    if eval_batch_size is None:
        eval_batch_size = batch_size

    assert num_sampling_users >= batch_size and num_sampling_users % batch_size == 0, \
        "number of sampling users should be a multiple of the batch size"

    self.__init_training(train_dataset=train_dataset, lr=lr,
                         weight_decay=weight_decay)

    train_dataloader = RecommendationDataLoader(
        train_dataset, batch_size=batch_size,
        negative_sampling=negative_sampling,
        num_sampling_users=num_sampling_users,
        num_workers=num_data_workers)
    if val_dataset is not None:
        val_dataloader = RecommendationDataLoader(
            val_dataset, batch_size=batch_size,
            negative_sampling=negative_sampling,
            num_sampling_users=num_sampling_users,
            num_workers=num_data_workers)
    else:
        val_dataloader = None

    if self._enable_gavel_iterator:
        train_dataloader = GavelIterator(train_dataloader, self._gavel_dir,
                                         self.load_checkpoint,
                                         self.save_checkpoint,
                                         synthetic_data=True)

    if os.path.exists(model_checkpoint_prefix):
        try:
            print('Loading checkpoint from %s...' % (model_checkpoint_prefix))
            if self._enable_gavel_iterator:
                self.init_from_model_file(model_checkpoint_prefix, local_rank,
                                          train_dataloader)
            else:
                self.init_from_model_file(model_checkpoint_prefix, local_rank)
        except Exception as e:
            print('Could not load from checkpoint: %s' % (e))
    else:
        print('Checkpoint does not exist at %s' % (model_checkpoint_prefix))

    if lr_milestones is not None:
        _last_epoch = -1 if self.current_epoch == 1 else (self.current_epoch - 2)
        lr_scheduler = MultiStepLR(self.optimizer, milestones=lr_milestones,
                                   gamma=0.1, last_epoch=_last_epoch)
    else:
        lr_scheduler = None

    self._train(train_dataloader=train_dataloader,
                val_dataloader=val_dataloader,
                num_epochs=num_epochs,
                current_epoch=self.current_epoch,
                lr_scheduler=lr_scheduler,
                batch_size=batch_size,
                model_checkpoint_prefix=model_checkpoint_prefix,
                checkpoint_freq=checkpoint_freq,
                eval_freq=eval_freq,
                metrics=metrics,
                eval_num_recommendations=eval_num_recommendations,
                iters_per_epoch=iters_per_epoch,
                eval_num_users=eval_num_users,
                eval_batch_size=eval_batch_size)

    if self._enable_gavel_iterator:
        train_dataloader.complete()
        self.save_state(model_checkpoint_prefix, train_dataloader)
    else:
        self.save_state(model_checkpoint_prefix)
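# Note on the _last_epoch computation above (illustrative worked example):
# PyTorch's MultiStepLR takes last_epoch as the index of the last completed
# epoch, with -1 meaning a fresh run. If training resumes at
# self.current_epoch == 5 (1-indexed, so epochs 1-4 already ran), then
# _last_epoch = 5 - 2 = 3, and the scheduler replays any milestones already
# passed before epoch 5 executes.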
def main():
    global args, best_acc1, total_minibatches, total_elapsed_time
    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    args.distributed = False
    if args.master_addr is not None:
        args.distributed = True
        os.environ['MASTER_ADDR'] = args.master_addr
        os.environ['MASTER_PORT'] = str(args.master_port)
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()
    model = model.cuda()

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    if args.enable_gavel_iterator:
        train_loader = GavelIterator(train_loader, args.checkpoint_dir,
                                     load_checkpoint, save_checkpoint)

    # Load from checkpoint.
    if not os.path.isdir(args.checkpoint_dir):
        os.mkdir(args.checkpoint_dir)
    checkpoint_path = os.path.join(args.checkpoint_dir, 'model.chkpt')
    if os.path.exists(checkpoint_path):
        if args.enable_gavel_iterator:
            checkpoint = train_loader.load_checkpoint(args, checkpoint_path)
        else:
            checkpoint = load_checkpoint(args, checkpoint_path)
        if checkpoint is not None:
            args.start_epoch = checkpoint['epoch']
            # best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(checkpoint_path, checkpoint['epoch']))

    if args.num_minibatches is not None:
        args.epochs = math.ceil(float(args.num_minibatches) *
                                args.batch_size / len(train_loader))

    epoch = args.start_epoch
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        num_minibatches, elapsed_time, finished_epoch = \
            train(train_loader, model, criterion, optimizer, epoch,
                  total_minibatches,
                  max_minibatches=args.num_minibatches,
                  total_elapsed_time=total_elapsed_time,
                  max_duration=args.max_duration)
        total_minibatches += num_minibatches
        total_elapsed_time += elapsed_time

        if args.enable_gavel_iterator and train_loader.done:
            break
        elif (args.num_minibatches is not None and
              total_minibatches >= args.num_minibatches):
            if args.enable_gavel_iterator:
                train_loader.complete()
            break
        elif (args.max_duration is not None and
              total_elapsed_time >= args.max_duration):
            if args.enable_gavel_iterator:
                train_loader.complete()
            break

        # evaluate on validation set
        # acc1 = validate(val_loader, model, criterion)

        # remember best acc@1 and save checkpoint
        # best_acc1 = max(acc1, best_acc1)

        if not args.distributed or args.rank == 0:
            state = {
                'epoch': epoch,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                # 'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }
            if args.enable_gavel_iterator:
                train_loader.save_checkpoint(state, checkpoint_path)
            else:
                save_checkpoint(state, checkpoint_path)
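# Sketch (assumption about GavelIterator internals, inferred from the fact
# that the wrapped and unwrapped branches above take the same arguments): the
# iterator's load_checkpoint/save_checkpoint methods presumably delegate to
# the callbacks registered at construction, so the same helpers work whether
# or not the iterator is enabled. A hypothetical equivalent:
class _GavelCheckpointPassThrough(object):
    def __init__(self, loader, checkpoint_dir, load_fn, save_fn):
        self._loader = loader
        self._load_fn = load_fn
        self._save_fn = save_fn

    def load_checkpoint(self, args, checkpoint_path):
        return self._load_fn(args, checkpoint_path)

    def save_checkpoint(self, state, checkpoint_path):
        self._save_fn(state, checkpoint_path)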
        return checkpoint
    except Exception as e:
        print('Error reading checkpoint: %s' % (e))
        return None
    return None

def save_checkpoint(checkpoint_path, state):
    print('==> Saving checkpoint at %s...' % (checkpoint_path))
    torch.save(state, checkpoint_path)

if args.checkpoint_dir is not None:
    checkpoint_path = os.path.join(args.checkpoint_dir, 'model.chkpt')
    if args.enable_gavel_iterator:
        trainloader = GavelIterator(trainloader, args.checkpoint_dir,
                                    load_checkpoint_func=load_checkpoint,
                                    save_checkpoint_func=save_checkpoint)
        checkpoint = trainloader.load_checkpoint(args, checkpoint_path)
    else:
        checkpoint = load_checkpoint(args, checkpoint_path)
    if checkpoint is not None:
        net.load_state_dict(checkpoint['net'])
        # best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch']

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr,
                      momentum=0.9, weight_decay=5e-4)

# Training
def train(epoch, cumulative_steps=None, cumulative_time=None):
    print('\nEpoch: %d' % epoch)
def load_checkpoint(opt, checkpoint_path):
    try:
        print('Loading checkpoint from %s...' % (checkpoint_path))
        checkpoint = torch.load(checkpoint_path,
                                map_location='cuda:{}'.format(opt.local_rank))
        return checkpoint
    except Exception as e:
        print('Could not load from checkpoint: %s' % (e))
        return None

def save_checkpoint(state, checkpoint_path):
    print('Saving checkpoint at %s...' % (checkpoint_path))
    torch.save(state, checkpoint_path)

if opt.enable_gavel_iterator:
    dataloader = GavelIterator(dataloader, opt.checkpoint_dir,
                               load_checkpoint, save_checkpoint)

checkpoint_path = os.path.join(opt.checkpoint_dir, "model.chkpt")
checkpoint = None
if os.path.exists(checkpoint_path):
    if opt.enable_gavel_iterator:
        checkpoint = dataloader.load_checkpoint(opt, checkpoint_path)
    else:
        checkpoint = load_checkpoint(opt, checkpoint_path)
else:
    print('Could not load from checkpoint!')

if checkpoint is not None:
    G_AB.load_state_dict(checkpoint['G_AB'])
    G_BA.load_state_dict(checkpoint['G_BA'])
    D_A.load_state_dict(checkpoint['D_A'])