Example #1
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=False,
                                           sampler=train_sampler,
                                           drop_last=True)
val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=eval_batch_size,
                                         shuffle=False,
                                         drop_last=True)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=eval_batch_size,
                                          shuffle=False,
                                          drop_last=True)

if args.enable_gavel_iterator:
    train_loader = GavelIterator(train_loader, args.checkpoint_dir,
                                 load_checkpoint, save_checkpoint)

state = None
if args.checkpoint_dir is not None:
    if not os.path.isdir(args.checkpoint_dir):
        os.mkdir(args.checkpoint_dir)
    else:
        checkpoint_path = os.path.join(args.checkpoint_dir, 'model.chkpt')
        if os.path.exists(checkpoint_path):
            if args.enable_gavel_iterator:
                state = train_loader.load_checkpoint(args, checkpoint_path)
            else:
                state = load_checkpoint(args, checkpoint_path)
if state is not None:
    model = state['model'].to(device)
    if model is None:
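
Note: Example #1 is cut off above. The load_checkpoint and save_checkpoint callbacks it hands to GavelIterator are not shown in this excerpt; below is a minimal sketch consistent with the concrete definitions in Example #7 (the map_location choice is an assumption):

import torch

def load_checkpoint(args, checkpoint_path):
    # Return the saved training state, or None if the file cannot be read.
    try:
        print('Loading checkpoint from %s...' % checkpoint_path)
        return torch.load(checkpoint_path, map_location='cpu')
    except Exception as e:
        print('Could not load from checkpoint: %s' % e)
        return None

def save_checkpoint(state, checkpoint_path):
    # Persist the state assembled at the end of training.
    print('Saving checkpoint at %s...' % checkpoint_path)
    torch.save(state, checkpoint_path)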
Example #2
    mp.set_start_method('spawn')
    setup_json = read_config(args.env_config)
    env_conf = setup_json["Default"]
    for i in setup_json.keys():
        if i in args.env:
            env_conf = setup_json[i]
    env = atari_env(args.env, env_conf, args)
    shared_model = A3Clstm(env.observation_space.shape[0], env.action_space)

    iters = {}
    for rank in range(0, args.workers):
        iters[rank] = range(args.max_steps)
        if args.enable_gavel_iterator and rank == 0:
            iters[rank] = GavelIterator(iters[rank],
                                        args.checkpoint_dir,
                                        load_checkpoint,
                                        save_checkpoint,
                                        write_on_close=False)

    if not os.path.isdir(args.checkpoint_dir):
        os.mkdir(args.checkpoint_dir)
    checkpoint_path = os.path.join(args.checkpoint_dir, 'model.chkpt')
    if os.path.exists(checkpoint_path):
        if args.enable_gavel_iterator:
            saved_state = iters[0].load_checkpoint(args, checkpoint_path)
        else:
            saved_state = load_checkpoint(args, checkpoint_path)
        shared_model.load_state_dict(saved_state)
    shared_model.share_memory()

    if args.shared_optimizer:
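
Note: Example #2 is also truncated. A minimal sketch of how a worker process might consume its per-rank iterator; the train_worker name and run_training_step helper are assumptions, and only rank 0's iterator is Gavel-wrapped, so only it exposes the done flag:

def train_worker(rank, iters, shared_model, args):
    # Step through this rank's iterator; on rank 0 it is a GavelIterator
    # that reports progress to the scheduler.
    for step in iters[rank]:
        run_training_step(shared_model, args)  # hypothetical A3C update
        if args.enable_gavel_iterator and rank == 0 and iters[rank].done:
            break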
Example #3
def main():
    ''' Main function '''
    parser = argparse.ArgumentParser()

    parser.add_argument('-data', required=True)

    parser.add_argument('-epoch', type=int, default=None)
    parser.add_argument('-step', type=int, default=None)
    parser.add_argument('-batch_size', type=int, default=64)

    #parser.add_argument('-d_word_vec', type=int, default=512)
    parser.add_argument('-d_model', type=int, default=512)
    parser.add_argument('-d_inner_hid', type=int, default=2048)
    parser.add_argument('-d_k', type=int, default=64)
    parser.add_argument('-d_v', type=int, default=64)

    parser.add_argument('-n_head', type=int, default=8)
    parser.add_argument('-n_layers', type=int, default=6)
    # NOTE(keshav2): This just refers to the learning rate schedule,
    #                nothing performance related.
    parser.add_argument('-n_warmup_steps', type=int, default=4000)

    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-embs_share_weight', action='store_true')
    parser.add_argument('-proj_share_weight', action='store_true')

    parser.add_argument('-log', default=None)
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default='/lfs/1/keshav2/checkpoints/transformer')
    parser.add_argument('-save_mode',
                        type=str,
                        choices=['all', 'best'],
                        default='all')

    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-label_smoothing', action='store_true')

    parser.add_argument('--dist-url',
                        default='env://',
                        type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend',
                        default='nccl',
                        type=str,
                        help='Distributed backend')
    parser.add_argument('--local_rank', default=0, type=int, help='Local rank')
    parser.add_argument('--rank', default=None, type=int, help='Rank')
    parser.add_argument('--world_size',
                        default=None,
                        type=int,
                        help='World size')
    parser.add_argument('--master_addr',
                        default=None,
                        type=str,
                        help='Master address to use for distributed run')
    parser.add_argument('--master_port',
                        default=None,
                        type=int,
                        help='Master port to use for distributed run')

    parser.add_argument('--throughput_estimation_interval',
                        type=int,
                        default=None,
                        help='Steps between logging steps completed')
    parser.add_argument('--max_duration',
                        type=int,
                        default=None,
                        help='Maximum duration in seconds')
    parser.add_argument('--enable_gavel_iterator',
                        action='store_true',
                        default=False,
                        help='If set, use Gavel iterator')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda
    opt.d_word_vec = opt.d_model

    torch.cuda.set_device(opt.local_rank)

    if opt.epoch is not None and opt.step is not None:
        raise ValueError('Only one of epoch and step may be set')
    elif opt.epoch is None and opt.step is None:
        raise ValueError('One of epoch and step must be set')

    opt.distributed = False
    if opt.master_addr is not None:
        opt.distributed = True
        os.environ['MASTER_ADDR'] = opt.master_addr
        os.environ['MASTER_PORT'] = str(opt.master_port)
        dist.init_process_group(backend=opt.dist_backend,
                                init_method=opt.dist_url,
                                world_size=opt.world_size,
                                rank=opt.rank)

    #========= Loading Dataset =========#
    data = torch.load(opt.data)
    opt.max_token_seq_len = data['settings'].max_token_seq_len

    training_data, validation_data = prepare_dataloaders(
        data, opt, opt.master_addr is not None)

    opt.src_vocab_size = training_data.dataset.src_vocab_size
    opt.tgt_vocab_size = training_data.dataset.tgt_vocab_size

    #========= Preparing Model =========#
    if opt.embs_share_weight:
        assert training_data.dataset.src_word2idx == training_data.dataset.tgt_word2idx, \
            'The src/tgt word2idx table are different but asked to share word embedding.'

    print(opt)

    device = torch.device('cuda' if opt.cuda else 'cpu')
    transformer = Transformer(opt.src_vocab_size,
                              opt.tgt_vocab_size,
                              opt.max_token_seq_len,
                              tgt_emb_prj_weight_sharing=opt.proj_share_weight,
                              emb_src_tgt_weight_sharing=opt.embs_share_weight,
                              d_k=opt.d_k,
                              d_v=opt.d_v,
                              d_model=opt.d_model,
                              d_word_vec=opt.d_word_vec,
                              d_inner=opt.d_inner_hid,
                              n_layers=opt.n_layers,
                              n_head=opt.n_head,
                              dropout=opt.dropout).to(device)

    if opt.distributed:
        transformer = DDP(transformer,
                          device_ids=[opt.local_rank],
                          output_device=opt.local_rank)

    if opt.enable_gavel_iterator:
        training_data = GavelIterator(training_data, opt.checkpoint_dir,
                                      load_checkpoint, save_checkpoint)

    optimizer = ScheduledOptim(
        optim.Adam(filter(lambda x: x.requires_grad, transformer.parameters()),
                   betas=(0.9, 0.98),
                   eps=1e-09), opt.d_model, opt.n_warmup_steps)

    train(transformer, training_data, validation_data, optimizer, device, opt)
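
Note: ScheduledOptim implements the warmup learning-rate schedule that the -n_warmup_steps comment above refers to. A minimal sketch of the standard Transformer schedule, written as an assumption about what the class does rather than its exact implementation:

class WarmupInverseSqrtSchedule:
    """Sketch of the 'Attention Is All You Need' LR schedule (not the project's exact class)."""

    def __init__(self, optimizer, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self._d_model = d_model
        self._n_warmup_steps = n_warmup_steps
        self._n_steps = 0

    def step_and_update_lr(self):
        # lr = d_model^-0.5 * min(step^-0.5, step * n_warmup_steps^-1.5)
        self._n_steps += 1
        lr = (self._d_model ** -0.5) * min(
            self._n_steps ** -0.5,
            self._n_steps * self._n_warmup_steps ** -1.5)
        for group in self._optimizer.param_groups:
            group['lr'] = lr
        self._optimizer.step()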
Example #4
  def train(self, local_rank, train_dataset, val_dataset=None,
            lr=0.001, weight_decay=0, num_epochs=1,
            iters_per_epoch=None, batch_size=64, lr_milestones=None,
            negative_sampling=False, num_sampling_users=0, num_data_workers=0,
            model_checkpoint_prefix=None, checkpoint_freq=0,
            eval_freq=0, eval_num_recommendations=None,
            eval_num_users=None, metrics=None, eval_batch_size=None):
    """
    Trains the model

    Args:
      train_dataset (RecommendationDataset): train dataset.
      val_dataset (RecommendationDataset, optional): validation dataset.
      lr (float, optional): learning rate.
      weight_decay (float, optional): weight decay (L2 normalization).
      num_epochs (int, optional): number of epochs to train the model.
      iters_per_epoch (int, optional): number of training iterations per epoch. If None,
        one epoch covers all training samples in the dataset.
      batch_size (int, optional): batch size.
      lr_milestones (list, optional): epoch milestones at which the optimizer learning rate
        is decayed by a factor of 0.1.
      negative_sampling (bool, optional): whether to apply mini-batch based negative sampling or not.
      num_sampling_users (int, optional): number of users to consider for sampling items.
        This is useful for increasing the number of negative samples in mini-batch based negative
        sampling while keeping the batch-size small. If 0, then num_sampling_users will
        be equal to batch_size.
      num_data_workers (int, optional): number of data workers to use for building the mini-batches.
      model_checkpoint_prefix (str, optional): path prefix used when saving model checkpoints.
      checkpoint_freq (int, optional): frequency, in epochs, at which a model checkpoint is saved.
      eval_freq (int, optional): frequency, in epochs, at which evaluation is run.
      eval_num_recommendations (int, optional): number of recommendations to generate during evaluation.
      eval_num_users (int, optional): number of users from the validation dataset to use for evaluation.
        If None, all users in the validation dataset are used for evaluation.
      metrics (list[Metric], optional): list of ``Metric`` used to evaluate the model
      eval_batch_size (int, optional): batch size used during evaluation.
    """
    log.info('{} Mode'.format('CPU' if self.device.type == 'cpu' else 'GPU'))
    model_params = self.model.model_params()
    for param in model_params:
      log.info('Model {}: {}'.format(param, model_params[param]))
    log.info('Initial Learning Rate: {}'.format(lr))
    log.info('Weight decay: {}'.format(weight_decay))
    log.info('Batch Size: {}'.format(batch_size))
    log.info('Optimizer: {}'.format(self.optimizer_type))
    log.info('LR milestones: {}'.format(lr_milestones))
    log.info('Loss Function: {}'.format(self.loss))
    for param in self.loss_params:
      log.info('Loss {}: {}'.format(param, self.loss_params[param]))

    if num_sampling_users == 0:
      num_sampling_users = batch_size

    if eval_batch_size is None:
      eval_batch_size = batch_size

    assert num_sampling_users >= batch_size and num_sampling_users % batch_size == 0, \
      "number of sampling users should be a multiple of the batch size"

    self.__init_training(train_dataset=train_dataset, lr=lr, weight_decay=weight_decay)

    train_dataloader = RecommendationDataLoader(train_dataset, batch_size=batch_size,
                                                negative_sampling=negative_sampling,
                                                num_sampling_users=num_sampling_users,
                                                num_workers=num_data_workers)
    if val_dataset is not None:
      val_dataloader = RecommendationDataLoader(val_dataset, batch_size=batch_size,
                                                negative_sampling=negative_sampling,
                                                num_sampling_users=num_sampling_users,
                                                num_workers=num_data_workers)
    else:
      val_dataloader = None

    if self._enable_gavel_iterator:
      train_dataloader = GavelIterator(train_dataloader, self._gavel_dir,
                                       self.load_checkpoint, self.save_checkpoint,
                                       synthetic_data=True)

    if model_checkpoint_prefix is not None and os.path.exists(model_checkpoint_prefix):
      try:
        print('Loading checkpoint from %s...' % (model_checkpoint_prefix))
        if self._enable_gavel_iterator:
          self.init_from_model_file(model_checkpoint_prefix, local_rank,
                                    train_dataloader)
        else:
          self.init_from_model_file(model_checkpoint_prefix, local_rank)
      except Exception as e:
        print('Could not load from checkpoint: %s' % (e))
    else:
      print('Checkpoint does not exist at %s' % (model_checkpoint_prefix))

    if lr_milestones is not None:
      _last_epoch = -1 if self.current_epoch == 1 else (self.current_epoch - 2)
      lr_scheduler = MultiStepLR(self.optimizer, milestones=lr_milestones,
                                 gamma=0.1, last_epoch=_last_epoch)
    else:
      lr_scheduler = None

    self._train(train_dataloader=train_dataloader,
                val_dataloader=val_dataloader,
                num_epochs=num_epochs,
                current_epoch=self.current_epoch,
                lr_scheduler=lr_scheduler,
                batch_size=batch_size,
                model_checkpoint_prefix=model_checkpoint_prefix,
                checkpoint_freq=checkpoint_freq,
                eval_freq=eval_freq,
                metrics=metrics,
                eval_num_recommendations=eval_num_recommendations,
                iters_per_epoch=iters_per_epoch,
                eval_num_users=eval_num_users,
                eval_batch_size=eval_batch_size)

    if self._enable_gavel_iterator:
      train_dataloader.complete()
      self.save_state(model_checkpoint_prefix, train_dataloader)
    else:
      self.save_state(model_checkpoint_prefix)
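
Note: a sketch of how this train method might be invoked once the datasets are built; the trainer instance, dataset objects, and checkpoint path are placeholders:

trainer.train(local_rank=0,
              train_dataset=train_dataset,
              val_dataset=val_dataset,
              lr=1e-3,
              num_epochs=10,
              batch_size=64,
              lr_milestones=[5, 8],
              negative_sampling=True,
              model_checkpoint_prefix='/tmp/recommender_model.chkpt',
              checkpoint_freq=1,
              eval_freq=2,
              eval_num_recommendations=100)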
Example #5
File: main.py Project: xieydd/gavel
def main():
    global args, best_acc1, total_minibatches, total_elapsed_time
    args = parser.parse_args()
    torch.cuda.set_device(args.local_rank)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    args.distributed = False
    if args.master_addr is not None:
        args.distributed = True
        os.environ['MASTER_ADDR'] = args.master_addr
        os.environ['MASTER_PORT'] = str(args.master_port)
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    model = model.cuda()
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.local_rank],
                output_device=args.local_rank)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    if args.enable_gavel_iterator:
        train_loader = GavelIterator(train_loader, args.checkpoint_dir,
                                     load_checkpoint, save_checkpoint)

    # Load from checkpoint.
    if not os.path.isdir(args.checkpoint_dir):
        os.mkdir(args.checkpoint_dir)
    checkpoint_path = os.path.join(args.checkpoint_dir, 'model.chkpt')
    if os.path.exists(checkpoint_path):
        if args.enable_gavel_iterator:
            checkpoint = train_loader.load_checkpoint(args, checkpoint_path)
        else:
            checkpoint = load_checkpoint(args, checkpoint_path)

        if checkpoint is not None:
            args.start_epoch = checkpoint['epoch']
            # best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(checkpoint_path, checkpoint['epoch']))

    if args.num_minibatches is not None:
        args.epochs = math.ceil(float(args.num_minibatches) *
                                args.batch_size / len(train_loader))

    epoch = args.start_epoch
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        num_minibatches, elapsed_time, finished_epoch = \
            train(train_loader, model, criterion, optimizer,
                  epoch, total_minibatches,
                  max_minibatches=args.num_minibatches,
                  total_elapsed_time=total_elapsed_time,
                  max_duration=args.max_duration)
        total_minibatches += num_minibatches
        total_elapsed_time += elapsed_time

        if args.enable_gavel_iterator and train_loader.done:
            break
        elif (args.num_minibatches is not None and
            total_minibatches >= args.num_minibatches):
            if args.enable_gavel_iterator:
                train_loader.complete()
            break
        elif (args.max_duration is not None and
             total_elapsed_time >= args.max_duration):
            if args.enable_gavel_iterator:
                train_loader.complete()
            break

        # evaluate on validation set
        # acc1 = validate(val_loader, model, criterion)

        # remember best acc@1 and save checkpoint
        #best_acc1 = max(acc1, best_acc1)
    if not args.distributed or args.rank == 0:
        state = {
            'epoch': epoch,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            # 'best_acc1': best_acc1,
            'optimizer' : optimizer.state_dict(),
        }
        if args.enable_gavel_iterator:
            train_loader.save_checkpoint(state, checkpoint_path)
        else:
            save_checkpoint(state, checkpoint_path)
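
Note: the adjust_learning_rate helper called in the epoch loop is not shown in this excerpt. A minimal sketch following the standard PyTorch ImageNet recipe; the step-decay schedule is an assumption:

def adjust_learning_rate(optimizer, epoch):
    # Step decay: scale the base learning rate by 0.1 every 30 epochs (assumed schedule).
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr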
Example #6
            return checkpoint
        except Exception as e:
            print('Error reading checkpoint: %s' % (e))
            return None
    return None

def save_checkpoint(checkpoint_path, state):
    print('==> Saving checkpoint at %s...' % (checkpoint_path))
    torch.save(state, checkpoint_path)

if args.checkpoint_dir is not None:
    checkpoint_path = os.path.join(args.checkpoint_dir, 'model.chkpt')

if args.enable_gavel_iterator:
    trainloader = GavelIterator(trainloader, args.checkpoint_dir,
                                load_checkpoint_func=load_checkpoint,
                                save_checkpoint_func=save_checkpoint)
    checkpoint = trainloader.load_checkpoint(args, checkpoint_path)
else:
    checkpoint = load_checkpoint(args, checkpoint_path)
if checkpoint is not None:
    net.load_state_dict(checkpoint['net'])
    # best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

# Training
def train(epoch, cumulative_steps=None, cumulative_time=None):
    print('\nEpoch: %d' % epoch)
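
Note: Example #6 cuts off at the start of train. A minimal sketch of how the body might continue, mirroring the step/time bookkeeping in Example #5; the loop body and the device variable are assumptions:

import time

def train(epoch, cumulative_steps=None, cumulative_time=None):
    print('\nEpoch: %d' % epoch)
    net.train()
    start = time.time()
    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = criterion(net(inputs), targets)
        loss.backward()
        optimizer.step()
        if cumulative_steps is not None:
            cumulative_steps += 1
    if cumulative_time is not None:
        cumulative_time += time.time() - start
    return cumulative_steps, cumulative_time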
Example #7
def load_checkpoint(opt, checkpoint_path):
    try:
        print('Loading checkpoint from %s...' % (checkpoint_path))
        checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(opt.local_rank))
        return checkpoint
    except Exception as e:
        print('Could not load from checkpoint: %s' % (e))
        return None

def save_checkpoint(state, checkpoint_path):
    print('Saving checkpoint at %s...' % (checkpoint_path))
    torch.save(state, checkpoint_path)

if opt.enable_gavel_iterator:
    dataloader = GavelIterator(dataloader, opt.checkpoint_dir, load_checkpoint, save_checkpoint)

checkpoint_path = os.path.join(opt.checkpoint_dir, "model.chkpt")
checkpoint = None
if os.path.exists(checkpoint_path):
    if opt.enable_gavel_iterator:
        checkpoint = dataloader.load_checkpoint(opt, checkpoint_path)
    else:
        checkpoint = load_checkpoint(opt, checkpoint_path)
else:
    print('Could not load from checkpoint!')

if checkpoint is not None:
    G_AB.load_state_dict(checkpoint['G_AB'])
    G_BA.load_state_dict(checkpoint['G_BA'])
    D_A.load_state_dict(checkpoint['D_A'])
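
Note: Example #7 is cut off after restoring D_A. A sketch of the matching state this script would save at the end of training, mirroring the save pattern in Example #5; the D_B key and the exact field set are assumptions based on the keys restored above:

state = {
    'G_AB': G_AB.state_dict(),
    'G_BA': G_BA.state_dict(),
    'D_A': D_A.state_dict(),
    'D_B': D_B.state_dict(),  # assumed counterpart discriminator
}
if opt.enable_gavel_iterator:
    dataloader.save_checkpoint(state, checkpoint_path)
else:
    save_checkpoint(state, checkpoint_path)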