# Legacy (pre-0.4 PyTorch) layer-wise training routine; superseded by the
# logger-based train() below, which uses .item() instead of .data[0].
def train(trainloader, net, index, optimizer, epoch, use_cuda):
    losses = AverageMeter()
    print('\nIndex: %d \t Epoch: %d' % (index, epoch))
    net.train()
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs = inputs.cuda()
        optimizer.zero_grad()
        inputs_Var = Variable(inputs)
        # net returns the scalar loss for sub-network `index`
        outputs = net(inputs_Var, index)
        # record loss
        losses.update(outputs.data[0], inputs.size(0))
        outputs.backward()
        optimizer.step()
    print('train_loss_{}'.format(index), losses.avg, epoch)
    # log to TensorBoard
    if args.tensorboard:
        log_value('train_loss_{}'.format(index), losses.avg, epoch)
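# All routines in this file rely on an AverageMeter helper that is never
# shown. A minimal sketch, assuming it follows the standard pattern from the
# official PyTorch ImageNet example; the repo's own class may differ.
class AverageMeter(object):
    """Tracks the latest value, running sum, count, and average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count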
def train(trainloader, net, index, optimizer, epoch, use_cuda, logger):
    losses = AverageMeter()
    print('\nIndex: %d \t Epoch: %d' % (index, epoch))
    net.train()
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs = inputs.cuda()
        optimizer.zero_grad()
        inputs_Var = Variable(inputs)
        outputs = net(inputs_Var, index)
        # record loss
        losses.update(outputs.item(), inputs.size(0))
        outputs.backward()
        '''
        # gradient clipping for mlp arch, with a NaN check on the weights
        for name, param in net.named_parameters():
            if 'enc' in name:
                torch.nn.utils.clip_grad_norm_(param, param.mean(dtype=float))
                if torch.isnan(param.view(-1)).any():
                    print('here')
                    exit(0)
        '''
        '''
        # gradient clipping for convolutional arch, with a NaN check on the weights
        for name, param in net.named_parameters():
            # if ('enc' in name and 'benc' not in name) or ('dec' in name and 'bdec' not in name):
            #     torch.nn.utils.clip_grad_norm_(param, param.mean(dtype=float))
            if torch.isnan(param.view(-1)).any():
                print('here')
                exit(0)
        '''
        optimizer.step()
    # log to TensorBoard
    if logger:
        logger.log_value('train_loss_{}'.format(index), losses.avg, epoch)
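# The commented-out experiments above clip each parameter's gradient
# separately, using that parameter's own mean as the threshold. The more
# common idiom clips the global norm of all gradients in one call, between
# backward() and optimizer.step(). A sketch assuming the same names as
# train() above; max_norm=1.0 is a placeholder value, not from the original.
#
#     outputs.backward()
#     torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
#     optimizer.step()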
# Legacy (pre-0.4 PyTorch) evaluation routine; volatile=True was the old way
# to disable autograd and is superseded by torch.no_grad() in the updated
# test() below.
def test(testloader, net, index, epoch, use_cuda):
    losses = AverageMeter()
    net.eval()
    for batch_idx, (inputs, targets) in enumerate(testloader):
        if use_cuda:
            inputs = inputs.cuda()
        inputs_Var = Variable(inputs, volatile=True)
        outputs = net(inputs_Var, index)
        # measure accuracy and record loss
        losses.update(outputs.data[0], inputs.size(0))
    # log to TensorBoard
    if args.tensorboard:
        log_value('val_loss_{}'.format(index), losses.avg, epoch)
def test(testloader, net, index, epoch, use_cuda, logger):
    losses = AverageMeter()
    net.eval()
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            if use_cuda:
                inputs = inputs.cuda()
            inputs_Var = Variable(inputs)
            outputs = net(inputs_Var, index)
            # measure accuracy and record loss
            losses.update(outputs.item(), inputs.size(0))
    # log to TensorBoard
    if logger:
        logger.log_value('val_loss_{}'.format(index), losses.avg, epoch)
def train(trainloader, net, index, optimizer, epoch, use_cuda, logger):
    losses = AverageMeter()
    print('\nIndex: %d \t Epoch: %d' % (index, epoch))
    net.train()
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs = inputs.cuda()
        optimizer.zero_grad()
        inputs_Var = Variable(inputs)
        outputs = net(inputs_Var, index)
        # record loss
        losses.update(outputs.item(), inputs.size(0))
        outputs.backward()
        optimizer.step()
    # log to TensorBoard
    if logger:
        logger.log_value('train_loss_{}'.format(index), losses.avg, epoch)
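# A sketch of how the logger-based train()/test() pair might be driven. The
# Logger class comes from the tensorboard_logger package these calls imply;
# the layer count, learning rate, and epoch count are placeholders, not
# values from the original.
from tensorboard_logger import Logger

use_cuda = torch.cuda.is_available()
logger = Logger('runs/pretrain')  # logger.log_value(name, value, step)

for index in range(num_layers):  # hypothetical layer-wise schedule
    optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
    for epoch in range(50):
        train(trainloader, net, index, optimizer, epoch, use_cuda, logger)
        test(testloader, net, index, epoch, use_cuda, logger)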
def train(trainloader, net, optimizer, criterion1, criterion2, epoch, use_cuda,
          _sigma1, _sigma2, _lambda, logger):
    losses = AverageMeter()
    losses1 = AverageMeter()
    losses2 = AverageMeter()
    print('\n Epoch: %d' % epoch)
    net.train()
    for batch_idx, (inputs, pairweights, sampweights, pairs, index) in enumerate(trainloader):
        inputs = torch.squeeze(inputs, 0)
        pairweights = torch.squeeze(pairweights)
        sampweights = torch.squeeze(sampweights)
        index = torch.squeeze(index)
        pairs = pairs.view(-1, 2)
        index = index.long()
        pairs = pairs.long()
        if use_cuda:
            inputs = inputs.cuda()
            pairweights = pairweights.cuda()
            sampweights = sampweights.cuda()
            index = index.cuda()
            pairs = pairs.cuda()
        optimizer.zero_grad()
        inputs_Var = Variable(inputs)
        sampweights = Variable(sampweights, requires_grad=False)
        pairweights = Variable(pairweights, requires_grad=False)
        enc, dec = net(inputs_Var)
        loss1 = criterion1(inputs_Var, dec, sampweights)  # reconstruction term
        loss2 = criterion2(enc, sampweights, pairweights, pairs, index,
                           _sigma1, _sigma2, _lambda)     # clustering (DCC) term
        loss = loss1 + loss2
        # record loss
        losses1.update(loss1.item(), inputs.size(0))
        losses2.update(loss2.item(), inputs.size(0))
        losses.update(loss.item(), inputs.size(0))
        loss.backward()
        optimizer.step()
    # log to TensorBoard
    if logger:
        logger.log_value('total_loss', losses.avg, epoch)
        logger.log_value('reconstruction_loss', losses1.avg, epoch)
        logger.log_value('dcc_loss', losses2.avg, epoch)
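# criterion1 and criterion2 are passed in, so their definitions are not shown
# here. A plausible shape for the reconstruction term, matching the
# criterion1(inputs, dec, sampweights) call above; this is an illustrative
# assumption, not the repo's actual loss.
import torch.nn as nn

class WeightedReconstructionLoss(nn.Module):
    """Sample-weighted mean-squared reconstruction error (illustrative)."""

    def forward(self, inputs, reconstruction, sampweights):
        # squared error per sample, averaged over feature dimensions
        per_sample = ((inputs - reconstruction) ** 2).flatten(1).mean(dim=1)
        # weight each sample's error, normalized by the total weight
        return (sampweights * per_sample).sum() / sampweights.sum()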
def train_one_epoch(dest_aug_mask_perm_dataloader, traj_encoder, dest_proj,
                    aug_proj, mapemb_proj, mask_proj, perm_proj, optimizer,
                    scheduler, criterion_ce, graphregion, config, log_f,
                    log_error):
    traj_encoder.train()
    dest_proj.train()
    aug_proj.train()
    mask_proj.train()
    perm_proj.train()
    losses = AverageMeter()
    losses_dest = AverageMeter()
    losses_aug = AverageMeter()
    losses_mapemb = AverageMeter()
    losses_mask = AverageMeter()
    losses_perm = AverageMeter()
    train_runs = 0
    sample_cnt = 0
    losses_hist = []
    losses_dest_hist = []
    losses_aug_hist = []
    losses_mapemb_hist = []
    losses_mask_hist = []
    losses_perm_hist = []
    while True:
        train_runs += 1
        # re-init the joint loss to zero every iteration
        loss = 0.
        # initialized here so the debug print below cannot hit a NameError
        # when the augmentation task is skipped
        is_mapemb = False
        try:
            train_batch = next(dest_aug_mask_perm_dataloader)
        except StopIteration:
            log_f.write("All dataloader ran out, finishing {}-th epoch's "
                        "training. \n".format(config.epoch))
            print("All dataloader ran out, finishing {}-th epoch's "
                  "training. \n".format(config.epoch))
            break
        if train_batch is None:  # all filtered out: length < 10 or [-1]
            train_runs -= 1
            continue
        batch_dest, batch_aug, batch_mask, batch_perm = train_batch

        ####### Destination ###############################################
        if 'dest' in config.del_tasks:
            loss_destination = None
        else:
            try:
                loss_destination, out_tm, h_t, w_uh_t, negs, neg_term = \
                    compute_destination_loss(batch_dest, traj_encoder,
                                             dest_proj, graphregion, config)
                loss_destination = config.loss_dest_weight * loss_destination
                loss += loss_destination
            except Exception:
                traceback.print_exc()
                log_error.write(traceback.format_exc())
                loss_destination = None
                if batch_dest is not None:
                    # batch size 1: skip the iteration
                    if batch_dest.traj_len.size(0) == 1:
                        train_runs -= 1
                        continue

        ####### Augmentation / map embedding ##############################
        if ('aug' in config.del_tasks) and ('mapemb' in config.del_tasks):
            loss_augmentation = None
            loss_mapemb = None
        else:
            try:
                left_aug, right_aug = batch_aug
                # map-embedding loss is only computed for long trajectories
                is_mapemb = left_aug.traj_len.min() >= 40
                loss_augmentation, loss_mapemb = compute_aug_loss(
                    left_aug, right_aug, traj_encoder, aug_proj, mapemb_proj,
                    graphregion, config, criterion_ce, is_mapemb)
                if 'aug' in config.del_tasks:  # only count loss_mapemb
                    loss_augmentation = None
                    if is_mapemb:  # loss_mapemb is not None
                        loss += loss_mapemb
                if 'mapemb' in config.del_tasks:  # only count loss_augmentation
                    loss_mapemb = None
                    loss += loss_augmentation
                if ('aug' not in config.del_tasks) and \
                        ('mapemb' not in config.del_tasks):
                    loss += loss_augmentation
                    if is_mapemb:  # loss_mapemb is not None
                        loss += loss_mapemb
            except Exception:
                traceback.print_exc()
                log_error.write(traceback.format_exc())
                loss_augmentation = None
                loss_mapemb = None

        ####### Mask ######################################################
        if 'mask' in config.del_tasks:
            loss_mask = None
        else:
            try:
                loss_mask, batch_queries, h_t, w_uh_t, _neg_term, neg_term = \
                    compute_mask_loss(batch_mask, traj_encoder, mask_proj,
                                      graphregion, config)
                loss_mask = config.loss_mask_weight * loss_mask
                loss += loss_mask
            except Exception:
                traceback.print_exc()
                log_error.write(traceback.format_exc())
                loss_mask = None
        ####### Perm ######################################################
        if 'perm' in config.del_tasks:
            loss_perm = None
        else:
            try:
                anchor, pos, neg = batch_perm
                loss_perm, logits_perm, target_perm = compute_perm_loss(
                    anchor, pos, neg, traj_encoder, perm_proj, graphregion,
                    config, criterion_ce)
                loss_perm = config.loss_perm_weight * loss_perm
                loss += loss_perm
            except Exception:
                traceback.print_exc()
                log_error.write(traceback.format_exc())
                loss_perm = None

        # every task either failed or was deleted: nothing to optimize
        if (loss_destination is None) and (loss_augmentation is None) and \
                (loss_mapemb is None) and (loss_mask is None) and \
                (loss_perm is None):
            if ('dest' in config.del_tasks) and ('perm' in config.del_tasks) and \
                    ('aug' in config.del_tasks) and ('mask' in config.del_tasks):
                # model_with_mapemb
                train_runs -= 1
                continue
            else:
                log_f.write("All loss none, at {}-th epoch's training: check "
                            "errordata_e{}_step{}.pkl \n".format(
                                config.epoch, config.epoch, train_runs))
                print("All loss none, at {}-th epoch's training: check "
                      "errordata_e{}_step{}.pkl \n".format(
                          config.epoch, config.epoch, train_runs))
                pickle.dump(
                    (batch_dest, batch_aug, batch_mask, batch_perm),
                    open('errordata_e{}_step{}.pkl'.format(config.epoch,
                                                           train_runs), 'wb'))
                train_runs -= 1
                continue

        # update the meters; sample_cnt is taken from the first task that
        # contributed a loss this iteration
        sample_cnted = False
        try:
            losses.update(loss.item())
        except Exception:
            print(loss, loss_destination, is_mapemb, loss_augmentation,
                  loss_mapemb, loss_mask, loss_perm)
        if loss_destination is not None:
            losses_dest.update(loss_destination.item())
            sample_cnt += batch_dest.tm_len.size(0)
            sample_cnted = True
        if loss_augmentation is not None:
            losses_aug.update(loss_augmentation.item())
            if not sample_cnted:
                sample_cnt += left_aug.tm_len.size(0)
                sample_cnted = True
        if loss_mapemb is not None:
            losses_mapemb.update(loss_mapemb.item())
            if not sample_cnted:
                sample_cnt += left_aug.tm_len.size(0)
                sample_cnted = True
        if loss_mask is not None:
            losses_mask.update(loss_mask.item())
            if not sample_cnted:
                sample_cnt += batch_mask.tm_len.size(0)
                sample_cnted = True
        if loss_perm is not None:
            losses_perm.update(loss_perm.item())
            if not sample_cnted:
                sample_cnt += pos.tm_len.size(0)
                sample_cnted = True

        if train_runs % 100 == 0:
            if 'perm' not in config.del_tasks:
                acc = torch.sum(
                    logits_perm.max(1)[1].cpu() == target_perm.view(-1).cpu()
                ).to(torch.float32) / target_perm.size(0)
                print("acc : {:.2f}".format(acc))
            losses_hist.append(losses.val)
            losses_dest_hist.append(losses_dest.val)
            losses_aug_hist.append(losses_aug.val)
            losses_mapemb_hist.append(losses_mapemb.val)
            losses_mask_hist.append(losses_mask.val)
            losses_perm_hist.append(losses_perm.val)
            log_f.write('Train Epoch:{} approx. [{}/{}] '
                        'total_loss:{:.2f}({:.2f})\n'.format(
                            config.epoch, sample_cnt, config.n_trains,
                            losses.val, losses.avg))
            log_f.write('loss_destination:{:.2f}({:.2f}) \n'
                        'loss_augmentation:{:.2f}({:.2f}) \n'
                        'loss_mapemb:{:.2f}({:.2f}) \n'
                        'loss_mask:{:.2f}({:.2f}) \n'
                        'loss_perm:{:.2f}({:.2f}) \n\n'.format(
                            losses_dest.val, losses_dest.avg,
                            losses_aug.val, losses_aug.avg,
                            losses_mapemb.val, losses_mapemb.avg,
                            losses_mask.val, losses_mask.avg,
                            losses_perm.val, losses_perm.avg))
            print('Train Epoch:{} approx. [{}/{}] '
                  'total_loss:{:.2f}({:.2f})'.format(
                      config.epoch, sample_cnt, config.n_trains,
                      losses.val, losses.avg))
            print('loss_destination:{:.2f}({:.2f}) \n'
                  'loss_augmentation:{:.2f}({:.2f}) \n'
                  'loss_mapemb:{:.2f}({:.2f}) \n'
                  'loss_mask:{:.2f}({:.2f}) \n'
                  'loss_perm:{:.2f}({:.2f}) \n'.format(
                      losses_dest.val, losses_dest.avg,
                      losses_aug.val, losses_aug.avg,
                      losses_mapemb.val, losses_mapemb.avg,
                      losses_mask.val, losses_mask.avg,
                      losses_perm.val, losses_perm.avg))
            log_f.flush()
            log_error.flush()

        if train_runs % 4500 == 0:
            # periodic checkpoint; the message now matches the saved filename
            ckpt_name = (config.name + '_num_hid_layer_' +
                         str(config.num_hidden_layers) +
                         '_step{}'.format(train_runs))
            log_f.write("At step {}, save model {}.pt\n".format(train_runs,
                                                                ckpt_name))
            print("At step {}, save model {}.pt\n".format(train_runs,
                                                          ckpt_name))
            models_dict = {
                traj_encoder.__class__.__name__: traj_encoder.state_dict(),
                mask_proj.__class__.__name__: mask_proj.state_dict(),
                perm_proj.__class__.__name__: perm_proj.state_dict(),
                aug_proj.__class__.__name__: aug_proj.state_dict(),
                mapemb_proj.__class__.__name__: mapemb_proj.state_dict(),
                dest_proj.__class__.__name__: dest_proj.state_dict(),
            }
            torch.save(models_dict,
                       os.path.join('models', ckpt_name + '.pt'))

        optimizer.zero_grad()
        loss.backward()  # backpropagate the joint loss every iteration
        optimizer.step()

    # save the loss histories once the epoch's iterator is exhausted
    torch.save((losses_hist, losses_dest_hist, losses_aug_hist,
                losses_mapemb_hist, losses_mask_hist, losses_perm_hist),
               os.path.join('train_hist',
                            config.name + '_loss_hist' +
                            '_hidlayer_' + str(config.num_hidden_layers) +
                            'e' + str(config.epoch) + '.pt'))
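# train_one_epoch pulls batches with next(...) rather than iterating with a
# for-loop, so the caller must pass an iterator that raises StopIteration at
# the end of the epoch. A minimal driver sketch; build_joint_dataloader and
# the surrounding names are placeholders, not from the original.
for epoch in range(num_epochs):
    config.epoch = epoch
    batch_iter = iter(build_joint_dataloader(config))  # hypothetical helper
    train_one_epoch(batch_iter, traj_encoder, dest_proj, aug_proj, mapemb_proj,
                    mask_proj, perm_proj, optimizer, scheduler, criterion_ce,
                    graphregion, config, log_f, log_error)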
def train_step_2(trainloader, net_s, net_z, net_d, optimizer_zc, optimizer_d,
                 criterion_rec, criterion_zc, criterion_d, epoch, use_cuda,
                 _sigma1, _sigma2, _lambda):
    losses = AverageMeter()
    losses1 = AverageMeter()
    losses2 = AverageMeter()
    losses_d_rec = AverageMeter()
    losses_d = AverageMeter()
    print('\n Epoch: %d' % epoch)
    net_z.train()
    net_d.train()
    decoder_loss = 0.0
    adversarial_loss = 0.0
    for i, (inputs, pairweights, sampweights, pairs, index) in enumerate(trainloader):
        inputs = torch.squeeze(inputs, 0)
        pairweights = torch.squeeze(pairweights)
        sampweights = torch.squeeze(sampweights)
        index = torch.squeeze(index)
        pairs = pairs.view(-1, 2)
        if use_cuda:
            inputs = inputs.cuda()
            pairweights = pairweights.cuda()
            sampweights = sampweights.cuda()
            index = index.cuda()
            pairs = pairs.cuda()
        inputs = Variable(inputs)
        sampweights = Variable(sampweights, requires_grad=False)
        pairweights = Variable(pairweights, requires_grad=False)
        # train z encoder and decoder on every third batch
        if i % 3 == 0:
            # zero the parameter gradients
            optimizer_d.zero_grad()
            optimizer_zc.zero_grad()
            # forward + backward + optimize
            outputs_s, _ = net_s(inputs)
            outputs_z, dec_z = net_z(inputs)
            loss1 = criterion_rec(inputs, dec_z, sampweights)
            loss2 = criterion_zc(outputs_z, sampweights, pairweights, pairs,
                                 index, _sigma1, _sigma2, _lambda)
            loss_zc = loss1 + loss2
            # record loss; .item() replaces the deprecated .data[0]
            losses1.update(loss1.item(), inputs.size(0))
            losses2.update(loss2.item(), inputs.size(0))
            losses.update(loss_zc.item(), inputs.size(0))
            decoder_input = torch.cat((outputs_s, outputs_z), 1)
            outputs_d = net_d(decoder_input)
            beta = 1.99  # reconstruction/clustering trade-off; may need tuning
            loss_d_rec = criterion_d(outputs_d, inputs)
            # the decoder minimizes reconstruction while opposing the
            # clustering objective
            loss_d = loss_d_rec - beta * loss_zc
            # record loss
            losses_d_rec.update(loss_d_rec.item(), inputs.size(0))
            losses_d.update(loss_d.item(), inputs.size(0))
            loss_d.backward()
            optimizer_d.step()
            optimizer_zc.step()
            decoder_loss += loss_d.item()
            print('dcc_reconstruction_loss', losses1.avg, epoch)
            print('dcc_clustering_loss', losses2.avg, epoch)
            print('dcc_loss', losses.avg, epoch)
            print('total_reconstruction_loss', losses_d_rec.avg, epoch)
            print('total_loss', losses_d.avg, epoch)
            # log to TensorBoard
            if args.tensorboard:
                log_value('dcc_reconstruction_loss', losses1.avg, epoch)
                log_value('dcc_clustering_loss', losses2.avg, epoch)
                log_value('dcc_loss', losses.avg, epoch)
                log_value('total_reconstruction_loss', losses_d_rec.avg, epoch)
                log_value('total_loss', losses_d.avg, epoch)
        # train adversarial clustering on the remaining batches
        else:
            # zero the parameter gradients
            optimizer_zc.zero_grad()
            # forward + backward + optimize
            outputs_z, dec_z = net_z(inputs)
            loss1 = criterion_rec(inputs, dec_z, sampweights)
            loss2 = criterion_zc(outputs_z, sampweights, pairweights, pairs,
                                 index, _sigma1, _sigma2, _lambda)
            loss_zc = loss1 + loss2
            # record loss
            losses1.update(loss1.item(), inputs.size(0))
            losses2.update(loss2.item(), inputs.size(0))
            losses.update(loss_zc.item(), inputs.size(0))
            loss_zc.backward()
            optimizer_zc.step()
            adversarial_loss += loss_zc.item()
        # print statistics every 2000 mini-batches
        if i % 2000 == 1999:
            print('[%d, %5d] decoder loss: %.3f, adversarial loss: %.3f'
                  % (epoch + 1, i + 1, decoder_loss / 500,
                     adversarial_loss / 1500))
            decoder_loss = 0.0
            adversarial_loss = 0.0