Official TensorFlow implementation of Dense Transformer Networks


This is the TensorFlow implementation of our recent work, “Dense Transformer Networks”. Please check the paper for details.

In this work, we propose Dense Transformer Networks to apply spatial transformations to semantic prediction tasks. Dense Transformer Networks can extract features from irregular areas, whose shapes and sizes are determined by the data. At the same time, Dense Transformer Networks provide a method that efficiently restores spatial relations.

If using this code, please cite our paper.

Experimental results:

We perform our experiment on two datasets to compare the baseline U-Net model and the proposed DTN model.

Sample segmentation results on the PASCAL 2012 segmentation data set. The first and second rows are the original images and the corresponding ground truth, respectively. The third and fourth rows are the segmentation results of U-Net and DTN, respectively.

All network hyperparameters are configured in main.py.

max_epoch: the total number of training iterations (steps)

test_step: interval, in steps, at which to run a mini test or validation

save_step: interval, in steps, at which to save the model

summary_step: interval, in steps, at which to save the summary

keep_prob: dropout probability

valid_start_epoch: the first step at which to test a model

valid_end_epoch: the last step at which to test a model

valid_stride_of_epoch: the step interval between tested models

data_dir: data directory

train_data: h5 file for training

valid_data: h5 file for validation

test_data: h5 file for testing

batch: batch size

channel: input image channel number

height, width: height and width of input image

logdir: where to store logs

modeldir: where to store saved models

sampledir: where to store predicted samples; please add a / at the end for convenience

model_name: the name prefix of saved models

reload_epoch: the step from which to resume training

test_epoch: which step to test or predict

random_seed: random seed for tensorflow

network_depth: the depth of the U-Net, including the bottom layer

class_num: the number of output classes, usually the number of foreground classes plus one for background

start_channel_num: the number of channels in the first convolutional layer

conv_name: which convolutional layer to use in the decoder: conv2d for the standard convolutional layer, or ipixel_cl for the input pixel convolutional layer proposed in our paper.

deconv_name: which upsampling layer to use in the decoder: deconv for the standard deconvolutional layer, ipixel_dcl for the input pixel deconvolutional layer, or pixel_dcl for the pixel deconvolutional layer proposed in our paper.

add_dtn: whether to add the Dense Transformer Networks.

dtn_location: where in the network to insert the Dense Transformer Networks.

control_points_ratio: the ratio of the number of control points to the Dense Transformer Networks input size.
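As a rough illustration, the hyperparameters above could be collected in a single configure function. This is a minimal sketch, not the repo's actual main.py (which may use TensorFlow flags instead of a dict); every default value shown here is a hypothetical example:

```python
def configure():
    """Sketch of a configuration holder for the hyperparameters listed
    above. All defaults are hypothetical examples, not the repo's."""
    conf = {
        # training schedule
        'max_epoch': 20000,         # total number of training steps
        'test_step': 100,           # run a mini validation every N steps
        'save_step': 1000,          # save the model every N steps
        'summary_step': 100,        # write a summary every N steps
        'keep_prob': 0.9,           # dropout keep probability
        # validation sweep
        'valid_start_epoch': 1000,
        'valid_end_epoch': 20000,
        'valid_stride_of_epoch': 1000,
        # data
        'data_dir': 'data/',
        'train_data': 'train.h5',
        'valid_data': 'valid.h5',
        'test_data': 'test.h5',
        'batch': 4,                 # batch size
        'channel': 3,               # input image channels
        'height': 256,
        'width': 256,
        # model
        'network_depth': 5,         # U-Net depth, including the bottom layer
        'class_num': 2,             # foreground classes + 1 for background
        'start_channel_num': 64,    # channels in the first conv layer
        'conv_name': 'conv2d',      # or 'ipixel_cl'
        'deconv_name': 'deconv',    # or 'ipixel_dcl' / 'pixel_dcl'
        # Dense Transformer Networks
        'add_dtn': True,
        'dtn_location': 3,          # hypothetical: which block to insert at
        'control_points_ratio': 2,
    }
    return conf
```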

After configuring the network, we can start training. Run

The training of a U-Net for semantic segmentation will start.

We employ TensorBoard to visualize the training process.

The segmentation results, including training and validation accuracies, and the prediction outputs are all available in TensorBoard.

Select a good point to test your model based on the validation results.

Fill in valid_start_epoch, valid_end_epoch, and valid_stride_of_epoch in the configure function. Then run

It will show the accuracy, loss, and mean_iou at each epoch.
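The validation sweep simply walks over the saved checkpoints between the start and end steps. The set of steps it evaluates can be derived as follows (a sketch; the function name is ours, not the repo's):

```python
def checkpoints_to_validate(start_epoch, end_epoch, stride):
    """Return the saved-model steps the validation sweep evaluates,
    from start_epoch to end_epoch inclusive, spaced by stride."""
    return list(range(start_epoch, end_epoch + 1, stride))

# e.g. validating every 1000 steps between steps 1000 and 5000:
steps = checkpoints_to_validate(1000, 5000, 1000)
# -> [1000, 2000, 3000, 4000, 5000]
```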

If you want to make some predictions, run

The predicted segmentation results, rendered in color, will be saved in the sampledir set in main.py.
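Coloring a prediction just means mapping each pixel's class index to an RGB value. A minimal sketch of the idea, with an entirely hypothetical palette (the repo's actual colors may differ):

```python
# Hypothetical palette: class index -> RGB triple.
PALETTE = {
    0: (0, 0, 0),        # background -> black
    1: (255, 0, 0),      # class 1 -> red
    2: (0, 255, 0),      # class 2 -> green
}

def colorize(prediction):
    """Map a 2-D list of class indices to a 2-D list of RGB triples."""
    return [[PALETTE[label] for label in row] for row in prediction]

pred = [[0, 1],
        [2, 0]]
colored = colorize(pred)
# colored[0][1] == (255, 0, 0)
```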

If you want to use Dense Transformer Networks, simply fill in add_dtn, dtn_location, and control_points_ratio in the configure function.
