def train(args, controller, task, epoch_itr): # #revise-task 7 """Train the model for one epoch.""" # Update parameters every N batches, CORE scaling method update_freq = args.update_freq[epoch_itr.epoch - 1] \ if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1] # Initialize data iterator itr = epoch_itr.next_epoch_itr( fix_batches_to_gpus=args.fix_batches_to_gpus, shuffle=(epoch_itr.epoch >= args.curriculum), ) itr = iterators.GroupedIterator(itr, update_freq) progress = progress_bar.build_progress_bar( args, itr, epoch_itr.epoch, no_progress_bar='simple', ) extra_meters = collections.defaultdict(lambda: AverageMeter()) valid_subsets = args.valid_subset.split(',') max_update = args.max_update or math.inf loop = enumerate(progress, start=epoch_itr.iterations_in_epoch) for i, samples in loop: log_output = controller.train_step(samples) if log_output is None: continue # log mid-epoch stats stats = get_training_stats(controller) for k, v in log_output.items(): if k in [ 'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size' ]: continue # these are already logged above if 'loss' in k or k == 'accuracy': extra_meters[k].update(v, log_output['sample_size']) else: extra_meters[k].update(v) stats[k] = extra_meters[k].avg progress.log(stats, tag='train', step=stats['num_updates']) # ignore the first mini-batch in words-per-second and updates-per-second calculation if i == 0: controller.get_meter('wps').reset() controller.get_meter('ups').reset() num_updates = controller.get_num_updates() if num_updates >= max_update: break
def init_meters(self, args): self.meters = OrderedDict() self.meters['train_loss'] = AverageMeter() self.meters['train_nll_loss'] = AverageMeter() self.meters['valid_loss'] = AverageMeter() self.meters['valid_nll_loss'] = AverageMeter() self.meters['wps'] = TimeMeter() # words per second self.meters['ups'] = TimeMeter() # updates per second self.meters['wpb'] = AverageMeter() # words per batch self.meters['bsz'] = AverageMeter() # sentences per batch self.meters['gnorm'] = AverageMeter() # gradient norm self.meters['clip'] = AverageMeter() # % of updates clipped self.meters['wall'] = TimeMeter() # wall time in seconds self.meters['train_wall'] = StopwatchMeter() # train wall time in seconds
def main(): use_cuda = args.use_cuda train_data = UnlabeledContact(data=args.data_dir) print('Number of samples: {}'.format(len(train_data))) trainloader = DataLoader(train_data, batch_size=args.batch_size) # Contact matrices are 21x21 input_size = 441 img_height = 21 img_width = 21 vae = AutoEncoder(code_size=20, imgsize=input_size, height=img_height, width=img_width) criterion = nn.BCEWithLogitsLoss() if use_cuda: #vae = nn.DataParallel(vae) vae = vae.cuda() #.half() criterion = criterion.cuda() optimizer = optim.SGD(vae.parameters(), lr=0.01) clock = AverageMeter(name='clock32single', rank=0) epoch_loss = 0 total_loss = 0 end = time.time() for epoch in range(15): for batch_idx, data in enumerate(trainloader): inputs = data['cont_matrix'] inputs = inputs.resize_(args.batch_size, 1, 21, 21) inputs = inputs.float() if use_cuda: inputs = inputs.cuda() #.half() inputs = Variable(inputs) optimizer.zero_grad() output, code = vae(inputs) loss = criterion(output, inputs) loss.backward() optimizer.step() epoch_loss += loss.data[0] clock.update(time.time() - end) end = time.time() if batch_idx % args.log_interval == 0: print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( epoch, batch_idx * len(data), len(trainloader.dataset), 100. * batch_idx / len(trainloader), loss.data[0])) clock.save( path= '/home/ygx/libraries/mds/molecules/molecules/conv_autoencoder/runtimes' )
def main(): use_cuda = args.use_cuda train_data = UnlabeledContact(data=args.data_dir) print('Number of samples: {}'.format(len(train_data))) trainloader = DataLoader(train_data, batch_size=args.batch_size) # Contact matrices are 21x21 input_size = 441 encoder = Encoder(input_size=input_size, latent_size=3) decoder = Decoder(latent_size=3, output_size=input_size) vae = VAE(encoder, decoder, use_cuda=use_cuda) criterion = nn.MSELoss() if use_cuda: encoder = nn.DataParallel(encoder) decoder = nn.DataParallel(decoder) encoder = encoder.cuda().half() decoder = decoder.cuda().half() vae = nn.DataParallel(vae) vae = vae.cuda().half() criterion = criterion.cuda().half() optimizer = optim.SGD(vae.parameters(), lr=0.01) clock = AverageMeter(name='clock16', rank=0) epoch_loss = 0 total_loss = 0 end = time.time() for epoch in range(15): for batch_idx, data in enumerate(trainloader): inputs = data['cont_matrix'] # inputs = inputs.resize_(args.batch_size, 1, 21, 21) inputs = inputs.float() if use_cuda: inputs = inputs.cuda().half() inputs = Variable(inputs) optimizer.zero_grad() dec = vae(inputs) ll = latent_loss(vae.z_mean, vae.z_sigma) loss = criterion(dec, inputs) + ll loss.backward() optimizer.step() epoch_loss += loss.data[0] clock.update(time.time() - end) end = time.time() if batch_idx % args.log_interval == 0: print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( epoch, batch_idx * len(data), len(trainloader.dataset), 100. * batch_idx / len(trainloader), loss.data[0])) clock.save(path='/home/ygx/libraries/mds/molecules/molecules/linear_vae')
def __init__(self, args, model, criterion, optimizer=None, ae_criterion=None): self.args = args # copy model and criterion on current device self.model = model.to(self.args.device) self.criterion = criterion.to(self.args.device) self.ae_criterion = ae_criterion.to(self.args.device) # initialize meters self.meters = OrderedDict() self.meters['train_loss'] = AverageMeter() self.meters['train_nll_loss'] = AverageMeter() self.meters['valid_loss'] = AverageMeter() self.meters['valid_nll_loss'] = AverageMeter() self.meters['wps'] = TimeMeter() # words per second self.meters['ups'] = TimeMeter() # updates per second self.meters['wpb'] = AverageMeter() # words per batch self.meters['bsz'] = AverageMeter() # sentences per batch self.meters['gnorm'] = AverageMeter() # gradient norm self.meters['clip'] = AverageMeter() # % of updates clipped self.meters['oom'] = AverageMeter() # out of memory self.meters['wall'] = TimeMeter() # wall time in seconds self._buffered_stats = defaultdict(lambda: []) self._flat_grads = None self._num_updates = 0 self._optim_history = None self._optimizer = None if optimizer is not None: self._optimizer = optimizer self.total_loss = 0.0 self.train_score = 0.0 self.total_norm = 0.0 self.count_norm = 0.0
def main(args): use_cuda = (len(args.gpuid) >= 1) print("{0} GPU(s) are available".format(cuda.device_count())) # Load dataset splits = ['train', 'valid'] if data.has_binary_files(args.data, splits): dataset = data.load_dataset(args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len) else: dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len) if args.src_lang is None or args.trg_lang is None: # record inferred languages in args, so that it's saved in checkpoints args.src_lang, args.trg_lang = dataset.src, dataset.dst print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict))) print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict))) for split in splits: print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split]))) g_logging_meters = OrderedDict() g_logging_meters['train_loss'] = AverageMeter() g_logging_meters['valid_loss'] = AverageMeter() g_logging_meters['train_acc'] = AverageMeter() g_logging_meters['valid_acc'] = AverageMeter() g_logging_meters['bsz'] = AverageMeter() # sentences per batch d_logging_meters = OrderedDict() d_logging_meters['train_loss'] = AverageMeter() d_logging_meters['valid_loss'] = AverageMeter() d_logging_meters['train_acc'] = AverageMeter() d_logging_meters['valid_acc'] = AverageMeter() d_logging_meters['bsz'] = AverageMeter() # sentences per batch # Set model parameters args.encoder_embed_dim = 1000 args.encoder_layers = 2 # 4 args.encoder_dropout_out = 0 args.decoder_embed_dim = 1000 args.decoder_layers = 2 # 4 args.decoder_out_embed_dim = 1000 args.decoder_dropout_out = 0 args.bidirectional = False generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) print("Generator loaded successfully!") discriminator = Discriminator(args.decoder_embed_dim, args.discriminator_hidden_size, args.discriminator_linear_size, args.discriminator_lin_dropout, use_cuda=use_cuda) print("Discriminator loaded successfully!") def _calcualte_discriminator_loss(tf_scores, ar_scores): tf_loss = torch.log(tf_scores + 1e-6) * (-1) ar_loss = torch.log(1 - ar_scores + 1e-6) * (-1) return tf_loss + ar_loss if use_cuda: if torch.cuda.device_count() > 1: discriminator = torch.nn.DataParallel(discriminator).cuda() generator = torch.nn.DataParallel(generator).cuda() else: generator.cuda() discriminator.cuda() else: discriminator.cpu() generator.cpu() # adversarial training checkpoints saving path if not os.path.exists('checkpoints/professorjp'): os.makedirs('checkpoints/professorjp') checkpoints_path = 'checkpoints/professorjp/' # define loss function g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(), reduction='sum') # d_criterion = torch.nn.BCELoss() pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(), size_average=True, reduce=True) # fix discriminator word embedding (as Wu et al. do) # for p in discriminator.embed_src_tokens.parameters(): # p.requires_grad = False # for p in discriminator.embed_trg_tokens.parameters(): # p.requires_grad = False # define optimizer g_optimizer = eval("torch.optim." + args.g_optimizer)(filter( lambda x: x.requires_grad, generator.parameters()), args.g_learning_rate) d_optimizer = eval("torch.optim." 
+ args.d_optimizer)( filter(lambda x: x.requires_grad, discriminator.parameters()), args.d_learning_rate, momentum=args.momentum, nesterov=True) # start joint training best_dev_loss = math.inf num_update = 0 # main training loop for epoch_i in range(1, args.epochs + 1): logging.info("At {0}-th epoch.".format(epoch_i)) seed = args.seed + epoch_i torch.manual_seed(seed) max_positions_train = (args.fixed_max_len, args.fixed_max_len) # Initialize dataloader, starting at batch_offset trainloader = dataset.train_dataloader( 'train', max_tokens=args.max_tokens, max_sentences=args.joint_batch_size, max_positions=max_positions_train, # seed=seed, epoch=epoch_i, sample_without_replacement=args.sample_without_replacement, sort_by_source_size=(epoch_i <= args.curriculum), shard_id=args.distributed_rank, num_shards=args.distributed_world_size, ) # reset meters for key, val in g_logging_meters.items(): if val is not None: val.reset() for key, val in d_logging_meters.items(): if val is not None: val.reset() # set training mode generator.train() discriminator.train() update_learning_rate(num_update, 8e4, args.g_learning_rate, args.lr_shrink, g_optimizer) for i, sample in enumerate(trainloader): if use_cuda: # wrap input tensors in cuda tensors sample = utils.make_variable(sample, cuda=cuda) ## part I: use gradient policy method to train the generator # print("Policy Gradient Training") sys_out_batch_PG, p_PG, hidden_list_PG = generator( 'PG', epoch_i, sample) # 64 X 50 X 6632 out_batch_PG = sys_out_batch_PG.contiguous().view( -1, sys_out_batch_PG.size(-1)) # (64 * 50) X 6632 _, prediction = out_batch_PG.topk(1) # prediction = prediction.squeeze(1) # 64*50 = 3200 # prediction = torch.reshape(prediction, sample['net_input']['src_tokens'].shape) # 64 X 50 with torch.no_grad(): reward = discriminator(hidden_list_PG) # 64 X 1 train_trg_batch_PG = sample['target'] # 64 x 50 pg_loss_PG = pg_criterion(sys_out_batch_PG, train_trg_batch_PG, reward, use_cuda) sample_size_PG = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] # 64 logging_loss_PG = pg_loss_PG / math.log(2) g_logging_meters['train_loss'].update(logging_loss_PG.item(), sample_size_PG) logging.debug( f"G policy gradient loss at batch {i}: {pg_loss_PG.item():.3f}, lr={g_optimizer.param_groups[0]['lr']}" ) g_optimizer.zero_grad() pg_loss_PG.backward(retain_graph=True) torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm) g_optimizer.step() # print("MLE Training") sys_out_batch_MLE, p_MLE, hidden_list_MLE = generator( "MLE", epoch_i, sample) out_batch_MLE = sys_out_batch_MLE.contiguous().view( -1, sys_out_batch_MLE.size(-1)) # (64 X 50) X 6632 train_trg_batch_MLE = sample['target'].view(-1) # 64*50 = 3200 loss_MLE = g_criterion(out_batch_MLE, train_trg_batch_MLE) sample_size_MLE = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] nsentences = sample['target'].size(0) logging_loss_MLE = loss_MLE.data / sample_size_MLE / math.log(2) g_logging_meters['bsz'].update(nsentences) g_logging_meters['train_loss'].update(logging_loss_MLE, sample_size_MLE) logging.debug( f"G MLE loss at batch {i}: {g_logging_meters['train_loss'].avg:.3f}, lr={g_optimizer.param_groups[0]['lr']}" ) # g_optimizer.zero_grad() loss_MLE.backward(retain_graph=True) # all-reduce grads and rescale by grad_denom for p in generator.parameters(): # print(p.size()) if p.requires_grad: p.grad.data.div_(sample_size_MLE) torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm) g_optimizer.step() num_update += 1 # part II: train 
the discriminator d_MLE = discriminator(hidden_list_MLE) d_PG = discriminator(hidden_list_PG) d_loss = _calcualte_discriminator_loss(d_MLE, d_PG).sum() logging.debug(f"D training loss {d_loss} at batch {i}") d_optimizer.zero_grad() d_loss.backward() torch.nn.utils.clip_grad_norm_(discriminator.parameters(), args.clip_norm) d_optimizer.step() # validation # set validation mode generator.eval() discriminator.eval() # Initialize dataloader max_positions_valid = (args.fixed_max_len, args.fixed_max_len) valloader = dataset.eval_dataloader( 'valid', max_tokens=args.max_tokens, max_sentences=args.joint_batch_size, max_positions=max_positions_valid, skip_invalid_size_inputs_valid_test=True, descending=True, # largest batch first to warm the caching allocator shard_id=args.distributed_rank, num_shards=args.distributed_world_size, ) # reset meters for key, val in g_logging_meters.items(): if val is not None: val.reset() for key, val in d_logging_meters.items(): if val is not None: val.reset() for i, sample in enumerate(valloader): with torch.no_grad(): if use_cuda: # wrap input tensors in cuda tensors sample = utils.make_variable(sample, cuda=cuda) # generator validation sys_out_batch_test, p_test, hidden_list_test = generator( 'test', epoch_i, sample) out_batch_test = sys_out_batch_test.contiguous().view( -1, sys_out_batch_test.size(-1)) # (64 X 50) X 6632 dev_trg_batch = sample['target'].view(-1) # 64*50 = 3200 loss_test = g_criterion(out_batch_test, dev_trg_batch) sample_size_test = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] loss_test = loss_test / sample_size_test / math.log(2) g_logging_meters['valid_loss'].update(loss_test, sample_size_test) logging.debug( f"G dev loss at batch {i}: {g_logging_meters['valid_loss'].avg:.3f}" ) # # discriminator validation # bsz = sample['target'].size(0) # src_sentence = sample['net_input']['src_tokens'] # # train with half human-translation and half machine translation # # true_sentence = sample['target'] # true_labels = torch.ones(sample['target'].size(0)).float() # # with torch.no_grad(): # sys_out_batch_PG, p, hidden_list = generator('test', epoch_i, sample) # # out_batch = sys_out_batch_PG.contiguous().view(-1, sys_out_batch_PG.size(-1)) # (64 X 50) X 6632 # # _, prediction = out_batch.topk(1) # prediction = prediction.squeeze(1) # 64 * 50 = 6632 # # fake_labels = torch.zeros(sample['target'].size(0)).float() # # fake_sentence = torch.reshape(prediction, src_sentence.shape) # 64 X 50 # # if use_cuda: # fake_labels = fake_labels.cuda() # # disc_out = discriminator(src_sentence, fake_sentence) # d_loss = d_criterion(disc_out.squeeze(1), fake_labels) # acc = torch.sum(torch.round(disc_out).squeeze(1) == fake_labels).float() / len(fake_labels) # d_logging_meters['valid_acc'].update(acc) # d_logging_meters['valid_loss'].update(d_loss) # logging.debug( # f"D dev loss {d_logging_meters['valid_loss'].avg:.3f}, acc {d_logging_meters['valid_acc'].avg:.3f} at batch {i}") torch.save( generator, open( checkpoints_path + f"sampling_{g_logging_meters['valid_loss'].avg:.3f}.epoch_{epoch_i}.pt", 'wb'), pickle_module=dill) if g_logging_meters['valid_loss'].avg < best_dev_loss: best_dev_loss = g_logging_meters['valid_loss'].avg torch.save(generator, open(checkpoints_path + "best_gmodel.pt", 'wb'), pickle_module=dill)
def create_meters(self): g_logging_meters = OrderedDict() g_logging_meters['train_loss'] = AverageMeter() g_logging_meters['valid_loss'] = AverageMeter() g_logging_meters['train_acc'] = AverageMeter() g_logging_meters['valid_acc'] = AverageMeter() g_logging_meters['bsz'] = AverageMeter() # sentences per batch d_logging_meters = OrderedDict() d_logging_meters['train_loss'] = AverageMeter() d_logging_meters['valid_loss'] = AverageMeter() d_logging_meters['train_acc'] = AverageMeter() d_logging_meters['valid_acc'] = AverageMeter() d_logging_meters['train_bleu'] = AverageMeter() d_logging_meters['valid_bleu'] = AverageMeter() d_logging_meters['train_rouge'] = AverageMeter() d_logging_meters['valid_rouge'] = AverageMeter() d_logging_meters['bsz'] = AverageMeter() # sentences per batch self.g_logging_meters = g_logging_meters self.d_logging_meters = d_logging_meters
writer.add_scalar(tag="lr", scalar_value=optimiser.param_groups[0]["lr"], global_step=e_i) for head_i in range(2): head = heads[head_i] if head == "A": dataloaders = dataloaders_head_A epoch_loss = config.epoch_loss_head_A epoch_loss_no_lamb = config.epoch_loss_no_lamb_head_A elif head == "B": dataloaders = dataloaders_head_B epoch_loss = config.epoch_loss_head_B epoch_loss_no_lamb = config.epoch_loss_no_lamb_head_B else: raise NotImplemented(head) avg_loss_meter = AverageMeter("avg_loss") mi_meter = AverageMeter("standard_mi") for head_i_epoch in range(head_epochs[head]): sys.stdout.flush() with tqdm(enumerate(zip(*dataloaders))) as indicator: indicator.set_description(f"Head:{head}") for b_i, tup in indicator: optimiser.zero_grad() # one less because this is before sobel with autocast: data, label = zip(*tup) all_imgs = torch.cat([data[0] for _ in range(len(data) - 1)]).cuda() all_imgs_tf = torch.cat([data[i] for i in range(1, len(data))]).cuda()
def run_one_epoch(self, training): tic = time.time() batch_time = AverageMeter() losses = AverageMeter() accs = AverageMeter() if training: amnt = self.num_train dataset = self.train_loader else: dataset = self.val_loader amnt = self.num_valid with tqdm(total=amnt) as pbar: for i, data in enumerate(dataset): x, y = data # segmentation task if self.classification: # assuming one-hot y = y.view(1, -1).expand(self.model.num_heads, -1) else: y = y.view(1, -1, 1, x.shape[-2], x.shape[-1]).expand(self.model.num_heads, -1, -1, -1, -1) if self.config.use_gpu: x, y = x.cuda(), y.cuda() output = self.model(x) if training: self.optimizer.zero_grad() loss = None for head in range(self.model.num_heads): if loss is None: loss = self.criterion(output[head], y[head]) else: loss = loss + self.criterion(output[head], y[head]) loss = loss / self.model.num_heads if training: loss.backward() self.optimizer.step() try: loss_data = loss.data[0] except IndexError: loss_data = loss.data.item() losses.update(loss_data) # measure elapsed time toc = time.time() batch_time.update(toc - tic) if self.classification: _, predicted = torch.max(output.data, -1) total = self.batch_size*self.model.num_heads correct = (predicted == y).sum().item() acc = correct/total accs.update(acc) pbar.set_description(f"{(toc - tic):.1f}s - loss: {loss_data:.3f} acc {accs.avg:.3f}") else: pbar.set_description(f"{(toc - tic):.1f}s - loss: {loss_data:.3f}") pbar.update(self.batch_size) if training and i % 2 == 0: self.model.log_illumination(self.curr_epoch, i) if not training and i == 0 and not self.classification: y_sample = y[0, 0].view(256, 256).detach().cpu().numpy() p_sample = output[0, 0].view(256, 256).detach().cpu().numpy() wandb.log({f"images_epoch{self.curr_epoch}": [ wandb.Image(np.round(p_sample * 255), caption="prediction"), wandb.Image(np.round(y_sample * 255), caption="label")]}, step=self.curr_epoch) return losses.avg, accs.avg
def main(): # Training settings parser = argparse.ArgumentParser(description='PyTorch MNIST Example') parser.add_argument('--batch-size', type=int, default=64, metavar='N', help='input batch size for training (default: 64)') parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', help='input batch size for testing (default: 1000)') parser.add_argument('--epochs', type=int, default=10, metavar='N', help='number of epochs to train (default: 10)') parser.add_argument('--lr', type=float, default=0.01, metavar='LR', help='learning rate (default: 0.01)') parser.add_argument('--momentum', type=float, default=0.5, metavar='M', help='SGD momentum (default: 0.5)') parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training') parser.add_argument('--resume', type=bool, default=False, help='Resumes training from savefile.') parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') parser.add_argument( '--log-interval', type=int, default=10, metavar='N', help='how many batches to wait before logging training status') args = parser.parse_args() use_cuda = not args.no_cuda and torch.cuda.is_available() torch.manual_seed(args.seed) device = torch.device("cuda" if use_cuda else "cpu") kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} train_loader = torch.utils.data.DataLoader(datasets.MNIST( '../data', train=True, download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, )) ])), batch_size=args.batch_size, shuffle=True, **kwargs) test_loader = torch.utils.data.DataLoader(datasets.MNIST( '../data', train=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, )) ])), batch_size=args.test_batch_size, shuffle=True, **kwargs) encoder = Encoder2() savefile = './savepoints/checkpoint10.pth.tar' if args.resume: if os.path.isfile(savefile): print("=> loading checkpoint '{}'".format(savefile)) checkpoint = torch.load(savefile) encoder.load_state_dict(checkpoint['state_dict']) print("=> loaded checkpoint '{}'".format(savefile)) else: print("=> no checkpoint found at '{}'".format(savefile)) model = TransferNet(encoder).to(device) optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) train_meter = AverageMeter(name='trainacc') test_meter = AverageMeter(name='testacc') for epoch in range(1, args.epochs + 1): train(args, model, device, train_loader, optimizer, epoch) test(args, model, device, test_loader, test_meter) test_meter.save('./')
def main(args): use_cuda = (len(args.gpuid) >= 1) print("{0} GPU(s) are available".format(cuda.device_count())) # Load dataset splits = ['train', 'valid'] if data.has_binary_files(args.data, splits): dataset = data.load_dataset(args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len) else: dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len) if args.src_lang is None or args.trg_lang is None: # record inferred languages in args, so that it's saved in checkpoints args.src_lang, args.trg_lang = dataset.src, dataset.dst print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict))) print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict))) for split in splits: print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split]))) g_logging_meters = OrderedDict() g_logging_meters['train_loss'] = AverageMeter() g_logging_meters['valid_loss'] = AverageMeter() g_logging_meters['train_acc'] = AverageMeter() g_logging_meters['valid_acc'] = AverageMeter() g_logging_meters['bsz'] = AverageMeter() # sentences per batch d_logging_meters = OrderedDict() d_logging_meters['train_loss'] = AverageMeter() d_logging_meters['valid_loss'] = AverageMeter() d_logging_meters['train_acc'] = AverageMeter() d_logging_meters['valid_acc'] = AverageMeter() d_logging_meters['bsz'] = AverageMeter() # sentences per batch # Set model parameters args.encoder_embed_dim = 1000 args.encoder_layers = 2 # 4 args.encoder_dropout_out = 0 args.decoder_embed_dim = 1000 args.decoder_layers = 2 # 4 args.decoder_out_embed_dim = 1000 args.decoder_dropout_out = 0 args.bidirectional = False generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) print("Generator loaded successfully!") discriminator = Discriminator(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) print("Discriminator loaded successfully!") g_model_path = 'checkpoints/zhenwarm/generator.pt' assert os.path.exists(g_model_path) # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) model_dict = generator.state_dict() model = torch.load(g_model_path) pretrained_dict = model.state_dict() # 1. filter out unnecessary keys pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict } # 2. overwrite entries in the existing state dict model_dict.update(pretrained_dict) # 3. load the new state dict generator.load_state_dict(model_dict) print("pre-trained Generator loaded successfully!") # # Load discriminator model d_model_path = 'checkpoints/zhenwarm/discri.pt' assert os.path.exists(d_model_path) # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) d_model_dict = discriminator.state_dict() d_model = torch.load(d_model_path) d_pretrained_dict = d_model.state_dict() # 1. filter out unnecessary keys d_pretrained_dict = { k: v for k, v in d_pretrained_dict.items() if k in d_model_dict } # 2. overwrite entries in the existing state dict d_model_dict.update(d_pretrained_dict) # 3. 
load the new state dict discriminator.load_state_dict(d_model_dict) print("pre-trained Discriminator loaded successfully!") if use_cuda: if torch.cuda.device_count() > 1: discriminator = torch.nn.DataParallel(discriminator).cuda() generator = torch.nn.DataParallel(generator).cuda() else: generator.cuda() discriminator.cuda() else: discriminator.cpu() generator.cpu() # adversarial training checkpoints saving path if not os.path.exists('checkpoints/myzhencli5'): os.makedirs('checkpoints/myzhencli5') checkpoints_path = 'checkpoints/myzhencli5/' # define loss function g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(), reduction='sum') d_criterion = torch.nn.BCELoss() pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(), size_average=True, reduce=True) # fix discriminator word embedding (as Wu et al. do) for p in discriminator.embed_src_tokens.parameters(): p.requires_grad = False for p in discriminator.embed_trg_tokens.parameters(): p.requires_grad = False # define optimizer g_optimizer = eval("torch.optim." + args.g_optimizer)(filter( lambda x: x.requires_grad, generator.parameters()), args.g_learning_rate) d_optimizer = eval("torch.optim." + args.d_optimizer)( filter(lambda x: x.requires_grad, discriminator.parameters()), args.d_learning_rate, momentum=args.momentum, nesterov=True) # start joint training best_dev_loss = math.inf num_update = 0 # main training loop for epoch_i in range(1, args.epochs + 1): logging.info("At {0}-th epoch.".format(epoch_i)) seed = args.seed + epoch_i torch.manual_seed(seed) max_positions_train = (args.fixed_max_len, args.fixed_max_len) # Initialize dataloader, starting at batch_offset trainloader = dataset.train_dataloader( 'train', max_tokens=args.max_tokens, max_sentences=args.joint_batch_size, max_positions=max_positions_train, # seed=seed, epoch=epoch_i, sample_without_replacement=args.sample_without_replacement, sort_by_source_size=(epoch_i <= args.curriculum), shard_id=args.distributed_rank, num_shards=args.distributed_world_size, ) # reset meters for key, val in g_logging_meters.items(): if val is not None: val.reset() for key, val in d_logging_meters.items(): if val is not None: val.reset() for i, sample in enumerate(trainloader): # set training mode generator.train() discriminator.train() update_learning_rate(num_update, 8e4, args.g_learning_rate, args.lr_shrink, g_optimizer) if use_cuda: # wrap input tensors in cuda tensors sample = utils.make_variable(sample, cuda=cuda) ## part I: use gradient policy method to train the generator # use policy gradient training when random.random() > 50% if random.random() >= 0.5: print("Policy Gradient Training") sys_out_batch = generator(sample) # 64 X 50 X 6632 out_batch = sys_out_batch.contiguous().view( -1, sys_out_batch.size(-1)) # (64 * 50) X 6632 _, prediction = out_batch.topk(1) prediction = prediction.squeeze(1) # 64*50 = 3200 prediction = torch.reshape( prediction, sample['net_input']['src_tokens'].shape) # 64 X 50 with torch.no_grad(): reward = discriminator(sample['net_input']['src_tokens'], prediction) # 64 X 1 train_trg_batch = sample['target'] # 64 x 50 pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward, use_cuda) sample_size = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] # 64 logging_loss = pg_loss / math.log(2) g_logging_meters['train_loss'].update(logging_loss.item(), sample_size) logging.debug( f"G policy gradient loss at batch {i}: {pg_loss.item():.3f}, lr={g_optimizer.param_groups[0]['lr']}" ) g_optimizer.zero_grad() pg_loss.backward() 
torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm) g_optimizer.step() else: # MLE training print("MLE Training") sys_out_batch = generator(sample) out_batch = sys_out_batch.contiguous().view( -1, sys_out_batch.size(-1)) # (64 X 50) X 6632 train_trg_batch = sample['target'].view(-1) # 64*50 = 3200 loss = g_criterion(out_batch, train_trg_batch) sample_size = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] nsentences = sample['target'].size(0) logging_loss = loss.data / sample_size / math.log(2) g_logging_meters['bsz'].update(nsentences) g_logging_meters['train_loss'].update(logging_loss, sample_size) logging.debug( f"G MLE loss at batch {i}: {g_logging_meters['train_loss'].avg:.3f}, lr={g_optimizer.param_groups[0]['lr']}" ) g_optimizer.zero_grad() loss.backward() # all-reduce grads and rescale by grad_denom for p in generator.parameters(): if p.requires_grad: p.grad.data.div_(sample_size) torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm) g_optimizer.step() num_update += 1 # part II: train the discriminator if num_update % 5 == 0: bsz = sample['target'].size(0) # batch_size = 64 src_sentence = sample['net_input'][ 'src_tokens'] # 64 x max-len i.e 64 X 50 # now train with machine translation output i.e generator output true_sentence = sample['target'].view(-1) # 64*50 = 3200 true_labels = Variable( torch.ones( sample['target'].size(0)).float()) # 64 length vector with torch.no_grad(): sys_out_batch = generator(sample) # 64 X 50 X 6632 out_batch = sys_out_batch.contiguous().view( -1, sys_out_batch.size(-1)) # (64 X 50) X 6632 _, prediction = out_batch.topk(1) prediction = prediction.squeeze(1) # 64 * 50 = 6632 fake_labels = Variable( torch.zeros( sample['target'].size(0)).float()) # 64 length vector fake_sentence = torch.reshape(prediction, src_sentence.shape) # 64 X 50 true_sentence = torch.reshape(true_sentence, src_sentence.shape) if use_cuda: fake_labels = fake_labels.cuda() true_labels = true_labels.cuda() # fake_disc_out = discriminator(src_sentence, fake_sentence) # 64 X 1 # true_disc_out = discriminator(src_sentence, true_sentence) # # fake_d_loss = d_criterion(fake_disc_out.squeeze(1), fake_labels) # true_d_loss = d_criterion(true_disc_out.squeeze(1), true_labels) # # fake_acc = torch.sum(torch.round(fake_disc_out).squeeze(1) == fake_labels).float() / len(fake_labels) # true_acc = torch.sum(torch.round(true_disc_out).squeeze(1) == true_labels).float() / len(true_labels) # acc = (fake_acc + true_acc) / 2 # # d_loss = fake_d_loss + true_d_loss if random.random() > 0.5: fake_disc_out = discriminator(src_sentence, fake_sentence) fake_d_loss = d_criterion(fake_disc_out.squeeze(1), fake_labels) fake_acc = torch.sum( torch.round(fake_disc_out).squeeze(1) == fake_labels).float() / len(fake_labels) d_loss = fake_d_loss acc = fake_acc else: true_disc_out = discriminator(src_sentence, true_sentence) true_d_loss = d_criterion(true_disc_out.squeeze(1), true_labels) true_acc = torch.sum( torch.round(true_disc_out).squeeze(1) == true_labels).float() / len(true_labels) d_loss = true_d_loss acc = true_acc d_logging_meters['train_acc'].update(acc) d_logging_meters['train_loss'].update(d_loss) logging.debug( f"D training loss {d_logging_meters['train_loss'].avg:.3f}, acc {d_logging_meters['train_acc'].avg:.3f} at batch {i}" ) d_optimizer.zero_grad() d_loss.backward() d_optimizer.step() if num_update % 10000 == 0: # validation # set validation mode generator.eval() discriminator.eval() # Initialize dataloader max_positions_valid = 
(args.fixed_max_len, args.fixed_max_len) valloader = dataset.eval_dataloader( 'valid', max_tokens=args.max_tokens, max_sentences=args.joint_batch_size, max_positions=max_positions_valid, skip_invalid_size_inputs_valid_test=True, descending= True, # largest batch first to warm the caching allocator shard_id=args.distributed_rank, num_shards=args.distributed_world_size, ) # reset meters for key, val in g_logging_meters.items(): if val is not None: val.reset() for key, val in d_logging_meters.items(): if val is not None: val.reset() for i, sample in enumerate(valloader): with torch.no_grad(): if use_cuda: # wrap input tensors in cuda tensors sample = utils.make_variable(sample, cuda=cuda) # generator validation sys_out_batch = generator(sample) out_batch = sys_out_batch.contiguous().view( -1, sys_out_batch.size(-1)) # (64 X 50) X 6632 dev_trg_batch = sample['target'].view( -1) # 64*50 = 3200 loss = g_criterion(out_batch, dev_trg_batch) sample_size = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] loss = loss / sample_size / math.log(2) g_logging_meters['valid_loss'].update( loss, sample_size) logging.debug( f"G dev loss at batch {i}: {g_logging_meters['valid_loss'].avg:.3f}" ) # discriminator validation bsz = sample['target'].size(0) src_sentence = sample['net_input']['src_tokens'] # train with half human-translation and half machine translation true_sentence = sample['target'] true_labels = Variable( torch.ones(sample['target'].size(0)).float()) with torch.no_grad(): sys_out_batch = generator(sample) out_batch = sys_out_batch.contiguous().view( -1, sys_out_batch.size(-1)) # (64 X 50) X 6632 _, prediction = out_batch.topk(1) prediction = prediction.squeeze(1) # 64 * 50 = 6632 fake_labels = Variable( torch.zeros(sample['target'].size(0)).float()) fake_sentence = torch.reshape( prediction, src_sentence.shape) # 64 X 50 true_sentence = torch.reshape(true_sentence, src_sentence.shape) if use_cuda: fake_labels = fake_labels.cuda() true_labels = true_labels.cuda() fake_disc_out = discriminator(src_sentence, fake_sentence) # 64 X 1 true_disc_out = discriminator(src_sentence, true_sentence) fake_d_loss = d_criterion(fake_disc_out.squeeze(1), fake_labels) true_d_loss = d_criterion(true_disc_out.squeeze(1), true_labels) d_loss = fake_d_loss + true_d_loss fake_acc = torch.sum( torch.round(fake_disc_out).squeeze(1) == fake_labels).float() / len(fake_labels) true_acc = torch.sum( torch.round(true_disc_out).squeeze(1) == true_labels).float() / len(true_labels) acc = (fake_acc + true_acc) / 2 d_logging_meters['valid_acc'].update(acc) d_logging_meters['valid_loss'].update(d_loss) logging.debug( f"D dev loss {d_logging_meters['valid_loss'].avg:.3f}, acc {d_logging_meters['valid_acc'].avg:.3f} at batch {i}" ) # torch.save(discriminator, # open(checkpoints_path + f"numupdate_{num_update/10000}k.discri_{d_logging_meters['valid_loss'].avg:.3f}.pt",'wb'), pickle_module=dill) # if d_logging_meters['valid_loss'].avg < best_dev_loss: # best_dev_loss = d_logging_meters['valid_loss'].avg # torch.save(discriminator, open(checkpoints_path + "best_dmodel.pt", 'wb'), pickle_module=dill) torch.save( generator, open( checkpoints_path + f"numupdate_{num_update/10000}k.joint_{g_logging_meters['valid_loss'].avg:.3f}.pt", 'wb'), pickle_module=dill)
def forward(data_loader, model, criterion, epoch, training, model_type, optimizer=None, writer=None): if training: model.train() else: model.eval() batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() end = time.time() total_steps = len(data_loader) for i, (inputs, target) in enumerate(data_loader): # measure data loading time data_time.update(time.time() - end) inputs = inputs.to('cuda:0') target = target.to('cuda:0') # compute output output = model(inputs) if model_type == 'int': # omit the output exponent output, output_exp = output output = output.float() loss = criterion(output * (2**output_exp.float()), target) else: output_exp = 0 loss = criterion(output, target) # measure accuracy and record loss losses.update(float(loss), inputs.size(0)) prec1, prec5 = accuracy(output.detach(), target, topk=(1, 5)) top1.update(float(prec1), inputs.size(0)) top5.update(float(prec5), inputs.size(0)) if training: if model_type == 'int': model.backward(target) elif model_type == 'hybrid': # float backward optimizer.update(epoch, epoch * len(data_loader) + i) optimizer.zero_grad() loss.backward() optimizer.step() #int8 backward model.backward() else: optimizer.update(epoch, epoch * len(data_loader) + i) optimizer.zero_grad() loss.backward() optimizer.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.log_interval == 0 and training: logging.info('{model_type} [{0}][{1}/{2}] ' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 'Data {data_time.val:.2f} ' 'loss {loss.val:.3f} ({loss.avg:.3f}) ' 'e {output_exp:d} ' '@1 {top1.val:.3f} ({top1.avg:.3f}) ' '@5 {top5.val:.3f} ({top5.avg:.3f})'.format( epoch, i, len(data_loader), model_type=model_type, batch_time=batch_time, data_time=data_time, loss=losses, output_exp=output_exp, top1=top1, top5=top5)) if args.grad_hist: if args.model_type == 'int': for idx, l in enumerate(model.forward_layers): if hasattr(l, 'weight'): grad = l.grad_int32acc writer.add_histogram( 'Grad/' + l.__class__.__name__ + '_' + str(idx), grad, epoch * total_steps + i) elif args.model_type == 'float': for idx, l in enumerate(model.layers): if hasattr(l, 'weight'): writer.add_histogram( 'Grad/' + l.__class__.__name__ + '_' + str(idx), l.weight.grad, epoch * total_steps + i) for idx, l in enumerate(model.classifier): if hasattr(l, 'weight'): writer.add_histogram( 'Grad/' + l.__class__.__name__ + '_' + str(idx), l.weight.grad, epoch * total_steps + i) return losses.avg, top1.avg, top5.avg
def main(args): use_cuda = (len(args.gpuid) >= 1) if args.gpuid: cuda.set_device(args.gpuid[0]) # Load dataset splits = ['train', 'valid'] if data.has_binary_files(args.data, splits): dataset = data.load_dataset(args.data, splits, args.src_lang, args.trg_lang) else: dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang, args.trg_lang) if args.src_lang is None or args.trg_lang is None: # record inferred languages in args, so that it's saved in checkpoints args.src_lang, args.trg_lang = dataset.src, dataset.dst print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict))) print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict))) for split in splits: print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split]))) # Set model parameters args.encoder_embed_dim = 1000 args.encoder_layers = 4 args.encoder_dropout_out = 0 args.decoder_embed_dim = 1000 args.decoder_layers = 4 args.decoder_out_embed_dim = 1000 args.decoder_dropout_out = 0 args.bidirectional = False logging_meters = OrderedDict() logging_meters['train_loss'] = AverageMeter() logging_meters['valid_loss'] = AverageMeter() logging_meters['bsz'] = AverageMeter() # sentences per batch # Build model generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) if use_cuda: generator.cuda() else: generator.cpu() optimizer = eval("torch.optim." + args.optimizer)(generator.parameters(), args.learning_rate) lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, patience=0, factor=args.lr_shrink) # Train until the learning rate gets too small max_epoch = args.max_epoch or math.inf epoch_i = 1 best_dev_loss = math.inf lr = optimizer.param_groups[0]['lr'] # main training loop # added for write training loss f1 = open("train_loss", "a") while lr > args.min_lr and epoch_i <= max_epoch: logging.info("At {0}-th epoch.".format(epoch_i)) seed = args.seed + epoch_i torch.manual_seed(seed) max_positions_train = (min(args.max_source_positions, generator.encoder.max_positions()), min(args.max_target_positions, generator.decoder.max_positions())) # Initialize dataloader, starting at batch_offset itr = dataset.train_dataloader( 'train', max_tokens=args.max_tokens, max_sentences=args.max_sentences, max_positions=max_positions_train, seed=seed, epoch=epoch_i, sample_without_replacement=args.sample_without_replacement, sort_by_source_size=(epoch_i <= args.curriculum), shard_id=args.distributed_rank, num_shards=args.distributed_world_size, ) # set training mode generator.train() # reset meters for key, val in logging_meters.items(): if val is not None: val.reset() for i, sample in enumerate(itr): if use_cuda: # wrap input tensors in cuda tensors sample = utils.make_variable(sample, cuda=cuda) loss = generator(sample) sample_size = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] nsentences = sample['target'].size(0) logging_loss = loss.data / sample_size / math.log(2) logging_meters['bsz'].update(nsentences) logging_meters['train_loss'].update(logging_loss, sample_size) f1.write("{0}\n".format(logging_meters['train_loss'].avg)) logging.debug( "loss at batch {0}: {1:.3f}, batch size: {2}, lr={3}".format( i, logging_meters['train_loss'].avg, round(logging_meters['bsz'].avg), optimizer.param_groups[0]['lr'])) optimizer.zero_grad() loss.backward() # all-reduce grads and rescale by grad_denom for p in generator.parameters(): if p.requires_grad: p.grad.data.div_(sample_size) torch.nn.utils.clip_grad_norm(generator.parameters(), 
args.clip_norm) optimizer.step() # validation -- this is a crude estimation because there might be some padding at the end max_positions_valid = ( generator.encoder.max_positions(), generator.decoder.max_positions(), ) # Initialize dataloader itr = dataset.eval_dataloader( 'valid', max_tokens=args.max_tokens, max_sentences=args.max_sentences, max_positions=max_positions_valid, skip_invalid_size_inputs_valid_test=args. skip_invalid_size_inputs_valid_test, descending=True, # largest batch first to warm the caching allocator shard_id=args.distributed_rank, num_shards=args.distributed_world_size, ) # set validation mode generator.eval() # reset meters for key, val in logging_meters.items(): if val is not None: val.reset() for i, sample in enumerate(itr): with torch.no_grad(): if use_cuda: # wrap input tensors in cuda tensors sample = utils.make_variable(sample, cuda=cuda) loss = generator(sample) sample_size = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] loss = loss / sample_size / math.log(2) logging_meters['valid_loss'].update(loss, sample_size) logging.debug("dev loss at batch {0}: {1:.3f}".format( i, logging_meters['valid_loss'].avg)) # update learning rate lr_scheduler.step(logging_meters['valid_loss'].avg) lr = optimizer.param_groups[0]['lr'] logging.info( "Average loss value per instance is {0} at the end of epoch {1}". format(logging_meters['valid_loss'].avg, epoch_i)) torch.save( generator.state_dict(), open( args.model_file + "data.nll_{0:.3f}.epoch_{1}.pt".format( logging_meters['valid_loss'].avg, epoch_i), 'wb')) if logging_meters['valid_loss'].avg < best_dev_loss: best_dev_loss = logging_meters['valid_loss'].avg torch.save(generator.state_dict(), open(args.model_file + "best_gmodel.pt", 'wb')) epoch_i += 1 f1.close()
def train_d(args, dataset): logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S', level=logging.DEBUG) use_cuda = (torch.cuda.device_count() >= 1) # check checkpoints saving path if not os.path.exists('checkpoints/discriminator'): os.makedirs('checkpoints/discriminator') checkpoints_path = 'checkpoints/discriminator/' logging_meters = OrderedDict() logging_meters['train_loss'] = AverageMeter() logging_meters['train_acc'] = AverageMeter() logging_meters['valid_loss'] = AverageMeter() logging_meters['valid_acc'] = AverageMeter() logging_meters['update_times'] = AverageMeter() # Build model discriminator = Discriminator(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) # Load generator assert os.path.exists('checkpoints/generator/best_gmodel.pt') generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) model_dict = generator.state_dict() pretrained_dict = torch.load('checkpoints/generator/best_gmodel.pt') # 1. filter out unnecessary keys pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict } # 2. overwrite entries in the existing state dict model_dict.update(pretrained_dict) # 3. load the new state dict generator.load_state_dict(model_dict) if use_cuda: if torch.cuda.device_count() > 1: discriminator = torch.nn.DataParallel(discriminator).cuda() # generator = torch.nn.DataParallel(generator).cuda() generator.cuda() else: generator.cuda() discriminator.cuda() else: discriminator.cpu() generator.cpu() criterion = torch.nn.CrossEntropyLoss() # optimizer = eval("torch.optim." + args.d_optimizer)(filter(lambda x: x.requires_grad, discriminator.parameters()), # args.d_learning_rate, momentum=args.momentum, nesterov=True) optimizer = torch.optim.RMSprop( filter(lambda x: x.requires_grad, discriminator.parameters()), 1e-4) lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, patience=0, factor=args.lr_shrink) # Train until the accuracy achieve the define value max_epoch = args.max_epoch or math.inf epoch_i = 1 trg_acc = 0.82 best_dev_loss = math.inf lr = optimizer.param_groups[0]['lr'] # validation set data loader (only prepare once) train = prepare_training_data(args, dataset, 'train', generator, epoch_i, use_cuda) valid = prepare_training_data(args, dataset, 'valid', generator, epoch_i, use_cuda) data_train = DatasetProcessing(data=train, maxlen=args.fixed_max_len) data_valid = DatasetProcessing(data=valid, maxlen=args.fixed_max_len) # main training loop while lr > args.min_d_lr and epoch_i <= max_epoch: logging.info("At {0}-th epoch.".format(epoch_i)) seed = args.seed + epoch_i torch.manual_seed(seed) if args.sample_without_replacement > 0 and epoch_i > 1: train = prepare_training_data(args, dataset, 'train', generator, epoch_i, use_cuda) data_train = DatasetProcessing(data=train, maxlen=args.fixed_max_len) # discriminator training dataloader train_loader = train_dataloader(data_train, batch_size=args.joint_batch_size, seed=seed, epoch=epoch_i, sort_by_source_size=False) valid_loader = eval_dataloader(data_valid, num_workers=4, batch_size=args.joint_batch_size) # set training mode discriminator.train() # reset meters for key, val in logging_meters.items(): if val is not None: val.reset() for i, sample in enumerate(train_loader): if use_cuda: # wrap input tensors in cuda tensors sample = utils.make_variable(sample, cuda=use_cuda) disc_out = discriminator(sample['src_tokens'], sample['trg_tokens']) loss = criterion(disc_out, sample['labels']) _, prediction = 
F.softmax(disc_out, dim=1).topk(1) acc = torch.sum( prediction == sample['labels'].unsqueeze(1)).float() / len( sample['labels']) logging_meters['train_acc'].update(acc.item()) logging_meters['train_loss'].update(loss.item()) logging.debug("D training loss {0:.3f}, acc {1:.3f}, avgAcc {2:.3f}, lr={3} at batch {4}: ". \ format(logging_meters['train_loss'].avg, acc, logging_meters['train_acc'].avg, optimizer.param_groups[0]['lr'], i)) optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm(discriminator.parameters(), args.clip_norm) optimizer.step() # del src_tokens, trg_tokens, loss, disc_out, labels, prediction, acc del disc_out, loss, prediction, acc # set validation mode discriminator.eval() for i, sample in enumerate(valid_loader): with torch.no_grad(): if use_cuda: # wrap input tensors in cuda tensors sample = utils.make_variable(sample, cuda=use_cuda) disc_out = discriminator(sample['src_tokens'], sample['trg_tokens']) loss = criterion(disc_out, sample['labels']) _, prediction = F.softmax(disc_out, dim=1).topk(1) acc = torch.sum( prediction == sample['labels'].unsqueeze(1)).float() / len( sample['labels']) logging_meters['valid_acc'].update(acc.item()) logging_meters['valid_loss'].update(loss.item()) logging.debug("D eval loss {0:.3f}, acc {1:.3f}, avgAcc {2:.3f}, lr={3} at batch {4}: ". \ format(logging_meters['valid_loss'].avg, acc, logging_meters['valid_acc'].avg, optimizer.param_groups[0]['lr'], i)) del disc_out, loss, prediction, acc lr_scheduler.step(logging_meters['valid_loss'].avg) if logging_meters['valid_acc'].avg >= 0.70: torch.save(discriminator.state_dict(), checkpoints_path + "ce_{0:.3f}_acc_{1:.3f}.epoch_{2}.pt" \ .format(logging_meters['valid_loss'].avg, logging_meters['valid_acc'].avg, epoch_i)) if logging_meters['valid_loss'].avg < best_dev_loss: best_dev_loss = logging_meters['valid_loss'].avg torch.save(discriminator.state_dict(), checkpoints_path + "best_dmodel.pt") # pretrain the discriminator to achieve accuracy 82% if logging_meters['valid_acc'].avg >= trg_acc: return epoch_i += 1
def main(args): use_cuda = (len(args.gpuid) >= 1) print("{0} GPU(s) are available".format(cuda.device_count())) # Load dataset splits = ['train', 'valid'] if data.has_binary_files(args.data, splits): dataset = data.load_dataset(args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len) else: dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len) if args.src_lang is None or args.trg_lang is None: # record inferred languages in args, so that it's saved in checkpoints args.src_lang, args.trg_lang = dataset.src, dataset.dst print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict))) print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict))) for split in splits: print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split]))) g_logging_meters = OrderedDict() g_logging_meters['MLE_train_loss'] = AverageMeter() g_logging_meters['valid_loss'] = AverageMeter() g_logging_meters['PG_train_loss'] = AverageMeter() g_logging_meters['MLE_train_acc'] = AverageMeter() g_logging_meters['PG_train_acc'] = AverageMeter() g_logging_meters['valid_acc'] = AverageMeter() g_logging_meters['bsz'] = AverageMeter() # sentences per batch d_logging_meters = OrderedDict() d_logging_meters['D_h_train_loss'] = AverageMeter() d_logging_meters['valid_loss'] = AverageMeter() d_logging_meters['D_s_train_loss'] = AverageMeter() d_logging_meters['D_h_train_acc'] = AverageMeter() d_logging_meters['D_s_train_acc'] = AverageMeter() d_logging_meters['valid_acc'] = AverageMeter() d_logging_meters['bsz'] = AverageMeter() # sentences per batch # Set model parameters args.encoder_embed_dim = 1000 args.encoder_layers = 2 # 4 args.encoder_dropout_out = 0.1 args.decoder_embed_dim = 1000 args.decoder_layers = 2 # 4 args.decoder_out_embed_dim = 1000 args.decoder_dropout_out = 0.1 args.bidirectional = False generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) print("Generator loaded successfully!") discriminator_h = Discriminator_h(args.decoder_embed_dim, args.discriminator_hidden_size, args.discriminator_linear_size, args.discriminator_lin_dropout, use_cuda=use_cuda) print("Discriminator_h loaded successfully!") discriminator_s = Discriminator_s(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) print("Discriminator_s loaded successfully!") # Load generator model g_model_path = 'checkpoints/zhenwarm/genev.pt' assert os.path.exists(g_model_path) # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) model_dict = generator.state_dict() model = torch.load(g_model_path) pretrained_dict = model.state_dict() # 1. filter out unnecessary keys pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict } # 2. overwrite entries in the existing state dict model_dict.update(pretrained_dict) # 3. load the new state dict generator.load_state_dict(model_dict) print("pre-trained Generator loaded successfully!") # Load discriminator model d_model_path = 'checkpoints/zhenwarm/discri.pt' assert os.path.exists(d_model_path) # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) d_model_dict = discriminator_s.state_dict() d_model = torch.load(d_model_path) d_pretrained_dict = d_model.state_dict() # 1. filter out unnecessary keys d_pretrained_dict = { k: v for k, v in d_pretrained_dict.items() if k in d_model_dict } # 2. overwrite entries in the existing state dict d_model_dict.update(d_pretrained_dict) # 3. 
load the new state dict discriminator_s.load_state_dict(d_model_dict) print("pre-trained Discriminator loaded successfully!") # # Load discriminatorH model # d_H_model_path = 'checkpoints/joint_warm/DH.pt' # assert os.path.exists(d_H_model_path) # # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) # d_H_model_dict = discriminator_h.state_dict() # d_H_model = torch.load(d_H_model_path) # d_H_pretrained_dict = d_H_model.state_dict() # # 1. filter out unnecessary keys # d_H_pretrained_dict = {k: v for k, v in d_H_pretrained_dict.items() if k in d_H_model_dict} # # 2. overwrite entries in the existing state dict # d_H_model_dict.update(d_H_pretrained_dict) # # 3. load the new state dict # discriminator_h.load_state_dict(d_H_model_dict) # print("pre-trained Discriminator_H loaded successfully!") def _calcualte_discriminator_loss(tf_scores, ar_scores): tf_loss = torch.log(tf_scores + 1e-9) * (-1) ar_loss = torch.log(1 - ar_scores + 1e-9) * (-1) return tf_loss + ar_loss if use_cuda: if torch.cuda.device_count() > 1: discriminator_h = torch.nn.DataParallel(discriminator_h).cuda() discriminator_s = torch.nn.DataParallel(discriminator_s).cuda() generator = torch.nn.DataParallel(generator).cuda() else: generator.cuda() discriminator_h.cuda() discriminator_s.cuda() else: discriminator_h.cpu() discriminator_s.cpu() generator.cpu() # adversarial training checkpoints saving path if not os.path.exists('checkpoints/realmyzhenup10shrink07drop01new'): os.makedirs('checkpoints/realmyzhenup10shrink07drop01new') checkpoints_path = 'checkpoints/realmyzhenup10shrink07drop01new/' # define loss function g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(), reduction='sum') d_criterion = torch.nn.BCELoss() pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(), size_average=True, reduce=True) # fix discriminator_s word embedding (as Wu et al. do) for p in discriminator_s.embed_src_tokens.parameters(): p.requires_grad = False for p in discriminator_s.embed_trg_tokens.parameters(): p.requires_grad = False # define optimizer g_optimizer = eval("torch.optim." + args.g_optimizer)(filter( lambda x: x.requires_grad, generator.parameters()), args.g_learning_rate) d_optimizer_h = eval("torch.optim." + args.d_optimizer)( filter(lambda x: x.requires_grad, discriminator_h.parameters()), args.d_learning_rate, momentum=args.momentum, nesterov=True) d_optimizer_s = eval("torch.optim." 
+ args.d_optimizer)( filter(lambda x: x.requires_grad, discriminator_s.parameters()), args.d_learning_rate, momentum=args.momentum, nesterov=True) # start joint training best_dev_loss = math.inf num_update = 0 # main training loop for epoch_i in range(1, args.epochs + 1): logging.info("At {0}-th epoch.".format(epoch_i)) seed = args.seed + epoch_i torch.manual_seed(seed) max_positions_train = (args.fixed_max_len, args.fixed_max_len) # Initialize dataloader, starting at batch_offset trainloader = dataset.train_dataloader( 'train', max_tokens=args.max_tokens, max_sentences=args.joint_batch_size, max_positions=max_positions_train, # seed=seed, epoch=epoch_i, sample_without_replacement=args.sample_without_replacement, sort_by_source_size=(epoch_i <= args.curriculum), shard_id=args.distributed_rank, num_shards=args.distributed_world_size, ) # reset meters for key, val in g_logging_meters.items(): if val is not None: val.reset() for key, val in d_logging_meters.items(): if val is not None: val.reset() # set training mode for i, sample in enumerate(trainloader): generator.train() discriminator_h.train() discriminator_s.train() update_learning_rate(num_update, 8e4, args.g_learning_rate, args.lr_shrink, g_optimizer) if use_cuda: # wrap input tensors in cuda tensors sample = utils.make_variable(sample, cuda=cuda) train_MLE = 0 train_PG = 0 if random.random() > 0.5 and i != 0: train_MLE = 1 # print("MLE Training") sys_out_batch_MLE, p_MLE, hidden_list_MLE = generator( "MLE", epoch_i, sample) out_batch_MLE = sys_out_batch_MLE.contiguous().view( -1, sys_out_batch_MLE.size(-1)) # (64 X 50) X 6632 train_trg_batch_MLE = sample['target'].view(-1) # 64*50 = 3200 loss_MLE = g_criterion(out_batch_MLE, train_trg_batch_MLE) sample_size_MLE = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] nsentences = sample['target'].size(0) logging_loss_MLE = loss_MLE.data / sample_size_MLE / math.log( 2) g_logging_meters['bsz'].update(nsentences) g_logging_meters['MLE_train_loss'].update( logging_loss_MLE, sample_size_MLE) logging.debug( f"G MLE loss at batch {i}: {g_logging_meters['MLE_train_loss'].avg:.3f}, lr={g_optimizer.param_groups[0]['lr']}" ) g_optimizer.zero_grad() loss_MLE.backward() # all-reduce grads and rescale by grad_denom for p in generator.parameters(): if p.requires_grad: p.grad.data.div_(sample_size_MLE) torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm) g_optimizer.step() else: train_PG = 1 ## part I: use gradient policy method to train the generator # print("Policy Gradient Training") sys_out_batch_PG, p_PG, hidden_list_PG = generator( 'PG', epoch_i, sample) # 64 X 50 X 6632 out_batch_PG = sys_out_batch_PG.contiguous().view( -1, sys_out_batch_PG.size(-1)) # (64 * 50) X 6632 _, prediction = out_batch_PG.topk(1) prediction = prediction.squeeze(1) # 64*50 = 3200 prediction = torch.reshape( prediction, sample['net_input']['src_tokens'].shape) # 64 X 50 # if d_logging_meters['train_acc'].avg >= 0.75: with torch.no_grad(): reward = discriminator_s(sample['net_input']['src_tokens'], prediction) # 64 X 1 # else: # reward = torch.ones(args.joint_batch_size) train_trg_batch_PG = sample['target'] # 64 x 50 pg_loss_PG = pg_criterion(sys_out_batch_PG, train_trg_batch_PG, reward, use_cuda) sample_size_PG = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] # 64 logging_loss_PG = pg_loss_PG / math.log(2) g_logging_meters['PG_train_loss'].update( logging_loss_PG.item(), sample_size_PG) logging.debug( f"G policy gradient loss at batch {i}: 
{pg_loss_PG.item():.3f}, lr={g_optimizer.param_groups[0]['lr']}" ) g_optimizer.zero_grad() if g_logging_meters['PG_train_loss'].val < 0.5: pg_loss_PG.backward() torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm) g_optimizer.step() num_update += 1 # if g_logging_meters["MLE_train_loss"].avg < 4: # part II: train the discriminator # discriminator_h if num_update % 10 == 0: if g_logging_meters["PG_train_loss"].val < 2: assert (train_MLE == 1) != (train_PG == 1) if train_MLE == 1: d_MLE = discriminator_h(hidden_list_MLE.detach()) M_loss = torch.log(d_MLE + 1e-9) * (-1) h_d_loss = M_loss.sum() elif train_PG == 1: d_PG = discriminator_h(hidden_list_PG.detach()) P_loss = torch.log(1 - d_PG + 1e-9) * (-1) h_d_loss = P_loss.sum() # d_loss = _calcualte_discriminator_loss(d_MLE, d_PG).sum() logging.debug(f"D_h training loss {h_d_loss} at batch {i}") d_optimizer_h.zero_grad() h_d_loss.backward() torch.nn.utils.clip_grad_norm_( discriminator_h.parameters(), args.clip_norm) d_optimizer_h.step() #discriminator_s bsz = sample['target'].size(0) # batch_size = 64 src_sentence = sample['net_input'][ 'src_tokens'] # 64 x max-len i.e 64 X 50 # now train with machine translation output i.e generator output with torch.no_grad(): sys_out_batch, p, hidden_list = generator( 'PG', epoch_i, sample) # 64 X 50 X 6632 out_batch = sys_out_batch.contiguous().view( -1, sys_out_batch.size(-1)) # (64 X 50) X 6632 _, prediction = out_batch.topk(1) prediction = prediction.squeeze(1) # 64 * 50 = 6632 fake_labels = torch.zeros( sample['target'].size(0)).float() # 64 length vector # true_labels = torch.ones(sample['target'].size(0)).float() # 64 length vector fake_sentence = torch.reshape(prediction, src_sentence.shape) # 64 X 50 if use_cuda: fake_labels = fake_labels.cuda() # true_labels = true_labels.cuda() # if random.random() > 0.5: fake_disc_out = discriminator_s(src_sentence, fake_sentence) # 64 X 1 fake_d_loss = d_criterion(fake_disc_out.squeeze(1), fake_labels) acc = torch.sum( torch.round(fake_disc_out).squeeze(1) == fake_labels).float() / len(fake_labels) d_loss = fake_d_loss # else: # # true_sentence = sample['target'].view(-1) # 64*50 = 3200 # # true_sentence = torch.reshape(true_sentence, src_sentence.shape) # true_disc_out = discriminator_s(src_sentence, true_sentence) # acc = torch.sum(torch.round(true_disc_out).squeeze(1) == true_labels).float() / len(true_labels) # true_d_loss = d_criterion(true_disc_out.squeeze(1), true_labels) # d_loss = true_d_loss # acc_fake = torch.sum(torch.round(fake_disc_out).squeeze(1) == fake_labels).float() / len(fake_labels) # acc_true = torch.sum(torch.round(true_disc_out).squeeze(1) == true_labels).float() / len(true_labels) # acc = (acc_fake + acc_true) / 2 # acc = acc_fake # d_loss = fake_d_loss + true_d_loss s_d_loss = d_loss d_logging_meters['D_s_train_acc'].update(acc) d_logging_meters['D_s_train_loss'].update(s_d_loss) logging.debug( f"D_s training loss {d_logging_meters['D_s_train_loss'].avg:.3f}, acc {d_logging_meters['D_s_train_acc'].avg:.3f} at batch {i}" ) d_optimizer_s.zero_grad() s_d_loss.backward() d_optimizer_s.step() if num_update % 10000 == 0: # validation # set validation mode print( 'validation and save+++++++++++++++++++++++++++++++++++++++++++++++' ) generator.eval() discriminator_h.eval() discriminator_s.eval() # Initialize dataloader max_positions_valid = (args.fixed_max_len, args.fixed_max_len) valloader = dataset.eval_dataloader( 'valid', max_tokens=args.max_tokens, max_sentences=args.joint_batch_size, max_positions=max_positions_valid, 
skip_invalid_size_inputs_valid_test=True, descending= True, # largest batch first to warm the caching allocator shard_id=args.distributed_rank, num_shards=args.distributed_world_size, ) # reset meters for key, val in g_logging_meters.items(): if val is not None: val.reset() for key, val in d_logging_meters.items(): if val is not None: val.reset() for i, sample in enumerate(valloader): with torch.no_grad(): if use_cuda: # wrap input tensors in cuda tensors sample = utils.make_variable(sample, cuda=cuda) # generator validation sys_out_batch_test, p_test, hidden_list_test = generator( 'test', epoch_i, sample) out_batch_test = sys_out_batch_test.contiguous().view( -1, sys_out_batch_test.size(-1)) # (64 X 50) X 6632 dev_trg_batch = sample['target'].view( -1) # 64*50 = 3200 loss_test = g_criterion(out_batch_test, dev_trg_batch) sample_size_test = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] loss_test = loss_test / sample_size_test / math.log(2) g_logging_meters['valid_loss'].update( loss_test, sample_size_test) logging.debug( f"G dev loss at batch {i}: {g_logging_meters['valid_loss'].avg:.3f}" ) torch.save( generator, open( checkpoints_path + f"numupdate_{num_update/10000}w.sampling_{g_logging_meters['valid_loss'].avg:.3f}.pt", 'wb'), pickle_module=dill)
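# Hedged sketch: the discriminator_h update above treats teacher-forced (MLE)
# hidden states as "real" (-log D(h)) and free-running (PG) hidden states as
# "fake" (-log(1 - D(h))), which is what the _calcualte_discriminator_loss
# helper encodes with its tf/ar scores. `toy_disc`, `h_mle` and `h_pg` below
# are illustrative stand-ins, not the project's actual modules or shapes.
import torch
import torch.nn as nn


def hidden_state_d_loss(real_scores, fake_scores, eps=1e-9):
    """Sum of -log D(h_mle) and -log(1 - D(h_pg)) over the batch."""
    real_loss = -torch.log(real_scores + eps)
    fake_loss = -torch.log(1 - fake_scores + eps)
    return real_loss.sum() + fake_loss.sum()


toy_disc = nn.Sequential(nn.Linear(16, 1), nn.Sigmoid())
h_mle = torch.randn(8, 16)  # hidden states gathered under teacher forcing
h_pg = torch.randn(8, 16)   # hidden states gathered under free-running sampling
loss = hidden_state_d_loss(toy_disc(h_mle), toy_disc(h_pg))
loss.backward()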
def main(args): use_cuda = (len(args.gpuid) >= 1) print("{0} GPU(s) are available".format(cuda.device_count())) print("======printing args========") print(args) print("=================================") # Load dataset splits = ['train', 'valid'] if data.has_binary_files(args.data, splits): print("Loading bin dataset") dataset = data.load_dataset(args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len) #args.data, splits, args.src_lang, args.trg_lang) else: print(f"Loading raw text dataset {args.data}") dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len) #args.data, splits, args.src_lang, args.trg_lang) if args.src_lang is None or args.trg_lang is None: # record inferred languages in args, so that it's saved in checkpoints args.src_lang, args.trg_lang = dataset.src, dataset.dst print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict))) print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict))) for split in splits: print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split]))) g_logging_meters = OrderedDict() g_logging_meters['train_loss'] = AverageMeter() g_logging_meters['valid_loss'] = AverageMeter() g_logging_meters['train_acc'] = AverageMeter() g_logging_meters['valid_acc'] = AverageMeter() g_logging_meters['bsz'] = AverageMeter() # sentences per batch d_logging_meters = OrderedDict() d_logging_meters['train_loss'] = AverageMeter() d_logging_meters['valid_loss'] = AverageMeter() d_logging_meters['train_acc'] = AverageMeter() d_logging_meters['valid_acc'] = AverageMeter() d_logging_meters['bsz'] = AverageMeter() # sentences per batch # Set model parameters args.encoder_embed_dim = 1000 args.encoder_layers = 4 args.encoder_dropout_out = 0 args.decoder_embed_dim = 1000 args.decoder_layers = 4 args.decoder_out_embed_dim = 1000 args.decoder_dropout_out = 0 args.bidirectional = False # try to load generator model g_model_path = 'checkpoints/generator/best_gmodel.pt' if not os.path.exists(g_model_path): print("Start training generator!") train_g(args, dataset) assert os.path.exists(g_model_path) generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) model_dict = generator.state_dict() pretrained_dict = torch.load(g_model_path) #print(f"First dict: {pretrained_dict}") # 1. filter out unnecessary keys pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict } #print(f"Second dict: {pretrained_dict}") # 2. overwrite entries in the existing state dict model_dict.update(pretrained_dict) #print(f"model dict: {model_dict}") # 3. load the new state dict generator.load_state_dict(model_dict) print("Generator has successfully loaded!") # try to load discriminator model d_model_path = 'checkpoints/discriminator/best_dmodel.pt' if not os.path.exists(d_model_path): print("Start training discriminator!") train_d(args, dataset) assert os.path.exists(d_model_path) discriminator = Discriminator(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) model_dict = discriminator.state_dict() pretrained_dict = torch.load(d_model_path) # 1. filter out unnecessary keys pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict } # 2. overwrite entries in the existing state dict model_dict.update(pretrained_dict) # 3. 
load the new state dict discriminator.load_state_dict(model_dict) print("Discriminator has successfully loaded!") #return print("starting main training loop") torch.autograd.set_detect_anomaly(True) if use_cuda: if torch.cuda.device_count() > 1: discriminator = torch.nn.DataParallel(discriminator).cuda() generator = torch.nn.DataParallel(generator).cuda() else: generator.cuda() discriminator.cuda() else: discriminator.cpu() generator.cpu() # adversarial training checkpoints saving path if not os.path.exists('checkpoints/joint'): os.makedirs('checkpoints/joint') checkpoints_path = 'checkpoints/joint/' # define loss function g_criterion = torch.nn.NLLLoss(size_average=False, ignore_index=dataset.dst_dict.pad(), reduce=True) d_criterion = torch.nn.BCEWithLogitsLoss() pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(), size_average=True, reduce=True) # fix discriminator word embedding (as Wu et al. do) for p in discriminator.embed_src_tokens.parameters(): p.requires_grad = False for p in discriminator.embed_trg_tokens.parameters(): p.requires_grad = False # define optimizer g_optimizer = eval("torch.optim." + args.g_optimizer)(filter( lambda x: x.requires_grad, generator.parameters()), args.g_learning_rate) d_optimizer = eval("torch.optim." + args.d_optimizer)( filter(lambda x: x.requires_grad, discriminator.parameters()), args.d_learning_rate, momentum=args.momentum, nesterov=True) # start joint training best_dev_loss = math.inf num_update = 0 # main training loop for epoch_i in range(1, args.epochs + 1): logging.info("At {0}-th epoch.".format(epoch_i)) # seed = args.seed + epoch_i # torch.manual_seed(seed) max_positions_train = (args.fixed_max_len, args.fixed_max_len) # Initialize dataloader, starting at batch_offset itr = dataset.train_dataloader( 'train', max_tokens=args.max_tokens, max_sentences=args.joint_batch_size, max_positions=max_positions_train, # seed=seed, epoch=epoch_i, sample_without_replacement=args.sample_without_replacement, sort_by_source_size=(epoch_i <= args.curriculum), shard_id=args.distributed_rank, num_shards=args.distributed_world_size, ) # reset meters for key, val in g_logging_meters.items(): if val is not None: val.reset() for key, val in d_logging_meters.items(): if val is not None: val.reset() # set training mode generator.train() discriminator.train() update_learning_rate(num_update, 8e4, args.g_learning_rate, args.lr_shrink, g_optimizer) for i, sample in enumerate(itr): if use_cuda: # wrap input tensors in cuda tensors sample = utils.make_variable(sample, cuda=cuda) ## part I: use gradient policy method to train the generator # use policy gradient training when rand > 50% rand = random.random() if rand >= 0.5: # policy gradient training generator.decoder.is_testing = True sys_out_batch, prediction, _ = generator(sample) generator.decoder.is_testing = False with torch.no_grad(): n_i = sample['net_input']['src_tokens'] #print(f"net input:\n{n_i}, pred: \n{prediction}") reward = discriminator( sample['net_input']['src_tokens'], prediction) # dataset.dst_dict.pad()) train_trg_batch = sample['target'] #print(f"sys_out_batch: {sys_out_batch.shape}:\n{sys_out_batch}") pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward, use_cuda) # logging.debug("G policy gradient loss at batch {0}: {1:.3f}, lr={2}".format(i, pg_loss.item(), g_optimizer.param_groups[0]['lr'])) g_optimizer.zero_grad() pg_loss.backward() torch.nn.utils.clip_grad_norm(generator.parameters(), args.clip_norm) g_optimizer.step() # oracle valid _, _, loss = generator(sample) sample_size = 
sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] logging_loss = loss.data / sample_size / math.log(2) g_logging_meters['train_loss'].update(logging_loss, sample_size) logging.debug( "G MLE loss at batch {0}: {1:.3f}, lr={2}".format( i, g_logging_meters['train_loss'].avg, g_optimizer.param_groups[0]['lr'])) else: # MLE training #print(f"printing sample: \n{sample}") _, _, loss = generator(sample) sample_size = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] nsentences = sample['target'].size(0) logging_loss = loss.data / sample_size / math.log(2) g_logging_meters['bsz'].update(nsentences) g_logging_meters['train_loss'].update(logging_loss, sample_size) logging.debug( "G MLE loss at batch {0}: {1:.3f}, lr={2}".format( i, g_logging_meters['train_loss'].avg, g_optimizer.param_groups[0]['lr'])) g_optimizer.zero_grad() loss.backward() # all-reduce grads and rescale by grad_denom for p in generator.parameters(): if p.requires_grad: p.grad.data.div_(sample_size) torch.nn.utils.clip_grad_norm(generator.parameters(), args.clip_norm) g_optimizer.step() num_update += 1 # part II: train the discriminator bsz = sample['target'].size(0) src_sentence = sample['net_input']['src_tokens'] # train with half human-translation and half machine translation true_sentence = sample['target'] true_labels = Variable( torch.ones(sample['target'].size(0)).float()) with torch.no_grad(): generator.decoder.is_testing = True _, prediction, _ = generator(sample) generator.decoder.is_testing = False fake_sentence = prediction fake_labels = Variable( torch.zeros(sample['target'].size(0)).float()) trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0) labels = torch.cat([true_labels, fake_labels], dim=0) indices = np.random.permutation(2 * bsz) trg_sentence = trg_sentence[indices][:bsz] labels = labels[indices][:bsz] if use_cuda: labels = labels.cuda() disc_out = discriminator(src_sentence, trg_sentence) #, dataset.dst_dict.pad()) #print(f"disc out: {disc_out.shape}, labels: {labels.shape}") #print(f"labels: {labels}") d_loss = d_criterion(disc_out, labels.long()) acc = torch.sum(torch.Sigmoid() (disc_out).round() == labels).float() / len(labels) d_logging_meters['train_acc'].update(acc) d_logging_meters['train_loss'].update(d_loss) # logging.debug("D training loss {0:.3f}, acc {1:.3f} at batch {2}: ".format(d_logging_meters['train_loss'].avg, # d_logging_meters['train_acc'].avg, # i)) d_optimizer.zero_grad() d_loss.backward() d_optimizer.step() # validation # set validation mode generator.eval() discriminator.eval() # Initialize dataloader max_positions_valid = (args.fixed_max_len, args.fixed_max_len) itr = dataset.eval_dataloader( 'valid', max_tokens=args.max_tokens, max_sentences=args.joint_batch_size, max_positions=max_positions_valid, skip_invalid_size_inputs_valid_test=True, descending=True, # largest batch first to warm the caching allocator shard_id=args.distributed_rank, num_shards=args.distributed_world_size, ) # reset meters for key, val in g_logging_meters.items(): if val is not None: val.reset() for key, val in d_logging_meters.items(): if val is not None: val.reset() for i, sample in enumerate(itr): with torch.no_grad(): if use_cuda: sample['id'] = sample['id'].cuda() sample['net_input']['src_tokens'] = sample['net_input'][ 'src_tokens'].cuda() sample['net_input']['src_lengths'] = sample['net_input'][ 'src_lengths'].cuda() sample['net_input']['prev_output_tokens'] = sample[ 'net_input']['prev_output_tokens'].cuda() sample['target'] = sample['target'].cuda() # 
generator validation _, _, loss = generator(sample) sample_size = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] loss = loss / sample_size / math.log(2) g_logging_meters['valid_loss'].update(loss, sample_size) logging.debug("G dev loss at batch {0}: {1:.3f}".format( i, g_logging_meters['valid_loss'].avg)) # discriminator validation bsz = sample['target'].size(0) src_sentence = sample['net_input']['src_tokens'] # train with half human-translation and half machine translation true_sentence = sample['target'] true_labels = Variable( torch.ones(sample['target'].size(0)).float()) with torch.no_grad(): generator.decoder.is_testing = True _, prediction, _ = generator(sample) generator.decoder.is_testing = False fake_sentence = prediction fake_labels = Variable( torch.zeros(sample['target'].size(0)).float()) trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0) labels = torch.cat([true_labels, fake_labels], dim=0) indices = np.random.permutation(2 * bsz) trg_sentence = trg_sentence[indices][:bsz] labels = labels[indices][:bsz] if use_cuda: labels = labels.cuda() disc_out = discriminator(src_sentence, trg_sentence, dataset.dst_dict.pad()) d_loss = d_criterion(disc_out, labels) acc = torch.sum(torch.Sigmoid()(disc_out).round() == labels).float() / len(labels) d_logging_meters['valid_acc'].update(acc) d_logging_meters['valid_loss'].update(d_loss) # logging.debug("D dev loss {0:.3f}, acc {1:.3f} at batch {2}".format(d_logging_meters['valid_loss'].avg, # d_logging_meters['valid_acc'].avg, i)) torch.save(generator, open( checkpoints_path + "joint_{0:.3f}.epoch_{1}.pt".format( g_logging_meters['valid_loss'].avg, epoch_i), 'wb'), pickle_module=dill) if g_logging_meters['valid_loss'].avg < best_dev_loss: best_dev_loss = g_logging_meters['valid_loss'].avg torch.save(generator, open(checkpoints_path + "best_gmodel.pt", 'wb'), pickle_module=dill)
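# Hedged sketch: the discriminator step above concatenates human references
# (label 1) and generator outputs (label 0), shuffles the 2*bsz pool, and keeps
# one batch worth. BCEWithLogitsLoss expects raw logits and float targets, so
# the labels.long() cast above would typically raise a dtype error, and
# torch.Sigmoid() is not a PyTorch callable (torch.sigmoid / nn.Sigmoid are).
# Toy tensors stand in for the real discriminator outputs here.
import numpy as np
import torch

bsz = 4
true_labels = torch.ones(bsz)
fake_labels = torch.zeros(bsz)
labels = torch.cat([true_labels, fake_labels], dim=0)
logits = torch.randn(2 * bsz)  # pretend discriminator logits for the mixed pool

indices = torch.as_tensor(np.random.permutation(2 * bsz))
logits = logits[indices][:bsz]
labels = labels[indices][:bsz]

d_criterion = torch.nn.BCEWithLogitsLoss()
d_loss = d_criterion(logits, labels)  # float targets, raw logits
acc = (torch.sigmoid(logits).round() == labels).float().mean()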
def main(args): use_cuda = (len(args.gpuid) >= 1) print("{0} GPU(s) are available".format(cuda.device_count())) # Load dataset splits = ['train', 'valid'] if data.has_binary_files(args.data, splits): dataset = data.load_dataset(args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len) else: dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len) if args.src_lang is None or args.trg_lang is None: # record inferred languages in args, so that it's saved in checkpoints args.src_lang, args.trg_lang = dataset.src, dataset.dst print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict))) print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict))) for split in splits: print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split]))) # check checkpoints saving path if not os.path.exists('checkpoints/generator'): os.makedirs('checkpoints/generator') checkpoints_path = 'checkpoints/generator/' logging_meters = OrderedDict() logging_meters['train_loss'] = AverageMeter() logging_meters['valid_loss'] = AverageMeter() logging_meters['bsz'] = AverageMeter() # sentences per batch logging_meters['update_times'] = AverageMeter() # Set model parameters args.encoder_embed_dim = 1000 args.encoder_layers = 2 # 4 args.encoder_dropout_out = 0 args.decoder_embed_dim = 1000 args.decoder_layers = 2 # 4 args.decoder_out_embed_dim = 1000 args.decoder_dropout_out = 0 args.bidirectional = False # Build model generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) # g_model_path = 'checkpoints/generator/numupdate1.4180668458302803.data.nll_105000.000.pt' # assert os.path.exists(g_model_path) # # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) # model_dict = generator.state_dict() # model = torch.load(g_model_path) # pretrained_dict = model.state_dict() # # 1. filter out unnecessary keys # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} # # 2. overwrite entries in the existing state dict # model_dict.update(pretrained_dict) # # 3. load the new state dict # generator.load_state_dict(model_dict) # print("pre-trained Generator loaded successfully!") if use_cuda: if len(args.gpuid) > 1: generator = torch.nn.DataParallel(generator).cuda() else: generator.cuda() else: generator.cpu() print("Training generator...") g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(), reduction='sum') optimizer = eval("torch.optim." 
+ args.optimizer)(generator.parameters(), args.learning_rate) lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, patience=0, factor=args.lr_shrink) # Train until the learning rate gets too small max_epoch = args.max_epoch or math.inf epoch_i = 1 best_dev_loss = math.inf lr = optimizer.param_groups[0]['lr'] num_update = 0 # main training loop while lr > args.min_g_lr and epoch_i <= max_epoch: logging.info("At {0}-th epoch.".format(epoch_i)) seed = args.seed + epoch_i torch.manual_seed(seed) max_positions_train = (min(args.max_source_positions, generator.encoder.max_positions()), min(args.max_target_positions, generator.decoder.max_positions())) # Initialize dataloader, starting at batch_offset itr = dataset.train_dataloader( 'train', max_tokens=args.max_tokens, max_sentences=args.max_sentences, max_positions=max_positions_train, seed=seed, epoch=epoch_i, sample_without_replacement=args.sample_without_replacement, sort_by_source_size=(epoch_i <= args.curriculum), shard_id=args.distributed_rank, num_shards=args.distributed_world_size, ) # set training mode # reset meters for key, val in logging_meters.items(): if val is not None: val.reset() for i, sample in enumerate(itr): generator.train() if use_cuda: # wrap input tensors in cuda tensors sample = utils.make_variable(sample, cuda=cuda) sys_out_batch = generator(sample) out_batch = sys_out_batch.contiguous().view( -1, sys_out_batch.size(-1)) train_trg_batch = sample['target'].view(-1) loss = g_criterion(out_batch, train_trg_batch) sample_size = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] nsentences = sample['target'].size(0) logging_loss = loss.item() / sample_size / math.log(2) logging_meters['bsz'].update(nsentences) logging_meters['train_loss'].update(logging_loss, sample_size) logging.debug( "g loss at batch {0}: {1:.3f}, batch size: {2}, lr={3}".format( i, logging_meters['train_loss'].avg, round(logging_meters['bsz'].avg), optimizer.param_groups[0]['lr'])) optimizer.zero_grad() loss.backward() # all-reduce grads and rescale by grad_denom for p in generator.parameters(): if p.requires_grad: p.grad.data.div_(sample_size) torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm) optimizer.step() num_update = num_update + 1 if num_update % 5000 == 0: # validation -- this is a crude estimation because there might be some padding at the end max_positions_valid = ( generator.encoder.max_positions(), generator.decoder.max_positions(), ) # Initialize dataloader itr = dataset.eval_dataloader( 'valid', max_tokens=args.max_tokens, max_sentences=args.max_sentences, max_positions=max_positions_valid, skip_invalid_size_inputs_valid_test=args. 
skip_invalid_size_inputs_valid_test, descending= True, # largest batch first to warm the caching allocator shard_id=args.distributed_rank, num_shards=args.distributed_world_size, ) # set validation mode generator.eval() # reset meters for key, val in logging_meters.items(): if val is not None: val.reset() with torch.no_grad(): for i, sample in enumerate(itr): if use_cuda: # wrap input tensors in cuda tensors sample = utils.make_variable(sample, cuda=cuda) sys_out_batch = generator(sample) out_batch = sys_out_batch.contiguous().view( -1, sys_out_batch.size(-1)) val_trg_batch = sample['target'].view(-1) loss = g_criterion(out_batch, val_trg_batch) sample_size = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] loss = loss.item() / sample_size / math.log(2) logging_meters['valid_loss'].update(loss, sample_size) logging.debug( "g dev loss at batch {0}: {1:.3f}".format( i, logging_meters['valid_loss'].avg)) # update learning rate lr_scheduler.step(logging_meters['valid_loss'].avg) lr = optimizer.param_groups[0]['lr'] logging.info( "Average g loss value per instance is {0} at the end of epoch {1}" .format(logging_meters['valid_loss'].avg, epoch_i)) torch.save( generator, open( checkpoints_path + "numupdate{1}.data.nll_{0:.1f}.pt".format( num_update, logging_meters['valid_loss'].avg), 'wb')) # if logging_meters['valid_loss'].avg < best_dev_loss: # best_dev_loss = logging_meters['valid_loss'].avg # torch.save(generator.state_dict(), open( # checkpoints_path + "best_gmodel.pt", 'wb')) epoch_i += 1
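# The recurring `loss / sample_size / math.log(2)` normalization above converts
# a summed NLL (in nats, from NLLLoss with reduction='sum') into a per-token
# loss in bits when sample_size is the token count (args.sentence_avg off), so
# perplexity is simply 2 ** loss. A tiny self-contained illustration with toy
# tensors, not the real generator output:
import math
import torch

log_probs = torch.log_softmax(torch.randn(6, 10), dim=-1)  # 6 tokens, vocab of 10
targets = torch.randint(0, 10, (6,))
nll_sum = torch.nn.NLLLoss(reduction='sum')(log_probs, targets)

ntokens = targets.numel()
loss_bits_per_token = nll_sum.item() / ntokens / math.log(2)
perplexity = 2 ** loss_bits_per_token
print(f"{loss_bits_per_token:.3f} bits/token, ppl {perplexity:.2f}")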
def train_g(args, dataset): logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S', level=logging.DEBUG) use_cuda = (cuda.device_count() >= 1) # check checkpoints saving path if not os.path.exists('checkpoints/generator'): os.makedirs('checkpoints/generator') checkpoints_path = 'checkpoints/generator/' logging_meters = OrderedDict() logging_meters['train_loss'] = AverageMeter() logging_meters['valid_loss'] = AverageMeter() logging_meters['bsz'] = AverageMeter() # sentences per batch logging_meters['update_times'] = AverageMeter() # Build model generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda) if use_cuda: if len(args.gpuid) > 1: generator = torch.nn.DataParallel(generator).cuda() else: generator.cuda() else: generator.cpu() optimizer = eval("torch.optim." + args.optimizer)(generator.parameters(), args.learning_rate) lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, patience=0, factor=args.lr_shrink) # Train until the learning rate gets too small max_epoch = args.max_epoch or math.inf epoch_i = 1 best_dev_loss = math.inf lr = optimizer.param_groups[0]['lr'] # main training loop while lr > args.min_g_lr and epoch_i <= max_epoch: logging.info("At {0}-th epoch.".format(epoch_i)) seed = args.seed + epoch_i torch.manual_seed(seed) max_positions_train = (min(args.max_source_positions, generator.encoder.max_positions()), min(args.max_target_positions, generator.decoder.max_positions())) # Initialize dataloader, starting at batch_offset itr = dataset.train_dataloader( 'train', max_tokens=args.max_tokens, max_sentences=args.max_sentences, max_positions=max_positions_train, seed=seed, epoch=epoch_i, sample_without_replacement=args.sample_without_replacement, sort_by_source_size=(epoch_i <= args.curriculum), shard_id=args.distributed_rank, num_shards=args.distributed_world_size, ) # set training mode generator.train() # reset meters for key, val in logging_meters.items(): if val is not None: val.reset() for i, sample in enumerate(itr): if use_cuda: # wrap input tensors in cuda tensors sample = utils.make_variable(sample, cuda=cuda) loss = generator(sample) sample_size = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] nsentences = sample['target'].size(0) logging_loss = loss.item() / sample_size / math.log(2) logging_meters['bsz'].update(nsentences) logging_meters['train_loss'].update(logging_loss, sample_size) logging.debug( "g loss at batch {0}: {1:.3f}, batch size: {2}, lr={3}".format( i, logging_meters['train_loss'].avg, round(logging_meters['bsz'].avg), optimizer.param_groups[0]['lr'])) optimizer.zero_grad() loss.backward() # all-reduce grads and rescale by grad_denom for p in generator.parameters(): if p.requires_grad: p.grad.data.div_(sample_size) torch.nn.utils.clip_grad_norm(generator.parameters(), args.clip_norm) optimizer.step() # validation -- this is a crude estimation because there might be some padding at the end max_positions_valid = ( generator.encoder.max_positions(), generator.decoder.max_positions(), ) # Initialize dataloader itr = dataset.eval_dataloader( 'valid', max_tokens=args.max_tokens, max_sentences=args.max_sentences, max_positions=max_positions_valid, skip_invalid_size_inputs_valid_test=args. 
skip_invalid_size_inputs_valid_test, descending=True, # largest batch first to warm the caching allocator shard_id=args.distributed_rank, num_shards=args.distributed_world_size, ) # set validation mode generator.eval() # reset meters for key, val in logging_meters.items(): if val is not None: val.reset() with torch.no_grad(): for i, sample in enumerate(itr): if use_cuda: # wrap input tensors in cuda tensors sample = utils.make_variable(sample, cuda=cuda) loss = generator(sample) sample_size = sample['target'].size( 0) if args.sentence_avg else sample['ntokens'] loss = loss.item() / sample_size / math.log(2) logging_meters['valid_loss'].update(loss, sample_size) logging.debug("g dev loss at batch {0}: {1:.3f}".format( i, logging_meters['valid_loss'].avg)) # update learning rate lr_scheduler.step(logging_meters['valid_loss'].avg) lr = optimizer.param_groups[0]['lr'] logging.info( "Average g loss value per instance is {0} at the end of epoch {1}". format(logging_meters['valid_loss'].avg, epoch_i)) torch.save( generator.state_dict(), open( checkpoints_path + "data.nll_{0:.3f}.epoch_{1}.pt".format( logging_meters['valid_loss'].avg, epoch_i), 'wb')) if logging_meters['valid_loss'].avg < best_dev_loss: best_dev_loss = logging_meters['valid_loss'].avg torch.save(generator.state_dict(), open(checkpoints_path + "best_gmodel.pt", 'wb')) epoch_i += 1
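# Hedged sketch of the update used in train_g and the MLE branches above:
# gradients of a *summed* loss are divided by sample_size before clipping, so
# the step effectively optimizes the mean per-token loss while the summed loss
# is what gets logged. clip_grad_norm (no trailing underscore, as used in
# train_g) is the deprecated spelling of torch.nn.utils.clip_grad_norm_.
# A toy linear model stands in for the generator.
import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(4, 8)
loss = model(x).pow(2).sum()  # summed loss, as with reduction='sum'
sample_size = x.size(0)

optimizer.zero_grad()
loss.backward()
for p in model.parameters():
    if p.requires_grad and p.grad is not None:
        p.grad.data.div_(sample_size)  # rescale summed grads to a mean
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()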
def train(self, epoch, data_loader, opt_sn, opt_vn, mode, writer=None, print_freq=1):
    self.sn.train()
    self.vn.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses_sn = AverageMeter()
    losses_vn = AverageMeter()
    ious = AverageMeter()
    end = time.time()
    seg_pred = None  # only computed in 'vn' mode
    for i, inputs in enumerate(data_loader):
        data_time.update(time.time() - end)
        img, lbl = self._parse_data(inputs)

        # train sn (segmentation network)
        loss_sn, iou_, heat_map = self._forward_sn(img, lbl)
        losses_sn.update(loss_sn.item(), lbl.size(0))
        ious.update(iou_, lbl.size(0))
        if mode == 'sn':
            # if opt_sn is None:
            #     img.volatile = True
            #     lbl.volatile = True
            # else:
            #     img.volatile = False
            #     lbl.volatile = False
            self.step(opt_sn, loss_sn)
        # train vn (value network): regress the IoU of the current prediction
        elif mode == 'vn':
            # heat_map = heat_map.detach()
            _, seg_pred = torch.max(heat_map, dim=1, keepdim=True)
            # seg_pred = onehot(seg_pred, 2)
            # heat_map = heat_map
            target_iou = iou(heat_map.data, lbl.data, average=False)
            loss_vn, iou_pred = self._forward_vn(img, heat_map, target_iou)
            losses_vn.update(loss_vn.item(), lbl.size(0))
            self.step(opt_vn, loss_vn)

        # backprop & gradient step
        # if opt_sn is not None:
        #     self.step(opt_sn, loss_sn)
        # if opt_vn is not None:
        #     self.step(opt_vn, loss_vn)

        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % print_freq == 0:
            print('Epoch: [{}][{}/{}]\t'
                  'Time {:.3f} ({:.3f})\t'
                  'Data {:.3f} ({:.3f})\t'
                  'Loss_sn {:.3f} ({:.3f})\t'
                  'Loss_vn {:.3f} ({:.3f})\t'
                  'Prec {:.2%} ({:.2%})\t'.format(
                      epoch, i + 1, len(data_loader),
                      batch_time.val, batch_time.avg,
                      data_time.val, data_time.avg,
                      losses_sn.val, losses_sn.avg,
                      losses_vn.val, losses_vn.avg,
                      ious.val, ious.avg))

    # seg_pred only exists after a 'vn' pass; guard the summary writer call
    if writer is not None and seg_pred is not None:
        summary_output_lbl(seg_pred.data, lbl.data, writer, epoch)
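# Hedged sketch: the 'vn' branch above regresses the value network onto a
# per-sample IoU between the predicted heat map and the label via the project's
# iou() helper, which is not shown here. binary_iou below is an assumed,
# illustrative implementation of such a per-sample IoU, not the repository's
# actual function.
import torch


def binary_iou(pred_mask, target_mask, eps=1e-6):
    """Per-sample IoU for binary masks of shape (N, H, W) or (N, 1, H, W)."""
    pred = pred_mask.float().flatten(1)
    target = target_mask.float().flatten(1)
    inter = (pred * target).sum(dim=1)
    union = pred.sum(dim=1) + target.sum(dim=1) - inter
    return (inter + eps) / (union + eps)


heat_map = torch.randn(2, 2, 8, 8)        # (N, C=2, H, W) toy prediction
lbl = torch.randint(0, 2, (2, 8, 8))      # (N, H, W) toy labels
_, seg_pred = torch.max(heat_map, dim=1)  # predicted class ids, as in the loop above
target_iou = binary_iou(seg_pred == 1, lbl == 1)  # one IoU value per sample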
def train_model(output_path, model, dataloaders, dataset_sizes, criterion, optimizer, num_epochs=5, scheduler=None): if not os.path.exists('iterations/' + str(output_path) + '/saved'): os.makedirs('iterations/' + str(output_path) + '/saved') device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") losses = AverageMeter() accuracies = AverageMeter() all_preds = [] all_labels = [] val_auc_all = [] val_acc_all = [] test_auc_all = [] test_acc_all = [] TPFPFN0_all = [] TPFPFN1_all = [] best_val_auc = 0.0 best_epoch = 0 for epoch in range(1, num_epochs + 1): print('-' * 50) print('Epoch {}/{}'.format(epoch, num_epochs)) for phase in ['train', 'val']: if phase == 'train': model.train() # Set model to training mode else: model.eval() # Set model to evaluate mode # tqdm_loader = tqdm(dataloaders[phase]) # for data in tqdm_loader: # inputs, labels = data for i, (inputs, labels) in enumerate(dataloaders[phase]): inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() # with torch.set_grad_enabled(True): outputs = model(inputs) _, preds = torch.max(outputs.data, 1) labels_onehot = torch.nn.functional.one_hot(labels, num_classes=2) labels_onehot = labels_onehot.type(torch.FloatTensor) # BCEloss = torch.nn.functional.binary_cross_entropy_with_logits(outputs.cpu(), labels_onehot, torch.FloatTensor([1.0, 1.0])) BCEloss = criterion(outputs.cpu(), labels_onehot) # print("BCEloss", BCEloss) BCEloss_rank = binary_crossentropy_with_ranking( outputs, labels_onehot) # print("BCEloss_rank", BCEloss_rank) # BCEloss_rank.requires_grad = True loss = BCEloss + 0 * BCEloss_rank # print("BCEloss, BCEloss_rank", BCEloss, BCEloss_rank) # loss = (BCEloss_rank + 1) * BCEloss loss.backward() optimizer.step() losses.update(loss.item(), inputs.size(0)) acc = float(torch.sum(preds == labels.data)) / preds.shape[0] accuracies.update(acc) all_preds += list( torch.nn.functional.softmax(outputs, dim=1)[:, 1].cpu().data.numpy()) all_labels += list(labels.cpu().data.numpy()) # tqdm_loader.set_postfix(loss=losses.avg, acc=accuracies.avg) auc = roc_auc_score(all_labels, all_preds) if phase == 'train': auc_t = auc loss_t = losses.avg acc_t = accuracies.avg if phase == 'val': auc_v = auc loss_v = losses.avg acc_v = accuracies.avg val_acc_all.append(acc_v) val_auc_all.append(auc_v) print('Train AUC: {:.8f} Loss: {:.8f} ACC: {:.8f} '.format( auc_t, loss_t, acc_t)) print('Val AUC: {:.8f} Loss: {:.8f} ACC: {:.8f} '.format( auc_v, loss_v, acc_v)) if auc_v > best_val_auc: best_val_auc = auc_v best_epoch = epoch # print(auc_v, best_val_auc) # print(best_epoch) best_model = copy.deepcopy(model) torch.save( model.module, './iterations/' + str(output_path) + '/saved/model_{}_epoch.pt'.format(epoch)) # ############################################################################################################# Test for phase in ['test']: model.eval() # Set model to evaluate mode for i, (inputs, labels) in enumerate(dataloaders[phase]): inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() with torch.set_grad_enabled(False): outputs = model(inputs) _, preds = torch.max(outputs.data, 1) acc = float(torch.sum(preds == labels.data)) / preds.shape[0] accuracies.update(acc) all_preds += list( torch.nn.functional.softmax(outputs, dim=1)[:, 1].cpu().data.numpy()) all_labels += list(labels.cpu().data.numpy()) # tqdm_loader.set_postfix(loss=losses.avg, acc=accuracies.avg) auc = roc_auc_score(all_labels, all_preds) auc_test = auc loss_test = losses.avg acc_test = accuracies.avg 
test_acc_all.append(acc_test) test_auc_all.append(auc_test) print('Test AUC: {:.8f} Loss: {:.8f} ACC: {:.8f} '.format( auc_test, loss_test, acc_test)) nb_classes = 2 confusion_matrix = torch.zeros(nb_classes, nb_classes) with torch.no_grad(): TrueP0 = 0 FalseP0 = 0 FalseN0 = 0 TrueP1 = 0 FalseP1 = 0 FalseN1 = 0 for i, (inputs, classes) in enumerate(dataloaders[phase]): confusion_matrix = torch.zeros(nb_classes, nb_classes) input = inputs.to(device) target = classes.to(device) outputs = model(input) _, preds = torch.max(outputs, 1) for t, p in zip(target.view(-1), preds.view(-1)): confusion_matrix[t, p] += 1 this_class = 0 col = confusion_matrix[:, this_class] row = confusion_matrix[this_class, :] TP = row[this_class] FN = sum(row) - TP FP = sum(col) - TP # print("TP, FP, FN: ", TP, FP, FN) TrueP0 = TrueP0 + TP FalseP0 = FalseP0 + FP FalseN0 = FalseN0 + FN this_class = 1 col = confusion_matrix[:, this_class] row = confusion_matrix[this_class, :] TP = row[this_class] FN = sum(row) - TP FP = sum(col) - TP # print("TP, FP, FN: ", TP, FP, FN) TrueP1 = TrueP1 + TP FalseP1 = FalseP1 + FP FalseN1 = FalseN1 + FN TPFPFN0 = [TrueP0, FalseP0, FalseN0] TPFPFN1 = [TrueP1, FalseP1, FalseN1] TPFPFN0_all.append(TPFPFN0) TPFPFN1_all.append(TPFPFN1) print("overall_TP, FP, FN for 0: ", TrueP0, FalseP0, FalseN0) print("overall_TP, FP, FN for 1: ", TrueP1, FalseP1, FalseN1) print("best_ValidationEpoch:", best_epoch) # print(TPFPFN0_all, val_auc_all, test_auc_all) TPFPFN0_best = TPFPFN0_all[best_epoch - 1][0] TPFPFN1_best = TPFPFN1_all[best_epoch - 1][0] val_auc_best = val_auc_all[best_epoch - 1] val_acc_best = val_acc_all[best_epoch - 1] test_auc_best = test_auc_all[best_epoch - 1] test_acc_best = test_acc_all[best_epoch - 1] # #################### save only the best, delete others file_path = './iterations/' + str(output_path) + '/saved/model_' + str( best_epoch) + '_epoch.pt' if os.path.isfile(file_path): for CleanUp in glob.glob('./iterations/' + str(output_path) + '/saved/*.pt'): if 'model_' + str(best_epoch) + '_epoch.pt' not in CleanUp: os.remove(CleanUp) # # ###################################################### return best_epoch, best_model, TPFPFN0_all[best_epoch - 1], TPFPFN1_all[ best_epoch - 1], test_acc_best, test_auc_best # def binary_crossentropy_with_ranking(y_true, y_pred): # """ Trying to combine ranking loss with numeric precision""" # # first get the log loss like normal # logloss = K.mean(K.binary_crossentropy(y_pred, y_true), axis=-1) # # # next, build a rank loss # # # clip the probabilities to keep stability # y_pred_clipped = K.clip(y_pred, K.epsilon(), 1 - K.epsilon()) # # # translate into the raw scores before the logit # y_pred_score = K.log(y_pred_clipped / (1 - y_pred_clipped)) # # # determine what the maximum score for a zero outcome is # y_pred_score_zerooutcome_max = K.max(y_pred_score * (y_true < 1)) # # # determine how much each score is above or below it # rankloss = y_pred_score - y_pred_score_zerooutcome_max # # # only keep losses for positive outcomes # rankloss = rankloss * y_true # # # only keep losses where the score is below the max # rankloss = K.square(K.clip(rankloss, -100, 0)) # # # average the loss for just the positive outcomes # rankloss = K.sum(rankloss, axis=-1) / (K.sum(y_true > 0) + 1) # # # return (rankloss + 1) * logloss - an alternative to try # return rankloss + logloss
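# The confusion-matrix bookkeeping in train_model above extracts TP/FP/FN per
# class by slicing the matrix row (true class) and column (predicted class).
# A compact, self-contained version of the same arithmetic on toy predictions:
import torch


def per_class_tp_fp_fn(targets, preds, nb_classes=2):
    cm = torch.zeros(nb_classes, nb_classes)
    for t, p in zip(targets.view(-1), preds.view(-1)):
        cm[t, p] += 1
    stats = {}
    for c in range(nb_classes):
        tp = cm[c, c]
        fn = cm[c, :].sum() - tp  # true class c, predicted as something else
        fp = cm[:, c].sum() - tp  # other classes predicted as c
        stats[c] = (tp.item(), fp.item(), fn.item())
    return stats


targets = torch.tensor([0, 0, 1, 1, 1, 0])
preds = torch.tensor([0, 1, 1, 0, 1, 0])
print(per_class_tp_fp_fn(targets, preds))  # {0: (2.0, 1.0, 1.0), 1: (2.0, 1.0, 1.0)}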