def create_model_and_optimizer(opt, texts):
    """Builds the model and related optimizer.

    Args:
      opt: parsed options. Reads ``opt.model`` (which composition model to
        build), ``opt.embed_dim``, ``opt.dataset``, ``opt.learning_rate`` and
        ``opt.weight_decay``.
      texts: forwarded unchanged to the model constructor.

    Returns:
      A ``(model, optimizer)`` tuple. ``optimizer`` is ``torch.optim.SGD``
      with momentum 0.9. The model is moved to the GPU when CUDA is available.

    Exits the process with status 1 (``SystemExit``) if ``opt.model`` is not
    one of: imgonly, textonly, concat, tirg, tirg_lastconv.
    """
    print('Creating model and optimizer for', opt.model)
    if opt.model == 'imgonly':
        model = img_text_composition_models.SimpleModelImageOnly(
            texts, embed_dim=opt.embed_dim)
    elif opt.model == 'textonly':
        model = img_text_composition_models.SimpleModelTextOnly(
            texts, embed_dim=opt.embed_dim)
    elif opt.model == 'concat':
        model = img_text_composition_models.Concat(
            texts, embed_dim=opt.embed_dim)
    elif opt.model == 'tirg':
        model = img_text_composition_models.TIRG(
            texts, embed_dim=opt.embed_dim)
    elif opt.model == 'tirg_lastconv':
        model = img_text_composition_models.TIRGLastConv(
            texts, embed_dim=opt.embed_dim)
    else:
        print('Invalid model', opt.model)
        print('available: imgonly, textonly, concat, tirg or tirg_lastconv')
        # Non-zero status so callers/scripts see the misconfiguration as a
        # failure (a bare sys.exit() reports success).
        sys.exit(1)
    if torch.cuda.is_available():
        model = model.cuda()

    # create optimizer
    params = []
    # low learning rate for pretrained layers on real image datasets:
    # img_model.fc at the full rate, the rest of img_model at 0.1x.
    if opt.dataset != 'css3d':
        params.append({
            'params': list(model.img_model.fc.parameters()),
            'lr': opt.learning_rate,
        })
        params.append({
            'params': list(model.img_model.parameters()),
            'lr': 0.1 * opt.learning_rate,
        })
    # Everything else uses the optimizer's default lr.
    params.append({'params': list(model.parameters())})

    # Remove duplicated params: a tensor may belong to only one group, and the
    # first group it appears in wins. Later occurrences are replaced by dummy
    # zero tensors instead of being dropped, so every group keeps its original
    # length and the optimizer state_dict layout is unchanged. Duplicates
    # inside a single group are left alone, as before.
    seen_ids = set()
    for group in params:
        group_params = group['params']
        for j, p in enumerate(group_params):
            if id(p) in seen_ids:
                group_params[j] = torch.tensor(0.0, requires_grad=True)
        seen_ids.update(id(p) for p in group_params)

    optimizer = torch.optim.SGD(
        params,
        lr=opt.learning_rate,
        momentum=0.9,
        weight_decay=opt.weight_decay)
    return model, optimizer
def create_model_and_optimizer(opt, texts):
    """Builds the model and related optimizer.

    Args:
      opt: parsed options, also forwarded to the model constructor. Reads
        ``opt.model`` (which composition model to build), ``opt.optimizer``
        ('SGD' or 'Adam'), ``opt.learning_rate`` and, for SGD,
        ``opt.weight_decay``.
      texts: forwarded unchanged to the model constructor.

    Returns:
      A ``(model, optimizer)`` tuple. The model has been moved to the GPU
      with ``.cuda()``, so a CUDA device is required.

    Exits the process with status 1 (``SystemExit``) if ``opt.model`` or
    ``opt.optimizer`` is not one of the supported names.
    """
    print('Creating model and optimizer for', opt.model)
    if opt.model == 'imgonly':
        model = img_text_composition_models.SimpleModelImageOnly(texts, opt)
    elif opt.model == 'textonly':
        model = img_text_composition_models.SimpleModelTextOnly(texts, opt)
    elif opt.model == 'add':
        model = img_text_composition_models.Add(texts, opt)
    elif opt.model == 'concat':
        model = img_text_composition_models.Concat(texts, opt)
    elif opt.model == 'tirg':
        model = img_text_composition_models.TIRG(texts, opt)
    elif opt.model == 'tirg_lastconv':
        model = img_text_composition_models.TIRGLastConv(texts, opt)
    else:
        print('Invalid model', opt.model)
        print('available: imgonly, textonly, add, concat, tirg, tirg_lastconv')
        # Non-zero status so a misconfigured run is reported as a failure.
        sys.exit(1)
    model = model.cuda()

    # create optimizer: a single param group holding every model parameter
    # (same parameters, same order as model.named_parameters()).
    params = [{'params': list(model.parameters())}]
    if opt.optimizer == 'SGD':
        optimizer = torch.optim.SGD(
            params,
            lr=opt.learning_rate,
            momentum=0.9,
            weight_decay=opt.weight_decay)
    elif opt.optimizer == 'Adam':
        # NOTE: opt.weight_decay is not passed here; only SGD applies it.
        optimizer = torch.optim.Adam(params, lr=opt.learning_rate)
    else:
        # Previously fell through and failed at `return` with
        # UnboundLocalError; report the valid choices instead.
        print('Invalid optimizer', opt.optimizer)
        print('available: SGD, Adam')
        sys.exit(1)
    return model, optimizer