Example No. 1
def main():
    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # data loader
    print('\n--- load dataset ---')
    dataset = dataset_unpair(opts)
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=opts.batch_size,
                                               shuffle=True,
                                               num_workers=opts.nThreads)

    # model
    print('\n--- load model ---')
    model = UID(opts)
    model.setgpu(opts.gpu)
    if opts.resume is None:
        model.initialize()
        ep0 = -1
        total_it = 0
    else:
        ep0, total_it = model.resume(opts.resume)
    model.set_scheduler(opts, last_ep=ep0)
    ep0 += 1
    print('start the training at epoch %d' % (ep0))

    # saver for display and output
    saver = Saver(opts)

    # train
    print('\n--- train ---')
    max_it = 500000
    for ep in range(ep0, opts.n_ep):
        for it, (images_a, images_b) in enumerate(train_loader):
            if images_a.size(0) != opts.batch_size or images_b.size(0) != opts.batch_size:
                continue
            images_a = images_a.cuda(opts.gpu).detach()
            images_b = images_b.cuda(opts.gpu).detach()

            # update model
            model.update_D(images_a, images_b)
            if (it + 1) % 2 != 0 and it != len(train_loader) - 1:
                continue
            model.update_EG()

            # save to display file
            if (it + 1) % 48 == 0:
                print('total_it: %d (ep %d, it %d), lr %08f' %
                      (total_it + 1, ep, it + 1,
                       model.gen_opt.param_groups[0]['lr']))
                print(
                    'Dis_I_loss: %04f, Dis_B_loss %04f, GAN_loss_I %04f, GAN_loss_B %04f'
                    % (model.disA_loss, model.disB_loss, model.gan_loss_i,
                       model.gan_loss_b))
                print('B_percp_loss %04f, Recon_II_loss %04f' %
                      (model.B_percp_loss, model.l1_recon_II_loss))
            if (it + 1) % 200 == 0:
                saver.write_img(ep * len(train_loader) + (it + 1), model)

            total_it += 1
            if total_it >= max_it:
                saver.write_img(-1, model)
                saver.write_model(-1, model)
                break

        # decay learning rate
        if opts.n_ep_decay > -1:
            model.update_lr()

        # Save network weights
        saver.write_model(ep, total_it + 1, model)

    return
Example No. 2
def main_worker(gpu, ngpus_per_node, args):
  ngpus_per_node = 1
  print(gpu, ngpus_per_node)
  #if args.multiprocessing_distributed and args.gpu != 0:
    #def print_pass(*args):
      #pass
    #builtins.print = print_pass
  ##if args.gpu is not None:
    #print("Use GPU: {} for training".format(args.gpu))
  args.dist_url = "tcp://127.0.0.1:2036"
  if args.distributed:
    if args.dist_url == "env://" and args.rank == -1:
      args.rank = int(os.environ["RANK"])
    if args.multiprocessing_distributed:
      #args.rank = args.rank * ngpus_per_node + gpu
      args.rank = gpu
    #args.rank = args.rank * ngpus_per_node + gpu
    #print("world size:", args.world_size)
    print("rank:", args.rank)
    #print("dist backend:", args.dist_backend)
    #print("dist url:", args.dist_url)
    print(args.world_size)
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank)
  #print("define model")
  if args.mm == "c":
    m = E_content(args.input_dim_a, args.input_dim_b)
  else:
    m = E_attr()
  #print("building model")
  model = moco.builder.MoCo(
    m,
    args.moco_dim, args.moco_k, args.moco_m, args.moco_t, args.mlp)
  #print(model)
  if args.distributed:
    if args.gpu is not None:
      torch.cuda.set_device(0) #original: args.gpu / gpu
      model.cuda(0) #args.gpu / gpu
      args.batch_size = int(args.batch_size / ngpus_per_node)
      args.workers = 8 #int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
      # device_ids must be a list; use [0] to match the hard-coded device above (originally [args.gpu])
      model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[0])
    else:
      model.cuda()
      model = torch.nn.parallel.DistributedDataParallel(model)
  elif args.gpu is not None:
    torch.cuda.set_device(args.gpu)
    model = model.cuda(args.gpu)
    raise NotImplementedError("Only DistributedDataParallel is supported.")
  else:
    raise NotImplementedError("Only DistributedDataParallel is supported.")
    
  criterion = nn.CrossEntropyLoss().cuda(args.gpu)
  optimizer = torch.optim.SGD(model.parameters(), args.lr,
              momentum=args.momentum,
              weight_decay=args.weight_decay)
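  # optionally resume from a checkpoint (restores epoch, model and optimizer state)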
  if args.resume:
    if os.path.isfile(args.resume):
      print("=> loading checkpoint '{}'".format(args.resume))
      if args.gpu is None:
        checkpoint = torch.load(args.resume)
      else:
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.resume, map_location=loc)
      args.start_epoch = checkpoint['epoch']
      model.load_state_dict(checkpoint['state_dict'])
      optimizer.load_state_dict(checkpoint['optimizer'])
      print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
    else:
      print("=> no checkpoint found at '{}'".format(args.resume))
  cudnn.benchmark = True

  # Data loading code
  print("loading dataset")
  train_dataset = dataset.dataset_unpair(args)
  print("loaded dataset")
  if args.distributed:
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
  else:
    train_sampler = None
  train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
    num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)

  for epoch in range(args.start_epoch, args.epochs):
    if args.distributed:
      train_sampler.set_epoch(epoch)
    adjust_learning_rate(optimizer, epoch, args)

    train(gpu, train_loader, model, criterion, optimizer, epoch, args)

    if not args.multiprocessing_distributed or (args.multiprocessing_distributed
      and args.rank % ngpus_per_node == 0):
      save_checkpoint(epoch, {
        'epoch': epoch + 1,
        #'arch': args.arch,
        'state_dict': model.state_dict(),
        'optimizer' : optimizer.state_dict(),
      }, is_best=True, filename='checkpoint.pt')
Example No. 3
def main():
    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # data loader
    print('\n--- load dataset ---')

    if opts.multi_modal:
        dataset = dataset_unpair_multi(opts)
    else:
        dataset = dataset_unpair(opts)

    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=opts.batch_size,
                                               shuffle=True,
                                               num_workers=opts.nThreads)

    # model
    print('\n--- load model ---')
    model = DRIT(opts)
    model.setgpu(opts.gpu)
    if opts.resume is None:
        model.initialize()
        ep0 = -1
        total_it = 0
    else:
        ep0, total_it = model.resume(opts.resume)
    model.set_scheduler(opts, last_ep=ep0)
    ep0 += 1
    print('start the training at epoch %d' % (ep0))

    # saver for display and output
    saver = Saver(opts)

    # train
    print('\n--- train ---')
    max_it = 500000
    for ep in range(ep0, opts.n_ep):
        for it, (images_a, images_b) in enumerate(train_loader):
            if images_a.size(0) != opts.batch_size or images_b.size(0) != opts.batch_size:
                continue

            # input data
            images_a = images_a.cuda(opts.gpu).detach()
            images_b = images_b.cuda(opts.gpu).detach()

            # update model
            if (it + 1) % opts.d_iter != 0 and it < len(train_loader) - 2:
                model.update_D_content(images_a, images_b)
                continue
            else:
                model.update_D(images_a, images_b)
                model.update_EG()

            # save to display file

            if not opts.no_display_img and not opts.multi_modal:
                saver.write_display(total_it, model)

            print('total_it: %d (ep %d, it %d), lr %08f' %
                  (total_it, ep, it, model.gen_opt.param_groups[0]['lr']))
            total_it += 1
            if total_it >= max_it:
                # saver.write_img(-1, model)
                saver.write_model(-1, model)
                break

        # decay learning rate
        if opts.n_ep_decay > -1:
            model.update_lr()

        # save result image
        if not opts.multi_modal:
            saver.write_img(ep, model)

        # Save network weights
        saver.write_model(ep, total_it, model)

    return
Example No. 4
def main():

    debug_mode = False

    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # data loader
    print('\n--- load dataset ---')
    dataset = dataset_unpair(opts)
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=opts.batch_size, shuffle=True,
                                               num_workers=opts.nThreads)
    '''
        Inspecting dataset_unpair shows that images are first resized to 256x256
        and then a random 216x216 patch is cropped (a center crop is used at test
        time).
    '''
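    # A minimal sketch (an assumption, not code from this repo) of what such a
    # torchvision preprocessing pipeline typically looks like:
    #   transforms.Compose([
    #       transforms.Resize((256, 256)),
    #       transforms.RandomCrop(216),   # CenterCrop(216) at test time
    #       transforms.ToTensor(),
    #   ])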

    # model
    print('\n--- load model ---')
    model = DRIT(opts)
    if not debug_mode:
        model.setgpu(opts.gpu)
    if opts.resume is None:
        model.initialize()
        ep0 = -1
        total_it = 0
    else:
        ep0, total_it = model.resume(opts.resume)
    model.set_scheduler(opts, last_ep=ep0)
    ep0 += 1
    print('start the training at epoch %d' % (ep0))

    # saver for display and output
    saver = Saver(opts)

    # train
    print('\n--- train ---')
    max_it = 500000
    for ep in range(ep0, opts.n_ep):
        '''
            images_a, images_b: tensors of shape (2, 3, 216, 216),
            i.e. (batch, channels, height, width)
        '''
        for it, (images_a, images_b) in enumerate(train_loader):
            # if this batch happens to be the leftover one or two samples at the end
            # of the epoch, skip it and draw the next batch
            if images_a.size(0) != opts.batch_size or images_b.size(0) != opts.batch_size:
                continue

            # input data
            if not debug_mode:
                images_a = images_a.cuda(opts.gpu).detach()  # detach here, presumably to avoid tracking unneeded gradients and save GPU memory
                images_b = images_b.cuda(opts.gpu).detach()

            # update model: iterations where (it + 1) % opts.d_iter != 0 only update
            # the content discriminator; the remaining iterations update both D and EG
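            # For example, if opts.d_iter were 3, two out of every three iterations
            # would train only the content discriminator and every third iteration
            # would run the full D and EG update (the last two iterations of an
            # epoch always take the full-update branch).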
            if (it + 1) % opts.d_iter != 0 and it < len(train_loader) - 2:
                model.update_D_content(images_a, images_b)
                continue
            else:
                model.update_D(images_a, images_b)
                model.update_EG()

            # save to display file
            if not opts.no_display_img:
                saver.write_display(total_it, model)

            print('total_it: %d (ep %d, it %d), lr %08f' % (total_it, ep, it, model.gen_opt.param_groups[0]['lr']))
            sys.stdout.flush()
            total_it += 1
            if total_it >= max_it:
                saver.write_img(-1, model)
                saver.write_model(-1, model)
                break

        # decay learning rate
        if opts.n_ep_decay > -1:
            model.update_lr()

        # save result image
        saver.write_img(ep, model)

        # Save network weights
        saver.write_model(ep, total_it, model)

    return