pred_ids = outputs.data.max(1)[1] valid_corrects += pred_ids.eq(targets.data).cpu().numpy().tolist() if opt.debug: break valid_acc = sum(valid_corrects) / float(len(valid_corrects)) valid_loss = sum(valid_loss) / float(len(valid_corrects)) return valid_acc, valid_loss if __name__ == "__main__": torch.manual_seed(2018) opt = BaseOptions().parse() writer = SummaryWriter(opt.results_dir) opt.writer = writer dset = TVQADataset(opt) opt.vocab_size = len(dset.word2idx) model = ABC(opt) if not opt.no_glove: model.load_embedding(dset.vocab_embedding) model.to(opt.device) cudnn.benchmark = True criterion = nn.CrossEntropyLoss(size_average=False).to(opt.device) optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=opt.lr, weight_decay=opt.wd)
def main():
    """Train STAGE on TVQA+: seed everything, build the dataset/model/optimizer,
    then run the epoch loop with LR scheduling and accuracy-based early stopping.

    Returns:
        tuple: (results sub-directory name, ``opt.debug`` flag).
    """
    opt = BaseOptions().parse()

    # Determinism: seed torch and numpy, disable cudnn autotuning.
    torch.manual_seed(opt.seed)
    cudnn.benchmark = False
    cudnn.deterministic = True
    np.random.seed(opt.seed)

    writer = SummaryWriter(opt.results_dir)
    opt.writer = writer

    dset = TVQADataset(opt)
    opt.vocab_size = len(dset.word2idx)
    model = STAGE(opt)
    count_parameters(model)

    if opt.device.type == "cuda":
        print("CUDA enabled.")
        model.to(opt.device)
        if len(opt.device_ids) > 1:
            print("Use multi GPU", opt.device_ids)
            model = torch.nn.DataParallel(
                model, device_ids=opt.device_ids)  # use multi GPU

    # Summed (not averaged) loss; normalization happens downstream.
    criterion = nn.CrossEntropyLoss(reduction="sum").to(opt.device)
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=opt.lr, weight_decay=opt.wd)
    # Halve the LR when validation accuracy plateaus for `patience` epochs.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.5, patience=10, verbose=True)

    best_acc = 0.
    start_epoch = 0
    early_stopping_cnt = 0
    for epoch in range(start_epoch, opt.n_epoch):
        # Hard negative sampling only kicks in after `hard_negative_start` epochs.
        use_hard_negatives = epoch + 1 > opt.hard_negative_start
        niter = epoch * np.ceil(len(dset) / float(opt.bsz))
        opt.writer.add_scalar("learning_rate",
                              float(optimizer.param_groups[0]["lr"]), niter)

        cur_acc = train(opt, dset, model, criterion, optimizer, epoch, best_acc,
                        use_hard_negatives=use_hard_negatives)
        scheduler.step(cur_acc)  # decrease lr when acc is not improving

        # Track the best accuracy; count consecutive non-improving epochs.
        is_best = cur_acc > best_acc
        best_acc = max(cur_acc, best_acc)
        if not is_best:
            early_stopping_cnt += 1
            if early_stopping_cnt >= opt.max_es_cnt:
                print("=> early stop with valid acc %.4f" % best_acc)
                break  # early stop break
        else:
            early_stopping_cnt = 0

        if opt.debug:
            break

    # FIX: always export TensorBoard scalars and close the writer. The original
    # did this only on the early-stopping path, leaking the writer (and never
    # writing all_scalars.json) when training ran to completion or exited via
    # --debug.
    opt.writer.export_scalars_to_json(
        os.path.join(opt.results_dir, "all_scalars.json"))
    opt.writer.close()

    return opt.results_dir.split("/")[1], opt.debug