def create_model_and_optimizer(opt, texts):
    """Builds the model and related optimizer.

    Args:
        opt: Options namespace. Reads ``model`` (which architecture to
            build), ``embed_dim``, ``dataset``, ``learning_rate`` and
            ``weight_decay``.
        texts: Forwarded unchanged to the model constructor (presumably the
            training texts used to build the text encoder -- TODO confirm).

    Returns:
        A ``(model, optimizer)`` tuple. The model is moved to the GPU when
        CUDA is available. The optimizer is SGD with momentum 0.9 and weight
        decay ``opt.weight_decay``.

    Exits the process with status 1 if ``opt.model`` is not a known model.
    """
    print('Creating model and optimizer for', opt.model)

    if opt.model == 'imgonly':
        model = img_text_composition_models.SimpleModelImageOnly(
            texts, embed_dim=opt.embed_dim)
    elif opt.model == 'textonly':
        model = img_text_composition_models.SimpleModelTextOnly(
            texts, embed_dim=opt.embed_dim)
    elif opt.model == 'concat':
        model = img_text_composition_models.Concat(texts,
                                                   embed_dim=opt.embed_dim)
    elif opt.model == 'tirg':
        model = img_text_composition_models.TIRG(texts,
                                                 embed_dim=opt.embed_dim)
    elif opt.model == 'tirg_lastconv':
        model = img_text_composition_models.TIRGLastConv(
            texts, embed_dim=opt.embed_dim)
    else:
        print('Invalid model', opt.model)
        print('available: imgonly, textonly, concat, tirg or tirg_lastconv')
        # Non-zero status so shell scripts / job schedulers see a failure.
        sys.exit(1)

    if torch.cuda.is_available():
        model = model.cuda()

    # create optimizer
    params = []
    # low learning rate for pretrained layers on real image datasets
    if opt.dataset != 'css3d':
        params.append({
            'params': list(model.img_model.fc.parameters()),
            'lr': opt.learning_rate
        })
        params.append({
            'params': list(model.img_model.parameters()),
            'lr': 0.1 * opt.learning_rate
        })
    params.append({'params': list(model.parameters())})

    # torch.optim rejects a tensor that appears in more than one param group,
    # and the groups above overlap (fc is inside img_model, which is inside
    # model). Each parameter therefore stays in the FIRST group listing it;
    # later occurrences are swapped for a grad-less placeholder. Placeholders
    # (rather than deletion) keep every group's length unchanged, so saved
    # optimizer state dicts still line up.
    seen = set()
    for group in params:
        group_params = group['params']
        for j, p in enumerate(group_params):
            if id(p) in seen:
                group_params[j] = torch.tensor(0.0, requires_grad=True)
            else:
                seen.add(id(p))

    optimizer = torch.optim.SGD(params,
                                lr=opt.learning_rate,
                                momentum=0.9,
                                weight_decay=opt.weight_decay)
    return model, optimizer
Ejemplo n.º 2
0
def create_model_and_optimizer(opt, texts):
  """Builds the model and related optimizer.

  Args:
    opt: Options namespace. Reads ``model`` (which architecture to build),
      ``optimizer`` ('SGD' or 'Adam'), ``learning_rate`` and
      ``weight_decay``; the whole object is also forwarded to the model
      constructor.
    texts: Forwarded unchanged to the model constructor.

  Returns:
    A ``(model, optimizer)`` tuple. The model is moved to the GPU when CUDA is
    available. All model parameters go into a single param group.

  Exits the process with status 1 if ``opt.model`` or ``opt.optimizer`` is not
  recognised.
  """
  print('Creating model and optimizer for', opt.model)
  if opt.model == 'imgonly':
    model = img_text_composition_models.SimpleModelImageOnly(
        texts, opt)
  elif opt.model == 'textonly':
    model = img_text_composition_models.SimpleModelTextOnly(
        texts, opt)
  elif opt.model == 'add':
    model = img_text_composition_models.Add(texts, opt)
  elif opt.model == 'concat':
    model = img_text_composition_models.Concat(texts, opt)
  elif opt.model == 'tirg':
    model = img_text_composition_models.TIRG(texts, opt)
  elif opt.model == 'tirg_lastconv':
    model = img_text_composition_models.TIRGLastConv(
        texts, opt)
  else:
    print('Invalid model', opt.model)
    print('available: imgonly, textonly, add, concat, tirg, tirg_lastconv')
    # Non-zero status so shell scripts / job schedulers see a failure.
    sys.exit(1)
  # Guarded so the code still runs (on CPU) on machines without CUDA.
  if torch.cuda.is_available():
    model = model.cuda()

  # create optimizer: one param group containing every model parameter
  params = [{'params': list(model.parameters())}]
  if opt.optimizer == 'SGD':
    optimizer = torch.optim.SGD(
        params, lr=opt.learning_rate, momentum=0.9, weight_decay=opt.weight_decay)
  elif opt.optimizer == 'Adam':
    # NOTE(review): opt.weight_decay is not passed to Adam -- confirm intended.
    optimizer = torch.optim.Adam(params, lr=opt.learning_rate)
  else:
    # Without this branch `optimizer` would be unbound at the return below.
    print('Invalid optimizer', opt.optimizer)
    print('available: SGD, Adam')
    sys.exit(1)
  return model, optimizer